# Notebook to evaluate model performance
Looking at the performance of a model on the test images. 

**IMPORTANT!!** 
Remember to first download the models running the `download_extract_models.py` file. 
**IMPORTANT!!** 

In [22]:
from sklearn.metrics import confusion_matrix
from torchvision import transforms
import matplotlib.pyplot as plt
from skimage.io import imread
import numpy as np
import torch
import sys 
sys.path.append('../src/') 
from data_process import *
from u2net import *
from unet import *
from resnest import *

## Helper functions
Functions displaying the results or performing different kinds of calculations in order to evaluate the models.

In [23]:
def display_results(model, device, img=1):
    """Function displaying the predictions of a model on a given image from the test set.

    Args:
        model (nn.Module): the model being evaluated
        img (int, optional): Can take any number in range 1-10. Defaults to 1.
    """
    path_to_arr = "../data/arrays/image_{img}.npy"
    path_to_img = "../data/images/image_{img}.png"
    mask_arr = np.load(path_to_arr)[:,:,3]
    
    ref = imread(path_to_img)
    img = resize_2_256(ref)

    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                                         std=[0.229, 0.224, 0.225])])
    img = transform(img)
    img = img.to(device)
    img = img.unsqueeze(0)
    
    if model.name == "U2NET":
        pred, _,_,_,_,_,_ = model(img)
    else:
        pred = model(img) # Rest

    pred_np = pred.cpu().detach().numpy().copy()  # output_np1.shape = (1, 10, 256, 256)
    pred_np = (np.argmax(pred_np, axis=1) * 10).astype(np.uint8)
    
    plt.figure(figsize=(15, 10))
    plt.subplot(1, 3, 1)
    plt.title("Picture")
    plt.imshow(ref)
    plt.subplot(1, 3, 2)
    plt.title("Ground truth")
    plt.imshow(label_2_colour(mask_arr))
    plt.subplot(1, 3, 3)
    plt.title("Predicted - test")
    plt.imshow(label_2_colour(pred_np[0]))

In [14]:
# Make the evaluation functions
def con_matrix(ref_mask, pred_mask):
    """Computes the confusion matrix between a reference mask and a predicted mask.

    Args:
        ref_mask (ndarray): the reference mask
        pred_mask (ndarray): the predicted mask

    Returns:
        ndarray: the confusion matrix
    """
    true = ref_mask.reshape(-1)
    pred = pred_mask.reshape(-1)
    hist = confusion_matrix(true, pred)
    
    return hist

def pixel_accuracy(hist):
    """Calculates the pr pixel accuracy."""
    acc = np.diag(hist).sum() / hist.sum()
    return acc

def mean_pixel_accuracy(hist):
    """Calculates the mean pixel accuracy."""
    acc = np.diag(hist) /  hist.sum(1)
    mean_Acc = np.nanmean(acc)
    return mean_Acc

def miou(hist):
    """Calculates the mean IOU."""
    sum = np.mean(np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)))
    return sum

def evaluate(model, device, test_array):
    """Function to evaluate a model on the test set. Return the mean IOU, mean pixel accuracy and pixel accuracy.

    Args:
        model (nn.Module): a trained model (U-Net, U^2-Net, ResNeSt)
        device (torch.device): device to run the model on
        test_array (np.array): numpy array holding the test dataset
    """
    model.eval()
    sum_miou = 0
    sum_mean_acc = 0
    sum_acc = 0
    for array in test_array:
        mask = array[:,:,3]
        imgA = array[:,:,0:3]

        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                            std=[0.229, 0.224, 0.225])])
        imgA = transform(imgA)
        imgA = imgA.to(device)      
        imgA = imgA.unsqueeze(0)
        
        if model.name == "U2NET":
            prediction, _,_,_,_,_,_ = model(imgA)
        else:
            prediction = model(imgA)

        pred_np = prediction.cpu().detach().numpy().copy()  
        pred_np = (np.argmax(pred_np, axis=1) * 10).astype(np.uint8)
        
        hist = con_matrix(mask, pred_np[0])
        sum_miou = miou(hist) + sum_miou
        sum_mean_acc = mean_pixel_accuracy(hist) + sum_mean_acc
        sum_acc = pixel_accuracy(hist) + sum_acc
    
    mean_miou = sum_miou / test_array.shape[0]
    mean_acc = sum_mean_acc / test_array.shape[0]
    acc = sum_acc / test_array.shape[0]
    
    return mean_miou, mean_acc, acc

# Evaluate the model

In [20]:
# Load the test dataset
test_array = np.load("../data/test.npy")

# Specify device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Specify which trained model to load by commenting out everyone except the one you want to evaluate
modelname = "unet" 
# modelname = "resnest50" 
# modelname = "resnest101" 
# modelname = "resnest200" 
# modelname = "u2net" 

# Load the model
model = torch.load(f"../trained_models/{modelname}.pt", map_location=device)

# Evaluate
iou, mean_acc, acc = evaluate(model, device, test_array)
print(f"Mean Pixel Accuracy: {round(mean_acc, 3)}")
print(f"Mean IOU: {round(iou, 3)}")
print(f"Pixel Accuracy: {round(acc, 3)}")

# Display the performance of the model on one of the test images
display_results(model, device, img=1)

  acc = np.diag(hist) /  hist.sum(1)
  acc = np.diag(hist) /  hist.sum(1)
  acc = np.diag(hist) /  hist.sum(1)
  acc = np.diag(hist) /  hist.sum(1)
  acc = np.diag(hist) /  hist.sum(1)


Mean Pixel Accuracy: 0.506
Mean IOU: 0.391
Pixel Accuracy: 0.874
