# 🖼️ Image Captioning 

This notebook demonstrates the end-to-end workflow for the CNN + Attention LSTM image captioning project contained in this repository.

It will: 
1. Inspect environment & dependencies
2. (Optional) Locate an existing local dataset (Kaggle mini COCO) or fall back to a tiny synthetic example
3. Build / load vocabulary
4. Construct the encoder-decoder model and load an existing checkpoint if available
5. (Optional) Run a super-light illustrative training micro-step (on 1–2 batches)
6. Run inference to generate captions for sample images
7. Provide next steps & improvement ideas

> Designed to be resilient: if the full dataset isn't present or Kaggle credentials aren't configured, the notebook still runs using a synthetic mini dataset so you can exercise the pipeline quickly.

## 1. Environment & Imports

In [None]:
import os, sys, json, random, math, time, pathlib
from pathlib import Path

import torch
import torchvision
from PIL import Image

print('Python:', sys.version)
print('Torch version:', torch.__version__)
print('Torchvision version:', torchvision.__version__)
print('CUDA available:', torch.cuda.is_available())
print('MPS available:', torch.backends.mps.is_available() if hasattr(torch.backends, 'mps') else False)

# Add project root to path (in case notebook run from a different working dir)
PROJECT_ROOT = Path(__file__).resolve().parent if '__file__' in globals() else Path.cwd()
if (PROJECT_ROOT / 'image_captioning').exists():
    sys.path.append(str(PROJECT_ROOT))
print('Project root:', PROJECT_ROOT)

## 2. Download Dataset
This project expects a COCO-style (or simplified list) captions JSON and an `images/` directory.

If you already trained via `scripts/train.py`, artifacts should be present under `artifacts/`.

**Kaggle Download (Optional):** In a standard environment you would run the CLI script:
```bash
python scripts/train.py --epochs 1 --no-pretrained --batch-size 8
```
For portability (e.g. in hosted notebook environments without Kaggle credentials), we skip automatic download here and attempt to reuse existing local data.

In [3]:
if 'PROJECT_ROOT' not in globals():
    from pathlib import Path
    PROJECT_ROOT = Path.cwd()

DATA_ROOT_CANDIDATES = [
    PROJECT_ROOT / 'data',
    PROJECT_ROOT / 'dataset',
    PROJECT_ROOT / 'mini_coco',
    PROJECT_ROOT / 'kaggle',
] + [p for p in (PROJECT_ROOT).glob('*') if p.is_dir() and 'coco' in p.name.lower()]

captions_file = None
images_dir = None
for root in DATA_ROOT_CANDIDATES:
    if not root.exists():
        continue
    cand_caps = list(root.glob('**/captions*.json'))
    cand_imgs = [p for p in root.glob('**/images') if p.is_dir()]
    if cand_caps and cand_imgs:
        captions_file = cand_caps[0]
        images_dir = cand_imgs[0]
        break

if captions_file and images_dir:
    print('Found dataset:')
    print('  Captions JSON:', captions_file)
    print('  Images dir  :', images_dir)
else:
    print('No dataset found; will synthesize a tiny in-memory dataset for demonstration.')

No dataset found; will synthesize a tiny in-memory dataset for demonstration.
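
To make the two accepted layouts concrete, here is a minimal sketch of a normalizer that accepts either a COCO-style dict with an `annotations` list or a plain list of records. The helper name `extract_captions`, the sample records, and the `image` / `file_name` fields are illustrative assumptions, not part of the repository. Only the `caption` and `text` keys mirror what the vocabulary cell in section 3 looks for.

```python
import json

def extract_captions(payload, limit=None):
    """Return caption strings from a COCO-style dict or a plain list of records.

    Hypothetical helper for illustration; not part of the repository.
    """
    if isinstance(payload, dict) and 'annotations' in payload:
        annotations = payload['annotations']
    elif isinstance(payload, list):
        annotations = payload
    else:
        annotations = []
    captions = []
    for ann in annotations[:limit]:
        cap = ann.get('caption') or ann.get('text') or ''
        if cap:
            captions.append(cap)
    return captions

# COCO-style layout (sample data is made up)
coco_style = json.loads('''
{
  "images": [{"id": 1, "file_name": "000001.jpg"}],
  "annotations": [{"image_id": 1, "caption": "a dog playing with a ball"}]
}
''')

# Simplified layout: a flat list of records (field names besides caption/text are assumed)
simple_list = [
    {"image": "000002.jpg", "caption": "a cat sitting on the mat"},
    {"image": "000003.jpg", "text": "a child riding a red bicycle"},
    {"image": "000004.jpg"},  # no caption: skipped
]

print(extract_captions(coco_style))
print(extract_captions(simple_list))
```

Anything that is neither shape yields an empty list, which is the condition under which the notebook falls back to synthetic captions.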


## 3. Vocabulary Build / Load
If a vocabulary already exists in `artifacts/vocab.json` we load it. Otherwise we build one from either the located captions file or a synthetic fallback set.

In [None]:
from image_captioning.utils.tokenizer import Vocabulary, tokenize

artifacts_dir = PROJECT_ROOT / 'artifacts'
vocab_path = artifacts_dir / 'vocab.json'
vocab = None

def build_vocab_from_captions(captions_list, min_freq=1, max_size=10000):
    freqs = {}
    for c in captions_list:
        for tok in tokenize(c):
            freqs[tok] = freqs.get(tok, 0) + 1
    # Sort by frequency desc then alpha
    tokens = [t for t,_ in sorted(freqs.items(), key=lambda x: (-x[1], x[0])) if freqs[t] >= min_freq]
    if max_size: tokens = tokens[:max_size]
    return Vocabulary.from_tokens(tokens)

if vocab_path.exists():
    vocab = Vocabulary.load(vocab_path)
    print('Loaded existing vocabulary with size:', len(vocab))
else:
    captions_list = []
    if 'captions_file' in globals() and captions_file and captions_file.exists():
        try:
            payload = json.loads(captions_file.read_text())
            if isinstance(payload, dict) and 'annotations' in payload:
                annotations = payload['annotations']
            elif isinstance(payload, list):
                annotations = payload
            else:
                annotations = []
            for ann in annotations[:2000]:  # sample subset
                cap = ann.get('caption') or ann.get('text') or ''
                if cap:
                    captions_list.append(cap)
        except Exception as e:
            print('Failed to parse real captions, fallback to synthetic. Error:', e)
    if not captions_list:
        captions_list = [
            'a dog playing with a ball',
            'a cat sitting on the mat',
            'a child riding a red bicycle',
            'a group of people hiking a mountain trail',
            'a plate of fresh colorful fruit'
        ]
    vocab = build_vocab_from_captions(captions_list)
    artifacts_dir.mkdir(exist_ok=True, parents=True)
    vocab.save(vocab_path)
    print('Built vocabulary size:', len(vocab))

# Quick inspection
print('Sample tokens:', list(vocab.token_to_idx.keys())[:20])
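
The project's `Vocabulary` class is not shown in this notebook, so the following is only a sketch of the general idea: count tokens, order them by frequency and then alphabetically, and map them to indices after a few reserved special tokens. `ToyVocabulary`, `simple_tokenize`, and the special tokens `<start>`, `<end>`, `<unk>` are assumptions made for illustration (only `<pad>` is referenced elsewhere in this notebook).

```python
from collections import Counter

# Assumed special tokens; only '<pad>' is referenced elsewhere in this notebook.
SPECIALS = ['<pad>', '<start>', '<end>', '<unk>']

def simple_tokenize(text):
    """Lowercase whitespace tokenizer (a stand-in for the project's tokenize)."""
    return text.lower().split()

class ToyVocabulary:
    """Illustrative vocabulary; the real Vocabulary class may differ."""

    def __init__(self, tokens):
        self.idx_to_token = list(SPECIALS) + [t for t in tokens if t not in SPECIALS]
        self.token_to_idx = {t: i for i, t in enumerate(self.idx_to_token)}

    def __len__(self):
        return len(self.idx_to_token)

    def encode(self, text):
        # Unknown words map to <unk>; the sequence is wrapped in <start> ... <end>
        unk = self.token_to_idx['<unk>']
        ids = [self.token_to_idx.get(t, unk) for t in simple_tokenize(text)]
        return [self.token_to_idx['<start>']] + ids + [self.token_to_idx['<end>']]

    def decode(self, ids):
        # Drop <pad>, <start>, <end>; keep <unk> so gaps stay visible
        hidden = set(SPECIALS[:3])
        return ' '.join(self.idx_to_token[i] for i in ids if self.idx_to_token[i] not in hidden)

def build_toy_vocab(captions, min_freq=1, max_size=None):
    freqs = Counter(tok for c in captions for tok in simple_tokenize(c))
    ordered = sorted(freqs.items(), key=lambda kv: (-kv[1], kv[0]))
    tokens = [t for t, n in ordered if n >= min_freq]
    if max_size:
        tokens = tokens[:max_size]
    return ToyVocabulary(tokens)

toy_vocab = build_toy_vocab(['a dog playing with a ball', 'a cat sitting on the mat'])
ids = toy_vocab.encode('a dog on the sofa')
print(ids)
print(toy_vocab.decode(ids))
```

Raising `min_freq` trades coverage for a smaller output layer: rare words collapse into `<unk>`.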

## 4. Create / Load Model

In [None]:
from image_captioning.config import ModelConfig
from image_captioning.models.encoder_decoder import EncoderDecoder
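# Guard the MPS check so older torch builds without torch.backends.mps do not raise
mps_ok = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
device = torch.device('cuda' if torch.cuda.is_available() else ('mps' if mps_ok else 'cpu'))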
model_cfg = ModelConfig(vocab_size=len(vocab), embed_dim=256, hidden_dim=512, attention_dim=256, num_layers=1, dropout=0.1, use_pretrained=False)
model = EncoderDecoder(model_cfg, vocab).to(device)

checkpoint_dir = artifacts_dir / 'checkpoints'
loaded_checkpoint = None
if checkpoint_dir.exists():
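    # Pick the latest epoch numerically (a plain sort would put epoch 10 before epoch 2)
    def _epoch_num(p):
        digits = ''.join(ch for ch in p.stem.split('_')[-1] if ch.isdigit())
        return int(digits) if digits else -1
    checkpoints = sorted(checkpoint_dir.glob('model_epoch_*.pt'), key=_epoch_num)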
    if checkpoints:
        latest = checkpoints[-1]
        try:
            state = torch.load(latest, map_location=device)
            if isinstance(state, dict) and 'model_state' in state:
                model.load_state_dict(state['model_state'])
            else:
                model.load_state_dict(state)
            loaded_checkpoint = latest
            print(f'Loaded checkpoint: {latest.name}')
        except Exception as e:
            print('Could not load checkpoint:', e)
else:
    print('No checkpoints directory found; using fresh model.')

print('Model parameters:', sum(p.numel() for p in model.parameters())//1000, 'K')
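
Because names like `model_epoch_10.pt` sort before `model_epoch_2.pt` as strings, picking the latest epoch is easiest with an explicit numeric key. A minimal sketch, assuming the `model_epoch_<N>.pt` naming used in the glob above. `latest_checkpoint` is a hypothetical helper, not a function from the repository.

```python
import re
from pathlib import Path

_EPOCH_RE = re.compile(r'model_epoch_(\d+)\.pt$')

def latest_checkpoint(paths):
    """Return the path with the highest epoch number, or None if nothing matches."""
    numbered = []
    for p in paths:
        m = _EPOCH_RE.search(Path(p).name)
        if m:
            numbered.append((int(m.group(1)), Path(p)))
    return max(numbered, key=lambda pair: pair[0])[1] if numbered else None

names = [Path('model_epoch_2.pt'), Path('model_epoch_10.pt'), Path('model_epoch_9.pt')]
print(sorted(names)[-1].name)          # lexicographic order picks epoch 9
print(latest_checkpoint(names).name)   # numeric order picks epoch 10
```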

## 5. Micro Training (Illustrative)
Uses the real dataset when one was located; otherwise falls back to a handful of synthetic solid-color images. Trains on at most 2 batches, only to demonstrate the API (it won't produce quality captions).

In [None]:
from torch import nn
from torch.utils.data import DataLoader
import torchvision.transforms as T
from image_captioning.data.dataset import CocoCaptionsDataset, CaptionExample, collate_fn

def build_demo_dataloader(max_items=32):
    if captions_file and images_dir and captions_file.exists() and images_dir.exists():
        try:
            ds = CocoCaptionsDataset(captions_file=str(captions_file), images_root=str(images_dir), vocab=vocab, transform=T.Compose([T.Resize((224,224)), T.ToTensor(), T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])]))
            if len(ds) > max_items:
                # Subsample deterministically: keep the first max_items examples
                ds.examples = ds.examples[:max_items]
            return DataLoader(ds, batch_size=4, shuffle=True, collate_fn=collate_fn)
        except Exception as e:
            print('Failed constructing real dataloader:', e)
    # Fallback synthetic dataset (solid color images in-memory)
    examples = []
    colors = [(255,0,0), (0,255,0), (0,0,255), (200,200,0)]
    captions = ['a red square', 'a green square', 'a blue square', 'a yellow square']
    for i,(c,cap) in enumerate(zip(colors, captions)):
        img_path = artifacts_dir / f'synthetic_{i}.png'
        if not img_path.exists():
            im = Image.new('RGB', (224,224), c)
            im.save(img_path)
        examples.append(CaptionExample(str(img_path), cap))
    ds = CocoCaptionsDataset(captions_file=None, images_root=str(artifacts_dir), vocab=vocab, transform=T.Compose([T.Resize((224,224)), T.ToTensor(), T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])]))
    ds.examples = examples
    return DataLoader(ds, batch_size=2, shuffle=True, collate_fn=collate_fn)

micro_loader = build_demo_dataloader()
criterion = nn.CrossEntropyLoss(ignore_index=vocab.token_to_idx.get('<pad>', 0))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

model.train()
max_steps = 2
step = 0
losses = []
for batch in micro_loader:
    images = batch['images'].to(device)
    captions = batch['captions'].to(device)
    targets = captions[:,1:]  # shift
    optimizer.zero_grad()
    outputs = model(images, captions[:,:-1])  # predict next token
    outputs = outputs.reshape(-1, outputs.size(-1))
    loss = criterion(outputs, targets.reshape(-1))
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)
    optimizer.step()
    losses.append(loss.item())
    step += 1
    print(f'Step {step} loss: {loss.item():.3f}')
    if step >= max_steps: break
# Guard against an empty loader so the average does not divide by zero
print('Micro training complete. Avg loss:', sum(losses)/len(losses) if losses else float('nan'))
model.eval()
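
The loop above uses teacher forcing: the decoder sees `captions[:, :-1]` and is scored against `captions[:, 1:]`, with padded positions excluded from the loss. The same idea in plain Python, as a sketch only. The token ids, the padding index `0`, and both helper names are assumptions chosen for illustration.

```python
import math

PAD = 0  # assumed padding index

def shift_for_teacher_forcing(seq):
    """Split one token-id sequence into decoder inputs and next-token targets."""
    return seq[:-1], seq[1:]

def masked_mean_nll(step_probs, targets, pad_idx=PAD):
    """Average negative log-likelihood over non-pad targets.

    step_probs[t] is a list of probabilities over the vocabulary at step t.
    """
    total, count = 0.0, 0
    for probs, tgt in zip(step_probs, targets):
        if tgt == pad_idx:
            continue  # padded positions contribute nothing, like ignore_index
        total += -math.log(probs[tgt])
        count += 1
    return total / count if count else 0.0

# <start>=1, <end>=2, words 3 and 4, followed by one pad (ids are made up)
caption = [1, 3, 4, 2, 0]
inputs, targets = shift_for_teacher_forcing(caption)
print(inputs)   # [1, 3, 4, 2]
print(targets)  # [3, 4, 2, 0]

# A uniform prediction over a 5-token vocabulary at every step
uniform = [[0.2] * 5 for _ in targets]
print(round(masked_mean_nll(uniform, targets), 4))
```

With a uniform prediction the loss equals `log(vocab_size)`, which is a useful sanity check for the first step of an untrained model.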

## 6. Inference / Caption Generation

In [None]:
from image_captioning.inference.service import CaptionGenerator
import matplotlib.pyplot as plt

generator = CaptionGenerator(model=model, vocab=vocab, device=device)

# Pick an image: prefer real dataset image, else synthetic
candidate_images = []
if 'images_dir' in globals() and images_dir and images_dir.exists():
    candidate_images = list(images_dir.glob('*.jpg')) + list(images_dir.glob('*.png'))
if not candidate_images:
    candidate_images = list(artifacts_dir.glob('synthetic_*.png'))
assert candidate_images, 'No images found to caption.'
sample_image = candidate_images[0]
print('Using image:', sample_image)
caption = generator.caption_image(str(sample_image), max_len=20)
print('Generated caption:', caption)

# Display image & caption
img_display = Image.open(sample_image).convert('RGB')
plt.figure(figsize=(4,4))
plt.imshow(img_display)
plt.axis('off')
plt.title(caption)
plt.show()
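
How `CaptionGenerator.caption_image` decodes internally is not shown here. A common baseline is greedy decoding, sketched below under that assumption: at each step take the most probable next token and stop at the end token or at `max_len`. The toy table stands in for the model and its numbers are made up.

```python
def greedy_decode(step_fn, start_token, end_token, max_len=20):
    """Repeatedly append the most probable next token until end_token or max_len."""
    tokens = [start_token]
    for _ in range(max_len):
        probs = step_fn(tokens)               # mapping: token -> probability
        next_token = max(probs, key=probs.get)
        if next_token == end_token:
            break
        tokens.append(next_token)
    return tokens[1:]  # drop the start token

# Toy "model": a fixed next-token table keyed by the previous token (made-up numbers)
GREEDY_TABLE = {
    '<start>': {'a': 0.9, 'red': 0.1},
    'a': {'red': 0.6, 'blue': 0.4},
    'red': {'square': 0.8, '<end>': 0.2},
    'square': {'<end>': 1.0},
}

def toy_step(prefix):
    return GREEDY_TABLE[prefix[-1]]

print(' '.join(greedy_decode(toy_step, '<start>', '<end>')))
```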

## 7. Next Steps & Improvements
**Potential Enhancements:**
- Implement beam search decoding for higher-quality captions.
- Add evaluation metrics (BLEU, CIDEr, ROUGE, METEOR).
- Increase training epochs & optionally enable pretrained encoder weights.
- Add scheduled sampling or label smoothing to improve robustness.
- Visualize attention maps over image regions per generated token.
- Experiment with transformer-based decoder for comparison.
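
As a starting point for the first item, here is a minimal beam search sketch over a toy next-token table. It is not tied to this project's model API: `step_fn`, the table, and the token names are hypothetical, and the scoring is a plain sum of log-probabilities without length normalization.

```python
import math

def beam_search(step_fn, start_token, end_token, beam_width=2, max_len=20):
    """Keep the beam_width best partial sequences by summed log-probability."""
    beams = [([start_token], 0.0)]
    finished = []
    for _ in range(max_len):
        candidates = []
        for tokens, score in beams:
            for tok, p in step_fn(tokens).items():
                if p <= 0.0:
                    continue
                candidates.append((tokens + [tok], score + math.log(p)))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates[:beam_width]:
            if tokens[-1] == end_token:
                finished.append((tokens, score))
            else:
                beams.append((tokens, score))
        if not beams:
            break
    finished.extend(beams)  # sequences cut off at max_len still compete
    if not finished:
        return []
    best_tokens, _ = max(finished, key=lambda c: c[1])
    return [t for t in best_tokens if t not in (start_token, end_token)]

# Toy table (made-up numbers) where the locally best first word is not globally best
BEAM_TABLE = {
    '<start>': {'a': 0.6, 'the': 0.4},
    'a': {'dog': 0.5, 'cat': 0.5},
    'the': {'dog': 0.9, 'cat': 0.1},
    'dog': {'<end>': 1.0},
    'cat': {'<end>': 1.0},
}

def beam_step(prefix):
    return BEAM_TABLE[prefix[-1]]

print(beam_search(beam_step, '<start>', '<end>', beam_width=1))  # behaves like greedy
print(beam_search(beam_step, '<start>', '<end>', beam_width=2))
```

With `beam_width=1` the search commits to `a` first and ends at probability 0.3, while a width of 2 keeps `the` alive and finds the 0.36 sequence. Real implementations usually add length normalization so short captions are not unduly favored.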


