# Complete Training Pipeline for ArtEmis Image Captioning

This notebook trains and evaluates both CNN+LSTM and Vision-Language Transformer models.


## 1. Setup and Configuration


In [None]:
import sys
from pathlib import Path

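# Add the project root to sys.path so that the `src` package can be imported.
# Path().resolve().parent is the parent of the current working directory, so
# this assumes the notebook is launched from a folder directly under the root.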
project_root = Path().resolve().parent
sys.path.insert(0, str(project_root))

import json
import logging
import random
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam, AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.datasets.artemis_dataset import ArtemisDataset, collate_captions
from src.models.cnn_lstm.cnn_lstm_model import ImageCaptioningCNNLSTM
from src.models.vit.caption_transformer import CaptionTransformer
from src.models.vit.config_transformer import TransformerHyperParams
from src.utils.tokenization import Tokenizer
from src.evaluation.evaluate_models import evaluate_cnn_lstm, evaluate_transformer

# Setup logging
# logging.FileHandler opens its file as soon as it is constructed and raises
# FileNotFoundError when the parent directory is missing, so the log
# directory is created here first. The path is anchored on project_root
# instead of the current working directory so the log always lands in
# <project_root>/logs no matter where the notebook is started from.
log_dir = project_root / "logs"
log_dir.mkdir(parents=True, exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        # Persistent record of the run.
        logging.FileHandler(log_dir / "training.log"),
        # Mirror every record to the notebook output as well.
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Set random seeds
def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Value passed to every generator's seeding function
            (defaults to 42).
    """
    # Same seed for every library: stdlib random, NumPy, then PyTorch
    # (default generator followed by all CUDA devices).
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Request deterministic cuDNN kernels and switch off its autotuner.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False


seed_everything(42)


In [None]:
# Configuration
# Every run setting lives in this one dictionary: input locations,
# optimisation hyperparameters, per-model architecture settings, output
# locations and the compute device. All paths hang off project_root.
_data_dir = project_root / "data"
_checkpoint_root = project_root / "checkpoints"

CONFIG = {
    # Data paths
    "csv_path": _data_dir / "artemis_sample_8k.csv",
    "images_root": _data_dir / "wikiart_sample_128",
    "vocab_path": project_root / "src" / "utils" / "vocab.json",
    "split_path": _data_dir / "splits.json",

    # Training hyperparameters
    "batch_size": 32,
    "num_workers": 4,
    "max_len": 32,
    "epochs_cnn_lstm": 15,
    "epochs_transformer": 15,
    "lr_cnn_lstm": 1e-3,
    "lr_transformer": 1e-4,
    "weight_decay": 1e-4,
    "max_grad_norm": 5.0,
    "early_stop_patience": 5,

    # Model configs (unpacked as keyword arguments when the models are built)
    "cnn_lstm": dict(
        embedding_dim=256,
        hidden_dim=256,
        num_layers=1,
        dropout=0.1,
        image_feat_dim=256,
    ),
    "transformer": dict(
        d_model=256,
        num_heads=8,
        num_layers=4,
        patch_size=16,
        dropout=0.1,
        mlp_ratio=4.0,
    ),

    # Output paths
    "checkpoint_dir_cnn_lstm": _checkpoint_root / "cnn_lstm",
    "checkpoint_dir_transformer": _checkpoint_root / "transformer",
    "results_dir": project_root / "results",

    # Device: prefer the GPU whenever PyTorch can see one
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}

# Create directories
_output_dirs = (
    CONFIG["checkpoint_dir_cnn_lstm"],
    CONFIG["checkpoint_dir_transformer"],
    CONFIG["results_dir"],
    project_root / "logs",
)
for out_dir in _output_dirs:
    # parents/exist_ok make this safe to re-run on an existing tree.
    out_dir.mkdir(parents=True, exist_ok=True)

print(f"Using device: {CONFIG['device']}")
print(f"CUDA available: {torch.cuda.is_available()}")


## 2. Load Tokenizer and Build Datasets


In [None]:
# Load tokenizer
# Restore the tokenizer from the vocabulary file named by CONFIG["vocab_path"].
tokenizer = Tokenizer.load(CONFIG["vocab_path"])
# Sanity output: number of word2idx entries and the special-token indices
# (pad / bos / eos) exposed by the tokenizer.
print(f"Vocabulary size: {len(tokenizer.word2idx)}")
print(f"Special tokens: PAD={tokenizer.pad_idx}, BOS={tokenizer.bos_idx}, EOS={tokenizer.eos_idx}")


In [None]:
# Build train/val/test splits
def build_splits(csv_path, split_path, val_ratio=0.1, test_ratio=0.1, seed=42):
    """Return train/val/test splits of the unique paintings, caching them on disk.

    If ``split_path`` already exists its JSON content is returned as-is; in
    that case ``csv_path``, the ratios and ``seed`` are not consulted.
    Otherwise the unique non-null values of the CSV's ``painting`` column are
    shuffled with a private ``random.Random(seed)``, cut into three disjoint
    lists and written to ``split_path`` before being returned.

    Args:
        csv_path: CSV file containing a ``painting`` column.
        split_path: JSON cache location (``str`` or ``Path``). Missing parent
            directories are created.
        val_ratio: Fraction of unique paintings assigned to validation.
        test_ratio: Fraction of unique paintings assigned to test.
        seed: Seed for the shuffle.

    Returns:
        A dict with ``"train"``, ``"val"`` and ``"test"`` keys, each mapping
        to a list of painting identifiers (strings).

    Raises:
        ValueError: If a ratio is negative or the two ratios sum to more
            than 1 (a fresh split would otherwise let partitions overlap).
    """
    split_path = Path(split_path)
    if split_path.exists():
        with open(split_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    if val_ratio < 0 or test_ratio < 0 or val_ratio + test_ratio > 1:
        raise ValueError(
            "val_ratio and test_ratio must be non-negative and sum to at most 1, "
            f"got val_ratio={val_ratio}, test_ratio={test_ratio}"
        )

    df = pd.read_csv(csv_path)
    unique_images = df["painting"].dropna().astype(str).unique().tolist()

    # A dedicated generator keeps the shuffle reproducible without touching
    # the global random state.
    rng = random.Random(seed)
    rng.shuffle(unique_images)

    total = len(unique_images)
    # Val and test each get at least one painting when one is available, but
    # never more than what is left; this keeps train_count non-negative so
    # the three slices below can never overlap.
    test_count = min(total, max(1, int(total * test_ratio)))
    val_count = min(total - test_count, max(1, int(total * val_ratio)))
    train_count = total - val_count - test_count

    splits = {
        "train": unique_images[:train_count],
        "val": unique_images[train_count:train_count + val_count],
        "test": unique_images[train_count + val_count:],
    }

    split_path.parent.mkdir(parents=True, exist_ok=True)
    with open(split_path, 'w', encoding='utf-8') as f:
        json.dump(splits, f, indent=2)

    return splits

# Build the splits (build_splits returns the cached JSON instead when
# CONFIG["split_path"] already exists) and report the size of each partition.
splits = build_splits(CONFIG["csv_path"], CONFIG["split_path"], val_ratio=0.1, test_ratio=0.1)
print(f"Train: {len(splits['train'])}, Val: {len(splits['val'])}, Test: {len(splits['test'])}")


In [None]:
# Create datasets
def _make_split_dataset(split_name):
    """Build an ArtemisDataset limited to the entries listed in ``splits[split_name]``.

    All three partitions share the same CSV, image root, tokenizer and
    maximum caption length; only the image filter differs.
    """
    allowed_images = set(splits[split_name])
    return ArtemisDataset(
        csv_path=CONFIG["csv_path"],
        img_root=CONFIG["images_root"],
        tokenizer=tokenizer,
        max_len=CONFIG["max_len"],
        image_filter=allowed_images,
    )


train_dataset = _make_split_dataset("train")
val_dataset = _make_split_dataset("val")
test_dataset = _make_split_dataset("test")

# Report how many samples each dataset exposes.
print(f"Train dataset: {len(train_dataset)} samples")
print(f"Val dataset: {len(val_dataset)} samples")
print(f"Test dataset: {len(test_dataset)} samples")


In [None]:
# Create data loaders with multiprocessing
device = torch.device(CONFIG["device"])
# Pinned host memory is only requested when batches will be copied to a GPU.
pin_memory = device.type == "cuda"


def _make_loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader using the batching settings from CONFIG.

    Only the shuffle flag differs between the train, val and test loaders.
    """
    worker_count = CONFIG["num_workers"]
    return DataLoader(
        dataset,
        batch_size=CONFIG["batch_size"],
        shuffle=shuffle,
        num_workers=worker_count,
        collate_fn=collate_captions,
        pin_memory=pin_memory,
        # Keeping workers alive is only requested when workers exist.
        persistent_workers=worker_count > 0,
    )


train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False)

print(f"Data loaders created with num_workers={CONFIG['num_workers']}, pin_memory={pin_memory}")


## 3. Train CNN+LSTM Model


In [None]:
# Initialize CNN+LSTM model
# The vocabulary size comes from the tokenizer; every other constructor
# argument is unpacked from CONFIG["cnn_lstm"]. The model is then moved to
# the selected device.
cnn_lstm_model = ImageCaptioningCNNLSTM(
    vocab_size=len(tokenizer.word2idx),
    **CONFIG["cnn_lstm"],
).to(device)

# Sum of element counts over everything returned by model.parameters().
print(f"CNN+LSTM model parameters: {sum(p.numel() for p in cnn_lstm_model.parameters()):,}")


In [None]:
# Training functions for CNN+LSTM
def train_epoch_cnn_lstm(model, loader, optimizer, criterion, device, scaler, max_grad_norm):
    """Train ``model`` for one pass over ``loader`` and return the mean batch loss.

    Every batch is a mapping with ``"images"``, ``"captions_in"`` and
    ``"captions_out"`` entries, each moved to ``device``. The forward pass
    runs under autocast when ``scaler`` is enabled, and gradients are
    unscaled before being clipped.

    Args:
        model: Module called as ``model(images, captions_in)``.
        loader: Iterable of batches that also supports ``len()``.
        optimizer: Optimizer stepped once per batch through ``scaler``.
        criterion: Loss applied to flattened logits and flattened targets.
        device: Target device for the batch tensors.
        scaler: Gradient scaler; its enabled state also switches autocast.
        max_grad_norm: Clipping threshold passed to ``clip_grad_norm_``.

    Returns:
        The sum of the per-batch loss values divided by ``len(loader)``.
    """
    model.train()
    amp_enabled = scaler.is_enabled()
    batch_fields = ("images", "captions_in", "captions_out")
    running_loss = 0.0

    for batch in tqdm(loader, desc="Training", leave=False):
        images, captions_in, captions_out = (
            batch[field].to(device, non_blocking=True) for field in batch_fields
        )

        optimizer.zero_grad(set_to_none=True)
        with autocast(enabled=amp_enabled):
            logits = model(images, captions_in)
            # Collapse all leading dimensions so each row is one prediction.
            flat_logits = logits.view(-1, logits.size(-1))
            flat_targets = captions_out.view(-1)
            loss = criterion(flat_logits, flat_targets)

        scaler.scale(loss).backward()
        # Unscale first so the clipping threshold applies to true gradients.
        scaler.unscale_(optimizer)
        clip_grad_norm_(model.parameters(), max_grad_norm)
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()

    return running_loss / len(loader)

@torch.no_grad()
def eval_epoch_cnn_lstm(model, loader, criterion, device, use_amp):
    """Compute the mean batch loss of ``model`` over ``loader`` without gradients.

    Args:
        model: Module called as ``model(images, captions_in)``; put in eval mode.
        loader: Iterable of batches with ``"images"``, ``"captions_in"`` and
            ``"captions_out"`` entries; must support ``len()``.
        criterion: Loss applied to flattened logits and flattened targets.
        device: Target device for the batch tensors.
        use_amp: Whether the forward pass runs under autocast.

    Returns:
        The sum of the per-batch loss values divided by ``len(loader)``.
    """
    model.eval()
    batch_fields = ("images", "captions_in", "captions_out")
    running_loss = 0.0

    for batch in tqdm(loader, desc="Evaluating", leave=False):
        images, captions_in, captions_out = (
            batch[field].to(device, non_blocking=True) for field in batch_fields
        )

        with autocast(enabled=use_amp):
            logits = model(images, captions_in)
            flat_logits = logits.view(-1, logits.size(-1))
            flat_targets = captions_out.view(-1)
            loss = criterion(flat_logits, flat_targets)

        running_loss += loss.item()

    return running_loss / len(loader)


In [None]:
# Train CNN+LSTM
# Loss ignores padding positions; the learning rate is halved after the
# validation loss has failed to improve for more than `patience` epochs.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_idx)
optimizer = Adam(cnn_lstm_model.parameters(), lr=CONFIG["lr_cnn_lstm"])
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
# Mixed precision is only switched on when running on a CUDA device.
use_amp = device.type == "cuda"
scaler = GradScaler(enabled=use_amp)

cnn_lstm_history = {"train_loss": [], "val_loss": []}
best_val_loss = float('inf')
# CPU copy of the weights from the best validation epoch, restored after the
# loop so that later cells do not silently use last-epoch weights.
best_cnn_lstm_state = None
early_stop_counter = 0

for epoch in range(1, CONFIG["epochs_cnn_lstm"] + 1):
    train_loss = train_epoch_cnn_lstm(cnn_lstm_model, train_loader, optimizer, criterion, device, scaler, CONFIG["max_grad_norm"])
    val_loss = eval_epoch_cnn_lstm(cnn_lstm_model, val_loader, criterion, device, use_amp)

    scheduler.step(val_loss)
    cnn_lstm_history["train_loss"].append(train_loss)
    cnn_lstm_history["val_loss"].append(val_loss)

    logger.info(f"CNN+LSTM Epoch {epoch}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")
    print(f"Epoch {epoch}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        early_stop_counter = 0
        # Snapshot the weights; copy=True guarantees an independent tensor
        # even when the model already lives on the CPU.
        best_cnn_lstm_state = {
            name: tensor.detach().to("cpu", copy=True)
            for name, tensor in cnn_lstm_model.state_dict().items()
        }
        # Save checkpoint
        checkpoint = {
            "model_state": cnn_lstm_model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "epoch": epoch,
            "val_loss": val_loss,
            "config": CONFIG["cnn_lstm"],
        }
        torch.save(checkpoint, CONFIG["checkpoint_dir_cnn_lstm"] / f"best_epoch{epoch:03d}_val{val_loss:.3f}.pt")
        print(f"Saved checkpoint: best_epoch{epoch:03d}_val{val_loss:.3f}.pt")
    else:
        early_stop_counter += 1
        if early_stop_counter >= CONFIG["early_stop_patience"]:
            logger.info("Early stopping triggered")
            print("Early stopping triggered")
            break

# Without this the in-memory model keeps the weights of the final epoch,
# which after early stopping is `early_stop_patience` epochs past the best.
# The snapshot is None only if no epoch ever improved on +inf (e.g. NaN loss).
if best_cnn_lstm_state is not None:
    cnn_lstm_model.load_state_dict(best_cnn_lstm_state)
    logger.info(f"Restored best CNN+LSTM weights (val_loss={best_val_loss:.4f})")


## 4. Train Transformer Model


In [None]:
# Initialize Transformer model
# Gather the hyperparameters in one object: vocabulary size and padding index
# come from the tokenizer, the maximum sequence length from CONFIG["max_len"],
# and the remaining fields are unpacked from CONFIG["transformer"].
hyper = TransformerHyperParams(
    vocab_size=len(tokenizer.word2idx),
    pad_idx=tokenizer.pad_idx,
    max_seq_len=CONFIG["max_len"],
    **CONFIG["transformer"],
)

# The vision and decoder configurations are both derived from `hyper`; the
# assembled model is then moved to the selected device.
transformer_model = CaptionTransformer(
    vision_config=hyper.vision_config(),
    decoder_config=hyper.decoder_config(),
).to(device)

# Sum of element counts over everything returned by model.parameters().
print(f"Transformer model parameters: {sum(p.numel() for p in transformer_model.parameters()):,}")


In [None]:
# Training functions for Transformer
def train_epoch_transformer(model, loader, optimizer, criterion, device, scaler, max_grad_norm):
    """Run one training epoch for ``model`` and return the mean batch loss.

    Batches are mappings holding ``"images"``, ``"captions_in"`` and
    ``"captions_out"``; all three are transferred to ``device``. Autocast is
    active exactly when ``scaler`` is enabled, and gradient clipping is
    applied to unscaled gradients.

    Args:
        model: Module called as ``model(images, captions_in)``.
        loader: Iterable of batches that also supports ``len()``.
        optimizer: Optimizer stepped once per batch through ``scaler``.
        criterion: Loss applied to flattened logits and flattened targets.
        device: Target device for the batch tensors.
        scaler: Gradient scaler; its enabled state also switches autocast.
        max_grad_norm: Clipping threshold passed to ``clip_grad_norm_``.

    Returns:
        The accumulated per-batch loss divided by ``len(loader)``.
    """
    model.train()
    mixed_precision = scaler.is_enabled()
    loss_sum = 0.0

    progress = tqdm(loader, desc="Training", leave=False)
    for batch in progress:
        images = batch["images"].to(device, non_blocking=True)
        captions_in = batch["captions_in"].to(device, non_blocking=True)
        targets = batch["captions_out"].to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with autocast(enabled=mixed_precision):
            logits = model(images, captions_in)
            vocab_dim = logits.size(-1)
            # One row per predicted position, matched against flat targets.
            loss = criterion(logits.view(-1, vocab_dim), targets.view(-1))

        scaler.scale(loss).backward()
        # Clip on true gradient magnitudes, hence the unscale beforehand.
        scaler.unscale_(optimizer)
        clip_grad_norm_(model.parameters(), max_grad_norm)
        scaler.step(optimizer)
        scaler.update()

        loss_sum += loss.item()

    return loss_sum / len(loader)

@torch.no_grad()
def eval_epoch_transformer(model, loader, criterion, device, use_amp):
    """Return the mean batch loss of ``model`` over ``loader`` with gradients disabled.

    Args:
        model: Module called as ``model(images, captions_in)``; put in eval mode.
        loader: Iterable of batches with ``"images"``, ``"captions_in"`` and
            ``"captions_out"`` entries; must support ``len()``.
        criterion: Loss applied to flattened logits and flattened targets.
        device: Target device for the batch tensors.
        use_amp: Whether the forward pass runs under autocast.

    Returns:
        The accumulated per-batch loss divided by ``len(loader)``.
    """
    model.eval()
    loss_sum = 0.0

    progress = tqdm(loader, desc="Evaluating", leave=False)
    for batch in progress:
        images = batch["images"].to(device, non_blocking=True)
        captions_in = batch["captions_in"].to(device, non_blocking=True)
        targets = batch["captions_out"].to(device, non_blocking=True)

        with autocast(enabled=use_amp):
            logits = model(images, captions_in)
            vocab_dim = logits.size(-1)
            loss = criterion(logits.view(-1, vocab_dim), targets.view(-1))

        loss_sum += loss.item()

    return loss_sum / len(loader)


In [None]:
# Train Transformer
# Loss ignores padding positions; AdamW adds decoupled weight decay, and the
# learning rate is halved after the validation loss has failed to improve
# for more than `patience` epochs.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_idx)
optimizer = AdamW(transformer_model.parameters(), lr=CONFIG["lr_transformer"], weight_decay=CONFIG["weight_decay"])
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
# Mixed precision is only switched on when running on a CUDA device.
use_amp = device.type == "cuda"
scaler = GradScaler(enabled=use_amp)

transformer_history = {"train_loss": [], "val_loss": []}
best_val_loss = float('inf')
# CPU copy of the weights from the best validation epoch, restored after the
# loop so that later cells do not silently use last-epoch weights.
best_transformer_state = None
early_stop_counter = 0

for epoch in range(1, CONFIG["epochs_transformer"] + 1):
    train_loss = train_epoch_transformer(transformer_model, train_loader, optimizer, criterion, device, scaler, CONFIG["max_grad_norm"])
    val_loss = eval_epoch_transformer(transformer_model, val_loader, criterion, device, use_amp)

    scheduler.step(val_loss)
    transformer_history["train_loss"].append(train_loss)
    transformer_history["val_loss"].append(val_loss)

    logger.info(f"Transformer Epoch {epoch}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")
    print(f"Epoch {epoch}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        early_stop_counter = 0
        # Snapshot the weights; copy=True guarantees an independent tensor
        # even when the model already lives on the CPU.
        best_transformer_state = {
            name: tensor.detach().to("cpu", copy=True)
            for name, tensor in transformer_model.state_dict().items()
        }
        # Save checkpoint
        checkpoint = {
            "model_state": transformer_model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "epoch": epoch,
            "val_loss": val_loss,
            "hyperparams": hyper.__dict__,
        }
        torch.save(checkpoint, CONFIG["checkpoint_dir_transformer"] / f"best_epoch{epoch:03d}_val{val_loss:.3f}.pt")
        print(f"Saved checkpoint: best_epoch{epoch:03d}_val{val_loss:.3f}.pt")
    else:
        early_stop_counter += 1
        if early_stop_counter >= CONFIG["early_stop_patience"]:
            logger.info("Early stopping triggered")
            print("Early stopping triggered")
            break

# Without this the in-memory model keeps the weights of the final epoch,
# which after early stopping is `early_stop_patience` epochs past the best.
# The snapshot is None only if no epoch ever improved on +inf (e.g. NaN loss).
if best_transformer_state is not None:
    transformer_model.load_state_dict(best_transformer_state)
    logger.info(f"Restored best Transformer weights (val_loss={best_val_loss:.4f})")


## 5. Plot Training Curves


In [None]:
# Plot training curves
# One panel per model, train and validation loss on the same axes.
fig, axes = plt.subplots(1, 2, figsize=(16, 5))

curve_specs = (
    (axes[0], cnn_lstm_history, "CNN+LSTM Training Curves"),
    (axes[1], transformer_history, "Transformer Training Curves"),
)
for ax, history, title in curve_specs:
    # Epochs are numbered from 1 in the logs and checkpoint names, so plot
    # against 1..N rather than matplotlib's default 0-based index.
    epoch_numbers = range(1, len(history["train_loss"]) + 1)
    ax.plot(epoch_numbers, history["train_loss"], label="Train Loss", marker='o')
    ax.plot(epoch_numbers, history["val_loss"], label="Val Loss", marker='s')
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
# Save before show(): some backends clear the figure once it is displayed.
plt.savefig(CONFIG["results_dir"] / "training_curves.png", dpi=150, bbox_inches='tight')
plt.show()


## 6. Evaluate Models on Test Set


In [None]:
# Evaluate CNN+LSTM
# NOTE(review): this scores whatever weights cnn_lstm_model holds in memory
# when the cell runs.
print("Evaluating CNN+LSTM model...")
cnn_lstm_scores = evaluate_cnn_lstm(
    cnn_lstm_model,
    tokenizer,
    test_loader,
    device,
    CONFIG["max_len"],
)

# Presumably a mapping of metric name -> numeric score (it is iterated with
# .items() and formatted with :.4f) — confirm in src.evaluation.evaluate_models.
print("\nCNN+LSTM Results:")
for metric, score in cnn_lstm_scores.items():
    print(f"  {metric}: {score:.4f}")


In [None]:
# Evaluate Transformer
# NOTE(review): this scores whatever weights transformer_model holds in
# memory when the cell runs.
print("Evaluating Transformer model...")
transformer_scores = evaluate_transformer(
    transformer_model,
    tokenizer,
    test_loader,
    device,
    CONFIG["max_len"],
)

# Presumably a mapping of metric name -> numeric score (it is iterated with
# .items() and formatted with :.4f) — confirm in src.evaluation.evaluate_models.
print("\nTransformer Results:")
for metric, score in transformer_scores.items():
    print(f"  {metric}: {score:.4f}")


In [None]:
# Save evaluation results
def _json_default(value):
    """Convert values the json encoder cannot handle natively.

    CONFIG stores its locations as pathlib.Path objects, which json rejects
    with a TypeError. NumPy scalars/arrays and tensors are handled as well in
    case the evaluators return them (their return types are not visible in
    this notebook). Anything else still raises TypeError.
    """
    if isinstance(value, Path):
        return str(value)
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, (np.ndarray, torch.Tensor)):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


results = {
    "cnn_lstm": cnn_lstm_scores,
    "transformer": transformer_scores,
    "config": CONFIG,
    "training_history": {
        "cnn_lstm": cnn_lstm_history,
        "transformer": transformer_history,
    },
}

# Serialise before opening the file so an encoding error cannot leave a
# truncated evaluation_results.json behind.
results_json = json.dumps(results, indent=2, default=_json_default)
with open(CONFIG["results_dir"] / "evaluation_results.json", 'w') as f:
    f.write(results_json)

print(f"\nResults saved to {CONFIG['results_dir'] / 'evaluation_results.json'}")
