# Image Captioning Training Notebook
## CNN (ResNet152) + LSTM + Attention
### Dataset: Flickr30k via Kaggle

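This notebook trains a custom image-captioning model in two stages: a pretrained ResNet152 encoder extracts and caches a spatial feature map for every image, and an LSTM decoder with soft attention is then trained on those cached features to generate captions.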

## 1. Setup & Download Flickr30k from Kaggle

In [1]:
import os
import zipfile

# Download Flickr30k from Kaggle (adityajn105/flickr30k)
DATASET_DIR = "flickr30k_data"

if not os.path.exists(DATASET_DIR):
    print("Downloading Flickr30k from Kaggle...")
    !kaggle datasets download -d adityajn105/flickr30k -p {DATASET_DIR}
    
    # Unzip
    print("Extracting...")
    for f in os.listdir(DATASET_DIR):
        if f.endswith('.zip'):
            with zipfile.ZipFile(os.path.join(DATASET_DIR, f), 'r') as z:
                z.extractall(DATASET_DIR)
            os.remove(os.path.join(DATASET_DIR, f))
    print("Done!")
else:
    print(f"Dataset already exists at {DATASET_DIR}")

# Find images and captions
IMAGES_DIR = os.path.join(DATASET_DIR, "flickr30k_images", "flickr30k_images")
CAPTIONS_FILE = os.path.join(DATASET_DIR, "flickr30k_images", "results.csv")

print(f"Images: {IMAGES_DIR}")
print(f"Captions: {CAPTIONS_FILE}")

Dataset already exists at flickr30k_data
Images: flickr30k_data/flickr30k_images/flickr30k_images
Captions: flickr30k_data/flickr30k_images/results.csv


## 2. Imports & Configuration

In [2]:
import csv
import json
import random
import re
from collections import Counter, defaultdict
from tqdm.notebook import tqdm

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Seed the Python, NumPy and PyTorch RNGs up front so that weight
# initialisation and DataLoader shuffling are repeatable from a fresh kernel
# (to the extent the backend itself is deterministic).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Paths (all relative to the directory the notebook is started from)
BASE_DIR = os.getcwd()
FEATURES_DIR = os.path.join(BASE_DIR, "features_flickr30k")  # cached encoder features (.npy)
MODEL_SAVE_PATH = os.path.join(BASE_DIR, "backend", "custom_caption_model.pth")
VOCAB_SAVE_PATH = os.path.join(BASE_DIR, "backend", "vocab.json")

# Hyperparameters
EMBED_DIM = 256        # word-embedding size
ATTENTION_DIM = 256    # hidden size of the attention layers
DECODER_DIM = 512      # LSTM hidden-state size
ENCODER_DIM = 2048     # channel count of the encoder feature map
DROPOUT = 0.5          # dropout probability
LEARNING_RATE = 4e-4   # optimizer learning rate
BATCH_SIZE = 32
NUM_EPOCHS = 20        # upper bound on training epochs
MAX_CAPTION_LEN = 50   # captions are padded/truncated to this many tokens
MIN_WORD_FREQ = 5      # words seen fewer times than this are left out of the vocabulary
PATIENCE = 5           # epochs without validation improvement before stopping

Using device: cuda


## 3. Load Captions

In [3]:
def clean_caption(text):
    """Normalise a raw caption to lower-case letters and spaces.

    The text is lower-cased, every run of characters other than ``a``-``z``
    and the space is replaced by a single space, and leading/trailing
    whitespace is stripped. Interior runs of spaces are left as they are.
    """
    return re.sub(r"[^a-z ]+", " ", text.lower()).strip()

print("Loading captions...")
# Maps image file name -> list of raw caption strings for that image.
raw_captions = defaultdict(list)

# newline="" is what the csv module asks for when it is handed a file object.
# QUOTE_NONE keeps a stray double quote inside a caption from being read as
# the start of a quoted (possibly multi-line) field: the file is plain
# '|'-delimited text.
with open(CAPTIONS_FILE, "r", encoding="utf-8", newline="") as f:
    reader = csv.reader(f, delimiter='|', quoting=csv.QUOTE_NONE)
    next(reader, None)  # Skip header (tolerates an empty file)
    for row in reader:
        # Blank or single-field rows carry no caption at all.
        if len(row) < 2:
            continue
        img_id = row[0].strip()
        # Normal rows hold the caption in the third field. Rows with only two
        # fields (e.g. a missing delimiter) fall back to the second field
        # rather than being dropped silently.
        caption = row[2].strip() if len(row) > 2 else row[1].strip()
        raw_captions[img_id].append(caption)

print(f"Loaded captions for {len(raw_captions)} images")

Loading captions...
Loaded captions for 31783 images


In [4]:
# Clean every caption, wrap it in <start>/<end> markers and tally token counts
print("Building vocabulary...")
counter = Counter()
all_captions = {}

for img_id, caps in tqdm(raw_captions.items(), desc="Processing"):
    # Ignore images that are listed in the caption table but missing on disk.
    if not os.path.exists(os.path.join(IMAGES_DIR, img_id)):
        continue

    # Captions that are empty after cleaning are discarded.
    cleaned_caps = (clean_caption(cap) for cap in caps)
    processed = [f"<start> {text} <end>" for text in cleaned_caps if text]

    for sentence in processed:
        counter.update(sentence.split())

    # Only images with at least one usable caption are kept.
    if processed:
        all_captions[img_id] = processed

print(f"Valid images: {len(all_captions)}")
print(f"Unique words: {len(counter)}")

Building vocabulary...


Processing:   0%|          | 0/31783 [00:00<?, ?it/s]

Valid images: 31783
Unique words: 18081


In [5]:
# Assemble the vocabulary: the four reserved tokens come first (so <pad> is
# index 0), followed by every counted word that meets MIN_WORD_FREQ.
special_tokens = ["<pad>", "<start>", "<end>", "<unk>"]
words = list(special_tokens)
for token, freq in counter.items():
    if freq < MIN_WORD_FREQ or token in special_tokens:
        continue
    words.append(token)

word2idx = {token: position for position, token in enumerate(words)}
idx2word = {position: token for token, position in word2idx.items()}
VOCAB_SIZE = len(word2idx)

print(f"Vocabulary size: {VOCAB_SIZE}")

# Persist both mappings; JSON object keys must be strings, so the integer
# indices of idx2word are stringified.
vocab_payload = {
    'word2idx': word2idx,
    'idx2word': {str(position): token for position, token in idx2word.items()},
}
os.makedirs(os.path.dirname(VOCAB_SAVE_PATH), exist_ok=True)
with open(VOCAB_SAVE_PATH, 'w') as f:
    json.dump(vocab_payload, f)
print(f"Saved to {VOCAB_SAVE_PATH}")

Vocabulary size: 7611
Saved to /home/sebastian/Desktop/UNI/Master Anul 1/Sem 1/IBD/ImageCaption/image_caption/backend/vocab.json


## 4. Train/Val Split

In [6]:
# Shuffle the image ids with a fixed seed, then cut them 85% / 15%.
all_ids = list(all_captions)
random.seed(42)
random.shuffle(all_ids)

split_at = int(0.85 * len(all_ids))
train_ids, val_ids = all_ids[:split_at], all_ids[split_at:]

print(f"Train: {len(train_ids)}, Val: {len(val_ids)}")

Train: 27015, Val: 4768


## 5. Feature Extraction (ResNet152)

In [7]:
print("Loading ResNet152...")
resnet = models.resnet152(weights=models.ResNet152_Weights.IMAGENET1K_V2)

# Keep every child module except the last two (presumably the pooling layer
# and the classification head), so the encoder returns a spatial feature map.
modules = list(resnet.children())[:-2]
encoder = nn.Sequential(*modules)
encoder.to(DEVICE)
encoder.eval()  # inference mode: the encoder is only used to extract features

# Preprocessing applied to every image before it is fed to the encoder:
# fixed 224x224 resize, conversion to a tensor, per-channel normalisation.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
img_transform = transforms.Compose(_preprocess_steps)

def extract_features(img_path):
    """Run one image through the encoder and return its feature map.

    Args:
        img_path: path of the image file to read.

    Returns:
        A NumPy array holding the encoder output for this image with the
        batch axis removed (presumably ``(2048, 7, 7)`` for the ResNet152
        trunk at 224x224 input; confirm against the encoder in use).
    """
    # Scope the file handle explicitly instead of relying on Pillow's lazy
    # close; convert() returns a loaded copy that outlives the handle.
    with Image.open(img_path) as raw:
        img = raw.convert("RGB")
    img = img_transform(img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        # squeeze(0) drops only the batch axis added above; a bare squeeze()
        # would also collapse any other size-1 axis and change the saved rank.
        return encoder(img).squeeze(0).cpu().numpy()

print("Encoder ready!")

Loading ResNet152...
Downloading: "https://download.pytorch.org/models/resnet152-f82ba261.pth" to /home/sebastian/.cache/torch/hub/checkpoints/resnet152-f82ba261.pth


100%|██████████| 230M/230M [00:23<00:00, 10.1MB/s] 


Encoder ready!


In [8]:
os.makedirs(FEATURES_DIR, exist_ok=True)
print(f"Extracting features to {FEATURES_DIR}...")

# One feature file per image; images that already have a file are skipped so
# the cell can be re-run cheaply.
for img_id in tqdm(all_captions.keys(), desc="Features"):
    save_path = os.path.join(FEATURES_DIR, img_id.replace(".jpg", ".npy"))
    if os.path.exists(save_path):
        continue

    # Write to a temporary file and rename it into place: an interrupted run
    # then never leaves a truncated file at save_path that a later run would
    # mistake for a finished one.
    tmp_path = save_path + ".tmp"
    try:
        feat = extract_features(os.path.join(IMAGES_DIR, img_id))
        with open(tmp_path, "wb") as fh:
            np.save(fh, feat)
        os.replace(tmp_path, save_path)
    except Exception as e:
        # Best effort: report the failure and carry on with the next image.
        print(f"Error {img_id}: {e}")
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

print("Done!")

Extracting features to /home/sebastian/Desktop/UNI/Master Anul 1/Sem 1/IBD/ImageCaption/image_caption/features_flickr30k...


Features:   0%|          | 0/31783 [00:00<?, ?it/s]

Done!


## 6. Dataset & DataLoaders

In [9]:
class CaptionDataset(Dataset):
    """(feature map, caption token ids) pairs backed by cached ``.npy`` files.

    Every caption of every image in ``ids`` becomes one sample, so an image
    with several captions contributes several samples that share one feature
    file. Images without a feature file in ``features_dir`` are skipped.

    Args:
        ids: image file names to include.
        captions: mapping from image file name to its list of caption strings.
        features_dir: directory holding one ``.npy`` feature file per image.
        word2idx: mapping from token to integer id; must contain ``<unk>`` and
            ``<pad>``.
        max_len: fixed length every caption is padded or truncated to.
    """

    def __init__(self, ids, captions, features_dir, word2idx, max_len=MAX_CAPTION_LEN):
        self.data = []
        for img_id in ids:
            fp = os.path.join(features_dir, img_id.replace(".jpg", ".npy"))
            if os.path.exists(fp):
                for cap in captions[img_id]:
                    self.data.append((fp, cap))
        self.word2idx = word2idx
        self.max_len = max_len

    def __len__(self):
        """Number of (image, caption) samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(features, token_ids)`` for sample ``idx``.

        ``features`` is the cached array as a float32 tensor; ``token_ids`` is
        a long tensor of exactly ``max_len`` ids.
        """
        fp, cap = self.data[idx]
        feat = torch.tensor(np.load(fp), dtype=torch.float32)

        unk_idx = self.word2idx["<unk>"]
        indices = [self.word2idx.get(w, unk_idx) for w in cap.split()]

        if len(indices) > self.max_len:
            indices = indices[:self.max_len]
            # Truncation would otherwise cut off the closing marker; keep the
            # sequence terminated so the decoder still sees an end target.
            end_idx = self.word2idx.get("<end>")
            if end_idx is not None and indices:
                indices[-1] = end_idx
        else:
            indices = indices + [self.word2idx["<pad>"]] * (self.max_len - len(indices))

        return feat, torch.tensor(indices, dtype=torch.long)

# Build one dataset per split; only the training loader shuffles.
train_dataset = CaptionDataset(train_ids, all_captions, FEATURES_DIR, word2idx)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

val_dataset = CaptionDataset(val_ids, all_captions, FEATURES_DIR, word2idx)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

Train batches: 4222, Val batches: 745


## 7. Model

In [10]:
class Attention(nn.Module):
    """Additive soft attention over a set of encoder positions.

    Args:
        enc_dim: feature size of each encoder position.
        dec_dim: size of the decoder hidden state.
        att_dim: size of the shared space both inputs are projected into.
    """

    def __init__(self, enc_dim, dec_dim, att_dim):
        super().__init__()
        self.enc_att = nn.Linear(enc_dim, att_dim)   # projects encoder features
        self.dec_att = nn.Linear(dec_dim, att_dim)   # projects decoder state
        self.full_att = nn.Linear(att_dim, 1)        # one score per position

    def forward(self, enc, dec):
        """Weight the encoder positions by their relevance to ``dec``.

        Args:
            enc: encoder features, expected as (batch, positions, enc_dim).
            dec: decoder hidden state, expected as (batch, dec_dim).

        Returns:
            ``(context, alpha)``: the attention-weighted sum of ``enc`` over
            the position axis, and the weights themselves with shape
            (batch, positions, 1), which sum to 1 along the position axis.
        """
        enc_proj = self.enc_att(enc)
        # Add a position axis so the decoder projection broadcasts over positions.
        dec_proj = self.dec_att(dec).unsqueeze(1)
        scores = self.full_att(torch.relu(enc_proj + dec_proj))
        alpha = scores.softmax(dim=1)
        context = (alpha * enc).sum(dim=1)
        return context, alpha

class Decoder(nn.Module):
    """LSTM caption decoder that attends over an encoder feature map.

    Args:
        att_dim: hidden size of the attention module.
        emb_dim: word-embedding size.
        dec_dim: LSTM hidden-state size.
        vocab_size: number of tokens in the vocabulary.
        enc_dim: channel count of the encoder feature map.
        drop: dropout probability applied before the output layer.
    """

    def __init__(self, att_dim, emb_dim, dec_dim, vocab_size, enc_dim=2048, drop=0.5):
        super().__init__()
        self.enc_dim = enc_dim
        self.vocab_size = vocab_size
        self.attention = Attention(enc_dim, dec_dim, att_dim)
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.dropout = nn.Dropout(drop)
        # Each step consumes the current word embedding plus the attended context.
        self.lstm = nn.LSTMCell(emb_dim + enc_dim, dec_dim)
        # Initial hidden/cell states are predicted from the mean encoder feature.
        self.init_h = nn.Linear(enc_dim, dec_dim)
        self.init_c = nn.Linear(enc_dim, dec_dim)
        self.fc = nn.Linear(dec_dim, vocab_size)

    def forward(self, enc_out, caps):
        """Decode with teacher forcing.

        Args:
            enc_out: encoder feature map, expected as (batch, enc_dim, H, W).
            caps: token ids, expected as (batch, seq_len); the token at step
                ``t`` is the input used to predict the token at ``t + 1``.

        Returns:
            ``(preds, alphas)``: vocabulary scores of shape
            (batch, seq_len - 1, vocab_size) and attention weights of shape
            (batch, seq_len - 1, H * W).
        """
        B = enc_out.size(0)
        # Allocate outputs next to the inputs instead of on a notebook-level
        # global, so the module works on whichever device it is given.
        device = enc_out.device

        # (B, C, H, W) -> (B, H*W, C). reshape() copies if the permuted layout
        # cannot be viewed directly.
        enc_out = enc_out.permute(0, 2, 3, 1).reshape(B, -1, self.enc_dim)
        num_pixels = enc_out.size(1)  # taken from the input, not assumed to be 7*7
        steps = caps.size(1) - 1

        emb = self.embedding(caps)
        mean = enc_out.mean(1)
        h, c = self.init_h(mean), self.init_c(mean)

        preds = torch.zeros(B, steps, self.vocab_size, device=device)
        alphas = torch.zeros(B, steps, num_pixels, device=device)

        for t in range(steps):
            ctx, a = self.attention(enc_out, h)
            h, c = self.lstm(torch.cat([emb[:, t], ctx], 1), (h, c))
            preds[:, t] = self.fc(self.dropout(h))
            alphas[:, t] = a.squeeze(2)

        return preds, alphas

## 8. Training

In [11]:
# Pass ENCODER_DIM and DROPOUT explicitly: otherwise the Decoder silently uses
# its own defaults, editing the config constants has no effect, and the
# 'encoder_dim' value written to the checkpoint can disagree with the model.
model = Decoder(
    ATTENTION_DIM, EMBED_DIM, DECODER_DIM, VOCAB_SIZE,
    enc_dim=ENCODER_DIM, drop=DROPOUT,
).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Halve the learning rate once the monitored value stops improving for more
# than `patience` epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=2)
# Padding positions contribute nothing to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=word2idx["<pad>"])

print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

Parameters: 14,378,428


In [12]:
# Train until NUM_EPOCHS is reached or the validation loss has failed to
# improve for PATIENCE consecutive epochs; the best model so far is saved.
best_loss = float('inf')
patience_cnt = 0

for epoch in range(NUM_EPOCHS):
    # ---- training pass ----
    model.train()
    train_loss = 0
    for imgs, caps in tqdm(train_loader, leave=False):
        imgs = imgs.to(DEVICE)
        caps = caps.to(DEVICE)
        out, alphas = model(imgs, caps)

        # Targets are the caption shifted left by one token.
        targets = caps[:, 1:].reshape(-1)
        word_loss = criterion(out.view(-1, VOCAB_SIZE), targets)
        # Regulariser: pushes each position's attention, summed over the
        # decoding steps, towards 1.
        attention_penalty = ((1 - alphas.sum(1)) ** 2).mean()
        loss = word_loss + attention_penalty

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_loader)

    # ---- validation pass (cross-entropy only, no attention penalty) ----
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for imgs, caps in val_loader:
            imgs = imgs.to(DEVICE)
            caps = caps.to(DEVICE)
            out, _ = model(imgs, caps)
            batch_loss = criterion(out.view(-1, VOCAB_SIZE), caps[:, 1:].reshape(-1))
            val_loss += batch_loss.item()
    val_loss /= len(val_loader)

    scheduler.step(val_loss)

    print(f"Epoch {epoch+1}/{NUM_EPOCHS} | Train: {train_loss:.4f} | Val: {val_loss:.4f}")

    # ---- early stopping / checkpointing ----
    if not val_loss < best_loss:
        patience_cnt += 1
        if patience_cnt >= PATIENCE:
            print("Early stopping!")
            break
        continue

    best_loss = val_loss
    patience_cnt = 0
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'vocab_size': VOCAB_SIZE,
        'embed_dim': EMBED_DIM,
        'attention_dim': ATTENTION_DIM,
        'decoder_dim': DECODER_DIM,
        'encoder_dim': ENCODER_DIM,
    }
    torch.save(checkpoint, MODEL_SAVE_PATH)
    print("  Saved!")

print(f"\nDone! Best loss: {best_loss:.4f}")
print(f"Model: {MODEL_SAVE_PATH}")

  0%|          | 0/4222 [00:00<?, ?it/s]

Epoch 1/20 | Train: 3.8142 | Val: 3.2720
  Saved!


  0%|          | 0/4222 [00:00<?, ?it/s]

Epoch 2/20 | Train: 3.1915 | Val: 3.0651
  Saved!


  0%|          | 0/4222 [00:00<?, ?it/s]

Epoch 3/20 | Train: 2.9765 | Val: 2.9862
  Saved!


  0%|          | 0/4222 [00:00<?, ?it/s]

Epoch 4/20 | Train: 2.8321 | Val: 2.9488
  Saved!


  0%|          | 0/4222 [00:00<?, ?it/s]

Epoch 5/20 | Train: 2.7216 | Val: 2.9297
  Saved!


  0%|          | 0/4222 [00:00<?, ?it/s]

Epoch 6/20 | Train: 2.6289 | Val: 2.9256
  Saved!


  0%|          | 0/4222 [00:00<?, ?it/s]

Epoch 7/20 | Train: 2.5498 | Val: 2.9299


  0%|          | 0/4222 [00:00<?, ?it/s]

Epoch 8/20 | Train: 2.4780 | Val: 2.9423


  0%|          | 0/4222 [00:00<?, ?it/s]

Epoch 9/20 | Train: 2.4153 | Val: 2.9596


  0%|          | 0/4222 [00:00<?, ?it/s]

Epoch 10/20 | Train: 2.2800 | Val: 2.9698


  0%|          | 0/4222 [00:00<?, ?it/s]

Epoch 11/20 | Train: 2.2206 | Val: 2.9903
Early stopping!

Done! Best loss: 2.9256
Model: /home/sebastian/Desktop/UNI/Master Anul 1/Sem 1/IBD/ImageCaption/image_caption/backend/custom_caption_model.pth
