In [8]:
# Cell 1: Imports and Configuration
import os
import json
import numpy as np
import pandas as pd
import faiss
import matplotlib.pyplot as plt
import seaborn as sns
from sentence_transformers import SentenceTransformer

# Set plot style for better visuals
sns.set_theme(style="whitegrid")

# --- Configuration ---
# All paths below are relative, so they resolve against the kernel's current
# working directory (presumably the project root -- confirm before running).
FAISS_INDEX_PATH = "data/index/text_index.faiss"  # read with faiss.read_index in load_resources()
METADATA_PATH = "data/index/text_index_metadata.csv"  # read with pd.read_csv in load_resources()
MODEL_NAME = "all-MiniLM-L6-v2"  # passed to SentenceTransformer(...) in load_resources()
TESTS_JSON_PATH = "tests/baseline_robustness_tests.json"  # scenarios consumed by run_robustness_test()
TOP_K_DEFAULT = 10  # NOTE(review): not referenced by any cell visible here; search() declares its own default of 10

# Helper to check paths
def check_paths(paths=None):
    """Print an [OK] / [WARNING] line for each expected input file.

    Args:
        paths: Optional iterable of file paths to check. When omitted, the
            three configured inputs (FAISS_INDEX_PATH, METADATA_PATH,
            TESTS_JSON_PATH) are checked, which is the original behaviour.

    Returns:
        None. This is a report only; a missing file is printed, not raised.
    """
    if paths is None:
        paths = [FAISS_INDEX_PATH, METADATA_PATH, TESTS_JSON_PATH]
    for p in paths:
        if os.path.exists(p):
            print(f"[OK] Found: {p}")
        else:
            print(f"[WARNING] File not found: {p}")

# Report up front which input files exist, before any later cell tries to load them.
check_paths()



In [6]:
# Cell 2: Load Resources
def load_resources():
    """Load the FAISS index, the metadata DataFrame, and the sentence encoder.

    Reads FAISS_INDEX_PATH and METADATA_PATH and instantiates the
    SentenceTransformer named by MODEL_NAME.

    Returns:
        tuple: (index, meta_df, model).

    Raises:
        FileNotFoundError: If the index or the metadata file does not exist.
            The message lists every missing path and the current working
            directory, because both configured paths are relative.
    """
    missing = [p for p in (FAISS_INDEX_PATH, METADATA_PATH) if not os.path.exists(p)]
    if missing:
        raise FileNotFoundError(
            "Index or Metadata file missing. "
            f"Not found: {', '.join(missing)} (cwd: {os.getcwd()})"
        )

    print("[INFO] Loading FAISS index...")
    index = faiss.read_index(FAISS_INDEX_PATH)

    print("[INFO] Loading metadata...")
    meta_df = pd.read_csv(METADATA_PATH)

    # search() looks metadata rows up by FAISS result position
    # (meta_df.iloc[idx]), so differing sizes mean hits may be dropped or
    # matched to the wrong row. Warn instead of raising to stay best-effort.
    if index.ntotal != len(meta_df):
        print(
            f"[WARNING] Index holds {index.ntotal} vectors but metadata has "
            f"{len(meta_df)} rows; search results may be misaligned."
        )

    print(f"[INFO] Loading Model ({MODEL_NAME})...")
    model = SentenceTransformer(MODEL_NAME)

    return index, meta_df, model

# Load resources globally for the notebook: later cells read `index`,
# `meta_df`, and `model` as module-level names, so if this cell raises they
# stay undefined and those cells cannot run.
index, meta_df, model = load_resources()
print("Resources loaded successfully.")

FileNotFoundError: Index or Metadata file missing.

In [None]:
# Cell 3: Search and Encoding Functions
def encode_and_normalize(model, query: str):
    """Embed one query string and scale each embedding row to unit L2 length.

    The query is passed to ``model.encode`` as a one-element list. A tiny
    epsilon is added to the row norm so an all-zero embedding does not cause
    a division by zero. The result is returned as float32.
    """
    embedding = model.encode([query], show_progress_bar=False)
    row_norm = np.linalg.norm(embedding, axis=1, keepdims=True)
    unit_embedding = embedding / (row_norm + 1e-12)
    return unit_embedding.astype("float32")

def search(index, meta_df, model, query: str, top_k=10):
    """Query the FAISS index and return the hits joined with their metadata.

    Args:
        index: Object exposing ``search(embeddings, k)`` that returns
            ``(scores, indices)``; only the first row of each is used.
        meta_df: DataFrame with "uid" and "caption" columns. Result ids are
            treated as positional row numbers into it.
        model: Encoder forwarded to encode_and_normalize().
        query: Text to search for.
        top_k: Number of neighbours requested from the index.

    Returns:
        list[dict]: One dict per usable hit with keys "rank" (1-based
        position in the raw index output), "uid", "caption", and "score"
        (float). Hits whose id does not map to a metadata row are omitted, so
        the list may be shorter than ``top_k`` and ranks may have gaps.
    """
    q_emb = encode_and_normalize(model, query)
    scores, indices = index.search(q_emb, top_k)

    n_rows = len(meta_df)
    results = []
    for rank, (idx, score) in enumerate(zip(indices[0], scores[0]), start=1):
        # FAISS pads the output with id -1 when it finds fewer than top_k
        # neighbours. Checking only the upper bound lets -1 through, and
        # `.iloc[-1]` would then return the LAST metadata row as a bogus hit.
        if not 0 <= idx < n_rows:
            continue
        row = meta_df.iloc[idx]
        results.append({
            "rank": rank,
            "uid": row["uid"],
            "caption": row["caption"],
            "score": float(score),
        })
    return results

def get_cosine_sim(model, q1, q2):
    """Return the cosine similarity of two strings under `model`.

    Both strings go through encode_and_normalize(), so the dot product of
    the two resulting vectors is their cosine similarity.
    """
    first_vec, second_vec = (
        encode_and_normalize(model, text)[0] for text in (q1, q2)
    )
    return float(np.dot(first_vec, second_vec))

In [None]:
# Cell 4: Manual Sanity Check
# Two hand-picked probe queries: a multi-attribute description and a single word.
queries = ["white sofa with wooden legs", "airplane"]

# Print the top-3 hits per query so the retrieval output can be inspected by
# eye. Relies on `index`, `meta_df`, and `model` set by the load cell.
for q in queries:
    print(f"\n[Query]: {q}")
    res = search(index, meta_df, model, q, top_k=3)
    for r in res:
        print(f"  {r['rank']}. [{r['uid']}] {r['caption']} (Score: {r['score']:.4f})")

In [None]:
# Cell 5: Robustness Testing Functions
def run_robustness_test(tests_path, index, meta_df, model):
    """Measure how stable retrieval is when queries are perturbed.

    For each scenario in the JSON file, the top-1 hit of the original query
    is taken as the target. Every variant query is then searched (top 10) and
    scored on whether, and at which rank, it retrieves that same target.

    Args:
        tests_path: Path to a JSON file holding a list of scenarios. Each
            scenario has "name", "orig", and optionally "variants": a list of
            objects with "query" and optionally "type".
        index: Forwarded to search().
        meta_df: Forwarded to search().
        model: Forwarded to search() and get_cosine_sim().

    Returns:
        pandas.DataFrame: One row per variant with columns test_name, type,
        orig_query, pert_query, rank (11 when the target is not in the top
        10), hit_at_10, hit_at_1, and sim_drop (1 - cosine similarity between
        the original and perturbed query, floored at 0). The columns are
        present even when no rows were produced.
    """
    columns = [
        "test_name", "type", "orig_query", "pert_query",
        "rank", "hit_at_10", "hit_at_1", "sim_drop",
    ]
    top_k = 10
    miss_rank = top_k + 1  # sentinel rank recorded when the target is not retrieved

    # Read as UTF-8 explicitly: the platform default encoding can mis-decode
    # non-ASCII characters in perturbed queries.
    with open(tests_path, "r", encoding="utf-8") as f:
        tests = json.load(f)

    results = []

    print(f"Running {len(tests)} test scenarios...")

    for test in tests:
        orig_q = test["orig"]

        # 1. Establish ground truth from the original query (top-1).
        orig_res = search(index, meta_df, model, orig_q, top_k=1)
        if not orig_res:
            # Without a target there is nothing to compare variants against.
            print(f"[WARNING] No result for original query {orig_q!r}; skipping scenario.")
            continue
        target_uid = orig_res[0]["uid"]

        # 2. Evaluate variants.
        for v in test.get("variants", []):
            pert_q = v["query"]
            v_type = v.get("type", "unknown")

            res = search(index, meta_df, model, pert_q, top_k=top_k)

            # Rank at which the original's top-1 item reappears, if it does.
            rank = next((r["rank"] for r in res if r["uid"] == target_uid), None)

            # How far the perturbed query moved in embedding space.
            sim = get_cosine_sim(model, orig_q, pert_q)

            results.append({
                "test_name": test["name"],
                "type": v_type,
                "orig_query": orig_q,
                "pert_query": pert_q,
                "rank": rank if rank is not None else miss_rank,
                "hit_at_10": int(rank is not None),
                "hit_at_1": int(rank == 1),
                "sim_drop": max(0.0, 1.0 - sim),
            })

    # Fixing the columns keeps the schema stable when `results` is empty.
    return pd.DataFrame(results, columns=columns)

# Run the test; `df_results` is reused by the plotting and saving cells below.
df_results = run_robustness_test(TESTS_JSON_PATH, index, meta_df, model)
# Bare last expression so the notebook renders the first rows as a table.
df_results.head()

In [None]:
# Cell 6: Visualization of Robustness
if not df_results.empty:
    # Success rate (mean hit_at_10) per perturbation type. Computed first so
    # both panels can share its category order; otherwise the boxplot follows
    # order of first appearance in df_results while the bar chart follows
    # groupby's sorted order, and the two x-axes may not line up.
    success_rates = df_results.groupby("type")["hit_at_10"].mean().reset_index()
    type_order = success_rates["type"].tolist()

    fig, axes = plt.subplots(1, 2, figsize=(15, 6))

    # NOTE(review): seaborn >= 0.13 emits a FutureWarning when `palette` is
    # given without `hue`. Left unchanged here because the replacement
    # (`hue="type", legend=False`) depends on the installed seaborn version.

    # Plot 1: Cosine Similarity Drop by Perturbation Type
    sns.boxplot(data=df_results, x="type", y="sim_drop", order=type_order, ax=axes[0], palette="Blues")
    axes[0].set_title("Cosine Similarity Drop by Perturbation Type")
    axes[0].set_ylabel("Similarity Drop (Lower is Better)")
    axes[0].set_xlabel("Perturbation Type")

    # Plot 2: Retrieval Success Rate (R@10)
    sns.barplot(data=success_rates, x="type", y="hit_at_10", order=type_order, ax=axes[1], palette="Greens")
    axes[1].set_title("Retrieval Success Rate (R@10) by Perturbation Type")
    axes[1].set_ylabel("Success Rate (Higher is Better)")
    axes[1].set_xlabel("Perturbation Type")
    axes[1].set_ylim(0, 1.0)

    plt.tight_layout()
    plt.show()

    # --- Print Summary Metrics ---
    print("\n=== Robustness Summary ===")
    print(f"Mean Similarity Drop: {df_results['sim_drop'].mean():.4f}")
    print(f"Overall R@10 Consistency: {df_results['hit_at_10'].mean():.2%}")

    # Calculate Robustness Ratio (RR)
    # RR = Perturbed Accuracy / Original Accuracy. The original accuracy is
    # taken as 100% because the ground truth is derived from the original
    # query itself, so RR reduces to the same mean hit_at_10 printed above.
    print(f"Robustness Ratio (Approx): {df_results['hit_at_10'].mean():.3f}")
else:
    print("No results to plot.")

In [None]:
# Cell 7: Save Baseline Results
RESULTS_CSV_PATH = "data/baseline_robustness_results.csv"
# to_csv does not create missing parent directories, so make sure the target
# folder exists; otherwise a fresh checkout fails at the very last cell.
os.makedirs(os.path.dirname(RESULTS_CSV_PATH), exist_ok=True)
df_results.to_csv(RESULTS_CSV_PATH, index=False)
print("Baseline results saved for future comparison.")