In [None]:
import jsonlines
import json
import requests
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
from io import BytesIO

def draw_boxes(image, predictions):
    """Overlay predicted boxes and their matched ground-truth boxes on *image*.

    Predicted boxes are outlined in blue, in red when ``is_hallucination`` is
    set, and in yellow when ``is_misclassification`` is set. Yellow wins when
    both flags are set. Each predicted box is labelled "<object_name> (P)".
    The ground-truth box a prediction was matched to (``bbox_match_box`` /
    ``bbox_match_object``) is outlined in green and labelled "<name> (GT)".

    Args:
        image: PIL image to draw on. RGB/RGBA images are drawn on in place.
            Any other mode is first converted to RGB so the colours render.
        predictions: iterable of dicts with keys 'bounding_box', 'object_name',
            'is_hallucination' and 'is_misclassification'. They may also carry
            'bbox_match_box' / 'bbox_match_object'. Boxes are passed straight
            to ``ImageDraw.rectangle``, so they are assumed to be
            [x0, y0, x1, y1] pixel corners. TODO confirm the pipeline does not
            emit COCO-style [x, y, w, h].

    Returns:
        The annotated image. It is the same object as *image* unless a mode
        conversion was needed.
    """
    # On single-channel or palette images, named colours collapse to a gray
    # level or palette index, which destroys the colour coding.
    if image.mode not in ("RGB", "RGBA"):
        image = image.convert("RGB")
    draw = ImageDraw.Draw(image)

    # The predictions are walked twice below; materialise them so that a
    # one-shot iterator still gets its GT boxes drawn.
    predictions = list(predictions)

    # Draw predicted bounding boxes
    for pred in predictions:
        bbox = [int(coord) for coord in pred['bounding_box']]
        if pred['is_misclassification']:
            color = "yellow"
        elif pred['is_hallucination']:
            color = "red"
        else:
            color = "blue"
        draw.rectangle(bbox, outline=color, width=2)
        draw.text((bbox[0], bbox[1]), f"{pred['object_name']} (P)", fill=color)

    # Draw ground-truth boxes on top of the predictions. Several predictions
    # can match the same GT box, so each distinct box/label is drawn once.
    # Predictions without a match ('bbox_match_box' missing or None) are
    # skipped instead of crashing in int(None). That is presumably the case
    # for hallucinations; confirm.
    drawn_gt = set()
    for pred in predictions:
        match_box = pred.get('bbox_match_box')
        if match_box is None:
            continue
        bbox = [int(coord) for coord in match_box]
        text = f"{pred['bbox_match_object']} (GT)"
        key = (tuple(bbox), text)
        if key in drawn_gt:
            continue
        drawn_gt.add(key)
        draw.rectangle(bbox, outline="green", width=2)
        draw.text((bbox[0], bbox[1]), text, fill="green")

    return image

data = []
with jsonlines.open('pipeline_outputs/bbox_hallucinations_hth_0.5_mth_0.5.jsonl') as reader:
    for obj in reader:
        data.append(obj)

# Load the COCO-style label file and map each image file name to its download
# URL. Read it explicitly as UTF-8 (the JSON interchange encoding) rather than
# relying on the platform's locale encoding.
with open('data/bbox_pope_images/labels.json', encoding='utf-8') as f:
    labels = json.load(f)

# NOTE(review): the rendering loop looks this map up with each record's
# 'question_id'. Presumably those ids are the image file names; confirm.
image_urls = {img['file_name']: img['coco_url'] for img in labels['images']}

# Bucket the records by image id ('question_id'). A record is kept only when
# it is neither a hallucination nor a misclassification. Every image id still
# gets an entry, possibly an empty list, in order of first appearance.
images_data = {}
for record in data:
    bucket = images_data.setdefault(record['question_id'], [])
    flagged = record['is_hallucination'] or record['is_misclassification']
    if not flagged:
        bucket.append(record)

# For each image with at least one kept prediction: download the image from
# its URL, overlay predicted and ground-truth boxes, and display it.
REQUEST_TIMEOUT_S = 30  # seconds; without a timeout requests.get can block forever

for image_id, predictions in images_data.items():
    if not predictions:  # Skip images with no valid predictions
        continue

    image_url = image_urls.get(image_id)
    if not image_url:
        print(f"Image URL for {image_id} not found.")
        continue

    try:
        response = requests.get(image_url, timeout=REQUEST_TIMEOUT_S)
    except requests.RequestException as err:
        # A network failure (DNS, connection reset, timeout) skips this image
        # instead of aborting the whole visualisation run.
        print(f"Failed to download image {image_id}: {err}")
        continue
    if response.status_code != 200:
        print(f"Failed to download image {image_id}")
        continue

    # Convert to RGB so the coloured outlines also render on grayscale or
    # palette images.
    image = Image.open(BytesIO(response.content)).convert("RGB")

    image_with_boxes = draw_boxes(image, predictions)

    plt.figure(figsize=(12, 8))
    plt.imshow(image_with_boxes)
    plt.axis('off')
    plt.title(image_id)
    plt.show()
