# Annotate paintings
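This notebook uses an LLM to extract the objects mentioned in each painting's description, grounds (localizes) them in the image with Grounding DINO, and evaluates both steps against the annotations (micro-F1 for objects and description spans, mean average precision for bounding boxes).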

### 0. Import libraries and set the configuration

In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
from tqdm import tqdm

from call_llm import *
from ground_objects import *
from compute_metrics import *
from annotate_paintings_utils import *

In [None]:
# Model identifiers used by the loading cell below.
GEMINI_MODEL = "gemini-2.0-flash"  # LLM passed to `generate` to extract objects (and spans)
GROUNDING_MODEL_ID = "IDEA-Research/grounding-dino-base"  # checkpoint id given to `get_grounding_model`
SENTENCE_SIMILARITY_MODEL_NAME = "all-mpnet-base-v2"  # model used when scoring predicted spans

### 1. Load models and data

In [None]:
# Select the device returned by get_device(); it is reused for the grounding model,
# for box post-processing and for the mAP computation.
device = get_device()

# Load the models: the LLM client, the Grounding DINO processor/model for `device`,
# and the sentence-similarity model used to score predicted description spans.
llm_client = get_llm_client()
grounding_processor, grounding_model = get_grounding_model(GROUNDING_MODEL_ID, device)
sentence_similarity_model = get_sentence_similarity_model(SENTENCE_SIMILARITY_MODEL_NAME)

# Load the dataset: painting metadata, annotations, the few-shot examples fed to the
# LLM prompt, and the test paintings evaluated in section 2.
paintings_data, annotations, few_shot_examples, test_paintings = load_data()

# Ground-truth boxes and (presumably) a label -> id mapping used to match predicted boxes.
# A painting that is absent from the box annotations has no annotated boxes.
ground_truth_bboxes, labels_to_ids = get_bbox_annotations()

### 2. Experiment with annotation prompts

In [None]:
# Settings for the prompt experiment run below.
verbose, prompt_type, observations = (
    False,  # forwarded as `verbose` to the helper functions
    "basic_with_spans",  # prompt variant passed to `generate`
    "first trial with few-shot learning",  # free-text note describing this run
)

In [None]:
# TODO: handle empty detections for a painting
# TODO: store everything that's needed inside the results file (e.g. spans)
# TODO: handle if the LLM output is None

In [None]:
# Run the full pipeline on the test paintings and score it:
#   1. the LLM extracts the described objects (and, for span prompts, their description spans),
#   2. Grounding DINO localizes the predicted objects in the image,
#   3. objects/spans are scored with micro-F1 and boxes with mean average precision.

# Number of test paintings to evaluate; set to None to run on the whole test split.
N_PAINTINGS = 1
# Thresholds forwarded to detect_objects (object_threshold / text_threshold).
OBJECT_THRESHOLD = 0.3
TEXT_THRESHOLD = 0.3

# [tp, fp, fn] counters; compute_f1's return value is not used, so it presumably
# updates these lists in place.
tp_fp_fn_objects = [0, 0, 0]
tp_fp_fn_spans = [0, 0, 0]
total_token_count = 0

painting_ids = []

all_predicted_objects = []
all_ground_truth_objects = []

all_predicted_bboxes = []
all_ground_truth_bboxes = []

# Span quality scores filled by compute_spans_quality; replaced by their means after the loop.
span_similarity_metrics = {
    "cosine similarity": [],
    "Levenshtein distance": [],
    "delete percentage": [],
    "false positive percentage": [],
    "coverage percentage": [],
}

# True as soon as any painting's LLM output carries description spans. Initialized
# here so the post-loop code also works when no painting is processed.
spans_are_extracted = False

for painting in tqdm(test_paintings[:N_PAINTINGS]):
    painting_id = painting["painting_id"]
    painting_ids.append(painting_id)

    description = painting["description"]
    image = load_image(painting_id)

    # extract described objects
    llm_output, token_count = generate(
        llm_client,
        few_shot_examples,
        image,
        description,
        prompt_type,
        GEMINI_MODEL,
        verbose,
    )
    total_token_count += token_count
    # The LLM output may be None or empty; only inspect the first item when one exists.
    painting_has_spans = bool(llm_output) and "description_spans" in llm_output[0].__dict__
    spans_are_extracted = spans_are_extracted or painting_has_spans

    # handle objects
    predicted_objects, ground_truth_objects = process_objects(
        llm_output, painting, all_predicted_objects, all_ground_truth_objects, verbose
    )
    compute_f1(predicted_objects, ground_truth_objects, tp_fp_fn_objects)

    # handle spans (only for this painting if its output contains them)
    if painting_has_spans:
        (
            predicted_spans_per_object,
            ground_truth_spans_per_object,
            predicted_spans,
            ground_truth_spans,
        ) = process_spans(llm_output, painting)
        compute_spans_quality(
            ground_truth_spans_per_object,
            predicted_spans_per_object,
            span_similarity_metrics,
            sentence_similarity_model,
            verbose,
        )
        compute_f1(predicted_spans, ground_truth_spans, tp_fp_fn_spans)

    # ground objects (`results` is kept for interactive inspection)
    labels_scores_boxes, results = detect_objects(
        image,
        predicted_objects,
        grounding_processor,
        grounding_model,
        device,
        verbose,
        object_threshold=OBJECT_THRESHOLD,
        text_threshold=TEXT_THRESHOLD,
    )

    get_bounding_boxes(
        labels_scores_boxes,
        labels_to_ids,
        ground_truth_bboxes,
        painting_id,
        all_predicted_bboxes,
        all_ground_truth_bboxes,
        device,
    )

micro_f1_objects = compute_micro_f1(tp_fp_fn_objects, "objects", verbose)

# Reset explicitly so a value from a previous run of this cell cannot leak through.
micro_f1_spans = None
if spans_are_extracted:
    micro_f1_spans = compute_micro_f1(tp_fp_fn_spans, "spans", verbose)
    # Collapse each list of scores to its mean; an empty list becomes NaN without
    # numpy's "Mean of empty slice" RuntimeWarning.
    for metric, values in span_similarity_metrics.items():
        span_similarity_metrics[metric] = np.array(values).mean() if len(values) else float("nan")

map_50, map_50_95 = compute_mean_average_precision(
    all_predicted_bboxes, all_ground_truth_bboxes, device, verbose
)
print(f"Total token count: {total_token_count}")

In [None]:
# store results for the tested prompt
# results_values = list(zip(painting_ids, all_predicted_objects, all_ground_truth_objects))
# store_results(micro_f1_objects, results_values, prompt_type, observations)

TODO: compute the F1 score for spans