In [None]:
from pathlib import Path

import cv2
import torch
import matplotlib.pyplot as plt
import numpy as np
from torchvision.transforms.functional import pil_to_tensor
from PIL import Image

from ssd import SSD
from ssd.data import LetterboxTransform

In [None]:
# Pixel size of the letterboxed canvas; the drawing cell multiplies the model's
# box coordinates by these to get pixel positions.
IMAGE_WIDTH, IMAGE_HEIGHT = 300, 300

In [None]:
# The model and the letterbox transform are both placed on this device.
device = torch.device("cpu")

# NOTE(review): machine-specific absolute paths; adjust for your environment.
model_path = Path("/mnt/data/code/ssd/models/23ac32e7-5881-4f99-87a5-3b556464f721/best.pt")
images_dir = Path("/mnt/data/datasets/object_detection/coco/images/val2017")

model = SSD.load(model_path, device)

# Path.glob yields entries in arbitrary, filesystem-dependent order; sorting makes
# an index into image_files pick the same image on every run and every machine.
image_files = sorted(images_dir.glob("*.jpg"))
# Fail here with a clear message rather than with an IndexError in a later cell.
assert image_files, f"no .jpg files found in {images_dir}"

transform = LetterboxTransform()

In [None]:
# Pick one validation image by index into the sorted file list.
file_idx = 4
file = image_files[file_idx]

# Inference settings passed to model.infer below.
CONFIDENCE_THRESHOLD = 0.5
NUM_TOP_K = 400
NMS_IOU_THRESHOLD = 0.3
# Red: the array is RGB (PIL convert below) and is displayed by matplotlib.
BOX_COLOR = (255, 0, 0)

# Letterbox the image so boxes can be drawn on it.
# NOTE(review): assumes this matches the preprocessing model.infer applies — confirm.
# convert("RGB") forces 3 channels: COCO val2017 contains some grayscale JPEGs,
# whose 1-channel tensor would otherwise give an (H, W, 1) array plt.imshow rejects.
# The `with` block closes the underlying file once the pixels have been read.
with Image.open(file) as pil_image:
    image_tensor = pil_to_tensor(pil_image.convert("RGB"))
letterbox_image = transform.transform_image(image_tensor, device)
# CHW -> HWC for OpenCV/matplotlib. `.cpu()` lets `.numpy()` work for any device;
# `.copy()` produces a C-contiguous array, which cv2 drawing functions need after
# a permute. NOTE(review): astype(np.uint8) assumes the letterbox output is in the
# 0-255 pixel range — confirm against LetterboxTransform.
image = letterbox_image.permute((1, 2, 0)).cpu().numpy().astype(np.uint8).copy()
# Boxes are scaled by IMAGE_WIDTH/IMAGE_HEIGHT below; a size mismatch would
# silently misplace every box, so check it up front.
assert image.shape[:2] == (IMAGE_HEIGHT, IMAGE_WIDTH), (
    f"letterboxed image is {image.shape[:2]}, expected {(IMAGE_HEIGHT, IMAGE_WIDTH)}"
)

# Run inference on the image file.
detections = model.infer(
    file=file,
    confidence_threshold=CONFIDENCE_THRESHOLD,
    num_top_k=NUM_TOP_K,
    nms_iou_threshold=NMS_IOU_THRESHOLD,
)
for box, score, label in zip(detections.boxes, detections.scores, detections.class_ids):
    # Boxes are unpacked as (cx, cy, w, h) and scaled by the canvas size, i.e.
    # presumably normalized to [0, 1]; convert to pixel corner coordinates.
    cx, cy, w, h = box
    left = int((cx - w / 2) * IMAGE_WIDTH)
    right = int((cx + w / 2) * IMAGE_WIDTH)
    top = int((cy - h / 2) * IMAGE_HEIGHT)
    bottom = int((cy + h / 2) * IMAGE_HEIGHT)

    image = cv2.rectangle(image, (left, top), (right, bottom), BOX_COLOR, 2)
    # putText's origin is the text's bottom-left corner; clamp it so labels of
    # boxes touching the top edge stay inside the image instead of vanishing.
    text_y = max(top - 4, 8)
    image = cv2.putText(
        image,
        f"C={int(label.item())} S={score.item():.3f}",
        (left, text_y),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.3,
        BOX_COLOR,
        1,
    )
    print(f"Score: {score.item()} Label: {label.item()}")

fig, ax = plt.subplots()
ax.imshow(image)
ax.set_title(file.name)
ax.axis("off")
plt.show()