# Implementing Vanilla U-Net
## Author:  JA Engelbrecht
## Supervisor:  Prof Martin Nieuwoudt
## Co-supervisor: Dr ST Malherbe

# Import Libraries

In [None]:
from matplotlib import rc
from jupyterthemes import jtplot
from skimage.util import montage as montage2d
import UNets.Vanilla.UNet_Vanilla as UNet

from MyFunctions.learningRateFunction import LRFinder
from MyFunctions.LoadImages import LoadImages
from MyFunctions.CreatePaths import CreatePaths
from CLR.clr_callback import *

import SimpleITK as sitk
import tensorflow as tf
import pandas as pd
import numpy as np
import os

from tensorflow.keras.optimizers import SGD
from tensorflow.keras.optimizers import Adam

from sklearn.model_selection import train_test_split

############ Plot Images/Graphs Functions ############

from matplotlib.colors import LinearSegmentedColormap
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import matplotlib as mpl

# Overlay colormap running black -> orange -> red; showCTMontageOverlay uses
# it to draw the second image (the mask) on top of the grayscale scan.
cmap = LinearSegmentedColormap.from_list('mycmap', ['black', 'orange', 'red'])


# Global matplotlib styling: serif font family (Computer Modern), text
# rendered with usetex enabled, and a base font size of 12.
# NOTE(review): usetex=True presumably needs a working LaTeX install - confirm.
rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})
rc('text', usetex=True)
mpl.rcParams.update({'font.size': 12})


def set_size(width='thesis', fraction=1, subplots=(1, 1)):
    """Compute a figure size in inches that needs no rescaling in LaTeX.

    Parameters
    ----------
    width: float or string
            Document width in points, or the string 'thesis' to select the
            predefined width of 398 pt.
    fraction: float, optional
            Fraction of the document width the figure should occupy.
    subplots: array-like, optional
            Number of rows and columns of subplots, as (rows, columns).

    Returns
    -------
    fig_dim: tuple
            (width, height) of the figure in inches.
    """
    # Resolve the document width in points
    document_width_pt = 398 if width == 'thesis' else width

    # Points -> inches conversion factor
    inch_per_point = 1 / 72.27
    # Golden ratio gives an aesthetically pleasing height
    # https://disq.us/p/2940ij3
    golden = (5 ** 0.5 - 1) / 2

    n_rows, n_cols = subplots[0], subplots[1]

    width_in = document_width_pt * fraction * inch_per_point
    # Scale the height by the rows/columns ratio of the subplot grid
    height_in = width_in * golden * (n_rows / n_cols)

    return (width_in, height_in)


def showCTImage(IMG, SIZE):
    """Show IMG in grayscale on a SIZE x SIZE inch figure with the axes hidden."""
    square = (SIZE,) * 2
    plt.figure(figsize=square)
    plt.imshow(IMG, cmap='gray', alpha=1)
    plt.axis('off')
    plt.show()


def showCTMontage(IMG, SIZE):
    """Tile IMG with montage2d and show the result in grayscale.

    The figure is SIZE x SIZE inches and is drawn without axes.
    """
    tiled = montage2d(IMG)
    square = (SIZE,) * 2
    plt.figure(figsize=square)
    plt.imshow(tiled, cmap='gray', alpha=1)
    plt.axis('off')
    plt.show()


def showCTMontageOverlay(IMG1, IMG2, SIZE=15, SaveFig=False, save_fig_name=""):
    """Show a montage of IMG1 with a half-transparent montage of IMG2 on top.

    Parameters
    ----------
    IMG1: array-like
            Stack passed to montage2d and drawn in grayscale (background).
    IMG2: array-like
            Stack passed to montage2d and drawn with the module-level `cmap`
            at alpha 0.5 (overlay).
    SIZE: float, optional
            Side length of the square figure in inches.
    SaveFig: bool, optional
            When True, the figure is also written to
            ./SavedFigures/<save_fig_name>.pdf.
    save_fig_name: str, optional
            File name (without extension) used when SaveFig is True.

    Drawing is best effort: if either layer fails, the error is reported and
    the remaining layer is still shown.
    """
    fig, ax = plt.subplots(figsize=(SIZE, SIZE))
    # Catch Exception (not a bare except) so Ctrl-C still interrupts, and
    # report the cause instead of hiding it.
    try:
        ax.imshow(montage2d(IMG1), alpha=1, cmap='gray')
    except Exception as err:
        print("Error: Img 1", err)
    try:
        ax.imshow(montage2d(IMG2, fill=0), alpha=0.5,
                  cmap=cmap, interpolation='none')
    except Exception as err:
        print("Error: Img 2", err)
    plt.axis('off')

    if SaveFig:
        save_fig_path = os.path.join(os.curdir, "SavedFigures")
        # savefig does not create missing directories
        os.makedirs(save_fig_path, exist_ok=True)
        plt.savefig(os.path.join(save_fig_path,
                                 save_fig_name+".pdf"), bbox_inches='tight')
    plt.show()
######################################################

# Set Variables

In [None]:
# Change Class Variable below for different paths e.g. CT/PET
# NOTE(review): `path` is only referenced by the commented-out lines below;
# the active paths are hard-coded.
path = CreatePaths(DeviceFlag="PC", ScanTypeFlag="CT", TrainTestFlag="Train")

#DATA_PATH = "D://Masters_Repo//TrainingData//CT_v1"
#IMGS_PATH = path.imgPath()
#MSKS_PATH = path.mskPath()
#OUTPUT_PATH = path.outputPath()

# Hard-coded locations of the training data, images, masks and outputs
DATA_PATH = "F://MyMasters//Data//TrainingData"
IMGS_PATH = "F://MyMasters//Data//TrainingData//imgs"
MSKS_PATH = "F://MyMasters//Data//TrainingData//masks"
OUTPUT_PATH = "F://MyMasters//Output"

# Names of the three slice orientations
ORIENTATION_ENSEMBLE = ["Axial", "Sagittal", "Coronal"]

# Echo the active paths so a wrong drive/folder is spotted early
print("Image Path: "+"\t"+IMGS_PATH+"\n"+"Mask Path: " +
      "\t"+MSKS_PATH+"\n"+"Output Path: "+"\t"+OUTPUT_PATH)

# Scan settings used when loading data and naming the model.
# The same three values are assigned again at the top of the loading cell.
ScanType = "CT"
n_Scans = 72
Orientation = "Axial"

# Import and Process Scans

In [None]:
# Scan settings passed to the LoadImages loaders below
ScanType = "CT"
n_Scans = 72
Orientation = "Axial"

# Load the images and their masks with the same scan type, scan count and
# orientation so the two collections line up.
CT_Images = LoadImages(ScanType=ScanType, ScanClass="Image",
                       ImgPath=IMGS_PATH, n_Scans=n_Scans, Orientation=Orientation).LoadScans()
CT_Masks = LoadImages(ScanType=ScanType, ScanClass="Mask",
                      MskPath=MSKS_PATH, n_Scans=n_Scans, Orientation=Orientation).LoadScans()

########################## Split Into Train and Test Set ##########################
# 15 % of the samples are held out as X_Val / y_Val; random_state fixes the split
X, X_Val, y, y_Val = train_test_split(
    CT_Images, CT_Masks, test_size=0.15, random_state=42)

# The full collections are no longer needed once split
del CT_Images, CT_Masks

# Cast the masks to float32
y = tf.cast(y, dtype='float32')
y_Val = tf.cast(y_Val, dtype='float32')

## View Imported Scans Overlaid with Masks

In [None]:
# Build a masked view of the training masks that hides entries equal to 0.5.
# NOTE(review): y_mask is not used afterwards - the overlay call below is
# given `y` itself. Confirm whether y_mask was meant to be passed as IMG2.
y_mask = np.copy(y)
y_mask = np.ma.masked_where(y == 0.5, y, copy=False)
# Overlay masks on images for indices 500-563 and save the figure as
# "Masks on Images.pdf" in the SavedFigures folder.
showCTMontageOverlay(IMG1=X[500:564, :, :], IMG2=y[500:564, :, :],
                     SIZE=25, SaveFig=True, save_fig_name="Masks on Images")

## Expand Arrays with a Fourth Singleton Dimension (Grayscale Images)

In [None]:
# Append a trailing axis (axis=3) to the training and validation images and
# masks in a single pass.
X, y, X_Val, y_Val = (
    np.expand_dims(volume, axis=3) for volume in (X, y, X_Val, y_Val)
)

#### Augment Data

In [None]:
# Augmentation settings shared by the image and mask generators: rotation
# range 15, zoom range 0.15, plus horizontal and vertical flips.
dataAug = dict(rotation_range=15,
               zoom_range=0.15,
               horizontal_flip=True,
               vertical_flip=True)

# Two generators built from identical settings. The same seed is passed to
# both everywhere below, presumably so an image and its mask receive the same
# random transform - confirm against the Keras documentation.
image_datagen = tf.keras.preprocessing.image.ImageDataGenerator(**dataAug)
mask_datagen = tf.keras.preprocessing.image.ImageDataGenerator(**dataAug)
seed = 42

# NOTE(review): fit() presumably only matters for feature-wise statistics,
# none of which are enabled in dataAug - confirm whether these calls are needed.
image_datagen.fit(X, augment=True, seed=seed)
mask_datagen.fit(y, augment=True, seed=seed)

In [None]:
# Draw a small preview set of augmented image/mask pairs, one sample per
# batch, using the same seed for both generators.
n_preview = 200
X_Aug = image_datagen.flow(X, batch_size=1, seed=seed)
y_Aug = mask_datagen.flow(y, batch_size=1, seed=seed)

# Size the buffers from the data instead of hard-coding 256 x 256 x 1
viewImages = np.zeros((n_preview,) + tuple(X.shape[1:]))
viewMasks = np.zeros((n_preview,) + tuple(y.shape[1:]))

# Fill every slot: the previous range(199) left the last preallocated
# entry as all zeros.
for i in range(n_preview):
    # Each batch holds a single sample, so take element 0
    viewImages[i] = next(X_Aug)[0]
    viewMasks[i] = next(y_Aug)[0]

In [None]:
# Show the first 199 augmented images with their masks overlaid (last axis
# dropped by indexing 0). SaveFig is False, so save_fig_name is not used here.
showCTMontageOverlay(IMG1=viewImages[0:199, :, :, 0], IMG2=viewMasks[0:199, :, :, 0],
                     SIZE=25, SaveFig=False, save_fig_name="Masks on Images")

# U-Net

## Preparing to Create U-Net

In [None]:
############## Functions to Log Training of U-Net ##############
def get_run_logdir(root_logdir, input_string):
    """Build a timestamped run directory path below root_logdir.

    The last path component is "run_<year>_<month>_<day>-<hour>_<min>_<sec>"
    taken from the current local time. When input_string is non-empty it is
    inserted as a sub-folder between root_logdir and the run component.
    """
    from time import strftime

    stamp = strftime("run_%Y_%m_%d-%H_%M_%S")
    if input_string:
        components = (root_logdir, input_string, stamp)
    else:
        components = (root_logdir, stamp)
    return os.path.join(*components)


def create_logdir(modelName):
    """Return a timestamped log directory path for modelName under ./My_logs."""
    logs_root = os.path.join(os.curdir, "My_logs")
    return get_run_logdir(logs_root, modelName)
################################################################

########## Custom Loss Function for Dice Coefficient ##########


def dice_coef(y_true, y_pred, smooth=1):
    """Return the smoothed Dice coefficient of y_true and y_pred.

    Both inputs are flattened; the result is
    (2 * sum(true * pred) + smooth) / (sum(true) + sum(pred) + smooth).
    """
    K = tf.keras.backend
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth) + K.sum(pred) + smooth
    return numerator / denominator


def dice_coef_loss(y_true, y_pred):
    """Return the Dice loss, i.e. one minus dice_coef(y_true, y_pred)."""
    score = dice_coef(y_true, y_pred)
    return 1 - score


def tversky(y_true, y_pred, smooth=1, alpha=0.7):
    """Return the smoothed Tversky index of y_true and y_pred.

    Both inputs are flattened. False negatives are weighted by alpha and
    false positives by (1 - alpha):
    (TP + smooth) / (TP + alpha * FN + (1 - alpha) * FP + smooth).
    """
    K = tf.keras.backend
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    tp = K.sum(truth * pred)
    fn = K.sum(truth * (1 - pred))
    fp = K.sum((1 - truth) * pred)
    numerator = tp + smooth
    denominator = tp + alpha * fn + (1 - alpha) * fp + smooth
    return numerator / denominator


def tversky_loss(y_true, y_pred):
    """Return the Tversky loss, i.e. one minus tversky(y_true, y_pred)."""
    index = tversky(y_true, y_pred)
    return 1 - index


def focal_tversky_loss(y_true, y_pred, gamma=4/3):
    """Return (1 - tversky(y_true, y_pred)) raised to the power gamma."""
    return tf.keras.backend.pow((1 - tversky(y_true, y_pred)), gamma)
################################################################

## Callbacks for Training

In [None]:
# Name the model after the slice orientation and derive its log directory
# and the location of the saved weights.
MyModelName = 'U-Net_V_' + Orientation
MyLogdir = create_logdir(MyModelName)
MyModelSaveRoot = os.path.join(os.curdir, "TrainedModels")
MyModelSavePath = os.path.join(MyModelSaveRoot, MyModelName+".h5")

# The CSV log, the checkpoint and the parameter text file are all written
# into this folder, so make sure it exists before training starts.
os.makedirs(MyModelSaveRoot, exist_ok=True)

print(MyLogdir)
print(MyModelSavePath)
print(MyModelName)

### Create Model

In [None]:
# Build the vanilla U-Net for inputs of shape (256, 256, 1) and compile it
# with Adam (learning rate 1e-4) and the Dice loss.
# "accuracy" is registered next to MeanIoU because the training-history plot
# further down reads an `accuracy` column from the CSV log.
# NOTE(review): MeanIoU is fed the raw model outputs here; confirm whether
# the predictions should be thresholded first for the metric to be meaningful.
MyModel = UNet.UNet_Vanilla(input_shape=(256, 256, 1)).CreateUnet()
MyModel.compile(optimizer=Adam(learning_rate=1e-4),
                loss=dice_coef_loss,
                metrics=["accuracy", tf.keras.metrics.MeanIoU(num_classes=2)])

In [None]:
# Learning-rate range test on a small subset of the training data.
batch_size = 10
lr_epochs = 3
# Never ask for more samples than are available
subset_size = min(1000, X.shape[0])
# np.ceil returns a float; cast so the step count is a plain integer
steps_p_epoch = int(np.ceil(subset_size / batch_size))

# Use small subset of data; the shared seed keeps images and masks paired
image_generator = image_datagen.flow(
    X[0:subset_size, :, :, :], batch_size=batch_size, seed=seed)
mask_generator = mask_datagen.flow(
    y[0:subset_size, :, :, :], batch_size=batch_size, seed=seed)
train_generator = zip(image_generator, mask_generator)

# Sweep the learning rate from 1e-5 to 1e-2 over the short run and plot the
# resulting loss curve.
# NOTE(review): this run updates MyModel's weights; the later training cells
# continue from them unless the model is re-created first.
lr_finder = LRFinder(min_lr=1e-5, max_lr=1e-2,
                     steps_per_epoch=steps_p_epoch, epochs=lr_epochs)
MyModel.fit(train_generator, steps_per_epoch=steps_p_epoch,
            epochs=lr_epochs, verbose=1, callbacks=[lr_finder])
lr_finder.plot_loss()

In [None]:
# Per-epoch metrics are appended to <MyModelName>.csv in the model folder
csv_logger_cb = tf.keras.callbacks.CSVLogger(
    os.path.join(MyModelSaveRoot, MyModelName+".csv"), append=True)
# Keep only the weights with the best validation loss
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(MyModelSavePath,
                                                   monitor='val_loss', verbose=1, save_best_only=True)
# Stop once val_loss has not improved for 15 epochs and restore the best weights
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    patience=15, restore_best_weights=True, monitor='val_loss')
# Cyclical learning rate between 1e-4 and 1.8e-4 in 'triangular2' mode.
# NOTE(review): step_size is derived from the number of samples (X.shape[0]),
# not the number of batches per epoch - confirm which unit CyclicLR expects.
clr_triangular_cb = CyclicLR(
    base_lr=1e-4, max_lr=1.8e-4, mode='triangular2', step_size=5*X.shape[0])
# TensorBoard logs go to the timestamped run directory
tensorboard_cb = tf.keras.callbacks.TensorBoard(MyLogdir)

### Document Compile Parameters

In [None]:
# Human-readable names of the optimizer and loss; they are only written to
# the parameter text file and do not configure the model.
Optimizer = "Adam"
loss = "dice_coef_loss"

### Set Training Parameters

In [None]:
# Training configuration for the full training set.
batch_size = 10
epochs = 20
# np.ceil returns a float; cast so the step count is a plain integer
steps_p_epoch = int(np.ceil(X.shape[0] / batch_size))
# The shared seed keeps the augmented images and masks paired
image_generator = image_datagen.flow(X, batch_size=batch_size, seed=seed)
mask_generator = mask_datagen.flow(y, batch_size=batch_size, seed=seed)
train_generator = zip(image_generator, mask_generator)

## Train U-Net

In [None]:
# First training run: augmented training pairs, validated on (X_Val, y_Val),
# with checkpointing, early stopping, cyclical LR, TensorBoard and CSV logging.
MyModel.fit(train_generator, steps_per_epoch=steps_p_epoch, epochs=epochs, verbose=1, validation_data=(X_Val, y_Val),
            callbacks=[checkpoint_cb, early_stopping_cb, clr_triangular_cb, tensorboard_cb, csv_logger_cb])

In [None]:
# Reset the cyclical learning-rate callback before the second training run.
# NOTE(review): _reset() is a private method of CyclicLR; presumably it
# restarts the cycle - confirm against the CLR implementation.
clr_triangular_cb._reset()

In [None]:
# Second training run with the same generator, epoch count and callbacks;
# it continues from the weights left by the first run.
MyModel.fit(train_generator, steps_per_epoch=steps_p_epoch, epochs=epochs, verbose=1, validation_data=(X_Val, y_Val),
            callbacks=[checkpoint_cb, early_stopping_cb, clr_triangular_cb, tensorboard_cb, csv_logger_cb])

## Load U-Net

In [None]:
# Load the weights stored at MyModelSavePath, the file written by the
# ModelCheckpoint callback.
MyModel.load_weights(MyModelSavePath)

# Write Model Parameters to Text File

In [None]:
# Labels and values written to the parameter file, in matching order.
MyModelParameters_Strings = ["ScanType", "n_Scans",
                             "Orientation", "Optimizer", "Loss", "batch_size", "epochs"]
# epochs is doubled because the model is fitted twice with `epochs` each
MyModelParameters_values = [ScanType, n_Scans,
                            Orientation, Optimizer, loss, batch_size, epochs*2]

TextFileName = MyModelName+".txt"
TextFilePath = os.path.join(os.curdir, "TrainedModels", TextFileName)
# open() does not create missing directories
os.makedirs(os.path.dirname(TextFilePath), exist_ok=True)

# The with-block closes the file itself, so no explicit close() is needed
with open(TextFilePath, "w") as param_file:
    param_file.write("Parameters for " + MyModelName + ":\n\n")
    for label, value in zip(MyModelParameters_Strings, MyModelParameters_values):
        param_file.write(label + ": " + str(value) + "\n")

# Check Performance on Test Set
## View Predicted Images Over Masks

In [None]:
# Predict masks for the validation images.
# The model needs a 4-D input; check the rank explicitly instead of using a
# bare except as a shape test, so genuine prediction errors are not masked
# by a blind retry.
if np.ndim(X_Val) != 4:
    print("Error: Input to Model has to be 4D (x, y, x, 1)")
    print("Reshaping..")
    X_Val = np.expand_dims(np.squeeze(X_Val), axis=3)

y_predict = MyModel.predict(X_Val, batch_size=10, verbose=1)

In [None]:
# Drop the singleton axes so the arrays can be tiled into montages.
y_predict = np.squeeze(y_predict)
X_Val = np.squeeze(X_Val)
y_Val = np.squeeze(y_Val)

# y_threshold is not defined in any visible cell; squeeze it only when an
# earlier (out-of-order) cell has created it.
try:
    y_threshold = np.squeeze(y_threshold)
except NameError:
    pass

# NOTE(review): this masks every prediction greater than 0, which hides the
# predicted foreground rather than the background - confirm the condition.
y_new = np.ma.masked_where(y_predict > 0, y_predict, copy=False)

# Overlay predictions on validation images for indices 300-399 and save the figure
showCTMontageOverlay(IMG1=X_Val[300:400, :, :],
                     IMG2=y_new[300:400, :, :], SIZE=25, SaveFig=True, save_fig_name="Predicted Masks on Actual Masks")

## Plot Training History

In [None]:
# Read the per-epoch training log from <MyModelName>.csv in the model
# folder, the same path the CSVLogger callback writes to.
history = pd.read_csv(os.path.join(MyModelSaveRoot, MyModelName + ".csv"))

In [None]:
# Two side-by-side panels sized for the thesis layout: metric (left), loss (right).
fig, axs = plt.subplots(1, 2, figsize=set_size(subplots=(1, 2)))

for ax in axs.flat:
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(5))

# Left panel: prefer the 'accuracy' column. The log only has it when the
# model was compiled with an accuracy metric, so fall back to the first other
# training metric instead of failing on a missing column.
metric_name = 'accuracy'
if metric_name not in history.columns:
    fallback = [col for col in history.columns
                if col not in ('epoch', 'loss', 'lr')
                and not col.startswith('val_')]
    metric_name = fallback[0] if fallback else None

axs[0].set_xlabel('Epoch', labelpad=10)
if metric_name == 'accuracy':
    axs[0].set_ylim((0.99, 1))
    axs[0].set_ylabel('Accuracy', labelpad=10)
    axs[0].plot(history['accuracy'])
elif metric_name is not None:
    # Underscores must be escaped when labels are rendered through LaTeX
    if mpl.rcParams['text.usetex']:
        metric_label = metric_name.replace('_', r'\_')
    else:
        metric_label = metric_name
    axs[0].set_ylabel(metric_label, labelpad=10)
    axs[0].plot(history[metric_name])
else:
    print("No training metric column found in history; left panel left empty")

# Right panel: training loss
axs[1].set_ylim((0.25, 0.5))
axs[1].set_xlabel('Epoch')
axs[1].set_ylabel('Loss')
axs[1].plot(history.loss)

fig.tight_layout()

save_path = os.path.join(os.curdir, "SavedFigures")
# savefig does not create missing directories
os.makedirs(save_path, exist_ok=True)
save_name = "UNet_Vanilla_Accuracy_Loss"
plt.savefig(os.path.join(save_path, save_name+".pdf"), bbox_inches='tight')