# Train - Validate - Evaluate
This notebook focuses on the multi-class model.

In [None]:
import torch, torchvision
import os
import random
import datasets
import metrics
import time

import constants as cst
import numpy as np
import torch.nn as nn
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch.nn.functional as F

import loss_function_multi_class as loss_fn
from unet import UNET
import utils

In [None]:
def predict_img(model, image, device, transform, out_threshold=0.5):
    """Predict one binary mask per class for a single image.

    The model output is passed through ``transform``, converted to class
    probabilities with a softmax over dimension 1, and every probability map
    is thresholded independently of the others.

    Args:
        model: Callable network, invoked as ``model(image.to(device))``.
        image: Input tensor. NOTE(review): assumed to be a batch of one
            (leading dimension of size 1) -- confirm against the caller.
        device: Device the input is moved to before the forward pass.
        transform: Callable applied to the raw model output before the
            softmax.
        out_threshold: Probability strictly above which a pixel is set to 1.

    Returns:
        CPU integer tensor of 0/1 values. A leading dimension of size 1 is
        squeezed out, so a single image yields one mask per class.
    """
    with torch.no_grad():
        logits = transform(model(image.to(device)))
        proba = torch.softmax(logits, dim=1).cpu().squeeze(0)
        # Threshold all class maps in a single call; this yields the same
        # values as thresholding each channel separately and stacking them.
        return torch.where(proba > out_threshold, 1, 0)

# Validation step
def validate(model, validation_loader, transform, DEVICE, loss_name,
             loss_criterion=None, num_classes=14):
    """Compute the mean validation loss of ``model`` over a loader.

    For every sample, the loss is evaluated separately on each class channel
    whose ground-truth mask is non-empty; the batch loss is the mean over
    those (sample, class) pairs, and the returned value is the mean of the
    batch losses.

    Args:
        model: Network to evaluate; switched to eval mode.
        validation_loader: Iterable yielding ``(images, masks, names)``.
        transform: Callable applied to both images and masks.
        DEVICE: Device used for the forward pass and the masks.
        loss_name: Unused here; kept so existing calls keep working.
        loss_criterion: Loss callable ``(prediction, target) -> tensor``.
            When ``None`` the notebook-level ``criterion`` is used, which is
            what this function relied on before the parameter existed.
        num_classes: Number of class channels inspected per sample.

    Returns:
        Mean of the per-batch losses, or ``nan`` when no batch contained any
        non-empty mask.
    """
    loss_function = criterion if loss_criterion is None else loss_criterion

    model.eval()
    val_loss = []
    with torch.no_grad():
        for images, masks, names in validation_loader:
            images = transform(images)
            outputs = F.softmax(model(images.to(DEVICE)), dim=1)

            masks = transform(masks.type(torch.FloatTensor))
            masks = torch.squeeze(masks, 1).to(DEVICE)

            total_loss = 0
            denum = 0
            for batch_n in range(images.shape[0]):
                for i in range(num_classes):
                    target = masks[batch_n, i, :, :]
                    # Classes absent from the ground truth are ignored.
                    if target.sum() > 0:
                        total_loss += loss_function(outputs[batch_n, i, :, :], target)
                        denum += 1

            # A batch without any annotated class has no defined loss;
            # skipping it avoids a division by zero.
            if denum == 0:
                continue
            val_loss.append((total_loss / denum).item())

    if not val_loss:
        return float("nan")
    return np.mean(val_loss)

# Training step
def train(model, training_loader, transforms, DEVICE, criterion, optimiser, loss_name,
          num_classes=14):
    """Run one training epoch and return the mean batch loss.

    For every sample, the loss is evaluated separately on each class channel
    whose ground-truth mask is non-empty; the batch loss is the mean over
    those (sample, class) pairs and is what gets back-propagated.

    Args:
        model: Network to train; switched to train mode.
        training_loader: Iterable yielding ``(images, masks, names)``.
        transforms: Callable applied to both images and masks. If a
            non-callable is passed (the notebook passes the
            ``torchvision.transforms`` module), the notebook-level
            ``transform`` is used instead, as this function did before.
        DEVICE: Device used for the forward pass and the masks.
        criterion: Loss callable ``(prediction, target) -> tensor``.
        optimiser: Optimiser stepping the parameters of ``model``.
        loss_name: Unused here; kept so existing calls keep working.
        num_classes: Number of class channels inspected per sample.

    Returns:
        Mean of the per-batch losses, or ``nan`` when no batch contained any
        non-empty mask.
    """
    apply_transform = transforms if callable(transforms) else transform

    model.train()
    train_loss = []
    for images, masks, names in training_loader:
        images = apply_transform(images)
        outputs = F.softmax(model(images.to(DEVICE)), dim=1)

        masks = apply_transform(masks.type(torch.FloatTensor))
        masks = torch.squeeze(masks, 1).to(DEVICE)

        total_loss = 0
        denum = 0
        for batch_n in range(images.shape[0]):
            for i in range(num_classes):
                target = masks[batch_n, i, :, :]
                # Classes absent from the ground truth are ignored.
                if target.sum() > 0:
                    total_loss += criterion(outputs[batch_n, i, :, :], target)
                    denum += 1

        # A batch without any annotated class has no loss to back-propagate;
        # skipping it avoids a division by zero.
        if denum == 0:
            continue

        tloss = total_loss / denum
        train_loss.append(tloss.detach().item())

        optimiser.zero_grad()
        tloss.backward()
        optimiser.step()

    if not train_loss:
        return float("nan")
    return np.mean(train_loss)

# Evaluating performances of the model
def evaluate(eval_model, testing_loader, threshold, num_classes=14):
    """Compute per-image precision, recall, F1 and IoU on a test loader.

    Each image is run through ``predict_img`` using the notebook-level
    ``transform``, ``untransform`` and ``DEVICE``. The metrics are computed
    for every class whose ground-truth mask is non-empty and averaged over
    those classes.

    Args:
        eval_model: Network to evaluate; switched to eval mode.
        testing_loader: Iterable yielding ``(image, mask, name)``.
            NOTE(review): assumed to use a batch size of 1, since the
            prediction is indexed per class while the mask keeps its batch
            dimension -- confirm against the loader.
        threshold: Probability threshold forwarded to ``predict_img``.
        num_classes: Number of class channels inspected per image.

    Returns:
        Four lists ``(precisions, recalls, F1s, IOUs)`` holding one
        class-averaged value per image. Images without any annotated class
        are skipped because their average is undefined.
    """
    precisions = []
    recalls = []
    F1s = []
    IOUs = []

    eval_model.eval()
    for image, mask, name in testing_loader:
        pred_masks = predict_img(eval_model, transform(image), DEVICE, untransform, threshold)

        p = 0  # summed precision
        r = 0  # summed recall
        f = 0  # summed F1
        iou = 0  # summed IoU
        total = 0
        for i in range(num_classes):
            truth = mask[:, i, :, :]
            if truth.sum() > 0:
                # Accumulate so that the division below is a real mean over
                # the annotated classes, not the last class divided by the
                # number of classes.
                p += metrics.precision(pred_masks[i, :, :], truth)
                r += metrics.recall(pred_masks[i, :, :], truth)
                f += metrics.F1Score(pred_masks[i, :, :], truth)
                iou += metrics.IOUScore(pred_masks[i, :, :], truth)
                total += 1

        # No annotated class: nothing to average, and dividing would fail.
        if total == 0:
            continue

        precisions.append(p / total)
        recalls.append(r / total)
        F1s.append(f / total)
        IOUs.append(iou / total)
    return precisions, recalls, F1s, IOUs

# Create one multi-class segmentation mask using multiple masks
def stack_masks(masks, size=(512,512)):
    """Merge per-class masks into one label map per batch element.

    Classes are visited in channel order and a pixel keeps the first
    non-zero value written to it, so lower channel indices take priority.
    Channel ``c`` is written as ``masks[b, c] * (c + 1)``; pixels where no
    channel is set stay 0.

    Args:
        masks: Tensor indexed as ``masks[batch, class, :, :]``.
        size: Spatial size of the label map; it must be broadcastable with
            the spatial dimensions of ``masks``.

    Returns:
        Tensor stacking one label map per batch element along dimension 0.
    """
    n_batch, n_classes = masks.shape[0], masks.shape[1]
    stacks = []
    for b in range(n_batch):
        # float64 accumulator with a leading singleton dimension, matching
        # the float64 zero array the label map was previously built from.
        label_map = torch.zeros((1, *size), dtype=torch.float64)
        for c in range(n_classes):
            # Only pixels still at 0 may be claimed by this class.
            label_map = torch.where(label_map == 0, masks[b, c, :, :] * (c + 1), label_map)
        stacks.append(label_map.squeeze())

    return torch.stack(stacks, 0)

In [None]:
# Seed every random number generator used by the notebook.
random.seed(cst.SEED)
torch.manual_seed(cst.SEED)
np.random.seed(cst.SEED)

TERM = "all_different_masks"
SIZE = (384, 512)  # Related to this project
FOLDS = [0,1,2,3,4]  # Allows to run all or some folds (computational resources limits)
LOSSES = ["CE", "Dice", "Tversky", "Focal"]  # Can study all or some loss functions (computational resources limits)

run_name = TERM
dir_name = os.path.join(cst.DIR, run_name)
os.makedirs(dir_name, exist_ok = True)

# Text file log: write the header only when the log does not exist yet, so
# re-running the cell does not overwrite earlier results.
if "train.txt" not in os.listdir(dir_name):
    # The context manager closes the file even if a write fails.
    with open(os.path.join(dir_name,"train.txt"),"w+") as f:
        f.write("--------------------------------------------------\n")
        f.write("Term studied: " + TERM + "\n\n")
        f.write("--------------------------------------------------\n")

# Use the first GPU when one is available, the CPU otherwise.
DEVICE_NAME = 'cuda:0' if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(DEVICE_NAME)

In [None]:
# Cross-validated training: for every loss function and every fold, train a
# fresh UNET, keep the weights with the lowest validation loss, save them,
# plot the loss curves and append a summary to the text log.
train_log_path = os.path.join(dir_name, "train.txt")

# Multiple loss functions can be studied
for loss_name in LOSSES:
    print("Starting loss:", loss_name)
    # Used for studying the performances using different threshold values
    # NOTE(review): `n_th` is not defined in the visible part of the notebook;
    # it has to be set before this cell runs.
    loss_precision = [0] * n_th
    loss_recall = [0] * n_th
    loss_f1 = [0] * n_th
    loss_IOU = [0] * n_th

    print("Starting term: " + TERM)
    start_term = time.time()

    image_folder = os.path.join(cst.DIR, "images")
    mask_folder = os.path.join(cst.DIR, TERM)

    # Transforms for the images: `transform` resizes and pads the network
    # input, `untransform` crops the padding away and resizes the output.
    transform = transforms.Compose([transforms.Resize(SIZE),
                                    transforms.Pad((0, 64, 0, 64))])
    untransform = transforms.Compose([transforms.CenterCrop(SIZE),
                                     transforms.Resize((1932, 2576))])

    fold_validation = []
    fold_precision = []
    fold_recall = []
    fold_f1 = []
    fold_IOU = []

    fold_precision_argmax = []
    fold_recall_argmax = []
    fold_f1_argmax = []
    fold_IOU_argmax = []

    for fold in FOLDS:
        print("Starting fold: {}".format(fold))
        start_fold = time.time()

        # Datasets and loaders
        training_set = datasets.ZebrafishDataset_multi(actual_fold=fold,
                                                      dataset="train",
                                                      folds=cst.FOLDS)
        validation_set = datasets.ZebrafishDataset_multi(actual_fold=fold,
                                                        dataset="validate",
                                                        folds=cst.FOLDS)
        testing_set = datasets.ZebrafishDataset_multi(actual_fold=fold,
                                                     dataset="test",
                                                     folds=cst.FOLDS)

        training_loader = torch.utils.data.DataLoader(training_set,
                                                      batch_size=cst.BATCH_SIZE,
                                                      shuffle=True,
                                                      num_workers=cst.WORKERS)

        validation_loader = torch.utils.data.DataLoader(validation_set,
                                                        batch_size=cst.BATCH_SIZE,
                                                        shuffle=True,
                                                        num_workers=cst.WORKERS)

        testing_loader = torch.utils.data.DataLoader(testing_set,
                                                     batch_size=1,
                                                     shuffle=True,
                                                     num_workers=cst.WORKERS)

        model = UNET(3, 14)
        model.to(DEVICE)
        # `best_model` is a separate network holding a copy of the best
        # weights. Binding it to `model` itself would make it follow every
        # later optimiser step, so the file saved below would contain the
        # last-epoch weights instead of the best ones.
        best_model = UNET(3, 14)
        best_model.to(DEVICE)
        best_model.load_state_dict(model.state_dict())

        # Loss selection; `loss_params` holds the log lines describing the
        # hyper-parameters of the chosen loss.
        criterion = nn.BCELoss()
        criterion_string = loss_name
        loss_params = []

        if loss_name == "Dice":
            criterion = loss_fn.DiceLoss()
        elif loss_name == "Tversky":
            A = 0.7
            B = 0.3
            criterion = loss_fn.TverskyLoss(alpha=A, beta=B)
            loss_params = ["Alpha: " + str(A) + "\n",
                           "Beta: " + str(B) + "\n\n"]
        elif loss_name == "Focal":
            A = 0.8
            G = 2
            criterion = loss_fn.FocalLoss(alpha=A, gamma=G, reduction="mean")
            loss_params = ["Alpha: " + str(A) + "\n",
                           "Gamma: " + str(G) + "\n\n"]

        # Text file log - writing loss and parameters (once per loss, on fold 0)
        with open(train_log_path, "a+") as f:
            if fold==0:
                f.write("Loss used: " + loss_name + "\n\n")
                f.writelines(loss_params)
                f.write("Learning rate: " + str(cst.LEARNING_RATE) + "\n")
                f.write("Weight decay: " + str(cst.WEIGHT_DECAY) + "\n")
                f.write("Max epochs: " + str(cst.EPOCHS) + "\n")
                f.write("Batch size: " + str(cst.BATCH_SIZE) + "\n")
                f.write("Workers: " + str(cst.WORKERS) + "\n\n")
                f.write("--------------------------------------------------\n")
            f.write("Current fold: " + str(fold) + "\n\n")

        optimiser = torch.optim.Adam(model.parameters(), lr=cst.LEARNING_RATE, weight_decay=cst.WEIGHT_DECAY)

        # Validation loss before training: the baseline an epoch has to beat.
        loss = validate(model, validation_loader, transform, DEVICE, loss_name)

        best_val = loss
        best_epoch = 0
        last_epoch = 0

        epochs_train_losses = []
        epochs_val_losses = []
        for epoch in range(cst.EPOCHS):
            # Training: pass the image pipeline `transform`, not the
            # torchvision `transforms` module.
            loss = train(model, training_loader, transform, DEVICE, criterion, optimiser, loss_name)
            epochs_train_losses.append(loss)

            # Validation
            loss = validate(model, validation_loader, transform, DEVICE, loss_name)
            epochs_val_losses.append(loss)

            # Updating best model: snapshot the weights of this epoch.
            if loss < best_val:
                best_val = loss
                best_model.load_state_dict(model.state_dict())
                best_epoch = epoch+1

            if (epoch+1)%50 == 0:
                print("Epoch: " + str(epoch+1))
                print("Validation: {}.".format(loss))
                print("Best validation: {}.".format(best_val))

            # Time elapsed since the current loss function was started.
            curr = time.time() - start_term
            secondes = curr % 60
            minutes = (curr-secondes)/60

            last_epoch = epoch

            # Notebooks shutdown after 6 hours. Stop the code and save the results.
            if minutes >= 350:
                with open(train_log_path, "a+") as f:
                    f.write("Learning stopped due to timeout. \n")
                break
            # Early stopping once the validation loss has stopped improving.
            if (epoch - best_epoch) >= 50:
                break

        # All epochs are over
        fold_validation.append(best_val)

        model_name = TERM + '_' + loss_name + "_Fold_" + str(fold) + "_Epoch_" + str(best_epoch) + "_MaxEpochs_"
        model_name += str(cst.EPOCHS) + '_' + cst.OPTIMIZER + "_LR_" + str(cst.LEARNING_RATE) + ".pth"

        model_filepath = os.path.join(dir_name, model_name)
        torch.save(best_model.state_dict(), model_filepath)

        curr = time.time() - start_fold
        secondes = curr % 60
        minutes = (curr-secondes)/60

        # Plot losses on a fresh figure so that curves from previous folds do
        # not pile up; the first epoch is left out of the plot.
        index = [i+1 for i in range(last_epoch+1)]
        plt.figure()
        plt.plot(index[1:], epochs_train_losses[1:], label="Training")
        plt.plot(index[1:], epochs_val_losses[1:], label="Validation")
        plt.title("Term: " + TERM + ", Loss: " + loss_name + ", Fold: " + str(fold))
        plt.ylabel("Loss")
        plt.xlabel("Epochs")
        plt.legend()
        plt.savefig(os.path.join(dir_name, TERM + "_" + loss_name +"_Fold_" + str(fold) + "_Loss_curves.jpg"))
        plt.show()
        plt.close()

        # Text file log - summary of the fold
        with open(train_log_path, "a+") as f:
            f.write("Best epoch: " + str(best_epoch) + "\n")
            f.write("Best validation loss: " + str(best_val) + "\n")
            f.write("Ellapsed time: " + str(minutes) + " minutes " + str(secondes) + " seconds\n\n")
            f.write("Name of the model saved:\n")
            f.write(model_name + "\n\n")
            f.write("Total fold time: " + str(minutes) + " minutes " + str(secondes) + " seconds\n\n")
            f.write("--------------------------------------------------\n")

        print("Fold took: " + str(minutes) + " minutes " + str(secondes) + " seconds to train")

    # Fold loop end

# Term loop end
print()