In [None]:
from keras.utils import normalize
import os
import cv2
from PIL import Image
import numpy as np
from matplotlib import pyplot as plt
from unet import unet_model
import random
from pathlib import Path
import logging
import itertools

# getLogger() is called without a name, so LOGGER is the root logger
LOGGER = logging.getLogger()
from datetime import datetime

from skimage.morphology import binary_erosion
from skimage.morphology import skeletonize
from skimage.filters import hessian
from skimage.feature import hessian_matrix, hessian_matrix_eigvals
from skimage.morphology import label
from skimage.measure import regionprops
from skimage.color import label2rgb

import tensorflow as tf
from sklearn.model_selection import train_test_split

from topostats.plottingfuncs import Colormap

# Instantiate the TopoStats Colormap helper and keep the colormap it returns as `cmap`
# NOTE(review): neither name is referenced in the cells visible here — confirm they are still needed
colormap = Colormap()
cmap = colormap.get_cmap()

# for key, value in os.environ.items():
#     print(f"{key} : {value}")

In [None]:
# Ensure that your GPU is working: the cell output is the device name TensorFlow reports
# NOTE(review): presumably an empty string means no GPU was found — confirm against the TF docs
tf.test.gpu_device_name()

In [None]:
# Set the random seeds
# Seed every RNG this notebook can draw from (standard-library `random`, NumPy and
# TensorFlow) so that batch selection, augmentation and weight initialisation are
# repeatable from one run to the next.
SEED = 4
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
# Root of the data tree. Set the TOPO_DATA_DIR environment variable to run the
# notebook on another machine; when it is unset the original location is used.
DATA_ROOT = Path(os.environ.get("TOPO_DATA_DIR", "/Users/sylvi/topo_data/hariborings"))

IMAGE_DIR = DATA_ROOT / "molecule_of_interest_identification" / "training_data" / "images"
MASK_DIR = DATA_ROOT / "molecule_of_interest_identification" / "training_data" / "masks"
MODEL_SAVE_DIR = DATA_ROOT / "saved_models"

# Check they exist, and say which one is missing if not
for directory in (IMAGE_DIR, MASK_DIR, MODEL_SAVE_DIR):
    assert directory.exists(), f"directory not found: {directory}"

# Get the images and masks (sorted so that both lists are in a stable order)
IMAGES = sorted(IMAGE_DIR.glob("*.npy"))
MASKS = sorted(MASK_DIR.glob("*.npy"))

print(f"number of images: {len(IMAGES)} | number of masks: {len(MASKS)}")

# Later cells pair IMAGES[i] with MASKS[i], so a count mismatch means the data is broken
assert len(IMAGES) == len(MASKS), (
    f"image/mask count mismatch: {len(IMAGES)} images vs {len(MASKS)} masks"
)

In [None]:
# blank = np.zeros((512, 512))
# np.save(MASK_DIR / "task-639.npy", blank)
# plt.imsave(MASK_DIR / "task-639.png", blank, cmap="gray")

In [None]:
# Sanity check the images and masks
# Each figure shows one image beside the mask at the same position in the sorted
# file lists; the file names are used as titles so a mispaired image/mask is visible.

for i in range(len(IMAGES)):
    fig, ax = plt.subplots(1, 2, figsize=(10, 10))
    ax[0].imshow(np.load(IMAGES[i]))
    ax[0].set_title(IMAGES[i].name)
    ax[1].imshow(np.load(MASKS[i]))
    ax[1].set_title(MASKS[i].name)

    plt.show()
    # Release the figure so one open figure per image does not accumulate
    plt.close(fig)

In [None]:
# Collect the PNG versions of the images and masks, in sorted path order
IMAGES_PNG = sorted(IMAGE_DIR.glob("image_*.png"))
MASKS_PNG = sorted(MASK_DIR.glob("mask_*.png"))

# Report how many of each were found
print(f"number of images: {len(IMAGES_PNG)} | number of masks: {len(MASKS_PNG)}")

In [None]:
# FILE_EXT = ".png"

# # rename images to the mask name, keeping the path the same
# for index, (image, mask) in enumerate(zip(IMAGES_PNG, MASKS_PNG)):
#     print(f"index: {index} | image: {image} | mask: {mask}")

#     new_image_name = f"image_{index}{FILE_EXT}"
#     new_mask_name = f"mask_{index}{FILE_EXT}"

#     new_image_path = image.parent / new_image_name
#     new_mask_path = mask.parent / new_mask_name

#     print(f"new_image_path: {new_image_path} | new_mask_path: {new_mask_path}")

#     image.rename(new_image_path)
#     mask.rename(new_mask_path)

In [None]:
# An image generator that loads images as they are needed
def image_generator(image_indexes, batch_size=4, image_dir=None, mask_dir=None):
    """Yield an endless stream of randomly drawn, augmented (image, mask) batches.

    Files are read from disk on every iteration, so only one batch is held in
    memory at a time.

    Parameters
    ----------
    image_indexes : sequence of int
        Indexes ``i`` for which ``image_{i}.npy`` and ``mask_{i}.npy`` exist.
    batch_size : int, optional
        Number of pairs per batch. Indexes are drawn with ``np.random.choice``
        using its default of sampling with replacement, so a batch may contain
        the same index more than once.
    image_dir, mask_dir : path-like, optional
        Directories holding the image and mask files. When omitted, the
        module-level ``IMAGE_DIR`` and ``MASK_DIR`` are used.

    Yields
    ------
    tuple of (np.ndarray, np.ndarray)
        ``(batch_x, batch_y)`` as float32 arrays with the batch on the first
        axis. Each image is min-max scaled to [0, 1]; each mask holds only
        0.0 and 1.0.
    """
    image_dir = IMAGE_DIR if image_dir is None else Path(image_dir)
    mask_dir = MASK_DIR if mask_dir is None else Path(mask_dir)

    while True:
        # Select files (paths/indices) for the batch
        batch_image_indexes = np.random.choice(a=image_indexes, size=batch_size)
        batch_input = []
        batch_output = []

        # Load the image and ground truth
        for index in batch_image_indexes:
            # Get the image
            image = np.load(image_dir / f"image_{index}.npy")
            # Normalise the image to [0, 1]
            image = image - np.min(image)
            peak = np.max(image)
            # A constant image is all zeros after the shift; dividing by its
            # maximum would be 0/0 and fill the batch with NaN, so leave it as zeros
            if peak > 0:
                image = image / peak

            # Get the ground truth as a boolean mask (any non-zero value is foreground)
            ground_truth = np.load(mask_dir / f"mask_{index}.npy").astype(bool)

            # Augment the images. NumPy's global RNG is used for every random
            # draw so that np.random.seed controls the whole generator.
            # Flip the images 50% of the time
            if np.random.randint(2) == 1:
                image = np.flip(image, axis=1)
                ground_truth = np.flip(ground_truth, axis=1)
            # Rotate the images by either 0, 90, 180, or 270 degrees
            rotation = np.random.randint(4)
            image = np.rot90(image, rotation)
            ground_truth = np.rot90(ground_truth, rotation)

            batch_input.append(image)
            batch_output.append(ground_truth)

        batch_x = np.array(batch_input).astype(np.float32)
        batch_y = np.array(batch_output).astype(np.float32)

        yield (batch_x, batch_y)

In [None]:
# Check that the generator is doing the right thing
# Draw a single batch and, for each pair, print basic statistics and show the
# image, the mask, and the mask overlaid on the image.
indexes = np.arange(0, len(IMAGES))
batch_generator = image_generator(indexes, batch_size=4)
(batch_x, batch_y) = next(batch_generator)
for image, mask in zip(batch_x, batch_y):
    print(f"image shape: {image.shape}")
    print(f"image max: {np.max(image)}")
    print(f"image min: {np.min(image)}")
    print(f"mask shape: {mask.shape}")
    print(f"mask unique: {np.unique(mask)}")
    print(f"mask dtype: {mask.dtype}")

    # figsize is (width, height): three panels side by side need a wide figure,
    # not a tall one
    fig, ax = plt.subplots(1, 3, figsize=(30, 10))
    ax[0].imshow(image)
    ax[0].set_title("image")
    ax[1].imshow(mask)
    ax[1].set_title("mask")
    ax[2].imshow(image)
    ax[2].imshow(mask, alpha=0.2)
    ax[2].set_title("image with mask")
    plt.show()

In [None]:
# Image geometry handed to the model constructor
IMG_WIDTH = 512
IMG_HEIGHT = 512
CHANNELS = 1

# Training settings
BATCH_SIZE = 3
EPOCHS = 10
LEARNING_RATE = 0.001
# Multiplier applied to the number of batches drawn per epoch
AUGMENTATION_FACTOR = 8
# Used as the test_size fraction when splitting into training and validation sets
TRAIN_TEST_SPLIT = 0.1

In [None]:
# Decide which image indexes feed training and which are held out for validation
train_image_indexes, validation_image_indexes = train_test_split(
    range(len(IMAGES)),
    test_size=TRAIN_TEST_SPLIT,
    random_state=SEED,
)

# Report the size of each split, then the indexes that landed in it
print(
    f"Number of training images: {len(train_image_indexes)}",
    f"Number of validation images: {len(validation_image_indexes)}",
    sep="\n",
)
print(
    f"Training image indexes: {train_image_indexes}",
    f"Validation image indexes: {validation_image_indexes}",
    sep="\n",
)

# One endless batch generator per split
train_generator = image_generator(batch_size=BATCH_SIZE, image_indexes=train_image_indexes)
validation_generator = image_generator(batch_size=BATCH_SIZE, image_indexes=validation_image_indexes)

In [None]:
# Build the U-Net from the configured image geometry and learning rate,
# then print its layer summary
model = unet_model(
    IMG_HEIGHT=IMG_HEIGHT,
    IMG_WIDTH=IMG_WIDTH,
    IMG_CHANNELS=CHANNELS,
    learning_rate=LEARNING_RATE,
)
model.summary()

In [None]:
# Size each pass from the split that is actually being drawn from. Counting every
# image for both generators would make validation run as many batches as training,
# even though the validation split is only a fraction of the data.
# max(1, ...) keeps the step count positive when a split holds fewer images than one batch.
train_steps = max(1, len(train_image_indexes) // BATCH_SIZE) * AUGMENTATION_FACTOR
validation_steps = max(1, len(validation_image_indexes) // BATCH_SIZE) * AUGMENTATION_FACTOR

history = model.fit(
    train_generator,
    # How many steps (batches of samples) to draw from generator before declaring one epoch finished and starting the next epoch
    steps_per_epoch=train_steps,
    epochs=EPOCHS,
    validation_data=validation_generator,
    # How many steps (batches) to yield from validation generator at the end of every epoch
    validation_steps=validation_steps,
    verbose=1,
)

In [None]:
# Plot the per-epoch training history: loss on the left, IoU on the right
fig, ax = plt.subplots(1, 2, figsize=(30, 8))

loss = history.history["loss"]
val_loss = history.history["val_loss"]
epochs = range(1, len(loss) + 1)
ax[0].plot(epochs, loss, "y", label="Training loss")
ax[0].plot(epochs, val_loss, "r", label="Validation loss")
ax[0].set_title("Training and validation loss")
ax[0].set_xlabel("Epochs")
ax[0].set_ylabel("Loss")
ax[0].legend()

# List the recorded metric names. The keys read below must match one of these;
# they depend on the metrics the model was compiled with (e.g. "mean_io_u").
print(history.history.keys())

iou_score = history.history["iou"]
val_iou_score = history.history["val_iou"]

ax[1].plot(epochs, iou_score, "y", label="Training iou")
ax[1].plot(epochs, val_iou_score, "r", label="Validation iou")
ax[1].set_title("Training and validation iou")
ax[1].set_xlabel("Epochs")
ax[1].set_ylabel("iou")
ax[1].legend()


# acc = history.history["accuracy"]
# val_acc = history.history["val_accuracy"]

# ax[1].plot(epochs, acc, "y", label="Training acc")
# ax[1].plot(epochs, val_acc, "r", label="Validation acc")
# ax[1].set_title("Training and validation accuracy")
# ax[1].set_xlabel("Epochs")
# ax[1].set_ylabel("Accuracy")
# ax[1].legend()

# iou_score = history.history["mean_io_u"]
# val_iou_score = history.history["val_mean_io_u"]

In [None]:
# Show the results of the model on the testing set using the validation generator
# Get the next batch from the generator
(batch_x, batch_y) = next(validation_generator)
# Predict the masks
predicted_masks = model.predict(batch_x)
# Threshold the masks
# NOTE(review): 0.01 is a very permissive cut-off — confirm it is intended
threshold = 0.01
predicted_masks = (predicted_masks > threshold).astype(np.uint8)
# Show the results: one row per image in the batch.
# squeeze=False keeps `ax` two-dimensional even when the batch holds a single
# image, so ax[index, column] is always valid.
fig, ax = plt.subplots(len(batch_x), 4, figsize=(20, 10 * len(batch_x)), squeeze=False)
for index, (image, mask, predicted_mask) in enumerate(zip(batch_x, batch_y, predicted_masks)):
    # Plot the images in a figure

    ax[index, 0].imshow(image)
    ax[index, 0].set_title("Image")
    ax[index, 1].imshow(mask)
    ax[index, 1].set_title("True Mask")
    ax[index, 2].imshow(predicted_mask)
    ax[index, 2].set_title(f"Predicted Mask (thresholded at {threshold})")
    ax[index, 3].imshow(image)
    ax[index, 3].imshow(predicted_mask, alpha=0.2)
    ax[index, 3].set_title("Image with predicted mask")

fig.tight_layout()
plt.show()