# Library


## Basic library


In [None]:
import json
import os
import random
import re
import warnings
from datetime import datetime
from tabnanny import verbose
from turtle import update
from typing import List

import hydra
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import torch
import transformers
from hydra import compose, initialize, initialize_config_dir, initialize_config_module
from IPython.display import clear_output, display
from omegaconf import DictConfig, OmegaConf
from scipy.spatial.distance import cdist
from sklearn.cluster import KMeans
from sympy import count_ops, use
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoImageProcessor, AutoProcessor, AutoTokenizer

# Import local packages
from src.data.cdc_datamodule import CDC_test
from src.models.cdc import CDC
from src.models.components.clustering import Clustering, UMAP_vis
from src.utils import (
    EmbeddingManager,
    print_model_info,
)
from src.utils.evaltools import eval_rank_oracle_check_per_label
from src.utils.inference import encode_data, inference_test

# Setup: disable HF tokenizers parallelism, silence Python / transformers warnings,
# and choose the compute device (GPU when available).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize Hydra. Clearing the global instance first lets this cell be re-run in the
# same kernel (Hydra refuses to initialize twice).
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(config_path="configs", version_base=None)
cfg = compose(config_name="redcaps")
# Print the top-level config keys, one per line.
print(*cfg, sep="\n")

## Visualization library


In [None]:
import os
import sys

import pandas as pd
import torch
from PIL import Image

from eval import main

# Widen pandas display limits so long captions and large tables render in full:
# untruncated cell text, up to 200 rows and up to 50 columns.
pd.set_option(
    "display.max_colwidth", None,
    "display.max_rows", 200,
    "display.max_columns", 50,
)

In [None]:
def plot_umap(umap_features_np, umap_labels, cluster_centers, representatives):
    """Scatter-plot 2-D UMAP points with optional cluster centers and representatives.

    Args:
        umap_features_np: ``(N, 2)`` array of UMAP coordinates, or ``None`` to skip the
            point cloud.
        umap_labels: length-``N`` integer labels used to colour the points; entries
            ``< 0`` are drawn in grey. Only used when ``umap_features_np`` is given.
        cluster_centers: ``(C, 2)`` array drawn as black crosses, or ``None``.
        representatives: ``(R, 2)`` array drawn as red circles and annotated with their
            row index, or ``None``.

    Returns:
        The created ``matplotlib.figure.Figure``.
    """
    fig = plt.figure(figsize=(10, 10))
    # The colorbar must describe the label-coloured points. A bare plt.colorbar() would
    # bind to whichever scatter was drawn last (e.g. the red representatives).
    labelled_points = None

    if umap_features_np is not None:
        labels = np.asarray(umap_labels)
        # Labels < 0 mark noise / unassigned points: draw them grey, underneath.
        is_labelled = labels >= 0
        # `color=` (not `c=`) so a 3-element RGB is never mistaken for 3 colormap values.
        plt.scatter(
            umap_features_np[~is_labelled, 0],
            umap_features_np[~is_labelled, 1],
            color=(0.5, 0.5, 0.5),
            s=0.2,
            alpha=0.5,
        )
        if is_labelled.any():
            labelled_points = plt.scatter(
                umap_features_np[is_labelled, 0],
                umap_features_np[is_labelled, 1],
                c=labels[is_labelled],
                s=0.2,
                alpha=0.5,
            )

    if cluster_centers is not None:
        plt.scatter(
            cluster_centers[:, 0],
            cluster_centers[:, 1],
            c="black",
            s=100,
            marker="x",
            label="Cluster Centers",
        )

    if representatives is not None:
        plt.scatter(
            representatives[:, 0],
            representatives[:, 1],
            c="red",
            s=100,
            marker="o",
            label="Representatives",
        )
        # Annotate each representative with its row index.
        for i, (xi, yi) in enumerate(representatives[:, :2]):
            plt.text(xi + 0.4, yi + 0.4, str(i), fontsize=12, color="purple")

    plt.title("UMAP with cluster_centers")
    if labelled_points is not None:
        plt.colorbar(labelled_points)
    # Show the "Cluster Centers" / "Representatives" labels given above.
    if cluster_centers is not None or representatives is not None:
        plt.legend(loc="best")
    return fig

# Experiments

Select the best label for each image and text.


This corresponds to the `main()` function in the original script, `eval.py`.


In [None]:
# Seed every RNG used by the notebook so runs are reproducible.
seed = cfg.seed
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Folder that holds one sub-directory per training run.
parent_folder = "res"

# Set to a specific run directory to pin it; leave as None to use the newest run.
res_path = None

if res_path is None:
    print("No path provided. Searching for the latest experiment...")
    # All run directories inside the parent folder (os.listdir raises if it is missing).
    subfolders = [
        os.path.join(parent_folder, d)
        for d in os.listdir(parent_folder)
        if os.path.isdir(os.path.join(parent_folder, d))
    ]
    # Fail here with a clear message instead of later on a "None/final_model.pth" path.
    if not subfolders:
        raise FileNotFoundError(f"No experiment folders found in '{parent_folder}'.")
    # Most recently modified run directory.
    res_path = max(subfolders, key=os.path.getmtime)

print(f"Using results from: {res_path}")

In [None]:
use_best_label = True  # NOTE(review): not referenced by any cell below.

# Initialize Model
model = CDC(
    clip_trainable=False,
    d_model=cfg.model.d_model,
    nhead=cfg.model.num_heads,
    num_layers=cfg.model.num_layers,
    label_dim=cfg.model.label_dim,
)
# Wrap before loading: the checkpoint is loaded into the DataParallel wrapper, so its
# keys are presumably "module."-prefixed (confirm against the training script).
model = nn.DataParallel(model)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load(f"{res_path}/final_model.pth", map_location=device))
model.to(device)
# Inference only: switch dropout / normalisation layers to evaluation behaviour.
model.eval()

clustering = Clustering()
umap_vis = UMAP_vis()

processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

In [None]:
ann_path = cfg.dataset.test_path
with open(ann_path, "r") as ann_file:
    ann = json.load(ann_file)

# Ratio handed to CDC_test so large test sets are cut down to roughly 5000 annotations
# (how the subset is chosen is defined in CDC_test).
ratio = 5000 / len(ann) if len(ann) > 5000 else 1
print(ratio)

test_dataset = CDC_test(
    annotation_path=cfg.dataset.test_path,
    image_path=cfg.dataset.img_path_test,
    processor=processor,
    ratio=ratio,
    crop_num=5,
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=cfg.eval.batch_size,
    shuffle=False,
    num_workers=cfg.train.num_workers,
)

# map_location keeps this loadable on machines without the GPU it was saved from.
unique_embeddings = torch.load(f"{res_path}/unique_embeddings.pt", map_location=device)

# Scratch directories, created if missing.
store_path_0 = "/project/Deep-Clustering/ckpt2/tmp0"  # Store np features of all
store_path = "/project/Deep-Clustering/ckpt2/tmp"  # per-label retrieval shards
store_path_2 = "/project/Deep-Clustering/ckpt2/tmp2"  # only used by commented-out code
for _scratch_dir in (store_path_0, store_path, store_path_2):
    os.makedirs(_scratch_dir, exist_ok=True)

In [None]:
# Sanity check: print one sample produced by the test dataset.
item = test_dataset[2]
print(item)

## If the embeddings were already extracted, just load them


In [None]:
# Label embeddings saved during training. With load_existing=True the manager
# presumably reads the files under `{res_path}/init/` instead of recomputing them
# (confirm in EmbeddingManager).
embedding_manager = EmbeddingManager(
    ann,
    embedding_dim=cfg.model.label_dim,
    chunk_size=cfg.train.batch_size,
    embeddings_dir=f"{res_path}/init/",
    load_existing=True,
    sample_ids_list=None,
)
# Fetch once (the previous version loaded everything twice). `all_embeddings` keeps
# the raw (sample_ids, label_embedding) result for any cell that still uses it.
all_embeddings = embedding_manager.get_all_embeddings()
sample_ids, label_embedding = all_embeddings

In [None]:
# Fit a 2-D UMAP projection on the label embeddings.
umap_features = umap_vis.learn_umap(label_embedding, n_components=2)

print("##########Performing Clustering##########")
# HDBSCAN on the 2-D UMAP coordinates; n_clusters / method are interpreted by the
# project's Clustering helper.
umap_labels, _ = clustering.get_hdbscan(umap_features, n_clusters=0, method="leaf")
# Project the stored unique embeddings (presumably the training-time cluster
# centers) into the same UMAP space.
umap_features_cluster = umap_vis.predict_umap(unique_embeddings.cpu().numpy())

To recompute the cluster centers in the high-dimensional embedding space, run the next cell. This is needed whenever the cluster centers change.

In [None]:
# Recompute cluster centers in the original (high-dimensional) embedding space from
# the HDBSCAN labels found on the UMAP projection.
_, cluster_centers, cluster_counts = clustering.hdbscan_update(
    umap_labels=umap_labels,
    original_embeddings=label_embedding,
    update_type="hard",
    alpha=0.4,
    update_noise="assign",
    center_only=True,
)
# Order centers by descending cluster_counts (presumably cluster sizes, largest first).
center_sorted_indices = torch.argsort(cluster_counts, descending=True)
unique_embeddings = cluster_centers[center_sorted_indices]
# Re-project the reordered centers into UMAP space.
umap_features_cluster = umap_vis.predict_umap(unique_embeddings.cpu().numpy())

In [None]:
# Move the UMAP results to NumPy and cache them in store_path_0, so later sessions can
# reload them (next cell) instead of recomputing UMAP / HDBSCAN.
umap_features_np = umap_features.cpu().numpy()
umap_labels_np = umap_labels.cpu().numpy()
umap_features_cluster_np = umap_features_cluster.cpu().numpy()

# Save the UMAP features and labels
np.save(os.path.join(store_path_0, "umap_features.npy"), umap_features_np)
np.save(os.path.join(store_path_0, "umap_labels.npy"), umap_labels_np)
np.save(os.path.join(store_path_0, "umap_features_cluster.npy"), umap_features_cluster_np)

Load the cached UMAP features


In [None]:
# Reload the cached UMAP artifacts written by the cell above.
# NOTE(review): allow_pickle=True should not be needed for these numeric arrays;
# only keep it if object arrays are expected, and only for files you wrote yourself.
umap_features_np = np.load(os.path.join(store_path_0, "umap_features.npy"), allow_pickle=True)
umap_labels_np = np.load(os.path.join(store_path_0, "umap_labels.npy"), allow_pickle=True)
umap_features_cluster_np = np.load(
    os.path.join(store_path_0, "umap_features_cluster.npy"), allow_pickle=True
)

Percentile-based selection does not work well here, because the points are extremely dense near the center.


In [None]:
# def cluster_by_direction_with_representatives(
#     umap_features,
#     num_directions=4,
#     percentiles=[0, 100],
# ):
#     """
#     Cluster UMAP points by angular direction and return farthest & middle points per direction.
#     Returns:
#         direction_labels: (N,) array of direction index
#         representatives: List of tuples (farthest_idx, middle_idx) per direction
#     """
#     # Center around origin
#     centered = umap_features - np.mean(umap_features, axis=0)

#     # Compute angles and radii
#     angles = np.arctan2(centered[:, 1], centered[:, 0])  # [-π, π]
#     angles = (angles + 2 * np.pi) % (2 * np.pi)  # [0, 2π)
#     radii = np.linalg.norm(centered, axis=1)

#     # Assign to angular bins
#     angle_bins = np.linspace(0, 2 * np.pi, num_directions + 1)
#     direction_labels = np.digitize(angles, angle_bins) - 1  # 0-based

#     percentile_indices = []

#     for i in range(num_directions):
#         idx_in_bin = np.where(direction_labels == i)[0]
#         if len(idx_in_bin) == 0:
#             percentile_indices.append([None] * len(percentiles))
#             continue

#         r_bin = radii[idx_in_bin]
#         sorted_idx = idx_in_bin[np.argsort(r_bin)]  # sorted indices by distance

#         reps = []
#         for p in percentiles:
#             # Compute percentile index (e.g., 20% → 0.2 * len)
#             rank = int(
#                 np.clip(np.ceil(p / 100 * len(sorted_idx)) - 1, 0, len(sorted_idx) - 1)
#             )
#             reps.append(sorted_idx[rank])

#         percentile_indices.append(reps)

#     return direction_labels, percentile_indices

In [None]:
def cluster_by_direction_with_radial_steps(umap_features, num_directions=4, n_steps=10):
    """Pick evenly spaced representatives along the outermost ray of each angular sector.

    The 2-D points are centred on their mean and split into ``num_directions`` equal
    angular sectors covering ``[0, 2*pi)``. Inside each sector, the point farthest from
    the centre defines a ray. ``n_steps`` targets are placed at equal Euclidean distances
    along that ray (the centre excluded, the far point included), and the sector point
    nearest to each target is selected.

    Args:
        umap_features: ``(N, 2)`` array of coordinates.
        num_directions: number of equal angular sectors.
        n_steps: number of representatives per sector.

    Returns:
        direction_labels: ``(N,)`` integer array with each point's sector index.
        stepwise_indices: list of ``num_directions`` lists of ``n_steps`` indices into
            ``umap_features``. An empty sector yields ``[None] * n_steps``. A sector whose
            farthest point lies exactly on the centre repeats that point's index.
    """
    offsets = umap_features - np.mean(umap_features, axis=0)

    theta = np.arctan2(offsets[:, 1], offsets[:, 0])
    theta = (theta + 2 * np.pi) % (2 * np.pi)  # map to [0, 2*pi)
    sector_edges = np.linspace(0, 2 * np.pi, num_directions + 1)
    direction_labels = np.digitize(theta, sector_edges) - 1

    def _sector_representatives(sector):
        members = np.where(direction_labels == sector)[0]
        if members.size == 0:
            return [None] * n_steps

        member_offsets = offsets[members]
        far_local = np.argmax(np.linalg.norm(member_offsets, axis=1))
        ray = member_offsets[far_local]
        ray_length = np.linalg.norm(ray)
        if ray_length == 0:
            return [members[far_local]] * n_steps

        unit = ray / ray_length
        chosen = []
        # Equally spaced distances along the ray, skipping the centre itself.
        for distance in np.linspace(0, ray_length, n_steps + 1)[1:]:
            gaps = np.linalg.norm(member_offsets - distance * unit, axis=1)
            chosen.append(members[np.argmin(gaps)])
        return chosen

    stepwise_indices = [_sector_representatives(s) for s in range(num_directions)]
    return direction_labels, stepwise_indices

In [None]:
direction_labels, direction_representatives_idx = cluster_by_direction_with_radial_steps(
    umap_features_np, num_directions=4, n_steps=10
)
# Flatten the per-direction lists into one integer index array (direction-major order).
# Empty angular bins come back as lists of None, which would produce an object array
# and break the fancy indexing below, so they are dropped.
direction_representatives_idx = np.array(
    [idx for reps in direction_representatives_idx for idx in reps if idx is not None],
    dtype=np.intp,
)
direction_representatives = umap_features_np[direction_representatives_idx]
print(direction_representatives.shape)

In [None]:
kmeans_number = 20
# Choose up to `kmeans_number` representatives among the projected cluster centers:
# run k-means in UMAP space, then snap each centroid to its nearest real center.
# random_state pins the initialisation, so the result does not depend on how many
# global RNG draws happened since the seeding cell ran.
kmeans = KMeans(
    n_clusters=min(kmeans_number, umap_features_cluster_np.shape[0]),
    random_state=seed,
).fit(umap_features_cluster_np)
centroids = kmeans.cluster_centers_

# Row index (into umap_features_cluster_np) of the closest real center per centroid.
indices = np.argmin(cdist(centroids, umap_features_cluster_np), axis=1)
representatives = umap_features_cluster_np[indices]

In [None]:
# All UMAP points coloured by angular direction, plus the projected cluster centers
# and the k-means representatives.
fig = plot_umap(umap_features_np, direction_labels, umap_features_cluster_np, representatives)
plt.show()

In [None]:
# Same view without the point cloud: cluster centers and k-means representatives only.
fig = plot_umap(None, direction_labels, umap_features_cluster_np, representatives)
plt.show()

In [None]:
# Cluster centers together with the direction-based (radial-step) representatives.
fig = plot_umap(
    None,
    direction_labels,
    umap_features_cluster_np,
    direction_representatives,
)
plt.show()

In [None]:
# Labels to evaluate: the label embeddings at the direction-based representative rows
# (umap_features_np was computed from label_embedding, so its rows line up).
unique_embeddings = label_embedding[direction_representatives_idx]
# NOTE(review): `indices` indexes umap_features_cluster_np (the projected centers), not
# label_embedding, so this alternative probably wants unique_embeddings[indices]; confirm.
# unique_embeddings = label_embedding[indices]

## EXP 2: Test all labels and visualize the results for a single image / text


In [None]:
print("##########Testing test dataset##########")
# Encode the whole test set once. Going by the unpacked names: image / text embeddings,
# the raw captions, the ground-truth text<->image index maps, and the baseline ("raw")
# text-to-image and image-to-text rankings (see src.utils.inference.encode_data).
(
    img_emb,
    txt_emb,
    txt_full,
    text_to_image_map,
    image_to_text_map,
    inds_raw_tti,
    inds_raw_itt,
) = encode_data(
    model,
    processor,
    test_dataloader,
    device,
)

In [None]:
# Go with itt experiments through all labels for a single image
inds_tti_all = []
mask_tti_all = []
inds_itt_all = []
mask_itt_all = []

In [None]:
# Evaluate every candidate label over the whole test set and write the per-label
# results to store_path in shards of `shard_size` labels:
#   - inds_itt: image-to-text ranking obtained with the label-combined embeddings
#     (inds_tti is the text-to-image counterpart);
#   - mask_itt: per-image mask marking which images improved with the selected label
#     (mask_tti is the per-text counterpart).
shard_size = 10
shard_prefixes = ("inds_tti_all", "mask_tti_all", "inds_itt_all", "mask_itt_all")

# Start from fresh buffers, so a re-run after an interruption (or after the loading
# cell replaced these names with arrays) cannot shift the shards.
inds_tti_all, mask_tti_all, inds_itt_all, mask_itt_all = [], [], [], []

# Delete shards left by an earlier run. The loader concatenates every
# `<prefix>_<n>.npy` in store_path, so stale extra shards would be appended silently.
_shard_file = re.compile(r"(?:inds_tti_all|mask_tti_all|inds_itt_all|mask_itt_all)_\d+\.npy")
for _fname in os.listdir(store_path):
    if _shard_file.fullmatch(_fname):
        os.remove(os.path.join(store_path, _fname))

_UINT16_MAX = np.iinfo(np.uint16).max


def _to_index_array(tensor):
    """Return `tensor` as a NumPy array, stored as uint16 only when every value fits.

    uint16 keeps the shards small, but casting indices of 65536 or more would wrap
    around silently, so such arrays fall back to int64.
    """
    values = tensor.detach().cpu().numpy()
    if values.size == 0 or (values.min() >= 0 and values.max() <= _UINT16_MAX):
        return values.astype(np.uint16)
    return values.astype(np.int64)


def _flush_shard(shard_id):
    """Save the four result buffers as shard `shard_id` in store_path, then empty them."""
    buffers = (inds_tti_all, mask_tti_all, inds_itt_all, mask_itt_all)
    for prefix, buffer in zip(shard_prefixes, buffers):
        np.save(f"{store_path}/{prefix}_{shard_id}.npy", np.array(buffer))
        buffer.clear()


save_id = 0
num_labels = len(unique_embeddings)
for idx, selected_label in enumerate(tqdm(unique_embeddings)):
    # The first output (presumably the combined embeddings) is not needed here.
    _, inds_tti, mask_tti, inds_itt, mask_itt = eval_rank_oracle_check_per_label(
        model,
        selected_label,
        img_emb,
        txt_emb,
        txt_full,
        text_to_image_map,
        image_to_text_map,
        inds_raw_tti=inds_raw_tti,
        inds_raw_itt=inds_raw_itt,
    )
    inds_tti_all.append(_to_index_array(inds_tti))
    mask_tti_all.append(mask_tti.detach().cpu().numpy().astype(np.uint16))
    inds_itt_all.append(_to_index_array(inds_itt))
    mask_itt_all.append(mask_itt.detach().cpu().numpy().astype(np.uint16))

    # Flush every `shard_size` labels and after the last label.
    if (idx + 1) % shard_size == 0 or idx == num_labels - 1:
        _flush_shard(save_id)
        save_id += 1

In [None]:
def load_sorted_npy_files(folder_path: str, prefix: str) -> np.ndarray:
    """Load ``{prefix}_{idx}.npy`` shards in numeric ``idx`` order and concatenate them.

    Shards are ordered by the integer ``idx`` (so ``_2`` comes before ``_10``), loaded
    with :func:`numpy.load`, and joined along axis 0. File names that do not match
    ``{prefix}_{digits}.npy`` exactly are ignored.

    Args:
        folder_path: Directory containing the ``.npy`` shards.
        prefix: Filename prefix before the index, e.g. ``"xxx"`` for ``"xxx_1.npy"``.

    Returns:
        One array: all matching shards concatenated along the first axis.

    Raises:
        FileNotFoundError: If no file in ``folder_path`` matches the pattern.
    """
    pattern = re.compile(rf"{re.escape(prefix)}_(\d+)\.npy")

    indexed = []
    for fname in os.listdir(folder_path):
        match = pattern.fullmatch(fname)
        if match is not None:
            indexed.append((int(match.group(1)), fname))
    indexed.sort()

    if not indexed:
        raise FileNotFoundError(f"No '{prefix}_<idx>.npy' files found in {folder_path!r}")

    return np.concatenate(
        [np.load(os.path.join(folder_path, fname)) for _, fname in indexed], axis=0
    )

In [None]:
# Load saved files
inds_tti_all = load_sorted_npy_files(store_path, "inds_tti_all")
mask_tti_all = load_sorted_npy_files(store_path, "mask_tti_all")
inds_itt_all = load_sorted_npy_files(store_path, "inds_itt_all")
mask_itt_all = load_sorted_npy_files(store_path, "mask_itt_all")

In [None]:
# Sanity check: the first dimension should equal the number of labels evaluated.
print(f"inds_tti_all shape: {inds_tti_all.shape}, mask_tti_all shape: {mask_tti_all.shape}")
print(f"inds_itt_all shape: {inds_itt_all.shape}, mask_itt_all shape: {mask_itt_all.shape}")

### Visualization: image-to-text (itt)


In [None]:
ann_path = cfg.dataset.test_path
img_path = cfg.dataset.img_path_test

with open(ann_path, "r") as ann_file:
    ann = json.load(ann_file)

# Flat list of every caption in annotation order. An entry's "caption" may be a single
# string or a list of strings; lists are flattened. Each entry is checked on its own,
# so a mix of both forms is handled without splitting strings into characters.
# NOTE(review): when the file has > 5000 annotations, the test dataloader was built
# with ratio < 1. Confirm CDC_test keeps entries in an order that matches `ann`,
# otherwise encoded indices will not line up with `ann` / `txt_collection`.
txt_collection = []
for entry in ann:
    caption = entry["caption"]
    if isinstance(caption, str):
        txt_collection.append(caption)
    else:
        txt_collection.extend(caption)

## itt


In [None]:
image_index = 0  # Which image to visualize
mask_itt_indices = 0  # Which label to visualize
only_see_improved = False  # Only show labels whose mask marks an improvement for the image
check_top_k = 50  # Number of retrieved captions to compare

In [None]:
import numpy as np
from scipy.special import softmax
from scipy.stats import entropy  # Computes KL divergence
from sklearn.metrics.pairwise import cosine_similarity


def compute_similarity_distribution(text_feats, image_feat, temperature=0.07):
    """Softmax over cosine similarities between each text feature and one image feature.

    Args:
        text_feats: ``(k, d)`` array of text features.
        image_feat: image feature with ``d`` elements (reshaped to ``(1, d)``).
        temperature: softmax temperature; smaller values give a sharper distribution.

    Returns:
        1-D array of length ``k`` that sums to 1.
    """
    # ravel() (not squeeze()) keeps a 1-D result even for k == 1; a 0-d array would
    # break the downstream scipy.stats.entropy calls.
    sims = cosine_similarity(text_feats, image_feat.reshape(1, -1)).ravel()
    return softmax(sims / temperature)


def compute_kl_divergence(text_feats_1, text_feats_2, image_feat, temperature=0.07):
    """Compare two retrieved text lists through their similarity distributions to one image.

    Each feature set becomes a softmax distribution over its ``k`` texts via
    ``compute_similarity_distribution``. The two distributions are compared position by
    position (rank ``i`` of list 1 against rank ``i`` of list 2), so both sets must hold
    the same number of texts.

    Returns:
        dict with (all in nats):
        - ``"KL(P||Q)"`` and ``"KL(Q||P)"``: the two directed KL divergences;
        - ``"JSD"``: the Jensen-Shannon divergence ``0.5*KL(P||M) + 0.5*KL(Q||M)`` with
          ``M = (P + Q) / 2`` (bounded by ``ln 2``);
        - ``"SymKL"``: the symmetrised KL ``0.5*(KL(P||Q) + KL(Q||P))``, which earlier
          versions reported under the ``"JSD"`` key.
    """
    P = compute_similarity_distribution(text_feats_1, image_feat, temperature)
    Q = compute_similarity_distribution(text_feats_2, image_feat, temperature)

    # Clip away exact zeros so the KL terms stay finite; entropy() renormalises.
    epsilon = 1e-8
    P = np.clip(P, epsilon, 1.0)
    Q = np.clip(Q, epsilon, 1.0)

    kl_PQ = entropy(P, Q)  # KL(P || Q)
    kl_QP = entropy(Q, P)  # KL(Q || P)
    M = 0.5 * (P + Q)
    jsd = 0.5 * (entropy(P, M) + entropy(Q, M))
    sym_kl = 0.5 * (kl_PQ + kl_QP)

    return {
        "KL(P||Q)": kl_PQ.item(),
        "KL(Q||P)": kl_QP.item(),
        "JSD": jsd.item(),
        "SymKL": sym_kl.item(),
    }

In [None]:
# Controls for the image-to-text viewer. The ranges come from the data actually
# encoded, so the sliders cannot point past the end of the arrays.
image_index_slider = widgets.IntSlider(
    value=image_index,
    min=0,
    max=min(len(ann), len(img_emb)) - 1,
    step=1,
    description="Image Index",
)
mask_itt_slider = widgets.IntSlider(
    value=mask_itt_indices,
    min=0,
    max=max(len(unique_embeddings) - 1, 0),  # label indices are 0-based
    step=1,
    description="Label Index",
)
only_improved_checkbox = widgets.Checkbox(
    value=only_see_improved, description="Only Show Improved"
)
top_k_slider = widgets.IntSlider(value=check_top_k, min=1, max=50, step=1, description="Top K")


def update_visualization(image_index, mask_itt_indices, only_see_improved, check_top_k):
    """Render the image-to-text comparison for one image and one label.

    Shows the image, prints divergence scores between the baseline ("raw") and the
    label-conditioned ("CDC") top-k caption lists, and displays both lists side by side.
    Green cells are ground-truth captions of the image; blue cells appear in both lists.

    Args:
        image_index: index of the query image (clamped to the valid range).
        mask_itt_indices: index of the label to inspect (clamped). With
            ``only_see_improved`` it instead selects among the labels that improved this
            image (0 = first improving label, and so on).
        only_see_improved: restrict the view to labels whose mask marks an improvement.
        check_top_k: number of retrieved captions to compare.
    """
    clear_output(wait=True)

    # Clamp before indexing into the per-label arrays.
    image_index = min(image_index, len(img_emb) - 1)
    inds_itt_per_image = [inds_itt[image_index] for inds_itt in inds_itt_all]
    mask_itt_per_image = [mask_itt[image_index] for mask_itt in mask_itt_all]
    mask_itt_indices = min(mask_itt_indices, len(mask_itt_per_image) - 1)

    if only_see_improved:
        # Browse only the labels that improved this image.
        improving_labels = np.flatnonzero(mask_itt_per_image)
        if improving_labels.size == 0:
            print("No improvement for this image.")
            return
        mask_itt_indices = improving_labels[min(mask_itt_indices, improving_labels.size - 1)]
    improved = bool(mask_itt_per_image[mask_itt_indices])

    # NOTE(review): assumes row i of img_emb corresponds to ann[i].
    item = ann[image_index]
    with Image.open(os.path.join(img_path, item["image"])) as raw_img:
        img = raw_img.convert("RGB")
    image_feature = img_emb[image_index]

    plt.axis("off")
    plt.title("Improved" if improved else "Not Improved")
    plt.imshow(img)
    plt.show()

    # Ground-truth caption(s) of this image, as a set for exact matching (a plain
    # string would make `in` a substring test).
    original_caption = item["caption"]
    gt_captions = (
        {original_caption} if isinstance(original_caption, str) else set(original_caption)
    )

    # Baseline retrieval (raw embeddings).
    raw_indices = inds_raw_itt[image_index][:check_top_k]
    raw_captions = [txt_collection[i] for i in raw_indices]
    raw_features = txt_emb[raw_indices]

    # Retrieval with the selected label.
    cdc_indices = inds_itt_per_image[mask_itt_indices][:check_top_k].tolist()
    cdc_captions = [txt_collection[i] for i in cdc_indices]
    cdc_features = txt_emb[cdc_indices]

    print(compute_kl_divergence(raw_features, cdc_features, image_feature))

    df = pd.DataFrame(
        {
            "Raw_retrieve": raw_captions,
            "CDC_retrieve": cdc_captions,
        }
    )
    shared_captions = set(raw_captions) & set(cdc_captions)

    def highlight(val):
        # Ground-truth caption: green. Retrieved by both methods: blue.
        if val in gt_captions:
            return "background-color: lightgreen"
        if val in shared_captions:
            return "background-color: lightblue"
        return ""

    with pd.option_context("display.max_colwidth", None):
        display(df.style.map(highlight))

In [None]:
# Interactive UI
ui = widgets.VBox([image_index_slider, mask_itt_slider, only_improved_checkbox, top_k_slider])

out = widgets.Output()


def on_change(change):
    with out:
        update_visualization(
            image_index_slider.value,
            mask_itt_slider.value,
            only_improved_checkbox.value,
            top_k_slider.value,
        )


image_index_slider.observe(on_change, names="value")
mask_itt_slider.observe(on_change, names="value")
only_improved_checkbox.observe(on_change, names="value")
top_k_slider.observe(on_change, names="value")

update_visualization(image_index, mask_itt_indices, only_see_improved, check_top_k)

In [None]:
display(ui, out)  # Show the image-to-text controls and their output area

## Divergence Assessment


## tti


In [None]:
text_index = 10  # Which text (caption) to visualize
mask_tti_indices = 0  # Which label to visualize
only_see_improved = False  # Only show labels whose mask marks an improvement for the text
check_top_k = 10  # Number of retrieved images per method

In [None]:
# Controls for the text-to-image viewer. The text range is the number of encoded
# captions (the old (len(ann) - 1) * 5 assumed exactly 5 captions per image).
text_index_slider = widgets.IntSlider(
    value=text_index,
    min=0,
    max=min(len(txt_collection), len(txt_emb)) - 1,
    step=1,
    description="Text Index",
)
mask_tti_slider = widgets.IntSlider(
    value=mask_tti_indices,
    min=0,
    max=max(len(unique_embeddings) - 1, 0),  # label indices are 0-based
    step=1,
    description="Label Index",
)
only_improved_checkbox = widgets.Checkbox(
    value=only_see_improved, description="Only Show Improved"
)
top_k_slider = widgets.IntSlider(value=check_top_k, min=1, max=50, step=1, description="Top K")


def update_visualization2(text_index, mask_tti_indices, only_see_improved, check_top_k):
    """Render the text-to-image comparison for one query caption and one label.

    Prints the query caption and whether the selected label improved its retrieval.
    It then shows a grid with the top-k images retrieved by the baseline ("Raw", left
    column) and with the selected label ("CDC", right column).

    Args:
        text_index: index of the query caption (clamped to the valid range).
        mask_tti_indices: index of the label to inspect (clamped). With
            ``only_see_improved`` it instead selects among the labels that improved this
            text (0 = first improving label, and so on).
        only_see_improved: restrict the view to labels whose mask marks an improvement.
        check_top_k: number of retrieved images per method.
    """
    clear_output(wait=True)

    # Clamp before indexing into the per-label arrays.
    text_index = min(text_index, len(txt_emb) - 1)
    inds_tti_per_text = [inds_tti[text_index] for inds_tti in inds_tti_all]
    mask_tti_per_text = [mask_tti[text_index] for mask_tti in mask_tti_all]
    mask_tti_indices = min(mask_tti_indices, len(mask_tti_per_text) - 1)

    if only_see_improved:
        # Browse only the labels that improved this text.
        improving_labels = np.flatnonzero(mask_tti_per_text)
        if improving_labels.size == 0:
            print("No improvement for this text.")
            return
        mask_tti_indices = improving_labels[min(mask_tti_indices, improving_labels.size - 1)]
    improved = bool(mask_tti_per_text[mask_tti_indices])

    query_caption = txt_collection[text_index]
    print(f"Query Caption: {query_caption}")
    print(f"Improved: {improved}")
    print("=" * 60)

    raw_indices = inds_raw_tti[text_index][:check_top_k]
    cdc_indices = inds_tti_per_text[mask_tti_indices][:check_top_k].tolist()

    # squeeze=False keeps `axes` 2-D even when check_top_k == 1.
    _, axes = plt.subplots(
        nrows=check_top_k, ncols=2, figsize=(15, check_top_k * 2.5), squeeze=False
    )
    for ax in axes.flat:
        ax.axis("off")

    # Iterate over the indices actually returned, so a ranking shorter than
    # check_top_k leaves blank cells instead of raising.
    for col, (label, indices) in enumerate((("Raw", raw_indices), ("CDC", cdc_indices))):
        for rank, img_id in enumerate(indices):
            ax = axes[rank, col]
            # NOTE(review): assumes image index i corresponds to ann[i].
            item = ann[int(img_id)]
            with Image.open(os.path.join(img_path, item["image"])) as raw_img:
                ax.imshow(raw_img.convert("RGB"))
            ax.set_title(f"{label} Top-{rank+1}: {item['caption']}", fontsize=8)

    plt.tight_layout()
    plt.show()

In [None]:
# Interactive UI
ui = widgets.VBox([text_index_slider, mask_tti_slider, only_improved_checkbox, top_k_slider])

out = widgets.Output()


def on_change(change):
    with out:
        update_visualization2(
            text_index_slider.value,
            mask_tti_slider.value,
            only_improved_checkbox.value,
            top_k_slider.value,
        )


text_index_slider.observe(on_change, names="value")
mask_tti_slider.observe(on_change, names="value")
only_improved_checkbox.observe(on_change, names="value")
top_k_slider.observe(on_change, names="value")

update_visualization2(text_index, mask_tti_indices, only_see_improved, check_top_k)

In [None]:
# Show the text-to-image controls and their output area.
display(ui, out)