# Evaluate Models

## Imports

In [None]:
import os.path

import matplotlib.pyplot as plt

import numpy as np

import torch
from torch.utils.data import DataLoader

from torchvision import transforms

import yaml

from model import RandomlyConnectedModel

from evaluation.hamlyn import evaluate_ssim
from evaluation.scared import evaluate_keyframes
from evaluation.utils import prepare_state_dict
from evaluation import sparsification as s
from evaluation import transforms as t

from loaders.hamlyn import HamlynDataset
from loaders.scared import SCAREDKeyframesLoader

## Configuration

In [None]:
# Checkpoint file name; the same name is looked up under models/hamlyn and
# models/scared in the setup cells below.
model_name = 'test.pt'

# Number of samples per DataLoader batch.
batch_size = 8

# Root directories of the two evaluation datasets.
scared_path = '../scared'
hamlyn_path = '../da-vinci'

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The parsed contents of config.yml are forwarded verbatim as keyword
# arguments to the model constructor.
with open('config.yml') as f:
    config = yaml.load(f, Loader=yaml.Loader)

model = RandomlyConnectedModel(**config)
model = model.to(device)

# Results are grouped per checkpoint: results/<checkpoint name without extension>.
model_stem = os.path.splitext(model_name)[0]
model_save_to = os.path.join('results', model_stem)

# Per-dataset preprocessing: resize to the given size, then convert to tensors.
hamlyn_transform = transforms.Compose([t.ResizeImage((256, 512)), t.ToTensor()])
scared_transform = transforms.Compose([t.ResizeImage((1024, 1280)), t.ToTensor()])

## Hamlyn SSIM/Sparsification

### Setup

In [None]:
# Checkpoint trained on Hamlyn, and the directory that receives this
# model's Hamlyn evaluation output.
hamlyn_model_path = os.path.join('models', 'hamlyn', model_name)
# os.path is a module and cannot be called; build the path with os.path.join.
hamlyn_save_to = os.path.join(model_save_to, 'hamlyn')

# The loaded object is a state dict (it is handed to prepare_state_dict and
# load_state_dict below), so it has no .to() method. map_location moves the
# stored tensors onto the target device while loading instead.
state_dict = torch.load(hamlyn_model_path, map_location=device)
state_dict = prepare_state_dict(state_dict)

model.load_state_dict(state_dict)

# Test split of the Hamlyn dataset, preprocessed with the Hamlyn transform.
hamlyn_dataset = HamlynDataset(hamlyn_path, 'test', hamlyn_transform)
hamlyn_dataloader = DataLoader(hamlyn_dataset, batch_size, shuffle=True)

### Run Hamlyn Evaluation

In [None]:
# Evaluate the model on the Hamlyn test loader. The call yields two
# collections: per-item SSIM values and sparsification data (each entry of
# the latter is later unpacked into predicted / oracle / random curves).
ssims, spars = evaluate_ssim(
    model,
    hamlyn_dataloader,
    hamlyn_save_to,
    device,
)

#### SSIM Metric

In [None]:
# Arithmetic mean of the SSIM values returned by the evaluation run.
ssim_total = sum(ssims)
mean_ssim = ssim_total / len(ssims)

print(f'Mean SSIM on Hamlyn test set: {mean_ssim}')

#### Sparsification Plot and Metrics

In [None]:
# Each entry of `spars` is a (predicted, oracle, random) triple of curves;
# transpose the list of triples into three parallel sequences.
pred_curves, oracle_curves, random_curves = zip(*spars)

# Average every curve type across all evaluated entries (axis 0).
pred_curve = np.array(pred_curves).mean(axis=0)
oracle_curve = np.array(oracle_curves).mean(axis=0)
random_curve = np.array(random_curves).mean(axis=0)

# Summary metrics from the averaged curves.
# NOTE(review): presumably area under the sparsification error (oracle vs.
# predicted) and area under the random gain (predicted vs. random); confirm
# against evaluation.sparsification.
ause = s.ause(oracle_curve, pred_curve)
aurg = s.aurg(pred_curve, random_curve)

# Two stacked axes in a single column. The original asked for zero columns,
# which matplotlib rejects and which could not be unpacked into two axes.
figure, (curve_axis, error_axis) = plt.subplots(2, 1)

## SCARED MAE

### Setup

In [None]:
# Checkpoint trained on SCARED, and the directory that receives this
# model's SCARED evaluation output.
scared_model_path = os.path.join('models', 'scared', model_name)
# os.path is a module and cannot be called; build the path with os.path.join.
scared_save_to = os.path.join(model_save_to, 'scared')

# The loaded object is a state dict (it is handed to prepare_state_dict and
# load_state_dict below), so it has no .to() method. map_location moves the
# stored tensors onto the target device while loading instead.
state_dict = torch.load(scared_model_path, map_location=device)
state_dict = prepare_state_dict(state_dict)

model.load_state_dict(state_dict)

# Keyframe loaders for SCARED test datasets 8 and 9. These must read from
# the SCARED root with the SCARED transform; the original passed the Hamlyn
# path and transform here, leaving scared_path and scared_transform unused.
scared_dataset_8 = SCAREDKeyframesLoader(scared_path, 'test', 8, scared_transform)
scared_dataloader_8 = DataLoader(scared_dataset_8, batch_size, shuffle=True)

scared_dataset_9 = SCAREDKeyframesLoader(scared_path, 'test', 9, scared_transform)
scared_dataloader_9 = DataLoader(scared_dataset_9, batch_size, shuffle=True)

### Run SCARED Evaluation

In [None]:
# Evaluate each SCARED test dataset in turn and report the mean of the
# per-item absolute errors returned by evaluate_keyframes.
# NOTE(review): both runs are given the same scared_save_to directory;
# confirm that evaluate_keyframes does not overwrite dataset 8's output
# when dataset 9 is processed.
scared_dataloaders = (
    (8, scared_dataloader_8),
    (9, scared_dataloader_9),
)

for dataset_number, dataloader in scared_dataloaders:
    maes = evaluate_keyframes(model, dataloader, scared_save_to, device)
    mean_mae = sum(maes) / len(maes)
    # The reported quantity is an error, so the label says so explicitly
    # (the original printed "Mean Absolute Depth").
    print(f'Mean Absolute Depth Error on SCARED Dataset {dataset_number}: '
          f'{mean_mae} mm')