# Deep Machine Learning Project (SSY340)

Project Group 92

## Environment setup

In [None]:
import re, random, os, ast, gc, torch
import matplotlib.pyplot as plt
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights
from collections import Counter, defaultdict
from PIL import Image
from tqdm.auto import tqdm

# NOTE(review): these installs run *after* torch/torchvision were already imported
# in this cell, so an upgraded torch presumably only takes effect after a kernel
# restart — consider moving the installs into their own first cell.
%pip install --upgrade pip
%pip install torch torchvision --index-url https://download.pytorch.org/whl/cu129

# Report the PyTorch build and whether a CUDA device is visible.
print("\nPyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

if torch.cuda.is_available():
    print("CUDA version:", torch.version.cuda)
    print("Device name:", torch.cuda.get_device_name(torch.cuda.current_device()))

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)
# Image folder and the tab-separated "<image>#<k>\t<caption>" token file.
IMAGES_DIR = "flickr30k-images"
CAPTIONS_FILE = "Flickr30k.token.txt"
SEED = 0
# Passed to DataLoader(num_workers=...); 0 loads batches in the main process.
NUM_WORKERS = 0

# Seed Python's and PyTorch's RNGs so shuffling and weight init are repeatable.
random.seed(SEED)
torch.manual_seed(SEED)

Set up the dataset: convert the Flickr30k CSV annotations into a tab-separated image–caption token file.

In [None]:
# Input annotation CSV and the token file that the cells below will produce.
csv_path, out_path = "flickr_annotations_30k.csv", "Flickr30k.token.txt"

# low_memory=False reads the file in one pass so column dtypes are inferred consistently.
df = pd.read_csv(csv_path, low_memory=False)
print("Columns in CSV:", df.columns.tolist())
display(df.head(3))  # Jupyter/IPython built-in preview of the first rows

def _get_image_name(row):
    # common img filename columns; adjust or add if your file uses another name
    for col in ('file_name','filename','image','img','image_filename','img_name','image_name','img_id','image_id','path'):
        if col in df.columns:
            val = row.get(col)
            if pd.isna(val):
                continue
            return os.path.basename(str(val))
    # fallback: use index as filename
    return f"{row.name}.jpg"

def _get_captions(row):
    # common caption columns
    for col in ('raw','captions','sentences','sentence','caption','raw_captions','sentids'):
        if col in df.columns:
            val = row.get(col)
            if pd.isna(val):
                continue
            # if already a list (some readers may parse JSON)
            if isinstance(val, (list, tuple)):
                return [str(x).strip() for x in val if str(x).strip()]
            # try to parse python/JSON list stored as string
            if isinstance(val, str):
                # attempt ast literal_eval for list-like strings
                try:
                    parsed = ast.literal_eval(val)
                    if isinstance(parsed, (list, tuple)):
                        return [str(x).strip() for x in parsed if str(x).strip()]
                    if isinstance(parsed, dict) and 'raw' in parsed:
                        r = parsed['raw']
                        if isinstance(r, (list, tuple)):
                            return [str(x).strip() for x in r if str(x).strip()]
                except Exception:
                    pass
                # common separators
                for sep in ('|||', '||', '\n'):
                    if sep in val:
                        return [s.strip() for s in val.split(sep) if s.strip()]
                # otherwise treat as single caption string
                return [val.strip()]
    return []

# Write the token file: one "<image>#<k>\t<caption>" line per caption, where k
# numbers the captions of each image from 0. Rows without captions are skipped.
with open(out_path, 'w', encoding='utf-8') as token_file:
    for _, record in df.iterrows():
        img_name = _get_image_name(record)
        caps = _get_captions(record)
        if caps:
            token_file.writelines(
                f"{img_name}#{k}\t{text}\n" for k, text in enumerate(caps)
            )

print("Wrote", out_path)

Sanity check: display a few random images together with their captions.

In [None]:
IMAGES_DIR = "flickr30k-images"
TOKENS_FILE = "Flickr30k.token.txt"
NUM_SAMPLES = 5

# Group captions by image. Each line of the token file is "<image>#<k>\t<caption>".
image_to_caps = defaultdict(list)
with open(TOKENS_FILE, 'r', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        img_token, sep, caption = line.partition('\t')
        if not sep:
            # Malformed line without a tab: skip it (Flickr30kDataset skips these too)
            # instead of crashing on tuple unpacking.
            continue
        img_name = img_token.split('#')[0]
        image_to_caps[img_name].append(caption)

print(f"Loaded {len(image_to_caps)} images from token file.")

sampled_imgs = random.sample(list(image_to_caps.keys()), min(NUM_SAMPLES, len(image_to_caps)))

for img_name in sampled_imgs:
    img_path = os.path.join(IMAGES_DIR, img_name)
    caps = image_to_caps[img_name]

    # Load the image; `with` releases the underlying file handle promptly.
    try:
        with Image.open(img_path) as src:
            img = src.convert('RGB')
    except FileNotFoundError:
        print(f"Image not found: {img_path}")
        continue

    plt.figure(figsize=(5, 5))
    plt.imshow(img)
    plt.axis('off')
    plt.title(img_name)
    plt.show()

    print(f"{img_name} — {len(caps)} captions:")
    for i, c in enumerate(caps):
        print(f"  {i+1}. {c}")
    print("-" * 60)

## CNN–RNN model

A simple CNN → RNN baseline: a pretrained ResNet-18 image encoder feeding an LSTM caption decoder.

In [None]:
# Free unreferenced Python objects and cached CUDA memory left over from earlier runs.
gc.collect()
torch.cuda.empty_cache()

# ----------------- Config -----------------

# Optimisation schedule
NUM_EPOCHS, BATCH_SIZE, LEARNING_RATE = 1, 64, 1e-3
# Model widths: image-feature / word-embedding size and LSTM hidden size
EMBED_SIZE, HIDDEN_SIZE = 256, 512
# Training-caption tokens seen fewer than MIN_FREQ times are mapped to <unk>
MIN_FREQ = 5
# When True, the ResNet backbone weights are trained too
FINE_TUNE = False
# Checkpoints are written here (created if missing)
OUTPUT_DIR = "./models_baseline"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ----------------- tokenizer / vocab -----------------

def tokenize_caption(text):
    """Split a caption into lower-case word tokens.

    Every run of characters other than a-z, 0-9, apostrophe and space
    (including tabs, newlines, hyphens, punctuation and non-ASCII letters)
    is replaced by a space. The result is then split on whitespace.
    """
    cleaned = re.sub(r"[^a-z0-9' ]+", " ", text.lower())
    return cleaned.split()

class Vocab:
    """Word-level vocabulary mapping tokens to integer ids and back.

    ``itos`` (id -> token) starts with the reserved tokens in the given order
    (default: ``<pad>``=0, ``<start>``=1, ``<end>``=2, ``<unk>``=3). It is
    followed by every other token seen at least ``min_freq`` times, most
    frequent first. ``stoi`` is the inverse mapping.
    """

    def __init__(self, min_freq=5, reserved=None):
        if reserved is None:
            reserved = ['<pad>', '<start>', '<end>', '<unk>']
        self.min_freq = min_freq
        self.reserved = reserved
        self.freq = Counter()   # token -> count from the most recent build()
        self.itos = []          # id -> token
        self.stoi = {}          # token -> id

    def build(self, token_lists):
        """(Re)build the vocabulary from an iterable of token lists.

        Counts are recomputed from scratch on every call. Building twice from
        the same data therefore gives the same vocabulary, rather than doubling
        the counts and letting rare tokens slip past ``min_freq``.
        """
        self.freq = Counter()
        for tokens in token_lists:
            self.freq.update(tokens)
        reserved = set(self.reserved)
        # most_common() orders by count, and equal counts keep first-seen order,
        # so the ids are deterministic for a given input.
        self.itos = list(self.reserved) + [
            tok for tok, cnt in self.freq.most_common()
            if cnt >= self.min_freq and tok not in reserved
        ]
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}

    def __len__(self):
        return len(self.itos)

    def numericalize(self, tokens):
        """Map tokens to ids; tokens missing from the vocabulary get the ``<unk>`` id."""
        return [self.stoi.get(t, self.stoi['<unk>']) for t in tokens]

# ----------------- Dataset -----------------

class Flickr30kDataset(Dataset):
    """(image, caption-id tensor) pairs read from a Flickr token file.

    Captions are read from ``captions_file`` (lines ``"<image>#<k>\\t<caption>"``)
    and kept only for images present in ``images_dir``. The *images* are then
    shuffled with ``seed`` and split ``train_ratio`` / ``1 - train_ratio``, so
    no image appears in both splits.

    Args:
        images_dir: folder containing the image files.
        captions_file: path of the tab-separated token file.
        vocab: vocabulary to use. If None and ``split == 'train'``, one is built
            from the training captions; required for ``split == 'val'``.
        transform: optional callable applied to each PIL image.
        split: ``'train'`` or ``'val'``.
        train_ratio: fraction of images assigned to the training split.
        seed: seed for the image shuffle; use the same value for both splits.
        min_freq: ``min_freq`` for a newly built vocabulary; defaults to the
            notebook-level ``MIN_FREQ`` when None.

    Raises:
        ValueError: for an unknown ``split``, or ``split == 'val'`` without ``vocab``.
    """

    def __init__(self, images_dir, captions_file, vocab=None, transform=None, split='train',
                 train_ratio=0.8, seed=42, min_freq=None):
        self.images_dir = str(images_dir)
        self.transform = transform
        image_to_captions = self._read_captions(captions_file)
        available = set(os.listdir(self.images_dir))
        # One entry per (image, caption); captions of missing images are dropped.
        self.entries = [
            (img, cap)
            for img, caps in image_to_captions.items() if img in available
            for cap in caps
        ]
        # Split by unique image (not by caption) to avoid train/val leakage.
        images = sorted({img for img, _ in self.entries})
        random.Random(seed).shuffle(images)
        n_train = int(len(images) * train_ratio)
        if split == 'train':
            keep = set(images[:n_train])
        elif split == 'val':
            keep = set(images[n_train:])
        else:
            raise ValueError("split must be 'train' or 'val'")
        self.entries = [e for e in self.entries if e[0] in keep]

        if vocab is not None:
            self.vocab = vocab
        elif split == 'train':
            # Vocabulary comes from training captions only.
            token_lists = [tokenize_caption(c) for _, c in self.entries]
            self.vocab = Vocab(min_freq=MIN_FREQ if min_freq is None else min_freq)
            self.vocab.build(token_lists)
        else:
            raise ValueError("Provide vocab for val split")

    @staticmethod
    def _read_captions(captions_file):
        """Parse the token file into ``{image_name: [caption, ...]}``.

        Blank lines and lines that do not have exactly one tab are skipped.
        """
        image_to_captions = defaultdict(list)
        with open(captions_file, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                parts = line.split('\t')
                if len(parts) != 2:
                    continue
                img_token, cap = parts
                # "<image>#<k>" -> "<image>"
                image_to_captions[img_token.split('#')[0]].append(cap)
        return image_to_captions

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, idx):
        """Return ``(image, ids)``: the (transformed) RGB image and a 1-D LongTensor
        of ``<start> tokens... <end>`` ids."""
        img_name, caption = self.entries[idx]
        img_path = os.path.join(self.images_dir, img_name)
        with Image.open(img_path) as src:
            image = src.convert('RGB')
        if self.transform:
            image = self.transform(image)
        tokens = ['<start>'] + tokenize_caption(caption) + ['<end>']
        num = torch.tensor(self.vocab.numericalize(tokens), dtype=torch.long)
        return image, num

def collate_fn(batch):
    """Combine ``(image, caption_ids)`` samples into one batch.

    Returns:
        images: tensor stacked along a new leading batch dimension.
        captions: (B, max_len) LongTensor, right-padded with 0 (the ``<pad>`` id).
        lengths: list with each caption's unpadded length.
    """
    images, captions = zip(*batch)
    stacked = torch.stack(images, dim=0)
    lengths = [cap.size(0) for cap in captions]
    padded = nn.utils.rnn.pad_sequence(captions, batch_first=True, padding_value=0)
    return stacked, padded, lengths

# ----------------- Models -----------------

class Encoder(nn.Module):
    """ResNet-18 image encoder producing one ``embed_size`` feature vector per image.

    The pretrained ResNet-18 (every layer except its final classifier) is
    followed by a trainable ``Linear(512, embed_size)`` projection and a ReLU.

    Args:
        embed_size: size of the output feature vector.
        fine_tune: if False, the backbone is frozen (no gradients) and stays in
            eval mode even when ``train()`` is called (see :meth:`train`).
    """

    def __init__(self, embed_size, fine_tune=False):
        super().__init__()
        resnet = resnet18(weights=ResNet18_Weights.DEFAULT)
        modules = list(resnet.children())[:-1]  # drop the final fc; keeps the global avg-pool
        self.backbone = nn.Sequential(*modules)
        self.fc = nn.Linear(resnet.fc.in_features, embed_size)  # in_features is 512 for ResNet-18
        self.relu = nn.ReLU()
        self.fine_tune(fine_tune)

    def forward(self, x):
        """Map a batch of images (B, 3, H, W) to features of shape (B, embed_size)."""
        feat = self.backbone(x)           # (B, 512, 1, 1)
        feat = torch.flatten(feat, 1)     # (B, 512)
        return self.relu(self.fc(feat))   # (B, embed)

    def fine_tune(self, fine):
        """Enable (True) or freeze (False) gradient updates for the backbone."""
        self._fine_tune = bool(fine)
        for p in self.backbone.parameters():
            p.requires_grad = fine
        # Re-apply the current mode so the backbone's train/eval state is consistent.
        self.train(self.training)

    def train(self, mode=True):
        """Set train/eval mode, keeping a frozen backbone in eval mode.

        Without this, BatchNorm layers in a frozen backbone would normalise with
        per-batch statistics during training and keep overwriting the pretrained
        running mean/variance, even though none of the backbone weights are trained.
        """
        super().train(mode)
        if not getattr(self, '_fine_tune', True):
            self.backbone.eval()
        return self


class Decoder(nn.Module):
    """LSTM caption decoder conditioned on an image feature vector.

    The image feature is fed as the first LSTM input. It is followed by the
    word embeddings of the caption (teacher forcing) during training, or by the
    previously predicted word during :meth:`sample`.

    Args:
        embed_size: size of the image feature and of the word embeddings.
        hidden_size: LSTM hidden size.
        vocab_size: number of output classes (vocabulary size).
        num_layers: number of stacked LSTM layers.
        dropout: dropout between stacked LSTM layers. It only applies when
            ``num_layers > 1``; for a single layer it is not passed to the LSTM.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1, dropout=0.5):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        # nn.LSTM applies dropout only *between* stacked layers. With one layer a
        # non-zero value has no effect and makes PyTorch emit a UserWarning, so pass
        # it only when it actually applies.
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True,
                            dropout=dropout if num_layers > 1 else 0.0)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        """Teacher-forced logits.

        Args:
            features: (B, embed) image features.
            captions: (B, L) caption ids.

        Returns:
            (B, L, vocab) logits. Position t is computed from the image feature
            followed by ``captions[:, :t]``.
        """
        embeddings = self.embed(captions)            # (B, L, embed)
        feats = features.unsqueeze(1)                # (B, 1, embed)
        inputs = torch.cat([feats, embeddings[:, :-1, :]], dim=1)  # shift right for teacher forcing
        outputs, _ = self.lstm(inputs)
        return self.linear(outputs)                  # (B, L, vocab)

    def sample(self, features, max_len=30):
        """Greedy decoding: return (B, max_len) predicted ids.

        Decoding always runs for ``max_len`` steps. It does not stop at
        ``<end>``, so callers must truncate.
        """
        ids = []
        inputs = features.unsqueeze(1)               # (B, 1, embed)
        states = None
        for _ in range(max_len):
            out, states = self.lstm(inputs, states)  # out: (B, 1, hidden)
            logits = self.linear(out.squeeze(1))     # (B, vocab)
            pred = logits.argmax(dim=1)              # (B,)
            ids.append(pred)
            inputs = self.embed(pred).unsqueeze(1)   # (B, 1, embed)
        return torch.stack(ids, dim=1)               # (B, max_len)

# ----------------- Training utilities -----------------

def train_epoch(enc, dec, loader, criterion, enc_opt, dec_opt, device):
    """Train ``enc``/``dec`` for one pass over ``loader`` with teacher forcing.

    Args:
        enc, dec: encoder and decoder modules.
        loader: yields ``(images, padded_captions, lengths)`` batches.
        criterion: loss over flattened logits (N, vocab) and targets (N,).
        enc_opt: optimizer for encoder parameters, or None when the encoder
            is not being fine-tuned.
        dec_opt: optimizer for the decoder parameters (may also hold other
            trainable parameters).
        device: device each batch is moved to.

    Returns:
        Mean of the per-batch losses, or ``float('nan')`` if ``loader`` is empty.
    """
    enc.train(); dec.train()
    n_batches = 0
    loss_sum = 0.0
    for images, caps, _ in tqdm(loader, desc="Train", leave=False):
        images = images.to(device)
        caps = caps.to(device)
        outputs = dec(enc(images), caps)  # (B, L, vocab), scored against caps[:, t]
        loss = criterion(outputs.reshape(-1, outputs.size(-1)), caps.reshape(-1))
        # Clear gradients on every parameter of both modules, not only on those
        # owned by an optimizer. Otherwise trainable parameters that no optimizer
        # steps would accumulate gradients from every batch.
        enc.zero_grad()
        dec.zero_grad()
        loss.backward()
        # clip_grad_norm_ ignores parameters whose .grad is None (e.g. a frozen backbone).
        nn.utils.clip_grad_norm_(dec.parameters(), 5.0)
        nn.utils.clip_grad_norm_(enc.parameters(), 5.0)
        if enc_opt is not None:
            enc_opt.step()
        dec_opt.step()
        loss_sum += loss.item()
        n_batches += 1
    return loss_sum / n_batches if n_batches else float('nan')

@torch.no_grad()
def validate(enc, dec, loader, criterion, device):
    """Return the mean per-batch teacher-forced loss over ``loader``.

    Runs without gradient tracking, with both modules in eval mode.

    Args:
        enc, dec: encoder and decoder modules.
        loader: yields ``(images, padded_captions, lengths)`` batches.
        criterion: loss over flattened logits (N, vocab) and targets (N,).
        device: device each batch is moved to.

    Returns:
        Mean of the per-batch losses, or ``float('nan')`` if ``loader`` is
        empty (e.g. an empty validation split).
    """
    enc.eval(); dec.eval()
    n_batches = 0
    loss_sum = 0.0
    for images, caps, _ in tqdm(loader, desc="Val", leave=False):
        images = images.to(device)
        caps = caps.to(device)
        outputs = dec(enc(images), caps)  # (B, L, vocab)
        loss = criterion(outputs.reshape(-1, outputs.size(-1)), caps.reshape(-1))
        loss_sum += loss.item()
        n_batches += 1
    return loss_sum / n_batches if n_batches else float('nan')

@torch.no_grad()
def generate_caption(enc, dec, img_path, transform, vocab, device, max_len=30):
    """Greedily caption the image at ``img_path`` and return it as a string.

    Decoding stops at the first ``<end>``. ``<pad>`` and ``<start>`` tokens are
    omitted, and ids outside the vocabulary are rendered as ``<unk>``.
    """
    enc.eval()
    dec.eval()
    pixels = transform(Image.open(img_path).convert('RGB'))
    batch = pixels.unsqueeze(0).to(device)
    token_ids = dec.sample(enc(batch), max_len=max_len)[0].cpu().tolist()
    n_known = len(vocab.itos)
    words = []
    for token_id in token_ids:
        token = vocab.itos[token_id] if token_id < n_known else '<unk>'
        if token == '<end>':
            break
        if token in ('<pad>', '<start>'):
            continue
        words.append(token)
    return ' '.join(words)

# ----------------- Prepare data & models -----------------

# Preprocessing for every image: resize to 224x224, convert to a tensor and
# normalise with the ImageNet mean/std used by torchvision's pretrained weights.
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

print("Preparing datasets...")
# The vocabulary is built from the training captions only and reused for validation.
train_ds = Flickr30kDataset(IMAGES_DIR, CAPTIONS_FILE, vocab=None, transform=transform, split='train', train_ratio=0.8, seed=SEED)
vocab = train_ds.vocab
val_ds   = Flickr30kDataset(IMAGES_DIR, CAPTIONS_FILE, vocab=vocab, transform=transform, split='val', train_ratio=0.8, seed=SEED)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=NUM_WORKERS)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn, num_workers=NUM_WORKERS)

print("Vocab size:", len(vocab))
enc = Encoder(EMBED_SIZE, fine_tune=FINE_TUNE).to(DEVICE)
dec = Decoder(EMBED_SIZE, HIDDEN_SIZE, vocab_size=len(vocab)).to(DEVICE)
# Encoder.fine_tune() freezes only the ResNet backbone. The new projection head
# enc.fc is always trainable, so it is optimised together with the decoder.
# Without this, nothing would ever update it when FINE_TUNE is False.
dec_opt = optim.Adam(list(dec.parameters()) + list(enc.fc.parameters()), lr=LEARNING_RATE)
# When fine-tuning, the pretrained backbone uses a 10x smaller learning rate.
enc_opt = optim.Adam(enc.backbone.parameters(), lr=LEARNING_RATE*0.1) if FINE_TUNE else None
# Padded positions (collate_fn pads with id 0 == '<pad>') do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=vocab.stoi['<pad>'])

# ----------------- Train loop -----------------

# Train for NUM_EPOCHS. Each epoch's checkpoint is saved, and the epoch with the
# lowest validation loss is additionally saved as best.pth.
best_val = float('inf')
for epoch in range(1, NUM_EPOCHS + 1):
    tr_loss = train_epoch(enc, dec, train_loader, criterion, enc_opt, dec_opt, DEVICE)
    val_loss = validate(enc, dec, val_loader, criterion, DEVICE)
    print(f"Epoch {epoch}/{NUM_EPOCHS}  train_loss={tr_loss:.4f}  val_loss={val_loss:.4f}")
    ckpt = {
        'epoch': epoch,
        'encoder': enc.state_dict(),
        'decoder': dec.state_dict(),
        'vocab': vocab.itos,
        # Extra keys (existing ones unchanged): the losses, plus the hyperparameters
        # needed to rebuild Encoder/Decoder with matching shapes when loading.
        'train_loss': tr_loss,
        'val_loss': val_loss,
        'config': {
            'embed_size': EMBED_SIZE,
            'hidden_size': HIDDEN_SIZE,
            'min_freq': MIN_FREQ,
            'fine_tune': FINE_TUNE,
            'learning_rate': LEARNING_RATE,
            'batch_size': BATCH_SIZE,
        },
    }
    torch.save(ckpt, os.path.join(OUTPUT_DIR, f"ckpt_epoch_{epoch}.pth"))
    if val_loss < best_val:
        best_val = val_loss
        torch.save(ckpt, os.path.join(OUTPUT_DIR, "best.pth"))

print("Done. Models saved to", OUTPUT_DIR)