# Goal
Once we have trained a model, we want to see how well it performs on the test set.
We would also like to compare performance between models; both are done in this notebook.

# 1. Imports

In [None]:
# Data tools
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import numpy as np
import json
import sys
sys.path.append(".")
import gc

# Data visualization
import cv2
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw

# Data loading and manipulation
from torch.utils.data import DataLoader
from packages.dataset import Dataset
import albumentations as albu
from packages.helpers import *

# Machine Learning model and training
import torch
import segmentation_models_pytorch as smp

# Start from a clean slate: collect unreachable Python objects, then release
# cached-but-unused CUDA memory (only meaningful when a GPU is present).
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# 2. Model and data
We are going to fetch a previously trained model from the "models" folder and use it to infer predictions on the test dataset.

### 2.1. Global variables

In [None]:
DATA_DIR = 'data'
CLASSES = ['solar panel']
ENCODER = 'se_resnext50_32x4d'
ENCODER_WEIGHTS = 'imagenet'
# Use the GPU when there is one, otherwise fall back to CPU so the notebook
# can still run (slowly) on a machine without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Label values of the two classes (0 = solar panel, 1 = background).
PANELS, BACKGROUND = 0, 1
MASK_VALUE = PANELS     # Class the model was trained on (PANELS or BACKGROUND); part of the checkpoint name
EPOCHS = 200            # Number of epochs the model was trained for; part of the checkpoint name

MODEL = 'unet'          # Architecture name; part of the checkpoint name

### 2.2. Load model

In [None]:
# Load the whole pickled model object written by the training notebook.
# `map_location` places the tensors on DEVICE regardless of the device they were
# saved from, so a GPU-trained checkpoint also loads on a CPU-only machine.
# NOTE(review): torch.load unpickles arbitrary objects - only load trusted files.
# On PyTorch >= 2.6 `weights_only` defaults to True; loading a full model object
# may then require passing `weights_only=False` explicitly.
best_model = torch.load(
    f'./models/best_model_{MODEL}_{MASK_VALUE}_{EPOCHS}.pth',
    map_location=DEVICE,
)

### 2.3. Set up test dataset

In [None]:
# Test images and their annotation masks live in sibling folders under DATA_DIR.
x_test_dir, y_test_dir = (
    os.path.join(DATA_DIR, subdir) for subdir in ('test', 'testannot')
)

# Input normalisation associated with the chosen encoder and its pretrained weights.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

# Test set with the same validation-time augmentation and preprocessing
# as used elsewhere for this model.
test_dataset = Dataset(
    x_test_dir,
    y_test_dir,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
    classes=CLASSES,
    mask_value=MASK_VALUE,
)

# DataLoader defaults: one sample per batch, dataset order (no shuffling).
test_dataloader = DataLoader(test_dataset)

# 3. Evaluate model on test set

### 3.1. Loss function and metrics

In [None]:
# Loss reported alongside the metrics on the test set.
# NOTE(review): SoftBCEWithLogitsLoss expects raw logits. If the trained model
# ends in a sigmoid activation, this value is computed on probabilities instead.
# Confirm against the training setup.
loss = smp.losses.SoftBCEWithLogitsLoss()
# The smp.utils epoch runners key their logs by `loss.__name__`; the newer
# smp.losses classes do not define it, so set it by hand.
loss.__name__ = "SoftBCEWithLogitsLoss"

# Metrics evaluated by the test epoch.
# NOTE(review): a threshold of 0.5 is applied to the raw model output here,
# while the per-image metrics further down use a different threshold.
metrics = [
    smp.utils.metrics.IoU(threshold=0.5),
    smp.utils.metrics.Fscore(),
    smp.utils.metrics.Accuracy(),
    smp.utils.metrics.Recall(),
    smp.utils.metrics.Precision(),
]

### 3.2. Run testing

In [None]:
# Evaluation-only pass over the test loader (no optimizer is involved). It
# reports the loss and the metrics defined above.
test_epoch = smp.utils.train.ValidEpoch(
    model=best_model,
    device=DEVICE,
    loss=loss,
    metrics=metrics,
)

logs = test_epoch.run(test_dataloader)

# 4. Extract metrics

In [None]:
# Loss and metric values returned by `test_epoch.run` above, shown as the cell output.
logs

In [None]:
def get_masks(test_dataset, threshold, model=None, device=None):
    """Predict a binary mask for one test sample and return it with its ground truth.

    Args:
        test_dataset: a single ``(image, gt_mask)`` sample such as ``test_dataset[n]``,
            not the whole dataset. The parameter name is kept for backward
            compatibility. ``image`` must be a NumPy array because it is
            passed to ``torch.from_numpy``.
        threshold: model-output values strictly greater than this become 1;
            all other values become 0.
        model: object exposing ``predict(tensor)``. Defaults to the global ``best_model``.
        device: device the input tensor is moved to. Defaults to the global ``DEVICE``.

    Returns:
        ``(gt_mask, pr_mask)``: the squeezed ground-truth mask and the binarised
        prediction as a NumPy array with the same dtype as the model output.
    """
    if model is None:
        model = best_model
    if device is None:
        device = DEVICE

    image, gt_mask = test_dataset
    gt_mask = gt_mask.squeeze()

    # Add a leading batch dimension of size 1 for the model.
    x_tensor = torch.from_numpy(image).to(device).unsqueeze(0)
    pr_mask = model.predict(x_tensor).squeeze().cpu().numpy()

    # Binarise with a single comparison. Two sequential in-place assignments
    # (`<= threshold -> 0`, then `> threshold -> 1`) would turn the freshly
    # written zeros into ones whenever threshold < 0 (e.g. thresholding logits).
    pr_mask = (pr_mask > threshold).astype(pr_mask.dtype)

    return gt_mask, pr_mask
    
    

# 5. Visualize predictions

In [None]:
# Same test images and masks, but constructed without the augmentation and
# preprocessing arguments, so the raw images can be displayed next to the masks.
test_dataset_vis = Dataset(
    x_test_dir,
    y_test_dir,
    mask_value=MASK_VALUE,
    classes=CLASSES,
)

In [None]:
def custom_metrics(truth, pre, positive_value=None):
    """Pixel-wise precision and recall of a binary prediction against ground truth.

    Args:
        truth: ground-truth mask as a NumPy array.
        pre: predicted binary mask as a NumPy array of the same shape.
        positive_value: mask value counted as the positive class. The negative
            class is ``1 - positive_value``. Defaults to ``int(not MASK_VALUE)``,
            i.e. 1 when MASK_VALUE == 0 and 0 when MASK_VALUE == 1.

    Returns:
        ``(precision, recall)`` as 0-d float tensors. A score is NaN when its
        denominator is zero: no predicted positives for precision, no
        ground-truth positives for recall. Callers filter these out.
    """
    if positive_value is None:
        positive_value = int(not MASK_VALUE)
    negative_value = 1 - positive_value

    gt = torch.from_numpy(truth)
    pred = torch.from_numpy(pre)

    pred_pos, pred_neg = pred == positive_value, pred == negative_value
    gt_pos, gt_neg = gt == positive_value, gt == negative_value

    tp = torch.sum(pred_pos & gt_pos)   # predicted positive, truly positive
    fp = torch.sum(pred_pos & gt_neg)   # predicted positive, truly negative
    fn = torch.sum(pred_neg & gt_pos)   # predicted negative, truly positive

    precision = tp / (tp + fp)
    recall = tp / (tp + fn)

    return precision, recall


In [None]:
# Label values: 0 = solar panel (PANELS), 1 = background (BACKGROUND).
# Binarisation threshold passed to `get_masks`. It depends on the class the
# model was trained on and is reused by the metrics loop in section 6.
# NOTE(review): the values look hand-tuned; confirm they match the scale of the
# model output (probabilities vs. logits).
if MASK_VALUE == BACKGROUND:
    threshold = 0.8
else:
    threshold = 0.1

# Draw a reproducible sample of distinct test images to display.
N_VIS_SAMPLES = 5
VIS_SEED = 0
vis_rng = np.random.default_rng(VIS_SEED)
vis_indices = vis_rng.choice(
    len(test_dataset),
    size=min(N_VIS_SAMPLES, len(test_dataset)),
    replace=False,
)

for n in vis_indices:
    n = int(n)  # plain int index for the Dataset

    image_vis = test_dataset_vis[n][0].astype('uint8')

    gt_mask, pr_mask = get_masks(test_dataset[n], threshold=threshold)
    precision, recall = custom_metrics(gt_mask, pr_mask)

    visualize(
        image=image_vis,
        ground_truth_mask=gt_mask,
        predicted_mask=pr_mask
    )
    print(f'precision: {precision}, recall: {recall}')



# 6. Metrics


In [None]:
# Per-image precision and recall over the whole test set, using the `threshold`
# set in the section 5 cell. Images whose score is NaN (zero denominator in
# `custom_metrics`) are left out of the corresponding list.
precision_array, recall_array = [], []

for idx in range(len(test_dataset)):
    gt_mask, pr_mask = get_masks(test_dataset[idx], threshold=threshold)
    precision, recall = custom_metrics(gt_mask, pr_mask)

    for score, bucket in ((precision, precision_array), (recall, recall_array)):
        if not score.isnan():
            bucket.append(float(score))


In [None]:
# Persist the per-image scores so that different models/runs can be compared later.
metrics_path = f'./metrics/test_metrics_{MODEL}_{MASK_VALUE}_{EPOCHS}.json'
# Create the output folder on first use instead of failing with FileNotFoundError.
os.makedirs(os.path.dirname(metrics_path), exist_ok=True)
with open(metrics_path, 'w') as f:
    json.dump({'precision': precision_array, 'recall': recall_array}, f)

In [None]:
# Mean of the per-image scores collected above (NaN scores were already excluded).
mean_precision = np.mean(precision_array)
print(f'precision: {mean_precision}')
mean_recall = np.mean(recall_array)
print(f'recall: {mean_recall}')
