# Image Captioning with Transformers (Encoder–Decoder)

An image-captioning model using a **pretrained ResNet50** encoder (image features) and a **custom Transformer decoder** for caption generation.

What's included and *why*:
- **Pretrained ResNet50**: gives strong visual features without training a vision model from scratch. Faster convergence and better captions.
- **Patch/region features**: we convert CNN feature maps into a sequence the decoder can attend to.
- **Transformer decoder**: autoregressive text generation with self-attention + cross-attention to image features.
- **Beam search**: improved inference over greedy decoding; explores top-k candidate sequences.
- **Evaluation**: BLEU (via `nltk`) and CIDEr (via `pycocoevalcap`)—CIDEr correlates better with human judgement for captioning.


# Setup & Install

In [None]:

# Notebook-only cell: a leading "!" runs the line as a shell command (IPython syntax).
# Upgrade pip first, then install the third-party packages this notebook uses.
!pip install --upgrade pip
!pip install transformers torchvision pycocotools tqdm nltk pycocoevalcap
print("Install complete")


# Imports & Device

In [None]:

import os, math, random, json, collections
import random
from pathlib import Path
from tqdm.notebook import tqdm, trange

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CocoCaptions
import torchvision.models as models

from transformers import AutoTokenizer

# Turn off parallelism inside the Hugging Face tokenizers library
# (presumably to avoid fork warnings with DataLoader workers - confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)


# Hyperparameters

In [None]:

# Side length (pixels) every image is resized/cropped to by the transforms below.
image_size = 128
batch_size = 64
learning_rate = 1e-4
num_epochs = 50
# Patch side length; only consumed by the from-scratch VisionEncoder.
patch_size = 16
# Shared feature width of the encoder output and the decoder.
hidden_size = 192
# (encoder layers, decoder layers); the encoder count is unused by the ResNet encoder.
num_layers = (4, 4)
num_heads = 8

# When True, train/val are cut down to at most the sample counts below.
sample_small_dataset = True
max_train_samples = 2000
max_val_samples = 500

# True selects the pretrained ResNet50 encoder, False the patch Transformer encoder.
use_pretrained_resnet_encoder = True


# Download COCO Captions

The cell below supports a full download (set FULL_DOWNLOAD = True) or a demo path that downloads val2014 + annotations and samples a small subset.

In [None]:

# Notebook-only cell: lines starting with "!" are shell commands with {expr} interpolation.
# All archives are downloaded into, and extracted under, ROOT.
ROOT = Path("/content/coco_captions"); ROOT.mkdir(parents=True, exist_ok=True)
# False = demo path (val2014 images + annotations only); True = everything in `urls`.
FULL_DOWNLOAD = False
urls = {
    "train2014": "http://images.cocodataset.org/zips/train2014.zip",
    "val2014": "http://images.cocodataset.org/zips/val2014.zip",
    "annotations": "http://images.cocodataset.org/annotations/annotations_trainval2014.zip"
}
if FULL_DOWNLOAD:
    # Download each archive that is not on disk yet (stored as "<key>.zip") ...
    for k, url in urls.items():
        out_zip = ROOT / f"{k}.zip"
        if not out_zip.exists():
            print(f"Downloading {k} ...")
            !wget -q -c {url} -O "{out_zip}"
    # ... then extract every archive that exists, overwriting existing files.
    for k in urls:
        out_zip = ROOT / f"{k}.zip"
        if out_zip.exists():
            !unzip -q -o "{out_zip}" -d "{ROOT}"
else:
    # Demo path: fetch only the validation images and the caption annotations.
    # Note: both unzip commands run on every execution, even if already extracted.
    if not (ROOT/"val2014.zip").exists():
        print("Downloading val2014.zip (≈6GB)...")
        !wget -q -c {urls['val2014']} -O "{ROOT/'val2014.zip'}"
    !unzip -q -o "{ROOT/'val2014.zip'}" -d "{ROOT}"
    if not (ROOT/"annotations_trainval2014.zip").exists():
        print("Downloading annotations (≈241MB)...")
        !wget -q -c {urls['annotations']} -O "{ROOT/'annotations_trainval2014.zip'}"
    !unzip -q -o "{ROOT/'annotations_trainval2014.zip'}" -d "{ROOT}"
print("Data in:", ROOT)


# Data Transforms, Dataset & Dataloaders

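- Resize to `image_size` (128 in this notebook, chosen for speed). ResNet50 was pretrained at ~224×224, but the encoder keeps only the convolutional layers, so it also accepts smaller inputs at the cost of a coarser feature map.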
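- `RandomCrop` with 4-pixel padding on training images adds small random shifts, which improves generalization. (No flip is applied in the code below; add `transforms.RandomHorizontalFlip()` to the training transform if you want flips.)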
- Using `SampleCaption` to randomly choose one of the 5 COCO captions per image — this provides diversity without duplicating images in the loader.

In [None]:

class SampleCaption(object):
    """Target transform that returns one randomly chosen caption from a sequence."""

    def __call__(self, captions):
        # Uses the global `random` module state; an empty sequence raises IndexError.
        picked = random.choice(captions)
        return picked

# Training pipeline: resize to a square, random crop of the same size after
# 4-pixel padding (i.e. a small random shift), tensor conversion, then
# per-channel normalization with the ImageNet mean/std values.
train_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.RandomCrop(image_size, padding=4, pad_if_needed=True),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
])
# Validation pipeline: deterministic. The center crop has the same size as the
# resized image, so it leaves the image unchanged.
val_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
])

# Locations of the extracted COCO images and caption annotation files.
ROOT = Path("/content/coco_captions")
train_images = ROOT / "train2014"
val_images = ROOT / "val2014"
ann_dir = ROOT / "annotations"
train_ann = ann_dir / "captions_train2014.json"
val_ann = ann_dir / "captions_val2014.json"

# True when train2014 is unavailable and the training set is built from val2014.
train_from_val = not (train_ann.exists() and train_images.exists())
if train_from_val:
    print("Train set not found. Using val set as training for demo.")
    train_ds = CocoCaptions(root=str(val_images), annFile=str(val_ann), transform=train_transform, target_transform=SampleCaption())
else:
    train_ds = CocoCaptions(root=str(train_images), annFile=str(train_ann), transform=train_transform, target_transform=SampleCaption())

val_ds = CocoCaptions(root=str(val_images), annFile=str(val_ann), transform=val_transform, target_transform=SampleCaption())

if sample_small_dataset:
    rng = torch.Generator().manual_seed(42)
    if train_from_val:
        # Both datasets index the same val2014 images, so two independent random
        # splits could put the same image in train and val. Draw ONE permutation
        # and hand each dataset a disjoint slice of it instead.
        order = torch.randperm(len(val_ds), generator=rng).tolist()
        val_n = min(len(val_ds), max_val_samples)
        train_n = min(len(val_ds) - val_n, max_train_samples)
        val_ds = torch.utils.data.Subset(val_ds, order[:val_n])
        train_ds = torch.utils.data.Subset(train_ds, order[val_n:val_n + train_n])
    else:
        # Separate image folders: independent random subsets cannot overlap.
        train_n = min(len(train_ds), max_train_samples)
        val_n = min(len(val_ds), max_val_samples)
        train_ds, _ = torch.utils.data.random_split(train_ds, [train_n, max(0, len(train_ds)-train_n)], generator=rng)
        val_ds, _ = torch.utils.data.random_split(val_ds, [val_n, max(0, len(val_ds)-val_n)], generator=rng)
elif train_from_val:
    # Without subsampling the fallback trains on the full validation set.
    print("Warning: train and val both use all of val2014; validation metrics will be optimistic.")

# Batch both datasets; only the training loader shuffles.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

print("Train samples:", len(train_ds), "Val samples:", len(val_ds))

# Sanity check: show the first validation image with the caption sampled for it.
imgs, caps = next(iter(val_loader))
plt.figure(figsize=(3,3))
# normalize=True rescales the (mean/std-normalized) tensor into a displayable range.
out = torchvision.utils.make_grid(imgs[0:1], 1, normalize=True)
# make_grid returns [C,H,W]; matplotlib wants [H,W,C].
plt.imshow(out.numpy().transpose(1,2,0)); plt.axis('off')
print("Sample caption:", caps[0])


# Tokenizer + TokenDrop

Using `distilbert-base-uncased` tokenizer to convert captions to token ids. The decoder's final logits will predict over this tokenizer's vocabulary.

Tokenize captions on the fly to avoid storing a full tokenized dataset. `TokenDrop` randomly replaces tokens with pad token to regularize decoder reliance on previous ground-truth tokens (reduces teacher-forcing memorization).

In [None]:

# Module-level tokenizer shared by training, decoding and TokenDrop.
# Its vocabulary size is also used as the decoder's output size.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
print("Vocab size:", tokenizer.vocab_size)

def _first_not_none(*candidates):
    """Return the first argument that is not None, or None if all are None."""
    for candidate in candidates:
        if candidate is not None:
            return candidate
    return None


class TokenDrop(object):
    """Randomly replace token ids with a blank token.

    Each position is replaced independently with probability ``prob``; the
    end-of-sequence and [CLS] tokens are never replaced.

    Args:
        prob: per-token replacement probability.
        blank_token: id written into dropped positions. Defaults to the global
            tokenizer's pad id, falling back to its mask id.
        eos_token: id that must never be dropped. Defaults to the global
            tokenizer's sep id, falling back to its eos id.

    Raises:
        ValueError: if no blank token id can be determined.
    """

    def __init__(self, prob=0.5, blank_token=None, eos_token=None):
        self.prob = prob
        # Compare against None explicitly: a valid id of 0 (e.g. a pad id of 0)
        # is falsy, so `a or b` would silently skip it.
        self.blank_token = _first_not_none(blank_token, tokenizer.pad_token_id, tokenizer.mask_token_id)
        self.eos_token = _first_not_none(eos_token, tokenizer.sep_token_id, tokenizer.eos_token_id)
        self.cls_token = tokenizer.cls_token_id
        if self.blank_token is None:
            raise ValueError("TokenDrop needs a blank token id (tokenizer has no pad/mask token)")

    def __call__(self, input_ids):
        """Return a copy of ``input_ids`` with randomly chosen tokens blanked."""
        drop_prob = torch.full_like(input_ids, self.prob, dtype=torch.float)
        mask = torch.bernoulli(drop_prob).bool()
        # Protect the sequence delimiters from being dropped.
        if self.eos_token is not None:
            mask &= (input_ids != self.eos_token)
        if self.cls_token is not None:
            mask &= (input_ids != self.cls_token)
        out = input_ids.clone()
        out[mask] = self.blank_token
        return out

# Token-drop instance with a 50% drop probability.
# NOTE(review): not referenced by the training loop visible in this notebook - confirm whether it should be.
td = TokenDrop(prob=0.5)


# Model – Encoder–Decoder Transformer

- **Why use ResNet50?** Pretrained ResNet50 yields high-quality spatial feature maps. Remove the final classification head and extract an intermediate feature map (e.g., after layer4) which has spatial dimensions (H', W'). Flatten spatial positions into a sequence the decoder can attend to.

- Decoder receives **token embeddings** + **sinusoidal positional embeddings**.
- Each block does self-attention (causal mask) then cross-attention to image features, then MLP.
- Output logits map to the tokenizer vocabulary.

This approach is simple and effective: the decoder learns to attend to different spatial regions when generating each token.

In [None]:

def extract_patches(image_tensor, patch_size=16):
    """Cut a batch of images into flattened, non-overlapping square patches.

    Args:
        image_tensor: 4-D tensor [batch, channels, height, width].
        patch_size: side length of each patch (also used as the stride).

    Returns:
        Tensor [batch, num_patches, channels * patch_size * patch_size].
    """
    batch, channels, _height, _width = image_tensor.size()
    patcher = torch.nn.Unfold(kernel_size=patch_size, stride=patch_size)
    # Unfold yields [batch, patch_values, num_patches]; move patches to dim 1.
    columns = patcher(image_tensor)
    per_patch = columns.permute(0, 2, 1)
    return per_patch.reshape(batch, -1, channels * patch_size * patch_size)

class SinusoidalPosEmb(nn.Module):
    """Fixed sinusoidal embedding of scalar positions.

    The output concatenates ``dim // 2`` sine features with ``dim // 2`` cosine
    features, using geometrically spaced frequencies from 1 down to 1/10000.
    ``dim`` must be at least 4, because the spacing divides by ``dim // 2 - 1``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Embed positions ``x`` (indexed along dim 0) into sin/cos features."""
        half_dim = self.dim // 2
        decay = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -decay)
        # Outer product: one row per position, one column per frequency.
        angles = x[:, None] * freqs[None, :]
        return torch.cat([angles.sin(), angles.cos()], dim=-1)

class AttentionBlock(nn.Module):
    """Multi-head attention wrapper with an optional causal (look-back only) mask.

    Args:
        hidden_size: embedding width of queries, keys and values.
        num_heads: number of attention heads.
        masking: when True, query position i may only attend to key positions <= i.
    """

    def __init__(self, hidden_size=128, num_heads=4, masking=True):
        super().__init__()
        self.masking = masking
        self.mha = nn.MultiheadAttention(hidden_size, num_heads=num_heads, batch_first=True, dropout=0.0)

    def forward(self, q_in, kv_in, key_mask=None):
        """Attend from ``q_in`` over ``kv_in``; ``key_mask`` is True at keys to ignore."""
        causal = None
        if self.masking:
            steps = q_in.shape[1]
            # True strictly above the diagonal = future positions are blocked.
            causal = torch.ones(steps, steps, device=q_in.device).triu(1).bool()
        attended, _weights = self.mha(q_in, kv_in, kv_in, attn_mask=causal, key_padding_mask=key_mask)
        return attended

class TransformerBlock(nn.Module):
    """Post-norm Transformer layer: self-attention, optional cross-attention, MLP.

    Every sub-layer is applied as ``norm(sublayer(x) + x)``.

    Args:
        hidden_size: feature width.
        num_heads: attention heads in each attention sub-layer.
        decoder: when True, add a cross-attention sub-layer over ``kv_cross``.
        masking: when True, the self-attention is causal.
    """

    def __init__(self, hidden_size=128, num_heads=4, decoder=False, masking=True):
        super().__init__()
        self.decoder = decoder
        self.norm1 = nn.LayerNorm(hidden_size)
        self.attn1 = AttentionBlock(hidden_size, num_heads, masking=masking)
        if decoder:
            self.norm2 = nn.LayerNorm(hidden_size)
            # Cross-attention sees the whole encoder sequence, so no causal mask.
            self.attn2 = AttentionBlock(hidden_size, num_heads, masking=False)
        self.norm_mlp = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.ELU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )

    def forward(self, x, input_key_mask=None, cross_key_mask=None, kv_cross=None):
        """Run the layer on ``x``; the masks are True at positions to ignore."""
        self_out = self.attn1(x, x, key_mask=input_key_mask)
        x = self.norm1(self_out + x)
        if self.decoder:
            cross_out = self.attn2(x, kv_cross, key_mask=cross_key_mask)
            x = self.norm2(cross_out + x)
        mlp_out = self.mlp(x)
        return self.norm_mlp(mlp_out + x)

class Decoder(nn.Module):
    """Autoregressive caption decoder that cross-attends to encoder features.

    Args:
        num_emb: vocabulary size (embedding rows and output logits).
        hidden_size: feature width.
        num_layers: number of decoder Transformer blocks.
        num_heads: attention heads per block.
        dropout: dropout rate applied after the embeddings and before the output layer.
    """

    def __init__(self, num_emb, hidden_size=128, num_layers=3, num_heads=4, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(num_emb, hidden_size)
        # Shrink the default embedding initialization by a factor of 1000.
        with torch.no_grad():
            self.embedding.weight.mul_(0.001)
        self.pos_emb = SinusoidalPosEmb(hidden_size)
        layers = [TransformerBlock(hidden_size, num_heads, decoder=True) for _ in range(num_layers)]
        self.blocks = nn.ModuleList(layers)
        self.dropout = nn.Dropout(dropout)
        self.fc_out = nn.Linear(hidden_size, num_emb)

    def forward(self, input_seq, encoder_output, input_padding_mask=None, encoder_padding_mask=None):
        """Return vocabulary logits [batch, length, num_emb] for ``input_seq``.

        Both padding masks are boolean with True at positions to ignore.
        """
        tokens = self.embedding(input_seq)
        batch, length, width = tokens.shape
        positions = torch.arange(length, device=input_seq.device)
        # One positional row per time step, shared by the whole batch.
        pos = self.pos_emb(positions).view(1, length, width).expand(batch, length, width)
        hidden = self.dropout(tokens + pos)
        for layer in self.blocks:
            hidden = layer(
                hidden,
                input_key_mask=input_padding_mask,
                cross_key_mask=encoder_padding_mask,
                kv_cross=encoder_output,
            )
        return self.fc_out(self.dropout(hidden))


class VisionEncoder(nn.Module):
    """Patch-based Transformer image encoder trained from scratch.

    Images are cut into ``patch_size`` squares, linearly projected to
    ``hidden_size``, given a learned positional embedding and passed through
    unmasked Transformer blocks.
    """

    def __init__(self, image_size, channels_in, patch_size=16, hidden_size=128, num_layers=3, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.fc_in = nn.Linear(channels_in * patch_size * patch_size, hidden_size)
        num_patches = (image_size // patch_size) ** 2
        # Learned positions, one per patch, drawn from N(0, 0.02^2).
        init_pos = torch.empty(1, num_patches, hidden_size).normal_(std=0.02)
        self.pos_embedding = nn.Parameter(init_pos)
        self.blocks = nn.ModuleList(
            [TransformerBlock(hidden_size, num_heads, decoder=False, masking=False) for _ in range(num_layers)]
        )

    def forward(self, image):
        """Encode ``image`` into a sequence [batch, num_patches, hidden_size]."""
        patches = extract_patches(image, self.patch_size)
        hidden = self.fc_in(patches) + self.pos_embedding
        for layer in self.blocks:
            hidden = layer(hidden)
        return hidden

class ResNetEncoder(nn.Module):
    """ImageNet-pretrained ResNet50 trunk that outputs a sequence of region features.

    The last two child modules of the torchvision model are dropped, and the
    remaining feature map is flattened over its spatial positions and projected
    to ``hidden_size``.
    """

    def __init__(self, hidden_size=128):
        super().__init__()
        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        trunk = list(resnet.children())[:-2]
        self.backbone = nn.Sequential(*trunk)
        self._out_channels = 2048
        self.proj = nn.Linear(self._out_channels, hidden_size)

    def forward(self, image):
        """Return features [batch, Hf * Wf, hidden_size] for a batch of images."""
        fmap = self.backbone(image)  # [B, 2048, Hf, Wf]
        # Flatten the spatial grid row by row and move channels last: [B, Hf*Wf, C].
        regions = fmap.flatten(2).transpose(1, 2)
        return self.proj(regions)

class VisionEncoderDecoder(nn.Module):
    """Image captioning model: image encoder followed by a Transformer text decoder.

    Args:
        image_size, channels_in, patch_size: only used by the patch encoder.
        num_emb: vocabulary size of the decoder.
        hidden_size: shared feature width of encoder output and decoder.
        num_layers: (encoder layers, decoder layers).
        num_heads: attention heads per block.
        use_pretrained: True selects ResNetEncoder, False selects VisionEncoder.
    """

    def __init__(self, image_size, channels_in, num_emb, patch_size=16, hidden_size=128, num_layers=(3,3), num_heads=4, use_pretrained=False):
        super().__init__()
        enc_layers, dec_layers = num_layers[0], num_layers[1]
        if not use_pretrained:
            self.encoder = VisionEncoder(
                image_size=image_size,
                channels_in=channels_in,
                patch_size=patch_size,
                hidden_size=hidden_size,
                num_layers=enc_layers,
                num_heads=num_heads,
            )
        else:
            self.encoder = ResNetEncoder(hidden_size=hidden_size)
        self.decoder = Decoder(num_emb=num_emb, hidden_size=hidden_size, num_layers=dec_layers, num_heads=num_heads)

    def forward(self, input_image, target_seq, padding_mask):
        """Return decoder logits for ``target_seq`` conditioned on ``input_image``.

        ``padding_mask`` uses 1 for real tokens and 0 for padding; it is inverted
        here because the attention layers expect True at positions to ignore.
        """
        ignore = padding_mask.eq(0)
        features = self.encoder(input_image)
        return self.decoder(input_seq=target_seq, encoder_output=features, input_padding_mask=ignore)


# Initialize Model & Optimizer

In [None]:

# Infer the channel count from one real batch; fall back to 3 (RGB) if the
# loader cannot produce a batch for any reason.
try:
    sample_images, _ = next(iter(train_loader))
    channels_in = sample_images.shape[1]
except Exception as e:
    channels_in = 3

# Build the captioning model from the hyperparameters and move it to the device.
caption_model = VisionEncoderDecoder(
    image_size=image_size,
    channels_in=channels_in,
    num_emb=tokenizer.vocab_size,
    patch_size=patch_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    num_heads=num_heads,
    use_pretrained=use_pretrained_resnet_encoder
).to(device)

optimizer = optim.Adam(caption_model.parameters(), lr=learning_rate)
# Gradient scaler for mixed precision; None disables the AMP path on CPU.
scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None
# reduction='none' keeps per-token losses so the training loop can mask them.
loss_fn = nn.CrossEntropyLoss(reduction='none', label_smoothing=0.1)

num_params = sum(p.numel() for p in caption_model.parameters())
print(f"Params: {num_params:,} (~{num_params/1e6:.2f}M)")


# Training Loop

In [None]:
def tokenize_batch(caps):
    """Tokenize a batch of caption strings with the module-level tokenizer.

    Captions are padded to the longest one in the batch and truncated when needed.

    Returns:
        (input_ids, attention_mask) as PyTorch tensors.
    """
    encoded = tokenizer(caps, padding=True, truncation=True, return_tensors='pt')
    ids = encoded['input_ids']
    mask = encoded['attention_mask']
    return ids, mask

# Per-step training losses, appended to by train_one_epoch and plotted later.
training_loss_logger = []


def _masked_caption_loss(model, images, tokens_in, attn, target_ids, target_mask, loss_fn):
    """Average the per-token loss over positions whose target is a real token."""
    preds = model(images, tokens_in, padding_mask=attn)
    # CrossEntropyLoss wants [batch, vocab, length]; loss_fn must use reduction='none'.
    per_token = loss_fn(preds.transpose(1, 2), target_ids)
    return (per_token * target_mask).sum() / target_mask.sum().clamp_min(1.0)


def train_one_epoch(model, loader, optimizer, device, loss_fn, scaler=None, ss_ratio=0.1, max_steps=None):
    """Train ``model`` for one pass over ``loader`` with scheduled sampling.

    Args:
        model: callable as ``model(images, token_ids, padding_mask=...)`` returning logits.
        loader: yields ``(images, captions)`` batches; captions are strings.
        optimizer: optimizer over the model parameters.
        device: device the batch tensors are moved to.
        loss_fn: per-token loss (``reduction='none'``).
        scaler: optional AMP gradient scaler; when given, the forward pass runs under autocast.
        ss_ratio: per-position probability of feeding the model's own prediction
            instead of the ground-truth token.
        max_steps: optional cap on the number of batches processed.

    Returns:
        Mean loss over the processed batches (0.0 if none were processed).
    """
    model.train()
    running = 0.0
    steps = 0
    for images, captions in tqdm(loader, leave=False):
        images = images.to(device)
        input_ids, attn = tokenize_batch(captions)
        input_ids, attn = input_ids.to(device), attn.to(device)
        bs, L = input_ids.shape

        # Targets are the inputs shifted left by one, padded at the end.
        pad_col = input_ids.new_full((bs, 1), tokenizer.pad_token_id)
        target_ids = torch.cat((input_ids[:, 1:], pad_col), dim=1)
        # The loss mask must follow the TARGETS: shift the attention mask the same
        # way. Masking with the unshifted input mask would also score the position
        # after the final [SEP], teaching the model to emit [PAD].
        target_mask = torch.cat((attn[:, 1:], attn.new_zeros(bs, 1)), dim=1).float()

        # Scheduled sampling: build the decoder input token by token.
        tokens_in = torch.full_like(input_ids, tokenizer.pad_token_id)
        tokens_in[:, 0] = input_ids[:, 0]  # start token ([CLS])
        for t in range(1, L):
            use_model_pred = (random.random() < ss_ratio)
            if use_model_pred and t > 1:
                # Feed the model's own greedy prediction for this position.
                with torch.no_grad():
                    preds = model(images, tokens_in[:, :t], padding_mask=attn[:, :t])
                    next_token = preds[:, -1, :].argmax(-1)
                tokens_in[:, t] = next_token
            else:
                # Teacher forcing: use the ground-truth token.
                tokens_in[:, t] = input_ids[:, t]

        # Forward + loss + optimizer step (mixed precision when a scaler is given).
        if scaler is not None:
            with torch.cuda.amp.autocast():
                loss = _masked_caption_loss(model, images, tokens_in, attn, target_ids, target_mask, loss_fn)
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss = _masked_caption_loss(model, images, tokens_in, attn, target_ids, target_mask, loss_fn)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        step_loss = loss.item()
        running += step_loss
        training_loss_logger.append(step_loss)
        steps += 1
        if max_steps and steps >= max_steps:
            break

    return running / max(1, steps)


# Run training; each epoch is capped at 200 optimizer steps via max_steps.
for epoch in range(num_epochs):
    loss = train_one_epoch(caption_model, train_loader, optimizer, device, loss_fn, scaler, max_steps=200)
    print(f"Epoch {epoch+1}/{num_epochs} - loss: {loss:.4f}")


# Plot Training Loss

In [None]:

# Plot the per-step training loss, if any steps have been logged.
if training_loss_logger:
    plt.figure(figsize=(10,4)); plt.plot(training_loss_logger); plt.title("Training Loss"); plt.xlabel("Step"); plt.show()
else:
    print("No logs yet.")


# Inference – Greedy & Beam Search

Beam search keeps the top `beam_size` sequences at each step (by cumulative log-probability). It's a middle ground between greedy (k=1) and full search (exponential).

In [None]:

def greedy_decode(model, image_tensor, tokenizer, max_len=30, device='cpu'):
    """Generate a caption by always taking the most likely next token.

    Args:
        model: callable as ``model(images, token_ids, padding_mask=...)`` returning logits.
        image_tensor: single image [C,H,W]; a batch dimension is added here.
        tokenizer: provides ``cls_token_id``, ``pad_token_id``, ``sep_token_id`` and ``decode``.
        max_len: maximum number of tokens generated after the start token.
        device: device used for the image and the token tensors.

    Returns:
        The decoded caption string (special tokens skipped).
    """
    model.eval()
    with torch.no_grad():
        batch = image_tensor.unsqueeze(0).to(device)
        first_id = tokenizer.cls_token_id or tokenizer.pad_token_id or 0
        generated = torch.full((1, 1), first_id, dtype=torch.long, device=device)
        for _step in range(max_len):
            # No padding during generation: every position is visible.
            visible = torch.ones_like(generated)
            step_logits = model(batch, generated, padding_mask=visible)[:, -1, :]
            choice = step_logits.argmax(-1, keepdim=True)
            generated = torch.cat([generated, choice], dim=1)
            hit_end = tokenizer.sep_token_id is not None and choice.item() == tokenizer.sep_token_id
            if hit_end:
                break
        token_ids = generated.squeeze().tolist()
        return tokenizer.decode(token_ids, skip_special_tokens=True)

def beam_search_decode(model, image_tensor, tokenizer, max_len=30, beam_size=3, device='cpu'):
    """Generate a caption with beam search and return the best-scoring one.

    Beams are ranked by the sum of token log-probabilities. A beam that has
    produced the end token is kept as-is and never extended; the search stops
    when every kept beam has ended or ``max_len`` steps have run.

    Args:
        model: callable as ``model(images, token_ids, padding_mask=...)`` returning logits.
        image_tensor: single image [C,H,W]; a batch dimension is added here.
        tokenizer: provides ``cls_token_id``, ``pad_token_id``, ``sep_token_id`` and ``decode``.
        max_len: maximum number of expansion steps.
        beam_size: number of beams kept after each step.
        device: device used for the image and the token tensors.

    Returns:
        The decoded caption string of the highest-scoring beam.
    """
    model.eval()
    sep_id = tokenizer.sep_token_id

    def _finished(seq):
        # The length check keeps the lone start token from ever counting as "ended".
        return sep_id is not None and seq.shape[1] > 1 and seq[0, -1].item() == sep_id

    with torch.no_grad():
        img = image_tensor.unsqueeze(0).to(device)
        start_id = tokenizer.cls_token_id or tokenizer.pad_token_id or 0
        beams = [(torch.tensor([[start_id]], device=device, dtype=torch.long), 0.0)]
        for _ in range(max_len):
            candidates = []
            for seq, score in beams:
                if _finished(seq):
                    # Carry completed hypotheses forward unchanged so nothing is
                    # generated after the end token.
                    candidates.append((seq, score))
                    continue
                logits = model(img, seq, padding_mask=torch.ones_like(seq))[:, -1, :]
                log_probs = torch.nn.functional.log_softmax(logits, dim=-1).squeeze(0)
                topk = torch.topk(log_probs, min(beam_size, log_probs.numel()))
                for idx, lp in zip(topk.indices.tolist(), topk.values.tolist()):
                    nxt = torch.tensor([[idx]], device=device, dtype=torch.long)
                    candidates.append((torch.cat([seq, nxt], dim=1), score + lp))
            beams = sorted(candidates, key=lambda item: item[1], reverse=True)[:beam_size]
            if all(_finished(seq) for seq, _ in beams):
                break
        # squeeze(0) keeps a list even for a length-1 sequence.
        best_seq = beams[0][0].squeeze(0).tolist()
        return tokenizer.decode(best_seq, skip_special_tokens=True)

# Demo: caption the first image of the first validation batch with both decoders.
# NOTE(review): the broad except only prints the message, hiding the traceback.
try:
    img_batch, cap_batch = next(iter(val_loader))
    img0 = img_batch[0]
    plt.figure(figsize=(3,3)); out = torchvision.utils.make_grid(img0,1,normalize=True)
    plt.imshow(out.numpy().transpose(1,2,0)); plt.axis('off')
    print("Reference:", cap_batch[0])
    print("Greedy:", greedy_decode(caption_model, img0, tokenizer, device=device))
    print("Beam (k=3):", beam_search_decode(caption_model, img0, tokenizer, device=device))
except Exception as e:
    print("Inference demo failed:", e)


In [None]:
# Demo: same as the previous cell, but on an image from the (shuffled, augmented)
# training loader - useful to see how well the model fits its training data.
try:
    img_batch, cap_batch = next(iter(train_loader))
    img0 = img_batch[0]
    plt.figure(figsize=(3,3)); out = torchvision.utils.make_grid(img0,1,normalize=True)
    plt.imshow(out.numpy().transpose(1,2,0)); plt.axis('off')
    print("Reference:", cap_batch[0])
    print("Greedy:", greedy_decode(caption_model, img0, tokenizer, device=device))
    print("Beam (k=3):", beam_search_decode(caption_model, img0, tokenizer, device=device))
except Exception as e:
    print("Inference demo failed:", e)

In [None]:
# Notebook-only shell command: install the package that provides the YOLO detector.
!pip install ultralytics

In [None]:
from ultralytics import YOLO

# Use pretrained YOLOv8s (small) or YOLOv8m/l if you want stronger
# Module-level detector; generate_verified_caption reads this global directly.
yolo = YOLO("yolov8s.pt")

In [None]:
# ---------------------
# Dependencies (install in Colab if missing)
# ---------------------
# !pip install -q transformers
# ---------------------

import torch, torchvision
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from torchvision import transforms
import math, itertools

# COCO category names indexed by the label id that torchvision detection models
# output. The id space has 91 slots (0 = background); ids that COCO never
# assigned are filled with 'N/A' so that list index == predicted label id.
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A',
    'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
    'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
    'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass', 'cup', 'fork',
    'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli',
    'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
    'potted plant', 'bed', 'N/A', 'dining table', 'N/A', 'N/A', 'toilet', 'N/A',
    'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
    'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book', 'clock', 'vase',
    'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]


# ---------------------
# 1) Utility: unnormalize an image tensor (the detector expects 0..1 not normalized)
# ---------------------
# Per-channel statistics used by the image transforms (ImageNet mean / std).
MEAN = torch.tensor([0.485, 0.456, 0.406])
STD  = torch.tensor([0.229, 0.224, 0.225])

def unnormalize_image_tensor(img_norm):
    """
    Undo the MEAN/STD normalization of a single image.

    img_norm: tensor [C,H,W] normalized with MEAN/STD
    Returns: a new CPU tensor [C,H,W] clamped to the 0-1 range; the input is not modified.
    """
    restored = img_norm.detach().cpu().clone()
    # Reshape the statistics to [C,1,1] so they broadcast over height and width.
    scale = STD[:, None, None]
    shift = MEAN[:, None, None]
    restored = restored * scale + shift
    return restored.clamp(0., 1.)

# ---------------------
# 2) Object detection helper (Faster-RCNN)
# ---------------------
def load_detection_model(device='cpu'):
    """Load a COCO-pretrained Faster R-CNN (ResNet50-FPN) in eval mode on ``device``.

    Uses the ``weights=`` API (as the ResNet encoder in this file does) instead
    of the deprecated ``pretrained=True`` flag; COCO_V1 is the checkpoint that
    flag selected.
    """
    weights = torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.COCO_V1
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=weights)
    model.to(device).eval()
    return model


def detect_objects(img_norm_tensor, detection_model, device='cpu', score_thresh=0.35, topk=15):
    """Run a torchvision-style detector on one image and summarize what it found.

    Args:
        img_norm_tensor: image tensor [C,H,W]. If it holds values outside 0..1 it is
            assumed to be MEAN/STD-normalized and is un-normalized first, because
            the detector expects plain 0..1 pixel values.
        detection_model: callable on a list of images, returning one dict per
            image with 'labels' and 'scores' tensors.
        device: device the image is moved to.
        score_thresh: minimum score for a detection to be kept.
        topk: maximum number of classes returned.

    Returns:
        List of (class_name, best_score) sorted by score, highest first,
        with one entry per class.
    """
    detection_model.eval()
    img = img_norm_tensor.detach()
    if img.numel() > 0 and (img.min() < 0 or img.max() > 1):
        img = unnormalize_image_tensor(img)
    with torch.no_grad():
        out = detection_model([img.to(device)])[0]
    labels = out['labels'].cpu().tolist()
    scores = out['scores'].cpu().tolist()

    # Keep the best score per class name, ignoring weak detections.
    best = {}
    for label, score in zip(labels, scores):
        if score < score_thresh:
            continue
        if not 0 <= label < len(COCO_INSTANCE_CATEGORY_NAMES):
            continue  # label id outside the name table
        name = COCO_INSTANCE_CATEGORY_NAMES[label]
        if name in ('__background__', 'N/A'):
            continue  # placeholder entries are not real objects
        best[name] = max(best.get(name, 0.0), score)

    return sorted(best.items(), key=lambda item: item[1], reverse=True)[:topk]

def detect_objects_yolo(img_tensor, model, device='cpu', conf=0.25):
    """
    Detect objects in one image with an Ultralytics YOLO model.

    img_tensor: torch.Tensor [C,H,W], RGB. Values are expected in [0,1]; a tensor
        with values outside that range is assumed to be MEAN/STD-normalized and
        is un-normalized first.
    model: YOLO model (needs ``predict`` and ``names``)
    device: device passed to ``model.predict``
    conf: confidence threshold passed to ``model.predict``
    Returns: list of (label, score), one entry per detected box.
    """
    import numpy as np
    img = img_tensor.detach().cpu().float()
    if img.numel() > 0 and (img.min() < 0 or img.max() > 1):
        img = unnormalize_image_tensor(img)
    # Clamp before the byte cast so out-of-range values cannot wrap around.
    rgb = img.clamp(0.0, 1.0).mul(255).byte().numpy().transpose(1, 2, 0)  # [H,W,C] uint8
    # Ultralytics treats numpy inputs as OpenCV-style BGR, so reverse the channels.
    bgr = np.ascontiguousarray(rgb[..., ::-1])

    results = model.predict(bgr, conf=conf, device=device, verbose=False)
    if len(results) == 0 or len(results[0].boxes) == 0:
        return []

    boxes = results[0].boxes
    return [
        (model.names[int(cls_id)], score)
        for cls_id, score in zip(boxes.cls.tolist(), boxes.conf.tolist())
    ]

# ---------------------
# 3) Candidate generation (beam search returning multiple candidates + log-prob scores)
# ---------------------
def beam_search_candidates(model, image_tensor, tokenizer, beam_size=5, max_len=30, device='cpu'):
    """
    Beam search that returns every hypothesis it ends up with, not just the best.

    A hypothesis is moved to the finished pool as soon as it emits the tokenizer's
    sep token; the remaining ones keep growing until ``max_len`` steps have run or
    no live hypothesis is left.

    Returns a list of (decoded_text, token_ids, score) sorted by score (desc),
    where score is the sum of token log-probs (higher is better).
    """
    model.eval()
    with torch.no_grad():
        img = image_tensor.unsqueeze(0).to(device)
        first_id = tokenizer.cls_token_id or tokenizer.pad_token_id or 0
        live = [(torch.tensor([[first_id]], device=device, dtype=torch.long), 0.0)]
        finished = []
        for _step in range(max_len):
            grown = []
            for prefix, prefix_score in live:
                visible = torch.ones_like(prefix).to(device)
                last_logits = model(img, prefix, padding_mask=visible)[:, -1, :]  # [1, vocab]
                token_logp = F.log_softmax(last_logits, dim=-1).squeeze(0)  # [vocab]
                best = torch.topk(token_logp, k=beam_size)
                for token_id, token_lp in zip(best.indices.tolist(), best.values.tolist()):
                    step = torch.tensor([[token_id]], device=device, dtype=torch.long)
                    extended = (torch.cat([prefix, step], dim=1), prefix_score + float(token_lp))
                    ended = tokenizer.sep_token_id is not None and token_id == tokenizer.sep_token_id
                    # Ended hypotheses leave the beam; the others compete for a slot.
                    if ended:
                        finished.append(extended)
                    else:
                        grown.append(extended)
            grown.sort(key=lambda item: item[1], reverse=True)
            live = grown[:beam_size]
            if not live:
                break
        ranked = sorted(finished + live, key=lambda item: item[1], reverse=True)
        results = []
        for seq, score in ranked:
            ids = seq.squeeze().tolist()
            results.append((tokenizer.decode(ids, skip_special_tokens=True), ids, score))
        return results

# ---------------------
# 4) LM scoring using a small causal model (distilgpt2) — used to prefer grammatical captions
# ---------------------
def load_lm(device='cpu', model_name='distilgpt2'):
    """Load a causal language model and its tokenizer for caption scoring.

    Returns:
        (lm_tokenizer, lm_model) with the model moved to ``device`` and set to eval mode.
    """
    lm_tokenizer = AutoTokenizer.from_pretrained(model_name)
    lm_model = AutoModelForCausalLM.from_pretrained(model_name)
    lm_model = lm_model.to(device).eval()
    return lm_tokenizer, lm_model

def lm_score_text(lm_tokenizer, lm_model, text, device='cpu'):
    """
    Returns the total log-likelihood of ``text`` under the language model
    (higher = more probable).

    Causal LMs that compute their own loss from ``labels`` predict token i+1 from
    tokens <= i, so the reported loss is the mean over n_tokens - 1 predictions.
    The mean is therefore multiplied by n_tokens - 1 (not n_tokens) to recover
    the summed log-probability.

    Texts with fewer than two tokens have nothing to predict and score 0.0.
    """
    enc = lm_tokenizer(text, return_tensors='pt')
    input_ids = enc['input_ids'].to(device)
    n_scored = input_ids.shape[1] - 1
    if n_scored < 1:
        # Avoids the NaN/empty-batch failure the model would produce here.
        return 0.0
    with torch.no_grad():
        # model returns loss if labels provided
        out = lm_model(input_ids, labels=input_ids)
    return -float(out.loss.item()) * n_scored

# ---------------------
# 5) object-match scoring and hallucination detection (substring matching)
# ---------------------
def _caption_mentions_label(label, padded_caption, caption_words):
    """True if the caption contains the label as a phrase or any single word of it."""
    # Whole label as a word-bounded phrase, allowing a simple plural ending.
    if any(f" {label}{suffix} " in padded_caption for suffix in ("", "s", "es")):
        return True
    # Fall back to word-by-word matching (e.g. 'racket' for 'tennis racket').
    return any(word in caption_words for word in label.split())


def caption_object_overlap_score(caption, detected_items):
    """
    Compare the objects a caption talks about with the objects a detector found.

    caption: string
    detected_items: list of (label, score)
    returns: fraction_of_detected_labels_mentioned, list_of_matched_labels, list_of_mentioned_labels

    Matching works on whole words: the caption is lowercased, punctuation is
    treated as whitespace, and simple plurals ('dogs', 'buses') count as their
    singular. A label matches if it appears as a phrase or if any one of its
    words appears. Whole-word matching avoids false hits such as 'car' inside
    'carrot'.

    list_of_mentioned_labels holds every COCO label that matches the caption,
    in COCO_INSTANCE_CATEGORY_NAMES order.
    """
    cleaned = "".join(ch if ch.isalnum() else " " for ch in caption.lower())
    words = cleaned.split()
    padded = f" {' '.join(words)} "

    # Caption vocabulary plus naive singular forms of plural-looking words.
    word_set = set(words)
    for word in words:
        if word.endswith("es") and len(word) > 3:
            word_set.add(word[:-2])
        if word.endswith("s") and len(word) > 2:
            word_set.add(word[:-1])

    detected_names = [name for name, _ in detected_items]
    matched = [name for name in detected_names if _caption_mentions_label(name, padded, word_set)]

    # find any COCO labels mentioned in caption (deterministic order, no duplicates)
    mentioned = []
    for coco_name in COCO_INSTANCE_CATEGORY_NAMES:
        if coco_name == '__background__' or coco_name in mentioned:
            continue
        if _caption_mentions_label(coco_name, padded, word_set):
            mentioned.append(coco_name)

    frac = len(matched) / max(1, len(detected_names))
    return frac, matched, mentioned

# ---------------------
# 6) n-gram repeat removal (deduplicate repeated chunks)
# ---------------------
def remove_repeated_ngrams(tokens, max_ngram=4):
    """
    Collapse immediately repeated n-grams in a token list.

    tokens: list of strings; it is edited in place and also returned.
    max_ngram: longest n-gram length checked; longer n-grams are tried first.

    Whenever an n-gram is directly followed by an identical copy, the copy is
    removed (e.g. "a b a b c" -> "a b c", "x x x" -> "x"). After any removal the
    search restarts from the longest n-gram, until a full pass changes nothing.
    """
    def _squash(width):
        # Drop every back-to-back duplicate of length `width`; report whether any was dropped.
        removed = False
        pos = 0
        while pos + 2 * width <= len(tokens):
            first = tokens[pos:pos + width]
            second = tokens[pos + width:pos + 2 * width]
            if first == second:
                del tokens[pos + width:pos + 2 * width]
                removed = True
                # stay at `pos` so a third copy is caught as well
            else:
                pos += 1
        return removed

    while True:
        for width in range(max_ngram, 0, -1):
            if _squash(width):
                break  # something changed: start over from the longest n-gram
        else:
            return tokens

def dedupe_and_postprocess_caption(raw_caption):
    """Clean a generated caption by removing stuttered words and phrases.

    The caption is split on whitespace, immediately repeated n-grams (up to
    length 4) are collapsed, any remaining adjacent duplicate words are dropped,
    and the words are re-joined with single spaces.
    """
    words = remove_repeated_ngrams(raw_caption.strip().split(), max_ngram=4)
    kept = []
    for word in words:
        # keep a word only if it differs from the one just kept ('a a a' -> 'a')
        if not kept or kept[-1] != word:
            kept.append(word)
    return " ".join(kept)

# ---------------------
# 7) High-level verified caption generator
# ---------------------
def generate_verified_caption(image_tensor,
                              caption_model,
                              caption_tokenizer,
                              detection_model=None,
                              lm_tokenizer=None,
                              lm_model=None,
                              device='cpu',
                              beam_k=5,
                              alpha_lm=1.0,
                              beta_overlap=2.0,
                              gamma_halluc=1.0):
    """
    Pick the beam-search caption that best agrees with detected objects and an LM.

    Every beam candidate gets a combined score
        alpha_lm * lm_score + beta_overlap * frac_overlap - gamma_halluc * n_hallucinated
    where n_hallucinated counts labels mentioned in the caption that were not
    matched against the detections. The top-scoring candidate is post-processed
    with dedupe_and_postprocess_caption and returned.

    image_tensor: single normalized image [C,H,W] (same preprocessing used for caption_model)
    caption_model: encoder-decoder caption model, passed to beam_search_candidates
    caption_tokenizer: tokenizer used by the caption decoder (for decode)
    detection_model: when not None, object detection is run (see NOTE in the body)
    lm_tokenizer, lm_model: causal LM for grammar scoring (optional; score is 0.0 without it)
    device: device forwarded to the detection, beam-search and LM helpers
    beam_k: beam size for candidate generation
    alpha_lm, beta_overlap, gamma_halluc: weights of the combined score
    Returns: (best_caption, diagnostics dict). The dict holds 'final_text',
        'detected', 'matched', 'mentioned' and 'hallucinated'; when a candidate
        was chosen it also holds 'text', 'tokens', 'beam_score', 'lm_score',
        'frac_overlap' and 'combined'. If there are no candidates the caption
        is "" and the dict additionally has 'error'.
    """
    import warnings  # local import: only needed to report LM scoring failures

    # 1) detect objects
    detected = []
    if detection_model is not None:
        # NOTE(review): detection runs on the module-level `yolo` model; the
        # `detection_model` argument only switches detection on or off here.
        # detected = detect_objects(image_tensor, detection_model, device=device, score_thresh=0.35, topk=15)
        detected = detect_objects_yolo(image_tensor, yolo, device=device)

    # 2) generate beam candidates
    candidates = beam_search_candidates(caption_model, image_tensor, caption_tokenizer,
                                        beam_size=beam_k, max_len=28, device=device)

    # 3) score every candidate
    scored = []
    for txt, token_ids, score in candidates:
        # LM score (higher better); stays 0.0 without an LM or when scoring fails
        lm_s = 0.0
        if lm_model is not None:
            try:
                lm_s = lm_score_text(lm_tokenizer, lm_model, txt, device=device)
            except Exception as e:
                # Best-effort scoring: keep going, but make the failure visible
                # instead of silently ranking the candidate with a zero LM score.
                warnings.warn(
                    f"LM scoring failed for {txt!r} ({e!r}); using lm_score=0.0",
                    RuntimeWarning,
                    stacklevel=2,
                )
        # overlap/hallucination
        frac_overlap, matched, mentioned = caption_object_overlap_score(txt, detected)
        hallucinated_labels = [m for m in mentioned if m not in matched]
        # combined score: weighted sum (tune alpha/beta/gamma)
        combined = alpha_lm * lm_s + beta_overlap * frac_overlap - gamma_halluc * len(hallucinated_labels)
        scored.append({
            'text': txt,
            'tokens': token_ids,
            'beam_score': score,
            'lm_score': lm_s,
            'frac_overlap': frac_overlap,
            'matched': matched,
            'mentioned': mentioned,
            'hallucinated': hallucinated_labels,
            'combined': combined,
        })

    if not scored:
        # Keep the keys callers read on the success path so they can index
        # the diagnostics without special-casing the error.
        return "", {
            'error': 'no candidates',
            'final_text': "",
            'detected': detected,
            'matched': [],
            'mentioned': [],
            'hallucinated': [],
        }

    # 4) pick best by combined score (the first candidate wins ties)
    best = max(scored, key=lambda x: x['combined'])
    best_text = dedupe_and_postprocess_caption(best['text'])
    best['final_text'] = best_text
    best['detected'] = detected
    return best_text, best

# ---------------------
# Example usage (after you have caption_model, tokenizer, device)
# ---------------------
# 1) load detector and LM once (both are reused for every image captioned below)
detection_model = load_detection_model(device=device)   # recommended
lm_tokenizer, lm_model = load_lm(device=device)         # optional (distilgpt2)
#
# 2) run on an image (img0 is normalized tensor [C,H,W] as in your val_loader)
# NOTE(review): img0 is not assigned in this cell; in this part of the notebook
# it is set by the demo cell further down, so make sure it exists before running.
caption, diag = generate_verified_caption(img0, caption_model, tokenizer,
                                          detection_model=detection_model,
                                          lm_tokenizer=lm_tokenizer, lm_model=lm_model,
                                          device=device, beam_k=5)
# diag is the diagnostics dict of the selected candidate; the entries read
# below are the ones set when a best candidate was chosen.
print("Detected objects:", diag['detected'])
print("Matched labels in caption:", diag['matched'])
print("Hallucinated labels:", diag['hallucinated'])
print("Returned caption:", caption)


In [None]:
# Demo: caption the first image of one training batch (greedy vs. beam search).
# Any failure is reported instead of stopping the notebook.
try:
    img_batch, cap_batch = next(iter(train_loader))
    img0 = img_batch[0]

    # Show the image; make_grid(normalize=True) rescales it for display and
    # the transpose moves the first axis last for imshow.
    plt.figure(figsize=(3, 3))
    out = torchvision.utils.make_grid(img0, nrow=1, normalize=True)
    plt.imshow(np.transpose(out.numpy(), (1, 2, 0)))
    plt.axis('off')

    # Reference caption next to both decoding strategies.
    print("Reference:", cap_batch[0])
    print("Greedy:", greedy_decode(caption_model, img0, tokenizer, device=device))
    print("Beam (k=3):", beam_search_decode(caption_model, img0, tokenizer, device=device))
except Exception as err:
    print("Inference demo failed:", err)

# Evaluation – BLEU & CIDEr (optional)

In [None]:
import nltk
# Fetch the NLTK "punkt" and "punkt_tab" resources. evaluate_bleu below calls
# nltk.word_tokenize, which presumably needs this tokenizer data (which of the
# two is required depends on the installed NLTK version — verify).
nltk.download("punkt")
nltk.download("punkt_tab")

In [None]:

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from pycocotools.coco import COCO
from pycocoevalcap.cider.cider import Cider

def evaluate_bleu(loader, model, tokenizer, max_samples=100):
    """
    Mean sentence-level BLEU of greedy captions over at most `max_samples` pairs.

    loader: iterable of (imgs, caps) batches; caps[i] is the reference string for imgs[i]
    model: caption model; it is switched to eval mode here
    tokenizer: forwarded to greedy_decode
    max_samples: scoring stops once this many pairs have been scored
    Returns: the mean BLEU as a float, or 0.0 if nothing was scored.
    """
    model.eval()
    smoothing = SmoothingFunction().method4
    bleu_values = []
    for imgs, caps in tqdm(loader, leave=False):
        for idx, img in enumerate(imgs):
            hypothesis = greedy_decode(model, img, tokenizer, device=device)
            # Both sides are lower-cased and word-tokenized before scoring.
            reference_tokens = nltk.word_tokenize(caps[idx].lower())
            hypothesis_tokens = nltk.word_tokenize(hypothesis.lower())
            bleu_values.append(
                sentence_bleu([reference_tokens], hypothesis_tokens, smoothing_function=smoothing)
            )
            if len(bleu_values) >= max_samples:
                # Sample budget reached: stop in the middle of the batch.
                return float(np.mean(bleu_values))
    return float(np.mean(bleu_values)) if bleu_values else 0.0

def evaluate_cider(val_ann_path, ids_to_caps_pred, max_samples=200):
    """
    CIDEr score of predicted captions against the annotated reference captions.

    val_ann_path: path of the COCO annotation file (loaded with COCO)
    ids_to_caps_pred: mapping image id -> predicted caption string
    max_samples: collection stops once this many images have been added
    Returns: the first value returned by Cider.compute_score, as a float.
    """
    coco = COCO(str(val_ann_path))
    references = {}
    hypotheses = {}

    for seen, (img_id, pred) in enumerate(ids_to_caps_pred.items(), start=1):
        annotations = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
        # References are the raw caption strings; the prediction is wrapped in a list.
        references[img_id] = [ann["caption"] for ann in annotations]
        hypotheses[img_id] = [pred]
        if seen >= max_samples:
            break

    score, _ = Cider().compute_score(references, hypotheses)
    return float(score)



# BLEU quick check: greedy captions on up to 100 samples from val_loader.
# Best-effort: a failure is printed instead of stopping the notebook.
try:
    bleu = evaluate_bleu(val_loader, caption_model, tokenizer, max_samples=100)
    print(f"BLEU (greedy, 100 samples): {bleu:.3f}")
except Exception as e:
    print("BLEU eval failed:", e)

# CIDEr subset on first 200 images: caption each image greedily, then score
# the predictions against the annotation file. Failures are printed as well.
try:
    coco_val = COCO(str(val_ann))
    img_ids = coco_val.getImgIds()[:200]
    ids_to_pred = {}  # image id -> greedy caption
    from PIL import Image
    for img_id in tqdm(img_ids, leave=False):
        # Resolve the image file from its COCO record and preprocess it with val_transform.
        info = coco_val.loadImgs([img_id])[0]
        img_path = (val_images / info['file_name']).as_posix()
        img = Image.open(img_path).convert('RGB')
        img_t = val_transform(img)
        pred = greedy_decode(caption_model, img_t, tokenizer, device=device)
        ids_to_pred[img_id] = pred

    # evaluate_cider builds its own COCO object from the same annotation path.
    cider = evaluate_cider(val_ann, ids_to_pred, max_samples=200)
    print(f"CIDEr (subset): {cider:.3f}")
except Exception as e:
    print("CIDEr eval failed:", e)




# Save/Load Helpers

In [None]:

def save_model(model, path='/content/caption_model.pth'):
    """Write the model's state_dict to `path` with torch.save and print the path."""
    weights = model.state_dict()
    torch.save(weights, path)
    print("Saved:", path)
def load_model(model, path='/content/caption_model.pth', map_location=None):
    """Read a state_dict from `path` via torch.load (honoring `map_location`), load it into `model` and print the path."""
    weights = torch.load(path, map_location=map_location)
    model.load_state_dict(weights)
    print("Loaded:", path)


In [None]:
# Save caption_model's state_dict to save_model's default path.
save_model(caption_model)