In [1]:
# imports

from pathlib import Path
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
from transformers import CLIPModel, CLIPProcessor
from rich.console import Console
import matplotlib.pyplot as plt
import seaborn as sns

console = Console()

device = "cuda" if torch.cuda.is_available() else "cpu"
console.print(f"[bold green]Using device:[/bold green] {device}")

PROJECT_ROOT = Path("../").resolve()

VAL_CSV = PROJECT_ROOT / "data" / "processed" / "coco_val_20k.csv"
CHECKPOINT_PATH = PROJECT_ROOT / "checkpoints" / "best_model.pt"

assert VAL_CSV.exists(), "Validation CSV not found."

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
df = pd.read_csv(VAL_CSV)

# For faster experimentation, evaluate on 2000 samples first
df = df.sample(2000, random_state=42).reset_index(drop=True)

len(df)

2000

In [3]:
from PIL import Image

def compute_embeddings(model, processor, dataframe):
    model.eval()
    image_embeddings = []
    text_embeddings = []

    with torch.no_grad():
        for i in tqdm(range(len(dataframe))):
            row = dataframe.iloc[i]
            image = Image.open(row["image_path"]).convert("RGB")
            caption = row["caption"]

            inputs = processor(
                text=[caption],
                images=[image],
                return_tensors="pt",
                padding=True
            ).to(device)

            outputs = model(**inputs)

            image_embeds = outputs.image_embeds
            text_embeds = outputs.text_embeds

            image_embeddings.append(image_embeds.cpu())
            text_embeddings.append(text_embeds.cpu())

    image_embeddings = torch.cat(image_embeddings)
    text_embeddings = torch.cat(text_embeddings)

    # Normalize
    image_embeddings /= image_embeddings.norm(dim=1, keepdim=True)
    text_embeddings /= text_embeddings.norm(dim=1, keepdim=True)

    return image_embeddings, text_embeddings

In [4]:
# Retrieval metrics

def recall_at_k(similarity, k):
    correct = 0
    for i in range(len(similarity)):
        if i in similarity[i].topk(k).indices:
            correct += 1
    return correct / len(similarity)


def mean_reciprocal_rank(similarity):
    reciprocal_ranks = []
    for i in range(len(similarity)):
        sorted_indices = similarity[i].argsort(descending=True)
        rank = (sorted_indices == i).nonzero(as_tuple=True)[0].item() + 1
        reciprocal_ranks.append(1 / rank)
    return np.mean(reciprocal_ranks)

In [5]:
# Evaluate pretrained CLIP

console.print("[bold cyan]Evaluating PRETRAINED CLIP[/bold cyan]")

pretrained_model = CLIPModel.from_pretrained(
    "openai/clip-vit-base-patch32",
    use_safetensors=True
).to(device)

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

img_emb, txt_emb = compute_embeddings(pretrained_model, processor, df)

similarity = img_emb @ txt_emb.T

for k in [1,5,10]:
    console.print(f"Recall@{k}: {recall_at_k(similarity, k):.4f}")

console.print(f"MRR: {mean_reciprocal_rank(similarity):.4f}")

Loading weights: 100%|██████████| 398/398 [00:00<00:00, 1694.54it/s, Materializing param=visual_projection.weight]                                
[1mCLIPModel LOAD REPORT[0m from: openai/clip-vit-base-patch32
Key                                  | Status     |  | 
-------------------------------------+------------+--+-
text_model.embeddings.position_ids   | UNEXPECTED |  | 
vision_model.embeddings.position_ids | UNEXPECTED |  | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.[0m
The image processor of type `CLIPImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. 
100%|██████████| 2000/2000 [00:20<00:00, 95.36it/s] 


In [7]:
from peft import LoraConfig, get_peft_model

console.print("\n[bold yellow]Evaluating FINE-TUNED CLIP (LoRA)[/bold yellow]")

# Load base model
finetuned_model = CLIPModel.from_pretrained(
    "openai/clip-vit-base-patch32",
    use_safetensors=True
)

# Apply SAME LoRA config as training
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.1,
    bias="none"
)

finetuned_model.text_model = get_peft_model(
    finetuned_model.text_model,
    peft_config
)

# Load weights
state_dict = torch.load(CHECKPOINT_PATH, map_location=device)
finetuned_model.load_state_dict(state_dict)

finetuned_model.to(device)
finetuned_model.eval()

# Compute embeddings
img_emb_ft, txt_emb_ft = compute_embeddings(
    finetuned_model,
    processor,
    df
)

similarity_ft = img_emb_ft @ txt_emb_ft.T

for k in [1,5,10]:
    console.print(f"Recall@{k}: {recall_at_k(similarity_ft, k):.4f}")

console.print(f"MRR: {mean_reciprocal_rank(similarity_ft):.4f}")

Loading weights: 100%|██████████| 398/398 [00:00<00:00, 1869.60it/s, Materializing param=visual_projection.weight]                                
[1mCLIPModel LOAD REPORT[0m from: openai/clip-vit-base-patch32
Key                                  | Status     |  | 
-------------------------------------+------------+--+-
text_model.embeddings.position_ids   | UNEXPECTED |  | 
vision_model.embeddings.position_ids | UNEXPECTED |  | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.[0m
100%|██████████| 2000/2000 [00:21<00:00, 93.67it/s]


In [8]:
# Compare results

metrics = {
    "Model": ["Pretrained", "Fine-tuned"],
    "Recall@1": [
        recall_at_k(similarity, 1),
        recall_at_k(similarity_ft, 1)
    ],
    "Recall@5": [
        recall_at_k(similarity, 5),
        recall_at_k(similarity_ft, 5)
    ],
    "MRR": [
        mean_reciprocal_rank(similarity),
        mean_reciprocal_rank(similarity_ft)
    ]
}

results_df = pd.DataFrame(metrics)
results_df

Unnamed: 0,Model,Recall@1,Recall@5,MRR
0,Pretrained,0.403,0.7345,0.552257
1,Fine-tuned,0.4705,0.816,0.622517
