# Implementing Vanilla U-Net
## Author:  JA Engelbrecht
## Supervisor:  Prof Martin Nieuwoudt
## Co-supervisor: Dr ST Malherbe

# Import Libraries

In [3]:
from matplotlib import rc
from jupyterthemes import jtplot
from skimage.util import montage as montage2d
import UNets.Vanilla.UNet_Vanilla as UNet

from MyFunctions.LoadImages import LoadImages
from MyFunctions.CreatePaths import CreatePaths
from CLR.clr_callback import *

import SimpleITK as sitk
import tensorflow as tf
import pandas as pd
import numpy as np
import os

from tensorflow.keras.optimizers import SGD

from sklearn.model_selection import train_test_split

############ Plot Images/Graphs Functions ############

from matplotlib.colors import LinearSegmentedColormap
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import matplotlib as mpl

# Colormap used for overlays: evenly interpolated from black through orange to red.
cmap = LinearSegmentedColormap.from_list('mycmap', ['black', 'orange', 'red'])

# Typeset all figure text with LaTeX in Computer Modern (serif), size 12,
# so saved figures match the surrounding document.
rc('font', family='serif', serif=['Computer Modern'])
rc('text', usetex=True)
mpl.rcParams['font.size'] = 12


def set_size(width='thesis', fraction=1, subplots=(1, 1)):
    """Return a (width, height) figure size in inches that matches a LaTeX page.

    Parameters
    ----------
    width: float or string
            Document text width in points, or the string 'thesis' for the
            predefined thesis text width of 398 pt.
    fraction: float, optional
            Fraction of the text width the figure should occupy.
    subplots: array-like, optional
            Number of rows and columns of subplots; the golden-ratio height
            is scaled by rows / columns.

    Returns
    -------
    fig_dim: tuple
            (width, height) of the figure in inches.
    """
    width_pt = 398 if width == 'thesis' else width
    points_to_inches = 1 / 72.27          # TeX points per inch
    golden_ratio = (5 ** 0.5 - 1) / 2     # https://disq.us/p/2940ij3

    n_rows, n_cols = subplots[0], subplots[1]
    fig_width_in = width_pt * fraction * points_to_inches
    fig_height_in = fig_width_in * golden_ratio * (n_rows / n_cols)
    return (fig_width_in, fig_height_in)


def showCTImage(IMG, SIZE):
    """Display one 2D image in grayscale on a SIZE x SIZE inch figure, axes hidden."""
    fig_size = (SIZE, SIZE)
    plt.figure(figsize=fig_size)
    plt.imshow(IMG, cmap='gray', alpha=1)
    plt.axis('off')
    plt.show()


def showCTMontage(IMG, SIZE):
    """Tile a stack of images with montage2d and show it in grayscale, axes hidden."""
    tiled = montage2d(IMG)
    plt.figure(figsize=(SIZE, SIZE))
    plt.imshow(tiled, cmap='gray', alpha=1)
    plt.axis('off')
    plt.show()


def showCTMontageOverlay(IMG1, IMG2, SIZE=15, SaveFig=False, save_fig_name=""):
    """Overlay a montage of IMG2 on top of a grayscale montage of IMG1.

    Parameters
    ----------
    IMG1: array-like
            Stack of background images, tiled with ``montage2d`` and drawn
            in grayscale.
    IMG2: array-like
            Stack of overlay images (e.g. masks), tiled with ``montage2d``
            (padding filled with 0) and drawn at alpha 0.5 with the
            module-level ``cmap``.
    SIZE: float, optional
            Width and height of the square figure in inches.
    SaveFig: bool, optional
            If True, save the figure to ``./SavedFigures/<save_fig_name>.pdf``;
            the folder is created if it does not exist.
    save_fig_name: str, optional
            File name (without extension) used when SaveFig is True.
    """
    fig, ax = plt.subplots(figsize=(SIZE, SIZE))
    # Drawing is best-effort: a failure on one layer is reported but does not
    # stop the other layer from being shown. Catch Exception (not a bare
    # except) so KeyboardInterrupt/SystemExit still propagate, and print the
    # underlying error so shape/dtype problems can be diagnosed.
    try:
        ax.imshow(montage2d(IMG1), alpha=1, cmap='gray')
    except Exception as err:
        print("Error: Img 1:", repr(err))
    try:
        ax.imshow(montage2d(IMG2, fill=0), alpha=0.5,
                  cmap=cmap, interpolation='none')
    except Exception as err:
        print("Error: Img 2:", repr(err))
    plt.axis('off')

    if SaveFig:
        save_fig_path = os.path.join(os.curdir, "SavedFigures")
        # savefig does not create missing parent folders.
        os.makedirs(save_fig_path, exist_ok=True)
        plt.savefig(os.path.join(save_fig_path,
                                 save_fig_name + ".pdf"), bbox_inches='tight')
    plt.show()
######################################################

# Set Variables

In [5]:
# Change the CreatePaths flags below to select different paths, e.g. CT/PET.
path = CreatePaths(DeviceFlag="PC", ScanTypeFlag="CT", TrainTestFlag="Train")

# Alternative: take the locations from the CreatePaths helper instead of the
# hard-coded paths used below.
# DATA_PATH = "D://Masters_Repo//TrainingData//CT_v1"
# IMGS_PATH = path.imgPath()
# MSKS_PATH = path.mskPath()
# OUTPUT_PATH = path.outputPath()

# Hard-coded data and output locations currently in use.
DATA_PATH = "F://MyMasters//Data//TrainingData"
IMGS_PATH = DATA_PATH + "//imgs"
MSKS_PATH = DATA_PATH + "//masks"
OUTPUT_PATH = "F://MyMasters//Output"

ORIENTATION_ENSEMBLE = ["Axial", "Sagittal", "Coronal"]

print("\n".join([
    "Image Path: \t" + IMGS_PATH,
    "Mask Path: \t" + MSKS_PATH,
    "Output Path: \t" + OUTPUT_PATH,
]))

ScanType = "CT"
n_Scans = 72
Orientation = "Axial"

Image Path: 	F://MyMasters//Data//TrainingData//imgs
Mask Path: 	F://MyMasters//Data//TrainingData//masks
Output Path: 	F://MyMasters//Output


# Import and Process Scans

In [6]:
ScanType, n_Scans, Orientation = "CT", 72, "Axial"

# Images and masks are loaded with the same scan type, scan count and
# orientation. The split below assumes LoadScans returns them in matching
# order (TODO confirm in MyFunctions.LoadImages).
CT_Images = LoadImages(ScanType=ScanType, ScanClass="Image", ImgPath=IMGS_PATH,
                       n_Scans=n_Scans, Orientation=Orientation).LoadScans()
CT_Masks = LoadImages(ScanType=ScanType, ScanClass="Mask", MskPath=MSKS_PATH,
                      n_Scans=n_Scans, Orientation=Orientation).LoadScans()

########################## Split Into Train and Test Set ##########################
# 85 % training / 15 % validation, reproducible through random_state.
# NOTE(review): train_test_split shuffles along axis 0. If axis 0 indexes
# individual 2D slices (as the later X[500:564, :, :] indexing suggests),
# slices from the same scan land in both sets and validation scores will be
# optimistic. Consider a split grouped by scan — TODO confirm.
X, X_Val, y, y_Val = train_test_split(CT_Images, CT_Masks,
                                      test_size=0.15, random_state=42)
del CT_Images, CT_Masks  # drop the unsplit arrays to free memory

# Masks as float32: the loss functions multiply them with the predictions.
y, y_Val = (tf.cast(labels, dtype='float32') for labels in (y, y_Val))

Reading the following CT Images:
CB_001_CT_M0.nii.gz
CB_003_CT_M0.nii.gz
CB_004_CT_M0.nii.gz
CB_005_CT_M0.nii.gz
CB_007_CT_M0.nii.gz
CB_009_CT_M0.nii.gz
CB_013_CT_M0.nii.gz
CB_020_CT_M1.nii.gz
CB_021_CT_M6.nii.gz
CB_022_CT_M0.nii.gz
CB_027_CT_M0.nii.gz
CB_029_CT_M0.nii.gz
CB_030_CT_M0.nii.gz
CB_034_CT_M0.nii.gz
CB_034_CT_M1.nii.gz
CB_034_CT_M6.nii.gz
CB_035_CT_M0.nii.gz
CB_035_CT_M1.nii.gz
CB_035_CT_M6.nii.gz
CB_041_CT_M0.nii.gz
CB_041_CT_M1.nii.gz
CB_041_CT_M3.nii.gz
CB_042_CT_M0.nii.gz
CB_042_CT_M1.nii.gz
CB_042_CT_M6.nii.gz
CB_043_CT_M0.nii.gz
CB_043_CT_M1.nii.gz
CB_043_CT_M6.nii.gz
CB_077_CT_M1.nii.gz
CB_087_CT_M0.nii.gz
CB_087_CT_M1.nii.gz
CB_087_CT_M6.nii.gz
CB_088_CT_M0.nii.gz
CB_088_CT_M1.nii.gz
CB_088_CT_M6.nii.gz
CB_089_CT_M0.nii.gz
CB_089_CT_M1.nii.gz
CB_089_CT_M6.nii.gz
CB_090_CT_M0.nii.gz
CB_090_CT_M1.nii.gz
CB_090_CT_M6.nii.gz
CB_092_CT_M0.nii.gz
CB_092_CT_M1.nii.gz
CB_092_CT_M18.nii.gz
CB_092_CT_M6.nii.gz
CB_093_CT_M0.nii.gz
CB_095_CT_M0.nii.gz
CB_095_CT_M1.nii.gz
CB_095

## View Imported Scans Overlaid with Masks

In [None]:
# Masked view of the training masks with entries equal to 0.5 hidden.
# (No separate np.copy of y is made first: a copy of the entire training-mask
# array would only be discarded on the next assignment.)
# NOTE(review): y_mask is not used below (the overlay receives `y`), and
# masking 0.5 hides nothing unless the masks contain that value. Confirm
# whether masking the background (y == 0) was intended.
y_mask = np.ma.masked_where(y == 0.5, y, copy=False)

# Show 64 training slices (indices 500-563) with their masks overlaid.
showCTMontageOverlay(IMG1=X[500:564, :, :], IMG2=y[500:564, :, :],
                     SIZE=25, SaveFig=True, save_fig_name="Masks on Images")

## Expand Arrays with a 4th Singleton (Channel) Dimension (Grayscale Images)

In [7]:
# Append a trailing channel axis of size 1 (grayscale) to images and masks.
# Assumes the arrays are 3D here, so the result is 4D, matching the
# (256, 256, 1) per-sample input shape given to the U-Net.
X, y, X_Val, y_Val = [np.expand_dims(arr, axis=3) for arr in (X, y, X_Val, y_Val)]

# U-Net

## Preparing to Create U-Net

In [43]:
############## Functions to Log Training of U-Net ##############
def get_run_logdir(root_logdir, input_string):
    """Build a timestamped run-directory path under *root_logdir*.

    The run folder is named ``run_YYYY_MM_DD-HH_MM_SS`` (local time). When
    *input_string* is non-empty it is inserted as an intermediate folder, so
    runs of the same model are grouped together. Only the path is returned;
    nothing is created on disk.
    """
    import time
    parts = [root_logdir]
    if input_string:
        parts.append(input_string)
    parts.append(time.strftime("run_%Y_%m_%d-%H_%M_%S"))
    return os.path.join(*parts)


def create_logdir(modelName):
    """Return a fresh log path of the form ./My_logs/<modelName>/run_<timestamp>.

    Delegates to get_run_logdir; the directory itself is not created here.
    """
    return get_run_logdir(os.path.join(os.curdir, "My_logs"), modelName)
################################################################

########## Custom Loss Functions: Dice, Tversky and Focal Tversky ##########


def dice_coef(y_true, y_pred, smooth=1):
    """Soft Dice coefficient between two tensors, computed over all elements.

    Dice = (2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)

    ``smooth`` keeps the ratio defined (and equal to 1) when both inputs are
    all zeros.
    """
    K = tf.keras.backend
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    denominator = K.sum(truth) + K.sum(pred) + smooth
    return (2. * overlap + smooth) / denominator


def dice_coef_loss(y_true, y_pred):
    """Dice loss: one minus the soft Dice coefficient (0 for perfect overlap)."""
    coefficient = dice_coef(y_true, y_pred)
    return 1 - coefficient

def tversky(y_true, y_pred, smooth=1, alpha=0.7):
    """Soft Tversky index computed over all elements.

    TI = (TP + smooth) / (TP + alpha * FN + (1 - alpha) * FP + smooth)

    with soft counts TP = sum(t * p), FN = sum(t * (1 - p)) and
    FP = sum((1 - t) * p). With alpha > 0.5 (default 0.7), false negatives
    are weighted more heavily than false positives. Ignoring ``smooth``,
    alpha = 0.5 gives the Dice coefficient.
    """
    K = tf.keras.backend
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    tp = K.sum(truth * pred)
    fn = K.sum(truth * (1 - pred))
    fp = K.sum((1 - truth) * pred)
    denominator = tp + alpha * fn + (1 - alpha) * fp + smooth
    return (tp + smooth) / denominator

def tversky_loss(y_true, y_pred):
    """Tversky loss: one minus the Tversky index with its default parameters."""
    index = tversky(y_true, y_pred)
    return 1 - index


def focal_tversky_loss(y_true, y_pred, gamma=4/3):
    """Focal Tversky loss: (1 - TI) ** gamma, with TI from tversky().

    NOTE(review): Abraham & Khan (2019) write the focal Tversky loss as
    (1 - TI) ** (1 / gamma) with gamma = 4/3, i.e. an exponent of 0.75.
    Here gamma is used directly as the exponent (about 1.33). Confirm which
    convention is intended.
    """
    tversky_index = tversky(y_true, y_pred)
    return tf.keras.backend.pow(1 - tversky_index, gamma)
################################################################

## Callbacks for Training

In [44]:
# The model name encodes the architecture and slice orientation. The log
# directory and the checkpoint path (.h5) are both derived from it.
MyModelName = 'U-Net_V_' + Orientation
MyLogdir = create_logdir(MyModelName)
MyModelSaveRoot = os.path.join(os.curdir, "TrainedModels")
MyModelSavePath = os.path.join(MyModelSaveRoot, MyModelName + ".h5")

for value in (MyLogdir, MyModelSavePath, MyModelName):
    print(value)

.\My_logs\U-Net_V_Axial\run_2020_10_05-14_29_57
.\TrainedModels\U-Net_V_Axial.h5
U-Net_V_Axial


In [45]:
# CSVLogger and ModelCheckpoint write into this folder but do not create it.
os.makedirs(MyModelSaveRoot, exist_ok=True)

# Append each epoch's metrics to TrainedModels/<MyModelName>.csv; append=True
# keeps rows from earlier fit() calls.
csv_logger_cb = tf.keras.callbacks.CSVLogger(
    os.path.join(MyModelSaveRoot, MyModelName+".csv"), append=True)
# Save the model only when val_loss improves.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(MyModelSavePath,
                                                   monitor='val_loss', verbose=1, save_best_only=True)
# Stop after 15 epochs without val_loss improvement and restore the best weights.
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    patience=15, restore_best_weights=True, monitor='val_loss')

# CyclicLR counts `step_size` in training iterations (batches), not samples.
# A half cycle of 5 * X.shape[0] iterations would be longer than the whole
# training run (epochs * X.shape[0] / batch_size iterations), so the learning
# rate would never reach max_lr. Use five epochs per half cycle instead.
# CLR_BATCH_SIZE must match the batch_size later passed to MyModel.fit().
CLR_BATCH_SIZE = 8
steps_per_epoch = int(np.ceil(X.shape[0] / CLR_BATCH_SIZE))
clr_triangular_cb = CyclicLR(
    base_lr=1e-3, max_lr=1e-2, mode='triangular2', step_size=5*steps_per_epoch)
tensorboard_cb = tf.keras.callbacks.TensorBoard(MyLogdir)

## Create U-Net

In [46]:
class ThresholdedMeanIoU(tf.keras.metrics.MeanIoU):
    """MeanIoU that binarises probabilistic predictions before scoring.

    tf.keras.metrics.MeanIoU expects class ids and casts y_pred to integers.
    Raw probabilities in [0, 1) therefore all truncate to class 0, and the
    metric barely reflects the prediction. This subclass thresholds y_pred at
    `threshold` first. It keeps MeanIoU's default name ('mean_io_u'), so the
    logged CSV column names stay the same.
    """

    def __init__(self, num_classes=2, threshold=0.5, name='mean_io_u', **kwargs):
        super().__init__(num_classes=num_classes, name=name, **kwargs)
        self.threshold = threshold

    def update_state(self, y_true, y_pred, sample_weight=None):
        # 1 where the predicted probability exceeds the threshold, else 0.
        y_pred = tf.cast(y_pred > self.threshold, tf.int32)
        return super().update_state(y_true, y_pred, sample_weight)

    def get_config(self):
        config = super().get_config()
        config['threshold'] = self.threshold
        return config


# NOTE(review): assumes UNet_Vanilla ends in a sigmoid that outputs foreground
# probabilities — confirm in UNets/Vanilla/UNet_Vanilla.py.
MyModel = UNet.UNet_Vanilla(input_shape=(256, 256, 1)).CreateUnet()
MyModel.compile(optimizer=SGD(momentum=0.9, nesterov=True),
                loss=focal_tversky_loss,
                metrics=[ThresholdedMeanIoU(num_classes=2)])

Model: "model_7"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_8 (InputLayer)            [(None, 256, 256, 1) 0                                            
__________________________________________________________________________________________________
conv2d_161 (Conv2D)             (None, 256, 256, 64) 640         input_8[0][0]                    
__________________________________________________________________________________________________
activation_154 (Activation)     (None, 256, 256, 64) 0           conv2d_161[0][0]                 
__________________________________________________________________________________________________
conv2d_162 (Conv2D)             (None, 256, 256, 64) 36928       activation_154[0][0]             
____________________________________________________________________________________________

NameError: name 'tversky' is not defined

### Document Compile Parameters

In [47]:
# Human-readable record of the compile settings, written to the parameter
# text file after training. Keep it in sync with the MyModel.compile(...)
# call, which uses focal_tversky_loss as the loss.
Optimizer = "SGD, Momentum = 0.9, Nesterov = True"
loss = "focal_tversky_loss"

### Set Training Parameters

In [48]:
# Mini-batch size and number of epochs used for each MyModel.fit(...) call.
batch_size, epochs = 8, 20

## Train U-Net

In [49]:
# First training round; validation data drives checkpointing and early stopping.
MyModel.fit(
    X, y,
    validation_data=(X_Val, y_Val),
    batch_size=batch_size,
    epochs=epochs,
    verbose=1,
    callbacks=[checkpoint_cb, early_stopping_cb, clr_triangular_cb,
               tensorboard_cb, csv_logger_cb],
)

Train on 15667 samples, validate on 2765 samples
Epoch 1/20


NameError: name 'tversky' is not defined

In [None]:
# Reset the cyclic learning-rate callback before the second fit() call below.
# NOTE(review): _reset() is a private method of the CLR callback; assumed to
# restart the LR cycle — confirm in CLR/clr_callback.py.
clr_triangular_cb._reset()

In [None]:
# Second training round with the same data, batch size, epoch count and callbacks.
MyModel.fit(
    X, y,
    validation_data=(X_Val, y_Val),
    batch_size=batch_size,
    epochs=epochs,
    verbose=1,
    callbacks=[checkpoint_cb, early_stopping_cb, clr_triangular_cb,
               tensorboard_cb, csv_logger_cb],
)

## Load U-Net

In [None]:
# Restore the weights from the best checkpoint written by checkpoint_cb
# (saved only when val_loss improved).
MyModel.load_weights(MyModelSavePath)

# Write Model Parameters to Text File

In [None]:
# Record the run's settings next to the saved model in
# ./TrainedModels/<MyModelName>.txt, one "name: value" line per setting.
# NOTE(review): epochs is recorded as epochs * 2 because fit() is called twice.
# EarlyStopping may end either call early, so this is an upper bound; the CSV
# log records the epochs actually run.
MyModelParameters_Strings = ["ScanType", "n_Scans",
                             "Orientation", "Optimizer", "Loss", "batch_size", "epochs"]
MyModelParameters_values = [ScanType, n_Scans,
                            Orientation, Optimizer, loss, batch_size, epochs*2]

TextFileName = MyModelName+".txt"
TextFilePath = os.path.join(os.curdir, "TrainedModels", TextFileName)
os.makedirs(os.path.dirname(TextFilePath), exist_ok=True)

# The with-block closes the file; no explicit close() is needed.
with open(TextFilePath, "w") as file:
    file.write("Parameters for " + MyModelName + ":\n\n")
    for param_name, param_value in zip(MyModelParameters_Strings, MyModelParameters_values):
        file.write(param_name + ": " + str(param_value) + "\n")

# Check Performance on Test Set
## View Predicted Images Over Masks

In [None]:
# The model takes 4D input (samples, rows, cols, 1). Restore the channel axis
# if X_Val is 3D (e.g. after the post-processing cell below squeezed it).
# An explicit check replaces a bare try/except, which also retried on
# unrelated failures such as out-of-memory or KeyboardInterrupt.
if X_Val.ndim == 3:
    print("Error: Input to Model has to be 4D (n, x, y, 1)")
    print("Reshaping..")
    X_Val = np.expand_dims(X_Val, axis=3)
y_predict = MyModel.predict(X_Val, batch_size=10, verbose=1)

In [None]:
# Drop singleton axes (the channel axis) so predictions, images and masks are
# 3D stacks that montage2d can tile. np.squeeze does not fail on arrays, so no
# try/except is needed. (No y_threshold is defined anywhere in this notebook,
# so nothing else is squeezed here.)
y_predict = np.squeeze(y_predict)
X_Val = np.squeeze(X_Val)
y_Val = np.squeeze(y_Val)

# NOTE(review): this masks every strictly positive prediction, i.e. it hides
# the predicted foreground rather than the background. If the intent was to
# show only predicted lesions, the condition should probably be
# y_predict <= threshold. skimage's montage also appears to convert its input
# with np.asarray, which would discard the mask anyway — verify.
y_new = np.ma.masked_where(y_predict > 0, y_predict, copy=False)

showCTMontageOverlay(IMG1=y_Val[300:336, :, :],
                     IMG2=y_new[300:336, :, :], SIZE=25, SaveFig=True,
                     save_fig_name="Predicted Masks on Actual Masks")

## Plot Training History

In [None]:
# Load the per-epoch training log written by csv_logger_cb
# (./TrainedModels/<MyModelName>.csv).
history = pd.read_csv(os.path.join(MyModelSaveRoot, MyModelName + ".csv"))

In [None]:
# Two side-by-side panels (training metric, training loss) sized for the thesis page.
fig, axs = plt.subplots(1, 2, figsize=set_size(subplots=(1, 2)))

for ax in axs.flat:
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(5))

# The model is compiled with an IoU metric only, so the CSV log may have no
# 'accuracy' column. Plot accuracy when present, otherwise the first training
# metric that is not the loss.
metric_columns = [col for col in history.columns
                  if col not in ('epoch', 'loss', 'lr') and not col.startswith('val_')]
if 'accuracy' in metric_columns:
    metric_col, metric_label = 'accuracy', 'Accuracy'
elif metric_columns:
    metric_col = metric_columns[0]
    # text.usetex is enabled: escape underscores so LaTeX can typeset the label.
    metric_label = metric_col.replace('_', r'\_')
else:
    metric_col, metric_label = None, ''

if metric_col is not None:
    if metric_col == 'accuracy':
        # Zoomed range chosen for near-saturated pixel accuracy.
        axs[0].set_ylim((0.99, 1))
    axs[0].set_xlabel('Epoch', labelpad=10)
    axs[0].set_ylabel(metric_label, labelpad=10)
    axs[0].plot(history[metric_col])

# NOTE(review): fixed y-range; loss values outside [0.25, 0.5] are clipped from view.
axs[1].set_ylim((0.25, 0.5))
axs[1].set_xlabel('Epoch')
axs[1].set_ylabel('Loss')
axs[1].plot(history.loss)

fig.tight_layout()

save_path = os.path.join(os.curdir, "SavedFigures")
# savefig does not create missing parent folders.
os.makedirs(save_path, exist_ok=True)
save_name = "UNet_Vanilla_Accuracy_Loss"
plt.savefig(os.path.join(save_path, save_name+".pdf"), bbox_inches='tight')