# Notebook to evaluate model performance
This notebook evaluates a trained segmentation model on the 30 test images and reports the mean IoU, the mean pixel accuracy and the overall pixel accuracy.

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torchvision import transforms
import numpy as np
from skimage.io import imread
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import glob
import sys 
sys.path.append('../src/')
from data_process import *
from u2net import *
from unet import *
from resnest import *

## Helper functions

In [38]:
# Make the evaluation functions
def CON_matrix(ref_mask, pred_mask):
    """Build the confusion matrix of a predicted mask against its reference.

    Both masks are flattened to 1-D label vectors before they are compared,
    so they must contain the same number of elements.

    Args:
        ref_mask (ndarray): the ground-truth label mask
        pred_mask (ndarray): the predicted label mask

    Returns:
        ndarray: the confusion matrix produced by sklearn's ``confusion_matrix``
            (reference labels along the rows, predicted labels along the columns)
    """
    return confusion_matrix(ref_mask.reshape(-1), pred_mask.reshape(-1))

def Pixel_Accuracy(hist):
    """Calculates the overall pixel accuracy from a confusion matrix.

    Args:
        hist (ndarray): square confusion matrix

    Returns:
        float: fraction of pixels on the diagonal, i.e. correctly classified
    """
    # acc = (TP + TN) / (TP + TN + FP + FN)
    n_correct = np.diag(hist).sum()
    n_total = hist.sum()
    return n_correct / n_total

def Mean_Pixel_Accuracy(hist):
    """Calculates the mean per-class pixel accuracy from a confusion matrix.

    The accuracy of each class is its diagonal entry divided by its row sum.
    Classes whose row sum is zero produce NaN and are excluded from the mean.

    Args:
        hist (ndarray): square confusion matrix; rows are assumed to index the
            reference labels, as in the matrix built by ``CON_matrix``

    Returns:
        float: mean of the per-class accuracies over the classes with a
            non-empty row
    """
    # per-class acc = TP / (TP + FN)
    # An empty row gives 0/0; silence the resulting RuntimeWarning because the
    # NaN is deliberately dropped by nanmean below.
    with np.errstate(divide="ignore", invalid="ignore"):
        per_class_acc = np.diag(hist) / hist.sum(1)
    return np.nanmean(per_class_acc)

def MIOU(hist):
    """Calculates the mean IOU (intersection over union) from a confusion matrix.

    For each class the intersection is the diagonal entry and the union is
    row sum + column sum - diagonal. Classes with an empty union produce NaN
    and are excluded from the mean instead of turning the whole score into NaN.

    Args:
        hist (ndarray): square confusion matrix

    Returns:
        float: mean of the per-class IOU values over the classes with a
            non-empty union
    """
    intersection = np.diag(hist)
    union = hist.sum(1) + hist.sum(0) - intersection
    # An empty union gives 0/0; the NaN is dropped by nanmean, so the warning
    # carries no information.
    with np.errstate(divide="ignore", invalid="ignore"):
        iou = intersection / union
    return np.nanmean(iou)

def evaluate(model, device, test_array):
    """Function to evaluate a model on the test set. Return the mean IOU, mean pixel accuracy and pixel accuracy.

    Each metric is computed per image from that image's confusion matrix and
    then averaged over all images in ``test_array``.

    Args:
        model (nn.Module): a trained model (U-Net, U^2-Net, ResNeSt); must expose a
            ``name`` attribute, which is used to detect the multi-output U^2-Net
        device (torch.device): device to run the model on
        test_array (np.array): numpy array holding the test dataset; every entry is
            indexed as ``[:, :, 0:3]`` for the image and ``[:, :, 3]`` for the mask

    Returns:
        tuple: (mean IOU, mean pixel accuracy, pixel accuracy), each averaged
            over the images
    """
    model.eval()

    # The preprocessing is the same for every image, so build it once.
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                         std=[0.229, 0.224, 0.225])])

    sum_miou = 0
    sum_mean_acc = 0
    sum_acc = 0
    # Inference only: without autograd no graph is kept for each forward pass.
    with torch.no_grad():
        for array in test_array:
            mask = array[:, :, 3]
            imgA = array[:, :, 0:3]

            imgA = transform(imgA)
            imgA = imgA.to(device)
            imgA = imgA.unsqueeze(0)  # add the batch dimension

            if model.name == "U2NET":
                # U^2-Net returns seven outputs; only the first one is scored.
                prediction, _, _, _, _, _, _ = model(imgA)
            else:
                prediction = model(imgA)

            # original note: pred_np.shape = (1, 10, 256, 256) -- TODO confirm for other models
            pred_np = prediction.cpu().detach().numpy().copy()
            # NOTE(review): the factor 10 presumably maps class indices onto the
            # label values stored in the mask (class index * 10) -- confirm.
            pred_np = (np.argmax(pred_np, axis=1) * 10).astype(np.uint8)
            hist = CON_matrix(mask, pred_np[0])
            sum_miou = MIOU(hist) + sum_miou
            sum_mean_acc = Mean_Pixel_Accuracy(hist) + sum_mean_acc
            sum_acc = Pixel_Accuracy(hist) + sum_acc

    n_images = test_array.shape[0]
    mean_miou = sum_miou / n_images
    mean_acc = sum_mean_acc / n_images
    acc = sum_acc / n_images

    return mean_miou, mean_acc, acc

# Evaluate the model

In [52]:
# Load the test dataset. evaluate() reads each entry as an H x W x 4 array:
# channels 0-2 are the image and channel 3 is the ground-truth mask.
# test_array = np.load("../data/test3_arr.npy")   # path to original ground truth
test_array = np.load("../data/good_test_arr.npy") # path where bad ground truth is removed

# Load the model to be evaluated
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a trusted source.
# NOTE(review): the loaded object is passed straight to evaluate(), so the file
# presumably stores the whole model rather than a state_dict -- confirm.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.load("../trained_models/resnest101/ade1-4_d05_b6_e100/ResNeSt101_50_loss_trian_0.0183_val_0.01608.pt", map_location=device)
# Evaluate: each metric is averaged over the images in test_array
miou, mean_acc, acc = evaluate(model, device, test_array)
print(f"Mean IOU: {miou}")
print(f"Mean Pixel Accuracy: {mean_acc}")
print(f"Pixel Accuracy: {acc}")

  acc = np.diag(hist) /  hist.sum(1)
  acc = np.diag(hist) /  hist.sum(1)
  acc = np.diag(hist) /  hist.sum(1)
  acc = np.diag(hist) /  hist.sum(1)
  acc = np.diag(hist) /  hist.sum(1)
  acc = np.diag(hist) /  hist.sum(1)
  acc = np.diag(hist) /  hist.sum(1)
  acc = np.diag(hist) /  hist.sum(1)
  acc = np.diag(hist) /  hist.sum(1)


Mean IOU: 0.4855293261939275
Mean Pixel Accuracy: 0.6210484732327283
Pixel Accuracy: 0.8663412729899088


## Best so far
**Model**: resnest50/ad1e-4_d08_b3_e250/ResNeSt50_130_loss_trian_0.02333_val_0.01631.pt

Results on the cleaned test set (`good_test_arr.npy`, bad ground truth removed):
```
Mean IOU: 0.5111849011853413
Mean Pixel Accuracy: 0.6279514596907317
Pixel Accuracy: 0.8778921763102213
```