In [None]:
from itertools import cycle

import matplotlib.pyplot as plt
import requests
import torch
from PIL import Image
from transformers import DetrImageProcessor, DetrForObjectDetection

# The image processor and the model must come from the same checkpoint so the
# preprocessing (resizing/normalization) matches what the weights expect.
checkpoint = "facebook/detr-resnet-50"
processor = DetrImageProcessor.from_pretrained(checkpoint)
model = DetrForObjectDetection.from_pretrained(checkpoint)

In [None]:
url = "http://images.cocodataset.org/val2017/000000439715.jpg"
# Use a timeout so a stalled connection cannot hang the kernel, and fail loudly
# on HTTP errors instead of letting PIL choke on an error page. The `with`
# block closes the connection once the image has been read.
with requests.get(url, stream=True, timeout=30) as response:
    response.raise_for_status()
    # response.raw is the undecoded socket stream; have urllib3 undo any
    # Content-Encoding (e.g. gzip) so PIL sees the actual image bytes.
    response.raw.decode_content = True
    # Image.open is lazy; convert() forces the full decode while the stream is
    # still open and normalizes grayscale/RGBA inputs to 3-channel RGB.
    image = Image.open(response.raw).convert("RGB")


In [None]:
inputs = processor(images=image, return_tensors="pt")
# Pure inference: disable autograd so no computation graph (and the
# activations it would keep alive) is built for this forward pass.
with torch.no_grad():
    outputs = model(**inputs)

In [None]:
# PIL's Image.size is (width, height); the post-processor takes one
# (height, width) row per image in the batch, and boxes are scaled to it.
width, height = image.size
target_sizes = torch.tensor([[height, width]])
# Keep detections above the 0.9 score threshold and take the entry for our
# single image (the call returns one result dict per image).
results = processor.post_process_object_detection(
    outputs,
    target_sizes=target_sizes,
    threshold=0.9,
)[0]

In [None]:
# One line per detection: class name, confidence (3 dp) and box (2 dp).
# The message is kept on one source line: the previous backslash continuation
# inside the f-string glued " " and " with confidence " together, producing a
# double space in every printed line.
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    box = [round(coord, 2) for coord in box.tolist()]
    print(
        f"Detected {model.config.id2label[label.item()]} with confidence "
        f"{round(score.item(), 3)} at location {box}"
    )


In [None]:
# Box outline colors as RGB float triples in [0, 1]; visualize_prediction
# reuses them in order, one per drawn detection.
COLORS = [
    [0.0, 0.5, 0.8], [0.9, 0.3, 0.1], [0.9, 0.6, 0.1],
    [0.4, 0.1, 0.5], [0.4, 0.6, 0.1], [0.3, 0.7, 0.9],
]


def visualize_prediction(pil_img, output_dict, threshold, id2label=None):
    """Draw the detections scoring above ``threshold`` on top of ``pil_img``.

    Args:
        pil_img: Image to draw on (anything ``Axes.imshow`` accepts, e.g. a PIL image).
        output_dict: One per-image result of
            ``processor.post_process_object_detection``, holding ``"scores"``,
            ``"labels"`` and ``"boxes"`` tensors. Boxes are unpacked as
            ``(xmin, ymin, xmax, ymax)``, in the coordinates passed as
            ``target_sizes`` (the image's pixel size in this notebook).
        threshold: Only detections with a score strictly greater than this
            value are drawn.
        id2label: Optional mapping from class id to display name. Defaults to
            the global ``model.config.id2label``.

    Returns:
        None. The figure is rendered with ``plt.show()``.
    """
    if id2label is None:
        id2label = model.config.id2label

    keep = output_dict["scores"] > threshold
    boxes = output_dict["boxes"][keep].tolist()
    scores = output_dict["scores"][keep].tolist()
    labels = [id2label[class_id] for class_id in output_dict["labels"][keep].tolist()]

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.imshow(pil_img)
    # cycle() reuses the palette for any number of boxes; the former
    # `COLORS * 100` made zip() silently drop detections beyond 600.
    for score, (xmin, ymin, xmax, ymax), label, color in zip(
        scores, boxes, labels, cycle(COLORS)
    ):
        ax.add_patch(
            plt.Rectangle(
                (xmin, ymin),
                xmax - xmin,
                ymax - ymin,
                fill=False,
                color=color,
                linewidth=3,
            )
        )
        # Show the confidence next to the class name (scores were previously
        # computed but never displayed).
        ax.text(
            xmin,
            ymin,
            f"{label}: {score:0.2f}",
            fontsize=8,
            bbox=dict(facecolor="yellow", alpha=0.5),
        )
    # Hide the axes once, outside the loop, so this also applies when no
    # detection passes the threshold.
    ax.axis("off")
    plt.show()


In [None]:
# `results` was already filtered at 0.9 during post-processing, so this
# threshold only removes further boxes if it is set higher.
visualize_prediction(image, results, threshold=0.9)
