In [None]:
from pathlib import Path
import pickle
import random
import re

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from keras.utils import to_categorical
from cnn_classification import classification_model
from sklearn.model_selection import train_test_split

In [None]:
# Ask TensorFlow for the name of the GPU device it can use; the returned
# string is displayed as this cell's output.
tf.test.gpu_device_name()

In [None]:
# Seed every random number generator the notebook draws from so that a
# Restart & Run All reproduces the same batches and augmentations.
SEED = 2
# The image generator uses the stdlib `random` module for flips and rotations,
# so it has to be seeded alongside NumPy and TensorFlow.
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
# Load the data

# Just load OT2_SC for now
# NOTE(review): this is a machine-specific absolute path; change it when running elsewhere.
DATA_PATH = Path("/Users/sylvi/topo_data/hariborings/dna_manual_tags/OT2_SC/")
DATA_PATH_IMAGES = DATA_PATH / "images"
assert DATA_PATH_IMAGES.exists(), f"Image directory not found: {DATA_PATH_IMAGES}"

# Class name -> integer id, used to one-hot encode the label parsed from each file name
label_to_int = {
    "churro": 0,
    "dorito": 1,
    "pasty": 2,
}
# Derived from the mapping so the class count can never disagree with it
NUM_CLASSES = len(label_to_int)

tagged_grain_dict_file = DATA_PATH / "OT2_SC_tagged_grains.pkl"
assert tagged_grain_dict_file.exists(), f"Tagged grain file not found: {tagged_grain_dict_file}"

# NOTE(security): pickle.load can execute arbitrary code - only load files from a trusted source.
with open(tagged_grain_dict_file, "rb") as f:
    tagged_grain_dict = pickle.load(f)

print(tagged_grain_dict.keys())
NUM_IMAGES = len(tagged_grain_dict)

# Later cells pick images by position in the sorted list of .npy files, using indexes
# drawn from range(NUM_IMAGES). A mismatch therefore means either an IndexError during
# training or images that are never used, so surface it here.
num_image_files = len(list(DATA_PATH_IMAGES.glob("*.npy")))
if num_image_files != NUM_IMAGES:
    print(
        f"WARNING: tagged_grain_dict has {NUM_IMAGES} entries but "
        f"{num_image_files} .npy files were found in {DATA_PATH_IMAGES}"
    )

In [None]:
def zoom_and_shift(
    image: np.ndarray, ground_truth: np.ndarray, max_zoom_percentage: float = 0.1
) -> "tuple[np.ndarray, np.ndarray]":
    """Randomly crop ("zoom into") an image and its ground truth with a random offset.

    A zoom fraction is drawn uniformly between 0 and ``max_zoom_percentage`` and
    converted to a margin of ``int(image.shape[0] * zoom)`` pixels. The crop window
    has that margin removed from every side and is then shifted along each axis by
    a random offset in ``[-margin, margin)``, which keeps the window inside the
    image. The same window is applied to both arrays.

    Parameters
    ----------
    image : np.ndarray
        Array to crop; its first two axes are cropped.
    ground_truth : np.ndarray
        Array cropped with exactly the same window as ``image``.
    max_zoom_percentage : float
        Upper bound of the zoom fraction (0.1 means up to 10% of the first axis
        is removed from each side).

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        The cropped image and the cropped ground truth. When the drawn margin is
        zero pixels the inputs are returned unchanged.
    """
    # Choose a zoom fraction and calculate the number of pixels to trim per side
    zoom = np.random.uniform(0, max_zoom_percentage)
    zoom_pixels = int(image.shape[0] * zoom)

    # Never trim so much that the crop would be empty along either axis
    largest_margin = (min(image.shape[0], image.shape[1]) - 1) // 2
    zoom_pixels = min(zoom_pixels, largest_margin)

    # No zoom means no room to shift either: hand the inputs back untouched
    if zoom_pixels <= 0:
        return image, ground_truth

    # randint excludes the upper bound, so the slice end below is always negative
    # and the window can never run off the far edge of the array
    shift_x = np.random.randint(-zoom_pixels, zoom_pixels)
    shift_y = np.random.randint(-zoom_pixels, zoom_pixels)

    rows = slice(zoom_pixels + shift_x, -zoom_pixels + shift_x)
    cols = slice(zoom_pixels + shift_y, -zoom_pixels + shift_y)

    return image[rows, cols], ground_truth[rows, cols]


# An image generator that loads images as they are needed
def image_generator(image_indexes, batch_size=4, return_index=False, image_size=256):
    """Endlessly yield batches of randomly augmented images with one-hot labels.

    Each batch is built by sampling ``batch_size`` entries (with replacement) from
    ``image_indexes``, loading the corresponding ``.npy`` file from
    ``DATA_PATH_IMAGES`` and then: padding it to a square, applying a random
    zoom/shift, resizing to ``image_size`` x ``image_size``, min-max normalising,
    flipping along axis 1 half of the time and rotating by a random multiple of
    90 degrees.

    Parameters
    ----------
    image_indexes : sequence of int
        Positions into the alphabetically sorted list of ``.npy`` files.
    batch_size : int
        Number of images per yielded batch.
    return_index : bool
        When True, each yielded tuple also carries the image numbers parsed from
        the file names.
    image_size : int
        Side length, in pixels, that every image is resized to.

    Yields
    ------
    tuple
        ``(batch_x, batch_y)`` or, when ``return_index`` is True,
        ``(batch_x, batch_y, real_image_indexes)``. Both arrays are float32;
        ``batch_y`` holds the one-hot labels produced by ``to_categorical``.

    Raises
    ------
    ValueError
        If a file name does not follow the ``image_{number}_{label}`` pattern.
    """
    # Listing and sorting the directory does not depend on the batch, so do it once
    image_paths = sorted(DATA_PATH_IMAGES.glob("*.npy"))

    while True:
        # Select files (paths/indices) for the batch
        batch_image_indexes = np.random.choice(a=image_indexes, size=batch_size)
        batch_input = []
        batch_output = []
        real_image_indexes = []

        # Load the image and ground truth
        for index in batch_image_indexes:
            # Get the nth image after sorting the files
            image_path = image_paths[index]
            image = np.load(image_path)

            # File stems look like "image_{number}_{label}": the number is the real
            # image index and the trailing text is the class name
            name_match = re.search(r"image_(\d+)_(\w+)", image_path.stem)
            if name_match is None:
                raise ValueError(f"Unexpected image file name: {image_path.name}")
            image_index = int(name_match.group(1))

            # Convert the label to a one-hot vector
            label = to_categorical(label_to_int[name_match.group(2)], num_classes=NUM_CLASSES)

            # Pad the shorter dimension so the image is square. The padding is split
            # as evenly as possible; an odd difference puts the spare pixel at the end
            # so the result is exactly square.
            diff = abs(image.shape[0] - image.shape[1])
            if diff:
                pad_before = diff // 2
                pad_after = diff - pad_before
                if image.shape[0] < image.shape[1]:
                    image = np.pad(image, ((pad_before, pad_after), (0, 0)), mode="constant")
                else:
                    image = np.pad(image, ((0, 0), (pad_before, pad_after)), mode="constant")

            # Randomly zoom and shift the image (there is no separate ground truth
            # here, so the image is passed twice and the second result is discarded)
            image, _ = zoom_and_shift(image=image, ground_truth=image, max_zoom_percentage=0.15)

            # Resize to image_size x image_size
            image = Image.fromarray(image)
            image = image.resize((image_size, image_size), resample=Image.BILINEAR)
            image = np.array(image)

            # Normalise the image to the range [0, 1]. A constant image has a peak of
            # zero after the shift; leave it as all zeros instead of dividing by zero.
            image = image - np.min(image)
            peak = np.max(image)
            if peak > 0:
                image = image / peak

            # Augment the image
            # Flip the images 50% of the time
            if random.choice([0, 1]) == 1:
                image = np.flip(image, axis=1)
            # Rotate the images by either 0, 90, 180, or 270 degrees
            rotation = random.choice([0, 1, 2, 3])
            image = np.rot90(image, rotation)

            batch_input.append(image)
            batch_output.append(label)
            real_image_indexes.append(image_index)

        batch_x = np.array(batch_input).astype(np.float32)
        batch_y = np.array(batch_output).astype(np.float32)

        # Yield exactly one tuple per batch so consumers always see a consistent shape
        if return_index:
            yield (batch_x, batch_y, real_image_indexes)
        else:
            yield (batch_x, batch_y)

In [None]:
# Sanity-check the generator: draw one batch from the first five image
# indexes, then print a few statistics and show every image with its label.
batch_generator = image_generator([0, 1, 2, 3, 4], batch_size=4, return_index=True)
batch_x, batch_y, real_image_indexes = next(batch_generator)

for image, label, real_image_index in zip(batch_x, batch_y, real_image_indexes):
    print(
        f"image shape: {image.shape}",
        f"image max: {np.max(image)}",
        f"image min: {np.min(image)}",
        f"label: {label}",
        f"real_image_index: {real_image_index}",
        sep="\n",
    )

    fig, ax = plt.subplots(1, 1, figsize=(10, 30))
    ax.imshow(image)
    # Map the one-hot vector back to its class name via the position of the 1
    hot_position = list(label).index(1)
    text_label = list(label_to_int.keys())[hot_position]
    ax.set_title(f"label: {label}  {text_label}")
    plt.show()

In [None]:
# Hyperparameters read by the model-building, data-splitting and training cells
IMG_SIZE = 256  # passed to classification_model as img_size
BATCH_SIZE = 5  # images per batch for both generators; also used to compute the step counts for fit
EPOCHS = 25  # number of epochs passed to fit
AUGMENTATION_FACTOR = 8  # NOTE(review): not referenced by any visible cell - confirm whether it is still needed
LEARNING_RATE = 0.001  # passed to classification_model as learning_rate
TRAIN_TEST_SPLIT = 0.1  # passed to train_test_split as test_size

In [None]:
# Partition the image indexes into a training set and a held-out validation set
train_image_indexes, validation_image_indexes = train_test_split(
    range(0, NUM_IMAGES),
    test_size=TRAIN_TEST_SPLIT,
    random_state=SEED,
)

# Report how large each split is and exactly which indexes ended up in it
print(
    f"Number of training images: {len(train_image_indexes)}",
    f"Number of validation images: {len(validation_image_indexes)}",
    f"Training image indexes: {train_image_indexes}",
    f"Validation image indexes: {validation_image_indexes}",
    sep="\n",
)

# One endless batch generator per split
train_generator = image_generator(train_image_indexes, batch_size=BATCH_SIZE)
validation_generator = image_generator(validation_image_indexes, batch_size=BATCH_SIZE)

In [None]:
# Build the classifier with the project-local classification_model helper
# (presumably a compiled Keras model - it is used below via summary/fit/predict)
# and display its summary.
class_model = classification_model(img_size=IMG_SIZE, classes=NUM_CLASSES, learning_rate=LEARNING_RATE)
class_model.summary()

In [None]:
# Train the classifier. Both generators are endless, so the number of batches
# per epoch has to be given explicitly. Floor division gives 0 whenever a split
# holds fewer than BATCH_SIZE images (easy to hit with the small validation
# split), which would leave an epoch with no batches to run, so both counts are
# clamped to at least one step.
history = class_model.fit(
    train_generator,
    steps_per_epoch=max(1, len(train_image_indexes) // BATCH_SIZE),
    validation_data=validation_generator,
    validation_steps=max(1, len(validation_image_indexes) // BATCH_SIZE),
    epochs=EPOCHS,
    verbose=1,
)

In [None]:
# Plot the training curves: loss on the left panel, accuracy on the right
fig, ax = plt.subplots(1, 2, figsize=(30, 8))

loss = history.history["loss"]
val_loss = history.history["val_loss"]
epochs = range(1, len(loss) + 1)
ax[0].plot(epochs, loss, "y", label="Training loss")
ax[0].plot(epochs, val_loss, "r", label="Validation loss")
ax[0].set_title("Training and validation loss")
ax[0].set_xlabel("Epochs")
ax[0].set_ylabel("Loss")
ax[0].legend()

# List the recorded metric names so the keys used below can be checked
print(history.history.keys())

acc = history.history["accuracy"]
val_acc = history.history["val_accuracy"]

ax[1].plot(epochs, acc, "y", label="Training acc")
ax[1].plot(epochs, val_acc, "r", label="Validation acc")
ax[1].set_title("Training and validation accuracy")
ax[1].set_xlabel("Epochs")
ax[1].set_ylabel("Accuracy")

# IoU curves are only drawn when the training history recorded an "iou" metric
if "iou" in history.history:
    ax[1].plot(epochs, history.history["iou"], "b", label="Training iou")
    ax[1].plot(epochs, history.history["val_iou"], "g", label="Validation iou")

ax[1].legend()

In [None]:
# Show the results of the model on one batch drawn from the validation generator
# Get the next batch from the generator
(batch_x, batch_y) = next(validation_generator)
# Predict the classes (one row of per-class scores for every image in the batch)
predicted_class = class_model.predict(batch_x)

print(predicted_class)
print(batch_x.shape)

# Class names ordered as in label_to_int, so a score position maps to a name
class_names = list(label_to_int.keys())

# Plot the results
# squeeze=False makes subplots return an array of axes even for a single image,
# so indexing works for any batch size; ravel flattens it back to one dimension.
fig, ax = plt.subplots(len(batch_x), 1, figsize=(10, 30), squeeze=False)
ax = ax.ravel()
for index, (image, ground_truth_label, predicted_label) in enumerate(zip(batch_x, batch_y, predicted_class)):
    ax[index].imshow(image)
    # The ground truth is one-hot, so its largest entry marks the true class
    text_label = class_names[np.argmax(ground_truth_label)]
    text_predicted = class_names[np.argmax(predicted_label)]
    # Format every class score instead of assuming exactly three classes
    scores = " ".join(f"{score:.2f}" for score in predicted_label)
    title = f"predicted: {text_predicted.upper()} {scores} \nground truth: {text_label.upper()} {ground_truth_label}"
    # Green title for a correct prediction, red for a wrong one
    title_colour = "green" if text_label == text_predicted else "red"
    ax[index].set_title(title, color=title_colour)
plt.show()