# Assignment 2 - Contrast transfer
### Course: Convolutional Neural Networks with Applications in Medical Image Analysis

For the second assignment we will use the same dataset as before! Previously you have classified the available contrasts of the same anatomy, and for this assignment you will train an image to image model to generate one contrast from another. The task is to take T1-weighted images as inputs, and generate the corresponding T2-weighted images.

Your tasks, to include in the Jupyter notebook you hand in, are:
- Reach a validation MSE below 0.015 on the validation set, and describe what parameter combinations you have gone through to reach those results.
- Describe the effect of each hyper-parameter you have changed, and the way you have experimented with them. What problems did you face? What happened when the training failed? Try describing everything that you have learnt.
- Answer the questions posed in the `NOTE` comments throughout the notebook.

Upload the updated notebook to canvas, that also contains your answers to the questions above. The deadline for the assignment is March $30^{th}$, 15:00.

Good luck and have fun!

In [None]:
import os
import numpy as np
np.random.seed(2023)  # Set seed for reproducibility
import tensorflow as tf
import tensorflow.keras as keras
tf.random.set_seed(2026) # Note: Different to test different initializations.
!pip install tqdm # Adding tqdm to use progress bars. Unbarable waiting for each epoch to finish without feedback.
from tqdm import tqdm

## GPU verification

In [None]:
from keras.utils import img_to_array
from keras.utils import load_img
from keras.utils import to_categorical

# Using ImageGrid to plot the encodings.
from mpl_toolkits.axes_grid1 import ImageGrid
from typing import List
import matplotlib.pyplot as plt

# Ask TensorFlow which GPUs are visible to this process.
gpus = tf.config.experimental.list_physical_devices('GPU')
available = len(gpus) > 0
if available:
    # Memory growth is enabled for the first GPU only; further GPUs keep their default setting.
    tf.config.experimental.set_memory_growth(gpus[0], True)
    print(f"GPU(s) available (using '{gpus[0].name}'). Training will be lightning fast!")
    # Run some dummy code to initialize the GPU.
    rand = tf.random.uniform((100, 100))
    _ = tf.matmul(rand, rand)
else:
    print("No GPU(s) available. Training will be suuuuper slow!")

# NOTE: These are the packages you will need for the assignment.
# NOTE: You are encouraged to use the course virtual environment, which already has GPU support.

## Data Generator class

In [None]:
class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` that serves batches read from ``.npz`` files in one directory.

    Every file in ``data_path`` is treated as one sample. For each name in
    ``arrays`` the generator returns one batch array, in the same order.
    """
    def __init__(self,
                 data_path,
                 arrays,
                 batch_size=32,
                 ):
        """Index the files in ``data_path`` and pre-allocate the batch buffers.

        Args:
            data_path: Directory containing the ``.npz`` files.
            arrays: Names of the arrays to read from every file.
            batch_size: Number of files per batch.

        Raises:
            ValueError: If ``data_path`` is None or is not a directory.
        """

        self.data_path = data_path
        self.arrays = arrays
        self.batch_size = batch_size

        if data_path is None:
            raise ValueError('The data path is not defined.')

        if not os.path.isdir(data_path):
            raise ValueError('The data path is incorrectly defined.')

        self.file_idx = 0
        self.file_list = [self.data_path + '/' + s for s in
                          os.listdir(self.data_path)]
        
        # Create the first shuffled index order.
        self.on_epoch_end()
        # The batch shapes are derived from the first file only.
        with np.load(self.file_list[0]) as npzfile:
            self.in_dims = []
            self.n_channels = 1
            for i in range(len(self.arrays)):
                im = npzfile[self.arrays[i]]
                self.in_dims.append((self.batch_size,
                                    *np.shape(im),
                                    self.n_channels))
        # Empty initialization array
        arrays = []
        for i in range(len(self.arrays)):
            arrays.append(np.empty(self.in_dims[i]).astype(np.single))
        self.init_arrays = arrays

    def __len__(self):
        """Get the number of batches per epoch."""
        # floor() drops a trailing, incomplete batch.
        return int(np.floor((len(self.file_list)) / self.batch_size))

    def __getitem__(self, index):
        """Generate one batch of data."""
        # Generate indexes of the batch
        indexes = self.indexes[index * self.batch_size:(index + 1) *
                               self.batch_size]

        # Find list of IDs
        list_IDs_temp = [self.file_list[k] for k in indexes]

        # Generate data
        a = self.__data_generation(list_IDs_temp)
        return a

    def on_epoch_end(self):
        """Update indexes after each epoch."""
        # Shuffling uses NumPy's global random state.
        self.indexes = np.arange(len(self.file_list))
        np.random.shuffle(self.indexes)
    
    #@threadsafe_generator
    def __data_generation(self, temp_list):
        """Generate data containing batch_size samples."""
        # X : (n_samples, *dim, n_channels)
        # Initialization
        # NOTE: the buffers allocated in __init__ are reused, so every call
        # writes into and returns the same array objects. Callers that keep a
        # batch across calls must copy it.
        arrays = self.init_arrays

        # for i in range(len(self.arrays)):
        #     arrays.append(np.empty(self.in_dims[i]).astype(np.single))

        for i, ID in enumerate(temp_list):
            with np.load(ID) as npzfile:
                for idx in range(len(self.arrays)):
                    x = npzfile[self.arrays[idx]] \
                        .astype(np.single)
                    # Append the channel axis (axis 2).
                    x = np.expand_dims(x, axis=2)
                    # Divide by the maximum value; when the maximum is 0 the
                    # array is stored unscaled to avoid dividing by zero.
                    x_max = np.max(x)
                    if x_max == 0:
                        arrays[idx][i, ] = x
                    else:
                        arrays[idx][i, ] = x / x_max

        return arrays

# NOTE: Don't change the data generator!

### Data Generators creation

In [None]:
gen_dir = './data_zip/' # Change if you have copied the data locally on your machine
array_labels = ['t1', 't1ce', 't2', 'flair','mask']  # Available arrays are: 't1', 't1ce', 't2', 'flair', 'mask'.
per_batch = len(array_labels)-1 # We don't want to use the mask as an input.
batch_size = 16

# From here on `batch_size` is the number of files per generator batch:
# batch_size * per_batch is at most the 16 images chosen above.
batch_size = batch_size//per_batch 

# One generator per data split; all three read the same arrays with the same batch size.
gen_train = DataGenerator(data_path=gen_dir + 'training',
                          arrays=array_labels,
                          batch_size=batch_size)

gen_val = DataGenerator(data_path=gen_dir + 'validating',
                        arrays=array_labels,
                        batch_size=batch_size)

gen_test = DataGenerator(data_path=gen_dir + 'testing',
                         arrays=array_labels,
                         batch_size=batch_size)

# NOTE: What arrays are you using? You can use multiple contrasts as inputs, if you'd like.
# NOTE: What batch size are you using? Should you use more? Or less?
# NOTE: Are you using the correct generators for the correct task? Training for training and validating for validating?

In [None]:
# Count the masks (last array of every training batch) that contain at least one NaN.
nan_count = 0
mask_total = 0
for batch in tqdm(gen_train):
    mask = batch[-1]
    # One flag per sample: reduce over every axis except the batch axis.
    has_nan = np.isnan(mask).any(axis=tuple(range(1, mask.ndim)))
    nan_count += int(has_nan.sum())
    mask_total += mask.shape[0]
# The counter is per mask (sample), not per batch, so the message says so.
print(f"Found {nan_count}/{mask_total} masks with nan values.")


## Plot examples of the dataset

In [None]:
# A quick summary of the data:
# The count below is the number of files found in the training directory.
print(f"Number of training images : {len(gen_train.file_list)}")
# NOTE: `in_dims` is a list with one full batch shape per requested array, not a single integer.
print(f"Training batch size       : {gen_train.in_dims}")

## Keras Imports

In [None]:
# Import packages important for building and training your model.
# import keras
from keras import backend as K
from keras import mixed_precision
from keras import optimizers
from keras.models import Model
from keras.layers import Dense, Conv2D
from keras.layers import Flatten, Input
from keras.layers import MaxPooling2D, GlobalAveragePooling2D
from keras.layers import Activation, concatenate
from keras.layers import BatchNormalization
from keras.layers import Dropout, UpSampling2D
from keras.models import Sequential
from keras.optimizers import Adam, RMSprop, Nadam
# Binary focal cross-entropy loss (combined with a custom Dice loss further down)
from keras.losses import binary_focal_crossentropy, BinaryFocalCrossentropy

## Provided model

In [None]:
# Model provided by the supervisors
from tensorflow import Tensor
from keras.layers import Input, Conv2D, ReLU, BatchNormalization, \
                        Add, AveragePooling2D, Flatten, Dense, UpSampling2D
from keras.models import Model

def build_model():
    """Build the provided encoder-decoder CNN for 256x256 single-channel images.

    The encoder has four 2x2 max-pooling stages with 8, 16, 32 and 64 filters,
    followed by a 128-filter bottleneck. Each decoder stage upsamples by 2,
    concatenates the matching encoder features and adds a 1x1-convolved copy
    of those features to the result of its first convolution.

    Returns:
        An uncompiled Keras ``Model`` whose output is a single-channel,
        sigmoid-activated layer named ``"prediction"``.
    """
    filt_size = 8
    # input1 = Input(shape=(128, 128, 1))
    input1 = Input(shape=(256, 256, 1))

    # Encoder: two 3x3 convolutions per level, then 2x2 max pooling.
    conv1 = Conv2D(filt_size, 3, activation='relu', padding='same', kernel_initializer='he_normal')(input1)
    conv1 = Conv2D(filt_size, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = Conv2D(filt_size * 2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool1)
    conv2 = Conv2D(filt_size * 2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = Conv2D(filt_size * 4, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool2)
    conv3 = Conv2D(filt_size * 4, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = Conv2D(filt_size * 8, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool3)
    conv4 = Conv2D(filt_size * 8, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    # Bottleneck, followed by the first decoder stage.
    conv5 = Conv2D(filt_size * 16, 3, activation='relu', padding='same', kernel_initializer='he_normal',name="bottle_neck1")(pool4)
    conv5 = Conv2D(filt_size * 16, 3, activation='relu', padding='same', kernel_initializer='he_normal',name="bottle_neck2")(conv5)
    drop5 = Dropout(0.5)(conv5)
    # The additive skip branch reads conv4 (before dropout); the concatenation uses drop4.
    skip46 = Conv2D(filt_size * 8, 1, activation='relu', padding='same', kernel_initializer='he_normal',name="skip4-6")(conv4)
    up6 = Conv2D(filt_size * 8, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(drop5))
    merge6 = concatenate([drop4, up6], axis=3)
    conv6 = Conv2D(filt_size * 8, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge6) + skip46
    conv6 = Conv2D(filt_size * 8, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)

    # Decoder stage paired with encoder level 3.
    skip37 = Conv2D(filt_size * 4, 1, activation='relu', padding='same', kernel_initializer='he_normal',name="skip3-7")(conv3)
    up7 = Conv2D(filt_size * 4, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(conv6))
    merge7 = concatenate([conv3, up7], axis=3)
    conv7 = Conv2D(filt_size * 4, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7) + skip37
    conv7 = Conv2D(filt_size * 4, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv7)

    # Decoder stage paired with encoder level 2.
    skip28 = Conv2D(filt_size * 2, 1, activation='relu', padding='same', kernel_initializer='he_normal',name="skip2-8")(conv2)
    up8 = Conv2D(filt_size * 2, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(conv7))
    merge8 = concatenate([conv2, up8], axis=3)
    conv8 = Conv2D(filt_size * 2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge8) + skip28
    conv8 = Conv2D(filt_size * 2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv8)

    # Decoder stage paired with encoder level 1, then the output head.
    skip19 = Conv2D(filt_size, 1, activation='relu', padding='same', kernel_initializer='he_normal',name="skip1-9")(conv1)
    up9 = Conv2D(filt_size, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(conv8))
    merge9 = concatenate([conv1, up9], axis=3)
    conv9 = Conv2D(filt_size, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge9) + skip19
    conv9 = Conv2D(filt_size, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)
    conv9 = Conv2D(2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)
    # 1x1 convolution with a sigmoid keeps the single output channel in (0, 1).
    conv10 = Conv2D(1, 1, activation="sigmoid",name="prediction")(conv9)

    return Model(inputs=input1, outputs=conv10)

In [None]:
# Model parameters
from archs.segmentation.unet import build_unet
# Input geometry: 256x256 images with a single channel.
W,H,C = 256,256,1
input_shape = (W,H,C)
num_classes = C # Number of classes is equal to the number of channels in the output
# Per-level settings handed to build_unet; `filters` and `kernel_size` both have 8 entries.
filters = [8,16,32,64,128,256,512,1024]
kernel_size = [3,3,3,3,3,3,3,1]
strides = 1
padding = "same"
activation = "selu"
# drop_rate_encoder = [0.01,0.02,0.02,0.1]
# drop_rate_decoder = [0.01,0.1,0.1,0]
# NOTE(review): the drop-rate lists are shorter (4 and 1 entries) than `filters` (8);
# confirm how build_unet treats lists that do not cover every level.
drop_rate_encoder = [0.0,0.02,0.02,0.05]
drop_rate_decoder = [0.0]
depth_encoder = [2,2,3,4,4,5,6,8]
depth_decoder = [1,1,1,1,1,1,1,1]
output_depth = 10
output_activation = "sigmoid"

# Build the project U-Net (archs.segmentation.unet.build_unet) and print its layer summary.
model = build_unet(
    input_shape=input_shape,
    num_classes=num_classes,
    filters=filters,
    kernel_size=kernel_size,
    strides=strides,
    padding=padding,
    activation=activation,
    depth_encoder=depth_encoder,
    decoder_type="add",
    upsample_type="bilinear",
    depth_decoder=depth_decoder,
    drop_rate_encoder=drop_rate_encoder,
    drop_rate_decoder=drop_rate_decoder,
    output_depth=output_depth,
    output_activation=output_activation,
)
model.summary()

### Important questions to answer:

In [None]:
# Build your model.
# model = build_model()
# model.summary()

# NOTE: Are the input sizes correct?
# NOTE: Do you have the correct number of input images?
# NOTE: Are the output sizes correct?
# NOTE: Do you have the correct number of output images?
# NOTE: What's the range of the output? Can you use an activation as a regularizer?
# NOTE: Try to imagine the model layer-by-layer and think it through. Is it doing something reasonable?
# NOTE: Are your parameters split evenly inside the model? Try making "too large" layers smaller
# NOTE: Will the model fit into memory? Is the model too small? Is the model too large?

## Augmentation class

In [None]:
# Augmentation classes used in the training pipeline.
# float wrapper for probability
def prob(p: float) -> bool:
    """Return a truthy value with probability ``p`` (one draw from NumPy's global RNG)."""
    draw = np.random.random()
    return draw < p
class Augmentation:
    """Base class for augmentations applied to an (input, target) pair.

    Calling an instance draws once through ``prob(self.p)``: on success the
    pair is handed to :meth:`augment`, otherwise it is returned untouched.
    Both paths return the pair ``(x, y)``.
    """
    # Class-wide switch: when True every applied augmentation is announced on stdout.
    verbose: bool = False

    def __init__(self, p: float):
        # Probability with which the augmentation is applied per call.
        self.p = p

    @property
    def name(self) -> str:
        """Name of the concrete augmentation class."""
        return type(self).__name__

    def __call__(self, x: np.ndarray,y: np.ndarray) -> np.ndarray:
        """Apply :meth:`augment` with probability ``p``; otherwise pass the pair through."""
        if not prob(self.p):
            return x,y
        if self.verbose:
            print(f"Augmenting: Applying {self.name}")
        return self.augment(x,y)

    def augment(self, x: np.ndarray,y:np.ndarray) -> np.ndarray:
        """Transform the pair; must be provided by subclasses."""
        raise NotImplementedError
    

class Flip(Augmentation):
    """Mirror input and target along one image axis.

    ``axis`` is counted from the first image dimension of a channels-last
    array: 0 mirrors the rows (vertical flip), 1 mirrors the columns
    (horizontal flip). Leading batch dimensions are skipped. Without this
    offset, ``axis=0`` on a batched ``(N, H, W, C)`` array would only reverse
    the order of the samples and ``axis=1`` would flip vertically.
    """
    def __init__(self, p: float = 0.5, axis: int = 0):
        super().__init__(p)
        self.axis = axis # 1 for horizontal, 0 for vertical

    def augment(self, x: np.ndarray,y:np.ndarray) -> np.ndarray:
        """Return the pair flipped along the configured image axis."""
        # Arrays with more than three dimensions carry leading batch axes;
        # arrays with three or fewer dimensions are flipped along `axis` as given.
        x_axis = self.axis + max(np.ndim(x) - 3, 0)
        y_axis = self.axis + max(np.ndim(y) - 3, 0)
        return np.flip(x, axis=x_axis),np.flip(y, axis=y_axis)

class Rotate(Augmentation):
    """Rotate every image of the input and target batch by one shared random angle.

    ``angle`` is the maximum absolute rotation in radians.
    """
    def __init__(self, p: float = 0.5, angle: float = np.pi/4):
        super().__init__(p)
        self.angle = angle

    def augment(self, x: np.ndarray,y:np.ndarray) -> np.ndarray:
        """Return both batches rotated by the same angle, with zero fill."""
        # Random angle between -angle and angle (radians).
        angle = np.random.uniform(-self.angle, self.angle)
        # NOTE(review): `rotate` is not imported in the visible part of this file.
        # Its keyword arguments match skimage.transform.rotate, which expects
        # the angle in degrees, so the sampled radians are converted first.
        degrees = np.degrees(angle)

        def _rotate_batch(batch):
            # Rotate image by image; areas outside the source are filled with zeros.
            return np.array([rotate(img, degrees, resize=False, mode="constant", cval=0) for img in batch])

        return _rotate_batch(x),_rotate_batch(y)
class Noise(Augmentation):
    """Add Gaussian noise to the input; the target is returned unchanged."""
    def __init__(self, p: float = 0.5, mean: float = 0.0, std: float = 0.1):
        super().__init__(p)
        self.mean, self.std = mean, std

    def augment(self, x: np.ndarray,y:np.ndarray) -> np.ndarray:
        """Return ``(x + noise, y)`` with noise drawn from N(mean, std) per element."""
        perturbed = x + np.random.normal(self.mean, self.std, x.shape)
        return (perturbed, y)

class Mask(Augmentation):
    """Zero out random rectangles in the input; the target is returned unchanged.

    ``max_n_masks`` is the largest number of rectangles per call and
    ``mask_size`` the largest rectangle side as a fraction of the image side.
    """
    def __init__(self, p: float = 0.5, max_n_masks: int = 10, mask_size: float = 0.5):
        super().__init__(p)
        self.max_n_masks = max_n_masks
        self.mask_size = mask_size

    def augment(self, x: np.ndarray, y:np.ndarray) -> np.ndarray:
        """Return a masked copy of ``x`` together with the untouched ``y``."""
        # randint's upper bound is exclusive: the +1 makes `max_n_masks`
        # reachable and keeps the call valid when max_n_masks is 1.
        n_masks = np.random.randint(1, max(int(self.max_n_masks), 1) + 1)
        # Work on a copy so the caller's array is not modified in place.
        x = np.array(x, copy=True)
        w = x.shape[-3]
        h = x.shape[-2]
        for _ in range(n_masks):
            # Random mask size
            mask_size = np.random.uniform(low=self.mask_size/2, high=self.mask_size)
            # Random mask position; rectangles running past the border are clipped by slicing.
            x1 = np.random.randint(0, w)
            y1 = np.random.randint(0, h)
            x2 = int(x1 + w * mask_size)
            y2 = int(y1 + h * mask_size)
            # The leading ellipsis covers any batch axes, so unbatched input works as well.
            x[..., x1:x2, y1:y2, :] = 0
        return (x, y)

class Translate(Augmentation):
    """Shift every image of the input and target batch by one shared random offset.

    ``factor`` is the maximum shift as a fraction of the image size.
    """
    def __init__(self, p: float = 0.5, factor: float = 0.5):
        super().__init__(p)
        self.factor = factor

    def augment(self, x: np.ndarray,y:np.ndarray) -> np.ndarray:
        """Return both batches translated by the same offset, with zero fill."""
        # Scale by the image size of a channels-last (..., H, W, C) array.
        # Indexing from the end keeps the batch size out of the offset.
        height, width = x.shape[-3], x.shape[-2]
        tx = np.random.uniform(-self.factor, self.factor) * width
        ty = np.random.uniform(-self.factor, self.factor) * height
        # NOTE(review): `AffineTransform` and `warp` are not imported in the visible
        # part of this file; presumably they come from skimage.transform.
        tform = AffineTransform(translation=(tx, ty))
        # Batched warp, image by image; uncovered areas are filled with zeros.
        x = np.array([warp(img, tform.inverse, mode="constant", cval=0) for img in x])
        y = np.array([warp(img, tform.inverse, mode="constant", cval=0) for img in y])
        return x,y

class Shear(Augmentation):
    """Shear every image of the input and target batch by one shared random factor."""
    def __init__(self, p: float = 0.5, factor: float = 0.5):
        super().__init__(p)
        self.factor = factor

    def augment(self, x: np.ndarray,y:np.ndarray) -> np.ndarray:
        """Return both batches warped with the same shear transform, with zero fill."""
        # One shear value, drawn uniformly from [-factor, factor], is shared by x and y.
        tform = AffineTransform(shear=np.random.uniform(-self.factor, self.factor))

        def _shear_batch(batch):
            # Warp image by image; uncovered areas are filled with zeros.
            return np.array([warp(img, tform.inverse, mode="constant", cval=0) for img in batch])

        sheared_x = _shear_batch(x)
        sheared_y = _shear_batch(y)
        return sheared_x,sheared_y
class Augmentor:
    """
    Pipeline of paired (input, target) augmentations.

    An augmentation is only registered when its probability is greater than 0.
    Registered augmentations run in this order: noise, flip_x, flip_y, rotate,
    translate, mask, shear.

    Augmentations:
    flip_x: float
        Probability of flipping the image horizontally
    flip_y: float
        Probability of flipping the image vertically
    rotate: float
        Probability of rotating the image
    radians: float
        Maximum rotation angle in radians
    translate: float
        Probability of translating the image (the same value is also used as
        the translation factor)
    noise: float
        Probability of adding noise to the image
    noise_std: float
        Standard deviation of the noise
    noise_mean: float
        Mean of the noise
    mask: float
        Probability of masking the image
    max_n_masks: int
        Maximum number of masks
    mask_size: float
        Maximum size of the mask as a fraction of the image size
    shear: float
        Probability of shearing the image
    shear_factor: float
        Maximum shear factor
    verbose: bool
        Print every applied augmentation and probability change
    """
    def __init__(self,
                flip_x:float=0.25,
                flip_y:float=0.25,
                rotate:float=0.5,
                radians:float=np.pi/6,
                translate:float=0.2,
                noise:float=0.25,
                noise_std:float=0.1,
                noise_mean:float=0.1,
                mask:float=0.8,
                max_n_masks:int=10,
                mask_size:float=0.25,
                shear:float=0.1,
                shear_factor:float=0.4,
                verbose:bool=False
                ):
        self.verbose = verbose
        self._active = True
        # The verbosity flag lives on the Augmentation base class, so it is shared by all instances.
        Augmentation.verbose = self.verbose
        # Insertion order of this dict is the order in which augmentations run.
        self.augmentations = {}
        if noise > 0:
            self.augmentations["noise"] = Noise(p=noise, std=noise_std, mean=noise_mean)
        if flip_x > 0:
            self.augmentations["flip_x"] = Flip(p=flip_x, axis=1)
        if flip_y > 0:
            self.augmentations["flip_y"] = Flip(p=flip_y, axis=0)
        if rotate > 0:
            self.augmentations["rotate"] = Rotate(p=rotate, angle=radians)
        if translate > 0:
            # `translate` doubles as probability and as translation factor.
            self.augmentations["translate"] = Translate(p=translate, factor=translate)
        if mask > 0:
            self.augmentations["mask"] = Mask(p=mask, max_n_masks=max_n_masks, mask_size=mask_size)
        if shear > 0:
            # Stored by key like the others: a dict has no append().
            self.augmentations["shear"] = Shear(p=shear, factor=shear_factor)

    def __call__(self, x: np.ndarray,y:np.ndarray) -> np.ndarray:
        """Run every registered augmentation on the pair and return ``(x, y)``.

        When the augmentor is inactive the pair is returned untouched.

        Raises:
            ValueError: If ``x`` and ``y`` differ in shape.
        """
        if self._active:
            if x.shape != y.shape:
                raise ValueError("x and y must have the same shape")
            # Promote inputs with fewer than four dimensions by adding a leading batch axis.
            if len(x.shape) < 4:
                x = x[np.newaxis,...]
                y = y[np.newaxis,...]
            for aug in self.augmentations.values():
                x,y = aug(x,y)
        return x, y

    @property
    def keys(self):
        """Names of the registered augmentations, in execution order."""
        return list(self.augmentations)

    @property
    def active(self):
        """Whether calling the augmentor currently applies augmentations."""
        return self._active

    def scale_probability(self, key:str, factor:float):
        """Multiply the probability of augmentation ``key`` by ``factor``."""
        if self.verbose:
            print(f"Scaling probability of {key} by {factor:3.3e}: {self.augmentations[key].p:3.3e} -> {self.augmentations[key].p * factor:3.3e}")
        self.augmentations[key].p *= factor

    def set_active(self, active:bool):
        """Enable or disable the whole pipeline."""
        self._active = active

    def __repr__(self):
        return f"Augmentor({', '.join([f'{k}: {v.p:3.3e}' for k,v in self.augmentations.items()])})"
    

In [None]:
# Load model
def load_model(path,custom_objects=None,compile=True):
    """Load a saved Keras model, print its summary and return it.

    Args:
        path: Location of the saved model.
        custom_objects: Optional mapping of custom names forwarded to Keras.
        compile: Whether Keras should compile the model after loading.

    Returns:
        The loaded model.
    """
    m = keras.models.load_model(path,custom_objects=custom_objects,compile=compile)
    # summary() prints the table itself and returns None, so it is not wrapped in print().
    m.summary()
    return m

## Dice and Focal Loss - Custom FocalDiceLoss

In [None]:
# Custom dice loss for keras
def dice_coef(y_true,y_pred, smooth=100):
    """Return the smoothed Dice coefficient computed over all elements of both tensors.

    ``smooth`` is added to the numerator and the denominator.
    """
    overlap = K.sum(y_true * y_pred)
    total = K.sum(y_true) + K.sum(y_pred)
    return (2. * overlap + smooth) / (total + smooth)

def dice_coef_loss(y_true, y_pred):
    """Return the negative log of the Dice coefficient (with the default smoothing)."""
    score = dice_coef(y_true, y_pred)
    return -K.log(score)

# Custom loss function for keras containing dice loss and binary focal loss
class FocalDiceLoss(keras.losses.Loss):
    """Weighted sum of a log-Dice loss and binary focal cross-entropy.

    Args:
        w_focal: Weight of the focal cross-entropy term.
        w_dice: Weight of the Dice term.
        gamma: Focusing parameter passed to ``BinaryFocalCrossentropy``.
        alpha: Balancing parameter passed to ``BinaryFocalCrossentropy``.
        smooth: Smoothing constant of the Dice coefficient.
        **kwargs: Forwarded to ``keras.losses.Loss``.
    """
    def __init__(self, w_focal,w_dice,gamma=2.0, alpha=0.25, smooth=100, **kwargs):
        super().__init__(**kwargs)
        self.gamma = gamma
        self.alpha = alpha
        self.smooth = smooth
        self.focal_loss = BinaryFocalCrossentropy(gamma=gamma, alpha=alpha)
        self.w_focal = w_focal
        self.w_dice = w_dice

    def call(self, y_true, y_pred):
        """Return ``w_dice * dice_loss + w_focal * focal_loss`` on the flattened tensors."""
        y_true = K.flatten(y_true)
        y_pred = K.flatten(y_pred)
        # Use the configured smoothing constant; the default (100) matches dice_coef's own default.
        dice_loss = -K.log(dice_coef(y_true, y_pred, smooth=self.smooth))
        focal_loss = self.focal_loss(y_true, y_pred)
        return self.w_dice*dice_loss + self.w_focal*focal_loss

    def get_config(self):
        """Return the constructor arguments so the loss can be serialized with a model."""
        config = super().get_config()
        config.update({
            "w_focal": self.w_focal,
            "w_dice": self.w_dice,
            "gamma": self.gamma,
            "alpha": self.alpha,
            "smooth": self.smooth,
        })
        return config
    
# Test loss function
def test_loss():
    """Print the FocalDiceLoss (weights 0.5 / 0.5) of random binary targets vs. random predictions."""
    shape = (100,100)
    y_true = np.random.randint(0,2,shape).astype(np.float32)
    y_pred = np.random.rand(*shape).astype(np.float32)
    criterion = FocalDiceLoss(0.5,0.5)
    loss = criterion(y_true,y_pred)
    print(loss)

## Pre processing class
Used to reshape the list of images to the correct format for the model

In [None]:
class PreProcessor:
    """Turn one generator batch (a list of arrays) into a model-ready ``(x, y)`` pair.

    The last array of the list is the target; all others are inputs. The
    inputs are stacked into one batch of ``W x H x C`` images and the target
    is repeated once per input array so both have the same length.
    """
    def __init__(self, W:int,H:int,C:int,batch_size:int, per_batch:int, augmentor:Augmentor):
        self.W, self.H, self.C = W, H, C
        self.per_batch = per_batch
        self.augmentor = augmentor
        # [0..batch_size-1] repeated `per_batch` times: the target rows matching the stacked inputs.
        self.index = np.tile(np.arange(batch_size), per_batch)

    def __call__(self, x:List[np.ndarray],augment:bool=True):
        """Return ``(inputs, targets)``; the augmentor is applied when ``augment`` is True."""
        target = x[-1]
        inputs = np.array(x[:-1])
        # Stacked shape: (n_input_arrays, batch_size, height, width, channels).
        # Flattening the first two axes keeps all images of one array together.
        inputs = inputs.reshape((-1,self.W,self.H,self.C))
        # Repeat the targets in the same array-major order.
        target = target[self.index]
        if augment:
            inputs, target = self.augmentor(inputs,target)
        return inputs, target


## Training Setup

In [None]:
## Original code.
# learning_rate = 0.01
# optim = optimizers.Adam(lr=learning_rate)
# model.compile(loss="mse",
#               optimizer=optim)
##
W,H,C = 256,256,1

custom_lr = 0.0001 #0.00005, Original. NOTE: I used 0.0005 for the first 50 Epochs.
weight_decay = 0.0 # Passed to Adam as `decay` below. NOTE(review): in legacy Keras optimizers `decay` is learning-rate decay, not L2 weight decay — confirm intent.
clipvalue = 2 # Passed to Adam as `clipvalue` (gradient clipping by value).
augmentation_warmup = 0 # Warmup augmentations. How many epochs to train without augmentations.
fp16 = False # Mixed precision training, i.e. use float16 instead of float32: Faster training, but less accurate.
# NOTE: Might need to replace the keyword "learning_rate" with "lr" since i used an newer version of Keras, see code below.
optim = Adam(learning_rate=custom_lr,decay=weight_decay,clipvalue=clipvalue)
# custom_optimizer = Adam(lr=custom_lr,decay=weight_decay) # Replaced RMSprop for Adam.
custom_loss = FocalDiceLoss(0.5,0.5) # Custom loss function.
custom_metric = dice_coef # Custom metric function.
# flip_x and flip_y are not passed and therefore keep the Augmentor defaults.
augmentor = Augmentor(translate=0, # No translation. Due to lack of speed.
                      shear=0, # No shear. Due to lack of speed.
                      rotate=0, # No rotation. Due to lack of speed.
                      mask=0.8, # Probability of masking the image.
                      mask_size=0.2, # Maximum size of the mask as a fraction of the image size.
                      max_n_masks=8, # Maximum number of masks to apply.
                      noise=0.4, # Probability of adding Gaussian noise to the image.
                      noise_mean=0.05, # Mean of the noise.
                      noise_std=0.1, # Standard deviation of the noise.
                      ) # Augmentation of the data.
pre_processor = PreProcessor(W,H,C,batch_size,per_batch,augmentor)

# With a warm-up the augmentor starts disabled; the training loop re-enables it.
if augmentation_warmup > 0:
    augmentor.set_active(False)

model.compile(loss=custom_loss,
              optimizer=optim,
              metrics=[custom_metric])
name = "unet_segmentation"
# Create model directory.
if not os.path.exists("./models"):
    os.makedirs("./models")
model_dir = os.path.join("./models", name)
n_epochs = 50

# Lowest validation loss seen so far; used to decide when to save the model.
best = np.inf
batches = len(gen_train)
total_batches = batches * n_epochs # Every batch is a training step.
# Define the history arrays for speed.
train_epoch_history = np.zeros(n_epochs)
valid_epoch_history = np.zeros(n_epochs)
train_batch_history = np.zeros(total_batches)
# Global training-step counter; indexes train_batch_history.
counter = 0
h = 0 # Initial loss for progress bar.

# NOTE: Are you satisfied with the optimizer and its parameters?

In [None]:
def prediction_threshold(y_pred:np.ndarray,threshold:float=0.5):
    """Binarize predictions: 1 where ``y_pred`` is strictly above ``threshold``, else 0."""
    above = y_pred > threshold
    return np.where(above, 1, 0)

In [None]:
def plot_sample(generator: DataGenerator,
                pre_processor:PreProcessor,
                array_labels:List[str],
                model:Model=None,
                batch_idx:int=None,
                augment:bool=False,
                threshold:float=0.5,
                save:bool=False,
                save_path:str="./figures/",
                figure_name:str="sample.png",
                title:str="Samples",):
    """Plot the first sample of one batch: one column per input array.

    Row 0 shows each input with the target region masked out, row 1 the target
    overlaid on the input and, when ``model`` is given, row 2 the prediction
    overlaid on the input.

    Args:
        generator: Source of batches.
        pre_processor: Converts a generator batch into ``(x, y)``.
        array_labels: Array names; the last entry is the target label.
        model: Optional model used for the prediction row.
        batch_idx: Batch to plot; a random batch is chosen when None.
        augment: Forwarded to ``pre_processor``.
        threshold: If truthy, predictions are binarized at this value.
        save: Save the figure to ``save_path/figure_name`` instead of showing it.
        save_path: Target directory, created when missing.
        figure_name: File name of the saved figure.
        title: Figure title.
    """
    # Get a batch of data.
    idx = batch_idx if batch_idx is not None else np.random.randint(0,len(generator))
    print(f"Plotting sample {idx}")
    data = generator[idx]
    # Preprocess the data.
    x, y = pre_processor(data,augment=augment)
    # Grid with one of each array_label.
    n_formats = len(array_labels)-1
    # The pre-processor stacks the batch array by array, so the images of
    # input array `col` start at index col * n_samples.
    n_samples = x.shape[0] // n_formats

    has_model = model is not None
    rows = 3 if has_model else 2
    cols = n_formats
    preds = model.predict(x) if has_model else None
    if threshold and has_model:
        preds = prediction_threshold(preds,threshold)
    # Create a figure with the correct number of subplots.
    fig = plt.figure(figsize=(16, 10))
    grid = ImageGrid(fig, 111,
                    nrows_ncols=(rows, cols),
                    axes_pad=[0.0, 0.35],
                    )
    for i, ax in enumerate(grid):
        # Get the row and column index.
        row, col = divmod(i, cols)
        # First sample of the batch, taken from the input array shown in this column.
        img_idx = col * n_samples
        if row == 0:
            # Input with the target region masked out.
            inp = np.ma.masked_array(x[img_idx], mask=y[img_idx])
            ax.imshow(inp, cmap='gray')
            ax.set_title(f"{array_labels[col]}", fontsize=20)
        elif row == 1:
            ax.imshow(y[img_idx], cmap='Reds')
            # Overlay the input image, half transparent.
            ax.imshow(x[img_idx], cmap='gray', alpha=0.5)
            ax.set_title(f"{array_labels[-1]}", fontsize=20)
        else:
            # Plot the prediction of this image only, not the whole 4-D batch.
            ax.imshow(preds[img_idx], cmap='Blues')
            # Overlay the input image, half transparent.
            ax.imshow(x[img_idx], cmap='gray', alpha=0.5)
            ax.set_title("Prediction", fontsize=20)

        ax.axis('off')
    fig.suptitle(title, fontsize=30)
    # Tight
    plt.tight_layout()
    if save:
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        fig.savefig(os.path.join(save_path,figure_name))
        plt.close(fig)
    else:
        plt.show()
%matplotlib inline
# NOTE: `best_model` is only assigned in a later cell (via load_model); run that cell first.
for i in range(1):
    plot_sample(gen_val, pre_processor, array_labels, model=best_model,threshold=0.5)
print("Done!")

## Training Cell

In [None]:
if fp16:
    # NOTE(review): the policy is set after `model` was built in an earlier cell;
    # confirm that it still affects the existing layers.
    policy = mixed_precision.Policy('mixed_float16')
    mixed_precision.set_global_policy(policy)
    print('Compute dtype: %s' % policy.compute_dtype)
    print('Variable dtype: %s' % policy.variable_dtype)
# Training step after which augmentations are switched on.
augmentation_start = batches*augmentation_warmup
print("Start training... This may take a while.")
for epoch in range(n_epochs):
    validating_loss = []
    # Passing the total lets tqdm show a real progress bar instead of a bare counter.
    pbar = tqdm(enumerate(gen_train), total=batches) # Progess bar to make it less boring, and trackable.
    for idx, xy in pbar:
        x,y = pre_processor(xy)
        # train_on_batch returns [loss, metric]; only the loss is tracked.
        h = model.train_on_batch(x, y)[0]
        train_batch_history[counter] = h
        counter += 1
        if counter > augmentation_start and not augmentor.active:
            augmentor.set_active(True)
        if (idx+1)%10==0 or idx==0:
            pbar.set_description(f"Training Epoch {epoch+1}/{n_epochs}. {idx+1}/{batches} Batch. Training Loss MSE: {h:.3e}")
    # Validate on clean data: augmenting the validation set would distort the
    # loss that drives model selection below.
    for xy in gen_val:
        x,y = pre_processor(xy, augment=False)
        h = model.test_on_batch(x, y)[0]
        validating_loss.append(h)
    train_epoch_history[epoch] = np.mean(train_batch_history[epoch*batches:(epoch+1)*batches])
    valid_epoch_history[epoch] = np.mean(validating_loss)
    # Keep the model with the lowest validation loss and save example figures next to it.
    if valid_epoch_history[epoch] < best:
        best = valid_epoch_history[epoch]
        model.save(model_dir)
        plot_sample(gen_train,
                    pre_processor,
                    array_labels,
                    model,
                    save=True,
                    save_path=os.path.join(model_dir,"figures"),
                    figure_name=f"train_{epoch+1}.png",
                    title=f"Epoch {epoch+1} Training Sample",
                    batch_idx=0
                )
        plot_sample(gen_val,
                    pre_processor,
                    array_labels,
                    model,
                    save=True,
                    save_path=os.path.join(model_dir,"figures"),
                    figure_name=f"validation_{epoch+1}.png",
                    title=f"Epoch {epoch+1} Validation Sample",
                )
    print(f"Epoch: {epoch + 1:2d}. Average loss - Training: {train_epoch_history[epoch]:.3e}, Validation: {valid_epoch_history[epoch]:.3e}")
    # A manual loop never triggers Keras' epoch hooks, so reshuffle the training order explicitly.
    gen_train.on_epoch_end()
# NOTE: Plotting the losses helps a lot.
# NOTE: What does plotting the training data tell you? Should you plot something else?
# NOTE: What should one do with the validation data? The data generator has a 'validation_data' argument as well.
# NOTE: When should one stop? Did you overtrain? Did you train for long enough?
# NOTE: Think about implementing Early Stopping?

In [None]:
# Plotting the losses.
# Plot in separate window.
# Batch index at which each epoch starts; used as x-values for the per-epoch curves.
epoch_to_batch = np.arange(0, total_batches, batches)
%matplotlib qt
# Bottom-left panel: training loss per epoch and per batch.
# NOTE(review): the axis labels say "MSE", but the model is compiled with the custom FocalDiceLoss.
plt.subplot(2, 2, 3)
plt.plot(epoch_to_batch, train_epoch_history, label="Training Loss",color="C0")
plt.plot(train_batch_history, label="Training Loss per Batch",alpha=0.5, color="C0")
# Center the plot around the epoch_to_batch.
# plt.xlim(epoch_to_batch[0], epoch_to_batch[-1])
# The upper limit is the largest per-epoch mean, so higher per-batch values are cut off.
plt.ylim(np.min(train_batch_history), np.max(train_epoch_history))
plt.grid()
plt.legend(loc='upper right')
plt.xlabel("Batches")
plt.ylabel("Loss - MSE")
# Bottom-right panel: validation loss per epoch.
plt.subplot(2, 2, 4)
plt.plot(epoch_to_batch, valid_epoch_history, label="Validation Loss",color="C1")
plt.grid()
plt.legend(loc='upper right')
plt.xlabel("Batches")
plt.ylabel("Loss - MSE")
# Increase spacing between subplots.
plt.subplots_adjust(wspace=0.5)
plt.show()
# Bottom two plots as one plot.
plt.subplot(2, 1, 1)
plt.plot(epoch_to_batch, train_epoch_history, label="Training Loss",color="C0")
plt.plot(epoch_to_batch, valid_epoch_history, label="Validation Loss",color="C1")
plt.grid()
plt.legend(loc='upper right')
plt.xlabel("Batches")
plt.ylabel("Loss - MSE")
plt.show()



In [None]:
# Load a model from `model_dir` into `best_model` for the evaluation cells below.
# NOTE(review): `import keras.losses` binds the name `keras` to the top-level
# `keras` package, replacing the `tensorflow.keras` alias imported at the top
# of the notebook — confirm this is intended.
import keras.losses
# Expose FocalDiceLoss as the attribute `keras.losses.custom_loss`.
# NOTE(review): presumably so that deserialisation can resolve a custom loss;
# with compile=False below this may not be needed — confirm.
keras.losses.custom_loss = FocalDiceLoss
# compile=False: the model is loaded without being compiled.
best_model = load_model(model_dir,compile=False)

In [None]:
# Plot a sample from `gen_test` using the loaded `best_model`.
# NOTE(review): `plot_sample` is defined outside this cell; confirm what
# `threshold=0.8` controls there.
plot_sample(gen_test, pre_processor, array_labels, model=best_model,threshold=0.8)


In [None]:
def test_model(model: Model, gen_data: DataGenerator, n: int = batch_size, pre_processer: PreProcessor = None):
    """Plot input, ground truth and prediction for samples of one random batch.

    A batch is drawn at random from ``gen_data``, optionally passed through
    ``pre_processer`` and fed to ``model.predict``. For each of the first
    ``n`` samples a row of three images (INPUT, GT, PRED) is drawn and the
    per-sample MSE is printed; the MSE over the whole batch is printed last.

    Args:
        model: Object whose ``predict`` method is called on the batch inputs.
        gen_data: Indexable batch source supporting ``len``; ``gen_data[i]``
            must unpack into an ``(x, y)`` pair. Assumes ``x`` holds the
            T1-weighted inputs and ``y`` the T2-weighted targets, as in the
            assignment description — confirm against the generator.
        n: Number of samples (rows) to plot. Clamped to the number of
            predictions in the batch.
        pre_processer: Optional callable ``(x, y) -> (x, y)`` applied to the
            batch before prediction. Skipped when ``None``.
    """
    # Unpack the batch directly; the input and target must be bound before
    # they can be pre-processed.
    x, y = gen_data[np.random.randint(0, len(gen_data))]
    if pre_processer is not None:
        x, y = pre_processer(x, y)
    prediction = model.predict(x)

    # Plot and score the exact arrays the model saw, not notebook globals.
    t1, t2 = np.asarray(x), np.asarray(y)

    # Never index past the end of the batch.
    n = min(n, len(prediction))

    panels = (("INPUT", t1), ("GT", t2), ("PRED", prediction))
    cols = len(panels)
    plt.figure(figsize=(16, 10 * n))
    for idx in range(n):
        for col, (title, images) in enumerate(panels, start=1):
            plt.subplot(n, cols, idx * cols + col)
            plt.imshow(images[idx, :, :], cmap='gray')
            plt.colorbar()
            plt.title(title)
            # No x-axis tick labels.
            plt.xticks([])
        # Per-sample error between ground truth and prediction.
        print(f"MSE: {np.mean((t2[idx, :, :] - prediction[idx, :, :])**2):.3e}")
    plt.show()

    # Error over the whole batch, not only the plotted samples.
    print(f"Average:{np.mean((t2 - prediction)**2):.3e}")
    # NOTE: What do the predictions mean? What values do they take on?
# Render the following figures in a separate Qt window.
%matplotlib qt
# Run test_model on `gen_train` with `best_model`, requesting 2 samples.
# `pre_processer` is not passed, so it keeps its default (None).
test_model(best_model, gen_train,2)#,augmentor)

In [None]:
# Wrap the 'Encoder' and 'Decoder' layers of `model` as stand-alone models so
# that each half can be run on its own.
# (Earlier variant, kept for reference:
#  encoder = Model(model.input, model.get_layer('input_2').output))
_encoder_layer = model.get_layer('Encoder')
_decoder_layer = model.get_layer('Decoder')
encoder = Model(_encoder_layer.input, _encoder_layer.output)
decoder = Model(_decoder_layer.input, _decoder_layer.output)

In [None]:

def plot_encodings_grid(encodings: List[np.ndarray]):
    """Show every channel of each encoding as a tile in a square image grid.

    One figure is produced per entry of ``encodings``. Each entry is indexed
    as ``encoding[:, :, channel]``, and its last axis gives the channel count.
    The grid side length is ``ceil(sqrt(n_channels))``, so some trailing tiles
    may stay empty.

    Args:
        encodings: Arrays to visualise, one figure each.
    """
    print(f"Number of encodings: {len(encodings)}")
    for feature_map in encodings:
        channel_count = feature_map.shape[-1]
        # Smallest square grid with room for every channel.
        side = int(np.ceil(np.sqrt(channel_count)))
        print(f"Number of channels: {channel_count}. Rows: {side}. Cols: {side}.")

        figure = plt.figure(figsize=(16, 10))
        tiles = ImageGrid(figure, 111, nrows_ncols=(side, side), axes_pad=0.0)

        # Fill the first `channel_count` tiles, one channel per tile.
        for channel in range(channel_count):
            tile = tiles[channel]
            tile.imshow(feature_map[:, :, channel], cmap='gray')
            # No tick marks on either axis.
            tile.set_xticks([])
            tile.set_yticks([])

        plt.subplots_adjust(wspace=0, hspace=0)
        plt.show()

In [None]:
# Take the first batch from `gen_val`.
# NOTE(review): no pre-processing is applied to this batch here — confirm
# the encoder expects the raw generator output.
x_val, y_val = gen_val[0]
# Run the encoder sub-model on the batch inputs.
encodings_batched = encoder.predict(x_val)
# Feed the encoder outputs to the decoder in reversed order.
# NOTE(review): `decoded` is not used further in this cell.
decoded = decoder.predict(encodings_batched[::-1])
# Keep element 0 of every encoder output.
# Assumes encoder.predict returns one batched array per output, so element 0
# is the first sample of the batch — TODO confirm.
encodings = [enc[0] for enc in encodings_batched]
# Plot the channels of the last entry in `encodings`.
plot_encodings_grid([encodings[-1]])