In [None]:
import os
import json
import numpy as np
import pandas as pd
import faiss
from sentence_transformers import SentenceTransformer
from scipy import stats
import matplotlib.pyplot as plt
import seaborn as sns

# Comparison logic requires statistical tests
from scipy.stats import ttest_rel, wilcoxon

In [9]:
# Cell 1: Configuration (the imports live in the cell above)

# Set plot style for better visuals
sns.set_theme(style="whitegrid")

# --- Configuration ---
# All paths are relative to the notebook's working directory.

# Prebuilt FAISS index searched by `search()`; its row ids are used as
# positional indices into the rows of METADATA_PATH (`meta_df.iloc[idx]`).
FAISS_INDEX_PATH = "../data/index/text_index.faiss"
# Metadata CSV; `search()` reads its 'uid' and 'caption' columns.
METADATA_PATH = "../data/index/text_index_metadata.csv"
# SentenceTransformer used to encode queries. Presumably the same model that
# produced the indexed embeddings (TODO confirm) — otherwise scores are meaningless.
MODEL_NAME = "all-MiniLM-L6-v2"
# JSON list of robustness test families consumed by `run_robustness_tests()`.
TESTS_JSON_PATH = "../tests/baseline_robustness_tests.json"
# Default number of hits returned by `search()`.
TOP_K_DEFAULT = 10
# Helper to check paths
def check_paths():
    paths = [FAISS_INDEX_PATH, METADATA_PATH, TESTS_JSON_PATH]
    for p in paths:
        if not os.path.exists(p):
            print(f"[WARNING] File not found: {p}")
        else:
            print(f"[OK] Found: {p}")

check_paths()

[OK] Found: ../data/index/text_index.faiss
[OK] Found: ../data/index/text_index_metadata.csv
[OK] Found: ../tests/baseline_robustness_tests.json


In [10]:
def load_resources():
    """
    Load everything needed for text-to-asset retrieval.

    Both on-disk files are checked up front, so a missing file fails fast
    before any expensive load starts.

    returns:
        index: FAISS index read from FAISS_INDEX_PATH
        meta_df: DataFrame read from METADATA_PATH (later cells read its
            'uid' and 'caption' columns)
        model: SentenceTransformer(MODEL_NAME), used to encode queries

    raises:
        FileNotFoundError: if FAISS_INDEX_PATH or METADATA_PATH is missing
    """
    for required_path, missing_msg in (
        (FAISS_INDEX_PATH, f"FAISS index not found at {FAISS_INDEX_PATH}"),
        (METADATA_PATH, f"Metadata CSV not found at {METADATA_PATH}"),
    ):
        if not os.path.exists(required_path):
            raise FileNotFoundError(missing_msg)

    print("[INFO] Loading FAISS index from disk...")
    faiss_index = faiss.read_index(FAISS_INDEX_PATH)

    print("[INFO] Loading metadata from disk...")
    metadata = pd.read_csv(METADATA_PATH)
    print(f"[INFO] Loaded metadata with {len(metadata)} rows.")

    print(f"[INFO] Loading MiniLM model ({MODEL_NAME}) for query encoding...")
    encoder = SentenceTransformer(MODEL_NAME)

    return faiss_index, metadata, encoder

In [11]:
def encode_and_normalize_query(model, query: str) -> np.ndarray:
    """
    Embed a single query string and rescale it to unit L2 length.

    The query is passed to ``model.encode`` as a one-element list, so the
    result keeps a leading batch axis of size 1. That is the 2-D layout
    ``index.search`` is called with elsewhere in this notebook.

    input:
        model: object exposing SentenceTransformer-style
            ``encode(list_of_str, show_progress_bar=...)``
        query: text string

    output:
        float32 NumPy array of shape (1, D)
    """
    raw = model.encode([query], show_progress_bar=False)
    # The tiny epsilon keeps an all-zero embedding from dividing by zero.
    length = np.linalg.norm(raw, axis=1, keepdims=True) + 1e-12
    unit = raw / length
    return unit.astype("float32")


def cosine_similarity(model, q1: str, q2: str) -> float:
    """
    Cosine similarity between two query strings.

    Each string is embedded with ``encode_and_normalize_query`` (q1 first,
    then q2). Both embeddings therefore already have unit length, and the
    cosine reduces to their plain dot product.
    """
    first, second = (
        encode_and_normalize_query(model, text)[0]  # drop the batch axis -> (D,)
        for text in (q1, q2)
    )
    return float(np.dot(first, second))


def compute_cosine_similarity_drop(
    model,
    original_queries: list,
    perturbed_queries: list
) -> dict:
    """
    Compute the cosine similarity drop between original and perturbed queries.

    Pairs are formed positionally (``original_queries[i]`` with
    ``perturbed_queries[i]``) and compared with ``cosine_similarity``.

    Args:
        model: SentenceTransformer model
        original_queries: List of original query strings
        perturbed_queries: List of perturbed query strings (same length)

    Returns:
        {
            'similarities': list of floats (one per pair, in input order),
            'mean_similarity': float,
            'mean_drop': float (1 - mean_similarity),
            'std_similarity': float (population std, NumPy default ddof=0),
            'min_similarity': float,
            'max_similarity': float
        }

    Raises:
        ValueError: if the lists differ in length, or are empty. The summary
            statistics are undefined for zero pairs, and NumPy would otherwise
            fail with an unhelpful reduction error.
    """
    if len(original_queries) != len(perturbed_queries):
        raise ValueError("Query lists must have same length")
    if len(original_queries) == 0:
        raise ValueError("Query lists must not be empty")

    similarities = np.array([
        cosine_similarity(model, orig, pert)
        for orig, pert in zip(original_queries, perturbed_queries)
    ])
    mean_sim = float(np.mean(similarities))

    return {
        'similarities': similarities.tolist(),
        'mean_similarity': mean_sim,
        'mean_drop': 1.0 - mean_sim,
        'std_similarity': float(np.std(similarities)),
        'min_similarity': float(np.min(similarities)),
        'max_similarity': float(np.max(similarities))
    }

In [12]:
def search(index, meta_df, model, query: str, top_k: int = TOP_K_DEFAULT):
    """
    Run a text query against the FAISS index and return up to ``top_k`` results.

    inputs:
        index: FAISS index (IndexFlatIP with normalized embeddings)
        meta_df: DataFrame with metadata; row i (positional) describes FAISS id i
        model: MiniLM SentenceTransformer
        query: text query string
        top_k: maximum number of results to return

    output:
        results: list of dicts, best first, like:
            {
                "rank": int (1-based),
                "idx": int,
                "uid": str,
                "caption": str,
                "score": float,
            }

    FAISS pads its output with id -1 when fewer than ``top_k`` vectors are
    available (e.g. ``top_k > index.ntotal``). Those slots are dropped, so the
    list may be shorter than ``top_k``. Without this, ``meta_df.iloc[-1]``
    would silently return the last metadata row as a bogus hit.
    """
    q_emb_norm = encode_and_normalize_query(model, query)
    scores, indices = index.search(q_emb_norm, top_k)

    results = []
    for idx, score in zip(indices[0], scores[0]):
        if idx < 0:
            # FAISS "no result" padding; only ever appears after the real hits.
            continue
        row = meta_df.iloc[idx]
        results.append(
            {
                "rank": len(results) + 1,
                "idx": int(idx),
                "uid": row["uid"],
                "caption": row["caption"],
                "score": float(score),
            }
        )
    return results

In [13]:
def evaluate_random_subset(
    index,
    meta_df,
    model,
    num_samples: int = 100,
    top_k: int = 10,
    seed: int = 42,
):
    """
    Evaluate retrieval on a random subset of captions (self-retrieval).

    For each sampled row:
        - use its caption as the query
        - treat its UID as the "correct" asset
        - run search(top_k)
        - record its rank, R@1 / R@5 / R@10 hits and reciprocal rank

    Sampling uses ``np.random.default_rng(seed)`` without replacement, so a
    given seed always evaluates the same rows. ``num_samples`` is clamped to
    ``len(meta_df)``.

    Caveats:
        - Only the first ``top_k`` hits are inspected, so R@k for k > top_k
          is capped at R@top_k (e.g. with top_k=5, "R@10" equals "R@5").
        - If several rows share an identical caption, another uid may rank
          first; that counts as a miss here.

    returns:
        dict with averaged metrics "R@1", "R@5", "R@10", "MRR", plus
        "num_samples" and per-query "individual_scores".

    raises:
        ValueError: if there is nothing to sample (empty meta_df or
            num_samples < 1), since every metric would be undefined.
    """
    rng = np.random.default_rng(seed)
    n = len(meta_df)
    num_samples = min(num_samples, n)
    if num_samples < 1:
        raise ValueError(
            f"Nothing to evaluate: num_samples={num_samples}, metadata rows={n}"
        )

    sampled_indices = rng.choice(n, size=num_samples, replace=False)

    hits = {1: 0, 5: 0, 10: 0}
    mrr_sum = 0.0
    individual_scores = []

    for count, idx in enumerate(sampled_indices, start=1):
        row = meta_df.iloc[idx]
        true_uid = row["uid"]
        query_caption = row["caption"]

        results = search(index, meta_df, model, query_caption, top_k=top_k)

        # Rank of the first result carrying the correct UID, or None if absent.
        rank_of_true = next(
            (r["rank"] for r in results if r["uid"] == true_uid), None
        )
        reciprocal_rank = 1.0 / rank_of_true if rank_of_true is not None else 0.0
        in_top = {k: rank_of_true is not None and rank_of_true <= k for k in hits}

        individual_scores.append({
            'query': query_caption,
            'true_uid': true_uid,
            'rank': rank_of_true,
            'found_in_top1': in_top[1],
            'found_in_top5': in_top[5],
            'found_in_top10': in_top[10],
            'reciprocal_rank': reciprocal_rank
        })

        for k, found in in_top.items():
            hits[k] += found
        mrr_sum += reciprocal_rank

        if count % 10 == 0:
            print(f"[INFO] Processed {count}/{num_samples} samples...")

    num = float(num_samples)
    metrics = {
        "R@1": hits[1] / num,
        "R@5": hits[5] / num,
        "R@10": hits[10] / num,
        "MRR": mrr_sum / num,
        "num_samples": num_samples,
        "individual_scores": individual_scores,
    }
    return metrics

In [14]:
def compute_robustness_ratio(
    index,
    meta_df,
    model,
    original_queries: list,
    perturbed_queries: list,
    ground_truth_uids: list,
    k_values: tuple = (1, 5, 10)
) -> dict:
    """
    Compute the Robustness Ratio: RR@k = R@k(perturbed) / R@k(original).

    Each query is searched once for ``max(k_values)`` hits, and R@k is read off
    that single ranking for every k. This is the same approach as
    evaluate_random_subset and run_robustness_tests, instead of one search per
    k. Results match per-k searches, up to tie-breaking between equal scores.

    Args:
        index: FAISS index
        meta_df: Metadata DataFrame
        model: SentenceTransformer model
        original_queries: List of original query strings
        perturbed_queries: List of perturbed versions (same length)
        ground_truth_uids: List of correct UIDs for each query (same length)
        k_values: k values to report (default: (1, 5, 10)); any iterable of
            positive ints, e.g. a list, is accepted

    Returns:
        {
            'original': {'R@1': float, 'R@5': float, 'R@10': float},
            'perturbed': {'R@1': float, 'R@5': float, 'R@10': float},
            'robustness_ratios': {'RR@1': float, 'RR@5': float, 'RR@10': float},
            'num_queries': int
        }
        (keys follow ``k_values``). RR@k is 0.0 when R@k(original) is 0.

    Raises:
        ValueError: if the three input lists do not all have the same length.
    """
    n = len(original_queries)
    # Explicit all-equal check: the chained form `a != b != c` would accept
    # lengths like (3, 3, 2), and zip() would then silently truncate.
    if not (len(perturbed_queries) == n == len(ground_truth_uids)):
        raise ValueError("All input lists must have same length")

    k_values = list(k_values)
    results = {
        'original': {},
        'perturbed': {},
        'robustness_ratios': {},
        'num_queries': n
    }
    if not k_values:
        return results
    max_k = max(k_values)

    def ranks_of_truth(queries):
        """1-based rank of each query's ground-truth UID in its top max_k hits, or None."""
        ranks = []
        for query, true_uid in zip(queries, ground_truth_uids):
            hits = search(index, meta_df, model, query, top_k=max_k)
            ranks.append(next((r['rank'] for r in hits if r['uid'] == true_uid), None))
        return ranks

    def recall_at_k(ranks, k):
        """Fraction of queries whose rank is <= k (0.0 when there are no queries)."""
        if not ranks:
            return 0.0
        return sum(1 for rank in ranks if rank is not None and rank <= k) / len(ranks)

    orig_ranks = ranks_of_truth(original_queries)
    pert_ranks = ranks_of_truth(perturbed_queries)

    for k in k_values:
        r_orig = recall_at_k(orig_ranks, k)
        r_pert = recall_at_k(pert_ranks, k)

        results['original'][f'R@{k}'] = r_orig
        results['perturbed'][f'R@{k}'] = r_pert
        results['robustness_ratios'][f'RR@{k}'] = r_pert / r_orig if r_orig > 0 else 0.0

    return results


def run_robustness_tests(index, meta_df, model, tests_path, top_ks=(1, 5, 10)):
    """
    Run robustness tests over a set of query families and print a report.

    tests_path: UTF-8 JSON file holding a list of dicts, each like:
        {
            "name": str (optional, default "UNKNOWN_TEST"),
            "orig": str,
            "orig_type": str (optional, default "canonical"),
            "uid": str (optional, ground-truth asset uid),
            "variants": [
                {"query": str, "type": str},
                ...
            ]
        }

    Target asset: the test's "uid" when given, otherwise the top-1 hit of
    "orig". In the fallback case the original query succeeds at rank 1 by
    construction, so only the variants carry information.

    Every query is searched once for max(top_ks) hits and labelled with the
    smallest "R@k" it satisfies, or "miss". After all tests, δSim and the
    Robustness Ratio are printed over all (orig, variant) pairs, for every k
    in top_ks.

    Returns:
        list of per-query dicts with keys test_name, query_type,
        variant_label, variant_type, query, rank, success_level.
    """

    with open(tests_path, "r", encoding="utf-8") as f:
        tests = json.load(f)

    max_k = max(top_ks)
    top_ks_sorted = sorted(top_ks)

    def classify_rank(rank):
        """Smallest "R@k" label with rank <= k, or "miss"."""
        if rank is None:
            return "miss"
        for k in top_ks_sorted:
            if rank <= k:
                return f"R@{k}"
        return "miss"

    results = []

    def evaluate_query(query, query_type, variant_label, variant_type, target_uid, test_name):
        """Search `query`, append one row to `results`, return (rank, success_level)."""
        res = search(index, meta_df, model, query, top_k=max_k)
        rank = next((r["rank"] for r in res if r["uid"] == target_uid), None)
        success = classify_rank(rank)

        results.append(
            {
                "test_name": test_name,
                "query_type": query_type,        # "orig" or "variant"
                "variant_label": variant_label,  # "orig" or the variant query
                "variant_type": variant_type,    # e.g. "typo", "hypernym", etc.
                "query": query,
                "rank": rank,
                "success_level": success,
            }
        )

        return rank, success

    k_labels = " / ".join(f"R@{k}" for k in top_ks_sorted)
    print(f"[INFO] Running robustness tests ({k_labels})...")

    # Collect (orig, variant, target) triples for the aggregate metrics.
    all_original_queries = []
    all_variant_queries = []
    all_target_uids = []

    for test in tests:
        name = test.get("name", "UNKNOWN_TEST")
        orig_query = test["orig"]
        orig_type = test.get("orig_type", "canonical")
        variants = test.get("variants", [])

        print("\n" + "-" * 80 + "\n")
        print(f"[TEST] {name}")
        print(f"  Original query: {orig_query}")

        orig_results = search(index, meta_df, model, orig_query, top_k=max_k)
        if not orig_results:
            print("  [WARNING] Original query returned no results; skipping this test.")
            continue
        orig_top = orig_results[0]
        # Prefer an explicit ground truth; fall back to the original's top-1 hit.
        target_uid = test.get("uid", orig_top["uid"])
        print(
            "  [ORIG TOP-1] uid={}  score={:.4f}".format(
                orig_top["uid"], orig_top["score"]
            )
        )
        print("              caption: {}".format(orig_top["caption"]))
        if target_uid != orig_top["uid"]:
            print(f"  [TARGET]    uid={target_uid} (from test definition)")

        orig_rank, orig_level = evaluate_query(
            query=orig_query,
            query_type="orig",
            variant_label="orig",
            variant_type=orig_type,
            target_uid=target_uid,
            test_name=name,
        )
        print(f"\n  => Original: success={orig_level}  rank={orig_rank}  type={orig_type}")

        if not variants:
            continue

        # Table-like listing of variants.
        row_fmt = "    {succ:<7}  {rank:<4}  {vtype:<20}  {query}"
        print("\n  Variants:")
        print(row_fmt.format(succ="success", rank="rank", vtype="type", query="query"))
        print("    {:-<7}  {:-<4}  {:-<20}  {:-<40}".format("", "", "", ""))

        for v in variants:
            q = v["query"]
            v_type = v.get("type", "unknown")

            rank_v, level_v = evaluate_query(
                query=q,
                query_type="variant",
                variant_label=q,
                variant_type=v_type,
                target_uid=target_uid,
                test_name=name,
            )

            print(
                row_fmt.format(
                    succ=level_v,
                    rank="-" if rank_v is None else str(rank_v),
                    vtype=v_type[:20],
                    query=q,
                )
            )

            all_original_queries.append(orig_query)
            all_variant_queries.append(q)
            all_target_uids.append(target_uid)

    if all_variant_queries:
        print("\n" + "=" * 80)
        print("AGGREGATE METRICS (across all tests):")
        print("=" * 80)

        delta_sim = compute_cosine_similarity_drop(
            model,
            all_original_queries,
            all_variant_queries
        )

        # Use the same k values as the per-query labels, so any `top_ks` works.
        rr_metrics = compute_robustness_ratio(
            index, meta_df, model,
            all_original_queries,
            all_variant_queries,
            all_target_uids,
            k_values=top_ks_sorted
        )

        print("\nCosine Similarity Drop (δSim):")
        print(f"  Mean similarity: {delta_sim['mean_similarity']:.4f}")
        print(f"  Mean drop (δSim): {delta_sim['mean_drop']:.4f}")
        print(f"  Std deviation:   {delta_sim['std_similarity']:.4f}")
        print(f"  Range:           [{delta_sim['min_similarity']:.4f}, {delta_sim['max_similarity']:.4f}]")

        def print_row(label, value):
            print(f"  {label:<15}{value:.3f}")

        k_first = top_ks_sorted[0]
        print("\nRobustness Ratio (RR):")
        print_row(f"Original R@{k_first}:", rr_metrics['original'][f'R@{k_first}'])
        print_row(f"Perturbed R@{k_first}:", rr_metrics['perturbed'][f'R@{k_first}'])
        for k in top_ks_sorted:
            print_row(f"RR@{k}:", rr_metrics['robustness_ratios'][f'RR@{k}'])
        print(f"  Total queries: {rr_metrics['num_queries']}")

    return results

In [15]:
# Load the FAISS index, metadata DataFrame and query encoder once. The
# `index`, `meta_df` and `model` globals are reused by the search and
# evaluation cells below. Raises FileNotFoundError if a file is missing.
index, meta_df, model = load_resources()

'(ProtocolError('Connection aborted.', RemoteDisconnected('Remote end closed connection without response')), '(Request ID: 678f2a20-854f-44e6-9a77-ff1184b777a8)')' thrown while requesting HEAD https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/./modules.json
Retrying in 1s [Retry 1/5].


[INFO] Loading FAISS index from disk...
[INFO] Loading metadata from disk...
[INFO] Loaded metadata with 10000 rows.
[INFO] Loading MiniLM model (all-MiniLM-L6-v2) for query encoding...


In [16]:
# Quick sanity check: eyeball the top-5 hits for a couple of hand-written
# queries before running the quantitative evaluation.
test_queries = [
    "white sofa with wooden legs",
    "airplane",
]

for q in test_queries:
    print("\n" + "=" * 80)
    print(f"[QUERY] {q}")
    results = search(index, meta_df, model, q, top_k=5)
    for r in results:
        print(f"  {r['rank']}. uid={r['uid']}  score={r['score']:.4f}")
        print(f"     caption: {r['caption']}")


[QUERY] white sofa with wooden legs
  1. uid=53d0b31aa7f84bc4b1733224963d0114  score=0.9310
     caption: A white sofa with wooden legs and a wooden frame.
  2. uid=f761658fafcc42a78fd42912ef9f57e9  score=0.8172
     caption: A modern wooden sofa with white cushions and pillows, featuring a slatted backrest and solid legs.
  3. uid=69ed1dad032444cdb5bbb6fb883982e3  score=0.7854
     caption: A white sofa with curved features and an integrated wall light and shelf.
  4. uid=90be7242f24749c3a8e0b0a69c616fc1  score=0.7761
     caption: Brown leather Chesterfield sofa with metal legs.
  5. uid=d0b4812ad71b4d7f9eda54f48b65362a  score=0.7688
     caption: A white couch, bench, or sofa on a gray background.

[QUERY] airplane
  1. uid=ed60862e215a498a9a214ca51af1bf32  score=0.7869
     caption: A white airplane.
  2. uid=042f8ae68d594a0497d10b5f73e154d9  score=0.7335
     caption: A small blue airplane.
  3. uid=10891bcc6a924d33b0c38dda490f55ac  score=0.7335
     caption: A small blue airplan

In [17]:
# Evaluation on a random subset (seed defaults to 42 inside
# evaluate_random_subset, so this sample is reproducible).
# NOTE(review): each query is a caption taken from meta_df itself. Presumably
# the index holds embeddings of these same captions, so this measures
# self-retrieval (an upper bound), not generalisation — confirm.
print("\n" + "=" * 80)
print("[INFO] Starting evaluation on random subset of captions...")

metrics = evaluate_random_subset(
    index,
    meta_df,
    model,
    num_samples=100,
    top_k=10,
)

print("\nEvaluation Summary ({} samples):".format(metrics["num_samples"]))
print("  R@1  = {:.3f}".format(metrics["R@1"]))
print("  R@5  = {:.3f}".format(metrics["R@5"]))
print("  R@10 = {:.3f}".format(metrics["R@10"]))
print("  MRR  = {:.3f}".format(metrics["MRR"]))


[INFO] Starting evaluation on random subset of captions...
[INFO] Processed 10/100 samples...
[INFO] Processed 20/100 samples...
[INFO] Processed 30/100 samples...
[INFO] Processed 40/100 samples...
[INFO] Processed 50/100 samples...
[INFO] Processed 60/100 samples...
[INFO] Processed 70/100 samples...
[INFO] Processed 80/100 samples...
[INFO] Processed 90/100 samples...
[INFO] Processed 100/100 samples...

Evaluation Summary (100 samples):
  R@1  = 0.990
  R@5  = 1.000
  R@10 = 1.000
  MRR  = 0.995


In [19]:
# Small robustness check: run every test family in TESTS_JSON_PATH, print a
# per-query table plus aggregate metrics, and keep the per-query rows.
robustness_results = run_robustness_tests(index, meta_df, model, TESTS_JSON_PATH)

[INFO] Running robustness tests (R@1 / R@5 / R@10)...

--------------------------------------------------------------------------------

[TEST] penguin
  Original query: A stylized purple penguin with a yellow beak, black wings and feet with yellow tips, and a gradient orange-yellow belly.
  [ORIG TOP-1] uid=1e00453a031940b48a4ea1ea16bb1a88  score=1.0000
              caption: A stylized purple penguin with a yellow beak, black wings and feet with yellow tips, and a gradient orange-yellow belly.

  => Original: success=R@1  rank=1  type=canonical

  Variants:
    success  rank  type                  query
    -------  ----  --------------------  ----------------------------------------
    miss     -     typo                  purple pgnguin
    R@1      1     paraphrase            stylized penguin with yellow beak
    R@5      3     filler                a cute small cartoon penguin with with wings and black eyes
    miss     -     synonym               A purple bird with wings
    mis

## Part 2: Supervised contrastive models (Base vs. Augmented)

In [24]:
import os
import torch
import pandas as pd
import numpy as np
import faiss
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel, CLIPVisionModel, CLIPImageProcessor
from PIL import Image
from tqdm.notebook import tqdm

# --- Configuration ---
# Use the GPU when one is visible to PyTorch, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Test split CSV; the cells below read its 'caption' and 'image_path' columns.
TEST_DATA_PATH = "../data/processed/test_split.csv" 
# state_dict files for MultiModalContrastiveModel (loaded by load_trained_model).
MODEL_BASE_PATH = "../models/contrastive_base/model_state_dict.pth"
MODEL_AUG_PATH = "../models/contrastive_finetuned/model_state_dict.pth" # Augmented model
# Prefix joined onto image_path values that do not already start with '../'.
IMAGES_ROOT = "../" # Adjust based on where your notebook sits relative to data

# Constants matching training
MAX_TEXT_LEN = 128  # tokenizer truncation length for text queries
# Output size of the vision projection and the FAISS index dimension. Text
# embeddings are mean-pooled without a projection, so this must equal the text
# encoder's hidden size (384 for all-MiniLM-L6-v2).
PROJECTION_DIM = 384
VISION_MODEL = "openai/clip-vit-base-patch16"  # vision tower + image processor checkpoint
TEXT_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # text tower + tokenizer checkpoint

# --- Model Class (Must match training exactly) ---
# Submodule/parameter names (text_encoder, vision_encoder, vision_projection,
# logit_scale) must stay identical to the training-time class so the saved
# state_dict loads with strict key matching.
class MultiModalContrastiveModel(nn.Module):
    """Dual encoder: a HuggingFace text tower plus a CLIP vision tower.

    Text embeddings are attention-masked mean pools of the token states, with
    no projection head. Image embeddings are the CLIP pooled output, linearly
    projected to ``projection_dim``. Both are L2-normalised, so their dot
    product is a cosine similarity; presumably ``projection_dim`` must
    therefore equal the text encoder's hidden size.
    """

    def __init__(self, text_model_name, vision_model_name, projection_dim=384):
        super().__init__()
        self.text_encoder = AutoModel.from_pretrained(text_model_name)
        self.vision_encoder = CLIPVisionModel.from_pretrained(vision_model_name)
        vision_hidden_size = self.vision_encoder.config.hidden_size
        self.vision_projection = nn.Linear(vision_hidden_size, projection_dim)
        # 2.6592 ≈ ln(1 / 0.07), the logit-scale initialisation used by CLIP;
        # forward() returns its exp().
        self.logit_scale = nn.Parameter(torch.ones([]) * 2.6592)

    @staticmethod
    def _mean_pool(token_embeddings, attention_mask):
        """Average token embeddings over the positions where attention_mask is 1."""
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * mask, 1)
        # Clamp so an all-padding row yields zeros instead of dividing by zero.
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return summed / counts

    def encode_text(self, input_ids, attention_mask):
        """Return L2-normalised, mean-pooled text embeddings."""
        text_outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        text_embeds = self._mean_pool(text_outputs.last_hidden_state, attention_mask)
        return F.normalize(text_embeds, p=2, dim=1)

    def encode_image(self, pixel_values):
        """Return L2-normalised, projected CLIP pooled image embeddings."""
        vision_outputs = self.vision_encoder(pixel_values=pixel_values)
        image_embeds = self.vision_projection(vision_outputs.pooler_output)
        return F.normalize(image_embeds, p=2, dim=1)

    def forward(self, input_ids, attention_mask, pixel_values):
        """Return (text_embeds, image_embeds, logit_scale.exp())."""
        return (
            self.encode_text(input_ids, attention_mask),
            self.encode_image(pixel_values),
            self.logit_scale.exp(),
        )

print("[INFO] Model class defined.")

[INFO] Model class defined.


In [25]:
class EvalDataset(Dataset):
    """Yield CLIP-preprocessed pixel tensors for each row of ``df``, in row order.

    Images that cannot be opened are replaced by a black 224x224 placeholder,
    so one bad file does not abort evaluation. Each such path is recorded in
    ``self.missing_paths``, and a warning is printed for the first one.
    """

    def __init__(self, df, processor):
        self.df = df
        self.processor = processor
        # Paths that failed to load and were replaced by a placeholder. Only
        # filled in this process; DataLoader worker processes get their own copy.
        self.missing_paths = []

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # Paths already relative to the notebook ('../...') are used as-is;
        # everything else is resolved against IMAGES_ROOT.
        img_path = row['image_path']
        if not img_path.startswith('../'):
            img_path = os.path.join(IMAGES_ROOT, img_path)

        try:
            # `with` closes the file handle; convert() returns a loaded copy.
            with Image.open(img_path) as img:
                image = img.convert("RGB")
        except OSError as exc:
            # Missing/unreadable/corrupt file (PIL's UnidentifiedImageError is an
            # OSError). The placeholder will likely be a retrieval miss, so make
            # it visible instead of silently skewing the metrics.
            if not self.missing_paths:
                print(
                    f"[WARNING] Could not load image {img_path!r} ({exc}); "
                    "using a black placeholder. All failures are listed in "
                    "EvalDataset.missing_paths."
                )
            self.missing_paths.append(img_path)
            image = Image.new('RGB', (224, 224), color='black')

        return self.processor(images=image, return_tensors="pt").pixel_values.squeeze(0)

def build_image_index(model, df, batch_size=64):
    """
    Embed every image referenced by ``df`` with ``model.encode_image`` and
    return an exact inner-product FAISS index (IndexFlatIP) over them.

    FAISS id i corresponds to ``df.iloc[i]``: the DataLoader is not shuffled,
    so embeddings are added in row order. ``encode_image`` L2-normalises its
    output, so inner product equals cosine similarity.

    The index dimension is taken from the embeddings actually produced.
    PROJECTION_DIM is only used to create an (empty) index when ``df`` has no
    rows, a case in which concatenating zero batches would otherwise crash.
    """
    model.eval()
    processor = CLIPImageProcessor.from_pretrained(VISION_MODEL)
    dataset = EvalDataset(df, processor)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    all_embeddings = []

    print(f"[INFO] Generating embeddings for {len(df)} test images...")
    with torch.no_grad():
        for batch_pixels in tqdm(loader):
            img_embs = model.encode_image(batch_pixels.to(DEVICE))
            all_embeddings.append(img_embs.cpu().numpy())

    # Images replaced by placeholders will almost never be retrieved; surface it.
    missing = getattr(dataset, "missing_paths", None)
    if missing:
        print(f"[WARNING] {len(missing)} image(s) could not be loaded and were replaced by placeholders.")

    if all_embeddings:
        embeddings = np.concatenate(all_embeddings, axis=0).astype('float32')
        dim = embeddings.shape[1]
    else:
        embeddings = np.empty((0, PROJECTION_DIM), dtype='float32')
        dim = PROJECTION_DIM

    index = faiss.IndexFlatIP(dim)
    if len(embeddings):
        index.add(embeddings)
    print(f"[INFO] Index built with {index.ntotal} vectors.")
    return index

def load_trained_model(path):
    """
    Build a fresh MultiModalContrastiveModel and load the weights at ``path``.

    The checkpoint is a plain state_dict, deserialised with
    ``map_location=DEVICE``. The model is moved to DEVICE and put in eval
    mode before being returned.

    NOTE: torch.load unpickles the file; only load checkpoints you trust.
    """
    print(f"[INFO] Loading weights from {path}...")
    net = MultiModalContrastiveModel(TEXT_MODEL, VISION_MODEL, PROJECTION_DIM)
    net.load_state_dict(torch.load(path, map_location=DEVICE))
    # Module.to() and Module.eval() both return the module itself, so they chain.
    return net.to(DEVICE).eval()

In [26]:
def _encode_queries(model, tokenizer, texts, batch_size=64):
    """Encode ``texts`` with ``model.encode_text`` in batches -> float32 array (len(texts), D)."""
    chunks = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch_text = texts[start:start + batch_size]
            inputs = tokenizer(
                batch_text, padding=True, truncation=True,
                max_length=MAX_TEXT_LEN, return_tensors="pt",
            ).to(DEVICE)
            emb = model.encode_text(inputs['input_ids'], inputs['attention_mask'])
            chunks.append(emb.cpu().numpy())
    return np.concatenate(chunks, axis=0).astype('float32')


def _rank_of(retrieved, target):
    """1-based position of ``target`` in a row of FAISS ids, or None if absent."""
    positions = np.flatnonzero(retrieved == target)
    return int(positions[0]) + 1 if positions.size else None


def evaluate_supervised_model(model, index, test_df, robustness_tests_path):
    """
    Evaluate one trained model against an image index built from ``test_df``.

    1. General metrics: each test caption is a query whose correct answer is
       the image in the same row. This assumes FAISS id i == test_df row i, as
       produced by build_image_index on the same DataFrame.
    2. Robustness: for every test in the JSON whose "orig" text exactly
       matches a caption in ``test_df``, the original and each variant are
       ranked against that row's image. Other tests are skipped.

    Caveat: if a caption is duplicated in ``test_df``, only its first row is
    used as the target.

    Returns:
        general_metrics: {"R@1", "R@5", "R@10", "MRR"} averaged over test_df
        rob_results: list of dicts with test_name, query_type, variant_type,
            query, rank (int, or None when not in the top 10)

    Raises:
        ValueError: if test_df is empty (all metrics would be undefined).
    """
    n = len(test_df)
    if n == 0:
        raise ValueError("test_df is empty; nothing to evaluate")

    tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL)
    k = 10

    # --- 1. General Test Set Evaluation ---
    print("[INFO] Encoding test set captions...")
    query_embs = _encode_queries(model, tokenizer, test_df['caption'].tolist())
    _, I = index.search(query_embs, k)

    r1 = r5 = r10 = 0
    mrr = 0.0
    for i in range(n):
        rank = _rank_of(I[i], i)
        if rank is None:
            continue
        mrr += 1.0 / rank
        r1 += rank == 1
        r5 += rank <= 5
        r10 += rank <= 10

    general_metrics = {
        "R@1": r1 / n, "R@5": r5 / n, "R@10": r10 / n, "MRR": mrr / n
    }

    # --- 2. Robustness Tests (Specific JSON) ---
    print("[INFO] Running Robustness Tests...")
    # Explicit UTF-8: the platform default (e.g. cp1252 on Windows) can
    # mis-decode non-ASCII characters in the test file.
    with open(robustness_tests_path, 'r', encoding='utf-8') as f:
        tests = json.load(f)

    def run_query(q_str):
        """Top-k FAISS ids for a single query string."""
        _, I_q = index.search(_encode_queries(model, tokenizer, [q_str]), k)
        return I_q[0]

    captions = test_df['caption'].to_numpy()
    rob_results = []

    for test in tests:
        name = test.get('name', 'UNKNOWN_TEST')
        # The tests carry no uid, so the target row is found by exact caption
        # match. Positional (not label-based), so it also works when
        # test_df.index is not a clean RangeIndex.
        matches = np.flatnonzero(captions == test['orig'])
        if matches.size == 0:
            continue  # this test's asset is not in this test split
        target_iloc = int(matches[0])

        rob_results.append({
            "test_name": name, "query_type": "orig",
            "variant_type": test.get('orig_type', 'canonical'),
            "query": test['orig'], "rank": _rank_of(run_query(test['orig']), target_iloc)
        })

        for v in test.get('variants', []):
            rob_results.append({
                "test_name": name, "query_type": "variant",
                "variant_type": v.get('type', 'unknown'),
                "query": v['query'], "rank": _rank_of(run_query(v['query']), target_iloc)
            })

    return general_metrics, rob_results

In [27]:
# Load Test Data
# Row order matters: build_image_index adds images to FAISS in this order and
# evaluate_supervised_model treats FAISS id i as the image for row i.
test_df = pd.read_csv(TEST_DATA_PATH)
# Ensure clean index for mapping
test_df = test_df.reset_index(drop=True)

# model name -> {'metrics': R@k/MRR dict, 'robustness': per-query result rows}
results_store = {}

# NOTE(review): TESTS_JSON_PATH (and the `json` import used by
# evaluate_supervised_model) come from Part 1, so this cell only works in a
# kernel where Part 1's import/config cells have already run.

# --- 1. Evaluate BASE Model ---
print("\n" + "="*40)
print("EVALUATING: Contrastive BASE (10k)")
print("="*40)
model_base = load_trained_model(MODEL_BASE_PATH)
index_base = build_image_index(model_base, test_df)

metrics_base, rob_base = evaluate_supervised_model(
    model_base, index_base, test_df, TESTS_JSON_PATH
)
results_store['Base'] = {'metrics': metrics_base, 'robustness': rob_base}

# Cleanup to save VRAM: drop the first model before loading the second, then
# ask PyTorch's CUDA allocator to release its cached blocks.
del model_base
del index_base
torch.cuda.empty_cache()

# --- 2. Evaluate AUGMENTED Model ---
print("\n" + "="*40)
print("EVALUATING: Contrastive AUGMENTED (Paraphrased)")
print("="*40)
model_aug = load_trained_model(MODEL_AUG_PATH)
index_aug = build_image_index(model_aug, test_df)

metrics_aug, rob_aug = evaluate_supervised_model(
    model_aug, index_aug, test_df, TESTS_JSON_PATH
)
results_store['Augmented'] = {'metrics': metrics_aug, 'robustness': rob_aug}

# Print General Comparison (rows = metric names, columns = models)
print("\n--- Final Metric Comparison ---")
comparison_df = pd.DataFrame({
    'Base': results_store['Base']['metrics'],
    'Augmented': results_store['Augmented']['metrics']
})
print(comparison_df)


EVALUATING: Contrastive BASE (10k)
[INFO] Loading weights from ../models/contrastive_base/model_state_dict.pth...


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Publication-style theme for the comparison figures below.
sns.set_theme(style="whitegrid", context="paper", font_scale=1.4)
# NOTE(review): `palette` is not used; plot_model_comparison passes its own
# palettes ("viridis", "magma"). Remove it or wire it in.
palette = sns.color_palette("deep")

def plot_model_comparison(results_store):
    """
    Draw the model-comparison figures for the entries in ``results_store``.

    Args:
        results_store: mapping model name -> {
            "metrics": dict with at least "R@1", "R@5", "R@10",
            "robustness": list of per-query dicts with "query_type",
                "variant_type" and "rank" (None when not in the top 10),
        }

    Figure 4: grouped bars of R@1 / R@5 / R@10 per model.
    Figure 5: for each perturbation type, the fraction of variant queries
        whose target was found within the top 10, per model.

    Models without robustness rows are skipped in Figure 5, and Figure 5 is
    skipped entirely when no model has variant rows. That happens e.g. when
    none of the tests' original captions occur in the test split.
    """
    if not results_store:
        print("[INFO] results_store is empty; nothing to plot.")
        return

    # --- FIGURE 4: Side-by-Side Recall Comparison ---
    df_metrics = pd.DataFrame([
        {'Model': model_name, 'Metric': key, 'Score': data['metrics'][key]}
        for model_name, data in results_store.items()
        for key in ('R@1', 'R@5', 'R@10')
    ])

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.barplot(data=df_metrics, x='Metric', y='Score', hue='Model', palette="viridis", ax=ax)
    ax.set_title("Supervised Fine-Tuning: Base vs. Augmented Data", fontweight='bold', pad=15)
    ax.set_ylim(0, 1.05)
    ax.set_ylabel("Recall Score")
    ax.legend(title="Training Strategy")
    for container in ax.containers:
        ax.bar_label(container, fmt='%.2f', padding=3, fontsize=10)
    fig.tight_layout()
    plt.show()

    # --- FIGURE 5: Robustness Drop-off Comparison ---
    rob_frames = []
    for model_name, data in results_store.items():
        df_r = pd.DataFrame(data.get('robustness', []))
        # pd.DataFrame([]) has no columns, so guard before selecting 'query_type'.
        if df_r.empty or 'query_type' not in df_r.columns:
            print(f"[WARNING] No robustness results for {model_name}; skipping it in Figure 5.")
            continue

        # Ranks are ints or None; to_numeric maps None -> NaN, and NaN <= 10 is False.
        df_var = df_r[df_r['query_type'] == 'variant'].assign(
            success=lambda d: pd.to_numeric(d['rank'], errors='coerce').le(10).astype(int)
        )
        if df_var.empty:
            continue

        agg = df_var.groupby('variant_type', as_index=False)['success'].mean()
        agg['Model'] = model_name
        rob_frames.append(agg)

    if not rob_frames:
        print("[INFO] No variant robustness results to plot; skipping Figure 5.")
        return

    df_rob_comp = pd.concat(rob_frames, ignore_index=True)

    fig2, ax2 = plt.subplots(figsize=(10, 6))
    sns.barplot(
        data=df_rob_comp,
        x='variant_type',
        y='success',
        hue='Model',
        palette="magma",
        ax=ax2,
    )
    ax2.set_title("Robustness Analysis: Resilience to Query Perturbation", fontweight='bold', pad=15)
    ax2.set_ylabel("Recall@10 on Perturbed Queries")
    ax2.set_xlabel("Perturbation Type")
    ax2.set_ylim(0, 1.1)
    fig2.tight_layout()
    plt.show()

# Render Figures 4 and 5 from the results collected in the evaluation cell.
# NOTE(review): that cell's saved output ends in KeyboardInterrupt, so in the
# saved kernel state results_store may be missing models — re-run it first.
plot_model_comparison(results_store)