# Perform experiments
This notebook is used to perform different experiments, such as finding the optimal hyperparameters for Grounding DINO.

### 0. Import libraries

In [None]:
%load_ext autoreload
%autoreload 2

import re
import numpy as np
import polars as pl
from tqdm import tqdm
from itertools import product

from ground_objects import *
from compute_metrics import *
from annotate_paintings_utils import *

In [None]:
# Directory locations used throughout the notebook. They are relative paths and
# every one ends with "/" because file names are appended to them directly.

# directory that holds "manual_annotations.json"
ANNOTATIONS_PATH = "../../data/annotations/"
# directory that holds "grounding_dino_hyperparameters.csv"
GROUNDING_RESULTS_PATH = "../../experiments/grounding/"

# directory that holds "filtered_paintings_enhanced_data.json"
INTERMEDIATE_DATA_PATH = "../../data/intermediate/filtered_paintings/"

### 1. Find the optimal hyperparameters for Grounding DINO

In [None]:
# painting ids that make up the mini validation split
mini_val_set_ids = [
    91, 256, 517, 518, 549, 723, 1006, 1056,
    1067, 1667, 1722, 1753, 1799, 1823, 1966, 2024,
    2156, 2225, 2441, 2484, 2628, 2737, 3344, 3687,
    4416, 5551, 6660, 8316, 8344, 10748, 11422, 11819,
]
# painting ids that make up the mini test split
mini_test_set_ids = [
    1466, 2461, 2549, 2613, 2657,
    2702, 8197, 9053, 9821, 11593,
]

In [None]:
def _objects_per_painting(frame, painting_ids):
    """Return a NumPy array with one (painting_id, object names) row per selected painting."""
    selected = frame.filter(pl.col("painting_id").is_in(painting_ids))
    grouped = selected.group_by("painting_id").agg(pl.col("object_name"))
    return grouped.to_numpy()


# manual annotations, with commas stripped from the object names
annotations = pl.read_json(ANNOTATIONS_PATH + "manual_annotations.json").with_columns(
    pl.col("object_name").str.replace_all(",", "", literal=True)
)

mini_val_set = _objects_per_painting(annotations, mini_val_set_ids)
mini_test_set = _objects_per_painting(annotations, mini_test_set_ids)

In [None]:
# get the device type the model will run on
device = get_device()

# load the grounding processor and model for that device
grounding_processor, grounding_model = get_grounding_model(device)

# load the ground truth bounding boxes and the label -> id mapping;
# if an image is not included, it doesn't have annotations
ground_truth_bboxes, labels_to_ids = get_bbox_annotations()

In [None]:
def ground_all_objects(
    dataset,
    ground_truth_bboxes,
    labels_to_ids,
    grounding_processor,
    grounding_model,
    device,
    object_threshold=0.3,
    text_threshold=0.3,
):
    """Ground the annotated objects of every painting in `dataset` and score the result.

    Args:
        dataset: iterable of `(painting_id, objects)` pairs.
        ground_truth_bboxes: iterable of annotation mappings that expose an
            "image_id" key; also forwarded to `get_bounding_boxes`.
        labels_to_ids: forwarded to `get_bounding_boxes`.
        grounding_processor: forwarded to `detect_objects`.
        grounding_model: forwarded to `detect_objects`.
        device: forwarded to `detect_objects`, `get_bounding_boxes` and
            `compute_mean_average_precision`.
        object_threshold: forwarded to `detect_objects`.
        text_threshold: forwarded to `detect_objects`.

    Returns:
        `(map_50, mape)` where `map_50` is the first value returned by
        `compute_mean_average_precision` and `mape` is the mean, over the
        paintings that have at least one ground truth box, of
        `|#ground truth boxes - #predicted boxes| / #ground truth boxes`.
        `mape` is NaN when no painting has ground truth boxes.
        `(None, None)` is returned as soon as a prediction with an empty label
        is encountered.
    """
    mape = 0
    paintings_no = 0
    all_predicted_bboxes = []
    all_ground_truth_bboxes = []

    for painting_id, objects in dataset:
        _, image = load_image(painting_id)

        labels_scores_boxes, _ = detect_objects(
            image,
            objects,
            grounding_processor,
            grounding_model,
            device,
            object_threshold=object_threshold,
            text_threshold=text_threshold,
        )

        # a prediction without a label invalidates the whole run for these thresholds
        if any(pred[0] == "" for pred in labels_scores_boxes):
            return None, None

        # appends to all_predicted_bboxes / all_ground_truth_bboxes in place
        get_bounding_boxes(
            labels_scores_boxes,
            labels_to_ids,
            ground_truth_bboxes,
            painting_id,
            all_predicted_bboxes,
            all_ground_truth_bboxes,
            device,
        )

        gt_bboxes_no = sum(
            1 for gt_bboxes in ground_truth_bboxes if gt_bboxes["image_id"] == painting_id
        )
        pred_bboxes_no = len(labels_scores_boxes)

        # paintings without ground truth cannot contribute a relative error
        if gt_bboxes_no != 0:
            mape += abs(gt_bboxes_no - pred_bboxes_no) / gt_bboxes_no
            paintings_no += 1

    map_50, _ = compute_mean_average_precision(
        all_predicted_bboxes, all_ground_truth_bboxes, device
    )

    # avoid ZeroDivisionError when no painting had ground truth boxes
    mape = mape / paintings_no if paintings_no else float("nan")

    return map_50, mape

In [None]:
# grid of thresholds to evaluate (both ranges go from 0.30 to 0.35 in steps of 0.01)
object_thresholds = np.arange(0.3, 0.351, 0.01)
text_thresholds = np.arange(0.3, 0.351, 0.01)
grounding_results = {
    "object_threshold": [],
    "text_threshold": [],
    "map_50_val": [],
    "mape_val": [],
    "map_50_test": [],
    "mape_test": [],
}

for object_threshold, text_threshold in tqdm(list(product(object_thresholds, text_thresholds))):
    map_50_val, mape_val = ground_all_objects(
        mini_val_set,
        ground_truth_bboxes,
        labels_to_ids,
        grounding_processor,
        grounding_model,
        device,
        object_threshold,
        text_threshold,
    )

    # ground_all_objects returns (None, None) when a prediction has an empty label
    if map_50_val is None:
        continue

    map_50_test, mape_test = ground_all_objects(
        mini_test_set,
        ground_truth_bboxes,
        labels_to_ids,
        grounding_processor,
        grounding_model,
        device,
        object_threshold,
        text_threshold,
    )

    if map_50_test is None:
        continue

    # only record combinations that produced metrics on both splits
    grounding_results["object_threshold"].append(object_threshold)
    grounding_results["text_threshold"].append(text_threshold)
    grounding_results["map_50_val"].append(map_50_val)
    grounding_results["mape_val"].append(mape_val)
    grounding_results["map_50_test"].append(map_50_test)
    grounding_results["mape_test"].append(mape_test)

# built once, after the loop, so the frame exists even if every combination was skipped
grounding_results_df = pl.from_dict(grounding_results)
# NOTE(review): saving is disabled, so the CSV read later in the notebook is
# whatever a previous run wrote to disk.
# grounding_results_df.write_csv(f"{GROUNDING_RESULTS_PATH}grounding_dino_hyperparameters.csv")

grounding_results_df

In [None]:
# per split: 1 - MAPE, and the mean of mAP@50 and 1 - MAPE
one_minus_mape_val = (1 - pl.col("mape_val")).alias("1-mape_val")
avg_metric_val = ((pl.col("map_50_val") + pl.col("1-mape_val")) / 2).alias("avg metric val")
one_minus_mape_test = (1 - pl.col("mape_test")).alias("1-mape_test")
avg_metric_test = ((pl.col("map_50_test") + pl.col("1-mape_test")) / 2).alias("avg metric test")

# each derived column is added in its own step because the averages read the
# "1-mape_*" columns created just before them
grounding_results_df = (
    pl.read_csv(f"{GROUNDING_RESULTS_PATH}grounding_dino_hyperparameters.csv")
    .with_columns(one_minus_mape_val)
    .with_columns(avg_metric_val)
    .with_columns(one_minus_mape_test)
    .with_columns(avg_metric_test)
    .drop("mape_val", "mape_test")
)

# best validation score first
grounding_results_df.sort("avg metric val", descending=True)

In [None]:
# Evaluate a single (object_threshold, text_threshold) combination.
# NOTE(review): the result is stored in `map_50_test`, but the call evaluates
# `mini_val_set` - confirm which split is intended.
map_50_test, diff_obj_no = ground_all_objects(
    dataset=mini_val_set,
    ground_truth_bboxes=ground_truth_bboxes,
    labels_to_ids=labels_to_ids,
    grounding_processor=grounding_processor,
    grounding_model=grounding_model,
    device=device,
    object_threshold=0.34,
    text_threshold=0.32,
)

### 2. Analyze what to remove from descriptions to clean them

In [None]:
# painting ids that appear in the manual annotations (one entry per annotation row)
annotated_painting_ids = (
    pl.read_json(ANNOTATIONS_PATH + "manual_annotations.json")
    .get_column("painting_id")
    .to_list()
)

# (source, description) rows of the annotated paintings, as a NumPy array
descriptions = (
    pl.read_json(f"{INTERMEDIATE_DATA_PATH}filtered_paintings_enhanced_data.json")
    .filter(pl.col("id").is_in(annotated_painting_ids))
    .select(["source", "description"])
    .to_numpy()
)

In [None]:
def print_in_chunks(text, chunk_size=24):
    """Print `text` one line at a time, with at most `chunk_size` words per line.

    Words are the whitespace-separated tokens of `text`; they are re-joined
    with single spaces. Nothing is printed when `text` contains no words.
    """
    tokens = text.split()
    offsets = range(0, len(tokens), chunk_size)
    lines = (" ".join(tokens[offset:offset + chunk_size]) for offset in offsets)
    for line in lines:
        print(line)

In [None]:
def clean_description(text):
    """Strip BBCode-style markup and URLs from a description.

    Args:
        text: the raw description string.

    Returns:
        The description with `[url ...]...[/url]` wrappers replaced by their
        inner text, `[i]`/`[b]`/`[u]` tags and raw URLs removed, whitespace
        runs collapsed to a single space and the ends stripped.
    """
    # unwrap [url href=...]label[/url] and [url=...]label[/url], keep the label;
    # this has to happen before raw URLs are removed: the URL pattern consumes
    # everything up to the next whitespace, which would otherwise swallow the
    # closing "]" together with the first word of the label
    text = re.sub(r'\[url(?: href)?=.*?\](.*?)\[/url\]', r'\1', text)

    # remove [i], [/i], [b], [/b], [u], [/u]
    text = re.sub(r'\[/?[ibu]\]', '', text)

    # remove raw URLs
    text = re.sub(r'http[s]?://\S+|www\.\S+', '', text)

    # remove remaining (unpaired) url tags
    text = re.sub(r'\[/url\]', '', text)
    text = re.sub(r'\[url=?', '', text)

    # collapse multiple spaces and strip whitespace
    text = re.sub(r'\s+', ' ', text).strip()

    return text

In [None]:
# show, for every description, whether cleaning changed it and, if so, a before/after view
for source, description in descriptions:
    cleaned_description = clean_description(description)
    is_unchanged = description == cleaned_description

    print(f"---{source}--- {is_unchanged}")

    if is_unchanged:
        continue

    print_in_chunks(description)
    print("+++")
    print_in_chunks(cleaned_description)
    print("\n")