In [2]:
import os, sys, math, random, time, zipfile
from collections import Counter, defaultdict
from pathlib import Path
from typing import List, Tuple

import numpy as np
from PIL import Image
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models

import nltk
nltk.download('punkt')
from nltk.tokenize import word_tokenize
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

import warnings
warnings.filterwarnings("ignore")


  from .autonotebook import tqdm as notebook_tqdm


Device: cuda


[nltk_data] Downloading package punkt to C:\Users\Parshuram
[nltk_data]     Singh\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [3]:
zip_path = "flicker-8k-image-caption.zip"
extract_to = "flickr8k"

# If you already have files in Colab, skip extraction
if os.path.exists(extract_to):
    print("Found extracted folder:", extract_to)
else:
    if os.path.exists(zip_path):
        with zipfile.ZipFile(zip_path, 'r') as z:
            z.extractall(extract_to)
        print("Extracted zip to", extract_to)
    else:
        print("No zip found at", zip_path, "-- please upload or mount Drive.")

# Typical expected structure:
images_dir = os.path.join(extract_to, "Flicker8k_Dataset")  # common name
# also possible: /content/flickr8k/flickr8k/images or /content/flickr8k/images
if not os.path.exists(images_dir):
    # try other common locations
    for candidate in [os.path.join(extract_to, "images"),
                      os.path.join(extract_to, "flickr8k", "images"),
                      os.path.join(extract_to, "flickr8k","Flickr8k_Dataset")]:
        if os.path.exists(candidate):
            images_dir = candidate
            break

print("Images dir:", images_dir)
# captions file
possible_caps = [
    os.path.join(extract_to, "captions.txt"),
    os.path.join(extract_to, "Flickr8k_text", "Flickr8k.token.txt"),
    os.path.join(extract_to, "flickr8k", "captions.txt"),
    os.path.join(extract_to, "Flickr8k.token.txt")
]
captions_file = None
for p in possible_caps:
    if p and os.path.exists(p):
        captions_file = p
        break
print("Captions file:", captions_file)
if captions_file is None:
    raise FileNotFoundError("Could not locate captions file; please adjust paths.")


Found extracted folder: flickr8k
Images dir: flickr8k\flickr8k\images
Captions file: flickr8k\flickr8k\captions.txt


In [4]:
import nltk
nltk.download('punkt')
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt to C:\Users\Parshuram
[nltk_data]     Singh\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to C:\Users\Parshuram
[nltk_data]     Singh\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [5]:
def clean_caption(s: str) -> str:
    """Clean caption text - NO special tokens added here"""
    s = s.lower().strip()
    allowed = set("abcdefghijklmnopqrstuvwxyz0123456789 '")
    s = ''.join(ch if ch in allowed else ' ' for ch in s)
    s = ' '.join(s.split())
    return s

# Parse captions - DON'T add special tokens yet
image2caps = defaultdict(list)
with open(captions_file, 'r', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        if '\t' in line:
            left, cap = line.split('\t', 1)
        else:
            parts = line.split()
            left = parts[0]
            cap = ' '.join(parts[1:])
        img = left.split('#')[0]
        cap = clean_caption(cap)
        #  Store clean caption WITHOUT special tokens
        image2caps[img].append(cap)

print(f"Loaded captions for {len(image2caps)} images")

# FIXED Vocabulary class
class Vocabulary:
    def __init__(self, min_freq=2, max_size=10000):  # Increased from 8000
        self.min_freq = min_freq
        self.max_size = max_size
        self.counter = Counter()
        self.itos = []
        self.stoi = {}
        self.specials = ["<pad>", "<start>", "<end>", "<unk>"]

    def build(self, captions: List[str]):
        for c in captions:
            toks = word_tokenize(c.lower())
            self.counter.update(toks)

        # Get most common words (excluding very rare ones)
        freq_items = [
            w for w, cnt in self.counter.most_common(self.max_size)
            if cnt >= self.min_freq
        ]

        self.itos = list(self.specials) + freq_items
        self.stoi = {w: i for i, w in enumerate(self.itos)}

        # Print vocabulary statistics
        print(f"\nüìä Vocabulary Statistics:")
        print(f"  Total unique words in dataset: {len(self.counter)}")
        print(f"  Words included in vocab: {len(self.itos)}")
        print(f"  Words with freq >= {self.min_freq}: {len(freq_items)}")
        print(f"  Most common words: {self.itos[4:14]}")  # Skip special tokens

        # Check for important words
        important_words = ['wearing', 'walking', 'standing', 'holding', 'uniform',
                          'climbing', 'running', 'sitting', 'playing', 'looking']
        missing = [w for w in important_words if w not in self.stoi]
        if missing:
            print(f"  ‚ö†Ô∏è Missing important words: {missing}")

    def __len__(self):
        return len(self.itos)

# Build vocab from clean captions (no special tokens in text)
all_caps = []
for caps in image2caps.values():
    all_caps.extend(caps)

vocab = Vocabulary(min_freq=2, max_size=12000)
vocab.build(all_caps)

Loaded captions for 21672 images

üìä Vocabulary Statistics:
  Total unique words in dataset: 8411
  Words included in vocab: 5117
  Words with freq >= 2: 5113
  Most common words: ['a', 'in', 'the', 'on', 'is', 'and', 'dog', 'with', 'man', 'of']


In [6]:
def diagnose_vocabulary(vocab, val_ds, n_samples=100):
    """Check which words are being mapped to <unk>"""
    print("\nüîç Vocabulary Diagnosis:\n")

    missing_words = Counter()
    total_words = 0

    # Sample captions and check coverage
    for i in range(min(n_samples, len(val_ds))):
        try:
            if isinstance(val_ds, torch.utils.data.Subset):
                idx = val_ds.indices[i]
                imgpath, cap = val_ds.dataset.samples[idx]
            elif hasattr(val_ds, 'samples'):
                imgpath, cap = val_ds.samples[i]
            else:
                imgpath, cap = val_ds.dataset.samples[i]

            if isinstance(cap, str):
                words = word_tokenize(cap.lower())
            else:
                continue

            for word in words:
                total_words += 1
                if word not in vocab.stoi:
                    missing_words[word] += 1
        except:
            continue

    coverage = (total_words - sum(missing_words.values())) / total_words * 100

    print(f"Vocabulary Coverage: {coverage:.2f}%")
    print(f"Total words checked: {total_words}")
    print(f"Unique missing words: {len(missing_words)}")

    if missing_words:
        print(f"\nTop 20 missing words:")
        for word, count in missing_words.most_common(20):
            print(f"  '{word}': {count} times")

    return coverage, missing_words

In [7]:
def caption_to_indices(caption: str, vocab: Vocabulary, max_len: int):
    """
    Convert caption to indices with proper special token handling
    Format: [<start>, word1, word2, ..., <end>, <pad>, <pad>, ...]
    """
    toks = word_tokenize(caption.lower())

    # Add start token, caption tokens, end token
    idxs = [vocab.stoi["<start>"]]
    for t in toks:
        idxs.append(vocab.stoi.get(t, vocab.stoi["<unk>"]))
    idxs.append(vocab.stoi["<end>"])

    # Pad or truncate
    if len(idxs) < max_len:
        idxs += [vocab.stoi["<pad>"]] * (max_len - len(idxs))
    else:
        idxs = idxs[:max_len-1] + [vocab.stoi["<end>"]]  # Ensure <end> token

    return idxs

# Determine max length
max_len = max(len(word_tokenize(c)) for c in all_caps) + 2  # +2 for <start> and <end>
MAX_LEN = min(max_len, 22)  # Cap at reasonable length
print(f"Using MAX_LEN = {MAX_LEN}")

Using MAX_LEN = 22


In [8]:

!ls /flickr8k
!ls /flickr8k/flickr8k/images | head -5
!ls /flickr8k/flickr8k/captions.txt


'ls' is not recognized as an internal or external command,
operable program or batch file.
'ls' is not recognized as an internal or external command,
operable program or batch file.
'ls' is not recognized as an internal or external command,
operable program or batch file.


In [9]:

import pandas as pd
import os

captions_file = "flickr8k/flickr8k/captions.txt"
images_dir = "flickr8k/flickr8k/images"

# Read CSV
df = pd.read_csv(captions_file)
print(df.head())

# Build dictionary {image_name: [caption1, caption2, ...]}
image2caps = {}
for _, row in df.iterrows():
    image_name = row["image"].strip()
    caption = row["caption"].strip()
    image2caps.setdefault(image_name, []).append(caption)

print(f" Loaded {len(image2caps)} unique images with captions.")

# Check a sample entry
sample_key = list(image2caps.keys())[0]
print(f"Sample image: {sample_key}")
print(f"Captions for it: {image2caps[sample_key]}")


                       image  \
0  1000268201_693b08cb0e.jpg   
1  1000268201_693b08cb0e.jpg   
2  1000268201_693b08cb0e.jpg   
3  1000268201_693b08cb0e.jpg   
4  1000268201_693b08cb0e.jpg   

                                             caption  
0  A child in a pink dress is climbing up a set o...  
1              A girl going into a wooden building .  
2   A little girl climbing into a wooden playhouse .  
3  A little girl climbing the stairs to her playh...  
4  A little girl in a pink dress going into a woo...  
 Loaded 8091 unique images with captions.
Sample image: 1000268201_693b08cb0e.jpg
Captions for it: ['A child in a pink dress is climbing up a set of stairs in an entry way .', 'A girl going into a wooden building .', 'A little girl climbing into a wooden playhouse .', 'A little girl climbing the stairs to her playhouse .', 'A little girl in a pink dress going into a wooden cabin .']


In [10]:
class CaptionDataset(Dataset):
    def __init__(self, image2caps, images_dir, vocab: Vocabulary, max_len=MAX_LEN, transform=None):
        self.vocab = vocab
        self.max_len = max_len
        self.transform = transform
        self.samples = []

        for img, caps in image2caps.items():
            imgpath = os.path.join(images_dir, img)
            if not os.path.exists(imgpath):
                continue
            # Store path and caption (will add special tokens during indexing)
            for c in caps:
                self.samples.append((imgpath, c))

        random.shuffle(self.samples)
        print(f"Dataset created with {len(self.samples)} samples")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        imgpath, caption = self.samples[idx]
        image = Image.open(imgpath).convert("RGB")
        if self.transform:
            image = self.transform(image)

        # Convert caption to indices (adds <start>, <end>, <pad>)
        caps_idx = caption_to_indices(caption, self.vocab, self.max_len)
        return image, torch.tensor(caps_idx, dtype=torch.long)

# transforms
train_transform = T.Compose([
    T.Resize((256,256)),
    T.RandomHorizontalFlip(),
    T.RandomCrop(224),
    T.ColorJitter(0.1,0.1,0.1),
    T.ToTensor(),
    T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])
val_transform = T.Compose([
    T.Resize((256,256)),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

# collate_fn to stack and keep lengths (we have fixed padded len, but keep mask)
def collate_fn(batch):
    images = torch.stack([b[0] for b in batch], dim=0)
    captions = torch.stack([b[1] for b in batch], dim=0)
    # create mask where pad token index = 0
    pad_idx = vocab.stoi["<pad>"]
    mask = (captions != pad_idx)
    return images, captions, mask

# create dataset and loaders
dataset = CaptionDataset(image2caps, images_dir, vocab, max_len=MAX_LEN, transform=train_transform)
# split
n = len(dataset)
train_n = int(0.8 * n)
train_ds, val_ds = torch.utils.data.random_split(dataset, [train_n, n-train_n])
# override val transform image preprocessing
val_ds.dataset.transform = val_transform

print("\n" + "="*60)
print("Running Vocabulary Diagnosis...")
print("="*60)

coverage, missing_words = diagnose_vocabulary(vocab, val_ds, n_samples=200)

if coverage < 95:
    print("\nWARNING: Vocabulary coverage is low!")
    print("Consider:")
    print("  - Increasing max_size to 12000")
    print("  - Decreasing min_freq to 1")
    print("  - Or accept that rare words will be <unk>")
else:
    print("\n Vocabulary coverage is good!")


BATCH_SIZE = 32
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn, num_workers=2, pin_memory=True)
print("Train samples:", len(train_ds), "Val samples:", len(val_ds))


Dataset created with 40455 samples

Running Vocabulary Diagnosis...

üîç Vocabulary Diagnosis:

Vocabulary Coverage: 90.10%
Total words checked: 2272
Unique missing words: 30

Top 20 missing words:
  '.': 180 times
  ',': 16 times
  '``': 2 times
  'clowds': 1 times
  'ankle-deep': 1 times
  'tanktops': 1 times
  'models': 1 times
  'megaphone': 1 times
  'african-american': 1 times
  'tundra': 1 times
  'crotch': 1 times
  'formula': 1 times
  'crooswalk': 1 times
  'multi-colored': 1 times
  'goofing': 1 times
  'democrat': 1 times
  'supporters': 1 times
  'fton': 1 times
  'abs': 1 times
  'gothically': 1 times

Consider:
  - Increasing max_size to 12000
  - Decreasing min_freq to 1
  - Or accept that rare words will be <unk>
Train samples: 32364 Val samples: 8091


In [11]:
class EncoderCNNSpatial(nn.Module):
    def __init__(self, embed_dim=512, pretrained=True, train_backbone=False):
        super().__init__()
        resnet = models.resnet101(pretrained=pretrained)
        # remove last fc & avgpool so we have spatial map after layer4
        modules = list(resnet.children())[:-2]  # up to conv5_x -> outputs (B, 2048, H/32, W/32)
        self.backbone = nn.Sequential(*modules)
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(2048, embed_dim)
        self.bn = nn.BatchNorm1d(embed_dim, momentum=0.01)

        # freeze backbone optionally
        if not train_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

    def forward(self, images):
        # images: (B,3,224,224)
        feat_map = self.backbone(images)  # (B, 2048, Hf, Wf)
        b, c, h, w = feat_map.shape
        # pooled vector
        pooled = self.avgpool(feat_map).view(b, c)
        embed = self.fc(pooled)  # (B, embed_dim)
        embed = self.bn(embed)
        # flatten spatial features for attention: (B, H*W, C)
        spatial = feat_map.view(b, c, -1).permute(0,2,1)  # (B, num_patches, C)
        return spatial, embed  # spatial used by attention, embed as initial hidden


In [12]:
class BahdanauAttention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        self.enc_att = nn.Linear(encoder_dim, attention_dim)  # transform encoder
        self.dec_att = nn.Linear(decoder_dim, attention_dim)  # transform decoder hidden
        self.full_att = nn.Linear(attention_dim, 1)

    def forward(self, encoder_feats, decoder_hidden, mask=None):
        # encoder_feats: (B, num_patches, encoder_dim)
        # decoder_hidden: (B, decoder_dim)
        enc_proj = self.enc_att(encoder_feats)  # (B, N, attn_dim)
        dec_proj = self.dec_att(decoder_hidden).unsqueeze(1)  # (B,1,attn_dim)
        e = torch.tanh(enc_proj + dec_proj)  # (B,N,attn_dim)
        scores = self.full_att(e).squeeze(-1)  # (B,N)
        if mask is not None:
            # mask shape for spatial features - usually None (images always full)
            scores = scores.masked_fill(~mask, -1e9)
        alpha = F.softmax(scores, dim=1)  # (B,N)
        context = (encoder_feats * alpha.unsqueeze(-1)).sum(dim=1)  # (B, encoder_dim)
        return context, alpha

class DecoderWithAttention(nn.Module):
    def __init__(self, embed_dim, decoder_dim, vocab_size, encoder_dim=512,
                 attention_dim=512, dropout=0.5, padding_idx=0):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim
        self.vocab_size = vocab_size

        self.attention = BahdanauAttention(encoder_dim, decoder_dim, attention_dim)
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        self.dropout = nn.Dropout(p=dropout)
        self.lstm = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim)

        # Initialize hidden state from encoder pooled features
        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)

        # Output layer
        self.fc = nn.Linear(decoder_dim, vocab_size)

        # Optional: attention gating
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)
        self.sigmoid = nn.Sigmoid()

    def init_hidden_state(self, encoder_pooled):
        """Initialize LSTM hidden state from encoder output"""
        h = torch.tanh(self.init_h(encoder_pooled))
        c = torch.tanh(self.init_c(encoder_pooled))
        return h, c

    def forward(self, encoder_feats, encoder_pooled, captions, teacher_forcing_ratio=1.0):
        """
        Forward pass with proper teacher forcing
        Args:
            encoder_feats: (B, num_pixels, encoder_dim)
            encoder_pooled: (B, encoder_dim)
            captions: (B, max_len) - includes <start> at position 0
            teacher_forcing_ratio: probability of using ground truth
        Returns:
            outputs: (B, max_len, vocab_size) - predictions for positions 1 to max_len
        """
        batch_size = encoder_feats.size(0)
        seq_length = captions.size(1)

        # Initialize hidden state
        h, c = self.init_hidden_state(encoder_pooled)

        # Prepare outputs tensor
        outputs = torch.zeros(batch_size, seq_length, self.vocab_size,
                             device=encoder_feats.device)

        # Get embeddings for all caption tokens
        embeddings = self.embedding(captions)  # (B, seq_len, embed_dim)

        # Start with <start> token
        current_input = embeddings[:, 0, :]  # (B, embed_dim)

        # Generate predictions for positions 1 to seq_length
        for t in range(1, seq_length):
            # Attention
            context, alpha = self.attention(encoder_feats, h)

            # Optional gating
            if hasattr(self, 'f_beta'):
                gate = self.sigmoid(self.f_beta(h))
                context = gate * context

            # LSTM step
            lstm_input = torch.cat([current_input, context], dim=1)
            h, c = self.lstm(lstm_input, (h, c))

            # Predict next word
            output = self.fc(self.dropout(h))  # (B, vocab_size)
            outputs[:, t, :] = output

            # Teacher forcing decision
            use_teacher = random.random() < teacher_forcing_ratio
            if use_teacher and t < seq_length - 1:
                # Use ground truth
                current_input = embeddings[:, t, :]
            else:
                # Use prediction
                predicted_token = output.argmax(dim=-1)
                current_input = self.embedding(predicted_token)

        return outputs

    def sample_beam(self, encoder_feats, encoder_pooled, sos_idx, eos_idx, beam_size=3, max_len=20):
        # Beam search implementation (batch size = 1 only for simplicity)
        assert encoder_feats.size(0) == 1, "Beam search currently supports batch_size=1"
        device = encoder_feats.device
        # Flatten input for convenience
        enc = encoder_feats  # (1,N,enc_dim)
        pooled = encoder_pooled  # (1,enc_dim)
        h, c = self.init_hidden_state(pooled)  # (1, dec_dim)
        # Each beam: (tokens, logprob, h, c)
        beams = [([sos_idx], 0.0, h, c)]
        completed = []
        for _ in range(max_len):
            new_beams = []
            for tokens, score, h_b, c_b in beams:
                if tokens[-1] == eos_idx:
                    completed.append((tokens, score))
                    continue
                # get last token embedding
                last_idx = torch.tensor([tokens[-1]], device=device)
                emb = self.embedding(last_idx)  # (1, emb)
                context, _ = self.attention(enc, h_b)
                lstm_input = torch.cat([emb.squeeze(0), context], dim=0).unsqueeze(0)  # (1, emb+enc)
                h_new, c_new = self.lstm(lstm_input, (h_b, c_b))
                out = F.log_softmax(self.fc(h_new), dim=-1).squeeze(0)  # (vocab,)
                topk_logprobs, topk_idx = torch.topk(out, beam_size)
                for k_logp, k_idx in zip(topk_logprobs.tolist(), topk_idx.tolist()):
                    new_tokens = tokens + [int(k_idx)]
                    new_score = score + float(k_logp)
                    new_beams.append((new_tokens, new_score, h_new, c_new))
            # keep top beams
            new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_size]
            beams = new_beams
            if len(beams) == 0:
                break
        completed.extend([(b[0], b[1]) for b in beams])
        completed = sorted(completed, key=lambda x: x[1], reverse=True)
        best_tokens = completed[0][0]
        return best_tokens


In [13]:
EMBED_DIM = 300   # word embed dim (we can load GloVe into this shape)
DECODER_DIM = 512
ATTN_DIM = 512
ENC_EMBED = 512  # embed size for pooled vector to decoder init

encoder = EncoderCNNSpatial(
    embed_dim=ENC_EMBED,
    pretrained=True,
    train_backbone=False
).to(device)

decoder = DecoderWithAttention(
    embed_dim=EMBED_DIM,
    decoder_dim=DECODER_DIM,
    vocab_size=len(vocab),
    encoder_dim=512,  #  Changed from 2048 to 512 (after projection)
    attention_dim=ATTN_DIM,
    padding_idx=vocab.stoi["<pad>"]
).to(device)

#  Projection layer: 2048 ‚Üí 512 (for encoder spatial features)
proj = nn.Linear(2048, 512).to(device)

# Optionally load pretrained GloVe into decoder.embedding here
# Loss: ignore pad
criterion = nn.CrossEntropyLoss(
    ignore_index=vocab.stoi["<pad>"],
    label_smoothing=0.1  #  Helps with overconfidence
)

#  Include all trainable parameters: decoder, encoder head, and projection
params = (
    list(decoder.parameters()) +
    list(encoder.fc.parameters()) +
    list(encoder.bn.parameters()) +
    list(proj.parameters())  #  Add projection layer
)

optimizer = torch.optim.AdamW(params, lr=3e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=2
)

total_params = sum(p.numel() for p in params if p.requires_grad)
print(f"Total trainable parameters: {total_params:,}")
print(f"  - Decoder: {sum(p.numel() for p in decoder.parameters() if p.requires_grad):,}")
print(f"  - Encoder head: {sum(p.numel() for p in encoder.fc.parameters() if p.requires_grad):,}")
print(f"  - Projection: {sum(p.numel() for p in proj.parameters() if p.requires_grad):,}")


Total trainable parameters: 10,288,762
  - Decoder: 8,189,562
  - Encoder head: 1,049,088
  - Projection: 1,049,088


In [None]:

from pathlib import Path

save_path = "best_caps.pt"
NUM_EPOCHS = 30
best_val_loss = float('inf')
patience = 5
patience_cnt = 0

def train_one_epoch(epoch, teacher_forcing_ratio):
    encoder.train()
    decoder.train()
    proj.train()
    running_loss = 0.0
    pbar = tqdm(train_loader, desc=f"Train E{epoch}")

    for batch_idx, (images, captions, mask) in enumerate(pbar):
        images, captions = images.to(device), captions.to(device)
        optimizer.zero_grad()

        # Encode
        encoder_feats, encoder_pooled = encoder(images)
        encoder_feats = proj(encoder_feats)  # Project to 512

        # Decode - captions include <start> at position 0
        logits = decoder(
            encoder_feats, encoder_pooled, captions,
            teacher_forcing_ratio=teacher_forcing_ratio
        )

        #  FIXED: Predictions are for positions 1 onwards
        # logits[:, t, :] predicts token at position t (given 0 to t-1)
        # So logits[:, 1, :] predicts position 1 (given <start>)
        # We want to predict positions 1 to end (skip <start>, predict everything else)

        # Shift: compare predictions at t with targets at t
        predictions = logits[:, 1:, :]  # (B, seq_len-1, vocab_size)
        targets = captions[:, 1:]       # (B, seq_len-1) - actual tokens at positions 1+

        # Loss - ignore padding
        loss = criterion(
            predictions.reshape(-1, predictions.size(-1)),
            targets.reshape(-1)
        )

        loss.backward()
        torch.nn.utils.clip_grad_norm_(params, 5.0)
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    return running_loss / len(train_loader.dataset)


def validate(epoch):
    encoder.eval()
    decoder.eval()
    proj.eval()  #  Set proj to eval mode
    val_loss = 0.0

    with torch.no_grad():
        for images, captions, mask in tqdm(val_loader, desc=f"Val E{epoch}"):
            images, captions = images.to(device), captions.to(device)

            # ---- Encoder ----
            encoder_feats, encoder_pooled = encoder(images)

            #  Project spatial features: 2048 ‚Üí 512
            encoder_feats = proj(encoder_feats)

            # ---- Decoder ----
            logits = decoder(
                encoder_feats,
                encoder_pooled,
                captions,
                teacher_forcing_ratio=0.0  # No teacher forcing in validation
            )

            # ---- Shift for next-word prediction ----
            logits = logits[:, 1:, :]
            targets = captions[:, 1:]

            loss = criterion(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1)
            )
            val_loss += loss.item() * images.size(0)

    return val_loss / len(val_loader.dataset)



print("Starting training...")
for epoch in range(1, NUM_EPOCHS + 1):
    # Linearly decay teacher forcing from 0.5 to 0.2
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)
    tfr = max(0.5 * (1 - (epoch - 1) / NUM_EPOCHS), 0.2)

    train_loss = train_one_epoch(epoch, teacher_forcing_ratio=tfr)
    val_loss = validate(epoch)

    print(f"Epoch {epoch}/{NUM_EPOCHS} | "
          f"TrainLoss={train_loss:.4f} | "
          f"ValLoss={val_loss:.4f} | "
          f"TFR={tfr:.3f}")

    # Step scheduler
    scheduler.step(val_loss)

    # Save best model
    if val_loss < best_val_loss - 1e-4:
        best_val_loss = val_loss
        patience_cnt = 0
        torch.save({
            'encoder': encoder.state_dict(),
            'decoder': decoder.state_dict(),
            'proj': proj.state_dict(),  #  Save projection layer
            'vocab_itos': vocab.itos,
            'vocab_stoi': vocab.stoi,
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'epoch': epoch,
            'best_val_loss': best_val_loss
        }, save_path)
        print(f" Saved best model (val_loss: {val_loss:.4f})")
    else:
        patience_cnt += 1
        print(f"‚è≥ No improvement for {patience_cnt}/{patience} epochs")
        if patience_cnt >= patience:
            print("‚èπÔ∏è Early stopping triggered.")
            break

print("\n" + "="*50)
print("Training completed!")
print(f"Best validation loss: {best_val_loss:.4f}")
print("="*50)


print("\nLoading best checkpoint...")
ck = torch.load(save_path, map_location=device)
encoder.load_state_dict(ck['encoder'])
decoder.load_state_dict(ck['decoder'])
proj.load_state_dict(ck['proj'])  #  Load projection layer
print(f" Loaded best checkpoint from epoch {ck['epoch']}")
print(f" Best validation loss: {ck['best_val_loss']:.4f}")

Starting training...
Device: cuda


Train E1:   0%|          | 0/1012 [00:00<?, ?it/s]

In [None]:

def decode_indices(idx_seq, itos):
    words = []
    for idx in idx_seq:
        token = itos[idx]
        if token == "<end>":
            break
        if token not in ("<start>", "<pad>"):
            words.append(token)
    return " ".join(words)


def generate_caption_greedy(image_pil, max_len=20):
    """Generate caption with <unk> filtering"""
    encoder.eval()
    decoder.eval()
    proj.eval()

    img_t = val_transform(image_pil).unsqueeze(0).to(device)

    with torch.no_grad():
        enc_feats, enc_pooled = encoder(img_t)
        enc_feats = proj(enc_feats)

        h, c = decoder.init_hidden_state(enc_pooled)
        input_idx = torch.tensor([vocab.stoi["<start>"]], device=device)
        generated = []
        unk_count = 0

        for step in range(max_len):
            emb = decoder.embedding(input_idx)
            context, _ = decoder.attention(enc_feats, h)

            if hasattr(decoder, 'f_beta'):
                gate = decoder.sigmoid(decoder.f_beta(h))
                context = gate * context

            lstm_input = torch.cat([emb, context], dim=1)
            h, c = decoder.lstm(lstm_input, (h, c))

            output = decoder.fc(h)

            #  Get top-5 predictions and avoid <unk> if possible
            topk_vals, topk_ids = output.topk(5, dim=-1)

            # Try to find first non-<unk> token
            pred_idx = topk_ids[0, 0].item()  # Default to top prediction
            for i in range(5):
                candidate = topk_ids[0, i].item()
                if vocab.itos[candidate] != "<unk>":
                    pred_idx = candidate
                    break

            # Stop if <end> or too many <unk>
            if pred_idx == vocab.stoi["<end>"]:
                break

            if vocab.itos[pred_idx] == "<unk>":
                unk_count += 1
                if unk_count >= 3:  # Stop if too many unknowns
                    break

            # Add to generated sequence (skip special tokens)
            if vocab.itos[pred_idx] not in ("<start>", "<pad>", "<unk>"):
                generated.append(vocab.itos[pred_idx])

            input_idx = torch.tensor([pred_idx], device=device)

        caption = " ".join(generated)
        #  Apply post-processing
        caption = post_process_caption(caption)
        return caption

def post_process_caption(caption: str) -> str:
    """Clean up generated caption"""
    # Remove multiple spaces
    caption = ' '.join(caption.split())

    # Remove <unk> tokens if they appear
    caption = caption.replace('<unk>', '').strip()
    caption = ' '.join(caption.split())  # Clean extra spaces again

    # Capitalize first letter
    if caption:
        caption = caption[0].upper() + caption[1:]

    # Add period if missing
    if caption and caption[-1] not in '.!?':
        caption += '.'

    return caption


In [None]:

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk.tokenize import word_tokenize
import random
import numpy as np

def evaluate_bleu(n_samples=200, beam_size=3, max_len=20):
    """
    Evaluate BLEU score on validation set samples
    """
    encoder.eval()
    decoder.eval()
    proj.eval()  #  Set projection to eval mode

    smooth = SmoothingFunction().method1
    scores = []

    # Sample random indices from val_ds
    total_samples = len(val_ds)
    idxs = random.sample(range(total_samples), min(n_samples, total_samples))

    for i in tqdm(idxs, desc="BLEU eval"):
        # Get image path and caption
        if isinstance(val_ds, torch.utils.data.Subset):
            imgpath, cap = val_ds.dataset.samples[val_ds.indices[i]]
        else:
            imgpath, cap = val_ds.samples[i] if hasattr(val_ds, 'samples') else val_ds.dataset.samples[i]

        # Get ground truth caption
        if isinstance(cap, str):
            gt = cap
        else:
            gt_tokens = [
                vocab.itos[idx]
                for idx in cap.tolist()
                if vocab.itos[idx] not in ("<pad>", "<start>", "<end>")
            ]
            gt = " ".join(gt_tokens)

        # Generate prediction
        img = Image.open(imgpath).convert("RGB")
        pred = generate_caption_greedy(img, beam_size=beam_size, max_len=max_len)

        # Calculate BLEU
        bleu = sentence_bleu(
            [word_tokenize(gt)],
            word_tokenize(pred),
            smoothing_function=smooth
        )
        scores.append(bleu)

    print(f"Avg BLEU: {np.mean(scores):.4f}")
    print(f"Median BLEU: {np.median(scores):.4f}")
    print(f"Std BLEU: {np.std(scores):.4f}")

    return scores


# ============================
#  Generate Caption with Beam Search (Fixed)
# ============================
def generate_caption_beam(encoder, decoder, image_path, vocab, beam_size=5, max_len=20):
    """Beam search with <unk> penalization"""
    encoder.eval()
    decoder.eval()
    proj.eval()

    image = Image.open(image_path).convert("RGB")
    transform = val_transform if 'val_transform' in globals() else train_transform
    image = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        encoder_out, encoder_pooled = encoder(image)
        encoder_out = proj(encoder_out)

        num_pixels = encoder_out.size(1)
        encoder_dim = encoder_out.size(2)

        encoder_out = encoder_out.expand(beam_size, num_pixels, encoder_dim)
        encoder_pooled = encoder_pooled.expand(beam_size, -1)

        k = beam_size
        vocab_size = len(vocab)
        start_idx = vocab.stoi["<start>"]
        end_idx = vocab.stoi["<end>"]
        unk_idx = vocab.stoi["<unk>"]

        seqs = torch.full((k, 1), start_idx, dtype=torch.long).to(device)
        top_k_scores = torch.zeros(k, 1).to(device)
        complete_seqs = []
        complete_seqs_scores = []

        h, c = decoder.init_hidden_state(encoder_pooled)

        for step in range(max_len):
            embeddings = decoder.embedding(seqs[:, -1])
            awe, _ = decoder.attention(encoder_out, h)

            if hasattr(decoder, 'f_beta'):
                gate = decoder.sigmoid(decoder.f_beta(h))
                awe = gate * awe

            lstm_input = torch.cat([embeddings, awe], dim=1)
            h, c = decoder.lstm(lstm_input, (h, c))

            scores = decoder.fc(h)
            scores = F.log_softmax(scores, dim=1)

            #  Penalize <unk> tokens
            scores[:, unk_idx] -= 2.0

            scores = top_k_scores.expand_as(scores) + scores

            if step == 0:
                top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)
            else:
                top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)

            prev_word_inds = top_k_words // vocab_size
            next_word_inds = top_k_words % vocab_size
            seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)

            incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds)
                              if next_word != end_idx]
            complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))

            if len(complete_inds) > 0:
                complete_seqs.extend(seqs[complete_inds].tolist())
                lengths = torch.tensor([len(seqs[i]) for i in complete_inds], device=device)
                normalized_scores = top_k_scores[complete_inds] / lengths.float()
                complete_seqs_scores.extend(normalized_scores.tolist())

            k -= len(complete_inds)
            if k == 0:
                break

            seqs = seqs[incomplete_inds]
            h = h[prev_word_inds[incomplete_inds]]
            c = c[prev_word_inds[incomplete_inds]]
            encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
            top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)

        if len(complete_seqs_scores) == 0:
            complete_seqs = seqs.tolist()
            lengths = torch.tensor([len(seq) for seq in complete_seqs], device=device)
            complete_seqs_scores = (top_k_scores.squeeze(1) / lengths.float()).tolist()

        i = complete_seqs_scores.index(max(complete_seqs_scores))
        best_seq = complete_seqs[i]

        caption = [
            vocab.itos[idx]
            for idx in best_seq
            if vocab.itos[idx] not in ("<start>", "<end>", "<pad>", "<unk>")
        ]

        caption = " ".join(caption)
        #  Apply post-processing
        caption = post_process_caption(caption)
        return caption

In [None]:
import matplotlib.pyplot as plt


In [None]:

from IPython.display import display

def show_val_examples(n=5, beam_size=5):
    """Display with improved captions"""
    encoder.eval()
    decoder.eval()
    proj.eval()

    for i in range(n):
        if isinstance(val_ds, torch.utils.data.Subset):
            idx = val_ds.indices[i]
            imgpath, cap = val_ds.dataset.samples[idx]
        elif hasattr(val_ds, 'samples'):
            imgpath, cap = val_ds.samples[i]
        else:
            imgpath, cap = val_ds.dataset.samples[i]

        if isinstance(cap, str):
            gt = cap
        else:
            gt_tokens = [
                vocab.itos[idx]
                for idx in cap.tolist()
                if vocab.itos[idx] not in ("<pad>", "<start>", "<end>")
            ]
            gt = " ".join(gt_tokens)

        img = Image.open(imgpath).convert("RGB")
        plt.figure(figsize=(8, 6))
        plt.imshow(img)
        plt.axis("off")

        with torch.no_grad():
            caption = generate_caption_beam(
                encoder, decoder, imgpath, vocab, beam_size=beam_size
            )
            # Caption is already post-processed in generate_caption_beam

        plt.title(f"Predicted: {caption}\n\nGround Truth: {gt}",
                  fontsize=12)
        plt.tight_layout()
        plt.show()



print("Generating captions for validation examples...")
show_val_examples(6, beam_size=5)

# Optionally run BLEU evaluation (uncomment to run)
# print("\nEvaluating BLEU scores...")
# bleu_scores = evaluate_bleu(n_samples=100, beam_size=3, max_len=20)

In [None]:
def generate_caption_greedy_v2(image_pil, max_len=20, prevent_repetition=True):
    """Generate caption with repetition prevention and better stopping"""
    encoder.eval()
    decoder.eval()
    proj.eval()

    img_t = val_transform(image_pil).unsqueeze(0).to(device)

    with torch.no_grad():
        enc_feats, enc_pooled = encoder(img_t)
        enc_feats = proj(enc_feats)

        h, c = decoder.init_hidden_state(enc_pooled)
        input_idx = torch.tensor([vocab.stoi["<start>"]], device=device)
        generated = []
        generated_ids = []

        for step in range(max_len):
            emb = decoder.embedding(input_idx)
            context, _ = decoder.attention(enc_feats, h)

            if hasattr(decoder, 'f_beta'):
                gate = decoder.sigmoid(decoder.f_beta(h))
                context = gate * context

            lstm_input = torch.cat([emb, context], dim=1)
            h, c = decoder.lstm(lstm_input, (h, c))

            output = decoder.fc(h).squeeze(0)

            # Repetition prevention
            if prevent_repetition and len(generated_ids) >= 2:
                for prev_id in generated_ids[-3:]:
                    output[prev_id] -= 2.0
                if len(generated_ids) >= 1:
                    output[generated_ids[-1]] -= 5.0

            # Penalize articles at end
            if len(generated) >= 3:
                article_ids = [vocab.stoi.get(w, -1) for w in ['a', 'an', 'the']]
                for aid in article_ids:
                    if aid >= 0:
                        output[aid] -= 3.0

            # Boost <end> after reasonable length
            if len(generated) >= 5:
                output[vocab.stoi["<end>"]] += 1.0

            # Get top-k and avoid <unk>
            topk_vals, topk_ids = output.topk(10, dim=-1)
            pred_idx = topk_ids[0].item()
            for i in range(10):
                candidate = topk_ids[i].item()
                if vocab.itos[candidate] != "<unk>":
                    pred_idx = candidate
                    break

            if pred_idx == vocab.stoi["<end>"]:
                break

            current_word = vocab.itos[pred_idx]
            if len(generated) >= 3:
                if generated[-1] in ['a', 'an', 'the'] and current_word in ['a', 'an', 'the', '.']:
                    break

            if current_word not in ("<start>", "<pad>", "<unk>"):
                generated.append(current_word)
                generated_ids.append(pred_idx)

            input_idx = torch.tensor([pred_idx], device=device)

        caption = " ".join(generated).strip()
        while caption.endswith((' a', ' an', ' the')):
            caption = caption.rsplit(' ', 1)[0]

        if caption:
            caption = caption[0].upper() + caption[1:] if len(caption) > 1 else caption.upper()
            if caption[-1] not in '.!?':
                caption += '.'

        return caption


def generate_caption_nucleus(image_pil, max_len=20, p=0.9, temperature=1.0):
    """Generate caption using nucleus (top-p) sampling for diversity"""
    encoder.eval()
    decoder.eval()
    proj.eval()

    img_t = val_transform(image_pil).unsqueeze(0).to(device)

    with torch.no_grad():
        enc_feats, enc_pooled = encoder(img_t)
        enc_feats = proj(enc_feats)

        h, c = decoder.init_hidden_state(enc_pooled)
        input_idx = torch.tensor([vocab.stoi["<start>"]], device=device)
        generated = []
        generated_ids = []

        for step in range(max_len):
            emb = decoder.embedding(input_idx)
            context, _ = decoder.attention(enc_feats, h)

            if hasattr(decoder, 'f_beta'):
                gate = decoder.sigmoid(decoder.f_beta(h))
                context = gate * context

            lstm_input = torch.cat([emb, context], dim=1)
            h, c = decoder.lstm(lstm_input, (h, c))

            output = decoder.fc(h).squeeze(0) / temperature

            if len(generated_ids) >= 1:
                output[generated_ids[-1]] -= 5.0

            probs = F.softmax(output, dim=-1)
            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
            cumsum_probs = torch.cumsum(sorted_probs, dim=-1)

            sorted_indices_to_remove = cumsum_probs > p
            sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
            sorted_indices_to_remove[0] = False

            probs[sorted_indices[sorted_indices_to_remove]] = 0
            probs = probs / probs.sum()

            pred_idx = torch.multinomial(probs, 1).item()

            if pred_idx == vocab.stoi["<end>"]:
                break

            current_word = vocab.itos[pred_idx]
            if current_word not in ("<start>", "<pad>", "<unk>"):
                generated.append(current_word)
                generated_ids.append(pred_idx)

            input_idx = torch.tensor([pred_idx], device=device)

        caption = " ".join(generated).strip()
        while caption.endswith((' a', ' an', ' the')):
            caption = caption.rsplit(' ', 1)[0]

        if caption:
            caption = caption[0].upper() + caption[1:]
            if caption[-1] not in '.!?':
                caption += '.'

        return caption


def generate_caption_beam_v2(encoder, decoder, image_path, vocab, beam_size=5,
                             max_len=20, length_penalty=0.7):
    """Improved beam search with length penalty and diversity"""
    encoder.eval()
    decoder.eval()
    proj.eval()

    image = Image.open(image_path).convert("RGB")
    transform = val_transform if 'val_transform' in globals() else train_transform
    image = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        encoder_out, encoder_pooled = encoder(image)
        encoder_out = proj(encoder_out)

        num_pixels = encoder_out.size(1)
        encoder_dim = encoder_out.size(2)

        encoder_out = encoder_out.expand(beam_size, num_pixels, encoder_dim)
        encoder_pooled = encoder_pooled.expand(beam_size, -1)

        k = beam_size
        vocab_size = len(vocab)
        start_idx = vocab.stoi["<start>"]
        end_idx = vocab.stoi["<end>"]
        unk_idx = vocab.stoi["<unk>"]

        seqs = torch.full((k, 1), start_idx, dtype=torch.long).to(device)
        top_k_scores = torch.zeros(k, 1).to(device)
        complete_seqs = []
        complete_seqs_scores = []

        h, c = decoder.init_hidden_state(encoder_pooled)

        for step in range(max_len):
            embeddings = decoder.embedding(seqs[:, -1])
            awe, _ = decoder.attention(encoder_out, h)

            if hasattr(decoder, 'f_beta'):
                gate = decoder.sigmoid(decoder.f_beta(h))
                awe = gate * awe

            lstm_input = torch.cat([embeddings, awe], dim=1)
            h, c = decoder.lstm(lstm_input, (h, c))

            scores = decoder.fc(h)
            scores = F.log_softmax(scores, dim=1)

            scores[:, unk_idx] -= 3.0

            for beam_idx in range(scores.size(0)):
                if seqs.size(1) >= 2:
                    prev_token = seqs[beam_idx, -1].item()
                    scores[beam_idx, prev_token] -= 2.0

                    if seqs.size(1) >= 3 and seqs[beam_idx, -1] == seqs[beam_idx, -2]:
                        scores[beam_idx, prev_token] -= 5.0

            scores = top_k_scores.expand_as(scores) + scores

            if step == 0:
                top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)
            else:
                top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)

            prev_word_inds = top_k_words // vocab_size
            next_word_inds = top_k_words % vocab_size
            seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)

            incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds)
                              if next_word != end_idx]
            complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))

            if len(complete_inds) > 0:
                complete_seqs.extend(seqs[complete_inds].tolist())
                lengths = torch.tensor([len(seqs[i]) for i in complete_inds],
                                      device=device, dtype=torch.float)
                normalized_scores = top_k_scores[complete_inds].squeeze() / (lengths ** length_penalty)
                complete_seqs_scores.extend(normalized_scores.tolist())

            k -= len(complete_inds)
            if k == 0:
                break

            seqs = seqs[incomplete_inds]
            h = h[prev_word_inds[incomplete_inds]]
            c = c[prev_word_inds[incomplete_inds]]
            encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
            top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)

        if len(complete_seqs_scores) == 0:
            complete_seqs = seqs.tolist()
            lengths = torch.tensor([len(seq) for seq in complete_seqs],
                                  device=device, dtype=torch.float)
            complete_seqs_scores = (top_k_scores.squeeze() / (lengths ** length_penalty)).tolist()

        i = complete_seqs_scores.index(max(complete_seqs_scores))
        best_seq = complete_seqs[i]

        caption = [
            vocab.itos[idx]
            for idx in best_seq
            if vocab.itos[idx] not in ("<start>", "<end>", "<pad>", "<unk>")
        ]

        caption = " ".join(caption).strip()
        while caption.endswith((' a', ' an', ' the')):
            caption = caption.rsplit(' ', 1)[0]

        if caption:
            caption = caption[0].upper() + caption[1:]
            if caption[-1] not in '.!?':
                caption += '.'

        return caption


def compare_generation_methods(image_path, vocab):
    """Compare different generation methods on one image"""
    img = Image.open(image_path).convert("RGB")

    print("="*70)
    print(f"Image: {image_path.split('/')[-1]}")
    print("="*70)

    # Method 1: Greedy v2
    cap1 = generate_caption_greedy_v2(img, max_len=20)
    print(f"1. Greedy v2:    {cap1}")

    # Method 2: Nucleus sampling
    cap2 = generate_caption_nucleus(img, max_len=20, p=0.9)
    print(f"2. Nucleus:      {cap2}")

    # Method 3: Beam search v2
    cap3 = generate_caption_beam_v2(encoder, decoder, image_path, vocab, beam_size=5)
    print(f"3. Beam (5):     {cap3}")

    # Method 4: Beam search with higher beam
    cap4 = generate_caption_beam_v2(encoder, decoder, image_path, vocab, beam_size=10)
    print(f"4. Beam (10):    {cap4}")

    print("="*70)
    print()

    # Display image
    plt.figure(figsize=(10, 8))
    plt.imshow(img)
    plt.axis('off')
    plt.title(f"Best Caption (Beam-10): {cap4}", fontsize=14, wrap=True)
    plt.tight_layout()
    plt.show()


print(" All functions loaded successfully!")


In [None]:
print("üß™ Test 1: Comparing generation methods on one image\n")

# Get first validation image
if isinstance(val_ds, torch.utils.data.Subset):
    test_img_path = val_ds.dataset.samples[val_ds.indices[0]][0]
else:
    test_img_path = val_ds.samples[0][0] if hasattr(val_ds, 'samples') else val_ds.dataset.samples[0][0]

compare_generation_methods(test_img_path, vocab)

In [None]:

print("üß™ Test 2: Testing on previously problematic images\n")

# Test on 5 images
for i in range(5):
    if isinstance(val_ds, torch.utils.data.Subset):
        imgpath = val_ds.dataset.samples[val_ds.indices[i]][0]
        _, gt_cap = val_ds.dataset.samples[val_ds.indices[i]]
    else:
        imgpath = val_ds.samples[i][0] if hasattr(val_ds, 'samples') else val_ds.dataset.samples[i][0]
        _, gt_cap = val_ds.samples[i] if hasattr(val_ds, 'samples') else val_ds.dataset.samples[i]

    # Get ground truth
    if isinstance(gt_cap, str):
        gt = gt_cap
    else:
        gt = " ".join([vocab.itos[idx] for idx in gt_cap.tolist()
                      if vocab.itos[idx] not in ("<pad>", "<start>", "<end>")])

    print(f"\n{'='*70}")
    print(f"Image {i+1}")
    print(f"Ground Truth: {gt}")
    print(f"{'='*70}")

    img = Image.open(imgpath).convert("RGB")

    # Generate with each method
    pred1 = generate_caption_greedy_v2(img, max_len=20)
    pred2 = generate_caption_nucleus(img, max_len=20, p=0.9)
    pred3 = generate_caption_beam_v2(encoder, decoder, imgpath, vocab, beam_size=10)

    print(f"Greedy v2: {pred1}")
    print(f"Nucleus:   {pred2}")
    print(f"Beam-10:   {pred3}")

    # Show image with best caption
    plt.figure(figsize=(8, 6))
    plt.imshow(img)
    plt.axis('off')
    plt.title(f"Predicted: {pred3}\n\nGround Truth: {gt}", fontsize=11)
    plt.tight_layout()
    plt.show()

In [None]:
# Cell 13: Save final artifacts and quick predict function
torch.save({
    'encoder': encoder.state_dict(),
    'decoder': decoder.state_dict(),
    'vocab_itos': vocab.itos,
    'vocab_stoi': vocab.stoi,
    'MAX_LEN': MAX_LEN
}, "/content/flickr8k_caption_best.pt")
print("Saved to /flickr8k_caption_best.pt")

# Quick inference wrapper (for local use or to plug into Gradio)
def predict_from_pil(image_pil, beam_size=3):
    return generate_caption_greedy(image_pil.convert("RGB"), beam_size=beam_size, max_len=MAX_LEN)

# Example:
# from PIL import Image
# im = Image.open('/path/to/some/image.jpg')
# print(predict_from_pil(im, beam_size=5))


In [None]:
# def evaluate_accuracy(encoder, decoder, dataloader, vocab, proj):
#     """
#     Evaluate word-level accuracy on validation set
#     """
#     decoder.eval()
#     encoder.eval()
#     total, correct = 0, 0

#     with torch.no_grad():
#         for imgs, caps, caplens in tqdm(dataloader, desc="Evaluating Accuracy"):
#             imgs, caps = imgs.to(device), caps.to(device)

#             # Encode
#             enc_out, enc_pooled = encoder(imgs)         # (B, num_pixels, 2048), (B, 512)

#             #  Apply projection to match decoder input dim
#             enc_out = proj(enc_out)                     # (B, num_pixels, 512)
#             enc_pooled = proj(enc_pooled)               # (B, 512)

#             # Forward decode
#             scores, caps_sorted, decode_lengths, _, _ = decoder(enc_out, caps, caplens)
#             targets = caps_sorted[:, 1:]

#             # Align predictions
#             scores_copy = torch.zeros_like(scores)
#             for i, length in enumerate(decode_lengths):
#                 scores_copy[i, :length, :] = scores[i, :length, :]
#             scores = scores_copy

#             preds = scores.argmax(-1)
#             mask = (targets != vocab.stoi["<pad>"])
#             correct += ((preds == targets) & mask).sum().item()
#             total += mask.sum().item()

#     acc = 100 * correct / total
#     print(f" Word-level Accuracy: {acc:.2f}%")
#     return acc


In [None]:
# evaluate_accuracy(encoder, decoder, val_loader, vocab, proj)

In [None]:
import torchvision

In [None]:
from nltk.translate.bleu_score import corpus_bleu

def evaluate_bleu(encoder, decoder, data_loader, vocab, beam_size=3, max_len=20):
    encoder.eval()
    decoder.eval()

    refs, hyps = [], []

    with torch.no_grad():
        for images, caps, _ in tqdm(data_loader, desc="Evaluating BLEU"):
            for i in range(images.size(0)):
                img = images[i]
                tmp_path = "/tmp/tmp_eval.jpg"
                torchvision.utils.save_image(img.cpu(), tmp_path)
                hyp = generate_caption_beam(encoder, decoder, tmp_path, vocab, beam_size, max_len)
                ref = [ [vocab.itos[idx] for idx in caps[i].tolist() if vocab.itos[idx] not in ("<start>", "<end>", "<pad>")] ]
                hyps.append(hyp.split())
                refs.append(ref)

    bleu1 = corpus_bleu(refs, hyps, weights=(1, 0, 0, 0))
    bleu4 = corpus_bleu(refs, hyps, weights=(0.25, 0.25, 0.25, 0.25))
    print(f"BLEU-1: {bleu1*100:.2f}, BLEU-4: {bleu4*100:.2f}")
    return bleu1, bleu4


In [None]:
evaluate_bleu(encoder, decoder, val_loader, vocab)