In [None]:
from functools import partial
from pathlib import Path

import torch
import matplotlib.pyplot as plt
from pydantic import BaseModel, Field
from torch import Tensor
from torch.utils.data import DataLoader
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from tqdm import tqdm

from ssd import SSD
from ssd.data import LetterboxTransform, SSDDataset
from ssd.structs import FrameLabels
from ssd.utils import TrainUtils, Metrics

In [None]:
class EvaluateConfig(BaseModel):
    """Settings for evaluating a trained SSD checkpoint.

    The three post-processing fields are passed, in declaration order, to the
    model's detection post-processing step by the evaluation loop. Bounds are
    validated at construction time so that an out-of-range value fails fast
    instead of silently producing meaningless metrics.
    """

    # Directory containing the evaluation images.
    images_dir: Path
    # Directory containing the label files for the evaluation images.
    labels_dir: Path
    # Confidence threshold for post-processing; presumably lower-scoring detections are dropped.
    min_confidence_threshold: float = Field(default=0.1, ge=0.0, le=1.0)
    # Top-k limit for post-processing; presumably the max detections kept per image.
    num_top_k: int = Field(default=100, ge=1)
    # IoU threshold for non-maximum suppression during post-processing.
    nms_iou_threshold: float = Field(default=0.2, ge=0.0, le=1.0)

### Define constants

In [None]:
# Prefer the first GPU, but fall back to CPU so the notebook still runs without CUDA.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE(review): machine-specific absolute roots — adjust these for your environment.
COCO_ROOT = Path("/mnt/data/datasets/object_detection/coco")
MODELS_ROOT = Path("/mnt/data/code/ssd/models")

CONFIG = EvaluateConfig(
    images_dir=COCO_ROOT / "images" / "val2017",
    labels_dir=COCO_ROOT / "labels" / "val2017",
    min_confidence_threshold=0.1,
    num_top_k=100,
    nms_iou_threshold=0.2,
)

# Checkpoint of the training run being evaluated.
MODEL_FILE = MODELS_ROOT / "91f18512-9b06-4c9a-9d2c-8330ed7458c3" / "best.pt"
# Input size handed to LetterboxTransform.
IMAGE_WIDTH = 300
IMAGE_HEIGHT = 300
# Floating-point dtype used for the transformed images and the dataset.
DTYPE = torch.float32

### Evaluate the model

In [None]:
# Load the trained checkpoint at MODEL_FILE; DEVICE is presumably the device the
# weights are placed on (confirm in SSD.load).
model = SSD.load(MODEL_FILE, DEVICE)

In [None]:
# Letterbox every image to the network input size before batching.
transform = LetterboxTransform(IMAGE_WIDTH, IMAGE_HEIGHT, DTYPE)
dataset = SSDDataset(
    CONFIG.images_dir,
    CONFIG.labels_dir,
    model.num_classes,
    transform,
    DEVICE,
    DTYPE,
)

# The collate function is bound to DEVICE (presumably it moves each batch there).
collate_func = partial(TrainUtils.batch_collate_func, device=DEVICE)
# Fixed order (no shuffling) keeps evaluation runs comparable.
data_loader = DataLoader(dataset, batch_size=8, shuffle=False, collate_fn=collate_func)

In [None]:
model.eval()

images: Tensor
frame_labels: list[FrameLabels]
metrics = Metrics()
image_detections: list[dict[str, Tensor]] = []
image_labels: list[dict[str, Tensor]] = []

# One no_grad context for the whole pass: nothing below needs autograd.
with torch.no_grad():
    for images, frame_labels in tqdm(data_loader):
        head_outputs, anchors = model.forward(images)
        frame_detections = model._post_process_detections(
            head_outputs,
            anchors,
            CONFIG.min_confidence_threshold,
            CONFIG.num_top_k,
            CONFIG.nms_iou_threshold,
        )

        # Accumulate into the project's own metrics object (used for the PR curves).
        metrics.update(frame_detections, frame_labels)

        # Keep dict copies of predictions/targets for the torchmetrics cross-check.
        image_detections.extend([detection.model_dump() for detection in frame_detections])
        image_labels.extend(
            [{"boxes": labels.boxes, "labels": labels.class_ids} for labels in frame_labels]
        )

In [None]:
# Build per-class precision/recall curves from the accumulated metrics; the result
# exposes `confidences`, `precisions` and `recalls`.
# NOTE(review): 0.4 is an unexplained constant — presumably the IoU threshold for
# matching detections to labels; confirm and consider moving it into EvaluateConfig.
pr_curve = metrics.generate_precision_recall_curve(model.num_classes, 0.4)

In [None]:
# Stack each list of per-point vectors into one tensor; the list index becomes dim 0
# (presumably one row per confidence value, one column per class).
confidences = pr_curve.confidences
precisions, recalls = (
    torch.stack(curve_points, dim=0)
    for curve_points in (pr_curve.precisions, pr_curve.recalls)
)

In [None]:
# Plot at most the first 10 classes to keep the legend readable; min() avoids an
# IndexError when the model has fewer classes.
num_classes_to_plot = min(10, precisions.shape[1])

# A fresh figure per cell: reusing plt.figure(1) would draw on top of earlier plots
# under backends that keep figures open between cells.
fig, ax = plt.subplots()
for class_id in range(num_classes_to_plot):
    ax.plot(confidences, precisions[:, class_id].cpu().numpy(), label=f"Class ID: {class_id}")
ax.set(
    title="Precision-Confidence Curve",
    xlabel="Confidence",
    ylabel="Precision",
    xlim=(0, 1),
    ylim=(0, 1),
)
ax.legend()
# Explicit True: a bare grid() toggles, which depends on the active style.
ax.grid(True)
plt.show()

In [None]:
# Plot at most the first 10 classes to keep the legend readable; min() avoids an
# IndexError when the model has fewer classes.
num_classes_to_plot = min(10, recalls.shape[1])

# A fresh figure per cell: reusing plt.figure(1) would draw on top of earlier plots
# under backends that keep figures open between cells.
fig, ax = plt.subplots()
for class_id in range(num_classes_to_plot):
    ax.plot(confidences, recalls[:, class_id].cpu().numpy(), label=f"Class ID: {class_id}")
ax.set(
    title="Recall-Confidence Curve",
    xlabel="Confidence",
    ylabel="Recall",
    xlim=(0, 1),
    ylim=(0, 1),
)
ax.legend()
# Explicit True: a bare grid() toggles, which depends on the active style.
ax.grid(True)
plt.show()

In [None]:
# Plot at most the first 10 classes to keep the legend readable; min() avoids an
# IndexError when the model has fewer classes.
num_classes_to_plot = min(10, recalls.shape[1], precisions.shape[1])

# A fresh figure per cell: reusing plt.figure(1) would draw on top of earlier plots
# under backends that keep figures open between cells.
fig, ax = plt.subplots()
for class_id in range(num_classes_to_plot):
    ax.plot(
        recalls[:, class_id].cpu().numpy(),
        precisions[:, class_id].cpu().numpy(),
        label=f"Class ID: {class_id}",
    )
ax.set(
    title="Precision-Recall Curve",
    xlabel="Recall",
    ylabel="Precision",
    xlim=(0, 1),
    ylim=(0, 1),
)
ax.legend()
# Explicit True: a bare grid() toggles, which depends on the active style.
ax.grid(True)
plt.show()

In [None]:
def compute_average_precision(precisions: list[Tensor], recalls: list[Tensor]) -> Tensor:
    """Integrate per-class precision over recall with the trapezoidal rule.

    Args:
        precisions: One tensor of per-class precision values per curve point.
        recalls: One tensor of per-class recall values per curve point, aligned
            element-for-element with ``precisions``.

    Returns:
        A tensor with one area-under-the-PR-curve value per class.

    Raises:
        ValueError: If there are no points, the precision and recall inputs do not
            align, or recall decreases once the points are reversed (which would
            otherwise silently yield a negative area).

    The points are reversed so that recall runs left to right. This assumes the
    inputs are ordered by ascending confidence threshold (TODO confirm against
    ``Metrics``). The curve is anchored at (recall=0, precision=1). NaN entries,
    if any are produced upstream, propagate into the affected class's AP.
    """
    if len(precisions) == 0 or precisions[0].shape == (0,):
        raise ValueError("No precisions found.")

    if len(precisions) != len(recalls):
        msg = (
            f"Precision-recall shape mismatch: {len(precisions)} != "
            f"{len(recalls)}."
        )
        raise ValueError(msg)

    # Reverse the point order so recall increases from left to right on the plot.
    precision_pts = torch.stack(list(precisions)[::-1], dim=0)
    recall_pts = torch.stack(list(recalls)[::-1], dim=0)
    if precision_pts.shape != recall_pts.shape:
        msg = (
            f"Precision-recall shape mismatch: {tuple(precision_pts.shape)} != "
            f"{tuple(recall_pts.shape)}."
        )
        raise ValueError(msg)

    # Anchor the curve at (recall=0, precision=1).
    precision_pts = torch.cat([torch.ones_like(precision_pts[:1]), precision_pts], dim=0)
    recall_pts = torch.cat([torch.zeros_like(recall_pts[:1]), recall_pts], dim=0)

    delta_r = recall_pts[1:] - recall_pts[:-1]
    if (delta_r < 0).any():
        raise ValueError(
            "Recall decreases along the reversed curve; check the point ordering "
            "returned by generate_precision_recall_curve."
        )

    # Mean precision between consecutive points, weighted by the recall step.
    p_mean = precision_pts[:-1] + (precision_pts[1:] - precision_pts[:-1]) / 2
    return (p_mean * delta_r).sum(dim=0)


APs = compute_average_precision(pr_curve.precisions, pr_curve.recalls)
APs

In [None]:
# Cross-check against torchmetrics' MeanAveragePrecision on the collected dicts.
# NOTE(review): box_format="cxcywh" assumes both predictions and targets store
# (cx, cy, w, h) boxes — confirm against the SSD output and label format.
mAP = MeanAveragePrecision(box_format="cxcywh")
mAP.update(preds=image_detections, target=image_labels)
results = mAP.compute()

results

In [None]:
# Debug: check which container type torchmetrics returned.
results.__class__

In [None]:
# Debug: box tensor shape of the first image's detections.
image_detections[0]["boxes"].size()

In [None]:
# Debug: box tensor shape of the first image's ground-truth labels.
image_labels[0]["boxes"].size()

In [None]:
# Debug: ground-truth class ids of the first image (compare values/dtype with the
# detection class ids, since torchmetrics matches on them).
image_labels[0]["labels"]

In [None]:
# Debug: predicted class ids of the first image's detections.
image_detections[0]["labels"]