In [None]:
from functools import partial
from pathlib import Path

import cv2
import torch
import matplotlib.pyplot as plt
import numpy as np
from pydantic import BaseModel, Field
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.ops import box_convert
from tqdm import tqdm

from ssd import SSD
from ssd.data import LetterboxTransform, SSDDataset
from ssd.structs import FrameLabels
from ssd.utils import TrainUtils
from ssd.utils.metrics_calculator import MetricsCalculator

In [None]:
class EvaluateConfig(BaseModel):
    """Settings for one evaluation run.

    Holds the location of the evaluation data and the three post-processing
    values that this notebook forwards, in declaration order, to
    ``model._post_process_detections``.

    Out-of-range values are rejected at construction time so that a typo
    (e.g. a confidence threshold of 10) fails loudly instead of silently
    producing empty or meaningless metrics.
    """

    # Directory holding the evaluation images
    images_dir: Path
    # Directory holding the matching label files
    labels_dir: Path
    # Confidence cut-off handed to post-processing; a score threshold only
    # makes sense inside [0, 1]
    min_confidence_threshold: float = Field(default=0.1, ge=0.0, le=1.0)
    # Top-k value handed to post-processing; must keep at least one candidate
    num_top_k: int = Field(default=100, ge=1)
    # IoU threshold handed to NMS; an IoU is always inside [0, 1]
    nms_iou_threshold: float = Field(default=0.2, ge=0.0, le=1.0)

### Define constants

In [None]:
# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# top-to-bottom on a machine without CUDA
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Root of the COCO dataset; the image and label folders share this prefix
DATASET_ROOT = Path("/mnt/data/datasets/object_detection/coco")
CONFIG = EvaluateConfig(
    images_dir=DATASET_ROOT / "images" / "val2017",
    labels_dir=DATASET_ROOT / "labels" / "val2017",
    min_confidence_threshold=0.1,
    num_top_k=200,
    nms_iou_threshold=0.2,
)

# Checkpoint of the training run to evaluate
MODEL_FILE = Path(
    "/mnt/data/code/ssd/runs/ea4e4832-f3ec-4aef-8a65-d4badf2bb9c8/best.pt"
)
# Network input size in pixels; also used to scale normalised boxes for drawing
IMAGE_WIDTH = 300
IMAGE_HEIGHT = 300
# dtype passed to the letterbox transform and the dataset
DTYPE = torch.float32

### Evaluate the model

In [None]:
# Load the SSD checkpoint stored at MODEL_FILE, passing DEVICE to the loader
# (presumably the device the weights are placed on — confirm against SSD.load)
model = SSD.load(MODEL_FILE, DEVICE)

In [None]:
# Collate function with the target device bound up front, so the DataLoader
# can call it with the batch alone
collate_func = partial(TrainUtils.batch_collate_func, device=DEVICE)

# Per-image pre-processing configured with the network input size and dtype
transform = LetterboxTransform(IMAGE_WIDTH, IMAGE_HEIGHT, DTYPE)

# Evaluation dataset; the fifth positional argument is deliberately None
# (presumably an optional extra transform — TODO confirm against SSDDataset)
dataset = SSDDataset(
    CONFIG.images_dir,
    CONFIG.labels_dir,
    model.num_classes,
    transform,
    None,
    DEVICE,
    DTYPE,
)

# Batches of 8, shuffled, so the visualisation cells see a random sample
data_loader = DataLoader(
    dataset=dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_func,
)

In [None]:
# Put the model in evaluation mode before running inference
model.eval()

# A fresh calculator on every run keeps this cell idempotent when re-executed
metrics_calculator = MetricsCalculator(model.num_classes)

images: Tensor
frame_labels: list[FrameLabels]
# Gradients are never needed here, so autograd is disabled for the whole pass
with torch.no_grad():
    for images, frame_labels in tqdm(data_loader, desc="Evaluating"):
        # Raw network outputs plus the anchors they are expressed against
        head_outputs, anchors = model.forward(images)
        # Turn the raw outputs into per-frame detections using the
        # thresholds from CONFIG
        frame_detections = model._post_process_detections(
            head_outputs,
            anchors,
            CONFIG.min_confidence_threshold,
            CONFIG.num_top_k,
            CONFIG.nms_iou_threshold,
        )

        # Accumulate this batch into the running metrics
        metrics_calculator.update(frame_detections, frame_labels)

### Analyse per-class metrics

In [None]:
# Class index used to slice the precision and recall arrays plotted in the cells below
CLASS_ID = 0

In [None]:
# Precision of the selected class as a function of the confidence threshold.
# After slicing, rows follow the confidence thresholds and columns the IoU
# thresholds (that is how the array is indexed in the loop below).
precisions = metrics_calculator.precisions()[:, :, CLASS_ID].cpu().numpy()

fig, ax = plt.subplots()
num_iou_thresholds = precisions.shape[1]
for k in range(num_iou_thresholds):
    # One curve per IoU threshold
    iou_threshold = metrics_calculator._iou_thresholds[k]
    ax.plot(
        metrics_calculator._confidence_thresholds,
        precisions[:, k],
        label=f"IoU thresh = {iou_threshold:.2f}",
    )
ax.set_title(f"Precision with confidence\nclass_id={CLASS_ID}")
ax.set_xlabel("Confidence threshold")
ax.set_ylabel("Precision")
ax.set_xlim((0, 1))
ax.set_ylim((0, 1))
ax.legend()
ax.grid()

In [None]:
# Recall of the selected class as a function of the confidence threshold.
# After slicing, rows follow the confidence thresholds and columns the IoU
# thresholds (that is how the array is indexed in the loop below).
recalls = metrics_calculator.recalls()[:, :, CLASS_ID]
recalls = recalls.cpu().numpy()

plt.figure()
# Iterate over this cell's own array so the plot does not depend on the
# precision cell having been executed first
for iou_idx in range(recalls.shape[1]):
    plt.plot(
        metrics_calculator._confidence_thresholds,
        recalls[:, iou_idx],
        label=f"IoU thresh = {metrics_calculator._iou_thresholds[iou_idx]:.2f}",
    )
plt.xlim((0, 1))
plt.ylim((0, 1))
plt.title(f"Recall with confidence\nclass_id={CLASS_ID}")
plt.xlabel("Confidence threshold")
plt.ylabel("Recall")
plt.legend()
plt.grid()

In [None]:
# Precision-recall curve of the selected class: one line per IoU threshold,
# built from the `precisions` and `recalls` arrays computed in the two cells above
fig, ax = plt.subplots()
num_iou_thresholds = precisions.shape[1]
for k in range(num_iou_thresholds):
    iou_threshold = metrics_calculator._iou_thresholds[k]
    ax.plot(
        recalls[:, k],
        precisions[:, k],
        label=f"IoU thresh = {iou_threshold:.2f}",
    )
ax.set_title(f"Precision-recall curve\nclass_id={CLASS_ID}")
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.set_xlim((0, 1))
ax.set_ylim((0, 1))
ax.legend()
ax.grid()

### Analyse overall accuracy

In [None]:
# Per-class values returned by the metrics calculator, moved to host memory
mAPs = metrics_calculator.mAPs().cpu().numpy()

plt.figure()
class_ids = list(range(model.num_classes))
plt.bar(class_ids, mAPs)
plt.title("mAP per class")
plt.xlabel("Class ID")
plt.ylabel("mAP@(50-95)")
plt.grid()
# Bars are centred on their integer class id, so pad the x-range by half a
# unit on each side; starting at 0 would cut the class-0 bar in half
plt.xlim((-0.5, model.num_classes - 0.5))
plt.ylim((0, 1))

In [None]:
# Average of the per-class values as a plain Python float
float(mAPs.mean())

### Visualise detections

In [None]:
# One random 3-channel colour per class, indexed by class id.
# A dedicated, fixed-seed generator keeps each class's colour stable across
# kernel restarts, and sizing the list from the model avoids an IndexError
# for models with a class count other than 80.
_colour_rng = np.random.default_rng(0)
class_colours = [
    # Plain Python ints: OpenCV drawing calls reject some NumPy scalar types
    tuple(int(channel) for channel in _colour_rng.integers(0, 256, size=3))
    for _ in range(model.num_classes)
]

In [None]:
# Draw one batch from the loader to visualise
images: Tensor
objects: list[FrameLabels]
batch = next(iter(data_loader))
images, objects = batch

# Post-processing settings, in the positional order the model expects them
post_process_args = (
    CONFIG.min_confidence_threshold,
    CONFIG.num_top_k,
    CONFIG.nms_iou_threshold,
)

# Same inference path as the evaluation loop, with autograd disabled
with torch.no_grad():
    head_outputs, anchors = model.forward(images)
    frame_detections = model._post_process_detections(
        head_outputs, anchors, *post_process_args
    )

In [None]:
# Display the ground-truth labels of the sampled batch
num_images = images.shape[0]
num_rows = 2
# Ceiling division so an odd batch size still gets a slot for every image
num_cols = max(1, (num_images + num_rows - 1) // num_rows)
# squeeze=False keeps `axes` two-dimensional even for a single-column grid
fig, axes = plt.subplots(num_rows, num_cols, squeeze=False)
fig.set_figwidth(16)
fig.set_figheight(8)
fig.suptitle("Ground-truth labels")
# `axes.flat` walks the grid row by row, i.e. slot idx sits at
# (idx // num_cols, idx % num_cols)
for idx, ax in enumerate(axes.flat):
    if idx >= num_images:
        # Grid slot with no image to show
        ax.axis("off")
        continue

    # Extract the image as a contiguous HWC uint8 array OpenCV can draw on
    # (assumes the tensor is channel-first with values in [0, 1] — TODO confirm)
    image = images[idx, ...].permute((1, 2, 0)).clone()
    image *= 255
    image = image.to(torch.uint8).cpu().numpy().copy()

    # Convert the labelled boxes to pixel-space corner coordinates
    # (assumes boxes are normalised cxcywh — TODO confirm)
    objs = objects[idx]
    boxes = box_convert(objs.boxes, "cxcywh", "xyxy")
    boxes[:, ::2] *= IMAGE_WIDTH
    boxes[:, 1::2] *= IMAGE_HEIGHT
    # .tolist() yields plain Python ints, which OpenCV accepts for points
    # regardless of version (NumPy scalars are rejected by some builds)
    boxes = boxes.cpu().to(torch.int).tolist()
    class_ids = objs.class_ids.cpu().to(torch.int).tolist()

    # Draw every labelled box with its class id
    for (x_min, y_min, x_max, y_max), class_id in zip(boxes, class_ids):
        colour = class_colours[class_id]
        image = cv2.rectangle(image, (x_min, y_min), (x_max, y_max), colour, 2)
        cv2.putText(
            image,
            f"c={class_id}",
            (x_min, y_min),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            colour,
            2,
        )

    ax.imshow(image)

In [None]:
# Display the model's detections for the sampled batch
num_images = images.shape[0]
num_rows = 2
# Ceiling division so an odd batch size still gets a slot for every image
num_cols = max(1, (num_images + num_rows - 1) // num_rows)
# squeeze=False keeps `axes` two-dimensional even for a single-column grid
fig, axes = plt.subplots(num_rows, num_cols, squeeze=False)
fig.set_figwidth(16)
fig.set_figheight(8)
fig.suptitle("Model detections")
# `axes.flat` walks the grid row by row, i.e. slot idx sits at
# (idx // num_cols, idx % num_cols)
for idx, ax in enumerate(axes.flat):
    if idx >= num_images:
        # Grid slot with no image to show
        ax.axis("off")
        continue

    # Extract the image as a contiguous HWC uint8 array OpenCV can draw on
    # (assumes the tensor is channel-first with values in [0, 1] — TODO confirm)
    image = images[idx, ...].permute((1, 2, 0)).clone()
    image *= 255
    image = image.to(torch.uint8).cpu().numpy().copy()

    # Convert the detected boxes to pixel-space corner coordinates
    # (assumes boxes are normalised cxcywh — TODO confirm)
    detections = frame_detections[idx]
    boxes = box_convert(detections.boxes, "cxcywh", "xyxy")
    boxes[:, ::2] *= IMAGE_WIDTH
    boxes[:, 1::2] *= IMAGE_HEIGHT
    # .tolist() yields plain Python ints, which OpenCV accepts for points
    # regardless of version (NumPy scalars are rejected by some builds)
    boxes = boxes.cpu().to(torch.int).tolist()
    class_ids = detections.class_ids.cpu().to(torch.int).tolist()

    # Draw every detected box with its class id
    for (x_min, y_min, x_max, y_max), class_id in zip(boxes, class_ids):
        colour = class_colours[class_id]
        image = cv2.rectangle(image, (x_min, y_min), (x_max, y_max), colour, 2)
        cv2.putText(
            image,
            f"c={class_id}",
            (x_min, y_min),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            colour,
            2,
        )

    ax.imshow(image)

In [None]:
# Echo the first image tensor of the batch drawn for the visualisation cells.
# NOTE(review): looks like a leftover inspection cell (presumably a check of the
# raw pixel values) — consider removing it or replacing it with a shape/range summary.
images[0]