# Train - Validate - Evaluate - Threshold study
This notebook focuses on the binary model. At the end of training, the model's performance is computed for multiple threshold values.

In [None]:
import torch, torchvision
import os
import random
import datasets
import metrics
import time

import constants as cst
import numpy as np
import torch.nn as nn
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch.nn.functional as F

import torchgeometry.losses as loss_fn
from unet import UNET
import utils

In [None]:
def predict_img(model, image, device, transform, out_threshold=0.5):
    """Return a boolean mask of the pixels predicted as class 1.

    The image is moved to `device` and passed through `model`; the raw
    output is then passed through `transform` before a softmax over
    dimension 1 turns it into per-class probabilities.

    Args:
        model: callable producing the raw class scores for `image`.
        image: input tensor; the leading dimension is squeezed away after
            the forward pass, so a single-image batch is expected.
        device: device on which the forward pass runs.
        transform: callable applied to the model output before the softmax.
        out_threshold: probability above which a pixel is marked True.

    Returns:
        NumPy boolean array, True where the probability of class index 1
        is strictly greater than `out_threshold`.
    """
    to_probabilities = nn.Softmax(dim=1)
    with torch.no_grad():
        scores = transform(model(image.to(device)))
        class_probs = to_probabilities(scores)
    # Drop the batch dimension and keep only the channel of class index 1
    foreground = class_probs.detach().cpu().squeeze(0).numpy()[1, :, :]
    return foreground > out_threshold

# One validation step
def validate(model, validation_loader, transform, DEVICE, criterion=None):
    """Return the mean loss of `model` over one pass of `validation_loader`.

    Args:
        model: network to evaluate; it is switched to eval mode.
        validation_loader: iterable yielding (images, masks, names) batches.
        transform: callable applied to both the images and the masks.
        DEVICE: device on which the forward pass and the loss are computed.
        criterion: loss callable taking (outputs, masks). When omitted, the
            module-level `criterion` is used, which is what callers passing
            only four arguments rely on.

    Returns:
        The NumPy mean of the per-batch loss values.

    Raises:
        KeyError: if `criterion` is omitted and no module-level `criterion`
            has been defined yet.
    """
    if criterion is None:
        # Backward compatibility: this function used to read the loss from
        # the enclosing module instead of receiving it as an argument.
        criterion = globals()["criterion"]

    model.eval()
    batch_losses = []
    with torch.no_grad():
        for images, masks, _names in validation_loader:
            outputs = model(transform(images).to(DEVICE))

            # Masks go through the same transform as the images, then lose
            # dimension 1 (when it has size 1) before reaching the loss.
            targets = transform(masks.type(torch.LongTensor))
            targets = torch.squeeze(targets, 1)

            batch_losses.append(criterion(outputs, targets.to(DEVICE)).detach().item())

    return np.mean(batch_losses)

# One training step
def train(model, training_loader, transforms, DEVICE, criterion, optimiser):
    """Run one optimisation pass over `training_loader` and return the mean loss.

    Args:
        model: network to train; it is switched to train mode.
        training_loader: iterable yielding (images, masks, names) batches.
        transforms: callable applied to both the images and the masks.
            This parameter used to be ignored in favour of the module-level
            `transform`; to stay compatible with callers that pass a
            non-callable value here (e.g. the `torchvision.transforms`
            module), such a value still falls back to the module-level
            `transform`.
        DEVICE: device on which the forward pass and the loss are computed.
        criterion: loss callable taking (outputs, masks).
        optimiser: optimiser stepped once per batch.

    Returns:
        The NumPy mean of the per-batch loss values, each recorded before
        the corresponding optimiser step.
    """
    apply_transform = transforms if callable(transforms) else transform

    model.train()
    batch_losses = []
    for images, masks, _names in training_loader:
        outputs = model(apply_transform(images).to(DEVICE))

        # Masks go through the same transform as the images, then lose
        # dimension 1 (when it has size 1) before reaching the loss.
        targets = apply_transform(masks.type(torch.LongTensor))
        targets = torch.squeeze(targets, 1)

        tloss = criterion(outputs, targets.to(DEVICE))
        batch_losses.append(tloss.detach().item())

        optimiser.zero_grad()
        tloss.backward()
        optimiser.step()

    return np.mean(batch_losses)

# Evaluating the performances of the model for a certain threshold value
def evaluate(eval_model, testing_loader, threshold):
    """Compute per-item metrics of `eval_model` at one probability threshold.

    Relies on the module-level `transform`, `untransform` and `DEVICE`:
    each image is passed through `transform`, predicted with `predict_img`
    (which applies `untransform` to the model output), and the thresholded
    prediction is compared with the mask yielded by the loader.

    Args:
        eval_model: network to evaluate; it is switched to eval mode.
        testing_loader: iterable yielding (image, mask, name) items.
            `predict_img` squeezes the leading dimension, so single-image
            batches are expected.
        threshold: value forwarded to `predict_img` as `out_threshold`.

    Returns:
        Four lists (precisions, recalls, F1s, IOUs) holding one value per
        item yielded by `testing_loader`.
    """
    precisions, recalls, F1s, IOUs = [], [], [], []

    eval_model.eval()
    for image, mask, _name in testing_loader:
        prediction = predict_img(eval_model, transform(image), DEVICE, untransform, out_threshold=threshold)
        # predict_img returns a NumPy boolean array; the metrics receive a tensor
        pred = torch.from_numpy(prediction)

        precisions.append(metrics.precision(pred, mask))
        recalls.append(metrics.recall(pred, mask))
        F1s.append(metrics.F1Score(pred, mask))
        IOUs.append(metrics.IOUScore(pred, mask))
    return precisions, recalls, F1s, IOUs

# Writes the performances of the model with a threshold of 0.5 to a text file 
# Writes it in a latex table format
def write_latex_half(dir_name, loss_name, val, prec, rec, f1, IOU):
    """Append one LaTeX table row to `latex_half.txt` inside `dir_name`.

    The table preamble is written before the row when `loss_name` is "CE",
    and the closing lines after the row when `loss_name` is "Focal", so a
    complete table is produced when rows are written in that order.

    Args:
        dir_name: directory containing (or receiving) `latex_half.txt`.
        loss_name: name of the loss, written in the first column (str).
        val: validation value, written in the second column.
        prec, rec, f1, IOU: metric values; their mean fills the last column.
    """
    # The context manager closes the file even if a write fails
    with open(os.path.join(dir_name, "latex_half.txt"), "a+") as f:
        if loss_name == "CE":
            f.write("\\begin{table}[!h]\n")
            f.write("\\centering\n")
            f.write("\\begin{tabular}{|c|c|c|c|c|c|c|}\n")
            f.write("\\hline\n")
            f.write("Loss & Validation & Precision & Recall & F1 & IOU & Avg\\\\ \n")
            f.write("\\hline\n")
            f.write("\\hline\n")

        average = (prec + rec + f1 + IOU) / 4
        cells = [loss_name, str(val), str(prec), str(rec), str(f1), str(IOU), str(average)]
        f.write(" & ".join(cells) + "\\\\ \n")
        f.write("\\hline\n")

        if loss_name == "Focal":
            # Backslash is escaped explicitly: "\e" is not a valid escape sequence
            f.write("\\end{tabular}\n")
            f.write("\\caption{}\n")
            f.write("\\label{}\n")
            f.write("\\end{table}\n")
    return

# Writes the performances of the model with a threshold maximizing the IOU to a text file 
# Writes it in a latex table format
def write_latex_IOU(dir_name, loss_name, prec, rec, f1, IOU, th):
    """Append one LaTeX table row to `latex_max_IOU.txt` inside `dir_name`.

    The table preamble is written before the row when `loss_name` is "CE",
    and the closing lines after the row when `loss_name` is "Focal", so a
    complete table is produced when rows are written in that order.

    Args:
        dir_name: directory containing (or receiving) `latex_max_IOU.txt`.
        loss_name: name of the loss, written in the first column (str).
        prec, rec, f1, IOU: metric values; their mean fills the last column.
        th: threshold value, written in the second column.
    """
    # The context manager closes the file even if a write fails
    with open(os.path.join(dir_name, "latex_max_IOU.txt"), "a+") as f:
        if loss_name == "CE":
            f.write("\\begin{table}[!h]\n")
            f.write("\\centering\n")
            f.write("\\begin{tabular}{|c|c|c|c|c|c|c|}\n")
            f.write("\\hline\n")
            f.write("Loss & Threshold & Precision & Recall & F1 & IOU & Avg\\\\ \n")
            f.write("\\hline\n")
            f.write("\\hline\n")

        average = (prec + rec + f1 + IOU) / 4
        cells = [loss_name, str(th), str(prec), str(rec), str(f1), str(IOU), str(average)]
        f.write(" & ".join(cells) + "\\\\ \n")
        f.write("\\hline\n")

        if loss_name == "Focal":
            f.write("\\end{tabular}\n")
            f.write("\\caption{}\n")
            f.write("\\label{}\n")
            f.write("\\end{table}\n")
    return

In [None]:
# Seed the Python, PyTorch and NumPy random number generators with the project seed
random.seed(cst.SEED)
torch.manual_seed(cst.SEED)
np.random.seed(cst.SEED)

TERM = "br2"  # Term to segment
SIZE = (384, 512)  # Related to this project
FOLDS = [0,1,2,3,4]  # Allows to run all or some folds (computational resources limits)
LOSSES = ["CE", "Dice", "Tversky", "Focal"] # Can study all or some loss functions (computational resources limits)
n_th = 51
thresholds = np.linspace(0, 1, num=n_th)  # All threshold values used to study the performances

# Every output of this run (logs, models, figures) goes to this directory
run_name = TERM + "_normal"
dir_name = os.path.join(cst.DIR, run_name)
os.makedirs(dir_name, exist_ok=True)

# Text file log: the header is written only when the log does not exist yet,
# so re-running the cell keeps what was logged before.
train_log_path = os.path.join(dir_name, "train.txt")
if not os.path.exists(train_log_path):
    with open(train_log_path, "w+") as f:
        f.write("--------------------------------------------------\n")
        f.write("Term studied: " + TERM + "\n\n")
        f.write("--------------------------------------------------\n")

# Use the first CUDA device when one is available, the CPU otherwise
DEVICE_NAME = "cuda:0" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(DEVICE_NAME)

In [None]:
def _append_log(path, lines):
    """Append every string in `lines` to the text file at `path`, then close it."""
    with open(path, "a+") as log_file:
        log_file.writelines(lines)


def _elapsed_since(start):
    """Return the time elapsed since `start` as (minutes, seconds below 60)."""
    elapsed = time.time() - start
    seconds = elapsed % 60
    return (elapsed - seconds) / 60, seconds


def _threshold_report(title, idx, precision, recall, f1, iou, avg):
    """Build the log lines describing all metrics at threshold index `idx`."""
    return [
        title + str(thresholds[idx]) + "\n",
        "Precision: " + str(precision[idx]) + "\n",
        "Recall: " + str(recall[idx]) + "\n",
        "F1-Dice: " + str(f1[idx]) + "\n",
        "IOU: " + str(iou[idx]) + "\n",
        "Average: " + str(avg[idx]) + "\n\n",
    ]


def _plot_metric_curve(values, best_idx, label, color, ylabel, title, out_path):
    """Plot one metric against the thresholds, mark its maximum, save and show it."""
    plt.plot(thresholds, values, label=label, color=color)
    plt.vlines(thresholds[best_idx], 0, values[best_idx], colors="black", linestyles="dashed")
    plt.hlines(values[best_idx], 0, thresholds[best_idx], colors="black", linestyles="dashed")
    plt.ylabel(ylabel)
    plt.xlabel("Threshold")
    plt.title(title)
    plt.savefig(out_path)
    plt.show()


train_log = os.path.join(dir_name, "train.txt")
# Index of the studied threshold closest to 0.5 (exactly 0.5 when n_th is odd)
half_idx = int(np.argmin(np.abs(thresholds - 0.5)))

for loss_name in LOSSES:
    print("Starting loss:", loss_name)
    # Per-threshold metric sums over the folds; averaged after the fold loop
    loss_precision = [0] * n_th
    loss_recall = [0] * n_th
    loss_f1 = [0] * n_th
    loss_IOU = [0] * n_th

    print("Starting term: " + TERM)
    start_term = time.time()

    image_folder = os.path.join(cst.DIR, "images")
    mask_folder = os.path.join(cst.DIR, TERM)

    # Transforms for the images; `untransform` is applied to the model output
    # at prediction time (see evaluate / predict_img)
    transform = transforms.Compose([transforms.Resize(SIZE),
                                    transforms.Pad((0, 64, 0, 64))])
    untransform = transforms.Compose([transforms.CenterCrop(SIZE),
                                     transforms.Resize((1932, 2576))])

    # Per-fold results at the 0.5 threshold
    fold_validation = []
    fold_precision = []
    fold_recall = []
    fold_f1 = []
    fold_IOU = []

    for fold in FOLDS:
        print("Starting fold: {}".format(fold))
        start_fold = time.time()
        # Run parameters are logged once per loss, on the first fold that is
        # actually run (FOLDS may be a subset that does not contain fold 0)
        first_fold = fold == FOLDS[0]

        # Datasets and loaders
        training_set = datasets.ZebrafishDataset_KFold(image_folder,
                                                      mask_folder,
                                                      actual_fold=fold,
                                                      dataset="train",
                                                      folds=cst.FOLDS)
        validation_set = datasets.ZebrafishDataset_KFold(image_folder,
                                                        mask_folder,
                                                        actual_fold=fold,
                                                        dataset="validate",
                                                        folds=cst.FOLDS)
        testing_set = datasets.ZebrafishDataset_KFold(image_folder,
                                                     mask_folder,
                                                     actual_fold=fold,
                                                     dataset="test",
                                                     folds=cst.FOLDS)

        training_loader = torch.utils.data.DataLoader(training_set,
                                                      batch_size=cst.BATCH_SIZE,
                                                      shuffle=True,
                                                      num_workers=cst.WORKERS)

        validation_loader = torch.utils.data.DataLoader(validation_set,
                                                        batch_size=cst.BATCH_SIZE,
                                                        shuffle=True,
                                                        num_workers=cst.WORKERS)

        testing_loader = torch.utils.data.DataLoader(testing_set,
                                                     batch_size=1,
                                                     shuffle=True,
                                                     num_workers=cst.WORKERS)

        model = UNET(3, 2)
        model.to(DEVICE)
        # `best_model` is a separate instance holding a copy of the best
        # weights seen so far. Binding it to `model` would only alias the
        # network being trained, so later epochs would overwrite the "best".
        best_model = UNET(3, 2)
        best_model.to(DEVICE)
        best_model.load_state_dict(model.state_dict())

        # Loss selection; any name without a dedicated branch uses cross-entropy
        header = []
        if first_fold:
            header.append("Loss used: " + loss_name + "\n\n")

        if loss_name == "Dice":
            criterion = loss_fn.DiceLoss()
        elif loss_name == "Tversky":
            A = 0.7
            B = 0.3
            criterion = loss_fn.TverskyLoss(alpha=A, beta=B)
            if first_fold:
                header.append("Alpha: " + str(A) + "\n")
                header.append("Beta: " + str(B) + "\n\n")
        elif loss_name == "Focal":
            A = 0.8
            G = 2
            criterion = loss_fn.FocalLoss(alpha=A, gamma=G, reduction="mean")
            if first_fold:
                header.append("Alpha: " + str(A) + "\n")
                header.append("Gamma: " + str(G) + "\n\n")
        else:
            criterion = nn.CrossEntropyLoss()

        # Text file log - writing loss and parameters
        if first_fold:
            header.append("Learning rate: " + str(cst.LEARNING_RATE) + "\n")
            header.append("Weight decay: " + str(cst.WEIGHT_DECAY) + "\n")
            header.append("Max epochs: " + str(cst.EPOCHS) + "\n")
            header.append("Batch size: " + str(cst.BATCH_SIZE) + "\n")
            header.append("Workers: " + str(cst.WORKERS) + "\n\n")
            header.append("--------------------------------------------------\n")
        header.append("Current fold: " + str(fold) + "\n\n")
        _append_log(train_log, header)

        optimiser = torch.optim.Adam(model.parameters(), lr=cst.LEARNING_RATE, weight_decay=cst.WEIGHT_DECAY)

        # Validation loss before any training: the baseline an epoch must beat
        loss = validate(model, validation_loader, transform, DEVICE)

        best_val = loss
        best_epoch = 0
        last_epoch = 0

        epochs_train_losses = []
        epochs_val_losses = []
        for epoch in range(cst.EPOCHS):
            # Training
            loss = train(model, training_loader, transform, DEVICE, criterion, optimiser)
            epochs_train_losses.append(loss)

            # Validation
            loss = validate(model, validation_loader, transform, DEVICE)
            epochs_val_losses.append(loss)

            # Updating best model: copy the current weights into the snapshot
            if loss < best_val:
                best_val = loss
                best_model.load_state_dict(model.state_dict())
                best_epoch = epoch+1

            if (epoch+1)%50 == 0:
                print("Epoch: " + str(epoch+1))
                print("Validation: {}.".format(loss))
                print("Best validation: {}.".format(best_val))

            # Time spent since the current loss started
            minutes, secondes = _elapsed_since(start_term)

            last_epoch = epoch

            # Notebooks shutdown after 6 hours. Stop the code and save the results.
            if minutes >= 345:
                _append_log(train_log, ["Learning stopped due to timeout. \n"])
                break
            # Stop once the best epoch is 50 or more epochs behind
            if (epoch - best_epoch) >= 50:
                break

        # All epochs are over
        fold_validation.append(best_val)

        model_name = TERM + '_' + loss_name + "_Fold_" + str(fold) + "_Epoch_" + str(best_epoch) + "_MaxEpochs_"
        model_name += str(cst.EPOCHS) + '_' + cst.OPTIMIZER + "_LR_" + str(cst.LEARNING_RATE) + ".pth"

        model_filepath = os.path.join(dir_name, model_name)
        torch.save(best_model.state_dict(), model_filepath)

        minutes, secondes = _elapsed_since(start_fold)

        # Plot losses (the first epoch is left out of the curves)
        index = [i+1 for i in range(last_epoch+1)]
        plt.plot(index[1:], epochs_train_losses[1:], label="Training")
        plt.plot(index[1:], epochs_val_losses[1:], label="Validation")
        plt.title("Term: " + TERM + ", Loss: " + loss_name + ", Fold: " + str(fold))
        plt.ylabel("Loss")
        plt.xlabel("Epochs")
        plt.legend()
        plt.savefig(os.path.join(dir_name, TERM + "_" + loss_name +"_Fold_" + str(fold) + "_Loss_curves.jpg"))
        plt.show()

        # Text file log
        _append_log(train_log, [
            "Best epoch: " + str(best_epoch) + "\n",
            "Best validation loss: " + str(best_val) + "\n",
            "Ellapsed time: " + str(minutes) + " minutes " + str(secondes) + " seconds\n\n",
            "Name of the model saved:\n",
            model_name + "\n\n",
        ])

        # Evaluating performances for each threshold, read from `thresholds`
        # so the evaluated values are the ones reported in the logs and plots
        for th, threshold in enumerate(thresholds):
            precisions, recalls, F1s, IOUs = evaluate(best_model, testing_loader, float(threshold))

            mean_precision = np.mean(precisions)
            mean_recall = np.mean(recalls)
            mean_f1 = np.mean(F1s)
            mean_IOU = np.mean(IOUs)

            loss_precision[th] += mean_precision
            loss_recall[th] += mean_recall
            loss_f1[th] += mean_f1
            loss_IOU[th] += mean_IOU

            # Text file log - perfomances with a threshold of 0.5
            if th == half_idx:
                _append_log(train_log, [
                    "Performance of this model for a threshold of 0.5: \n",
                    "Precision: " + str(mean_precision) + "\n",
                    "Recall: " + str(mean_recall) + "\n",
                    "F1-Dice: " + str(mean_f1) + "\n",
                    "IOU: " + str(mean_IOU) + "\n",
                    "Average: " + str((mean_precision+mean_recall+mean_f1+mean_IOU)/4)+ "\n",
                ])
                fold_precision.append(mean_precision)
                fold_recall.append(mean_recall)
                fold_f1.append(mean_f1)
                fold_IOU.append(mean_IOU)

        minutes, secondes = _elapsed_since(start_fold)
        _append_log(train_log, [
            "Total fold time: " + str(minutes) + " minutes " + str(secondes) + " seconds\n\n",
            "--------------------------------------------------\n",
        ])

        print("Last epoch: {}".format(last_epoch))
        print("Term: " + TERM)
        print("Fold: {}".format(fold))
        print("Fold took: " + str(minutes) + " minutes " + str(secondes) + " seconds to train")
        print("Last val: {}".format(loss))
        print("Best val: {}".format(best_val))
        print()
        # The current fold's results are the last ones appended; indexing by
        # fold id would fail when FOLDS is a subset such as [2, 3]
        print("Precision: {}".format(fold_precision[-1]))
        print("Recall: {}".format(fold_recall[-1]))
        print("F1/Dice score: {}".format(fold_f1[-1]))
        print("IoU: {}".format(fold_IOU[-1]))
        avg = (fold_precision[-1]+ fold_recall[-1]+ fold_f1[-1]+fold_IOU[-1])/4
        print("Avg:", avg)
        print("--------------------")
    # Fold loop end

    # Mean performances of all folds
    all_f_prec = np.mean(fold_precision)
    all_f_rec = np.mean(fold_recall)
    all_f_f1 = np.mean(fold_f1)
    all_f_IOU = np.mean(fold_IOU)
    all_f_val = str(np.mean(fold_validation))

    # Text file log - mean performmances
    _append_log(train_log, [
        "Average performance of all models for a threshold of 0.5: \n",
        "Mean validation: " + str(all_f_val) + "\n",
        "Precision: " + str(all_f_prec) + "\n",
        "Recall: " + str(all_f_rec) + "\n",
        "F1-Dice: " + str(all_f_f1) + "\n",
        "IOU: " + str(all_f_IOU) + "\n",
        "Average: " + str((all_f_prec+all_f_rec+all_f_f1+all_f_IOU)/4) + "\n\n",
    ])

    write_latex_half(dir_name, loss_name, all_f_val, all_f_prec, all_f_rec, all_f_f1, all_f_IOU)

    # Average of all performances for all threshold values, over the folds
    # that were actually run (`or 1` avoids a division by zero when none were)
    n_folds = len(FOLDS) or 1
    loss_precision = [total / n_folds for total in loss_precision]
    loss_recall = [total / n_folds for total in loss_recall]
    loss_f1 = [total / n_folds for total in loss_f1]
    loss_IOU = [total / n_folds for total in loss_IOU]
    loss_avg = [(p + r + f + i) / 4
                for p, r, f, i in zip(loss_precision, loss_recall, loss_f1, loss_IOU)]

    # Studying threshold maximizing each metric
    max_IOU = np.argmax(loss_IOU)
    max_prec = np.argmax(loss_precision)
    max_rec = np.argmax(loss_recall)
    max_f1 = np.argmax(loss_f1)
    max_avg = np.argmax(loss_avg)

    # Text file log - all metrics at the threshold maximising each of them
    report = []
    for section_title, best_idx in (
        ("Threshold maximising the precision: ", max_prec),
        ("Threshold maximising the recall: ", max_rec),
        ("Threshold maximising the F1-Dice score: ", max_f1),
        ("Threshold maximising the IOU score: ", max_IOU),
        ("Threshold maximising the average performances: ", max_avg),
    ):
        report.extend(_threshold_report(section_title, best_idx, loss_precision,
                                        loss_recall, loss_f1, loss_IOU, loss_avg))
    report.append("--------------------------------------------------\n")
    _append_log(train_log, report)

    write_latex_IOU(dir_name,
                    loss_name,
                    loss_precision[max_IOU],
                    loss_recall[max_IOU],
                    loss_f1[max_IOU],
                    loss_IOU[max_IOU],
                    thresholds[max_IOU])

    plot_title = "Term: " + TERM + ", Loss: " + loss_name
    file_prefix = os.path.join(dir_name, TERM + "_" + loss_name)

    # Plot - all metric curves
    plt.plot(thresholds, loss_avg , label="average", color="tab:blue")
    plt.plot(thresholds, loss_precision , label="precision", color="tab:orange")
    plt.plot(thresholds, loss_recall, label="recall", color="tab:green")
    plt.plot(thresholds, loss_f1, label="F1", color="tab:red")
    plt.plot(thresholds, loss_IOU, label="IOU", color="tab:purple")
    plt.ylabel("Metrics")
    plt.xlabel("Threshold")
    plt.title(plot_title)
    plt.legend()
    plt.savefig(file_prefix + "_Metric_curves.jpg")
    plt.show()

    # Plot - one curve per metric, with its maximum marked by dashed lines
    _plot_metric_curve(loss_avg, max_avg, "average", "tab:blue", "Average",
                       plot_title, file_prefix + "_Average_curve.jpg")
    _plot_metric_curve(loss_precision, max_prec, "precision", "tab:orange", "Precision",
                       plot_title, file_prefix + "_Precision_curve.jpg")
    _plot_metric_curve(loss_recall, max_rec, "recall", "tab:green", "Recall",
                       plot_title, file_prefix + "_Recall_curve.jpg")
    _plot_metric_curve(loss_f1, max_f1, "f1", "tab:red", "F1 score",
                       plot_title, file_prefix + "_F1_curve.jpg")
    _plot_metric_curve(loss_IOU, max_IOU, "IOU", "tab:purple", "IOU",
                       plot_title, file_prefix + "_IOU_curve.jpg")

    print()
    print("ALL FOLDS TRAINING ENDED")
    print("Mean best validation: {}".format(np.mean(fold_validation)))
    print("Mean precision: {}".format(np.mean(fold_precision)))
    print("Mean recall: {}".format(np.mean(fold_recall)))
    print("Mean F1: {}".format(np.mean(fold_f1)))
    print("Mean IOU: {}".format(np.mean(fold_IOU)))
# Loss loop end
print()