In [None]:
from PIL import Image
import matplotlib.pyplot as plt
import torch
import json
import requests
import os
import torch
from torchvision import transforms
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection, AutoModelForCausalLM

In [None]:
# Choose the compute backend: CUDA if available, otherwise the MPS backend,
# otherwise CPU. The MPS probe is only evaluated when CUDA is unavailable.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

In [None]:
def plot_results(image, result, category, caption, prompt_image):
    """Show detections on the validation image next to the prompt image.

    Draws one figure with two panels: the left panel is ``image`` with every
    detected box outlined and annotated with ``"<label>: <score>"``; the right
    panel is ``prompt_image``. The figure title shows ``category`` and
    ``caption``. The figure is displayed and then closed.

    Args:
        image: Validation image, anything ``Axes.imshow`` accepts (a PIL image
            in this notebook). Boxes are drawn in its pixel coordinates.
        result: Detection dict with ``"boxes"``, ``"scores"`` and ``"labels"``
            entries, iterated in parallel. Each box must support ``.tolist()``
            yielding ``(xmin, ymin, xmax, ymax)``; each score must support
            ``:.2f`` formatting.
        category: Category identifier shown in the figure title.
        caption: Caption text shown in the figure title.
        prompt_image: Image displayed in the right-hand panel.
    """
    boxes = result["boxes"]
    scores = result["scores"]
    labels = result["labels"]

    fig, (det_ax, prompt_ax) = plt.subplots(1, 2, figsize=(8, 4))
    fig.suptitle(f"Category: {category}\nCaption: {caption}")

    # Left panel: validation image with the detected boxes.
    det_ax.imshow(image)
    det_ax.set_title("Validation Image")
    det_ax.axis("off")

    # The hsv colormap is cyclic (0.0 and 1.0 are both red), so sample n points
    # evenly in [0, 1) instead of linspace(0, 1, n); otherwise the first and
    # last boxes would always share the same colour.
    n_boxes = len(boxes)
    colors = plt.cm.hsv([i / n_boxes for i in range(n_boxes)]).tolist()

    for box, score, label, color in zip(boxes, scores, labels, colors):
        xmin, ymin, xmax, ymax = box.tolist()
        det_ax.add_patch(
            plt.Rectangle(
                (xmin, ymin), xmax - xmin, ymax - ymin,
                fill=False, color=color, linewidth=2,
            )
        )
        det_ax.text(xmin, ymin, f"{label}: {score:.2f}", bbox=dict(facecolor="yellow", alpha=0.5))

    # Right panel: the prompt (cropped) image.
    prompt_ax.imshow(prompt_image)
    prompt_ax.set_title("Prompt Image")
    prompt_ax.axis("off")

    fig.tight_layout()
    plt.show()
    # This function is called once per input line; release the figure so open
    # figures do not accumulate across the loop.
    plt.close(fig)

In [None]:
import json

# --- Grounding DINO detection over every line of the OFA output file ---------
DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny"
CAPTIONS_FILE = "OFA_cropped_output.txt"
BOX_THRESHOLD = 0.4     # score threshold for keeping predicted boxes
TEXT_THRESHOLD = 0.3    # score threshold for keeping text (label) predictions
REQUEST_TIMEOUT_S = 30  # seconds; keeps one unreachable URL from hanging the loop

# DINO
dino_processor = AutoProcessor.from_pretrained(DINO_MODEL_ID)
# NOTE(review): `device` is computed in an earlier cell, but the model and the
# inputs are left on CPU (the .to(device) calls are commented out). Confirm
# whether that is intentional before enabling them.
dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(DINO_MODEL_ID)  # .to(device)

with open(CAPTIONS_FILE, "r") as f:
    lines = f.readlines()
# NOTE(review): never populated in this cell; kept in case later cells read it.
full_results = []
for idx, line in enumerate(lines):
    line = line.strip()
    if not line:
        # Tolerate blank lines (e.g. a trailing newline at end of file).
        continue
    try:
        # Each line: "<image URL> || <caption> || <category id> || <cropped image path>"
        parts = line.split(" || ")
        image_path, caption = parts[0], parts[1]
        category = int(parts[2])
        cropped_image_path = parts[3]

        # copy() loads the pixels, so the file handle can be closed immediately
        # instead of staying open for the lifetime of the image object.
        with Image.open(cropped_image_path) as cropped:
            prompt_image = cropped.copy()

        # Close the HTTP connection once the image is decoded; fail fast on
        # HTTP errors rather than letting PIL choke on an error page.
        with requests.get(image_path, stream=True, timeout=REQUEST_TIMEOUT_S) as response:
            response.raise_for_status()
            # .raw is undecoded; let urllib3 undo any gzip/deflate transfer encoding.
            response.raw.decode_content = True
            image = Image.open(response.raw).convert("RGB")

        # NOTE(review): "null" looks like a distractor label next to the caption; confirm.
        text_labels = [[caption, "null"]]

        # Process with DINO
        dino_inputs = dino_processor(images=image, text=text_labels, return_tensors="pt")  # .to(device)
        with torch.no_grad():
            dino_outputs = dino_model(**dino_inputs)

        # Post-process
        results = dino_processor.post_process_grounded_object_detection(
            dino_outputs,
            dino_inputs.input_ids,
            box_threshold=BOX_THRESHOLD,
            text_threshold=TEXT_THRESHOLD,
            # PIL's size is (width, height); post-processing expects (height, width).
            target_sizes=[image.size[::-1]],
        )

        # Plot output and prompt image
        plot_results(image, results[0], category, caption, prompt_image)
    except Exception as e:
        # Best-effort: keep processing the remaining lines, but report which
        # line failed and with what kind of error.
        print(f"line {idx + 1}: {type(e).__name__}: {e}")