In [15]:

# @title Imports & Globals
import os, json, math, random
from collections import Counter
import os
import json
import h5py
import torch

import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader

# Prefer the GPU whenever PyTorch can see one; otherwise everything runs on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# A single seed shared by Python's `random`, NumPy and PyTorch.
SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Reserved vocabulary entries; a token's position in this tuple is its integer id.
SPECIALS = {
    token: index
    for index, token in enumerate(("<PAD>", "<BOS>", "<EOS>", "<UNK>"))
}


In [16]:

# @title Vocabulary
class Vocab:
    """Word-level vocabulary mapping tokens to integer ids and back.

    The entries of ``SPECIALS`` are always present at their fixed ids; corpus
    words are appended after them by :meth:`build`.
    """

    def __init__(self, min_count=3):
        """Create a vocabulary that holds only the special tokens.

        Args:
            min_count: minimum number of occurrences a word needs in the
                corpus before :meth:`build` admits it.
        """
        self.min_count = min_count
        # token -> id, seeded with a copy so SPECIALS itself is never mutated.
        self.stoi = dict(SPECIALS)
        # id -> token; each special token is placed at the slot named by its id.
        slots = [None] * len(SPECIALS)
        for token, index in SPECIALS.items():
            slots[index] = token
        self.itos = slots

    def build(self, captions_json_path):
        """Add every sufficiently frequent corpus word to the vocabulary.

        Args:
            captions_json_path: path to a JSON list whose items each carry a
                ``"caption"`` entry holding an iterable of caption strings.

        Returns:
            ``self``, so construction and building can be chained.
        """
        with open(captions_json_path, 'r') as fh:
            records = json.load(fh)
        freq = Counter(
            word
            for record in records
            for caption in record["caption"]
            for word in self.tokenize(caption)
        )
        # Counter keeps first-seen order, so ids are assigned deterministically.
        for word, count in freq.items():
            if count < self.min_count or word in self.stoi:
                continue
            self.stoi[word] = len(self.itos)
            self.itos.append(word)
        return self

    def tokenize(self, s):
        """Lower-case ``s`` and split it on whitespace (punctuation is kept)."""
        normalized = s.lower().strip()
        normalized = normalized.replace('\n', ' ')
        return normalized.split()

    def encode(self, s, add_bos_eos=True):
        """Turn a caption string into a 1-D ``torch.long`` tensor of token ids.

        Words missing from the vocabulary map to ``<UNK>``. When
        ``add_bos_eos`` is true the ids are wrapped in ``<BOS>`` ... ``<EOS>``.
        """
        unknown = SPECIALS["<UNK>"]
        ids = [self.stoi.get(word, unknown) for word in self.tokenize(s)]
        if add_bos_eos:
            ids = [SPECIALS["<BOS>"], *ids, SPECIALS["<EOS>"]]
        return torch.tensor(ids, dtype=torch.long)

    def decode(self, ids):
        """Turn an iterable of token ids back into a space-joined string.

        Decoding stops at the first ``<EOS>``; ``<BOS>`` and ``<PAD>`` are
        dropped wherever they occur.
        """
        stop_id = SPECIALS["<EOS>"]
        silent_ids = (SPECIALS["<BOS>"], SPECIALS["<PAD>"])
        words = []
        for token_id in ids:
            if token_id == stop_id:
                break
            if token_id not in silent_ids:
                words.append(self.itos[token_id])
        return ' '.join(words)

    def __len__(self):
        """Number of entries in the vocabulary, special tokens included."""
        return len(self.itos)


In [17]:

# @title Dataset & Collate
class VideoCaptionDataset(Dataset):
    """Pairs per-video feature arrays with one of their reference captions.

    Features are read from ``<data_dir>/feat/<id>.npy``. Captions come from a
    JSON list of items that each provide an ``'id'`` and a ``'caption'`` entry.
    """

    def __init__(self, data_dir, labels_json, vocab: Vocab, max_frames=80):
        """Index the captioned videos found under ``data_dir``.

        Args:
            data_dir: directory holding ``feat/`` and, optionally, ``id.txt``.
            labels_json: path to the JSON caption file.
            vocab: vocabulary used to encode the sampled caption.
            max_frames: number of frames every feature array is cut or
                zero-padded to.
        """
        self.data_dir = data_dir
        self.vocab = vocab
        self.max_frames = max_frames
        self.feat_dir = os.path.join(data_dir, 'feat')

        with open(labels_json, 'r') as fh:
            self.items = json.load(fh)
        self.id2caps = {}
        for record in self.items:
            self.id2caps[record['id']] = record['caption']
        self.ids = list(self.id2caps)

        # An id.txt, when present, fixes both the order and the subset of
        # videos; ids without captions in labels_json are dropped.
        listing_path = os.path.join(data_dir, 'id.txt')
        if os.path.isfile(listing_path):
            with open(listing_path, 'r') as fh:
                stripped = (raw.strip() for raw in fh)
                listed = [name for name in stripped if name]
            self.ids = [name for name in listed if name in self.id2caps]

    def __len__(self):
        """Number of videos that will be served."""
        return len(self.ids)

    def _load_feat(self, vid):
        """Load the features of ``vid`` as a ``(max_frames, D)`` float32 tensor.

        ``<vid>.npy`` is tried first; if that file does not exist, the last
        extension of ``vid`` is stripped and the lookup is retried. 1-D arrays
        become a single frame and 3-D arrays are flattened per frame. Longer
        clips keep only their first ``max_frames`` frames, shorter ones are
        padded with zero frames at the end.
        """
        path = os.path.join(self.feat_dir, f"{vid}.npy")
        if not os.path.isfile(path):
            stem = os.path.splitext(vid)[0]
            path = os.path.join(self.feat_dir, f"{stem}.npy")

        frames = np.load(path)
        if frames.ndim == 1:
            frames = frames[None, :]
        elif frames.ndim == 3:
            frames = frames.reshape(frames.shape[0], -1)

        n_frames, feat_dim = frames.shape[0], frames.shape[-1]
        missing = self.max_frames - n_frames
        if missing < 0:
            frames = frames[:self.max_frames]
        elif missing > 0:
            filler = np.zeros((missing, feat_dim), dtype=frames.dtype)
            frames = np.concatenate([frames, filler], axis=0)
        return torch.tensor(frames, dtype=torch.float32)

    def __getitem__(self, idx):
        """Return ``(video_id, features, caption_ids)`` for position ``idx``.

        The caption is drawn at random from the video's references on every
        access, so repeated epochs see different targets for the same clip.
        """
        video_id = self.ids[idx]
        features = self._load_feat(video_id)            # (T, D)
        caption = random.choice(self.id2caps[video_id])
        caption_ids = self.vocab.encode(caption, add_bos_eos=True)
        return video_id, features, caption_ids

def collate_batch(batch):
    """Merge ``(vid, features, caption_ids)`` samples into one padded batch.

    Returns:
        A tuple ``(vids, xs, ys, lengths)``: the video ids as a tuple, the
        features stacked along a new leading batch axis, the captions
        right-padded with ``<PAD>`` to the longest one in the batch, and the
        unpadded caption lengths as a ``torch.long`` tensor.
    """
    vids, feats, captions = zip(*batch)
    feat_batch = torch.stack(feats, dim=0)          # (B, T, D)
    lengths = torch.tensor([cap.numel() for cap in captions], dtype=torch.long)
    padded = pad_sequence(captions, batch_first=True, padding_value=SPECIALS["<PAD>"])
    return vids, feat_batch, padded, lengths


In [18]:

# @title Model (Encoder, Attention, Decoder)
class Encoder(nn.Module):
    """GRU that reads a batch-first sequence of frame features."""

    def __init__(self, feat_dim, hidden_size, num_layers=1, dropout=0.1):
        """Build the recurrent encoder.

        Args:
            feat_dim: size of each input frame vector.
            hidden_size: size of the GRU hidden state.
            num_layers: number of stacked GRU layers.
            dropout: dropout passed to the GRU; forced to 0 for a single
                layer, where the GRU has no between-layer connection to drop.
        """
        super().__init__()
        inter_layer_dropout = dropout if num_layers > 1 else 0
        self.rnn = nn.GRU(
            input_size=feat_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
        )

    def forward(self, x):
        """Encode ``x`` of shape (B, T, feat_dim).

        Returns:
            ``(outputs, h)`` with ``outputs`` of shape (B, T, H) and ``h`` of
            shape (num_layers, B, H).
        """
        return self.rnn(x)

class Attention(nn.Module):
    """Parameter-free scaled dot-product attention over encoder outputs."""

    def __init__(self, hidden_size):
        """Store the ``1 / sqrt(hidden_size)`` factor applied to raw scores."""
        super().__init__()
        self.scale = 1.0 / math.sqrt(hidden_size)

    def forward(self, dec_h, enc_out):
        """Attend over ``enc_out`` using ``dec_h`` as the query.

        Args:
            dec_h: decoder state, shape (B, H).
            enc_out: encoder outputs, shape (B, T, H).

        Returns:
            ``(ctx, weights)``: the weighted sum of encoder outputs, shape
            (B, H), and the softmax weights over time, shape (B, T).
        """
        query = dec_h.unsqueeze(2)                                  # (B, H, 1)
        raw_scores = torch.bmm(enc_out, query).squeeze(2)           # (B, T)
        weights = (raw_scores * self.scale).softmax(dim=1)          # (B, T)
        ctx = torch.bmm(weights.unsqueeze(1), enc_out).squeeze(1)   # (B, H)
        return ctx, weights

class Decoder(nn.Module):
    """GRU caption decoder with optional attention over the encoder outputs."""

    def __init__(self, vocab_size, hidden_size, emb_size=256, num_layers=1, dropout=0.1, use_attention=True):
        """Build embedding, GRU, optional attention and output projection.

        Args:
            vocab_size: number of output classes / embedding rows.
            hidden_size: GRU hidden size; also the size of the attention
                context appended to the GRU input when attention is enabled.
            emb_size: token embedding size.
            num_layers: number of stacked GRU layers.
            dropout: GRU dropout, forced to 0 for a single layer.
            use_attention: whether each step attends over the encoder outputs.
        """
        super().__init__()
        # Sub-modules are created in a fixed order (emb, rnn, attn, out) so
        # seeded initialisation and state_dict layout stay the same.
        self.emb = nn.Embedding(vocab_size, emb_size, padding_idx=SPECIALS["<PAD>"])
        context_size = hidden_size if use_attention else 0
        self.rnn = nn.GRU(
            emb_size + context_size,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
        )
        self.use_attention = use_attention
        self.attn = Attention(hidden_size) if use_attention else None
        self.out = nn.Linear(hidden_size, vocab_size)

    def forward(self, enc_outputs, h0, tgt, teacher_forcing_ratio=1.0):
        """Produce logits for every target position after ``<BOS>``.

        Args:
            enc_outputs: encoder outputs, shape (B, T, H).
            h0: initial GRU state handed to the decoder GRU.
            tgt: padded target ids, shape (B, L); column 0 is the first input.
            teacher_forcing_ratio: probability, drawn once per time step for
                the whole batch, of feeding the gold token instead of the
                model's own argmax prediction.

        Returns:
            Logits of shape (B, L-1, V).
        """
        # Unpacking also insists on a rank-3 encoder output.
        batch_size, num_frames, enc_dim = enc_outputs.size()
        hidden = h0
        prev_tokens = tgt[:, 0]  # <BOS>
        step_outputs = []
        # Walk over the gold tokens for steps 1..L-1, one column at a time.
        for gold_tokens in tgt[:, 1:].unbind(dim=1):
            embedded = self.emb(prev_tokens)
            if self.use_attention:
                # Query with the top layer's state from the previous step.
                context, _ = self.attn(hidden[-1], enc_outputs)
                rnn_in = torch.cat([embedded, context], dim=-1).unsqueeze(1)
            else:
                rnn_in = embedded.unsqueeze(1)
            rnn_out, hidden = self.rnn(rnn_in, hidden)
            step_logits = self.out(rnn_out.squeeze(1))
            step_outputs.append(step_logits)
            # scheduled sampling: gold token vs. the model's own prediction
            use_gold = random.random() < teacher_forcing_ratio
            prev_tokens = gold_tokens if use_gold else step_logits.argmax(dim=-1)
        return torch.stack(step_outputs, dim=1)  # (B, L-1, V)

class S2VT(nn.Module):
    """Encoder-decoder video captioner with greedy and beam-search decoding."""

    def __init__(self, feat_dim, vocab_size, hidden=256, emb=256, enc_layers=1, dec_layers=1, dropout=0.1, use_attention=True):
        """Build the encoder and decoder.

        Args:
            feat_dim: size of each input frame feature vector.
            vocab_size: number of tokens the decoder can emit.
            hidden: hidden size shared by encoder and decoder GRUs.
            emb: token embedding size.
            enc_layers: number of encoder GRU layers.
            dec_layers: number of decoder GRU layers (may differ from
                ``enc_layers``).
            dropout: GRU dropout for multi-layer stacks.
            use_attention: whether the decoder attends over encoder outputs.
        """
        super().__init__()
        self.encoder = Encoder(feat_dim, hidden, enc_layers, dropout)
        self.decoder = Decoder(vocab_size, hidden, emb, dec_layers, dropout, use_attention)

    def _init_decoder_state(self, h):
        """Adapt the encoder's final state to the decoder's layer count.

        The decoder GRU requires an initial state with exactly its own number
        of layers. With equal depths ``h`` is returned untouched; a deeper
        encoder contributes its top layers; a shallower encoder has its top
        layer repeated for the extra decoder layers.
        """
        dec_layers = self.decoder.rnn.num_layers
        enc_layers = h.size(0)
        if enc_layers == dec_layers:
            return h
        if enc_layers > dec_layers:
            return h[-dec_layers:].contiguous()
        extra = h[-1:].expand(dec_layers - enc_layers, -1, -1)
        return torch.cat([h, extra], dim=0).contiguous()

    def _decode_step(self, inputs, h, enc_out):
        """Run one decoder step; return ``(logits, new_hidden)``.

        ``inputs`` holds the previous token ids, shape (B,). The attention
        query is the top layer of the incoming hidden state.
        """
        emb = self.decoder.emb(inputs)
        if self.decoder.use_attention:
            ctx, _ = self.decoder.attn(h[-1], enc_out)
            rnn_in = torch.cat([emb, ctx], dim=-1).unsqueeze(1)
        else:
            rnn_in = emb.unsqueeze(1)
        out, h = self.decoder.rnn(rnn_in, h)
        return self.decoder.out(out.squeeze(1)), h

    def forward(self, x, tgt, teacher_forcing_ratio=1.0):
        """Return training logits of shape (B, L-1, V) for targets ``tgt``."""
        enc_out, h = self.encoder(x)
        return self.decoder(enc_out, self._init_decoder_state(h), tgt, teacher_forcing_ratio)

    @torch.no_grad()
    def greedy_decode(self, x, max_len, bos_id, eos_id):
        """Decode a batch by always taking the most likely next token.

        Args:
            x: features, shape (B, T, D).
            max_len: maximum number of tokens generated per sequence.
            bos_id: id fed as the first decoder input.
            eos_id: id that terminates a sequence.

        Returns:
            A list of B id lists, each cut before its first ``eos_id`` (the
            EOS itself is not included).
        """
        enc_out, h = self.encoder(x)
        h = self._init_decoder_state(h)
        batch_size = enc_out.size(0)
        inputs = torch.full((batch_size,), bos_id, dtype=torch.long, device=x.device)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=x.device)
        outputs = []
        for _ in range(max_len):
            logits, h = self._decode_step(inputs, h, enc_out)
            next_ids = torch.argmax(logits, dim=-1)
            outputs.append(next_ids)
            # Tokens after a sequence's first EOS are discarded below, so once
            # every sequence has produced one, further steps cannot matter.
            finished |= next_ids == eos_id
            if bool(finished.all()):
                break
            inputs = next_ids
        if not outputs:
            # max_len <= 0: nothing was generated (torch.stack rejects []).
            return [[] for _ in range(batch_size)]
        outs = torch.stack(outputs, dim=1)  # (B, steps)
        sents = []
        for seq in outs.tolist():
            if eos_id in seq:
                seq = seq[:seq.index(eos_id)]
            sents.append(seq)
        return sents

    @torch.no_grad()
    def beam_search(self, x, max_len, bos_id, eos_id, beam_size=3):
        """Decode a single video with beam search.

        Hypotheses are ranked by their summed token log-probability (no
        length normalisation). A hypothesis that has emitted ``eos_id`` is
        carried forward unchanged.

        Args:
            x: features of exactly one video, shape (1, T, D).
            max_len: maximum number of expansion steps.
            bos_id: id every hypothesis starts from.
            eos_id: id that terminates a hypothesis.
            beam_size: number of hypotheses kept after each step.

        Returns:
            The best hypothesis as a list of ids, without the leading BOS and
            cut before its first ``eos_id``.

        Raises:
            ValueError: if ``beam_size`` is below 1 or ``x`` is not a batch
                of one.
        """
        if beam_size < 1:
            raise ValueError(f"beam_size must be >= 1, got {beam_size}")
        if x.size(0) != 1:
            raise ValueError(f"beam_search expects a batch of one video, got {x.size(0)}")

        enc_out, h0 = self.encoder(x)  # x: (1,T,D)
        h0 = self._init_decoder_state(h0)
        beams = [(0.0, [bos_id], h0)]  # (logprob, seq, h)
        for _ in range(max_len):
            # Once every hypothesis has ended, each step would only copy the
            # beams over unchanged.
            if all(seq[-1] == eos_id for _, seq, _ in beams):
                break
            new_beams = []
            for logp, seq, h in beams:
                if seq[-1] == eos_id:
                    new_beams.append((logp, seq, h))
                    continue
                last = torch.tensor([seq[-1]], device=x.device)
                logits, h2 = self._decode_step(last, h, enc_out)
                logprobs = torch.log_softmax(logits, dim=-1).squeeze(0)
                # topk raises if asked for more entries than the vocabulary has.
                width = min(beam_size, logprobs.size(-1))
                top = torch.topk(logprobs, width)
                for value, index in zip(top.values.tolist(), top.indices.tolist()):
                    new_beams.append((logp + value, seq + [index], h2))
            beams = sorted(new_beams, key=lambda beam: beam[0], reverse=True)[:beam_size]
        best = max(beams, key=lambda beam: beam[0])[1]
        if eos_id in best:
            best = best[1:best.index(eos_id)]
        else:
            best = best[1:]
        return best


In [19]:

# @title BLEU@1 (assignment-compatible)
def bleu1_precision(candidate_tokens, reference_tokens_list):
    """Sentence-level BLEU@1: clipped unigram precision times brevity penalty.

    Args:
        candidate_tokens: list of tokens in the generated caption.
        reference_tokens_list: list of reference captions, each a list of
            tokens.

    Returns:
        A float in [0, 1]. Returns 0.0 when the candidate is empty or when no
        references are supplied.
    """
    # Without references nothing can match, and min() below would raise on an
    # empty sequence.
    if not candidate_tokens or not reference_tokens_list:
        return 0.0

    cand_counts = Counter(candidate_tokens)
    # Counter union keeps, per word, the largest count seen in any reference.
    max_ref_counts = Counter()
    for ref in reference_tokens_list:
        max_ref_counts |= Counter(ref)
    # Counter intersection is the per-word minimum, i.e. the clipped count.
    clipped = sum((cand_counts & max_ref_counts).values())

    c = len(candidate_tokens)
    precision = clipped / c
    # Effective reference length: closest to the candidate's length, with the
    # shorter reference winning ties.
    r = min((len(ref) for ref in reference_tokens_list), key=lambda L: (abs(L - c), L))
    bp = 1.0 if c > r else math.exp(1 - r / c)
    return bp * precision


In [20]:

# @title Train Utilities
def build_vocab(training_label_json, min_count=3):
    """Create a ``Vocab`` and populate it from the training caption file.

    Args:
        training_label_json: path to the JSON caption file.
        min_count: minimum word frequency required for inclusion.

    Returns:
        The populated ``Vocab``.
    """
    vocab = Vocab(min_count=min_count)
    return vocab.build(training_label_json)

def infer_feat_dim(training_data_dir):
    """Read the per-frame feature size off a ``.npy`` file in ``<dir>/feat``.

    Files are visited in ``os.listdir`` order. The first array of rank 1, 2
    or 3 decides the answer: its length, its second dimension, or the size of
    one flattened frame respectively. Arrays of any other rank are skipped.

    Raises:
        RuntimeError: if no ``.npy`` file yields a dimension.
    """
    feat_dir = os.path.join(training_data_dir, "feat")
    npy_names = (name for name in os.listdir(feat_dir) if name.endswith(".npy"))
    for name in npy_names:
        sample = np.load(os.path.join(feat_dir, name))
        rank = sample.ndim
        if rank == 1:
            return sample.shape[0]
        if rank == 2:
            return sample.shape[1]
        if rank == 3:
            # Collapse everything after the frame axis into one feature axis.
            per_frame = sample.reshape(sample.shape[0], -1)
            return per_frame.shape[1]
    raise RuntimeError("No .npy features found in training_data/feat")


In [22]:

# @title Train Loop
def train_model(training_data='training_data', training_label='training_label.json', max_frames=80,
               hidden=256, emb=256, enc_layers=1, dec_layers=1, dropout=0.1, use_attention=True,
               epochs=20, batch_size=32, lr=1e-3, min_count=3, schedule_sampling=True, ss_epochs=10,
               ckpt_path='your_seq2seq_model/ckpt.pt'):
    """Train the captioning model and save it once training has finished.

    Three files are written next to ``ckpt_path``: the PyTorch checkpoint
    itself, ``model_weights.h5`` (every state-dict tensor as an HDF5 dataset)
    and ``model_meta.json`` (vocabulary, feature size and model config).

    Args:
        training_data: directory containing ``feat/`` (and optional ``id.txt``).
        training_label: path to the JSON caption file.
        max_frames: frames per video after truncation / zero padding.
        hidden, emb, enc_layers, dec_layers, dropout, use_attention: model
            hyper-parameters, forwarded to ``S2VT``.
        epochs: number of passes over the dataset.
        batch_size: mini-batch size.
        lr: Adam learning rate.
        min_count: minimum word frequency for the vocabulary.
        schedule_sampling: if true, the teacher-forcing ratio starts at 1.0
            and drops by ``1 / ss_epochs`` per epoch down to 0.0; otherwise
            it stays at 1.0.
        ss_epochs: number of epochs over which teacher forcing is annealed.
        ckpt_path: destination of the PyTorch checkpoint; may be a bare
            filename, in which case everything lands in the current directory.
    """
    # dirname() is '' for a bare filename and os.makedirs('') raises, so fall
    # back to the current directory before creating anything.
    out_dir = os.path.dirname(ckpt_path) or "."
    os.makedirs(out_dir, exist_ok=True)

    vocab = build_vocab(training_label, min_count=min_count)
    ds = VideoCaptionDataset(training_data, training_label, vocab, max_frames=max_frames)
    # Read the feature size off the first sample rather than requiring it.
    _, x0, _ = ds[0]
    feat_dim = x0.size(-1)

    # Single source of truth for the hyper-parameters stored with the weights.
    cfg = {
        'hidden': hidden, 'emb': emb, 'enc_layers': enc_layers, 'dec_layers': dec_layers,
        'dropout': dropout, 'use_attention': use_attention
    }

    model = S2VT(feat_dim, len(vocab), hidden=hidden, emb=emb, enc_layers=enc_layers,
                 dec_layers=dec_layers, dropout=dropout, use_attention=use_attention).to(DEVICE)
    # Padded target positions contribute nothing to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=SPECIALS["<PAD>"])
    optim = torch.optim.Adam(model.parameters(), lr=lr)
    dl = DataLoader(ds, batch_size=batch_size, shuffle=True, collate_fn=collate_batch, num_workers=0)

    for epoch in range(1, epochs + 1):
        model.train()
        epoch_loss = 0.0
        tf_ratio = max(0.0, 1.0 - (epoch - 1) / max(1, ss_epochs)) if schedule_sampling else 1.0
        for vids, xs, ys, lengths in dl:
            xs = xs.to(DEVICE); ys = ys.to(DEVICE)
            logits = model(xs, ys, teacher_forcing_ratio=tf_ratio)
            # Logits at step t predict token t+1, so drop the leading <BOS>.
            tgt = ys[:, 1:]
            loss = criterion(logits.reshape(-1, logits.size(-1)), tgt.reshape(-1))
            optim.zero_grad(); loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optim.step()
            epoch_loss += loss.item()
        print(f"Epoch {epoch}/{epochs} - loss={epoch_loss/len(dl):.4f} - tf={tf_ratio:.2f}")

    torch.save({
        'model': model.state_dict(),
        'stoi': vocab.stoi,
        'itos': vocab.itos,
        'feat_dim': feat_dim,
        'cfg': cfg,
    }, ckpt_path)

    # Also save to HDF5 and a JSON sidecar (once, at the end)
    h5_path   = os.path.join(out_dir, "model_weights.h5")
    meta_path = os.path.join(out_dir, "model_meta.json")

    state = model.state_dict()
    with h5py.File(h5_path, "w") as hf:
        for k, v in state.items():
            hf.create_dataset(k, data=v.detach().cpu().numpy())
    print(f"Saved model weights to {h5_path} ({len(state)} tensors)")

    meta = {
        "itos": vocab.itos,
        "stoi": vocab.stoi,
        "feat_dim": feat_dim,
        "cfg": cfg,
    }
    with open(meta_path, "w") as f:
        json.dump(meta, f)
    print(f"Saved metadata to {meta_path}")

    print(f"Saved checkpoint to {ckpt_path}")


In [23]:

# @title Inference
def _list_inference_ids(data_dir):
    """List every video id under ``data_dir`` without consulting any labels.

    Uses the non-empty lines of ``id.txt`` when that file exists; otherwise
    the sorted names of the ``.npy`` files in ``feat/`` with the ``.npy``
    suffix removed.
    """
    id_txt = os.path.join(data_dir, 'id.txt')
    if os.path.isfile(id_txt):
        with open(id_txt, 'r') as f:
            return [line.strip() for line in f if line.strip()]
    feat_dir = os.path.join(data_dir, 'feat')
    return sorted(os.path.splitext(name)[0] for name in os.listdir(feat_dir) if name.endswith('.npy'))


@torch.no_grad()
def run_inference(testing_data='testing_data', output='testset_output.txt', ckpt_path='your_seq2seq_model/ckpt.pt',
                  training_label='training_label.json', testing_label='testing_label.json',
                  beam_size=3, max_decode_len=25, skip_eval=False, max_frames=80):
    """Caption every video in ``testing_data`` and optionally score BLEU@1.

    Args:
        testing_data: directory containing ``feat/`` (and optional ``id.txt``).
        output: text file to write, one ``id,caption`` row per video.
        ckpt_path: checkpoint produced by ``train_model``.
        training_label: caption file used only to construct the dataset when
            ``testing_label`` does not exist.
        testing_label: reference captions; when the file exists it selects
            the videos to caption and enables evaluation.
        beam_size: beam width; values of 1 or less use greedy decoding.
        max_decode_len: maximum number of generated tokens per caption.
        skip_eval: if true, BLEU@1 is not computed even when references exist.
        max_frames: frames per video after truncation / zero padding.
    """
    ckpt = torch.load(ckpt_path, map_location=DEVICE)
    # Restore the exact training-time vocabulary instead of rebuilding it.
    vocab = Vocab(min_count=3)
    vocab.stoi, vocab.itos = ckpt['stoi'], ckpt['itos']

    feat_dim = ckpt['feat_dim']
    cfg = ckpt['cfg']
    model = S2VT(feat_dim, len(vocab), hidden=cfg['hidden'], emb=cfg['emb'], enc_layers=cfg['enc_layers'],
                 dec_layers=cfg['dec_layers'], dropout=cfg['dropout'], use_attention=cfg['use_attention']).to(DEVICE)
    # Loading stays non-strict, but a mismatch is reported instead of being
    # silently ignored (it would leave some weights randomly initialised).
    load_result = model.load_state_dict(ckpt['model'], strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"Warning: checkpoint does not match the model - "
              f"missing keys: {list(load_result.missing_keys)}, "
              f"unexpected keys: {list(load_result.unexpected_keys)}")
    model.eval()

    has_test_labels = os.path.isfile(testing_label)
    labels_json = testing_label if has_test_labels else training_label
    ds = VideoCaptionDataset(testing_data, labels_json, vocab, max_frames=max_frames)
    # The dataset keeps only ids that have captions in labels_json. With the
    # training labels as a stand-in that would discard the test videos, so in
    # that case take the ids straight from the data directory.
    video_ids = ds.ids if has_test_labels else _list_inference_ids(testing_data)

    results = []
    for vid in video_ids:
        x = ds._load_feat(vid).unsqueeze(0).to(DEVICE)
        if beam_size > 1:
            seq = model.beam_search(x, max_len=max_decode_len, bos_id=SPECIALS["<BOS>"], eos_id=SPECIALS["<EOS>"], beam_size=beam_size)
        else:
            seq = model.greedy_decode(x, max_len=max_decode_len, bos_id=SPECIALS["<BOS>"], eos_id=SPECIALS["<EOS>"])[0]
        caption = vocab.decode(seq)
        results.append((vid, caption))

    with open(output, 'w') as f:
        for vid, cap in results:
            f.write(f"{vid},{cap}\n")
    print(f"Wrote predictions to {output}")

    if has_test_labels and not skip_eval:
        with open(testing_label, 'r') as f:
            test_items = json.load(f)
        id2refs = {it['id']: [vocab.tokenize(c) for c in it['caption']] for it in test_items}
        bleu_scores = []
        for vid, cap in results:
            cand = vocab.tokenize(cap)
            refs = id2refs.get(vid, [])
            if refs:
                bleu_scores.append(bleu1_precision(cand, refs))
        if bleu_scores:
            print(f"Average BLEU@1: {sum(bleu_scores)/len(bleu_scores):.4f}")



## Quick Start

1. **Train** (or skip if you already have a checkpoint at `your_seq2seq_model/ckpt.pt`):


In [26]:

# @title Train (baseline settings)
# Adjust epochs/batch_size for speed/quality trade-off
# Create the checkpoint directory up front so the run cannot fail on it later.
os.makedirs('your_seq2seq_model', exist_ok=True)
train_model(
    # data
    training_data='training_data',
    training_label='training_label.json',
    max_frames=80,
    min_count=4,
    # model
    hidden=256,
    emb=256,
    enc_layers=1,
    dec_layers=1,
    dropout=0.1,
    use_attention=True,
    # optimisation
    epochs=200,
    batch_size=32,
    lr=1e-3,
    schedule_sampling=True,
    ss_epochs=10,
    # output
    ckpt_path='your_seq2seq_model/ckpt.pt',
)


Epoch 1/200 - loss=5.0095 - tf=1.00
Epoch 2/200 - loss=4.1372 - tf=0.90
Epoch 3/200 - loss=3.9894 - tf=0.80
Epoch 4/200 - loss=3.9298 - tf=0.70
Epoch 5/200 - loss=3.8540 - tf=0.60
Epoch 6/200 - loss=3.8591 - tf=0.50
Epoch 7/200 - loss=3.8083 - tf=0.40
Epoch 8/200 - loss=3.8441 - tf=0.30
Epoch 9/200 - loss=3.8602 - tf=0.20
Epoch 10/200 - loss=3.9152 - tf=0.10
Epoch 11/200 - loss=3.8923 - tf=0.00
Epoch 12/200 - loss=3.9038 - tf=0.00
Epoch 13/200 - loss=3.8669 - tf=0.00
Epoch 14/200 - loss=3.8087 - tf=0.00
Epoch 15/200 - loss=3.7837 - tf=0.00
Epoch 16/200 - loss=3.6954 - tf=0.00
Epoch 17/200 - loss=3.7015 - tf=0.00
Epoch 18/200 - loss=3.6656 - tf=0.00
Epoch 19/200 - loss=3.6164 - tf=0.00
Epoch 20/200 - loss=3.6643 - tf=0.00
Epoch 21/200 - loss=3.5776 - tf=0.00
Epoch 22/200 - loss=3.5604 - tf=0.00
Epoch 23/200 - loss=3.5247 - tf=0.00
Epoch 24/200 - loss=3.4735 - tf=0.00
Epoch 25/200 - loss=3.5019 - tf=0.00
Epoch 26/200 - loss=3.4320 - tf=0.00
Epoch 27/200 - loss=3.4193 - tf=0.00
Epoch 28/2


2. **Inference** (creates `testset_output.txt` with `id,caption` rows):


In [27]:

# @title Inference
run_inference(
    # inputs
    testing_data='testing_data',
    ckpt_path='your_seq2seq_model/ckpt.pt',
    training_label='training_label.json',
    testing_label='testing_label.json',
    max_frames=80,
    # decoding
    beam_size=3,
    max_decode_len=25,
    # outputs
    output='testset_output.txt',
    skip_eval=False,
)


Wrote predictions to testset_output.txt
Average BLEU@1: 0.6698
