# VQA End-to-End Pipeline

This notebook covers the full pipeline for the Visual Question Answering (VQA) project:

| Step | Description |
|------|-------------|
| 0 | Environment & GPU check |
| 1 | Build vocabulary from training data |
| 2 | Explore the dataset |
| 3 | Train a model (A / B / C / D) |
| 4 | Plot training curves |
| 5 | Evaluate a single model (all metrics) |
| 6 | Compare all 4 models side-by-side |
| 7 | Single-sample inference |
| 8 | Attention visualization (Model C / D) |

**Run from repository root:** `vqa_new/`

---
## Step 0 — Environment & GPU Check

In [None]:
import os, sys, json, warnings
import torch
warnings.filterwarnings('ignore')

# ── Ensure the notebook runs from the project root ──────────────────
# All data/checkpoint paths below are relative to the repository root, so
# normalise the working directory whether the notebook was opened from the
# root or from its src/ sub-directory.
ROOT = os.path.abspath('')              # notebook cwd
if os.path.basename(ROOT) == 'src':     # notebook was run from src/ (exact dir
    ROOT = os.path.dirname(ROOT)        # name, not any path ending in "src")
os.chdir(ROOT)

# Make the project modules under src/ importable. Skip the insert when the
# cell is re-executed so sys.path does not accumulate duplicate entries.
_SRC_DIR = os.path.join(ROOT, 'src')
if _SRC_DIR not in sys.path:
    sys.path.insert(0, _SRC_DIR)

print(f"Working directory : {os.getcwd()}")
print(f"Python            : {sys.version.split()[0]}")
print(f"PyTorch           : {torch.__version__}")
print(f"CUDA available    : {torch.cuda.is_available()}")

if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        mem_gb = props.total_memory / 1e9
        print(f"GPU {i}             : {props.name}  ({mem_gb:.1f} GB)")

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\nUsing device      : {DEVICE}")

### Global Path & Hyperparameter Configuration
Edit this cell to change model type, epochs, batch size, etc.

In [None]:
# ── Paths ───────────────────────────────────────────────────────────
TRAIN_IMAGE_DIR       = 'data/raw/images/train2014'
VAL_IMAGE_DIR         = 'data/raw/images/val2014'
TRAIN_QUESTION_JSON   = 'data/raw/vqa_json/v2_OpenEnded_mscoco_train2014_questions.json'
VAL_QUESTION_JSON     = 'data/raw/vqa_json/v2_OpenEnded_mscoco_val2014_questions.json'
TRAIN_ANNOTATION_JSON = 'data/raw/vqa_json/v2_mscoco_train2014_annotations.json'
VAL_ANNOTATION_JSON   = 'data/raw/vqa_json/v2_mscoco_val2014_annotations.json'
VOCAB_Q_PATH          = 'data/processed/vocab_questions.json'
VOCAB_A_PATH          = 'data/processed/vocab_answers.json'
CHECKPOINT_DIR        = 'checkpoints'

# ── Training hyperparameters ─────────────────────────────────────────
MODEL_TYPE    = 'A'     # 'A' | 'B' | 'C' | 'D'
EPOCHS        = 10
LR            = 1e-3
BATCH_SIZE    = 128
RESUME        = None    # set e.g. 'checkpoints/model_a_resume.pth' to continue training

# ── Scheduled Sampling (reduces exposure bias) ───────────────────────
# epsilon = SS_K / (SS_K + exp(epoch / SS_K)) is the probability of feeding the
# ground-truth token (teacher forcing) instead of the model's own prediction.
# It starts at SS_K / (SS_K + 1) (≈0.83 for SS_K=5) and decays toward 0;
# a larger SS_K makes both the start value higher and the decay slower.
SCHEDULED_SAMPLING = False
SS_K               = 5.0

# ── CNN Fine-tuning (models B and D only) ────────────────────────────
# Unfreezes ResNet layer3 + layer4 with a smaller LR to fine-tune high-level
# features for VQA without causing catastrophic forgetting of ImageNet knowledge.
# Recommended: train ~5 epochs frozen first, then resume with FINETUNE_CNN=True.
FINETUNE_CNN    = False   # Set True to unfreeze layer3+layer4 of ResNet backbone
CNN_LR_FACTOR   = 0.1     # backbone LR = LR × CNN_LR_FACTOR  (e.g. 1e-4 with default 1e-3)

# ── Dataset caps (None = use full dataset) ────────────────────────────
MAX_TRAIN_SAMPLES = None
MAX_VAL_SAMPLES   = None

# ── Evaluation / comparison settings ─────────────────────────────────
# EVAL_EPOCH / COMPARE_EPOCH select checkpoints/model_<x>_epoch<N>.pth.
EVAL_EPOCH        = 10
COMPARE_EPOCH     = 10
COMPARE_MODELS    = ['A', 'B', 'C', 'D']
NUM_EVAL_SAMPLES  = None

# ── Decoding strategy ─────────────────────────────────────────────────
BEAM_WIDTH = 1    # 1 = greedy (fast); 3–5 = beam search (better quality, slower)

# ── Fail fast on an invalid model choice ──────────────────────────────
# Otherwise the typo only surfaces inside training, after both datasets
# have already been loaded.
if MODEL_TYPE not in ('A', 'B', 'C', 'D'):
    raise ValueError(f"MODEL_TYPE must be one of 'A', 'B', 'C', 'D' (got {MODEL_TYPE!r})")

print("Configuration loaded.")
print(f"  Model             : {MODEL_TYPE}")
print(f"  Epochs            : {EPOCHS}  |  LR: {LR}")
print(f"  Batch size        : {BATCH_SIZE}")
print(f"  Scheduled Sampling: {SCHEDULED_SAMPLING}  (k={SS_K})")
print(f"  CNN Fine-tuning   : {FINETUNE_CNN}  (backbone LR factor={CNN_LR_FACTOR})")
print(f"  Beam Width        : {BEAM_WIDTH}")

---
## Step 1 — Build Vocabulary

Builds:
- `data/processed/vocab_questions.json` — question vocabulary (words appearing ≥ 3 times)
- `data/processed/vocab_answers.json` — answer vocabulary (`multiple_choice_answer` values appearing ≥ 5 times)

> **Existing vocab files are reused.** The cell checks for them and only builds what is missing, but always run it: it also loads `vocab_q` / `vocab_a`, which every later step needs.

In [None]:
from collections import Counter
from vocab import Vocabulary

os.makedirs('data/processed', exist_ok=True)


def _ensure_parent_dir(path):
    """Create the parent directory of ``path`` if it has one and it is missing."""
    parent = os.path.dirname(path)
    if parent:  # a bare filename has no parent; os.makedirs('') would raise
        os.makedirs(parent, exist_ok=True)


def build_vocab_if_missing():
    """Build and save the question/answer vocabularies if their files are absent.

    The two vocabularies are handled independently: an existing file is left
    untouched and only the missing one is rebuilt from the training split.

    * Question vocab: every training question text, ``threshold=3``.
    * Answer vocab: each training annotation's ``multiple_choice_answer``,
      ``threshold=5``.

    The output directory of each vocab path is created if needed, so
    ``VOCAB_Q_PATH`` / ``VOCAB_A_PATH`` may point outside ``data/processed``.
    """
    q_exists = os.path.exists(VOCAB_Q_PATH)
    a_exists = os.path.exists(VOCAB_A_PATH)

    if q_exists and a_exists:
        print("Vocab files already exist — skipping build.")
        print(f"  {VOCAB_Q_PATH}")
        print(f"  {VOCAB_A_PATH}")
        return

    # ── Question vocabulary ──────────────────────────────────────────
    if not q_exists:
        print(f"Building question vocab from: {TRAIN_QUESTION_JSON}")
        with open(TRAIN_QUESTION_JSON, 'r') as f:
            questions = json.load(f)['questions']
        q_texts = [q['question'] for q in questions]
        q_vocab = Vocabulary()
        q_vocab.build(q_texts, threshold=3)
        _ensure_parent_dir(VOCAB_Q_PATH)
        q_vocab.save(VOCAB_Q_PATH)
        print(f"  Saved  : {VOCAB_Q_PATH}  ({len(q_vocab)} tokens)")

    # ── Answer vocabulary ────────────────────────────────────────────
    if not a_exists:
        print(f"Building answer vocab from: {TRAIN_ANNOTATION_JSON}")
        with open(TRAIN_ANNOTATION_JSON, 'r') as f:
            annotations = json.load(f)['annotations']
        a_texts = [ann['multiple_choice_answer'] for ann in annotations]
        a_vocab = Vocabulary()
        a_vocab.build(a_texts, threshold=5)
        _ensure_parent_dir(VOCAB_A_PATH)
        a_vocab.save(VOCAB_A_PATH)
        print(f"  Saved  : {VOCAB_A_PATH}  ({len(a_vocab)} tokens)")

build_vocab_if_missing()

# Load both vocabularies; every later step uses vocab_q / vocab_a.
vocab_q = Vocabulary()
vocab_q.load(VOCAB_Q_PATH)

vocab_a = Vocabulary()
vocab_a.load(VOCAB_A_PATH)

print(f"\nQuestion vocab size : {len(vocab_q)}")
print(f"Answer vocab size   : {len(vocab_a)}")

---
## Step 2 — Explore the Dataset

Load a few samples from the training set and visualize them.

In [None]:
from dataset import VQADataset, vqa_collate_fn

# Small preview dataset capped at max_samples=200 (presumably the first 200
# questions — depends on how VQADataset applies max_samples; TODO confirm).
# preview_dataset is reused by the visualisation cell below.
preview_dataset = VQADataset(
    image_dir=TRAIN_IMAGE_DIR,
    question_json_path=TRAIN_QUESTION_JSON,
    annotations_json_path=TRAIN_ANNOTATION_JSON,
    vocab_q=vocab_q,
    vocab_a=vocab_a,
    split='train2014',
    max_samples=200
)

print(f"Dataset size (preview): {len(preview_dataset)} samples")

# Each item unpacks into (image, question, answer) tensors; print their shapes
# as a quick sanity check of the preprocessing.
img_tensor, q_tensor, a_tensor = preview_dataset[0]
print(f"Image tensor shape    : {img_tensor.shape}")
print(f"Question tensor shape : {q_tensor.shape}")
print(f"Answer tensor shape   : {a_tensor.shape}")

In [None]:
import matplotlib
# Agg is non-interactive: with it plt.show() renders nothing (and the warning
# is hidden by the global warnings filter), so inside a Jupyter kernel every
# figure would silently disappear. Only force Agg when running outside
# IPython, e.g. as a plain script on a headless machine.
if 'ipykernel' not in sys.modules:
    matplotlib.use('Agg')     # headless-safe backend
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image as PILImage

def denormalize(t):
    """Convert an ImageNet-normalized (3, H, W) tensor to an (H, W, 3) array in [0, 1].

    Undoes ``Normalize(mean, std)`` channel-wise and clips to [0, 1] so the
    result can be passed straight to ``imshow``. The tensor is detached and
    moved to CPU first, so GPU tensors and tensors that require grad are
    accepted as well.
    """
    mean = np.array([0.485, 0.456, 0.406])
    std  = np.array([0.229, 0.224, 0.225])
    img  = t.detach().cpu().permute(1, 2, 0).numpy()
    return np.clip(img * std + mean, 0, 1)

def decode_tensor_str(tensor, vocab):
    """Join the vocabulary words for ``tensor``'s indices into one string.

    Indices of ``<pad>``, ``<start>`` and ``<end>`` are dropped; a special token
    missing from the vocabulary maps to -1, which never matches a real index.
    """
    skip_ids = set()
    for token in ('<pad>', '<start>', '<end>'):
        skip_ids.add(vocab.word2idx.get(token, -1))

    words = []
    for raw in tensor:
        index = int(raw)
        if index in skip_ids:
            continue
        words.append(vocab.idx2word[index])
    return ' '.join(words)

# Show up to 6 random samples (fewer if the preview set is smaller: sampling
# without replacement cannot draw more items than exist).
n_show  = min(6, len(preview_dataset))
indices = np.random.choice(len(preview_dataset), n_show, replace=False)

fig, axes = plt.subplots(2, 3, figsize=(14, 8))
for ax in axes.flat[n_show:]:
    ax.axis('off')          # blank out panels that get no sample
for ax, idx in zip(axes.flat, indices):
    img, q, a = preview_dataset[idx]
    q_str = decode_tensor_str(q, vocab_q)
    a_str = decode_tensor_str(a, vocab_a)
    ax.imshow(denormalize(img))
    ax.set_title(f'Q: {q_str}\nA: {a_str}', fontsize=8, wrap=True)
    ax.axis('off')

plt.suptitle('Training Samples', fontsize=13, fontweight='bold')
plt.tight_layout()
# The checkpoint directory is otherwise only created by the training step,
# which runs after this cell; create it here so a fresh clone can save.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
preview_path = f"{CHECKPOINT_DIR}/dataset_preview.png"
plt.savefig(preview_path, dpi=120, bbox_inches='tight')
plt.show()
print(f"Saved: {preview_path}")

---
## Step 3 — Train

Trains the model specified by `MODEL_TYPE` in the configuration cell.

**Features enabled:**
- Mixed precision (AMP) on CUDA
- `ReduceLROnPlateau` learning rate scheduler
- Per-epoch weights saved to `checkpoints/model_X_epoch{N}.pth` (Steps 5–7 load these)
- Best checkpoint saved to `checkpoints/model_X_best.pth`
- Full resume checkpoint saved to `checkpoints/model_X_resume.pth`
- Loss history saved to `checkpoints/history_model_X.json`

**To resume training:** set `RESUME = 'checkpoints/model_a_resume.pth'` in the config cell.

In [None]:
import math
import random
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler, autocast
import tqdm.notebook as tqdm_nb

from models.vqa_models import VQAModelA, VQAModelB, VQAModelC, VQAModelD, hadamard_fusion

def get_model(model_type, vocab_q_size, vocab_a_size):
    """Instantiate the VQA model class selected by ``model_type``.

    Args:
        model_type: one of 'A', 'B', 'C', 'D'.
        vocab_q_size: question vocabulary size (passed as ``vocab_size``).
        vocab_a_size: answer vocabulary size (passed as ``answer_vocab_size``).

    Raises:
        ValueError: if ``model_type`` is not one of A/B/C/D.
    """
    registry = dict(A=VQAModelA, B=VQAModelB, C=VQAModelC, D=VQAModelD)
    model_cls = registry.get(model_type)
    if model_cls is None:
        raise ValueError(f"Unknown model type: {model_type}. Choose A/B/C/D.")
    return model_cls(vocab_size=vocab_q_size, answer_vocab_size=vocab_a_size)


def ss_forward(model, model_type, imgs, questions, decoder_input, epsilon):
    """
    Scheduled Sampling forward pass (Bengio et al. 2015).
    epsilon=1.0 -> pure teacher forcing; epsilon=0.0 -> fully autoregressive.

    Re-implements the encoder/decoder pass step by step so that, at each
    step after the first, the next decoder input is either the ground-truth
    token (probability ``epsilon``) or the model's own argmax prediction.

    Args:
        model: a VQAModelA/B/C/D instance; its i_encoder, q_encoder, decoder
            and num_layers attributes are accessed directly.
        model_type: 'A' | 'B' | 'C' | 'D'; C/D take the attention path.
        imgs: image batch fed to ``model.i_encoder``.
        questions: question token batch fed to ``model.q_encoder``.
        decoder_input: teacher-forcing tokens, indexed as (batch, time).
        epsilon: probability of feeding the ground-truth token at each step.

    Returns:
        Per-step logits stacked along dim 1 — presumably (batch, time, vocab);
        TODO confirm the decoder's per-step output shape.
    """
    # NOTE(review): this must stay in sync with each model's own forward()
    # (not visible here); any change there has to be mirrored below.
    max_len = decoder_input.size(1)

    if model_type in ('C', 'D'):
        # Attention models: L2-normalise along the last dim, keep per-region
        # features for attention, and fuse their mean with the question.
        img_features = F.normalize(model.i_encoder(imgs), p=2, dim=-1)
        q_feat   = model.q_encoder(questions)
        fusion   = hadamard_fusion(img_features.mean(dim=1), q_feat)
    else:
        # Non-attention models: a single image vector normalised along dim 1.
        img_feat = F.normalize(model.i_encoder(imgs), p=2, dim=1)
        q_feat   = model.q_encoder(questions)
        fusion   = hadamard_fusion(img_feat, q_feat)

    # The fused vector initialises the hidden state of every decoder layer;
    # the cell state starts at zero.
    h = fusion.unsqueeze(0).repeat(model.num_layers, 1, 1)
    c = torch.zeros_like(h)
    hidden = (h, c)

    # The first decoder input is always the ground-truth first token.
    current_token = decoder_input[:, 0]
    logits_list   = []

    for t in range(max_len):
        tok = current_token.unsqueeze(1)
        if model_type in ('C', 'D'):
            logit, hidden, _ = model.decoder.decode_step(tok, hidden, img_features)
        else:
            emb         = model.decoder.dropout(model.decoder.embedding(tok))
            out, hidden = model.decoder.lstm(emb, hidden)
            logit       = model.decoder.fc(out.squeeze(1))
        logits_list.append(logit)
        if t < max_len - 1:
            # One coin flip per time step, applied to the whole batch (not per
            # sample). Predictions are detached so no gradient flows through
            # the sampled inputs.
            current_token = (decoder_input[:, t + 1] if random.random() < epsilon
                             else logit.detach().argmax(dim=-1))

    return torch.stack(logits_list, dim=1)


def _ss_epsilon(epoch, k):
    """Return the scheduled-sampling teacher-forcing probability for ``epoch``.

    Inverse-sigmoid decay ``k / (k + exp(epoch / k))``: k/(k+1) at epoch 0,
    decaying toward 0. ``ss_forward`` feeds the ground-truth token with this
    probability and the model's own prediction otherwise.
    """
    return k / (k + math.exp(epoch / k))


def _restore_from_checkpoint(path, model, optimizer, scheduler, scaler):
    """Load a ``*_resume.pth`` checkpoint into the given training objects.

    Returns the raw checkpoint dict so the caller can read epoch, best loss
    and history.

    The optimizer is only restored when its parameter-group layout matches
    the checkpoint. The layout changes when a frozen-CNN run is resumed with
    fine-tuning enabled (one group -> two groups). In that case the optimizer
    and the LR scheduler (whose state is per param group) start fresh instead
    of crashing in ``optimizer.load_state_dict``.
    """
    ckpt = torch.load(path, map_location=lambda s, l: s)
    model.load_state_dict(ckpt['model_state_dict'])
    try:
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    except ValueError as err:
        print(f"  [WARN] Optimizer state not restored ({err}); "
              "optimizer and LR scheduler start fresh.")
    else:
        scheduler.load_state_dict(ckpt['scheduler_state_dict'])
    # A checkpoint written with AMP disabled (e.g. on CPU) has an empty scaler
    # state; loading that into an enabled GradScaler raises, so skip it.
    if ckpt.get('scaler_state_dict'):
        scaler.load_state_dict(ckpt['scaler_state_dict'])
    return ckpt


def train_model(model_type=MODEL_TYPE, epochs=EPOCHS, lr=LR,
                batch_size=BATCH_SIZE, resume=RESUME,
                scheduled_sampling=SCHEDULED_SAMPLING, ss_k=SS_K,
                finetune_cnn=FINETUNE_CNN, cnn_lr_factor=CNN_LR_FACTOR):
    """Train VQA model ``model_type`` and return its loss history.

    Trains with Adam + ReduceLROnPlateau (AMP on CUDA, grad-norm clipping at
    5.0). After every epoch it writes, under ``CHECKPOINT_DIR``:

    * ``model_{x}_epoch{N}.pth`` — model weights for that epoch,
    * ``model_{x}_resume.pth``   — full training state for ``resume=``,
    * ``model_{x}_best.pth``     — weights, whenever val loss improves,
    * ``history_model_{x}.json`` — train/val loss per epoch.

    Args:
        model_type: 'A' | 'B' | 'C' | 'D'.
        epochs: number of epochs to run, counted from the resumed epoch.
        lr: base learning rate.
        batch_size: DataLoader batch size.
        resume: path to a ``*_resume.pth`` checkpoint, or None.
        scheduled_sampling: train with ``ss_forward`` and a decaying
            teacher-forcing probability.
        ss_k: decay constant for scheduled sampling.
        finetune_cnn: unfreeze the ResNet top layers (models B/D only).
        cnn_lr_factor: backbone LR multiplier when fine-tuning.

    Returns:
        dict with 'train_loss' and 'val_loss' lists, including any history
        restored from the resume checkpoint.
    """
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    train_dataset = VQADataset(
        image_dir=TRAIN_IMAGE_DIR, question_json_path=TRAIN_QUESTION_JSON,
        annotations_json_path=TRAIN_ANNOTATION_JSON,
        vocab_q=vocab_q, vocab_a=vocab_a,
        split='train2014', max_samples=MAX_TRAIN_SAMPLES
    )
    val_dataset = VQADataset(
        image_dir=VAL_IMAGE_DIR, question_json_path=VAL_QUESTION_JSON,
        annotations_json_path=VAL_ANNOTATION_JSON,
        vocab_q=vocab_q, vocab_a=vocab_a,
        split='val2014', max_samples=MAX_VAL_SAMPLES
    )
    print(f"Train: {len(train_dataset):,} | Val: {len(val_dataset):,}")

    pin = torch.cuda.is_available()
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              collate_fn=vqa_collate_fn, num_workers=4, pin_memory=pin)
    val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False,
                              collate_fn=vqa_collate_fn, num_workers=4, pin_memory=pin)

    model     = get_model(model_type, len(vocab_q), len(vocab_a)).to(DEVICE)
    # NOTE(review): ignore_index=0 assumes <pad> has index 0 in vocab_a — confirm.
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    tag       = model_type.lower()   # used in every checkpoint file name

    # ── Differential LR when fine-tuning pretrained CNN backbone ─────────────
    # Models B/D: unfreeze ResNet layer3+layer4 at low LR to avoid forgetting.
    # Models A/C: scratch CNN — finetune_cnn has no effect.
    if finetune_cnn and model_type in ('B', 'D'):
        model.i_encoder.unfreeze_top_layers()
        backbone_ids    = {id(p) for p in model.i_encoder.backbone_params()}
        backbone_params = [p for p in model.parameters()
                           if id(p) in backbone_ids and p.requires_grad]
        other_params    = [p for p in model.parameters()
                           if id(p) not in backbone_ids and p.requires_grad]
        optimizer = optim.Adam([
            {'params': other_params,    'lr': lr},
            {'params': backbone_params, 'lr': lr * cnn_lr_factor},
        ])
        print(f"CNN fine-tuning  : ON | backbone LR={lr*cnn_lr_factor:.2e}  other LR={lr:.2e}")
        print(f"  backbone trainable params: {sum(p.numel() for p in backbone_params):,}")
    else:
        optimizer = optim.Adam(model.parameters(), lr=lr)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2
    )
    use_amp = torch.cuda.is_available()
    scaler  = GradScaler(enabled=use_amp)

    history       = {'train_loss': [], 'val_loss': []}
    history_path  = f"{CHECKPOINT_DIR}/history_model_{tag}.json"
    best_val_loss = float('inf')
    start_epoch   = 0

    if resume:
        if os.path.exists(resume):
            print(f"Resuming from: {resume}")
            ckpt = _restore_from_checkpoint(resume, model, optimizer, scheduler, scaler)
            start_epoch   = ckpt['epoch']
            best_val_loss = ckpt['best_val_loss']
            history       = ckpt.get('history', history)
            print(f"  Resumed at epoch {start_epoch} | best_val_loss: {best_val_loss:.4f}")
        else:
            print(f"[WARN] Resume checkpoint not found: {resume} — training from scratch.")

    print(f"\nModel: {model_type} | Device: {DEVICE}")
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable params: {total_params:,}")
    if scheduled_sampling:
        eps0 = _ss_epsilon(start_epoch, ss_k)
        epsN = _ss_epsilon(start_epoch + epochs - 1, ss_k)
        print(f"Scheduled Sampling: ON | k={ss_k} | epsilon {eps0:.2f} -> {epsN:.2f}")

    for epoch in tqdm_nb.trange(start_epoch, start_epoch + epochs, desc=f'Model {model_type}'):
        # Epsilon depends only on the epoch, so compute it once per epoch.
        epsilon = _ss_epsilon(epoch, ss_k) if scheduled_sampling else 1.0

        model.train()
        total_loss = 0
        for imgs, questions, answer in train_loader:
            imgs, questions, answer = imgs.to(DEVICE), questions.to(DEVICE), answer.to(DEVICE)
            # Shift by one: the decoder sees tokens [0..T-2] and predicts [1..T-1].
            dec_in  = answer[:, :-1]
            dec_tgt = answer[:, 1:]
            optimizer.zero_grad()
            with autocast(enabled=use_amp):
                if scheduled_sampling:
                    logits = ss_forward(model, model_type, imgs, questions, dec_in, epsilon)
                else:
                    logits = model(imgs, questions, dec_in)
                loss = criterion(logits.view(-1, logits.size(-1)), dec_tgt.contiguous().view(-1))
            scaler.scale(loss).backward()
            # Unscale before clipping so the 5.0 threshold applies to real gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            scaler.step(optimizer)
            scaler.update()
            total_loss += loss.item()
        avg_train = total_loss / len(train_loader)

        model.eval()
        val_loss = 0
        with torch.no_grad():
            for imgs, questions, answer in val_loader:
                imgs, questions, answer = imgs.to(DEVICE), questions.to(DEVICE), answer.to(DEVICE)
                with autocast(enabled=use_amp):
                    logits = model(imgs, questions, answer[:, :-1])
                    loss   = criterion(logits.view(-1, logits.size(-1)),
                                       answer[:, 1:].contiguous().view(-1))
                val_loss += loss.item()
        avg_val    = val_loss / len(val_loader)
        current_lr = optimizer.param_groups[0]['lr']

        log = (f"Epoch {epoch+1}/{start_epoch+epochs} | Train: {avg_train:.4f} "
               f"| Val: {avg_val:.4f} | LR: {current_lr:.2e}")
        if scheduled_sampling:
            log += f" | SS ε: {epsilon:.3f}"
        print(log)

        scheduler.step(avg_val)
        history['train_loss'].append(avg_train)
        history['val_loss'].append(avg_val)
        with open(history_path, 'w') as f:
            json.dump(history, f, indent=2)

        # Update the best loss BEFORE writing the resume checkpoint. Otherwise
        # the resume file stores a stale best, and after resuming a worse
        # epoch could overwrite model_<x>_best.pth.
        is_best = avg_val < best_val_loss
        if is_best:
            best_val_loss = avg_val

        torch.save(model.state_dict(),
                   f"{CHECKPOINT_DIR}/model_{tag}_epoch{epoch+1}.pth")
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict':     model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'scaler_state_dict':    scaler.state_dict(),
            'best_val_loss':        best_val_loss,
            'history':              history,
        }, f"{CHECKPOINT_DIR}/model_{tag}_resume.pth")

        if is_best:
            torch.save(model.state_dict(),
                       f"{CHECKPOINT_DIR}/model_{tag}_best.pth")
            print(f"  -> New best val loss: {best_val_loss:.4f}. Saved best checkpoint.")

    print(f"\nTraining complete. Best val loss: {best_val_loss:.4f}")
    return history


history = train_model()

---
## Step 4 — Plot Training Curves

In [None]:
MODEL_COLORS = {'A': '#1f77b4', 'B': '#ff7f0e', 'C': '#2ca02c', 'D': '#d62728'}
MODEL_LABELS = {
    'A': 'A — Scratch CNN, No Attn',
    'B': 'B — ResNet101, No Attn',
    'C': 'C — Scratch CNN, Attention',
    'D': 'D — ResNet101, Attention',
}

def plot_all_curves(model_types=None, save_path='checkpoints/training_curves.png'):
    """Plot train (left) and validation (right) loss curves for several models.

    For each model letter, reads ``{CHECKPOINT_DIR}/history_model_{x}.json``
    (written by training); models without a history file are skipped. The
    figure is saved to ``save_path`` and shown.

    Args:
        model_types: iterable of model letters; defaults to A/B/C/D.
        save_path: output image path. Missing parent directories are created;
            a bare filename is saved in the current directory.
    """
    if model_types is None:
        model_types = ['A', 'B', 'C', 'D']

    fig, (ax_train, ax_val) = plt.subplots(1, 2, figsize=(14, 5))
    plotted = False

    for mt in model_types:
        path = f"{CHECKPOINT_DIR}/history_model_{mt.lower()}.json"
        if not os.path.exists(path):
            print(f"  [SKIP] {path} not found")
            continue
        with open(path) as f:
            h = json.load(f)
        eps   = range(1, len(h['train_loss']) + 1)
        # Unknown letters fall back to matplotlib's colour cycle / a generic label.
        color = MODEL_COLORS.get(mt)
        label = MODEL_LABELS.get(mt, f'Model {mt}')
        ax_train.plot(eps, h['train_loss'], 'o-', ms=3, color=color, label=label)
        ax_val.plot(eps,   h['val_loss'],   's--', ms=3, color=color, label=label)
        plotted = True

    if not plotted:
        # Close the empty figure so Jupyter does not render blank axes.
        plt.close(fig)
        print("No history files found. Train at least one model first.")
        return

    for ax, title in [(ax_train, 'Training Loss'), (ax_val, 'Validation Loss')]:
        ax.set_title(title, fontsize=12)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.legend(fontsize=9)
        ax.grid(True, alpha=0.3)

    fig.suptitle('VQA Training Curves', fontsize=14, fontweight='bold')
    plt.tight_layout()
    out_dir = os.path.dirname(save_path)
    if out_dir:  # os.makedirs('') raises for a bare filename
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"Saved: {save_path}")

# Plot every model that has a history file (missing ones are skipped).
plot_all_curves()

---
## Step 5 — Evaluate a Single Model

Computes all metrics on the validation split (the full split unless `NUM_EVAL_SAMPLES` caps it):
- **VQA Accuracy** — `min(matching_annotations / 3, 1.0)` per sample
- **Exact Match**
- **BLEU-1 / 2 / 3 / 4**
- **METEOR**

In [None]:
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk.translate.meteor_score import meteor_score
import torch.nn.functional as F

# WordNet data used by nltk's meteor_score; quiet=True avoids re-download chatter.
for _corpus in ('wordnet', 'omw-1.4'):
    nltk.download(_corpus, quiet=True)

from inference import (
    get_model as inf_get_model,
    batch_greedy_decode,
    batch_greedy_decode_with_attention,
    batch_beam_search_decode,
    batch_beam_search_decode_with_attention,
)

def decode_tensor_str(a_tensor, vocab):
    """Return the space-joined words of ``a_tensor``, dropping <pad>, <start> and <end>.

    Special tokens absent from the vocabulary map to -1 and so never match.
    """
    ignored = frozenset(vocab.word2idx.get(tok, -1) for tok in ('<pad>', '<start>', '<end>'))
    token_ids = map(int, a_tensor)
    return ' '.join(vocab.idx2word[i] for i in token_ids if i not in ignored)


def run_evaluate(model_type=MODEL_TYPE, epoch=EVAL_EPOCH,
                 num_samples=NUM_EVAL_SAMPLES, beam_width=BEAM_WIDTH):
    """Evaluate one saved checkpoint on the validation split and print a report.

    Loads ``{CHECKPOINT_DIR}/model_{x}_epoch{epoch}.pth`` and decodes an answer
    for each of up to ``num_samples`` validation questions. Decoding is greedy
    when ``beam_width`` is 1 and beam search otherwise; models C/D use the
    attention-aware decoders.

    Metrics, averaged over samples (all on stripped, lower-cased strings):
      * VQA accuracy: min(#annotator answers equal to the prediction / 3, 1), in %.
      * Exact match: prediction == decoded ground-truth answer, in %.
      * BLEU-1..4 and METEOR against the decoded ground-truth answer.

    Returns:
        dict of metrics, or None when the checkpoint is missing or the
        validation set is empty.

    Raises:
        RuntimeError: if there are fewer question ids than predictions, so
            predictions cannot be matched to their annotator answers.
    """
    checkpoint = f"{CHECKPOINT_DIR}/model_{model_type.lower()}_epoch{epoch}.pth"
    if not os.path.exists(checkpoint):
        print(f"Checkpoint not found: {checkpoint}")
        return None

    val_dataset = VQADataset(
        image_dir=VAL_IMAGE_DIR,
        question_json_path=VAL_QUESTION_JSON,
        annotations_json_path=VAL_ANNOTATION_JSON,
        vocab_q=vocab_q, vocab_a=vocab_a,
        split='val2014', max_samples=num_samples
    )
    if len(val_dataset) == 0:
        print("Validation set is empty — nothing to evaluate.")
        return None

    # Every annotator answer per question, used for the VQA accuracy metric.
    with open(VAL_ANNOTATION_JSON) as f:
        raw_anns = json.load(f)['annotations']
    qid_to_all = {
        ann['question_id']: [a['answer'].lower().strip() for a in ann['answers']]
        for ann in raw_anns
    }
    # NOTE(review): assumes val_dataset.questions is ordered like the dataset's
    # items (the loader below uses shuffle=False) — confirm in VQADataset.
    question_ids = [q['question_id'] for q in val_dataset.questions]

    model = inf_get_model(model_type, len(vocab_q), len(vocab_a))
    model.load_state_dict(torch.load(checkpoint, map_location=lambda s, l: s))
    model.to(DEVICE).eval()

    use_attention = model_type in ('C', 'D')
    if beam_width > 1:
        decode_fn     = batch_beam_search_decode_with_attention if use_attention \
                        else batch_beam_search_decode
        decode_kwargs = dict(beam_width=beam_width)
        decode_label  = f'beam (width={beam_width})'
    else:
        decode_fn     = batch_greedy_decode_with_attention if use_attention \
                        else batch_greedy_decode
        decode_kwargs = {}
        decode_label  = 'greedy'

    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False,
                            collate_fn=vqa_collate_fn, num_workers=2)

    smoothie  = SmoothingFunction().method1
    all_preds = []
    all_gts   = []

    print(f"Evaluating Model {model_type} | {len(val_dataset):,} samples | decode: {decode_label}")
    with torch.no_grad():
        for imgs, qs, ans in tqdm_nb.tqdm(val_loader, desc='Eval'):
            preds = decode_fn(model, imgs, qs, vocab_a, device=DEVICE, **decode_kwargs)
            all_preds.extend(preds)
            for a_t in ans:
                all_gts.append(decode_tensor_str(a_t, vocab_a))

    n = len(all_preds)
    if len(question_ids) < n:
        raise RuntimeError(
            f"Got {n} predictions but only {len(question_ids)} question ids; "
            "cannot align predictions with annotator answers.")

    em = vqa_acc = b1 = b2 = b3 = b4 = met = 0.0

    for pred, gt, qid in zip(all_preds, all_gts, question_ids):
        p = pred.strip().lower()
        g = gt.strip().lower()
        if p == g:
            em += 1
        answers = qid_to_all.get(qid, [g])
        vqa_acc += min(sum(1 for a in answers if a == p) / 3.0, 1.0)
        # BLEU/METEOR use the same normalised strings as EM and VQA accuracy.
        gw = g.split() or ['<unk>']
        pw = p.split() or ['<unk>']
        b1  += sentence_bleu([gw], pw, weights=(1,0,0,0),             smoothing_function=smoothie)
        b2  += sentence_bleu([gw], pw, weights=(0.5,0.5,0,0),         smoothing_function=smoothie)
        b3  += sentence_bleu([gw], pw, weights=(1/3,1/3,1/3,0),       smoothing_function=smoothie)
        b4  += sentence_bleu([gw], pw, weights=(0.25,0.25,0.25,0.25), smoothing_function=smoothie)
        met += meteor_score([gw], pw)

    results = {
        'model_type':   model_type,
        'checkpoint':   checkpoint,
        'n':            n,
        'decode':       decode_label,
        'vqa_accuracy': vqa_acc / n * 100,
        'exact_match':  em / n * 100,
        'bleu1':        b1 / n,
        'bleu2':        b2 / n,
        'bleu3':        b3 / n,
        'bleu4':        b4 / n,
        'meteor':       met / n,
    }

    print(f"\n{'='*52}")
    print(f"  Model        : {results['model_type']}")
    print(f"  Checkpoint   : {results['checkpoint']}")
    print(f"  Decode       : {results['decode']}")
    print(f"  Samples      : {results['n']:,}")
    print(f"  {'-'*48}")
    print(f"  VQA Accuracy : {results['vqa_accuracy']:>7.2f}%")
    print(f"  Exact Match  : {results['exact_match']:>7.2f}%")
    print(f"  BLEU-1       : {results['bleu1']:>8.4f}")
    print(f"  BLEU-2       : {results['bleu2']:>8.4f}")
    print(f"  BLEU-3       : {results['bleu3']:>8.4f}")
    print(f"  BLEU-4       : {results['bleu4']:>8.4f}")
    print(f"  METEOR       : {results['meteor']:>8.4f}")
    print(f"{'='*52}")
    return results


eval_results = run_evaluate()

---
## Step 6 — Compare All 4 Models

Evaluates each model (A / B / C / D) and prints a side-by-side comparison table.
Models without a checkpoint at the specified epoch are skipped automatically.

In [None]:
def compare_models(model_types=COMPARE_MODELS, epoch=COMPARE_EPOCH,
                   num_samples=NUM_EVAL_SAMPLES, beam_width=BEAM_WIDTH):
    """Evaluate several model checkpoints on the same validation subset.

    For each letter in ``model_types``, loads
    ``{CHECKPOINT_DIR}/model_{x}_epoch{epoch}.pth`` (skipped if missing), decodes
    every validation sample, and scores it with the same metrics as
    ``run_evaluate``: VQA accuracy, exact match, BLEU-1..4 and METEOR, all on
    stripped, lower-cased strings. A side-by-side table is printed.

    Returns:
        dict mapping model letter -> metrics dict, or None for models that
        were skipped (missing checkpoint or no samples).

    Raises:
        RuntimeError: if there are fewer question ids than predictions, so
            predictions cannot be matched to their annotator answers.
    """
    val_dataset = VQADataset(
        image_dir=VAL_IMAGE_DIR,
        question_json_path=VAL_QUESTION_JSON,
        annotations_json_path=VAL_ANNOTATION_JSON,
        vocab_q=vocab_q, vocab_a=vocab_a,
        split='val2014', max_samples=num_samples
    )

    # Every annotator answer per question, used for the VQA accuracy metric.
    with open(VAL_ANNOTATION_JSON) as f:
        raw_anns = json.load(f)['annotations']
    qid_to_all = {
        ann['question_id']: [a['answer'].lower().strip() for a in ann['answers']]
        for ann in raw_anns
    }
    # NOTE(review): assumes val_dataset.questions is ordered like the dataset's
    # items (the loader below uses shuffle=False) — confirm in VQADataset.
    question_ids = [q['question_id'] for q in val_dataset.questions]

    decode_label = f'beam (width={beam_width})' if beam_width > 1 else 'greedy'
    print(f"Comparing: {model_types} | epoch={epoch} | samples={len(val_dataset):,} | decode={decode_label}")

    table_rows = {}
    smoothie   = SmoothingFunction().method1
    # The loader is the same for every model, so build it once.
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False,
                            collate_fn=vqa_collate_fn, num_workers=2)

    for mt in model_types:
        ckpt = f"{CHECKPOINT_DIR}/model_{mt.lower()}_epoch{epoch}.pth"
        if not os.path.exists(ckpt):
            print(f"  [SKIP] {ckpt} not found.")
            table_rows[mt] = None
            continue

        model = inf_get_model(mt, len(vocab_q), len(vocab_a))
        model.load_state_dict(torch.load(ckpt, map_location=lambda s, l: s))
        model.to(DEVICE).eval()

        use_attention = mt in ('C', 'D')
        if beam_width > 1:
            decode_fn     = batch_beam_search_decode_with_attention if use_attention \
                            else batch_beam_search_decode
            decode_kwargs = dict(beam_width=beam_width)
        else:
            decode_fn     = batch_greedy_decode_with_attention if use_attention \
                            else batch_greedy_decode
            decode_kwargs = {}

        all_preds, all_gts = [], []
        with torch.no_grad():
            for imgs, qs, ans in tqdm_nb.tqdm(val_loader, desc=f'Model {mt}', leave=False):
                preds = decode_fn(model, imgs, qs, vocab_a, device=DEVICE, **decode_kwargs)
                all_preds.extend(preds)
                for a_t in ans:
                    all_gts.append(decode_tensor_str(a_t, vocab_a))

        # Drop this model before the next one is built so that (barring other
        # references) the two are not held in device memory at the same time.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        n = len(all_preds)
        if n == 0:
            print(f"  [SKIP] Model {mt}: no validation samples.")
            table_rows[mt] = None
            continue
        if len(question_ids) < n:
            raise RuntimeError(
                f"Got {n} predictions but only {len(question_ids)} question ids; "
                "cannot align predictions with annotator answers.")

        em = vqa_acc = b1 = b2 = b3 = b4 = met = 0.0

        for pred, gt, qid in zip(all_preds, all_gts, question_ids):
            p = pred.strip().lower()
            g = gt.strip().lower()
            if p == g:
                em += 1
            answers = qid_to_all.get(qid, [g])
            vqa_acc += min(sum(1 for a in answers if a == p) / 3.0, 1.0)
            # BLEU/METEOR use the same normalised strings as EM and VQA accuracy.
            gw = g.split() or ['<unk>']
            pw = p.split() or ['<unk>']
            b1  += sentence_bleu([gw], pw, weights=(1,0,0,0),             smoothing_function=smoothie)
            b2  += sentence_bleu([gw], pw, weights=(0.5,0.5,0,0),         smoothing_function=smoothie)
            b3  += sentence_bleu([gw], pw, weights=(1/3,1/3,1/3,0),       smoothing_function=smoothie)
            b4  += sentence_bleu([gw], pw, weights=(0.25,0.25,0.25,0.25), smoothing_function=smoothie)
            met += meteor_score([gw], pw)

        table_rows[mt] = {
            'vqa_accuracy': vqa_acc / n * 100,
            'exact_match':  em / n * 100,
            'bleu1': b1/n, 'bleu2': b2/n, 'bleu3': b3/n, 'bleu4': b4/n,
            'meteor': met/n, 'n': n
        }

    hdr = f"{'Model':<7} {'VQA Acc':>9} {'Exact':>8} {'BLEU-1':>8} {'BLEU-2':>8} {'BLEU-3':>8} {'BLEU-4':>8} {'METEOR':>8}"
    print()
    print(hdr)
    print('-' * len(hdr))
    for mt in sorted(table_rows):
        r = table_rows[mt]
        if r is None:
            print(f"{mt:<7} {'N/A':>9} {'N/A':>8} {'N/A':>8} {'N/A':>8} {'N/A':>8} {'N/A':>8} {'N/A':>8}")
        else:
            print(f"{mt:<7}"
                  f" {r['vqa_accuracy']:>8.2f}%"
                  f" {r['exact_match']:>7.2f}%"
                  f" {r['bleu1']:>8.4f}"
                  f" {r['bleu2']:>8.4f}"
                  f" {r['bleu3']:>8.4f}"
                  f" {r['bleu4']:>8.4f}"
                  f" {r['meteor']:>8.4f}")
    print()
    return table_rows


compare_results = compare_models()

In [None]:
# Bar chart: VQA Accuracy, BLEU-1 and METEOR across all evaluated models.
# Models skipped by compare_models (value None) are left out.
valid = {mt: r for mt, r in compare_results.items() if r is not None}

if valid:
    labels   = sorted(valid.keys())
    vqa_vals = [valid[m]['vqa_accuracy'] for m in labels]
    # BLEU-1 and METEOR are scaled ×100 so they share one y-axis with VQA
    # accuracy, which compare_models already reports as a percentage.
    b1_vals  = [valid[m]['bleu1'] * 100 for m in labels]  # scale to %
    met_vals = [valid[m]['meteor'] * 100 for m in labels]

    # Three bars per model, each of width w, centred on the model's x position.
    x  = np.arange(len(labels))
    w  = 0.26
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.bar(x - w, vqa_vals, w, label='VQA Acc (%)',   color='#1f77b4')
    ax.bar(x,     b1_vals,  w, label='BLEU-1 ×100',   color='#ff7f0e')
    ax.bar(x + w, met_vals, w, label='METEOR ×100',   color='#2ca02c')

    ax.set_xticks(x)
    ax.set_xticklabels([f'Model {m}' for m in labels], fontsize=11)
    ax.set_ylabel('Score')
    ax.set_title('Model Comparison — VQA Acc / BLEU-1 / METEOR', fontsize=13)
    ax.legend()
    ax.grid(axis='y', alpha=0.3)

    plt.tight_layout()
    plt.savefig('checkpoints/comparison_bar.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("Saved: checkpoints/comparison_bar.png")
else:
    print("No valid results to plot.")

---
## Step 7 — Single-Sample Inference

Run all 4 models on the **same image + question** and compare their predicted answers.

In [None]:
from torchvision import transforms as T
from inference import greedy_decode, greedy_decode_with_attention

# Choose a sample from the val set
SAMPLE_IDX = 0   # change to any index in the val question JSON

# 224x224 resize + ImageNet mean/std normalisation.
# NOTE(review): must match the transform used at training time — confirm.
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# The VQA JSON files are UTF-8; pass the encoding explicitly so loading does not
# depend on the platform's locale default (e.g. cp1252 on Windows).
with open(VAL_QUESTION_JSON, encoding='utf-8') as f:
    val_questions = json.load(f)['questions']
with open(VAL_ANNOTATION_JSON, encoding='utf-8') as f:
    val_annotations = json.load(f)['annotations']

# question_id -> reference answer ('multiple_choice_answer'); also used by Step 8.
qid2gt = {
    ann['question_id']: ann['multiple_choice_answer']
    for ann in val_annotations
}

sample    = val_questions[SAMPLE_IDX]
q_text    = sample['question']
img_id    = sample['image_id']
qid       = sample['question_id']
gt_answer = qid2gt.get(qid, '?')
img_path  = os.path.join(VAL_IMAGE_DIR, f'COCO_val2014_{img_id:012d}.jpg')

original_img = PILImage.open(img_path).convert('RGB')
img_tensor   = transform(original_img)
q_tensor     = torch.tensor(vocab_q.numericalize(q_text), dtype=torch.long)

print(f"Question  : {q_text}")
print(f"GT Answer : {gt_answer}")
print(f"Image     : {img_path}")

In [None]:
# Run every model that has an EVAL_EPOCH checkpoint on the selected sample and
# collect its predicted answer (or a placeholder when the checkpoint is missing).
predictions = {}

for mt in ['A', 'B', 'C', 'D']:
    ckpt = f"{CHECKPOINT_DIR}/model_{mt.lower()}_epoch{EVAL_EPOCH}.pth"
    if not os.path.exists(ckpt):
        predictions[mt] = '(checkpoint missing)'
        continue
    model = inf_get_model(mt, len(vocab_q), len(vocab_a))
    # Identity map_location loads the weights onto CPU; .to(DEVICE) moves them after.
    model.load_state_dict(torch.load(ckpt, map_location=lambda s, l: s))
    model.to(DEVICE).eval()

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        # C and D are the attention models and use the attention-aware decoder.
        if mt in ('C', 'D'):
            ans = greedy_decode_with_attention(model, img_tensor, q_tensor, vocab_a, device=DEVICE)
        else:
            ans = greedy_decode(model, img_tensor, q_tensor, vocab_a, device=DEVICE)

    predictions[mt] = ans

# Show image + results
fig, axes = plt.subplots(1, 2, figsize=(12, 5), gridspec_kw={'width_ratios': [1.2, 1]})

axes[0].imshow(original_img)
axes[0].set_title(f'Q: {q_text}', fontsize=10, wrap=True)
axes[0].axis('off')

# One row per model (sorted A..D), plus a final ground-truth row.
table_text  = [[mt, pred] for mt, pred in sorted(predictions.items())]
table_text += [['GT', gt_answer]]
col_labels  = ['Model', 'Predicted Answer']
tbl = axes[1].table(cellText=table_text, colLabels=col_labels,
                    loc='center', cellLoc='left')
tbl.scale(1, 2)
tbl.auto_set_font_size(False)
tbl.set_fontsize(11)
axes[1].axis('off')
axes[1].set_title('Predictions vs Ground Truth', fontsize=11)

plt.tight_layout()
# Save under CHECKPOINT_DIR like the other plots, creating it if needed,
# instead of a hard-coded relative 'checkpoints/' path.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
fig.savefig(f"{CHECKPOINT_DIR}/inference_sample.png", dpi=150, bbox_inches='tight')
plt.show()

# Case-insensitive exact match against the reference answer.
print("\nResults:")
for mt, pred in sorted(predictions.items()):
    match = '✓' if pred.strip().lower() == gt_answer.strip().lower() else ' '
    print(f"  [{match}] Model {mt}: {pred}")
print(f"       GT:       {gt_answer}")

---
## Step 8 — Attention Visualization (Model C / D)

For each generated answer token, show which image regions the decoder attended to.
The heatmap uses the `alpha` weights from Bahdanau attention (`alpha` shape: `(49,)` → reshaped to `7×7`).

In [None]:
from models.vqa_models import hadamard_fusion

ATTN_MODEL_TYPE = 'C'   # change to 'D' for ResNet + attention
ATTN_SAMPLE_IDX = SAMPLE_IDX


def visualize_attention_notebook(model_type, sample_idx, epoch=None,
                                 save_path=None, max_tokens=12):
    """Plot per-token attention heatmaps for one validation sample.

    Loads ``{CHECKPOINT_DIR}/model_<type>_epoch<epoch>.pth``, greedily decodes the
    answer to ``val_questions[sample_idx]`` while recording the decoder's
    attention weights at every step, then draws the resized image followed by
    one heatmap overlay per generated token and saves the figure.

    Args:
        model_type: 'C' or 'D' (the attention models); other types are rejected.
        sample_idx: index into ``val_questions``.
        epoch: checkpoint epoch. ``None`` (default) uses the value of
            ``EVAL_EPOCH`` at call time, so editing the config cell takes effect
            without re-running this cell.
        save_path: output PNG path; defaults to
            ``{CHECKPOINT_DIR}/attn_model_<type lowercased>.png``.
        max_tokens: maximum number of decoding steps.

    Returns:
        None. Prints a message and returns early if the model type has no
        attention, the checkpoint is missing, or no token was decoded.
    """
    # decode_step with attention weights is only used for Models C/D.
    if model_type.upper() not in ('C', 'D'):
        print(f"Attention visualization needs Model C or D, got: {model_type}")
        return
    if epoch is None:
        epoch = EVAL_EPOCH

    ckpt = f"{CHECKPOINT_DIR}/model_{model_type.lower()}_epoch{epoch}.pth"
    if not os.path.exists(ckpt):
        print(f"Checkpoint not found: {ckpt}")
        return

    model = inf_get_model(model_type, len(vocab_q), len(vocab_a))
    model.load_state_dict(torch.load(ckpt, map_location=lambda s, l: s))
    model.to(DEVICE).eval()

    # Load sample
    sample   = val_questions[sample_idx]
    q_text   = sample['question']
    img_id   = sample['image_id']
    img_path = os.path.join(VAL_IMAGE_DIR, f'COCO_val2014_{img_id:012d}.jpg')

    raw_img  = PILImage.open(img_path).convert('RGB')
    img_t    = transform(raw_img)
    q_t      = torch.tensor(vocab_q.numericalize(q_text), dtype=torch.long)

    # ── Decode step-by-step, collect alphas ──────────────────────────
    # The decoder's initial (h, c) state is rebuilt by hand: L2-normalised region
    # features, question encoding, Hadamard fusion of the mean region feature
    # with the question -> h repeated per layer, c = 0.
    # NOTE(review): presumably mirrors the model's forward pass — keep in sync
    # with models.vqa_models.
    tokens, alphas = [], []
    with torch.no_grad():
        img  = img_t.unsqueeze(0).to(DEVICE)
        q    = q_t.unsqueeze(0).to(DEVICE)

        img_features = model.i_encoder(img)
        img_features = F.normalize(img_features, p=2, dim=-1)  # (1, 49, hidden)
        q_feat       = model.q_encoder(q)                      # (1, hidden)
        img_mean     = img_features.mean(dim=1)                # (1, hidden)
        fusion       = hadamard_fusion(img_mean, q_feat)

        h = fusion.unsqueeze(0).repeat(model.num_layers, 1, 1)
        c = torch.zeros_like(h)
        hidden = (h, c)

        start_idx = vocab_a.word2idx['<start>']
        end_idx   = vocab_a.word2idx['<end>']
        tok = torch.tensor([[start_idx]], dtype=torch.long, device=DEVICE)

        for _ in range(max_tokens):
            logit, hidden, alpha = model.decoder.decode_step(tok, hidden, img_features)
            pred = logit.argmax(dim=-1).item()
            if pred == end_idx:
                break
            word = vocab_a.idx2word.get(pred, '<unk>')
            tokens.append(word)
            alphas.append(alpha.squeeze(0).cpu().numpy())  # (49,)
            tok = torch.tensor([[pred]], dtype=torch.long, device=DEVICE)

    if not tokens:
        print("No tokens decoded.")
        return

    # ── Plot ─────────────────────────────────────────────────────────
    n_cols = len(tokens) + 1
    fig, axes = plt.subplots(1, n_cols, figsize=(3 * n_cols, 4))

    img_np = np.array(raw_img.resize((224, 224))) / 255.0

    axes[0].imshow(img_np)
    axes[0].set_title('Original', fontsize=9)
    axes[0].axis('off')

    for i, (word, alpha) in enumerate(zip(tokens, alphas)):
        # Attention is over a square grid of regions (49 -> 7x7); derive the side
        # from the weight count instead of hard-coding 7.
        side     = int(round(np.sqrt(alpha.size)))
        attn_map = alpha.reshape(side, side)
        # Min-max normalise to [0, 1], then upsample to the displayed image size.
        attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
        attn_up  = np.array(
            PILImage.fromarray((attn_map * 255).astype(np.uint8)).resize((224, 224), PILImage.BILINEAR)
        ) / 255.0
        axes[i + 1].imshow(img_np)
        axes[i + 1].imshow(attn_up, alpha=0.55, cmap='jet')
        axes[i + 1].set_title(f'"{word}"', fontsize=9)
        axes[i + 1].axis('off')

    answer = ' '.join(tokens)
    gt_ans = qid2gt.get(sample['question_id'], '?')
    fig.suptitle(f'Model {model_type} | Q: {q_text}\nPred: {answer}  |  GT: {gt_ans}',
                 fontsize=10, fontweight='bold')
    fig.tight_layout()

    if save_path is None:
        save_path = f"{CHECKPOINT_DIR}/attn_model_{model_type.lower()}.png"
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures on repeated calls
    print(f"Question : {q_text}")
    print(f"Predicted: {answer}")
    print(f"GT answer: {gt_ans}")
    print(f"Saved    : {save_path}")


visualize_attention_notebook(ATTN_MODEL_TYPE, ATTN_SAMPLE_IDX)

In [None]:
# Visualize attention on several random val samples. Each sample gets its own
# figure: the resized image followed by one heatmap overlay per decoded token
# (up to 8 overlays; the full predicted answer is shown in the title).
N_SAMPLES_ATTN = 4   # how many samples to show

ckpt = f"{CHECKPOINT_DIR}/model_{ATTN_MODEL_TYPE.lower()}_epoch{EVAL_EPOCH}.pth"
if ATTN_MODEL_TYPE.upper() not in ('C', 'D'):
    print(f"Attention visualization needs Model C or D, got: {ATTN_MODEL_TYPE}")
elif not os.path.exists(ckpt):
    print(f"Model {ATTN_MODEL_TYPE} checkpoint not found at: {ckpt}")
else:
    model_attn = inf_get_model(ATTN_MODEL_TYPE, len(vocab_q), len(vocab_a))
    model_attn.load_state_dict(torch.load(ckpt, map_location=lambda s, l: s))
    model_attn.to(DEVICE).eval()

    # Sample from the first 500 val questions. Clamp the count so that
    # np.random.choice(..., replace=False) is never asked for more items than exist.
    pool_size      = min(500, len(val_questions))
    sample_indices = np.random.choice(pool_size, min(N_SAMPLES_ATTN, pool_size), replace=False)

    start_idx = vocab_a.word2idx['<start>']
    end_idx   = vocab_a.word2idx['<end>']

    # No shared outer figure: one that is never drawn into would still be
    # rendered (blank) by the first plt.show() call.
    for s_idx in sample_indices:
        s      = val_questions[s_idx]
        q_txt  = s['question']
        i_path = os.path.join(VAL_IMAGE_DIR, f"COCO_val2014_{s['image_id']:012d}.jpg")
        raw    = PILImage.open(i_path).convert('RGB')
        img_t  = transform(raw)
        q_t    = torch.tensor(vocab_q.numericalize(q_txt), dtype=torch.long)

        # Greedy decode (max 10 steps), recording the attention weights per token.
        # The decoder's initial (h, c) state is rebuilt by hand.
        # NOTE(review): keep in sync with the model's own forward pass.
        toks, alps = [], []
        with torch.no_grad():
            im  = img_t.unsqueeze(0).to(DEVICE)
            qq  = q_t.unsqueeze(0).to(DEVICE)
            imf = F.normalize(model_attn.i_encoder(im), p=2, dim=-1)
            qf  = model_attn.q_encoder(qq)
            fus = hadamard_fusion(imf.mean(dim=1), qf)
            h   = fus.unsqueeze(0).repeat(model_attn.num_layers, 1, 1)
            hid = (h, torch.zeros_like(h))
            tok = torch.tensor([[start_idx]], dtype=torch.long, device=DEVICE)
            for _ in range(10):
                logit, hid, alpha = model_attn.decoder.decode_step(tok, hid, imf)
                pred = logit.argmax(dim=-1).item()
                if pred == end_idx:
                    break
                toks.append(vocab_a.idx2word.get(pred, '<unk>'))
                alps.append(alpha.squeeze(0).cpu().numpy())
                tok = torch.tensor([[pred]], dtype=torch.long, device=DEVICE)

        if not toks:
            continue

        n_shown = min(len(toks), 8)
        inner_fig, inner_axes = plt.subplots(1, n_shown + 1, figsize=(3 * (n_shown + 1), 3.5))
        img_np = np.array(raw.resize((224, 224))) / 255.0
        inner_axes[0].imshow(img_np)
        inner_axes[0].set_title('Orig', fontsize=8)
        inner_axes[0].axis('off')
        # `word` (not `w`) so the bar-chart cell's global bar width is not clobbered.
        for j, (word, a) in enumerate(zip(toks[:n_shown], alps[:n_shown]), start=1):
            side = int(round(np.sqrt(a.size)))   # square region grid, e.g. 49 -> 7x7
            am = a.reshape(side, side)
            am = (am - am.min()) / (am.max() - am.min() + 1e-8)
            au = np.array(PILImage.fromarray((am * 255).astype(np.uint8)).resize((224, 224), PILImage.BILINEAR)) / 255.
            inner_axes[j].imshow(img_np)
            inner_axes[j].imshow(au, alpha=0.5, cmap='jet')
            inner_axes[j].set_title(f'"{word}"', fontsize=8)
            inner_axes[j].axis('off')
        ans = ' '.join(toks)
        gt  = qid2gt.get(s['question_id'], '?')
        inner_fig.suptitle(f'Q: {q_txt}  |  Pred: {ans}  |  GT: {gt}', fontsize=9, fontweight='bold')
        inner_fig.tight_layout()
        inner_fig.savefig(f"{CHECKPOINT_DIR}/attn_multi_{s_idx}.png", dpi=120, bbox_inches='tight')
        plt.show()
        plt.close(inner_fig)

---
## Summary

| Step | Output |
|------|--------|
| Vocab | `data/processed/vocab_questions.json`, `vocab_answers.json` |
| Train | `checkpoints/model_X_epoch*.pth`, `model_X_best.pth`, `model_X_resume.pth`, `history_model_X.json` |
| Curves | `checkpoints/training_curves.png` |
| Evaluation | Printed table: VQA Acc / Exact Match / BLEU-1~4 / METEOR |
| Comparison | `checkpoints/comparison_bar.png` |
| Inference | `checkpoints/inference_sample.png` |
| Attention | `checkpoints/attn_model_c.png` (or `attn_model_d.png`, lowercase model letter), `attn_multi_*.png` |

**To re-run a specific step**, simply re-execute that cell — all variables from earlier cells are available throughout the notebook session.