In [1]:
#pip install timm

In [None]:
from collections import Counter

import matplotlib.pyplot as plt
import torch
from PIL import Image, ImageDraw
from torchvision import transforms
from transformers import DetrImageProcessor, DetrForObjectDetection

In [None]:
# Load the pre-trained DETR processor and model from the same Hub checkpoint,
# so the pre-/post-processing configuration matches the model weights.
model_name = "facebook/detr-resnet-50"
processor, model = (
    DetrImageProcessor.from_pretrained(model_name),
    DetrForObjectDetection.from_pretrained(model_name),
)

In [None]:
# Load the image and preprocess it into model-ready tensors.
image_path = "Restaurant.jpg"
# Force 3-channel RGB so RGBA / palette / grayscale files match what the
# processor's normalization expects.
image = Image.open(image_path).convert("RGB")
# Resize while keeping the aspect ratio: the shortest edge is scaled to 800 px,
# then the longest edge is capped at 800 px. (`size_format` is not a
# DetrImageProcessor option, and a plain `(800, 800)` tuple is interpreted as an
# exact height/width, which would squash the image.)
inputs = processor(
    images=image,
    return_tensors="pt",
    size={"shortest_edge": 800, "longest_edge": 800},
)

In [None]:
# Make predictions. No autograd graph is needed for inference.
with torch.no_grad():
    outputs = model(**inputs)

# PIL reports (width, height); the post-processor expects one (height, width)
# pair per image and uses it to rescale the normalized boxes to absolute pixel
# coordinates on the original image. Without it the boxes stay in [0, 1].
target_sizes = torch.tensor([image.size[::-1]])
# threshold=0.0 keeps every query here, so the score filtering done afterwards
# is what decides which detections are shown (the default would be 0.5).
detections = processor.post_process_object_detection(
    outputs, threshold=0.0, target_sizes=target_sizes
)

In [None]:
# Filter detections based on the threshold, draw them, and summarize per class.
# `detections` holds one dict per image with "scores", "labels" and "boxes"
# tensors as returned by `processor.post_process_object_detection`.
threshold = 0.5
for detection in detections:
    selected_indices = detection["scores"] > threshold

    # Keep only confident detections; detach/move to CPU before converting to
    # NumPy in case the tensors still track gradients or live on a GPU.
    labels = detection["labels"][selected_indices].detach().cpu().numpy()
    scores = detection["scores"][selected_indices].detach().cpu().numpy()
    boxes = detection["boxes"][selected_indices].detach().cpu().numpy()

    # Draw on a copy so `image` itself stays untouched.
    # NOTE: assumes boxes are absolute (x0, y0, x1, y1) pixel coordinates,
    # i.e. post-processing was given `target_sizes`.
    img_with_boxes = image.copy()
    draw = ImageDraw.Draw(img_with_boxes)
    # Walk boxes, labels and scores in lockstep: the i-th score belongs to the
    # i-th box. (Indexing `scores` by the class id picks the wrong value.)
    for box, label, score in zip(boxes, labels, scores):
        box = [round(float(coord), 2) for coord in box]
        draw.rectangle(box, outline="red", width=3)
        draw.text((box[0], box[1]), f"{model.config.id2label[int(label)]}: {score:.2f}", fill="red")

    fig, ax = plt.subplots()
    ax.imshow(img_with_boxes)
    ax.set_title(f"DETR detections (score > {threshold})")
    ax.axis("off")
    plt.show()

    # Count detections per class id (insertion order = first-seen order).
    label_counts = Counter(int(label) for label in labels)

    # Summarize as e.g. "2 persons, 1 chair" (naive plural: appends "s").
    summary = ", ".join(
        f"{count} {model.config.id2label[label]}{'s' if count > 1 else ''}"
        for label, count in label_counts.items()
    )
    print("Summary:", summary or "no objects above threshold")