# Evaluation


## Imports


In [1]:
# Import external libraries
import torch
import numpy as np
import pandas as pd

from PIL import Image
from pathlib import Path
from torchinfo import summary

In [2]:
# Import internal libraries
from melanoma_classification.model import get_dermmel_classifier_v1
from utils.dermmel import DermMel
from melanoma_classification.utils import (
    production_transform, 
    get_device,
    visualize_single_attention,
    visualize_multihead_as_single_attention,
    visualize_multihead_attention,
)
from evaluation.evaluator import (
    visualize_loss,
    visualize_f1_precision_recall,
    visualize_accuracy,
    create_evaluation_report,
    visualize_confusion_matrix,
    visualize_model_confidence,
)

## Preparations


In [None]:
# Init device
# `get_device` comes from melanoma_classification.utils; the value it returns
# is reused below for checkpoint loading, the model summary, evaluation and
# single-image inference.
device = get_device()
print(f"Using device: {device}")

In [4]:
# Locations used throughout the notebook (all relative to the working dir)
# - figures written by the visualisation helpers / the sample input image
figure_path = Path("evaluation", "images")
# - training run whose checkpoints and metrics are evaluated here
checkpoint_base_path = Path("checkpoints", "dermmel_orig_image_test")
training_metrics_filename = "metrics.csv"
# - destination for the exported production weights; created if missing
#   (its parent directory must already exist)
final_model_path = Path("..", "src", "melanoma_classification", "weights")
final_model_path.mkdir(exist_ok=True)


### Create & read in model from checkpoints


In [None]:
# Load the epoch-20 checkpoint of the training run; `map_location` remaps the
# stored tensors onto the device selected above. The next cell reads the
# "model_state_dict" entry from the loaded object.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider `weights_only=True` if the checkpoint contents and
# the installed torch version allow it).
checkpoint_path = checkpoint_base_path / "checkpoint_epoch_20.pth"
checkpoint = torch.load(checkpoint_path, map_location=device)

In [None]:
# Build the classifier and restore the trained weights from the checkpoint
vit = get_dermmel_classifier_v1()
vit.load_state_dict(checkpoint["model_state_dict"])

# Place the model on the target device explicitly and switch it to inference
# mode, so the evaluation cells below do not depend on `summary()` moving the
# model as a side effect or on a later cell calling `eval()`.
vit.to(device)
vit.eval()

# Show the architecture summary for a single 3x224x224 input
summary(vit, input_size=(1, 3, 224, 224), device=device)

### Create test dataset & dataloader


In [None]:
# Test split of the DermMel data, preprocessed with the production transform
test_dataset = DermMel(
    "../data",
    split="test",
    transform=production_transform(),
)

# Iterate the test set one sample at a time, in dataset order, in-process
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
)

### Read in training metrics


In [None]:
# Read the metrics CSV stored next to the checkpoints of the training run and
# preview its first rows.
training_metrics = pd.read_csv(checkpoint_base_path / training_metrics_filename)
training_metrics.head()

## Training analysis


In [None]:
# Plot the loss from the training metrics; the second argument is a file path
# under `figure_path` (presumably where the figure is saved — confirm in
# evaluation.evaluator).
visualize_loss(training_metrics, figure_path / "loss.png")

In [None]:
# Plot F1 / precision / recall from the training metrics; the second argument
# is a file path under `figure_path` (presumably the save location — confirm
# in evaluation.evaluator).
visualize_f1_precision_recall(
    training_metrics, 
    figure_path / "f1_precision_recall.png"
)

In [None]:
# Plot the accuracy from the training metrics; the second argument is a file
# path under `figure_path` (presumably the save location — confirm in
# evaluation.evaluator).
visualize_accuracy(training_metrics, figure_path / "accuracy.png")

## Test model


In [12]:
# Run the restored model over the test dataloader and collect the evaluation
# report that the confusion-matrix and confidence plots below consume. The
# class names come from the test dataset.
evaluation_report = create_evaluation_report(
    vit, test_dataloader, test_dataset.classes, device
)

In [None]:
# Confusion matrix of the test-set evaluation, labelled with the dataset's
# class names; the last argument is a file path under `figure_path`
# (presumably the save location — confirm in evaluation.evaluator).
visualize_confusion_matrix(
    evaluation_report,
    test_dataset.classes,
    figure_path / "confusion_matrix.png"
)

In [None]:
# Plot the model's confidences from the test-set evaluation report; the second
# argument is a file path under `figure_path`.
visualize_model_confidence(
    evaluation_report,
    figure_path / "model_confidences.png",
)

## Save only the model weights


In [15]:
# Save the model to production folder
# Only the state dict (parameters/buffers) is written — not the full training
# checkpoint loaded earlier — as "vit.pth" inside `final_model_path`.
torch.save(vit.state_dict(), final_model_path / "vit.pth")

## Infer unseen image & visualize attention maps


Image source: [Wikipedia](https://en.wikipedia.org/wiki/Melanoma) 

In [None]:
# Load the sample image and keep the untouched RGB version for the overlays
img_path = figure_path / "Melanoma.jpg"
raw_image = Image.open(img_path).convert("RGB")

# Apply the production preprocessing (called with an `image=` keyword and
# returning a mapping with an "image" entry), then move the result to the
# device and prepend a batch dimension.
image = production_transform()(image=np.array(raw_image))["image"]
image = image.to(device).unsqueeze(0)

In [None]:
# Run the classifier on the single preprocessed image
vit.eval()
with torch.no_grad():
    model_outputs = vit(image)
    # Keep the attention maps for the visualisation cell below
    attention = model_outputs["attentions"]
    # Softmax along dim 1 turns the raw outputs into probabilities; note that
    # `logits` holds probabilities from here on.
    logits = torch.softmax(model_outputs["outputs"], dim=1)
    # Highest probability and its index, unpacked into plain Python numbers
    confidence, prediction = (t.item() for t in logits.max(dim=1))

# Translate the predicted index into its class label via the model's class map
detected = vit.class_map[prediction]
print(f"Found a {detected} sample with confidence {confidence*100:.2f}%.")

In [None]:
# Visualize the attention maps
#
# Output files are written next to the other figures in `figure_path`.
# (`img_path` points at the input file "Melanoma.jpg", so it must not be used
# as the parent directory of the output files.)

# Choose layer to visualize
layer = -1

# Visualize single attention map averaged over all heads and layers
visualize_single_attention(
    raw_image,
    attention,
    img_path=figure_path / "single_attention.png"
)

# Visualize multihead attention maps as a single attention map for a specific
# layer
visualize_multihead_as_single_attention(
    raw_image,
    attention,
    layer=layer,
    img_path=figure_path / f"multihead_as_single_attention{layer}.png"
)

# Visualize multihead attention maps for a specific layer separately for each
# head
# NOTE(review): `layer` is only used in the file name here and is not passed
# to the function — confirm whether visualize_multihead_attention accepts a
# `layer` argument and, if so, pass it so the plot matches the file name.
visualize_multihead_attention(
    raw_image, 
    attention, 
    img_path=figure_path / f"multihead_attention{layer}.png"
)