In [None]:
from keras.utils import normalize
import os
import cv2
from PIL import Image
import numpy as np
from matplotlib import pyplot as plt
from unet import unet_model
import random
from pathlib import Path
import logging
import itertools

LOGGER = logging.getLogger()
from datetime import datetime

from skimage.morphology import binary_erosion
from skimage.morphology import skeletonize
from skimage.filters import hessian
from skimage.feature import hessian_matrix, hessian_matrix_eigvals
from skimage.morphology import label
from skimage.measure import regionprops
from skimage.color import label2rgb

import tensorflow as tf
from sklearn.model_selection import train_test_split

# for key, value in os.environ.items():
#     print(f"{key} : {value}")

In [None]:
# Ensure that your GPU is working: the cell output is the device name reported by TensorFlow
# (expected to be an empty string when no GPU is visible to TensorFlow — see the TensorFlow docs).
tf.test.gpu_device_name()

In [None]:
# Set the random seeds for every random number generator used in this notebook.
# The image generator draws from both NumPy (`np.random`) and Python's built-in
# `random` module (flip / rotation augmentation), so both have to be seeded for the
# augmentations to be reproducible.
SEED = 5
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)
# NOTE: the normalisation bounds for the molecules (in nm) are set in the configuration cell below.

In [None]:
# Dataset locations. Alternative datasets are kept below; uncomment the pair to train on.
# MacOS
# ORIGINAL_IMAGE_DIR = Path(
#     "/Users/sylvi/topo_data/hariborings/training_data/cropped/dna_only/images_256/"
# )
# MASK_DIR = Path("/Users/sylvi/topo_data/hariborings/training_data/cropped/dna_only/masks_256/")
# ORIGINAL_IMAGE_DIR = Path("/Users/sylvi/topo_data/hariborings/training_data/cropped/dna_cas9/images_256/")
# MASK_DIR = Path("/Users/sylvi/topo_data/hariborings/training_data/cropped/dna_cas9/masks_256_ring/")

# DNA ONLY SHARPER
ORIGINAL_IMAGE_DIR = Path("/Users/sylvi/topo_data/hariborings/training_data/cropped/dna_only/images_extra_doritos_256/")
MASK_DIR = Path("/Users/sylvi/topo_data/hariborings/training_data/cropped/dna_only/masks_extra_doritos_256/")

# Upper and lower bounds for normalisation: pixel values are clipped to this range
# and then rescaled to [0, 1] by the image generator.
NORM_UPPER_BOUND = 5
NORM_LOWER_BOUND = -1

# Directory the trained model is written to (created if it does not exist yet).
# MODEL_SAVE_DIR = Path("./saved_models")
# MacOS
# MODEL_SAVE_DIR = Path("/Users/sylvi/topo_data/hariborings/training_data/cropped/saved_models/dna_cas9_ring")
MODEL_SAVE_DIR = Path("/Users/sylvi/topo_data/hariborings/saved_models/dna_only_extra_doritos/")
MODEL_SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Count the .npy arrays available in the image and mask directories
NUM_IMAGES = len(list(ORIGINAL_IMAGE_DIR.glob("*.npy")))
NUM_MASKS = len(list(MASK_DIR.glob("*.npy")))
print(f"Number of images: {NUM_IMAGES}, number of masks: {NUM_MASKS}")

# Images and masks are loaded in pairs by index (image_<i>.npy / mask_<i>.npy), so a
# differing file count means some pairs are probably incomplete. Warn instead of failing so
# that directories holding a few extra files can still be used.
if NUM_IMAGES != NUM_MASKS:
    LOGGER.warning(
        "Image/mask count mismatch (%d images vs %d masks): images and masks are paired by index "
        "when loading, so some pairs may be missing.",
        NUM_IMAGES,
        NUM_MASKS,
    )

In [None]:
def zoom_and_shift(image: np.ndarray, ground_truth: np.ndarray, max_zoom_percentage: float = 0.1) -> tuple:
    """Randomly crop ("zoom into") an image and its ground truth mask with a random offset.

    A zoom fraction is drawn uniformly from ``[0, max_zoom_percentage)`` and converted to a
    number of pixels using the image height (``image.shape[0]``). That many pixels are
    removed from both ends of each of the two leading axes, and the crop window is then
    shifted along each axis by a random offset of at most that many pixels, so it always
    stays inside the array. Image and ground truth are cropped with the same window.
    The result is NOT resized back to the input size; that is left to the caller.

    Parameters
    ----------
    image : np.ndarray
        Image to crop. The first two axes are cropped.
    ground_truth : np.ndarray
        Mask to crop with the identical window (expected to share the image's leading
        two dimensions).
    max_zoom_percentage : float
        Upper bound of the zoom, as a fraction of the image height removed from each side.
        Must satisfy ``0 <= max_zoom_percentage < 0.5``; larger values could remove the
        whole image.

    Returns
    -------
    tuple
        ``(cropped_image, cropped_ground_truth)``. The inputs are returned unchanged when
        the drawn zoom amounts to less than one pixel.

    Raises
    ------
    ValueError
        If ``max_zoom_percentage`` is outside ``[0, 0.5)``.
    """
    if not 0 <= max_zoom_percentage < 0.5:
        raise ValueError(f"max_zoom_percentage must be in the range [0, 0.5), got {max_zoom_percentage}")

    # Choose a zoom fraction (honouring the caller's maximum) and convert it to pixels
    zoom = np.random.uniform(0, max_zoom_percentage)
    zoom_pixels = int(image.shape[0] * zoom)

    # Less than one pixel of zoom: nothing to crop or shift
    if zoom_pixels <= 0:
        return image, ground_truth

    # randint excludes the upper bound, so the stop index of each slice below stays
    # strictly negative and the window never leaves the array.
    shift_x = np.random.randint(-zoom_pixels, zoom_pixels)
    shift_y = np.random.randint(-zoom_pixels, zoom_pixels)

    # Same crop window for the image and its ground truth
    rows = slice(zoom_pixels + shift_x, -zoom_pixels + shift_x)
    cols = slice(zoom_pixels + shift_y, -zoom_pixels + shift_y)

    return image[rows, cols], ground_truth[rows, cols]


def image_generator(image_indexes, batch_size=4):
    """Endlessly yield augmented (images, masks) batches, loading the arrays from disk on demand.

    For every batch, `batch_size` indexes are sampled from `image_indexes` with
    `np.random.choice` (NumPy's default, i.e. with replacement). For each sampled index the
    generator:

    1. loads `image_<index>.npy` from ORIGINAL_IMAGE_DIR and `mask_<index>.npy` from MASK_DIR
       (the mask is cast to bool),
    2. applies `zoom_and_shift` to the pair, passing max_zoom_percentage=0.2,
    3. resizes both to 256x256 (bilinear for the image, nearest neighbour for the mask),
    4. clips the image to [NORM_LOWER_BOUND, NORM_UPPER_BOUND] and rescales it to [0, 1],
    5. flips both along axis 1 half of the time and rotates both by a random multiple of
       90 degrees.

    Yields
    ------
    tuple
        `(batch_x, batch_y)`: the stacked images and masks, both cast to float32.
    """
    while True:
        # Pick which files make up this batch
        sampled_indexes = np.random.choice(a=image_indexes, size=batch_size)
        images = []
        masks = []

        for sample_index in sampled_indexes:
            # Load the image / ground truth pair
            sample_image = np.array(np.load(ORIGINAL_IMAGE_DIR / f"image_{sample_index}.npy"))
            sample_mask = np.load(MASK_DIR / f"mask_{sample_index}.npy").astype(bool)

            # Random zoom and shift, applied identically to image and mask
            sample_image, sample_mask = zoom_and_shift(
                image=sample_image, ground_truth=sample_mask, max_zoom_percentage=0.2
            )

            # Bring both back to 256x256; nearest neighbour keeps the mask free of interpolated values
            sample_image = np.array(Image.fromarray(sample_image).resize((256, 256), resample=Image.BILINEAR))
            sample_mask = np.array(Image.fromarray(sample_mask).resize((256, 256), resample=Image.NEAREST))

            # Normalise: clip to the fixed bounds, then map that range onto [0, 1]
            sample_image = (np.clip(sample_image, NORM_LOWER_BOUND, NORM_UPPER_BOUND) - NORM_LOWER_BOUND) / (
                NORM_UPPER_BOUND - NORM_LOWER_BOUND
            )

            # Augmentation: flip along axis 1 with a probability of 50 %
            if random.choice([0, 1]) == 1:
                sample_image = np.flip(sample_image, axis=1)
                sample_mask = np.flip(sample_mask, axis=1)

            # Augmentation: rotate by 0, 90, 180 or 270 degrees
            quarter_turns = random.choice([0, 1, 2, 3])
            sample_image = np.rot90(sample_image, quarter_turns)
            sample_mask = np.rot90(sample_mask, quarter_turns)

            images.append(sample_image)
            masks.append(sample_mask)

        yield (np.array(images).astype(np.float32), np.array(masks).astype(np.float32))

In [None]:
# Sanity-check the generator: draw a single batch and inspect every image / mask pair in it
batch_generator = image_generator([0, 1, 2, 3, 4], batch_size=4)
batch_x, batch_y = next(batch_generator)
for image, mask in zip(batch_x, batch_y):
    # Shapes, value range and mask contents of this sample
    print(
        f"image shape: {image.shape}",
        f"image max: {np.max(image)}",
        f"image min: {np.min(image)}",
        f"mask shape: {mask.shape}",
        f"mask unique: {np.unique(mask)}",
        f"mask dtype: {mask.dtype}",
        sep="\n",
    )

    # Three panels side by side: the image, the mask, and the mask overlaid on the image
    fig, ax = plt.subplots(1, 3, figsize=(10, 30))
    image_axis, mask_axis, overlay_axis = ax

    image_axis.imshow(image)
    image_axis.set_title("image")

    mask_axis.imshow(mask)
    mask_axis.set_title("mask")

    overlay_axis.imshow(image)
    overlay_axis.imshow(mask, alpha=0.2)
    overlay_axis.set_title("image with mask")

    plt.show()

In [None]:
# Model input size and number of channels
IMG_WIDTH = 256
IMG_HEIGHT = 256
CHANNELS = 1

# Training hyperparameters
BATCH_SIZE = 25
EPOCHS = 45
LEARNING_RATE = 1e-3
LOSS_FUNCTION = "binary_crossentropy"

# Multiplier applied to the number of batches drawn per epoch when fitting the model
AUGMENTATION_FACTOR = 8

# Fraction of the images held out for validation (passed as test_size to train_test_split)
TRAIN_TEST_SPLIT = 0.1

In [None]:
# Decide which image indexes are used for training and which are held out for validation.
# The split is seeded with SEED so that it is the same on every run.
train_image_indexes, validation_image_indexes = train_test_split(
    range(NUM_IMAGES),
    test_size=TRAIN_TEST_SPLIT,
    random_state=SEED,
)

# Report the size and the members of each split
print(
    f"Number of training images: {len(train_image_indexes)}",
    f"Number of validation images: {len(validation_image_indexes)}",
    f"Training image indexes: {train_image_indexes}",
    f"Validation image indexes: {validation_image_indexes}",
    sep="\n",
)

# One generator per split; both produce batches of BATCH_SIZE samples
validation_generator = image_generator(validation_image_indexes, batch_size=BATCH_SIZE)
train_generator = image_generator(train_image_indexes, batch_size=BATCH_SIZE)

In [None]:
# Build the model with `unet_model` from the configured input size and training settings,
# then print its summary
unet_settings = dict(
    IMG_HEIGHT=IMG_HEIGHT,
    IMG_WIDTH=IMG_WIDTH,
    IMG_CHANNELS=CHANNELS,
    learning_rate=LEARNING_RATE,
    loss_function=LOSS_FUNCTION,
)
model = unet_model(**unet_settings)
model.summary()

In [None]:
# Number of batches drawn from a generator per epoch. The generators never run out of data,
# so the length of an epoch has to be stated explicitly. The value is clamped to at least one
# step: with fewer images than BATCH_SIZE the integer division would otherwise give zero
# steps, which is not a meaningful epoch.
# NOTE(review): both training and validation use a count derived from the TOTAL number of
# images, not from the size of their own split — confirm that this is intended.
STEPS_PER_EPOCH = max(1, (NUM_IMAGES // BATCH_SIZE) * AUGMENTATION_FACTOR)

history = model.fit(
    train_generator,
    # How many steps (batches of samples) to draw from generator before declaring one epoch finished and starting the next epoch
    steps_per_epoch=STEPS_PER_EPOCH,
    epochs=EPOCHS,
    validation_data=validation_generator,
    # How many steps (batches) to yield from validation generator at the end of every epoch
    validation_steps=STEPS_PER_EPOCH,
    verbose=1,
)

In [None]:
# Plot the training history: loss on the left, accuracy (and IoU, if recorded) on the right
fig, ax = plt.subplots(1, 2, figsize=(30, 8))

loss = history.history["loss"]
val_loss = history.history["val_loss"]
epochs = range(1, len(loss) + 1)
ax[0].plot(epochs, loss, "y", label="Training loss")
ax[0].plot(epochs, val_loss, "r", label="Validation loss")
ax[0].set_title("Training and validation loss")
ax[0].set_xlabel("Epochs")
ax[0].set_ylabel("Loss")
ax[0].legend()

# Show which metrics were recorded; the keys depend on how the model was compiled
# (earlier versions of this cell read "mean_io_u" / "val_mean_io_u" instead of accuracy).
print(history.history.keys())

acc = history.history["accuracy"]
val_acc = history.history["val_accuracy"]

ax[1].plot(epochs, acc, "y", label="Training acc")
ax[1].plot(epochs, val_acc, "r", label="Validation acc")
ax[1].set_title("Training and validation accuracy")
ax[1].set_xlabel("Epochs")
ax[1].set_ylabel("Accuracy")

# The IoU curves are optional: each one is only drawn when its key was recorded
if "iou" in history.history:
    ax[1].plot(epochs, history.history["iou"], "b", label="Training iou")
if "val_iou" in history.history:
    ax[1].plot(epochs, history.history["val_iou"], "g", label="Validation iou")

ax[1].legend()

In [None]:
# Show the results of the model on held-out data, using the validation generator
# Get the next batch from the generator
(batch_x, batch_y) = next(validation_generator)
# Predict the masks
predicted_masks = model.predict(batch_x)
# Threshold the predictions to obtain binary masks
threshold = 0.5
predicted_masks = (predicted_masks > threshold).astype(np.uint8)
# One row per sample: image, true mask, predicted mask, prediction overlaid on the image.
# squeeze=False keeps `ax` two-dimensional even when the batch holds a single sample, so
# that the ax[row, column] indexing below always works.
fig, ax = plt.subplots(len(batch_x), 4, figsize=(20, 10 * len(batch_x)), squeeze=False)
for index, (image, mask, predicted_mask) in enumerate(zip(batch_x, batch_y, predicted_masks)):
    ax[index, 0].imshow(image)
    ax[index, 0].set_title("Image")
    ax[index, 1].imshow(mask)
    ax[index, 1].set_title("True Mask")
    ax[index, 2].imshow(predicted_mask)
    ax[index, 2].set_title(f"Predicted Mask (thresholded at {threshold})")
    ax[index, 3].imshow(image)
    ax[index, 3].imshow(predicted_mask, alpha=0.2)
    ax[index, 3].set_title("Image with predicted mask")

fig.tight_layout()
plt.show()

In [None]:
# Save the trained model under a descriptive, timestamped file name.
# The name records the time of saving (formatted as YYYY-MM-DD_HH-MM-SS), the image size,
# the number of epochs, the batch size and the learning rate.
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
file_name = (
    f"haribonet_dna_only_single_class_extra_doritos_{timestamp}"
    f"_image-size-{IMG_HEIGHT}x{IMG_WIDTH}"
    f"_epochs-{EPOCHS}"
    f"_batch-size-{BATCH_SIZE}"
    f"_learning-rate-{LEARNING_RATE}.h5"
)

model.save(MODEL_SAVE_DIR / file_name)

print(f"Model saved as {file_name} to {MODEL_SAVE_DIR}")

In [None]:
# Load model

# model = tf.keras.models.load_model(
#     MODEL_SAVE_DIR
#     / "haribonet_single_class_2023-12-19_13-13-45_image-size-256x256_epochs-30_batch-size-32_learning-rate-0.001.h5"
# )