In [2]:
import json, re, torch, numpy as np
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
import nltk
from collections import Counter

import torch, torch.nn as nn
from torchvision import models

import argparse, os, pandas as pd
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt


# Special vocabulary tokens shared by the vocabulary, dataset and decoder code.
PAD = "<pad>"  # padding filler for batched sequences
BOS = "<bos>"  # beginning-of-sequence marker
EOS = "<eos>"  # end-of-sequence marker
UNK = "<unk>"  # stand-in for out-of-vocabulary words


# Utils

In [3]:
class Vocabulary:
    """Word-level vocabulary mapping tokens to integer ids and back.

    Ids 0-3 are reserved for the special tokens PAD, BOS, EOS and UNK (in that
    order). Further words are appended by `build` in the order they are first
    seen, provided they occur at least `min_freq` times.

    Attributes:
        min_freq: minimum corpus count for a word to receive its own id.
        word2id:  dict mapping token string -> integer id.
        id2word:  dict mapping integer id -> token string.
    """

    # Everything except ASCII letters, digits, apostrophes and spaces is
    # replaced by a space before tokenisation.
    _NON_WORD = re.compile(r"[^A-Za-z0-9' ]")

    def __init__(self, min_freq=3):
        self.min_freq = min_freq
        self.word2id = {PAD: 0, BOS: 1, EOS: 2, UNK: 3}
        self.id2word = {0: PAD, 1: BOS, 2: EOS, 3: UNK}

    @classmethod
    def _tokenize(cls, text):
        """Lower-case `text`, blank out unsupported characters, and tokenise with NLTK."""
        return nltk.word_tokenize(cls._NON_WORD.sub(" ", text.lower()))

    def build(self, texts):
        """Add every word occurring at least `min_freq` times in `texts` to the vocabulary.

        Args:
            texts: iterable of caption strings.
        """
        counts = Counter()
        for text in texts:
            counts.update(self._tokenize(text))
        for word, count in counts.items():
            if count >= self.min_freq and word not in self.word2id:
                idx = len(self.word2id)
                self.word2id[word] = idx
                self.id2word[idx] = word

    def encode(self, text, max_len=20):
        """Convert `text` into a list of ids framed by BOS and EOS.

        Unknown words map to the UNK id. The word ids are truncated to
        `max_len` *before* BOS/EOS are added, so the result holds at most
        `max_len + 2` ids.
        """
        unk_id = self.word2id[UNK]
        ids = [self.word2id.get(w, unk_id) for w in self._tokenize(text)][:max_len]
        return [self.word2id[BOS]] + ids + [self.word2id[EOS]]

    def decode(self, ids):
        """Turn a sequence of ids back into a space-joined string.

        PAD and BOS are skipped, decoding stops at the first EOS, and ids that
        are not in the vocabulary are rendered as UNK.
        """
        words = []
        for i in ids:
            w = self.id2word.get(int(i), UNK)
            if w in (PAD, BOS):
                continue
            if w == EOS:
                break
            words.append(w)
        return " ".join(words)

    def to_json(self, path):
        """Write `min_freq` and the `word2id` mapping to `path` as UTF-8 JSON."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump({"min_freq": self.min_freq, "word2id": self.word2id}, f)

    @classmethod
    def from_json(cls, path):
        """Rebuild a vocabulary from a file written by `to_json`.

        A missing "min_freq" entry falls back to 3, matching the constructor
        default.
        """
        with open(path, "r", encoding="utf-8") as f:
            obj = json.load(f)
        v = cls(min_freq=obj.get("min_freq", 3))
        v.word2id = obj["word2id"]
        # Invert word -> id. (The previous double swap ended up calling int()
        # on the word itself and raised ValueError on every load.)
        v.id2word = {int(i): w for w, i in v.word2id.items()}
        return v

class CaptionDataset(Dataset):
    """Image/caption pairs for one split of a captions table.

    Args:
        df:          DataFrame with at least the columns "image_path",
                     "caption" and "split".
        images_root: directory prefix prepended to each "image_path".
        vocab:       object exposing `encode(text, max_len=...)` that returns a
                     list of token ids.
        split:       value of the "split" column to keep.
        max_len:     forwarded to `vocab.encode`.
        image_size:  images are resized to (image_size, image_size).
    """

    def __init__(self, df, images_root, vocab, split="train", max_len=20, image_size=224):
        self.df = df[df["split"] == split].reset_index(drop=True)
        self.images_root = images_root
        self.vocab = vocab
        self.max_len = max_len
        # Resize -> tensor -> per-channel normalisation with the mean/std below.
        self.tf = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        """Number of rows in the selected split."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return `(image_tensor, caption_ids)` for row `idx`.

        `caption_ids` is a 1-D LongTensor built from `vocab.encode`.
        """
        row = self.df.iloc[idx]
        img_path = f"{self.images_root}/{row['image_path']}"
        # Close the underlying file deterministically instead of waiting for
        # garbage collection; convert() returns an independent in-memory copy.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        img = self.tf(img)
        ids = self.vocab.encode(row["caption"], max_len=self.max_len)
        return img, torch.tensor(ids, dtype=torch.long)

def pad_collate(batch):
    """Collate `(image, token_ids)` pairs into batched tensors.

    Args:
        batch: sequence of `(image_tensor, ids)` pairs, e.g. the items yielded
            by `CaptionDataset`.

    Returns:
        A tuple `(images, padded_ids, lengths)`: the images stacked along a new
        leading dimension, a LongTensor of ids right-padded with 0 up to the
        longest sequence in the batch, and a LongTensor of the original
        sequence lengths.
    """
    images, sequences = zip(*batch)
    stacked = torch.stack(images, dim=0)

    pad_id = 0
    seq_lens = [len(seq) for seq in sequences]
    padded = torch.full((len(sequences), max(seq_lens)), pad_id, dtype=torch.long)
    # Each `row` is a view into `padded`, so the slice assignment fills it in place.
    for row, seq in zip(padded, sequences):
        row[: len(seq)] = seq

    return stacked, padded, torch.tensor(seq_lens, dtype=torch.long)

def compute_bleu(gens, refs, n=4):
    """Corpus-level BLEU of generated captions against one reference each.

    Args:
        gens: list of generated caption strings.
        refs: list of reference caption strings, aligned with `gens`.
        n:    highest n-gram order (1-4); any other value falls back to the
              BLEU-4 weights.

    Returns:
        The score from NLTK's `corpus_bleu`, or 0.0 if it raises
        ZeroDivisionError.
    """
    import nltk

    # Weight vectors always have four slots; orders above `n` get weight 0.
    if n == 1:
        weights = (1.0, 0, 0, 0)
    elif n == 2:
        weights = (0.5, 0.5, 0, 0)
    elif n == 3:
        weights = (1/3, 1/3, 1/3, 0)
    else:
        weights = (0.25, 0.25, 0.25, 0.25)

    # corpus_bleu expects a *list* of references per hypothesis, hence the
    # extra nesting around each tokenised reference.
    references = []
    for ref in refs:
        references.append([nltk.word_tokenize(ref.lower())])
    hypotheses = []
    for gen in gens:
        hypotheses.append(nltk.word_tokenize(gen.lower()))

    try:
        score = nltk.translate.bleu_score.corpus_bleu(references, hypotheses, weights=weights)
    except ZeroDivisionError:
        score = 0.0
    return score

# Models

In [4]:
class EncoderCNN(nn.Module):
    """Frozen ResNet-50 feature extractor followed by a trainable projection.

    The pretrained backbone (all children except its final `fc` layer) has
    `requires_grad=False` on every parameter and is run under
    `torch.no_grad()`. Only `fc` (Linear) and `bn` (BatchNorm1d) are meant to
    be trained.

    Args:
        embed_dim: size of the output feature vector per image.
    """

    def __init__(self, embed_dim=256):
        super().__init__()
        backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        for p in backbone.parameters():
            p.requires_grad = False
        # Drop the classification head; keep everything up to the pooled features.
        self.cnn = nn.Sequential(*list(backbone.children())[:-1])
        self.cnn.eval()
        self.fc = nn.Linear(backbone.fc.in_features, embed_dim)
        self.bn = nn.BatchNorm1d(embed_dim, momentum=0.01)

    def train(self, mode=True):
        """Switch `fc`/`bn` between train and eval mode; the backbone stays in eval.

        BatchNorm layers update their running statistics on every forward pass
        made in training mode, even under `no_grad`. Pinning the backbone to
        eval mode keeps the "frozen" network truly unchanged.
        """
        super().train(mode)
        self.cnn.eval()
        return self

    def forward(self, images):
        """Map a batch of images to `(batch, embed_dim)` features (after ReLU)."""
        with torch.no_grad():
            # flatten(1) keeps the batch dimension even when the batch has one
            # image, unlike a bare squeeze().
            feats = torch.flatten(self.cnn(images), start_dim=1)
        feats = self.fc(feats)
        feats = self.bn(feats)
        return torch.relu(feats)

class DecoderLSTM(nn.Module):
    """LSTM caption decoder conditioned on an image feature vector.

    The image feature is fed to the LSTM as the first time step, followed by
    the embedded caption tokens.

    Args:
        vocab_size: number of token ids (size of the embedding table and of the
                    output layer).
        embed_dim:  embedding size; image features must have the same size
                    because they are concatenated with the embeddings.
        hidden_dim: LSTM hidden size.
        num_layers: number of stacked LSTM layers.
        dropout:    dropout between LSTM layers; only applied when
                    `num_layers > 1`.
    """

    def __init__(self, vocab_size, embed_dim=256, hidden_dim=512, num_layers=1, dropout=0.1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        # nn.LSTM applies dropout only between layers and warns when it is
        # non-zero for a single-layer network, so pass 0 in that case.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers,
                            dropout=lstm_dropout, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, features, captions):
        """Teacher-forced pass.

        Args:
            features: `(batch, embed_dim)` image features.
            captions: `(batch, T)` token ids.

        Returns:
            Logits of shape `(batch, T + 1, vocab_size)`; position 0 comes from
            the image feature step.
        """
        emb = self.embed(captions)
        feats = features.unsqueeze(1)
        x = torch.cat([feats, emb], dim=1)
        out, _ = self.lstm(x)
        logits = self.fc(out)
        return logits

    def sample(self, features, max_len=20, bos_id=1, eos_id=2):
        """Greedy decoding starting from the image feature.

        Args:
            features: `(batch, embed_dim)` image features.
            max_len:  maximum number of tokens to generate.
            bos_id:   accepted for interface compatibility; not used, because
                      decoding is seeded by the image feature rather than by a
                      BOS embedding.
            eos_id:   id that marks a finished sequence.

        Returns:
            LongTensor of shape `(batch, steps)` with `steps <= max_len`.
            Generation stops as soon as every sequence has produced `eos_id`
            at least once. Tokens after a sequence's first EOS are not
            meaningful and should be ignored by the caller.
        """
        batch = features.size(0)
        inputs = features.unsqueeze(1)
        states = None
        # Track completion per sequence so the loop can stop once all are done,
        # not only when they happen to emit EOS on the same step.
        finished = torch.zeros(batch, dtype=torch.bool, device=features.device)
        outputs = []
        for _ in range(max_len):
            out, states = self.lstm(inputs, states)
            logits = self.fc(out[:, -1, :])
            next_ids = logits.argmax(dim=1)
            outputs.append(next_ids)
            finished |= next_ids == eos_id
            if finished.all():
                break
            inputs = self.embed(next_ids).unsqueeze(1)
        if not outputs:
            # max_len <= 0: return an empty batch of sequences; stack() would fail.
            return torch.empty((batch, 0), dtype=torch.long, device=features.device)
        return torch.stack(outputs, dim=1)

# Train

In [None]:
def plot_curves(history, outpath):
    """Save a line plot of the training and validation loss per epoch.

    Args:
        history: mapping with "train_loss" and "val_loss" sequences.
        outpath: destination image path.
    """
    fig, ax = plt.subplots(figsize=(7,5))
    for series in ("train_loss", "val_loss"):
        ax.plot(history[series], label=series)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss (CE)")
    ax.set_title("Training & Validation Loss")
    ax.legend()
    fig.tight_layout()
    fig.savefig(outpath, dpi=160)
    # Release the figure; this function is called once per epoch.
    plt.close(fig)

def plot_bleu(bleus, outpath):
    """Save a line plot of the validation BLEU-4 score per epoch.

    Args:
        bleus:   sequence of BLEU scores, one per epoch.
        outpath: destination image path.
    """
    fig, ax = plt.subplots(figsize=(7,5))
    ax.plot(bleus, label="BLEU-4")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("BLEU")
    ax.set_title("Validation BLEU-4")
    ax.legend()
    fig.tight_layout()
    fig.savefig(outpath, dpi=160)
    # Release the figure; this function is called once per epoch.
    plt.close(fig)

def main(argv=None):
    """Train the CNN encoder / LSTM decoder captioner and track validation BLEU-4.

    Reads a captions CSV, builds and saves the vocabulary, trains for the
    requested number of epochs, and writes into `--outdir`: `vocab.json`,
    `best_captioner.pt` (the checkpoint with the highest validation BLEU-4),
    `training_curves.png`, `bleu_scores.png` and `metrics.json`.

    Args:
        argv: optional list of command-line argument strings. When None,
            argparse reads `sys.argv[1:]` as before. Passing an explicit list
            lets the function be called from a notebook cell, e.g.
            `main(["--captions", "captions.csv"])`.

    Raises:
        ValueError: if the CSV has no rows with split == "train".

    NOTE: a later cell of this notebook defines another function named `main`
    (inference); whichever cell ran last owns the name.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--captions", required=True, help="CSV with columns: image_path, caption, split")
    ap.add_argument("--images-root", type=str, default="data")
    ap.add_argument("--outdir", type=str, default="outputs")
    ap.add_argument("--epochs", type=int, default=10)
    ap.add_argument("--batch-size", type=int, default=64)
    ap.add_argument("--embed-dim", type=int, default=256)
    ap.add_argument("--hidden-dim", type=int, default=512)
    ap.add_argument("--min-freq", type=int, default=3)
    ap.add_argument("--max-len", type=int, default=20)
    ap.add_argument("--lr", type=float, default=1e-3)
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args(argv)

    os.makedirs(args.outdir, exist_ok=True)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # --- data -----------------------------------------------------------
    df = pd.read_csv(args.captions)
    vocab = Vocabulary(min_freq=args.min_freq)
    # The vocabulary is built from training captions only.
    vocab.build(df[df["split"] == "train"]["caption"].tolist())
    vocab.to_json(os.path.join(args.outdir, "vocab.json"))

    train_ds = CaptionDataset(df, args.images_root, vocab, split="train", max_len=args.max_len)
    val_ds = CaptionDataset(df, args.images_root, vocab, split="val", max_len=args.max_len)
    if len(train_ds) == 0:
        raise ValueError("No rows with split == 'train' in the captions CSV")

    # The encoder's BatchNorm1d cannot normalise a single sample in training
    # mode, so drop a trailing batch that would contain exactly one item.
    drop_last = len(train_ds) > 1 and len(train_ds) % args.batch_size == 1
    train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                          collate_fn=pad_collate, num_workers=2, drop_last=drop_last)
    val_dl = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False,
                        collate_fn=pad_collate, num_workers=2)

    # --- model ----------------------------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    enc = EncoderCNN(embed_dim=args.embed_dim).to(device)
    dec = DecoderLSTM(vocab_size=len(vocab.word2id), embed_dim=args.embed_dim, hidden_dim=args.hidden_dim).to(device)

    pad_id, bos_id, eos_id = vocab.word2id[PAD], vocab.word2id[BOS], vocab.word2id[EOS]
    # Padded target positions do not contribute to the loss.
    crit = nn.CrossEntropyLoss(ignore_index=pad_id)
    # Only the decoder and the encoder's projection head are optimised.
    params = list(dec.parameters()) + list(enc.fc.parameters()) + list(enc.bn.parameters())
    opt = torch.optim.Adam(params, lr=args.lr)

    history = {"train_loss": [], "val_loss": [], "bleu4": []}
    best_bleu = -1.0
    for ep in range(1, args.epochs + 1):
        # --- train ------------------------------------------------------
        enc.train()
        dec.train()
        tr_loss, n = 0.0, 0
        for imgs, tgt, _ in tqdm(train_dl, desc=f"Epoch {ep}/{args.epochs} [train]"):
            imgs, tgt = imgs.to(device), tgt.to(device)
            opt.zero_grad()
            feats = enc(imgs)
            # The decoder prepends the image feature as step 0, so feeding
            # tgt[:, :-1] yields one logit per position of tgt.
            logits = dec(feats, tgt[:, :-1])
            loss = crit(logits.reshape(-1, logits.size(-1)), tgt.reshape(-1))
            loss.backward()
            opt.step()
            tr_loss += loss.item() * imgs.size(0)
            n += imgs.size(0)
        tr = tr_loss / n

        # --- validate ---------------------------------------------------
        enc.eval()
        dec.eval()
        va_loss, vn = 0.0, 0
        gens, refs = [], []
        with torch.no_grad():
            for imgs, tgt, _ in tqdm(val_dl, desc=f"Epoch {ep}/{args.epochs} [val]"):
                imgs, tgt = imgs.to(device), tgt.to(device)
                feats = enc(imgs)
                logits = dec(feats, tgt[:, :-1])
                loss = crit(logits.reshape(-1, logits.size(-1)), tgt.reshape(-1))
                va_loss += loss.item() * imgs.size(0)
                vn += imgs.size(0)
                out_ids = dec.sample(feats, max_len=args.max_len, bos_id=bos_id, eos_id=eos_id)
                for i in range(out_ids.size(0)):
                    gens.append(vocab.decode(out_ids[i].cpu().numpy()))
                    refs.append(vocab.decode(tgt[i, 1:].cpu().numpy()))
        # An empty validation split yields NaN loss instead of a ZeroDivisionError.
        vl = va_loss / vn if vn else float("nan")
        bleu4 = compute_bleu(gens, refs, n=4)

        history["train_loss"].append(tr)
        history["val_loss"].append(vl)
        history["bleu4"].append(bleu4)
        print(f"[epoch {ep}] train_loss={tr:.4f} val_loss={vl:.4f} bleu4={bleu4:.4f}")
        if bleu4 > best_bleu:
            best_bleu = bleu4
            torch.save({"encoder": enc.state_dict(), "decoder": dec.state_dict(),
                        "vocab_size": len(vocab.word2id), "embed_dim": args.embed_dim,
                        "hidden_dim": args.hidden_dim, "max_len": args.max_len},
                       os.path.join(args.outdir, "best_captioner.pt"))

        # Plots are refreshed every epoch so progress can be inspected mid-run.
        plot_curves(history, os.path.join(args.outdir, "training_curves.png"))
        plot_bleu(history["bleu4"], os.path.join(args.outdir, "bleu_scores.png"))

    with open(os.path.join(args.outdir, "metrics.json"), "w") as f:
        json.dump({"best_bleu4": best_bleu}, f, indent=2)
    print("[OK] Training done. Best BLEU-4:", best_bleu)

if __name__ == "__main__":
    main()

# Infer

In [None]:
def load_vocab(path):
    """Load a `Vocabulary` from a JSON file holding "word2id" and optionally "min_freq".

    Args:
        path: path to the UTF-8 JSON vocabulary file.

    Returns:
        A `Vocabulary` whose `word2id` is the stored mapping and whose
        `id2word` is its inverse with integer keys. `min_freq` defaults to 3
        when it is absent from the file.
    """
    with open(path, "r", encoding="utf-8") as fh:
        payload = json.load(fh)

    vocab = Vocabulary(min_freq=payload.get("min_freq", 3))
    vocab.word2id = payload["word2id"]

    # Invert word -> id into id -> word.
    inverse = {}
    for word, idx in vocab.word2id.items():
        inverse[int(idx)] = word
    vocab.id2word = inverse
    return vocab

def main(argv=None):
    """Caption a single image with a trained checkpoint and print the result.

    Args:
        argv: optional list of command-line argument strings. When None,
            argparse reads `sys.argv[1:]` as before. Passing an explicit list
            lets the function be called from a notebook cell.

    Raises:
        ValueError: if the vocabulary file does not have the number of entries
            the checkpoint was trained with.

    NOTE: this redefines the `main` declared in the training cell of this
    notebook; whichever cell ran last owns the name.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", required=True)
    ap.add_argument("--vocab", required=True)
    ap.add_argument("--image", required=True)
    ap.add_argument("--max-len", type=int, default=None,
                    help="maximum tokens to generate (default: the checkpoint's max_len, else 20)")
    args = ap.parse_args(argv)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    ckpt = torch.load(args.checkpoint, map_location=device)
    vocab = load_vocab(args.vocab)
    if len(vocab.word2id) != ckpt["vocab_size"]:
        # A mismatched vocabulary would silently decode to the wrong words.
        raise ValueError(
            f"Vocabulary has {len(vocab.word2id)} entries but the checkpoint "
            f"expects {ckpt['vocab_size']}"
        )
    # Prefer the generation length stored with the checkpoint unless the user
    # overrides it on the command line.
    max_len = args.max_len if args.max_len is not None else ckpt.get("max_len", 20)

    # EncoderCNN's constructor loads pretrained ResNet weights first; the
    # checkpoint's state dict then overwrites them.
    enc = EncoderCNN(embed_dim=ckpt["embed_dim"]).to(device)
    dec = DecoderLSTM(vocab_size=ckpt["vocab_size"], embed_dim=ckpt["embed_dim"], hidden_dim=ckpt["hidden_dim"]).to(device)

    enc.load_state_dict(ckpt["encoder"])
    dec.load_state_dict(ckpt["decoder"])
    enc.eval()
    dec.eval()

    # Same preprocessing as CaptionDataset with its default image_size of 224.
    tf = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
    ])
    # Close the image file deterministically once it has been converted.
    with Image.open(args.image) as raw:
        img = tf(raw.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        feats = enc(img)
        ids = dec.sample(feats, max_len=max_len, bos_id=vocab.word2id["<bos>"], eos_id=vocab.word2id["<eos>"])
    caption = vocab.decode(ids[0].cpu().numpy())
    print(caption)

if __name__ == "__main__":
    main()