In [1]:
import os, re, glob, random, math, time, tarfile, urllib.request
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torchaudio
import sentencepiece as spm

from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from jiwer import wer, cer

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))


Device: cuda
GPU: NVIDIA GeForce RTX 4060 Laptop GPU


In [2]:
DATA_DIR = "./datasets"
TOKENIZER_DIR = "./tokenizers"
CHECKPOINT_DIR = "./checkpoints_en"

os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(TOKENIZER_DIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# LibriSpeech splits (recommended for real evaluation)
DOWNLOAD_SPLITS = {
    "train-clean-100": True,
    "dev-clean": True,
    "test-clean": True,
}

VOCAB_SIZE = 8000

BATCH_SIZE = 16
EPOCHS = 25
LR = 3e-4
LOG_EVERY = 50
GRAD_CLIP = 1.0

# Data constraints
MAX_AUDIO_SEC = 20.0      # we will SKIP (not truncate) audio longer than this
MAX_TOK_LEN = 250         # optionally skip extreme transcripts

# Training tricks (optional)
USE_LABEL_SMOOTHING = True
LABEL_SMOOTHING = 0.05

USE_SPEC_AUG = False      # start False; enable later


In [3]:
OPENSLR_BASE = "https://www.openslr.org/resources/12"
SPLIT_URLS = {
    "train-clean-100": f"{OPENSLR_BASE}/train-clean-100.tar.gz",
    "dev-clean":       f"{OPENSLR_BASE}/dev-clean.tar.gz",
    "test-clean":      f"{OPENSLR_BASE}/test-clean.tar.gz",
}

def download_and_extract_librispeech(split_name: str):
    url = SPLIT_URLS[split_name]
    tar_path = os.path.join(DATA_DIR, f"{split_name}.tar.gz")
    extract_root = os.path.join(DATA_DIR, "LibriSpeech")

    os.makedirs(extract_root, exist_ok=True)

    # Check if already extracted
    expected = os.path.join(extract_root, split_name)
    if os.path.isdir(expected) and len(glob.glob(os.path.join(expected, "**", "*.flac"), recursive=True)) > 0:
        print(f"{split_name}: already present.")
        return expected

    # Download tar if missing
    if not os.path.exists(tar_path):
        print(f"Downloading {split_name} ...")
        urllib.request.urlretrieve(url, tar_path)

    # Extract
    print(f"Extracting {split_name} ...")
    with tarfile.open(tar_path, "r:gz") as tar:
        tar.extractall(extract_root)

    # LibriSpeech extracts to DATA_DIR/LibriSpeech/LibriSpeech/<split>
    # Some people end up with double folder; handle both.
    candidates = [
        os.path.join(extract_root, "LibriSpeech", split_name),
        os.path.join(extract_root, split_name),
    ]
    for c in candidates:
        if os.path.isdir(c) and len(glob.glob(os.path.join(c, "**", "*.flac"), recursive=True)) > 0:
            print(f"{split_name}: ready at {c}")
            return c

    raise FileNotFoundError(f"Could not find extracted split folder for {split_name}.")

paths = {}
for split, do in DOWNLOAD_SPLITS.items():
    if do:
        paths[split] = download_and_extract_librispeech(split)

print("Split paths:", paths)


Downloading train-clean-100 ...
Extracting train-clean-100 ...
train-clean-100: ready at ./datasets\LibriSpeech\LibriSpeech\train-clean-100
Downloading dev-clean ...
Extracting dev-clean ...
dev-clean: ready at ./datasets\LibriSpeech\LibriSpeech\dev-clean
Downloading test-clean ...
Extracting test-clean ...
test-clean: ready at ./datasets\LibriSpeech\LibriSpeech\test-clean
Split paths: {'train-clean-100': './datasets\\LibriSpeech\\LibriSpeech\\train-clean-100', 'dev-clean': './datasets\\LibriSpeech\\LibriSpeech\\dev-clean', 'test-clean': './datasets\\LibriSpeech\\LibriSpeech\\test-clean'}


In [4]:
def build_librispeech_pairs(split_root: str):
    trans_map = {}
    for trans_path in glob.glob(os.path.join(split_root, "**", "*.trans.txt"), recursive=True):
        with open(trans_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                utt_id, text = line.split(" ", 1)
                trans_map[utt_id] = text

    pairs = []
    for flac in glob.glob(os.path.join(split_root, "**", "*.flac"), recursive=True):
        utt_id = os.path.basename(flac).replace(".flac", "")
        if utt_id in trans_map:
            pairs.append((flac, trans_map[utt_id]))
    return pairs

train_pairs = build_librispeech_pairs(paths["train-clean-100"])
val_pairs   = build_librispeech_pairs(paths["dev-clean"])
test_pairs  = build_librispeech_pairs(paths["test-clean"])

print("Train pairs:", len(train_pairs))
print("Val pairs:", len(val_pairs))
print("Test pairs:", len(test_pairs))
print("Sample:", train_pairs[0])


Train pairs: 28539
Val pairs: 2703
Test pairs: 2620
Sample: ('./datasets\\LibriSpeech\\LibriSpeech\\train-clean-100\\103\\1240\\103-1240-0000.flac', 'CHAPTER ONE MISSUS RACHEL LYNDE IS SURPRISED MISSUS RACHEL LYNDE LIVED JUST WHERE THE AVONLEA MAIN ROAD DIPPED DOWN INTO A LITTLE HOLLOW FRINGED WITH ALDERS AND LADIES EARDROPS AND TRAVERSED BY A BROOK')


In [5]:
SP_PREFIX = os.path.join(TOKENIZER_DIR, "en_sp")
SP_MODEL  = SP_PREFIX + ".model"
SP_VOCAB  = SP_PREFIX + ".vocab"

# If you changed settings, DELETE old model+vocab before running this cell.
if not os.path.exists(SP_MODEL):
    print("Training SentencePiece tokenizer ...")
    txt_path = os.path.join(TOKENIZER_DIR, "en_transcripts.txt")

    with open(txt_path, "w", encoding="utf-8") as f:
        for _, t in train_pairs:
            f.write(t.strip().lower() + "\n")

    spm.SentencePieceTrainer.train(
        input=txt_path,
        model_prefix=SP_PREFIX,
        vocab_size=VOCAB_SIZE,
        model_type="unigram",
        character_coverage=1.0,
        pad_id=0, unk_id=1, bos_id=2, eos_id=3,
        byte_fallback=True,   # ✅ important for robustness
    )
    print("Tokenizer trained:", SP_MODEL)
else:
    print("Tokenizer exists:", SP_MODEL)

sp = spm.SentencePieceProcessor(model_file=SP_MODEL)
PAD_ID = sp.pad_id()
UNK_ID = sp.unk_id()
BOS_ID = sp.bos_id()
EOS_ID = sp.eos_id()

print("Vocab:", sp.get_piece_size(), "| PAD/UNK/BOS/EOS:", PAD_ID, UNK_ID, BOS_ID, EOS_ID)
print("Sanity:", sp.decode(sp.encode("this is a test", out_type=int)))


Training SentencePiece tokenizer ...
Tokenizer trained: ./tokenizers\en_sp.model
Vocab: 8000 | PAD/UNK/BOS/EOS: 0 1 2 3
Sanity: this is a test


In [6]:
def load_audio_16k_mono(path: str):
    wav, sr = torchaudio.load(path)  # (C, T)
    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)
    if sr != 16000:
        wav = torchaudio.functional.resample(wav, sr, 16000)
    return wav.squeeze(0), 16000  # (T,), 16000


In [7]:
class LibriTokenDataset(Dataset):
    def __init__(self, pairs, sp, max_audio_sec=20.0, max_tok_len=250):
        self.pairs = pairs
        self.sp = sp
        self.max_audio_samples = int(max_audio_sec * 16000)
        self.max_tok_len = max_tok_len

        # ✅ build transform once (not inside __getitem__)
        self.mel_fn = torchaudio.transforms.MelSpectrogram(
            sample_rate=16000, n_fft=1024, hop_length=256, n_mels=80
        )

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        path, text = self.pairs[idx]
        wav, _ = load_audio_16k_mono(path)

        # ✅ skip (do not truncate) to avoid audio/text mismatch
        if wav.numel() > self.max_audio_samples:
            return self.__getitem__((idx + 1) % len(self.pairs))

        text = text.strip().lower()
        ids = self.sp.encode(text, out_type=int)
        if len(ids) > self.max_tok_len:
            return self.__getitem__((idx + 1) % len(self.pairs))

        ids = [BOS_ID] + ids + [EOS_ID]
        tok = torch.tensor(ids, dtype=torch.long)

        mel = self.mel_fn(wav.unsqueeze(0))              # (1, 80, frames)
        feat = torch.log(mel + 1e-9).squeeze(0).T        # (T, 80)
        return feat, tok


In [8]:
def collate_pad(batch):
    feats, toks = zip(*batch)
    feat_lens = torch.tensor([f.size(0) for f in feats], dtype=torch.long)
    tok_lens  = torch.tensor([t.size(0) for t in toks], dtype=torch.long)

    maxT = int(feat_lens.max())
    maxL = int(tok_lens.max())

    feat_pad = torch.zeros(len(batch), maxT, feats[0].size(1), dtype=torch.float32)
    tok_pad  = torch.full((len(batch), maxL), PAD_ID, dtype=torch.long)

    for i, (f, t) in enumerate(zip(feats, toks)):
        feat_pad[i, : f.size(0)] = f
        tok_pad[i, : t.size(0)]  = t

    return feat_pad, tok_pad, feat_lens, tok_lens

train_ds = LibriTokenDataset(train_pairs, sp, max_audio_sec=MAX_AUDIO_SEC, max_tok_len=MAX_TOK_LEN)
val_ds   = LibriTokenDataset(val_pairs,   sp, max_audio_sec=MAX_AUDIO_SEC, max_tok_len=MAX_TOK_LEN)
test_ds  = LibriTokenDataset(test_pairs,  sp, max_audio_sec=MAX_AUDIO_SEC, max_tok_len=MAX_TOK_LEN)

# ✅ num_workers=0 for Windows notebook stability
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  collate_fn=collate_pad,
                          num_workers=0, pin_memory=torch.cuda.is_available())
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_pad,
                          num_workers=0, pin_memory=torch.cuda.is_available())
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_pad,
                          num_workers=0, pin_memory=torch.cuda.is_available())

print("Loaders ready:", len(train_loader), len(val_loader), len(test_loader))


Loaders ready: 1784 169 164


In [9]:
def make_key_padding_mask(lengths, max_len):
    idx = torch.arange(max_len, device=lengths.device).unsqueeze(0)
    return idx >= lengths.unsqueeze(1)   # (B, T) True where padded

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        T = x.size(1)
        return x + self.pe[:, :T]

def make_key_padding_mask(lengths, max_len):
    idx = torch.arange(max_len, device=lengths.device).unsqueeze(0)
    return idx >= lengths.unsqueeze(1)   # bool


In [10]:
class ConvSubsample(nn.Module):
    """
    2-layer Conv2D subsampling: T -> ~T/4
    Input:  (B, T, F)
    Output: (B, T', d_model)
    """
    def __init__(self, in_feats=80, d_model=256, channels=32):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, channels, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
        f_out = (in_feats + 1) // 2
        f_out = (f_out + 1) // 2
        self.out = nn.Linear(channels * f_out, d_model)

    def forward(self, x, lengths):
        x = x.unsqueeze(1)             # (B,1,T,F)
        x = self.conv(x)               # (B,C,T',F')
        B, C, T2, F2 = x.shape
        x = x.permute(0, 2, 1, 3).contiguous().view(B, T2, C * F2)
        x = self.out(x)

        l1 = (lengths + 1) // 2
        l2 = (l1 + 1) // 2
        return x, l2


In [11]:
class ASRTransformer(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        pad_id: int,
        n_mels: int = 80,
        d_model: int = 256,
        nhead: int = 4,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        dim_feedforward: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model
        self.pad_id = pad_id

        self.subsample = ConvSubsample(in_feats=n_mels, d_model=d_model, channels=32)
        self.pos_enc   = PositionalEncoding(d_model)

        self.tok_emb   = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=False,
            norm_first=True,
        )
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt, src_lengths, tgt_key_padding_mask=None):
        # src: (B, Ts, 80) | tgt: (B, Lt)
        src, src_lengths2 = self.subsample(src, src_lengths)     # (B, Ts', d)
        src = src * math.sqrt(self.d_model)
        src = self.pos_enc(src)
        src = src.transpose(0, 1)                                # (Ts', B, d)

        Ts2 = src.size(0)
        src_pad_mask = make_key_padding_mask(src_lengths2, Ts2)  # (B, Ts')

        tgt = self.tok_emb(tgt) * math.sqrt(self.d_model)
        tgt = self.pos_enc(tgt)
        tgt = tgt.transpose(0, 1)                                # (Lt, B, d)

        tgt_mask = subsequent_mask(tgt.size(0), tgt.device)

        out = self.transformer(
            src, tgt,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_pad_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_pad_mask,
        )
        out = out.transpose(0, 1)  # (B, Lt, d)
        return self.fc_out(out)


In [12]:
model = ASRTransformer(vocab_size=sp.get_piece_size(), pad_id=PAD_ID).to(device)

if USE_LABEL_SMOOTHING:
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID, label_smoothing=LABEL_SMOOTHING)
else:
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID)

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, betas=(0.9, 0.98), weight_decay=1e-2)

use_cuda = torch.cuda.is_available()
scaler = torch.amp.GradScaler("cuda", enabled=use_cuda)

LAST_CKPT = os.path.join(CHECKPOINT_DIR, "asr_en_last.pt")
BEST_CKPT = os.path.join(CHECKPOINT_DIR, "asr_en_best.pt")

@dataclass
class TrainState:
    epoch: int = 0
    best_val: float = 1e9

state = TrainState()

if os.path.exists(LAST_CKPT):
    ckpt = torch.load(LAST_CKPT, map_location=device)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optim"])
    state.epoch = ckpt.get("epoch", 0)
    state.best_val = ckpt.get("best_val", 1e9)
    print(f"Resumed from epoch {state.epoch}, best_val={state.best_val:.4f}")
else:
    print("No checkpoint found. Starting fresh.")

print("Params (M):", sum(p.numel() for p in model.parameters())/1e6)




No checkpoint found. Starting fresh.
Params (M): 12.186016


In [13]:
@torch.inference_mode()
def run_eval_loss(loader):
    model.eval()
    total_loss, steps = 0.0, 0

    for feats, toks, feat_lens, tok_lens in loader:
        feats, toks = feats.to(device), toks.to(device)
        feat_lens, tok_lens = feat_lens.to(device), tok_lens.to(device)

        tgt_in  = toks[:, :-1]
        tgt_out = toks[:, 1:]
        tgt_pad_mask = make_key_padding_mask(tok_lens - 1, tgt_in.size(1))

        logits = model(feats, tgt_in, src_lengths=feat_lens, tgt_key_padding_mask=tgt_pad_mask)
        loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))

        total_loss += float(loss.item())
        steps += 1

    return total_loss / max(1, steps)


In [15]:
import torch

def subsequent_mask(size: int, device: torch.device):
    """
    Generate a causal (look-ahead) mask for Transformer decoder.
    Shape: (size, size)
    """
    return torch.triu(
        torch.ones(size, size, device=device, dtype=torch.bool),
        diagonal=1
    )


In [16]:
for epoch in range(state.epoch, EPOCHS):
    model.train()
    t0 = time.time()
    total, steps = 0.0, 0

    for b, (feats, toks, feat_lens, tok_lens) in enumerate(train_loader):
        feats, toks = feats.to(device), toks.to(device)
        feat_lens, tok_lens = feat_lens.to(device), tok_lens.to(device)

        if USE_SPEC_AUG:
            specaug.train()
            feats = specaug(feats)

        tgt_in  = toks[:, :-1]
        tgt_out = toks[:, 1:]
        tgt_pad_mask = make_key_padding_mask(tok_lens - 1, tgt_in.size(1))

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast(device_type="cuda", enabled=use_cuda):
            logits = model(feats, tgt_in, src_lengths=feat_lens, tgt_key_padding_mask=tgt_pad_mask)
            loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        scaler.step(optimizer)
        scaler.update()

        total += float(loss.item())
        steps += 1

        if (b == 0) or ((b + 1) % LOG_EVERY == 0):
            print(f"E{epoch+1}/{EPOCHS} | batch {b+1:05d} | loss {loss.item():.4f}")

    train_loss = total / max(1, steps)
    val_loss = run_eval_loss(val_loader)

    # Save last
    torch.save({
        "epoch": epoch+1,
        "best_val": state.best_val,
        "model": model.state_dict(),
        "optim": optimizer.state_dict(),
        "sp_model": SP_MODEL,
    }, LAST_CKPT)

    # Save best
    best_tag = ""
    if val_loss < state.best_val:
        state.best_val = val_loss
        torch.save({
            "epoch": epoch+1,
            "best_val": state.best_val,
            "model": model.state_dict(),
            "optim": optimizer.state_dict(),
            "sp_model": SP_MODEL,
        }, BEST_CKPT)
        best_tag = "✅ best"

    print(f"Epoch {epoch+1}/{EPOCHS} done in {time.time()-t0:.1f}s | train={train_loss:.4f} | val={val_loss:.4f} {best_tag}")


E1/25 | batch 00001 | loss 9.1936
E1/25 | batch 00050 | loss 6.8160
E1/25 | batch 00100 | loss 6.7713
E1/25 | batch 00150 | loss 6.6492
E1/25 | batch 00200 | loss 6.6592
E1/25 | batch 00250 | loss 6.7142
E1/25 | batch 00300 | loss 6.6077
E1/25 | batch 00350 | loss 6.3930
E1/25 | batch 00400 | loss 6.4021
E1/25 | batch 00450 | loss 6.3427
E1/25 | batch 00500 | loss 6.4428
E1/25 | batch 00550 | loss 6.3677
E1/25 | batch 00600 | loss 6.3221
E1/25 | batch 00650 | loss 6.0563
E1/25 | batch 00700 | loss 6.1656
E1/25 | batch 00750 | loss 6.0675
E1/25 | batch 00800 | loss 6.0909
E1/25 | batch 00850 | loss 6.1539
E1/25 | batch 00900 | loss 6.3288
E1/25 | batch 00950 | loss 6.1328
E1/25 | batch 01000 | loss 6.0092
E1/25 | batch 01050 | loss 6.1644
E1/25 | batch 01100 | loss 6.0557
E1/25 | batch 01150 | loss 6.0179
E1/25 | batch 01200 | loss 6.0183
E1/25 | batch 01250 | loss 6.2742
E1/25 | batch 01300 | loss 5.9616
E1/25 | batch 01350 | loss 6.2090
E1/25 | batch 01400 | loss 6.0582
E1/25 | batch 

In [17]:
def norm_text(s: str) -> str:
    s = s.lower()
    s = re.sub(r"[^a-z0-9\s']", " ", s)
    s = re.sub(r"\s+", " ", s).strip()
    return s

@torch.inference_mode()
def beam_search_decode(model, feats, feat_len, beam_size=5, max_len=200, len_penalty=0.6):
    model.eval()
    device_ = feats.device

    beams = [(torch.tensor([[BOS_ID]], device=device_, dtype=torch.long), 0.0, False)]

    for _ in range(max_len):
        new_beams = []
        for ys, score, finished in beams:
            if finished:
                new_beams.append((ys, score, True))
                continue

            logits = model(feats, ys, src_lengths=feat_len, tgt_key_padding_mask=None)
            log_probs = torch.log_softmax(logits[:, -1, :], dim=-1)

            topk = torch.topk(log_probs, k=beam_size, dim=-1)
            vals = topk.values.squeeze(0)
            idxs = topk.indices.squeeze(0)

            for lp, tid in zip(vals.tolist(), idxs.tolist()):
                ys2 = torch.cat([ys, torch.tensor([[tid]], device=device_, dtype=torch.long)], dim=1)
                new_beams.append((ys2, score + lp, tid == EOS_ID))

        def rank_key(item):
            ys, sc, fin = item
            L = ys.size(1)
            lp = ((5 + L) / 6) ** len_penalty
            return sc / lp

        new_beams.sort(key=rank_key, reverse=True)
        beams = new_beams[:beam_size]

        if all(fin for _, _, fin in beams):
            break

    best = max(beams, key=lambda x: x[1])[0].squeeze(0).tolist()
    return best

@torch.inference_mode()
def eval_wer_cer_fast(loader, max_utterances=30, beam_size=5, max_len=200):
    model.eval()
    refs, hyps = [], []
    n = 0
    t0 = time.time()

    for feats, toks, feat_lens, tok_lens in loader:
        feats, toks = feats.to(device), toks.to(device)
        feat_lens = feat_lens.to(device)

        B = feats.size(0)
        for i in range(B):
            f = feats[i:i+1]
            fl = feat_lens[i:i+1]

            pred_ids = beam_search_decode(model, f, fl, beam_size=beam_size, max_len=max_len)
            if pred_ids and pred_ids[0] == BOS_ID:
                pred_ids = pred_ids[1:]
            if EOS_ID in pred_ids:
                pred_ids = pred_ids[:pred_ids.index(EOS_ID)]
            pred_ids = [x for x in pred_ids if x != PAD_ID]
            pred_txt = sp.decode(pred_ids)

            ref_ids = [x for x in toks[i].tolist() if x not in (PAD_ID, BOS_ID, EOS_ID)]
            ref_txt = sp.decode(ref_ids)

            refs.append(norm_text(ref_txt))
            hyps.append(norm_text(pred_txt))
            n += 1

            if n % 10 == 0:
                print(f"Decoded {n}/{max_utterances} | WER~{wer(refs, hyps):.3f} CER~{cer(refs, hyps):.3f} | {time.time()-t0:.1f}s")

            if n >= max_utterances:
                print(f"Eval | WER={wer(refs, hyps):.4f} | CER={cer(refs, hyps):.4f} | N={n}")
                return

    print(f"Eval | WER={wer(refs, hyps):.4f} | CER={cer(refs, hyps):.4f} | N={n}")


In [19]:
# ----------------------------
# Load BEST checkpoint (safe)
# ----------------------------
if os.path.exists(BEST_CKPT):
    ckpt = torch.load(BEST_CKPT, map_location=device, weights_only=True)
    # If your checkpoint is a dict like {"model": state_dict, ...}
    if isinstance(ckpt, dict) and "model" in ckpt:
        model.load_state_dict(ckpt["model"])
    else:
        # fallback if you saved the raw state_dict
        model.load_state_dict(ckpt)
    print("Loaded best checkpoint:", BEST_CKPT)
else:
    print("BEST checkpoint not found:", BEST_CKPT)

# ----------------------------
# Evaluate WER/CER (more stable)
# ----------------------------
EVAL_UTT = 200        # 30 is noisy; 200 is much more reliable
BEAM = 8
MAX_LEN = 220
LEN_PEN = 0.8

# If your eval_wer_cer_fast doesn't take len_penalty, keep it as-is.
# Otherwise, update it similarly (optional).
eval_wer_cer_fast(test_loader, max_utterances=EVAL_UTT, beam_size=BEAM, max_len=MAX_LEN)

# ----------------------------
# Show one random sample
# ----------------------------
model.eval()

sample_path, sample_text = random.choice(test_pairs)
wav, _ = load_audio_16k_mono(sample_path)

# Use same feature extraction as training
mel_fn = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000, n_fft=1024, hop_length=256, n_mels=80
)
mel = mel_fn(wav.unsqueeze(0))
feat = torch.log(mel + 1e-9).squeeze(0).T  # (T, 80)

feats = feat.unsqueeze(0).to(device)
feat_len = torch.tensor([feat.size(0)], device=device, dtype=torch.long)

# Decode with beam search (stronger settings)
pred_ids = beam_search_decode(
    model, feats, feat_len,
    beam_size=BEAM, max_len=250, len_penalty=LEN_PEN
)

# Strip BOS/EOS/PAD
if pred_ids and pred_ids[0] == BOS_ID:
    pred_ids = pred_ids[1:]
if EOS_ID in pred_ids:
    pred_ids = pred_ids[:pred_ids.index(EOS_ID)]
pred_ids = [x for x in pred_ids if x != PAD_ID]

pred_text = sp.decode(pred_ids)

print("\nAudio:", sample_path)
print("REF :", sample_text)
print("HYP :", pred_text)

# Optional: also show normalized versions (useful for debugging WER differences)
print("\nREF(norm):", norm_text(sample_text))
print("HYP(norm):", norm_text(pred_text))

Loaded best checkpoint: ./checkpoints_en\asr_en_best.pt
Decoded 10/200 | WER~0.583 CER~0.371 | 21.3s
Decoded 20/200 | WER~0.500 CER~0.334 | 47.7s
Decoded 30/200 | WER~0.512 CER~0.335 | 77.3s
Decoded 40/200 | WER~0.497 CER~0.329 | 91.4s
Decoded 50/200 | WER~0.493 CER~0.330 | 112.0s
Decoded 60/200 | WER~0.483 CER~0.324 | 135.4s
Decoded 70/200 | WER~0.504 CER~0.339 | 161.7s
Decoded 80/200 | WER~0.504 CER~0.345 | 185.2s
Decoded 90/200 | WER~0.494 CER~0.338 | 218.4s
Decoded 100/200 | WER~0.502 CER~0.344 | 236.1s
Decoded 110/200 | WER~0.517 CER~0.352 | 258.2s
Decoded 120/200 | WER~0.524 CER~0.355 | 271.1s
Decoded 130/200 | WER~0.545 CER~0.367 | 307.8s
Decoded 140/200 | WER~0.547 CER~0.368 | 327.6s
Decoded 150/200 | WER~0.546 CER~0.368 | 340.8s
Decoded 160/200 | WER~0.546 CER~0.368 | 364.6s
Decoded 170/200 | WER~0.540 CER~0.364 | 386.6s
Decoded 180/200 | WER~0.535 CER~0.361 | 418.1s
Decoded 190/200 | WER~0.536 CER~0.365 | 453.5s
Decoded 200/200 | WER~0.534 CER~0.365 | 491.1s
Eval | WER=0.5338