In [12]:
# train_transformer_ctc_best.py
import os
import json
import math
import random
from collections import Counter
from typing import List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm

# -----------------------
# Config (basic)
# -----------------------
class CFG:
    data_dir = "D:/Semester/Semester5/DPL302/Project/sentence_dataset"
    meta = "D:/Semester/Semester5/DPL302/Project/metadata.json"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    epochs = 40
    batch_size = 16
    max_frames = 600            # pad/truncate length (frames)
    subsample_factor = 4        # conv subsample factor
    d_model = 256
    nhead = 8
    num_layers = 4
    dropout = 0.1
    lr = 2e-4
    weight_decay = 1e-4
    clip_grad = 1.0
    temperature = 0.6           # <1 to sharpen logits
    blank_penalty = 0.0         # small positive to discourage blank collapse e.g. 0.01
    num_workers = 2
    seed = 42
    save_path = "best_transformer_ctc_best.pth"
    use_amp = True
    balanced_batch = True       # use balanced sampling by token frequency
    shuffle = True

# -----------------------
# Utilities
# -----------------------
def set_seed(s):
    random.seed(s); np.random.seed(s); torch.manual_seed(s)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(s)

In [3]:
# -----------------------
# Tokenizer (space-tokenized labels)
# -----------------------
class Tokenizer:
    def __init__(self, sequences: List[List[str]]):
        tokens = sorted({t for seq in sequences for t in seq})
        self.blank = "<BLANK>"
        self.idx_to_token = [self.blank] + tokens
        self.token_to_idx = {t:i for i,t in enumerate(self.idx_to_token)}
        self.blank_idx = 0

    def encode(self, seq: List[str]) -> List[int]:
        return [self.token_to_idx[t] for t in seq]

    def decode(self, indices: List[int]) -> str:
        out = []
        prev = None
        for i in indices:
            if i == prev: 
                continue
            if i == self.blank_idx:
                prev = i; continue
            out.append(self.idx_to_token[i]); prev = i
        return " ".join(out)

# -----------------------
# Dataset
# -----------------------
class SentenceDataset(Dataset):
    def __init__(self, data_dir, meta_path, tokenizer, max_frames):
        with open(meta_path, "r", encoding="utf-8") as f:
            self.meta = json.load(f)   # id -> [labels]
        self.ids = sorted(self.meta.keys())
        self.data_dir = data_dir
        self.tokenizer = tokenizer
        self.max_frames = max_frames

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        sid = self.ids[idx]
        arr = np.load(os.path.join(self.data_dir, f"{sid}.npy")).astype(np.float32)  # (T,D)
        T, D = arr.shape
        # already normalized per your comment
        # truncate/pad later in collate
        target = np.array(self.tokenizer.encode(self.meta[sid]), dtype=np.int32)
        return arr, target

    def collate_fn(batch, max_frames):
        seqs, targets = zip(*batch)
        B = len(seqs)
        D = seqs[0].shape[1]
        padded = np.zeros((B, max_frames, D), dtype=np.float32)
        input_lengths = []
        tgt_list = []
        tgt_lens = []
        for i, s in enumerate(seqs):
            L = min(s.shape[0], max_frames)
            padded[i, :L, :] = s[:L, :]
            input_lengths.append(L)
        for t in targets:
            tgt_list.append(t)
            tgt_lens.append(len(t))
        if len(tgt_list) > 0:
            targets_concat = np.concatenate(tgt_list).astype(np.int32)
        else:
            targets_concat = np.array([], dtype=np.int32)
        return (torch.from_numpy(padded), torch.tensor(input_lengths, dtype=torch.long),
                torch.from_numpy(targets_concat).long(), torch.tensor(tgt_lens, dtype=torch.long))

# -----------------------
# Balanced sampler by token frequency (batch balanced roughly by number of tokens in batch)
# -----------------------
class BalancedSampler(Sampler):
    def __init__(self, meta, batch_size):
        # meta: id->list labels
        # build index per token
        token_to_ids = {}
        for sid, seq in meta.items():
            for t in seq:
                token_to_ids.setdefault(t, []).append(sid)
        self.token_to_ids = token_to_ids
        self.ids = list(meta.keys())
        self.batch_size = batch_size
        self.meta = meta

    def __iter__(self):
        # build batches by selecting token types uniformly then choosing ids containing them
        ids_chosen = []
        token_list = list(self.token_to_ids.keys())
        # Round-robin across tokens to reach all ids roughly balanced
        idx = 0
        while len(ids_chosen) < len(self.ids):
            t = token_list[idx % len(token_list)]
            pool = self.token_to_ids[t]
            pick = random.choice(pool)
            if pick not in ids_chosen:
                ids_chosen.append(pick)
            idx += 1
        # pad to multiple of batch_size
        while len(ids_chosen) % self.batch_size != 0:
            ids_chosen.append(random.choice(self.ids))
        # yield batches
        for i in range(0, len(ids_chosen), self.batch_size):
            yield ids_chosen[i:i+self.batch_size]

    def __len__(self):
        return math.ceil(len(self.ids) / self.batch_size)

# -----------------------
# Model: Conv subsample + Transformer encoder + linear
# -----------------------
class ConvSubsample(nn.Module):
    def __init__(self, in_dim, d_model, dropout=0.1):
        super().__init__()
        self.conv1 = nn.Conv1d(in_dim, d_model, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv1d(d_model, d_model, kernel_size=3, stride=2, padding=1)
        self.act = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        # x: (B, T, D)
        x = x.transpose(1,2)   # (B, D, T)
        x = self.conv1(x); x = self.act(x)
        x = self.conv2(x); x = self.act(x)
        x = self.dropout(x)
        x = x.transpose(1,2)   # (B, T_sub, d_model)
        return x

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=2000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div)
        pe[:, 1::2] = torch.cos(position * div)
        self.register_buffer('pe', pe.unsqueeze(0))
    def forward(self, x):
        # x: (B, T, D)
        T = x.size(1)
        return x + self.pe[:, :T, :]

class TransformerCTC(nn.Module):
    def __init__(self, in_dim, num_classes, d_model=256, nhead=8, num_layers=4, dropout=0.1, use_subsample=True, temp=1.0):
        super().__init__()
        self.use_subsample = use_subsample
        self.sub = ConvSubsample(in_dim, d_model) if use_subsample else None
        self.input_proj = nn.Linear(in_dim, d_model) if not use_subsample else None
        self.pos = PositionalEncoding(d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=4*d_model, dropout=dropout)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.out = nn.Linear(d_model, num_classes)
        self.d_model = d_model
        self.temp = temp
    def forward(self, x, src_key_padding_mask=None):
        # x: (B, T, D)
        if self.sub is not None:
            x = self.sub(x)    # (B, T_sub, d_model)
        else:
            x = self.input_proj(x) * math.sqrt(self.d_model)
        x = self.pos(x)
        x = x.transpose(0,1)  # (T, B, D)
        enc = self.encoder(x, src_key_padding_mask=src_key_padding_mask)
        logits = self.out(enc) / (self.temp if self.temp>0 else 1.0)
        return F.log_softmax(logits, dim=-1)   # (T', B, C)

# -----------------------
# Helpers: masks, decode
# -----------------------
def make_src_key_padding_mask(lengths, max_len, subsample=4):
    # lengths: tensor (B,) original lengths
    B = lengths.size(0)
    out_len = (max_len + subsample - 1) // subsample
    mask = torch.zeros((B, out_len), dtype=torch.bool)
    for i, L in enumerate(lengths.tolist()):
        newL = (L + subsample - 1) // subsample
        if newL < out_len:
            mask[i, newL:] = True
    return mask

def greedy_decode(log_probs, tokenizer):
    # log_probs: (T, B, C)
    probs = log_probs.exp()
    arg = probs.argmax(dim=-1).cpu().numpy()  # (T, B)
    preds = []
    blank_probs = []
    for b in range(arg.shape[1]):
        seq = arg[:, b].tolist()
        prev = None
        out = []
        for idx in seq:
            if idx == prev:
                continue
            if idx == tokenizer.blank_idx:
                prev = idx; continue
            out.append(idx); prev = idx
        preds.append(tokenizer.decode(out))
        blank_probs.append(probs[:, b, tokenizer.blank_idx].mean().item())
    return preds, sum(blank_probs)/len(blank_probs)



In [6]:
# -----------------------
# Training / Validation
# -----------------------
def train_one_epoch(model, loader, optimizer, scaler, ctc_loss_fn, cfg, tokenizer):
    model.train()
    total_loss = 0.0
    total_blank = 0.0
    for X, input_lens, targets, target_lens in tqdm(loader, desc="Train"):
        X = X.to(cfg.device); input_lens = input_lens.to(cfg.device)
        targets = targets.to(cfg.device); target_lens = target_lens.to(cfg.device)
        mask = make_src_key_padding_mask(input_lens, cfg.max_frames, subsample=cfg.subsample_factor).to(cfg.device)
        optimizer.zero_grad()
        with autocast(enabled=cfg.use_amp):
            logp = model(X, src_key_padding_mask=mask)  # (T', B, C)
            input_lens_ctc = ((input_lens + cfg.subsample_factor - 1)//cfg.subsample_factor).clamp(min=1)
            loss = ctc_loss_fn(logp, targets, input_lens_ctc, target_lens)
            if cfg.blank_penalty > 0:
                blank_prob = logp.exp()[:, :, 0].mean()
                loss = loss + cfg.blank_penalty * blank_prob
        scaler.scale(loss).backward()
        if cfg.clip_grad > 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_grad)
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item() * X.size(0)
        # blank prob reporting
        with torch.no_grad():
            bprob = logp.exp()[:, :, 0].mean().item()
            total_blank += bprob * X.size(0)
    return total_loss / len(loader.dataset), total_blank / len(loader.dataset)

def validate(model, loader, ctc_loss_fn, cfg, tokenizer):
    model.eval()
    total = 0; correct = 0; blanks = 0.0; tot_loss = 0.0
    with torch.no_grad():
        for X, input_lens, targets, target_lens in tqdm(loader, desc="Val"):
            X = X.to(cfg.device); input_lens = input_lens.to(cfg.device)
            targets = targets.to(cfg.device); target_lens = target_lens.to(cfg.device)
            mask = make_src_key_padding_mask(input_lens, cfg.max_frames, subsample=cfg.subsample_factor).to(cfg.device)
            logp = model(X, src_key_padding_mask=mask)
            input_lens_ctc = ((input_lens + cfg.subsample_factor - 1)//cfg.subsample_factor).clamp(min=1)
            loss = ctc_loss_fn(logp, targets, input_lens_ctc, target_lens)
            tot_loss += loss.item() * X.size(0)
            preds, avg_blank = greedy_decode(logp, tokenizer)
            blanks += avg_blank * X.size(0)
            # reconstruct true strings
            ptr = 0
            for b in range(X.size(0)):
                L = target_lens[b].item()
                true_inds = targets[ptr:ptr+L].cpu().numpy().tolist()
                true_str = tokenizer.decode(true_inds)
                ptr += L
                if preds[b].strip() == true_str.strip():
                    correct += 1
                total += 1
    acc = correct / total if total>0 else 0.0
    return tot_loss/len(loader.dataset), acc, blanks/len(loader.dataset)

In [14]:
# -----------------------
# Main
# -----------------------
def main(cfg: CFG):
    set_seed(cfg.seed)
    # load meta
    with open(cfg.meta, "r", encoding="utf-8") as f:
        meta = json.load(f)  # id -> [labels]
    sequences = list(meta.values())
    tokenizer = Tokenizer(sequences)
    dataset = SentenceDataset(cfg.data_dir, cfg.meta, tokenizer, cfg.max_frames)

    # split train/val
    ids = dataset.ids
    random.shuffle(ids)
    n_train = int(0.8 * len(ids))
    train_ids = set(ids[:n_train])
    train_meta = {k:v for k,v in meta.items() if k in train_ids}
    val_meta = {k:v for k,v in meta.items() if k not in train_ids}

    train_ds = SentenceDataset(cfg.data_dir, cfg.meta, tokenizer, cfg.max_frames)
    val_ds = SentenceDataset(cfg.data_dir, cfg.meta, tokenizer, cfg.max_frames)

    # dataloaders
    if cfg.balanced_batch:
        sampler = BalancedSampler(meta, cfg.batch_size)
        train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, collate_fn=lambda b: collate_fn(b, cfg.max_frames),
                                  num_workers=cfg.num_workers, sampler=sampler)
    else:
        train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, collate_fn=lambda b: collate_fn(b, cfg.max_frames),
                                  num_workers=cfg.num_workers, shuffle=cfg.shuffle)
    val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, collate_fn=lambda b: collate_fn(b, cfg.max_frames),
                            num_workers=cfg.num_workers, shuffle=False)

    # model
    sample_arr, _ = dataset[0]
    in_dim = sample_arr.shape[1]
    model = TransformerCTC(in_dim, num_classes=len(tokenizer.idx_to_token),
                           d_model=cfg.d_model, nhead=cfg.nhead, num_layers=cfg.num_layers,
                           dropout=cfg.dropout, use_subsample=True, temp=cfg.temperature)
    model.to(cfg.device)

    ctc_loss_fn = nn.CTCLoss(blank=tokenizer.blank_idx, zero_infinity=True, reduction="mean")
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3, verbose=True)
    scaler = GradScaler(enabled=cfg.use_amp)

    best_acc = 0.0
    for epoch in range(1, cfg.epochs+1):
        tr_loss, tr_blank = train_one_epoch(model, train_loader, optimizer, scaler, ctc_loss_fn, cfg, tokenizer)
        val_loss, val_acc, val_blank = validate(model, val_loader, ctc_loss_fn, cfg, tokenizer)
        scheduler.step(val_acc)
        print(f"[Epoch {epoch}] tr_loss={tr_loss:.4f} tr_blank={tr_blank:.3f} | val_loss={val_loss:.4f} val_acc={val_acc:.4f} val_blank={val_blank:.3f}")
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save({
                "model_state": model.state_dict(),
                "tokenizer": tokenizer.__dict__,
                "cfg": cfg.__dict__,
                "epoch": epoch,
                "val_acc": val_acc
                }, cfg.save_path)
            print("Saved best model:", cfg.save_path)

if __name__ == "__main__":
    cfg = CFG()
    main(cfg)

ValueError: too many values to unpack (expected 2)