# Image Captioning: Data Exploration and Training

This notebook explores the Flickr8k dataset (captions, vocabulary, and data pipeline), inspects the image captioning model, and then trains it.

## 1. Setup

In [None]:
import sys
from pathlib import Path

# Add project root to path so the `src.*` imports below resolve.
# Assumes the kernel's working directory is a direct subdirectory of the project
# root (e.g. notebooks/). Any existing copies of the entry are removed before
# inserting, so re-running this cell does not accumulate duplicates while the
# project root still ends up at the highest import priority.
project_root = Path().resolve().parent
project_root_str = str(project_root)
sys.path[:] = [entry for entry in sys.path if entry != project_root_str]
sys.path.insert(0, project_root_str)

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from PIL import Image
from collections import Counter

from src.utils import load_config, get_device, parse_captions_file, set_seed
from src.vocab import Vocabulary
from src.dataset import build_vocab_from_dataloader, get_dataloaders, get_transforms
from src.model import ImageCaptioningModel
from src.train import train

%matplotlib inline
plt.style.use('seaborn-v0_8-darkgrid')

In [None]:
# Load the experiment configuration (path relative to the notebook's working directory)
config = load_config("../configs/config.yaml")
print("Configuration loaded successfully!")

# Echo the data settings that most affect exploration and batching
data_settings = (
    ("Batch size", "batch_size"),
    ("Vocab frequency threshold", "freq_threshold"),
    ("Max caption length", "max_caption_length"),
)
for setting_label, setting_key in data_settings:
    print(f"{setting_label}: {config['data'][setting_key]}")

In [None]:
# Choose the compute device, then seed the RNGs via the project helper for reproducibility
device = get_device()
set_seed(42)
print("Using device:", device)

## 2. Data Exploration

In [None]:
# Parse the captions file into a DataFrame with one row per (image, caption) pair
captions_df = parse_captions_file(config["data"]["captions_file"])
num_captions_loaded = len(captions_df)
print(f"Loaded {num_captions_loaded} captions")
captions_df.head(10)

In [None]:
# Summary statistics: image count, caption count, and how captions spread across images
captions_per_image = captions_df.groupby('image').size()
num_unique_images = captions_df['image'].nunique()
total_captions = captions_df.shape[0]

banner = "=" * 50
print(banner)
print("Dataset Statistics")
print(banner)
print(f"Number of unique images: {num_unique_images}")
print(f"Total number of captions: {total_captions}")
print("\nCaptions per image distribution:")
print(captions_per_image.value_counts().sort_index())
print(f"\nMean captions per image: {captions_per_image.mean():.2f}")
print(f"Median captions per image: {captions_per_image.median():.0f}")

In [None]:
# Caption length (number of whitespace-separated words): histogram + box plot
captions_df['caption_length'] = captions_df['caption'].map(lambda text: len(text.split()))
caption_lengths = captions_df['caption_length']
mean_length = caption_lengths.mean()

fig, (hist_ax, box_ax) = plt.subplots(1, 2, figsize=(12, 5))

hist_ax.hist(caption_lengths, bins=30, edgecolor='black', alpha=0.7)
hist_ax.set_xlabel('Caption Length (words)')
hist_ax.set_ylabel('Frequency')
hist_ax.set_title('Distribution of Caption Lengths')
hist_ax.axvline(mean_length, color='red', linestyle='--',
                label=f"Mean: {mean_length:.1f}")
hist_ax.legend()

box_ax.boxplot(caption_lengths)
box_ax.set_ylabel('Caption Length (words)')
box_ax.set_title('Caption Length Box Plot')

fig.tight_layout()
plt.show()

print("\nCaption length statistics:")
print(f"  Mean: {mean_length:.2f}")
print(f"  Median: {caption_lengths.median():.0f}")
print(f"  Min: {caption_lengths.min()}")
print(f"  Max: {caption_lengths.max()}")

In [None]:
# Word frequencies over lower-cased, whitespace-tokenized captions
all_words = [word for text in captions_df['caption'] for word in text.lower().split()]
word_counts = Counter(all_words)
most_common_words = word_counts.most_common(50)

words, counts = zip(*most_common_words)
bar_positions = range(len(words))

fig, ax = plt.subplots(figsize=(15, 6))
ax.bar(bar_positions, counts, alpha=0.7)
ax.set_xticks(bar_positions)
ax.set_xticklabels(words, rotation=45, ha='right')
ax.set_xlabel('Words')
ax.set_ylabel('Frequency')
ax.set_title('Top 50 Most Frequent Words')
fig.tight_layout()
plt.show()

print(f"\nTotal unique words: {len(word_counts)}")
print(f"Total word occurrences: {len(all_words)}")
print("\nTop 10 words:")
for top_word, top_count in most_common_words[:10]:
    print(f"  {top_word}: {top_count}")

In [None]:
# Display a reproducible random sample of images with their captions
image_dir = Path(config["data"]["image_dir"])
num_sample_images = 5
caption_preview_chars = 30  # truncate captions in titles so they fit above each image

# Draw a genuinely random sample of distinct images. A dedicated, seeded generator
# keeps the sample reproducible no matter how much global RNG state earlier cells used.
unique_images = captions_df['image'].unique()
sample_rng = np.random.default_rng(42)
random_images = sample_rng.choice(
    unique_images, size=min(num_sample_images, len(unique_images)), replace=False
)

# squeeze=False keeps `axes` 2-D even when only one image is sampled
fig, axes = plt.subplots(1, len(random_images), figsize=(4 * len(random_images), 5), squeeze=False)
axes = axes.ravel()

for ax, img_name in zip(axes, random_images):
    ax.axis('off')
    img_path = image_dir / img_name
    if not img_path.exists():
        # Label missing files explicitly instead of leaving an unexplained blank panel
        ax.set_title(f"{img_name}\n(image file not found)", fontsize=8)
        continue

    # Context manager closes the underlying file handle once the pixels are loaded
    with Image.open(img_path) as img_file:
        ax.imshow(img_file.convert('RGB'))

    # Title: image name followed by (at most) the first 3 captions, truncated for space
    img_captions = captions_df.loc[captions_df['image'] == img_name, 'caption'].tolist()
    title_lines = [img_name]
    for i, cap in enumerate(img_captions[:3], 1):
        preview = cap if len(cap) <= caption_preview_chars else cap[:caption_preview_chars] + "..."
        title_lines.append(f"{i}. {preview}")
    ax.set_title("\n".join(title_lines), fontsize=8)

plt.tight_layout()
plt.show()

# Print all captions for one image
print(f"\nAll captions for {random_images[0]}:")
for i, cap in enumerate(captions_df.loc[captions_df['image'] == random_images[0], 'caption'], 1):
    print(f"  {i}. {cap}")

## 3. Vocabulary Analysis

In [None]:
# Build the vocabulary from all parsed captions using the configured frequency threshold
# (how rare words are handled is decided inside src.dataset / src.vocab)
print("Building vocabulary...")
vocab = build_vocab_from_dataloader(captions_df, config["data"]["freq_threshold"])

print(f"\nVocabulary size: {len(vocab)}")
print(f"Frequency threshold: {vocab.freq_threshold}")
print("\nSpecial tokens:")
for special_token in ("<PAD>", "<SOS>", "<EOS>", "<UNK>"):
    print(f"  {special_token}: {vocab.stoi[special_token]}")

In [None]:
# Vocabulary coverage: share of caption tokens present in vocab.stoi
# (tokens missing from it are the ones that would end up as <UNK>).
# NOTE(review): tokens here come from caption.lower().split(); if Vocabulary tokenizes
# differently inside numericalize(), this estimate may not match what the model sees — confirm.
caption_tokens = [caption.lower().split() for caption in captions_df['caption']]
total_tokens = sum(len(tokens) for tokens in caption_tokens)
unk_count = sum(1 for tokens in caption_tokens for token in tokens if token not in vocab.stoi)

coverage = (1 - unk_count / total_tokens) * 100

print("Vocabulary Coverage:")
print(f"  Total tokens: {total_tokens}")
print(f"  UNK tokens: {unk_count}")
print(f"  Coverage: {coverage:.2f}%")

In [None]:
# Round-trip a few hand-written captions through the vocabulary.
# NOTE(review): "Match" can be False if denumericalize() keeps special/<UNK> tokens
# or normalizes text differently — check Vocabulary before reading it as a bug.
test_captions = [
    "a dog is running in the park",
    "two people walking on the beach",
    "a cat sitting on a couch",
]

print("Numericalize/Denumericalize Examples:")
print("=" * 80)

for caption in test_captions:
    indices = vocab.numericalize(caption)
    reconstructed = vocab.denumericalize(indices)
    report_lines = (
        f"\nOriginal:      {caption}",
        f"Indices:       {indices}",
        f"Reconstructed: {reconstructed}",
        f"Match: {caption == reconstructed}",
    )
    print("\n".join(report_lines))

## 4. Dataset Verification

In [None]:
# Build train/val/test loaders from the config and the vocabulary created above
print("Creating dataloaders...")
train_loader, val_loader, test_loader = get_dataloaders(config, vocab)

print("\nDataLoader Statistics:")
for split_label, split_loader in (("Train", train_loader), ("Val", val_loader), ("Test", test_loader)):
    print(f"  {split_label} batches: {len(split_loader)}")
print(f"  Batch size: {config['data']['batch_size']}")

In [None]:
# Pull a single training batch to sanity-check tensor shapes and pixel statistics
train_iter = iter(train_loader)
images, captions, image_names = next(train_iter)

print("Batch shapes:")
print(f"  Images: {images.shape}  # (batch, channels, height, width)")
print(f"  Captions: {captions.shape}  # (seq_len, batch) - time-first format")
print(f"  Image names: {len(image_names)}")

pixel_stats = {
    "Mean": images.mean(),
    "Std": images.std(),
    "Min": images.min(),
    "Max": images.max(),
}
print("\nImage statistics:")
for stat_label, stat_value in pixel_stats.items():
    print(f"  {stat_label}: {stat_value:.3f}")
print("  (Images are normalized with ImageNet stats)")

In [None]:
# Visualize batch
# ImageNet per-channel mean/std, shaped (3, 1, 1) so they broadcast over
# (C, H, W) and (B, C, H, W) image tensors.
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

def denormalize(img_tensor):
    """Undo ImageNet normalization so an image tensor can be displayed.

    Args:
        img_tensor: normalized image tensor with channels in dim -3, e.g. (3, H, W)
            or (B, 3, H, W). May live on any device.

    Returns:
        Tensor of the broadcast shape with values clamped to [0, 1], on the same
        device as ``img_tensor``.
    """
    # Move the stats to the input's device so GPU tensors (e.g. a batch already
    # sent to the model) can be denormalized without a device-mismatch error.
    channel_mean = mean.to(img_tensor.device)
    channel_std = std.to(img_tensor.device)
    img = img_tensor * channel_std + channel_mean
    return torch.clamp(img, 0, 1)

# Display up to 4 images from the batch with their decoded captions
num_to_show = min(4, images.size(0))  # guard against batches smaller than 4
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
axes = axes.ravel()

for idx in range(num_to_show):
    # Denormalize, (C, H, W) -> (H, W, C), and ensure CPU memory before .numpy()
    img = denormalize(images[idx]).permute(1, 2, 0).cpu().numpy()

    # Captions are time-first (seq_len, batch): column idx is this sample's token sequence
    caption_indices = captions[:, idx].tolist()
    caption_text = vocab.denumericalize(caption_indices)

    axes[idx].imshow(img)
    axes[idx].axis('off')
    axes[idx].set_title(f"{image_names[idx]}\n{caption_text}", fontsize=10)

# Hide any panels left empty when the batch holds fewer than 4 images
for unused_ax in axes[num_to_show:]:
    unused_ax.axis('off')

plt.tight_layout()
plt.show()

## 5. Model Summary

In [None]:
# Instantiate the captioning model from the config and place it on the chosen device
print("Creating model...")
model = ImageCaptioningModel.create_from_config(config, vocab_size=len(vocab)).to(device)

print("\nModel created successfully!")
print("Device:", device)

In [None]:
# Human-readable summary of the architecture hyperparameters taken from the config.
# NOTE(review): the "(batch, 49, ...)" / "7x7" encoder output is hard-coded text, not
# read from the model; it presumably assumes a 7x7 final feature map — confirm for the
# configured backbone and input size.
model_cfg = config['model']
separator = "=" * 80
print(separator)
print("Model Architecture")
print(separator)

print("\nEncoder (CNN):")
print(f"  Backbone: {model_cfg['encoder_backbone']}")
print(f"  Output: (batch, 49, {model_cfg['embed_size']})")
print("  Features from: 7x7 spatial grid")

print("\nDecoder (Transformer):")
decoder_fields = (
    ("Embed size", "embed_size"),
    ("Num heads", "num_heads"),
    ("Num layers", "num_layers"),
    ("Dropout", "dropout"),
)
for field_label, field_key in decoder_fields:
    print(f"  {field_label}: {model_cfg[field_key]}")
print(f"  Vocab size: {len(vocab)}")

In [None]:
# Count parameters overall, split by trainability, and per top-level component
param_info = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in param_info)
trainable_params = sum(count for count, needs_grad in param_info if needs_grad)
frozen_params = total_params - trainable_params

encoder_params = sum(p.numel() for p in model.encoder.parameters())
decoder_params = sum(p.numel() for p in model.decoder.parameters())

trainable_pct = trainable_params / total_params * 100
frozen_pct = frozen_params / total_params * 100

separator = "=" * 80
print(separator)
print("Parameter Counts")
print(separator)

print(f"\nTotal parameters: {total_params:,}")
print(f"  Trainable: {trainable_params:,} ({trainable_pct:.1f}%)")
print(f"  Frozen: {frozen_params:,} ({frozen_pct:.1f}%)")

print("\nBy component:")
print(f"  Encoder: {encoder_params:,}")
print(f"  Decoder: {decoder_params:,}")

print(f"\nNote: Encoder is initially frozen and will be unfrozen at epoch {config['training']['unfreeze_encoder_epoch']}")

## 6. Training

**Note:** Training is slow, especially without a GPU. For a quick smoke test, lower the number of training epochs in `configs/config.yaml` before running the next cell.

In [None]:
# Run the full training loop from src.train; runtime depends heavily on hardware.
# NOTE(review): the checkpoint path printed below is a fixed string — the actual
# location is whatever train() returns (shown in the next cell).
print("Starting training...", "This will save checkpoints to: checkpoints/best_model.pt", "", sep="\n")

best_model_path = train("../configs/config.yaml")

In [None]:
# train() only hands back the best checkpoint path, not a loss history. To plot loss
# curves, extend train() to return per-epoch losses or log them to TensorBoard / W&B.
print("Training completed!", f"Best model saved to: {best_model_path}", sep="\n")

## 7. Verify Saved Artifacts

In [None]:
# Check saved artifacts.
# Prefer the checkpoint path returned by train() in the training cell, so this check
# looks where training actually wrote. Fall back to the default location when that
# cell was skipped in this session (e.g. inspecting an earlier run).
_trained_path = globals().get("best_model_path")
best_model = Path(_trained_path) if _trained_path else Path("../checkpoints/best_model.pt")
checkpoint_dir = best_model.parent
# NOTE(review): assumes the vocabulary is saved next to the best checkpoint — confirm in src.train.
vocab_file = checkpoint_dir / "vocab.pkl"

print("=" * 80)
print("Saved Artifacts")
print("=" * 80)

if best_model.exists():
    size_mb = best_model.stat().st_size / (1024 * 1024)
    print(f"✓ Best model: {best_model}")
    print(f"  Size: {size_mb:.2f} MB")
else:
    print(f"✗ Best model not found: {best_model}")

if vocab_file.exists():
    size_kb = vocab_file.stat().st_size / 1024
    print(f"\n✓ Vocabulary: {vocab_file}")
    print(f"  Size: {size_kb:.2f} KB")
else:
    print(f"\n✗ Vocabulary not found: {vocab_file}")

print("\n" + "=" * 80)
print("Next steps:")
print("  1. Run inference.ipynb to test the model")
print("  2. Run: python -m src.evaluate to get BLEU scores")
print("  3. Visualize results in the inference notebook")
print("=" * 80)