In [1]:
# ==========================================
# 1. SETUP & IMPORTS
# ==========================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.amp import autocast, GradScaler
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
import torchvision.models as models

import pandas as pd
import numpy as np
import random
import math
import re
import shutil
import pickle
import json
from collections import Counter
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report

from PIL import Image
import torchvision.transforms as transforms
from collections import defaultdict

In [2]:
# BEST PRACTICE: Set seeds for reproducibility
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
torch.backends.cudnn.deterministic = True

# Device configuration (GPU if available, else CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [3]:
file_path = '/kaggle/input/flickr8k/captions.txt'

img_caption_pairs = []
with open(file_path, 'r', encoding='utf-8') as f:
    lines = [line.strip() for line in f if line.strip()]

lines = lines[1:]  # skipped header

for line in lines:
    img, caption = line.split(',', 1)
    img_caption_pairs.append((img, caption.lower()))

print("First (image, caption) pair:")
print(img_caption_pairs[0])

# 2. tokenizer (word-level) for Glove embeddings
def tokenize(caption):
    return re.findall(r"[a-z0-9]+", caption.lower())

# Building world-level vocabulary from captions we have
all_words = []
for _, caption in img_caption_pairs:
    all_words.extend(tokenize(caption))

word_counts = Counter(all_words)

print(f"\nTotal word tokens: {len(all_words):,}")
print(f"Unique words: {len(word_counts):,}")

# Add special tokens
special_tokens = ['<pad>', '<unk>', '<s>', '</s>']
vocab_words = special_tokens + list(word_counts.keys())

word2idx = {w: i for i, w in enumerate(vocab_words)}
idx2word = {i: w for w, i in word2idx.items()}

print(f"Final vocabulary size (with special tokens): {len(word2idx):,}")

# Loading GloVe 300d pretrained embeddings
glove_path = '/kaggle/input/glove-embeddings/glove.6B.300d.txt'

glove_dict = {}
with open(glove_path, 'r', encoding='utf-8') as f:
    for line in f:
        values = line.split()
        word = values[0]
        vector = np.asarray(values[1:], dtype=np.float32)
        glove_dict[word] = vector

print(f"\nLoaded GloVe vectors: {len(glove_dict):,}")

# Computing GloVe coverage of our words we have in vocab
real_vocab = [w for w in vocab_words if w not in special_tokens]

covered_words = []
missing_words = []

for w in real_vocab:
    if w in glove_dict:
        covered_words.append(w)
    else: missing_words.append(w)

coverage_pct = len(covered_words) / len(real_vocab) * 100

print(f"\nWords in vocab (incl. specials): {len(vocab_words):,}")
print(f"Words checked (no specials): {len(real_vocab):,}")
print(f"Words found in GloVe: {len(covered_words):,}")
print(f"Words missing from GloVe: {len(missing_words):,}")
print(f"GloVe coverage: {coverage_pct:.2f}%")

print("\nExamples of missing words:")
print(missing_words[:20])



First (image, caption) pair:
('1000268201_693b08cb0e.jpg', 'a child in a pink dress is climbing up a set of stairs in an entry way .')

Total word tokens: 437,601
Unique words: 8,488
Final vocabulary size (with special tokens): 8,492

Loaded GloVe vectors: 400,000

Words in vocab (incl. specials): 8,492
Words checked (no specials): 8,488
Words found in GloVe: 7,828
Words missing from GloVe: 660
GloVe coverage: 92.22%

Examples of missing words:
['fingerpaints', 'aross', 'belays', 'frolicks', 'moutains', 'magizine', 'overshirt', 'anouther', 'jumphouse', 'rappels', 'rappeling', 'barrior', 'torwards', 'bloe', 'inground', 'litlle', 'colred', 'carying', 'wakeboarder', 'waterskier']


In [4]:
def create_embedding_matrix(word2idx, glove_dict, embed_dim=300):
    vocab_size = len(word2idx)
    embedding_matrix = np.random.randn(vocab_size, embed_dim).astype('float32') * 0.02

    found = 0
    for word, idx in word2idx.items():
        if word in glove_dict:
            embedding_matrix[idx] = glove_dict[word]
            found += 1

    # pad token â†’ zero vector
    embedding_matrix[word2idx['<pad>']] = np.zeros(embed_dim)

    real_tokens = [w for w in word2idx if w not in ['<pad>', '<unk>', '<s>', '</s>']]
    found = sum(1 for w in real_tokens if w in glove_dict)

    print(f"âœ… Mapped {found}/{len(real_tokens)} words "
          f"({100*found/len(real_tokens):.1f}% coverage)")

    return torch.tensor(embedding_matrix)

embedding_matrix = create_embedding_matrix(word2idx, glove_dict, embed_dim=300)


âœ… Mapped 7828/8488 words (92.2% coverage)


In [5]:
# Save word2idx
with open('/kaggle/working/word2idx.json', 'w', encoding='utf-8') as f:
    json.dump(word2idx, f, ensure_ascii=False)

# Save idx2word
with open('/kaggle/working/idx2word.json', 'w', encoding='utf-8') as f:
    json.dump(idx2word, f, ensure_ascii=False)

# Save special tokens indices
special_token_ids = {
    'pad': word2idx['<pad>'],
    'unk': word2idx['<unk>'],
    'start': word2idx['<s>'],
    'end': word2idx['</s>']
}
with open('/kaggle/working/special_tokens.json', 'w', encoding='utf-8') as f:
    json.dump(special_token_ids, f, ensure_ascii=False)

# Save GloVe embedding matrix for inference
np.save('/kaggle/working/embedding_matrix.npy', embedding_matrix.cpu().numpy())

In [6]:
def encode_caption(caption, word2idx, max_len=50):
    """Convert caption to token IDs"""
    tokens = tokenize(caption)
    token_ids = [word2idx['<s>']]  # Start token
    
    for token in tokens:
        token_ids.append(word2idx.get(token, word2idx['<unk>']))
    
    token_ids.append(word2idx['</s>'])  # End token
    
    # Truncate if too long
    if len(token_ids) > max_len:
        token_ids = token_ids[:max_len-1] + [word2idx['</s>']]
    
    return token_ids

def decode_caption(token_ids, idx2word):
    """Convert token IDs back to text"""
    words = []
    for idx in token_ids:
        word = idx2word.get(idx, '<unk>')
        if word in ['<pad>', '<s>', '</s>']:
            continue
        words.append(word)
    return ' '.join(words)

# Test  tokenization (just in case)
test_caption = img_caption_pairs[0][1]
print("Test caption:")
print(f"  Original: {test_caption}")

encoded = encode_caption(test_caption, word2idx)
print(f"  Encoded:  {encoded}")

decoded = decode_caption(encoded, idx2word)
print(f"  Decoded:  {decoded}")


Test caption:
  Original: a child in a pink dress is climbing up a set of stairs in an entry way .
  Encoded:  [2, 4, 5, 6, 4, 7, 8, 9, 10, 11, 4, 12, 13, 14, 6, 15, 16, 17, 3]
  Decoded:  a child in a pink dress is climbing up a set of stairs in an entry way


In [7]:
class ImageCaptionDataset(Dataset):
    def __init__(self, img_caption_pairs, word2idx, image_root, transform=None):
        self.data = img_caption_pairs
        self.word2idx = word2idx
        self.image_root = image_root
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name, caption = self.data[idx]

        # Load image
        img_path = f"{self.image_root}/{img_name}"
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        # Tokenize caption
        caption_ids = encode_caption(caption, self.word2idx, max_len=50)
        caption_tensor = torch.tensor(caption_ids, dtype=torch.long)

        return image, caption_tensor, len(caption_tensor)
        

In [8]:
def collate_fn(batch):
    images, captions, lengths = zip(*batch)

    images = torch.stack(images, dim=0)

    captions_padded = pad_sequence(
        captions,
        batch_first=True,
        padding_value=word2idx['<pad>']
    )

    lengths = torch.tensor(lengths)

    return images, captions_padded, lengths


In [9]:
# train/val split

img_to_captions = defaultdict(list)

for img, caption in img_caption_pairs:
  img_to_captions[img].append(caption)

all_images = list(img_to_captions.keys())

train_images, val_images = train_test_split(
    all_images,
    test_size=0.2,
    random_state=SEED
)

train_pairs = []
val_pairs = []

for img in train_images:
  for caption in img_to_captions[img]:
    train_pairs.append((img, caption))

for img in val_images:
  for caption in img_to_captions[img]:
    val_pairs.append((img, caption))

image_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                        std=[0.229, 0.224, 0.225])  # ImageNet stats for pretrained CNN
])


In [10]:
image_root = '/kaggle/input/flickr8k/Images'

train_dataset = ImageCaptionDataset(train_pairs, word2idx, image_root, transform=image_transform)
val_dataset = ImageCaptionDataset(val_pairs, word2idx, image_root, transform=image_transform)

train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    prefetch_factor=4,
    collate_fn=collate_fn,
    persistent_workers=True   # keeps workers alive between epochs
)

val_loader = DataLoader(
    val_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn,
    persistent_workers=True
)

images, captions, lengths = next(iter(train_loader))

print("Images:", images.shape)        # (B, 3, 256, 256)
print("Captions:", captions.shape)    # (B, max_len)
print("Lengths:", lengths)

Images: torch.Size([64, 3, 256, 256])
Captions: torch.Size([64, 21])
Lengths: tensor([16,  8, 12,  9, 16, 16, 13,  9, 11, 15,  8, 11, 11, 11, 19, 15, 10, 18,
        15, 21, 10,  9, 16, 19, 11, 10, 16, 14, 13,  9, 11, 14, 21, 10, 11, 14,
        15, 16,  7, 18, 18,  9, 11, 13, 14, 13, 10, 18, 11,  8, 14,  9, 14, 14,
        15,  8, 12, 11, 20, 17, 17, 12, 13, 11])


In [11]:
class ImgToCaptionModel(nn.Module):
    def __init__(self, embedding_matrix, embed_dim=300, hidden_dim=512, max_seq_len=50, pad_token_id=0):
        super().__init__()
        vocab_size, embed_dim = embedding_matrix.shape
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.max_seq_len = max_seq_len
        self.pad_token_id = pad_token_id
        
        # IMAGE ENCODER (Pretrained CNN)
        resnet = models.resnet50(pretrained=True)
        self.cnn = nn.Sequential(*list(resnet.children())[:-2])
        
        # FREEZE pretrained CNN
        for param in self.cnn.parameters():
            param.requires_grad = False
        
        # IMAGE PROJECTION
        self.img_proj = nn.Sequential(
            nn.Linear(2048, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1)
        )
        
        # 2D positional embeddings for 8x8 feature map
        self.img_pos_embed = nn.Parameter(torch.randn(1, 64, hidden_dim) * 0.02)
        
        # TEXT DECODER (GloVe Embeddings)
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.embedding.weight = nn.Parameter(embedding_matrix.clone())
        self.embedding.weight.requires_grad = False
        
        # TEXT PROJECTION
        self.word_proj = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(0.1)
        )
        
        # Learnable positional embeddings
        self.pos_encoding = nn.Parameter(torch.randn(max_seq_len, hidden_dim) * 0.02)
        
        # TRANSFORMER DECODER
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=8,
            dim_feedforward=2048,
            dropout=0.1,
            batch_first=True
        )
        
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer,
            num_layers=3
        )
        
        # Output projection
        self.fc_out = nn.Linear(hidden_dim, vocab_size)
        
    def forward(self, images, captions):
        # IMAGE ENCODING (Frozen CNN)
        with torch.no_grad():
            img_features = self.cnn(images)  # (B, 2048, 8, 8)
        
        img_features = img_features.flatten(2).permute(0, 2, 1)  # (B, 64, 2048)
        img_features = self.img_proj(img_features)  # (B, 64, hidden_dim)
        img_features = img_features + self.img_pos_embed 
        
        # TEXT ENCODING (Frozen GloVe)
        seq_len = captions.size(1)
        caption_embeds = self.embedding(captions)  # (B, seq_len, 300)
        caption_embeds = self.word_proj(caption_embeds)  # (B, seq_len, hidden_dim)
        caption_embeds = caption_embeds + self.pos_encoding[:seq_len].unsqueeze(0)
        
        # MASKS
        tgt_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=captions.device, dtype=torch.bool),
            diagonal=1
        )
        
        tgt_key_padding_mask = (captions == self.pad_token_id)
        
        # DECODER
        output = self.transformer_decoder(
            tgt=caption_embeds,
            memory=img_features,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask
        )
        
        logits = self.fc_out(output)
        return logits


# initialize model
model = ImgToCaptionModel(
    embedding_matrix=embedding_matrix,  # GloVe matrix
    embed_dim=300,                      # GloVe dimension
    hidden_dim=512,                     # Transformer hidden size
    max_seq_len=50,
    pad_token_id=word2idx['<pad>']
)

model = model.to(device)

# counting total params
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen_params = total_params - trainable_params

print(f"Total Parameters: {total_params:,}")
print(f"Trainable Parameters: {trainable_params:,}")
print(f"Frozen Parameters: {frozen_params:,}")




Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 97.8M/97.8M [00:00<00:00, 144MB/s]


Total Parameters: 44,287,740
Trainable Parameters: 18,232,108
Frozen Parameters: 26,055,632


In [12]:
# perplexity metric calculation (clearer indicator than just raw loss function)
def perplexity_from_loss(loss):
    return math.exp(min(loss, 100))
    

In [13]:
num_epochs = 15                   
learning_rate = 1e-4  

pad_token_id = word2idx['<pad>']
start_token_id = word2idx['<s>']
end_token_id = word2idx['</s>']

criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id, label_smoothing=0.05)

# Filtering out frozen parameters
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable_params, lr=learning_rate, weight_decay=1e-4)

# Learning rate scheduler 
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
    min_lr=1e-6
)

# Mixed precision scaler
scaler = GradScaler()

# For saving best model
best_val_loss = float('inf')

# Training loop
for epoch in range(num_epochs):
    # ========== TRAIN ==========
    model.train()
    train_loss = 0.0

    for images, captions, lengths in train_loader:
        images = images.to(device)
        captions = captions.to(device)

        inputs = captions[:, :-1]
        targets = captions[:, 1:]

        optimizer.zero_grad()

        with autocast(device_type='cuda'):
            outputs = model(images, inputs)
            loss = criterion(
                outputs.reshape(-1, outputs.size(-1)),
                targets.reshape(-1)
            )

        # Backprop with AMP
        scaler.scale(loss).backward()

        # Unscale before clipping
        scaler.unscale_(optimizer)
        clip_grad_norm_(model.parameters(), max_norm=1.0)

        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()

    avg_train_loss = train_loss / len(train_loader)
    train_ppl = perplexity_from_loss(avg_train_loss)

    # ========== VALIDATION ==========
    model.eval()
    val_loss = 0.0
    bleu_scores = []

    with torch.no_grad():
        for images, captions, lengths in val_loader:
            images = images.to(device)
            captions = captions.to(device)

            inputs = captions[:, :-1]
            targets = captions[:, 1:]

            with autocast(device_type='cuda'):
                outputs = model(images, inputs)
                loss = criterion(
                    outputs.reshape(-1, outputs.size(-1)),
                    targets.reshape(-1)
                )

            val_loss += loss.item()

    avg_val_loss = val_loss / len(val_loader)
    val_ppl = perplexity_from_loss(avg_val_loss)

    scheduler.step(avg_val_loss)
    
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), '/kaggle/working/image_caption_model_best.pth')
        print("âœ… Best model saved!")

    print(f"\nðŸ“Š Epoch [{epoch+1}/{num_epochs}]")
    print(f"   Train Loss: {avg_train_loss:.4f}")
    print(f"   Train PPL:  {train_ppl:.2f}")
    print(f"   Val Loss:   {avg_val_loss:.4f}")
    print(f"   Val PPL:    {val_ppl:.2f}")
    print(f"   LR:         {optimizer.param_groups[0]['lr']:.6f}")

print("\nâœ… Training complete!")


âœ… Best model saved!

ðŸ“Š Epoch [1/15]
   Train Loss: 4.4627
   Train PPL:  86.72
   Val Loss:   3.8631
   Val PPL:    47.61
   LR:         0.000100
âœ… Best model saved!

ðŸ“Š Epoch [2/15]
   Train Loss: 3.6586
   Train PPL:  38.81
   Val Loss:   3.5805
   Val PPL:    35.89
   LR:         0.000100
âœ… Best model saved!

ðŸ“Š Epoch [3/15]
   Train Loss: 3.3920
   Train PPL:  29.72
   Val Loss:   3.4538
   Val PPL:    31.62
   LR:         0.000100
âœ… Best model saved!

ðŸ“Š Epoch [4/15]
   Train Loss: 3.2099
   Train PPL:  24.78
   Val Loss:   3.3797
   Val PPL:    29.36
   LR:         0.000100
âœ… Best model saved!

ðŸ“Š Epoch [5/15]
   Train Loss: 3.0645
   Train PPL:  21.42
   Val Loss:   3.3256
   Val PPL:    27.82
   LR:         0.000100
âœ… Best model saved!

ðŸ“Š Epoch [6/15]
   Train Loss: 2.9402
   Train PPL:  18.92
   Val Loss:   3.3028
   Val PPL:    27.19
   LR:         0.000100
âœ… Best model saved!

ðŸ“Š Epoch [7/15]
   Train Loss: 2.8325
   Train PPL:  16.99
   Val Los

In [14]:
# save model after 15 epochs
torch.save(model.state_dict(), '/kaggle/working/image_caption_model_after.pth')