In [None]:
import requests
import torch

from PIL import Image
from transformers import AutoImageProcessor, AutoModelForObjectDetection

# Name of repo on the hub or path to a local folder.
# Both the image processor and the detection model below are loaded from this
# same identifier.
model_name = "ckpts/detr-simple-2-ft-9-epochs-05-01_12-28"
# Alternative checkpoint; uncomment to use it instead of the one above.
# model_name = "ckpts/detr-simple-ft-368-epochs"

# Pre/post-processing pipeline for the checkpoint
image_processor = AutoImageProcessor.from_pretrained(model_name)
# Object-detection model for the checkpoint
model = AutoModelForObjectDetection.from_pretrained(model_name)


In [None]:
# Show the class-id -> class-name mapping stored in the loaded model's config,
# so it can be inspected alongside the hand-written mapping.
print(f'Labels in model: {model.config.id2label}')
# Hand-written class-id -> class-name table; ids are the 0-based positions of
# the names in the list below.
# NOTE(review): the visible cells read model.config.id2label rather than this
# dict -- confirm whether it is meant for comparison or as an override.
id2label = dict(enumerate([
    "biker",
    "car",
    "pedestrian",
    "trafficlight",
    "trafficlight-Green",
    "trafficlight-Greenleft",
    "trafficlight-Red",
    "trafficlight-Redleft",
    "trafficlight-Yellow",
    "trafficlight-Yellowleft",
    "truck",
    "Arret",
]))

In [None]:

# Load image for inference
fname = "../data/joint_55_56/frame704.jpg"
# Open inside a context manager so the file handle is released as soon as the
# pixels are read, and force 3-channel RGB so grayscale / palette / RGBA files
# do not reach the processor with an unexpected channel count.
# convert() returns an independent in-memory copy, so `image` stays usable
# after the file is closed.
with Image.open(fname) as source_image:
    image = source_image.convert("RGB")

# Prepare image for the model
inputs = image_processor(images=image, return_tensors="pt")

# Forward pass only; gradients are not needed for inference
with torch.no_grad():
    outputs = model(**inputs)


# PIL reports (width, height); the target size is passed as (height, width)
width, height = image.size
target_sizes = torch.tensor([height, width]).unsqueeze(0)  # add batch dim
# Keep detections scoring above 0.5; [0] selects the result for the single input image
results = image_processor.post_process_object_detection(outputs, threshold=0.5, target_sizes=target_sizes)[0]

# Report every retained detection: class name, confidence and box coordinates
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    box = [round(i, 2) for i in box.tolist()]
    print(
        f"Detected {model.config.id2label[label.item()]} with confidence "
        f"{round(score.item(), 3)} at location {box}"
    )


In [None]:
import matplotlib.pyplot as plt

# Palette of RGB triples (components in the 0-1 range) used for box outlines
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]

def plot_results(pil_img, scores, labels, boxes, id2label=None):
    """Draw detection boxes and their captions on top of an image.

    Args:
        pil_img: Image to display (anything ``Axes.imshow`` accepts).
        scores: Per-detection confidence values; must provide ``.tolist()``.
        labels: Per-detection class ids; must provide ``.tolist()``.
        boxes: Per-detection boxes; must provide ``.tolist()``, each row
            unpacking to ``(xmin, ymin, xmax, ymax)`` in the image's
            coordinate system.
        id2label: Optional mapping from class id to display name. When
            omitted, the mapping from the notebook-level ``model.config`` is
            used, as before.

    The figure is shown with ``plt.show()``; nothing is returned.
    """
    # Local import keeps this cell self-contained.
    from itertools import cycle

    if id2label is None:
        id2label = model.config.id2label

    _, ax = plt.subplots(figsize=(16, 10))
    ax.imshow(pil_img)

    # cycle() repeats the palette indefinitely, so zip() is bounded only by the
    # number of detections and none are silently dropped.
    detections = zip(scores.tolist(), labels.tolist(), boxes.tolist(), cycle(COLORS))
    for score, label, (xmin, ymin, xmax, ymax), c in detections:
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=c, linewidth=3))
        text = f'{id2label[label]}: {score:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))
    ax.axis('off')
    plt.show()
    


In [None]:
# Overlay the retained detections on the image they were computed from
plot_results(
    pil_img=image,
    scores=results["scores"],
    labels=results["labels"],
    boxes=results["boxes"],
)