
# Fine-tune CLIP (ViT-B/32) on Flickr30k (CPU)

This notebook guides you through CPU-only, linear-probe-style fine-tuning of
OpenAI's CLIP model on the Flickr30k dataset. You can either point it at the
standard Flickr30k image folder and captions file (loaded by the torchvision-based
helper below) or provide a local TSV file with image/caption metadata. The notebook
evaluates Recall@1/5/10 and median rank for both image→text and text→image
retrieval, comparing the raw (zero-shot) CLIP model against the fine-tuned model.



## What You'll Get

- CPU-friendly fine-tuning that updates only CLIP projection heads plus the
  temperature parameter.
- Retrieval metrics (Recall@1/5/10 and Median Rank) for both image→text and
  text→image directions.
- Built-in comparison tables and JSON export showing zero-shot vs fine-tuned
  performance.
- Optional TSV-based data ingestion if you prefer a manifest-driven workflow.



## Installation (CPU)

Install the dependencies in a CPU-only environment (a fresh virtualenv or conda
env is recommended):

```bash
python -m pip install --upgrade pip
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
pip install open_clip_torch datasets pillow tqdm numpy faiss-cpu
```


In [None]:

# Uncomment and run the following lines if you still need to install dependencies.
#!python -m pip install --upgrade pip
#!pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
#!pip install open_clip_torch datasets pillow tqdm numpy faiss-cpu


## Data Options

1. **Torchvision dataset (recommended):** Point to your Flickr30k image directory, the
   tab-separated captions file (lines of the form `image.jpg#n<TAB>caption`, such as the
   official `results_20130124.token`), and optionally a directory/file containing split
   lists (`train.txt`, `val.txt`, `test.txt`).
2. **Local TSV fallback:** Provide a tab-separated file with columns
   `image_path`, `caption`, `image_id`, `split`. Each image should appear five times
   (one per caption). This is handy for custom subsets or pre-filtered data.

Pick exactly one path (torchvision *or* TSV) per run.
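If you prefer the TSV route, one way to produce a manifest is to expand the tab-separated captions file using your split lists. The sketch below assumes the manifest carries a header row and that split files list one image name per line; `write_tsv_manifest` is a hypothetical helper, not something this notebook defines.

```python
import csv
from pathlib import Path


def write_tsv_manifest(token_file, images_dir, split_files, out_path) -> int:
    """Hypothetical helper: write image_path, caption, image_id, split rows (one per caption)."""
    # Map each listed image name to its split; bare IDs get a ".jpg" suffix.
    split_of = {}
    for split, list_path in split_files.items():
        for line in Path(list_path).read_text(encoding='utf-8').splitlines():
            parts = line.split()
            if not parts:
                continue
            name = parts[0] if parts[0].lower().endswith('.jpg') else f'{parts[0]}.jpg'
            split_of[name] = split

    written = 0
    with open(out_path, 'w', encoding='utf-8', newline='') as out:
        writer = csv.writer(out, delimiter='\t')
        writer.writerow(['image_path', 'caption', 'image_id', 'split'])  # assumed header row
        for line in Path(token_file).read_text(encoding='utf-8').splitlines():
            if '\t' not in line:
                continue
            key, caption = line.split('\t', 1)
            image_id = key.split('#', 1)[0]  # "1000092795.jpg#0" -> "1000092795.jpg"
            split = split_of.get(image_id)
            caption = caption.strip()
            if split is None or not caption:
                continue  # image not listed in any split, or empty caption
            writer.writerow([str(Path(images_dir) / image_id), caption, image_id, split])
            written += 1
    return written
```

For example, `write_tsv_manifest('data/flickr30k/results_20130124.token', 'data/flickr30k_images', {'train': 'splits/train.txt', 'val': 'splits/val.txt', 'test': 'splits/test.txt'}, 'data/flickr30k_manifest.tsv')` (illustrative paths). Because the file is written with Python's `csv` module, read it back the same way so quoted captions round-trip correctly.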


---

## Imports and Dataset Utilities

In [None]:
import json
import math
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from tqdm import tqdm

import open_clip

try:
    from torchvision.datasets import Flickr30k as TVFlickr30k
    from torchvision.datasets.folder import default_loader as tv_default_loader
except ImportError:
    TVFlickr30k = None
    tv_default_loader = None


In [None]:
@dataclass
class Sample:
    image: torch.Tensor
    tokens: torch.Tensor
    image_id: str
    caption: str


In [None]:
class TorchvisionFlickr(Dataset):
    """Dataset backed by torchvision Flickr30k annotations (images + captions)."""

    def __init__(
        self,
        root: str | os.PathLike[str],
        ann_file: str | os.PathLike[str],
        split: str,
        preprocess,
        tokenizer,
        subset: Optional[int] = None,
        split_file: str | os.PathLike[str] | None = None,
    ) -> None:
        if tv_default_loader is None:
            raise ImportError(
                "torchvision is required for TorchvisionFlickr. Install torchvision first."
            )

        self.root = Path(root)
        self.ann_file = Path(ann_file)
        self.preprocess = preprocess
        self.tokenizer = tokenizer
        self.loader = tv_default_loader

        if not self.root.exists():
            raise FileNotFoundError(f"Image root not found: {root}")
        if not self.ann_file.is_file():
            raise FileNotFoundError(f"Annotation file not found: {ann_file}")

        allowed_images: Optional[set[str]] = None
        split_lower = split.lower()
        candidate_paths: List[Path] = []
        if split_file:
            candidate_paths.append(Path(split_file))
        candidate_paths.extend(
            [
                self.ann_file.with_name(f"{split_lower}.txt"),
                self.ann_file.with_name(f"{split_lower}.lst"),
                self.ann_file.with_name(f"{split_lower}.list"),
                self.root / f"{split_lower}.txt",
                self.root / "splits" / f"{split_lower}.txt",
                self.root / "splits" / f"{split_lower}.list",
            ]
        )

        for candidate in candidate_paths:
            if candidate.is_file():
                allowed_images = self._read_split_file(candidate)
                break

        captions_map: Dict[str, List[str]] = {}
        with self.ann_file.open("r", encoding="utf-8") as handle:
            for line in handle:
                line = line.strip()
                if not line or "\t" not in line:
                    continue
                key, caption = line.split("\t", 1)
                image_name = key.split("#", 1)[0]
                captions_map.setdefault(image_name, []).append(caption)

        rows: List[Tuple[Path, str, str]] = []
        for image_name, captions in captions_map.items():
            normalized = image_name if image_name.lower().endswith(".jpg") else f"{image_name}.jpg"
            if allowed_images is not None and normalized not in allowed_images:
                continue
            image_path = self._resolve_image_path(normalized)
            if image_path is None:
                continue
            for caption in captions:
                if not caption:
                    continue
                rows.append((image_path, caption, normalized))

        if subset and subset > 0:
            rows = rows[:subset]

        if not rows:
            raise ValueError(
                f"No rows found for split='{split}' using torchvision Flickr30k annotations."
            )

        self.rows = rows

    def _read_split_file(self, path: Path) -> set[str]:
        allowed: set[str] = set()
        with path.open("r", encoding="utf-8") as handle:
            for line in handle:
                parts = line.split()
                if not parts:
                    continue  # skip blank lines
                entry = parts[0]
                if not entry.lower().endswith(".jpg"):
                    entry = f"{entry}.jpg"
                allowed.add(entry)
        return allowed

    def _resolve_image_path(self, image_name: str) -> Path | None:
        candidates = [
            self.root / image_name,
            self.root / "flickr30k_images" / image_name,
            self.root / "flickr30k-images" / image_name,
            self.root / "Images" / image_name,
        ]
        for candidate in candidates:
            if candidate.is_file():
                return candidate
        return None

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, index: int) -> Sample:
        image_path, caption, image_id = self.rows[index]
        image = self.loader(str(image_path))
        if not isinstance(image, Image.Image):
            raise TypeError("torchvision default_loader did not return a PIL.Image")
        image = self.preprocess(image.convert("RGB"))
        tokens = self.tokenizer([caption])[0]
        return Sample(image=image, tokens=tokens, image_id=image_id, caption=caption)


In [None]:

def collate(batch: Sequence[Sample]) -> Tuple[torch.Tensor, torch.Tensor, List[str], List[str]]:
    images = torch.stack([item.image for item in batch], dim=0)
    tokens = torch.stack([item.tokens for item in batch], dim=0)
    image_ids = [item.image_id for item in batch]
    captions = [item.caption for item in batch]
    return images, tokens, image_ids, captions


def recall_at_k(similarity: np.ndarray, gt_sets: Sequence[Sequence[int]], ks: Iterable[int] = (1, 5, 10)) -> dict:
    if similarity.ndim != 2:
        raise ValueError('similarity matrix must be 2D')

    num_images = similarity.shape[0]
    order = np.argsort(-similarity, axis=1)
    ranks: List[int] = []
    for idx in range(num_images):
        gt = gt_sets[idx]
        rank_candidates = [int(np.where(order[idx] == gt_idx)[0][0]) for gt_idx in gt]
        ranks.append(min(rank_candidates))

    ranks_array = np.array(ranks, dtype=np.int64)
    metrics = {f'R@{k}': float((ranks_array < k).mean() * 100.0) for k in ks}
    metrics['MedR'] = float(np.median(ranks_array) + 1.0)
    return metrics


def evaluate(model: torch.nn.Module, dataloader: DataLoader, device: str = 'cpu') -> dict:
    model.eval()
    img_features: List[np.ndarray] = []
    txt_features: List[np.ndarray] = []
    caption_image_ids: List[str] = []
    image_id_to_index: dict[str, int] = {}
    unique_image_ids: List[str] = []

    with torch.no_grad():
        for images, tokens, image_ids, _ in tqdm(dataloader, desc='Encode eval', leave=False):
            images = images.to(device)
            tokens = tokens.to(device)

            image_feats = model.encode_image(images)
            text_feats = model.encode_text(tokens)
            image_feats = image_feats / image_feats.norm(dim=-1, keepdim=True)
            text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)

            for j in range(text_feats.size(0)):
                txt_features.append(text_feats[j].cpu().numpy())
                caption_image_ids.append(image_ids[j])

            for j in range(image_feats.size(0)):
                iid = image_ids[j]
                if iid not in image_id_to_index:
                    image_id_to_index[iid] = len(unique_image_ids)
                    unique_image_ids.append(iid)
                    img_features.append(image_feats[j].cpu().numpy())

    if not img_features or not txt_features:
        raise ValueError('Evaluation dataset produced no features')

    img_matrix = np.stack(img_features, axis=0)
    txt_matrix = np.stack(txt_features, axis=0)

    gt_sets: List[List[int]] = [list() for _ in unique_image_ids]
    for caption_idx, iid in enumerate(caption_image_ids):
        gt_sets[image_id_to_index[iid]].append(caption_idx)

    sim_i2t = img_matrix @ txt_matrix.T
    i2t_metrics = recall_at_k(sim_i2t, gt_sets)

    sim_t2i = txt_matrix @ img_matrix.T
    order_t2i = np.argsort(-sim_t2i, axis=1)
    ranks = []
    for caption_idx, iid in enumerate(caption_image_ids):
        target = image_id_to_index[iid]
        rank = int(np.where(order_t2i[caption_idx] == target)[0][0])
        ranks.append(rank)
    ranks_array = np.array(ranks, dtype=np.int64)
    t2i_metrics = {f'R@{k}': float((ranks_array < k).mean() * 100.0) for k in (1, 5, 10)}
    t2i_metrics['MedR'] = float(np.median(ranks_array) + 1.0)

    return {'I->T': i2t_metrics, 'T->I': t2i_metrics}
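`recall_at_k` finds each ground-truth item with a full `np.where` scan of the sorted row, and `evaluate` repeats that pattern for text→image. A vectorized sketch of the same metric inverts each row's `argsort` once and looks up positions directly; because it uses the same sort, ties resolve identically. `best_gt_ranks` and `summarize_ranks` are hypothetical names, not part of the notebook.

```python
import numpy as np


def best_gt_ranks(similarity: np.ndarray, gt_sets) -> np.ndarray:
    """0-based rank of the best-ranked ground-truth column for every query row."""
    order = np.argsort(-similarity, axis=1)  # candidate indices, best match first
    positions = np.empty_like(order)
    rows = np.arange(similarity.shape[0])[:, None]
    # Inverse permutation per row: positions[i, j] = where column j landed in row i's ranking.
    positions[rows, order] = np.arange(similarity.shape[1])[None, :]
    return np.array([positions[i, list(gt)].min() for i, gt in enumerate(gt_sets)], dtype=np.int64)


def summarize_ranks(ranks: np.ndarray, ks=(1, 5, 10)) -> dict:
    metrics = {f'R@{k}': float((ranks < k).mean() * 100.0) for k in ks}
    metrics['MedR'] = float(np.median(ranks) + 1.0)  # report a 1-based median rank
    return metrics


sim = np.array([[0.9, 0.1, 0.5],
                [0.2, 0.3, 0.8]])
print(summarize_ranks(best_gt_ranks(sim, [[1], [2]])))
# {'R@1': 50.0, 'R@5': 100.0, 'R@10': 100.0, 'MedR': 2.0}
```

For text→image, pass one-element sets such as `[[image_id_to_index[iid]] for iid in caption_image_ids]`.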


In [None]:

def train_one_epoch(
    model: torch.nn.Module,
    dataloader: DataLoader,
    optimizer: optim.Optimizer,
    device: str = 'cpu',
) -> float:
    model.train()
    total_loss = 0.0
    batches = 0

    for images, tokens, _, _ in tqdm(dataloader, desc='Train', leave=False):
        images = images.to(device)
        tokens = tokens.to(device)

        image_feats = model.encode_image(images)
        text_feats = model.encode_text(tokens)
        image_feats = image_feats / image_feats.norm(dim=-1, keepdim=True)
        text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)

        logit_scale = model.logit_scale.exp()
        logits = logit_scale * image_feats @ text_feats.t()
        labels = torch.arange(logits.size(0), device=device)

        loss_i2t = nn.CrossEntropyLoss()(logits, labels)
        loss_t2i = nn.CrossEntropyLoss()(logits.t(), labels)
        loss = 0.5 * (loss_i2t + loss_t2i)

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_([p for p in model.parameters() if p.requires_grad], 1.0)
        optimizer.step()

        with torch.no_grad():
            model.logit_scale.data.clamp_(math.log(1 / 0.07), math.log(100.0))

        total_loss += loss.item()
        batches += 1

    return total_loss / max(batches, 1)
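The loss in `train_one_epoch` is the symmetric contrastive (InfoNCE) objective: each image must pick out its own caption among the batch, and each caption its own image. The NumPy sketch below reproduces that arithmetic outside PyTorch, assuming `logit_scale` is already exponentiated (as `model.logit_scale.exp()` is in the training loop); `symmetric_clip_loss` is a hypothetical name.

```python
import numpy as np


def log_softmax(x: np.ndarray, axis: int) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))


def symmetric_clip_loss(image_feats: np.ndarray, text_feats: np.ndarray, logit_scale: float) -> float:
    img = image_feats / np.linalg.norm(image_feats, axis=1, keepdims=True)
    txt = text_feats / np.linalg.norm(text_feats, axis=1, keepdims=True)
    logits = logit_scale * img @ txt.T  # (B, B); row i = image i scored against every caption
    diag = np.arange(logits.shape[0])    # matching pairs sit on the diagonal
    loss_i2t = -log_softmax(logits, axis=1)[diag, diag].mean()  # softmax over captions
    loss_t2i = -log_softmax(logits, axis=0)[diag, diag].mean()  # softmax over images
    return float(0.5 * (loss_i2t + loss_t2i))
```

Two sanity checks follow from the definition: if every embedding is identical, all logits tie and the loss equals log(B); with perfectly aligned, mutually orthogonal pairs and a large scale it approaches 0.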


In [None]:

def _flatten_metrics(metrics: dict) -> dict:
    flat: dict[str, float] = {}
    for direction in ('I->T', 'T->I'):
        for metric, value in metrics[direction].items():
            flat[f'{direction}_{metric}'] = float(value)
    return flat


def _compare_and_print(split_name: str, baseline: dict, tuned: dict) -> None:
    def fmt(value: float) -> str:
        return f'{value:6.2f}'

    print('\n' + '=' * 64)
    print(f'{split_name} — Raw CLIP vs Fine-tuned CLIP')
    print('=' * 64)

    for direction in ('I->T', 'T->I'):
        print(f'\n{direction}')
        # Δ is in percentage points for R@k and in rank positions for MedR (lower MedR is better).
        header = '{:<8} {:>12} {:>12} {:>12}'.format('Metric', 'Raw', 'Fine-tuned', 'Δ')
        print(header)
        print('-' * len(header))
        for metric in ('R@1', 'R@5', 'R@10', 'MedR'):
            base_val = float(baseline[direction][metric])
            tuned_val = float(tuned[direction][metric])
            delta = tuned_val - base_val
            # The arrow shows the direction of the change; the number is its magnitude.
            arrow = '↑' if delta > 0 else '↓' if delta < 0 else '→'
            print(
                '{:<8} {:>12} {:>12} {:>6s}{:>6.2f}'.format(
                    metric, fmt(base_val), fmt(tuned_val), arrow, abs(delta)
                )
            )


def _save_results_json(
    path: str | os.PathLike[str],
    val_raw: dict,
    val_tuned: dict,
    test_raw: dict,
    test_tuned: dict,
) -> None:
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    payload = {
        'val': {'raw': val_raw, 'tuned': val_tuned},
        'test': {'raw': test_raw, 'tuned': test_tuned},
    }
    with path.open('w', encoding='utf-8') as handle:
        json.dump(payload, handle, indent=2)
    print(f'\nSaved results to: {path}')
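To analyze the exported file later, the deltas can be recomputed straight from the JSON. This sketch assumes the layout written by `_save_results_json` as called in `run_experiment`, where each `raw`/`tuned` entry holds the flattened keys produced by `_flatten_metrics` (e.g. `I->T_R@1`); `load_deltas` is a hypothetical helper.

```python
import json
from pathlib import Path


def load_deltas(path) -> dict:
    """Return {split: {flat_metric: tuned - raw}} from the saved results JSON."""
    payload = json.loads(Path(path).read_text(encoding='utf-8'))
    deltas = {}
    for split, entry in payload.items():
        raw, tuned = entry['raw'], entry['tuned']
        deltas[split] = {name: tuned[name] - raw[name] for name in raw}
    return deltas
```

Usage: `load_deltas('artifacts/clip_flickr30k_results.json')['test']['T->I_R@1']`. Positive values are improvements for R@k; for MedR a negative value is the improvement.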


In [None]:
def build_dataloader(
    split: str,
    preprocess,
    tokenizer,
    batch_size: int,
    shuffle: bool,
    subset: int,
    data_csv: Optional[str],
    torchvision_root: Optional[str],
    torchvision_ann: Optional[str],
    torchvision_split: Optional[str],
) -> DataLoader:
    if data_csv:
        dataset = TSVFlickr(
            data_csv,
            split,
            preprocess,
            tokenizer,
            subset=subset if subset > 0 else None,
        )
    elif torchvision_root and torchvision_ann:
        split_file: Optional[Path] = None
        if torchvision_split:
            split_path = Path(torchvision_split)
            if split_path.is_dir():
                for suffix in ('.txt', '.lst', '.list'):
                    candidate = split_path / f"{split.lower()}{suffix}"
                    if candidate.is_file():
                        split_file = candidate
                        break
            elif split_path.is_file():
                split_file = split_path
        dataset = TorchvisionFlickr(
            torchvision_root,
            torchvision_ann,
            split,
            preprocess,
            tokenizer,
            subset=subset if subset > 0 else None,
            split_file=split_file,
        )
    else:
        raise ValueError('Provide either data_csv or torchvision root/annotation paths.')

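    # DataLoader workers started with 'spawn' (the default on macOS and Windows) cannot
    # unpickle classes defined inside a notebook, so only use workers under 'fork'.
    import multiprocessing
    workers = 2 if torch.get_num_threads() > 1 and multiprocessing.get_start_method() == 'fork' else 0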

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=workers,
        pin_memory=False,
        collate_fn=collate,
    )
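`build_dataloader` instantiates a `TSVFlickr` class that this notebook never defines, so the TSV path fails with a `NameError` as written. A minimal sketch of such a class follows, assuming a header row with the columns listed under Data Options, relative image paths resolved against the TSV's folder, and reuse of the `Sample` dataclass defined above. It is an illustration, not the implementation used by the standalone scripts.

```python
import csv
from pathlib import Path
from typing import List, Optional, Tuple

from PIL import Image
from torch.utils.data import Dataset


class TSVFlickr(Dataset):
    """Sketch: TSV-backed dataset with columns image_path, caption, image_id, split."""

    def __init__(self, tsv_path, split, preprocess, tokenizer, subset: Optional[int] = None) -> None:
        self.preprocess = preprocess
        self.tokenizer = tokenizer
        base = Path(tsv_path).parent
        rows: List[Tuple[Path, str, str]] = []
        with open(tsv_path, 'r', encoding='utf-8', newline='') as handle:
            for record in csv.DictReader(handle, delimiter='\t'):  # assumes a header row
                if (record['split'] or '').strip().lower() != split.lower():
                    continue
                caption = (record['caption'] or '').strip()
                if not caption:
                    continue
                image_path = Path(record['image_path'])
                if not image_path.is_absolute():
                    image_path = base / image_path  # assumption: relative to the TSV's folder
                rows.append((image_path, caption, record['image_id']))
        if subset and subset > 0:
            rows = rows[:subset]
        if not rows:
            raise ValueError(f"No rows found for split='{split}' in {tsv_path}")
        self.rows = rows

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, index: int):
        image_path, caption, image_id = self.rows[index]
        with Image.open(image_path) as img:
            image = self.preprocess(img.convert('RGB'))
        tokens = self.tokenizer([caption])[0]
        # Sample is the dataclass defined earlier in this notebook.
        return Sample(image=image, tokens=tokens, image_id=image_id, caption=caption)
```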


## Configure Dataset Source

- **Torchvision**: set `torchvision_root` to the image directory, `torchvision_ann`
  to the tab-separated captions file (`image.jpg#n<TAB>caption` lines), and
  (optionally) `torchvision_split` to a folder/file with `train/val/test` lists.
  This is the default setup.
- **TSV**: if you maintain a custom tab-separated manifest, provide its path via
  `data_csv`. Leave the torchvision fields blank.

Do not mix sources in the same run.
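`build_dataloader` silently prefers `data_csv` when both sources are filled in, so a mixed configuration will not raise an error on its own. A small guard you could call before `run_experiment`, sketched as a hypothetical `resolve_data_source` helper:

```python
def resolve_data_source(data_csv: str, torchvision_root: str, torchvision_ann: str) -> str:
    """Return 'tsv' or 'torchvision', refusing mixed or incomplete configurations."""
    use_tsv = bool(data_csv)
    tv_fields = [bool(torchvision_root), bool(torchvision_ann)]
    if use_tsv and any(tv_fields):
        raise ValueError('Set either data_csv or the torchvision_* paths, not both.')
    if use_tsv:
        return 'tsv'
    if all(tv_fields):
        return 'torchvision'
    raise ValueError('Set data_csv, or both torchvision_root and torchvision_ann.')
```

Usage: `resolve_data_source(cfg.data_csv, cfg.torchvision_root, cfg.torchvision_ann)`.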


## Download Flickr30k Dataset

The Flickr30k dataset consists of images and captions. You need to download both components:

### Method 1: Manual Download (Recommended)

1. **Images**: 
   - Visit: https://shannon.cs.illinois.edu/DenotationGraph/
   - Request access and download `flickr30k_images.tar.gz`
   
2. **Captions**: 
   - Download from: http://shannon.cs.illinois.edu/DenotationGraph/data/flickr30k.tar.gz

### Method 2: Check Setup from the Notebook

Run the status check below to see which components are still missing:

In [26]:
# Check Flickr30k dataset setup status
import os
from pathlib import Path

def check_flickr30k_status():
    """Check the current status of Flickr30k dataset setup"""
    
    # Project root is assumed to be the parent of the notebook's working directory
    project_root = Path(__file__).parent.parent if '__file__' in globals() else Path.cwd().parent
    data_dir = project_root / 'data'
    images_dir = data_dir / 'flickr30k_images'
    caption_file = data_dir / 'flickr30k' / 'results_20130124.token'
    
    print("📊 Flickr30k Dataset Status:")
    print("=" * 40)
    print(f"🔍 Looking in: {data_dir}")
    
    # Check directories
    print(f"📁 Data directory: {'✅ Exists' if data_dir.exists() else '❌ Missing'}")
    print(f"📁 Images directory: {'✅ Exists' if images_dir.exists() else '❌ Missing'}")
    print(f"📄 Captions file: {'✅ Exists' if caption_file.exists() else '❌ Missing'}")
    
    # Count images if directory exists
    if images_dir.exists():
        img_count = len(list(images_dir.glob('*.jpg')))
        print(f"🖼️  Image count: {img_count}")
        
        if img_count > 0:
            print("\n🎉 Dataset ready for training!")
            return True
        else:
            print(f"\n⚠️  No images found in {images_dir}")
            print("   Please download flickr30k_images.tar.gz from:")
            print("   https://shannon.cs.illinois.edu/DenotationGraph/")
            print(f"   And extract all .jpg files to: {images_dir}")
            return False
    else:
        print(f"\n❌ Images directory {images_dir} not found")
        return False

# Run the status check
dataset_ready = check_flickr30k_status()

if not dataset_ready:
    print("\n" + "="*50)
    print("📥 TO COMPLETE SETUP:")
    print("1. Visit: https://shannon.cs.illinois.edu/DenotationGraph/")
    print("2. Request access and download 'flickr30k_images.tar.gz'")
    print("3. Extract to: ../data/flickr30k_images/ (relative to notebook)")
    print("4. Run this cell again to verify setup")
    print("="*50)



### Alternative: Terminal Commands

If you prefer using the terminal, run these commands:

```bash
# Create directory structure
mkdir -p ./data/flickr30k_images
mkdir -p ./data/flickr30k

# Download captions (annotations)
cd ./data
wget http://shannon.cs.illinois.edu/DenotationGraph/data/flickr30k.tar.gz
tar -xzf flickr30k.tar.gz
rm flickr30k.tar.gz

# For images: manually download from https://shannon.cs.illinois.edu/DenotationGraph/
# Extract flickr30k_images.tar.gz to ./data/flickr30k_images/
```

### Expected Final Structure:
```
./data/
├── flickr30k_images/
│   ├── 1000092795.jpg
│   ├── 10002456.jpg
│   └── ... (31,783 total images)
└── flickr30k/
    └── results_20130124.token
```
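The status check above only counts `.jpg` files; it does not confirm that the images named in the captions file are actually present. A small stdlib sketch, assuming the `image.jpg#n<TAB>caption` format and images stored directly in the top-level folder (`find_missing_images` is a hypothetical helper):

```python
from pathlib import Path


def find_missing_images(token_file, images_dir) -> list:
    """List image names referenced in the captions file but absent from images_dir."""
    referenced = set()
    for line in Path(token_file).read_text(encoding='utf-8').splitlines():
        if '\t' in line:
            # "1000092795.jpg#0<TAB>caption" -> "1000092795.jpg"
            referenced.add(line.split('\t', 1)[0].split('#', 1)[0])
    images_dir = Path(images_dir)
    return sorted(name for name in referenced if not (images_dir / name).is_file())
```

Usage: `find_missing_images('../data/flickr30k/results_20130124.token', '../data/flickr30k_images')` should return an empty list once extraction is complete.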

In [None]:
# Optional: If you have direct access to flickr30k_images.tar.gz, 
# uncomment and modify the path below to extract it automatically

def extract_flickr30k_images(tar_path):
    """Extract Flickr30k images from tar file"""
    import tarfile
    
    tar_path = Path(tar_path)
    if not tar_path.exists():
        print(f"❌ File not found: {tar_path}")
        return
    
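    # Match the location check_flickr30k_status() inspects: ../data relative to the notebook.
    extract_dir = Path('../data/flickr30k_images')
    extract_dir.mkdir(parents=True, exist_ok=True)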
    
    print(f"📦 Extracting {tar_path.name}...")
    with tarfile.open(tar_path, 'r:gz') as tar:
        # Extract only .jpg files
        members = [m for m in tar.getmembers() if m.name.endswith('.jpg')]
        for member in members:
            # Extract with just the filename, not the full path
            member.name = Path(member.name).name
            tar.extract(member, extract_dir)
            
    print(f"✅ Extracted {len(members)} images to {extract_dir}")

# Uncomment and set the correct path if you have the tar file:
# extract_flickr30k_images("path/to/flickr30k_images.tar.gz")

In [None]:
@dataclass
class Config:
    data_csv: str = ''  # TSV path if using local data
    torchvision_root: str = ''  # Directory with Flickr30k images
    torchvision_ann: str = ''  # Tab-separated captions file (e.g. results_20130124.token)
    torchvision_split: str = ''  # Optional split directory/file
    epochs: int = 1
    batch_size: int = 32
    subset: int = 5000
    zero_shot: bool = False
    val_split: str = 'val'
    test_split: str = 'test'
    results_json: str = 'artifacts/clip_flickr30k_results.json'
    save_checkpoint: bool = True
    checkpoint_path: str = 'checkpoints/linearprobe_best.pt'


In [None]:
def run_experiment(cfg: Config) -> dict:
    device = 'cpu'
    model_name = 'ViT-B-32'
    pretrained = 'openai'

    model, preprocess_train, preprocess_eval = open_clip.create_model_and_transforms(
        model_name, pretrained=pretrained
    )
    tokenizer = open_clip.get_tokenizer(model_name)

    model.to(device)

    if not cfg.data_csv and not (cfg.torchvision_root and cfg.torchvision_ann):
        raise ValueError('Provide either data_csv or torchvision root/ann paths.')

    for param in model.parameters():
        param.requires_grad = False

    trainable_params: List[torch.nn.Parameter] = []
    if getattr(model, 'text_projection', None) is not None:
        model.text_projection.requires_grad = True
        trainable_params.append(model.text_projection)
    if getattr(model, 'visual', None) is not None and getattr(model.visual, 'proj', None) is not None:
        model.visual.proj.requires_grad = True
        trainable_params.append(model.visual.proj)

    model.logit_scale.requires_grad = True
    trainable_params.append(model.logit_scale)

    val_loader = build_dataloader(
        cfg.val_split,
        preprocess_eval,
        tokenizer,
        batch_size=cfg.batch_size,
        shuffle=False,
        subset=0,
        data_csv=cfg.data_csv,
        torchvision_root=cfg.torchvision_root,
        torchvision_ann=cfg.torchvision_ann,
        torchvision_split=cfg.torchvision_split,
    )
    test_loader = build_dataloader(
        cfg.test_split,
        preprocess_eval,
        tokenizer,
        batch_size=cfg.batch_size,
        shuffle=False,
        subset=0,
        data_csv=cfg.data_csv,
        torchvision_root=cfg.torchvision_root,
        torchvision_ann=cfg.torchvision_ann,
        torchvision_split=cfg.torchvision_split,
    )

    print('Evaluating RAW CLIP (zero-shot) before training...')
    baseline_val = evaluate(model, val_loader, device)
    baseline_test = evaluate(model, test_loader, device)
    _compare_and_print('VAL', baseline_val, baseline_val)
    _compare_and_print('TEST', baseline_test, baseline_test)

    if cfg.zero_shot:
        _save_results_json(
            cfg.results_json,
            _flatten_metrics(baseline_val),
            _flatten_metrics(baseline_val),
            _flatten_metrics(baseline_test),
            _flatten_metrics(baseline_test),
        )
        return {
            'baseline': {'val': baseline_val, 'test': baseline_test},
            'fine_tuned': {'val': baseline_val, 'test': baseline_test},
        }

    train_loader = build_dataloader(
        'train',
        preprocess_train,
        tokenizer,
        batch_size=cfg.batch_size,
        shuffle=True,
        subset=cfg.subset,
        data_csv=cfg.data_csv,
        torchvision_root=cfg.torchvision_root,
        torchvision_ann=cfg.torchvision_ann,
        torchvision_split=cfg.torchvision_split,
    )

    optimizer = optim.AdamW(
        [param for param in trainable_params if param is not None],
        lr=1e-3,
        weight_decay=0.2,
    )

    best_val_r1 = -float('inf')
    for epoch in range(1, cfg.epochs + 1):
        print(f'Epoch {epoch}/{cfg.epochs}')
        train_loss = train_one_epoch(model, train_loader, optimizer, device)
        val_metrics = evaluate(model, val_loader, device)
        val_score = 0.5 * (
            val_metrics['I->T']['R@1'] + val_metrics['T->I']['R@1']
        )
        print(
            f'train_loss={train_loss:.4f}  val_R@1(avg)={val_score:.2f}  details={val_metrics}'
        )

        if cfg.save_checkpoint and val_score > best_val_r1:
            best_val_r1 = val_score
            Path(cfg.checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
            torch.save({'model': model.state_dict()}, cfg.checkpoint_path)
            print(f'Saved checkpoint to {cfg.checkpoint_path}')

    print('\nRunning final evaluations with fine-tuned model...')
    val_metrics = evaluate(model, val_loader, device)
    test_metrics = evaluate(model, test_loader, device)

    _compare_and_print('VAL', baseline_val, val_metrics)
    _compare_and_print('TEST', baseline_test, test_metrics)

    _save_results_json(
        cfg.results_json,
        _flatten_metrics(baseline_val),
        _flatten_metrics(val_metrics),
        _flatten_metrics(baseline_test),
        _flatten_metrics(test_metrics),
    )

    return {
        'baseline': {'val': baseline_val, 'test': baseline_test},
        'fine_tuned': {'val': val_metrics, 'test': test_metrics},
    }



## Recommended CPU-Friendly Settings

Start with a quick sanity run on a subset before committing to the full train
split:

- `epochs = 1`
- `batch_size = 32`
- `subset = 5000`

If the run is still slow, drop `subset` to ~3000. For improved accuracy once
things look stable, try `epochs = 3` (still CPU-only, just longer).
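To estimate run time, it helps to know how many optimizer steps each setting implies. `subset` caps caption rows rather than images, so `subset = 5000` covers roughly 1,000 images at five captions each. The DataLoader keeps the final partial batch (`drop_last` defaults to `False`), so the step count is a ceiling division; `steps_per_epoch` is a hypothetical helper:

```python
import math


def steps_per_epoch(num_rows: int, batch_size: int) -> int:
    # The last, smaller batch still counts as a step because drop_last=False.
    return math.ceil(num_rows / batch_size)


print(steps_per_epoch(5000, 32))  # 157
print(steps_per_epoch(3000, 32))  # 94
```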


In [None]:
# Example configuration (edit and uncomment to run)
# cfg = Config(
#     data_csv='',                 # set if using TSV
#     torchvision_root='/data/flickr30k/images',
#     torchvision_ann='/data/flickr30k/annotations/captions.txt',
#     torchvision_split='/data/flickr30k/splits',
#     epochs=1,
#     batch_size=32,
#     subset=5000,
#     zero_shot=False,
# )
# results = run_experiment(cfg)
# results



## Notes & Next Steps

- The JSON metrics file (default `artifacts/clip_flickr30k_results.json`) stores
  both baseline and fine-tuned metrics for later analysis or visualization.
- Use `zero_shot=True` to capture just the raw CLIP performance before any
  training. This is useful for regression checks.
- The notebook's helper functions mirror the standalone scripts
  (`finetune_clip_flickr30k_cpu.py` and `scripts/compare_models.py`) so you can
  mix and match CLI or notebook workflows depending on preference.
- Consider logging additional diagnostics (loss curves, wall-clock timings) or
  exporting metrics to experiment trackers as you iterate.
