# Model Training Code

This Python notebook implements the training process of a deep learning segmentation model (V-Net or 3D U-Net) on NRRD image/mask volumes.

## Importing Packages

In [None]:
import numpy as np
import SimpleITK as sitk
from matplotlib import pyplot as plt
import os

In [None]:
import keras

%env SM_FRAMEWORK=tf.keras
from segmentation_models.losses import *
from segmentation_models.metrics import *

In [None]:
from VNet import vnet
from UNet import unet_3d

## Loading and Processing Dataset

In [None]:
# Each split is stored as two parallel folders of NRRD files: the input
# volumes in *_img and the corresponding masks in *_msk. Images and masks are
# paired purely by sorted filename order, so the files for one case must sort
# to the same position in both folders.


def load_nrrd_stack(folder):
    """Read every NRRD file in `folder` and stack them along a new leading axis.

    Files are read in sorted filename order. os.listdir returns names in
    arbitrary order, which may differ between the image and mask folders;
    sorting keeps image/mask stacks from parallel folders aligned case-by-case.

    Parameters
    ----------
    folder : str
        Directory expected to contain only NRRD files whose arrays all have
        the same shape.

    Returns
    -------
    tuple[list[str], numpy.ndarray]
        The sorted filenames, and an array of shape (n_files, *volume_shape),
        where volume_shape is the shape returned by sitk.GetArrayFromImage.

    Raises
    ------
    ValueError
        If the folder contains no files, or the volumes differ in shape.
    """
    filenames = sorted(os.listdir(folder))
    if not filenames:
        raise ValueError(f"No files found in {folder!r}")
    volumes = [
        sitk.GetArrayFromImage(
            sitk.ReadImage(os.path.join(folder, name), imageIO="NrrdImageIO")
        )
        for name in filenames
    ]
    # A single stack avoids re-copying the growing array on every file.
    return filenames, np.stack(volumes, axis=0)


# Training Set
train_img_folder = 'Data/train_img'
train_img_list, x_train = load_nrrd_stack(train_img_folder)
train_msk_folder = 'Data/train_msk'
train_msk_list, y_train = load_nrrd_stack(train_msk_folder)

# Validation Set
val_img_folder = 'Data/val_img'
val_img_list, x_val = load_nrrd_stack(val_img_folder)
val_msk_folder = 'Data/val_msk'
val_msk_list, y_val = load_nrrd_stack(val_msk_folder)

# Test Set
test_img_folder = 'Data/test_img'
test_img_list, x_test = load_nrrd_stack(test_img_folder)
test_msk_folder = 'Data/test_msk'
test_msk_list, y_test = load_nrrd_stack(test_msk_folder)

# Fail early if any split has a different number of images and masks.
for split_name, images, masks in (('train', x_train, y_train),
                                  ('val', x_val, y_val),
                                  ('test', x_test, y_test)):
    assert len(images) == len(masks), (
        f"{split_name}: {len(images)} images but {len(masks)} masks"
    )

In [None]:
# Make the dataset compatible with the model.
# Each stack from the loading cell is (n_cases, a0, a1, a2), where (a0, a1, a2)
# is the per-volume array from sitk.GetArrayFromImage. Axis 1 is moved to the
# end, giving (n_cases, a1, a2, a0).
# NOTE(review): assumes vnet/unet_3d expect the first array axis of each volume
# last — confirm against the input shape printed by my_model.summary().
# NOTE: this cell rebinds its own inputs, so running it twice would transpose
# twice; re-run the loading cell before re-executing it.


def to_model_layout(volumes, dtype=np.float32):
    """Move axis 1 of a 4-D volume stack to the end and cast it.

    Parameters
    ----------
    volumes : numpy.ndarray
        4-D array of stacked volumes, indexed (case, a0, a1, a2).
    dtype : numpy dtype, optional
        Output dtype. Defaults to float32, the Keras default float type;
        float64 would only double host memory for these 3-D volumes.

    Returns
    -------
    numpy.ndarray
        Array indexed (case, a1, a2, a0) with the requested dtype.
    """
    return np.transpose(volumes, (0, 2, 3, 1)).astype(dtype)


x_train = to_model_layout(x_train)
y_train = to_model_layout(y_train)
x_val = to_model_layout(x_val)
y_val = to_model_layout(y_val)
x_test = to_model_layout(x_test)
y_test = to_model_layout(y_test)

## Model Training

In [None]:
# Setting Model
# Assign a different segmentation_models loss to LOSS_FN to train with
# another objective; everything else in the cell stays the same.
LOSS_FN = dice_loss
my_model = vnet(loss=LOSS_FN)
my_model.summary()

In [None]:
# Training: one volume per step, 50 passes over the training set, with the
# validation split scored during training and the per-epoch log kept in `history`.
history = my_model.fit(
    x_train,
    y_train,
    epochs=50,
    batch_size=1,
    validation_data=(x_val, y_val),
    verbose=1,
)

In [None]:
# Saving the trained model to disk.
# NOTE(review): reloading this file will presumably need `custom_objects` for
# the segmentation_models loss/metrics — confirm when loading it elsewhere.
model_path = 'my_model.h5'
my_model.save(model_path)

In [None]:
# Evaluate the trained model on the held-out test set; the loss/metric values
# are kept in `test_scores` and shown as the cell output.
test_scores = my_model.evaluate(x_test, y_test, verbose=1, batch_size=1)
test_scores

## Plotting History

In [None]:
# Plotting history of loss function (training in red, validation in green)
fig, ax = plt.subplots()
ax.plot(history.history['loss'], color='r', label='train_loss')
ax.plot(history.history['val_loss'], color='g', label='val_loss')
ax.set_title('Model Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend(loc='upper left')
plt.show()

In [None]:
# Plotting history of the IoU metric (training in blue, validation in black).
# The curves are IoU, not accuracy, so title and legend say IoU.
# NOTE(review): the 'iou_score' / 'val_iou_score' keys assume the model was
# compiled with the segmentation_models IoU metric — check history.history.keys().
fig, ax = plt.subplots()
ax.plot(history.history['iou_score'], color='b', label='train_iou')
ax.plot(history.history['val_iou_score'], color='k', label='val_iou')
ax.set_title('Model IoU')
ax.set_xlabel('Epoch')
ax.set_ylabel('IoU')
ax.legend(loc='upper left')
plt.show()