In [1]:
import math
import torch
import torchaudio
import soundfile as sf

from train_streaming_ctc import BiLSTMCTC, CharTokenizerCTC

# Prefer soundfile backend to avoid torchcodec requirement when loading audio.
try:
    torchaudio.set_audio_backend("soundfile")
except Exception:
    pass

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT_PATH = "bilstm_ctc_checkpoint.pt"


def build_logmel_extractor(
    target_sample_rate: int = 16_000,
    n_fft: int = 400,
    hop_length: int = 160,
    n_mels: int = 80,
    fmin: float = 0.0,
    fmax: float | None = None,
):
    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=target_sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
        f_min=fmin,
        f_max=fmax or target_sample_rate / 2,
        power=2.0,
    )
    to_db = torchaudio.transforms.AmplitudeToDB(stype="power")

    def extract(waveform: torch.Tensor, sample_rate: int) -> torch.Tensor:
        # Collapse stereo to mono and resample if needed, then return [frames, n_mels].
        if waveform.ndim == 2:
            waveform = waveform.mean(dim=0)
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)
        if sample_rate != target_sample_rate:
            waveform = torchaudio.functional.resample(waveform, sample_rate, target_sample_rate)
        logmel = to_db(mel(waveform)).squeeze(0)
        return logmel.transpose(0, 1)

    return extract


def load_audio(path: str):
    try:
        waveform, sample_rate = torchaudio.load(path)
        return waveform, sample_rate
    except ImportError as e:
        if "torchcodec" not in str(e).lower():
            raise
    audio_np, sample_rate = sf.read(path, dtype="float32", always_2d=False)
    waveform = torch.from_numpy(audio_np)
    if waveform.ndim == 1:
        waveform = waveform.unsqueeze(0)
    elif waveform.ndim == 2:
        waveform = waveform.transpose(0, 1)
    return waveform, sample_rate


def load_bilstm_ctc(checkpoint_path: str = CHECKPOINT_PATH, device: torch.device = DEVICE):
    ckpt = torch.load(checkpoint_path, map_location=device)

    tokenizer = CharTokenizerCTC()
    tok_state = ckpt.get("tokenizer_state") or {}
    if tok_state:
        tokenizer.id2ch = tok_state.get("id2ch", tokenizer.id2ch)
        tokenizer.blank_id = tok_state.get("blank_id", tokenizer.blank_id)
        tokenizer.ch2id = {ch: i for i, ch in enumerate(tokenizer.id2ch)}
        tokenizer.vocab_size = len(tokenizer.id2ch)

    cfg = ckpt.get("config") or {}
    logmel_extractor = build_logmel_extractor(
        target_sample_rate=cfg.get("sample_rate", 16_000),
        n_fft=cfg.get("n_fft", 400),
        hop_length=cfg.get("hop_length", 160),
        n_mels=cfg.get("n_mels", 80),
        fmin=cfg.get("fmin", 0.0),
        fmax=cfg.get("fmax"),
    )

    model = BiLSTMCTC(
        n_mels=cfg.get("n_mels", 80),
        vocab_size=tokenizer.vocab_size,
        hidden=cfg.get("hidden", 256),
        num_layers=cfg.get("num_layers", 3),
        dropout=cfg.get("dropout", 0.1),
    ).to(device)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()
    return model, tokenizer, logmel_extractor


def log_add_exp(a: float, b: float) -> float:
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    if a < b:
        a, b = b, a
    return a + math.log1p(math.exp(b - a))


def greedy_decode(log_probs: torch.Tensor, tokenizer: CharTokenizerCTC) -> str:
    best_path = log_probs.argmax(dim=-1)
    if best_path.ndim == 2:
        best_path = best_path[:, 0]
    decoded = []
    prev = None
    for idx in best_path.tolist():
        if idx == tokenizer.blank_id:
            prev = None
            continue
        if idx != prev:
            decoded.append(tokenizer.id2ch[idx])
        prev = idx
    return "".join(decoded)


def ctc_prefix_beam_search(log_probs: torch.Tensor, tokenizer: CharTokenizerCTC, beam_width: int = 10) -> str:
    if log_probs.ndim == 3:
        log_probs = log_probs[:, 0, :]
    log_probs = log_probs.cpu()
    T, V = log_probs.shape
    beams: dict[tuple[int, ...], tuple[float, float]] = {(): (0.0, -math.inf)}
    for t in range(T):
        next_beams: dict[tuple[int, ...], tuple[float, float]] = {}
        step_log_probs = log_probs[t]
        topk_vals, topk_ids = torch.topk(step_log_probs, k=min(beam_width * 2, V))
        topk = list(zip(topk_ids.tolist(), topk_vals.tolist()))
        for prefix, (p_b, p_nb) in beams.items():
            # Extend with blank
            p_blank = step_log_probs[tokenizer.blank_id].item()
            nb_p_b, nb_p_nb = next_beams.get(prefix, (-math.inf, -math.inf))
            nb_p_b = log_add_exp(nb_p_b, p_b + p_blank)
            nb_p_b = log_add_exp(nb_p_b, p_nb + p_blank)
            next_beams[prefix] = (nb_p_b, nb_p_nb)

            # Extend with non-blank tokens
            last = prefix[-1] if prefix else None
            for idx, p in topk:
                if idx == tokenizer.blank_id:
                    continue
                new_prefix = prefix + (idx,) if idx != last else prefix
                nb_p_b, nb_p_nb = next_beams.get(new_prefix, (-math.inf, -math.inf))
                if idx == last:
                    nb_p_nb = log_add_exp(nb_p_nb, p_b + p)
                else:
                    nb_p_nb = log_add_exp(nb_p_nb, p_b + p)
                nb_p_nb = log_add_exp(nb_p_nb, p_nb + p)
                next_beams[new_prefix] = (nb_p_b, nb_p_nb)

        def beam_score(item):
            p_b, p_nb = item[1]
            return log_add_exp(p_b, p_nb)

        beams = dict(sorted(next_beams.items(), key=beam_score, reverse=True)[:beam_width])

    best_prefix, (p_b, p_nb) = max(beams.items(), key=lambda kv: log_add_exp(kv[1][0], kv[1][1]))
    tokens = []
    prev = None
    for idx in best_prefix:
        if idx == tokenizer.blank_id:
            prev = None
            continue
        if idx != prev:
            tokens.append(tokenizer.id2ch[idx])
        prev = idx
    return "".join(tokens)


def transcribe_waveform(
    waveform: torch.Tensor,
    sample_rate: int,
    model: BiLSTMCTC,
    tokenizer: CharTokenizerCTC,
    logmel_extractor,
    device: torch.device = DEVICE,
    decode: str = "beam",
    beam_width: int = 10,
) -> str:
    model.eval()
    with torch.inference_mode():
        features = logmel_extractor(waveform, sample_rate).unsqueeze(0).to(device)
        log_probs = model(features)
        if decode == "beam":
            transcript = ctc_prefix_beam_search(log_probs, tokenizer, beam_width=beam_width)
        else:
            transcript = greedy_decode(log_probs.cpu(), tokenizer)
    return transcript


def transcribe_file(
    path: str,
    model: BiLSTMCTC,
    tokenizer: CharTokenizerCTC,
    logmel_extractor,
    device: torch.device = DEVICE,
    decode: str = "beam",
    beam_width: int = 10,
) -> str:
    waveform, sample_rate = load_audio(path)
    return transcribe_waveform(waveform, sample_rate, model, tokenizer, logmel_extractor, device, decode, beam_width)


model, tokenizer, logmel_extractor = load_bilstm_ctc()


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from streaming_librispeech_dataset import StreamingLibriSpeechDataset


def _canonical_split(split: str) -> str:
    aliases = {
        "validation.clean": "validation",
        "dev-clean": "validation",
        "dev": "validation",
        "test.clean": "test",
        "test-clean": "test",
    }
    return aliases.get(split, split)


def _levenshtein(seq1, seq2):
    m, n = len(seq1), len(seq2)
    dp = list(range(n + 1))
    for i in range(1, m + 1):
        prev = dp[0]
        dp[0] = i
        s1 = seq1[i - 1]
        for j in range(1, n + 1):
            temp = dp[j]
            cost = 0 if s1 == seq2[j - 1] else 1
            dp[j] = min(dp[j] + 1, dp[j - 1] + 1, prev + cost)
            prev = temp
    return dp[-1]


def _cer(ref: str, hyp: str) -> float:
    if not ref:
        return 0.0 if not hyp else 1.0
    return _levenshtein(ref, hyp) / len(ref)


def _wer(ref: str, hyp: str) -> float:
    ref_toks, hyp_toks = ref.split(), hyp.split()
    if not ref_toks:
        return 0.0 if not hyp_toks else 1.0
    return _levenshtein(ref_toks, hyp_toks) / len(ref_toks)


def evaluate_librispeech(
    split: str = "validation",
    subset: str = "clean",
    max_samples: int = 20,
    decode: str = "beam",
    beam_width: int = 10,
):
    """
    Stream a few samples from LibriSpeech and report CER/WER.
    """

    split = _canonical_split(split)

    dataset = StreamingLibriSpeechDataset(
        subset=subset,
        split=split,
        sampling_rate=16_000,
        streaming=True,
        max_samples=max_samples,
    )

    char_total = 0
    char_errors = 0
    word_total = 0
    word_errors = 0

    for idx, (waveform, _logmel, sample_rate, transcript) in enumerate(dataset, start=1):
        ref = tokenizer.normalize(transcript)
        hyp = tokenizer.normalize(
            transcribe_waveform(
                waveform,
                sample_rate,
                model,
                tokenizer,
                logmel_extractor,
                decode=decode,
                beam_width=beam_width,
            )
        )

        char_errors += _levenshtein(ref, hyp)
        char_total += len(ref)
        word_errors += _levenshtein(ref.split(), hyp.split())
        word_total += max(len(ref.split()), 1)

    cer = char_errors / max(char_total, 1)
    wer = word_errors / max(word_total, 1)
    print(f"Split={split}, samples={max_samples}, decode={decode}, beam_width={beam_width}")
    print(f"CER={cer:.4f}, WER={wer:.4f}")
    return {"cer": cer, "wer": wer, "samples": max_samples}


In [3]:
# Example: run after the first two cells have executed
metrics_dev = evaluate_librispeech(split="validation", max_samples=20, decode="beam", beam_width=10)
metrics_test = evaluate_librispeech(split="test", max_samples=20, decode="beam", beam_width=10)


Split=validation, samples=20, decode=beam, beam_width=10
CER=0.2716, WER=0.6386
Split=test, samples=20, decode=beam, beam_width=10
CER=0.2811, WER=0.6871


In [6]:
# Transcribe a local WAV file using the loaded model
wav_path = "audio_tests/epi.wav"  # change to your file
transcript = transcribe_file(
    wav_path,
    model,
    tokenizer,
    logmel_extractor,
    decode="beam",
    beam_width=10,
)
print(transcript)


her lile it eat gelar im love deter od oldite demolegeot enede e lorg a an de quint e lad loreded to da he won wet in toge mole and lined in uorig e be doal i e te ganis blel o tedit it anger wy beleded gher klolis te li no


In [7]:
# Transcribe a random sample from the LibriSpeech test split
import random
from streaming_librispeech_dataset import StreamingLibriSpeechDataset

max_samples = 50  # how many items to scan from the stream
target_idx = random.randint(0, max_samples - 1)

dataset = StreamingLibriSpeechDataset(
    subset="clean",
    split="test",
    sampling_rate=16_000,
    streaming=True,
    max_samples=max_samples,
)

chosen = None
for i, (waveform, _logmel, sample_rate, transcript) in enumerate(dataset):
    if i == target_idx:
        chosen = (waveform, sample_rate, transcript)
        break

if chosen is None:
    raise RuntimeError(f"Failed to grab sample {target_idx} from first {max_samples} items.")

waveform, sample_rate, reference = chosen
decoded = transcribe_waveform(
    waveform,
    sample_rate,
    model,
    tokenizer,
    logmel_extractor,
    decode="beam",
    beam_width=10,
)
print(f"Reference: {tokenizer.normalize(reference)}")
print(f"Decoded:   {decoded}")




Reference: it is you who are mistaken raoul i have read his distress in his eyes in his every gesture and action the whole day
Decoded:   it is e wiv was daken erl i arve lid is distres in his es in his every gestur at paction the whole day
