In [None]:
# scripts/prepare_manifest.py
import os, argparse, json

def gather_librispeech(root, out_manifest):
    """Write a JSON-lines manifest pairing ``.flac`` files with their transcripts.

    Walks ``root`` recursively and reads every ``*.trans.txt`` file. Each line
    is split into an utterance id and a transcript at the first space; when
    ``<utterance-id>.flac`` exists in the same directory, the record
    ``{"audio": <flac path>, "text": <lower-cased transcript>}`` is emitted.
    Lines without a space-separated transcript are skipped.

    Args:
        root: directory to search.
        out_manifest: path of the manifest file to (over)write; one JSON
            object per line.

    Returns:
        The number of entries written.
    """
    items = []
    for subdir, dirs, files in os.walk(root):
        # Sorting in place makes os.walk descend in a stable order, so the
        # manifest is reproducible regardless of filesystem listing order.
        dirs.sort()
        for f in sorted(files):
            if not f.endswith(".trans.txt"):
                continue
            trans_path = os.path.join(subdir, f)
            # Explicit encoding: don't depend on the platform's locale default.
            with open(trans_path, 'r', encoding="utf-8") as fh:
                for line in fh:
                    parts = line.strip().split(' ', 1)
                    if len(parts) < 2:
                        continue
                    key, text = parts
                    wav = os.path.join(subdir, key + ".flac")
                    if os.path.exists(wav):
                        items.append({"audio": wav, "text": text.lower()})
    with open(out_manifest, 'w', encoding="utf-8") as out:
        for it in items:
            out.write(json.dumps(it) + "\n")
    print(f"Manifest saved to {out_manifest}, {len(items)} entries.")
    return len(items)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Build a JSON-lines manifest from a LibriSpeech directory tree.")
    parser.add_argument("--librispeech_root", required=True,
                        help="Root directory searched recursively for *.trans.txt files.")
    parser.add_argument("--out", default="data/manifest.json",
                        help="Output manifest path (one JSON object per line).")
    args = parser.parse_args()
    out_dir = os.path.dirname(args.out)
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when --out actually contains one (not for a bare "manifest.json").
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    gather_librispeech(args.librispeech_root, args.out)


In [None]:
# src/features.py
import torch
import torchaudio
import numpy as np

class FeatureExtractor:
    """Log-mel feature extractor with per-utterance mean/variance normalisation."""

    def __init__(self, sample_rate=16000, n_mels=80, n_fft=512, win_length=400, hop_length=160):
        # Kept on the instance so callers can compare it with the audio's rate.
        self.sample_rate = sample_rate
        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate, n_fft=n_fft, win_length=win_length,
            hop_length=hop_length, n_mels=n_mels, power=1.0
        )
        # power=1.0 above yields a *magnitude* mel spectrogram, so the dB
        # conversion must use stype='magnitude' (20*log10). With 'power'
        # (10*log10), top_db=80 would actually allow a 160 dB range.
        # NOTE: this slightly changes features vs. earlier checkpoints in
        # very quiet regions where the top_db clipping now kicks in.
        self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB(stype='magnitude', top_db=80.0)

    def extract(self, waveform):
        """Return normalised log-mel features for one waveform.

        Args:
            waveform: float32 tensor, expected ``(1, n_samples)``.

        Returns:
            Tensor ``(1, n_mels, time)`` in which every mel bin is centred to
            zero mean over time and scaled by its (unbiased) std, clamped to
            at least 1e-5. With a single frame the std is undefined, so that
            utterance is only mean-centred (all zeros) instead of becoming NaN.
        """
        mel = self.mel_spec(waveform)  # (1, n_mels, time)
        log_mel = self.amplitude_to_db(mel)
        # Per-utterance CMVN: statistics over the time axis for each mel bin.
        mean = log_mel.mean(dim=-1, keepdim=True)
        if log_mel.shape[-1] > 1:
            std = log_mel.std(dim=-1, keepdim=True).clamp(min=1e-5)
        else:
            # Unbiased std of a single sample is NaN and clamp() keeps NaN.
            std = torch.ones_like(mean)
        norm = (log_mel - mean) / std
        return norm  # shape: (1, n_mels, time)


In [None]:
# src/model.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBranch(nn.Module):
    """Convolutional branch used inside ``MultiBranchLayer``.

    Pointwise expansion -> depthwise temporal conv -> GELU -> pointwise
    projection -> dropout, applied to ``(batch, time, d_model)`` inputs.
    """

    def __init__(self, d_model, expansion=2, kernel_size=31, dropout=0.1):
        super().__init__()
        hidden = d_model * expansion
        self.pointwise = nn.Conv1d(d_model, hidden, 1)
        # groups == channels makes this depthwise; padding k//2 keeps the
        # time length unchanged for odd kernel sizes.
        self.depthwise = nn.Conv1d(hidden, hidden, kernel_size, padding=kernel_size // 2, groups=hidden)
        self.project = nn.Conv1d(hidden, d_model, 1)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply the branch.

        Args:
            x: ``(batch, time, d_model)``.
            mask: optional bool ``(batch, time)``, True at padded frames.
                Those frames are zeroed right before the depthwise conv so
                they cannot leak into neighbouring valid frames.

        Returns:
            ``(batch, time, d_model)``.
        """
        x = x.transpose(1, 2)  # Conv1d expects (batch, channels, time)
        x = self.pointwise(x)
        if mask is not None:
            # Zero AFTER the pointwise conv: its bias would otherwise make
            # padded frames non-zero going into the depthwise conv.
            x = x.masked_fill(mask.unsqueeze(1), 0.0)
        x = self.depthwise(x)
        x = self.gelu(x)
        x = self.project(x)
        x = x.transpose(1, 2)
        return self.dropout(x)


class MultiBranchLayer(nn.Module):
    """Post-norm layer: self-attention, conv branch and feed-forward, each
    followed by a residual connection and LayerNorm."""

    def __init__(self, d_model, nhead, conv_expansion=2, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.conv_branch = ConvBranch(d_model, expansion=conv_expansion, dropout=dropout)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model*4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model*4, d_model),
            nn.Dropout(dropout)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, key_padding_mask=None):
        """Apply the layer to ``x`` of shape ``(batch, time, d_model)``.

        Args:
            x: ``(batch, time, d_model)``.
            key_padding_mask: optional bool ``(batch, time)``, True at padded
                frames. Padded frames are excluded as attention keys and
                zeroed inside the conv branch. None reproduces the unmasked
                behaviour.

        Returns:
            ``(batch, time, d_model)``.
        """
        # attention path
        attn_out, _ = self.self_attn(x, x, x, key_padding_mask=key_padding_mask)
        x = self.norm1(x + attn_out)
        # conv branch
        conv_out = self.conv_branch(x, mask=key_padding_mask)
        x = self.norm2(x + conv_out)
        # ff
        ff_out = self.ff(x)
        x = self.norm3(x + ff_out)
        return x


class EBranchformerEncoder(nn.Module):
    """Stack of ``MultiBranchLayer`` blocks over projected mel features.

    The time axis is never subsampled: output frame count equals input.
    """

    def __init__(self, input_dim=80, d_model=256, num_layers=12, nhead=4, conv_expansion=2, dropout=0.1):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, d_model)
        self.layers = nn.ModuleList([MultiBranchLayer(d_model, nhead, conv_expansion, dropout) for _ in range(num_layers)])
        self.final_ln = nn.LayerNorm(d_model)

    def forward(self, x, lengths=None):
        """Encode a batch of feature sequences.

        Args:
            x: ``(batch, n_mels, time)`` or ``(batch, 1, n_mels, time)``; the
                singleton channel dimension of the 4-D form is dropped.
            lengths: optional 1-D tensor/sequence with the number of valid
                frames per item. When given, frames at index >= length are
                treated as padding and masked out of attention and the conv
                branch, so they do not affect valid frames.

        Returns:
            ``(batch, time, d_model)``; values at padded frames are unspecified.
        """
        # Only the 4-D form has a channel dim to drop; squeezing a 3-D input
        # would wrongly collapse n_mels when n_mels == 1.
        if x.dim() == 4:
            x = x.squeeze(1)
        x = x.transpose(1, 2)  # (batch, time, n_mels)
        pad_mask = None
        if lengths is not None:
            t = x.size(1)
            # clamp(min=1): a fully masked row makes the attention softmax NaN.
            lengths = torch.as_tensor(lengths, device=x.device).clamp(min=1, max=t)
            pad_mask = torch.arange(t, device=x.device).unsqueeze(0) >= lengths.unsqueeze(1)
        x = self.input_proj(x)
        for layer in self.layers:
            x = layer(x, key_padding_mask=pad_mask)
        x = self.final_ln(x)
        return x  # (batch, time, d_model)


In [None]:
# src/train.py
import os, json, yaml, argparse, math
from pathlib import Path
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchaudio
from src.features import FeatureExtractor
from src.model import EBranchformerEncoder

class LibriDataset(Dataset):
    """Map-style dataset over a JSON-lines manifest of ``{"audio", "text"}`` records.

    Audio is loaded lazily in ``__getitem__``, down-mixed to mono, optionally
    resampled, turned into features by ``feat_extractor.extract`` and paired
    with the tokenized transcript.
    """

    def __init__(self, manifest, tokenizer, feat_extractor, sample_rate=None):
        """
        Args:
            manifest: path to a JSON-lines file; each non-blank line is an
                object with ``audio`` (audio file path) and ``text`` keys.
            tokenizer: object providing ``encode(text) -> list[int]``.
            feat_extractor: object providing ``extract(waveform)``.
            sample_rate: rate the feature extractor expects. When given,
                audio loaded at a different rate is resampled to it; the
                default None passes audio through unchanged.
        """
        self.items = []
        with open(manifest, encoding="utf-8") as fh:
            for line in fh:
                line = line.strip()
                # Skip blank lines instead of crashing in json.loads('').
                if line:
                    self.items.append(json.loads(line))
        self.tokenizer = tokenizer
        self.feat_extractor = feat_extractor
        self.sample_rate = sample_rate

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        """Return ``(feats, target)`` for item ``idx``.

        ``feats`` is the output of ``feat_extractor.extract`` on the mono
        waveform; ``target`` is a 1-D LongTensor of token ids.
        """
        it = self.items[idx]
        wav, sr = torchaudio.load(it['audio'])
        if wav.shape[0] > 1:
            # Average channels down to a single (1, n_samples) waveform.
            wav = wav.mean(dim=0, keepdim=True)
        if self.sample_rate is not None and sr != self.sample_rate:
            wav = torchaudio.functional.resample(wav, sr, self.sample_rate)
        feats = self.feat_extractor.extract(wav)
        target = torch.tensor(self.tokenizer.encode(it['text']), dtype=torch.long)
        return feats, target

def collate_fn(batch):
    """Pad a list of ``(feats, target)`` pairs into CTC-ready batch tensors.

    Args:
        batch: sequence of ``(feats, target)``. ``feats`` has time as its last
            dimension (``(1, n_mels, time)`` for ``LibriDataset`` items);
            ``target`` is a 1-D LongTensor of token ids.

    Returns:
        ``(feats_p, targets_concat, input_lengths, target_lengths)``:
        features right-padded with zeros to the longest item and stacked
        (``(B, 1, n_mels, T)`` for ``LibriDataset`` items); all targets
        concatenated into one 1-D tensor; the UNPADDED frame count of each
        item; and the length of each target.
    """
    feats = [b[0] for b in batch]
    targets = [b[1] for b in batch]
    # Record true lengths BEFORE padding: CTC must not score padded frames.
    input_lengths = [f.shape[-1] for f in feats]
    max_t = max(input_lengths)
    feats_p = torch.stack([
        torch.nn.functional.pad(f, (0, max_t - n)) if n < max_t else f
        for f, n in zip(feats, input_lengths)
    ])
    targets_concat = torch.cat(targets)
    target_lengths = torch.tensor([t.shape[0] for t in targets], dtype=torch.long)
    return feats_p, targets_concat, torch.tensor(input_lengths, dtype=torch.long), target_lengths

# Simple tokenizer (char-level)
class CharTokenizer:
    """Character-level tokenizer for lower-case English transcripts.

    Id 0 is the CTC blank ``'<blank>'``; it is followed by apostrophe, space
    and the letters a-z, in that order.
    """

    def __init__(self):
        letters = [chr(code) for code in range(ord("a"), ord("z") + 1)]
        self.vocab = ["<blank>", "'", " ", *letters]
        self.ctoi = dict(zip(self.vocab, range(len(self.vocab))))

    def encode(self, text):
        """Map ``text`` (case-insensitive) to ids; unknown characters map to the space id."""
        fallback = self.ctoi[" "]
        lookup = self.ctoi.get
        return [lookup(ch, fallback) for ch in text.lower()]

    def decode(self, ids):
        """Map ids back to text, dropping blanks (id 0); repeats are kept as-is."""
        return "".join(self.vocab[i] for i in ids if i != 0)

def train(manifest, config_path, out_dir):
    """Train the E-Branchformer encoder plus a linear CTC head.

    Args:
        manifest: JSON-lines manifest consumed by ``LibriDataset``.
        config_path: YAML file with sections ``dataset`` (FeatureExtractor
            kwargs, must include ``n_mels``), ``model`` (``d_model``,
            ``num_layers``, ``nhead``, ``conv_expansion``, ``dropout``) and
            ``train`` (``batch_size``, ``lr``, ``epochs``, optional
            ``device`` and ``num_workers`` [default 4]).
        out_dir: directory that receives ``checkpoint_epoch{N}.pt`` after
            every epoch (created if missing).
    """
    with open(config_path, encoding="utf-8") as fh:
        config = yaml.safe_load(fh)
    train_cfg = config['train']
    model_cfg = config['model']
    os.makedirs(out_dir, exist_ok=True)
    device = torch.device(train_cfg.get('device', 'cuda') if torch.cuda.is_available() else 'cpu')

    feat = FeatureExtractor(**config['dataset'])
    tokenizer = CharTokenizer()
    ds = LibriDataset(manifest, tokenizer, feat)
    dl = DataLoader(ds, batch_size=train_cfg['batch_size'], shuffle=True,
                    collate_fn=collate_fn, num_workers=train_cfg.get('num_workers', 4))

    model_enc = EBranchformerEncoder(input_dim=config['dataset']['n_mels'],
                                     d_model=model_cfg['d_model'],
                                     num_layers=model_cfg['num_layers'],
                                     nhead=model_cfg['nhead'],
                                     conv_expansion=model_cfg['conv_expansion'],
                                     dropout=model_cfg['dropout']).to(device)
    ctc_head = nn.Linear(model_cfg['d_model'], len(tokenizer.vocab)).to(device)

    params = list(model_enc.parameters()) + list(ctc_head.parameters())
    optimizer = torch.optim.Adam(params, lr=train_cfg['lr'])
    ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)
    num_epochs = train_cfg['epochs']

    for epoch in range(num_epochs):
        model_enc.train()
        ctc_head.train()
        epoch_loss = 0.0
        pbar = tqdm(dl, desc=f"Epoch {epoch+1}/{num_epochs}")
        for feats_p, targets_concat, input_lengths, target_lengths in pbar:
            feats_p = feats_p.squeeze(1).to(device)  # (B, 1, n_mels, T) -> (B, n_mels, T)
            # The encoder does not subsample time, so collate_fn's frame
            # counts (already a LongTensor) are valid CTC input lengths and
            # also tell the encoder which frames are padding.
            enc_out = model_enc(feats_p, lengths=input_lengths.to(device))  # (B, T, d_model)
            logits = ctc_head(enc_out)  # (B, T, C)
            log_probs = logits.log_softmax(-1).permute(1, 0, 2)  # (T, B, C) as CTCLoss expects
            loss = ctc_loss(log_probs, targets_concat.to(device), input_lengths, target_lengths.to(device))
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(params, 5.0)
            optimizer.step()
            step_loss = loss.item()
            epoch_loss += step_loss
            pbar.set_postfix(loss=step_loss)
        # Guard against an empty loader (e.g. an empty manifest).
        avg_loss = epoch_loss / max(len(dl), 1)
        print(f"Epoch {epoch+1} finished. Avg Loss: {avg_loss:.4f}")
        torch.save({
            'epoch': epoch+1,
            'model_state': model_enc.state_dict(),
            'ctc_head': ctc_head.state_dict(),
            'optimizer': optimizer.state_dict()
        }, os.path.join(out_dir, f'checkpoint_epoch{epoch+1}.pt'))

if __name__ == "__main__":
    # Command-line entry point: parse the paths and launch training.
    cli = argparse.ArgumentParser()
    cli.add_argument("--manifest", required=True)
    cli.add_argument("--config", default="configs/config.yaml")
    cli.add_argument("--out", default="checkpoints")
    opts = cli.parse_args()
    train(manifest=opts.manifest, config_path=opts.config, out_dir=opts.out)
