# Cross-modal retrieval
This notebook is used to evaluate the quality of the embedding space through cross-modal retrieval.

### 0. Import libraries and load data

In [None]:
import numpy as np
import polars as pl
from PIL import Image 
from tqdm import tqdm
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from sklearn.metrics.pairwise import cosine_similarity

# Selects which set of input paths (embeddings file and image folder) the notebook uses.
DETAILED_EMBEDDING_SPACE_EVALUATION = True

# NOTE: all directory paths must end with "/" because they are joined to file names
# by plain string concatenation further below.
if DETAILED_EMBEDDING_SPACE_EVALUATION:
    INPUT_PATH_EMBEDDINGS = "../../../Open-Grounding-DINO/embeddings_data/"
    EMBEDDINGS_FILE_NAME = "clip_embeddings_test_clip_full_2e_6_diff_lr.json"
    IMAGES_PATH = "../../../Enhancing-Visual-Grounding-in-Paintings-with-Descriptions/data/fine-tuning_clip/"

else:
    INPUT_PATH_EMBEDDINGS = "../../data/embeddings/"
    EMBEDDINGS_FILE_NAME = "baseline_embeddings_test_projections_text_embedding_enhanced_features.json"
    # Trailing slash added: image paths are built as f"{IMAGES_PATH}{image_name}".
    IMAGES_PATH = "../../data/fine-tuning_clip/"

In [None]:
# Read the embeddings file once.
embeddings_data = pl.read_json(f"{INPUT_PATH_EMBEDDINGS}{EMBEDDINGS_FILE_NAME}", infer_schema_length=1000)

# Presumably some files store one list per column (a single row overall) that has to be
# exploded into one row per sample, while others are already row-wise -- TODO confirm.
try:
    embeddings_data = embeddings_data.explode(pl.all())
except Exception:
    # The columns cannot be exploded: keep the frame exactly as it was read.
    pass

# Add the "index" column (the default name used by with_row_index) unless it already exists.
if "index" not in embeddings_data.columns:
    embeddings_data = embeddings_data.with_row_index()

indices = embeddings_data["index"].to_list()

# Derive the image file name of every row and expose the object description as "text".
embeddings_data = embeddings_data.with_columns(
    ("test/" + pl.col("index").cast(pl.String) + ".png").alias("image_name")
).rename({"object_description": "text"})
embeddings_data

### 1. Define functions to perform retrieval and measure metrics

In [None]:
def retrieve_documents(embeddings_data, description_embeddings, image_object_embeddings, query_index, query_type):
    """
    Ranks all embeddings of the opposite modality against a single query by cosine similarity.

    Args:
        embeddings_data: Polars DataFrame providing at least the columns "text" and "index".
        description_embeddings: Indexable collection of description embeddings.
        image_object_embeddings: Indexable collection of image object embeddings.
        query_index (int): Position of the query in the data and in the embedding collections.
        query_type (str): "description" to rank image object embeddings for a description query,
                          "image" to rank description embeddings for an image object query.

    Returns:
        tuple: (relevant_document_ids, ranked_document_indices, sorted_similarities) where
            relevant_document_ids are the "index" values of all rows sharing the query's text,
            ranked_document_indices are positions ordered by decreasing similarity and
            sorted_similarities are the similarities in that same order.

    Raises:
        NameError: If query_type is neither "description" nor "image".
    """
    # Both directions share the same logic; only the roles of the two modalities swap.
    if query_type == "description":
        query_embeddings, document_embeddings = description_embeddings, image_object_embeddings
    elif query_type == "image":
        query_embeddings, document_embeddings = image_object_embeddings, description_embeddings
    else:
        raise NameError("This query type does not exist.")

    # Every row carrying the same text as the query counts as a relevant document.
    # NOTE(review): these are values of the "index" column, whereas the ranking below holds
    # row positions -- the two are only comparable if "index" equals the row position.
    query_text = embeddings_data["text"][query_index]
    relevant_document_ids = embeddings_data.filter(pl.col("text") == query_text)["index"].to_numpy()

    query_vector = np.array(query_embeddings[query_index]).reshape(1, -1)
    similarities = cosine_similarity(query_vector, np.array(document_embeddings))[0]

    # Sort once and reuse the ordering for the similarity values.
    ranked_document_indices = np.argsort(similarities)[::-1]
    sorted_similarities = similarities[ranked_document_indices]

    return relevant_document_ids, ranked_document_indices, sorted_similarities

In [None]:
def evaluate_retrieval(embeddings_data, description_embeddings, image_object_embeddings, indices, query_type):
    """
    Runs retrieval for every query in `indices` and prints Hit@1/5/10, median rank and MRR.

    Args:
        embeddings_data: Polars DataFrame passed through to retrieve_documents.
        description_embeddings: Indexable collection of description embeddings.
        image_object_embeddings: Indexable collection of image object embeddings.
        indices (list): Query indices to evaluate.
        query_type (str): "description" or "image" (see retrieve_documents).

    Raises:
        IndexError: If no relevant document appears in the ranking of a query.
    """
    cutoffs = (1, 5, 10)
    hits = {cutoff: [] for cutoff in cutoffs}
    first_rank = []
    reciprocal_first_rank = []

    for query_index in tqdm(indices):
        relevant_document_ids, ranked_document_indices, _ = retrieve_documents(embeddings_data, description_embeddings, image_object_embeddings, query_index, query_type)

        # 1-based rank of the best-placed relevant document.
        is_relevant = np.isin(ranked_document_indices, relevant_document_ids)
        rank = np.where(is_relevant)[0][0] + 1

        first_rank.append(rank)
        reciprocal_first_rank.append(1 / rank)

        # A relevant document is within the top k exactly when the first relevant rank is <= k.
        for cutoff in cutoffs:
            hits[cutoff].append(int(rank <= cutoff))

    hit_rate_at_1 = np.array(hits[1]).mean()
    hit_rate_at_5 = np.array(hits[5]).mean()
    hit_rate_at_10 = np.array(hits[10]).mean()
    median_rank = round(np.median(np.array(first_rank)), 2)
    mean_reciprocal_rank = round(np.array(reciprocal_first_rank).mean(), 4)

    print(f"Hit@1: {hit_rate_at_1}\nHit@5: {hit_rate_at_5}\nHit@10: {hit_rate_at_10}\nMedian Rank: {median_rank}\nMRR: {mean_reciprocal_rank}")

### 2. Perform retrieval

In [None]:
image_object_embeddings = embeddings_data["embedding_object_image"].to_list()

# The text-embedding columns are treated as optional. When a column is missing, its variable
# is left undefined and a notice is printed instead of silently ignoring the problem.
if "text_embedding_enhanced" in embeddings_data.columns:
    description_embeddings_enhanced = embeddings_data["text_embedding_enhanced"].to_list()
else:
    print('Column "text_embedding_enhanced" not found: enhanced text embeddings are unavailable.')

if "text_embedding_backbone" in embeddings_data.columns:
    description_embeddings_backbone = embeddings_data["text_embedding_backbone"].to_list()
else:
    print('Column "text_embedding_backbone" not found: backbone text embeddings are unavailable.')

#### 2.1. Use image objects as queries 

In [None]:
# Best-effort cell: report why the evaluation was skipped instead of failing silently.
try:
    evaluate_retrieval(embeddings_data, description_embeddings_backbone, image_object_embeddings, indices, "image")
except Exception as error:
    print(f"Skipped image queries with backbone text embeddings: {error!r}")

In [None]:
# Best-effort cell: report why the evaluation was skipped instead of failing silently.
try:
    evaluate_retrieval(embeddings_data, description_embeddings_enhanced, image_object_embeddings, indices, "image")
except Exception as error:
    print(f"Skipped image queries with enhanced text embeddings: {error!r}")

#### 2.2. Use descriptions as queries

In [None]:
# Best-effort cell: report why the evaluation was skipped instead of failing silently.
try:
    evaluate_retrieval(embeddings_data, description_embeddings_backbone, image_object_embeddings, indices, "description")
except Exception as error:
    print(f"Skipped description queries with backbone text embeddings: {error!r}")

In [None]:
# Best-effort cell: report why the evaluation was skipped instead of failing silently.
try:
    evaluate_retrieval(embeddings_data, description_embeddings_enhanced, image_object_embeddings, indices, "description")
except Exception as error:
    print(f"Skipped description queries with enhanced text embeddings: {error!r}")

### 3. Perform retrieval per coarse type

In [None]:
# Count the rows per coarse type and order the result by the "count" column.
(
    embeddings_data
    .get_column("coarse_type")
    .value_counts()
    .sort("count")
)

In [None]:
# Collect the distinct coarse types, leaving out rows that have none.
coarse_types = set(embeddings_data["coarse_type"].drop_nulls().to_list())
coarse_types

#### 3.1. Use image objects as queries 

In [None]:
# Evaluate retrieval with image object queries separately for every coarse type.
# sorted() keeps the report order reproducible; iterating the set directly is arbitrary.
for coarse_type in sorted(coarse_types):
    # The separator is built outside the f-string: reusing the enclosing quote character
    # inside an f-string expression is a syntax error before Python 3.12.
    separator = "-" * len(coarse_type)
    print(f"{separator}\n{coarse_type}")
    indices_current_type = embeddings_data.filter(pl.col("coarse_type") == coarse_type)["index"].to_list()
    evaluate_retrieval(embeddings_data, description_embeddings_enhanced, image_object_embeddings, indices_current_type, "image")

#### 3.2. Use descriptions as queries 

In [None]:
# Evaluate retrieval with description queries separately for every coarse type.
# sorted() keeps the report order reproducible; iterating the set directly is arbitrary.
for coarse_type in sorted(coarse_types):
    # The separator is built outside the f-string: reusing the enclosing quote character
    # inside an f-string expression is a syntax error before Python 3.12.
    separator = "-" * len(coarse_type)
    print(f"{separator}\n{coarse_type}")
    indices_current_type = embeddings_data.filter(pl.col("coarse_type") == coarse_type)["index"].to_list()
    evaluate_retrieval(embeddings_data, description_embeddings_enhanced, image_object_embeddings, indices_current_type, "description")

### 4. Perform visual analysis

In [None]:
def display_images_in_row_plotly(image_list, titles=None, figure_title="PIL Images in a Row (Plotly)"):
    """
    Displays a list of PIL images in a single row using Plotly subplots.

    Args:
        image_list (list): A list of PIL Image objects. Nothing is drawn if it is empty.
        titles (list, optional): A list of titles, one per image. If its length differs
                                 from image_list, a warning is printed and titles are ignored.
        figure_title (str, optional): Title shown above the whole figure.
                                      Defaults to the previously hard-coded title.
    """
    if not image_list:
        print("The image list is empty.")
        return

    num_images = len(image_list)

    if titles and len(titles) != num_images:
        print("Warning: Number of titles does not match number of images. Titles will be ignored.")
        titles = None

    # One subplot column per image.
    fig = make_subplots(rows=1, cols=num_images, subplot_titles=titles)

    # Axis decorations are hidden so that only the images remain visible.
    hidden_axis = dict(showticklabels=False, showgrid=False, zeroline=False)

    # Plotly subplot columns are 1-indexed.
    for column, pil_img in enumerate(image_list, start=1):
        fig.add_trace(go.Image(z=np.array(pil_img)), row=1, col=column)
        fig.update_xaxes(row=1, col=column, **hidden_axis)
        fig.update_yaxes(row=1, col=column, **hidden_axis)

    # Fixed height; the width is left to Plotly's automatic sizing.
    fig.update_layout(
        title_text=figure_title,
        height=300,
        margin=dict(l=20, r=20, t=50, b=20),
    )

    fig.show()

In [None]:
# # for more diverse results, keep only the most probable bounding box per description
# embeddings_data = (
#     embeddings_data.sort("probability", descending=True)
#     .group_by("text", maintain_order=True)
#     .agg(pl.all().first())
#     .sort("painting_id")
# )
# image_object_embeddings = embeddings_data["embedding_object_image"].to_list()

# try:
#     description_embeddings_enhanced = embeddings_data["text_embedding_enhanced"].to_list()
# except:
#     pass

# try:
#     description_embeddings_backbone = embeddings_data["text_embedding_backbone"].to_list()
# except:
#     pass

In [None]:
# Visual check with an image object as query: show the ground truth next to the top results.
query_index = 188
relevant_document_ids, ranked_document_indices, sorted_similarities = retrieve_documents(
    embeddings_data,
    description_embeddings_enhanced,
    image_object_embeddings,
    query_index=query_index,
    query_type="image",
)

# Only the first five entries of each list are inspected.
top_relevant_positions = [int(index) for index in relevant_document_ids[:5]]
top_retrieved_positions = [int(index) for index in ranked_document_indices[:5]]

image_names = embeddings_data["image_name"]
texts = embeddings_data["text"]

relevant_images = [Image.open(f"{IMAGES_PATH}{image_names[position]}") for position in top_relevant_positions]
relevant_descriptions = {texts[position] for position in top_relevant_positions}

retrieved_images = [Image.open(f"{IMAGES_PATH}{image_names[position]}") for position in top_retrieved_positions]
retrieved_descriptions = [texts[position] for position in top_retrieved_positions]

print("==QUERY==")
for description in relevant_descriptions:
    print(description)
display_images_in_row_plotly(relevant_images)

print("==RETRIEVED==")
for description in retrieved_descriptions:
    print(description)
display_images_in_row_plotly(retrieved_images)

In [None]:
# Visual check with a description as query: show the ground truth next to the top results.
query_index = 188
relevant_document_ids, ranked_document_indices, sorted_similarities = retrieve_documents(
    embeddings_data,
    description_embeddings_enhanced,
    image_object_embeddings,
    query_index=query_index,
    query_type="description",
)

# Only the first five entries of each list are inspected.
top_relevant_positions = [int(index) for index in relevant_document_ids[:5]]
top_retrieved_positions = [int(index) for index in ranked_document_indices[:5]]

image_names = embeddings_data["image_name"]
texts = embeddings_data["text"]

relevant_images = [Image.open(f"{IMAGES_PATH}{image_names[position]}") for position in top_relevant_positions]
relevant_descriptions = {texts[position] for position in top_relevant_positions}

retrieved_images = [Image.open(f"{IMAGES_PATH}{image_names[position]}") for position in top_retrieved_positions]
retrieved_descriptions = [texts[position] for position in top_retrieved_positions]

print("==QUERY==")
for description in relevant_descriptions:
    print(description)
display_images_in_row_plotly(relevant_images)

print("==RETRIEVED==")
for description in retrieved_descriptions:
    print(description)
display_images_in_row_plotly(retrieved_images)