In [1]:
# ==========================================
# 1. SETUP & IMPORTS
# ==========================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.amp import autocast, GradScaler
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
import torchvision.models as models

import pandas as pd
import numpy as np
import random
import math
import re
import shutil
from collections import Counter
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report

from PIL import Image
import torchvision.transforms as transforms
from collections import defaultdict

In [2]:
# BEST PRACTICE: Set seeds for reproducibility
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
torch.backends.cudnn.deterministic = True

# Device configuration (GPU if available, else CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [3]:
file_path = '/kaggle/input/flickr8k/captions.txt'

# read file
img_caption_pairs = []

with open(file_path, 'r', encoding='utf-8') as f:
    lines = [line.strip() for line in f if line.strip()]

# Remove header
lines = lines[1:]

for line in lines:
    img, caption = line.split(',', 1)
    img_caption_pairs.append((img, caption.lower()))

print("First (image, caption) pair:")
print(img_caption_pairs[0])

# save only captions for tokenizer
captions_file = '/kaggle/working/captions_clean.txt'

with open(captions_file, 'w', encoding='utf-8') as f:
    for _, caption in img_caption_pairs:
        f.write(caption + '\n')

def tokenize(caption):
    return re.findall(r"\w+", caption.lower())

SPECIAL_TOKENS = {
    "<pad>": 0,
    "<unk>": 1,
    "<bos>": 2,
    "<eos>": 3
}

def build_vocab(captions, min_freq=2):
    counter = Counter()
    for c in captions:
        counter.update(tokenize(c))

    vocab = dict(SPECIAL_TOKENS)
    idx = len(vocab)

    for word, freq in counter.items():
        if freq >= min_freq:
            vocab[word] = idx
            idx += 1

    return vocab

def load_glove_embeddings(glove_path, vocab, embed_dim=300):
    embeddings = np.random.normal(
        scale=0.6,
        size=(len(vocab), embed_dim)
    )

    found = 0
    with open(glove_path, "r", encoding="utf-8") as f:
        for line in f:
            values = line.split()
            word = values[0]
            if word in vocab:
                vector = np.asarray(values[1:], dtype="float32")
                embeddings[vocab[word]] = vector
                found += 1

    print(f"Loaded GloVe vectors for {found}/{len(vocab)} words")
    return torch.tensor(embeddings, dtype=torch.float32)

all_captions = [c for _, c in img_caption_pairs]
vocab = build_vocab(all_captions, min_freq=2)

print("Vocab size:", len(vocab))

glove_path = "/kaggle/input/glove-embeddings/glove.6B.300d.txt"
glove_embeddings = load_glove_embeddings(glove_path, vocab)
torch.save(glove_embeddings, '/kaggle/working/glove_embeddings.pt')

First (image, caption) pair:
('1000268201_693b08cb0e.jpg', 'a child in a pink dress is climbing up a set of stairs in an entry way .')
Vocab size: 5156
Loaded GloVe vectors for 5089/5156 words


In [4]:
from PIL import Image
import torchvision.transforms as transforms

class ImageCaptionDataset(Dataset):
    def __init__(self, img_caption_pairs, vocab, image_root, transform=None):
        self.data = img_caption_pairs
        self.vocab = vocab
        self.image_root = image_root
        self.transform = transform

        self.pad_id = vocab["<pad>"]
        self.unk_id = vocab["<unk>"]
        self.bos_id = vocab["<bos>"]
        self.eos_id = vocab["<eos>"]

    def encode_caption(self, caption):
        tokens = tokenize(caption)
        ids = [self.vocab.get(t, self.unk_id) for t in tokens]
        return [self.bos_id] + ids + [self.eos_id]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name, caption = self.data[idx]

        image = Image.open(f"{self.image_root}/{img_name}").convert("RGB")
        if self.transform:
            image = self.transform(image)

        caption_ids = self.encode_caption(caption)
        caption_tensor = torch.tensor(caption_ids, dtype=torch.long)

        return image, caption_tensor, len(caption_tensor)

image_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
])

In [5]:
def collate_fn(batch):
    images, captions, lengths = zip(*batch)

    images = torch.stack(images)
    captions = pad_sequence(
        captions,
        batch_first=True,
        padding_value=vocab["<pad>"]
    )

    return images, captions, torch.tensor(lengths)

In [6]:
# train/val split

img_to_captions = defaultdict(list)

for img, caption in img_caption_pairs:
  img_to_captions[img].append(caption)

all_images = list(img_to_captions.keys())

train_images, val_images = train_test_split(
    all_images,
    test_size=0.2,
    random_state=SEED
)

train_pairs = []
val_pairs = []

for img in train_images:
  for caption in img_to_captions[img]:
    train_pairs.append((img, caption))

for img in val_images:
  for caption in img_to_captions[img]:
    val_pairs.append((img, caption))

In [7]:
image_root = '/kaggle/input/flickr8k/Images'

train_dataset = ImageCaptionDataset(train_pairs, vocab, image_root, transform=image_transform)
val_dataset   = ImageCaptionDataset(val_pairs, vocab, image_root, transform=image_transform)

train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    prefetch_factor=4,
    collate_fn=collate_fn,
    persistent_workers=True   # keeps workers alive between epochs
)

val_loader = DataLoader(
    val_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn,
    persistent_workers=True
)

images, captions, lengths = next(iter(train_loader))

print("Images:", images.shape)        # (B, 3, 256, 256)
print("Captions:", captions.shape)    # (B, max_len)
print("Lengths:", lengths)

Images: torch.Size([64, 3, 256, 256])
Captions: torch.Size([64, 21])
Lengths: tensor([16,  8, 12,  9, 16, 16, 13,  9, 11, 15,  8, 11, 11, 11, 19, 15, 10, 18,
        15, 21, 10,  9, 16, 19, 11, 10, 16, 14, 13,  9, 11, 14, 21, 10, 11, 14,
        15, 16,  7, 18, 18,  9, 11, 13, 14, 13, 10, 18, 11,  8, 14,  9, 14, 14,
        15,  8, 12, 11, 20, 17, 17, 12, 13, 11])


In [8]:
class ImgToCaptionModel(nn.Module):
    def __init__(self, vocab_size, embed_dim=512, max_seq_len=50, pad_token_id=0):
        super().__init__()

        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        self.pad_token_id = pad_token_id

        # 2D positional embeddings for image features
        self.row_embed = nn.Parameter(torch.randn(8, embed_dim) * 0.02)
        self.col_embed = nn.Parameter(torch.randn(8, embed_dim) * 0.02)
        
        # CNN ENCODER
        resnet = models.resnet50(pretrained=True)
        self.cnn = nn.Sequential(*list(resnet.children())[:-2])

        # Project visual features
        self.prep = nn.Sequential(
            nn.Linear(2048, 300),
            nn.LayerNorm(300),
            nn.ReLU(),
            nn.Dropout(0.1)
        )

        # TEXT EMBEDDING
        self.embedding = nn.Embedding.from_pretrained(
            glove_embeddings,
            freeze=True,
            padding_idx=pad_token_id
        )
        
        self.embed_dropout = nn.Dropout(0.1)

        # Positional encoding for captions 
        self.pos_encoding = nn.Parameter(
            torch.randn(max_seq_len, embed_dim) * 0.02
        )

        # TRANSFORMER DECODER
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=300,
            nhead=6,              # IMPORTANT: 300 % 6 == 0
            dim_feedforward=1024,
            batch_first=True
        )

        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer,
            num_layers=3
        )

        self.fc_out = nn.Linear(embed_dim, vocab_size)

    def forward(self, images, captions):
        # IMAGE ENCODING
        img_features = self.cnn(images)               # (B, 2048, 8, 8)
        img_features = img_features.flatten(2)        # (B, 2048, 64)
        img_features = img_features.transpose(1, 2)   # (B, 64, 2048)
        img_features = self.prep(img_features)        # (B, 64, 512)                                
    
        # Add 2D positional encoding to image features
        pos = self.row_embed[:, None, :] + self.col_embed[None, :, :]
        pos = pos.reshape(64, -1)
        img_features = img_features + pos.unsqueeze(0)
        
        # TEXT EMBEDDING
        seq_len = captions.size(1)
        caption_embeds = self.embedding(captions)
        caption_embeds = self.embed_dropout(caption_embeds)
        caption_embeds = caption_embeds + self.pos_encoding[:seq_len]

        # MASKS
        tgt_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=captions.device, dtype=torch.bool),
            diagonal=1
        )

        tgt_key_padding_mask = (captions == self.pad_token_id)
        
        # DECODER
        output = self.transformer_decoder(
            tgt=caption_embeds,
            memory=img_features,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask
        )

        output = self.fc_out(output)

        return output

# Initialize the model
model = ImgToCaptionModel(
    vocab_size=len(vocab),
    embed_dim=300,
    max_seq_len=50,
    pad_token_id=vocab["<pad>"]
).to(device)

model = model.to(device)

# number of total params
total_params = sum(p.numel() for p in model.parameters())
print(f"Total Parameters: {total_params:,}")



Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 97.8M/97.8M [00:00<00:00, 145MB/s]


Total Parameters: 31,261,660


In [9]:
# perplexity metric calculation (clearer indicator than just raw loss function)
def perplexity_from_loss(loss):
    return math.exp(loss) if loss < 20 else float("inf")

In [10]:
num_epochs = 15                   
learning_rate = 1e-4  

pad_token_id = 0
start_token_id = 2  
end_token_id = 3

criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id, label_smoothing=0.05)

# Separate CNN and other parameters for different learning rates
cnn_params = list(model.cnn.parameters())
other_params = [p for n, p in model.named_parameters() if not n.startswith("cnn")]

optimizer = optim.AdamW([
    {'params': cnn_params, 'lr': 1e-5},        # slow LR for pretrained CNN
    {'params': other_params, 'lr': 1e-4}       # higher LR for transformer & prep layers
], weight_decay=1e-4)

# Learning rate scheduler 
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
    min_lr=1e-6
)

# Mixed precision scaler
scaler = GradScaler()

# For saving best model
best_val_loss = float('inf')

# Training loop
for epoch in range(num_epochs):
    # ========== TRAIN ==========
    model.train()
    train_loss = 0.0

    for images, captions, lengths in train_loader:
        images = images.to(device)
        captions = captions.to(device)

        inputs = captions[:, :-1]
        targets = captions[:, 1:]

        optimizer.zero_grad()

        with autocast(device_type='cuda'):
            outputs = model(images, inputs)
            loss = criterion(
                outputs.reshape(-1, outputs.size(-1)),
                targets.reshape(-1)
            )

        # Backprop with AMP
        scaler.scale(loss).backward()

        # Unscale before clipping
        scaler.unscale_(optimizer)
        clip_grad_norm_(model.parameters(), max_norm=1.0)

        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()

    avg_train_loss = train_loss / len(train_loader)
    train_ppl = perplexity_from_loss(avg_train_loss)

    # ========== VALIDATION ==========
    model.eval()
    val_loss = 0.0
    bleu_scores = []

    with torch.no_grad():
        for images, captions, lengths in val_loader:
            images = images.to(device)
            captions = captions.to(device)

            inputs = captions[:, :-1]
            targets = captions[:, 1:]

            with autocast(device_type='cuda'):
                outputs = model(images, inputs)
                loss = criterion(
                    outputs.reshape(-1, outputs.size(-1)),
                    targets.reshape(-1)
                )

            val_loss += loss.item()

    avg_val_loss = val_loss / len(val_loader)
    val_ppl = perplexity_from_loss(avg_val_loss)

    scheduler.step(avg_val_loss)
    
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), '/kaggle/working/image_caption_model_cnnfrozen_best.pth')
        print("âœ… Best model saved!")

    print(f"\nðŸ“Š Epoch [{epoch+1}/{num_epochs}]")
    print(f"   Train Loss: {avg_train_loss:.4f}")
    print(f"   Train PPL:  {train_ppl:.2f}")
    print(f"   Val Loss:   {avg_val_loss:.4f}")
    print(f"   Val PPL:    {val_ppl:.2f}")
    print(f"   LR (CNN):      {optimizer.param_groups[0]['lr']:.6f}")
    print(f"   LR (Transformer): {optimizer.param_groups[1]['lr']:.6f}")

print("\nâœ… Training complete!")

âœ… Best model saved!

ðŸ“Š Epoch [1/15]
   Train Loss: 4.7184
   Train PPL:  111.99
   Val Loss:   4.0327
   Val PPL:    56.41
   LR (CNN):      0.000010
   LR (Transformer): 0.000100
âœ… Best model saved!

ðŸ“Š Epoch [2/15]
   Train Loss: 3.8730
   Train PPL:  48.08
   Val Loss:   3.7273
   Val PPL:    41.56
   LR (CNN):      0.000010
   LR (Transformer): 0.000100
âœ… Best model saved!

ðŸ“Š Epoch [3/15]
   Train Loss: 3.6124
   Train PPL:  37.06
   Val Loss:   3.5654
   Val PPL:    35.36
   LR (CNN):      0.000010
   LR (Transformer): 0.000100
âœ… Best model saved!

ðŸ“Š Epoch [4/15]
   Train Loss: 3.4461
   Train PPL:  31.38
   Val Loss:   3.4749
   Val PPL:    32.29
   LR (CNN):      0.000010
   LR (Transformer): 0.000100
âœ… Best model saved!

ðŸ“Š Epoch [5/15]
   Train Loss: 3.3214
   Train PPL:  27.70
   Val Loss:   3.3995
   Val PPL:    29.95
   LR (CNN):      0.000010
   LR (Transformer): 0.000100
âœ… Best model saved!

ðŸ“Š Epoch [6/15]
   Train Loss: 3.2215
   Train PPL:  2

In [11]:
import shutil

# Save model
torch.save(model.state_dict(), '/kaggle/working/image_caption_model_cnnfrozen_last.pth')

In [12]:
import pickle

with open('/kaggle/working/vocab.pkl', 'wb') as f:
    pickle.dump(vocab, f)
