In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
import re

# Load the captions
def load_captions(caption_file):
    captions = {}
    with open(caption_file, 'r') as f:
        next(f)  # Skip header
        for line in f:
            line = line.strip()
            if line:
                # Split only on first comma to handle commas in captions
                parts = line.split(',', 1)
                if len(parts) == 2:
                    image_name = parts[0]
                    caption = parts[1].strip()
                    
                    if image_name not in captions:
                        captions[image_name] = []
                    captions[image_name].append(caption)
    return captions

# Load your captions
captions_dict = load_captions('/kaggle/input/flickr8k/captions.txt')

# Basic exploration
print(f"Total images: {len(captions_dict)}")
print(f"Total captions: {sum(len(caps) for caps in captions_dict.values())}")

# Show some examples
sample_images = list(captions_dict.keys())[:3]
for img in sample_images:
    print(f"\n{img}:")
    for i, cap in enumerate(captions_dict[img]):
        print(f"  {i+1}. {cap}")

Total images: 8091
Total captions: 40455

1000268201_693b08cb0e.jpg:
  1. A child in a pink dress is climbing up a set of stairs in an entry way .
  2. A girl going into a wooden building .
  3. A little girl climbing into a wooden playhouse .
  4. A little girl climbing the stairs to her playhouse .
  5. A little girl in a pink dress going into a wooden cabin .

1001773457_577c3a7d70.jpg:
  1. A black dog and a spotted dog are fighting
  2. A black dog and a tri-colored dog playing with each other on the road .
  3. A black dog and a white dog with brown spots are staring at each other in the street .
  4. Two dogs of different breeds looking at each other on the road .
  5. Two dogs on pavement moving toward each other .

1002674143_1b742ab4b8.jpg:
  1. A little girl covered in paint sits in front of a painted rainbow with her hands in a bowl .
  2. A little girl is sitting in front of a large painted rainbow .
  3. A small girl in the grass plays with fingerpaints in front of a whit

In [2]:
import string

def preprocess_caption(caption):
    """Clean and preprocess a single caption"""
    # Convert to lowercase
    caption = caption.lower()
    
    # Remove extra whitespace
    caption = caption.strip()
    
    # Remove punctuation except periods (we'll handle them specially)
    caption = re.sub(r'[^\w\s.]', '', caption)
    
    # Remove multiple spaces
    caption = re.sub(r'\s+', ' ', caption)
    
    # Remove trailing period if present
    caption = caption.rstrip('.')
    
    # Add start and end tokens
    caption = '<start> ' + caption + ' <end>'
    
    return caption

# Preprocess all captions
processed_captions = {}
all_words = []

for image_name, caption_list in captions_dict.items():
    processed_captions[image_name] = []
    for caption in caption_list:
        processed_cap = preprocess_caption(caption)
        processed_captions[image_name].append(processed_cap)
        # Collect all words for vocabulary
        all_words.extend(processed_cap.split())

print("Sample processed captions:")
sample_img = list(processed_captions.keys())[0]
for i, cap in enumerate(processed_captions[sample_img][:3]):
    print(f"{i+1}. {cap}")

# Build vocabulary
word_counts = Counter(all_words)
print(f"\nTotal unique words: {len(word_counts)}")
print(f"Most common words: {word_counts.most_common(10)}")

# Create vocabulary (words appearing at least 5 times)
min_word_count = 5
vocab = [word for word, count in word_counts.items() if count >= min_word_count]

# Add special tokens if not present
special_tokens = ['<start>', '<end>', '<unk>', '<pad>']
for token in special_tokens:
    if token not in vocab:
        vocab.append(token)

print(f"\nVocabulary size (min_count={min_word_count}): {len(vocab)}")
print(f"Special tokens added: {special_tokens}")

# Create word-to-index and index-to-word mappings
word_to_idx = {word: idx for idx, word in enumerate(vocab)}
idx_to_word = {idx: word for word, idx in word_to_idx.items()}

print(f"\nSample word mappings:")
for word in ['<start>', 'a', 'dog', 'running', '<end>']:
    if word in word_to_idx:
        print(f"'{word}' -> {word_to_idx[word]}")

Sample processed captions:
1. <start> a child in a pink dress is climbing up a set of stairs in an entry way  <end>
2. <start> a girl going into a wooden building  <end>
3. <start> a little girl climbing into a wooden playhouse  <end>

Total unique words: 8836
Most common words: [('a', 62986), ('<start>', 40455), ('<end>', 40455), ('in', 18974), ('the', 18418), ('on', 10743), ('is', 9345), ('and', 8851), ('dog', 8136), ('with', 7765)]

Vocabulary size (min_count=5): 2996
Special tokens added: ['<start>', '<end>', '<unk>', '<pad>']

Sample word mappings:
'<start>' -> 0
'a' -> 1
'dog' -> 26
'running' -> 110
'<end>' -> 14


In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os

class Flickr8kDataset(Dataset):
    def __init__(self, image_dir, captions_dict, word_to_idx, transform=None):
        self.image_dir = image_dir
        self.word_to_idx = word_to_idx
        self.transform = transform
        
        # Create list of (image_path, caption) pairs
        self.data = []
        for image_name, caption_list in captions_dict.items():
            for caption in caption_list:
                self.data.append((image_name, caption))
        
        print(f"Dataset created with {len(self.data)} image-caption pairs")
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        image_name, caption = self.data[idx]
        
        # Load and transform image
        image_path = os.path.join(self.image_dir, image_name)
        image = Image.open(image_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
        
        # Convert caption to tensor
        caption_tokens = caption.split()
        caption_indices = []
        
        for token in caption_tokens:
            if token in self.word_to_idx:
                caption_indices.append(self.word_to_idx[token])
            else:
                caption_indices.append(self.word_to_idx['<unk>'])
        
        caption_tensor = torch.tensor(caption_indices, dtype=torch.long)
        
        return image, caption_tensor

# Define image transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                        std=[0.229, 0.224, 0.225])  # ImageNet normalization
])

# Create dataset
dataset = Flickr8kDataset(
    image_dir='/kaggle/input/flickr8k/Images',  # adjust path
    captions_dict=processed_captions,
    word_to_idx=word_to_idx,
    transform=transform
)

# Test the dataset
print("Testing dataset...")
sample_image, sample_caption = dataset[0]
print(f"Image shape: {sample_image.shape}")
print(f"Caption tensor: {sample_caption}")
print(f"Caption length: {len(sample_caption)}")

# Convert back to words to verify
sample_words = [idx_to_word[idx.item()] for idx in sample_caption]
print(f"Caption as words: {' '.join(sample_words)}")

Dataset created with 40455 image-caption pairs
Testing dataset...
Image shape: torch.Size([3, 224, 224])
Caption tensor: tensor([   0,    1,    2,    3,    1,    4,    5,    6,    7,    8,    1,    9,
          10,   11,    3,   12, 2994,   13,   14])
Caption length: 19
Caption as words: <start> a child in a pink dress is climbing up a set of stairs in an <unk> way <end>


In [4]:
def collate_fn(batch):
    """Custom collate function to handle variable length captions"""
    images, captions = zip(*batch)
    
    # Stack images
    images = torch.stack(images)
    
    # Get caption lengths
    lengths = [len(cap) for cap in captions]
    max_length = max(lengths)
    
    # Pad captions to max length
    padded_captions = torch.zeros(len(captions), max_length, dtype=torch.long)
    
    for i, caption in enumerate(captions):
        padded_captions[i, :len(caption)] = caption
        # Pad with <pad> token
        if len(caption) < max_length:
            padded_captions[i, len(caption):] = word_to_idx['<pad>']
    
    return images, padded_captions, torch.tensor(lengths)

# Test collate function
test_loader = DataLoader(dataset, batch_size=3, shuffle=False, collate_fn=collate_fn)
test_batch = next(iter(test_loader))
test_images, test_captions, test_lengths = test_batch

print(f"\nBatch shapes:")
print(f"Images: {test_images.shape}")
print(f"Captions: {test_captions.shape}")
print(f"Lengths: {test_lengths}")

print(f"\nFirst caption in batch:")
first_cap_words = [idx_to_word[idx.item()] for idx in test_captions[0]]
print(' '.join(first_cap_words))


Batch shapes:
Images: torch.Size([3, 3, 224, 224])
Captions: torch.Size([3, 19])
Lengths: tensor([19,  9, 10])

First caption in batch:
<start> a child in a pink dress is climbing up a set of stairs in an <unk> way <end>


In [5]:
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Subset

# Split image names (not individual captions) to avoid data leakage
image_names = list(processed_captions.keys())
print(f"Total images: {len(image_names)}")

# 80% train, 20% validation split
train_images, val_images = train_test_split(
    image_names, 
    test_size=0.2, 
    random_state=97,
    shuffle=True
)

print(f"Train images: {len(train_images)}")
print(f"Validation images: {len(val_images)}")

# Create separate caption dictionaries for train and val
train_captions = {img: processed_captions[img] for img in train_images}
val_captions = {img: processed_captions[img] for img in val_images}

# Verify split
train_caption_count = sum(len(caps) for caps in train_captions.values())
val_caption_count = sum(len(caps) for caps in val_captions.values())

print(f"Train captions: {train_caption_count}")
print(f"Val captions: {val_caption_count}")
print(f"Total: {train_caption_count + val_caption_count}")

Total images: 8091
Train images: 6472
Validation images: 1619
Train captions: 32360
Val captions: 8095
Total: 40455


In [6]:
# Create train and validation datasets
train_dataset = Flickr8kDataset(
    image_dir='/kaggle/input/flickr8k/Images',
    captions_dict=train_captions,
    word_to_idx=word_to_idx,
    transform=transform
)

val_dataset = Flickr8kDataset(
    image_dir='/kaggle/input/flickr8k/Images',
    captions_dict=val_captions,
    word_to_idx=word_to_idx,
    transform=transform
)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Val dataset size: {len(val_dataset)}")

# Create data loaders
batch_size = 16  # Power of 2, adjust based on GPU memory

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=2,  # For faster data loading
    pin_memory=True  # For GPU efficiency
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,  # No need to shuffle validation
    collate_fn=collate_fn,
    num_workers=2,
    pin_memory=True
)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

# Test the loaders
print("\nTesting train loader...")
train_batch = next(iter(train_loader))
train_images, train_captions, train_lengths = train_batch

print(f"Train batch - Images: {train_images.shape}")
print(f"Train batch - Captions: {train_captions.shape}")
print(f"Train batch - Lengths: {train_lengths}")

print(f"Max caption length in batch: {train_captions.shape[1]}")
print(f"Caption lengths: {train_lengths.tolist()}")

Dataset created with 32360 image-caption pairs
Dataset created with 8095 image-caption pairs
Train dataset size: 32360
Val dataset size: 8095
Train batches: 2023
Val batches: 506

Testing train loader...
Train batch - Images: torch.Size([16, 3, 224, 224])
Train batch - Captions: torch.Size([16, 25])
Train batch - Lengths: tensor([15, 13, 13, 13,  9, 15, 25,  9, 14, 14, 12, 14, 12, 13, 10, 13])
Max caption length in batch: 25
Caption lengths: [15, 13, 13, 13, 9, 15, 25, 9, 14, 14, 12, 14, 12, 13, 10, 13]


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import math

class ImageEncoder(nn.Module):
    """EfficientNet-B0 encoder (frozen)"""
    def __init__(self, embed_dim=256):
        super().__init__()
        # Load pretrained EfficientNet-B0
        self.efficientnet = models.efficientnet_b0(pretrained=True)
        
        # Freeze all parameters
        for param in self.efficientnet.parameters():
            param.requires_grad = False
        
        # Get the feature dimension before the classifier
        # EfficientNet-B0 has 1280 features before classifier
        self.feature_dim = self.efficientnet.classifier[1].in_features
        
        # Remove the classifier
        self.efficientnet.classifier = nn.Identity()
        
        # Add projection layer to embed_dim
        self.projection = nn.Linear(self.feature_dim, embed_dim)
        
        print(f"ImageEncoder: EfficientNet-B0 -> {embed_dim} dims")
        print(f"Frozen parameters: {sum(1 for p in self.efficientnet.parameters())}")
    
    def forward(self, images):
        # Extract features
        features = self.efficientnet(images)  # [batch, 1280]
        
        # Project to embedding dimension
        embedded = self.projection(features)  # [batch, embed_dim]
        
        return embedded

# Test the encoder
print("Testing ImageEncoder...")
encoder = ImageEncoder(embed_dim=256)

# Test with a batch from train_loader
test_images, _, _ = next(iter(train_loader))
print(f"Input images shape: {test_images.shape}")

with torch.no_grad():
    encoded_features = encoder(test_images)
    print(f"Encoded features shape: {encoded_features.shape}")

# Check trainable parameters
trainable_params = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in encoder.parameters())
print(f"Trainable params: {trainable_params:,}")
print(f"Total params: {total_params:,}")

Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth


Testing ImageEncoder...


100%|██████████| 20.5M/20.5M [00:00<00:00, 161MB/s]

ImageEncoder: EfficientNet-B0 -> 256 dims
Frozen parameters: 211





Input images shape: torch.Size([16, 3, 224, 224])
Encoded features shape: torch.Size([16, 256])
Trainable params: 327,936
Total params: 4,335,484


In [8]:
class PositionalEncoding(nn.Module):
    """Positional encoding for transformer"""
    def __init__(self, d_model, max_len=100):
        super().__init__()
        
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                           -(math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe.unsqueeze(0))
    
    def forward(self, x):
        # x shape: [batch, seq_len, d_model]
        return x + self.pe[:, :x.size(1)]

# Test positional encoding
pos_encoding = PositionalEncoding(d_model=256, max_len=50)
test_input = torch.randn(2, 10, 256)  # [batch, seq_len, embed_dim]
pos_output = pos_encoding(test_input)
print(f"Positional encoding test: {test_input.shape} -> {pos_output.shape}")

Positional encoding test: torch.Size([2, 10, 256]) -> torch.Size([2, 10, 256])


In [9]:
class TransformerDecoder(nn.Module):
    """2-layer Transformer decoder for caption generation"""
    def __init__(self, vocab_size, embed_dim=256, num_heads=4, num_layers=2, 
                 max_seq_len=50, dropout=0.1):
        super().__init__()
        
        self.embed_dim = embed_dim
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        
        # Word embedding layer
        self.word_embedding = nn.Embedding(vocab_size, embed_dim)
        
        # Positional encoding
        self.pos_encoding = PositionalEncoding(embed_dim, max_seq_len)
        
        # Transformer decoder layers
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,  # Standard is 4x
            dropout=dropout,
            batch_first=True  # Important: batch dimension first
        )
        
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer, 
            num_layers=num_layers
        )
        
        # Output projection to vocabulary
        self.output_projection = nn.Linear(embed_dim, vocab_size)
        
        # Dropout
        self.dropout = nn.Dropout(dropout)
        
        print(f"TransformerDecoder: {num_layers} layers, {num_heads} heads")
        print(f"Vocabulary size: {vocab_size}")
    
    def create_padding_mask(self, sequences, lengths):
        """Create mask for padded tokens"""
        batch_size, max_len = sequences.shape
        device = sequences.device  # Get device from input tensor
        
        mask = torch.arange(max_len, device=device).expand(batch_size, max_len) >= lengths.unsqueeze(1)
        return mask
    
    def create_causal_mask(self, seq_len):
        """Create causal (look-ahead) mask for decoder"""
        mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
        return mask.bool()
    
    def forward(self, image_features, captions, caption_lengths, training=True):
        batch_size, seq_len = captions.shape
        
        # Word embeddings
        word_embeds = self.word_embedding(captions)  # [batch, seq_len, embed_dim]
        
        # Add positional encoding
        word_embeds = self.pos_encoding(word_embeds)
        word_embeds = self.dropout(word_embeds)
        
        # Prepare image features as memory (encoder output)
        # Transformer expects [batch, src_seq_len, embed_dim]
        # Our image features are [batch, embed_dim], so we add sequence dimension
        memory = image_features.unsqueeze(1)  # [batch, 1, embed_dim]
        
        # Create masks
        if training:
            # Causal mask prevents looking ahead
            causal_mask = self.create_causal_mask(seq_len).to(captions.device)
        else:
            causal_mask = None
        
        # Padding mask for target sequence
        tgt_key_padding_mask = self.create_padding_mask(captions, caption_lengths).to(captions.device)
        
        # Transformer decoder
        decoder_output = self.transformer_decoder(
            tgt=word_embeds,
            memory=memory,
            tgt_mask=causal_mask,
            tgt_key_padding_mask=tgt_key_padding_mask
        )
        
        # Project to vocabulary
        logits = self.output_projection(decoder_output)  # [batch, seq_len, vocab_size]
        
        return logits

# Test the decoder
print("Testing TransformerDecoder...")
decoder = TransformerDecoder(
    vocab_size=len(vocab),
    embed_dim=256,
    num_heads=4,
    num_layers=2
)

# Test with encoder output and captions
test_images, test_captions, test_lengths = next(iter(train_loader))
test_encoded = encoder(test_images)

print(f"Image features shape: {test_encoded.shape}")
print(f"Input captions shape: {test_captions.shape}")
print(f"Caption lengths: {test_lengths[:5]}")  # Show first 5

# Forward pass
with torch.no_grad():
    logits = decoder(test_encoded, test_captions, test_lengths, training=True)
    print(f"Output logits shape: {logits.shape}")
    print(f"Expected: [batch_size, seq_len, vocab_size] = [{test_captions.shape[0]}, {test_captions.shape[1]}, {len(vocab)}]")

# Check decoder parameters
decoder_trainable = sum(p.numel() for p in decoder.parameters() if p.requires_grad)
print(f"Decoder trainable params: {decoder_trainable:,}")

Testing TransformerDecoder...
TransformerDecoder: 2 layers, 4 heads
Vocabulary size: 2996
Image features shape: torch.Size([16, 256])
Input captions shape: torch.Size([16, 19])
Caption lengths: tensor([12, 19, 10, 10, 10])
Output logits shape: torch.Size([16, 19, 2996])
Expected: [batch_size, seq_len, vocab_size] = [16, 19, 2996]
Decoder trainable params: 3,643,828


In [10]:
class ImageCaptionGenerator(nn.Module):
    """Complete Image Caption Generator: EfficientNet-B0 + Transformer Decoder"""
    def __init__(self, vocab_size, embed_dim=256, num_heads=4, num_layers=2, 
                 max_seq_len=50, dropout=0.1):
        super().__init__()
        
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        
        # Image encoder
        self.encoder = ImageEncoder(embed_dim=embed_dim)
        
        # Text decoder
        self.decoder = TransformerDecoder(
            vocab_size=vocab_size,
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            max_seq_len=max_seq_len,
            dropout=dropout
        )
        
        # Store special token indices
        self.start_token = word_to_idx['<start>']
        self.end_token = word_to_idx['<end>']
        self.pad_token = word_to_idx['<pad>']
        
        print(f"Complete model created!")
        self.print_model_info()
    
    def generate_beam_search(self, images, beam_size=3, max_length=20):
        """Generate captions using beam search for better diversity"""
        self.eval()
        
        # Encode images
        image_features = self.encoder(images)
        batch_size = image_features.shape[0]
        device = images.device
        
        # Process each image separately (beam search is typically done per image)
        all_results = []
        
        for batch_idx in range(batch_size):
            single_image_features = image_features[batch_idx:batch_idx+1]  # [1, embed_dim]
            
            # Initialize beam with start token
            # Each beam item: (sequence, cumulative_log_prob)
            beam = [([self.start_token], 0.0)]
            
            for step in range(max_length):
                candidates = []
                
                for sequence, score in beam:
                    # Skip if sequence already ended
                    if sequence[-1] == self.end_token:
                        candidates.append((sequence, score))
                        continue
                    
                    # Convert sequence to tensor
                    seq_tensor = torch.tensor([sequence], dtype=torch.long, device=device)
                    length_tensor = torch.tensor([len(sequence)], device=device)
                    
                    # Get next token probabilities
                    with torch.no_grad():
                        logits = self.decoder(single_image_features, seq_tensor, length_tensor, training=False)
                        next_token_logits = logits[0, -1, :]  # [vocab_size]
                        log_probs = F.log_softmax(next_token_logits, dim=0)
                    
                    # Get top beam_size candidates
                    top_log_probs, top_indices = torch.topk(log_probs, beam_size)
                    
                    # Add candidates
                    for log_prob, token_idx in zip(top_log_probs, top_indices):
                        new_sequence = sequence + [token_idx.item()]
                        new_score = score + log_prob.item()
                        candidates.append((new_sequence, new_score))
                
                # Keep top beam_size candidates
                candidates.sort(key=lambda x: x[1], reverse=True)  # Sort by score (higher better)
                beam = candidates[:beam_size]
                
                # Stop if all beams have ended
                if all(seq[-1] == self.end_token for seq, _ in beam):
                    break
            
            # Get best sequence for this image
            best_sequence = beam[0][0]  # Highest scoring sequence
            best_tensor = torch.tensor(best_sequence, dtype=torch.long, device=device)
            all_results.append(best_tensor)
        
        # Pad sequences to same length for batch return
        max_len = max(len(seq) for seq in all_results)
        
        batched_results = torch.full((batch_size, max_len), self.pad_token, 
                                    dtype=torch.long, device=device)
        
        for i, seq in enumerate(all_results):
            batched_results[i, :len(seq)] = seq
        
        return batched_results
    
    def print_model_info(self):
        """Print model parameter information"""
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        
        print(f"Total parameters: {total_params:,}")
        print(f"Trainable parameters: {trainable_params:,}")
        print(f"Frozen parameters: {total_params - trainable_params:,}")
        
        # Check if under 4M limit
        if trainable_params < 4_000_000:
            print(f"✅ Under 4M parameter limit!")
        else:
            print(f"❌ Exceeds 4M parameter limit by {trainable_params - 4_000_000:,}")
    
    def forward(self, images, captions=None, caption_lengths=None):
        """Forward pass for training"""
        # Encode images
        image_features = self.encoder(images)
        
        if captions is not None:
            # Training mode: use teacher forcing
            logits = self.decoder(image_features, captions, caption_lengths, training=True)
            return logits
        else:
            # Inference mode: generate captions
            return self.generate_caption(image_features)
    
    def generate_caption(self, image_features, max_length=20, temperature=1.0):
        """Generate caption using greedy decoding"""
        batch_size = image_features.shape[0]
        device = image_features.device
        
        # Start with <start> token
        generated = torch.full((batch_size, 1), self.start_token, 
                              dtype=torch.long, device=device)
        
        for _ in range(max_length):
            # Get current sequence length
            seq_len = generated.shape[1]
            lengths = torch.full((batch_size,), seq_len, device=device)
            
            # Forward pass
            with torch.no_grad():
                logits = self.decoder(image_features, generated, lengths, training=False)
                
                # Get next token probabilities (last position)
                next_token_logits = logits[:, -1, :] / temperature
                next_token_probs = F.softmax(next_token_logits, dim=-1)
                
                # Greedy sampling (take most probable token)
                next_token = torch.argmax(next_token_probs, dim=-1).unsqueeze(1)
            
            # Append to generated sequence
            generated = torch.cat([generated, next_token], dim=1)
            
            # Check if all sequences have generated <end> token
            if torch.all(next_token.squeeze() == self.end_token):
                break
        
        return generated


In [11]:
# Create the complete model
model = ImageCaptionGenerator(
    vocab_size=len(vocab),
    embed_dim=256,
    num_heads=4,
    num_layers=2,
    max_seq_len=50,
    dropout=0.1
)


## Load best model and evaluate
print("Loading best model for evaluation...")
checkpoint = torch.load('/kaggle/input/imagecaption-after8/model_checkpoints_after_8/best_model.pth')
model.load_state_dict(checkpoint['model_state_dict'])

# Test the complete model
print("\nTesting complete model...")
test_images, test_captions, test_lengths = next(iter(train_loader))

# Training forward pass
print("Training mode:")
with torch.no_grad():
    training_logits = model(test_images, test_captions, test_lengths)
    print(f"Training output shape: {training_logits.shape}")

# Inference mode
print("\nInference mode:")
with torch.no_grad():
    generated_captions = model(test_images[:2])  # Generate for first 2 images
    print(f"Generated captions shape: {generated_captions.shape}")
    
    # Convert to words
    for i in range(2):
        caption_indices = generated_captions[i].cpu().tolist()
        caption_words = [idx_to_word[idx] for idx in caption_indices]
        print(f"Generated caption {i+1}: {' '.join(caption_words)}")

ImageEncoder: EfficientNet-B0 -> 256 dims
Frozen parameters: 211
TransformerDecoder: 2 layers, 4 heads
Vocabulary size: 2996
Complete model created!
Total parameters: 7,979,312
Trainable parameters: 3,971,764
Frozen parameters: 4,007,548
✅ Under 4M parameter limit!
Loading best model for evaluation...

Testing complete model...
Training mode:
Training output shape: torch.Size([16, 20, 2996])

Inference mode:
Generated captions shape: torch.Size([2, 15])
Generated caption 1: <start> a dog stands in the water <end> <end> <end> <end> <end> <end> <end> <end>
Generated caption 2: <start> a man in a blue jacket and a backpack and a <unk> <unk> <end>


In [13]:
import os
import json
from datetime import datetime
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

class TrainingManager:
    def __init__(self, model, train_loader, val_loader, word_to_idx, idx_to_word, 
                 checkpoint_dir='/kaggle/working/model_checkpoints', device='cuda'):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.word_to_idx = word_to_idx
        self.idx_to_word = idx_to_word
        self.device = device
        self.checkpoint_dir = checkpoint_dir
        
        # Create checkpoint directory
        os.makedirs(checkpoint_dir, exist_ok=True)
        
        # Training components
        self.criterion = nn.CrossEntropyLoss(ignore_index=word_to_idx['<pad>'])
        
        # Only train decoder + projection parameters
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        self.optimizer = optim.Adam(trainable_params, lr=1e-3, weight_decay=1e-4)
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode='min', patience=2, 
                                          factor=0.5)
        
        # Training state
        self.start_epoch = 0
        self.best_val_loss = float('inf')
        self.train_losses = []
        self.val_losses = []
        
        print(f"Training setup complete!")
        print(f"Device: {device}")
        print(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")
    
    def save_checkpoint(self, epoch, train_loss, val_loss, is_best=False):
        """Save training checkpoint"""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'best_val_loss': self.best_val_loss,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'word_to_idx': self.word_to_idx,
            'idx_to_word': self.idx_to_word,
            'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        }
        
        # Save last checkpoint
        last_path = os.path.join(self.checkpoint_dir, 'last_checkpoint.pth')
        torch.save(checkpoint, last_path)
        print(f"Saved checkpoint: {last_path}")
        
        # Save best model
        if is_best:
            best_path = os.path.join(self.checkpoint_dir, 'best_model.pth')
            torch.save(checkpoint, best_path)
            print(f"🎉 New best model saved: {best_path} (val_loss: {val_loss:.4f})")
    
    def load_checkpoint(self, checkpoint_path=r"/kaggle/input/model"):
        """Load checkpoint to resume training"""
        if os.path.exists(checkpoint_path):
            print(f"Loading checkpoint: {checkpoint_path}")
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            
            self.start_epoch = checkpoint['epoch'] + 1
            self.train_losses = checkpoint['train_losses']
            self.val_losses = checkpoint['val_losses']
            self.best_val_loss = checkpoint['best_val_loss']
            
            print(f"Resumed from epoch {self.start_epoch}")
            print(f"Best validation loss so far: {self.best_val_loss:.4f}")
            return True
        else:
            print(f"No checkpoint found at {checkpoint_path}")
            return False
    
    def train_epoch(self):
        """Train for one epoch"""
        self.model.train()
        total_loss = 0
        num_batches = len(self.train_loader)
        
        for batch_idx, (images, captions, lengths) in enumerate(tqdm.tqdm(self.train_loader)):
            images = images.to(self.device)
            captions = captions.to(self.device)
            lengths = lengths.to(self.device)
            
            # Forward pass
            # Input: captions without last token, Target: captions without first token
            input_captions = captions[:, :-1]  # Remove <end> token
            target_captions = captions[:, 1:]   # Remove <start> token
            input_lengths = lengths - 1
            
            logits = self.model(images, input_captions, input_lengths)
            
            # Reshape for loss calculation
            logits = logits.reshape(-1, logits.size(-1))  # [batch*seq, vocab]
            targets = target_captions.reshape(-1)         # [batch*seq]
            
            loss = self.criterion(logits, targets)
            
            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=5.0)
            
            self.optimizer.step()
            
            total_loss += loss.item()
        
        return total_loss / num_batches
    
    def validate(self):
        """Validate the model"""
        self.model.eval()
        total_loss = 0
        
        with torch.no_grad():
            for images, captions, lengths in tqdm.tqdm(self.val_loader):
                images = images.to(self.device)
                captions = captions.to(self.device)
                lengths = lengths.to(self.device)
                
                # Same input/target split as training
                input_captions = captions[:, :-1]
                target_captions = captions[:, 1:]
                input_lengths = lengths - 1
                
                logits = self.model(images, input_captions, input_lengths)
                
                logits = logits.reshape(-1, logits.size(-1))
                targets = target_captions.reshape(-1)
                
                loss = self.criterion(logits, targets)
                total_loss += loss.item()
        
        return total_loss / len(self.val_loader)
    
    def generate_sample_captions(self, num_samples=3):
        """Generate sample captions for monitoring progress - FIXED"""
        self.model.eval()
        
        # Get a few validation samples
        val_batch = next(iter(self.val_loader))
        images, true_captions, _ = val_batch
        
        images = images[:num_samples].to(self.device)
        
        with torch.no_grad():
            # Pass raw images, let the model handle encoding
            generated = self.model(images)  # This will call generate_caption internally
        
        print(f"\nSample Generated Captions:")
        for i in range(num_samples):
            # True caption
            true_caption_tokens = [self.idx_to_word[idx.item()] for idx in true_captions[i] 
                                  if idx.item() != self.word_to_idx['<pad>']]
            
            # Generated caption
            gen_caption_tokens = [self.idx_to_word[idx.item()] for idx in generated[i]]
            
            print(f"True:      {' '.join(true_caption_tokens)}")
            print(f"Generated: {' '.join(gen_caption_tokens)}")
            print("-" * 50)

In [22]:
import tqdm

def train_model(model, train_loader, val_loader, word_to_idx, idx_to_word,
                num_epochs=10, device='cuda', resume_from_checkpoint=True):
    """Main training function"""
    
    # Move model to device
    model = model.to(device)
    
    # Initialize training manager
    trainer = TrainingManager(model, train_loader, val_loader, 
                            word_to_idx, idx_to_word, device=device)
    
    # Try to resume from checkpoint
    if resume_from_checkpoint:
        trainer.load_checkpoint('/kaggle/input/model/last_checkpoint.pth')
    
    print(f"\n🚀 Starting training for {num_epochs} epochs...")
    print(f"Starting from epoch: {trainer.start_epoch}")
    
    for epoch in range(trainer.start_epoch, num_epochs):
        print(f"\n{'='*60}")
        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"{'='*60}")
        
        # Training
        print("Training...")
        train_loss = trainer.train_epoch()
        trainer.train_losses.append(train_loss)
        
        # Validation  
        print("Validating...")
        val_loss = trainer.validate()
        trainer.val_losses.append(val_loss)
        
        # Learning rate scheduling
        trainer.scheduler.step(val_loss)
        current_lr = trainer.optimizer.param_groups[0]['lr']
        
        # Check if best model
        is_best = val_loss < trainer.best_val_loss
        if is_best:
            trainer.best_val_loss = val_loss
        
        # Print epoch results
        print(f"\nEpoch {epoch+1} Results:")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Loss: {val_loss:.4f} {'🎉 (Best!)' if is_best else ''}")
        print(f"Learning Rate: {current_lr:.6f}")
        
        # Generate sample captions to monitor progress
        trainer.generate_sample_captions(num_samples=2)
        
        # Save checkpoint
        trainer.save_checkpoint(epoch, train_loss, val_loss, is_best)
        
        print(f"Epoch {epoch+1} completed in {'-'*20}")
    
    print(f"\n🎉 Training completed!")
    print(f"Best validation loss: {trainer.best_val_loss:.4f}")
    
    return trainer

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Start training
trainer = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    word_to_idx=word_to_idx,
    idx_to_word=idx_to_word,
    num_epochs=8,  # Start with 8 epochs
    device=device,
    resume_from_checkpoint=False  # Set True to resume
)

Using device: cuda
Training setup complete!
Device: cuda
Trainable parameters: 3,971,764

🚀 Starting training for 8 epochs...
Starting from epoch: 0

Epoch 1/8
Training...


 10%|█         | 203/2023 [00:12<01:50, 16.53it/s]

Batch [200/2023], Loss: 3.2570


 20%|█▉        | 402/2023 [00:23<01:35, 17.04it/s]

Batch [400/2023], Loss: 3.2423


 30%|██▉       | 602/2023 [00:34<01:18, 18.15it/s]

Batch [600/2023], Loss: 3.2301


 40%|███▉      | 801/2023 [00:46<01:13, 16.56it/s]

Batch [800/2023], Loss: 3.2105


 50%|████▉     | 1003/2023 [00:58<00:58, 17.57it/s]

Batch [1000/2023], Loss: 3.1921


 59%|█████▉    | 1203/2023 [01:09<00:46, 17.68it/s]

Batch [1200/2023], Loss: 3.1776


 69%|██████▉   | 1402/2023 [01:21<00:35, 17.55it/s]

Batch [1400/2023], Loss: 3.1635


 79%|███████▉  | 1601/2023 [01:32<00:25, 16.65it/s]

Batch [1600/2023], Loss: 3.1527


 89%|████████▉ | 1803/2023 [01:44<00:12, 17.00it/s]

Batch [1800/2023], Loss: 3.1469


 99%|█████████▉| 2002/2023 [01:56<00:01, 17.16it/s]

Batch [2000/2023], Loss: 3.1394


100%|██████████| 2023/2023 [01:57<00:00, 17.26it/s]


Validating...


100%|██████████| 506/506 [00:29<00:00, 17.27it/s]


Epoch 1 Results:
Train Loss: 3.1374
Val Loss: 3.0422 🎉 (Best!)
Learning Rate: 0.001000






Sample Generated Captions:
True:      <start> a blurred image of a black and white dog running through the grass <end>
Generated: <start> a brown dog is running through the grass <end>
--------------------------------------------------
True:      <start> a dog is running through a field <end>
Generated: <start> a brown dog is running through the grass <end>
--------------------------------------------------
Saved checkpoint: /kaggle/working/model_checkpoints/last_checkpoint.pth
🎉 New best model saved: /kaggle/working/model_checkpoints/best_model.pth (val_loss: 3.0422)
Epoch 1 completed in --------------------

Epoch 2/8
Training...


 10%|█         | 203/2023 [00:12<01:43, 17.62it/s]

Batch [200/2023], Loss: 2.9554


 20%|█▉        | 401/2023 [00:23<01:56, 13.93it/s]

Batch [400/2023], Loss: 2.9558


 30%|██▉       | 601/2023 [00:35<01:19, 17.78it/s]

Batch [600/2023], Loss: 2.9644


 40%|███▉      | 803/2023 [00:47<01:37, 12.49it/s]

Batch [800/2023], Loss: 2.9644


 50%|████▉     | 1002/2023 [00:59<01:01, 16.63it/s]

Batch [1000/2023], Loss: 2.9607


 59%|█████▉    | 1201/2023 [01:11<00:47, 17.37it/s]

Batch [1200/2023], Loss: 2.9611


 69%|██████▉   | 1403/2023 [01:23<00:33, 18.45it/s]

Batch [1400/2023], Loss: 2.9604


 79%|███████▉  | 1603/2023 [01:34<00:23, 17.90it/s]

Batch [1600/2023], Loss: 2.9576


 89%|████████▉ | 1803/2023 [01:46<00:12, 17.59it/s]

Batch [1800/2023], Loss: 2.9558


 99%|█████████▉| 2003/2023 [01:57<00:01, 16.95it/s]

Batch [2000/2023], Loss: 2.9547


100%|██████████| 2023/2023 [01:59<00:00, 16.95it/s]


Validating...


100%|██████████| 506/506 [00:28<00:00, 17.76it/s]


Epoch 2 Results:
Train Loss: 2.9551
Val Loss: 2.9986 🎉 (Best!)
Learning Rate: 0.001000






Sample Generated Captions:
True:      <start> a blurred image of a black and white dog running through the grass <end>
Generated: <start> a dog running through a field <end>
--------------------------------------------------
True:      <start> a dog is running through a field <end>
Generated: <start> a dog running through a field <end>
--------------------------------------------------
Saved checkpoint: /kaggle/working/model_checkpoints/last_checkpoint.pth
🎉 New best model saved: /kaggle/working/model_checkpoints/best_model.pth (val_loss: 2.9986)
Epoch 2 completed in --------------------

Epoch 3/8
Training...


 10%|█         | 203/2023 [00:12<01:43, 17.63it/s]

Batch [200/2023], Loss: 2.8158


 20%|█▉        | 403/2023 [00:24<01:29, 18.13it/s]

Batch [400/2023], Loss: 2.8328


 30%|██▉       | 603/2023 [00:36<01:19, 17.87it/s]

Batch [600/2023], Loss: 2.8431


 40%|███▉      | 802/2023 [00:48<01:10, 17.39it/s]

Batch [800/2023], Loss: 2.8493


 50%|████▉     | 1002/2023 [00:59<00:58, 17.53it/s]

Batch [1000/2023], Loss: 2.8525


 59%|█████▉    | 1201/2023 [01:11<00:52, 15.62it/s]

Batch [1200/2023], Loss: 2.8579


 69%|██████▉   | 1401/2023 [01:23<00:39, 15.66it/s]

Batch [1400/2023], Loss: 2.8618


 79%|███████▉  | 1602/2023 [01:35<00:24, 17.29it/s]

Batch [1600/2023], Loss: 2.8637


 89%|████████▉ | 1801/2023 [01:47<00:13, 16.96it/s]

Batch [1800/2023], Loss: 2.8672


 99%|█████████▉| 2001/2023 [01:59<00:01, 15.06it/s]

Batch [2000/2023], Loss: 2.8681


100%|██████████| 2023/2023 [02:00<00:00, 16.74it/s]


Validating...


100%|██████████| 506/506 [00:33<00:00, 15.27it/s]


Epoch 3 Results:
Train Loss: 2.8685
Val Loss: 2.9651 🎉 (Best!)
Learning Rate: 0.001000






Sample Generated Captions:
True:      <start> a blurred image of a black and white dog running through the grass <end>
Generated: <start> a brown dog is running through a grassy field <end>
--------------------------------------------------
True:      <start> a dog is running through a field <end>
Generated: <start> a brown dog is running through a grassy field <end>
--------------------------------------------------
Saved checkpoint: /kaggle/working/model_checkpoints/last_checkpoint.pth
🎉 New best model saved: /kaggle/working/model_checkpoints/best_model.pth (val_loss: 2.9651)
Epoch 3 completed in --------------------

Epoch 4/8
Training...


 10%|█         | 203/2023 [00:14<01:58, 15.38it/s]

Batch [200/2023], Loss: 2.7709


 20%|█▉        | 402/2023 [00:26<01:33, 17.25it/s]

Batch [400/2023], Loss: 2.7720


 30%|██▉       | 602/2023 [00:39<01:31, 15.52it/s]

Batch [600/2023], Loss: 2.7800


 40%|███▉      | 801/2023 [00:51<01:09, 17.60it/s]

Batch [800/2023], Loss: 2.7812


 50%|████▉     | 1003/2023 [01:03<00:57, 17.80it/s]

Batch [1000/2023], Loss: 2.7851


 59%|█████▉    | 1202/2023 [01:15<00:46, 17.65it/s]

Batch [1200/2023], Loss: 2.7938


 69%|██████▉   | 1403/2023 [01:26<00:35, 17.28it/s]

Batch [1400/2023], Loss: 2.7985


 79%|███████▉  | 1600/2023 [01:41<00:33, 12.56it/s]

Batch [1600/2023], Loss: 2.8055


 89%|████████▉ | 1803/2023 [01:53<00:12, 17.96it/s]

Batch [1800/2023], Loss: 2.8066


 99%|█████████▉| 2002/2023 [02:05<00:01, 16.74it/s]

Batch [2000/2023], Loss: 2.8098


100%|██████████| 2023/2023 [02:06<00:00, 15.96it/s]


Validating...


100%|██████████| 506/506 [00:28<00:00, 17.74it/s]


Epoch 4 Results:
Train Loss: 2.8099
Val Loss: 2.9119 🎉 (Best!)
Learning Rate: 0.001000






Sample Generated Captions:
True:      <start> a blurred image of a black and white dog running through the grass <end>
Generated: <start> a dog runs through the grass <end>
--------------------------------------------------
True:      <start> a dog is running through a field <end>
Generated: <start> a dog runs through the grass <end>
--------------------------------------------------
Saved checkpoint: /kaggle/working/model_checkpoints/last_checkpoint.pth
🎉 New best model saved: /kaggle/working/model_checkpoints/best_model.pth (val_loss: 2.9119)
Epoch 4 completed in --------------------

Epoch 5/8
Training...


 10%|█         | 203/2023 [00:12<01:47, 16.99it/s]

Batch [200/2023], Loss: 2.6934


 20%|█▉        | 403/2023 [00:24<01:34, 17.07it/s]

Batch [400/2023], Loss: 2.7232


 30%|██▉       | 603/2023 [00:35<01:21, 17.51it/s]

Batch [600/2023], Loss: 2.7251


 40%|███▉      | 802/2023 [00:47<01:08, 17.84it/s]

Batch [800/2023], Loss: 2.7443


 50%|████▉     | 1003/2023 [00:58<00:52, 19.60it/s]

Batch [1000/2023], Loss: 2.7473


 59%|█████▉    | 1202/2023 [01:10<00:49, 16.66it/s]

Batch [1200/2023], Loss: 2.7526


 69%|██████▉   | 1403/2023 [01:22<00:32, 19.13it/s]

Batch [1400/2023], Loss: 2.7528


 79%|███████▉  | 1603/2023 [01:33<00:25, 16.46it/s]

Batch [1600/2023], Loss: 2.7586


 89%|████████▉ | 1803/2023 [01:44<00:12, 17.55it/s]

Batch [1800/2023], Loss: 2.7643


 99%|█████████▉| 2001/2023 [01:56<00:01, 17.28it/s]

Batch [2000/2023], Loss: 2.7704


100%|██████████| 2023/2023 [01:57<00:00, 17.16it/s]


Validating...


100%|██████████| 506/506 [00:29<00:00, 17.29it/s]


Epoch 5 Results:
Train Loss: 2.7706
Val Loss: 2.9211 
Learning Rate: 0.001000






Sample Generated Captions:
True:      <start> a blurred image of a black and white dog running through the grass <end>
Generated: <start> a dog running across a field of flowers <end>
--------------------------------------------------
True:      <start> a dog is running through a field <end>
Generated: <start> a dog running across a field of flowers <end>
--------------------------------------------------
Saved checkpoint: /kaggle/working/model_checkpoints/last_checkpoint.pth
Epoch 5 completed in --------------------

Epoch 6/8
Training...


 10%|▉         | 202/2023 [00:12<01:45, 17.23it/s]

Batch [200/2023], Loss: 2.6702


 20%|█▉        | 403/2023 [00:24<01:30, 17.96it/s]

Batch [400/2023], Loss: 2.6884


 30%|██▉       | 602/2023 [00:35<01:20, 17.73it/s]

Batch [600/2023], Loss: 2.6991


 40%|███▉      | 803/2023 [00:47<01:06, 18.39it/s]

Batch [800/2023], Loss: 2.7090


 50%|████▉     | 1002/2023 [00:59<01:00, 16.90it/s]

Batch [1000/2023], Loss: 2.7108


 59%|█████▉    | 1202/2023 [01:10<00:46, 17.59it/s]

Batch [1200/2023], Loss: 2.7194


 69%|██████▉   | 1402/2023 [01:22<00:37, 16.55it/s]

Batch [1400/2023], Loss: 2.7202


 79%|███████▉  | 1603/2023 [01:34<00:24, 17.02it/s]

Batch [1600/2023], Loss: 2.7245


 89%|████████▉ | 1802/2023 [01:45<00:12, 17.69it/s]

Batch [1800/2023], Loss: 2.7308


 99%|█████████▉| 2003/2023 [01:57<00:01, 15.67it/s]

Batch [2000/2023], Loss: 2.7340


100%|██████████| 2023/2023 [01:58<00:00, 17.01it/s]


Validating...


100%|██████████| 506/506 [00:28<00:00, 17.78it/s]


Epoch 6 Results:
Train Loss: 2.7352
Val Loss: 2.8847 🎉 (Best!)
Learning Rate: 0.001000






Sample Generated Captions:
True:      <start> a blurred image of a black and white dog running through the grass <end>
Generated: <start> a dog running through the grass <end>
--------------------------------------------------
True:      <start> a dog is running through a field <end>
Generated: <start> a dog running through the grass <end>
--------------------------------------------------
Saved checkpoint: /kaggle/working/model_checkpoints/last_checkpoint.pth
🎉 New best model saved: /kaggle/working/model_checkpoints/best_model.pth (val_loss: 2.8847)
Epoch 6 completed in --------------------

Epoch 7/8
Training...


 10%|▉         | 202/2023 [00:12<01:44, 17.43it/s]

Batch [200/2023], Loss: 2.6385


 20%|█▉        | 403/2023 [00:24<01:34, 17.17it/s]

Batch [400/2023], Loss: 2.6533


 30%|██▉       | 603/2023 [00:35<01:23, 16.91it/s]

Batch [600/2023], Loss: 2.6691


 40%|███▉      | 803/2023 [00:47<01:13, 16.66it/s]

Batch [800/2023], Loss: 2.6755


 50%|████▉     | 1002/2023 [00:58<00:58, 17.48it/s]

Batch [1000/2023], Loss: 2.6821


 59%|█████▉    | 1203/2023 [01:10<00:46, 17.54it/s]

Batch [1200/2023], Loss: 2.6875


 69%|██████▉   | 1402/2023 [01:21<00:35, 17.31it/s]

Batch [1400/2023], Loss: 2.6925


 79%|███████▉  | 1603/2023 [01:33<00:26, 16.10it/s]

Batch [1600/2023], Loss: 2.6997


 89%|████████▉ | 1802/2023 [01:44<00:12, 17.73it/s]

Batch [1800/2023], Loss: 2.7021


 99%|█████████▉| 2002/2023 [01:56<00:01, 17.32it/s]

Batch [2000/2023], Loss: 2.7081


100%|██████████| 2023/2023 [01:57<00:00, 17.21it/s]


Validating...


100%|██████████| 506/506 [00:28<00:00, 17.70it/s]


Epoch 7 Results:
Train Loss: 2.7086
Val Loss: 2.8808 🎉 (Best!)
Learning Rate: 0.001000






Sample Generated Captions:
True:      <start> a blurred image of a black and white dog running through the grass <end>
Generated: <start> a dog running through the grass <end>
--------------------------------------------------
True:      <start> a dog is running through a field <end>
Generated: <start> a dog running through the grass <end>
--------------------------------------------------
Saved checkpoint: /kaggle/working/model_checkpoints/last_checkpoint.pth
🎉 New best model saved: /kaggle/working/model_checkpoints/best_model.pth (val_loss: 2.8808)
Epoch 7 completed in --------------------

Epoch 8/8
Training...


 10%|▉         | 201/2023 [00:12<01:52, 16.25it/s]

Batch [200/2023], Loss: 2.6097


 20%|█▉        | 403/2023 [00:23<01:30, 17.86it/s]

Batch [400/2023], Loss: 2.6292


 30%|██▉       | 602/2023 [00:35<01:21, 17.38it/s]

Batch [600/2023], Loss: 2.6436


 40%|███▉      | 802/2023 [00:47<01:21, 14.93it/s]

Batch [800/2023], Loss: 2.6514


 50%|████▉     | 1002/2023 [00:58<01:00, 16.79it/s]

Batch [1000/2023], Loss: 2.6632


 59%|█████▉    | 1202/2023 [01:10<00:43, 18.67it/s]

Batch [1200/2023], Loss: 2.6670


 69%|██████▉   | 1403/2023 [01:22<00:35, 17.25it/s]

Batch [1400/2023], Loss: 2.6711


 79%|███████▉  | 1603/2023 [01:33<00:23, 17.73it/s]

Batch [1600/2023], Loss: 2.6770


 89%|████████▉ | 1803/2023 [01:44<00:12, 17.21it/s]

Batch [1800/2023], Loss: 2.6834


 99%|█████████▉| 2001/2023 [01:56<00:01, 16.11it/s]

Batch [2000/2023], Loss: 2.6842


100%|██████████| 2023/2023 [01:58<00:00, 17.11it/s]


Validating...


100%|██████████| 506/506 [00:29<00:00, 17.36it/s]


Epoch 8 Results:
Train Loss: 2.6850
Val Loss: 2.8777 🎉 (Best!)
Learning Rate: 0.001000






Sample Generated Captions:
True:      <start> a blurred image of a black and white dog running through the grass <end>
Generated: <start> a dog running through the grass <end>
--------------------------------------------------
True:      <start> a dog is running through a field <end>
Generated: <start> a dog running through the grass <end>
--------------------------------------------------
Saved checkpoint: /kaggle/working/model_checkpoints/last_checkpoint.pth
🎉 New best model saved: /kaggle/working/model_checkpoints/best_model.pth (val_loss: 2.8777)
Epoch 8 completed in --------------------

🎉 Training completed!
Best validation loss: 2.8777


In [14]:
import matplotlib.pyplot as plt

def plot_training_progress(trainer):
    """Plot training and validation losses"""
    plt.figure(figsize=(12, 4))
    
    # Loss plot
    plt.subplot(1, 2, 1)
    epochs = range(1, len(trainer.train_losses) + 1)
    plt.plot(epochs, trainer.train_losses, 'b-', label='Training Loss')
    plt.plot(epochs, trainer.val_losses, 'r-', label='Validation Loss')
    plt.title('Training Progress')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    
    # Learning rate plot (if available)
    plt.subplot(1, 2, 2)
    # This will be useful to see LR scheduling
    plt.title('Model Performance')
    plt.text(0.5, 0.7, f'Best Val Loss: {trainer.best_val_loss:.4f}', 
             transform=plt.gca().transAxes, ha='center', fontsize=12)
    plt.text(0.5, 0.5, f'Total Epochs: {len(trainer.train_losses)}', 
             transform=plt.gca().transAxes, ha='center', fontsize=12)
    plt.text(0.5, 0.3, f'Final Train Loss: {trainer.train_losses[-1]:.4f}', 
             transform=plt.gca().transAxes, ha='center', fontsize=12)
    plt.axis('off')
    
    plt.tight_layout()
    plt.savefig('training_progress.png', dpi=300, bbox_inches='tight')
    plt.show()

# After training completes, call this to visualize progress

In [12]:
checkpoint.keys()

dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'scheduler_state_dict', 'train_losses', 'val_losses', 'best_val_loss', 'train_loss', 'val_loss', 'word_to_idx', 'idx_to_word', 'timestamp'])

In [13]:
checkpoint['val_losses']

[3.0421794080451545,
 2.998595917413357,
 2.9650969974137107,
 2.911877240823663,
 2.9211416072525056,
 2.8846528888219902,
 2.880760518929704,
 2.8776617766369004]

In [25]:
!zip model_checkpoints /kaggle/working/model_checkpoints/*

  adding: kaggle/working/model_checkpoints/best_model.pth (deflated 10%)
  adding: kaggle/working/model_checkpoints/last_checkpoint.pth (deflated 10%)


In [14]:
# Create test split from remaining validation data
# Since we did 80/20 split, let's take 50% of val for test
val_images_list = list(val_images)
val_test_images, val_val_images = train_test_split(
    val_images_list, 
    test_size=0.5,  # Split validation in half
    random_state=42
)

print(f"New validation images: {len(val_val_images)}")
print(f"Test images: {len(val_test_images)}")

# Create test caption dictionary
test_captions = {img: processed_captions[img] for img in val_test_images}
updated_val_captions = {img: processed_captions[img] for img in val_val_images}

# Create test dataset
test_dataset = Flickr8kDataset(
    image_dir='/kaggle/input/flickr8k/Images',
    captions_dict=test_captions,
    word_to_idx=word_to_idx,
    transform=transform
)

# Create test loader
test_loader = DataLoader(
    test_dataset,
    batch_size=16,
    shuffle=False,  # No shuffling for test set
    collate_fn=collate_fn,
    num_workers=2,
    pin_memory=True
)

print(f"Test dataset size: {len(test_dataset)}")
print(f"Test batches: {len(test_loader)}")

New validation images: 810
Test images: 809
Dataset created with 4045 image-caption pairs
Test dataset size: 4045
Test batches: 253


In [16]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [20]:
model = model.to(device)

In [18]:
import tqdm

In [21]:
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import numpy as np

def evaluate_model(model, test_loader, idx_to_word, device):
    """Comprehensive model evaluation with BLEU scores"""
    model.eval()
    
    all_bleu1_scores = []
    all_bleu2_scores = []
    all_bleu3_scores = []
    all_bleu4_scores = []
    
    total_samples = 0
    smoothing = SmoothingFunction().method1
    
    print("Evaluating model on test set...")
    print("=" * 50)
    
    with torch.no_grad():
        for batch_idx, (images, true_captions, lengths) in enumerate(tqdm.tqdm(test_loader)):
            images = images.to(device)
            batch_size = images.shape[0]
            
            # Generate captions
            generated_captions = model(images)
            
            # Process each sample in batch
            for i in range(batch_size):
                # Convert true caption to words (remove padding)
                true_tokens = []
                for idx in true_captions[i]:
                    if idx.item() == word_to_idx['<pad>']:
                        break
                    token = idx_to_word[idx.item()]
                    if token not in ['<start>', '<end>']:  # Remove special tokens
                        true_tokens.append(token)
                
                # Convert generated caption to words
                gen_tokens = []
                for idx in generated_captions[i]:
                    token = idx_to_word[idx.item()]
                    if token == '<end>':
                        break
                    if token not in ['<start>', '<pad>']:  # Remove special tokens
                        gen_tokens.append(token)
                
                # Calculate BLEU scores
                if len(gen_tokens) > 0 and len(true_tokens) > 0:
                    # BLEU expects reference as list of lists
                    reference = [true_tokens]
                    candidate = gen_tokens
                    
                    # Calculate different BLEU scores
                    bleu1 = sentence_bleu(reference, candidate, weights=(1, 0, 0, 0), smoothing_function=smoothing)
                    bleu2 = sentence_bleu(reference, candidate, weights=(0.5, 0.5, 0, 0), smoothing_function=smoothing)
                    bleu3 = sentence_bleu(reference, candidate, weights=(0.33, 0.33, 0.33, 0), smoothing_function=smoothing)
                    bleu4 = sentence_bleu(reference, candidate, weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=smoothing)
                    
                    all_bleu1_scores.append(bleu1)
                    all_bleu2_scores.append(bleu2)
                    all_bleu3_scores.append(bleu3)
                    all_bleu4_scores.append(bleu4)
                    
                    total_samples += 1
    
    # Calculate average scores
    avg_bleu1 = np.mean(all_bleu1_scores)
    avg_bleu2 = np.mean(all_bleu2_scores)
    avg_bleu3 = np.mean(all_bleu3_scores)
    avg_bleu4 = np.mean(all_bleu4_scores)
    
    # Print results
    print(f"\n🎯 EVALUATION RESULTS on {total_samples} samples:")
    print("=" * 50)
    print(f"BLEU-1: {avg_bleu1:.4f}")
    print(f"BLEU-2: {avg_bleu2:.4f}") 
    print(f"BLEU-3: {avg_bleu3:.4f}")
    print(f"BLEU-4: {avg_bleu4:.4f} ⭐ (Primary Metric)")
    
    # Interpretation
    print(f"\n📊 INTERPRETATION:")
    if avg_bleu4 > 0.25:
        print("🎉 Excellent performance! Research-level quality")
    elif avg_bleu4 > 0.15:
        print("✅ Good performance! Acceptable for applications")
    elif avg_bleu4 > 0.10:
        print("⚠️  Fair performance. Room for improvement")
    else:
        print("❌ Poor performance. Needs significant improvement")
    
    return {
        'bleu1': avg_bleu1,
        'bleu2': avg_bleu2, 
        'bleu3': avg_bleu3,
        'bleu4': avg_bleu4,
        'total_samples': total_samples
    }

# Run evaluation
results = evaluate_model(model, test_loader, idx_to_word, device)
print(results)

Evaluating model on test set...


100%|██████████| 253/253 [00:21<00:00, 12.03it/s]


🎯 EVALUATION RESULTS on 4045 samples:
BLEU-1: 0.2226
BLEU-2: 0.1083
BLEU-3: 0.0630
BLEU-4: 0.0427 ⭐ (Primary Metric)

📊 INTERPRETATION:
❌ Poor performance. Needs significant improvement
{'bleu1': 0.22264526954273994, 'bleu2': 0.10828822645810378, 'bleu3': 0.06303495889097521, 'bleu4': 0.04267422823700421, 'total_samples': 4045}





In [22]:
# Test beam search
def test_beam_search(model, test_loader, idx_to_word, device, num_samples=5):
    """Test beam search vs greedy decoding"""
    model.eval()
    
    images, true_captions, _ = next(iter(test_loader))
    images = images[:num_samples].to(device)
    
    print("🔍 BEAM SEARCH vs GREEDY COMPARISON:")
    print("=" * 70)
    
    with torch.no_grad():
        # Greedy decoding
        greedy_generated = model(images)
        
        # Beam search
        beam_generated = model.generate_beam_search(images, beam_size=3)
    
    for i in range(num_samples):
        # True caption
        true_tokens = [idx_to_word[idx.item()] for idx in true_captions[i] 
                      if idx.item() != word_to_idx['<pad>']]
        
        # Greedy caption
        greedy_tokens = [idx_to_word[idx.item()] for idx in greedy_generated[i]]
        if '<end>' in greedy_tokens:
            end_idx = greedy_tokens.index('<end>') + 1
            greedy_tokens = greedy_tokens[:end_idx]
        
        # Beam search caption
        beam_tokens = [idx_to_word[idx.item()] for idx in beam_generated[i]]
        if '<end>' in beam_tokens:
            end_idx = beam_tokens.index('<end>') + 1
            beam_tokens = beam_tokens[:end_idx]
        
        print(f"Sample {i+1}:")
        print(f"True:   {' '.join(true_tokens)}")
        print(f"Greedy: {' '.join(greedy_tokens)}")
        print(f"Beam:   {' '.join(beam_tokens)}")
        print("-" * 50)

# Test it
test_beam_search(model, test_loader, idx_to_word, device)

🔍 BEAM SEARCH vs GREEDY COMPARISON:
Sample 1:
True:   <start> a boy smiles by the pool <end>
Greedy: <start> a woman and a child are sitting on a bench <end>
Beam:   <start> a man and a woman are sitting on a bench <end>
--------------------------------------------------
Sample 2:
True:   <start> a laughing young boy is near a swimming pool <end>
Greedy: <start> a woman and a child are sitting on a bench <end>
Beam:   <start> a man and a woman are sitting on a bench <end>
--------------------------------------------------
Sample 3:
True:   <start> a young boy outside by the pool smiling <end>
Greedy: <start> a woman and a child are sitting on a bench <end>
Beam:   <start> a man and a woman are sitting on a bench <end>
--------------------------------------------------
Sample 4:
True:   <start> the child is laughing by a pool on the other side of the bars <end>
Greedy: <start> a woman and a child are sitting on a bench <end>
Beam:   <start> a man and a woman are sitting on a bench <end>

In [23]:
# Better evaluation using all 5 captions per image
def evaluate_with_multiple_references(model, test_captions_dict, idx_to_word, device, num_samples=500):
    """Evaluate using all available captions for each image"""
    model.eval()
    
    # Randomly sample some test images
    test_images_list = list(test_captions_dict.keys())
    sampled_images = np.random.choice(test_images_list, min(num_samples, len(test_images_list)), replace=False)
    
    all_bleu4_scores = []
    
    for img_name in sampled_images:
        # Load and preprocess image
        img_path = f'/kaggle/input/flickr8k/Images/{img_name}'
        image = Image.open(img_path).convert('RGB')
        image_tensor = transform(image).unsqueeze(0).to(device)
        
        # Generate caption
        with torch.no_grad():
            generated = model(image_tensor)
        
        # Convert generated to words
        gen_tokens = []
        for idx in generated[0]:
            token = idx_to_word[idx.item()]
            if token == '<end>':
                break
            if token not in ['<start>', '<pad>']:
                gen_tokens.append(token)
        
        # Get all reference captions for this image
        references = []
        for caption in test_captions_dict[img_name]:
            ref_tokens = []
            for token in caption.split():
                if token not in ['<start>', '<end>']:
                    ref_tokens.append(token)
            references.append(ref_tokens)
        
        # Calculate BLEU-4 with multiple references
        if len(gen_tokens) > 0:
            bleu4 = sentence_bleu(references, gen_tokens, weights=(0.25, 0.25, 0.25, 0.25), 
                                smoothing_function=SmoothingFunction().method1)
            all_bleu4_scores.append(bleu4)
    
    avg_bleu4 = np.mean(all_bleu4_scores)
    print(f"BLEU-4 with multiple references: {avg_bleu4:.4f}")
    return avg_bleu4

# Test with multiple references
multi_ref_bleu4 = evaluate_with_multiple_references(model, test_captions, idx_to_word, device)

BLEU-4 with multiple references: 0.1275
