In [1]:
from pathlib import Path

data_root = Path("LibriSpeech/train-clean-100")
audio_with_text = []

for trans_file in sorted(data_root.rglob("*.trans.txt")):
    with trans_file.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                clip_id, transcript = line.split(" ", 1)
            except ValueError:
                continue
            audio_path = trans_file.with_name(f"{clip_id}.flac")
            if not audio_path.exists():
                continue
            audio_with_text.append([str(audio_path), transcript])

print(f"Collected {len(audio_with_text)} audio/transcription pairs")


Collected 28539 audio/transcription pairs


In [2]:
import soundfile as sf

def load_flac(path):
    waveform, sample_rate = sf.read(path, dtype="float32")
    return waveform, sample_rate

def pairs_to_audio_dataset(pairs, limit=None):
    dataset = []
    for idx, (audio_path, transcript) in enumerate(pairs):
        if limit is not None and idx >= limit:
            break
        waveform, _ = load_flac(audio_path)
        dataset.append([waveform, transcript])
    return dataset

# Limit can be raised/removed to collect more entries, but beware of RAM.
audio_dataset = pairs_to_audio_dataset(audio_with_text, limit=5)
print(len(audio_dataset))
audio_dataset[0][0].shape, audio_dataset[0][1]


5


((225360,),
 'CHAPTER ONE MISSUS RACHEL LYNDE IS SURPRISED MISSUS RACHEL LYNDE LIVED JUST WHERE THE AVONLEA MAIN ROAD DIPPED DOWN INTO A LITTLE HOLLOW FRINGED WITH ALDERS AND LADIES EARDROPS AND TRAVERSED BY A BROOK')

In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
import librosa
import numpy as np
from torch.nn.utils.rnn import pad_sequence

class OnDemandLibriSpeechDataset(Dataset):
    """Load FLAC files on access so we never keep the entire corpus in RAM."""
    def __init__(self, audio_pairs):
        self.audio_pairs = audio_pairs

    def __len__(self):
        return len(self.audio_pairs)

    def __getitem__(self, idx):
        audio_path, transcript = self.audio_pairs[idx]
        waveform, sample_rate = load_flac(audio_path)
        waveform_tensor = torch.from_numpy(waveform)
        logmel = self.audio_to_melspec(waveform=waveform, sr=sample_rate)
        logmel_tensor = torch.from_numpy(logmel).float() 
        return waveform_tensor, logmel_tensor, sample_rate, transcript
    
    def audio_to_melspec(self, waveform, sr = 16_000, n_fft=400, hop_length=160, n_mels=80, fmin=0.0, fmax=None, eps=1e-9):
        mel = librosa.feature.melspectrogram(
            y=waveform, 
            sr=sr,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            fmin=fmin,
            fmax=sr/2,
            power=2.0
        )
        logmel = np.log(mel + eps)
        return logmel



def collate_waveforms_ctc(batch):
    """
    batch: list of (waveform_tensor, logmel_tensor, sample_rate, transcript)
      waveform_tensor: (N,)
      logmel_tensor: (n_mels, T)

    returns:
      waveforms: list[Tensor] (still ragged; optional)
      logmels_padded: Tensor (B, T_max, n_mels)
      input_lengths: Tensor (B,)  # number of frames T for each sample
      sample_rates: list[int]
      transcripts: list[str]
    """
    waveforms, logmels, sample_rates, transcripts = zip(*batch)

    # Convert each logmel from (n_mels, T) -> (T, n_mels)
    logmels_TF = [lm.transpose(0, 1) for lm in logmels]  # list of (T_i, n_mels)

    input_lengths = torch.tensor([lm.shape[0] for lm in logmels_TF], dtype=torch.long)

    # Pad in time dimension to (B, T_max, n_mels)
    logmels_padded = pad_sequence(logmels_TF, batch_first=True, padding_value=0.0)

    return waveforms, logmels_padded, input_lengths, list(sample_rates), list(transcripts)

train_dataset = OnDemandLibriSpeechDataset(audio_with_text)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_waveforms_ctc)
batch_waveforms, logmels, lengths, batch_rates, batch_text = next(iter(train_loader))
print(logmels.shape) # batch size x T_max x n_mels

torch.Size([4, 1579, 80])


In [4]:
import IPython.display as ipd
ipd.display(ipd.Audio(batch_waveforms[1], rate=batch_rates[0]))

# Tokenizer

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CharTokenizerCTC:
    def __init__(self):
        alphabet = list("abcdefghijklmnopqrstuvwxyz' ")  # includes space
        self.blank_id = 0
        self.id2ch = ["<blank>"] + alphabet
        self.ch2id = {ch: i for i, ch in enumerate(self.id2ch)}
        self.vocab_size = len(self.id2ch)

    def normalize(self, text: str) -> str:
        text=text.lower()
        out=[]
        for c in text:
            out.append(c if c in self.ch2id else " ")
        s = "".join(out)
        while "  " in s:
            s = s.rplace("  ", " ")
        return s.strip()
    
    def encode(self, text: str) -> torch.Tensor:
        text=self.normalize(text)
        ids = [self.ch2id[c] for c in text]
        return torch.tensor(ids, dtype=torch.long)
    
    def decode_ctc_greedy(self, log_probs: torch.Tensor) -> str:
        """
        log_probs: (T, V) or (T, B, V) - use per-sample before calling.
        """
        pred = log_probs.argmax(dim=-1) # (T,)
        # CTC collapse repeats, remove blanks
        out = []
        prev = None
        for p in pred.tolist():
            if p != prev and p!= self.blank_id:
                out.append(self.id2ch[p])
            prev = p
        return "".join(out).strip()

# Network

In [6]:
class BiLSTMCTC(nn.Module):
    def __init__(self, n_mels, vocab_size, hidden = 256, num_layers = 3, dropout = 0.1):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=n_mels,
            hidden_size = hidden,
            num_layers = num_layers,
            dropout = dropout if num_layers > 1 else 0.0,
            bidirectional = True,
            batch_first = True, # (B, T, F)
        )
        self.proj = nn.Linear(hidden * 2, vocab_size)

    def forward(self, x):
        """
        x: (B, T, n_mels)
        returns:
          log_probs: (T, B, V)  (what nn.CTCLoss expects if batch_first=False)
        """
        h, _ = self.lstm(x)          # (B, T, 2H)
        logits = self.proj(h)        # (B, T, V)
        log_probs = F.log_softmax(logits, dim=-1)  # (B, T, V)
        return log_probs.transpose(0, 1)           # (T, B, V)

    

# Target builder

In [7]:
def make_ctc_targets(tokenizer: CharTokenizerCTC, batch_text):
    """
    batch_text: list[str] length B
    returns:
      targets_1d: (sumU,)
      target_lengths: (B,)
    """
    encoded = [tokenizer.encode(t) for t in batch_text]
    target_lengths = torch.tensor([e.numel() for e in encoded], dtype=torch.long)
    targets_1d = torch.cat(encoded, dim=0) if len(encoded) > 0 else torch.empty(0, dtype=torch.long)
    return targets_1d, target_lengths

In [8]:
tok = CharTokenizerCTC()
example = make_ctc_targets(tok, ["I am kevin"])
print(example[0])

tensor([ 9, 28,  1, 13, 28, 11,  5, 22,  9, 14])


# Training

In [None]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = CharTokenizerCTC()
sample_n_mels = train_dataset[0][1].shape[0]
model = BiLSTMCTC(
    n_mels=sample_n_mels,
    vocab_size=tokenizer.vocab_size,
    hidden=256,
    num_layers=3,
    dropout=0.1,
).to(device)

criterion = nn.CTCLoss(blank=tokenizer.blank_id, zero_infinity=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
num_epochs = 5
grad_clip = 5.0
log_interval = 10
checkpoint_path = "bilstm_ctc_checkpoint.pt"

for epoch in range(1, num_epochs + 1):
    model.train()
    running_loss = 0.0
    frames_seen = 0
    last_loss = None
    for step, (_, batch_logmels, input_lengths, _, batch_text) in enumerate(train_loader, start=1):
        batch_logmels = batch_logmels.to(device)
        input_lengths = input_lengths.to(device)

        targets_1d, target_lengths = make_ctc_targets(tokenizer, batch_text)
        if targets_1d.numel() == 0:
            continue  # skip samples that normalize to empty strings
        targets_1d = targets_1d.to(device)
        target_lengths = target_lengths.to(device)

        log_probs = model(batch_logmels)
        loss = criterion(log_probs, targets_1d, input_lengths, target_lengths)

        optimizer.zero_grad()
        loss.backward()
        if grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        running_loss += loss.item() * batch_logmels.size(0)
        frames_seen += input_lengths.sum().item()
        last_loss = loss.item()

        if step % log_interval == 0 or step == len(train_loader):
            avg_loss = running_loss / max(frames_seen, 1)
            print(
                f"Epoch {epoch} Step {step}/{len(train_loader)} - avg loss/frame: {avg_loss:.4f} - batch loss: {loss.item():.4f}"
            )

    if last_loss is not None:
        print(f"Epoch {epoch} finished. last batch loss={last_loss:.4f}")
    else:
        print(f"Epoch {epoch} finished but no valid transcripts were processed.")
    print()

torch.save(
    {
        "model_state_dict": model.state_dict(),
        "tokenizer_state": {
            "id2ch": tokenizer.id2ch,
            "blank_id": tokenizer.blank_id,
        },
        "config": {
            "n_mels": sample_n_mels,
            "hidden": 256,
            "num_layers": 3,
            "dropout": 0.1,
        },
    },
    checkpoint_path,
)
print(f"Saved model checkpoint to {checkpoint_path}")


Epoch 1 Step 10/7135 - avg loss/frame: 0.0053 - batch loss: 3.1721
Epoch 1 Step 20/7135 - avg loss/frame: 0.0039 - batch loss: 2.9192
Epoch 1 Step 30/7135 - avg loss/frame: 0.0033 - batch loss: 2.8880
Epoch 1 Step 40/7135 - avg loss/frame: 0.0031 - batch loss: 2.8721
Epoch 1 Step 50/7135 - avg loss/frame: 0.0029 - batch loss: 2.8637
Epoch 1 Step 60/7135 - avg loss/frame: 0.0029 - batch loss: 2.9068
Epoch 1 Step 70/7135 - avg loss/frame: 0.0027 - batch loss: 2.8399
Epoch 1 Step 80/7135 - avg loss/frame: 0.0027 - batch loss: 2.9635
Epoch 1 Step 90/7135 - avg loss/frame: 0.0026 - batch loss: 2.9838
Epoch 1 Step 100/7135 - avg loss/frame: 0.0026 - batch loss: 2.8402
Epoch 1 Step 110/7135 - avg loss/frame: 0.0026 - batch loss: 2.8362
Epoch 1 Step 120/7135 - avg loss/frame: 0.0025 - batch loss: 2.8593
Epoch 1 Step 130/7135 - avg loss/frame: 0.0025 - batch loss: 2.9102
Epoch 1 Step 140/7135 - avg loss/frame: 0.0025 - batch loss: 2.8886
Epoch 1 Step 150/7135 - avg loss/frame: 0.0025 - batch lo