In [None]:
from keras.utils import normalize, to_categorical
import os
import cv2
from PIL import Image
import numpy as np
from matplotlib import pyplot as plt
from unet import unet_model
import random
from pathlib import Path
import logging
import itertools
from datetime import datetime

from skimage.morphology import binary_erosion
from skimage.morphology import skeletonize
from skimage.filters import hessian
from skimage.feature import hessian_matrix, hessian_matrix_eigvals
from skimage.morphology import label
from skimage.measure import regionprops
from skimage.color import label2rgb

import tensorflow as tf
from sklearn.model_selection import train_test_split

# Root logger (logging.getLogger() with no name).
LOGGER = logging.getLogger()

# Only echo environment variables that influence the TensorFlow / GPU set-up.
# Dumping the whole environment into saved notebook output can leak credentials,
# API tokens and other secrets.
_RELEVANT_ENV_PREFIXES = ("CUDA", "TF_", "XLA_", "NVIDIA", "KERAS", "LD_LIBRARY_PATH")
for key, value in sorted(os.environ.items()):
    if key.startswith(_RELEVANT_ENV_PREFIXES):
        print(f"{key} : {value}")

In [None]:
# Report the GPU TensorFlow will use. tf.test.gpu_device_name() returns an empty
# string when no GPU is visible, in which case training silently runs on the CPU.
gpu_device_name = tf.test.gpu_device_name()
if not gpu_device_name:
    print("WARNING: no GPU device found - TensorFlow will run on the CPU")
gpu_device_name

In [None]:
# Set the random seeds. image_generator samples indexes with numpy but picks its
# flip / rotation augmentation with Python's `random` module, so all three RNGs
# (random, numpy, TensorFlow) must be seeded for reproducible runs.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
# Data locations. The default root is the original author's macOS layout; set the
# TOPO_DATA_DIR environment variable to run the notebook on another machine (e.g. Linux).
TOPO_DATA_DIR = Path(os.environ.get("TOPO_DATA_DIR", "/Users/sylvi/topo_data"))
ORIGINAL_IMAGE_DIR = TOPO_DATA_DIR / "cats/training_data/cropped/cropped_images"
MASK_DIR = TOPO_DATA_DIR / "cats/training_data/cropped/cropped_labels"
MODEL_SAVE_DIR = TOPO_DATA_DIR / "cats/saved_models"

# Number of input channels given to the U-Net.
CHANNELS = 1

# Number of samples per batch; also used to derive the steps per epoch.
BATCH_SIZE = 5

# Count the .npy image and mask files available on disk.
NUM_IMAGES = len(list(ORIGINAL_IMAGE_DIR.glob("image_*.npy")))
NUM_MASKS = len(list(MASK_DIR.glob("mask_*.npy")))
print(f"Number of images: {NUM_IMAGES}, Number of masks: {NUM_MASKS}")

# image_generator loads mask_<i>.npy for every image_<i>.npy, so fail fast here
# rather than deep inside model.fit.
assert NUM_IMAGES > 0, f"No image_*.npy files found in {ORIGINAL_IMAGE_DIR}"
assert NUM_IMAGES == NUM_MASKS, "Expected exactly one mask_<i>.npy per image_<i>.npy"

In [None]:
# Visual sanity check of the first image / mask pair, exactly as stored on disk.
img = np.load(ORIGINAL_IMAGE_DIR / "image_0.npy")
msk = np.load(MASK_DIR / "mask_0.npy")
print(f"image shape: {img.shape}, mask shape: {msk.shape}")
# The label values decide whether masks should be binarised in image_generator.
print(f"mask unique values: {np.unique(msk)}")

fig, (ax_img, ax_msk) = plt.subplots(1, 2, figsize=(10, 5))
ax_img.imshow(img, cmap="gray")
ax_img.set_title("image_0")
ax_msk.imshow(msk, cmap="gray")
ax_msk.set_title("mask_0")
plt.show()

In [None]:
def _load_normalised_image(index, image_size):
    """Load image_<index>.npy, resize it to image_size x image_size and rescale to [0, 1]."""
    image = np.load(ORIGINAL_IMAGE_DIR / f"image_{index}.npy")
    image = np.array(Image.fromarray(image).resize((image_size, image_size)))
    # Min-max normalisation. A constant image has zero range; leave it at all zeros
    # instead of dividing 0 / 0 and feeding NaNs to the network.
    image = image - np.min(image)
    value_range = np.max(image)
    if value_range > 0:
        image = image / value_range
    return image


def _load_mask(index, image_size, binarise):
    """Load mask_<index>.npy and resize it with nearest-neighbour so labels are never blended."""
    ground_truth = np.load(MASK_DIR / f"mask_{index}.npy")
    if binarise:
        # Every non-zero label becomes 1.
        ground_truth = ground_truth.astype(bool)
    # NOTE(review): the uint8 cast assumes label values fit in 0..255 - confirm for multiclass masks.
    ground_truth = Image.fromarray(ground_truth.astype(np.uint8))
    ground_truth = ground_truth.resize((image_size, image_size), resample=Image.NEAREST)
    return np.array(ground_truth).astype(int)


def _augment_pair(image, ground_truth):
    """Apply the same random horizontal flip (50%) and 0/90/180/270-degree rotation to both arrays."""
    if random.choice([0, 1]) == 1:
        image = np.flip(image, axis=1)
        ground_truth = np.flip(ground_truth, axis=1)
    rotation = random.choice([0, 1, 2, 3])
    return np.rot90(image, rotation), np.rot90(ground_truth, rotation)


# An image generator that loads images as they are needed
def image_generator(image_indexes, batch_size=4, image_size=512, binarise_masks=True):
    """Yield an endless stream of randomly sampled, augmented (image, mask) batches.

    Each batch draws ``batch_size`` indexes (with replacement) from ``image_indexes``.
    For each index it loads ``image_<i>.npy`` from ORIGINAL_IMAGE_DIR and ``mask_<i>.npy``
    from MASK_DIR. Both are resized to ``image_size`` x ``image_size``. The image is
    min-max rescaled to [0, 1]. Image and mask then get the same random flip and rotation.

    Args:
        image_indexes: Sequence of integer file indexes to sample from.
        batch_size: Number of samples in each yielded batch.
        image_size: Side length in pixels that images and masks are resized to.
        binarise_masks: If True (the original behaviour), every non-zero mask value
            becomes 1. Pass False to keep the integer class labels, e.g. when training
            the multiclass (background / ring / gem) model.

    Yields:
        ``(batch_x, batch_y)`` float32 arrays, presumably of shape
        ``(batch_size, image_size, image_size)`` when the stored arrays are 2-D (TODO confirm).
    """
    while True:
        batch_image_indexes = np.random.choice(a=image_indexes, size=batch_size)
        batch_input = []
        batch_output = []

        for index in batch_image_indexes:
            image = _load_normalised_image(index, image_size)
            ground_truth = _load_mask(index, image_size, binarise_masks)
            image, ground_truth = _augment_pair(image, ground_truth)
            batch_input.append(image)
            batch_output.append(ground_truth)

        batch_x = np.array(batch_input).astype(np.float32)
        batch_y = np.array(batch_output).astype(np.float32)

        yield (batch_x, batch_y)

In [None]:
# Check that the generator is doing the right thing on (up to) the first five images.
batch_generator = image_generator(list(range(min(5, NUM_IMAGES))), batch_size=4)
(batch_x, batch_y) = next(batch_generator)
for image, mask in zip(batch_x, batch_y):
    print(f"image shape: {image.shape}, min: {np.min(image)}, max: {np.max(image)}")
    print(f"mask shape: {mask.shape}, dtype: {mask.dtype}, unique: {np.unique(mask)}")

    fig, (ax_image, ax_mask) = plt.subplots(1, 2, figsize=(10, 5))
    ax_image.imshow(image, cmap="gray")
    ax_image.set_title("Generated image")
    ax_mask.imshow(mask)
    ax_mask.set_title("Generated mask")
    plt.show()

In [None]:
# Split which image indexes are used for training and which for validation.
train_image_indexes, validation_image_indexes = train_test_split(
    range(0, NUM_IMAGES), test_size=0.2, random_state=SEED
)

print(f"Number of training images: {len(train_image_indexes)}")
print(f"Number of validation images: {len(validation_image_indexes)}")
print(f"Training image indexes: {train_image_indexes}")
print(f"Validation image indexes: {validation_image_indexes}")

# Use BATCH_SIZE (not a hard-coded 4) so the batches drawn agree with the
# steps-per-epoch arithmetic done at fit time.
train_generator = image_generator(train_image_indexes, batch_size=BATCH_SIZE)
validation_generator = image_generator(validation_image_indexes, batch_size=BATCH_SIZE)

In [None]:
# Build the U-Net. Its spatial size has to match the 512 x 512 that
# image_generator resizes every sample to.
model = unet_model(
    IMG_HEIGHT=512,
    IMG_WIDTH=512,
    IMG_CHANNELS=CHANNELS,
)

In [None]:
EPOCHS = 100
# How many random batches to draw per "epoch" relative to the size of each split.
# The generators sample with replacement and randomly flip / rotate (2 x 4 = 8
# variants), which is presumably why this is 8.
AUGMENTATION_FACTOR = 8

# Derive step counts from the split each generator actually draws from, not from the
# whole dataset. max(1, ...) keeps at least one step when a split is smaller than a batch.
steps_per_epoch = max(1, len(train_image_indexes) // BATCH_SIZE) * AUGMENTATION_FACTOR
validation_steps = max(1, len(validation_image_indexes) // BATCH_SIZE) * AUGMENTATION_FACTOR

history = model.fit(
    train_generator,
    # Batches drawn from the (infinite) generator before one epoch is declared finished.
    steps_per_epoch=steps_per_epoch,
    epochs=EPOCHS,
    validation_data=validation_generator,
    # Batches drawn from the validation generator at the end of every epoch.
    validation_steps=validation_steps,
    verbose=1,
)

In [None]:
def _plot_train_val(metrics, name):
    """Plot the training curve for ``name`` and, if Keras recorded it, its validation curve."""
    epochs = range(1, len(metrics[name]) + 1)
    fig, ax = plt.subplots()
    ax.plot(epochs, metrics[name], "y", label=f"Training {name}")
    val_name = f"val_{name}"
    if val_name in metrics:
        ax.plot(epochs, metrics[val_name], "r", label=f"Validation {name}")
    ax.set_title(f"Training and validation {name}")
    ax.set_xlabel("Epochs")
    ax.set_ylabel(name)
    ax.legend()
    plt.show()


print(history.history.keys())

# Plot every recorded metric (loss, accuracy, mean_io_u, ...) rather than hard-coding
# keys that only exist for some compile() configurations.
for metric_name in history.history:
    if not metric_name.startswith("val_"):
        _plot_train_val(history.history, metric_name)

In [None]:
# Save the model under a timestamped name so earlier runs are not overwritten.
# Create the directory first: model.save does not create missing parent directories.
MODEL_SAVE_DIR.mkdir(parents=True, exist_ok=True)
model_path = MODEL_SAVE_DIR / f"haribonet_multiclass_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}.h5"
model.save(model_path)
print(f"Saved model to {model_path}")

In [None]:
# Show the model's predictions on one validation batch next to the ground truth.
(batch_x, batch_y) = next(validation_generator)
predicted_masks = model.predict(batch_x)
print(f"predicted masks shape: {predicted_masks.shape}")

# Per-channel probability threshold used for the binary class masks.
threshold = 0.5

for image, mask, predicted_mask in zip(batch_x, batch_y, predicted_masks):
    # Channel order as interpreted by this notebook: 0 = background, 1 = ring, 2 = gem.
    # NOTE(review): confirm against unet_model's output layer.
    predicted_background_mask = predicted_mask[:, :, 0] > threshold
    predicted_ring_mask = predicted_mask[:, :, 1] > threshold
    predicted_gem_mask = predicted_mask[:, :, 2] > threshold

    # Label image from the thresholded masks: 0 = neither, 2 = gem, 1 = ring. Ring is
    # written last, so it wins wherever both ring and gem exceed the threshold.
    combined_predicted_mask = np.zeros_like(predicted_gem_mask).astype(int)
    combined_predicted_mask[predicted_gem_mask] = 2
    combined_predicted_mask[predicted_ring_mask] = 1

    # Unlike thresholding, argmax assigns every pixel exactly one class, which
    # makes overlaps / gaps in the thresholded version easy to spot.
    argmax_predicted_mask = np.argmax(predicted_mask, axis=-1)

    fig, axs = plt.subplots(3, 3, figsize=(12, 12))
    axs[0, 0].imshow(image)
    axs[0, 0].set_title("Image")
    axs[0, 1].imshow(predicted_gem_mask)
    axs[0, 1].set_title("Predicted Gem Mask")
    axs[0, 2].imshow(combined_predicted_mask)
    axs[0, 2].set_title("Combined Predicted Mask")
    axs[1, 0].imshow(predicted_ring_mask)
    axs[1, 0].set_title("Predicted Ring Mask")
    axs[1, 1].imshow(predicted_background_mask)
    axs[1, 1].set_title("Predicted Background Mask")
    axs[1, 2].imshow(argmax_predicted_mask)
    axs[1, 2].set_title("Argmax Predicted Mask")
    axs[2, 0].imshow(predicted_mask)
    axs[2, 0].set_title("Predicted Mask")
    axs[2, 1].imshow(mask)
    axs[2, 1].set_title("Ground Truth")
    # No panel for this grid cell; hide its empty frame.
    axs[2, 2].axis("off")
    fig.tight_layout()
    plt.show()
    # Release the figure so looping over many samples does not accumulate memory.
    plt.close(fig)