# Training Script (Notebook Version)

This notebook mirrors `scripts/train.py`: it trains the CLIP model from scratch.


In [None]:
# Install dependencies
# `%pip` is an IPython magic, so this cell only works when run inside a
# Jupyter/Colab kernel (it is not valid in a plain Python script).
%pip install torch torchvision numpy pillow pyyaml tqdm scikit-learn transformers

# Mount Google Drive if needed
# (optional, Google Colab only: uncomment the two lines below to enable it)
# from google.colab import drive
# drive.mount('/content/drive')


In [None]:
import sys
from pathlib import Path
import torch
import yaml
from tqdm import tqdm

# Resolve the project root: use the fixed Colab checkout location when it
# exists, otherwise fall back to the parent of the current working directory.
_colab_root = Path('/content/CLIP_model')
if _colab_root.exists():
    BASE_DIR = _colab_root
else:
    BASE_DIR = Path.cwd().parent

# Put the project root first on the import path so the `src.*` packages
# imported below can be found.
sys.path.insert(0, str(BASE_DIR))

from src.data.coco_dataset import build_coco_dataloader
from src.models.clip_model import CLIPModel
from src.training.train_clip import get_lr_scheduler, save_checkpoint, train_epoch
from src.utils.tokenization import SimpleTokenizer


In [None]:
# Configuration - adjust these paths as needed
CONFIG_PATH = BASE_DIR / "configs/clip_coco_small.yaml"
RESUME_FROM = None  # Set to checkpoint path if resuming, e.g., "checkpoints/checkpoint_epoch_5.pt"

# Load config.
# The encoding is pinned to UTF-8 so that parsing the YAML file does not
# depend on the platform's default locale encoding.
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# yaml.safe_load returns None for an empty document; fail here with a clear
# message instead of an opaque "NoneType is not subscriptable" further down.
if not isinstance(config, dict):
    raise ValueError(f"Config file is empty or not a mapping: {CONFIG_PATH}")

# Echo the settings that drive the rest of the notebook.
print("Configuration:")
print(f"  Config: {CONFIG_PATH}")
print(f"  Resume from: {RESUME_FROM}")
print(f"  Batch size: {config['training']['batch_size']}")
print(f"  Learning rate: {config['training']['learning_rate']}")
print(f"  Epochs: {config['training']['num_epochs']}")


In [None]:
# Pick the compute device: CUDA when PyTorch reports it as available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Make sure the checkpoint directory (resolved relative to the project root)
# exists, creating any missing parent directories along the way.
checkpoint_dir = BASE_DIR / config["training"]["save_dir"]
checkpoint_dir.mkdir(parents=True, exist_ok=True)
print(f"Checkpoint directory: {checkpoint_dir}")


In [None]:
# Build the tokenizer, sized by the text-model vocabulary setting.
print("Building tokenizer vocabulary...")
tokenizer = SimpleTokenizer(
    vocab_size=config["model"]["text"]["vocab_size"],
    min_freq=2,
)

# A temporary, unshuffled loader capped at 5000 samples supplies the captions
# the vocabulary is built from.
train_split_cfg = config["data"]["train"]
temp_loader = build_coco_dataloader(
    annotation_file=str(BASE_DIR / train_split_cfg["annotation_file"]),
    image_dir=str(BASE_DIR / train_split_cfg["image_dir"]),
    batch_size=32,
    shuffle=False,
    num_workers=2,
    max_samples=5000,
)

# Flatten the per-batch caption lists into a single list.
all_captions = [
    caption
    for batch in tqdm(temp_loader, desc="Collecting captions")
    for caption in batch["caption"]
]

tokenizer.build_vocab(all_captions)
print(f"Vocabulary size: {len(tokenizer)}")


In [None]:
# Create data loaders
def collate_fn(batch, tokenizer, max_seq_length):
    """Merge dataset samples into one batch and tokenize their captions.

    Args:
        batch: Sequence of sample dicts, each with an "image" tensor and a
            "caption" string.
        tokenizer: Object providing ``encode(text, max_length=...)`` and
            ``get_pad_token_id()``.
        max_seq_length: Value forwarded to ``tokenizer.encode`` as
            ``max_length``.

    Returns:
        A dict with the stacked images ("image"), the token id tensor
        ("text_tokens"), a boolean mask that is True at padding positions
        ("text_mask"), and the raw caption strings ("caption").
    """
    stacked_images = torch.stack(tuple(sample["image"] for sample in batch))
    texts = [sample["caption"] for sample in batch]

    # Encode every caption, then pack the id lists into a single tensor.
    encoded = []
    for text in texts:
        encoded.append(tokenizer.encode(text, max_length=max_seq_length))
    token_tensor = torch.tensor(encoded)

    # Padding mask: True wherever the token equals the pad id.
    padding_mask = token_tensor.eq(tokenizer.get_pad_token_id())

    collated = {}
    collated["image"] = stacked_images
    collated["text_tokens"] = token_tensor
    collated["text_mask"] = padding_mask
    collated["caption"] = texts
    return collated

from functools import partial

from torch.utils.data import DataLoader

# build_coco_dataloader is called only to construct the dataset: the loader it
# returns is discarded and just its `.dataset` is kept, so that a loader with
# the tokenizing collate function can be created below.
train_dataset = build_coco_dataloader(
    annotation_file=str(BASE_DIR / config["data"]["train"]["annotation_file"]),
    image_dir=str(BASE_DIR / config["data"]["train"]["image_dir"]),
    batch_size=config["training"]["batch_size"],
    shuffle=True,
    num_workers=4,
    subset_percentage=config["data"]["train"].get("subset_percentage"),
).dataset

# functools.partial is used instead of a lambda for the collate function:
# a lambda cannot be pickled, which breaks worker start-up whenever the
# multiprocessing start method is "spawn", and it would look up `tokenizer`
# and `config` lazily at call time. The partial binds both values now.
train_loader = DataLoader(
    train_dataset,
    batch_size=config["training"]["batch_size"],
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    collate_fn=partial(
        collate_fn,
        tokenizer=tokenizer,
        max_seq_length=config["model"]["text"]["max_seq_length"],
    ),
)

print(f"Train batches: {len(train_loader)}")


In [None]:
# Instantiate the model from the vision/text sections of the config, then
# move it to the selected device.
model = CLIPModel(
    vision_config=config["model"]["vision"],
    text_config=config["model"]["text"],
)
model = model.to(device)

print(f"Model created on {device}")

# Tally parameter counts in a single pass: every parameter contributes to the
# total, and those with requires_grad set also count as trainable.
total_params = 0
trainable_params = 0
for param in model.parameters():
    param_count = param.numel()
    total_params += param_count
    if param.requires_grad:
        trainable_params += param_count

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")


In [None]:
# Setup optimizer and scheduler
training_cfg = config["training"]

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=training_cfg["learning_rate"],
    weight_decay=training_cfg["weight_decay"],
)

# Total step count = batches per epoch x number of epochs.
num_training_steps = len(train_loader) * training_cfg["num_epochs"]
scheduler = get_lr_scheduler(
    optimizer,
    num_warmup_steps=training_cfg["warmup_steps"],
    num_training_steps=num_training_steps,
)

# A gradient scaler is created only when AMP is switched on in the config;
# otherwise `scaler` is None.
if training_cfg["use_amp"]:
    scaler = torch.cuda.amp.GradScaler()
else:
    scaler = None

print(f"Optimizer: AdamW")
print(f"Total training steps: {num_training_steps}")
print(f"Warmup steps: {training_cfg['warmup_steps']}")
print(f"AMP enabled: {training_cfg['use_amp']}")


In [None]:
# Resume from checkpoint if provided.
# Defaults describe a fresh run: begin at epoch 0 with no best loss recorded.
start_epoch, best_loss = 0, float("inf")

if not RESUME_FROM:
    print("Starting training from scratch")
else:
    checkpoint_path = BASE_DIR / RESUME_FROM
    if not checkpoint_path.exists():
        # A missing checkpoint is reported but not fatal; the defaults above
        # stay in effect.
        print(f"Checkpoint not found: {checkpoint_path}")
    else:
        checkpoint = torch.load(checkpoint_path, map_location=device)

        # Restore model, optimizer and scheduler state from the checkpoint.
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

        # Continue from the epoch and loss stored in the checkpoint.
        start_epoch = checkpoint["epoch"]
        best_loss = checkpoint["loss"]
        print(f"Resumed from epoch {start_epoch}, loss: {best_loss:.4f}")


In [None]:
# Training loop
# Alias for the same dict object as config["training"], so lookups made
# through it always see the current values.
training_cfg = config["training"]

for epoch in range(start_epoch, training_cfg["num_epochs"]):
    # 1-based epoch number, used for display and stored in checkpoints.
    epoch_number = epoch + 1
    print(f"\nEpoch {epoch_number}/{training_cfg['num_epochs']}")

    train_loss = train_epoch(
        model, train_loader, optimizer, scheduler, scaler, device, config
    )
    print(f"Train loss: {train_loss:.4f}")

    # Positional arguments shared by both save_checkpoint calls below.
    checkpoint_args = (
        model,
        optimizer,
        scheduler,
        epoch_number,
        train_loss,
        checkpoint_dir,
        len(tokenizer),
    )

    # Periodic checkpoint every `save_every` epochs.
    if epoch_number % training_cfg["save_every"] == 0:
        checkpoint_path = save_checkpoint(*checkpoint_args)
        print(f"Checkpoint saved: {checkpoint_path}")

    # Additionally save a "best" checkpoint whenever the training loss
    # improves on the lowest value seen so far.
    if train_loss < best_loss:
        best_loss = train_loss
        best_path = save_checkpoint(*checkpoint_args, is_best=True)
        print(f"Best model saved: {best_path}")

print("\nTraining completed!")
