# Threshold study
Evaluates the performance of the multi-class model, per annotation, for different threshold values.

In [None]:
import torch, torchvision
import os
import random
import datasets
import metrics
import time

import constants as cst
import numpy as np
import torch.nn as nn
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch.nn.functional as F

from unet import UNET
import utils

In [None]:
def predict_img(model, image, device, transform, out_threshold=0.5):
    """Predict one binary mask per class for a single image.

    The model logits are passed through ``transform``, turned into per-pixel
    class probabilities with a softmax over dimension 1, and thresholded
    independently for every class.

    Args:
        model: callable returning logits; assumed to be (1, C, H, W) — the
            leading batch dimension is squeezed, so only a batch of one is
            supported (TODO confirm for other callers).
        image: input tensor, moved to ``device`` before the forward pass.
        device: device the forward pass runs on.
        transform: callable applied to the logits before the softmax.
        out_threshold: a pixel is positive for a class when its probability is
            strictly greater than this value.

    Returns:
        An int64 tensor of 0/1 values on the CPU, one mask per class (C, H, W).
    """
    with torch.no_grad():
        logits = transform(model(image.to(device)))
        proba = F.softmax(logits, dim=1).cpu().squeeze(0)
        # Threshold all class channels at once instead of looping per channel.
        masks = torch.where(proba > out_threshold, 1, 0)
    return masks

# Evaluate the performance for one threshold value
def evaluate(eval_model, testing_loader, threshold, n_classes=14, *,
             device=None, image_transform=None, output_transform=None):
    """Compute per-annotation precision, recall, F1 and IOU for one threshold.

    Every image of ``testing_loader`` is predicted with ``predict_img``. For
    each annotation class whose ground-truth mask is non-empty in that image,
    the four metrics are recorded. The threshold is therefore studied per
    annotation and not globally. Classes absent from an image contribute
    nothing for it.

    Args:
        eval_model: segmentation model; it is switched to eval mode.
        testing_loader: iterable yielding ``(image, mask, name)`` triples,
            where ``mask[:, c, :, :]`` is the ground truth of class ``c``.
        threshold: probability threshold forwarded to ``predict_img``.
        n_classes: number of annotation channels (14 for this project). The
            class order is br1, br2, cb, ch, cl, d, en, hm, m, n, oc, op, p, vc.
        device: device for inference; defaults to the notebook global ``DEVICE``.
        image_transform: applied to each input image; defaults to the
            notebook global ``transform``.
        output_transform: applied to the logits; defaults to the notebook
            global ``untransform``.

    Returns:
        ``(precisions, recalls, F1s, IOUs)``: four lists of ``n_classes``
        lists each, holding one score per image in which the class is present.
    """
    # Fall back to the notebook-level globals the original code relied on.
    if device is None:
        device = DEVICE
    if image_transform is None:
        image_transform = transform
    if output_transform is None:
        output_transform = untransform

    # Built before the loop so an empty loader returns empty lists instead of
    # raising UnboundLocalError.
    precisions = [[] for _ in range(n_classes)]
    recalls = [[] for _ in range(n_classes)]
    F1s = [[] for _ in range(n_classes)]
    IOUs = [[] for _ in range(n_classes)]

    eval_model.eval()
    for image, mask, _name in testing_loader:
        pred_masks = predict_img(eval_model, image_transform(image), device,
                                 output_transform, threshold)

        # One computation per mask; skip classes with an empty ground truth.
        for c in range(n_classes):
            true_mask = mask[:, c, :, :]
            if true_mask.sum() > 0:
                pred_mask = pred_masks[c, :, :]
                precisions[c].append(metrics.precision(pred_mask, true_mask))
                recalls[c].append(metrics.recall(pred_mask, true_mask))
                F1s[c].append(metrics.F1Score(pred_mask, true_mask))
                IOUs[c].append(metrics.IOUScore(pred_mask, true_mask))

    return precisions, recalls, F1s, IOUs

In [None]:
# Checkpoints of every cross-validation fold of a single training run
model_names = [
    "alles_different_masks_CE_Fold_0_Epoch_59_MaxEpochs_250_Adam_LR_0.0001.pth",
    "alles_different_masks_CE_Fold_1_Epoch_66_MaxEpochs_250_Adam_LR_0.0001.pth",
    "alles_different_masks_CE_Fold_2_Epoch_56_MaxEpochs_250_Adam_LR_0.0001.pth",
    "alles_different_masks_CE_Fold_3_Epoch_55_MaxEpochs_250_Adam_LR_0.0001.pth",
    "alles_different_masks_CE_Fold_4_Epoch_96_MaxEpochs_250_Adam_LR_0.0001.pth",
]

testing_set = datasets.ZebrafishDataset_multi(0, dataset="test", folds=cst.FOLDS)
testing_loader = torch.utils.data.DataLoader(
    testing_set,
    batch_size=1,
    shuffle=True,
    num_workers=cst.WORKERS,
)

SIZE = (384, 512)

# Network input: resize to SIZE, then pad top and bottom by 64 px each
transform = transforms.Compose([
    transforms.Resize(SIZE),
    transforms.Pad((0, 64, 0, 64)),
])
# Applied to the logits: crop the padding away, then resize to 1932x2576
# (presumably the raw image resolution — confirm against the dataset)
untransform = transforms.Compose([
    transforms.CenterCrop(SIZE),
    transforms.Resize((1932, 2576)),
])

n_th = 51
# One curve of n_th points per annotation; 14 annotations is project specific
loss_precision = [[0] * n_th for _ in range(14)]
loss_recall = [[0] * n_th for _ in range(14)]
loss_f1 = [[0] * n_th for _ in range(14)]
loss_IOU = [[0] * n_th for _ in range(14)]

thresholds = np.linspace(0, 1, num=n_th)

# Per-fold summary scores
fold_validation, fold_precision, fold_recall, fold_f1, fold_IOU = [], [], [], [], []

DEVICE_NAME = 'cuda:0' if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(DEVICE_NAME)

TERM = "multi_all"
dir_name = os.path.join(cst.DIR, "all_different_masks")
loss_name = "CE"

# Sweep every threshold for every fold model. The per-annotation curves are
# averaged over the models; the pooled scores at the middle threshold are
# kept per fold.
n_models = len(model_names)
for model_name in model_names:
    model_path = os.path.join(cst.DIR, "alles_different_masks_normal", model_name)
    model = utils.load_model_all(model_path)
    model.to(DEVICE)

    for th in range(n_th):
        # Evaluate at exactly the threshold that is later printed and plotted.
        # The former (0.1 + th*2)/100 was offset by 0.001 from `thresholds`
        # and reached 1.001 at the last step.
        precisions, recalls, F1s, IOUs = evaluate(model, testing_loader, float(thresholds[th]))

        for i in range(14):
            # NOTE: an annotation never present in the test set gives an empty
            # list, so np.mean yields nan for that curve.
            loss_precision[i][th] += np.mean(precisions[i]) / n_models
            loss_recall[i][th] += np.mean(recalls[i]) / n_models
            loss_f1[i][th] += np.mean(F1s[i]) / n_models
            loss_IOU[i][th] += np.mean(IOUs[i]) / n_models

        if th == n_th // 2:
            # Scores at the middle threshold, pooled over all annotations.
            # The per-class lists are ragged, so flatten them before averaging
            # (np.mean on the list-of-lists itself fails).
            mean_precision = np.mean([v for scores in precisions for v in scores])
            mean_recall = np.mean([v for scores in recalls for v in scores])
            mean_f1 = np.mean([v for scores in F1s for v in scores])
            mean_IOU = np.mean([v for scores in IOUs for v in scores])

            fold_precision.append(mean_precision)
            fold_recall.append(mean_recall)
            fold_f1.append(mean_f1)
            fold_IOU.append(mean_IOU)

# Average of the per-fold scores
all_f_prec = np.mean(fold_precision)
all_f_rec = np.mean(fold_recall)
all_f_f1 = np.mean(fold_f1)
all_f_IOU = np.mean(fold_IOU)

# savefig fails if the output folder is missing, so create it first
os.makedirs(dir_name, exist_ok=True)

for i in range(14):
    term = cst.COMBINED_TERM[i]
    # Index into `thresholds` of the threshold with the highest mean IOU
    max_IOU = np.argmax(loss_IOU[i])
    print("Max IOU:", loss_IOU[i][max_IOU])
    print("Threshold:", thresholds[max_IOU])

    print("Threshold maximising the IOU score: " + str(thresholds[max_IOU]))
    print("Precision: " + str(loss_precision[i][max_IOU]))
    print("Recall: " + str(loss_recall[i][max_IOU]))
    print("F1-Dice: " + str(loss_f1[i][max_IOU]))
    print("IOU: " + str(loss_IOU[i][max_IOU]))

    # Use a fresh figure per annotation. Otherwise non-inline backends keep
    # drawing on the same axes, and each saved image accumulates the curves
    # of all previous annotations.
    fig, ax = plt.subplots()
    ax.plot(thresholds, loss_precision[i], label="precision", color="tab:orange")
    ax.plot(thresholds, loss_recall[i], label="recall", color="tab:green")
    ax.plot(thresholds, loss_f1[i], label="F1", color="tab:red")
    ax.plot(thresholds, loss_IOU[i], label="IOU", color="tab:purple")
    ax.set_ylabel("Metrics")
    ax.set_xlabel("Threshold")
    ax.set_title("Term: " + term + ", Loss: " + loss_name)
    ax.legend()
    fig.savefig(os.path.join(dir_name, term + "_" + loss_name + "_Metric_curves.jpg"))
    plt.show()
    plt.close(fig)
