In [12]:
# ========================== 1) Imports & Config ==========================
import os, zipfile, tempfile, requests, time, random
from collections import Counter
from PIL import Image

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from torchvision.models.feature_extraction import create_feature_extractor
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import nltk
from nltk.tokenize import word_tokenize
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

# word_tokenize() needs the Punkt sentence-splitter data. Older NLTK releases
# ship it as the 'punkt' package, while recent releases load 'punkt_tab'
# instead, so request both. With quiet=True an unknown or unreachable
# resource does not raise, so asking for both is harmless.
for _punkt_resource in ('punkt', 'punkt_tab'):
    nltk.download(_punkt_resource, quiet=True)

class Config:
    """Namespace of notebook-wide settings; read as class attributes, never instantiated."""

    # --- optimisation ---
    BATCH_SIZE = 32        # kept small so training is feasible on CPU
    NUM_EPOCHS = 20
    LR = 1e-4

    # --- decoder architecture ---
    EMBED_SIZE = 512
    HIDDEN_SIZE = 512
    DROPOUT = 0.3

    # --- vocabulary ---
    VOCAB_THRESHOLD = 5    # passed to Vocabulary as its minimum word count

    # --- hardware ---
    DEVICE = torch.device('cpu')   # CPU is forced on purpose

    # --- dataset locations (filled in later by download_flickr8k()) ---
    IMAGES_DIR = None
    CAPTIONS_FILE = None
    CACHED_FEATURES_DIR = None


print("Using device:", Config.DEVICE)

Using device: cpu


In [2]:
# ========================== 2) Vocab ==========================
class Vocabulary:
    """Two-way mapping between caption tokens and consecutive integer ids.

    The four special tokens '<pad>', '<start>', '<end>' and '<unk>' are
    registered first, so they always receive ids 0, 1, 2 and 3 respectively.
    """

    def __init__(self, threshold=5):
        # Minimum number of occurrences a word needs to be added by build().
        self.threshold = threshold
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0  # next free id
        for special in ('<pad>', '<start>', '<end>', '<unk>'):
            self.add_word(special)

    def add_word(self, w):
        """Register `w` under the next free id; a word already present is left untouched."""
        if w in self.word2idx:
            return
        new_id = self.idx
        self.word2idx[w] = new_id
        self.idx2word[new_id] = w
        self.idx = new_id + 1

    def build(self, captions_list):
        """Tokenise every caption (lower-cased) and add each word seen at least `threshold` times."""
        freq = Counter()
        for caption in tqdm(captions_list, desc="Building Vocabulary"):
            for token in word_tokenize(caption.lower()):
                freq[token] += 1
        # Counter keeps first-seen order, so ids are assigned in that order.
        frequent = [token for token, count in freq.items() if count >= self.threshold]
        for token in frequent:
            self.add_word(token)
        print(f"Vocabulary size: {len(self)}")

    def __call__(self, w):
        """Return the id of `w`, or the id of '<unk>' for an unregistered word."""
        unk_id = self.word2idx['<unk>']
        return self.word2idx.get(w, unk_id)

    def __len__(self):
        """Number of registered words, special tokens included."""
        return len(self.word2idx)

In [3]:
# ========================== 3) Encoder (ResNet-50 -> 2048) ==========================
class EncoderCNN(nn.Module):
    """Frozen ImageNet-pretrained ResNet-50 that maps an image batch to one vector per image.

    The output of the network's `layer4` node is average-pooled to 1x1 and
    flattened, giving a (B, 2048) tensor as noted in the shape comments below.
    """

    def __init__(self):
        super().__init__()
        pretrained = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        # Expose only the `layer4` activation, under the key 'feat'.
        self.backbone = create_feature_extractor(pretrained, return_nodes={'layer4': 'feat'})
        # The backbone is used as a fixed feature extractor: no gradients needed.
        self.backbone.requires_grad_(False)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, images):
        """Return pooled `layer4` features for `images`."""
        feature_map = self.backbone(images)['feat']   # (B, 2048, H, W)
        pooled = self.pool(feature_map)               # (B, 2048, 1, 1)
        return pooled.flatten(1)                      # (B, 2048)

In [4]:
# ========================== 4) Captions parser ==========================
def parse_captions(captions_file):
    """Parse a Flickr8k-style token file into per-image caption lists.

    Each useful line looks like ``<image>.jpg#<n><TAB><caption>``; a single
    space is accepted as the separator when the line contains no tab.
    Blank lines, lines without a separator, and entries whose image name
    does not end in ``.jpg`` (case-insensitive) are ignored.

    The file is decoded explicitly as UTF-8 (a leading BOM is tolerated and
    undecodable bytes are replaced) so the result does not depend on the
    platform's default text encoding.

    Args:
        captions_file: path to the captions text file.

    Returns:
        (ann, image_ids): `ann` maps image file name -> list of caption
        strings in file order; `image_ids` lists the keys of `ann` in
        first-seen order.
    """
    ann = {}
    with open(captions_file, 'r', encoding='utf-8-sig', errors='replace') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # Prefer the tab separator; fall back to the first space.
            image_caption_id, sep, caption = line.partition('\t')
            if not sep:
                image_caption_id, sep, caption = line.partition(' ')
                if not sep:
                    continue
            image_caption_id, caption = image_caption_id.strip(), caption.strip()
            # "name.jpg#3" -> "name.jpg"
            img_id = image_caption_id.split('#')[0]
            if not img_id.lower().endswith('.jpg'):
                continue
            ann.setdefault(img_id, []).append(caption)
    image_ids = list(ann.keys())
    return ann, image_ids

# ========================== 5) Simple feature caching (no DataLoader) ==========================
def cache_features_naive(image_ids, images_dir, out_dir, transform):
    """Encode each image once with EncoderCNN and store the feature vector on disk.

    For every id in `image_ids` the image `<images_dir>/<img_id>` is loaded,
    converted to RGB, passed through `transform` and the encoder, and the
    result is saved to `<out_dir>/<img_id with '.jpg' replaced by '.pt'>`.
    Images whose feature file already exists are skipped, so the function
    can be re-run to resume.

    Unreadable or missing images are still skipped (best effort), but they
    are now recorded and reported instead of being silently ignored,
    because any later code that loads the missing feature file will fail.

    Args:
        image_ids: iterable of image file names.
        images_dir: directory containing the image files.
        out_dir: directory that receives the `.pt` feature files (created if needed).
        transform: callable applied to the RGB PIL image; it must return a
            tensor the encoder accepts once a batch dimension is added.

    Returns:
        List of `(img_id, error_repr)` tuples for images that could not be
        processed (empty when everything succeeded).
    """
    os.makedirs(out_dir, exist_ok=True)
    enc = EncoderCNN().to('cpu').eval()
    failed = []
    for img_id in tqdm(image_ids, desc="Caching features (simple loop)"):
        out_path = os.path.join(out_dir, img_id.replace('.jpg', '.pt'))
        if os.path.exists(out_path):
            continue
        img_path = os.path.join(images_dir, img_id)
        # Write to a side file first and rename afterwards, so an interrupted
        # save never leaves a truncated file that a later run would treat as done.
        part_path = out_path + '.part'
        try:
            with Image.open(img_path) as im:
                x = transform(im.convert('RGB')).unsqueeze(0)  # (1,3,224,224)
            with torch.inference_mode():
                feat = enc(x).squeeze(0).cpu()                 # (2048,)
            torch.save(feat, part_path)
            os.replace(part_path, out_path)
        except Exception as e:
            # Skip corrupt/missing images, but remember which ones and why.
            failed.append((img_id, repr(e)))
            if os.path.exists(part_path):
                os.remove(part_path)
            continue
    if failed:
        print(f"WARNING: could not cache features for {len(failed)} image(s); "
              f"first failure: {failed[0][0]}: {failed[0][1]}")
    return failed

# ========================== 6) Dataset that reads cached features ==========================
class Flickr8kCached(Dataset):
    """Dataset of (cached image feature, tokenised caption) pairs.

    One sample exists per caption. An explicit (image id, caption index)
    table is built at construction time from the real number of captions
    of every image, rather than assuming exactly five captions per image.
    When every image has five captions the sample order is unchanged:
    sample `i` is caption `i % 5` of image `i // 5`.

    Args:
        image_ids: list of image file names; each must be a key of `annotations`.
        annotations: dict mapping image file name -> list of caption strings.
        vocab: callable mapping a token string to its integer id.
        cached_dir: directory holding one `.pt` feature file per image, named
            like the image with '.jpg' replaced by '.pt'.

    Note: the sample table is computed once in `__init__`; changes made to
    `image_ids` or `annotations` afterwards are not picked up.
    """

    def __init__(self, image_ids, annotations, vocab, cached_dir):
        self.image_ids = image_ids
        self.annotations = annotations
        self.vocab = vocab
        self.cached_dir = cached_dir
        # Flat list of every (image id, caption index) pair, image by image.
        self._samples = [
            (img_id, cap_i)
            for img_id in image_ids
            for cap_i in range(len(annotations[img_id]))
        ]

    def __len__(self):
        """Total number of captions across all images."""
        return len(self._samples)

    def __getitem__(self, idx):
        """Return `(feat, seq)`: the loaded feature tensor and a 1-D LongTensor of
        token ids wrapped in the '<start>' and '<end>' ids."""
        img_id, cap_i = self._samples[idx]
        feat_path = os.path.join(self.cached_dir, img_id.replace('.jpg', '.pt'))
        feat = torch.load(feat_path)        # (2048,)
        # tokenise caption
        tokens = word_tokenize(self.annotations[img_id][cap_i].lower())
        seq = [self.vocab('<start>')] + [self.vocab(t) for t in tokens] + [self.vocab('<end>')]
        return feat, torch.tensor(seq, dtype=torch.long)

def collate_fn(batch):
    """Merge (feature, caption) samples into padded batch tensors.

    The batch list is reordered in place, longest caption first. Returns
    `(feats, targets, lengths)`: the stacked features, a LongTensor of
    captions right-padded with 0, and a LongTensor with each caption's
    true length (same order as the rows).
    """
    # Negated length + ascending stable sort == descending order, ties keep input order.
    batch.sort(key=lambda sample: -len(sample[1]))
    feats, caps = zip(*batch)              # feats: tuple of per-image feature tensors
    lengths = [len(cap) for cap in caps]

    targets = torch.zeros((len(caps), max(lengths)), dtype=torch.long)
    for row, cap in zip(targets, caps):
        # `row` is a view into `targets`, so this writes the caption in place.
        row[:len(cap)] = cap

    return torch.stack(feats, 0), targets, torch.tensor(lengths, dtype=torch.long)

In [5]:
# ========================== 7) GRU Decoder (train + step for inference) ==========================
class DecoderWithGRU(nn.Module):
    """Caption decoder: a single-layer GRU whose initial state is a linear map of the image feature.

    `forward` runs a whole (padded) caption batch for training; `step`
    advances one token at a time for decoding.
    """

    def __init__(self, embed_dim, hidden_dim, vocab_size, encoder_dim=2048, dropout=0.3):
        super().__init__()
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
        self.drop = nn.Dropout(p=dropout)
        self.init_h = nn.Linear(in_features=encoder_dim, out_features=hidden_dim)
        self.gru = nn.GRU(input_size=embed_dim, hidden_size=hidden_dim, batch_first=True)
        self.fc = nn.Linear(in_features=hidden_dim, out_features=vocab_size)
        self.reset_params()

    def reset_params(self):
        """Re-initialise embedding and output weights uniformly in [-0.1, 0.1] and zero the output bias."""
        nn.init.uniform_(self.embed.weight, -0.1, 0.1)
        nn.init.uniform_(self.fc.weight, -0.1, 0.1)
        nn.init.zeros_(self.fc.bias)

    def init_hidden_state(self, enc_out):   # enc_out: (B,2048) or (B,1,2048)
        """Project image features to the GRU's initial hidden state, shape (1,B,H)."""
        feats = enc_out.squeeze(1) if enc_out.dim() == 3 else enc_out
        return self.init_h(feats).unsqueeze(0)

    # single decoding step (used at inference time)
    def step(self, word_idx, h):
        """Feed one token per sequence (`word_idx`: (B,) Long) and return `(logits (B,V), new hidden state)`."""
        token_emb = self.embed(word_idx)[:, None, :]   # (B,1,E)
        gru_out, h_new = self.gru(token_emb, h)        # gru_out: (B,1,H)
        return self.fc(gru_out[:, 0, :]), h_new        # logits: (B,V)

    # full-sequence pass over packed captions (used for training/validation)
    def forward(self, enc_feats, caps, caplens):
        """Run the decoder over padded captions.

        The batch is re-sorted by decreasing caption length before packing.
        Returns `(logits, caps, decode_lengths, sort_idx)` where `logits` is
        (B,T,V) in the sorted order, `caps` is the sorted caption tensor,
        `decode_lengths` is each sorted length minus one, and `sort_idx` is
        the permutation that was applied.
        """
        feats = enc_feats.unsqueeze(1) if enc_feats.dim() == 2 else enc_feats
        lengths = caplens.squeeze(-1) if caplens.dim() > 1 else caplens

        # longest caption first, as required by enforce_sorted=True below
        lengths, sort_idx = torch.sort(lengths, dim=0, descending=True)
        feats = feats[sort_idx]
        caps = caps[sort_idx]

        h0 = self.init_hidden_state(feats)             # (1,B,H)
        embedded = self.drop(self.embed(caps))         # (B,T,E)

        rnn_utils = nn.utils.rnn
        packed_in = rnn_utils.pack_padded_sequence(
            embedded, lengths.cpu(), batch_first=True, enforce_sorted=True
        )
        packed_out, _ = self.gru(packed_in, h0)
        hidden_seq, _ = rnn_utils.pad_packed_sequence(packed_out, batch_first=True)  # (B,T,H)

        logits = self.fc(hidden_seq)                   # (B,T,V)
        return logits, caps, (lengths - 1).tolist(), sort_idx

In [6]:
# ========================== 8) Beam search & BLEU on cached features ==========================
def beam_search(decoder, img_feat, vocab, beam_size=5, max_len=20):
    """Decode one caption with beam search.

    Args:
        decoder: object exposing `init_hidden_state(enc)` and
            `step(word_idx, h) -> (logits, h_new)` (see DecoderWithGRU).
        img_feat: feature tensor for ONE image with a leading batch
            dimension of 1, e.g. (1, 2048).
        vocab: object whose `word2idx` dict contains '<start>' and '<end>'.
        beam_size: number of hypotheses kept after every step (>= 1).
        max_len: maximum number of decoding steps.

    Returns:
        List of token ids of the highest-scoring hypothesis (sum of token
        log-probabilities, no length normalisation). The '<start>' id is
        not included; the '<end>' id is included when the hypothesis
        finished within `max_len` steps.

    Raises:
        ValueError: if `beam_size` is smaller than 1.
    """
    if beam_size < 1:
        raise ValueError(f"beam_size must be >= 1, got {beam_size}")
    start = vocab.word2idx['<start>']
    end = vocab.word2idx['<end>']

    # Decoding never needs gradients; the guard keeps standalone calls from
    # building an autograd graph through every hidden state held in the beam.
    with torch.no_grad():
        enc = img_feat.unsqueeze(1)                    # (1,1,2048)
        h = decoder.init_hidden_state(enc)             # (1,1,H)
        # Create input tokens on the same device as the hidden state so the
        # search also works when the decoder does not live on the CPU.
        device = h.device

        beams = [([], h, 0.0)]                         # (seq, h, logp)
        for _ in range(max_len):
            cand = []
            for seq, h_cur, score in beams:
                # Finished hypotheses are carried over unchanged.
                if seq and seq[-1] == end:
                    cand.append((seq, h_cur, score))
                    continue
                last = torch.tensor([seq[-1] if seq else start], dtype=torch.long, device=device)
                logits, h_next = decoder.step(last, h_cur)
                logp = torch.log_softmax(logits, dim=-1)
                # topk raises when asked for more entries than the vocabulary holds.
                k = min(beam_size, logp.size(-1))
                topk_lp, topk_idx = torch.topk(logp, k, dim=-1)
                # One transfer per step instead of one .item() call per candidate.
                for tok, lp in zip(topk_idx[0].tolist(), topk_lp[0].tolist()):
                    cand.append((seq + [tok], h_next, score + lp))
            beams = sorted(cand, key=lambda x: x[2], reverse=True)[:beam_size]
            if all(s and s[-1] == end for s, _, _ in beams):
                break
    return beams[0][0]

def evaluate_bleu_features(decoder, val_dataset, vocab, beam_size=5, max_len=20):
    """Mean sentence-level BLEU of beam-search captions over a dataset's images.

    For every id in `val_dataset.image_ids` the cached feature is loaded from
    `val_dataset.cached_dir`, a caption is decoded with `beam_search`, the
    '<start>', '<end>' and '<pad>' ids are removed, and the result is scored
    against all lower-cased, tokenised reference captions of that image
    using NLTK's `sentence_bleu` with smoothing method 4.

    Images without a cached feature file are skipped and counted rather
    than aborting the evaluation (the caching step skips unreadable images).

    Side effect: the decoder is switched to eval mode and left there.

    Returns:
        The mean score as a float, or 0.0 when no image could be scored.
    """
    sf = SmoothingFunction()
    decoder.eval()
    # Loop-invariant: ids that must not appear in the predicted word list.
    special_ids = {vocab.word2idx[tok] for tok in ('<start>', '<end>', '<pad>')}
    scores = []
    missing = 0
    with torch.no_grad():
        for img_id in tqdm(val_dataset.image_ids, desc="BLEU eval (cached)"):
            feat_path = os.path.join(val_dataset.cached_dir, img_id.replace('.jpg', '.pt'))
            if not os.path.exists(feat_path):
                missing += 1
                continue
            feat = torch.load(feat_path).unsqueeze(0)
            seq = beam_search(decoder, feat, vocab, beam_size=beam_size, max_len=max_len)
            pred = [vocab.idx2word[i] for i in seq if i not in special_ids]
            refs = [word_tokenize(c.lower()) for c in val_dataset.annotations[img_id]]
            scores.append(sentence_bleu(refs, pred, smoothing_function=sf.method4))
    if missing:
        print(f"WARNING: BLEU skipped {missing} image(s) without cached features")
    return float(np.mean(scores)) if scores else 0.0

In [7]:
# ========================== 9) Download Flickr8k ==========================
def download_flickr8k(root=None):
    """Download and unpack Flickr8k, then point `Config` at the extracted data.

    Args:
        root: optional directory in which to keep the archives, the
            extracted data and the feature cache. When omitted (the
            default, and the original behaviour) a fresh temporary
            directory is created, which means everything is downloaded
            again on every call. Passing a persistent directory lets
            later calls reuse archives that are already present.

    Side effects:
        Sets `Config.CAPTIONS_FILE`, `Config.IMAGES_DIR` and
        `Config.CACHED_FEATURES_DIR`, and creates the cache directory.

    Returns:
        The directory that holds the data.

    Raises:
        requests.HTTPError: if a download responds with an error status.
    """
    if root is None:
        tmp = tempfile.mkdtemp(prefix="flickr8k_")
    else:
        tmp = os.path.abspath(root)
        os.makedirs(tmp, exist_ok=True)

    cap_url = "https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip"
    img_url = "https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip"
    cap_zip = os.path.join(tmp, "caps.zip")
    img_zip = os.path.join(tmp, "imgs.zip")

    for url, out in [(cap_url, cap_zip), (img_url, img_zip)]:
        if os.path.exists(out):
            continue
        # Stream into a side file and rename once complete, so an interrupted
        # download is never mistaken for a finished archive.
        part = out + ".part"
        with requests.get(url, stream=True, timeout=300) as r:
            r.raise_for_status()
            with open(part, "wb") as f:
                for chunk in r.iter_content(1 << 20):
                    f.write(chunk)
        os.replace(part, out)

    with zipfile.ZipFile(cap_zip, 'r') as z:
        z.extractall(tmp)
    with zipfile.ZipFile(img_zip, 'r') as z:
        z.extractall(tmp)

    Config.CAPTIONS_FILE = os.path.join(tmp, "Flickr8k.token.txt")
    # "Flicker8k_Dataset" (sic) is the folder name this code expects inside the image archive.
    Config.IMAGES_DIR = os.path.join(tmp, "Flicker8k_Dataset")
    Config.CACHED_FEATURES_DIR = os.path.join(tmp, "cached_features")
    os.makedirs(Config.CACHED_FEATURES_DIR, exist_ok=True)
    print("Data at:", tmp)
    return tmp

In [8]:
# ========================== 10) Prepare data ==========================
def prepare_data(seed=42):
    """Build everything the training loop needs, end to end.

    Steps: download the dataset, parse captions, build the vocabulary,
    split image ids 90/10 into train/val, cache encoder features for all
    images, create the datasets and loaders, and instantiate the decoder,
    loss and optimiser from `Config`.

    Args:
        seed: seeds Python's `random`, NumPy and PyTorch before anything
            stochastic happens, so weight initialisation and batch
            shuffling are repeatable between runs. Pass None to leave the
            global RNG state untouched. The train/val split keeps its own
            fixed `random_state=42` either way.

    Returns:
        `(decoder, optimizer, criterion, vocab, train_ds, val_ds,
        train_loader, val_loader, root)` where `root` is the data directory
        returned by `download_flickr8k()`.
    """
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    root = download_flickr8k()
    annotations, image_ids = parse_captions(Config.CAPTIONS_FILE)

    # Build vocab
    all_caps = [c for caps in annotations.values() for c in caps]
    vocab = Vocabulary(Config.VOCAB_THRESHOLD)
    vocab.build(all_caps)

    # Split by image (not by caption) so no image appears in both subsets
    train_ids, val_ids = train_test_split(image_ids, test_size=0.1, random_state=42)

    # Transform for encoder
    tx = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
    ])

    # Cache features (simple loop; CPU) for train and val together
    ids_to_cache = train_ids + val_ids
    cache_features_naive(ids_to_cache, Config.IMAGES_DIR, Config.CACHED_FEATURES_DIR, tx)

    # Build datasets that read cached features only
    train_ann = {i: annotations[i] for i in train_ids}
    val_ann   = {i: annotations[i] for i in val_ids}
    train_ds = Flickr8kCached(train_ids, train_ann, vocab, Config.CACHED_FEATURES_DIR)
    val_ds   = Flickr8kCached(val_ids,   val_ann,   vocab, Config.CACHED_FEATURES_DIR)

    train_loader = DataLoader(train_ds, batch_size=Config.BATCH_SIZE, shuffle=True,
                              num_workers=0, collate_fn=collate_fn)
    val_loader   = DataLoader(val_ds, batch_size=Config.BATCH_SIZE, shuffle=False,
                              num_workers=0, collate_fn=collate_fn)

    # Model + optim
    decoder = DecoderWithGRU(Config.EMBED_SIZE, Config.HIDDEN_SIZE, len(vocab), dropout=Config.DROPOUT).to(Config.DEVICE)
    # Padded positions must not contribute to the loss
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.word2idx['<pad>'])
    optimizer = optim.Adam(decoder.parameters(), lr=Config.LR)

    return decoder, optimizer, criterion, vocab, train_ds, val_ds, train_loader, val_loader, root


In [9]:
# ========================== 11) Train / Validate / BLEU ==========================
def train_one_epoch(decoder, optimizer, criterion, loader, device='cpu'):
    """Run one optimisation pass over `loader` and return the mean batch loss.

    The batch loss is the average over captions of `criterion` evaluated on
    each caption's own valid time steps; the target at step t is the token
    at position t+1 of the (length-sorted) caption. Gradients are clipped
    to a total norm of 5.0 before every optimiser step.
    """
    decoder.train()
    epoch_loss = 0.0
    for batch in tqdm(loader, desc="Training"):
        feats, caps, lens = (tensor.to(device) for tensor in batch)
        logits, caps_sorted, decode_lengths, _ = decoder(feats, caps, lens)

        # Drop the first token of every caption: it is input only, never a target.
        targets = caps_sorted[:, 1:]
        per_caption = [
            criterion(logits[row, :n_steps, :], targets[row, :n_steps])
            for row, n_steps in enumerate(decode_lengths)
        ]
        loss = sum(per_caption) / len(decode_lengths)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(decoder.parameters(), 5.0)
        optimizer.step()

        epoch_loss += loss.item()
    return epoch_loss / len(loader)

def validate(decoder, criterion, loader, device='cpu'):
    """Return the mean batch loss over `loader` without updating the decoder.

    The decoder is put in eval mode and gradients are disabled. The batch
    loss is computed as in training: `criterion` on each caption's valid
    time steps, averaged over the captions of the batch.
    """
    decoder.eval()
    batch_losses = []
    with torch.no_grad():
        for feats, caps, lens in tqdm(loader, desc="Validation"):
            feats = feats.to(device)
            caps = caps.to(device)
            lens = lens.to(device)

            logits, caps_sorted, decode_lengths, _ = decoder(feats, caps, lens)
            targets = caps_sorted[:, 1:]   # shift by one: predict the next token

            total = 0.0
            for row, n_steps in enumerate(decode_lengths):
                total = total + criterion(logits[row, :n_steps, :], targets[row, :n_steps])
            batch_losses.append((total / len(decode_lengths)).item())

    running = 0.0
    for value in batch_losses:
        running += value
    return running / len(loader)

In [13]:
# Build the full training setup once; `_tmp` is the data directory returned by prepare_data().
(
    decoder,
    optimizer,
    criterion,
    vocab,
    train_ds,
    val_ds,
    train_loader,
    val_loader,
    _tmp,
) = prepare_data()

Data at: C:\Users\jocel\AppData\Local\Temp\flickr8k_ejxzg7dy


Building Vocabulary: 100%|██████████| 40455/40455 [00:23<00:00, 1705.42it/s]


Vocabulary size: 3005


Caching features (simple loop): 100%|██████████| 8091/8091 [54:16<00:00,  2.48it/s]  


In [14]:
# ========================== 12) Run ==========================
print(f"Train images: {len(train_ds.image_ids)}, Val images: {len(val_ds.image_ids)}")

# Train for NUM_EPOCHS epochs; after each one, measure validation loss and
# full-validation BLEU, and checkpoint whenever BLEU improves.
best_bleu = 0.0
for epoch in range(1, Config.NUM_EPOCHS+1):
    t0 = time.time()
    # Forward the configured device explicitly instead of relying on the 'cpu' default.
    tr_loss = train_one_epoch(decoder, optimizer, criterion, train_loader, device=Config.DEVICE)
    va_loss = validate(decoder, criterion, val_loader, device=Config.DEVICE)
    # Full-val BLEU with beam search (k=5)
    bleu = evaluate_bleu_features(decoder, val_ds, vocab, beam_size=5, max_len=20)

    if bleu > best_bleu:
        best_bleu = bleu
        # Store the vocabulary mapping with the weights: predicted ids cannot
        # be turned back into words without it, and the vocabulary otherwise
        # exists only in this kernel session.
        torch.save({'decoder': decoder.state_dict(),
                    'vocab_size': len(vocab),
                    'word2idx': vocab.word2idx,
                    'epoch': epoch,
                    'bleu': bleu}, f'best_decoder_epoch{epoch}.pth')
    dt = time.time()-t0
    print(f"Epoch {epoch}/{Config.NUM_EPOCHS} | Train {tr_loss:.3f} | Val {va_loss:.3f} | BLEU {bleu:.4f} | {dt/60:.1f} min")


Train images: 7281, Val images: 810


Training: 100%|██████████| 1138/1138 [11:10<00:00,  1.70it/s]
Validation: 100%|██████████| 127/127 [00:42<00:00,  2.97it/s]
BLEU eval (cached): 100%|██████████| 810/810 [01:29<00:00,  9.06it/s]


Epoch 1/20 | Train 4.086 | Val 3.420 | BLEU 0.1775 | 13.4 min


Training: 100%|██████████| 1138/1138 [09:04<00:00,  2.09it/s]
Validation: 100%|██████████| 127/127 [00:15<00:00,  7.98it/s]
BLEU eval (cached): 100%|██████████| 810/810 [01:38<00:00,  8.19it/s]


Epoch 2/20 | Train 3.181 | Val 3.055 | BLEU 0.1857 | 11.0 min


Training: 100%|██████████| 1138/1138 [09:14<00:00,  2.05it/s]
Validation: 100%|██████████| 127/127 [00:12<00:00, 10.01it/s]
BLEU eval (cached): 100%|██████████| 810/810 [01:30<00:00,  8.99it/s]


Epoch 3/20 | Train 2.882 | Val 2.886 | BLEU 0.2119 | 11.0 min


Training: 100%|██████████| 1138/1138 [08:35<00:00,  2.21it/s]
Validation: 100%|██████████| 127/127 [00:18<00:00,  6.89it/s]
BLEU eval (cached): 100%|██████████| 810/810 [01:23<00:00,  9.76it/s]


Epoch 4/20 | Train 2.704 | Val 2.787 | BLEU 0.2100 | 10.3 min


Training: 100%|██████████| 1138/1138 [08:48<00:00,  2.15it/s]
Validation: 100%|██████████| 127/127 [00:17<00:00,  7.28it/s]
BLEU eval (cached): 100%|██████████| 810/810 [01:17<00:00, 10.39it/s]


Epoch 5/20 | Train 2.575 | Val 2.723 | BLEU 0.2191 | 10.4 min


Training: 100%|██████████| 1138/1138 [08:52<00:00,  2.14it/s]
Validation: 100%|██████████| 127/127 [00:15<00:00,  8.02it/s]
BLEU eval (cached): 100%|██████████| 810/810 [01:43<00:00,  7.80it/s]


Epoch 6/20 | Train 2.471 | Val 2.675 | BLEU 0.2181 | 10.9 min


Training: 100%|██████████| 1138/1138 [08:58<00:00,  2.11it/s]
Validation: 100%|██████████| 127/127 [00:16<00:00,  7.89it/s]
BLEU eval (cached): 100%|██████████| 810/810 [01:25<00:00,  9.46it/s]


Epoch 7/20 | Train 2.384 | Val 2.640 | BLEU 0.2143 | 10.7 min


Training: 100%|██████████| 1138/1138 [09:22<00:00,  2.02it/s]
Validation: 100%|██████████| 127/127 [00:17<00:00,  7.19it/s]
BLEU eval (cached): 100%|██████████| 810/810 [01:16<00:00, 10.60it/s]


Epoch 8/20 | Train 2.310 | Val 2.615 | BLEU 0.2201 | 10.9 min


Training: 100%|██████████| 1138/1138 [08:56<00:00,  2.12it/s]
Validation: 100%|██████████| 127/127 [00:18<00:00,  6.77it/s]
BLEU eval (cached): 100%|██████████| 810/810 [01:28<00:00,  9.14it/s]


Epoch 9/20 | Train 2.243 | Val 2.595 | BLEU 0.2253 | 10.7 min


Training: 100%|██████████| 1138/1138 [08:41<00:00,  2.18it/s]
Validation: 100%|██████████| 127/127 [00:18<00:00,  6.85it/s]
BLEU eval (cached): 100%|██████████| 810/810 [01:27<00:00,  9.31it/s]


Epoch 10/20 | Train 2.183 | Val 2.584 | BLEU 0.2243 | 10.5 min


Training: 100%|██████████| 1138/1138 [08:47<00:00,  2.16it/s]
Validation: 100%|██████████| 127/127 [00:18<00:00,  6.80it/s]
BLEU eval (cached): 100%|██████████| 810/810 [01:29<00:00,  9.06it/s]


Epoch 11/20 | Train 2.128 | Val 2.571 | BLEU 0.2218 | 10.6 min


Training: 100%|██████████| 1138/1138 [09:16<00:00,  2.05it/s]
Validation: 100%|██████████| 127/127 [00:16<00:00,  7.49it/s]
BLEU eval (cached): 100%|██████████| 810/810 [01:35<00:00,  8.53it/s]


Epoch 12/20 | Train 2.076 | Val 2.568 | BLEU 0.2152 | 11.1 min


Training: 100%|██████████| 1138/1138 [08:45<00:00,  2.17it/s]
Validation: 100%|██████████| 127/127 [00:13<00:00,  9.59it/s]
BLEU eval (cached): 100%|██████████| 810/810 [01:30<00:00,  8.92it/s]


Epoch 13/20 | Train 2.028 | Val 2.559 | BLEU 0.2194 | 10.5 min


Training: 100%|██████████| 1138/1138 [09:05<00:00,  2.09it/s]
Validation: 100%|██████████| 127/127 [00:16<00:00,  7.50it/s]
BLEU eval (cached): 100%|██████████| 810/810 [01:43<00:00,  7.82it/s]


Epoch 14/20 | Train 1.983 | Val 2.561 | BLEU 0.2266 | 11.1 min


Training: 100%|██████████| 1138/1138 [09:03<00:00,  2.09it/s]
Validation: 100%|██████████| 127/127 [00:15<00:00,  8.29it/s]
BLEU eval (cached): 100%|██████████| 810/810 [01:42<00:00,  7.92it/s]


Epoch 15/20 | Train 1.940 | Val 2.561 | BLEU 0.2184 | 11.0 min


Training: 100%|██████████| 1138/1138 [09:50<00:00,  1.93it/s]
Validation: 100%|██████████| 127/127 [00:15<00:00,  8.34it/s]
BLEU eval (cached): 100%|██████████| 810/810 [01:46<00:00,  7.58it/s]


Epoch 16/20 | Train 1.899 | Val 2.564 | BLEU 0.2146 | 11.9 min


Training: 100%|██████████| 1138/1138 [09:05<00:00,  2.09it/s]
Validation: 100%|██████████| 127/127 [00:15<00:00,  8.02it/s]
BLEU eval (cached): 100%|██████████| 810/810 [01:27<00:00,  9.27it/s]


Epoch 17/20 | Train 1.860 | Val 2.568 | BLEU 0.2089 | 10.8 min


Training: 100%|██████████| 1138/1138 [09:46<00:00,  1.94it/s]
Validation: 100%|██████████| 127/127 [00:19<00:00,  6.56it/s]
BLEU eval (cached): 100%|██████████| 810/810 [01:21<00:00,  9.97it/s]


Epoch 18/20 | Train 1.821 | Val 2.565 | BLEU 0.2152 | 11.4 min


Training: 100%|██████████| 1138/1138 [09:24<00:00,  2.02it/s]
Validation: 100%|██████████| 127/127 [00:20<00:00,  6.18it/s]
BLEU eval (cached): 100%|██████████| 810/810 [01:35<00:00,  8.49it/s]


Epoch 19/20 | Train 1.785 | Val 2.574 | BLEU 0.2133 | 11.3 min


Training: 100%|██████████| 1138/1138 [09:48<00:00,  1.93it/s]
Validation: 100%|██████████| 127/127 [00:15<00:00,  8.12it/s]
BLEU eval (cached): 100%|██████████| 810/810 [01:29<00:00,  9.10it/s]

Epoch 20/20 | Train 1.750 | Val 2.578 | BLEU 0.2042 | 11.6 min



