In [1]:
# ============================================================
# CELL 1: Load Everything for Evaluation
# ============================================================
import numpy as np
import pandas as pd
import faiss
import time

# Paths
EMBED_DIR = '../data/embeddings'
DATA_DIR = '../data/processed'

# Load embeddings
ft_embeddings = np.load(f'{EMBED_DIR}/finetuned_embeddings.npy').astype(np.float32)
ft_ids = np.load(f'{EMBED_DIR}/finetuned_ids.npy')
bl_embeddings = np.load(f'{EMBED_DIR}/baseline_embeddings.npy').astype(np.float32)
bl_ids = np.load(f'{EMBED_DIR}/baseline_ids.npy')

# Normalize
faiss.normalize_L2(ft_embeddings)
faiss.normalize_L2(bl_embeddings)

# Load metadata
df = pd.read_csv(f'{DATA_DIR}/filtered_styles.csv')
test_df = pd.read_csv(f'{DATA_DIR}/test.csv')

# Build FAISS indexes for both
ft_index = faiss.IndexFlatL2(ft_embeddings.shape[1])
ft_index.add(ft_embeddings)

bl_index = faiss.IndexFlatL2(bl_embeddings.shape[1])
bl_index.add(bl_embeddings)

# Create lookup: image_id → article type
id_to_type = dict(zip(df['id'], df['articleType']))

print(f"Fine-tuned embeddings: {ft_embeddings.shape}")
print(f"Baseline embeddings:   {bl_embeddings.shape}")
print(f"Test set queries:      {len(test_df)}")
print(f"FAISS indexes built:   ✅")


Fine-tuned embeddings: (43916, 2048)
Baseline embeddings:   (43916, 2048)
Test set queries:      6588
FAISS indexes built:   ✅


In [2]:
# ============================================================
# CELL 2: Define Evaluation Metrics
# ============================================================

def evaluate_search(index, embeddings, ids, query_ids, id_to_type, k=10):
    """
    Run evaluation on a set of queries.
    
    Parameters:
        index:      FAISS index to search
        embeddings: all embeddings (to find query vector by position)
        ids:        all image IDs (to convert position → image ID)
        query_ids:  list of image IDs to use as queries
        id_to_type: dict mapping image_id → articleType
        k:          how many results to retrieve
    
    Returns:
        dictionary with all metrics
    """
    
    # Storage for all metrics
    precisions = {1: [], 5: [], 10: []}   # P@1, P@5, P@10
    recalls = {1: [], 5: [], 10: []}       # R@1, R@5, R@10
    reciprocal_ranks = []                   # for MRR
    average_precisions = []                 # for mAP
    
    # We need to know: for each query, how many TOTAL correct items exist?
    # = how many images share the same articleType in the ENTIRE index
    
    for query_id in query_ids:
        
        # --- STEP 1: Find query's position and type ---
        query_positions = np.where(ids == query_id)[0]
        if len(query_positions) == 0:
            continue    # skip if query not found in embeddings
        
        query_pos = query_positions[0]
        query_type = id_to_type.get(query_id)
        if query_type is None:
            continue    # skip if no metadata
        
        # --- STEP 2: Search FAISS ---
        query_vector = embeddings[query_pos].reshape(1, -1)
        distances, indices = index.search(query_vector, k + 1)  # k+1 because first result is the query itself
        
        # Remove the query itself from results (first result, distance=0)
        result_positions = indices[0][1:]   # skip first result
        result_ids = [ids[pos] for pos in result_positions]
        
        # --- STEP 3: Check each result — correct or wrong? ---
        correct = []    # list of True/False for each result
        for rid in result_ids:
            result_type = id_to_type.get(rid)
            correct.append(result_type == query_type)
        
        # correct = [True, False, True, True, False, True, False, ...]
        #            rank1  rank2  rank3  rank4  rank5  rank6  rank7
        
        # --- STEP 4: Precision@K ---
        for k_val in [1, 5, 10]:
            if k_val <= len(correct):
                num_correct = sum(correct[:k_val])     # count True in top k_val
                precisions[k_val].append(num_correct / k_val)
        
        # --- STEP 5: Recall@K ---
        # Total number of same-type images in dataset (minus the query itself)
        total_relevant = sum(1 for iid in ids if id_to_type.get(iid) == query_type) - 1
        
        for k_val in [1, 5, 10]:
            if k_val <= len(correct):
                num_correct = sum(correct[:k_val])
                recalls[k_val].append(num_correct / total_relevant)
        
        # --- STEP 6: Reciprocal Rank (for MRR) ---
        rr = 0.0
        for rank, is_correct in enumerate(correct):
            if is_correct:
                rr = 1.0 / (rank + 1)    # rank is 0-indexed, so +1
                break
        reciprocal_ranks.append(rr)
        
        # --- STEP 7: Average Precision (for mAP) ---
        ap = 0.0
        num_correct_so_far = 0
        for rank, is_correct in enumerate(correct):
            if is_correct:
                num_correct_so_far += 1
                precision_at_rank = num_correct_so_far / (rank + 1)
                ap += precision_at_rank
        
        # Divide by total correct found (not total relevant)
        num_correct_total = sum(correct)
        if num_correct_total > 0:
            ap = ap / num_correct_total
        average_precisions.append(ap)
    
    # --- STEP 8: Average all metrics ---
    results = {
        'P@1':  np.mean(precisions[1]),
        'P@5':  np.mean(precisions[5]),
        'P@10': np.mean(precisions[10]),
        'R@1':  np.mean(recalls[1]),
        'R@5':  np.mean(recalls[5]),
        'R@10': np.mean(recalls[10]),
        'MRR':  np.mean(reciprocal_ranks),
        'mAP@10': np.mean(average_precisions),
    }
    
    return results

print("Evaluation function defined ✅")


Evaluation function defined ✅


In [3]:
# ============================================================
# CELL 3: Run Evaluation — Baseline vs Fine-tuned
# ============================================================
import time

# Get test set image IDs
test_ids = test_df['id'].values
print(f"Evaluating on {len(test_ids)} test queries...\n")

# --- Evaluate Fine-tuned Model ---
print("Evaluating fine-tuned model...")
start = time.time()
ft_results = evaluate_search(ft_index, ft_embeddings, ft_ids, test_ids, id_to_type, k=10)
ft_time = time.time() - start
print(f"  Done in {ft_time:.1f}s\n")

# --- Evaluate Baseline Model ---
print("Evaluating baseline model...")
start = time.time()
bl_results = evaluate_search(bl_index, bl_embeddings, bl_ids, test_ids, id_to_type, k=10)
bl_time = time.time() - start
print(f"  Done in {bl_time:.1f}s\n")

# --- Display Results ---
print("=" * 65)
print("EVALUATION RESULTS: Baseline vs Fine-tuned")
print(f"Test set: {len(test_ids)} queries")
print("=" * 65)
print(f"{'Metric':<12} {'Baseline':>12} {'Fine-tuned':>12} {'Change':>12}")
print("-" * 65)

for metric in ['P@1', 'P@5', 'P@10', 'R@1', 'R@5', 'R@10', 'MRR', 'mAP@10']:
    bl_val = bl_results[metric]
    ft_val = ft_results[metric]
    diff = ft_val - bl_val
    arrow = "▲" if diff > 0.001 else ("▼" if diff < -0.001 else "─")
    print(f"{metric:<12} {bl_val:>11.2%} {ft_val:>11.2%} {arrow:>4} {diff:>+7.2%}")

print("=" * 65)


Evaluating on 6588 test queries...

Evaluating fine-tuned model...
  Done in 186.6s

Evaluating baseline model...
  Done in 197.8s

EVALUATION RESULTS: Baseline vs Fine-tuned
Test set: 6588 queries
Metric           Baseline   Fine-tuned       Change
-----------------------------------------------------------------
P@1               83.26%      86.23%    ▲  +2.98%
P@5               78.04%      82.61%    ▲  +4.57%
P@10              75.46%      81.12%    ▲  +5.66%
R@1                0.15%       0.16%    ─  +0.00%
R@5                0.64%       0.72%    ─  +0.08%
R@10               1.18%       1.39%    ▲  +0.21%
MRR               88.28%      90.22%    ▲  +1.94%
mAP@10            83.60%      86.79%    ▲  +3.19%


In [4]:
# ============================================================
# CELL 4: Per-Category Evaluation
# ============================================================

# Get top 20 categories by count in test set
top_categories = test_df['articleType'].value_counts().head(20).index.tolist()

print("=" * 80)
print("PER-CATEGORY RESULTS (Top 20 categories, Fine-tuned model)")
print("=" * 80)
print(f"{'Category':<22} {'Count':>6} {'P@5':>8} {'P@10':>8} {'MRR':>8} {'mAP@10':>8}")
print("-" * 80)

category_results = {}

for category in top_categories:
    # Get test images for this category
    cat_test_ids = test_df[test_df['articleType'] == category]['id'].values
    
    # Evaluate
    cat_results = evaluate_search(
        ft_index, ft_embeddings, ft_ids, cat_test_ids, id_to_type, k=10
    )
    category_results[category] = cat_results
    
    print(f"{category:<22} {len(cat_test_ids):>6} {cat_results['P@5']:>7.1%} "
          f"{cat_results['P@10']:>7.1%} {cat_results['MRR']:>7.1%} {cat_results['mAP@10']:>7.1%}")

# Find best and worst
print("-" * 80)
best_cat = max(category_results.items(), key=lambda x: x[1]['mAP@10'])
worst_cat = min(category_results.items(), key=lambda x: x[1]['mAP@10'])

print(f"\nBest category:  {best_cat[0]} (mAP@10 = {best_cat[1]['mAP@10']:.1%})")
print(f"Worst category: {worst_cat[0]} (mAP@10 = {worst_cat[1]['mAP@10']:.1%})")


PER-CATEGORY RESULTS (Top 20 categories, Fine-tuned model)
Category                Count      P@5     P@10      MRR   mAP@10
--------------------------------------------------------------------------------
Tshirts                  1060   90.3%   89.4%   95.1%   92.8%
Shirts                    483   94.1%   93.6%   97.2%   95.6%
Casual Shoes              427   73.5%   72.4%   86.5%   80.0%
Watches                   382   99.2%   98.8%  100.0%   99.6%
Sports Shoes              306   75.7%   74.1%   86.7%   80.8%
Kurtas                    277   78.4%   77.3%   87.1%   82.7%
Tops                      265   54.3%   51.5%   78.9%   67.4%
Handbags                  264   93.6%   91.8%   97.6%   95.2%
Heels                     199   75.7%   73.7%   89.5%   82.2%
Sunglasses                161  100.0%  100.0%  100.0%  100.0%
Wallets                   141   90.9%   89.9%   95.5%   92.8%
Flip Flops                137   84.2%   82.6%   91.3%   87.2%
Sandals                   135   73.6%   72.2%   86

In [5]:
# ============================================================
# CELL 5: Latency Benchmark
# ============================================================
import time
import numpy as np

print("=" * 55)
print("LATENCY BENCHMARK")
print("=" * 55)

# Single query latency (average over 100 runs)
query_vector = ft_embeddings[0].reshape(1, -1)

times = []
for _ in range(100):
    start = time.time()
    distances, indices = ft_index.search(query_vector, k=10)
    elapsed = time.time() - start
    times.append(elapsed * 1000)  # convert to ms

print(f"\nSingle query (Flat Index, k=10):")
print(f"  Average:  {np.mean(times):.2f} ms")
print(f"  Median:   {np.median(times):.2f} ms")
print(f"  P95:      {np.percentile(times, 95):.2f} ms")
print(f"  P99:      {np.percentile(times, 99):.2f} ms")

# Batch query latency
for batch_size in [1, 10, 50, 100]:
    query_vectors = ft_embeddings[:batch_size]
    
    times = []
    for _ in range(20):
        start = time.time()
        ft_index.search(query_vectors, k=10)
        elapsed = time.time() - start
        times.append(elapsed * 1000)
    
    avg_time = np.mean(times)
    per_query = avg_time / batch_size
    qps = batch_size / (avg_time / 1000)
    print(f"\nBatch size {batch_size:>3}: {avg_time:>8.2f} ms total | "
          f"{per_query:.2f} ms/query | {qps:.0f} queries/sec")

print(f"\n{'=' * 55}")
print("PRODUCTION READINESS:")
print(f"  Target: < 100ms per query for real-time search")
print(f"  Actual: {np.mean(times)/100:.2f} ms per query")
print(f"  Status: {'✅ READY' if np.mean(times)/100 < 100 else '❌ TOO SLOW'}")
print(f"{'=' * 55}")


LATENCY BENCHMARK

Single query (Flat Index, k=10):
  Average:  7.89 ms
  Median:   6.81 ms
  P95:      7.52 ms
  P99:      9.16 ms

Batch size   1:     6.88 ms total | 6.88 ms/query | 145 queries/sec

Batch size  10:    33.45 ms total | 3.34 ms/query | 299 queries/sec

Batch size  50:    37.48 ms total | 0.75 ms/query | 1334 queries/sec

Batch size 100:    46.16 ms total | 0.46 ms/query | 2166 queries/sec

PRODUCTION READINESS:
  Target: < 100ms per query for real-time search
  Actual: 0.46 ms per query
  Status: ✅ READY
