# Prepare data for fine-tuning
Prepare data for fine-tuning the Open Grounding DINO model.

### 0. Import libraries and load data

In [None]:
import os
import re
import sys
import copy
import json
import shutil
import polars as pl
from tqdm import tqdm
from transformers import AutoTokenizer
from sklearn.model_selection import train_test_split

# Input locations. These are relative paths, so they are resolved against the
# current working directory of the notebook kernel.
RAW_DATA_PATH = "../../data/raw/"
PROCESSED_DATA_PATH = "../../data/processed/"

# Active configuration: captions and region phrases are built from the object
# labels, and the generated splits are written under STORAGE_PATH.
TEXT_INCLUDED = "labels"
STORAGE_PATH = "../../data/fine-tuning/"

# Alternative configuration: build the text from the object descriptions instead.
# Uncomment both lines below (and comment out the pair above) to switch.
# TEXT_INCLUDED = "descriptions"
# STORAGE_PATH = "../../data/baseline/"

# Make the modules in the sibling annotation directory importable.
sys.path.append("../annotate_dataset/")

# NOTE(review): `load_image` is used further down but never defined in this
# notebook; it presumably comes from this star import - confirm.
from annotate_paintings_utils import *

In [None]:
# Load the per-painting annotations. The file is read explicitly as UTF-8 so the
# result does not depend on the platform's default encoding; the JSONL output
# written later in this notebook is UTF-8 as well.
with open(f"{PROCESSED_DATA_PATH}paintings_with_filtered_objects.json", encoding="utf-8") as f:
    all_annotations = json.load(f)

### 1. Create the stratified train/validation/test sets

In [None]:
# Painting-level metadata only: the nested "objects" payload is not needed to
# decide which split a painting belongs to.
paintings_data = pl.from_dicts(
    [
        {key: value for key, value in annotation.items() if key != "objects"}
        for annotation in all_annotations
    ]
)
paintings_data = paintings_data.with_columns((pl.col("year") // 100).alias("century"))
stratify_cols = ["century", "source", "coarse_type", "first_style"]

# Cast every stratification column to a string and replace nulls with "" so the
# columns can be joined into one composite key. The native `concat_str`
# expression builds that key without calling back into Python for every row.
paintings_data = paintings_data.with_columns(
    [pl.col(col).cast(pl.Utf8).fill_null("") for col in stratify_cols]
).with_columns(
    pl.concat_str([pl.col(col) for col in stratify_cols], separator="_").alias("stratify_key")
)

# Composite keys shared by fewer than 10 paintings are pooled into a single
# "other" stratum. NOTE(review): presumably this keeps every stratum large
# enough to survive both stratified splits below - confirm the threshold.
rare_stratify_keys = (
    paintings_data["stratify_key"]
    .value_counts()
    .filter(pl.col("count") < 10)["stratify_key"]
    .to_list()
)
paintings_data = paintings_data.with_columns(
    pl.when(pl.col("stratify_key").is_in(rare_stratify_keys))
    .then(pl.lit("other"))
    .otherwise(pl.col("stratify_key"))
    .alias("updated_statify_key")
)

painting_ids = paintings_data["id"].to_numpy()
stratify_key = paintings_data["updated_statify_key"].to_numpy()

# First split: 80% train, 20% held out. The held-out stratification keys are
# kept so that the second split can be stratified too.
train_ids, temp_ids, _, temp_stratify_key = train_test_split(
    painting_ids, stratify_key, test_size=0.20, random_state=42, stratify=stratify_key
)

# Second split: the held-out 20% is halved into validation and test sets.
val_ids, test_ids, _, _ = train_test_split(
    temp_ids, temp_stratify_key, test_size=0.50, random_state=42, stratify=temp_stratify_key
)
print(f"Train: {len(train_ids)} Validation: {len(val_ids)} Test: {len(test_ids)}")

### 2. Create data for fine-tuning in JSONL format

In [None]:
# Create the storage root (and any missing parents). An already existing
# directory is fine; any other failure is raised here instead of being swallowed
# and resurfacing later as a confusing error.
os.makedirs(STORAGE_PATH, exist_ok=True)

# Start every split from an empty directory so that files left over from a
# previous run cannot leak into the new dataset.
for split_name in ("train", "val", "test"):
    split_dir = f"{STORAGE_PATH}{split_name}/"
    try:
        shutil.rmtree(split_dir)
    except FileNotFoundError:
        pass
    os.mkdir(split_dir)

In [None]:
def _description_to_phrase(description):
    """Collapse a multi-sentence object description into a single phrase.

    Periods inside the description are replaced with " | " (the caption joins
    objects with " . ", so inner periods would otherwise look like object
    boundaries) and runs of spaces are collapsed to one space.
    """
    # Drop only a real trailing period; slicing unconditionally would eat the
    # last character of a description that does not end with one.
    if description.endswith("."):
        description = description[:-1]
    return re.sub(r" +", " ", description.replace(".", " | ").strip())


def create_dataset(all_annotations, set_name, tokenizer, max_caption_tokens=254):
    """Write the images and the JSONL grounding annotations of one data split.

    For every painting a caption is built from either the object labels or the
    object descriptions (selected by the module-level ``TEXT_INCLUDED``), one
    region is emitted per bounding box, and the painting image is copied from
    ``RAW_DATA_PATH`` into ``STORAGE_PATH/<set_name>/``. Paintings whose caption
    exceeds ``max_caption_tokens`` tokens are skipped and reported at the end.

    Args:
        all_annotations: Iterable of painting annotation dicts. Each one is read
            for its "id" and its "objects" mapping (label -> dict with
            "description" and "bounding_boxes").
        set_name: Name of the split; used for the output directory and for the
            ``<set_name>_annotations.jsonl`` file name.
        tokenizer: Object with a ``tokenize(text)`` method returning a sequence
            of tokens; only the length of that sequence is used.
        max_caption_tokens: Maximum number of caption tokens allowed.
            NOTE(review): the default of 254 presumably leaves room for two
            special tokens within a 256-token text limit - confirm.

    Raises:
        ValueError: If ``TEXT_INCLUDED`` is neither "labels" nor "descriptions".
    """
    if TEXT_INCLUDED not in ("labels", "descriptions"):
        raise ValueError(
            f'TEXT_INCLUDED must be "labels" or "descriptions", got {TEXT_INCLUDED!r}'
        )
    # Position of the selected text inside each (label, description) pair.
    text_index = 0 if TEXT_INCLUDED == "labels" else 1

    annotations = []
    paintings_with_long_descriptions = []

    for painting_annotations in tqdm(all_annotations):
        painting_id = painting_annotations["id"]
        objects = painting_annotations["objects"]

        # Keep the original key next to the cleaned texts: the bounding boxes
        # must be looked up with the unmodified key, otherwise any label that
        # contains a period raises a KeyError.
        object_texts = [
            (obj, (obj.replace(".", ""), _description_to_phrase(obj_data["description"])))
            for obj, obj_data in objects.items()
        ]

        caption = " . ".join(texts[text_index] for _, texts in object_texts) + " ."

        if len(tokenizer.tokenize(caption)) > max_caption_tokens:
            paintings_with_long_descriptions.append(painting_id)
            continue

        # Only the image size is needed here; `size` is read as (width, height).
        image = load_image(painting_id)[1]

        # One region per bounding box. The second element of each bounding-box
        # entry is stored as the box.
        regions = [
            {"bbox": bbox[1], "phrase": texts[text_index]}
            for obj, texts in object_texts
            for bbox in objects[obj]["bounding_boxes"]
        ]

        annotations.append(
            {
                "filename": f"{painting_id}.png",
                "height": image.size[1],
                "width": image.size[0],
                "grounding": {
                    "caption": caption,
                    "regions": regions,
                },
            }
        )
        source_path = f"{RAW_DATA_PATH}filtered_paintings/{painting_id}.png"
        destination_path = f"{STORAGE_PATH}{set_name}/{painting_id}.png"
        shutil.copy2(source_path, destination_path)

    with open(f"{STORAGE_PATH}{set_name}/{set_name}_annotations.jsonl", "w", encoding="utf-8") as f:
        for item in annotations:
            json_line = json.dumps(item, ensure_ascii=False)
            f.write(json_line + "\n")

    too_long_desc_counter = len(paintings_with_long_descriptions)
    print(f"{too_long_desc_counter} paintings from the {set_name} set have to long cumulated object descriptions.")
    print(paintings_with_long_descriptions)

In [None]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# `x in numpy_array` scans the whole array for every painting. Hashing the ids
# once turns each membership test into a constant-time lookup; `tolist()` yields
# plain Python values so they compare and hash like the ids in the annotations.
train_id_set = set(train_ids.tolist())
val_id_set = set(val_ids.tolist())
test_id_set = set(test_ids.tolist())

train_annotations = [annotation for annotation in all_annotations if annotation["id"] in train_id_set]
create_dataset(train_annotations, "train", tokenizer)
val_annotations = [annotation for annotation in all_annotations if annotation["id"] in val_id_set]
create_dataset(val_annotations, "val", tokenizer)
test_annotations = [annotation for annotation in all_annotations if annotation["id"] in test_id_set]
create_dataset(test_annotations, "test", tokenizer)

### 3. Analyze object labels for each set

In [None]:
# Pick the split to inspect (train_annotations, val_annotations or test_annotations).
chosen_set = test_annotations

# Per painting: the object labels and, in the same order, their descriptions.
object_labels = [list(annotation["objects"]) for annotation in chosen_set]
object_descriptions = [
    [obj_data["description"] for obj_data in annotation["objects"].values()]
    for annotation in chosen_set
]

# Build fresh row dicts instead of deep-copying every annotation: the nested
# "objects" payload (bounding boxes included) is replaced by the flat label list
# anyway, and the original annotations are left untouched.
paintings_objects = [
    {**annotation, "objects": labels, "object_description": descriptions}
    for annotation, labels, descriptions in zip(chosen_set, object_labels, object_descriptions)
]

# One row per (painting, object) pair.
paintings_objects = pl.from_dicts(paintings_objects, infer_schema_length=1000).explode("objects", "object_description")
print(f"Number of objects: {paintings_objects.shape[0]}")

unique_labels = paintings_objects.group_by("objects").len().sort("len", descending=True)
print(f"The number of unique labels: {unique_labels.shape[0]}")

# Show the 100 most frequent labels without the table being truncated.
with pl.Config(tbl_rows=100):
    display(unique_labels[:100])