In [None]:
import time
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import pandas as pd

In [None]:
# NOTE: relative chdir — re-running this cell from inside 'Segmentation' will normally fail; restart the kernel first.
os.chdir('Segmentation')

# Output folders for trained models, learning curves, saved predictions and example figures.
# exist_ok=True keeps the cell re-runnable (os.mkdir raises FileExistsError on a second run).
for output_dir in ('Models', 'Learning_curves', 'Predictions', 'Examples'):
    os.makedirs(output_dir, exist_ok=True)

# Load data

In [None]:
# Images and Masks are located in folders called 'Images' and 'Masks', respectively, and have the same names.
from PIL import Image


def _read_image(path):
    """Read one image file into a NumPy array, closing the file handle afterwards."""
    with Image.open(path) as im:
        return np.array(im)


# Pair each image with the mask of the same file name. os.listdir returns entries in arbitrary order, and
# that order is not guaranteed to match between the two folders. So list one folder (sorted, for a
# reproducible order) and look each mask up by name instead of listing 'Masks' independently.
image_files = sorted(os.listdir('Images'))
images = np.asarray([_read_image(os.path.join('Images', f)) for f in image_files])
masks = np.asarray([_read_image(os.path.join('Masks', f)) for f in image_files])
# NOTE(review): masks are used as-is as training targets and compared against {0, 1} predictions by the
# metrics; presumably they are stored as 0/1 label images — confirm (e.g. 0/255 PNGs would need scaling).

In [None]:
_, w, h = np.shape(images)  # images must be a 3-D stack (N, w, h); the unpacking fails otherwise

In [None]:
## Normalize images to [0, 1] with a single global min-max scaling over the whole dataset
M = images.max()
m = images.min()

value_range = M - m
if value_range == 0:
    # Constant input: (images - m) / 0 would produce NaNs; map everything to 0 instead.
    images = np.zeros(images.shape, dtype=float)
else:
    images = (images - m) / value_range

In [None]:
# Seed NumPy and TensorFlow so the shuffle below and TensorFlow's random ops are reproducible
seed = 15
np.random.seed(seed)
tf.random.set_seed(seed)

# Shuffle images and masks with one shared permutation so every image stays paired with its mask
p = np.random.permutation(len(images))
images, masks = images[p], masks[p]

# 80/20 train/test split; the split index is rounded to the nearest integer
p = round(0.8 * len(images))
images_train, images_test = images[:p], images[p:]
masks_train, masks_test = masks[:p], masks[p:]

In [None]:
## Convert into 3-channel images: the U-Net input is (w, h, 3), so the single grayscale channel is
## replicated three times along a new last axis.
images_train_3ch = np.stack([images_train] * 3, axis=-1)
images_test_3ch = np.stack([images_test] * 3, axis=-1)

In [None]:
def display(display_list,title_list):
    """Show a grid of images with masks overlaid.

    Each entry of ``display_list`` is a sequence ``[image, mask_1, ..., mask_L]``, where L is taken from
    the first entry. Row i of the grid draws ``image`` in gray with each of its masks overlaid at
    alpha 0.5, one mask per column. ``title_list[j]`` is used as the title of column j on the first row only.

    NOTE(review): this name shadows IPython's built-in ``display`` in the notebook namespace.
    """
    n_rows = len(display_list)
    n_cols = len(display_list[0]) - 1
    plt.figure(figsize=(15, 5*n_rows))

    for row, (img, *row_masks) in enumerate(display_list):
        for col in range(n_cols):
            plt.subplot(n_rows, n_cols, row*n_cols + col + 1)
            plt.imshow(img, 'gray')
            plt.imshow(row_masks[col], alpha=0.5)
            plt.axis('off')
            if row == 0:
                plt.title(title_list[col])

In [None]:
# Pick up to 10 *distinct* random test images; sampling without replacement prevents duplicates.
# They are reused below to visualize ground truth and later each model's predictions.
n_examples = min(10, len(images_test))
idx = np.random.choice(len(images_test), size=n_examples, replace=False)
display_list = [[images_test[i],masks_test[i]] for i in idx]
title_list = ['Ground truth']
display(display_list,title_list)

img_list = [images_test[i] for i in idx]
gt_list = [masks_test[i] for i in idx]

# Model

In [None]:
from tensorflow.keras.applications import densenet, inception_v3, mobilenet, mobilenet_v2, vgg16, vgg19, resnet
from tensorflow.keras.layers import Conv2D, BatchNormalization, Conv2DTranspose, MaxPooling2D, AveragePooling2D, DepthwiseConv2D, Cropping2D, ZeroPadding2D, Concatenate # UpSampling2D
from tensorflow.keras import Model
from tensorflow.keras.optimizers import SGD, Adam, RMSprop, schedules
from tensorflow.keras import Input

In [None]:
# Short encoder name -> 'module.Class' path of its tensorflow.keras.applications constructor
# (the module names are the ones imported above; UNet_CNN turns the path into a model).
CNN_dict = dict([
    ('densenet121', 'densenet.DenseNet121'),
    ('densenet169', 'densenet.DenseNet169'),
    ('densenet201', 'densenet.DenseNet201'),
    ('mobilenet', 'mobilenet.MobileNet'),
    ('mobilenet_v2', 'mobilenet_v2.MobileNetV2'),
    ('vgg16', 'vgg16.VGG16'),
    ('vgg19', 'vgg19.VGG19'),
    ('resnet50', 'resnet.ResNet50'),
    ('resnet101', 'resnet.ResNet101'),
    ('resnet152', 'resnet.ResNet152'),
    ('inception_v3', 'inception_v3.InceptionV3'),
])

In [None]:
def detect_downsampling_layers(encoder):
    """Collect the inputs of the spatial-downsampling layers of a backbone model.

    The returned tensors are used as U-Net skip connections by ``UNet_CNN``. A layer counts as
    downsampling if it is a ``MaxPooling2D``, or an ``AveragePooling2D`` / ``Conv2D`` /
    ``DepthwiseConv2D`` whose first stride component is > 1.

    Args:
        encoder: a Keras model (in this notebook, the backbone built by ``UNet_CNN``).

    Returns:
        list of tensors: ``layer.input`` of each selected layer, in ``encoder.layers`` order.

    NOTE(review): the '_block' branch relies on Keras application layer naming; presumably names like
    'conv3_block1_1_conv' (ResNet), where the character at index 4 is the stage digit — confirm before
    using this with other backbones.
    """
    downsampling_layers = []
    # Sentinel: its character at index 4 is '-', which presumably never equals a real stage digit, so the
    # first strided '_block' layer encountered is recorded.
    last_block = 'conv-_block-'
    for layer in encoder.layers:
        # Downsampling test: any max-pool, or an average-pool / conv / depthwise conv with stride > 1
        # (only strides[0] is checked).
        if isinstance(layer, tf.keras.layers.MaxPooling2D) \
           or (isinstance(layer, tf.keras.layers.AveragePooling2D) and layer.strides[0] > 1) \
           or (isinstance(layer, tf.keras.layers.Conv2D) and layer.strides[0] > 1) \
           or (isinstance(layer, tf.keras.layers.DepthwiseConv2D) and layer.strides[0] > 1):
            
            if '_block' not in layer.name:
                # Densenet, VGG and MobileNet models
                # Layers outside a named '_block' are always recorded.
                downsampling_layers.append(layer.input)
                
            else:
                # Inside '_block' layers, record only the first strided layer of each stage: a stage change
                # is detected when the 5th character of the name differs from that of the previous strided
                # '_block' layer (see NOTE in the docstring).
                if last_block[4] != layer.name[4]:
                    # ResNet models
                    downsampling_layers.append(layer.input)
                # Updated for every strided '_block' layer, whether or not it was recorded.
                last_block = layer.name
    return downsampling_layers

def detect_downsampling_layers_inception(encoder):
    """Skip-connection tensors for InceptionV3-style backbones.

    Selects every ``MaxPooling2D`` layer, plus strided ``Conv2D`` layers whose input still has the full
    input width (``encoder.input.shape[1]``), i.e. the first strided convolution. Returns the
    ``layer.input`` of each selected layer, in ``encoder.layers`` order.
    """
    def _is_downsampling(layer):
        if isinstance(layer, tf.keras.layers.MaxPooling2D):
            return True
        return (isinstance(layer, tf.keras.layers.Conv2D)
                and layer.strides[0] > 1
                and layer.input.shape[1] == encoder.input.shape[1])

    return [layer.input for layer in encoder.layers if _is_downsampling(layer)]

In [None]:
def reshape_var(var,out_shape):
    """Center-crop and/or zero-pad a feature map to a target spatial size.

    Parameters
    ----------
    var : Keras tensor
        Feature map laid out as (batch, width, height, channels), matching this notebook's
        ``Input(shape=(img_width, img_height, 3))``. Axis 1 is matched to ``out_shape[0]`` and
        axis 2 to ``out_shape[1]``.
    out_shape : tuple of int
        Target (width, height).

    Returns
    -------
    Keras tensor with spatial size ``out_shape``. When a size difference is odd, the extra row/column is
    cropped from (or padded on) the leading side. An axis that already matches is left untouched, and
    no Cropping2D / ZeroPadding2D layer is added when it would be a no-op.
    """
    def _crop_pad(diff):
        # (crop, pad) amounts for one axis, where diff = current size - target size.
        if diff == 0:
            return (0, 0), (0, 0)
        amount = abs(diff)
        split = (amount - amount // 2, amount // 2)  # odd remainder goes to the leading side
        return (split, (0, 0)) if diff > 0 else ((0, 0), split)

    crop_w, pad_w = _crop_pad(var.shape[1] - out_shape[0])
    crop_h, pad_h = _crop_pad(var.shape[2] - out_shape[1])

    if crop_w != (0, 0) or crop_h != (0, 0):
        var = Cropping2D(cropping=(crop_w, crop_h))(var)
    if pad_w != (0, 0) or pad_h != (0, 0):
        var = ZeroPadding2D(padding=(pad_w, pad_h))(var)

    return var

In [None]:
# Target spatial sizes for the decoder stages, coarsest first, ending with the full image size (w, h):
# keep halving (integer division) while the width is still at least 7.
skip_shapes = []
w_, h_ = w, h
while w_ >= 7:
    skip_shapes.insert(0, (w_, h_))
    w_, h_ = w_ // 2, h_ // 2

In [None]:
def UNet_CNN(model_handle, img_width, img_height, skip_shapes, num_classes):
    """Build a U-Net whose encoder is a tensorflow.keras.applications backbone.

    Parameters
    ----------
    model_handle : str
        Dotted constructor path such as ``'resnet.ResNet50'`` (the values of ``CNN_dict``). The first
        component is looked up in this notebook's global namespace and the rest via ``getattr``. The
        constructor is called with ``include_top=False``, so its default (pretrained) weights are used.
    img_width, img_height : int
        Spatial size of the (img_width, img_height, 3) input.
    skip_shapes : sequence of (int, int)
        Target spatial sizes ordered from coarsest to finest (as built from (w, h) by repeated halving).
        Decoder stage i (counting from the bottleneck) is resized to
        ``skip_shapes[len(skip_shapes) - n_skips + i]``, where ``n_skips`` is the number of downsampling
        layers detected in the encoder. Only the last ``n_skips`` entries are used, so alignment is
        correct for any input size, not only for sizes that yield exactly ``n_skips + 1`` entries.
    num_classes : int
        Number of output channels; each gets a sigmoid activation.

    Returns
    -------
    keras Model whose output has spatial size ``skip_shapes[-1]`` and ``num_classes`` channels. The
    encoder is wrapped as a nested sub-model named 'Encoder'; ``finetune`` relies on that name.

    Raises
    ------
    ValueError
        If no downsampling layers are detected, or ``skip_shapes`` has fewer than ``n_skips`` entries.
    """
    # Resolve the constructor by attribute lookup instead of eval(). This gives the same result for the
    # 'module.Class' strings in CNN_dict, but a malformed string can no longer execute arbitrary code.
    module_name, *attr_names = model_handle.split('.')
    constructor = globals()[module_name]
    for attr_name in attr_names:
        constructor = getattr(constructor, attr_name)
    base_model = constructor(include_top=False, input_shape=(img_width, img_height, 3))

    # Skip connections = inputs of the encoder's downsampling layers, plus the final encoder output
    if 'inception' not in model_handle:
        skip_connections = detect_downsampling_layers(base_model)
    else:
        skip_connections = detect_downsampling_layers_inception(base_model)

    n_skips = len(skip_connections)
    if n_skips == 0:
        raise ValueError('No downsampling layers detected in {}'.format(model_handle))
    # Index of the skip_shapes entry that matches the first (coarsest) decoder stage
    offset = len(skip_shapes) - n_skips
    if offset < 0:
        raise ValueError('skip_shapes has {} entries but {} has {} downsampling levels'
                         .format(len(skip_shapes), model_handle, n_skips))

    skip_connections.append(base_model.layers[-1].output)
    encoder = Model(base_model.input, skip_connections, name='Encoder')

    # Create the decoder part of the U-Net
    img_in = Input(shape=(img_width, img_height, 3))
    encoder_outputs = encoder(img_in)
    decoder = encoder_outputs[-1]

    for i, skip in enumerate(reversed(encoder_outputs[:-1])):
        target_shape = skip_shapes[offset + i]
        num_filters = skip.shape[-1]

        # upsample x2
        decoder = Conv2DTranspose(num_filters,(2, 2),activation='relu',padding='same',strides=(2, 2))(decoder)

        # Align both branches to this stage's target size (odd sizes, padding layers in the backbone)
        skip = reshape_var(skip, target_shape)
        decoder = reshape_var(decoder, target_shape)

        # concatenate
        decoder = Concatenate()([decoder, skip])

        # convolution + batch normalization
        decoder = Conv2D(num_filters,activation='relu',kernel_size=3,strides=1,padding='same',use_bias=True)(decoder)
        decoder = BatchNormalization()(decoder)

        if i < n_skips - 1:
            # convolution + batch normalization
            decoder = Conv2D(num_filters,activation='relu',kernel_size=3,strides=1,padding='same',use_bias=True)(decoder)
            decoder = BatchNormalization()(decoder)
        else:
            # Final segmentation (output) layer
            decoder = Conv2D(num_classes,activation='sigmoid',kernel_size=3,strides=1,padding='same',use_bias=True)(decoder)

    # Create the U-Net model
    return Model(img_in, decoder)

# Metrics

In [None]:
## Accuracy
def Accuracy(labels,preds,threshold=0.5,smooth=1e-6):
    """Pixel accuracy of thresholded predictions over the whole batch.

    ``preds`` is binarized (``>= threshold`` -> 1.0, otherwise 0.0) and compared element-wise with
    ``labels``. The result is the fraction of equal elements. ``smooth`` is accepted for signature
    symmetry with the other metrics but is unused.
    """
    binary_preds = tf.where(preds >= threshold, 1.0, 0.0)
    n_correct = tf.reduce_sum(tf.cast(tf.equal(binary_preds, labels), tf.float32))
    n_total = tf.cast(tf.reduce_prod(tf.shape(labels)), tf.float32)
    return n_correct / n_total


## Jaccard index (Intersection over Union, IoU)
def jaccard(labels,preds,smooth=1e-6):
    """Soft Jaccard index (IoU) computed over all elements of the batch at once.

    With 0/1 inputs this is |A ∩ B| / |A ∪ B|. ``smooth`` is added to numerator and denominator to avoid
    0/0; two empty masks score 1.
    """
    overlap = tf.reduce_sum(labels * preds)
    total = tf.reduce_sum(labels) + tf.reduce_sum(preds)
    return (overlap + smooth) / (total - overlap + smooth)

def Jaccard(labels,preds,threshold=0.5,smooth=1e-6):
    """Jaccard index (IoU) after binarizing ``preds`` at ``threshold`` (``>=`` -> 1.0, else 0.0)."""
    return jaccard(labels, tf.where(preds >= threshold, 1.0, 0.0), smooth)


## Dice coefficient
def diceCoeff(labels,preds,smooth=1e-6):
    """Soft Dice coefficient computed over all elements of the batch at once.

    With 0/1 inputs this is 2 |A ∩ B| / (|A| + |B|). ``smooth`` is added to numerator and denominator to
    avoid 0/0; two empty masks score 1. Used unthresholded by ``DiceLoss``.
    """
    overlap = tf.reduce_sum(labels * preds)
    cardinality = tf.reduce_sum(labels) + tf.reduce_sum(preds)
    return (2. * overlap + smooth) / (cardinality + smooth)

def DiceCoeff(labels,preds,threshold=0.5,smooth=1e-6):
    """Dice coefficient after binarizing ``preds`` at ``threshold`` (``>=`` -> 1.0, else 0.0)."""
    return diceCoeff(labels, tf.where(preds >= threshold, 1.0, 0.0), smooth)

# Loss function

In [None]:
def DiceLoss(labels,preds,smooth=1e-6):
    """Soft Dice loss, ``1 - diceCoeff``, evaluated on the raw (unthresholded) predictions."""
    return 1 - diceCoeff(labels, preds, smooth=smooth)

# Fine-tuning loop

In [None]:
def get_optimizer(optimizer_type,lr):
    """Create a new Keras optimizer instance by name.

    Parameters
    ----------
    optimizer_type : str
        One of 'Adam', 'SGD', 'RMSprop' (case-sensitive).
    lr : float
        Learning rate, passed as ``learning_rate``.

    Returns
    -------
    A fresh optimizer instance (a new one per call, so each compile starts with fresh optimizer state).

    Raises
    ------
    ValueError
        For an unknown ``optimizer_type``. Failing here is clearer than returning None and letting the
        problem surface later in ``model.compile``.
    """
    if optimizer_type == 'Adam':
        return Adam(learning_rate=lr)
    if optimizer_type == 'SGD':
        return SGD(learning_rate=lr)
    if optimizer_type == 'RMSprop':
        return RMSprop(learning_rate=lr)
    raise ValueError("Unknown optimizer {!r}; expected 'Adam', 'SGD' or 'RMSprop'".format(optimizer_type))

In [None]:
def _set_encoder_trainable(model, perc_unfreeze):
    """Make every non-encoder layer trainable and unfreeze only the top `perc_unfreeze` fraction of layers
    inside the sub-model named 'Encoder' (0 freezes the whole encoder).

    NOTE(review): unfrozen BatchNormalization layers in the encoder will also update their moving
    statistics; the Keras transfer-learning guide recommends keeping BN in inference mode — consider it.
    """
    for layer in model.layers:
        if layer.name == 'Encoder':
            encoder_layers = layer.layers
            n_unfreeze = round(len(encoder_layers) * perc_unfreeze)
            # `depth` counts from the output end, so exactly the last n_unfreeze layers become trainable.
            for depth, sub_layer in enumerate(reversed(encoder_layers)):
                sub_layer.trainable = depth < n_unfreeze
        else:
            layer.trainable = True


def _fit_stage(model, optimizer, criterion, train_images, train_targets, val_images, val_targets,
               batch_size, epochs):
    """Compile `model` and train it for `epochs` epochs. Return per-epoch (train_curves, val_curves), each
    ``[loss, accuracy, jaccard, dice_coeff]``. Curves are empty when `epochs` <= 0 (no fit is run)."""
    # Recompile so the current trainable flags and the fresh optimizer take effect.
    model.compile(optimizer=optimizer, loss=criterion, metrics=[Accuracy,Jaccard,DiceCoeff])
    metric_keys = ('loss', 'accuracy', 'jaccard', 'dice_coeff')
    if epochs <= 0:
        return [[] for _ in metric_keys], [[] for _ in metric_keys]
    history = model.fit(train_images, train_targets,
                        validation_data=(val_images, val_targets),
                        batch_size=batch_size, epochs=epochs, verbose=0).history
    train_curves = [list(history[key]) for key in metric_keys]
    val_curves = [list(history['val_' + key]) for key in metric_keys]
    return train_curves, val_curves


def finetune(model,lr_freeze,lr_unfreeze,optim,criterion,train_images,train_masks,val_images,val_masks,
             num_epochs=30,num_epochs_freeze=15,batch_size=8,perc_unfreeze=0.2):
    """Two-stage transfer-learning schedule for a model built by ``UNet_CNN``.

    Stage 1 trains for ``num_epochs_freeze`` epochs using ``lr_freeze``, with every layer of the nested
    'Encoder' sub-model frozen, so only the decoder learns.
    Stage 2 trains for the remaining ``num_epochs - num_epochs_freeze`` epochs using ``lr_unfreeze``,
    with the top ``round(perc_unfreeze * n_encoder_layers)`` encoder layers unfrozen.
    Each stage recompiles with a fresh optimizer from ``get_optimizer(optim, lr)``. A stage with <= 0
    epochs is compiled but not fitted.

    Masks get a trailing channel axis and are cast to float32 before training.

    Returns
    -------
    model : the same model object, trained in place
    train_history, val_history : lists ``[loss, accuracy, jaccard, dice]`` of per-epoch values,
        stage 1 followed by stage 2
    """
    # Keras targets: (..., 1) float32 masks, built once and reused by both stages.
    train_targets = np.expand_dims(train_masks, axis=-1).astype(np.float32)
    val_targets = np.expand_dims(val_masks, axis=-1).astype(np.float32)

    ## Step 1: train randomly initialized weights (decoder only)
    print('Freezing base model...')
    _set_encoder_trainable(model, 0)
    train_1, val_1 = _fit_stage(model, get_optimizer(optim, lr_freeze), criterion, train_images, train_targets,
                                val_images, val_targets, batch_size, num_epochs_freeze)
    print('Step 1 completed.')

    ## Step 2: fine tune the decoder plus the top perc_unfreeze fraction of encoder layers
    num_epochs_unfreeze = num_epochs - num_epochs_freeze
    print('Unfreezing base model...')
    _set_encoder_trainable(model, perc_unfreeze)
    train_2, val_2 = _fit_stage(model, get_optimizer(optim, lr_unfreeze), criterion, train_images, train_targets,
                                val_images, val_targets, batch_size, num_epochs_unfreeze)
    print('Step 2 completed.')

    train_history = [first + second for first, second in zip(train_1, train_2)]
    val_history = [first + second for first, second in zip(val_1, val_2)]

    return model, train_history, val_history

In [None]:
def plot_training(train_history,val_history,num_epochs):
    """Plot train/test learning curves (loss, accuracy, IoU, Dice) on a 2x2 grid.

    Parameters
    ----------
    train_history, val_history : sequence of 4 curves
        Per-epoch ``[loss, accuracy, IoU, Dice]`` values, as returned by ``finetune``.
    num_epochs : int
        Kept for backward compatibility. The x-axis is taken from the length of each curve, so it always
        matches the recorded history (e.g. when a training stage ran for 0 epochs).

    Raises
    ------
    ValueError
        If either history does not contain exactly 4 curves.

    The figure is left open (not shown), so callers can ``plt.savefig`` it first.
    """
    if len(train_history) != 4 or len(val_history) != 4:
        raise ValueError('expected [loss, accuracy, IoU, Dice] curves for both histories')

    font_title = 12
    font_legend = 10
    # (metric name used in labels/titles, legend location) for each panel, in history order
    panels = [('Loss', 'upper right'), ('Accuracy', 'lower right'), ('IoU', 'lower right'), ('Dice', 'lower right')]

    plt.figure(figsize=(10, 10))
    for panel, ((name, legend_loc), train_curve, val_curve) in enumerate(zip(panels, train_history, val_history)):
        plt.subplot(2, 2, panel + 1)
        plt.plot(range(len(train_curve)), train_curve, label='Training ' + name)
        plt.plot(range(len(val_curve)), val_curve, label='Test ' + name)
        plt.legend(loc=legend_loc, fontsize=font_legend)
        plt.title('Train and Test ' + name, fontsize=font_title)

In [None]:
def plot_examples(model,img_list,gt_list,threshold=0.5):
    """Plot ground-truth and predicted masks side by side for a few images.

    Parameters
    ----------
    model : keras Model
        Segmentation network taking 3-channel input; each grayscale image is replicated 3 times.
    img_list : list of 2-D arrays (all the same size)
        Grayscale images, drawn in gray with the mask overlaid at alpha 0.5.
    gt_list : list of arrays
        Ground-truth masks, one per image.
    threshold : float
        Cut-off for single-channel (sigmoid) outputs: ``prob >= threshold`` -> 1. This matches the
        thresholding of the Accuracy/Jaccard/DiceCoeff metrics. Multi-channel outputs use argmax.

    The figure is left open so the caller can save it before ``plt.show()``.
    """
    N = len(img_list)
    if N == 0:
        return

    # One batched predict call instead of one call per image.
    batch = np.repeat(np.stack(img_list)[..., np.newaxis], 3, axis=-1)
    probs = model.predict(batch)
    if probs.shape[-1] == 1:
        pred_masks = (probs[..., 0] >= threshold).astype(np.uint8)
    else:
        pred_masks = np.argmax(probs, axis=-1)

    plt.figure(figsize=(12, 5*N))
    for i, (img, gt, pred) in enumerate(zip(img_list, gt_list, pred_masks)):
        # Display ground truth
        plt.subplot(N, 2, 2*i+1)
        plt.imshow(img, 'gray')
        plt.imshow(gt, alpha=0.5)
        plt.axis('off')
        if i == 0:
            plt.title('Ground truth', fontsize=20)

        # Display predicted mask
        plt.subplot(N, 2, 2*i+2)
        plt.imshow(img, 'gray')
        plt.imshow(pred, alpha=0.5)
        plt.axis('off')
        if i == 0:
            plt.title('Predicted', fontsize=20)

In [None]:
## Hyperparameters
# Training: optimizer name (one of 'Adam', 'SGD', 'RMSprop' — see get_optimizer), batch size, total epochs
optim, batch_size, num_epochs = 'Adam', 8, 30

# Fine-tuning: epochs trained with the encoder frozen, and the fraction of encoder layers
# (counted from the top) that are unfrozen for the remaining num_epochs - num_epochs_freeze epochs
num_epochs_freeze = num_epochs // 2
perc_unfreeze = 0.2

# Number of output channels of the segmentation head
num_classes = 1

# Classification threshold: probability >= threshold -> 1, otherwise 0 (as in the Accuracy/Jaccard/DiceCoeff metrics)
threshold = 0.5

# Loss function (soft Dice loss)
criterion = DiceLoss

# Hyperparameter tuning

In [None]:
# Hyperparameter optimization budget (number of Optuna trials) and k for k-fold cross-validation
n_trials, k = 30, 4

In [None]:
import optuna

In [None]:
def create_crossval_subsets(images,masks,k):
    """Split paired images/masks into k contiguous folds for k-fold cross-validation.

    Fold sizes differ by at most one: the first ``len(images) % k`` folds get one extra sample, so every
    sample is used for validation exactly once and none are left out. When ``len(images)`` is divisible
    by ``k``, the folds are the equal ``len // k`` chunks. No shuffling is done here.

    Parameters
    ----------
    images, masks : arrays indexable along axis 0, with the same length
    k : int, number of folds, ``2 <= k <= len(images)``

    Returns
    -------
    list of k ``[train_images, train_masks, val_images, val_masks]`` lists.

    Raises
    ------
    ValueError
        If ``images`` and ``masks`` differ in length, or ``k`` is out of range.
    """
    n_samples = len(images)
    if len(masks) != n_samples:
        raise ValueError('images and masks must have the same length ({} != {})'.format(n_samples, len(masks)))
    if not 2 <= k <= n_samples:
        raise ValueError('k must be between 2 and the number of samples ({}), got {}'.format(n_samples, k))

    # Fold boundaries: base size everywhere, one extra sample for the first `n_larger` folds.
    base_size, n_larger = divmod(n_samples, k)
    bounds = [0]
    for fold in range(k):
        bounds.append(bounds[-1] + base_size + (1 if fold < n_larger else 0))

    train_val_subsets = []
    for start, stop in zip(bounds[:-1], bounds[1:]):
        val_images = images[start:stop]
        train_images = np.concatenate((images[:start], images[stop:]))

        val_masks = masks[start:stop]
        train_masks = np.concatenate((masks[:start], masks[stop:]))

        train_val_subsets.append([train_images, train_masks, val_images, val_masks])

    return train_val_subsets

def objective(trial, model_handle, images, masks, num_classes, k, optim='Adam', criterion=DiceLoss, num_epochs=100,
              num_epochs_freeze=50, batch_size=8, perc_unfreeze=0.2):
    """Optuna objective: mean last-epoch validation loss of ``finetune`` over k folds.

    Samples ``lr_freeze`` in [1e-5, 1e-2] and ``lr_unfreeze`` in [1e-6, 1e-3], both log-uniform. The
    U-Net is built once per trial (using the notebook-level ``skip_shapes``) and reset to the same
    initial weights before every fold. Remaining arguments are forwarded to ``finetune``.

    Returns
    -------
    float : mean final validation loss across folds (lower is better).
    """
    # Free models/graphs left over from previous trials so memory does not grow across the study.
    tf.keras.backend.clear_session()

    # Initialize model and snapshot its initial weights so every fold starts from the same point
    _,w,h,_ = images.shape
    unet = UNet_CNN(model_handle, w, h, skip_shapes, num_classes)
    init_weights = unet.get_weights()

    ## Hyperparameters: learning rates for the frozen / unfrozen stages
    lr_freeze = trial.suggest_float('lr_freeze', 1e-5, 1e-2, log=True)
    lr_unfreeze = trial.suggest_float('lr_unfreeze', 1e-6, 1e-3, log=True)

    ## Cross-validation
    fold_loss, fold_dice = [], []
    for j, (train_images, train_masks, val_images, val_masks) in enumerate(create_crossval_subsets(images,masks,k)):
        print('Fold {}/{}'.format(j+1, k))

        unet.set_weights(init_weights)
        unet, _, val_history = finetune(unet,lr_freeze,lr_unfreeze,optim,criterion,train_images,train_masks,val_images,val_masks,
                                        num_epochs=num_epochs,num_epochs_freeze=num_epochs_freeze,batch_size=batch_size,perc_unfreeze=perc_unfreeze)

        fold_loss.append(val_history[0][-1])
        fold_dice.append(val_history[3][-1])

    print('Fold Dice coeffs:')
    print(fold_dice)

    return float(np.mean(fold_loss))  # Objective value linked with the Trial object

In [None]:
# One-argument Optuna objective built from the notebook-level settings. Globals such as `model_handle`
# (assigned in the tuning loop) are looked up when the function is called, not when it is defined.
def func(trial):
    return objective(trial, model_handle, images_train_3ch, masks_train, num_classes, k, optim=optim,
                     criterion=criterion, num_epochs=num_epochs, num_epochs_freeze=num_epochs_freeze,
                     batch_size=batch_size, perc_unfreeze=perc_unfreeze)

In [None]:
# Hyperparameter-tuning results table: one row per pretrained encoder, appended by the tuning loop
columns = ('pretrainedCNN nTrials k-fold optimizer lossFunc numEpochsFreeze numEpochsUnfreeze '
           'batchSize percUnfreezeLayers trainTime lrFreeze lrUnfreeze bestMeanLoss').split()
df_hyperparameter_tuning = pd.DataFrame(columns=columns)

In [None]:
def tune_hyperparameters(model_handle,images,masks,num_classes,n_trials=50,k=4,optim='Adam',criterion=DiceLoss,num_epochs=100,num_epochs_freeze=50,
                         batch_size=8,perc_unfreeze=0.2):
    """Search lr_freeze / lr_unfreeze with Optuna, using k-fold cross-validation as the objective.

    Every argument is forwarded to ``objective``, so the study really uses the encoder, data, k and
    training settings passed here rather than notebook globals.

    Returns
    -------
    hyperparameters : [lr_freeze, lr_unfreeze] of the best trial
    best_value : best (lowest) mean validation loss
    trials_df : pandas.DataFrame from ``study.trials_dataframe()``
    """
    # Bind this call's arguments into the objective (no reliance on the global `func`).
    def _objective(trial):
        return objective(trial, model_handle, images, masks, num_classes, k, optim=optim, criterion=criterion,
                         num_epochs=num_epochs, num_epochs_freeze=num_epochs_freeze, batch_size=batch_size,
                         perc_unfreeze=perc_unfreeze)

    # The objective is a validation loss, so lower is better.
    study = optuna.create_study(direction='minimize')
    study.optimize(_objective, n_trials=n_trials)

    # Get best hyperparameter combination
    lr_freeze = study.best_params['lr_freeze']
    lr_unfreeze = study.best_params['lr_unfreeze']
    # plot_parallel_coordinate only builds the figure; show it explicitly so it is not discarded.
    optuna.visualization.plot_parallel_coordinate(study, params=['lr_freeze','lr_unfreeze']).show()

    hyperparameters = [lr_freeze,lr_unfreeze]

    return hyperparameters,study.best_value,study.trials_dataframe()

In [None]:
tuned_lr = {}
for baseModel_name, model_handle in CNN_dict.items():
    print(baseModel_name)
    print()

    # Tune hyperparameters (lr_freeze and lr_unfreeze); pass the configured k so the value recorded
    # in the results table is the one actually used.
    tic = time.time()
    hyperparameters,best_value,trials_df = tune_hyperparameters(model_handle,images_train_3ch,masks_train,num_classes,n_trials,k,optim,criterion,num_epochs,
                                                                num_epochs_freeze,batch_size,perc_unfreeze)
    elapsedTime = time.time() - tic
    minutes, seconds = divmod(elapsedTime, 60)
    print('Hyperparameter tuning took {:.0f} minutes and {:.1f} seconds'.format(minutes, seconds))

    lr_freeze,lr_unfreeze = hyperparameters
    tuned_lr[baseModel_name] = {'lr_freeze' : lr_freeze, 'lr_unfreeze' : lr_unfreeze}

    # trainTime is stored in seconds; the table is saved after every encoder so partial runs are kept
    df_hyperparameter_tuning.loc[len(df_hyperparameter_tuning.index)] = [baseModel_name,n_trials,k,optim,'DiceLoss',num_epochs_freeze,num_epochs-num_epochs_freeze,
                                                                        batch_size,100*perc_unfreeze,elapsedTime,lr_freeze,lr_unfreeze,best_value]
    df_hyperparameter_tuning.to_csv('HyperparamTuning_Results.csv',index=False)
    print(df_hyperparameter_tuning.iloc[-1])
    print()

In [None]:
pd.DataFrame.from_dict(tuned_lr)  # one column per encoder, rows lr_freeze / lr_unfreeze

# Training with tuned hyperparameters

In [None]:
# Reload the tuning results written by the tuning loop, so this section can run without re-tuning
df_hyperparameter_tuning = pd.read_csv('HyperparamTuning_Results.csv')

In [None]:
def train_model(baseModel_name,images_train,masks_train,images_test,masks_test,num_classes,CNN_dict,lr_freeze=0.005,lr_unfreeze=0.0005,optim='Adam',criterion=DiceLoss,
                num_epochs=100,num_epochs_freeze=50,batch_size=8,perc_unfreeze=0.2):
    """Train one U-Net with the given encoder and save its artifacts.

    Builds the model with ``UNet_CNN`` (using the notebook-level ``skip_shapes``). Runs the two-stage
    ``finetune`` schedule with the test split as validation data, then writes:
      * ``Models/<baseModel_name>.h5`` — the trained model,
      * ``Learning_curves/<baseModel_name>.png`` — loss/accuracy/IoU/Dice curves,
      * ``Predictions/<baseModel_name>.npy`` — predicted label maps for the train images followed by the
        test images.

    Returns
    -------
    unet : trained keras Model
    results : final-epoch [train loss, acc, IoU, Dice, test loss, acc, IoU, Dice]
    train_time : wall-clock seconds spent in ``finetune``
    """
    model_handle = CNN_dict[baseModel_name]

    # Initialize model
    _,w,h,_ = images_train.shape
    unet = UNet_CNN(model_handle, w, h, skip_shapes, num_classes)

    # Train the model
    tic = time.time()
    unet, train_history, test_history = finetune(unet,lr_freeze,lr_unfreeze,optim,criterion,images_train,masks_train,images_test,masks_test,
                                                 num_epochs=num_epochs,num_epochs_freeze=num_epochs_freeze,batch_size=batch_size,perc_unfreeze=perc_unfreeze)
    train_time = time.time() - tic

    results = [curve[-1] for curve in train_history] + [curve[-1] for curve in test_history]

    ## Save trained model
    unet.save('Models/' + baseModel_name + '.h5')

    # Plot training process
    plot_training(train_history,test_history,num_epochs)
    plt.savefig('Learning_curves/' + baseModel_name + '.png')
    plt.show()

    # Save predicted masks. The model output is always 4-D (N, w, h, channels), so branch on the channel
    # count. With a single sigmoid channel, binarize with the notebook-level `threshold` (>=, as in the
    # metrics) and drop the channel axis so the array has the masks' (N, w, h) layout; otherwise argmax.
    probs = unet.predict(np.concatenate((images_train,images_test)))
    if probs.shape[-1] == 1:
        preds = (probs[..., 0] >= threshold).astype(int)
    else:
        preds = np.argmax(probs,axis=-1)
    np.save('Predictions/'+baseModel_name+'.npy',preds)

    return unet,results,train_time

In [None]:
# Final results table: one row per pretrained encoder, appended by the training loop below
columns = ('pretrainedCNN optimizer lossFunc numEpochsFreeze numEpochsUnfreeze batchSize percUnfreezeLayers '
           'lrFreeze lrUnfreeze trainTime inferenceTime '
           'trainLoss trainAcc trainIOU trainDice testLoss testAcc testIOU testDice').split()
df_results = pd.DataFrame(columns=columns)

In [None]:
for baseModel_name in CNN_dict.keys():
    print(baseModel_name)
    print()

    # Learning rates found by the hyperparameter search for this encoder
    tuned = df_hyperparameter_tuning.loc[df_hyperparameter_tuning['pretrainedCNN']==baseModel_name].iloc[0]
    lr_freeze = tuned['lrFreeze']
    lr_unfreeze = tuned['lrUnfreeze']

    # Train the model
    model,results,train_time = train_model(baseModel_name,images_train_3ch,masks_train,images_test_3ch,masks_test,num_classes,CNN_dict,lr_freeze,lr_unfreeze,
                                           optim,criterion,num_epochs,num_epochs_freeze,batch_size,perc_unfreeze)

    # Inference time: mean wall-clock seconds of a single-image predict call. One warm-up call comes first,
    # so one-off tracing/compilation is not counted; verbose=0 keeps progress-bar output out of the timing.
    model.predict(images_test_3ch[:1], verbose=0)
    per_image_times = []
    for img in images_test_3ch:
        tic = time.perf_counter()
        model.predict(img[np.newaxis], verbose=0)
        per_image_times.append(time.perf_counter() - tic)
    inference_time = np.mean(per_image_times)

    # Store results (trainTime in seconds, as returned by train_model)
    train_loss,train_acc,train_iou,train_dice,test_loss,test_acc,test_iou,test_dice = results
    df_results.loc[len(df_results.index)] = [baseModel_name,optim,'DiceLoss',num_epochs_freeze,num_epochs-num_epochs_freeze,batch_size,100*perc_unfreeze,lr_freeze,
                                            lr_unfreeze,train_time,inference_time,train_loss,train_acc,train_iou,train_dice,test_loss,test_acc,test_iou,test_dice]
    df_results.to_csv('Results.csv',index=False)
    print(df_results.iloc[-1])

    # Plot some examples
    plot_examples(model,img_list,gt_list, threshold)
    plt.savefig('Examples/' + baseModel_name + '.png')
    plt.show()

    print()