In [1]:
import torch
from PIL import Image
from safetensors.torch import load_model

from yolo.model import Yolo
from yolo.data import CollateWithAnchors, CocoDataset, get_val_transforms
from yolo.anchors import DecodeDetections
from yolo.utils import DetectionMetrics

In [2]:
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

%load_ext autoreload
%autoreload 2

In [3]:
IMAGE_SIZE = 608
ANCHORS = [
    (10,13),
    (16,30),
    (33,23),
    (30,61),
    (62,45),
    (59,119),
    (116,90),
    (156,198),
    (373,326)
]
SCALES = [8, 16, 32]
num_anchors_per_scale = 3

val_dataset = CocoDataset(
    dataset_root = "/media/bryan/ssd01/fiftyone/coco-2017",
    split = "validation",
    transform = get_val_transforms(resize_size=IMAGE_SIZE)
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=16,
    shuffle=False,
    pin_memory=True,
    drop_last=False,
    num_workers=0,
    collate_fn=CollateWithAnchors(ANCHORS, SCALES, IMAGE_SIZE, IMAGE_SIZE, num_classes = val_dataset.num_classes)
)

model_checkpoint = "/media/bryan/ssd01/expr/yolo_from_scratch/debug02/checkpoints/checkpoint_9/model.safetensors"
model = Yolo(
    val_dataset.num_classes,
    num_anchors_per_scale,
)
weight_init = model.body.conv1.weight.detach().clone()
load_model(model, model_checkpoint, device="cpu")
weight_ckpt = model.body.conv1.weight.detach().clone()
assert not torch.allclose(weight_init, weight_ckpt)
model.eval()
model.to("cuda");


loading annotations into memory...
Done (t=0.17s)
creating index...
index created!


In [7]:
detection_decoder = DecodeDetections(
    ANCHORS,
    SCALES,
    IMAGE_SIZE,
    IMAGE_SIZE,
    class_names = val_dataset.class_names,
    num_anchors_per_scale=num_anchors_per_scale,
    box_min_area = 50
)
metrics = DetectionMetrics(val_dataset.class_names) #  backend="faster_coco_eval")

In [8]:
with torch.inference_mode():
    for i, batch in enumerate(val_dataloader):
        if i > 25:
            break
        outputs = model(batch["image"].to("cuda"))
        preds = detection_decoder(outputs, objectness_threshold=0.5, iou_threshold=0.5)
        metrics.update(preds, batch)

In [9]:
metric_outputs = metrics.compute()

In [10]:
metric_outputs

{'AP': 0.059606656432151794,
 'AP50': 0.10258711129426956,
 'AP75': 0.05941104516386986,
 'AP-large': 0.0769912376999855,
 'AP-medium': 0.044655200093984604,
 'AP-small': 0.011817298829555511,
 'AP-per-class/person': 0.10900713503360748,
 'AP-per-class/bicycle': 0.0,
 'AP-per-class/car': 0.1068156361579895,
 'AP-per-class/motorcycle': 0.04961424693465233,
 'AP-per-class/airplane': 0.05445544421672821,
 'AP-per-class/bus': 0.09529703110456467,
 'AP-per-class/train': 0.0,
 'AP-per-class/truck': 0.05445544421672821,
 'AP-per-class/boat': 0.0,
 'AP-per-class/traffic light': 0.031683169305324554,
 'AP-per-class/fire hydrant': 0.2052145153284073,
 'AP-per-class/stop sign': 0.2059405893087387,
 'AP-per-class/parking meter': 0.0,
 'AP-per-class/bench': 0.0,
 'AP-per-class/bird': 0.0,
 'AP-per-class/cat': 0.18415841460227966,
 'AP-per-class/dog': 0.10099010169506073,
 'AP-per-class/horse': 0.08193068951368332,
 'AP-per-class/sheep': 0.17298443615436554,
 'AP-per-class/cow': 0.10808581113815308,