# Cross-modal retrieval
This notebook is used to evaluate the quality of the embedding space through cross-modal retrieval.

### 0. Import libraries and load data

In [None]:
import textwrap
import numpy as np
import polars as pl
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
from scipy.stats import wilcoxon
from sklearn.metrics.pairwise import cosine_similarity

# Switch between the two sets of input locations defined below.
DETAILED_EMBEDDING_SPACE_EVALUATION = True

if not DETAILED_EMBEDDING_SPACE_EVALUATION:
    # Embeddings and images that live in sibling repositories.
    INPUT_PATH_EMBEDDINGS = "../../../Open-Grounding-DINO/embeddings_data/"
    EMBEDDINGS_FILE_NAME = "clip_embeddings_test_clip_full_2e_6_diff_lr.json"
    IMAGES_PATH = (
        "../../../Enhancing-Visual-Grounding-in-Paintings-with-Descriptions/data/fine-tuning_clip/"
    )

else:
    # Embeddings and images stored inside this repository's data folder.
    INPUT_PATH_EMBEDDINGS = "../../data/embeddings/"
    EMBEDDINGS_FILE_NAME = (
        "baseline_embeddings_test_projections_text_embedding_enhanced_features.json"
    )
    IMAGES_PATH = "../../data/fine-tuning_clip/"

In [None]:
# Read the embeddings file once; the fallback below reuses this frame instead
# of parsing the JSON a second time.
embeddings_data = pl.read_json(
    f"{INPUT_PATH_EMBEDDINGS}{EMBEDDINGS_FILE_NAME}", infer_schema_length=1000
)

# Best effort: explode every column into one row per element. When the columns
# cannot be exploded, keep the frame exactly as it was read.
# NOTE(review): presumably some files store each column as a single list and
# others are already row-oriented - confirm against the export scripts.
try:
    embeddings_data = embeddings_data.explode(pl.all())
except pl.exceptions.PolarsError:
    pass

# Only add a row index when the file does not already provide an "index" column.
if "index" not in embeddings_data.columns:
    embeddings_data = embeddings_data.with_row_index()

indices = embeddings_data["index"].to_list()

# Derive the image file name from the index and expose the object description
# under the column name "text", which the retrieval functions read.
embeddings_data = embeddings_data.with_columns(
    ("test/" + pl.col("index").cast(pl.String) + ".png").alias("image_name")
).rename({"object_description": "text"})
embeddings_data

### 1. Define functions to perform retrieval and measure metrics

In [None]:
def retrieve_documents(
    embeddings_data, description_embeddings, image_object_embeddings, query_index, query_type
):
    """Rank every item of the opposite modality against a single query.

    Args:
        embeddings_data: Frame with a "text" and an "index" column.
        description_embeddings: Sequence of description embedding vectors.
        image_object_embeddings: Sequence of image-object embedding vectors.
        query_index: Position of the query in the frame and in the embedding
            sequences.
        query_type: "description" to rank image objects for a description, or
            "image" to rank descriptions for an image object.

    Returns:
        A tuple ``(relevant_document_ids, ranked_document_indices,
        sorted_similarities)``: the "index" values of all rows whose text
        equals the query's text, the candidate positions ordered by decreasing
        cosine similarity, and the similarities in that same decreasing order.

    Raises:
        NameError: If ``query_type`` is neither "description" nor "image".
    """
    # Choose which side provides the query vector and which side is ranked.
    if query_type == "description":
        query_pool, candidate_pool = description_embeddings, image_object_embeddings
    elif query_type == "image":
        query_pool, candidate_pool = image_object_embeddings, description_embeddings
    else:
        raise NameError("This query type does not exist.")

    # Every row sharing the query's text counts as a correct answer.
    # NOTE(review): these are "index" column values, while the ranking below
    # holds positions - this assumes "index" equals the row position.
    query_text = embeddings_data["text"][query_index]
    rows_with_same_text = embeddings_data.filter(pl.col("text") == query_text)
    relevant_document_ids = rows_with_same_text["index"].to_numpy()

    query_vector = np.array(query_pool[query_index]).reshape(1, -1)
    similarities = cosine_similarity(query_vector, np.array(candidate_pool))[0]

    # Most similar candidate first.
    ranked_document_indices = np.argsort(similarities)[::-1]
    sorted_similarities = np.sort(similarities)[::-1]

    return relevant_document_ids, ranked_document_indices, sorted_similarities

In [None]:
def evaluate_retrieval(
    embeddings_data, description_embeddings, image_object_embeddings, indices, query_type
):
    """Run retrieval for every query in ``indices`` and print summary metrics.

    For each query the rank (1-based) of the first relevant document is
    recorded. From these ranks the function prints Hit@1, Hit@5, Hit@10, the
    median rank (rounded to 2 decimals) and the mean reciprocal rank (rounded
    to 4 decimals). Nothing is returned.

    Args:
        embeddings_data: Frame passed through to ``retrieve_documents``.
        description_embeddings: Sequence of description embedding vectors.
        image_object_embeddings: Sequence of image-object embedding vectors.
        indices: Iterable of query positions to evaluate.
        query_type: "description" or "image" (see ``retrieve_documents``).
    """
    first_rank = []

    for query_index in tqdm(indices):
        relevant_document_ids, ranked_document_indices, _ = retrieve_documents(
            embeddings_data,
            description_embeddings,
            image_object_embeddings,
            query_index,
            query_type,
        )

        # Positions in the ranking that hold a relevant document; the first one
        # gives the 1-based rank of the best-placed correct answer.
        relevant_positions = np.flatnonzero(
            np.isin(ranked_document_indices, relevant_document_ids)
        )
        first_rank.append(relevant_positions[0] + 1)

    ranks = np.array(first_rank)

    # A query is a hit at k exactly when its first relevant rank is within k.
    hit_rate_at_1 = (ranks <= 1).mean()
    hit_rate_at_5 = (ranks <= 5).mean()
    hit_rate_at_10 = (ranks <= 10).mean()
    median_rank = round(np.median(ranks), 2)
    mean_reciprocal_rank = round((1 / ranks).mean(), 4)

    print(
        f"Hit@1: {hit_rate_at_1}\nHit@5: {hit_rate_at_5}\nHit@10: {hit_rate_at_10}\nMedian Rank: {median_rank}\nMRR: {mean_reciprocal_rank}"
    )

### 2. Perform retrieval

In [None]:
image_object_embeddings = embeddings_data["embedding_object_image"].to_list()

# The two text-embedding columns may be absent from the loaded file, so each
# list is only created when its column exists. When a column is missing the
# corresponding name stays undefined, exactly as before.
if "text_embedding_enhanced" in embeddings_data.columns:
    description_embeddings_enhanced = embeddings_data["text_embedding_enhanced"].to_list()

if "text_embedding_backbone" in embeddings_data.columns:
    description_embeddings_backbone = embeddings_data["text_embedding_backbone"].to_list()

#### 2.1. Use image objects as queries 

In [None]:
# Image objects as queries, ranked against the backbone text embeddings.
# The evaluation is skipped only when those embeddings were never created;
# any error raised by the evaluation itself is no longer hidden.
if "description_embeddings_backbone" in globals():
    evaluate_retrieval(
        embeddings_data, description_embeddings_backbone, image_object_embeddings, indices, "image"
    )
else:
    print("Skipped: description_embeddings_backbone is not defined.")

In [None]:
# Image objects as queries, ranked against the enhanced text embeddings.
# The evaluation is skipped only when those embeddings were never created;
# any error raised by the evaluation itself is no longer hidden.
if "description_embeddings_enhanced" in globals():
    evaluate_retrieval(
        embeddings_data, description_embeddings_enhanced, image_object_embeddings, indices, "image"
    )
else:
    print("Skipped: description_embeddings_enhanced is not defined.")

#### 2.2. Use descriptions as queries

In [None]:
# Descriptions (backbone text embeddings) as queries, ranked against the image
# objects. The evaluation is skipped only when those embeddings were never
# created; any error raised by the evaluation itself is no longer hidden.
if "description_embeddings_backbone" in globals():
    evaluate_retrieval(
        embeddings_data,
        description_embeddings_backbone,
        image_object_embeddings,
        indices,
        "description",
    )
else:
    print("Skipped: description_embeddings_backbone is not defined.")

In [None]:
# Descriptions (enhanced text embeddings) as queries, ranked against the image
# objects. The evaluation is skipped only when those embeddings were never
# created; any error raised by the evaluation itself is no longer hidden.
if "description_embeddings_enhanced" in globals():
    evaluate_retrieval(
        embeddings_data,
        description_embeddings_enhanced,
        image_object_embeddings,
        indices,
        "description",
    )
else:
    print("Skipped: description_embeddings_enhanced is not defined.")

### 3. Perform retrieval based on type

In [None]:
# Number of rows per `coarse_type` value, sorted by ascending count.
embeddings_data["coarse_type"].value_counts().sort("count")

In [None]:
# Distinct non-null `coarse_type` values; rows with a null type are excluded.
coarse_types = set(
    embeddings_data.filter(pl.col("coarse_type").is_not_null())["coarse_type"].to_list()
)

# Display the resulting set as the cell output.
coarse_types

#### 3.1. Use image objects as queries 

In [None]:
# Image objects as queries, evaluated separately for the queries of each coarse
# type. Only the queries are restricted to the type; candidates are still
# ranked over the full embedding lists.
# sorted() makes the report order reproducible (set iteration order is arbitrary).
for coarse_type in sorted(coarse_types):
    # Single quotes inside the f-string keep this line valid before Python 3.12.
    print(f"{'-' * len(coarse_type)}\n{coarse_type}")
    indices_current_type = embeddings_data.filter(pl.col("coarse_type") == coarse_type)[
        "index"
    ].to_list()
    evaluate_retrieval(
        embeddings_data,
        description_embeddings_enhanced,
        image_object_embeddings,
        indices_current_type,
        "image",
    )

#### 3.2. Use descriptions as queries 

In [None]:
# Descriptions as queries, evaluated separately for the queries of each coarse
# type. Only the queries are restricted to the type; candidates are still
# ranked over the full embedding lists.
# sorted() makes the report order reproducible (set iteration order is arbitrary).
for coarse_type in sorted(coarse_types):
    # Single quotes inside the f-string keep this line valid before Python 3.12.
    print(f"{'-' * len(coarse_type)}\n{coarse_type}")
    indices_current_type = embeddings_data.filter(pl.col("coarse_type") == coarse_type)[
        "index"
    ].to_list()
    evaluate_retrieval(
        embeddings_data,
        description_embeddings_enhanced,
        image_object_embeddings,
        indices_current_type,
        "description",
    )

### 4. Perform visual analysis

In [None]:
def visualize_retrieval(images, descriptions, labels, output_filename="", font_size=8, figsize=(12, 5)):
    """Show a query and its retrieved items side by side in one row.

    The first entry of each sequence is treated as the query; entry ``i``
    (``i >= 1``) is titled "Rank i". Each panel shows the image, a title with
    the label, and the description wrapped to 30 characters below the image.

    Args:
        images: Images accepted by ``Axes.imshow``, query first.
        descriptions: One description string per image.
        labels: One label per image, shown in the panel title.
        output_filename: When non-empty, the figure is also saved to this path.
        font_size: Font size of the descriptions; titles use ``font_size + 4``.
        figsize: Figure size in inches passed to ``plt.subplots``.
    """
    # squeeze=False keeps `axes` two-dimensional, so a single image (one panel)
    # can be indexed the same way as several images.
    _, axes = plt.subplots(1, len(images), figsize=figsize, squeeze=False)

    plt.subplots_adjust(top=0.9, bottom=0.1, hspace=0.6, wspace=0.1)

    for i, image in enumerate(images):
        ax = axes[0][i]

        ax.imshow(image)
        title = f"Query [{labels[i]}]" if i == 0 else f"Rank {i} [{labels[i]}]"
        ax.set_title(title, fontsize=font_size + 4)
        ax.axis("off")

        # Place the wrapped description just below the image, centred.
        wrapped_text = textwrap.fill(descriptions[i], width=30)
        ax.text(
            0.5,
            -0.05,
            wrapped_text,
            transform=ax.transAxes,
            ha="center",
            va="top",
            fontsize=font_size,
            wrap=True,
        )

    plt.tight_layout()
    # Truthiness check also treats None as "do not save".
    if output_filename:
        plt.savefig(output_filename, bbox_inches="tight")
    plt.show()

In [None]:
# Load the projected embeddings, order them by painting_id and add a fresh row
# index. image_name is derived from that index here, i.e. before the
# de-duplication below renumbers the rows. century is computed from year, and
# the object description is exposed as "text".
projected_embeddings = (
    pl.read_json(
        f"{INPUT_PATH_EMBEDDINGS}baseline_embeddings_test_projections_text_embedding_enhanced_features.json",
        infer_schema_length=1000,
    ).sort("painting_id")
    .with_row_index()
    .with_columns((f"test/" + pl.col("index").cast(pl.String) + ".png").alias("image_name"))
    .with_columns((pl.col("year") // 100 + 1).alias("century"))
    .rename({"object_description": "text"})
)
# Keep the probability column (in painting_id order) for reuse outside this cell.
probabilities = projected_embeddings["probability"]
# Keep one row per distinct text - the one with the highest probability - then
# restore the original order and renumber the rows from 0.
projected_embeddings = projected_embeddings.sort("probability", descending=True).unique(subset=["text"], keep="first").sort("index").drop("index").with_row_index()

# Embedding lists aligned position by position with the de-duplicated frame.
description_projected_embeddings = projected_embeddings["text_embedding_enhanced"].to_list()
image_object_projected_embeddings = projected_embeddings["embedding_object_image"].to_list()
projected_embeddings.head()

In [None]:
# Load the CLIP embeddings; this chain reads an existing "index" column of the
# file (no row index is added here). The `probabilities` series defined
# elsewhere in the notebook is attached as a new column by position.
# NOTE(review): this assumes the rows of this file are in the same order as the
# frame `probabilities` was taken from - confirm.
clip_embeddings = (
    pl.read_json(
        f"{INPUT_PATH_EMBEDDINGS}clip_embeddings_test_clip_full_1e_6_diff_lr_not_frozen_features.json",
        infer_schema_length=1000,
    )
    .with_columns((f"test/" + pl.col("index").cast(pl.String) + ".png").alias("image_name"))
    .with_columns((pl.col("year") // 100 + 1).alias("century"))
    .rename({"object_description": "text"})
).with_columns(pl.Series(probabilities).alias("probability"))
# Keep one row per distinct text - the one with the highest probability - then
# restore the original order and renumber the rows from 0.
clip_embeddings = clip_embeddings.sort("probability", descending=True).unique(subset=["text"], keep="first").sort("index").drop("index").with_row_index()

# Embedding lists aligned position by position with the de-duplicated frame.
description_clip_embeddings = clip_embeddings["text_embedding_enhanced"].to_list()
image_object_clip_embeddings = clip_embeddings["embedding_object_image"].to_list()
clip_embeddings.head()

In [None]:
def display_retrieval_results(embeddings_data, description_embeddings, image_object_embeddings, query_index, query_modality, top_k=5):
    """Plot one query next to its ``top_k`` best-ranked retrieval results.

    The ranking comes from ``retrieve_documents``. For the query row and each
    of the first ``top_k`` ranked rows, the image is opened from
    ``IMAGES_PATH`` plus the row's "image_name", and shown together with the
    row's "text" and "label" through ``visualize_retrieval``.

    Args:
        embeddings_data: Frame with "image_name", "text" and "label" columns.
        description_embeddings: Sequence of description embedding vectors.
        image_object_embeddings: Sequence of image-object embedding vectors.
        query_index: Position of the query row.
        query_modality: "description" or "image" (see ``retrieve_documents``).
        top_k: Number of retrieved items to display.
    """
    _, ranked_document_indices, _ = retrieve_documents(
        embeddings_data,
        description_embeddings,
        image_object_embeddings,
        query_index=query_index,
        query_type=query_modality,
    )

    # Rows to display: the query itself first, then the best-ranked candidates.
    rows_to_show = [int(query_index)]
    rows_to_show.extend(int(index) for index in ranked_document_indices[:top_k])

    shown_images = [
        Image.open(f"{IMAGES_PATH}{embeddings_data['image_name'][row]}") for row in rows_to_show
    ]
    shown_descriptions = [embeddings_data["text"][row] for row in rows_to_show]
    shown_labels = [embeddings_data["label"][row] for row in rows_to_show]

    visualize_retrieval(shown_images, shown_descriptions, shown_labels)

In [None]:
# Use the first row labelled "lamb" in the projected frame as the example query.
# NOTE(review): the same row position is reused for the CLIP frame, which
# assumes both de-duplicated frames list the same items in the same order.
query_index = projected_embeddings.filter(pl.col("label") == "lamb")["index"][0]

compared_methods = (
    (projected_embeddings, description_projected_embeddings, image_object_projected_embeddings),
    (clip_embeddings, description_clip_embeddings, image_object_clip_embeddings),
)

# For each method, show the results for the image query and then for the
# description query.
for method_frame, method_text_vectors, method_image_vectors in compared_methods:
    for query_modality in ("image", "description"):
        display_retrieval_results(
            method_frame, method_text_vectors, method_image_vectors, query_index, query_modality
        )

In [None]:
# Show (up to 20 table rows) the labels that occur exactly once after keeping,
# for every distinct value of the "description" column, only the row with the
# highest probability.
# NOTE(review): `display` is not imported in this file - presumably the
# IPython/Jupyter built-in; confirm when running outside a notebook.
with pl.Config(tbl_rows=20):
    display(projected_embeddings.sort("probability", descending=True).unique(subset=["description"], keep="first")["label"].value_counts().filter(pl.col("count") == 1))

### 5. Analyze the attributes of the retrieved items

In [None]:
def quantify_identical_feature_value(
    embeddings_data, description_embeddings, image_object_embeddings, feature_name, query_index, query_type, top_k=5
):
    """Return the share of the top-k retrieved items that match the query's feature.

    The ranking is delegated to ``retrieve_documents`` so that both functions
    always rank candidates the same way. The query's feature value is read
    from the query row itself.

    Args:
        embeddings_data: Frame with "text", "index" and ``feature_name`` columns.
        description_embeddings: Sequence of description embedding vectors.
        image_object_embeddings: Sequence of image-object embedding vectors.
        feature_name: Column whose value is compared between query and results.
        query_index: Position of the query row.
        query_type: "description" or "image" (see ``retrieve_documents``).
        top_k: Number of best-ranked items to inspect; also the denominator.

    Returns:
        Number of inspected items whose ``feature_name`` value equals the
        query's value, divided by ``top_k``.

    Raises:
        NameError: If ``query_type`` is not supported (raised by
            ``retrieve_documents``).
    """
    _, ranked_document_indices, _ = retrieve_documents(
        embeddings_data,
        description_embeddings,
        image_object_embeddings,
        query_index,
        query_type,
    )
    top_ranked_indices = ranked_document_indices[:top_k]

    # Feature of the query row itself, not of the first row that merely shares
    # the query's text.
    feature_query_item = embeddings_data[feature_name][int(query_index)]

    # NOTE(review): matching ranked positions against the "index" column
    # assumes "index" equals the row position.
    feature_retrieved_items = embeddings_data.filter(
        pl.col("index").is_in(top_ranked_indices)
    )[feature_name].to_numpy()

    percentage_same_feature = (feature_query_item == feature_retrieved_items).sum() / top_k

    return percentage_same_feature

In [None]:
def quantify_identical_feature_value_wrapper(embeddings, feature, modality, n_values=10):
    """Print the average top-k feature agreement over the most frequent feature values.

    Only rows whose ``feature`` value is among the ``n_values`` most frequent
    non-null values are kept. The kept rows are renumbered from 0, every kept
    row is used once as a query for ``quantify_identical_feature_value``, and
    the mean of the returned shares is printed as
    ``"<modality> - <feature>: <mean>"``. Nothing is returned.

    Args:
        embeddings: Frame with "index", "text_embedding_enhanced",
            "embedding_object_image" and ``feature`` columns.
        feature: Column whose agreement between query and results is measured.
        modality: Query type, "description" or "image".
        n_values: How many of the most frequent feature values to keep.
    """
    feature_values = (
        embeddings.filter(pl.col(feature).is_not_null())[feature]
        .value_counts()
        .sort("count", descending=True)[:n_values][feature]
        .to_list()
    )

    # Renumber the kept rows so that "index" matches the position in the
    # embedding lists built below.
    selected_embeddings = (
        embeddings.filter(pl.col(feature).is_in(feature_values)).drop("index").with_row_index()
    )

    description_embeddings = selected_embeddings["text_embedding_enhanced"].to_list()
    image_object_embeddings = selected_embeddings["embedding_object_image"].to_list()

    # Collect in a list: np.append inside a loop copies the whole array on
    # every call, which makes the loop quadratic.
    items_same_feature_value = [
        quantify_identical_feature_value(
            selected_embeddings,
            description_embeddings,
            image_object_embeddings,
            feature,
            query_position,
            modality,
        )
        for query_position in range(selected_embeddings.shape[0])
    ]

    avg_percentage_same_feature_value = np.array(items_same_feature_value, dtype=float).mean()
    print(f"{modality} - {feature}: {avg_percentage_same_feature_value}")

In [None]:
# Features whose agreement between query and retrieved items is reported.
reported_features = ("coarse_type", "first_style", "century", "label")

# (heading printed before the block, frame evaluated in the block)
reported_methods = (
    ("Projected Grounding DINO embeddings", projected_embeddings),
    ("\nCLIP embeddings", clip_embeddings),
)

for heading, method_frame in reported_methods:
    print(heading)
    # Image queries first, then description queries, separated by a rule.
    for reported_feature in reported_features:
        quantify_identical_feature_value_wrapper(method_frame, reported_feature, "image")
    print("---")

    for reported_feature in reported_features:
        quantify_identical_feature_value_wrapper(method_frame, reported_feature, "description")

### 6. Compare embedding methods statistically

In [None]:
# Reload the projected embeddings without de-duplication: only the description
# column is renamed to "text" and a row index is added.
# NOTE(review): unlike the earlier load of this file, the rows are not sorted by
# painting_id here - confirm the file order matches the CLIP file, since both
# frames are later compared row by row.
projected_embeddings = (
    pl.read_json(
        f"{INPUT_PATH_EMBEDDINGS}baseline_embeddings_test_projections_text_embedding_enhanced_features.json",
        infer_schema_length=1000,
    ).rename({"object_description": "text"}).with_row_index()
)
# Embedding lists aligned position by position with the frame rows.
description_projected_embeddings = projected_embeddings["text_embedding_enhanced"].to_list()
image_object_projected_embeddings = projected_embeddings["embedding_object_image"].to_list()
projected_embeddings.head()

In [None]:
# Reload the CLIP embeddings without de-duplication. The "index" column stored
# in the file is dropped and replaced by a fresh row index starting at 0.
clip_embeddings = (
    pl.read_json(
        f"{INPUT_PATH_EMBEDDINGS}clip_embeddings_test_clip_full_1e_6_diff_lr_not_frozen_features.json",
        infer_schema_length=1000,
    ).rename({"object_description": "text"}).drop("index").with_row_index()
)
# Embedding lists aligned position by position with the frame rows.
description_clip_embeddings = clip_embeddings["text_embedding_enhanced"].to_list()
image_object_clip_embeddings = clip_embeddings["embedding_object_image"].to_list()
clip_embeddings.head()

In [None]:
def compute_rank_biserial_correlation(clip_ranks, grounding_dino_ranks):
    """Print and return the matched-pairs rank-biserial correlation.

    The paired differences ``clip_ranks - grounding_dino_ranks`` are computed,
    zero differences are discarded, and the absolute values of the remaining
    differences are ranked (ties receive average ranks). The result is

        (sum of ranks of positive differences
         - sum of ranks of negative differences) / (sum of all ranks)

    which lies in [-1, 1]. A positive value means the CLIP ranks tend to be
    numerically larger than the Grounding DINO ranks. The previous
    implementation counted signs only, ``(n_pos - n_neg) / n``, which ignores
    the size of the differences and is not the rank-biserial correlation.

    Args:
        clip_ranks: Sequence of numbers, one per query.
        grounding_dino_ranks: Sequence of numbers paired with ``clip_ranks``.

    Returns:
        The correlation as a float; 0.0 when there is no non-zero difference
        (including empty input), where the statistic is otherwise undefined.
    """
    # Local import: the file only imports `wilcoxon` from scipy.stats.
    from scipy.stats import rankdata

    clip_ranks = np.asarray(clip_ranks, dtype=float)
    grounding_dino_ranks = np.asarray(grounding_dino_ranks, dtype=float)

    rank_diffs = clip_ranks - grounding_dino_ranks
    nonzero_diffs = rank_diffs[rank_diffs != 0]

    if nonzero_diffs.size == 0:
        # No pair differs: avoid dividing by a zero rank sum.
        rank_biserial_correlation_value = 0.0
    else:
        magnitude_ranks = rankdata(np.abs(nonzero_diffs))
        positive_rank_sum = magnitude_ranks[nonzero_diffs > 0].sum()
        negative_rank_sum = magnitude_ranks[nonzero_diffs < 0].sum()
        rank_biserial_correlation_value = float(
            (positive_rank_sum - negative_rank_sum) / (positive_rank_sum + negative_rank_sum)
        )

    print(f"Rank-biserial correlation: {rank_biserial_correlation_value}")
    return rank_biserial_correlation_value

In [None]:
# Compare both methods query by query: for each query, record the rank of the
# first relevant document under CLIP and under the projected embeddings.
modality = "image"
clip_ranks = []
grounding_dino_ranks = []

# (frame, description embeddings, image-object embeddings, list collecting ranks)
paired_methods = (
    (clip_embeddings, description_clip_embeddings, image_object_clip_embeddings, clip_ranks),
    (
        projected_embeddings,
        description_projected_embeddings,
        image_object_projected_embeddings,
        grounding_dino_ranks,
    ),
)

# NOTE(review): the loop length comes from the CLIP frame only - this assumes
# both frames have the same number of rows in the same order.
for query_index in tqdm(range(clip_embeddings.shape[0])):
    for method_frame, method_text_vectors, method_image_vectors, collected_ranks in paired_methods:
        relevant_document_ids, ranked_document_indices, _ = retrieve_documents(
            method_frame, method_text_vectors, method_image_vectors, query_index, modality
        )
        relevant_positions = np.flatnonzero(
            np.isin(ranked_document_indices, relevant_document_ids)
        )
        # 1-based rank of the best-placed relevant document.
        collected_ranks.append(int(relevant_positions[0] + 1))

# Two-sided Wilcoxon signed-rank test on the paired ranks; zero differences
# are discarded.
statistic, p_value = wilcoxon(clip_ranks, grounding_dino_ranks, zero_method="wilcox", alternative="two-sided")
print(f"Statistic: {statistic} - p-value: {p_value}")
compute_rank_biserial_correlation(clip_ranks, grounding_dino_ranks)