# Training the U-Net

We have defined the data generator, the architecture, the loss function and the metric we are going to use. Now, we put all the pieces together to train our U-Net. First, we import the libraries we need for the training.

In [3]:
import json
import os

# These environment variables only take effect reliably if they are set before
# TensorFlow initializes its native runtime / CUDA, so set them before the
# tensorflow import rather than after it.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # quiet TensorFlow's C++ logging ('3' = most restrictive level)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # make only GPU 0 visible to this process

from glob import glob

import numpy as np
import tensorflow as tf

from tensorflow.keras import initializers, regularizers, constraints, optimizers, metrics
from tensorflow.keras import backend as K
from tensorflow.keras.activations import softmax
from tensorflow.keras.layers import Dense, Input, Conv2D, Conv2DTranspose, Dropout, Flatten, BatchNormalization, Concatenate, Lambda, Activation, Reshape, Layer, InputSpec
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.utils import Sequence

# Project-local modules: data generator, loss/metric definitions, U-Net model.
import Data_Generator.DataGenerator as DG
import Loss_Metrics.LossMetrics
import UNet_Arch.UNET_architecture as unet

# Reset Keras global state left over from earlier runs in this kernel.
tf.keras.backend.clear_session()

We choose the image size, download the data set if it is not already present, and set the path to the images.

In [9]:
dim = 128  # Available dimensions: 128, 256, 512

# Allow for overriding in tests. Environment variables are always strings, so
# cast to int: both the `:d` format spec and the `data_url` lookup below need
# an integer (a string value would raise ValueError / KeyError).
dim = int(os.environ.get("UNET_SHEZEN_DIM", dim))

# Archive location (relative to https://pandora.infn.it/public/) for each image size.
data_url = {128: '0e1e5d/dl/Shezen_128128.zip', 256: '3a4a44/dl/Shezen_256256.zip', 512: '80a4eb/dl/Shezen_512512.zip'}
if dim not in data_url:
    raise ValueError(f"Unsupported dim={dim}; available dimensions: {sorted(data_url)}")

if not os.path.exists(f"./Shezen_{dim:d}{dim:d}"):
    # Download and extract in pure Python so that a failed download raises
    # immediately instead of silently continuing with an empty data set.
    import urllib.request
    import zipfile

    archive = "tmp.zip"
    urllib.request.urlretrieve(f"https://pandora.infn.it/public/{data_url[dim]}", archive)
    try:
        with zipfile.ZipFile(archive) as zf:
            zf.extractall()  # extract into the current directory, like `unzip`
    finally:
        os.remove(archive)

path = './Shezen_{:d}{:d}/'.format(dim, dim)  # input images path
data_path = path + '*'

Now that the data is downloaded, we split it into training, validation and test sets. Please leave the data partition as written so that each group tests its architecture on the same test set.

In the following cell you have to set up the data generators for both the training and the validation sets.

In [5]:
# glob() returns entries in arbitrary, filesystem-dependent order. Sort the
# names so the train/val/test partition below (and therefore the shared test
# set) is identical on every machine.
data_list = sorted(os.path.basename(f) for f in glob(data_path))
assert data_list, f"No images found in {data_path!r}; did the download succeed?"

# Fixed partition: first 500 files for training, next 90 for validation,
# the remainder for testing. Keep it as is so every group uses the same test set.
train_list = data_list[0:500]
val_list = data_list[500:590]
test_list = data_list[590:]

training_gen = DG.DataGenerator(path=path, list_X=train_list, batch_size=8, dim=(dim, dim))
validation_gen = DG.DataGenerator(path=path, list_X=val_list, batch_size=1, dim=(dim, dim))

# Callbacks for training

Keras provides built-in callbacks that help to train the network.

`ModelCheckpoint` saves the weights every time the validation metric reaches a new best value. You need to create a folder named `results`. You can read the documentation [here](https://keras.io/api/callbacks/model_checkpoint/).

ReduceLROnPlateau is used to decrease the learning rate during training. You can read the documentation [here](https://keras.io/api/callbacks/reduce_lr_on_plateau/).

We also define the optimizer.

In [6]:
# Weights are written only when val_DSC improves (save_best_only=True, mode='max').
# The epoch, training loss and val_DSC are embedded in each file name.
checkpoint_path_val = "results/weights-val-{epoch:03d}-{loss:.2f}-{val_DSC:.2f}.hdf5"  # intermediate weight save path

check_val = ModelCheckpoint(filepath=checkpoint_path_val, monitor='val_DSC', verbose=0, save_best_only=True, save_weights_only=True, mode='max')

# DSC is a score to maximize, so mode must be 'max' (as for the checkpoint).
# With the default mode='auto', Keras infers 'max' only for monitors whose name
# contains 'acc'. It would treat val_DSC as a quantity to minimize and halve
# the learning rate while the model is still improving.
reduce_lr = ReduceLROnPlateau(monitor='val_DSC', factor=0.5, patience=8, min_lr=0.000005, mode='max', verbose=True)

adamlr = optimizers.Adam(learning_rate=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, amsgrad=True)

Build the U-Net architecture, then compile and train it.

In [8]:
unet_test = unet.U_net((dim, dim, 1))  # single-channel dim x dim input
unet_test.compile(loss=Loss_Metrics.LossMetrics.DSC_loss, optimizer=adamlr, metrics=[Loss_Metrics.LossMetrics.DSC])
unet_test.summary()
history = unet_test.fit(training_gen, validation_data=validation_gen, epochs=50, callbacks=[reduce_lr, check_val], verbose=1)

# Save the per-epoch metrics as a real JSON object (metric name -> list of values)
# so it can be reloaded with json.load. Values may be NumPy scalars (e.g. a
# learning-rate entry logged by ReduceLROnPlateau), which json cannot serialize,
# so cast each one to float.
with open('history.json', 'w') as f:
    json.dump({metric: [float(v) for v in values] for metric, values in history.history.items()}, f, indent=2)


Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 128, 128, 1  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_28 (Conv2D)             (None, 128, 128, 8)  16          ['input_2[0][0]']                
                                                                                                  
 batch_normalization_32 (BatchN  (None, 128, 128, 8)  32         ['conv2d_28[0][0]']              
 ormalization)                                                                                    
                                                                                            