# Threshold study
Evaluates the performance of the binary or cropped model for different threshold values.

In [None]:
import os
import time
import torch

import utils
import datasets
import metrics

import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as transforms

import constants as cst

In [None]:
plot_name = "Performances of the model for different thresholds"

# Checkpoints of one run, one file per fold (Fold_0 .. Fold_4 in the names).
models = [
    "vc_Dice_Fold_0_Epoch_246_MaxEpochs_250_Adam_LR_0.0001.pth",
    "vc_Dice_Fold_1_Epoch_248_MaxEpochs_250_Adam_LR_0.0001.pth",
    "vc_Dice_Fold_2_Epoch_249_MaxEpochs_250_Adam_LR_0.0001.pth",
    "vc_Dice_Fold_3_Epoch_247_MaxEpochs_250_Adam_LR_0.0001.pth",
    "vc_Dice_Fold_4_Epoch_249_MaxEpochs_250_Adam_LR_0.0001.pth",
]
# Keep the individual names available as well.
model0, model1, model2, model3, model4 = models

# Use the first GPU when one is visible, otherwise stay on the CPU.
DEVICE_NAME = 'cuda:0' if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(DEVICE_NAME)

term = "vc"
# Cropped model (True) or binary model (False)
crop = False

# Thresholds swept by the study: n_th evenly spaced values covering [0, 1].
n_th = 51
thresholds = np.linspace(0, 1, num=n_th)

fold = 0
image_folder = "/notebooks/images"
mask_folder = "/notebooks/" + term

if not crop:
    testing_set = datasets.ZebrafishDataset_KFold(
        image_folder,
        mask_folder,
        actual_fold=fold,
        dataset="test",
        folds=cst.FOLDS,
    )
else:
    # The cropped dataset is given a model and a device in addition to the
    # folders, so that model is loaded first.
    model_name = "fish_CE_Fold_4_Epoch_95_MaxEpochs_600_Adam_LR_0.0001.pth"
    model_path = os.path.join(cst.MODEL, model_name)

    model = utils.load_model(model_path)
    model.to(DEVICE)

    testing_set = datasets.ZebrafishDataset_KFold_crop_head(
        image_folder,
        mask_folder,
        actual_fold=fold,
        model=model,
        device=DEVICE,
        dataset="test",
        folds=cst.FOLDS,
    )

# One sample per batch: the evaluation functions score images one by one.
testing_loader = torch.utils.data.DataLoader(
    testing_set,
    batch_size=1,
    shuffle=True,
    num_workers=cst.WORKERS,
)

In [None]:
def predict_img(model, image, device, transform, out_threshold=0.5):
    """Return a boolean mask predicted by ``model`` for ``image``.

    The image is moved to ``device`` and fed to ``model``; the logits are
    passed through ``transform``, a softmax is taken over dimension 1, and
    channel 1 of the result is compared with ``out_threshold``.

    Args:
        model: Callable mapping the input tensor to a logits tensor.
        image: Input tensor; it is moved to ``device`` before the forward
            pass.
        device: Target ``torch.device`` for the forward pass.
        transform: Callable applied to the logits before the softmax.
        out_threshold: A pixel is ``True`` when its channel-1 probability is
            strictly greater than this value.

    Returns:
        A NumPy boolean array holding the thresholded channel-1 probabilities.
    """
    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = transform(model(image.to(device)))
        # `nn` is never imported in this file, so the previous
        # `nn.Softmax(dim=1)(logits)` raised NameError; torch.softmax is the
        # functional equivalent and only needs the `torch` import.
        probabilities = torch.softmax(logits, dim=1)
        # assumes logits of shape (1, C, H, W) with C >= 2, so that dropping
        # the batch axis leaves (C, H, W) -- TODO confirm against the model
        foreground = probabilities.detach().cpu().squeeze(0).numpy()[1, :, :]
        return foreground > out_threshold
    
# Evaluate the performances for the binary model
def evaluate(eval_model, testing_loader, threshold):
    """Score ``eval_model`` on every sample of ``testing_loader``.

    Each image is resized and padded, predicted with ``utils.predict_img``
    at the given threshold, and the prediction is compared with the mask
    using the four functions from ``metrics``.

    Args:
        eval_model: Model to evaluate; it is switched to eval mode.
        testing_loader: Iterable yielding ``(image, mask, name)`` triples.
        threshold: Forwarded to ``utils.predict_img`` as ``out_threshold``.

    Returns:
        Four lists ``(precisions, recalls, F1s, IOUs)`` with one entry per
        sample, in iteration order.
    """
    SIZE = (384, 512)

    # Input side: resize to SIZE, then pad 64 pixels on the top and on the
    # bottom (torchvision's 4-tuple order is left, top, right, bottom).
    transform = transforms.Compose([transforms.Resize(SIZE),
                                    transforms.Pad((0, 64, 0, 64))])
    # Output side: crop the padding away again, then resize to 1932 x 2576
    # (presumably the resolution of the masks -- TODO confirm).
    untransform = transforms.Compose([transforms.CenterCrop(SIZE),
                                      transforms.Resize((1932, 2576))])

    precisions, recalls, F1s, IOUs = [], [], [], []

    eval_model.eval()
    # The sample name is yielded by the loader but not needed for scoring.
    for image, mask, _name in testing_loader:
        prediction = utils.predict_img(eval_model, transform(image), DEVICE,
                                       untransform, out_threshold=threshold)
        pred = torch.from_numpy(prediction)

        precisions.append(metrics.precision(pred, mask))
        recalls.append(metrics.recall(pred, mask))
        F1s.append(metrics.F1Score(pred, mask))
        IOUs.append(metrics.IOUScore(pred, mask))
    return precisions, recalls, F1s, IOUs

# Evaluate the performances for the cropped model
def evaluate_crop(eval_model, testing_loader, threshold):
    """Score the cropped ``eval_model`` on every sample of ``testing_loader``.

    For each sample a transform is built from the sample's ``size`` and is
    applied both to the prediction (through ``utils.predict_img``) and to the
    mask before the four functions from ``metrics`` are evaluated.

    Args:
        eval_model: Model to evaluate; it is switched to eval mode.
        testing_loader: Iterable yielding ``(image, mask, name, size)``
            tuples, where ``size[0]`` and ``size[1]`` are read as the
            horizontal and vertical lengths of the crop.
        threshold: Forwarded to ``utils.predict_img`` as ``out_threshold``.

    Returns:
        Four lists ``(precisions, recalls, F1s, IOUs)`` with one entry per
        sample, in iteration order.
    """
    precisions, recalls, F1s, IOUs = [], [], [], []

    eval_model.eval()
    # The sample name is yielded by the loader but not needed for scoring.
    for image, mask, _name, size in testing_loader:
        h_length = int(size[0])
        v_length = int(size[1])

        # Resize to a square whose side is the longer edge; when the crop is
        # not square, centre-crop the shorter edge back to (v_length, h_length).
        side = max(h_length, v_length)
        steps = [transforms.Resize((side, side))]
        if h_length != v_length:
            steps.append(transforms.CenterCrop((v_length, h_length)))
        untr = transforms.Compose(steps)

        prediction = utils.predict_img(eval_model, image, DEVICE, untr,
                                       out_threshold=threshold)
        pred = torch.from_numpy(prediction)
        # The mask goes through the same transform as the prediction.
        mask = untr(mask)

        precisions.append(metrics.precision(pred, mask))
        recalls.append(metrics.recall(pred, mask))
        F1s.append(metrics.F1Score(pred, mask))
        IOUs.append(metrics.IOUScore(pred, mask))
    return precisions, recalls, F1s, IOUs

In [None]:
start = time.time()

# The checkpoint directory and the evaluation routine depend only on the
# configuration, so they are resolved once instead of inside the loops.
m_dir = term + ("_cropped" if crop else "_normal")
evaluate_fn = evaluate_crop if crop else evaluate

# per_model_scores[m][t] == (precision, recall, F1, IOU), each already
# averaged over the test samples, for checkpoint m at threshold t.
# Each checkpoint is loaded once and swept over all thresholds, rather than
# being reloaded from disk for every threshold value.
per_model_scores = []
for model_name in models:
    model_path = os.path.join(cst.DIR, m_dir, model_name)
    model = utils.load_model(model_path)
    model.to(DEVICE)

    model_scores = []
    for threshold in thresholds:
        p, r, f, i = evaluate_fn(model, testing_loader, threshold)
        model_scores.append((np.mean(p), np.mean(r), np.mean(f), np.mean(i)))
    per_model_scores.append(model_scores)

# Average every metric over the checkpoints, one value per threshold.
precisions = []
recalls = []
F1s = []
IOUs = []
for scores_at_threshold in zip(*per_model_scores):
    m_precision, m_recall, m_f1, m_iou = zip(*scores_at_threshold)
    precisions.append(np.mean(m_precision))
    recalls.append(np.mean(m_recall))
    F1s.append(np.mean(m_f1))
    IOUs.append(np.mean(m_iou))

# Unweighted average of the four metrics at each threshold.
means = [(p + r + f + i) / 4
         for p, r, f, i in zip(precisions, recalls, F1s, IOUs)]


TERM = term
loss_name = "Dice"
# NOTE(review): this previously read the undefined name `model_dir`, which
# raised NameError after the whole sweep had run. `m_dir` is the only run
# directory defined in this notebook -- confirm that the figures are meant to
# be saved next to the checkpoints.
run_name = m_dir
dir_name = os.path.join(cst.DIR, run_name)

    
# Shared pieces of every figure: the title and the file-name prefix.
figure_title = "Term: " + TERM + ", Loss: " + loss_name
file_prefix = TERM + "_" + loss_name

# One row per curve: legend label, values, colour, y-axis label, file suffix.
curves = (
    ("average", means, "tab:blue", "Average", "_average_curves.jpg"),
    ("precision", precisions, "tab:orange", "Precision", "_precision_curves.jpg"),
    ("recall", recalls, "tab:green", "Recall", "_recall_curves.jpg"),
    ("F1", F1s, "tab:red", "F1 score", "_F1_curves.jpg"),
    ("IOU", IOUs, "tab:purple", "IOU", "_IOU_curves.jpg"),
)

# Plot - all metric curves on a single figure
for curve_label, curve_values, curve_colour, _, _ in curves:
    plt.plot(thresholds, curve_values, label=curve_label, color=curve_colour)
plt.ylabel("Metrics")
plt.xlabel("Threshold")
plt.title(figure_title)
plt.legend()
plt.savefig(os.path.join(dir_name, file_prefix + "_Metric_curves.jpg"))
plt.show()

# Plot - one figure per curve, with dashed guides marking its maximum
for curve_label, curve_values, curve_colour, axis_label, suffix in curves:
    best = np.argmax(curve_values)
    plt.plot(thresholds, curve_values, label=curve_label, color=curve_colour)
    plt.vlines(thresholds[best], 0, curve_values[best],
               colors="black", linestyles="dashed")
    plt.hlines(curve_values[best], 0, thresholds[best],
               colors="black", linestyles="dashed")
    plt.ylabel(axis_label)
    plt.xlabel("Threshold")
    plt.title(figure_title)
    plt.savefig(os.path.join(dir_name, file_prefix + suffix))
    plt.show()

# Criteria to maximise: wording used in the headline and the ranked values.
rankings = (
    ("average performances", means),
    ("precision", precisions),
    ("recall", recalls),
    ("F1", F1s),
    ("IOU", IOUs),
)
# Values reported under every headline, in print order.
report_rows = (
    ("Average", means),
    ("Precision", precisions),
    ("Recall", recalls),
    ("F1", F1s),
    ("IOUs", IOUs),
)

# For each criterion, find the best threshold and print all metrics there.
for criterion, ranked_values in rankings:
    best = np.argmax(ranked_values)
    print("The threshold providing the best " + criterion + " is: " + str(thresholds[best]))
    for row_name, row_values in report_rows:
        print(row_name + ": " + str(row_values[best]))
    print()

# Elapsed wall-clock time since `start`, split into minutes and seconds.
curr = time.time() - start
secondes = curr % 60
minutes = (curr - secondes) / 60
print("Computing time: " + str(minutes) + " minutes " + str(secondes))