# Evaluation Script (Notebook Version)

This notebook mirrors `scripts/eval.py`: it loads a trained checkpoint and evaluates the model on image-text retrieval.


In [None]:
# Install dependencies
# (%pip is the IPython magic that installs into the running kernel's environment.)
%pip install torch torchvision numpy pillow pyyaml tqdm scikit-learn transformers

# Optional (Google Colab only): uncomment the two lines below to mount Google Drive.
# from google.colab import drive
# drive.mount('/content/drive')


In [None]:
import sys
from pathlib import Path
import json
import torch
import yaml
from tqdm import tqdm

# Locate the project root and make its packages importable.
# Prefer the /content/CLIP_model checkout when it exists (presumably the
# Colab location); otherwise fall back to the parent of the current working
# directory, i.e. assume the notebook is run from a direct subfolder.
BASE_DIR = Path('/content/CLIP_model')
if not BASE_DIR.exists():
    BASE_DIR = Path.cwd().parent
# Put the project root first on sys.path so the `src.*` imports below resolve.
sys.path.insert(0, str(BASE_DIR))

from src.data.coco_dataset import build_coco_dataloader
from src.eval.eval_retrieval import evaluate_retrieval
from src.models.clip_model import CLIPModel
from src.utils.tokenization import SimpleTokenizer


In [None]:
# Configuration - adjust these paths
# Both paths are resolved relative to the project root (BASE_DIR).
CONFIG_PATH = BASE_DIR / "configs/clip_coco_small.yaml"
CHECKPOINT_PATH = BASE_DIR / "checkpoints/best_model.pt"  # Update with your checkpoint

# Load config
# Decode explicitly as UTF-8: open() otherwise uses the locale's preferred
# encoding, so a config containing non-ASCII text could fail to parse (or be
# mis-decoded) on platforms whose default is not UTF-8.
# safe_load only builds plain Python objects, so the result is expected to be
# a nested dict that the later cells index into (config["model"], ...).
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# Echo the resolved paths so a wrong location is visible before the
# checkpoint is actually loaded.
print(f"Config: {CONFIG_PATH}")
print(f"Checkpoint: {CHECKPOINT_PATH}")
print(f"Checkpoint exists: {CHECKPOINT_PATH.exists()}")


In [None]:
# Pick the compute device: CUDA when available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Read the checkpoint, remapping stored tensors onto the selected device.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True if the checkpoint holds nothing
# but tensors and primitives - confirm against the training script).
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

# Rebuild the architecture from the config, then restore the trained weights.
model_cfg = config["model"]
model = CLIPModel(
    vision_config=model_cfg["vision"],
    text_config=model_cfg["text"],
)
model = model.to(device)
model.load_state_dict(checkpoint["model_state_dict"])

# Switch to evaluation mode before running the retrieval evaluation.
model.eval()
loaded_epoch = checkpoint["epoch"]
print(f"Model loaded from epoch {loaded_epoch}")


In [None]:
# Create the tokenizer, sized from the text-model config.
text_cfg = config["model"]["text"]
tokenizer = SimpleTokenizer(vocab_size=text_cfg["vocab_size"], min_freq=2)

# Collect captions from (at most 5000 samples of) the validation set and
# build the vocabulary from them.
# NOTE(review): the vocabulary is rebuilt here from validation captions; if
# training built its vocabulary from a different caption set, the token ids
# may not line up with the checkpoint's text embeddings - confirm.
# NOTE(review): iterating the loader presumably also loads the images even
# though only batch["caption"] is used - verify whether captions can be read
# from the dataset directly.
val_cfg = config["data"]["val"]
temp_loader = build_coco_dataloader(
    annotation_file=str(BASE_DIR / val_cfg["annotation_file"]),
    image_dir=str(BASE_DIR / val_cfg["image_dir"]),
    batch_size=32,
    shuffle=False,
    num_workers=2,
    max_samples=5000,
)

all_captions = []
for batch in tqdm(temp_loader, desc="Building vocab"):
    all_captions += batch["caption"]

tokenizer.build_vocab(all_captions)
print(f"Tokenizer vocabulary size: {len(tokenizer)}")


In [None]:
# Create data loader
def collate_fn(batch, tokenizer, max_seq_length):
    """Collate dataset samples into one batch, tokenizing the captions.

    Args:
        batch: Sequence of samples; each is a mapping with an ``"image"``
            tensor and a ``"caption"`` entry.
        tokenizer: Object providing ``encode(text, max_length=...)`` and
            ``get_pad_token_id()``.
        max_seq_length: Passed to ``tokenizer.encode`` as ``max_length``.

    Returns:
        A dict with:
            ``"image"``: the sample images stacked along a new first dim,
            ``"text_tokens"``: tensor built from the encoded captions,
            ``"text_mask"``: boolean tensor, True where a token id equals
                the tokenizer's pad id,
            ``"caption"``: the original captions as a list.
    """
    stacked_images = torch.stack([sample["image"] for sample in batch])
    raw_captions = [sample["caption"] for sample in batch]

    # Assumes encode() yields id sequences of equal length so they form a
    # rectangular tensor - TODO confirm against SimpleTokenizer.
    tokens = torch.tensor(
        [
            tokenizer.encode(text, max_length=max_seq_length)
            for text in raw_captions
        ]
    )
    pad_positions = tokens == tokenizer.get_pad_token_id()

    collated = {
        "image": stacked_images,
        "text_tokens": tokens,
        "text_mask": pad_positions,
        "caption": raw_captions,
    }
    return collated

from functools import partial

from torch.utils.data import DataLoader

# build_coco_dataloader is used only to construct the dataset: the loader it
# returns is discarded and its `.dataset` is re-wrapped below so that batches
# go through the tokenizing collate function.
val_dataset = build_coco_dataloader(
    annotation_file=str(BASE_DIR / config["data"]["val"]["annotation_file"]),
    image_dir=str(BASE_DIR / config["data"]["val"]["image_dir"]),
    batch_size=config["eval"]["batch_size"],
    shuffle=False,
    num_workers=2,
    subset_percentage=config["data"]["val"].get("subset_percentage"),
).dataset

val_loader = DataLoader(
    val_dataset,
    batch_size=config["eval"]["batch_size"],
    shuffle=False,
    num_workers=2,
    # Pinned host memory only helps host-to-GPU copies; enabling it without
    # CUDA just triggers a warning, and this notebook supports a CPU fallback.
    pin_memory=torch.cuda.is_available(),
    # functools.partial instead of a lambda: a lambda can never be pickled,
    # which breaks worker processes under the "spawn" start method. The
    # tokenizer and max length are bound here, at loader-creation time.
    collate_fn=partial(
        collate_fn,
        tokenizer=tokenizer,
        max_seq_length=config["model"]["text"]["max_seq_length"],
    ),
)

print(f"Validation batches: {len(val_loader)}")


In [None]:
# Run the retrieval evaluation at the cut-offs listed in the config.
top_k = config["eval"]["top_k"]
results = evaluate_retrieval(model, val_loader, device, k_values=top_k)

# Report recall per retrieval direction; each direction's scores are looked up
# by k under the corresponding key of `results`.
print("\n=== Retrieval Results ===")
report_sections = (
    ("\nImage-to-Text Retrieval:", "image_to_text"),
    ("\nText-to-Image Retrieval:", "text_to_image"),
)
for section_header, direction in report_sections:
    print(section_header)
    for k in top_k:
        print(f"  Recall@{k}: {results[direction][k]:.4f}")


In [None]:
# Record which checkpoint (and training epoch) produced these metrics so the
# saved file is self-describing.
results["checkpoint"] = str(CHECKPOINT_PATH)
results["epoch"] = checkpoint["epoch"]

# Write everything to <project root>/results/eval_results.json, creating the
# results folder on first use.
results_dir = BASE_DIR / "results"
results_dir.mkdir(exist_ok=True)
results_path = results_dir / "eval_results.json"

with results_path.open("w") as f:
    json.dump(results, f, indent=2)

print(f"\nResults saved to: {results_path}")
