In [None]:
from keras.utils import normalize, to_categorical
import os
import cv2
from PIL import Image
import numpy as np
from matplotlib import pyplot as plt
from unet import unet_model
import random
from pathlib import Path
import logging
import itertools

# Root logger (no name passed), shared by the whole notebook
LOGGER: logging.Logger = logging.getLogger(name=None)
from datetime import datetime

from skimage.filters import gaussian

import tensorflow as tf
from sklearn.model_selection import train_test_split

# Dump the environment (handy for debugging library/driver paths), but redact values whose
# names look like credentials so secrets are never written into saved notebook outputs.
_SECRET_MARKERS = ("KEY", "TOKEN", "SECRET", "PASSWORD", "PASSWD", "CREDENTIAL", "AUTH")
for key, value in os.environ.items():
    if any(marker in key.upper() for marker in _SECRET_MARKERS):
        value = "<redacted>"
    print(f"{key} : {value}")

In [None]:
# Confirm TensorFlow can see a GPU: evaluates to the GPU device name, or an empty string if none is available
gpu_device = tf.test.gpu_device_name()
gpu_device

In [None]:
# Set the random seeds for every RNG this notebook may draw from: Python's `random`
# module, NumPy's global RNG, and TensorFlow's global seed. Seeding only some of them
# leaves the corresponding random choices unreproducible between runs.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
# Training data locations (macOS paths - adjust for other machines)
ORIGINAL_IMAGE_DIR = Path("/Users/sylvi/topo_data/cats/training_data/cropped/cropped_images")
MASK_DIR = Path("/Users/sylvi/topo_data/cats/training_data/cropped/cropped_labels")

# Where trained models and their training histories are written
MODEL_SAVE_DIR = Path("/Users/sylvi/topo_data/cats/saved_models/")
# Number of input channels the U-Net is built with
CHANNELS = 1

BATCH_SIZE = 5
# Image values are clipped to [NORM_LOWER_BOUND, NORM_UPPER_BOUND] and then scaled to [0, 1]
NORM_LOWER_BOUND = -1
NORM_UPPER_BOUND = 7
# Upper bound on the fraction of the image trimmed from each side by the zoom augmentation
MAX_ZOOM_PERCENTAGE = 0.3
# Upper bound on the gaussian blur sigma (only referenced by the disabled blur augmentation)
MAX_GAUSSIAN_BLUR = 3

# Count the .npy image and mask files available for training
NUM_IMAGES = len(list(ORIGINAL_IMAGE_DIR.glob("image_*.npy")))
NUM_MASKS = len(list(MASK_DIR.glob("mask_*.npy")))
print(f"Number of images: {NUM_IMAGES}, Number of masks: {NUM_MASKS}")

# Later cells load image_<i>.npy together with mask_<i>.npy for i in range(NUM_IMAGES),
# so fail fast if the data directories are empty or out of step.
assert NUM_IMAGES > 0, f"No image_*.npy files found in {ORIGINAL_IMAGE_DIR}"
assert NUM_IMAGES == NUM_MASKS, "Image and mask counts differ; check the training data directories"

In [None]:
# Visual sanity check: display one image and its mask exactly as stored on disk
sanity_check_image_index = 74

img = np.load(ORIGINAL_IMAGE_DIR / f"image_{sanity_check_image_index}.npy")
plt.imshow(img, cmap="gray")
plt.title(f"image_{sanity_check_image_index}")
plt.show()
print(img.shape)

msk = np.load(MASK_DIR / f"mask_{sanity_check_image_index}.npy")
plt.imshow(msk, cmap="gray")
plt.title(f"mask_{sanity_check_image_index}")
plt.show()
print(msk.shape)

# zoom_and_shift crops image and mask with the same window, so their spatial shapes must agree
assert img.shape[:2] == msk.shape[:2], "image and mask spatial shapes differ"

In [None]:
def zoom_and_shift(
    image: np.ndarray, ground_truth: np.ndarray, max_zoom_percentage: float = 0.1
) -> tuple[np.ndarray, np.ndarray]:
    """Randomly zoom into an image / ground-truth pair and shift the crop window.

    A zoom fraction is drawn uniformly from ``[0, max_zoom_percentage)``. For each of the
    first two axes, ``int(axis_length * zoom)`` pixels are trimmed from both ends. The
    resulting window is then shifted by a random integer offset in ``[-trim, trim]``, so it
    always stays inside the original array. The same window is applied to ``image`` and
    ``ground_truth``, so the pair stays aligned.

    Parameters
    ----------
    image : np.ndarray
        Array whose first two axes are the spatial dimensions.
    ground_truth : np.ndarray
        Mask with the same first two dimensions as ``image``.
    max_zoom_percentage : float
        Upper bound on the fraction trimmed from each end. Must be in ``[0, 0.5)`` so the
        crop can never be empty.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        The cropped image and ground truth (views into the inputs). The inputs are returned
        unchanged when the drawn zoom trims zero pixels on both axes.

    Raises
    ------
    ValueError
        If ``max_zoom_percentage`` is outside ``[0, 0.5)`` or the spatial shapes differ.
    """
    if not 0 <= max_zoom_percentage < 0.5:
        raise ValueError(f"max_zoom_percentage must be in [0, 0.5), got {max_zoom_percentage}")
    if image.shape[:2] != ground_truth.shape[:2]:
        raise ValueError(
            f"image and ground_truth spatial shapes differ: {image.shape[:2]} vs {ground_truth.shape[:2]}"
        )

    # Choose a zoom fraction and convert it into a per-axis pixel trim, so that
    # non-square inputs are trimmed proportionally on each axis
    zoom = np.random.uniform(0, max_zoom_percentage)
    rows, cols = image.shape[:2]
    trim_rows = int(rows * zoom)
    trim_cols = int(cols * zoom)

    if trim_rows == 0 and trim_cols == 0:
        # No zoom: nothing to crop
        return image, ground_truth

    # randint's upper bound is exclusive; +1 makes the shift range symmetric, [-trim, trim]
    shift_rows = np.random.randint(-trim_rows, trim_rows + 1)
    shift_cols = np.random.randint(-trim_cols, trim_cols + 1)

    # Explicit (non-negative) end indices: a negative end such as `-trim + shift` would be 0,
    # i.e. an empty slice, whenever shift == trim.
    row_window = slice(trim_rows + shift_rows, rows - trim_rows + shift_rows)
    col_window = slice(trim_cols + shift_cols, cols - trim_cols + shift_cols)
    return image[row_window, col_window], ground_truth[row_window, col_window]


# An image generator that loads images as they are needed
def image_generator(image_indexes, batch_size=4, image_size=512):
    """Yield an endless stream of augmented (image, mask) training batches.

    Each batch samples ``batch_size`` indexes from ``image_indexes`` with replacement.
    For each index it loads ``image_<i>.npy`` / ``mask_<i>.npy`` from ``ORIGINAL_IMAGE_DIR``
    / ``MASK_DIR`` and applies, in order:

    - a random zoom & shift;
    - a resize to ``image_size`` x ``image_size``;
    - clipping and scaling of the image to [0, 1];
    - a random horizontal flip;
    - a random rotation by a multiple of 90 degrees.

    Parameters
    ----------
    image_indexes : sequence of int
        Indexes of the image/mask files that may be sampled.
    batch_size : int
        Number of samples per batch.
    image_size : int
        Side length, in pixels, of the square arrays in each batch.

    Yields
    ------
    tuple[np.ndarray, np.ndarray]
        ``(batch_x, batch_y)`` as float32 arrays. Assuming the stored arrays are 2-D, each has
        shape ``(batch_size, image_size, image_size)``. ``batch_y`` holds 0/1 mask values.
    """
    while True:
        # Select files (indices) for the batch, sampled with replacement
        batch_image_indexes = np.random.choice(a=image_indexes, size=batch_size)
        samples = [_load_augmented_sample(index, image_size) for index in batch_image_indexes]

        batch_x = np.array([image for image, _ in samples]).astype(np.float32)
        batch_y = np.array([ground_truth for _, ground_truth in samples]).astype(np.float32)

        yield (batch_x, batch_y)


def _load_augmented_sample(index, image_size):
    """Load one image/mask pair and apply the training augmentations used by ``image_generator``."""
    image = np.load(ORIGINAL_IMAGE_DIR / f"image_{index}.npy")
    ground_truth = np.load(MASK_DIR / f"mask_{index}.npy").astype(bool)

    # Zoom and shift both arrays with the same window
    image, ground_truth = zoom_and_shift(image, ground_truth, max_zoom_percentage=MAX_ZOOM_PERCENTAGE)

    # Resize the image with PIL's default filter. This DOES interpolate (bicubic for float
    # images in recent Pillow versions); any overshoot is removed by the clip below.
    image = np.array(Image.fromarray(image).resize((image_size, image_size)))

    # Optional blur augmentation, currently disabled:
    # gaussian_size = np.random.random() * MAX_GAUSSIAN_BLUR
    # image = gaussian(image, sigma=gaussian_size)

    # Normalise: clip to the configured range, then scale to [0, 1]
    image = np.clip(image, NORM_LOWER_BOUND, NORM_UPPER_BOUND)
    image = (image - NORM_LOWER_BOUND) / (NORM_UPPER_BOUND - NORM_LOWER_BOUND)

    # Resize the mask WITHOUT interpolation so it stays binary
    ground_truth = Image.fromarray(ground_truth.astype(np.uint8))
    ground_truth = ground_truth.resize((image_size, image_size), resample=Image.NEAREST)
    ground_truth = np.array(ground_truth).astype(int)

    # Flip 50% of the time. NumPy's RNG is used (rather than the stdlib `random` module) so that
    # np.random.seed makes the augmentation reproducible.
    if np.random.randint(2) == 1:
        image = np.flip(image, axis=1)
        ground_truth = np.flip(ground_truth, axis=1)

    # Rotate by 0, 90, 180 or 270 degrees
    rotation = np.random.randint(4)
    image = np.rot90(image, rotation)
    ground_truth = np.rot90(ground_truth, rotation)

    return image, ground_truth

In [None]:
# Check that the generator is doing the right thing on a few known indexes
batch_generator = image_generator([1, 2, 3, 4], batch_size=4)
(batch_x, batch_y) = next(batch_generator)
for image, mask in zip(batch_x, batch_y):
    plt.imshow(image, cmap="gray")
    plt.title("Augmented image")
    plt.show()
    print(f"image shape: {image.shape}")
    print(f"image max: {np.max(image)}")
    print(f"image min: {np.min(image)}")

    plt.imshow(mask, cmap="gray")
    plt.title("Augmented mask")
    plt.show()
    print(f"mask shape: {mask.shape}")
    print(f"mask unique: {np.unique(mask)}")
    print(f"mask dtype: {mask.dtype}")

    # Invariants of the preprocessing: image scaled into [0, 1], mask strictly binary
    assert 0.0 <= np.min(image) and np.max(image) <= 1.0, "image not normalised to [0, 1]"
    assert set(np.unique(mask)) <= {0.0, 1.0}, "mask is not binary"

In [None]:
# Split what images are used for training and validation
train_image_indexes, validation_image_indexes = train_test_split(range(NUM_IMAGES), test_size=0.2, random_state=SEED)
assert not set(train_image_indexes) & set(validation_image_indexes), "train/validation overlap"

print(f"Number of training images: {len(train_image_indexes)}")
print(f"Number of validation images: {len(validation_image_indexes)}")

print(f"Training image indexes: {train_image_indexes}")
print(f"Validation image indexes: {validation_image_indexes}")

# Create the generators with the configured BATCH_SIZE, the same constant the training
# cell uses to derive its step counts (previously these hard-coded a batch size of 4)
train_generator = image_generator(train_image_indexes, batch_size=BATCH_SIZE)
validation_generator = image_generator(validation_image_indexes, batch_size=BATCH_SIZE)

In [None]:
# Build a fresh U-Net sized for the 512x512 arrays produced by image_generator
# (call model.summary() to inspect the architecture)
model = unet_model(IMG_CHANNELS=CHANNELS, IMG_WIDTH=512, IMG_HEIGHT=512)

In [None]:
EPOCHS = 25
# Each "epoch" draws AUGMENTATION_FACTOR times as many batches as one pass over the split
AUGMENTATION_FACTOR = 8

# The generators loop forever and sample with replacement, so Keras needs explicit step
# counts. Derive them from each split's own size (at least one step each). Using
# NUM_IMAGES for both would make validation run as many batches as training on only
# ~20% of the images.
train_steps = max(1, len(train_image_indexes) // BATCH_SIZE) * AUGMENTATION_FACTOR
validation_steps = max(1, len(validation_image_indexes) // BATCH_SIZE) * AUGMENTATION_FACTOR

history = model.fit(
    train_generator,
    # How many batches to draw from the generator before declaring one epoch finished
    steps_per_epoch=train_steps,
    epochs=EPOCHS,
    validation_data=validation_generator,
    # How many batches to draw from the validation generator at the end of every epoch
    validation_steps=validation_steps,
    verbose=1,
)

In [None]:
loss = history.history["loss"]
val_loss = history.history["val_loss"]
epochs = range(1, len(loss) + 1)
plt.plot(epochs, loss, "y", label="Training loss")
plt.plot(epochs, val_loss, "r", label="Validation loss")
plt.title("Training and validation loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.show()

print(history.history.keys())

# The recorded accuracy-style metric depends on how the model was compiled
# ("accuracy" or "mean_io_u"), so plot whichever one is present instead of raising KeyError.
acc_key = next((k for k in ("accuracy", "mean_io_u") if k in history.history), None)
if acc_key is not None:
    acc = history.history[acc_key]
    val_acc = history.history[f"val_{acc_key}"]

    plt.plot(epochs, acc, "y", label="Training acc")
    plt.plot(epochs, val_acc, "r", label="Validation acc")
    plt.title("Training and validation accuracy")
    plt.xlabel("Epochs")
    plt.ylabel(acc_key)
    plt.legend()
    plt.show()
else:
    print("No accuracy metric recorded; skipping accuracy plot")

# Overlay every recorded metric (note: they may be on very different scales)
for key, value in history.history.items():
    plt.plot(epochs, value, label=key)
plt.title("All recorded metrics")
plt.xlabel("Epochs")
plt.legend()
plt.show()

In [None]:
# Save the model and its training history under ONE shared timestamp so the pair can be
# matched later (calling datetime.now() twice can straddle a second boundary).
MODEL_SAVE_DIR.mkdir(parents=True, exist_ok=True)
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

model.save(MODEL_SAVE_DIR / f"catsnet_no_gauss{timestamp}.h5")

# history.history is a dict, which np.save stores as a pickled 0-d object array;
# reload it with np.load(path, allow_pickle=True).item()
np.save(
    MODEL_SAVE_DIR / f"catsnet_no_gauss{timestamp}_history.npy",
    history.history,
)

In [None]:
# Show the model's predictions on a batch drawn from the VALIDATION generator
# (images the model was not trained on), next to the ground truth.
(batch_x, batch_y) = next(validation_generator)
predicted_masks = model.predict(batch_x)

threshold = 0.5

for image, mask, predicted_mask in zip(batch_x, batch_y, predicted_masks):
    # Threshold the predicted values into a binary mask; squeeze drops a trailing
    # singleton channel axis if the model output has one (TODO confirm output shape)
    predicted_mask = np.squeeze(predicted_mask > threshold)

    # Plot the result
    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    ax[0].imshow(image)
    ax[0].set_title("Image")
    ax[1].imshow(mask)
    ax[1].set_title("Ground truth mask")
    ax[2].imshow(predicted_mask)
    ax[2].set_title("Predicted mask")
    plt.show()

In [None]:
# Load a previously trained model from disk. The custom loss functions are passed via
# custom_objects so Keras can resolve them by name while deserialising the saved model.
from topostats.grain_finding_cats_unet import dice_loss, iou_loss

saved_model_path = MODEL_SAVE_DIR / "catsnet_REDUCED_EQUAL_TYPES_2024-04-19_23-39-12.h5"
loss_functions = {"dice_loss": dice_loss, "iou_loss": iou_loss}
model = tf.keras.models.load_model(saved_model_path, custom_objects=loss_functions)
model.summary()

In [None]:
# Run the model on a single image.
# To use a training image instead of Libby's data, swap in these lines:
# image_number = 76
# image = np.load(ORIGINAL_IMAGE_DIR / f"image_{image_number}.npy")
# mask = np.load(MASK_DIR / f"mask_{image_number}.npy")

# Load libby's data (no ground truth is available, so an all-zero mask is shown)
image = np.load("/Users/sylvi/topo_data/cats/libby_shelterin_examples/image_1.npy")
mask = np.zeros_like(image)

# Optional crop before resizing:
# image = image[300:-300, 300:-300]
# mask = mask[300:-300, 300:-300]

# Preprocess in the same order as image_generator does for training: resize FIRST, then clip
# and scale to [0, 1]. PIL's default resize interpolates, so normalising before resizing
# can leave values outside [0, 1] that the model never saw in training.
image = np.array(Image.fromarray(image).resize((512, 512)))
image = np.clip(image, NORM_LOWER_BOUND, NORM_UPPER_BOUND)
image = (image - NORM_LOWER_BOUND) / (NORM_UPPER_BOUND - NORM_LOWER_BOUND)

# Predict on a batch of one image, then threshold the predicted values
threshold = 0.5
predicted_mask = model.predict(np.expand_dims(image, axis=0))
predicted_mask = predicted_mask > threshold

# Plot the result
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
ax[0].imshow(image)
ax[0].set_title("Image")
ax[1].imshow(predicted_mask[0, :, :, 0])
ax[1].set_title("Predicted mask")
ax[2].imshow(mask)
ax[2].set_title("Ground truth mask")
plt.show()


# Plot a KDE of the normalised image values that were fed to the model
import seaborn as sns

sns.kdeplot(image.flatten())
plt.xlabel("Normalised pixel heights")
plt.title("Height distributions")
plt.show()