
# Imports and Hyperparameters

In [None]:
import os
import re
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import numpy as np
from collections import Counter
import matplotlib.pyplot as plt
from tqdm import tqdm

# BLEU scoring is optional: record whether NLTK could be imported so that
# later cells can check BLEU_AVAILABLE before using it.
try:
    from nltk.translate.bleu_score import corpus_bleu, sentence_bleu, SmoothingFunction
    import nltk
except ImportError:
    print("NLTK not installed.")
    BLEU_AVAILABLE = False
else:
    BLEU_AVAILABLE = True

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Device: {device}")

In [None]:
import os

# Kaggle kernels expose datasets under /kaggle/input; its presence decides
# which dataset location and output directory are used.
IS_KAGGLE = os.path.exists('/kaggle/input')

DATA_DIR = '/kaggle/input/deeplearn/caption_data' if IS_KAGGLE else './caption_data'
SAVE_DIR = '/kaggle/working' if IS_KAGGLE else '.'

# Both environments share the same layout below DATA_DIR.
IMAGES_DIR = f'{DATA_DIR}/Images'
CAPTIONS_FILE = f'{DATA_DIR}/captions.txt'

# --- Model dimensions ---
EMBED_DIM = 256        # word embedding size
HIDDEN_DIM = 768       # LSTM hidden state size
ATTENTION_DIM = 512    # attention projection size
ENCODER_DIM = 2048     # channel count of the encoder feature map
DROPOUT = 0.5

# --- Optimisation ---
BATCH_SIZE = 64
NUM_EPOCHS = 17
LEARNING_RATE = 3e-4   # decoder learning rate
ENCODER_LR = 1e-4      # encoder learning rate
FINE_TUNE_ENCODER = True
FINE_TUNE_AFTER = 5

# --- Text processing ---
MIN_WORD_FREQ = 5
MAX_CAPTION_LEN = 50

# --- Nominal split sizes ---
TRAIN_SIZE = 6000
VAL_SIZE = 1000
TEST_SIZE = 1000

print(f"Save directory: {SAVE_DIR}")
print(f"Data directory: {DATA_DIR}")

## Define Classes and Functions


In [3]:
def clean_caption(text):
    """Normalise a raw caption into lowercase, space-separated words.

    Punctuation and digit runs are deleted (not replaced by a space), and
    one-character words are dropped unless they are 'a' or 'i'.
    """
    lowered = text.lower()
    without_punct = re.sub(r'[^\w\s]', '', lowered)
    without_digits = re.sub(r'\d+', '', without_punct)

    kept = []
    for word in without_digits.split():
        if len(word) > 1 or word in ['a', 'i']:
            kept.append(word)
    # split() never yields whitespace, so the joined string needs no strip.
    return ' '.join(kept)

class Vocabulary:
    """Two-way mapping between caption words and integer ids.

    The ids 0-3 are reserved, in this order, for the special tokens
    <PAD>, <SOS>, <EOS> and <UNK>. Regular words are appended after them by
    build_vocabulary() once they occur at least ``min_freq`` times.
    """

    def __init__(self, min_freq=5):
        self.min_freq = min_freq
        self.word2idx = {}
        self.idx2word = {}
        self.word_freq = Counter()
        self.PAD_TOKEN = '<PAD>'
        self.SOS_TOKEN = '<SOS>'
        self.EOS_TOKEN = '<EOS>'
        self.UNK_TOKEN = '<UNK>'
        self._init_special_tokens()

    def _init_special_tokens(self):
        """Register the four special tokens at ids 0..3."""
        specials = (self.PAD_TOKEN, self.SOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN)
        for position, token in enumerate(specials):
            self.word2idx[token] = position
            self.idx2word[position] = token

    def build_vocabulary(self, captions):
        """Count words over ``captions`` and add the frequent ones.

        Captions are split on whitespace as-is (no cleaning is applied here).
        """
        for text in tqdm(captions, desc="Building vocab"):
            self.word_freq.update(text.split())

        next_id = len(self.word2idx)
        for word, count in self.word_freq.items():
            if count < self.min_freq or word in self.word2idx:
                continue
            self.word2idx[word] = next_id
            self.idx2word[next_id] = word
            next_id += 1

        print(f"Vocabulary: {len(self.word2idx)} words (filtered {sum(1 for f in self.word_freq.values() if f < self.min_freq)} rare words)")

    def encode(self, caption):
        """Return ``[<SOS>, word ids..., <EOS>]`` for a caption string.

        The caption is cleaned first; words missing from the vocabulary map
        to the <UNK> id.
        """
        tokens = clean_caption(caption).split()
        ids = [self.word2idx[self.SOS_TOKEN]]
        ids.extend(
            self.word2idx.get(tok, self.word2idx[self.UNK_TOKEN]) for tok in tokens
        )
        ids.append(self.word2idx[self.EOS_TOKEN])
        return ids

    def decode(self, indices):
        """Turn a sequence of ids (ints or tensors) back into a string.

        Decoding stops at the first <EOS>; <PAD> and <SOS> are skipped and
        unknown ids are rendered as <UNK>.
        """
        skipped = [self.PAD_TOKEN, self.SOS_TOKEN]
        pieces = []
        for raw in indices:
            key = raw.item() if isinstance(raw, torch.Tensor) else raw
            token = self.idx2word.get(key, self.UNK_TOKEN)
            if token == self.EOS_TOKEN:
                break
            if token in skipped:
                continue
            pieces.append(token)
        return ' '.join(pieces)

    def __len__(self):
        return len(self.word2idx)

    def save(self, path):
        """Write the mappings, threshold and word counts to ``path`` as JSON."""
        payload = {
            'word2idx': self.word2idx,
            'idx2word': {int(idx): word for idx, word in self.idx2word.items()},
            'min_freq': self.min_freq,
            'word_freq': dict(self.word_freq),
        }
        with open(path, 'w') as handle:
            json.dump(payload, handle)

    @classmethod
    def load(cls, path):
        """Rebuild a vocabulary from a file written by save()."""
        with open(path, 'r') as handle:
            stored = json.load(handle)

        instance = cls(min_freq=stored['min_freq'])
        instance.word2idx = stored['word2idx']
        # JSON object keys are always strings; restore the integer ids.
        instance.idx2word = {int(idx): word for idx, word in stored['idx2word'].items()}
        if 'word_freq' in stored:
            instance.word_freq = Counter(stored['word_freq'])
        return instance

In [4]:
def load_captions(captions_file):
    """Load captions from a text file and group the cleaned text by image.

    Each non-empty line is expected to be ``<image>,<caption>`` or
    ``<image>\\t<caption>``. The line is split once on whichever of the two
    delimiters appears first, so a tab-separated file whose captions contain
    commas (or the reverse) keeps the whole caption. A ``#<n>`` suffix on the
    image name is removed. A first line mentioning both "image" and
    "caption" is treated as a header and skipped.

    Args:
        captions_file: Path to the UTF-8 captions file.

    Returns:
        dict mapping image name -> list of cleaned caption strings. Lines
        without a delimiter and captions that are empty after cleaning are
        dropped. An empty file yields an empty dict.
    """
    with open(captions_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    image_captions = {}

    # Guard against an empty file before peeking at the header line.
    header = lines[0].lower() if lines else ''
    start_idx = 1 if 'image' in header and 'caption' in header else 0

    for line in lines[start_idx:]:
        line = line.strip()
        if not line:
            continue

        # Split on the earliest delimiter: checking for ',' first would cut a
        # tab-separated line in the middle of a caption that contains a comma.
        comma_pos = line.find(',')
        tab_pos = line.find('\t')
        if tab_pos != -1 and (comma_pos == -1 or tab_pos < comma_pos):
            delimiter = '\t'
        else:
            delimiter = ','

        parts = line.split(delimiter, 1)
        if len(parts) != 2:
            continue

        image_name, caption = parts[0].strip(), parts[1].strip()
        if '#' in image_name:
            image_name = image_name.split('#')[0]

        cleaned = clean_caption(caption)
        if cleaned:
            image_captions.setdefault(image_name, []).append(cleaned)

    print(f"Loaded {len(image_captions)} images, {sum(len(c) for c in image_captions.values())} captions")
    return image_captions

def create_data_splits(image_captions, train_ratio=0.75, seed=42, val_ratio=0.1):
    """Split images (with all their captions) into train/val/test dicts.

    Image names are shuffled once and then cut into three consecutive
    slices, so every image ends up in exactly one split.

    Args:
        image_captions: dict mapping image name -> list of captions.
        train_ratio: Fraction of images assigned to the training split.
        seed: Seed for the shuffle. NOTE: this reseeds NumPy's *global*
            random generator, which also affects later ``np.random`` calls.
        val_ratio: Fraction of images assigned to the validation split.
            The test split receives the remainder.

    Returns:
        (train_captions, val_captions, test_captions), each a dict with the
        same value lists as ``image_captions``.

    Raises:
        ValueError: If a ratio lies outside [0, 1] or the two ratios sum to
            more than 1.
    """
    if not 0.0 <= train_ratio <= 1.0 or not 0.0 <= val_ratio <= 1.0:
        raise ValueError("train_ratio and val_ratio must each be within [0, 1]")
    # Small tolerance so ratios such as 0.7 + 0.3 are not rejected by rounding.
    if train_ratio + val_ratio > 1.0 + 1e-9:
        raise ValueError("train_ratio + val_ratio must not exceed 1")

    image_names = list(image_captions.keys())
    np.random.seed(seed)
    np.random.shuffle(image_names)

    n_total = len(image_names)
    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)
    val_end = n_train + n_val

    train_captions = {img: image_captions[img] for img in image_names[:n_train]}
    val_captions = {img: image_captions[img] for img in image_names[n_train:val_end]}
    test_captions = {img: image_captions[img] for img in image_names[val_end:]}

    print(f"Split: {len(train_captions)} train, {len(val_captions)} val, {len(test_captions)} test")
    return train_captions, val_captions, test_captions

In [5]:
class CaptionDataset(Dataset):
    """Dataset yielding one (image, encoded caption) pair per caption.

    An image with several captions contributes several samples, one for each
    caption.

    Args:
        image_captions: dict mapping image file name -> list of captions.
        images_dir: Directory containing the image files.
        vocab: Object providing ``encode(caption)``, ``word2idx`` and
            ``EOS_TOKEN`` (see Vocabulary).
        transform: Optional callable applied to the RGB PIL image.
        max_len: Maximum encoded length including <SOS>/<EOS>.
    """

    def __init__(self, image_captions, images_dir, vocab, transform=None, max_len=50):
        self.images_dir = images_dir
        self.vocab = vocab
        self.transform = transform
        self.max_len = max_len
        # Flatten to (image name, caption) pairs so __getitem__ is a plain lookup.
        self.samples = [(img, cap) for img, caps in image_captions.items() for cap in caps]

    def __len__(self):
        return len(self.samples)

    def _encode_caption(self, caption):
        """Encode a caption to a LongTensor of at most ``max_len`` ids.

        Over-long sequences are cut to ``max_len - 1`` ids and the <EOS> id
        is appended, so a truncated caption is still terminated.
        """
        encoded = self.vocab.encode(caption)
        if len(encoded) > self.max_len:
            eos_idx = self.vocab.word2idx[self.vocab.EOS_TOKEN]
            encoded = encoded[:self.max_len - 1] + [eos_idx]
        return torch.tensor(encoded, dtype=torch.long)

    def __getitem__(self, idx):
        img_name, caption = self.samples[idx]

        # Image.open is lazy and keeps the file open. convert() returns an
        # independent in-memory copy, so closing the source right away avoids
        # leaking file descriptors in long-running DataLoader workers.
        with Image.open(os.path.join(self.images_dir, img_name)) as source:
            image = source.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, self._encode_caption(caption)

def collate_fn(batch):
    """Merge (image, caption) samples into batched tensors.

    Returns the stacked images, the captions right-padded with 0 to the
    longest caption in the batch, and each caption's original length.
    """
    image_group, caption_group = zip(*batch)

    stacked_images = torch.stack(image_group, dim=0)
    original_lengths = torch.tensor(list(map(len, caption_group)), dtype=torch.long)
    padded_captions = pad_sequence(caption_group, batch_first=True, padding_value=0)

    return stacked_images, padded_captions, original_lengths

# Per-channel normalisation statistics shared by both pipelines.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize slightly larger than the network input, then take
# a random 224x224 crop plus light flip/colour augmentation.
train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

# Evaluation pipeline: deterministic resize straight to 224x224, no augmentation.
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

## Data Processing

Loading captions, building vocabulary, and splitting dataset.

In [None]:
if os.path.exists(CAPTIONS_FILE):
    image_captions = load_captions(CAPTIONS_FILE)

    # Split at image level so all captions of one image stay in the same split.
    train_captions, val_captions, test_captions = create_data_splits(image_captions, 0.75)

    # The vocabulary is built from training captions only.
    all_train_captions = [
        caption
        for caption_list in train_captions.values()
        for caption in caption_list
    ]
    vocab = Vocabulary(min_freq=MIN_WORD_FREQ)
    vocab.build_vocabulary(all_train_captions)

    vocab_path = os.path.join(SAVE_DIR, 'vocab.json')
    vocab.save(vocab_path)
else:
    print(f"WARNING: Captions file not found at {CAPTIONS_FILE}")

In [None]:
if os.path.exists(CAPTIONS_FILE):
    # Only the training set uses the augmenting transform; validation and
    # test share the deterministic one.
    train_dataset = CaptionDataset(
        train_captions, IMAGES_DIR, vocab,
        transform=train_transform, max_len=MAX_CAPTION_LEN,
    )
    val_dataset = CaptionDataset(
        val_captions, IMAGES_DIR, vocab,
        transform=val_transform, max_len=MAX_CAPTION_LEN,
    )
    test_dataset = CaptionDataset(
        test_captions, IMAGES_DIR, vocab,
        transform=val_transform, max_len=MAX_CAPTION_LEN,
    )

    # Settings common to all three loaders; only `shuffle` differs.
    _loader_options = dict(
        batch_size=BATCH_SIZE,
        collate_fn=collate_fn,
        num_workers=4,
        pin_memory=True,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
    val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)
    test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)

    print(f"Dataloaders: {len(train_loader)} train, {len(val_loader)} val, {len(test_loader)} test batches")

In [None]:
# Exploratory plots for the training data: caption lengths, most frequent
# words, vocabulary size per frequency threshold, and the split proportions.
# The figure is saved to SAVE_DIR/data_exploration.png and then displayed.
if os.path.exists(CAPTIONS_FILE):
    # Lengths are counted in whitespace-separated words of the cleaned captions.
    caption_lengths = [len(cap.split()) for cap in all_train_captions]
    
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    
    # Top-left: histogram of caption lengths with mean and median markers.
    axes[0, 0].hist(caption_lengths, bins=30, edgecolor='black', alpha=0.7, color='steelblue')
    axes[0, 0].axvline(np.mean(caption_lengths), color='red', linestyle='--', 
                       label=f'Mean: {np.mean(caption_lengths):.1f}')
    axes[0, 0].axvline(np.median(caption_lengths), color='green', linestyle='--',
                       label=f'Median: {np.median(caption_lengths):.1f}')
    axes[0, 0].set_xlabel('Caption Length (words)')
    axes[0, 0].set_ylabel('Frequency')
    axes[0, 0].set_title('Caption Length Distribution')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)
    
    # Top-right: the 30 most frequent words, most frequent at the top
    # (hence the inverted y axis).
    top_words = vocab.word_freq.most_common(30)
    words, freqs = zip(*top_words)
    axes[0, 1].barh(range(len(words)), freqs, color='coral')
    axes[0, 1].set_yticks(range(len(words)))
    axes[0, 1].set_yticklabels(words)
    axes[0, 1].invert_yaxis()
    axes[0, 1].set_xlabel('Frequency')
    axes[0, 1].set_title('Top 30 Most Common Words')
    axes[0, 1].grid(True, alpha=0.3, axis='x')
    
    # Bottom-left: number of words that would survive each frequency
    # threshold. This counts entries of word_freq only, so the special
    # tokens are not included in the plotted sizes.
    thresholds = [1, 2, 3, 5, 10, 15, 20]
    vocab_sizes = []
    for thresh in thresholds:
        size = sum(1 for freq in vocab.word_freq.values() if freq >= thresh)
        vocab_sizes.append(size)
    
    axes[1, 0].plot(thresholds, vocab_sizes, marker='o', linewidth=2, color='purple')
    axes[1, 0].axvline(MIN_WORD_FREQ, color='red', linestyle='--', 
                       label=f'Selected threshold: {MIN_WORD_FREQ}')
    axes[1, 0].set_xlabel('Minimum Word Frequency')
    axes[1, 0].set_ylabel('Vocabulary Size')
    axes[1, 0].set_title('Vocabulary Size vs. Frequency Threshold')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)
    
    # Bottom-right: share of images (not captions) in each split.
    split_names = ['Train', 'Validation', 'Test']
    split_sizes = [len(train_captions), len(val_captions), len(test_captions)]
    colors = ['#2ecc71', '#3498db', '#e74c3c']
    
    axes[1, 1].pie(split_sizes, labels=split_names, autopct='%1.1f%%', 
                   colors=colors, explode=(0.02, 0.02, 0.02))
    axes[1, 1].set_title('Data Split Distribution')
    
    plt.tight_layout()
    plt.savefig(os.path.join(SAVE_DIR, 'data_exploration.png'), dpi=150)
    plt.show()
    
    print(f"Caption stats: min {min(caption_lengths)}, max {max(caption_lengths)}, mean {np.mean(caption_lengths):.1f}")

## Model Architecture

CNN encoder using ResNet-50 for feature extraction, attention mechanism for focusing on relevant image regions, and LSTM decoder for caption generation.

In [None]:
class EncoderCNN(nn.Module):
    """ResNet-50 feature extractor producing a grid of region features.

    The final average-pool and classification layers of ResNet-50 are
    removed; the remaining feature map is adaptively pooled to a fixed
    ``encoded_image_size x encoded_image_size`` grid.

    Args:
        encoded_image_size: Side length of the pooled feature grid.
        fine_tune: Whether the ResNet parameters start out trainable.
    """

    def __init__(self, encoded_image_size=7, fine_tune=False):
        super(EncoderCNN, self).__init__()

        try:
            resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        except Exception as exc:
            # Best-effort fallback (e.g. the weights cannot be downloaded).
            # Catch Exception rather than using a bare except so that
            # KeyboardInterrupt/SystemExit still propagate, and report the
            # cause so an untrained encoder is not used unknowingly.
            resnet = models.resnet50(weights=None)
            print("Warning: ResNet loaded without pretrained weights")
            print(f"  reason: {type(exc).__name__}: {exc}")

        # Drop the last two children (global average pool and fc classifier).
        self.resnet = nn.Sequential(*list(resnet.children())[:-2])
        self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))

        self.set_fine_tuning(fine_tune)

    def set_fine_tuning(self, fine_tune):
        """Enable or disable gradients for every ResNet parameter."""
        for param in self.resnet.parameters():
            param.requires_grad = fine_tune

    def forward(self, images):
        """Encode a batch of images.

        Args:
            images: Image batch accepted by ResNet-50, i.e. (batch, 3, H, W).

        Returns:
            Tensor of shape (batch, encoded_image_size ** 2, channels), one
            feature vector per grid cell.
        """
        features = self.resnet(images)
        features = self.adaptive_pool(features)

        batch_size = features.size(0)
        # Move channels last: (batch, C, S, S) -> (batch, S, S, C).
        features = features.permute(0, 2, 3, 1)
        # Flatten the grid; read the channel count from the tensor instead of
        # hard-coding 2048.
        features = features.view(batch_size, -1, features.size(-1))

        return features

In [16]:
class Attention(nn.Module):
    """Additive soft attention over encoder regions.

    Scores each region from the sum of a projection of its features and a
    projection of the decoder state, then returns the softmax-weighted
    average of the region features.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        self.full_att = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, encoder_out, decoder_hidden):
        """Return (context, alpha).

        ``alpha`` holds one weight per region (softmax over dim 1) and
        ``context`` is the alpha-weighted sum of ``encoder_out`` over dim 1.
        """
        region_proj = self.encoder_att(encoder_out)
        # Add a region axis so the decoder projection broadcasts over regions.
        state_proj = self.decoder_att(decoder_hidden).unsqueeze(1)

        energies = self.full_att(self.relu(region_proj + state_proj))
        alpha = self.softmax(energies.squeeze(2))

        weighted_regions = encoder_out * alpha.unsqueeze(2)
        return weighted_regions.sum(dim=1), alpha

In [18]:
class DecoderLSTM(nn.Module):
    """
    LSTM decoder with soft attention for image captioning.
    
    Generates a caption one word at a time. At every step the decoder
    attends over the encoder's region features, gates the resulting context
    vector, concatenates it with the current word embedding and feeds the
    result to a single LSTM cell.
    """
    
    def __init__(self, embed_dim, decoder_dim, attention_dim, vocab_size, 
                 encoder_dim=2048, dropout=0.5):
        """
        Args:
            embed_dim: Dimension of word embeddings
            decoder_dim: Dimension of LSTM hidden state
            attention_dim: Dimension of attention network
            vocab_size: Size of vocabulary
            encoder_dim: Dimension of encoded image features
            dropout: Dropout probability
        """
        super(DecoderLSTM, self).__init__()
        
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim
        self.attention_dim = attention_dim
        self.embed_dim = embed_dim
        self.vocab_size = vocab_size
        self.dropout = dropout
        
        # Attention network
        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)
        
        # Word embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        
        # Dropout layer (applied to the hidden state before the output
        # projection in forward())
        self.dropout_layer = nn.Dropout(dropout)
        
        # LSTM cell (not full LSTM - we need step-by-step control)
        # Input: concatenation of embedding and attention-weighted encoding
        self.lstm_cell = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim)
        
        # Linear layers to initialize LSTM states from encoder output
        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        
        # Linear layer feeding a sigmoid gate; the gate has encoder_dim
        # entries, i.e. one value per channel of the context vector
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)
        self.sigmoid = nn.Sigmoid()
        
        # Linear layer to find scores over vocabulary
        self.fc = nn.Linear(decoder_dim, vocab_size)
        
        # Initialize weights
        self._init_weights()
    
    def _init_weights(self):
        """Initialize embedding and output weights uniformly in [-0.1, 0.1]; zero the output bias."""
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)
    
    def init_hidden_state(self, encoder_out):
        """
        Initialize LSTM hidden and cell state from the encoder output.
        
        Args:
            encoder_out: (batch_size, num_pixels, encoder_dim)
        
        Returns:
            h: (batch_size, decoder_dim)
            c: (batch_size, decoder_dim)
        """
        # Mean of encoder output across spatial dimensions
        mean_encoder_out = encoder_out.mean(dim=1)  # (batch, encoder_dim)
        
        h = self.init_h(mean_encoder_out)  # (batch, decoder_dim)
        c = self.init_c(mean_encoder_out)  # (batch, decoder_dim)
        
        return h, c
    
    def forward(self, encoder_out, captions, caption_lengths):
        """
        Forward pass for training with teacher forcing.
        
        The batch is re-ordered by decreasing caption length, so every
        returned tensor is in the SORTED order, not the input order.
        
        Args:
            encoder_out: (batch_size, num_pixels, encoder_dim)
            captions: (batch_size, max_caption_length)
            caption_lengths: (batch_size,)
        
        Returns:
            predictions: (batch_size, max_decode_length, vocab_size); rows
                beyond a sample's decode length stay zero
            alphas: (batch_size, max_decode_length, num_pixels); rows beyond
                a sample's decode length stay zero
            captions: the input captions re-ordered by sort_idx
            decode_lengths: list of per-sample decode lengths
                (caption length - 1), in sorted order
            sort_idx: indices that map the sorted order back to positions in
                the input batch
        """
        batch_size = encoder_out.size(0)
        num_pixels = encoder_out.size(1)
        
        # Sort by decreasing caption length so that, at any timestep, the
        # sequences that are still being decoded form a prefix of the batch
        caption_lengths, sort_idx = caption_lengths.sort(dim=0, descending=True)
        encoder_out = encoder_out[sort_idx]
        captions = captions[sort_idx]
        
        # Embed captions
        embeddings = self.embedding(captions)  # (batch, max_len, embed_dim)
        
        # Initialize LSTM state
        h, c = self.init_hidden_state(encoder_out)
        
        # We won't decode at <EOS> position, since we've finished when we hit <EOS>
        # So decode_length = caption_length - 1
        decode_lengths = (caption_lengths - 1).tolist()
        max_decode_length = max(decode_lengths)
        
        # Create tensors to hold predictions and attention weights
        predictions = torch.zeros(batch_size, max_decode_length, self.vocab_size).to(encoder_out.device)
        alphas = torch.zeros(batch_size, max_decode_length, num_pixels).to(encoder_out.device)
        
        # For each timestep
        for t in range(max_decode_length):
            # Number of sequences still active at this timestep; thanks to
            # the sort above these are the first batch_size_t rows
            batch_size_t = sum([l > t for l in decode_lengths])
            
            # Attention
            attention_weighted_encoding, alpha = self.attention(
                encoder_out[:batch_size_t], 
                h[:batch_size_t]
            )
            
            # Sigmoid gate computed from the previous hidden state, applied
            # element-wise to the context vector
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))
            attention_weighted_encoding = gate * attention_weighted_encoding
            
            # LSTM input: concatenate embedding and attention-weighted encoding
            lstm_input = torch.cat([
                embeddings[:batch_size_t, t, :], 
                attention_weighted_encoding
            ], dim=1)
            
            # LSTM step
            h_new, c_new = self.lstm_cell(lstm_input, (h[:batch_size_t], c[:batch_size_t]))
            
            # Update hidden states of the active rows only; finished rows keep
            # their last state. NOTE(review): the clone presumably exists so
            # the slice assignment does not modify a tensor in place that
            # autograd still needs - confirm before removing.
            h = h.clone()
            c = c.clone()
            h[:batch_size_t] = h_new
            c[:batch_size_t] = c_new
            
            # Predict next word
            preds = self.fc(self.dropout_layer(h_new))  # (batch_size_t, vocab_size)
            
            # Store predictions and attention weights
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha
        
        return predictions, alphas, captions, decode_lengths, sort_idx
    
    def generate(self, encoder_out, vocab, max_len=50, beam_size=1):
        """
        Generate caption for a single image using greedy decoding or beam search.
        
        Args:
            encoder_out: (1, num_pixels, encoder_dim)
            vocab: Vocabulary object
            max_len: Maximum caption length
            beam_size: Beam size for beam search (1 = greedy)
        
        Returns:
            caption: Generated caption string
            attention_weights: List with one attention-weight array per
                decoding step for greedy decoding; always an empty list for
                beam search
        """
        if beam_size == 1:
            return self._greedy_decode(encoder_out, vocab, max_len)
        else:
            return self._beam_search(encoder_out, vocab, max_len, beam_size)
    
    def _greedy_decode(self, encoder_out, vocab, max_len):
        """Greedy decoding - select most probable word at each step."""
        device = encoder_out.device
        
        # Initialize LSTM state
        h, c = self.init_hidden_state(encoder_out)
        
        # Start with <SOS> token
        word_idx = vocab.word2idx[vocab.SOS_TOKEN]
        
        caption = []
        attention_weights = []
        
        for _ in range(max_len):
            # Embed current word
            embedding = self.embedding(torch.tensor([word_idx]).to(device))  # (1, embed_dim)
            
            # Attention
            attention_weighted_encoding, alpha = self.attention(encoder_out, h)
            attention_weights.append(alpha.squeeze(0).cpu().detach().numpy())
            
            # Gating
            gate = self.sigmoid(self.f_beta(h))
            attention_weighted_encoding = gate * attention_weighted_encoding
            
            # LSTM step
            lstm_input = torch.cat([embedding, attention_weighted_encoding], dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            
            # Predict (unlike forward(), dropout_layer is not applied here)
            scores = self.fc(h)
            word_idx = scores.argmax(dim=1).item()
            
            # Check for <EOS>
            if word_idx == vocab.word2idx[vocab.EOS_TOKEN]:
                break
            
            caption.append(vocab.idx2word[word_idx])
        
        return ' '.join(caption), attention_weights
    
    def _beam_search(self, encoder_out, vocab, max_len, beam_size):
        """Beam search decoding; returns (caption, []) - no attention weights are collected."""
        device = encoder_out.device
        num_pixels = encoder_out.size(1)
        
        # Expand encoder output for beam search: one copy per beam
        encoder_out = encoder_out.expand(beam_size, num_pixels, self.encoder_dim)
        
        # Initialize: every beam starts from <SOS> with an accumulated
        # log-probability of zero
        k_prev_words = torch.tensor([[vocab.word2idx[vocab.SOS_TOKEN]]] * beam_size).to(device)
        seqs = k_prev_words
        top_k_scores = torch.zeros(beam_size, 1).to(device)
        
        # Initialize LSTM states
        h, c = self.init_hidden_state(encoder_out)
        
        # Finished hypotheses (those that produced <EOS>) and their scores
        complete_seqs = []
        complete_seqs_scores = []
        
        for step in range(max_len):
            embeddings = self.embedding(k_prev_words).squeeze(1)
            attention_weighted_encoding, _ = self.attention(encoder_out, h)
            gate = self.sigmoid(self.f_beta(h))
            attention_weighted_encoding = gate * attention_weighted_encoding
            
            lstm_input = torch.cat([embeddings, attention_weighted_encoding], dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            
            # Accumulated log-probability of every (beam, next word) pair
            scores = self.fc(h)
            scores = torch.log_softmax(scores, dim=1)
            scores = top_k_scores.expand_as(scores) + scores
            
            if step == 0:
                # All beams are identical at the first step (same <SOS>
                # input and same initial state), so only row 0 is ranked
                top_k_scores, top_k_words = scores[0].topk(beam_size, 0, True, True)
            else:
                # Rank across all beams at once; indices address the
                # flattened (beam, word) grid
                top_k_scores, top_k_words = scores.view(-1).topk(beam_size, 0, True, True)
            
            # Recover which beam each winner extends and which word it adds
            prev_word_inds = top_k_words // self.vocab_size
            next_word_inds = top_k_words % self.vocab_size
            
            seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)
            
            incomplete_inds = [ind for ind, word in enumerate(next_word_inds) 
                             if word != vocab.word2idx[vocab.EOS_TOKEN]]
            complete_inds = [ind for ind, word in enumerate(next_word_inds) 
                           if word == vocab.word2idx[vocab.EOS_TOKEN]]
            
            if len(complete_inds) > 0:
                complete_seqs.extend(seqs[complete_inds].tolist())
                complete_seqs_scores.extend(top_k_scores[complete_inds].tolist())
            
            # The beam shrinks by one for every hypothesis that finished
            beam_size = len(incomplete_inds)
            if beam_size == 0:
                break
            
            # Keep only the unfinished hypotheses and re-align their states
            # with the beams they were extended from
            seqs = seqs[incomplete_inds]
            h = h[prev_word_inds[incomplete_inds]]
            c = c[prev_word_inds[incomplete_inds]]
            encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
            top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
            k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)
        
        # Nothing reached <EOS> within max_len: fall back to the unfinished beams
        if len(complete_seqs) == 0:
            complete_seqs = seqs.tolist()
            complete_seqs_scores = top_k_scores.squeeze(1).tolist()
        
        # Pick the hypothesis with the highest accumulated log-probability
        # (scores are not normalised by sequence length)
        best_seq_idx = complete_seqs_scores.index(max(complete_seqs_scores))
        best_seq = complete_seqs[best_seq_idx]
        
        # Drop the leading <SOS> and any special tokens before building the string
        caption = [vocab.idx2word[idx] for idx in best_seq[1:] 
                  if idx not in [vocab.word2idx[vocab.SOS_TOKEN], 
                                vocab.word2idx[vocab.EOS_TOKEN],
                                vocab.word2idx[vocab.PAD_TOKEN]]]
        
        return ' '.join(caption), []

## Model Training

Training with cross-entropy loss and attention regularization. Two-stage training approach: encoder frozen initially, then fine-tuned after 5 epochs.

In [None]:
def train_one_epoch(encoder, decoder, train_loader, criterion,
                    decoder_optimizer, encoder_optimizer, device, alpha_c=1.0):
    """Run one training epoch with teacher forcing.

    The optimised loss is the cross-entropy over the real (non-padded)
    target words plus an attention regulariser that encourages each image
    region's attention weights to sum to 1 over the timesteps of its caption.

    Args:
        encoder: Image encoder module.
        decoder: Caption decoder; its forward pass must return
            (predictions, alphas, sorted_captions, decode_lengths, sort_idx).
        train_loader: Iterable of (images, captions, lengths) batches.
        criterion: Loss applied to (packed predictions, packed targets).
        decoder_optimizer: Optimizer for the decoder parameters.
        encoder_optimizer: Optimizer for the encoder, or None to leave the
            encoder un-updated.
        device: Device the batches are moved to.
        alpha_c: Weight of the attention regulariser.

    Returns:
        (mean cross-entropy loss, mean attention regularisation) per batch.
    """
    encoder.train()
    decoder.train()

    total_ce_loss = 0.0
    total_att_loss = 0.0

    progress_bar = tqdm(train_loader, desc="Training")

    for images, captions, lengths in progress_bar:
        images = images.to(device)
        captions = captions.to(device)
        lengths = lengths.to(device)

        encoder_out = encoder(images)
        predictions, alphas, sorted_captions, decode_lengths, _ = decoder(
            encoder_out, captions, lengths
        )

        # Targets are the captions shifted left by one (everything after <SOS>).
        targets = sorted_captions[:, 1:]

        # Keep only the timesteps that were actually decoded for each sample.
        predictions_packed = torch.cat(
            [predictions[i, :n, :] for i, n in enumerate(decode_lengths)], dim=0
        )
        targets_packed = torch.cat(
            [targets[i, :n] for i, n in enumerate(decode_lengths)], dim=0
        )

        ce_loss = criterion(predictions_packed, targets_packed)

        # Doubly stochastic attention regularisation: for every image, each
        # region's weights summed over that image's timesteps should be ~1.
        # The sum must therefore run over the time axis (dim=1) of the
        # per-sample tensor; summing a batch-concatenated tensor over dim=0
        # would mix all images together. Timesteps past a sample's decode
        # length are zero-filled by the decoder and add nothing to the sum.
        att_regularization = alpha_c * ((1.0 - alphas.sum(dim=1)) ** 2).mean()
        loss = ce_loss + att_regularization

        decoder_optimizer.zero_grad()
        if encoder_optimizer is not None:
            encoder_optimizer.zero_grad()

        loss.backward()

        torch.nn.utils.clip_grad_norm_(decoder.parameters(), max_norm=5.0)
        if encoder_optimizer is not None:
            torch.nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=5.0)

        decoder_optimizer.step()
        if encoder_optimizer is not None:
            encoder_optimizer.step()

        total_ce_loss += ce_loss.item()
        total_att_loss += att_regularization.item()
        progress_bar.set_postfix({'loss': f'{ce_loss.item():.4f}'})

    return total_ce_loss / len(train_loader), total_att_loss / len(train_loader)

def validate(encoder, decoder, val_loader, criterion, device):
    """Return the mean per-batch loss over ``val_loader``.

    Both models are put in eval mode and gradients are disabled. Only
    ``criterion`` is evaluated; no attention regularisation is added.
    """
    encoder.eval()
    decoder.eval()

    running_loss = 0

    with torch.no_grad():
        for images, captions, lengths in tqdm(val_loader, desc="Validating"):
            images = images.to(device)
            captions = captions.to(device)
            lengths = lengths.to(device)

            features = encoder(images)
            outputs = decoder(features, captions, lengths)
            predictions, _, sorted_captions, decode_lengths, _ = outputs

            # Everything after <SOS> is what the decoder has to predict.
            targets = sorted_captions[:, 1:]

            # Collect only the timesteps that were decoded for each sample.
            kept_predictions = []
            kept_targets = []
            for row, steps in enumerate(decode_lengths):
                kept_predictions.append(predictions[row, :steps, :])
                kept_targets.append(targets[row, :steps])

            batch_loss = criterion(
                torch.cat(kept_predictions, dim=0),
                torch.cat(kept_targets, dim=0),
            )
            running_loss += batch_loss.item()

    return running_loss / len(val_loader)

def save_checkpoint(epoch, encoder, decoder, encoder_optimizer, decoder_optimizer, 
                   train_loss, val_loss, vocab_size, checkpoint_path):
    """Write a training checkpoint to ``checkpoint_path`` via ``torch.save``.

    The checkpoint stores the epoch, both model state dicts, both optimizer
    state dicts, the two loss values, the vocabulary size and the
    module-level architecture hyperparameters.

    Args:
        epoch: Epoch value stored under ``'epoch'``.
        encoder: Encoder module whose ``state_dict()`` is stored.
        decoder: Decoder module whose ``state_dict()`` is stored.
        encoder_optimizer: Encoder optimizer, or a falsy value (e.g. ``None``)
            in which case ``None`` is stored for its state.
        decoder_optimizer: Decoder optimizer whose ``state_dict()`` is stored.
        train_loss: Training loss stored under ``'train_loss'``.
        val_loss: Validation loss stored under ``'val_loss'``.
        vocab_size: Vocabulary size stored under ``'vocab_size'``.
        checkpoint_path: Destination passed to ``torch.save``.
    """
    encoder_state = encoder.state_dict()
    decoder_state = decoder.state_dict()
    decoder_optimizer_state = decoder_optimizer.state_dict()

    # The encoder optimizer is optional; record None when there is none.
    if encoder_optimizer:
        encoder_optimizer_state = encoder_optimizer.state_dict()
    else:
        encoder_optimizer_state = None

    checkpoint = dict(
        epoch=epoch,
        encoder_state_dict=encoder_state,
        decoder_state_dict=decoder_state,
        decoder_optimizer_state_dict=decoder_optimizer_state,
        encoder_optimizer_state_dict=encoder_optimizer_state,
        train_loss=train_loss,
        val_loss=val_loss,
        vocab_size=vocab_size,
        # Architecture hyperparameters come from the module-level constants.
        embed_dim=EMBED_DIM,
        hidden_dim=HIDDEN_DIM,
        attention_dim=ATTENTION_DIM,
        encoder_dim=ENCODER_DIM,
        dropout=DROPOUT,
    )
    torch.save(checkpoint, checkpoint_path)

In [None]:
if not os.path.exists(CAPTIONS_FILE):
    print("Dataset not found. Please extract caption_data.zip first.")
else:
    # Models. The encoder is built with fine_tune=False, so only the decoder
    # gets an optimizer at this point.
    encoder = EncoderCNN(fine_tune=False).to(device)
    decoder = DecoderLSTM(
        embed_dim=EMBED_DIM,
        decoder_dim=HIDDEN_DIM,
        attention_dim=ATTENTION_DIM,
        vocab_size=len(vocab),
        encoder_dim=ENCODER_DIM,
        dropout=DROPOUT,
    ).to(device)

    # Positions holding the PAD token index are ignored by the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.word2idx[vocab.PAD_TOKEN])

    decoder_optimizer = optim.Adam(decoder.parameters(), lr=LEARNING_RATE)
    encoder_optimizer = None

    # Halve the decoder learning rate when the monitored value (mode='min')
    # has not improved for 3 scheduler steps.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        decoder_optimizer,
        mode='min',
        factor=0.5,
        patience=3,
    )

    # Per-epoch loss curves, filled in by the training loop.
    history = {key: [] for key in ('train_loss', 'val_loss')}
    best_val_loss = float('inf')

    print(f"Starting training for {NUM_EPOCHS} epochs")
    print(f"Encoder params: {sum(map(torch.numel, encoder.parameters())):,}")
    print(f"Decoder params: {sum(map(torch.numel, decoder.parameters())):,}")

In [None]:
if not os.path.exists(CAPTIONS_FILE):
    print("Dataset not found. Skipping training.")
else:
    # A single file is overwritten every time the validation loss improves.
    best_model_path = os.path.join(SAVE_DIR, 'best_model.pt')

    for epoch in range(NUM_EPOCHS):
        print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS}")

        # At epoch index FINE_TUNE_AFTER, switch the encoder to fine-tuning
        # and give its trainable parameters their own optimizer.
        if FINE_TUNE_ENCODER and epoch == FINE_TUNE_AFTER:
            print("Starting encoder fine-tuning")
            encoder.set_fine_tuning(True)
            encoder_optimizer = optim.Adam(
                [param for param in encoder.parameters() if param.requires_grad],
                lr=ENCODER_LR,
            )

        train_ce_loss, train_att_loss = train_one_epoch(
            encoder, decoder, train_loader, criterion,
            decoder_optimizer, encoder_optimizer, device,
        )
        val_loss = validate(encoder, decoder, val_loader, criterion, device)

        # The plateau scheduler is driven by the validation loss.
        scheduler.step(val_loss)

        history['train_loss'].append(train_ce_loss)
        history['val_loss'].append(val_loss)

        print(f"\nEpoch {epoch + 1} Summary:")
        print(f"  Train Loss (CE): {train_ce_loss:.4f}")
        print(f"  Train Loss (Att Reg): {train_att_loss:.4f}")
        print(f"  Val Loss: {val_loss:.4f}")

        # Nothing is saved unless this epoch beats the best loss so far.
        if not val_loss < best_val_loss:
            print(f"  Not improved (best: {best_val_loss:.4f})")
            continue

        best_val_loss = val_loss
        save_checkpoint(
            epoch, encoder, decoder, encoder_optimizer, decoder_optimizer,
            train_loss=train_ce_loss,
            val_loss=val_loss,
            vocab_size=len(vocab),
            checkpoint_path=best_model_path,
        )
        print(f"  Best model saved (val loss: {val_loss:.4f})")

    print("\nTraining complete!")
    print(f"Best validation loss: {best_val_loss:.4f}")
    print(f"Model saved to: {best_model_path}")

In [None]:
if os.path.exists(CAPTIONS_FILE) and history['train_loss']:
    history_fig, (loss_ax, delta_ax) = plt.subplots(1, 2, figsize=(10, 5))

    # Left panel: train and validation loss per epoch.
    loss_ax.plot(history['train_loss'], label='Train Loss', marker='o')
    loss_ax.plot(history['val_loss'], label='Val Loss', marker='s')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Training and Validation Loss')
    loss_ax.grid(True)

    # Mark the fine-tuning start only if training got that far.
    if FINE_TUNE_ENCODER and FINE_TUNE_AFTER < len(history['train_loss']):
        loss_ax.axvline(x=FINE_TUNE_AFTER, color='r', linestyle='--',
                        label=f'Fine-tune start (epoch {FINE_TUNE_AFTER})')
    loss_ax.legend()

    # Right panel: drop in validation loss relative to the previous epoch.
    # The first epoch has no predecessor, so its bar is 0.
    improvements = [0] + [
        previous - current
        for previous, current in zip(history['val_loss'], history['val_loss'][1:])
    ]
    colors = ['green' if imp > 0 else 'red' for imp in improvements]
    delta_ax.bar(range(len(improvements)), improvements, color=colors)
    delta_ax.set_xlabel('Epoch')
    delta_ax.set_ylabel('Loss Improvement')
    delta_ax.set_title('Validation Loss Improvement per Epoch')
    delta_ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
    delta_ax.grid(True, axis='y')

    history_fig.tight_layout()
    history_fig.savefig(os.path.join(SAVE_DIR, 'training_history.png'), dpi=150)
    plt.show()

    # Keep the raw curves next to the figure.
    with open(os.path.join(SAVE_DIR, 'training_history.json'), 'w') as f:
        json.dump(history, f, indent=2)
    print("Training history saved")

## Sample Predictions

Visualize generated captions on validation images to verify model performance.

In [None]:
if os.path.exists(CAPTIONS_FILE):
    encoder.eval()
    decoder.eval()
    
    # One validation batch supplies the images shown below.
    sample_images, sample_captions, sample_lengths = next(iter(val_loader))
    
    def denormalize(tensor):
        """Map a normalized (3, H, W) image tensor back to display values.

        Inverts ``(x - mean) / std`` per channel. The constants are presumably
        the ones used by the input transform -- confirm against it.
        """
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        return tensor * std + mean
    
    def truncate_caption(text, limit=50):
        """Return ``text`` cut to ``limit`` characters, with '...' only if cut."""
        return text if len(text) <= limit else f"{text[:limit]}..."
    
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()
    
    with torch.no_grad():
        # zip stops at the shorter of (panels, batch), so at most 6 images.
        for ax, image, caption_ids in zip(axes, sample_images, sample_captions):
            img = image.unsqueeze(0).to(device)
            
            encoder_out = encoder(img)
            generated_caption, _ = decoder.generate(encoder_out, vocab, max_len=30, beam_size=3)
            
            gt_caption = vocab.decode(caption_ids.tolist())
            
            # The batch tensor was never moved to the device, so it can be
            # converted to NumPy directly; clip because denormalizing may
            # leave values slightly outside [0, 1].
            img_display = denormalize(image).permute(1, 2, 0).numpy()
            img_display = np.clip(img_display, 0, 1)
            ax.imshow(img_display)
            ax.axis('off')
            ax.set_title(f"Generated: {truncate_caption(generated_caption)}\n"
                         f"Ground Truth: {truncate_caption(gt_caption)}", 
                         fontsize=8, wrap=True)
    
    # Hide panels left without an image when the batch has fewer than 6 items.
    for ax in axes[len(sample_images):]:
        ax.axis('off')
    
    plt.tight_layout()
    plt.savefig(os.path.join(SAVE_DIR, 'sample_predictions.png'), dpi=150)
    plt.show()
    print("Sample predictions saved")

## Model Evaluation

BLEU score evaluation on test set. BLEU measures n-gram precision between generated and reference captions (BLEU-1 through BLEU-4).

In [25]:
# ============================================
# BLEU SCORE EVALUATION
# ============================================

def evaluate_bleu(encoder, decoder, test_captions, images_dir, vocab, device, 
                  transform, num_samples=None):
    """
    Evaluate model using corpus-level BLEU scores on test set.
    
    One caption is generated per image with ``decoder.generate`` (max_len=30,
    beam_size=3) and compared with that image's reference captions. Both
    sides are tokenized by whitespace. Images whose file is missing under
    ``images_dir`` are skipped and counted in a summary message.
    
    Args:
        encoder: Trained CNN encoder
        decoder: Trained LSTM decoder
        test_captions: Dict of {image_name: [reference_captions]}
        images_dir: Path to images directory
        vocab: Vocabulary object
        device: torch device
        transform: Image transform
        num_samples: Number of samples to evaluate (None = all)
    
    Returns:
        dict: BLEU-1, BLEU-2, BLEU-3, BLEU-4 scores, or None when NLTK is
        unavailable or no image could be evaluated.
    """
    if not BLEU_AVAILABLE:
        print("NLTK not available. Skipping BLEU evaluation.")
        return None
    
    encoder.eval()
    decoder.eval()
    
    references = []  # List of list of reference captions (tokenized)
    hypotheses = []  # List of generated captions (tokenized)
    
    image_names = list(test_captions.keys())
    # Compare against None so that only None means "all images".
    if num_samples is not None:
        image_names = image_names[:num_samples]
    
    print(f"Evaluating BLEU on {len(image_names)} images...")
    
    # method1 smoothing keeps a zero n-gram match count from collapsing the
    # whole score; it only changes the result in that degenerate case.
    smoothing = SmoothingFunction().method1
    skipped = 0
    
    with torch.no_grad():
        for img_name in tqdm(image_names, desc="Generating captions"):
            # Load and preprocess image
            img_path = os.path.join(images_dir, img_name)
            if not os.path.exists(img_path):
                skipped += 1
                continue
            
            # convert() returns an independent image, so the file can be
            # closed as soon as the pixels are loaded.
            with Image.open(img_path) as img_file:
                image = img_file.convert('RGB')
            image_tensor = transform(image).unsqueeze(0).to(device)
            
            # Generate caption
            encoder_out = encoder(image_tensor)
            generated_caption, _ = decoder.generate(encoder_out, vocab, max_len=30, beam_size=3)
            
            # Tokenize the hypothesis and every reference the same way
            hypotheses.append(generated_caption.split())
            references.append([cap.split() for cap in test_captions[img_name]])
    
    if skipped:
        print(f"Skipped {skipped} images not found in {images_dir}")
    
    # corpus_bleu has nothing to score on an empty corpus.
    if not hypotheses:
        print("No images could be evaluated. Skipping BLEU evaluation.")
        return None
    
    # Uniform weights over the first n n-gram orders; each row sums to 1.
    ngram_weights = {
        'BLEU-1': (1, 0, 0, 0),
        'BLEU-2': (0.5, 0.5, 0, 0),
        'BLEU-3': (1 / 3, 1 / 3, 1 / 3, 0),
        'BLEU-4': (0.25, 0.25, 0.25, 0.25),
    }
    
    # Calculate BLEU scores
    results = {
        name: corpus_bleu(references, hypotheses, weights=weights,
                          smoothing_function=smoothing)
        for name, weights in ngram_weights.items()
    }
    
    print(f"\n{'='*50}")
    print("BLEU SCORE RESULTS")
    print(f"{'='*50}")
    for name, score in results.items():
        print(f"{name}: {score:.4f}")
    print(f"{'='*50}")
    
    return results

In [None]:
if os.path.exists(CAPTIONS_FILE):
    # Score the first 100 test images with the validation-time transform.
    bleu_results = evaluate_bleu(
        encoder=encoder,
        decoder=decoder,
        test_captions=test_captions,
        images_dir=IMAGES_DIR,
        vocab=vocab,
        device=device,
        transform=val_transform,
        num_samples=100,
    )

    # evaluate_bleu may return None; then there is nothing to save or plot.
    if bleu_results:
        with open(os.path.join(SAVE_DIR, 'bleu_scores.json'), 'w') as f:
            json.dump(bleu_results, f, indent=2)

        bleu_names = list(bleu_results)
        bleu_values = list(bleu_results.values())

        bleu_fig, bleu_ax = plt.subplots(figsize=(8, 5))
        bars = bleu_ax.bar(
            bleu_names,
            bleu_values,
            color=['#3498db', '#2ecc71', '#f39c12', '#e74c3c'],
        )
        bleu_ax.set_ylabel('Score')
        bleu_ax.set_title('BLEU Scores on Test Set')
        bleu_ax.set_ylim(0, 1)

        # Print each score just above its bar.
        for bar, val in zip(bars, bleu_values):
            bleu_ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
                         f'{val:.3f}', ha='center', va='bottom', fontsize=11)

        bleu_ax.grid(True, axis='y', alpha=0.3)
        bleu_fig.tight_layout()
        bleu_fig.savefig(os.path.join(SAVE_DIR, 'bleu_scores.png'), dpi=150)
        plt.show()
        print("BLEU scores saved")