# Image Captioning with CNN Encoder and Transformer Decoder

This notebook implements an image captioning system using:
- **CNN Encoder**: Pretrained ResNet50
- **Transformer Decoder**: PyTorch's nn.TransformerDecoder
- **Dataset**: COCO 2017 (validation split only, subsampled via `MAX_SAMPLES`)

Optimized for Google Colab with limited GPU memory.

**Run cells sequentially from top to bottom.**

## Cell 1: Environment Setup

Install required libraries and download NLTK data

In [None]:
# Install the COCO API, NLTK and tqdm quietly (IPython shell escape).
!pip install -q pycocotools nltk tqdm

import nltk
# Download the 'punkt' tokenizer data without progress output.
# NOTE(review): the later cells tokenize captions with str.split(), so this
# download looks unused in the visible code - confirm before removing.
nltk.download('punkt', quiet=True)

# Summarise what this cell set up.
print("‚úÖ Environment setup complete!")
print("   - pycocotools installed")
print("   - nltk installed")
print("   - punkt tokenizer downloaded")

## Cell 2: Dataset Download

Download COCO 2017 validation dataset (~1GB)

In [None]:
import os

# Create directories (-p: no error if they already exist)
!mkdir -p /content/coco/images
!mkdir -p /content/coco/annotations

# Download validation images (~1GB), unpack them under /content/coco/images/
# and delete the archive afterwards to free disk space.
print("Downloading COCO validation images (this may take a few minutes)...")
!wget -q --show-progress http://images.cocodataset.org/zips/val2017.zip
!unzip -q val2017.zip -d /content/coco/images/
!rm val2017.zip

# Download annotations, unpack them under /content/coco/ and delete the archive.
print("\nDownloading COCO annotations...")
!wget -q --show-progress http://images.cocodataset.org/annotations/annotations_trainval2017.zip
!unzip -q annotations_trainval2017.zip -d /content/coco/
!rm annotations_trainval2017.zip

# Verify download: list the first few entries and count the extracted images.
print("\n‚úÖ Dataset downloaded!")
print("\nüìÅ Directory structure:")
!ls -lh /content/coco/images/val2017 | head -5
print(f"\nTotal images: {len(os.listdir('/content/coco/images/val2017'))}")

## Cell 3: Imports & Configuration

Import libraries and define hyperparameters

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as transforms

from pycocotools.coco import COCO
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
import math
import random
from tqdm import tqdm

# Set random seeds for reproducibility (PyTorch, NumPy and the stdlib RNG)
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# ============================================================================
# CONFIGURATION
# ============================================================================

# Paths (must match the locations the download cell extracts to)
IMAGES_DIR = '/content/coco/images/val2017'
ANNOTATIONS_FILE = '/content/coco/annotations/captions_val2017.json'

# Dataset limits (for memory optimization)
# NOTE: MAX_SAMPLES is applied to caption annotations, so the number of
# distinct images used can be lower than this value.
MAX_SAMPLES = 500        # Limit to 500 images for quick training
VOCAB_THRESHOLD = 3      # Minimum word frequency
MAX_CAPTION_LENGTH = 50  # Maximum caption length (counted in token ids, incl. <start>/<end>)

# Model hyperparameters
EMBED_DIM = 256          # Embedding dimension (small for limited GPU)
NUM_HEADS = 8            # Number of attention heads
NUM_LAYERS = 3           # Number of transformer layers
DIM_FEEDFORWARD = 1024   # Feedforward dimension
DROPOUT = 0.1            # Dropout rate

# Training hyperparameters
BATCH_SIZE = 8           # Batch size (small for limited GPU)
NUM_EPOCHS = 5           # Number of epochs
LEARNING_RATE = 1e-4     # Learning rate

# Image preprocessing
IMAGE_SIZE = 224
IMAGE_MEAN = [0.485, 0.456, 0.406]  # ImageNet mean
IMAGE_STD = [0.229, 0.224, 0.225]   # ImageNet std

# Device: use the GPU when one is visible, otherwise fall back to the CPU
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Print a short summary of the active configuration
print("="*70)
print("Configuration")
print("="*70)
print(f"Device: {DEVICE}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"Max samples: {MAX_SAMPLES}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Epochs: {NUM_EPOCHS}")
print(f"Embed dim: {EMBED_DIM}")
print("="*70)

## Cell 4: Vocabulary & Tokenization

Build vocabulary from COCO captions

In [None]:
class Vocabulary:
    """Two-way lookup table between words and integer ids.

    The four special tokens are registered at construction time, in the
    order ``<pad>``, ``<start>``, ``<end>``, ``<unk>``, so they receive the
    ids 0, 1, 2 and 3 respectively.
    """

    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0  # next id to hand out

        # Special tokens
        self.pad_token = '<pad>'
        self.start_token = '<start>'
        self.end_token = '<end>'
        self.unk_token = '<unk>'

        specials = (self.pad_token, self.start_token,
                    self.end_token, self.unk_token)
        for special in specials:
            self.add_word(special)

    def add_word(self, word):
        """Register ``word`` under the next free id; known words are ignored."""
        if word in self.word2idx:
            return
        new_id = self.idx
        self.word2idx[word] = new_id
        self.idx2word[new_id] = word
        self.idx = new_id + 1

    def __len__(self):
        """Return the number of registered words (special tokens included)."""
        return len(self.word2idx)

    def __call__(self, word):
        """Return the id of ``word``, or the ``<unk>`` id for unknown words."""
        unknown_id = self.word2idx[self.unk_token]
        return self.word2idx.get(word, unknown_id)


def build_vocabulary(annotation_file, threshold=3, max_samples=None):
    """Create a Vocabulary from the captions of a COCO annotation file.

    Captions are lower-cased and split on whitespace. A word is added to the
    vocabulary when it occurs at least ``threshold`` times.

    Args:
        annotation_file: Path to a COCO captions JSON file.
        threshold: Minimum number of occurrences for a word to be kept.
        max_samples: When truthy, only the first ``max_samples * 5``
            annotations are counted.

    Returns:
        The populated Vocabulary (special tokens included).
    """
    print("Building vocabulary...")
    coco = COCO(annotation_file)
    annotation_ids = list(coco.anns.keys())

    # Look at five times as many captions as the dataset limit so the
    # vocabulary is a little richer than the training subset.
    if max_samples:
        annotation_ids = annotation_ids[:max_samples * 5]

    word_counts = Counter()
    for annotation_id in tqdm(annotation_ids, desc="Tokenizing captions"):
        text = str(coco.anns[annotation_id]['caption'])
        word_counts.update(text.lower().split())

    vocab = Vocabulary()
    frequent_words = [w for w, n in word_counts.items() if n >= threshold]
    for w in frequent_words:
        vocab.add_word(w)

    print(f"‚úÖ Vocabulary size: {len(vocab)}")
    return vocab


# Build vocabulary from the validation captions, using the frequency threshold
# and sample limit defined in the configuration cell.
vocab = build_vocabulary(ANNOTATIONS_FILE, VOCAB_THRESHOLD, MAX_SAMPLES)

## Cell 5: COCO Dataset & DataLoader

Implement PyTorch Dataset and DataLoader for COCO

In [None]:
class COCODataset(Dataset):
    """COCO captions dataset yielding ``(image, caption_ids, length)``.

    Every item corresponds to one caption annotation; an image with several
    captions therefore appears once per caption.

    Args:
        root: Directory containing the image files.
        annotation_file: Path to the COCO captions JSON file.
        vocab: Callable word -> id mapping exposing ``start_token``,
            ``end_token`` and ``pad_token`` attributes.
        transform: Optional callable applied to the loaded RGB image.
        max_samples: When truthy, keep only the first ``max_samples``
            annotations.
        max_length: Fixed length of the returned id sequence. Defaults to the
            notebook-wide ``MAX_CAPTION_LENGTH``.
    """

    def __init__(self, root, annotation_file, vocab, transform=None,
                 max_samples=None, max_length=None):
        self.root = root
        self.coco = COCO(annotation_file)
        self.vocab = vocab
        self.transform = transform
        # Fall back to the global setting when no explicit length is given.
        self.max_length = MAX_CAPTION_LENGTH if max_length is None else max_length
        self.ids = list(self.coco.anns.keys())

        # Limit dataset
        if max_samples:
            self.ids = self.ids[:max_samples]

    def __len__(self):
        """Return the number of caption annotations in the dataset."""
        return len(self.ids)

    def _encode_caption(self, caption):
        """Turn a raw caption into a fixed-length list of token ids.

        The result is ``<start> + tokens + <end>`` padded with ``<pad>`` up to
        ``self.max_length``. Over-long captions lose their trailing words, not
        the ``<end>`` marker, so the model always sees a terminated sequence.

        Returns:
            ``(indices, length)`` where ``length`` counts the non-pad ids.
        """
        tokens = str(caption).lower().split()
        start_idx = self.vocab(self.vocab.start_token)
        end_idx = self.vocab(self.vocab.end_token)
        pad_idx = self.vocab(self.vocab.pad_token)

        # Reserve two slots for <start>/<end> before cutting the words.
        word_budget = max(self.max_length - 2, 0)
        indices = [start_idx]
        indices.extend(self.vocab(token) for token in tokens[:word_budget])
        indices.append(end_idx)
        indices = indices[:self.max_length]  # only matters when max_length < 2

        length = len(indices)
        indices.extend([pad_idx] * (self.max_length - length))
        return indices, length

    def __getitem__(self, index):
        """Load the image and encoded caption for annotation ``index``.

        Returns:
            ``(image, caption, length)``: the (optionally transformed) RGB
            image, a LongTensor of ``max_length`` token ids and the number of
            non-pad ids in it.
        """
        annotation = self.coco.anns[self.ids[index]]

        # Load image; the context manager closes the file handle promptly,
        # convert() returns an independent in-memory copy.
        file_name = self.coco.loadImgs(annotation['image_id'])[0]['file_name']
        with Image.open(os.path.join(self.root, file_name)) as handle:
            image = handle.convert('RGB')

        if self.transform:
            image = self.transform(image)

        caption_indices, length = self._encode_caption(annotation['caption'])
        return image, torch.LongTensor(caption_indices), length


def collate_fn(batch):
    """Merge a list of ``(image, caption, length)`` samples into one batch.

    The incoming list is reordered in place so that the longest caption comes
    first (ties keep their original relative order).

    Returns:
        ``(images, captions, lengths)``: the stacked image tensor, the
        stacked caption tensor and the caption lengths as a plain list.
    """
    batch.sort(key=lambda sample: sample[2], reverse=True)
    image_seq, caption_seq, length_seq = zip(*batch)
    return (
        torch.stack(image_seq, 0),
        torch.stack(caption_seq, 0),
        list(length_seq),
    )


# Image transforms: resize to a square IMAGE_SIZE x IMAGE_SIZE input, convert
# to a tensor, then normalise with the configured ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGE_MEAN, std=IMAGE_STD)
])

# Create dataset and dataloader.
# The loader reshuffles every epoch and uses collate_fn to sort each batch by
# caption length; two worker processes load images in the background.
dataset = COCODataset(IMAGES_DIR, ANNOTATIONS_FILE, vocab, transform, MAX_SAMPLES)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, 
                       num_workers=2, collate_fn=collate_fn)

print(f"‚úÖ Dataset created: {len(dataset)} samples")
print(f"‚úÖ DataLoader created: {len(dataloader)} batches")

## Cell 6: CNN Encoder

CNN Encoder using pretrained ResNet50

In [None]:
class Encoder(nn.Module):
    """Frozen ResNet50 backbone followed by a trainable linear projection.

    The backbone keeps its spatial feature map (the last two children of the
    torchvision model - global pooling and the classifier - are dropped), the
    map is pooled to 7x7 and every location is projected to ``embed_dim``.

    The backbone is frozen and permanently kept in eval mode, so calling
    ``.train()`` on a parent module neither switches its BatchNorm layers to
    batch statistics nor overwrites their pretrained running statistics.
    """

    def __init__(self, embed_dim=256):
        super().__init__()

        # `pretrained=True` is deprecated since torchvision 0.13 in favour of
        # the `weights=` enum; IMAGENET1K_V1 is what the legacy flag loaded.
        # Older torchvision releases fall back to the legacy flag.
        weights_enum = getattr(models, 'ResNet50_Weights', None)
        if weights_enum is not None:
            resnet = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
        else:
            resnet = models.resnet50(pretrained=True)

        # Remove final layers (keeps the convolutional feature extractor)
        self.resnet = nn.Sequential(*list(resnet.children())[:-2])

        # Freeze ResNet weights and fix its normalisation statistics
        for param in self.resnet.parameters():
            param.requires_grad = False
        self.resnet.eval()

        # Adaptive pooling and projection; the channel count is read from the
        # removed classifier instead of being hard-coded.
        self.feature_dim = resnet.fc.in_features
        self.adaptive_pool = nn.AdaptiveAvgPool2d((7, 7))
        self.projection = nn.Linear(self.feature_dim, embed_dim)

    def train(self, mode=True):
        """Set the training mode, but always leave the frozen backbone in eval."""
        super().train(mode)
        self.resnet.eval()
        return self

    def forward(self, images):
        """
        Args:
            images: (batch, 3, H, W) image batch, e.g. (batch, 3, 224, 224)
        Returns:
            features: (batch, 49, embed_dim)
        """
        # Extract features and pool them to a fixed 7x7 grid
        features = self.adaptive_pool(self.resnet(images))

        # (batch, C, 7, 7) -> (batch, 7, 7, C) -> (batch, 49, C)
        features = features.permute(0, 2, 3, 1).flatten(1, 2)

        # Project: (batch, 49, embed_dim)
        return self.projection(features)


# Test encoder: push two random "images" through a fresh Encoder and compare
# the output shape with the expected (batch, 49, EMBED_DIM).
# test_features is reused by the decoder smoke test in the next cell.
encoder = Encoder(EMBED_DIM).to(DEVICE)
test_images = torch.randn(2, 3, 224, 224).to(DEVICE)
test_features = encoder(test_images)

print(f"‚úÖ Encoder created")
print(f"   Input shape: {test_images.shape}")
print(f"   Output shape: {test_features.shape}")
print(f"   Expected: (2, 49, {EMBED_DIM})")

## Cell 7: Transformer Decoder

Transformer Decoder with positional encoding

In [None]:
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to batch-first embeddings.

    Args:
        embed_dim: Size of the embedding vectors (even or odd).
        max_len: Longest sequence length for which encodings are precomputed.
        dropout: Dropout probability applied after adding the encoding.
    """

    def __init__(self, embed_dim, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        # One frequency per sine column: (max_len, ceil(embed_dim / 2))
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() *
                             (-math.log(10000.0) / embed_dim))
        angles = position * div_term

        pe = torch.zeros(max_len, embed_dim)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd embed_dim there is one cosine column fewer than sine
        # columns, so the last frequency is only used for the sine.
        pe[:, 1::2] = torch.cos(angles[:, :embed_dim // 2])

        # Buffer of shape (1, max_len, embed_dim): moves with the module and
        # is saved in the state dict, but is not a trainable parameter.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add the encoding for positions ``0 .. x.size(1) - 1`` to ``x``.

        Args:
            x: (batch, seq_len, embed_dim) with ``seq_len <= max_len``.
        Returns:
            Tensor of the same shape as ``x``, after dropout.
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


class Decoder(nn.Module):
    """Caption decoder built on ``nn.TransformerDecoder`` (batch-first).

    Token ids are embedded, combined with positional encodings, decoded with
    causal self-attention plus cross-attention over the image features, and
    finally projected to vocabulary logits.
    """

    def __init__(self, vocab_size, embed_dim=256, num_heads=8, 
                 num_layers=3, dim_feedforward=1024, dropout=0.1):
        super().__init__()

        self.embed_dim = embed_dim
        self.vocab_size = vocab_size

        # Token embedding followed by positional encoding
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_encoder = PositionalEncoding(embed_dim, dropout=dropout)

        # Stack of identical decoder layers
        layer_options = dict(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(**layer_options), num_layers)

        # Logits over the vocabulary
        self.fc_out = nn.Linear(embed_dim, vocab_size)

        # Xavier-initialise the embedding and output weights, zero the bias
        for weight in (self.embedding.weight, self.fc_out.weight):
            nn.init.xavier_uniform_(weight)
        nn.init.zeros_(self.fc_out.bias)

    def forward(self, captions, encoder_out):
        """
        Args:
            captions: (batch, seq_len) token ids
            encoder_out: (batch, 49, embed_dim) image features
        Returns:
            predictions: (batch, seq_len, vocab_size) logits
        """
        token_states = self.pos_encoder(self.embedding(captions))

        # Position i may only attend to positions <= i
        causal_mask = self._generate_causal_mask(captions.size(1))
        causal_mask = causal_mask.to(captions.device)

        hidden = self.transformer_decoder(
            tgt=token_states,
            memory=encoder_out,
            tgt_mask=causal_mask,
        )
        return self.fc_out(hidden)

    def _generate_causal_mask(self, sz):
        """Return an (sz, sz) float mask: 0 on/below the diagonal, -inf above."""
        future = torch.triu(torch.ones(sz, sz), diagonal=1) == 1
        return torch.zeros(sz, sz).masked_fill(future, float('-inf'))


# Test decoder: decode two random id sequences of length 20 against the
# features produced by the encoder smoke test and check the logits shape.
decoder = Decoder(len(vocab), EMBED_DIM, NUM_HEADS, NUM_LAYERS, 
                 DIM_FEEDFORWARD, DROPOUT).to(DEVICE)
test_captions = torch.randint(0, len(vocab), (2, 20)).to(DEVICE)
test_predictions = decoder(test_captions, test_features)

print(f"‚úÖ Decoder created")
print(f"   Input captions shape: {test_captions.shape}")
print(f"   Input features shape: {test_features.shape}")
print(f"   Output predictions shape: {test_predictions.shape}")
print(f"   Expected: (2, 20, {len(vocab)})")

## Cell 8: Training

Training loop with teacher forcing

In [None]:
class ImageCaptioningModel(nn.Module):
    """Encoder-decoder captioner: image features in, vocabulary logits out."""

    def __init__(self, vocab_size, embed_dim=256, num_heads=8, 
                 num_layers=3, dim_feedforward=1024, dropout=0.1):
        super().__init__()
        self.encoder = Encoder(embed_dim)
        decoder_args = (vocab_size, embed_dim, num_heads,
                        num_layers, dim_feedforward, dropout)
        self.decoder = Decoder(*decoder_args)

    def forward(self, images, captions):
        """Encode ``images`` and decode ``captions`` against the features.

        Returns whatever the decoder returns for ``(captions, features)``.
        """
        return self.decoder(captions, self.encoder(images))


# Create model with the configured hyperparameters and move it to DEVICE
model = ImageCaptioningModel(len(vocab), EMBED_DIM, NUM_HEADS, 
                            NUM_LAYERS, DIM_FEEDFORWARD, DROPOUT).to(DEVICE)

# Loss and optimizer.
# Padding positions are excluded from the loss via ignore_index.
pad_idx = vocab.word2idx[vocab.pad_token]
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Report the total and trainable (requires_grad) parameter counts
print(f"‚úÖ Model created")
print(f"   Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"   Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

# Training loop
print("\n" + "="*70)
print("Training Started")
print("="*70)

train_losses = []  # one average loss value per epoch

for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    epoch_loss = 0
    
    progress_bar = tqdm(dataloader, desc=f'Epoch {epoch}/{NUM_EPOCHS}')
    
    # `lengths` is supplied by collate_fn but not needed here: padding is
    # handled by the loss' ignore_index.
    for images, captions, lengths in progress_bar:
        images = images.to(DEVICE)
        captions = captions.to(DEVICE)
        
        # Teacher forcing: input all except last, target all except first
        decoder_input = captions[:, :-1]
        targets = captions[:, 1:]
        
        # Forward pass: (batch, seq_len, vocab_size) logits
        predictions = model(images, decoder_input)
        
        # Calculate loss over all positions, flattened to (batch*seq_len, vocab_size)
        batch_size, seq_len, vocab_size = predictions.shape
        loss = criterion(predictions.reshape(-1, vocab_size), targets.reshape(-1))
        
        # Backward pass with gradient-norm clipping at 5.0
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()
        
        epoch_loss += loss.item()
        progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    # Mean of the per-batch losses for this epoch
    avg_loss = epoch_loss / len(dataloader)
    train_losses.append(avg_loss)
    
    print(f"Epoch {epoch}/{NUM_EPOCHS} - Average Loss: {avg_loss:.4f}")

print("\n‚úÖ Training complete!")

# Save model: weights, the vocabulary object and the hyperparameters needed
# to rebuild the architecture.
# NOTE(review): 'vocab' is stored as a pickled Python object, so reading this
# file back requires full unpickling (not a weights-only load).
torch.save({
    'model_state_dict': model.state_dict(),
    'vocab': vocab,
    'config': {
        'embed_dim': EMBED_DIM,
        'num_heads': NUM_HEADS,
        'num_layers': NUM_LAYERS,
        'dim_feedforward': DIM_FEEDFORWARD,
        'dropout': DROPOUT
    }
}, '/content/image_captioning_model.pth')

print("‚úÖ Model saved to /content/image_captioning_model.pth")

# Plot training loss: one point per epoch (x axis is the 1-based epoch number)
plt.figure(figsize=(10, 5))
plt.plot(range(1, NUM_EPOCHS + 1), train_losses, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.grid(True)
plt.show()

## Cell 9: Evaluation & Caption Generation

Define the greedy caption generator, the BLEU-4 scorer and the image de-normalizer, then reload the saved checkpoint

In [None]:
def generate_caption(model, image, vocab, max_length=50, device='cuda'):
    """Generate a caption for one image using greedy decoding.

    Decoding stops when ``<end>`` is predicted or after ``max_length`` steps.
    When the step limit is hit without an ``<end>`` token, every generated
    word is kept (nothing is trimmed from the end).

    Args:
        model: Object with ``eval()``, ``encoder(image)`` and
            ``decoder(caption_ids, encoder_out)``; it is left in eval mode.
        image: Image batch of size one, already on ``device``.
        vocab: Vocabulary providing ``word2idx``, ``idx2word`` and the
            special-token attributes.
        max_length: Maximum number of decoding steps.
        device: Device on which the caption id tensor is created.

    Returns:
        List of generated words without ``<start>``, ``<end>`` and ``<pad>``.
    """
    model.eval()

    start_idx = vocab.word2idx[vocab.start_token]
    end_idx = vocab.word2idx[vocab.end_token]
    generated = []  # predicted ids, excluding <start> and <end>

    with torch.no_grad():
        # Encode image once; it is reused for every decoding step
        encoder_out = model.encoder(image)

        for _ in range(max_length):
            caption_tensor = torch.tensor([[start_idx] + generated],
                                          dtype=torch.long, device=device)
            predictions = model.decoder(caption_tensor, encoder_out)
            predicted_idx = predictions[0, -1, :].argmax().item()

            # Stop before storing <end>, so no slicing is needed afterwards
            if predicted_idx == end_idx:
                break
            generated.append(predicted_idx)

    # Convert to words, dropping any padding the model may have emitted
    words = (vocab.idx2word.get(idx, vocab.unk_token) for idx in generated)
    return [word for word in words if word != vocab.pad_token]


def calculate_bleu(reference, candidate):
    """Calculate an unsmoothed, single-reference BLEU-4 score.

    The score is the brevity penalty times the geometric mean of the clipped
    1- to 4-gram precisions. Without smoothing, any n-gram order with zero
    matches (including candidates shorter than four tokens) yields 0.

    Args:
        reference: List of ground-truth tokens.
        candidate: List of generated tokens; may be empty.

    Returns:
        BLEU-4 score as a float in [0, 1].
    """
    from collections import Counter

    # An empty candidate has no n-grams; returning early also avoids the
    # division by len(candidate) in the brevity penalty.
    if not candidate:
        return 0.0

    log_precision_sum = 0.0
    for n in range(1, 5):
        ref_ngrams = Counter(tuple(reference[i:i + n])
                             for i in range(len(reference) - n + 1))
        cand_ngrams = Counter(tuple(candidate[i:i + n])
                              for i in range(len(candidate) - n + 1))

        # Each candidate n-gram counts at most as often as it occurs in the
        # reference (clipped precision).
        clipped = sum(min(count, ref_ngrams[ngram])
                      for ngram, count in cand_ngrams.items())
        if clipped == 0:
            # Covers "no candidate n-grams of this order" as well.
            return 0.0
        log_precision_sum += math.log(clipped / sum(cand_ngrams.values()))

    # Brevity penalty: only candidates shorter than the reference are punished
    if len(candidate) < len(reference):
        brevity_penalty = math.exp(1 - len(reference) / len(candidate))
    else:
        brevity_penalty = 1.0

    return brevity_penalty * math.exp(log_precision_sum / 4)


def denormalize_image(image_tensor, mean=None, std=None):
    """Undo channel-wise normalisation and return an image ready for display.

    Args:
        image_tensor: (C, H, W) normalised image tensor on any device.
        mean: Per-channel means used for normalisation. Defaults to the
            notebook-wide ``IMAGE_MEAN``.
        std: Per-channel standard deviations. Defaults to ``IMAGE_STD``.

    Returns:
        NumPy array of shape (H, W, C) with values clipped to [0, 1].
    """
    mean = IMAGE_MEAN if mean is None else mean
    std = IMAGE_STD if std is None else std

    # .numpy() needs a CPU tensor that is detached from the autograd graph.
    img = image_tensor.detach().cpu()
    mean_t = torch.tensor(mean).view(-1, 1, 1)
    std_t = torch.tensor(std).view(-1, 1, 1)

    img = img * std_t + mean_t
    img = img.permute(1, 2, 0).numpy()  # channels-last for matplotlib
    return np.clip(img, 0, 1)


# Load model.
# The checkpoint stores the pickled Vocabulary object next to the weights, so
# it cannot be read in weights-only mode (the torch.load default from
# PyTorch 2.6 on). weights_only=False runs the full unpickler, which can
# execute arbitrary code: only load checkpoints written by this notebook.
# map_location lets a checkpoint saved on GPU be restored on a CPU runtime.
checkpoint = torch.load('/content/image_captioning_model.pth',
                        map_location=DEVICE, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print("‚úÖ Model loaded for evaluation")

## Cell 10: Results

Display qualitative results with images, captions, and BLEU scores

In [None]:
print("="*70)
print("EVALUATION RESULTS")
print("="*70)

# Evaluate on 5 samples.
# NOTE(review): these are the first items of `dataset`, i.e. the same samples
# the model was trained on - the scores are not a held-out estimate.
num_samples = 5
bleu_scores = []

for i in range(num_samples):
    # Get sample: normalised image tensor, padded caption ids, caption length
    image, caption, length = dataset[i]
    
    # Generate caption (add a batch dimension and move the image to DEVICE)
    image_batch = image.unsqueeze(0).to(DEVICE)
    generated_words = generate_caption(model, image_batch, vocab, 50, DEVICE)
    
    # Get ground truth: decode ids back to words, stopping at <end> and
    # skipping <start>/<pad>
    gt_words = []
    for idx in caption.tolist():
        word = vocab.idx2word[idx]
        if word == vocab.end_token:
            break
        if word not in [vocab.start_token, vocab.pad_token]:
            gt_words.append(word)
    
    # Calculate BLEU against this single reference caption
    bleu = calculate_bleu(gt_words, generated_words)
    bleu_scores.append(bleu)
    
    # Display the de-normalised image with both captions in the title
    img = denormalize_image(image)
    
    plt.figure(figsize=(12, 8))
    plt.imshow(img)
    plt.axis('off')
    
    title = f"Sample {i+1}\n\n"
    title += f"Generated Caption:\n{' '.join(generated_words)}\n\n"
    title += f"Ground Truth Caption:\n{' '.join(gt_words)}\n\n"
    title += f"BLEU-4 Score: {bleu:.4f}"
    
    plt.title(title, fontsize=14, pad=20, wrap=True)
    plt.tight_layout()
    plt.show()
    
    # Repeat the same information as plain text
    print(f"\nSample {i+1}:")
    print(f"  Generated:    {' '.join(generated_words)}")
    print(f"  Ground Truth: {' '.join(gt_words)}")
    print(f"  BLEU-4:       {bleu:.4f}")
    print("-"*70)

# Summary statistics over the evaluated samples
avg_bleu = sum(bleu_scores) / len(bleu_scores)

print("\n" + "="*70)
print("FINAL RESULTS")
print("="*70)
print(f"Samples evaluated: {num_samples}")
print(f"Average BLEU-4 Score: {avg_bleu:.4f}")
print(f"Min BLEU-4: {min(bleu_scores):.4f}")
print(f"Max BLEU-4: {max(bleu_scores):.4f}")

# Interpretation: map the average score onto coarse quality bands
# (thresholds 0.30 / 0.20 / 0.10 are this notebook's own rule of thumb)
print("\nüìä Performance Interpretation:")
if avg_bleu > 0.30:
    print("   ‚≠ê‚≠ê‚≠ê Excellent - High quality captions!")
elif avg_bleu > 0.20:
    print("   ‚≠ê‚≠ê Good - Acceptable caption quality")
elif avg_bleu > 0.10:
    print("   ‚≠ê Fair - Needs more training")
else:
    print("   Needs improvement - Train longer or with more data")

# Closing hints for the reader
print("\n‚úÖ Evaluation complete!")
print("\nüí° Tips for improvement:")
print("   - Increase MAX_SAMPLES for more training data")
print("   - Increase NUM_EPOCHS for longer training")
print("   - Use beam search instead of greedy decoding")
print("   - Fine-tune the encoder")

print("\n" + "="*70)
print("NOTEBOOK EXECUTION COMPLETE")
print("="*70)