In [1]:
# Automatically reload imported modules before each cell runs, so edits to
# library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# Make the library that lives one directory above this notebook importable.
# NOTE(review): ".." is resolved against the kernel's current working
# directory — presumably the notebook's folder; confirm if launched elsewhere.
LIBRARY_ROOT_PATH = os.path.abspath("..")
# Guard so re-running this cell does not keep appending duplicate entries.
if LIBRARY_ROOT_PATH not in sys.path:
    sys.path.append(LIBRARY_ROOT_PATH)

In [3]:
# Asset directories, both rooted at LIBRARY_ROOT_PATH.
ASSETS_PATH, IMAGES_DIR_PATH = (
    os.path.join(LIBRARY_ROOT_PATH, "assets"),
    os.path.join(LIBRARY_ROOT_PATH, "assets", "images"),
)

In [4]:
from glob import glob

In [5]:
# glob() returns paths in filesystem-dependent order; sort so that
# IMAGES_PATHS[0] refers to the same image on every machine and every run.
IMAGES_PATHS = sorted(glob(os.path.join(IMAGES_DIR_PATH, "*.jpg")))
# Fail here with a clear message rather than an IndexError further down.
assert IMAGES_PATHS, f"No .jpg images found in {IMAGES_DIR_PATH}"

## How do torchvision object detection models work?

In [6]:
import torch

In [7]:
from torchvision.io.image import read_image
from torchvision.models.detection import (
    retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights
)
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image

In [8]:
# Use the first GPU when one is available; otherwise fall back to CPU so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Step 1: Initialize the model with explicitly pinned COCO_V1 weights
# (pinned rather than `.DEFAULT`, so a torchvision upgrade cannot silently
# change which weights are loaded).
weights = RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1
model = retinanet_resnet50_fpn_v2(weights=weights).to(device)
# Eval mode: torchvision detection models return predictions (not losses)
# when they are not in training mode.
model.eval()

# Step 2: Initialize the inference transforms bundled with these weights
preprocess = weights.transforms()

In [23]:
from IPython.display import display

# Minimum confidence for a detection to be drawn. torchvision's RetinaNet
# post-processing keeps boxes down to a low score (0.05 by default), so
# drawing every returned box clutters the image.
SCORE_THRESHOLD = 0.5

# Step 3: Apply inference preprocessing transforms
img = read_image(IMAGES_PATHS[0])  # uint8 tensor, channels-first (C, H, W)
batch = [preprocess(img).to(device)]

# Step 4: Use the model and visualize the prediction
with torch.no_grad():
    prediction = model(batch)[0]

# `prediction` is kept unfiltered for inspection in later cells;
# only the detections that get drawn are filtered.
keep = prediction["scores"] >= SCORE_THRESHOLD
labels = [weights.meta["categories"][i] for i in prediction["labels"][keep].tolist()]
# NOTE(review): some torchvision versions ignore `font_size` unless `font`
# is also passed — confirm the label size renders as intended.
box = draw_bounding_boxes(img, boxes=prediction["boxes"][keep].cpu(),
                          labels=labels,
                          colors="red",
                          width=4, font_size=30)
im = to_pil_image(box)
# display() renders inline in the notebook; PIL's Image.show() would open
# an external image viewer instead.
display(im)



In [11]:
# Inspect which fields the per-image prediction dict contains.
prediction.keys()

dict_keys(['boxes', 'scores', 'labels'])

In [22]:
# First row of the predicted boxes tensor. The values look like
# (x1, y1, x2, y2) pixel coordinates in the original image (compare with the
# image shape below) — NOTE(review): confirm against the torchvision docs.
prediction["boxes"][0, :]

tensor([ 598.6989,   40.1023, 1030.0309,  843.2996], device='cuda:0')

In [18]:
# Shape of the raw (un-preprocessed) image tensor read from disk.
img.size()

torch.Size([3, 853, 1280])

In [24]:
# Which transform class the weights' bundled preprocessing is.
preprocess.__class__

torchvision.transforms._presets.ObjectDetection