In [None]:
from transformers import YolosFeatureExtractor, YolosForObjectDetection
from PIL import Image
import requests
import numpy as np
import cv2
from matplotlib import pyplot as plt
import matplotlib.patches as patches
import json
import torch
import os
import torch.nn.functional as F

In [None]:
# Read the class names (one per line) and prepend an "unknown" entry so that
# list position 0 is a placeholder and the names from the file start at 1.
classes_path = os.path.join("..", "data", "coco_classes.txt")
with open(classes_path, "r") as f:
    coco_classes = ["unknown", *(line.rstrip("\n") for line in f)]

In [None]:
# Sanity check: number of class names (including "unknown"), then the list.
print(len(coco_classes), coco_classes, sep="\n")

In [None]:
# Pre-processor and detector are loaded from the same pretrained checkpoint
# so that the image normalisation matches what the model was trained with.
# NOTE(review): YolosFeatureExtractor is deprecated in recent transformers
# releases in favour of YolosImageProcessor — consider migrating.
CHECKPOINT = "hustvl/yolos-small"
feature_extractor = YolosFeatureExtractor.from_pretrained(CHECKPOINT)
model = YolosForObjectDetection.from_pretrained(CHECKPOINT)

In [None]:
# Load every file in the thumbnail directory into memory as a PIL image.
thumbnail_dir = os.path.join("..", "data", "thumbnails")

# os.listdir returns entries in arbitrary order; sort them so that later
# slices such as all_images[-5:] select the same files on every run.
img_names = sorted(os.listdir(thumbnail_dir))

all_images = []
for img_name in img_names:
    img_path = os.path.join(thumbnail_dir, img_name)
    # Image.open is lazy and keeps the underlying file open. Copy the pixel
    # data while the file is open, then let the context manager release the
    # handle so a large directory does not exhaust file descriptors.
    with Image.open(img_path) as img:
        all_images.append(img.copy())

In [None]:
# Keep only a small sample (the last few loaded thumbnails) for inference.
NUM_SAMPLES = 5
images = all_images[-NUM_SAMPLES:]

In [None]:
# Preview the third-from-last sample and print the shape of its pixel array.
sample = images[-3]
plt.imshow(sample)
print(np.array(sample).shape)

In [None]:
# Pre-process the sample images into a batch of tensors and run the detector.
inputs = feature_extractor(images=images, return_tensors="pt")

# Inference only: disable autograd so no computation graph is built for the
# forward pass (saves memory; the outputs are never back-propagated through).
with torch.no_grad():
    outputs = model(**inputs)

# model predicts bounding boxes and corresponding COCO classes
logits = outputs.logits
bboxes = outputs.pred_boxes

In [None]:
# Inspect what the pre-processor produced and the shapes of the model outputs.
print(inputs.keys())
for tensor in (inputs["pixel_values"], logits, bboxes):
    print(tensor.shape)

In [None]:
# Turn raw logits into per-image lists of confident, non-background detections.
#
# Outputs used by later cells:
#   preds        - tensor [B, 100], most likely class index per query
#   indices      - list (per image) of kept query indices, ascending
#   pred_classes - list (per image) of class names, aligned with `indices`
#   confidence   - nested list [B][100] of top-class probabilities in percent

# Minimum top-class probability for a detection to be kept.
CONFIDENCE_THRESHOLD = 0.75
# The last class index is the one filtered out as "no detection"; it was
# previously hard-coded as 91 (i.e. 92 logits per query).
NO_OBJECT_CLASS = logits.shape[-1] - 1

probs = F.softmax(logits.detach(), dim=-1) # [B, 100, 92]
# One reduction yields both the best probability and its class index.
confidence, preds = probs.max(dim=-1) # each [B, 100]

known_mask = preds != NO_OBJECT_CLASS             # [B, 100]
confident_mask = confidence > CONFIDENCE_THRESHOLD  # [B, 100]
keep_mask = known_mask & confident_mask             # [B, 100]

# Kept for backward compatibility with cells that may still read them.
known_indices     = [mask_img.nonzero()[:, 0] for mask_img in known_mask] # [B, known indices]
confident_indices = [mask_img.nonzero()[:, 0] for mask_img in confident_mask] # [B, confident indices]

# nonzero() returns indices in ascending order, so the result is deterministic
# (the previous set intersection had no guaranteed ordering).
indices = [mask_img.nonzero()[:, 0].tolist() for mask_img in keep_mask] # [B, kept indices]
# Boolean-mask indexing preserves the same ascending order as `indices`.
pred_classes = [
    [coco_classes[v] for v in preds_img[mask_img].tolist()]
    for preds_img, mask_img in zip(preds, keep_mask)
] # [B, predicted classes]

# From here on `confidence` is a nested Python list of percentages.
confidence = (confidence*100).tolist()

pred_classes

In [None]:
# Draw every kept detection (box + "class (confidence%)" label) on top of its
# image, one subplot row per image, and print how many boxes were drawn.
amount = len(images)
# squeeze=False always returns a 2-D array of Axes; without it a single image
# (amount == 1) yields a bare Axes object and ax[img_idx] fails.
fig, ax = plt.subplots(amount, 1, figsize=(15, amount*13), squeeze=False)
ax = ax[:, 0]  # single column -> 1-D array indexed by image

pred_boxes = bboxes.detach().clone()
# plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# pyplot.get_cmap takes the same (name, lut) arguments.
cmap = plt.get_cmap("hsv", len(coco_classes))

for img_idx, image in enumerate(images):
    a = ax[img_idx]
    a.imshow(image)
    a.axis("Off")

    im_w = image.width
    im_h = image.height

    box_counter = 0
    # pred_classes[img_idx] and indices[img_idx] are aligned element-wise.
    for c, patch_idx in zip(pred_classes[img_idx], indices[img_idx]):
        conf = confidence[img_idx][patch_idx]

        # Convert to plain floats so Matplotlib receives scalars, not tensors.
        # The box is treated as (center x, center y, width, height) in
        # fractions of the image size and converted to a pixel-space
        # top-left corner plus width/height.
        cx, cy, W, H = pred_boxes[img_idx, patch_idx].tolist()
        W *= im_w
        H *= im_h
        x = cx*im_w - W*.5
        y = cy*im_h - H*.5

        a.text(x, y, f"{c} ({conf:.2f}%)", fontsize=20, c="white")

        # Create a Rectangle patch, coloured by predicted class index
        color = cmap(preds[img_idx][patch_idx].item())
        rect = patches.Rectangle((x, y), W, H, linewidth=2,
            edgecolor=color,
            facecolor='none',
        )

        # Add the patch to the Axes
        a.add_patch(rect)
        box_counter += 1
    print("box count: ", box_counter)