# Library

## Basic library

In [None]:
import json
import os
import random
import warnings
from datetime import datetime
from tabnanny import verbose
from turtle import update

import hydra
import matplotlib.pyplot as plt
import numpy as np
import torch
import transformers
from hydra import compose, initialize, initialize_config_dir, initialize_config_module
from omegaconf import DictConfig, OmegaConf
from sympy import count_ops, use
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoImageProcessor, AutoProcessor, AutoTokenizer

# Import local packages
from src.data.cdc_datamodule import CDC_test
from src.models.cdc import CDC
from src.models.components.clustering import Clustering, UMAP_vis
from src.utils import (
    print_model_info,
)
from src.utils.evaltools import eval_rank_oracle_check_per_label
from src.utils.inference import encode_data, inference_test

# Setup
# Presumably silences the tokenizers parallelism/fork warning that appears when
# DataLoader worker processes are used — TODO confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Suppress every Python warning for the rest of the session.
warnings.filterwarnings("ignore")
# Restrict the transformers library's own logging to error-level messages.
transformers.logging.set_verbosity_error()
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize Hydra
# Clear the global Hydra instance first; presumably so this cell can be re-run
# without initialize() failing because Hydra is already initialized — TODO confirm.
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(config_path="configs", version_base=None)
# Compose the "flickr30k" configuration from the "configs" directory given above.
cfg = compose(config_name="flickr30k")
# Print every item yielded by iterating over cfg, one per line.
print(*cfg, sep="\n")

## Visualization library

In [None]:
import os
import sys

import pandas as pd
import torch
from IPython.display import display
from PIL import Image

from eval import main

# Notebook-wide pandas display settings, applied in one pass:
#   - max_colwidth=None -> never truncate cell text
#   - max_rows=200      -> show up to 200 rows
#   - max_columns=50    -> show up to 50 columns
for option_name, option_value in (
    ("display.max_colwidth", None),
    ("display.max_rows", 200),
    ("display.max_columns", 50),
):
    pd.set_option(option_name, option_value)

# Experiments
Select the best label for each image and text.

This corresponds to the `main()` function of the original script, `eval.py`.

In [None]:
# Seed every random number generator used in this notebook (NumPy, Python's
# `random`, and PyTorch on CPU and CUDA) with the value from the config.
seed = cfg.seed
for seed_rng in (
    np.random.seed,
    random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    seed_rng(seed)

def find_latest_experiment(parent):
    """Return the most recently modified sub-directory of ``parent``.

    Args:
        parent: Directory whose immediate sub-directories are experiment runs.

    Returns:
        The path of the newest sub-directory (by modification time), or
        ``None`` when ``parent`` does not exist or has no sub-directories.
    """
    if not os.path.isdir(parent):
        return None
    with os.scandir(parent) as entries:
        candidates = [entry.path for entry in entries if entry.is_dir()]
    return max(candidates, key=os.path.getmtime, default=None)


# Define the parent folder
parent_folder = "res"

# Explicit experiment folder to analyse; set to None to pick the newest run
# found under `parent_folder` instead.
# NOTE(review): the config composed earlier in this notebook is "flickr30k"
# while this path names an "mscoco" run — confirm that they are meant to match.
res_path = "/project/Deep-Clustering/res/20250219_043513_mscoco-preextracted"

if res_path is None:
    print("No path provided. Searching for the latest experiment...")
    res_path = find_latest_experiment(parent_folder)
    if res_path is None:
        # Fail here with a clear message rather than later, when the model
        # checkpoint would be looked up under the literal path "None/...".
        raise FileNotFoundError(f"No experiment folder found under '{parent_folder}'")

print(f"Using results from: {res_path}")

In [None]:
use_best_label = True

# Initialize Model
# Architecture hyper-parameters come from the Hydra config.
model = CDC(
    clip_trainable=False,
    d_model=cfg.model.d_model,
    nhead=cfg.model.num_heads,
    num_layers=cfg.model.num_layers,
    label_dim=cfg.model.label_dim,
)
# Wrap in DataParallel before loading the weights; presumably the checkpoint
# was saved from a DataParallel-wrapped model — TODO confirm.
model = nn.DataParallel(model)
# Load the checkpoint onto the CPU first so that a checkpoint written on a GPU
# machine can also be loaded on a CPU-only host; the model is moved to `device`
# right afterwards.
model.load_state_dict(torch.load(f"{res_path}/final_model.pth", map_location="cpu"))
model.to(device)
# NOTE(review): model.eval() is not called here — confirm that the inference
# helpers switch the model to evaluation mode themselves.

clustering = Clustering()

processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Only the number of test annotations is needed here. A context manager makes
# sure the file handle is closed as soon as the JSON has been parsed.
with open(cfg.dataset.test_path, "r", encoding="utf-8") as ann_file:
    num_annotations = len(json.load(ann_file))

# Fraction passed to the dataset as `ratio`: 5000 / N when there are more than
# 5000 annotations, otherwise 1.
ratio = 5000 / num_annotations if num_annotations > 5000 else 1

test_dataset = CDC_test(
    annotation_path=cfg.dataset.test_path,
    image_path=cfg.dataset.img_path_test,
    processor=processor,
    ratio=ratio,
)

# Fixed order (shuffle=False); batch size comes from the eval config and the
# worker count from the train config.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=cfg.eval.batch_size,
    shuffle=False,
    num_workers=cfg.train.num_workers,
)

## EXP 1: Find labels among all possible choices and select the best one

In [None]:
# print("##########Testing test dataset##########")
# unique_embeddings = torch.load(f"{res_path}/unique_embeddings.pt")
# (
#     img_emb,
#     txt_emb,
#     txt_full,
#     text_to_image_map,
#     image_to_text_map,
#     best_label_tti,
#     best_label_itt,
#     inds_raw_tti,
#     inds_raw_itt,
# ) = inference_test(
#     model,
#     processor,
#     test_dataloader,
#     unique_embeddings,
#     -1,
#     device,
#     inspect_labels=True,
#     use_best_label=use_best_label,
# )

Visualization of original image and text embeddings

In [None]:
# # First use img_emb + txt_emb to create a cat_emb and compute umap
# cat_emb = torch.cat((img_emb, txt_emb), dim=0)

# umap_vis = UMAP_vis()
# umap_features = umap_vis.learn_umap(cat_emb, n_components=2)
# umap_features_raw_image = umap_features[: img_emb.shape[0], :]
# umap_features_raw_text = umap_features[img_emb.shape[0] :, :]

# # Compute original umap
# fig = plt.figure(figsize=(10, 10))

# # Plot image embeddings
# plt.scatter(
#     umap_features[: img_emb.shape[0], 0],
#     umap_features[: img_emb.shape[0], 1],
#     s=5,
#     alpha=1,
#     label="Image Embeddings",
# )

# # Plot text embeddings
# plt.scatter(
#     umap_features[img_emb.shape[0] : img_emb.shape[0] + txt_emb.shape[0], 0],
#     umap_features[img_emb.shape[0] : img_emb.shape[0] + txt_emb.shape[0], 1],
#     s=5,
#     alpha=1,
#     label="Text Embeddings",
# )

# plt.legend()
# plt.grid()
# plt.xlabel("UMAP 1")
# plt.ylabel("UMAP 2")
# str_tag = "raw"
# plt.show()
# plt.close()

### itt part

In [None]:
# # itt label selections
# # Find unique values in best_label
# unique_values_label_itt, counts_label_itt = torch.unique(best_label_itt, return_counts=True)
# unique_values_label_itt = unique_values_label_itt[unique_values_label_itt != -1]

# N = len(unique_values_label_itt)

# print(f"##########Evaluating top {N} labels##########")
# selected_label_indices_itt = int(unique_values_label_itt[0])
# selected_label_itt = unique_embeddings[selected_label_indices_itt]

In [None]:
# # Go with itt experiments
# comb_emb_itt, inds_itt, mask_itt = eval_rank_oracle_check_per_label(
#     model,
#     selected_label_itt,
#     True,  # True means go with itt experiments
#     img_emb,
#     txt_emb,
#     txt_full,
#     text_to_image_map,
#     image_to_text_map,
#     inds_raw_tti=inds_raw_tti,
#     inds_raw_itt=inds_raw_itt,
# )

# """
# 1. comb_emb_itt holds the combined embeddings of all texts and the selected label.
# 2. inds_itt holds the itt indices computed from the combined embeddings.
# 3. mask_itt is an image-shaped mask that indicates which images improved when the selected label is used.
# """

In [None]:
# # Plot UMAP for all points
# fig = plt.figure(figsize=(10, 10))

# comb_emb_itt_np = comb_emb_itt.detach().cpu().numpy()
# umap_features_new_itt = umap_vis.predict_umap(comb_emb_itt_np)

# # Compute original umap
# fig = plt.figure(figsize=(10, 10))
# # Plot image embeddings
# plt.scatter(
#     umap_features_raw_image[:, 0],
#     umap_features_raw_image[:, 1],
#     s=5,
#     alpha=1,
#     label="Image Embeddings",
# )

# # Plot text embeddings
# plt.scatter(
#     umap_features_raw_text[:, 0],
#     umap_features_raw_text[:, 1],
#     s=5,
#     alpha=1,
#     label="Text Embeddings",
# )

# # Plot combined embeddings
# plt.scatter(
#     umap_features_new_itt[:, 0],
#     umap_features_new_itt[:, 1],
#     s=5,
#     alpha=1,
#     label="Combined Embeddings",
# )
# plt.legend()
# plt.grid()

# plt.xlabel("UMAP 1")
# plt.ylabel("UMAP 2")
# str_tag = "best" if use_best_label else "first"
# plt.show()
# plt.close()

### tti part

In [None]:
# # tti label selections
# # Find unique values in best_label
# unique_values_label_tti, counts_label_tti = torch.unique(best_label_tti, return_counts=True)
# unique_values_label_tti = unique_values_label_tti[unique_values_label_tti != -1]

# N = len(unique_values_label_tti)

# print(f"##########Evaluating top {N} labels##########")
# selected_label_indices_tti = int(unique_values_label_tti[0])
# selected_label_tti = unique_embeddings[selected_label_indices_tti]

In [None]:
# # Go with tti experiments
# comb_emb_tti, inds_tti, mask_tti, inds_tti, mask_tti = eval_rank_oracle_check_per_label(
#     model,
#     selected_label_tti,
#     img_emb,
#     txt_emb,
#     txt_full,
#     text_to_image_map,
#     image_to_text_map,
#     inds_raw_tti=inds_raw_tti,
#     inds_raw_itt=inds_raw_itt,
# )

# """
# 1. comb_emb_tti holds the combined embeddings of all images and the selected label.
# 2. inds_tti holds the tti indices computed from the combined embeddings.
# 3. mask_tti is an image-shaped mask that indicates which images improved when the selected label is used.
# """

In [None]:
# # Plot itt UMAP for improved points only
# fig = plt.figure(figsize=(10, 10))

# # Plot image embeddings
# plt.scatter(
#     umap_features_raw_image[mask_itt, 0],
#     umap_features_raw_image[mask_itt, 1],
#     s=5,
#     alpha=1,
#     label="Image Embeddings (Improved)",
# )

# # Plot text embeddings
# plt.scatter(
#     umap_features_raw_text[
#         mask_tti, 0
#     ],  # TODO this is not actually mask_tti but rather expanded mask_tti
#     umap_features_raw_text[mask_tti, 1],
#     s=5,
#     alpha=1,
#     label="Text Embeddings (Improved)",
# )

# # Plot combined embeddings
# plt.scatter(
#     umap_features_new_itt[mask_tti, 0],
#     umap_features_new_itt[mask_tti, 1],
#     s=5,
#     alpha=1,
#     label="Combined Embeddings (Improved)",
# )

# # Plot image embeddings
# plt.scatter(
#     umap_features_raw_image[~mask_itt, 0],
#     umap_features_raw_image[~mask_itt, 1],
#     s=5,
#     alpha=0.05,
#     color="blue",
# )

# # Plot text embeddings
# plt.scatter(
#     umap_features_raw_text[~mask_itt_expand, 0],
#     umap_features_raw_text[~mask_itt_expand, 1],
#     s=5,
#     alpha=0.05,
#     color="orange",
# )

# # Plot combined embeddings
# plt.scatter(
#     umap_features_new_itt[~mask_itt_expand, 0],
#     umap_features_new_itt[~mask_itt_expand, 1],
#     s=5,
#     alpha=0.05,
#     color="green",
# )

# plt.legend()
# plt.grid()

# plt.xlabel("UMAP 1")
# plt.ylabel("UMAP 2")
# str_tag = "best" if use_best_label else "first"
# plt.show()
# plt.close()

# umap_vis.close_cluster()

# # Clean cuda cache
# torch.cuda.empty_cache()

### Visualization itt

In [None]:
# ann_path = cfg.dataset.test_path
# img_path = cfg.dataset.img_path_test

# ann = json.load(open(ann_path, "r"))
# txt_collection = [item["caption"] for item in ann]
# txt_collection = [item for sublist in txt_collection for item in sublist]

In [None]:
# # Find the True values in mask_itt and its corresponding indices
# mask_itt_indices = np.where(mask_itt)[0]
# mask_itt_indices[:10]

In [None]:
# i = 2
# i = min(i, len(mask_itt_indices) - 1)

# idx = mask_itt_indices[i]  # idx is the true index of the image in ann that improved
# check_top_k = 50
# item = ann[idx]

# img = os.path.join(img_path, item["image"])
# img = Image.open(img).convert("RGB")

# # turn off axis
# plt.axis("off")
# plt.imshow(img)
# plt.show()

# # First get original caption
# original_caption = item["caption"]

# # Then get the caption retrieved by raw
# retrived_caption_index_raw = inds_raw_itt[idx][:check_top_k]
# retrived_caption_raw = [txt_collection[i] for i in retrived_caption_index_raw]

# # Finally get the caption retrieved by our method
# retrived_caption_index_cdc = inds_itt[idx][:check_top_k].tolist()
# retrived_caption_cdc = [txt_collection[i] for i in retrived_caption_index_cdc]

# # Turn into a panda dataframe
# df = pd.DataFrame(
#     {
#         "Raw_retrieve": retrived_caption_raw,
#         "CDC_retrieve": retrived_caption_cdc,
#     }
# )


# # Function to highlight duplicates
# def highlight_duplicates(val, col1, col2):
#     # If the value appears in both columns, color it blue

#     if val in original_caption:
#         return "background-color: lightgreen"

#     if val in df[col1].values and val in df[col2].values:
#         return "background-color: lightblue"

#     return ""


# # Display the dataframe
# with pd.option_context("display.max_colwidth", None):
#     styled_df = df.style.map(highlight_duplicates, col1="Raw_retrieve", col2="CDC_retrieve")
#     display(styled_df)

## EXP 2: Test all labels and focus visualization of a single image / text

In [None]:
print("##########Testing test dataset##########")

# Candidate label embeddings stored alongside the trained model.
unique_embeddings = torch.load(f"{res_path}/unique_embeddings.pt")

# Encode the test set once; the seven returned values are unpacked below.
encoded_outputs = encode_data(model, processor, test_dataloader, device)
(
    img_emb,
    txt_emb,
    txt_full,
    text_to_image_map,
    image_to_text_map,
    inds_raw_tti,
    inds_raw_itt,
) = encoded_outputs

In [None]:
# Evaluate every candidate label embedding and keep, per label, the itt
# indices and the improvement mask returned by the evaluation helper.
inds_itt_all, mask_itt_all = [], []
for selected_label in tqdm(unique_embeddings):
    label_outputs = eval_rank_oracle_check_per_label(
        model,
        selected_label,
        img_emb,
        txt_emb,
        txt_full,
        text_to_image_map,
        image_to_text_map,
        inds_raw_tti=inds_raw_tti,
        inds_raw_itt=inds_raw_itt,
    )
    # Five values are returned; only the last two are used here.
    _, _, _, inds_itt, mask_itt = label_outputs

    inds_itt_all.append(inds_itt)
    mask_itt_all.append(mask_itt)

"""
1. inds_itt is the indices of the itt inds using combined embeddings
2. mask_itt is the mask of the image-shape that indicate which image improved by using the selected label.
"""

### Visualization itt

In [None]:
ann_path = cfg.dataset.test_path
img_path = cfg.dataset.img_path_test

# Read the test annotations; the context manager closes the file handle as
# soon as the JSON has been parsed.
with open(ann_path, "r", encoding="utf-8") as ann_file:
    ann = json.load(ann_file)

# Flatten the per-item caption lists into one flat list of captions, keeping
# the annotation order.
txt_collection = [caption for item in ann for caption in item["caption"]]

In [None]:
image_index = 5  # Which image to choose to visualize
mask_itt_indices = 5  # Which label to choose to visualize
only_see_improved = True  # Only see mask_itt_indices if it improved the image

# Clamp the image index BEFORE it is used for indexing; clamping afterwards
# cannot prevent an IndexError for an out-of-range index.
image_index = min(image_index, len(img_emb) - 1)

# For the chosen image, collect one entry per label: its retrieved indices
# and whether that label improved the image.
inds_itt_per_image = [label_inds[image_index] for label_inds in inds_itt_all]
mask_itt_per_image = [label_mask[image_index] for label_mask in mask_itt_all]

# Make sure that the label index is within the range as well
mask_itt_indices = min(mask_itt_indices, len(mask_itt_per_image) - 1)

# Find the True values in mask_itt and its corresponding indices
mask_itt_indices_true = np.where(mask_itt_per_image)[0]
print(f"Image is improved with at least one label: {len(mask_itt_indices_true) > 0}")
print(
    f"Number of labels that improved the image: {len(mask_itt_indices_true)} out of {len(mask_itt_per_image)}"
)
improved = mask_itt_per_image[mask_itt_indices]
print(f"Image is improved: {improved}")

# NOTE(review): when the chosen label improved the image, this switches to the
# FIRST improving label, which may differ from the one chosen above — confirm
# that this is the intended meaning of `only_see_improved`.
if only_see_improved and improved:
    mask_itt_indices = mask_itt_indices_true[0]

In [None]:
check_top_k = 50
item = ann[image_index]

# Open the selected image and convert it to RGB
img = Image.open(os.path.join(img_path, item["image"])).convert("RGB")

# Show the image without axes, titled with the improvement status
if improved:
    figure_title = "Improved"
else:
    figure_title = "Not Improved"
plt.axis("off")
plt.title(figure_title)
plt.imshow(img)
plt.show()

# Ground-truth caption entry of the selected annotation item
original_caption = item["caption"]

# First `check_top_k` caption indices: raw retrieval vs. our method with the chosen label
retrived_caption_index_raw = inds_raw_itt[image_index][:check_top_k]
retrived_caption_index_cdc = inds_itt_per_image[mask_itt_indices][:check_top_k].tolist()

# Look the caption texts up by index
retrived_caption_raw = [txt_collection[caption_id] for caption_id in retrived_caption_index_raw]
retrived_caption_cdc = [txt_collection[caption_id] for caption_id in retrived_caption_index_cdc]

# Side-by-side table of both retrieval results
df = pd.DataFrame(
    data={
        "Raw_retrieve": retrived_caption_raw,
        "CDC_retrieve": retrived_caption_cdc,
    }
)


# Cell-styling callback for the retrieval comparison table
def highlight_duplicates(val, col1, col2):
    """Return the CSS background style for a single table cell.

    Reads the module-level names ``original_caption`` and ``df``.

    Args:
        val: The cell value being styled.
        col1: Name of the first ``df`` column to check for ``val``.
        col2: Name of the second ``df`` column to check for ``val``.

    Returns:
        ``"background-color: lightgreen"`` when ``val`` is contained in
        ``original_caption``; otherwise ``"background-color: lightblue"``
        when ``val`` occurs in both columns; otherwise ``""``.
    """
    if val in original_caption:
        css = "background-color: lightgreen"
    elif all(val in df[column].values for column in (col1, col2)):
        css = "background-color: lightblue"
    else:
        css = ""
    return css


# Render the comparison table with untruncated cell text, colouring each cell
# through highlight_duplicates.
compared_columns = {"col1": "Raw_retrieve", "col2": "CDC_retrieve"}
with pd.option_context("display.max_colwidth", None):
    styled_df = df.style.map(highlight_duplicates, **compared_columns)
    display(styled_df)