## Imports and settings

In [None]:
# Fix randomness and hide warnings.
# These environment variables are set here, before tensorflow is imported in a
# later cell.
seed = 1234

import os
os.environ.update({
    'TF_CPP_MIN_LOG_LEVEL': '3',
    # NOTE(review): PYTHONHASHSEED only influences interpreters started after
    # this point (e.g. worker subprocesses); the hash seed of the running
    # kernel is fixed at interpreter start-up.
    'PYTHONHASHSEED': str(seed),
    'MPLCONFIGDIR': os.getcwd() + '/configs/',
})

import warnings
# Warning is the base class of FutureWarning, so the second filter already
# covers the first; both calls are kept.
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=Warning)

import logging
import random

import numpy as np

# Seed the NumPy and Python RNGs so runs are reproducible.
np.random.seed(seed)
random.seed(seed)

In [None]:
# Import TensorFlow / Keras, quieten their logging and seed TensorFlow's RNG
import tensorflow as tf
from tensorflow import keras
from keras import backend

# Logging: Python-side TF logger, AutoGraph, and the legacy v1 logger
tf.get_logger().setLevel(logging.ERROR)
tf.autograph.set_verbosity(0)
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# Seeding: the v2 API and the v1 compat API both receive the same seed
tf.random.set_seed(seed)
tf.compat.v1.set_random_seed(seed)

print(tf.__version__)

In [None]:
# Print the GPU devices visible to TensorFlow (an empty list means none)
print(tf.config.list_physical_devices(device_type='GPU'))

In [None]:
import utils
import albumentations as A

## Load the data

In [None]:
# Load the preprocessed T1 and T2 training samples together with their labels
(
    T1_samples_loaded,
    T2_samples_loaded,
    T1_labels_loaded,
    T2_labels_loaded,
) = utils.load_preprocessed_data_train()

## Prepare the data for the network

In [None]:
# Convert the loaded T1/T2 samples and labels into network-ready arrays
(
    T1_data_samples,
    T2_data_samples,
    T1_data_labels,
    T2_data_labels,
) = utils.prepare_data_for_training(
    T1_samples_loaded,
    T2_samples_loaded,
    T1_labels_loaded,
    T2_labels_loaded,
)

In [None]:
# Combine T1 and T2 into one dataset by stacking them row-wise (np.vstack),
# samples with samples and labels with labels
data_samples = np.vstack([T1_data_samples, T2_data_samples])
data_labels = np.vstack([T1_data_labels, T2_data_labels])

# Sanity check: show the combined shapes
print(data_samples.shape)
print(data_labels.shape)

In [None]:
# Partition the combined arrays into training, validation and test subsets
(
    X_train, X_val, X_test,
    y_train, y_val, y_test,
) = utils.split_data_for_training(data_samples, data_labels)

In [None]:
# Per-example input shape and number of distinct label values in the training split
input_shape = X_train[0].shape
num_classes = len(np.unique(y_train))
print("Input shape: {}".format(input_shape))
print("Number of classes: {}".format(num_classes))

## Create the model and a custom IoU metric

In [None]:
# Create the UNet model.
# input_shape / num_classes are the values derived from the training split in
# the cell above. They are the same values later given to the IoU metric, so
# the model's number of output classes agrees with the metric. Counting
# classes from T1 labels alone would miss any class that occurs only in the
# T2 labels. Reusing them also keeps this cell runnable after the T1/T2
# arrays are deleted to free RAM.
# NOTE(review): num_classes counts distinct label values; sparse categorical
# cross-entropy additionally needs every label to be < num_classes, i.e.
# labels 0..K-1 contiguous — confirm against utils.prepare_data_for_training.
unet_model = utils.get_unet_model(input_shape=input_shape, num_classes=num_classes)
unet_model.summary()

In [None]:
# Define a custom metric class for mean intersection over union (IoU)
class UpdatedMeanIoU(tf.keras.metrics.MeanIoU):
    """Mean IoU that takes per-class model scores instead of class indices.

    The base ``MeanIoU`` is fed class indices here: ``update_state`` first
    reduces ``y_pred`` with an argmax over its last axis, so the metric can be
    attached directly to a model whose final axis holds one score per class.
    """

    def __init__(self, num_classes=None, name="mean_iou", dtype=None):
        super().__init__(num_classes=num_classes, name=name, dtype=dtype)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Argmax ``y_pred`` over the last axis, then update the base metric."""
        predicted_classes = tf.argmax(y_pred, axis=-1)
        return super().update_state(
            y_true, predicted_classes, sample_weight=sample_weight
        )

## Train the model

In [None]:
# Free up RAM: drop the raw, prepared and combined arrays now that the
# train/val/test splits exist.
# NOTE(review): memory is only reclaimed if the split arrays are copies rather
# than views of data_samples/data_labels — confirm in utils.split_data_for_training.
del (
    T1_samples_loaded, T2_samples_loaded, T1_labels_loaded, T2_labels_loaded,
    T1_data_samples, T2_data_samples, T1_data_labels, T2_data_labels,
    data_samples, data_labels,
)

In [None]:
# Hyperparameters
learning_rate = 1e-3
batch_size = 16
epochs = 1000  # upper bound; the EarlyStopping callback below can end training earlier

# Input parameters, derived from the training split
input_shape = X_train[0].shape
num_classes = len(np.unique(y_train))

# Compile parameters
# NOTE(review): SparseCategoricalCrossentropy() defaults to from_logits=False,
# so this assumes the model's last layer outputs probabilities (e.g. softmax).
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
loss = keras.losses.SparseCategoricalCrossentropy()
metrics = ['accuracy', UpdatedMeanIoU(num_classes=num_classes)]

# Callbacks. Both monitor 'val_mean_iou', the validation value of the metric
# named "mean_iou" above, and treat higher as better. The learning rate is
# reduced after 5 fewer stagnant epochs than early stopping waits for.
patience = 30
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_mean_iou',
    mode='max',
    patience=patience,
    restore_best_weights=True,
)
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor='val_mean_iou',
    mode='max',
    factor=0.1,
    patience=patience - 5,
    min_lr=1e-6,
)

In [None]:
# Compile the model with the optimizer, loss and metrics defined above
unet_model.compile(
    optimizer=optimizer,
    loss=loss,
    metrics=metrics,
)

In [None]:
# Train the model on the training split, validating on the validation split;
# the returned History object is kept in `history`
history = unet_model.fit(
    x=X_train,
    y=y_train,
    validation_data=(X_val, y_val),
    epochs=epochs,
    batch_size=batch_size,
    callbacks=[early_stopping, reduce_lr],
)

## Create a data augmentation pipeline

In [None]:
# Random horizontal and vertical flips, each applied with probability 0.5.
# The order of the transforms is preserved (horizontal first).
augmentation_pipeline = A.Compose(
    transforms=[
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
    ],
)

In [None]:
# Preview the augmentation pipeline on one training example (index 10)
sample, label = X_train[10], y_train[10]
augmented = augmentation_pipeline(image=sample, mask=label)
augmented_image, augmented_label = augmented['image'], augmented['mask']

# Show the original pair first, then the augmented pair
utils.plot_sample(sample, label, plot_separately=True)
utils.plot_sample(augmented_image, augmented_label, plot_separately=True)

In [None]:
# Sentinel that lets tf.data tune parallelism / buffer sizes at runtime; meant
# for arguments such as `num_parallel_calls` or a prefetch `buffer_size`.
AUTOTUNE = tf.data.experimental.AUTOTUNE

In [None]:
# Create a function to apply the augmentation pipeline to a single example
def augment_function(image, label):
    """Apply ``augmentation_pipeline`` to one image/mask pair.

    Intended to run inside ``tf.numpy_function``, so it receives and returns
    NumPy arrays.

    Args:
        image: The input image (NumPy array).
        label: The segmentation mask for ``image`` (NumPy array).

    Returns:
        ``(augmented_image, augmented_label)`` as NumPy arrays of dtype
        float32 and int32, matching ``Tout=[tf.float32, tf.int32]`` used by
        the ``tf.numpy_function`` wrapper.
    """
    # Augment the image passed in, not the global `sample` from the preview cell.
    augmented = augmentation_pipeline(image=image, mask=label)
    # Convert with NumPy: a numpy_function body should return ndarrays, not
    # TF tensors created by tf.cast.
    augmented_image = np.asarray(augmented['image'], dtype=np.float32)
    augmented_label = np.asarray(augmented['mask'], dtype=np.int32)
    return augmented_image, augmented_label

# Wrap augment_function so it can be used in a tf.data `map`
def process_data(image, label):
    """Augment one (image, label) pair inside a ``tf.data`` pipeline.

    ``augment_function`` runs as plain Python/NumPy through
    ``tf.numpy_function``.

    Args:
        image: Image tensor for one example.
        label: Mask tensor for ``image``.

    Returns:
        ``(augmented_image, augmented_label)`` float32 / int32 tensors carrying
        the same static shapes as ``image`` and ``label``.
    """
    augmented_image, augmented_label = tf.numpy_function(
        func=augment_function,
        inp=[image, label],
        Tout=[tf.float32, tf.int32],
    )
    # tf.numpy_function returns tensors of unknown shape; restore the input
    # shapes so batching and Keras see fully defined shapes. This is valid
    # because the pipeline only flips, which never changes array shapes
    # (revisit if shape-changing transforms such as crops are added).
    augmented_image.set_shape(image.shape)
    augmented_label.set_shape(label.shape)
    return augmented_image, augmented_label

# Create the augmented dataset: a tf.data pipeline over (image, label)
# training pairs, each pair augmented on the fly by process_data.
# NOTE(review): from_tensor_slices copies X_train/y_train into TF memory, and
# this dataset is unbatched and unshuffled. Add .shuffle(...).batch(batch_size)
# .prefetch(AUTOTUNE) before passing it to fit().
augmented_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)).map(
    process_data, num_parallel_calls=AUTOTUNE
)
augmented_dataset