In [None]:
from functools import partial
from pathlib import Path

import torch
from pydantic import BaseModel, Field
from torch import Tensor
from torch.utils.data import DataLoader
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from tqdm import tqdm

from ssd import SSD
from ssd.data import LetterboxTransform, SSDDataset
from ssd.utils import TrainUtils

In [None]:
class EvaluateConfig(BaseModel):
    """Settings for evaluating a trained SSD checkpoint on a labelled image set.

    The three post-processing fields are forwarded, in order, to the model's
    detection post-processing step in the evaluation loop. Out-of-range values
    are rejected when the config is constructed (pydantic ``ValidationError``)
    instead of silently producing a meaningless mAP.
    """

    # Directory containing the evaluation images.
    images_dir: Path
    # Directory containing the label files for ``images_dir``.
    labels_dir: Path
    # Minimum detection confidence passed to post-processing.
    # NOTE(review): the [0, 1] bound assumes scores are probabilities — confirm.
    min_confidence_threshold: float = Field(default=0.1, ge=0.0, le=1.0)
    # Number of top-scoring detections kept by post-processing; must be positive.
    num_top_k: int = Field(default=100, gt=0)
    # IoU threshold used for non-maximum suppression; IoU is always in [0, 1].
    nms_iou_threshold: float = Field(default=0.2, ge=0.0, le=1.0)

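### Define constants

Device, dataset paths, post-processing thresholds, checkpoint location, and network input size used by the rest of the notebook.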

In [None]:
# Prefer the first GPU, but fall back to CPU so the notebook still runs on
# machines without CUDA (evaluation will just be slower).
# NOTE(review): assumes SSD.load can map a GPU-saved checkpoint onto CPU — confirm.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE(review): absolute, machine-specific paths — adjust for your environment.
CONFIG = EvaluateConfig(
    images_dir=Path("/mnt/data/datasets/object_detection/coco/images/val2017"),
    labels_dir=Path("/mnt/data/datasets/object_detection/coco/labels/val2017"),
    min_confidence_threshold=0.1,
    num_top_k=100,
    nms_iou_threshold=0.2,
)

# Checkpoint to evaluate.
MODEL_FILE = Path("/mnt/data/code/ssd/models/91f18512-9b06-4c9a-9d2c-8330ed7458c3/best.pt")
# Network input size, passed to LetterboxTransform.
IMAGE_WIDTH = 300
IMAGE_HEIGHT = 300
# Floating-point dtype passed to the transform and the dataset.
DTYPE = torch.float32

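### Evaluate the model

Run the checkpoint over the COCO val2017 images, post-process its raw outputs into detections, and score them against the ground-truth labels with mAP at an IoU threshold of 0.5.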

In [None]:
# Load the trained SSD checkpoint from MODEL_FILE onto DEVICE.
# NOTE(review): SSD.load is project code; presumably it also restores
# hyper-parameters such as `num_classes` (read by the dataset cell) — confirm.
model = SSD.load(MODEL_FILE, DEVICE)

In [None]:
# Evaluation batch size; with the model in eval mode this should only affect
# throughput and memory, not the metric.
BATCH_SIZE = 8

# Letterbox every image to the network's input size and dtype.
transform = LetterboxTransform(IMAGE_WIDTH, IMAGE_HEIGHT, DTYPE)
# Bind DEVICE into the project's collate function (presumably it places each
# batch on that device — confirm in TrainUtils).
collate_func = partial(TrainUtils.batch_collate_func, device=DEVICE)

dataset = SSDDataset(CONFIG.images_dir, CONFIG.labels_dir, model.num_classes, transform, DEVICE, DTYPE)
# shuffle=False keeps iteration order deterministic, so index i of the
# collected results corresponds to dataset[i] on every run.
data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_func)

In [None]:
# Inference mode for layers such as dropout / batch-norm.
model.eval()

images: Tensor
labels: list[Tensor]
# Per-image predictions and targets in the list-of-dicts form consumed by
# torchmetrics' MeanAveragePrecision ("boxes", "scores", "labels" keys).
image_detections: list[dict[str, Tensor]] = []
image_labels: list[dict[str, Tensor]] = []

# One no_grad context around the whole pass: no autograd graph is needed, and
# entering it once avoids re-creating it for every batch.
with torch.no_grad():
    for images, labels in tqdm(data_loader):
        # Call the module (not .forward) so nn.Module hooks still run.
        head_outputs, anchors = model(images)
        frame_detections = model._post_process_detections(
            head_outputs,
            anchors,
            CONFIG.min_confidence_threshold,
            CONFIG.num_top_k,
            CONFIG.nms_iou_threshold,
        )

        # Move results to CPU so predictions for the whole split do not
        # accumulate in GPU memory; the metric is computed on CPU anyway.
        image_detections.extend(
            {
                key: value.cpu() if isinstance(value, Tensor) else value
                for key, value in detection.model_dump().items()
            }
            for detection in frame_detections
        )
        # Label rows appear to be [class_id, box...]. Class ids are shifted down
        # by one — presumably the dataset reserves id 0 for background.
        # NOTE(review): confirm predicted labels use the same 0-based ids (see the
        # label inspection cells at the end of the notebook).
        image_labels.extend(
            {
                "boxes": frame_labels[:, 1:].cpu(),
                "labels": (frame_labels[:, 0].to(torch.int) - 1).cpu(),
            }
            for frame_labels in labels
        )

In [None]:
# Score the collected detections against the ground truth. With a single IoU
# threshold of 0.5, the reported "map" value is mAP@0.5.
# NOTE(review): box_format="cxcywh" assumes predictions and targets are both
# (cx, cy, w, h) in the same coordinate units — confirm.
mAP = MeanAveragePrecision(iou_thresholds=[0.5], box_format="cxcywh")
mAP.update(preds=image_detections, target=image_labels)
mAP.compute()

In [None]:
# Inspect: shape of the predicted boxes for the first image (compare with its ground-truth boxes).
image_detections[0]["boxes"].shape

In [None]:
# Inspect: shape of the ground-truth boxes for the first image.
image_labels[0]["boxes"].shape

In [None]:
# Inspect: ground-truth class ids for the first image (after the -1 shift applied in the evaluation loop).
image_labels[0]["labels"]

In [None]:
# Inspect: predicted class ids for the first image; they should use the same id scheme as the ground truth above.
image_detections[0]["labels"]