In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys

# Make the library root (the parent of the notebook's working directory)
# importable. Only append when missing, so re-running this cell does not keep
# adding duplicate entries to sys.path.
LIBRARY_ROOT_PATH = os.path.abspath("..")
if LIBRARY_ROOT_PATH not in sys.path:
    sys.path.append(LIBRARY_ROOT_PATH)

In [None]:
# Bundled assets live in "<library root>/assets"; the sample images used by
# this notebook are in its "images" sub-directory.
ASSETS_PATH = os.path.join(LIBRARY_ROOT_PATH, "assets")
IMAGES_DIR_PATH = os.path.join(ASSETS_PATH, "images")  # <root>/assets/images

In [None]:
from glob import glob

In [None]:
# glob() returns matches in arbitrary, filesystem-dependent order; sort them so
# that IMAGES_PATHS[0] refers to the same image on every machine and every run.
IMAGES_PATHS = sorted(glob(os.path.join(IMAGES_DIR_PATH, "*.jpg")))
# Fail early with a clear message rather than an IndexError in a later cell.
assert IMAGES_PATHS, f"no *.jpg images found in {IMAGES_DIR_PATH}"

## How do torchvision object detection models work?

In [None]:
import torch

In [None]:
from torchvision.io.image import read_image
from torchvision.models.detection import (
    retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights
)
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image

In [None]:
# Use the first GPU when available, otherwise fall back to CPU so the notebook
# also runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Step 1: Initialize the model with its COCO-pretrained weights (COCO_V1)
weights = RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1
model = retinanet_resnet50_fpn_v2(weights=weights).to(device)
model.eval()  # switch to evaluation mode before running inference

# Step 2: Initialize the inference transforms that match these weights
preprocess = weights.transforms()

In [None]:
# Step 3: Apply inference preprocessing transforms
img = read_image(IMAGES_PATHS[0])
batch = [preprocess(img).to(device)]

# Step 4: Use the model and visualize the prediction.
# We never backpropagate here, so disable autograd tracking to avoid keeping
# gradient bookkeeping (and its memory) around.
with torch.no_grad():
    prediction = model(batch)[0]

# Only draw reasonably confident detections; set to 0.0 to draw every box
# the model returned.
SCORE_THRESHOLD = 0.5
keep = prediction["scores"] >= SCORE_THRESHOLD

# `img` lives on the CPU, so bring the selected boxes there as well.
boxes = prediction["boxes"][keep].cpu()
# .tolist() yields plain Python ints for indexing the category names.
labels = [weights.meta["categories"][i] for i in prediction["labels"][keep].tolist()]

box = draw_bounding_boxes(img, boxes=boxes,
                          labels=labels,
                          colors="red",
                          width=4, font_size=30)
im = to_pil_image(box)
# Leave the image as the cell's last expression so Jupyter renders it inline
# (PIL's Image.show() would open an external viewer instead).
im