# Deep Machine Learning Project (SSY340)

Project Group 92

## Setup:

### Imports and CUDA setup:

In [None]:
import os, math, random, torch, gc, ast, re
from collections import Counter, defaultdict
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights
from transformers import AutoImageProcessor

print("\nPyTorch:", torch.__version__)
print("CUDA version:", torch.version.cuda)

# Use the GPU when present; otherwise fall back to CPU instead of crashing on
# torch.cuda.get_device_name() (which raises when no CUDA device exists).
if torch.cuda.is_available():
    print("Device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
    DEVICE = torch.device("cuda")
else:
    print("Device name: CPU (CUDA not available)")
    DEVICE = torch.device("cpu")

# Favour speed over bit-exact reproducibility for cuDNN kernels.
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
# Pinned host memory only helps host->GPU copies.
PIN_MEMORY = DEVICE.type == "cuda"

SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # silently ignored when CUDA is unavailable

def clear_cache():
    """Release memory between experiments.

    Runs Python's garbage collector first (dropping unreachable tensors), then
    asks PyTorch to return cached, unused CUDA memory blocks.
    """
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

### Constants:

In [None]:
# --- Data locations ---
IMAGES_DIR = "flickr30k-images"            # directory holding the image files
CSV_PATH = "flickr_annotations_30k.csv"    # annotation CSV
CAPTIONS_FILE = "Flickr30k.token.txt"      # "<image>#<i>\t<caption>" lines

# --- Loading / vocabulary / generation ---
NUM_WORKERS = 0
MIN_FREQ = 5        # min token count for the vocab; lower value means slower training but bigger vocabulary
MAX_LEN = 50        # default max generated length / positional-encoding table size
EVAL_MAX_LEN = 30
PRINT_EVERY = 10    # batches between training progress prints

# --- Split ratios (by image); the test split is the remainder (~0.1) ---
TRAIN_RATIO = 0.8
VAL_RATIO = 0.1

### Setup Flickr30k.token.txt (captions):

In [None]:
def _get_image_name(row, df):
    for col in ('file_name','filename','image','img','image_filename','img_name','image_name','img_id','image_id','path'):
        if col in df.columns:
            val = row.get(col)
            if pd.isna(val):
                continue
            return os.path.basename(str(val))
    return f"{row.name}.jpg"

def _get_captions(row, df):
    for col in ('raw','captions','sentences','sentence','caption','raw_captions','sentids'):
        if col in df.columns:
            val = row.get(col)
            if pd.isna(val):
                continue
            if isinstance(val, (list, tuple)):
                return [str(x).strip() for x in val if str(x).strip()]
            if isinstance(val, str):
                try:
                    parsed = ast.literal_eval(val)
                    if isinstance(parsed, (list, tuple)):
                        return [str(x).strip() for x in parsed if str(x).strip()]
                    if isinstance(parsed, dict) and 'raw' in parsed:
                        r = parsed['raw']
                        if isinstance(r, (list, tuple)):
                            return [str(x).strip() for x in r if str(x).strip()]
                except Exception:
                    pass
                for sep in ('|||', '||', '\n'):
                    if sep in val:
                        return [s.strip() for s in val.split(sep) if s.strip()]
                return [val.strip()]
    return []

def generate_token_file_from_csv(csv_path, captions_file):
    """Convert an annotation CSV into a Flickr-style token file.

    Each caption becomes one line ``<image_name>#<i>\\t<caption>``, where
    ``i`` is the caption's position for that image. Rows without captions are
    skipped. Image names and captions come from ``_get_image_name`` and
    ``_get_captions``.

    Args:
        csv_path: path of the annotation CSV.
        captions_file: output path (overwritten, UTF-8).
    """
    df = pd.read_csv(csv_path, low_memory=False)
    print("CSV columns:", list(df.columns))
    with open(captions_file, 'w', encoding='utf-8') as fout:
        for _, row in df.iterrows():
            img_name = _get_image_name(row, df)
            for i, caption in enumerate(_get_captions(row, df)):
                # Collapse tabs/newlines/repeated spaces. The reader splits each
                # line on '\t' and drops lines without exactly two fields, so
                # embedded whitespace control characters would lose captions.
                clean = " ".join(caption.split())
                if clean:
                    fout.write(f"{img_name}#{i}\t{clean}\n")
    print("Wrote token file:", captions_file)

### Setup tokenizer/vocab:

In [None]:
class Vocab:
    """Bidirectional token <-> index mapping built from tokenized captions.

    Reserved tokens always take the first indices, in the given order (by
    default '<pad>', '<start>', '<end>', '<unk>'). Other tokens follow in
    descending frequency if they occur at least ``min_freq`` times.
    """

    def __init__(self, min_freq=MIN_FREQ, reserved=None):
        self.min_freq = min_freq
        self.reserved = ['<pad>', '<start>', '<end>', '<unk>'] if reserved is None else reserved
        self.freq = Counter()   # token counts, accumulated across build() calls
        self.itos = []          # index -> token
        self.stoi = {}          # token -> index

    def build(self, token_lists):
        """Count tokens from ``token_lists`` and rebuild ``itos``/``stoi``."""
        for tokens in token_lists:
            self.freq.update(tokens)
        kept = [
            tok
            for tok, count in self.freq.most_common()
            if count >= self.min_freq and tok not in self.reserved
        ]
        self.itos = [*self.reserved, *kept]
        self.stoi = {tok: index for index, tok in enumerate(self.itos)}

    def __len__(self):
        return len(self.itos)

    def numericalize(self, tokens):
        """Map tokens to indices; tokens not in the vocab map to '<unk>'."""
        stoi = self.stoi
        return [stoi.get(tok, stoi['<unk>']) for tok in tokens]


def tokenize_caption(text):
    """Split a caption into lower-case word tokens.

    Every run of characters other than a-z, 0-9, apostrophe and space is
    replaced by a single space, then the text is split on whitespace. This
    means apostrophes stay inside words ("dog's") while punctuation and
    hyphens separate tokens.
    """
    normalized = re.sub(r"[^a-z0-9' ]+", " ", text.lower())
    return normalized.split()

### Datasets:

In [None]:
class Flickr30kDataset(Dataset):
    """Flickr30k (image, caption) pairs with a deterministic split by image.

    Reads a token file of ``<image>#<idx>\\t<caption>`` lines and keeps only
    images present in ``images_dir``. It then shuffles the sorted image names
    with ``random.Random(seed)`` and cuts them into train / val / test by
    ``TRAIN_RATIO`` / ``VAL_RATIO`` (test = remainder). All captions of an
    image land in the same split.

    Args:
        images_dir: directory containing the image files.
        captions_file: token file path.
        vocab: Vocab to numericalize captions. If None and ``split == 'train'``,
            a new Vocab is built from this split's captions.
        transform: optional callable applied to the PIL image.
        split: 'train', 'val' or 'test'.
        seed: seed for the image shuffle (must match across splits).
        return_raw_caption: return ``(image, caption_str, len(caption_str))``
            instead of ``(image, LongTensor ids, num_ids)``.

    Raises:
        ValueError: for an unknown ``split``.
    """

    def __init__(self, images_dir, captions_file, vocab=None, transform=None, split='train', seed=SEED, return_raw_caption=False):
        super().__init__()
        if split not in ('train', 'val', 'test'):
            raise ValueError(f"split must be 'train', 'val' or 'test', got {split!r}")
        self.images_dir = str(images_dir)
        self.transform = transform
        self.return_raw_caption = return_raw_caption

        image_to_captions = self._read_captions(captions_file)
        available = set(os.listdir(self.images_dir))
        all_entries = [
            (img, cap)
            for img, caps in image_to_captions.items() if img in available
            for cap in caps
        ]
        split_images = self._split_images({img for img, _ in all_entries}, seed)[split]
        self.entries = [e for e in all_entries if e[0] in split_images]

        if vocab is None and split == 'train':
            self.vocab = Vocab(min_freq=MIN_FREQ)
            self.vocab.build(tokenize_caption(c) for _, c in self.entries)
        else:
            # May be None (e.g. raw-caption use); __getitem__ checks before use.
            self.vocab = vocab

    @staticmethod
    def _read_captions(captions_file):
        """Map image file name -> list of captions from a token file."""
        image_to_captions = defaultdict(list)
        with open(captions_file, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.strip().split('\t')
                if len(parts) != 2:
                    continue  # blank or malformed line
                img_token, cap = parts
                image_to_captions[img_token.split('#')[0]].append(cap)
        return image_to_captions

    @staticmethod
    def _split_images(image_names, seed):
        """Deterministically partition image names into train/val/test sets."""
        images = sorted(image_names)
        random.Random(seed).shuffle(images)
        n_train = int(len(images) * TRAIN_RATIO)
        n_val = int(len(images) * VAL_RATIO)
        return {
            'train': set(images[:n_train]),
            'val': set(images[n_train:n_train + n_val]),
            'test': set(images[n_train + n_val:]),
        }

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, idx):
        img_name, cap = self.entries[idx]
        img_path = os.path.join(self.images_dir, img_name)
        # Context manager closes the file handle even if decoding fails.
        with Image.open(img_path) as im:
            image = im.convert('RGB')

        if self.transform:
            image = self.transform(image)

        if self.return_raw_caption:
            return image, cap, len(cap)

        if self.vocab is None:
            raise RuntimeError("Flickr30kDataset has no vocab; pass vocab= for non-train splits")
        tokens = tokenize_caption(cap)
        numeric = [self.vocab.stoi['<start>'], *self.vocab.numericalize(tokens), self.vocab.stoi['<end>']]
        num_caption = torch.tensor(numeric, dtype=torch.long)
        return image, num_caption, num_caption.size(0)


def make_collate_fn(pad_idx):
    """Build a DataLoader ``collate_fn`` that right-pads captions with ``pad_idx``.

    The returned function maps a list of ``(image, caption_ids, length)``
    items to ``(images, captions, lengths)``:
      * ``images``: images stacked along a new leading batch dimension;
      * ``captions``: [B, T_max] tensor padded with ``pad_idx``;
      * ``lengths``: tuple of the per-item lengths, unchanged.
    """
    def collate_fn(batch):
        stacked_imgs, caption_seqs, lengths = list(zip(*batch))
        padded = nn.utils.rnn.pad_sequence(caption_seqs, batch_first=True, padding_value=pad_idx)
        return torch.stack(stacked_imgs, dim=0), padded, lengths

    return collate_fn

# Standard ImageNet channel statistics, shared by both pipelines below.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training: light augmentation (random scale/crop, horizontal flip, small rotation).
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
# Evaluation: deterministic resize + center crop to 224x224.
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# The vocabulary is built from the training split only and reused for val/test.
train_ds = Flickr30kDataset(IMAGES_DIR, CAPTIONS_FILE, vocab=None,
                            transform=train_transform, split='train')
vocab = train_ds.vocab
PAD_IDX = vocab.stoi['<pad>']
print("Vocab size:", len(vocab))
collate_base = make_collate_fn(PAD_IDX)

val_ds, test_ds = (
    Flickr30kDataset(IMAGES_DIR, CAPTIONS_FILE, vocab=vocab, transform=val_transform, split=name)
    for name in ('val', 'test')
)

### Fit and plot model functions:

In [None]:
def plot_training_history(history):
    """Show two figures: train/val loss and train/val accuracy per epoch.

    Args:
        history: dict with equally long lists under 'train_loss', 'val_loss',
            'train_acc' and 'val_acc' (accuracies in percent), as returned by
            ``fit_model``.
    """
    epochs = range(1, len(history['train_loss']) + 1)

    def _plot_pair(train_key, val_key, train_label, val_label, ylabel, title):
        plt.figure(figsize=(8,5))
        plt.plot(epochs, history[train_key], label=train_label, marker='o')
        plt.plot(epochs, history[val_key], label=val_label, marker='o')
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend()
        plt.grid(True)
        plt.show()

    _plot_pair('train_loss', 'val_loss', 'Train Loss', 'Val Loss',
               'Loss', 'Training and Validation Loss')
    _plot_pair('train_acc', 'val_acc', 'Train Accuracy', 'Val Accuracy',
               'Accuracy (%)', 'Training and Validation Accuracy')

# Default training loss: token-level cross-entropy over the vocabulary, with
# '<pad>' positions ignored so padding never contributes.
CEL_base = nn.CrossEntropyLoss(
    ignore_index=vocab.stoi['<pad>'],
)

def fit_model(
    enc, dec,
    train_loader, val_loader,
    enc_opt, dec_opt,
    device,
    vocab,
    output_dir,
    num_epochs,
    criterion=CEL_base,
    print_every=PRINT_EVERY
):
    """Train an encoder/decoder captioning pair with teacher forcing.

    Each epoch:
      1. runs one pass over ``train_loader``. Mixed precision is used on CUDA,
         gradients are clipped to norm 5.0 per optimizer, and each
         non-None optimizer is stepped;
      2. runs one no-grad pass over ``val_loader``;
      3. records loss and token accuracy (non-pad targets, argmax) and writes
         ``ckpt_epoch_XX.pth`` to ``output_dir``.
    Targets are the captions shifted by one. Each batch is trimmed to the
    longest real target so no all-pad columns are processed. On CUDA the
    images are moved in channels_last memory format.

    Args:
        enc, dec: encoder/decoder modules; ``dec(enc(images), inp)`` must
            return logits [B, T, V].
        train_loader, val_loader: yield ``(images, caps, lengths)`` with
            ``caps`` padded using ``vocab``'s '<pad>' index.
        enc_opt, dec_opt: optimizers; either may be None (not stepped).
        device: torch.device to train on.
        vocab: Vocab (pad index lookup; ``itos`` is stored in checkpoints).
        output_dir: checkpoint directory (created if missing).
        num_epochs: number of epochs.
        criterion: loss on flattened logits [N, V] and targets [N].
        print_every: print running training stats every N batches.

    Returns:
        dict of per-epoch lists 'train_loss', 'train_acc', 'val_loss',
        'val_acc' (accuracy in percent). An epoch phase without any usable
        batch records NaN instead of raising ZeroDivisionError.
    """
    from pathlib import Path

    outdir = Path(output_dir).resolve()
    outdir.mkdir(parents=True, exist_ok=True)
    enc.to(device)
    dec.to(device)

    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    pad_idx = vocab.stoi['<pad>'] if vocab is not None else 0

    is_cuda = device.type == "cuda"
    amp_enabled = is_cuda
    # autocast requires a valid device type even when disabled.
    amp_device = "cuda" if is_cuda else "cpu"
    scaler = torch.amp.GradScaler("cuda", enabled=amp_enabled)

    # Optimizers that are actually used, paired with the module they clip.
    optimizers = [(opt, module) for opt, module in ((enc_opt, enc), (dec_opt, dec)) if opt is not None]

    def compute_accuracy(logits, targets):
        """Fraction of non-pad target tokens whose argmax prediction is correct."""
        preds = logits.argmax(dim=-1)
        mask = targets != pad_idx
        total = mask.sum().item()
        return ((preds == targets) & mask).sum().item() / total if total > 0 else 0.0

    def _mean(total, count):
        """Average over ``count`` batches; NaN when no batch was usable."""
        return total / count if count > 0 else float('nan')

    def _to_device(images, caps):
        """Move one batch to ``device`` (channels_last images on CUDA)."""
        if is_cuda:
            images = images.to(device, non_blocking=True).to(memory_format=torch.channels_last)
        else:
            images = images.to(device)
        return images, caps.to(device, non_blocking=is_cuda)

    def _shift_and_trim(caps):
        """Teacher-forcing (input, target) pair cut to the longest real target.

        Returns None when the batch has no non-pad target tokens.
        """
        inp, tgt = caps[:, :-1], caps[:, 1:]
        max_len = (tgt != pad_idx).sum(dim=1).max().item()
        if max_len == 0:
            return None
        return inp[:, :max_len], tgt[:, :max_len]

    def _forward(images, inp, tgt):
        """Encoder + decoder forward under autocast; returns (logits, loss)."""
        with torch.amp.autocast(amp_device, enabled=amp_enabled):
            logits = dec(enc(images), inp)
            loss = criterion(logits.reshape(-1, logits.size(-1)), tgt.reshape(-1))
        return logits, loss

    for epoch in range(1, num_epochs + 1):
        # --------------- Train: ---------------
        enc.train(); dec.train()
        loss_sum, acc_sum, steps = 0.0, 0.0, 0
        n_batches = len(train_loader)

        for batch_idx, (images, caps, _) in enumerate(train_loader, 1):
            images, caps = _to_device(images, caps)
            for opt, _module in optimizers:
                opt.zero_grad(set_to_none=True)

            trimmed = _shift_and_trim(caps)
            if trimmed is None:
                continue  # nothing to learn this batch
            inp, tgt = trimmed

            logits, loss = _forward(images, inp, tgt)
            scaler.scale(loss).backward()

            # Unscale first so the clipping threshold applies to true gradients.
            for opt, module in optimizers:
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(module.parameters(), 5.0)
            for opt, _module in optimizers:
                scaler.step(opt)
            scaler.update()

            loss_sum += loss.item()
            acc_sum += compute_accuracy(logits, tgt)
            steps += 1
            if batch_idx % print_every == 0 or batch_idx == n_batches:
                print(f"Epoch {epoch} Train batch {batch_idx}/{n_batches} "
                      f"Loss={loss_sum/steps:.4f} Acc={100*acc_sum/steps:.2f}%")

        train_loss = _mean(loss_sum, steps)
        train_acc = 100 * _mean(acc_sum, steps)

        # --------------- Validate: ---------------
        enc.eval(); dec.eval()
        loss_sum, acc_sum, steps = 0.0, 0.0, 0
        with torch.no_grad():
            for images, caps, _ in val_loader:
                images, caps = _to_device(images, caps)
                trimmed = _shift_and_trim(caps)
                if trimmed is None:
                    continue
                inp, tgt = trimmed
                logits, loss = _forward(images, inp, tgt)
                loss_sum += loss.item()
                acc_sum += compute_accuracy(logits, tgt)
                steps += 1

        val_loss = _mean(loss_sum, steps)
        val_acc = 100 * _mean(acc_sum, steps)

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        print(f"Epoch {epoch}/{num_epochs} train_loss={train_loss:.4f} train_acc={train_acc:.2f}% "
              f"val_loss={val_loss:.4f} val_acc={val_acc:.2f}%")

        # --------------- Save model: ---------------
        ckpt = {
            "epoch": epoch,
            "encoder_state_dict": enc.state_dict(),
            "decoder_state_dict": dec.state_dict(),
            "vocab": getattr(vocab, "itos", vocab),
            "history": history,
            "enc_optimizer_state_dict": enc_opt.state_dict() if enc_opt else None,
            "dec_optimizer_state_dict": dec_opt.state_dict() if dec_opt else None,
        }
        outdir.mkdir(parents=True, exist_ok=True)
        ckpt_path = outdir / f"ckpt_epoch_{epoch:02d}.pth"
        try:
            torch.save(ckpt, ckpt_path)
        except Exception as err:  # best-effort: a failed save must not abort training
            print(f"Warning: could not save checkpoint {ckpt_path}: {err}")

    plot_training_history(history)
    return history

### BLEU evaluation:

In [None]:
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

def _build_refs_for_split(ds):
    """Group the tokenized reference captions of ``ds`` by image name.

    Only ``ds.entries`` is read, so references never cross split boundaries.
    Images appear in order of their first entry.

    Returns:
        dict mapping image file name -> list of reference token lists.
    """
    grouped = defaultdict(list)
    for img_name, caption in ds.entries:
        grouped[img_name].append(tokenize_caption(caption))
    return dict(grouped)

@torch.no_grad()
def generate_caption(enc, dec, img_path, transform, vocab, device,
                     max_len=MAX_LEN, use_amp=True):
    """Generate a caption for a single image file.

    The image is loaded as RGB, transformed, batched with ``unsqueeze(0)`` and
    encoded. Decoding then uses the first available option:
      * ``dec.sample`` with the extended signature;
      * on TypeError, ``dec.sample`` with the older
        ``(features, start_id, max_len)`` signature;
      * step-by-step greedy decoding through ``dec(features, tokens)`` when the
        decoder has no ``sample`` method.

    Args:
        enc, dec: encoder/decoder; both are moved to ``device`` and set to eval.
        img_path: path to the image file.
        transform: callable mapping a PIL image to an image tensor.
        vocab: Vocab providing the '<start>', '<end>' and '<pad>' indices.
        device: torch.device to run on.
        max_len: maximum number of generated tokens.
        use_amp: use CUDA autocast (ignored on non-CUDA devices).

    Returns:
        The caption as a space-joined string. Generation stops at the first
        '<end>' and skips '<pad>'/'<start>' tokens.
    """
    enc.to(device).eval()
    dec.to(device).eval()

    with Image.open(img_path) as im:
        x = transform(im.convert('RGB')).unsqueeze(0).to(device)

    # torch.amp.autocast needs a device-type *string* even when disabled
    # (None is rejected), so off-GPU use a disabled "cpu" context.
    amp_device = "cuda" if device.type == "cuda" else "cpu"
    amp_enabled = device.type == "cuda" and use_amp

    with torch.amp.autocast(amp_device, enabled=amp_enabled):
        feats = enc(x)

    # Autocast may yield half-precision features; match the decoder's
    # parameter dtype (fixes Half vs Float) before decoding.
    feats = feats.to(next(dec.parameters()).dtype)

    start_id = vocab.stoi['<start>']
    end_id   = vocab.stoi['<end>']
    pad_id   = vocab.stoi['<pad>']

    if callable(getattr(dec, "sample", None)):
        try:
            # Newer signature (with EOS, blocking, etc.)
            ids = dec.sample(
                feats,
                start_id=start_id,
                end_id=end_id,
                pad_id=pad_id,
                max_len=max_len,
                no_repeat_ngram_size=3,
                temperature=1.0,
                top_k=0
            )[0].tolist()
        except TypeError:
            # Older signature: (features, start_id=None, max_len=None)
            ids = dec.sample(feats, start_id=start_id, max_len=max_len)[0].tolist()
    else:
        # Greedy fallback via forward(): re-run the decoder on the growing prefix.
        generated = torch.tensor([[start_id]], device=device, dtype=torch.long)
        ids = []
        for _ in range(max_len):
            with torch.amp.autocast(amp_device, enabled=amp_enabled):
                logits = dec(feats, generated)        # [1,T,V]
                next_id = logits[:, -1, :].argmax(-1) # [1]
            nid = int(next_id.item())
            ids.append(nid)
            if nid == end_id:
                break
            generated = torch.cat([generated, next_id.unsqueeze(1)], dim=1)

    # ids -> words
    words = []
    for idx in ids:
        if idx == end_id:
            break
        if idx in (pad_id, start_id):
            continue
        words.append(vocab.itos[idx] if idx < len(vocab.itos) else '<unk>')
    return ' '.join(words)


@torch.no_grad()
def evaluate_bleu(enc, dec, ds, transform, vocab, device, max_len=MAX_LEN, limit=None, show_examples=0, use_amp=False):
    """Compute corpus-level BLEU-1..4 for generated captions on a dataset split.

    One caption is generated per distinct image in ``ds`` with
    ``generate_caption``. That uses ``dec.sample`` when the decoder has one
    (here with no-repeat trigram blocking), otherwise plain greedy decoding.
    Each caption is scored against all reference captions of that image in
    the same split.

    Args:
        enc, dec, transform, vocab, device, max_len, use_amp: passed to
            ``generate_caption``.
        ds: dataset exposing ``entries`` and ``images_dir``.
        limit: evaluate only the first ``limit`` images (None = all).
        show_examples: display this many images with their hypothesis and up
            to two references.

    Returns:
        dict with 'BLEU-1'..'BLEU-4' and 'num_images'. When no image is
        evaluated (e.g. ``limit=0``), every BLEU score is 0.0.
    """
    enc.eval(); dec.eval()

    img2refs = _build_refs_for_split(ds)
    img_names = list(img2refs)
    if limit is not None:
        img_names = img_names[:limit]

    list_of_references = []   # N x (#refs_i) x tokens
    hypotheses = []           # N x tokens
    for img_name in img_names:
        img_path = os.path.join(ds.images_dir, img_name)
        hyp_text = generate_caption(enc, dec, img_path, transform, vocab, device, max_len=max_len, use_amp=use_amp)
        list_of_references.append(img2refs[img_name])  # already tokenized
        hypotheses.append(tokenize_caption(hyp_text))

    bleu_weights = {
        "BLEU-1": (1.0, 0.0, 0.0, 0.0),
        "BLEU-2": (0.5, 0.5, 0.0, 0.0),
        "BLEU-3": (1/3, 1/3, 1/3, 0.0),
        "BLEU-4": (0.25, 0.25, 0.25, 0.25),
    }
    if hypotheses:
        smooth = SmoothingFunction().method1
        scores = {
            name: corpus_bleu(list_of_references, hypotheses, weights=w, smoothing_function=smooth)
            for name, w in bleu_weights.items()
        }
    else:
        # Handle the empty corpus explicitly instead of relying on corpus_bleu.
        scores = dict.fromkeys(bleu_weights, 0.0)

    # Show a few examples
    for i in range(min(show_examples, len(img_names))):
        img_path = os.path.join(ds.images_dir, img_names[i])
        # Load fully and close the file handle before plotting.
        with Image.open(img_path) as im:
            shown = im.convert("RGB")

        plt.figure(figsize=(5,5))
        plt.imshow(shown)
        plt.axis("off")
        plt.title(img_names[i])
        plt.show()

        print("Hyp:", " ".join(hypotheses[i]))
        print("Ref 1:", " ".join(list_of_references[i][0]))
        if len(list_of_references[i]) > 1:
            print("Ref 2:", " ".join(list_of_references[i][1]))

    return {**scores, "num_images": len(img_names)}

## Model 1: CNN-RNN

Simple CNN-to-RNN image-captioning baseline model.

Encoder: CNN with transfer learning from ResNet18.

Decoder: LSTM-based text decoder (RNN), trained from scratch (no transfer learning).

In [None]:
# --------------- Encoder: ---------------

class CNNEncoder(nn.Module):
    """Frozen ResNet18 backbone followed by a linear projection to ``embed_size``.

    The pretrained backbone keeps every ResNet18 layer except the final
    classifier. It is frozen and always run in eval mode without gradients;
    only ``fc`` (512 -> embed_size) has trainable parameters.
    """

    def __init__(self, embed_size):
        super().__init__()
        resnet = resnet18(weights=ResNet18_Weights.DEFAULT)
        # Everything up to (and including) global pooling; drop the classifier.
        head_less = list(resnet.children())[:-1]
        self.backbone = nn.Sequential(*head_less)
        self.backbone.requires_grad_(False)
        self.fc = nn.Linear(512, embed_size)
        self.relu = nn.ReLU()  # NOTE(review): defined but never applied in forward

    def forward(self, x):
        # Force eval mode so BatchNorm statistics stay fixed even in enc.train().
        self.backbone.eval()
        with torch.no_grad():
            pooled = self.backbone(x)
        return self.fc(pooled.flatten(start_dim=1))

# --------------- Decoder: ---------------

class RNNDecoder(nn.Module):
    """LSTM caption decoder conditioned on an image feature vector.

    The image feature initializes the LSTM hidden and cell states through two
    linear layers (``init_h`` / ``init_c``). Captions are consumed with teacher
    forcing and mapped to vocabulary logits.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1, dropout=0.3, pad_idx=0):
        super().__init__()
        self.embed  = nn.Embedding(vocab_size, embed_size, padding_idx=pad_idx)
        # Inter-layer LSTM dropout only exists for stacked LSTMs.
        self.lstm   = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True, dropout=dropout if num_layers > 1 else 0)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.init_h = nn.Linear(embed_size, hidden_size)
        self.init_c = nn.Linear(embed_size, hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, features, captions):
        """
        Args:
            features: [B, embed_size] image features.
            captions: [B, T] input token ids.

        Returns:
            Logits [B, T, vocab_size].
        """
        emb = self.dropout(self.embed(captions))
        # nn.LSTM expects initial states shaped [num_layers, B, hidden_size].
        # Every layer starts from the same image-derived state; `repeat` (not
        # `expand`) gives contiguous tensors. For num_layers=1 this equals the
        # plain unsqueeze(0).
        num_layers = self.lstm.num_layers
        h0 = self.init_h(features).unsqueeze(0).repeat(num_layers, 1, 1)
        c0 = self.init_c(features).unsqueeze(0).repeat(num_layers, 1, 1)
        out, _ = self.lstm(emb, (h0, c0))
        return self.linear(self.dropout(out))

# --------------- Training: ---------------

clear_cache()

DROPOUT = 0.3
EMBED_SIZE = 512
HIDDEN_SIZE = 1024

BATCH_SIZE = 128
VAL_BATCH_SIZE = 64
DEC_LR = 1e-3
NUM_EPOCHS = 5

OUTPUT_DIR = "./models_cnn_rnn"
os.makedirs(OUTPUT_DIR, exist_ok=True)

train_loader_cnn_rnn = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          collate_fn=collate_base, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader_cnn_rnn   = DataLoader(val_ds,   batch_size=VAL_BATCH_SIZE, shuffle=False,
                          collate_fn=collate_base, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

CNN_enc = CNNEncoder(EMBED_SIZE)
RNN_dec = RNNDecoder(
    embed_size=EMBED_SIZE,
    hidden_size=HIDDEN_SIZE,
    vocab_size=len(vocab),
    dropout=DROPOUT,
    pad_idx=vocab.stoi['<pad>']
)

CNN_enc = CNN_enc.to(DEVICE).to(memory_format=torch.channels_last)
RNN_dec = RNN_dec.to(DEVICE)

# The ResNet backbone is frozen, but CNN_enc.fc (512 -> EMBED_SIZE) is a newly
# initialized, trainable projection. The training cell passes enc_opt=None, so
# fc is optimized here together with the decoder; otherwise it would stay at
# its random initialization for the whole run.
dec_opt = optim.Adam(
    [*RNN_dec.parameters(), *CNN_enc.fc.parameters()],
    lr=DEC_LR,
)

In [None]:
# Model 1: no separate encoder optimizer (enc_opt=None).
cnn_rnn_history = fit_model(
    enc=CNN_enc,
    dec=RNN_dec,
    enc_opt=None,
    dec_opt=dec_opt,
    train_loader=train_loader_cnn_rnn,
    val_loader=val_loader_cnn_rnn,
    vocab=vocab,
    device=DEVICE,
    num_epochs=NUM_EPOCHS,
    output_dir=OUTPUT_DIR,
)

## Model 2: ViT trained from scratch

Image-captioning model in which both transformers (image encoder and text decoder) are trained from scratch.

Encoder: Vision Transformer (ViT), trained from scratch (no transfer learning).

Decoder: Small text transformer, no transfer learning.

In [None]:
# --------------- Encoder: ---------------

class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and embed each linearly.

    Implemented as a Conv2d whose kernel size equals its stride. Input
    [B, in_ch, H, W] becomes a token sequence [B, (H//p)*(W//p), embed_dim].
    """

    def __init__(self, in_ch=3, embed_dim=256, patch_size=16):
        super().__init__()
        self.proj = nn.Conv2d(in_ch, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        grid = self.proj(x)                    # [B, E, H//p, W//p]
        tokens = grid.flatten(start_dim=2)     # [B, E, N]
        return tokens.permute(0, 2, 1)         # [B, N, E]


class ViTEncoder(nn.Module):
    """Vision Transformer image encoder trained from scratch.

    Embeds ``patch_size`` patches, prepends a learnable CLS token, adds learned
    positional embeddings (sized for ``img_size`` inputs), applies
    ``num_layers`` TransformerEncoder layers and a final LayerNorm. Returns the
    CLS token, shape [B, embed_dim].
    """

    def __init__(self, embed_dim=256, patch_size=16, num_layers=4, num_heads=4, mlp_dim=512, dropout=0.2, img_size=224):
        super().__init__()
        self.patch_embed = PatchEmbed(in_ch=3, embed_dim=embed_dim, patch_size=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        num_patches = (img_size // patch_size) ** 2
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=mlp_dim,
            dropout=dropout, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.norm = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        x = self.patch_embed(x)                        # [B,S,E]
        B, S, _ = x.shape
        max_tokens = self.pos_embed.size(1)            # num_patches + 1 (CLS)
        if S + 1 > max_tokens:
            # Without this check the addition below fails with an opaque
            # broadcasting error.
            raise ValueError(
                f"input produces {S} patches but pos_embed only covers {max_tokens - 1}; "
                "images must not be larger than img_size"
            )
        cls_tokens = self.cls_token.expand(B, -1, -1)  # [B,1,E]
        x = torch.cat([cls_tokens, x], dim=1)          # [B,S+1,E]
        # Smaller inputs reuse the first S+1 learned positions.
        x = x + self.pos_embed[:, :S + 1, :]
        x = self.drop(x)
        x = self.transformer_encoder(x)
        x = self.norm(x)
        return x[:, 0]                                 # CLS

# --------------- Decoder: ---------------

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding for sequence-first input [L, B, d_model].

    A table for ``max_len`` positions is stored in the persistent ``pe`` buffer
    (shape [max_len, 1, d_model]). Longer inputs get a table built for that
    call only. The buffer is never replaced, so state_dicts (and checkpoints
    saved during training) always contain ``pe`` with the same shape.
    """

    def __init__(self, d_model, max_len=MAX_LEN):
        super().__init__()
        self.d_model = d_model
        self.register_buffer('pe', self._build_pe(max_len, d_model))

    @staticmethod
    def _build_pe(max_len, d_model):
        """Return the sin/cos table [max_len, 1, d_model] (sin on even, cos on odd dims)."""
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # Odd d_model has one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pe.unsqueeze(1)  # [max_len,1,d_model]

    def forward(self, x):
        seq_len = x.size(0)
        pe = self.pe
        if seq_len > pe.size(0):
            # Build a longer table on the fly, matching the buffer's device and
            # dtype, without mutating the registered buffer.
            pe = self._build_pe(seq_len, self.d_model).to(device=pe.device, dtype=pe.dtype)
        return x + pe[:seq_len]


class TransformerDecoderAdapter(nn.Module):
    """Transformer caption decoder that attends to a single image-feature token.

    The image feature vector [B, E] serves as a length-1 memory sequence.
    Captions are embedded, get sinusoidal positions, and are decoded with a
    causal self-attention mask plus a padding mask. The output projection
    shares its weight with the token embedding. Internally everything runs
    sequence-first ([L, B, E]).
    """

    def __init__(self, embed_size, vocab_size, nhead=8, num_layers=3,  dim_feedforward=2048, dropout=0.2, max_len=MAX_LEN, pad_idx=0):
        super().__init__()

        self.pad_idx = pad_idx

        # --- Embedding ---
        self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=self.pad_idx)

        # --- Positional encoding ---
        self.pos_enc = PositionalEncoding(embed_size, max_len=max_len)

        # --- Transformer decoder layers ---
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_size,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=False
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        # --- Output projection (weight tied to the embedding) ---
        self.linear_out = nn.Linear(embed_size, vocab_size, bias=False)
        self.linear_out.weight = self.embed.weight

        # --- Misc ---
        self.embed_size = embed_size
        self.max_len = max_len

    def _generate_square_subsequent_mask(self, sz, device):
        """Boolean causal mask [sz, sz]; True marks future positions that are blocked.

        A boolean mask matches the boolean ``tgt_key_padding_mask`` used in
        ``forward``. Mixing float and boolean masks is deprecated in
        ``nn.MultiheadAttention``.
        """
        return torch.triu(torch.ones(sz, sz, dtype=torch.bool, device=device), diagonal=1)

    def forward(self, features, captions):
        """Teacher-forced decoding.

        Args:
            features: [B, E] image features.
            captions: [B, L] input token ids.

        Returns:
            Logits [B, L, vocab_size].
        """
        device = features.device
        B, L = captions.size()
        memory = features.unsqueeze(0)                              # [1, B, E]
        tgt = self.pos_enc(self.embed(captions).permute(1, 0, 2))   # [L, B, E]
        tgt_mask = self._generate_square_subsequent_mask(L, device) # [L, L]
        tgt_key_padding_mask = captions == self.pad_idx             # [B, L]
        out = self.transformer_decoder(
            tgt, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask
        )                                                           # [L, B, E]
        return self.linear_out(out.permute(1, 0, 2))                # [B, L, V]

    @staticmethod
    def _ban_repeated_ngrams(logits, generated, finished, n):
        """Set, in place, the logits of tokens that would repeat an n-gram to -inf.

        For each unfinished row of ``generated`` ([B, T]), the last ``n - 1``
        tokens form the context. Every token that followed the same context
        earlier in that row is banned. A row is left untouched if banning would
        remove every vocabulary entry.
        """
        ctx = n - 1
        vocab_size = logits.size(-1)
        done = finished.tolist()
        for b, seq in enumerate(generated.tolist()):
            if done[b] or len(seq) < ctx:
                continue
            context = seq[len(seq) - ctx:]
            banned = {
                seq[i + ctx]
                for i in range(len(seq) - ctx)
                if seq[i:i + ctx] == context
            }
            if banned and len(banned) < vocab_size:
                logits[b, list(banned)] = float('-inf')

    @torch.no_grad()
    def sample(
        self,
        features,
        start_id=None,
        end_id=None,
        pad_id=None,
        max_len=None,
        no_repeat_ngram_size=3,
        temperature=1.0,
        top_k=0,          # 0 = no top-k
    ):
        """Autoregressively decode token ids for a batch of image features.

        Each step works as follows:
          * With ``no_repeat_ngram_size = n > 1``, any token that would
            recreate an n-gram already present in that row (start token
            included) is excluded, so the next-best token is chosen.
          * Logits are divided by ``temperature``.
          * The next token is the argmax, or is sampled from the ``top_k``
            most likely tokens when ``top_k > 0``.
          * A predicted ``pad_id`` is replaced by ``end_id``.
        Decoding stops early once every row has produced ``end_id``.

        Args:
            features: [B, E] image features.
            start_id: start token id (required).
            end_id: end token id; None means -1, which never matches.
            pad_id: pad token id; defaults to ``end_id``.
            max_len: maximum number of new tokens (defaults to ``self.max_len``).

        Returns:
            LongTensor [B, T] of generated ids, T <= max_len. Tokens produced
            after a row's ``end_id`` are still included. If ``max_len`` is 0,
            a [B, 1] tensor filled with ``end_id`` is returned.

        Raises:
            ValueError: if ``start_id`` is None.
        """
        if start_id is None:
            raise ValueError("start_id must be provided")
        if max_len is None:
            max_len = self.max_len
        if end_id is None:
            end_id = -1   # never matches a real token id
        if pad_id is None:
            pad_id = end_id  # harmless default

        device = features.device
        B = features.size(0)
        memory = features.unsqueeze(0)                                      # [1,B,E]
        generated = torch.full((B, 1), start_id, dtype=torch.long, device=device)
        finished  = torch.zeros(B, dtype=torch.bool, device=device)
        block_ngrams = bool(no_repeat_ngram_size) and no_repeat_ngram_size > 1

        ids_collected = []
        for _ in range(max_len):
            tgt = self.pos_enc(self.embed(generated).permute(1, 0, 2))      # [T,B,E]
            tgt_mask = self._generate_square_subsequent_mask(tgt.size(0), device)
            out = self.transformer_decoder(tgt, memory, tgt_mask=tgt_mask)  # [T,B,E]
            logits = self.linear_out(out[-1])                               # [B,V]

            # Ban repeating tokens *before* choosing, so a blocked argmax falls
            # back to the next-best token instead of terminating the caption.
            if block_ngrams:
                self._ban_repeated_ngrams(logits, generated, finished, no_repeat_ngram_size)

            if temperature != 1.0:
                logits = logits / max(temperature, 1e-6)
            probs = torch.softmax(logits, dim=-1)

            if top_k and top_k > 0:
                topk_vals, topk_idx = torch.topk(probs, k=min(top_k, probs.size(-1)), dim=-1)
                picked = torch.multinomial(topk_vals, num_samples=1)        # [B,1]
                next_ids = topk_idx.gather(1, picked).squeeze(1)            # [B]
            else:
                next_ids = probs.argmax(dim=-1)                             # [B]

            # <pad> is never a valid output token; treat it as end-of-sequence.
            next_ids = next_ids.masked_fill(next_ids == pad_id, end_id)

            ids_collected.append(next_ids)
            generated = torch.cat([generated, next_ids.unsqueeze(1)], dim=1)

            finished |= next_ids == end_id
            if torch.all(finished):
                break

        if not ids_collected:
            return torch.full((B, 1), end_id, dtype=torch.long, device=device)

        return torch.stack(ids_collected, dim=1)  # [B, T]


# --------------- Training: ---------------

clear_cache()

# --- Training schedule ---
BATCH_SIZE, VAL_BATCH_SIZE, NUM_EPOCHS = 64, 32, 5

# --- Architecture (encoder and decoder share the width EMBED_SIZE) ---
DROPOUT = 0.2
EMBED_SIZE = 768
ENC_NUM_LAYERS, DEC_NUM_LAYERS = 6, 3
NUM_HEADS = 8
DIM_FEEDFORWARD = 2048

OUTPUT_DIR = "./models_vit_no_tl"
os.makedirs(OUTPUT_DIR, exist_ok=True)

_vit_loader_kwargs = dict(collate_fn=collate_base, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
train_loader_vit = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, **_vit_loader_kwargs)
val_loader_vit = DataLoader(val_ds, batch_size=VAL_BATCH_SIZE, shuffle=False, **_vit_loader_kwargs)

vit_enc = ViTEncoder(
    embed_dim=EMBED_SIZE,
    patch_size=16,
    num_layers=ENC_NUM_LAYERS,
    num_heads=NUM_HEADS,
    mlp_dim=DIM_FEEDFORWARD,
    dropout=DROPOUT,
)
vit_dec = TransformerDecoderAdapter(
    embed_size=EMBED_SIZE,
    vocab_size=len(vocab),
    nhead=NUM_HEADS,
    num_layers=DEC_NUM_LAYERS,
    dim_feedforward=DIM_FEEDFORWARD,
    dropout=DROPOUT,
    max_len=MAX_LEN,
    pad_idx=PAD_IDX,
)
vit_enc = vit_enc.to(DEVICE).to(memory_format=torch.channels_last)
vit_dec = vit_dec.to(DEVICE)

# Separate AdamW optimizers so encoder and decoder can use different learning rates.
enc_opt = optim.AdamW(vit_enc.parameters(), lr=5e-4, weight_decay=0.05)
dec_opt = optim.AdamW(vit_dec.parameters(), lr=1e-3, weight_decay=0.05)

# Cross-entropy with label smoothing; '<pad>' targets are ignored.
CEL_label_smoothing = nn.CrossEntropyLoss(ignore_index=PAD_IDX, label_smoothing=0.1)

In [None]:
# Model 2: encoder and decoder are both optimized, with a label-smoothed loss.
vit_history = fit_model(
    enc=vit_enc,
    dec=vit_dec,
    enc_opt=enc_opt,
    dec_opt=dec_opt,
    criterion=CEL_label_smoothing,
    train_loader=train_loader_vit,
    val_loader=val_loader_vit,
    vocab=vocab,
    device=DEVICE,
    num_epochs=NUM_EPOCHS,
    output_dir=OUTPUT_DIR,
)

## Model 3: Pre-trained ViT as encoder, same text transformer decoder:

Image-captioning transformer model with a pre-trained image encoder.

Encoder: ViT with transfer learning from google/vit-base-patch16-224-in21k.

Decoder: Small text transformer, no transfer learning.

In [None]:
processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

# Input side length and normalization statistics expected by the pretrained ViT.
_vit_side = processor.size["height"]
_vit_normalize = transforms.Normalize(mean=processor.image_mean, std=processor.image_std)

# Training: resize, random crop to a square, random horizontal flip.
vit_tl_train = transforms.Compose([
    transforms.Resize(_vit_side),
    transforms.RandomCrop(_vit_side),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    _vit_normalize,
])

# Evaluation: resize and deterministic center crop.
vit_tl_val = transforms.Compose([
    transforms.Resize(_vit_side),
    transforms.CenterCrop(_vit_side),
    transforms.ToTensor(),
    _vit_normalize,
])

from transformers import AutoModel

class ViTEncoderTL(nn.Module):
    """
    Pre-trained ViT backbone used as the captioning encoder (transfer learning).

    Each image is summarised by the backbone's first output token
    (``last_hidden_state[:, 0]``, the CLS position). That token goes through
    LayerNorm -> Dropout -> Linear to produce a vector of size ``embed_size``.

    Args:
        model_name: checkpoint loaded with ``AutoModel.from_pretrained``
            (built without its pooling layer).
        embed_size: size of the output feature; should match the decoder's
            embedding size.
        trainable_layers: how many of the final ``encoder.layer`` blocks to
            fine-tune. Any value <= 0 keeps the whole backbone frozen. All
            other backbone parameters are always frozen. The projection head
            is always trainable.
        dropout: dropout probability applied before the projection.
    """
    def __init__(self, model_name="google/vit-base-patch16-224-in21k",
                 embed_size=768, trainable_layers=0, dropout=0.0):
        super().__init__()
        self.vit = AutoModel.from_pretrained(model_name, add_pooling_layer=False)
        width = self.vit.config.hidden_size

        # Start from a fully frozen backbone...
        self.vit.requires_grad_(False)
        # ...then re-enable gradients only for the last N transformer blocks.
        if trainable_layers > 0:
            for block in self.vit.encoder.layer[-trainable_layers:]:
                block.requires_grad_(True)

        # The head is created after the freeze above, so its parameters stay
        # trainable. The Sequential indices (0=LayerNorm, 1=Dropout, 2=Linear)
        # determine the state_dict keys; keep this order so saved checkpoints
        # still load.
        self.head = nn.Sequential(
            nn.LayerNorm(width),
            nn.Dropout(dropout),
            nn.Linear(width, embed_size),
        )

    def forward(self, x):
        """Encode preprocessed images into ``[B, embed_size]`` features.

        ``x`` is expected to be pixel values produced by the ``vit_tl_*``
        transforms (presumably ``[B, 3, 224, 224]``; TODO confirm against the
        processor size).
        """
        tokens = self.vit(pixel_values=x, output_hidden_states=False).last_hidden_state
        return self.head(tokens[:, 0])


# --- Pre-trained ViT (transfer learning) setup: data, models, optimizers, loss ---
# NOTE(review): this cell rebinds globals that the earlier ViT cells also read.
# EMBED_SIZE, enc_opt, dec_opt, NUM_EPOCHS and OUTPUT_DIR are used there;
# BATCH_SIZE / VAL_BATCH_SIZE possibly too (verify). Re-running an earlier
# cell after this one picks up the values set here. For example, enc_opt/dec_opt
# below only hold vit_tl_* parameters.
BATCH_SIZE = 64
VAL_BATCH_SIZE = 32

# Use TL transforms here:
# The train split is created with vocab=None, and its vocabulary is read back
# as train_ds_tl.vocab (presumably built from the training captions; confirm).
# That vocab is then passed to the val/test splits so all splits share the
# same token indices.
train_ds_tl = Flickr30kDataset(IMAGES_DIR, CAPTIONS_FILE, vocab=None, transform=vit_tl_train, split='train')
vocab_tl = train_ds_tl.vocab
PAD_IDX_TL = vocab_tl.stoi['<pad>']
collate_tl = make_collate_fn(PAD_IDX_TL)

# test_ds_tl gets no DataLoader here; it is passed directly to evaluate_bleu
# in the evaluation section.
val_ds_tl  = Flickr30kDataset(IMAGES_DIR, CAPTIONS_FILE, vocab=vocab_tl, transform=vit_tl_val, split='val')
test_ds_tl = Flickr30kDataset(IMAGES_DIR, CAPTIONS_FILE, vocab=vocab_tl, transform=vit_tl_val, split='test')

train_loader_tl = DataLoader(train_ds_tl, batch_size=BATCH_SIZE, shuffle=True,
                          collate_fn=collate_tl, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader_tl = DataLoader(val_ds_tl,  batch_size=VAL_BATCH_SIZE, shuffle=False,
                          collate_fn=collate_tl, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
EMBED_SIZE = 768         # must equal the decoder's embed_size; also the ViTEncoderTL default
vit_tl_enc = ViTEncoderTL(
    model_name="google/vit-base-patch16-224-in21k",
    embed_size=EMBED_SIZE,
    trainable_layers=2,     # try 0 (frozen), 2, or 4
    dropout=0.1
).to(DEVICE)

# NOTE(review): the decoder hyperparameters are hard-coded here (8 heads,
# 3 layers, FFN 2048, dropout 0.2) instead of using the NUM_HEADS /
# DEC_NUM_LAYERS / DIM_FEEDFORWARD / DROPOUT globals used by the earlier ViT
# decoder. Confirm this difference is intended when comparing the models.
vit_tl_dec = TransformerDecoderAdapter(
    embed_size=EMBED_SIZE, vocab_size=len(vocab_tl),
    nhead=8, num_layers=3, dim_feedforward=2048,
    dropout=0.2, max_len=MAX_LEN, pad_idx=PAD_IDX_TL
).to(DEVICE)

# Optimizers: smaller LR for (partially) unfrozen encoder, larger for decoder
# enc_params also contains the freshly initialised projection head, because
# ViTEncoderTL creates it after freezing the backbone. So the head trains at
# the small encoder LR (1e-5) as well. NOTE(review): consider a separate param
# group with a larger LR for the head. For the same reason enc_params is never
# empty, so the `else None` branch is purely defensive.
enc_params = [p for p in vit_tl_enc.parameters() if p.requires_grad]
enc_opt = optim.AdamW(enc_params, lr=1e-5, weight_decay=0.01) if enc_params else None
dec_opt = optim.AdamW(vit_tl_dec.parameters(), lr=1e-3, weight_decay=0.05)

NUM_EPOCHS = 5
OUTPUT_DIR = "./models_vit_tl"

# Plain cross-entropy that ignores padding positions. Unlike the earlier ViT
# model's loss, no label smoothing is used.
CEL_tl = nn.CrossEntropyLoss(ignore_index=PAD_IDX_TL)

In [None]:
# Train Model 3 (pre-trained ViT encoder + Transformer decoder).
# enc_opt, dec_opt, NUM_EPOCHS and OUTPUT_DIR are the values bound by the
# Model 3 setup cell just above. fit_model's return value is kept as
# vit_tl_history.
vit_tl_history = fit_model(
    enc=vit_tl_enc,
    dec=vit_tl_dec,
    enc_opt=enc_opt,
    dec_opt=dec_opt,
    train_loader=train_loader_tl,
    val_loader=val_loader_tl,
    criterion=CEL_tl,
    vocab=vocab_tl,
    device=DEVICE,
    num_epochs=NUM_EPOCHS,
    output_dir=OUTPUT_DIR,
)

## Eval:

In [None]:
# BLEU on the test split for the CNN encoder -> RNN decoder model, showing
# three example captions. limit=None presumably means "use the whole split";
# confirm in evaluate_bleu.
cnn_rnn_scores = evaluate_bleu(
    enc=CNN_enc,
    dec=RNN_dec,
    ds=test_ds,
    transform=val_transform,
    vocab=vocab,
    device=DEVICE,
    max_len=EVAL_MAX_LEN,
    limit=None,
    show_examples=3,
    use_amp=False,
)
print("CNN→RNN BLEU:", cnn_rnn_scores)

In [None]:
# BLEU on the test split for the ViTEncoder -> Transformer decoder model.
# NOTE(review): this reuses test_ds / val_transform from the CNN model. Confirm
# they match the preprocessing used by this model's training loaders, which
# are built outside this view.
vit_scores = evaluate_bleu(
    enc=vit_enc,
    dec=vit_dec,
    ds=test_ds,
    transform=val_transform,
    vocab=vocab,
    device=DEVICE,
    max_len=EVAL_MAX_LEN,
    limit=None,
    show_examples=3,
    use_amp=True,
)
print("ViT→Transformer BLEU:", vit_scores)

In [None]:
# BLEU on the test split for the pre-trained (transfer-learning) ViT model.
# It uses that model's own vocabulary and its deterministic eval transform.
vit_tl_scores = evaluate_bleu(
    enc=vit_tl_enc,
    dec=vit_tl_dec,
    ds=test_ds_tl,
    transform=vit_tl_val,
    vocab=vocab_tl,
    device=DEVICE,
    max_len=EVAL_MAX_LEN,
    limit=None,
    show_examples=3,
    use_amp=True,
)
print("TL ViT→Transformer BLEU:", vit_tl_scores)