In [7]:
# Standard library imports
import gc
import json
import math
import os
import pickle
import re
from collections import Counter, OrderedDict

# Third-party library imports
import nltk
import pandas as pd
from PIL import Image
from tqdm import tqdm

# PyTorch and related imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as transforms


In [8]:
#This code is needed for all the models
nltk.download('punkt_tab', quiet=True)
def tokenize_english(text):
    text = text.lower()
    text = re.sub(r"[\r\n\t]+", " ", text)
    text = re.sub(r"[^a-z0-9\.\,\!\?\:\;\’\'\-]", " ", text)
    tokens = nltk.word_tokenize(text)
    tokens = [t for t in tokens if any(c.isalnum() for c in t)]
    return tokens

with open('data/annotations/captions_train2017.json', 'r') as f:
    coco_train = json.load(f)

with open('data/annotations/captions_val2017.json', 'r') as f:
    coco_val = json.load(f)

img_id_to_filename = {img['id']: img['file_name'] for img in coco_train['images']}

annotations = []
for ann in coco_train['annotations']:
    fname = img_id_to_filename[ann['image_id']]
    tokens = tokenize_english(ann['caption'])
    seq = ['<start>'] + tokens + ['<end>']
    annotations.append((fname, seq))


min_freq = 5
counter = Counter(tok for _, seq in annotations for tok in seq)
words = [w for w, cnt in counter.items() if cnt >= min_freq]
specials = ['<pad>', '<start>', '<end>', '<unk>']
itos = specials + words
stoi = {w: i for i, w in enumerate(itos)}
vocab = {'itos': itos, 'stoi': stoi}

numerical = []
for fname, seq in annotations:
    ids = [stoi.get(tok, stoi['<unk>']) for tok in seq]
    numerical.append((fname, ids))


img_id_to_filename_val = {img['id']: img['file_name'] for img in coco_val['images']}

annotations_val = []
for ann in coco_val['annotations']:
    fname = img_id_to_filename_val[ann['image_id']]
    tokens = tokenize_english(ann['caption'])
    seq = ['<start>'] + tokens + ['<end>']
    annotations_val.append((fname, seq))

numerical_val = []
for fname, seq in annotations_val:
    ids = [stoi.get(tok, stoi['<unk>']) for tok in seq]
    numerical_val.append((fname, ids))


class CaptionImageDataset(Dataset):
    def __init__(self, img_dir, annotations, vocab, transform=None):
        self.img_dir = img_dir
        self.annotations = annotations
        self.vocab = vocab
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        fname, seq_ids = self.annotations[idx]
        img = Image.open(os.path.join(self.img_dir, fname)).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, torch.tensor(seq_ids, dtype=torch.long)

def collate_fn(batch):
    imgs, seqs = zip(*batch)
    imgs = torch.stack(imgs, 0)
    lengths = [len(s) for s in seqs]
    max_len = max(lengths)
    padded = torch.zeros(len(seqs), max_len, dtype=torch.long)
    for i, s in enumerate(seqs):
        padded[i, :lengths[i]] = s
    return imgs, padded, lengths

[nltk_data] Error loading punkt_tab: <urlopen error [Errno -3]
[nltk_data]     Temporary failure in name resolution>


# 1. NO PRETRAINED CNN

In [9]:
# Dataset that will only be used for computing mean and std
class SimpleImageDataset(Dataset):
    def __init__(self, img_dir, transform=None):
        self.img_paths = [
            os.path.join(img_dir, fname)
            for fname in os.listdir(img_dir)
            if fname.lower().endswith(('.jpg', '.png'))
        ]
        self.transform = transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img = Image.open(self.img_paths[idx]).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img

# Function to compute mean and std
def compute_mean_std(img_dir, batch_size=32, num_workers=8):
    basic_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
    ])

    dataset_mean = SimpleImageDataset(img_dir, transform=basic_transform)
    loader_mean = DataLoader(dataset_mean, batch_size=batch_size,
                        shuffle=False, num_workers=num_workers,
                        pin_memory=True)

    sum_rgb = torch.zeros(3)
    sum_squared = torch.zeros(3)
    num_pixels = 0

    for batch in loader_mean:
        b, c, h, w = batch.shape
        sum_rgb += batch.sum(dim=[0, 2, 3])
        sum_squared += (batch ** 2).sum(dim=[0, 2, 3])
        num_pixels += b * h * w

    mean = sum_rgb / num_pixels
    std = torch.sqrt((sum_squared / num_pixels) - mean ** 2)
    return mean, std

# Compute mean and std for the training images and create the datasets and loaders
img_directory = 'data/images/train2017'
if os.path.exists('image_stats.pkl'):
    with open('image_stats.pkl', 'rb') as f:
        stats = pickle.load(f)
        mean, std = stats['mean'], stats['std']
else:
    mean, std = compute_mean_std(img_directory)
    stats = {'mean': mean, 'std': std}
    with open('image_stats.pkl', 'wb') as f:
        pickle.dump(stats, f)

train_transform_no_pretrained = transforms.Compose([
    transforms.RandomResizedCrop(
        256, 
        scale=(0.9, 1.0), 
        ratio=(0.9, 1.1)
    ),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.ColorJitter(
        brightness=0.1, 
        contrast=0.1, 
        saturation=0.1, 
        hue=0.05
    ),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])
dataset_no_pretrained = CaptionImageDataset(
    img_dir=img_directory,
    annotations=numerical,
    vocab=vocab,
    transform=train_transform_no_pretrained
)

batch_size = 32
train_loader_no_pretrained = DataLoader(
    dataset_no_pretrained,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    collate_fn=collate_fn
)

val_transform_no_pretrained = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])
val_img_directory = 'data/images/val2017'

val_dataset_no_pretrained = CaptionImageDataset(
    img_dir=val_img_directory,
    annotations=numerical_val,
    vocab=vocab,
    transform=val_transform_no_pretrained
)

val_loader_no_pretrained = DataLoader(
    val_dataset_no_pretrained,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    collate_fn=collate_fn
)

# Define the models
# Encoder
class SimpleEncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(SimpleEncoderCNN, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )
        self.block4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1,1))
        )
        self.embed = nn.Linear(512, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = x.view(x.size(0), -1)
        x = self.embed(x)
        x = self.bn(x)
        return x
    
# Positional Encoding for Transformer
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return x

# Decoder
class TransformerDecoder(nn.Module):
    def __init__(self, embed_size, num_heads, ff_dim, vocab_size, max_len=70, dropout=0.1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.pos_enc = PositionalEncoding(embed_size, max_len)
        self.self_attn = nn.MultiheadAttention(embed_size, num_heads, dropout=dropout)
        self.cross_attn = nn.MultiheadAttention(embed_size, num_heads, dropout=dropout)
        self.ffn = nn.Sequential(
            nn.Linear(embed_size, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_size),
        )
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.norm3 = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(embed_size, vocab_size)

    def _generate_square_mask(self, sz, device):
        mask = torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1).to(device)
        return mask

    def forward(self, features, captions):
        tgt = captions[:, :-1]  
        x = self.embed(tgt) * math.sqrt(self.embed.embedding_dim)
        x = self.pos_enc(x)
        x = x.transpose(0, 1) 

        mask = self._generate_square_mask(x.size(0), x.device)
        attn_out, _ = self.self_attn(x, x, x, attn_mask=mask)
        x = self.norm1(x + self.dropout(attn_out))

        mem = features.unsqueeze(0)           
        cross_out, _ = self.cross_attn(x, mem, mem)
        x = self.norm2(x + self.dropout(cross_out))

        ff_out = self.ffn(x)
        x = self.norm3(x + self.dropout(ff_out))

        x = x.transpose(0, 1)      
        logits = self.out(x)                                
        return logits

In [11]:
#Hyperparameters
embed_size  = 512
num_heads   = 8
hidden_dim  = 2048
num_epochs = 4
lr = 1e-2
weight_decay = 1e-4
warmup_steps = 20000

vocab_size  = len(vocab['itos'])
#Load the model to the GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gc.collect()
torch.cuda.empty_cache()

#Model
encoder = SimpleEncoderCNN(embed_size).to(device)
decoder = TransformerDecoder(
    embed_size, num_heads, hidden_dim, vocab_size
).to(device)

#Loss function, optimizer and learning rate scheduler
criterion = nn.CrossEntropyLoss(ignore_index=vocab['stoi']['<pad>'])
optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(decoder.parameters()),
    lr=lr,
    weight_decay=weight_decay,
)
def get_transformer_scheduler(optimizer, d_model, warmup_steps=1000):
    def lr_lambda(step):
        step = max(step, 1)
        return (d_model ** -0.5) * min(step ** -0.5, step * (warmup_steps ** -1.5))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
scaler = torch.amp.GradScaler(
    'cuda',
    init_scale=512, 
    growth_interval=2000,
    growth_factor=2.0,
    backoff_factor=0.5
)
scheduler = get_transformer_scheduler(optimizer, embed_size, warmup_steps=warmup_steps)
best_val_loss = float('inf')

#Training loop
for epoch in range(1, num_epochs + 1):
    encoder.train()
    decoder.train()
    train_loss = 0.0

    pbar = tqdm(train_loader_no_pretrained, desc=f"[Epoch {epoch}/{num_epochs}] Train")
    for images, captions, lengths in pbar:
        images, captions = images.to(device), captions.to(device)

        optimizer.zero_grad()
        with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
            feats   = encoder(images)
            logits  = decoder(feats, captions)

        loss = criterion(
            logits.float().view(-1, vocab_size),
            captions[:, 1:].contiguous().view(-1)
        )
        if torch.isnan(loss) or torch.isinf(loss):
            continue
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(
            list(encoder.parameters()) + list(decoder.parameters()),
            max_norm=0.5
        )
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        train_loss += loss.item()
        lr = scheduler.get_last_lr()[0]
        pbar.set_postfix(lastLoss=f"{loss.item():.4f}", meanLoss=f"{train_loss/(pbar.n+1):.4f}", lr=f"{lr:.2e}")

    
    avg_train = train_loss / len(train_loader_no_pretrained)

    encoder.eval()
    decoder.eval()
    val_loss = 0.0
    # Validation
    with torch.no_grad():
        for images, captions, lengths in tqdm(val_loader_no_pretrained, desc=f"[Epoch {epoch}/{num_epochs}] Val  "):
            images, captions = images.to(device), captions.to(device)
            with torch.amp.autocast('cuda'):
                feats   = encoder(images)
                outputs = decoder(feats, captions)
                loss    = criterion(
                    outputs.view(-1, vocab_size),
                    captions[:, 1:].contiguous().view(-1)
                )
            val_loss += loss.item()

    avg_val = val_loss / len(val_loader_no_pretrained)
    print(f"Epoch {epoch}: Train Loss = {avg_train:.4f}, Val Loss = {avg_val:.4f}")

    if avg_val < best_val_loss:
        best_val_loss = avg_val
        ckpt = {
            'epoch': epoch,
            'encoder_state': encoder.state_dict(),
            'decoder_state': decoder.state_dict(),
            'optim_state': optimizer.state_dict(),
            'sched_state': scheduler.state_dict(),
            'val_loss': best_val_loss
        }
        torch.save(ckpt, 'best_model_no_pretrained.pth')
        print(f"best_model_no_pretrained updated (Val Loss: {best_val_loss:.4f})")

[Epoch 1/4] Train: 100%|██████████| 18024/18024 [40:44<00:00,  7.37it/s, lastLoss=3.9643, lr=2.82e-06, meanLoss=6.2744]
[Epoch 1/4] Val  : 100%|██████████| 782/782 [00:48<00:00, 16.13it/s]


Epoch 1: Train Loss = 6.2744, Val Loss = 4.5102
best_model_no_pretrained updated (Val Loss: 4.5102)


[Epoch 2/4] Train: 100%|██████████| 18024/18024 [40:59<00:00,  7.33it/s, lastLoss=2.8937, lr=2.33e-06, meanLoss=4.2410]
[Epoch 2/4] Val  : 100%|██████████| 782/782 [00:52<00:00, 14.98it/s]


Epoch 2: Train Loss = 4.2410, Val Loss = 4.0112
best_model_no_pretrained updated (Val Loss: 4.0112)


[Epoch 3/4] Train: 100%|██████████| 18024/18024 [40:57<00:00,  7.33it/s, lastLoss=3.8818, lr=1.90e-06, meanLoss=3.9599]
[Epoch 3/4] Val  : 100%|██████████| 782/782 [00:52<00:00, 14.79it/s]


Epoch 3: Train Loss = 3.9599, Val Loss = 3.8433
best_model_no_pretrained updated (Val Loss: 3.8433)


[Epoch 4/4] Train: 100%|██████████| 18024/18024 [41:05<00:00,  7.31it/s, lastLoss=3.3516, lr=1.65e-06, meanLoss=3.8348]
[Epoch 4/4] Val  : 100%|██████████| 782/782 [00:48<00:00, 16.10it/s]


Epoch 4: Train Loss = 3.8348, Val Loss = 3.7478
best_model_no_pretrained updated (Val Loss: 3.7478)


# 2. Pretrained CNN

In [6]:
IMNET_MEAN, IMNET_STD = [0.485,0.456,0.406], [0.229,0.224,0.225] # Imagenet mean and std, resnet is trained on imagenet

#Transforms for data
tr_tf = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.9,1.0), ratio=(0.9,1.1)),
    transforms.RandomHorizontalFlip(), transforms.RandomRotation(10),
    transforms.ColorJitter(0.1,0.1,0.1,0.05),
    transforms.ToTensor(), transforms.Normalize(IMNET_MEAN, IMNET_STD)])
va_tf = transforms.Compose([
    transforms.Resize(256), transforms.CenterCrop(224),
    transforms.ToTensor(),  transforms.Normalize(IMNET_MEAN, IMNET_STD)])

#Training dataset and dataloader for pretrained model
img_dir='data/images/train2017'
train_dataset_pretrained = CaptionImageDataset(
    img_dir=img_dir,
    annotations=numerical,
    vocab=vocab,
    transform=tr_tf
)
train_loader_pretrained = DataLoader(
    train_dataset_pretrained,
    batch_size=32,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    collate_fn=collate_fn
)

#Validation dataset and dataloader for pretrained model
val_img_dir='data/images/val2017'
val_dataset_pretrained = CaptionImageDataset(
    img_dir=val_img_dir,
    annotations=numerical_val,
    vocab=vocab,
    transform=va_tf
)
val_loader_pretrained = DataLoader(
    val_dataset_pretrained,
    batch_size=32,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    collate_fn=collate_fn
)

# Positional encoding for 2D images
class PE2D(nn.Module):
    def __init__(self, d, h=7, w=7):
        super().__init__()
        if d % 4: raise ValueError("d_model must be divisible by 4")
        pe = torch.zeros(d, h, w)
        d_half = d // 2
        div = torch.exp(torch.arange(0, d_half, 2) * (-math.log(10000.0) / d_half))

        pos_w = torch.arange(w).unsqueeze(1)
        pos_h = torch.arange(h).unsqueeze(1)

        sin_w = torch.sin(pos_w * div).T.unsqueeze(1).repeat(1, h, 1)
        cos_w = torch.cos(pos_w * div).T.unsqueeze(1).repeat(1, h, 1)
        pe[0:d_half:2]  = sin_w
        pe[1:d_half:2]  = cos_w

        sin_h = torch.sin(pos_h * div).T.unsqueeze(2).repeat(1, 1, w)
        cos_h = torch.cos(pos_h * div).T.unsqueeze(2).repeat(1, 1, w)
        pe[d_half::2]   = sin_h
        pe[d_half+1::2] = cos_h

        self.register_buffer('pe', pe.flatten(1).T.unsqueeze(1))
    def forward(self, x):
        return x + self.pe[:x.size(0)]

# Encoder for pretrained model
class ResNetEncoder(nn.Module):
    def __init__(self, d=512):
        super().__init__()
        rn = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        self.backbone = nn.Sequential(*list(rn.children())[:-2])
        for p in self.backbone.parameters(): p.requires_grad=False
        self.proj = nn.Conv2d(2048, d, 1)
        nn.init.xavier_uniform_(self.proj.weight)
        self.pe2d = PE2D(d, 7, 7)
    def forward(self, x):
        x = self.proj(self.backbone(x))
        x = x.flatten(2).permute(2,0,1)
        return self.pe2d(x)

# Decoder for pretrained model
class DecoderBlock(nn.Module):
    def __init__(self, d, heads, d_ff, p=0.1):
        super().__init__()
        self.self_attn  = nn.MultiheadAttention(d, heads, dropout=p, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d, heads, dropout=p, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d, d_ff), nn.ReLU(), nn.Dropout(p),
                                nn.Linear(d_ff, d))
        self.norm1 = nn.LayerNorm(d); self.drop1 = nn.Dropout(p)
        self.norm2 = nn.LayerNorm(d); self.drop2 = nn.Dropout(p)
        self.norm3 = nn.LayerNorm(d); self.drop3 = nn.Dropout(p)
    def forward(self, x, mem, attn_mask, pad_mask):
        sa,_ = self.self_attn(x,x,x, attn_mask=attn_mask, key_padding_mask=pad_mask)
        x = self.norm1(x + self.drop1(sa))
        ca,_ = self.cross_attn(x, mem, mem)
        x = self.norm2(x + self.drop2(ca))
        x = self.norm3(x + self.drop3(self.ff(x)))
        return x
class CaptionDecoder(nn.Module):
    def __init__(self, vocab, d=512, heads=8, d_ff=2048,
                 layers=6, dropout=0.1, max_len=70):
        super().__init__()
        self.tok = nn.Embedding(vocab, d)
        self.pos = nn.Embedding(max_len, d)
        self.blocks = nn.ModuleList(
            DecoderBlock(d, heads, d_ff, dropout) for _ in range(layers))
        self.ln = nn.LayerNorm(d)
        self.out = nn.Linear(d, vocab)
        self.d = d
    def forward(self, mem, caps, pad=0):
        B,L = caps.size()
        pos = torch.arange(L-1, device=caps.device).unsqueeze(0)
        x = self.tok(caps[:,:-1]) * math.sqrt(self.d) + self.pos(pos)
        tgt_mask = torch.triu(torch.ones((L-1, L-1),
                                 dtype=torch.bool,
                                 device=caps.device), 1)
        kpm = (caps[:,:-1] == pad)
        mem = mem.permute(1,0,2)
        for blk in self.blocks:
            x = blk(x, mem, tgt_mask, kpm)
        x = self.ln(x)
        return self.out(x)

In [None]:
# Hyperparameters
lr_encoder = 1e-3
lr_decoder = 4e-3
embed_size_pretrained = 512
warmup_steps = 50000
EPOCHS = 15

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define the models
encoder = ResNetEncoder(embed_size_pretrained).to(device)
decoder = CaptionDecoder(len(vocab['itos']), embed_size_pretrained).to(device)

# Define the loss function (cross entropy)
criterion = nn.CrossEntropyLoss(ignore_index=stoi['<pad>'], label_smoothing=0.1)
optimizer = torch.optim.AdamW([
    {'params': decoder.parameters(),          'lr': lr_decoder},
    {'params': encoder.proj.parameters(),     'lr': lr_encoder},
], weight_decay=1e-4)
def transformer_sched(opt, d=embed_size_pretrained, warm=warmup_steps):
    f = lambda s: (d**-0.5) * min(s**-0.5, s*warm**-1.5) if s else 0.
    return torch.optim.lr_scheduler.LambdaLR(opt, f)
scheduler = transformer_sched(optimizer)
scaler = torch.amp.GradScaler(device='cuda')

# Training loop
best = float('inf')
for ep in range(1, EPOCHS+1):
    encoder.train(); decoder.train()
    tloss = 0
    samples_seen = 0
    pbar = tqdm(train_loader_pretrained, desc=f"[Epoch {ep}/{EPOCHS}] Train")
    for imgs,caps,_ in pbar:
        imgs,caps = imgs.to(device), caps.to(device)
        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device.type, enabled=device.type=='cuda'):
            mem = encoder(imgs)
            logits = decoder(mem, caps)
            loss = criterion(logits.reshape(-1,len(vocab['itos'])), caps[:,1:].reshape(-1))
        if torch.isnan(loss) or torch.isinf(loss):
            continue
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(decoder.parameters(), 0.5)
        scaler.step(optimizer); scaler.update(); scheduler.step()
        
        batch_loss = loss.item()
        current_lr = scheduler.get_last_lr()[0]
        samples_seen += imgs.size(0)

        tloss += batch_loss*imgs.size(0)
        pbar.set_postfix(lastLoss=f"{batch_loss:.4f}", meanLoss=f"{tloss/(samples_seen):.4f}", lr=f"{current_lr:.2e}")

    tloss /= len(train_loader_pretrained.dataset)

    encoder.eval(); decoder.eval()
    vloss = 0
    with torch.no_grad():
        for imgs,caps,_ in tqdm(val_loader_pretrained, desc=f'ep{ep} val'):
            imgs,caps = imgs.to(device), caps.to(device)
            with torch.autocast(device.type, enabled=device.type=='cuda'):
                mem = encoder(imgs)
                logits = decoder(mem, caps)
                loss = criterion(logits.reshape(-1,len(vocab['itos'])), caps[:,1:].reshape(-1))
            vloss += loss.item()*imgs.size(0)
    vloss /= len(val_loader_pretrained.dataset)

    print(f'→ epoch {ep}: train {tloss:.4f} | val {vloss:.4f}')
    if vloss < best:
        best = vloss
        torch.save({'encoder': encoder.state_dict(),
            'decoder': decoder.state_dict(),
            'itos': itos,
            'stoi': stoi},
           'best_model_pretrained.pth')
        print(f"best_model_pretrained updated (Val Loss: {best:.4f})")
    
    torch.cuda.empty_cache()


[Epoch 1/15] Train: 100%|██████████| 18024/18024 [27:50<00:00, 10.79it/s, lastLoss=6.2569, lr=2.85e-07, meanLoss=7.4869]
ep1 val: 100%|██████████| 782/782 [00:42<00:00, 18.52it/s]


→ epoch 1: train 7.4869 | val 6.2537
best_model_pretrained updated (Val Loss: 6.2537)


[Epoch 2/15] Train: 100%|██████████| 18024/18024 [28:09<00:00, 10.67it/s, lastLoss=4.8921, lr=5.70e-07, meanLoss=5.6661]
ep2 val: 100%|██████████| 782/782 [00:41<00:00, 18.93it/s]


→ epoch 2: train 5.6661 | val 5.1725
best_model_pretrained updated (Val Loss: 5.1725)


[Epoch 3/15] Train: 100%|██████████| 18024/18024 [28:15<00:00, 10.63it/s, lastLoss=4.9420, lr=7.60e-07, meanLoss=4.9672]
ep3 val: 100%|██████████| 782/782 [00:41<00:00, 18.93it/s]


→ epoch 3: train 4.9672 | val 4.7133
best_model_pretrained updated (Val Loss: 4.7133)


[Epoch 4/15] Train: 100%|██████████| 18024/18024 [28:14<00:00, 10.64it/s, lastLoss=5.0846, lr=6.58e-07, meanLoss=4.6383]
ep4 val: 100%|██████████| 782/782 [00:41<00:00, 18.70it/s]


→ epoch 4: train 4.6383 | val 4.4884
best_model_pretrained updated (Val Loss: 4.4884)


[Epoch 5/15] Train: 100%|██████████| 18024/18024 [28:12<00:00, 10.65it/s, lastLoss=3.8941, lr=5.89e-07, meanLoss=4.4642]
ep5 val: 100%|██████████| 782/782 [00:44<00:00, 17.74it/s]


→ epoch 5: train 4.4642 | val 4.3503
best_model_pretrained updated (Val Loss: 4.3503)


[Epoch 6/15] Train: 100%|██████████| 18024/18024 [28:18<00:00, 10.61it/s, lastLoss=4.7821, lr=5.38e-07, meanLoss=4.3527]
ep6 val: 100%|██████████| 782/782 [00:41<00:00, 18.73it/s]


→ epoch 6: train 4.3527 | val 4.2603
best_model_pretrained updated (Val Loss: 4.2603)


[Epoch 7/15] Train: 100%|██████████| 18024/18024 [28:18<00:00, 10.61it/s, lastLoss=4.6283, lr=4.98e-07, meanLoss=4.2725]
ep7 val: 100%|██████████| 782/782 [00:41<00:00, 18.75it/s]


→ epoch 7: train 4.2725 | val 4.1905
best_model_pretrained updated (Val Loss: 4.1905)


[Epoch 8/15] Train: 100%|██████████| 18024/18024 [28:11<00:00, 10.65it/s, lastLoss=5.3288, lr=4.66e-07, meanLoss=4.2109]
ep8 val: 100%|██████████| 782/782 [00:42<00:00, 18.22it/s]


→ epoch 8: train 4.2109 | val 4.1399
best_model_pretrained updated (Val Loss: 4.1399)


[Epoch 9/15] Train: 100%|██████████| 18024/18024 [28:25<00:00, 10.57it/s, lastLoss=4.5949, lr=4.39e-07, meanLoss=4.1616]
ep9 val: 100%|██████████| 782/782 [00:41<00:00, 18.91it/s]


→ epoch 9: train 4.1616 | val 4.0950
best_model_pretrained updated (Val Loss: 4.0950)


[Epoch 10/15] Train: 100%|██████████| 18024/18024 [28:41<00:00, 10.47it/s, lastLoss=3.9350, lr=4.16e-07, meanLoss=4.1204]
ep10 val: 100%|██████████| 782/782 [00:41<00:00, 18.97it/s]


→ epoch 10: train 4.1204 | val 4.0560
best_model_pretrained updated (Val Loss: 4.0560)


[Epoch 11/15] Train:  28%|██▊       | 5106/18024 [08:10<21:06, 10.20it/s, lastLoss=3.8725, lr=4.11e-07, meanLoss=4.0963]

# EVALUATION OF THE MODELS

In [3]:
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.spice.spice import Spice
from collections import defaultdict

In [13]:
with open('data/annotations/captions_test2017.json', 'r') as f:
    coco_test = json.load(f)
# Function to generate captions for images
if (torch.cuda.is_available()):
    device = torch.device('cuda')
@torch.no_grad()
def generate_caption(img_path, encoder, decoder, infer_tf, table, max_len=70, top_k=1, temperature=1.0):
    img = Image.open(img_path).convert('RGB')
    img = infer_tf(img).unsqueeze(0).to(device)

    with torch.amp.autocast(device.type):
        feat = encoder(img)

    caption_idx = [stoi['<start>']]
    steps = []
    for _ in range(max_len - 1):
        seq = torch.tensor(caption_idx + [stoi['<pad>']],
                           device=device).unsqueeze(0)

        with torch.amp.autocast(device.type):
            logits = decoder(feat, seq)

        logits = logits[0, -1] / temperature
        probs  = torch.softmax(logits, dim=-1)
        top_p, top_i = torch.topk(probs, top_k)

        steps.append([(itos[i], p.item()) for i, p in zip(top_i, top_p)])
        next_idx = top_i[0].item()
        caption_idx.append(next_idx)

        if next_idx == stoi['<end>']:
            break
        
    if table:
        rows = []
        for pos, token_probs in enumerate(steps, start=1):
            row = OrderedDict(pos=pos)
            for r, (tok, p) in enumerate(token_probs, start=1):
                row[f'token{r}'] = tok
                row[f'prob{r}']  = f'{p:.4f}'
            rows.append(row)
        df = pd.DataFrame(rows).set_index('pos')
        print(df.to_string())

    words = [itos[i] for i in caption_idx[1:-1]]
    return ' '.join(words)


def evaluate_model(encoder, decoder, test_dataset, transform, itos, stoi, spice=True):
    references = {}
    hypotheses = {}
    
    # Get image paths and reference captions
    img_paths = []
    img_captions = defaultdict(list)
    
    for ann in test_dataset['annotations']:
        img_id = ann['image_id']
        for img in test_dataset['images']:
            if img['id'] == img_id:
                img_path = f"data/images/test2017/{img['file_name']}"
                img_paths.append((img_id, img_path))
                caption = ' '.join(tokenize_english(ann['caption']))
                img_captions[img_id].append(caption)
                break
    
    # Remove duplicates while preserving order
    unique_paths = []
    seen = set()
    for img_id, path in img_paths:
        if img_id not in seen:
            seen.add(img_id)
            unique_paths.append((img_id, path))
    
    # Generate captions for each image
    for i, (img_id, img_path) in enumerate(tqdm(unique_paths)):
        try:
            generated_caption = generate_caption(img_path, encoder, decoder, transform, table=False)
            hypotheses[str(i)] = [generated_caption]
            references[str(i)] = [cap for cap in img_captions[img_id]]
        except Exception as e:
            print(f"Error processing {img_path}: {e}")
    
    # Calculate metrics
    metrics = {
        'BLEU': Bleu(4),
        'ROUGE': Rouge(),
        'METEOR': Meteor(),
        'CIDEr': Cider(),
    }
    
    # Try to use SPICE if available
    if spice:
        metrics['SPICE'] = Spice()
    results = {}
    for metric_name, metric in metrics.items():
        if metric_name == 'BLEU':
            score, _ = metric.compute_score(references, hypotheses)
            for i, n in enumerate([1, 2, 3, 4]):
                results[f'BLEU-{n}'] = score[i]
        else:
            score, _ = metric.compute_score(references, hypotheses)
            results[metric_name] = score
    
    return results

## No pretrained CNN

In [14]:
embed_size_no_pretrained = 512
itos = vocab['itos']
stoi = vocab['stoi']
pad_idx   = stoi['<pad>']
start_idx = stoi['<start>']
end_idx   = stoi['<end>']
num_heads  = 8
hidden_dim = 2048
# Create the models
encoder = SimpleEncoderCNN(embed_size_no_pretrained).to(device)
decoder = TransformerDecoder(
    embed_size_no_pretrained, num_heads, hidden_dim, len(itos)
).to(device)

# Load best models
ckpt = torch.load('best_model_no_pretrained.pth', map_location=device)
encoder.load_state_dict(ckpt['encoder_state'])
decoder.load_state_dict(ckpt['decoder_state'])
encoder.eval(); decoder.eval()

infer_transform_no_pretrained = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])
# Evaluate the model
results = evaluate_model(encoder, decoder, coco_test, infer_transform_no_pretrained, itos, stoi, False)

# Print the results
print("Evaluation results of the model without pretraining:")
for metric, score in results.items():
    if isinstance(score, list):
        print(f"{metric}: {', '.join([f'{s:.4f}' for s in score])}")
    else:
        print(f"{metric}: {score:.4f}")

100%|██████████| 3000/3000 [01:52<00:00, 26.66it/s]


{'testlen': 30108, 'reflen': 26487, 'guess': [30108, 27108, 24108, 21108], 'correct': [12756, 4122, 1171, 300]}
ratio: 1.1367085740173997
Evaluation results of the model without pretraining:
BLEU-1: 0.4237
BLEU-2: 0.2538
BLEU-3: 0.1463
BLEU-4: 0.0817
ROUGE: 0.3699
METEOR: 0.1265
CIDEr: 0.2118


## Pretrained CNN

In [None]:
embed_size_pretrained = 512
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

# Create the models
encoder_pretrained = ResNetEncoder(embed_size_pretrained).to(device)
decoder_pretrained = CaptionDecoder(len(vocab['itos']), embed_size_pretrained).to(device)

# Load best models
ckpt = torch.load('best_model_pretrained.pth', map_location=device)
encoder_pretrained.load_state_dict(ckpt['encoder'])
decoder_pretrained.load_state_dict(ckpt['decoder'])
encoder_pretrained.eval(); decoder_pretrained.eval()


infer_tf_pretrained = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
results = evaluate_model(
    encoder_pretrained,
    decoder_pretrained,
    coco_test,
    infer_tf_pretrained,
    itos,
    stoi
)
print("Evaluation results for the model that uses ResNet50:")
for metric, score in results.items():
    if isinstance(score, list):
        print(f"{metric}: {', '.join([f'{s:.4f}' for s in score])}")
    else:
        print(f"{metric}: {score:.4f}")

100%|██████████| 3000/3000 [07:31<00:00,  6.65it/s]


{'testlen': 28934, 'reflen': 28719, 'guess': [28934, 25934, 22934, 19934], 'correct': [18662, 8636, 3351, 1230]}
ratio: 1.0074863330895572


Parsing reference captions
Parsing test captions


SPICE evaluation took: 2.753 s
Evaluation results for the model that uses ResNet50:
BLEU-1: 0.6450
BLEU-2: 0.4634
BLEU-3: 0.3154
BLEU-4: 0.2098
ROUGE: 0.4716
METEOR: 0.2022
CIDEr: 0.6338
SPICE: 0.1335
