In [None]:
from keras.utils import normalize
from keras.utils import to_categorical
from keras.preprocessing.image import ImageDataGenerator
import os
import cv2
from PIL import Image
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
import seaborn as sns
from unet import unet_model
from unet_any_size_single_output import unet_model
import random
from pathlib import Path
import logging
import itertools
import re

LOGGER = logging.getLogger()
from datetime import datetime

from skimage.morphology import binary_erosion
from skimage.morphology import skeletonize
from skimage.filters import hessian
from skimage.feature import hessian_matrix, hessian_matrix_eigvals
from skimage.morphology import label
from skimage.measure import regionprops
from skimage.color import label2rgb

# TopoStats needs to be >= version 2.1.0
# from topostats import io
# from topostats import grain_finding_cats_unet

import tensorflow as tf
from sklearn.model_selection import train_test_split

for key, value in os.environ.items():
    print(f"{key} : {value}")

In [None]:
# Check that TensorFlow can see a GPU. `tf.test.gpu_device_name()` returns an empty string
# when no GPU is available, so warn explicitly instead of silently training on the CPU.
gpu_name = tf.test.gpu_device_name()
if not gpu_name:
    LOGGER.warning("No GPU found by TensorFlow; training will run on the CPU.")
gpu_name

In [None]:
# Set the random seeds for every RNG this notebook uses:
# - numpy: batch sampling in the image generator (np.random.choice)
# - random: flip/rotation augmentation in the image generator (random.choice)
# - tensorflow: weight initialisation, dropout, etc.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
# Input data locations. The default points at the original author's machine; set the
# CATS_TRAINING_DATA_DIR environment variable to use a different location without editing code.
TRAINING_DATA_DIR = Path(os.environ.get("CATS_TRAINING_DATA_DIR", "/Users/sylvi/topo_data/cats/training_data"))
ORIGINAL_IMAGE_DIR = TRAINING_DATA_DIR / "images_flattened_all"
MASK_DIR = TRAINING_DATA_DIR / "images_edge_detection_upper_labels_multiclass_sigma_4"

MODEL_SAVE_DIR = Path("./saved_models")
# Images are loaded as single-channel greyscale.
CHANNELS = 1

# Number of image/mask pairs per batch.
BATCH_SIZE = 4

# Count only files matching the name pattern the generator loads (training_image_{index}.png),
# so unrelated PNGs in the folder cannot inflate the index range.
NUM_IMAGES = len(list(ORIGINAL_IMAGE_DIR.glob("training_image_*.png")))
if NUM_IMAGES == 0:
    # Fail here with a clear message rather than later inside train_test_split.
    raise FileNotFoundError(f"No training_image_*.png files found in {ORIGINAL_IMAGE_DIR}")
print(f"Number of images: {NUM_IMAGES}")

In [None]:
def _normalise_image(image):
    """Min-max scale an image to [0, 1]; a constant image becomes all zeros instead of NaN."""
    image = image.astype(np.float64) - np.min(image)
    image_range = np.max(image)
    if image_range > 0:
        image = image / image_range
    return image


def _load_image(index, image_size):
    """Load training_image_{index}.png as greyscale, resize it to image_size and scale it to [0, 1]."""
    image_path = ORIGINAL_IMAGE_DIR / f"training_image_{index}.png"
    image = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
    if image is None:
        # cv2.imread reports a missing/unreadable file by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = np.array(Image.fromarray(image).resize(image_size))
    return _normalise_image(image)


def _load_mask(index, image_size):
    """Load mask_array_{index}.npy, binarise it (any non-zero label -> 1) and resize it to image_size."""
    mask = np.load(MASK_DIR / f"mask_array_{index}.npy")
    # All non-zero labels are collapsed to a single foreground class, presumably to match the
    # single-output model imported from `unet_any_size_single_output` — confirm if using multiclass.
    mask = Image.fromarray(mask.astype(bool).astype(np.uint8))
    # NEAREST keeps the labels exactly 0/1; PIL's default filter for 8-bit images (bicubic)
    # would blur label boundaries.
    mask = mask.resize(image_size, resample=Image.NEAREST)
    return np.array(mask).astype(int)


def _augment_pair(image, mask):
    """Apply the same random horizontal flip (50%) and random 0/90/180/270-degree rotation to an image and its mask."""
    if random.choice([0, 1]) == 1:
        image = np.flip(image, axis=1)
        mask = np.flip(mask, axis=1)
    rotation = random.choice([0, 1, 2, 3])
    return np.rot90(image, rotation), np.rot90(mask, rotation)


def image_generator(image_indexes, batch_size=4, image_size=(512, 512), augment=True):
    """Endlessly yield (images, masks) batches, loading files from disk only when needed.

    Args:
        image_indexes: Indexes N of the training_image_N.png / mask_array_N.npy pairs to draw from.
            Each batch samples them uniformly *with replacement*.
        batch_size: Number of image/mask pairs per batch.
        image_size: Target (width, height) passed to PIL's resize.
        augment: If True, apply a random flip and 90-degree rotation to each pair.

    Yields:
        (batch_x, batch_y) float32 arrays of shape (batch_size, height, width). Images are
        scaled to [0, 1]; masks contain only 0 and 1.

    Raises:
        ValueError: If image_indexes is empty (raised on the first ``next()``).
        FileNotFoundError: If an image file cannot be read.
    """
    if len(image_indexes) == 0:
        raise ValueError("image_indexes must not be empty")
    while True:
        batch_image_indexes = np.random.choice(a=image_indexes, size=batch_size)
        batch_input = []
        batch_output = []
        for index in batch_image_indexes:
            image = _load_image(index, image_size)
            ground_truth = _load_mask(index, image_size)
            if augment:
                image, ground_truth = _augment_pair(image, ground_truth)
            batch_input.append(image)
            batch_output.append(ground_truth)

        batch_x = np.array(batch_input, dtype=np.float32)
        batch_y = np.array(batch_output, dtype=np.float32)
        yield (batch_x, batch_y)

In [None]:
# Sanity-check the generator on the first few images before training: images should be
# scaled to [0, 1] and masks should contain only the labels 0 and 1.
batch_generator = image_generator([0, 1, 2, 3, 4], batch_size=BATCH_SIZE)
(batch_x, batch_y) = next(batch_generator)
for image, mask in zip(batch_x, batch_y):
    print(f"image shape: {image.shape}, min: {np.min(image)}, max: {np.max(image)}")
    print(f"mask shape: {mask.shape}, dtype: {mask.dtype}, unique: {np.unique(mask)}")
    fig, (image_ax, mask_ax) = plt.subplots(1, 2, figsize=(10, 5))
    image_ax.imshow(image, cmap="gray")
    image_ax.set_title("Image")
    mask_ax.imshow(mask)
    mask_ax.set_title("Mask")
    plt.show()
    # Close each figure so repeated runs do not accumulate open figures.
    plt.close(fig)

# NaN comparisons are False, so these also catch NaNs from a bad normalisation.
assert np.all((batch_x >= 0) & (batch_x <= 1)), "images should be scaled to [0, 1]"
assert np.isin(batch_y, [0, 1]).all(), "masks should only contain 0 and 1"

In [None]:
# Split the image *indexes* into training and validation sets (80/20); the generators
# load the corresponding files lazily.
train_image_indexes, validation_image_indexes = train_test_split(
    range(0, NUM_IMAGES), test_size=0.2, random_state=SEED
)

print(f"Number of training images: {len(train_image_indexes)}")
print(f"Number of validation images: {len(validation_image_indexes)}")

print(f"Training image indexes: {train_image_indexes}")
print(f"Validation image indexes: {validation_image_indexes}")

# Create the generators with the configured BATCH_SIZE, so the batch size matches the
# one used to compute the number of steps per epoch.
train_generator = image_generator(train_image_indexes, batch_size=BATCH_SIZE)
validation_generator = image_generator(validation_image_indexes, batch_size=BATCH_SIZE)

In [None]:
# Build the U-Net for 512x512 single-channel inputs. Note: `unet_model` resolves to the
# version from `unet_any_size_single_output`, whose import comes after (and shadows)
# `from unet import unet_model`.
model = unet_model(
    IMG_CHANNELS=CHANNELS,
    IMG_WIDTH=512,
    IMG_HEIGHT=512,
)
# Call model.summary() here to print the layer-by-layer architecture.

In [None]:
EPOCHS = 50
# Size each epoch from its own split. The generators sample with replacement, so this sets
# the epoch length (about one pass over each split) rather than guaranteeing each image is seen once.
# max(1, ...) avoids a zero-step epoch when a split is smaller than BATCH_SIZE.
steps_per_epoch = max(1, len(train_image_indexes) // BATCH_SIZE)
validation_steps = max(1, len(validation_image_indexes) // BATCH_SIZE)

# Model.fit accepts Python generators directly; Model.fit_generator is deprecated.
history = model.fit(
    train_generator,
    # Batches drawn from the training generator per epoch
    steps_per_epoch=steps_per_epoch,
    epochs=EPOCHS,
    validation_data=validation_generator,
    # Batches drawn from the validation generator at the end of every epoch
    validation_steps=validation_steps,
    verbose=1,
)