# Annotate paintings
This notebook leverages an LLM to extract the objects that are described in a painting's textual description and also appear in the painting, and then grounds them in the image with Grounding DINO.

### 0. Import libraries and set the configuration

In [None]:
import io
import os
import json
import random

import torch
import polars as pl
from PIL import Image
from google import genai
from google.genai import types
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

In [None]:
# identifiers of the LLM and of the grounding checkpoint
GEMINI_MODEL = "gemini-2.0-flash"
GROUNDING_MODEL_ID = "IDEA-Research/grounding-dino-base"

# data locations (relative paths, each one ends with a slash)
RAW_DATA_PATH = "../../data/raw/"
ANNOTATIONS_PATH = "../../data/annotations/"
INTERMEDIATE_DATA_PATH = "../../data/intermediate/filtered_paintings/"

# run on the GPU only when torch reports that CUDA is available
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

In [None]:
# Read the Gemini API key from the local JSON config and publish it as an
# environment variable. The encoding is passed explicitly so that decoding the
# file does not depend on the platform's default locale encoding.
with open("../../config/keys.json", "r", encoding="utf-8") as file:
    os.environ["GEMINI_API_KEY"] = json.load(file)["gemini_api_key"]

### 1. Extract the described objects with an LLM

In [None]:
def show_available_models(client):
    """Print the short names of the models that support ``generateContent``.

    Args:
        client: Object exposing ``models.list()``. Every listed model is read
            for its ``name`` and ``supported_actions`` attributes.
    """
    existing_models = []

    for model in client.models.list():
        # supported_actions may be None; such a model supports nothing we need
        if "generateContent" not in (model.supported_actions or ()):
            continue

        # keep the segment after the first "/" of the name, or the whole name
        # when it has no "/" at all (instead of failing with an IndexError)
        name_parts = model.name.split("/")
        existing_models.append(name_parts[1] if len(name_parts) > 1 else model.name)

    print(existing_models)

In [None]:
def image_to_bytes(image):
    """Serialize an image into raw bytes.

    The image is encoded with its own ``format`` attribute; when that
    attribute is unset or empty, PNG is used instead.

    Args:
        image: Object with a ``format`` attribute and a
            ``save(stream, format=...)`` method (e.g. a PIL image).

    Returns:
        bytes: The encoded image.
    """
    target_format = image.format or "PNG"

    # encode into an in-memory stream and hand back everything written to it
    with io.BytesIO() as buffer:
        image.save(buffer, format=target_format)
        return buffer.getvalue()

In [None]:
def _image_part(image):
    """Wrap a PIL image as an inline-bytes part labelled with its actual MIME type."""
    # image_to_bytes encodes with the image's own format (PNG when unset), so
    # the MIME type follows the same rule instead of always claiming PNG
    mime_type = Image.MIME.get(image.format or "PNG", "image/png")

    return types.Part.from_bytes(mime_type=mime_type, data=image_to_bytes(image))


def generate(client, examples, image, description):
    """Ask the LLM which described objects also appear in the painting.

    The prompt is a few-shot conversation: every example contributes a user
    turn (painting + description) and a model turn (the expected object list),
    followed by the user turn with the painting to annotate.

    Args:
        client: Gemini client exposing ``models.generate_content``.
        examples: Iterable of dicts with the keys ``"painting_id"``,
            ``"description"`` and ``"object_name"`` (an iterable of strings).
            The painting of each example is read from
            ``{RAW_DATA_PATH}filtered_paintings/{painting_id}.png``.
        image: PIL image of the painting to annotate.
        description: Textual description of the painting to annotate.

    Returns:
        tuple: ``(prompt_parts, generate_content_config, output)`` where
        ``prompt_parts`` is the list of conversation turns that was sent,
        ``generate_content_config`` the generation config and ``output`` the
        text of the response.
    """
    prompt_parts = [
        types.Content(role="user", parts=[types.Part.from_text(text="\nHere are some examples:")])
    ]

    for example in examples:
        example_painting_id = example["painting_id"]
        example_description = example["description"]
        example_detected_objects = ", ".join(example["object_name"])

        # the context manager releases the file handle once the image is serialized
        with Image.open(f"{RAW_DATA_PATH}filtered_paintings/{example_painting_id}.png") as example_image:
            example_image_part = _image_part(example_image)

        prompt_parts.append(types.Content(role="user", parts=[
            example_image_part,
            types.Part.from_text(text=f"Description: \"{example_description}\"")
        ]))
        prompt_parts.append(types.Content(role="model", parts=[
            types.Part.from_text(text=f"Detected objects: {example_detected_objects}")
        ]))
        # separator between examples, wrapped as a user turn so that the list
        # holds only Content objects
        prompt_parts.append(types.Content(role="user", parts=[types.Part.from_text(text="---")]))

    prompt_parts.append(types.Content(role="user", parts=[
        _image_part(image),
        types.Part.from_text(text=f"Description: \"{description}\"\n\nList separated by comma of **ONLY** the objects (lowercased) described in the given painting that also appear in the textual description.")
    ]))

    generate_content_config = types.GenerateContentConfig(
        response_mime_type="text/plain",
        system_instruction=[
            types.Part.from_text(text="""You are an expert in art who can identify objects present in both a painting and its textual description."""),
        ],
    )

    response = client.models.generate_content(
        model=GEMINI_MODEL,
        contents=prompt_parts,
        config=generate_content_config,
    )

    output = response.text
    prompt_tokens_count = response.usage_metadata.prompt_token_count
    output_tokens_count = response.usage_metadata.candidates_token_count

    print(f"Prompt tokens count: {prompt_tokens_count}\nOutput tokens count: {output_tokens_count}")
    print(f"Response:\n{output}\n")

    return prompt_parts, generate_content_config, output

### 2. Ground Objects with Grounding DINO

In [None]:
# load the Grounding DINO processor and model from the same checkpoint
processor = AutoProcessor.from_pretrained(GROUNDING_MODEL_ID)
model = AutoModelForZeroShotObjectDetection.from_pretrained(GROUNDING_MODEL_ID)

# place the model on the configured device
model = model.to(DEVICE)

In [None]:
def display_annotated_image(image, labels_scores_boxes):
    """Display the image with one labelled bounding box per detection.

    The boxes are drawn on a copy, so the image passed in is left untouched
    and can be reused (e.g. for another detection run) without the boxes.

    Args:
        image: PIL image to annotate.
        labels_scores_boxes: Iterable of ``(label, score, coords)`` tuples;
            ``coords`` is passed to ``ImageDraw.rectangle`` and its first two
            values are used as the top-left anchor of the caption.
    """
    font = ImageFont.truetype("../../config/alata-regular.ttf", 18)

    # draw on a copy so the caller's image is not modified
    annotated_image = image.copy()
    # "RGBA" drawing mode, so the alpha byte appended to each color is honored
    draw = ImageDraw.Draw(annotated_image, "RGBA")

    for label, score, coords in labels_scores_boxes:
        # random RGB color followed by a fixed alpha of 0x80
        random_color = "#{:06x}".format(random.randint(0, 0xFFFFFF)) + "80"
        text_position = (coords[0] + 10, coords[1] + 5)

        draw.rectangle(coords, outline=random_color, width=5)
        draw.text(text_position, label + " " + str(round(score, 2)), fill=random_color, font=font)

    # `display` is not imported in this file; presumably the notebook (IPython) built-in
    display(annotated_image)

In [None]:
def detect_objects(image, text, processor, model, object_threshold=0.3, text_threshold=0.3):
    """Ground the objects named in ``text`` in the image, then print and display them.

    Args:
        image: PIL image in which the objects are searched.
        text: Text query handed to the processor.
        processor: Grounding processor; used to build the model inputs and to
            post-process the model outputs.
        model: Zero-shot object detection model.
        object_threshold: Threshold for filtering object detection predictions
            (lower -> more bounding boxes).
        text_threshold: Threshold for filtering text detection predictions
            (lower -> the input text is taken exactly).

    Returns:
        list: ``(label, score, box_coordinates)`` tuples sorted by ascending
        score, where ``box_coordinates`` is a list of coordinates.
    """
    # follow the device of the model that was passed in, so inputs and weights
    # always live on the same device
    inputs = processor(images=image, text=text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model(**inputs)

    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        # threshold for filtering object detection predictions (lower -> more bounding boxes)
        threshold=object_threshold,
        # threshold for filtering text detection predictions (lower -> the input text is taken exactly)
        text_threshold=text_threshold,
        # PIL reports (width, height); the reversed tuple is passed as the target size
        target_sizes=[image.size[::-1]],
    )

    # a single image was submitted, so exactly one result is expected
    assert len(results) == 1, f"expected 1 detection result, got {len(results)}"

    labels = results[0]["text_labels"]
    scores = results[0]["scores"].cpu().numpy()
    box_coordinates = [list(coords) for coords in results[0]["boxes"].cpu().numpy()]
    # ascending by score, so the most confident detections are printed and drawn last
    labels_scores_boxes = sorted(zip(labels, scores, box_coordinates), key=lambda x: x[1])

    for label, score, coords in labels_scores_boxes:
        print(label, float(score), [float(coord) for coord in coords])

    display_annotated_image(image, labels_scores_boxes)

    return labels_scores_boxes

### 3. Process painting

In [None]:
# metadata of the filtered paintings
paintings_data = pl.read_json(INTERMEDIATE_DATA_PATH + "filtered_paintings_enhanced_data.json")

# Gemini client, authenticated with the key stored in the environment
gemini_api_key = os.environ.get("GEMINI_API_KEY")
client = genai.Client(api_key=gemini_api_key)

In [None]:
# paintings whose manual annotations serve as few-shot examples
few_shot_example_ids = [2156, 2484, 11819, 256, 10748, 3344, 10676]
annotations = pl.read_json(ANNOTATIONS_PATH + "manual_annotations.json")
few_shots_descriptions = paintings_data.filter(pl.col("id").is_in(few_shot_example_ids)).select("id", "description").rename({"id": "painting_id"})

# one row per example painting, with its object names and description spans
# collected into lists
few_shot_examples = (
    annotations.filter(pl.col("painting_id").is_in(few_shot_example_ids))
    # keep the groups in the order in which the paintings first appear in the
    # annotations, so the few-shot prompt is the same on every run
    .group_by("painting_id", maintain_order=True)
    .agg(pl.col("object_name"), pl.col("description_spans"))
)

# attach the description of every example painting
few_shot_examples = few_shot_examples.join(few_shots_descriptions, on="painting_id").to_dicts()

In [None]:
painting_id = 2461

# Look the description up through the "id" column (the same key the few-shot
# examples and the image file names use) rather than by row position.
# `.item()` fails loudly unless exactly one painting has this id.
description = paintings_data.filter(pl.col("id") == painting_id).select("description").item()
image = Image.open(f"{RAW_DATA_PATH}filtered_paintings/{painting_id}.png")

# extract described objects
prompt_parts, generate_content_config, output = generate(client, few_shot_examples, image, description)

# Normalize the answer before grounding: the few-shot replies start with
# "Detected objects:", so the model may echo that prefix; surrounding
# whitespace, trailing dots and empty items are dropped as well.
answer = output.strip()
answer_prefix = "Detected objects:"
if answer.startswith(answer_prefix):
    answer = answer[len(answer_prefix):]

object_names = [
    name.strip().rstrip(".").strip().lower()
    for name in answer.split(",")
    if name.strip().rstrip(".").strip()
]

# ground objects (the query is a list of lowercased phrases, each closed by a dot)
text = ". ".join(object_names) + "."
labels_scores_boxes = detect_objects(
    image, text, processor, model, object_threshold=0.3, text_threshold=0.3
)