In [1]:
# ==========================================
# 1. SETUP & IMPORTS
# ==========================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.amp import autocast, GradScaler
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
import torchvision.models as models

import pandas as pd
import numpy as np
import random
import math
import re
from collections import Counter
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report

from PIL import Image
import torchvision.transforms as transforms
from collections import defaultdict

In [2]:
# BEST PRACTICE: Set seeds for reproducibility
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
torch.backends.cudnn.deterministic = True

# Device configuration (GPU if available, else CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [3]:
import sentencepiece as spm

file_path = '/kaggle/input/flickr8k/captions.txt'

# read file
img_caption_pairs = []

with open(file_path, 'r', encoding='utf-8') as f:
    lines = [line.strip() for line in f if line.strip()]

# Remove header
lines = lines[1:]

for line in lines:
    img, caption = line.split(',', 1)
    img_caption_pairs.append((img, caption.lower()))

print("First (image, caption) pair:")
print(img_caption_pairs[0])

# save only captions for tokenizer
captions_file = '/kaggle/working/captions_clean.txt'

with open(captions_file, 'w', encoding='utf-8') as f:
    for _, caption in img_caption_pairs:
        f.write(caption + '\n')

# train tokenizer
spm.SentencePieceTrainer.train(
    input=captions_file,
    model_prefix='/kaggle/working/spm',
    vocab_size=8000,
    model_type='bpe',
    pad_id=0,
    unk_id=1,
    bos_id=2,
    eos_id=3
)

# load tokenizer
sp = spm.SentencePieceProcessor()
sp.load('/kaggle/working/spm.model')

# building vocabulary
vocab = {sp.id_to_piece(i): i for i in range(sp.get_piece_size())}

print("Vocabulary size:", len(vocab))
print("Special tokens:")
print({k: v for k, v in vocab.items() if k in ["<pad>", "<unk>", "<s>", "</s>"]})

# Example: subword tokenization of first caption
first_caption = img_caption_pairs[0][1]

subword_tokens = sp.encode(first_caption, out_type=str)
subword_ids = sp.encode(first_caption, out_type=int)

print("\nFirst caption:")
print(first_caption)

print("\nSubword tokens:")
print(subword_tokens)

print("\nSubword token IDs:")
print(subword_ids)


First (image, caption) pair:
('1000268201_693b08cb0e.jpg', 'a child in a pink dress is climbing up a set of stairs in an entry way .')
Vocabulary size: 8000
Special tokens:
{'<pad>': 0, '<unk>': 1, '<s>': 2, '</s>': 3}

First caption:
a child in a pink dress is climbing up a set of stairs in an entry way .

Subword tokens:
['‚ñÅa', '‚ñÅchild', '‚ñÅin', '‚ñÅa', '‚ñÅpink', '‚ñÅdress', '‚ñÅis', '‚ñÅclimbing', '‚ñÅup', '‚ñÅa', '‚ñÅset', '‚ñÅof', '‚ñÅstairs', '‚ñÅin', '‚ñÅan', '‚ñÅent', 'ry', '‚ñÅway', '‚ñÅ.']

Subword token IDs:
[4, 128, 15, 4, 325, 270, 40, 414, 207, 4, 719, 46, 1045, 15, 135, 1879, 715, 1603, 7]


sentencepiece_trainer.cc(78) LOG(INFO) Starts training with : 
trainer_spec {
  input: /kaggle/working/captions_clean.txt
  input_format: 
  model_prefix: /kaggle/working/spm
  model_type: BPE
  vocab_size: 8000
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  pretokenization_delimiter: 
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  seed_sentencepieces_file: 
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 1
  bos_id: 2
  eos_id: 3
  pad_id: 0
  unk_piece: <unk>
  bos_piece: <s>
  eos_piece: </s>
  pad_piece: <pad>
  unk_surface:  ‚Åá 
  enable_d

In [4]:
from PIL import Image
import torchvision.transforms as transforms

class ImageCaptionDataset(Dataset):
    def __init__(self, img_caption_pairs, sp, image_root, transform=None):
        self.data = img_caption_pairs
        self.sp = sp # tokenizer
        self.image_root = image_root # path where the images are
        self.transform = transform

        self.bos_id = sp.bos_id()
        self.eos_id = sp.eos_id()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name, caption = self.data[idx]

        # ---- Load image ----
        img_path = f"{self.image_root}/{img_name}"
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        # ---- Tokenize caption ----
        caption_ids = self.sp.encode(caption, out_type=int)

        # Add <bos> and <eos>
        caption_ids = [self.bos_id] + caption_ids + [self.eos_id]

        caption_tensor = torch.tensor(caption_ids, dtype=torch.long)

        return image, caption_tensor, len(caption_tensor)

image_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
])

In [5]:
def collate_fn(batch):
    images, captions, lengths = zip(*batch)

    images = torch.stack(images, dim=0)

    captions_padded = pad_sequence(
        captions,
        batch_first=True,
        padding_value=sp.pad_id()
    )

    lengths = torch.tensor(lengths)

    return images, captions_padded, lengths


In [6]:
# train/val split

img_to_captions = defaultdict(list)

for img, caption in img_caption_pairs:
  img_to_captions[img].append(caption)

all_images = list(img_to_captions.keys())

train_images, val_images = train_test_split(
    all_images,
    test_size=0.2,
    random_state=SEED
)

train_pairs = []
val_pairs = []

for img in train_images:
  for caption in img_to_captions[img]:
    train_pairs.append((img, caption))

for img in val_images:
  for caption in img_to_captions[img]:
    val_pairs.append((img, caption))

In [7]:
image_root = '/kaggle/input/flickr8k/Images'

train_dataset = ImageCaptionDataset(train_pairs, sp, image_root, transform=image_transform)
val_dataset   = ImageCaptionDataset(val_pairs, sp, image_root, transform=image_transform)

train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    prefetch_factor=4,
    collate_fn=collate_fn,
    persistent_workers=True   # keeps workers alive between epochs
)

val_loader = DataLoader(
    val_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn,
    persistent_workers=True
)

images, captions, lengths = next(iter(train_loader))

print("Images:", images.shape)        # (B, 3, 256, 256)
print("Captions:", captions.shape)    # (B, max_len)
print("Lengths:", lengths)

Images: torch.Size([64, 3, 256, 256])
Captions: torch.Size([64, 24])
Lengths: tensor([17,  9, 13, 10, 17, 17, 14,  9, 12, 16,  9, 12, 12, 12, 20, 16, 11, 20,
        16, 22, 11, 10, 19, 21, 12, 11, 20, 17, 14, 10, 12, 15, 24, 10, 12, 14,
        16, 17,  7, 21, 19, 10, 12, 14, 15, 14, 11, 20, 12,  9, 15,  9, 15, 15,
        19,  9, 13, 12, 21, 18, 18, 13, 14, 12])


In [8]:
class ImgToCaptionModel(nn.Module):
    def __init__(self, vocab_size, embed_dim=512, max_seq_len=50, pad_token_id=0):
        super().__init__()

        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        self.pad_token_id = pad_token_id
        
        # CNN ENCODER (pretrained ResNet-50)
        # architecture: conv1, layer1, layer2, layer3, layer3 => 
        # 64 -> 256 -> 512 -> 1024 -> 2048 output
        resnet = models.resnet50(pretrained=True)
        self.cnn = nn.Sequential(*list(resnet.children())[:-2])  # (B, 2048, 8, 8)

        # Freeze all
        for p in self.cnn.parameters():
            p.requires_grad = False

        # Unfreeze entire layer4
        for p in self.cnn[-2].parameters():
            p.requires_grad = True

        # Project visual features
        self.prep = nn.Sequential(
            nn.Linear(2048, embed_dim),
            nn.LayerNorm(embed_dim),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        # TEXT EMBEDDING
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.embed_dropout = nn.Dropout(0.3)

        self.pos_encoding = nn.Parameter(
            torch.randn(max_seq_len, embed_dim) * 0.02
        )

        # TRANSFORMER DECODER
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dim,
            nhead=8,
            dim_feedforward=1024,
            dropout=0.3,
            batch_first=True
        )

        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer,
            num_layers=3
        )

        self.fc_out = nn.Linear(embed_dim, vocab_size)

    def forward(self, images, captions):
        # IMAGE ENCODING
        img_features = self.cnn(images)               # (B, 2048, 8, 8)
        img_features = img_features.flatten(2)        # (B, 2048, 64)
        img_features = img_features.transpose(1, 2)   # (B, 64, 2048)
        img_features = self.prep(img_features)        # (B, 64, 512)
        
        # TEXT EMBEDDING
        seq_len = captions.size(1)
        caption_embeds = self.embedding(captions)
        caption_embeds = self.embed_dropout(caption_embeds)
        caption_embeds = caption_embeds + self.pos_encoding[:seq_len]

        # MASKS
        tgt_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=captions.device, dtype=torch.bool),
            diagonal=1
        )

        tgt_key_padding_mask = (captions == self.pad_token_id)

        # DECODER
        output = self.transformer_decoder(
            tgt=caption_embeds,
            memory=img_features,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask
        )

        output = self.fc_out(output)

        return output

# Initialize the model
vocab_size = 8000
model = ImgToCaptionModel(vocab_size=vocab_size, embed_dim=512, max_seq_len=50, pad_token_id=0)
model = model.to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Total Parameters: {total_params:,}")




Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 97.8M/97.8M [00:00<00:00, 138MB/s]


Total Parameters: 42,247,040


In [9]:
# perplexity metric calculation (better indicator than just raw loss function)

def perplexity_from_loss(loss):
    return math.exp(loss) if loss < 20 else float("inf")


In [10]:
num_epochs = 15                   
learning_rate = 1e-4              
pad_token_id = 0

criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id, label_smoothing=0.05)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-3)

# Learning rate scheduler 
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
    min_lr=1e-6
)

# Mixed precision scaler
scaler = GradScaler()

# Training loop
for epoch in range(num_epochs):

    # ========== TRAIN ==========
    model.train()
    train_loss = 0.0

    for images, captions, lengths in train_loader:
        images = images.to(device)
        captions = captions.to(device)

        inputs = captions[:, :-1]
        targets = captions[:, 1:]

        optimizer.zero_grad()

        with autocast(device_type='cuda'):
            outputs = model(images, inputs)
            loss = criterion(
                outputs.reshape(-1, outputs.size(-1)),
                targets.reshape(-1)
            )

        # Backprop with AMP
        scaler.scale(loss).backward()

        # Unscale before clipping
        scaler.unscale_(optimizer)
        clip_grad_norm_(model.parameters(), max_norm=1.0)

        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()

    avg_train_loss = train_loss / len(train_loader)
    train_ppl = perplexity_from_loss(avg_train_loss)

    # ========== VALIDATION ==========
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for images, captions, lengths in val_loader:
            images = images.to(device)
            captions = captions.to(device)

            inputs = captions[:, :-1]
            targets = captions[:, 1:]

            with autocast(device_type='cuda'):
                outputs = model(images, inputs)
                loss = criterion(
                    outputs.reshape(-1, outputs.size(-1)),
                    targets.reshape(-1)
                )

            val_loss += loss.item()

    avg_val_loss = val_loss / len(val_loader)
    val_ppl = perplexity_from_loss(avg_val_loss)

    scheduler.step(avg_val_loss)

    print(f"\nüìä Epoch [{epoch+1}/{num_epochs}]")
    print(f"   Train Loss: {avg_train_loss:.4f}")
    print(f"   Train PPL:  {train_ppl:.2f}")
    print(f"   Val Loss:   {avg_val_loss:.4f}")
    print(f"   Val PPL:    {val_ppl:.2f}")
    print(f"   LR:         {optimizer.param_groups[0]['lr']:.6f}")

print("\n‚úÖ Training complete!")



üìä Epoch [1/15]
   Train Loss: 4.6470
   Train PPL:  104.27
   Val Loss:   3.9736
   Val PPL:    53.18
   LR:         0.000100

üìä Epoch [2/15]
   Train Loss: 3.8152
   Train PPL:  45.39
   Val Loss:   3.7186
   Val PPL:    41.21
   LR:         0.000100

üìä Epoch [3/15]
   Train Loss: 3.5376
   Train PPL:  34.38
   Val Loss:   3.6022
   Val PPL:    36.68
   LR:         0.000100

üìä Epoch [4/15]
   Train Loss: 3.3500
   Train PPL:  28.50
   Val Loss:   3.5309
   Val PPL:    34.16
   LR:         0.000100

üìä Epoch [5/15]
   Train Loss: 3.2094
   Train PPL:  24.76
   Val Loss:   3.4883
   Val PPL:    32.73
   LR:         0.000100

üìä Epoch [6/15]
   Train Loss: 3.0965
   Train PPL:  22.12
   Val Loss:   3.4571
   Val PPL:    31.73
   LR:         0.000100

üìä Epoch [7/15]
   Train Loss: 2.9992
   Train PPL:  20.07
   Val Loss:   3.4433
   Val PPL:    31.29
   LR:         0.000100

üìä Epoch [8/15]
   Train Loss: 2.9175
   Train PPL:  18.49
   Val Loss:   3.4277
   Val PPL: 

In [11]:
import shutil

# Save model
torch.save(model.state_dict(), '/kaggle/working/image_caption_model.pth')