In [12]:
# Standard library imports
import gc
import json
import math
import os
import pickle
import re
from collections import Counter, OrderedDict

# Third-party library imports
import nltk
import pandas as pd
from PIL import Image
from tqdm import tqdm

# PyTorch and related imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as transforms


In [13]:
#This code is needed for all the models
nltk.download('punkt_tab', quiet=True)
def tokenize_english(text):
    text = text.lower()
    text = re.sub(r"[\r\n\t]+", " ", text)
    text = re.sub(r"[^a-z0-9\.\,\!\?\:\;\’\'\-]", " ", text)
    tokens = nltk.word_tokenize(text)
    tokens = [t for t in tokens if any(c.isalnum() for c in t)]
    return tokens

with open('data/annotations/captions_train2017.json', 'r') as f:
    coco_train = json.load(f)

with open('data/annotations/captions_val2017.json', 'r') as f:
    coco_val = json.load(f)

img_id_to_filename = {img['id']: img['file_name'] for img in coco_train['images']}

annotations = []
for ann in coco_train['annotations']:
    fname = img_id_to_filename[ann['image_id']]
    tokens = tokenize_english(ann['caption'])
    seq = ['<start>'] + tokens + ['<end>']
    annotations.append((fname, seq))


min_freq = 5
counter = Counter(tok for _, seq in annotations for tok in seq)
words = [w for w, cnt in counter.items() if cnt >= min_freq]
specials = ['<pad>', '<start>', '<end>', '<unk>']
itos = specials + words
stoi = {w: i for i, w in enumerate(itos)}
vocab = {'itos': itos, 'stoi': stoi}

numerical = []
for fname, seq in annotations:
    ids = [stoi.get(tok, stoi['<unk>']) for tok in seq]
    numerical.append((fname, ids))


img_id_to_filename_val = {img['id']: img['file_name'] for img in coco_val['images']}

annotations_val = []
for ann in coco_val['annotations']:
    fname = img_id_to_filename_val[ann['image_id']]
    tokens = tokenize_english(ann['caption'])
    seq = ['<start>'] + tokens + ['<end>']
    annotations_val.append((fname, seq))

numerical_val = []
for fname, seq in annotations_val:
    ids = [stoi.get(tok, stoi['<unk>']) for tok in seq]
    numerical_val.append((fname, ids))


class CaptionImageDataset(Dataset):
    def __init__(self, img_dir, annotations, vocab, transform=None):
        self.img_dir = img_dir
        self.annotations = annotations
        self.vocab = vocab
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        fname, seq_ids = self.annotations[idx]
        img = Image.open(os.path.join(self.img_dir, fname)).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, torch.tensor(seq_ids, dtype=torch.long)

def collate_fn(batch):
    imgs, seqs = zip(*batch)
    imgs = torch.stack(imgs, 0)
    lengths = [len(s) for s in seqs]
    max_len = max(lengths)
    padded = torch.zeros(len(seqs), max_len, dtype=torch.long)
    for i, s in enumerate(seqs):
        padded[i, :lengths[i]] = s
    return imgs, padded, lengths

# 1. USING TRANSFORMERS

## 1.1 NO PRETRAINED CNN

In [14]:
# Dataset that will only be used for computing mean and std
class SimpleImageDataset(Dataset):
    def __init__(self, img_dir, transform=None):
        self.img_paths = [
            os.path.join(img_dir, fname)
            for fname in os.listdir(img_dir)
            if fname.lower().endswith(('.jpg', '.png'))
        ]
        self.transform = transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img = Image.open(self.img_paths[idx]).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img

# Function to compute mean and std
def compute_mean_std(img_dir, batch_size=64, num_workers=8):
    basic_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
    ])

    dataset_mean = SimpleImageDataset(img_dir, transform=basic_transform)
    loader_mean = DataLoader(dataset_mean, batch_size=batch_size,
                        shuffle=False, num_workers=num_workers,
                        pin_memory=True)

    sum_rgb = torch.zeros(3)
    sum_squared = torch.zeros(3)
    num_pixels = 0

    for batch in loader_mean:
        b, c, h, w = batch.shape
        sum_rgb += batch.sum(dim=[0, 2, 3])
        sum_squared += (batch ** 2).sum(dim=[0, 2, 3])
        num_pixels += b * h * w

    mean = sum_rgb / num_pixels
    std = torch.sqrt((sum_squared / num_pixels) - mean ** 2)
    return mean, std

# Compute mean and std for the training images and create the datasets and loaders
img_directory = 'data/images/train2017'
if os.path.exists('image_stats.pkl'):
    with open('image_stats.pkl', 'rb') as f:
        stats = pickle.load(f)
        mean, std = stats['mean'], stats['std']
else:
    mean, std = compute_mean_std(img_directory)
    stats = {'mean': mean, 'std': std}
    with open('image_stats.pkl', 'wb') as f:
        pickle.dump(stats, f)

train_transform_no_pretrained = transforms.Compose([
    transforms.RandomResizedCrop(
        256, 
        scale=(0.9, 1.0), 
        ratio=(0.9, 1.1)
    ),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.ColorJitter(
        brightness=0.1, 
        contrast=0.1, 
        saturation=0.1, 
        hue=0.05
    ),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])
dataset_no_pretrained = CaptionImageDataset(
    img_dir=img_directory,
    annotations=numerical,
    vocab=vocab,
    transform=train_transform_no_pretrained
)

batch_size = 64
train_loader_no_pretrained = DataLoader(
    dataset_no_pretrained,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    collate_fn=collate_fn
)

val_transform_no_pretrained = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])
val_img_directory = 'data/images/val2017'

val_dataset_no_pretrained = CaptionImageDataset(
    img_dir=val_img_directory,
    annotations=numerical_val,
    vocab=vocab,
    transform=val_transform_no_pretrained
)

val_loader_no_pretrained = DataLoader(
    val_dataset_no_pretrained,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    collate_fn=collate_fn
)

# Define the models
# Encoder
class SimpleEncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(SimpleEncoderCNN, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)  # 32×32
        )
        self.block4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1,1))  # reduce a 1×1
        )
        self.embed = nn.Linear(512, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = x.view(x.size(0), -1)
        x = self.embed(x)
        x = self.bn(x)
        return x
    
# Positional Encoding for Transformer
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return x

# Decoder
class TransformerDecoder(nn.Module):
    def __init__(self, embed_size, num_heads, ff_dim, vocab_size, max_len=70, dropout=0.1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.pos_enc = PositionalEncoding(embed_size, max_len)
        self.self_attn = nn.MultiheadAttention(embed_size, num_heads, dropout=dropout)
        self.cross_attn = nn.MultiheadAttention(embed_size, num_heads, dropout=dropout)
        self.ffn = nn.Sequential(
            nn.Linear(embed_size, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_size),
        )
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.norm3 = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(embed_size, vocab_size)

    def _generate_square_mask(self, sz, device):
        mask = torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1).to(device)
        return mask

    def forward(self, features, captions):
        tgt = captions[:, :-1]  
        x = self.embed(tgt) * math.sqrt(self.embed.embedding_dim)
        x = self.pos_enc(x)
        x = x.transpose(0, 1) 

        mask = self._generate_square_mask(x.size(0), x.device)
        attn_out, _ = self.self_attn(x, x, x, attn_mask=mask)
        x = self.norm1(x + self.dropout(attn_out))

        mem = features.unsqueeze(0)           
        cross_out, _ = self.cross_attn(x, mem, mem)
        x = self.norm2(x + self.dropout(cross_out))

        ff_out = self.ffn(x)
        x = self.norm3(x + self.dropout(ff_out))

        x = x.transpose(0, 1)      
        logits = self.out(x)                                
        return logits

In [15]:
#Hyperparameters
embed_size  = 512
num_heads   = 8
hidden_dim  = 2048
num_epochs = 12
lr = 7e-3
weight_decay = 1e-4
warmup_steps = 40000

vocab_size  = len(vocab['itos'])
#Load the model to the GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gc.collect()
torch.cuda.empty_cache()

#Model
encoder = SimpleEncoderCNN(embed_size).to(device)
decoder = TransformerDecoder(
    embed_size, num_heads, hidden_dim, vocab_size
).to(device)

#Loss function, optimizer and learning rate scheduler
criterion = nn.CrossEntropyLoss(ignore_index=vocab['stoi']['<pad>'])
optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(decoder.parameters()),
    lr=lr,
    weight_decay=weight_decay,
)
def get_transformer_scheduler(optimizer, d_model, warmup_steps=1000):
    def lr_lambda(step):
        step = max(step, 1)
        return (d_model ** -0.5) * min(step ** -0.5, step * (warmup_steps ** -1.5))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
scaler = torch.amp.GradScaler(
    'cuda',
    init_scale=512, 
    growth_interval=2000,
    growth_factor=2.0,
    backoff_factor=0.5
)
scheduler = get_transformer_scheduler(optimizer, embed_size, warmup_steps=warmup_steps)
best_val_loss = float('inf')

#Training loop
for epoch in range(1, num_epochs + 1):
    encoder.train()
    decoder.train()
    train_loss = 0.0

    pbar = tqdm(train_loader_no_pretrained, desc=f"[Epoch {epoch}/{num_epochs}] Train")
    for images, captions, lengths in pbar:
        images, captions = images.to(device), captions.to(device)

        optimizer.zero_grad()
        with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
            feats   = encoder(images)
            logits  = decoder(feats, captions)

        loss = criterion(
            logits.float().view(-1, vocab_size),
            captions[:, 1:].contiguous().view(-1)
        )

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(
            list(encoder.parameters()) + list(decoder.parameters()),
            max_norm=0.5
        )
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        train_loss += loss.item()
        lr = scheduler.get_last_lr()[0]
        pbar.set_postfix(lastLoss=f"{loss.item():.4f}", meanLoss=f"{train_loss/(pbar.n+1):.4f}", lr=f"{lr:.2e}")

    
    avg_train = train_loss / len(train_loader_no_pretrained)

    encoder.eval()
    decoder.eval()
    val_loss = 0.0
    # Validation
    with torch.no_grad():
        for images, captions, lengths in tqdm(val_loader_no_pretrained, desc=f"[Epoch {epoch}/{num_epochs}] Val  "):
            images, captions = images.to(device), captions.to(device)
            with torch.amp.autocast('cuda'):
                feats   = encoder(images)
                outputs = decoder(feats, captions)
                loss    = criterion(
                    outputs.view(-1, vocab_size),
                    captions[:, 1:].contiguous().view(-1)
                )
            val_loss += loss.item()

    avg_val = val_loss / len(val_loader_no_pretrained)
    print(f"Epoch {epoch}: Train Loss = {avg_train:.4f}, Val Loss = {avg_val:.4f}")

    if avg_val < best_val_loss:
        best_val_loss = avg_val
        ckpt = {
            'epoch': epoch,
            'encoder_state': encoder.state_dict(),
            'decoder_state': decoder.state_dict(),
            'optim_state': optimizer.state_dict(),
            'sched_state': scheduler.state_dict(),
            'val_loss': best_val_loss
        }
        torch.save(ckpt, 'best_model_no_pretrained.pth')
        print(f"best_model_no_pretrained updated (Val Loss: {best_val_loss:.4f})")

[Epoch 1/12] Train:   0%|          | 9/9012 [00:42<11:49:30,  4.73s/it, lastLoss=9.3566, lr=3.48e-10, meanLoss=9.3717]


KeyboardInterrupt: 

In [None]:
#Data for the reconstruction of the model
itos = vocab['itos']
stoi = vocab['stoi']
pad_idx   = stoi['<pad>']
start_idx = stoi['<start>']
end_idx   = stoi['<end>']
embed_size = 512
num_heads  = 8
hidden_dim = 2048

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#Reconstruction of the model
encoder = SimpleEncoderCNN(embed_size).to(device)
decoder = TransformerDecoder(
    embed_size, num_heads, hidden_dim, len(itos)
).to(device)
ckpt = torch.load('best_model_no_pretrained.pth', map_location=device)
encoder.load_state_dict(ckpt['encoder_state'])
decoder.load_state_dict(ckpt['decoder_state'])

encoder.eval()
decoder.eval()

infer_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# Function to caption an image and print the results in a table
@torch.no_grad()
def caption_with_table(img_path, max_len=50, top_k=10, temperature=1.0):
    image = Image.open(img_path).convert('RGB')
    image = infer_transform(image).unsqueeze(0).to(device)

    with torch.amp.autocast(device.type, torch.float16):
        feat = encoder(image)  

    steps = [] 
    caption_idx = [start_idx]

    for _ in range(max_len - 1):
        seq_in = caption_idx + [pad_idx]
        tgt    = torch.tensor(seq_in, device=device).unsqueeze(0)

        with torch.amp.autocast(device.type, torch.float16):
            logits = decoder(feat, tgt)

        logits_step = logits[0, -1] / temperature
        probs       = torch.softmax(logits_step, dim=-1)

        top_p, top_i = torch.topk(probs, top_k)
        step_info = [(itos[idx], p.item()) for idx, p in zip(top_i, top_p)]
        steps.append(step_info)

        next_idx = top_i[0].item()
        caption_idx.append(next_idx)
        if next_idx == end_idx:
            break

    rows = []
    for pos, token_probs in enumerate(steps, start=1):
        row = OrderedDict(pos=pos)
        for r, (tok, p) in enumerate(token_probs, start=1):
            row[f'token{r}'] = tok
            row[f'prob{r}']  = f'{p:.4f}'
        rows.append(row)

    df = pd.DataFrame(rows).set_index('pos')
    print(df.to_string())

    words = [itos[i] for i in caption_idx[1:-1]]
    return ' '.join(words)

# Selection of the image to be captioned
test_img = 'data/images/val2017/000000579635.jpg'
print(caption_with_table(test_img))

        token1   prob1     token2   prob2 token3   prob3     token4   prob4      token5   prob5  token6   prob6      token7   prob7  token8   prob8  token9   prob9  token10  prob10
pos                                                                                                                                                                                 
1            a  0.7109        the  0.0495    two  0.0476         an  0.0287       there  0.0186  people  0.0163     several  0.0108   three  0.0102     man  0.0067     some  0.0063
2          man  0.2203     person  0.1730  group  0.0389     surfer  0.0347       large  0.0317   woman  0.0268       skier  0.0208   small  0.0199  couple  0.0148    plane  0.0137
3       riding  0.1493         in  0.1136     is  0.1118         on  0.0695      flying  0.0485    with  0.0350    standing  0.0319     and  0.0245    that  0.0240  wearing  0.0233
4            a  0.5615       skis  0.0991     on  0.0655        the  0.0308        down  0.0191

## 1.2. Pretrained CNN

In [16]:
IMNET_MEAN, IMNET_STD = [0.485,0.456,0.406], [0.229,0.224,0.225] # Imagenet mean and std, resnet is trained on imagenet

#Transforms for data
tr_tf = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.9,1.0), ratio=(0.9,1.1)),
    transforms.RandomHorizontalFlip(), transforms.RandomRotation(10),
    transforms.ColorJitter(0.1,0.1,0.1,0.05),
    transforms.ToTensor(), transforms.Normalize(IMNET_MEAN, IMNET_STD)])
va_tf = transforms.Compose([
    transforms.Resize(256), transforms.CenterCrop(224),
    transforms.ToTensor(),  transforms.Normalize(IMNET_MEAN, IMNET_STD)])

#Training dataset and dataloader for pretrained model
img_dir='data/images/train2017'
train_dataset_pretrained = CaptionImageDataset(
    img_dir=img_dir,
    annotations=numerical,
    vocab=vocab,
    transform=tr_tf
)
train_loader_pretrained = DataLoader(
    train_dataset_pretrained,
    batch_size=64,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    collate_fn=collate_fn
)

#Validation dataset and dataloader for pretrained model
val_img_dir='data/images/val2017'
val_dataset_pretrained = CaptionImageDataset(
    img_dir=val_img_dir,
    annotations=numerical_val,
    vocab=vocab,
    transform=va_tf
)
val_loader_pretrained = DataLoader(
    val_dataset_pretrained,
    batch_size=64,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    collate_fn=collate_fn
)

# Positional encoding for 2D images
class PE2D(nn.Module):
    def __init__(self, d, h=7, w=7):
        super().__init__()
        if d % 4: raise ValueError("d_model must be divisible by 4")
        pe = torch.zeros(d, h, w)
        d_half = d // 2
        div = torch.exp(torch.arange(0, d_half, 2) * (-math.log(10000.0) / d_half))

        pos_w = torch.arange(w).unsqueeze(1)
        pos_h = torch.arange(h).unsqueeze(1)

        sin_w = torch.sin(pos_w * div).T.unsqueeze(1).repeat(1, h, 1)
        cos_w = torch.cos(pos_w * div).T.unsqueeze(1).repeat(1, h, 1)
        pe[0:d_half:2]  = sin_w
        pe[1:d_half:2]  = cos_w

        sin_h = torch.sin(pos_h * div).T.unsqueeze(2).repeat(1, 1, w)
        cos_h = torch.cos(pos_h * div).T.unsqueeze(2).repeat(1, 1, w)
        pe[d_half::2]   = sin_h
        pe[d_half+1::2] = cos_h

        self.register_buffer('pe', pe.flatten(1).T.unsqueeze(1))
    def forward(self, x):
        return x + self.pe[:x.size(0)]

# Encoder for pretrained model
class ResNetEncoder(nn.Module):
    def __init__(self, d=512):
        super().__init__()
        rn = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        self.backbone = nn.Sequential(*list(rn.children())[:-2])
        for p in self.backbone.parameters(): p.requires_grad=False
        self.proj = nn.Conv2d(2048, d, 1)
        nn.init.xavier_uniform_(self.proj.weight)
        self.pe2d = PE2D(d, 7, 7)
    def forward(self, x):
        x = self.proj(self.backbone(x))
        x = x.flatten(2).permute(2,0,1)
        return self.pe2d(x)

# Decoder for pretrained model
class DecoderBlock(nn.Module):
    def __init__(self, d, heads, d_ff, p=0.1):
        super().__init__()
        self.self_attn  = nn.MultiheadAttention(d, heads, dropout=p, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d, heads, dropout=p, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d, d_ff), nn.ReLU(), nn.Dropout(p),
                                nn.Linear(d_ff, d))
        self.norm1 = nn.LayerNorm(d); self.drop1 = nn.Dropout(p)
        self.norm2 = nn.LayerNorm(d); self.drop2 = nn.Dropout(p)
        self.norm3 = nn.LayerNorm(d); self.drop3 = nn.Dropout(p)
    def forward(self, x, mem, attn_mask, pad_mask):
        sa,_ = self.self_attn(x,x,x, attn_mask=attn_mask, key_padding_mask=pad_mask)
        x = self.norm1(x + self.drop1(sa))
        ca,_ = self.cross_attn(x, mem, mem)
        x = self.norm2(x + self.drop2(ca))
        x = self.norm3(x + self.drop3(self.ff(x)))
        return x
class CaptionDecoder(nn.Module):
    def __init__(self, vocab, d=512, heads=8, d_ff=2048,
                 layers=6, dropout=0.1, max_len=70):
        super().__init__()
        self.tok = nn.Embedding(vocab, d)
        self.pos = nn.Embedding(max_len, d)
        self.blocks = nn.ModuleList(
            DecoderBlock(d, heads, d_ff, dropout) for _ in range(layers))
        self.ln = nn.LayerNorm(d)
        self.out = nn.Linear(d, vocab)
        self.d = d
    def forward(self, mem, caps, pad=0):
        B,L = caps.size()
        pos = torch.arange(L-1, device=caps.device).unsqueeze(0)
        x = self.tok(caps[:,:-1]) * math.sqrt(self.d) + self.pos(pos)
        tgt_mask = torch.triu(torch.ones((L-1, L-1),
                                 dtype=torch.bool,
                                 device=caps.device), 1)
        kpm = (caps[:,:-1] == pad)
        mem = mem.permute(1,0,2)
        for blk in self.blocks:
            x = blk(x, mem, tgt_mask, kpm)
        x = self.ln(x)
        return self.out(x)

In [17]:
# Hyperparameters
lr_encoder = 1e-3
lr_decoder = 4e-3
embed_size_pretrained = 512
warmup_steps = 100000
EPOCHS = 30

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define the models
encoder = ResNetEncoder(embed_size_pretrained).to(device)
decoder = CaptionDecoder(len(vocab['itos']), embed_size_pretrained).to(device)

# Define the loss function (cross entropy)
criterion = nn.CrossEntropyLoss(ignore_index=stoi['<pad>'], label_smoothing=0.1)
optimizer = torch.optim.AdamW([
    {'params': decoder.parameters(),          'lr': lr_decoder},
    {'params': encoder.proj.parameters(),     'lr': lr_encoder},
], weight_decay=1e-4)
def transformer_sched(opt, d=embed_size_pretrained, warm=warmup_steps):
    f = lambda s: (d**-0.5) * min(s**-0.5, s*warm**-1.5) if s else 0.
    return torch.optim.lr_scheduler.LambdaLR(opt, f)
scheduler = transformer_sched(optimizer)
scaler = torch.amp.GradScaler('cuda')

# Training loop
best = float('inf')
for ep in range(1, EPOCHS+1):
    encoder.train(); decoder.train()
    tloss = 0
    pbar = tqdm(train_loader_pretrained, desc=f"[Epoch {ep}/{EPOCHS}] Train")
    for imgs,caps,_ in pbar:
        imgs,caps = imgs.to(device), caps.to(device)
        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device.type, enabled=device.type=='cuda'):
            mem = encoder(imgs)
            logits = decoder(mem, caps)
            loss = criterion(logits.reshape(-1,len(vocab['itos'])), caps[:,1:].reshape(-1))
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(decoder.parameters(), 0.5)
        scaler.step(optimizer); scaler.update(); scheduler.step()
        
        batch_loss = loss.item()
        current_lr = scheduler.get_last_lr()[0]
        
        tloss += batch_loss*imgs.size(0)
        pbar.set_postfix(lastLoss=f"{batch_loss:.4f}", meanLoss=f"{tloss/((pbar.n+1)*imgs.size(0)):.4f}", lr=f"{current_lr:.2e}")

    tloss /= len(train_loader_pretrained.dataset)

    encoder.eval(); decoder.eval()
    vloss = 0
    with torch.no_grad():
        for imgs,caps,_ in tqdm(val_loader_pretrained, desc=f'ep{ep} val'):
            imgs,caps = imgs.to(device), caps.to(device)
            with torch.autocast(device.type, enabled=device.type=='cuda'):
                mem = encoder(imgs)
                logits = decoder(mem, caps)
                loss = criterion(logits.reshape(-1,len(vocab['itos'])), caps[:,1:].reshape(-1))
            vloss += loss.item()*imgs.size(0)
    vloss /= len(val_loader_pretrained.dataset)

    print(f'→ epoch {ep}: train {tloss:.4f} | val {vloss:.4f}')
    if vloss < best:
        best = vloss
        torch.save({'encoder': encoder.state_dict(),
            'decoder': decoder.state_dict(),
            'itos': itos,
            'stoi': stoi},
           'best_model_pretrained.pth')
        print(f"best_model_pretrained updated (Val Loss: {best:.4f})")


[Epoch 1/30] Train:   0%|          | 3/9012 [00:19<16:32:22,  6.61s/it, lastLoss=9.3437, lr=1.68e-11, meanLoss=9.3498]


KeyboardInterrupt: 

In [None]:
embed_size_pretrained = 512
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

# Create the models
encoder = ResNetEncoder(embed_size_pretrained).to(device)
decoder = CaptionDecoder(len(vocab['itos']), embed_size_pretrained).to(device)

# Load best models
ckpt = torch.load('best_model_pretrained.pth', map_location=device)
encoder.load_state_dict(ckpt['encoder'])
decoder.load_state_dict(ckpt['decoder'])
encoder.eval(); decoder.eval()


infer_tf = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Function to create a caption from an image
@torch.no_grad()
def caption_with_table(img_path, max_len=70, top_k=10, temperature=1.0):
    img = Image.open(img_path).convert('RGB')
    img = infer_tf(img).unsqueeze(0).to(device)

    with torch.amp.autocast(device.type):
        feat = encoder(img)

    caption_idx = [stoi['<start>']]
    steps = []
    for _ in range(max_len - 1):
        seq = torch.tensor(caption_idx + [stoi['<pad>']],
                           device=device).unsqueeze(0)

        with torch.amp.autocast(device.type):
            logits = decoder(feat, seq)

        logits = logits[0, -1] / temperature
        probs  = torch.softmax(logits, dim=-1)
        top_p, top_i = torch.topk(probs, top_k)

        steps.append([(itos[i], p.item()) for i, p in zip(top_i, top_p)])
        next_idx = top_i[0].item()
        caption_idx.append(next_idx)

        if next_idx == stoi['<end>']:
            break

    rows = []
    for pos, token_probs in enumerate(steps, start=1):
        row = OrderedDict(pos=pos)
        for r, (tok, p) in enumerate(token_probs, start=1):
            row[f'token{r}'] = tok
            row[f'prob{r}']  = f'{p:.4f}'
        rows.append(row)

    df = pd.DataFrame(rows).set_index('pos')
    print(df.to_string())

    words = [itos[i] for i in caption_idx[1:-1]]
    return ' '.join(words)

# Execution of the function to generate captions
test_img = 'data/images/val2017/000000149375.jpg'
sent = caption_with_table(test_img, top_k=5)
print('\nCaption:', sent)

         token1   prob1  token2   prob2  token3   prob3   token4   prob4    token5   prob5
pos                                                                                       
1             a  0.7212     two  0.0433     the  0.0344    there  0.0125        an  0.0114
2           man  0.2261  person  0.1628   young  0.0875      boy  0.0468     woman  0.0436
3            in  0.1052      is  0.0950  riding  0.0767  holding  0.0495  standing  0.0487
4             a  0.5020     the  0.1028   front  0.0388       an  0.0278       his  0.0272
5    skateboard  0.2074   skate  0.0372    park  0.0176    field  0.0167      suit  0.0138
6            on  0.1010      in  0.0948    with  0.0694       is  0.0522        at  0.0298
7             a  0.6152     the  0.0826     top  0.0612      his  0.0509        an  0.0089
8    skateboard  0.5967   skate  0.0232   field  0.0141     ramp  0.0136     trick  0.0104
9         <end>  0.2590      in  0.1114      on  0.0426     with  0.0383        at  0.0264

# 2. USING LSTM 

In [None]:
"""
COCO-Captioner 2025
CNN-Transformer con multi-head attention explícito
(Python 3.10, PyTorch ≥ 2.1)

Pasos previos:
  ├─ data/images/train2017/   (118k jpg)
  ├─ data/images/val2017/     (5k jpg)
  └─ data/annotations/
         captions_train2017.json
         captions_val2017.json
"""

# ---------- 0) Imports ----------
import os, re, json, math
from collections import Counter
from PIL import Image
from tqdm import tqdm

import nltk
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

nltk.download('punkt_tab', quiet=True)

# ---------- 1) Utilidades de texto ----------
def tokenize_english(text: str):
    text = re.sub(r"[\r\n\t]+", " ", text.lower())
    text = re.sub(r"[^a-z0-9\.\,\!\?\:\;\’\'\-]", " ", text)
    return [t for t in nltk.word_tokenize(text) if any(c.isalnum() for c in t)]

# ---------- 2) Cargar captions ----------
def load_coco(split):
    with open(f'data/annotations/captions_{split}2017.json') as f:
        coco = json.load(f)
    id2fname = {img['id']: img['file_name'] for img in coco['images']}
    anns = []
    for ann in coco['annotations']:
        seq = ['<start>'] + tokenize_english(ann['caption']) + ['<end>']
        anns.append((id2fname[ann['image_id']], seq))
    return anns

train_anns = load_coco('train')
val_anns   = load_coco('val')

# ---------- 3) Vocabulario ----------
min_freq = 5
counter   = Counter(tok for _, seq in train_anns for tok in seq)
special   = ['<pad>', '<start>', '<end>', '<unk>']
itos      = special + [w for w, c in counter.items() if c >= min_freq]
stoi      = {w: i for i, w in enumerate(itos)}
V         = len(itos)

def numericalise(anns):
    out = []
    for fname, seq in anns:
        out.append((fname, [stoi.get(t, stoi['<unk>']) for t in seq]))
    return out

train_nums, val_nums = map(numericalise, (train_anns, val_anns))

# ---------- 4) Dataset ----------
class CocoDS(Dataset):
    def __init__(self, root, pairs, tf):
        self.root, self.pairs, self.tf = root, pairs, tf
    def __len__(self): return len(self.pairs)
    def __getitem__(self, idx):
        fname, ids = self.pairs[idx]
        img = Image.open(os.path.join(self.root, fname)).convert('RGB')
        return self.tf(img), torch.tensor(ids, dtype=torch.long)

def collate(batch):
    imgs, seqs = zip(*batch)
    imgs = torch.stack(imgs)
    Ls   = [len(s) for s in seqs]
    Lmax = max(Ls)
    padded = torch.full((len(seqs), Lmax), stoi['<pad>'], dtype=torch.long)
    for i,s in enumerate(seqs):
        padded[i,:Ls[i]] = s
    return imgs, padded, Ls

# ---------- 5) Transforms & Loaders ----------
IMNET_MEAN, IMNET_STD = [0.485,0.456,0.406], [0.229,0.224,0.225]
tr_tf = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.9,1.0), ratio=(0.9,1.1)),
    transforms.RandomHorizontalFlip(), transforms.RandomRotation(10),
    transforms.ColorJitter(0.1,0.1,0.1,0.05),
    transforms.ToTensor(), transforms.Normalize(IMNET_MEAN, IMNET_STD)])
va_tf = transforms.Compose([
    transforms.Resize(256), transforms.CenterCrop(224),
    transforms.ToTensor(),  transforms.Normalize(IMNET_MEAN, IMNET_STD)])

train_loader = DataLoader(
    CocoDS('data/images/train2017', train_nums, tr_tf),
    batch_size=64, shuffle=True,  num_workers=8, pin_memory=True, collate_fn=collate)
val_loader = DataLoader(
    CocoDS('data/images/val2017', val_nums, va_tf),
    batch_size=64, shuffle=False, num_workers=8, pin_memory=True, collate_fn=collate)

# ---------- 6) Encoder (49 tokens + PosEnc 2-D) ----------
class PE2D(nn.Module):
    def __init__(self, d, h=7, w=7):
        super().__init__()
        if d % 4: raise ValueError("d_model debe ser múltiplo de 4")
        pe = torch.zeros(d, h, w)
        d_half = d // 2
        div = torch.exp(torch.arange(0, d_half, 2) * (-math.log(10000.0) / d_half))

        pos_w = torch.arange(w).unsqueeze(1)                 # [W,1]
        pos_h = torch.arange(h).unsqueeze(1)                 # [H,1]

        # Ancho
        sin_w = torch.sin(pos_w * div).T.unsqueeze(1).repeat(1, h, 1)  # [d/4,H,W]
        cos_w = torch.cos(pos_w * div).T.unsqueeze(1).repeat(1, h, 1)
        pe[0:d_half:2]  = sin_w
        pe[1:d_half:2]  = cos_w

        # Alto
        sin_h = torch.sin(pos_h * div).T.unsqueeze(2).repeat(1, 1, w)  # [d/4,H,W]
        cos_h = torch.cos(pos_h * div).T.unsqueeze(2).repeat(1, 1, w)
        pe[d_half::2]   = sin_h
        pe[d_half+1::2] = cos_h

        self.register_buffer('pe', pe.flatten(1).T.unsqueeze(1))   # [49,1,D]
    def forward(self, x):      # x [49,B,D]
        return x + self.pe[:x.size(0)]

class ResNetEncoder(nn.Module):
    def __init__(self, d=512):
        super().__init__()
        rn = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        self.backbone = nn.Sequential(*list(rn.children())[:-2])   # hasta conv5_x
        for p in self.backbone.parameters(): p.requires_grad=False
        self.proj = nn.Conv2d(2048, d, 1)
        nn.init.xavier_uniform_(self.proj.weight)
        self.pe2d = PE2D(d, 7, 7)
    def forward(self, x):                      # [B,3,224,224]
        x = self.proj(self.backbone(x))        # [B,D,7,7]
        x = x.flatten(2).permute(2,0,1)        # [49,B,D]
        return self.pe2d(x)

# ---------- 7) Decoder con nn.MultiheadAttention ----------
class CaptionDecoderLSTM(nn.Module):
    """
    *   Imagen → vector (media de los 49 parches)  
    *   Ese vector inicializa h0; c0 = 0  
    *   Entrada al LSTM: embeddings de palabras desplazados (sin la última)  
    *   Salida: logits sobre el vocabulario para cada paso
    """
    def __init__(self, vocab_size, d=512, hidden=512,
                 layers=2, dropout=0.1, max_len=70):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, d)
        self.lstm = nn.LSTM(d, hidden, layers,
                            dropout=dropout if layers > 1 else 0,
                            batch_first=True)
        self.h_init = nn.Linear(d, hidden)   # imagen → h0
        self.out = nn.Linear(hidden, vocab_size)
        self.layers = layers
        self.hidden = hidden

    def forward(self, mem, caps, pad_idx=0):    # mem [49,B,D]   caps [B,L]
        B, L = caps.size()
        context = mem.mean(0)                   # [B,D]

        h0 = torch.tanh(self.h_init(context))   # [B,H]
        h0 = h0.unsqueeze(0).repeat(self.layers, 1, 1)     # [L,B,H]
        c0 = torch.zeros_like(h0)                              # igual que h0

        emb = self.tok(caps[:, :-1])            # [B,L-1,E]
        out, _ = self.lstm(emb, (h0, c0))       # [B,L-1,H]
        out = self.out(out)                     # [B,L-1,V]
        return out

# ---------- 8) Entrenamiento ----------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
D_MODEL = 512
encoder = ResNetEncoder(D_MODEL).to(device)
decoder = CaptionDecoderLSTM(V, D_MODEL).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=stoi['<pad>'], label_smoothing=0.1)
optimizer = torch.optim.AdamW([
    {'params': decoder.parameters(),          'lr': 1e-4},
    {'params': encoder.proj.parameters(),     'lr': 5e-5},
], weight_decay=1e-4)

def transformer_sched(opt, d=D_MODEL, warm=100000):
    f = lambda s: (d**-0.5) * min(s**-0.5, s*warm**-1.5) if s else 0.
    return torch.optim.lr_scheduler.LambdaLR(opt, f)
scheduler = transformer_sched(optimizer)

scaler = torch.amp.GradScaler('cuda')

EPOCHS = 20
best = float('inf')
for ep in range(1, EPOCHS+1):
    # ---- train ----
    encoder.train(); decoder.train()
    tloss = 0
    pbar = tqdm(train_loader, desc=f'ep{ep} train')
    running_loss = 0.0
    for i, (imgs,caps,_) in enumerate(pbar):
        imgs,caps = imgs.to(device), caps.to(device)
        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device.type, enabled=device.type=='cuda'):
            mem = encoder(imgs)
            logits = decoder(mem, caps)
            loss = criterion(logits.reshape(-1,V), caps[:,1:].reshape(-1))
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(decoder.parameters(), 0.5)
        scaler.step(optimizer); scaler.update(); scheduler.step()
        
        batch_loss = loss.item()
        tloss += batch_loss


        pbar.set_postfix({'loss': f'{batch_loss:.4f}'})
    tloss /= len(train_loader.dataset)

    # ---- val ----
    encoder.eval(); decoder.eval()
    vloss = 0
    with torch.no_grad():
        for imgs,caps,_ in tqdm(val_loader, desc=f'ep{ep} val'):
            imgs,caps = imgs.to(device), caps.to(device)
            with torch.autocast(device.type, enabled=device.type=='cuda'):
                mem = encoder(imgs)
                logits = decoder(mem, caps)
                loss = criterion(logits.reshape(-1,V), caps[:,1:].reshape(-1))
            vloss += loss.item()*imgs.size(0)
    vloss /= len(val_loader.dataset)

    print(f'→ epoch {ep}: train {tloss:.4f} | val {vloss:.4f}')
    if vloss < best:
        best = vloss
        torch.save({'encoder': encoder.state_dict(),
            'decoder': decoder.state_dict(),
            'itos': itos,
            'stoi': stoi},
           'best_model_lstm.pth')
        print('   ✓ checkpoint guardado')


ep1 train: 100%|██████████| 9247/9247 [27:53<00:00,  5.53it/s, loss=9.1759]  
ep1 val: 100%|██████████| 391/391 [00:50<00:00,  7.71it/s]


→ epoch 1: train 0.1438 | val 9.1664
   ✓ checkpoint guardado


ep2 train: 100%|██████████| 9247/9247 [28:50<00:00,  5.34it/s, loss=9.0985]  
ep2 val: 100%|██████████| 391/391 [00:56<00:00,  6.97it/s]


→ epoch 2: train 0.1427 | val 9.1018
   ✓ checkpoint guardado


ep3 train: 100%|██████████| 9247/9247 [29:34<00:00,  5.21it/s, loss=9.0274]  
ep3 val: 100%|██████████| 391/391 [00:55<00:00,  7.07it/s]


→ epoch 3: train 0.1417 | val 9.0318
   ✓ checkpoint guardado


ep4 train: 100%|██████████| 9247/9247 [29:36<00:00,  5.20it/s, loss=8.9311]  
ep4 val: 100%|██████████| 391/391 [00:59<00:00,  6.61it/s]


→ epoch 4: train 0.1405 | val 8.9533
   ✓ checkpoint guardado


ep5 train:   2%|▏         | 149/9247 [00:33<33:37,  4.51it/s, loss=8.9647] 


KeyboardInterrupt: 

In [None]:
embed_size = 512
num_heads  = 8
hidden_dim = 2048
num_layers = 6
mean = [0.485, 0.456, 0.406]
std  = [0.229, 0.224, 0.225]
checkpoint_path = 'best_model.pth'
image_path      = 'data/images/val2017/000000000139.jpg'  # cambia a tu imagen

# 2) Carga vocabulario
ckpt = torch.load(checkpoint_path, map_location='cpu')
itos = ckpt['itos']
stoi = ckpt['stoi']
pad_idx   = stoi['<pad>']
start_idx = stoi['<start>']
end_idx   = stoi['<end>']

# 3) Reconstruye y carga modelos
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder = SimpleEncoderCNN(embed_size).to(device)
decoder = TransformerDecoder(embed_size, num_heads, hidden_dim, len(itos), num_layers).to(device)

state = torch.load(checkpoint_path, map_location=device)
encoder.load_state_dict(state['encoder_state'])
decoder.load_state_dict(state['decoder_state'])
encoder.eval()
decoder.eval()

# 4) Transform para inferencia
infer_tf = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# 5) Generación greedy
@torch.no_grad()
def generate_caption(img_tensor, max_len=50):
    img_tensor = img_tensor.unsqueeze(0).to(device)            # [1,3,H,W]
    feat = encoder(img_tensor)                                 # [1, E]
    caption = [stoi['<start>']]

    for _ in range(max_len - 1):
        seq = torch.tensor([caption + [pad_idx]], device=device)
        logits = decoder(feat, seq)                           # [1, L, V]
        next_logits = logits[0, -1] / 1.0                     # temperatura=1
        prob = torch.softmax(next_logits, dim=-1)
        next_idx = prob.argmax().item()
        caption.append(next_idx)
        if next_idx == end_idx:
            break

    # convierte índices a tokens, quitando <start> y <end>
    tokens = [itos[i] for i in caption[1:-1]]
    return ' '.join(tokens)

# 6) Ejecución
if __name__ == '__main__':
    img = Image.open(image_path).convert('RGB')
    img = infer_tf(img)
    result = generate_caption(img)
    print("Generated caption:", result)