# Stage 1 – CTC pre-training of the Conformer encoder

In [None]:
import math
import os
import time
import random
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
import torchaudio
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
from datasets import load_dataset, Audio

from conformer import Conformer

# sentencepiece is optional: TextTokenizer only needs it when tokenizer.use_sentencepiece is True,
# and in that case it raises ImportError instead of silently falling back.
try:
    import sentencepiece as spm
except ImportError:
    spm = None
    print(
        "sentencepiece is not installed; only the built-in character tokenizer is available "
        "(use_sentencepiece=True will raise ImportError)."
    )

# Keep cuDNN autotuning (benchmark mode) disabled.
torch.backends.cudnn.benchmark = False
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] %(message)s")

In [None]:
def set_seed(seed: int) -> None:
    """Seed Python's ``random`` module and torch (CPU generator and every CUDA device)."""
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

# Single source of truth for every Stage-1 hyper-parameter; run_training(cfg) reads all sections below.
config = {
    "experiment_name": "stage1_ctc_conformer",
    "seed": 1337,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "data": {
        "sample_rate": 16_000,
        # Used as the tokenizer's default "lowercase" setting.
        "text_lowercase": True,
        # max_audio_seconds / pad_to_max_seconds are only applied on the Hugging Face path:
        # longer clips are truncated, shorter ones zero-padded when pad_to_max_seconds is True.
        "max_audio_seconds": 20.0,
        "pad_to_max_seconds": True,
        # Manifests are only consulted when "hf_dataset" below is empty/None.
        "train_manifest": None,
        "valid_manifest": None,
        "hf_dataset": {
            "name": "parler-tts/mls_eng",
            "config": None,
            "train_split": "train",
            # NOTE(review): confirm this split name exists for the dataset above.
            "valid_split": "validation",
            "text_column": "text",
            "audio_column": "audio",
            "cache_dir": "data/hf-cache",
            # Streaming is rejected by HuggingFaceSpeechDataset.
            "streaming": False,
        },
    },
    # Keyword arguments for LogMelFeatureExtractor (torchaudio MelSpectrogram settings).
    "feature_extractor": {
        "n_mels": 80,
        "n_fft": 1024,
        "win_length": 400,
        "hop_length": 160,
        "f_min": 0.0,
        "f_max": None,
        "mel_power": 2.0,
        # Floor applied to mel energies before taking the log.
        "log_offset": 1e-6,
    },
    "tokenizer": {
        "use_sentencepiece": False,
        "sentencepiece_model": "tokenizer/stage1_sp.model",
        # Index of the CTC blank; also passed to nn.CTCLoss(blank=...).
        "blank_id": 0,
        "chars": list("abcdefghijklmnopqrstuvwxyz' "),
    },
    # SpecAugment settings; "prob" is the chance of augmenting an utterance (0 disables it).
    "augmentation": {
        "freq_mask_param": 15,
        "time_mask_param": 30,
        "num_freq_masks": 2,
        "num_time_masks": 2,
        "prob": 0.5,
    },
    "model": {
        # Should equal feature_extractor.n_mels (the feature matrices have n_mels columns).
        "input_dim": 80,
        "encoder_dim": 512,
        "num_layers": 3,
        "num_attention_heads": 8,
        "feed_forward_expansion_factor": 4,
        "conv_expansion_factor": 2,
        "conv_kernel_size": 31,
        "dropout": 0.1,
        # Used by the collator (padding multiple and length filter) and the frame-rate log only;
        # it is NOT passed to Conformer. NOTE(review): verify it matches the encoder's real
        # time reduction, otherwise the filter and the logged frame rate are off.
        "subsampling_factor": 8,
        # Utterances shorter than subsampling_factor * this many feature frames are dropped.
        "min_subsample_len_multiplier": 2,
    },
    "dataloader": {
        "batch_size": 4,
        "num_workers": 4,
        "pin_memory": True,
        # Only forwarded to DataLoader when num_workers > 0.
        "prefetch_factor": 2,
        "persistent_workers": False,
        "shuffle": True,
    },
    "optim": {
        "peak_lr": 1e-3,
        "weight_decay": 1e-4,
        "eps": 1e-9,
        "betas": (0.9, 0.98),
        # Number of batches whose gradients are accumulated per optimizer step.
        "grad_accum_steps": 4,
    },
    # Step counts are optimizer steps (the scheduler steps once per optimizer update).
    "scheduler": {
        "warmup_steps": 20_000,
        "total_steps": 250_000,
        "final_lr_scale": 0.01,
    },
    "trainer": {
        "num_epochs": 30,
        # In optimizer steps.
        "log_interval": 50,
        # In epochs.
        "val_interval": 1,
        "grad_clip": 5.0,
        # AMP is only enabled when CUDA is available.
        "use_amp": True,
        "checkpoint_dir": "checkpoints/stage1",
        "max_to_keep": 5,
        # Path to a checkpoint written by this notebook, or None to start fresh.
        "resume_from": None,
    },
}

# Seed RNGs and prepare the checkpoint directory before any data/model objects are built.
set_seed(config["seed"])
Path(config["trainer"]["checkpoint_dir"]).mkdir(parents=True, exist_ok=True)
print(f"Running Stage 1 on device: {config['device']}")

## Data & tokenizer

In [None]:
class TextTokenizer:
    """Convert transcripts to CTC target ids and back.

    Two back-ends are supported, selected by ``cfg["use_sentencepiece"]``:

    * SentencePiece: piece ids come from the model at ``cfg["sentencepiece_model"]``.
    * Characters: one id per symbol of ``cfg["chars"]``; unknown characters are dropped.

    In both cases the output vocabulary reserves exactly one slot for the CTC blank at
    index ``blank_id``; every symbol at or above that index is shifted up by one. With
    the default ``blank_id == 0`` this means real symbols start at id 1.

    Config keys read: ``blank_id`` (default 0), ``lowercase`` (default True),
    ``use_sentencepiece`` (default False), ``sentencepiece_model``, ``chars``.

    Raises:
        ImportError: SentencePiece requested but the package is not installed.
        FileNotFoundError: the SentencePiece model file does not exist.
        ValueError: ``blank_id`` lies outside ``[0, number_of_symbols]``.
    """

    def __init__(self, cfg: Dict):
        self.blank_id = int(cfg.get("blank_id", 0))
        self.lowercase = bool(cfg.get("lowercase", True))
        self.use_sentencepiece = bool(cfg.get("use_sentencepiece", False))
        sp_model = cfg.get("sentencepiece_model")
        if self.use_sentencepiece:
            if spm is None:
                raise ImportError("sentencepiece is required for the requested tokenizer but is not installed.")
            if not sp_model or not Path(sp_model).is_file():
                raise FileNotFoundError(f"SentencePiece model not found: {sp_model}")
            self.processor = spm.SentencePieceProcessor(model_file=str(sp_model))
            num_pieces = self.processor.get_piece_size()
            self._check_blank_id(num_pieces)
            # Number of reserved (non-piece) slots in the output vocabulary: the blank.
            self.offset = 1
            self.vocab_size = num_pieces + 1
        else:
            symbols: List[str] = list(cfg.get("chars") or "abcdefghijklmnopqrstuvwxyz' ")
            self._check_blank_id(len(symbols))
            # Reserve the blank at its configured index so it never collides with a character
            # and is always counted in vocab_size (the CTC head needs blank_id < vocab_size).
            symbols.insert(self.blank_id, "<blank>")
            self.stoi = {ch: idx for idx, ch in enumerate(symbols)}
            self.itos = {idx: ch for ch, idx in self.stoi.items()}
            self.vocab_size = len(symbols)

    def _check_blank_id(self, num_symbols: int) -> None:
        """Raise ValueError unless the blank can be inserted among ``num_symbols`` symbols."""
        if not 0 <= self.blank_id <= num_symbols:
            raise ValueError(
                f"blank_id must be in [0, {num_symbols}] for a vocabulary of "
                f"{num_symbols} symbols, got {self.blank_id}."
            )

    def encode(self, text: str) -> List[int]:
        """Return target ids for ``text`` (stripped, optionally lowercased); never emits the blank."""
        if text is None:
            return []
        text = text.strip()
        if self.lowercase:
            text = text.lower()
        if self.use_sentencepiece:
            pieces = self.processor.encode(text, out_type=int)
            # Shift pieces at/after the blank slot up by one.
            return [piece if piece < self.blank_id else piece + 1 for piece in pieces]
        # Characters outside the inventory are silently dropped.
        return [
            idx
            for idx in (self.stoi.get(ch) for ch in text)
            if idx is not None and idx != self.blank_id
        ]

    def decode(self, ids: List[int]) -> str:
        """Map ids back to text, dropping blanks (and, for SentencePiece, out-of-range ids).

        Repeated symbols are not collapsed; apply CTC collapsing before calling this.
        """
        if self.use_sentencepiece:
            pieces = [
                idx if idx < self.blank_id else idx - 1
                for idx in ids
                if 0 <= idx < self.vocab_size and idx != self.blank_id
            ]
            return self.processor.decode(pieces)
        return "".join(self.itos.get(idx, "") for idx in ids if idx != self.blank_id)

In [None]:
class LogMelFeatureExtractor(nn.Module):
    """Turn a waveform into per-utterance normalised log-mel frames of shape ``(frames, n_mels)``.

    Keyword arguments (all optional) mirror the ``feature_extractor`` config section:
    ``n_mels``, ``n_fft``, ``win_length``, ``hop_length``, ``f_min``, ``f_max``,
    ``mel_power`` and ``log_offset`` (floor applied before the log).
    """

    def __init__(self, sample_rate: int, **kwargs):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_mels = kwargs.get("n_mels", 80)
        spectrogram_options = dict(
            sample_rate=sample_rate,
            n_fft=kwargs.get("n_fft", 1024),
            win_length=kwargs.get("win_length", 400),
            hop_length=kwargs.get("hop_length", 160),
            f_min=kwargs.get("f_min", 0.0),
            f_max=kwargs.get("f_max", None),
            power=kwargs.get("mel_power", 2.0),
            n_mels=self.n_mels,
            norm=None,
            mel_scale="htk",
        )
        self.melspec = torchaudio.transforms.MelSpectrogram(**spectrogram_options)
        self.log_offset = kwargs.get("log_offset", 1e-6)

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        # A 1-D input is treated as a single channel; dim 0 is assumed to be channels
        # otherwise (TODO confirm for all callers), and multi-channel audio is averaged.
        signal = waveform.unsqueeze(0) if waveform.dim() == 1 else waveform
        if signal.size(0) > 1:
            signal = signal.mean(dim=0, keepdim=True)
        log_mel = self.melspec(signal).clamp(min=self.log_offset).log()
        # Normalise every mel bin over time to zero mean / unit variance for this utterance.
        centred = log_mel - log_mel.mean(dim=-1, keepdim=True)
        scaled = centred / (centred.std(dim=-1, keepdim=True) + 1e-5)
        return scaled.transpose(1, 2).squeeze(0).contiguous()


class SpecAugment:
    """Randomly apply frequency and time masks to a ``(frames, bins)`` feature matrix.

    With probability ``prob`` an utterance receives ``num_freq_masks`` frequency masks
    followed by ``num_time_masks`` time masks; a mask whose ``*_mask_param`` is not
    positive is disabled. ``prob <= 0`` turns augmentation off entirely.
    """

    def __init__(self, freq_mask_param: int, time_mask_param: int, num_freq_masks: int, num_time_masks: int, prob: float = 0.0):
        self.freq_mask = None
        if freq_mask_param > 0:
            self.freq_mask = torchaudio.transforms.FrequencyMasking(freq_mask_param)
        self.time_mask = None
        if time_mask_param > 0:
            self.time_mask = torchaudio.transforms.TimeMasking(time_mask_param)
        self.num_freq_masks = num_freq_masks
        self.num_time_masks = num_time_masks
        self.prob = prob

    def __call__(self, features: torch.Tensor) -> torch.Tensor:
        skip = self.prob <= 0.0 or random.random() > self.prob
        if skip:
            return features
        masks = []
        if self.freq_mask is not None:
            masks += [self.freq_mask] * self.num_freq_masks
        if self.time_mask is not None:
            masks += [self.time_mask] * self.num_time_masks
        # torchaudio masks expect (batch, bins, frames), so transpose in and out.
        spectrogram = features.transpose(0, 1).unsqueeze(0)
        for mask in masks:
            spectrogram = mask(spectrogram)
        return spectrogram.squeeze(0).transpose(0, 1)


class ManifestSpeechDataset(Dataset):
    """Speech dataset described by a tab-separated manifest file.

    Each non-empty line not starting with ``#`` has one of these tab-separated layouts:

    * ``utt_id  audio_path  duration_seconds  transcript``
    * ``utt_id  audio_path  transcript``
    * ``audio_path  transcript`` (``utt_id`` defaults to the audio file's stem)

    Relative audio paths are resolved against the manifest's directory. Missing
    durations are probed with ``torchaudio.info`` when the dataset is built.
    NOTE(review): unlike the Hugging Face path, clips are not truncated/padded to
    ``max_audio_seconds`` here.

    Raises:
        FileNotFoundError: the manifest does not exist.
        ValueError: a line has an unsupported number of fields.
        RuntimeError: the manifest contains no entries.
    """

    def __init__(self, manifest_path: str, tokenizer: TextTokenizer, feature_extractor: LogMelFeatureExtractor, sample_rate: int, apply_augment: bool = False, augment: Optional[SpecAugment] = None):
        self.manifest_path = Path(manifest_path)
        if not self.manifest_path.is_file():
            raise FileNotFoundError(f"Manifest not found: {manifest_path}")
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        self.sample_rate = sample_rate
        self.apply_augment = apply_augment
        self.augment = augment
        self.entries = self._load_manifest()
        self._ensure_durations()

    def _load_manifest(self) -> List[Dict]:
        """Parse the manifest into dicts with keys utt_id, audio_path, transcript, duration."""
        entries: List[Dict] = []
        # Read explicitly as UTF-8 so transcripts do not depend on the platform's locale encoding.
        with self.manifest_path.open(encoding="utf-8") as handle:
            for raw_line in handle:
                line = raw_line.strip()
                if not line or line.startswith("#"):
                    continue
                parts = line.split("\t")
                if len(parts) == 4:
                    utt_id, audio_path, duration, transcript = parts
                elif len(parts) == 3:
                    utt_id, audio_path, transcript = parts
                    duration = None
                elif len(parts) == 2:
                    audio_path, transcript = parts
                    utt_id = Path(audio_path).stem
                    duration = None
                else:
                    raise ValueError(f"Invalid manifest line: {line}")
                path = Path(audio_path)
                if not path.is_absolute():
                    path = (self.manifest_path.parent / path).resolve()
                entries.append(
                    {
                        "utt_id": utt_id,
                        "audio_path": str(path),
                        "transcript": transcript.strip(),
                        "duration": float(duration) if duration else None,
                    }
                )
        if not entries:
            raise RuntimeError(f"Manifest {self.manifest_path} is empty.")
        return entries

    def _ensure_durations(self) -> None:
        """Fill missing durations from the audio header; unreadable files get 0.0 seconds.

        Best effort: durations only feed ``total_hours`` and the ``seconds`` field, so a
        failed probe does not abort, but it is logged instead of being swallowed silently.
        """
        failures = 0
        for entry in self.entries:
            if entry["duration"] is not None:
                continue
            try:
                info = torchaudio.info(entry["audio_path"])
                entry["duration"] = info.num_frames / info.sample_rate
            except Exception as exc:
                if failures == 0:
                    logging.warning(f"Could not read duration of {entry['audio_path']} ({exc}); recording 0.0s.")
                failures += 1
                entry["duration"] = 0.0
        if failures > 1:
            logging.warning(f"{failures} entries in {self.manifest_path} had unreadable durations.")

    def __len__(self) -> int:
        return len(self.entries)

    @property
    def total_hours(self) -> float:
        """Sum of the recorded entry durations, in hours."""
        total_seconds = sum(e.get("duration", 0.0) or 0.0 for e in self.entries)
        return total_seconds / 3600.0

    def __getitem__(self, idx: int) -> Dict:
        """Load, resample and featurise one utterance and tokenise its transcript."""
        entry = self.entries[idx]
        waveform, sr = torchaudio.load(entry["audio_path"])
        waveform = waveform.to(torch.float32)
        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)
        features = self.feature_extractor(waveform)
        if self.apply_augment and self.augment is not None:
            features = self.augment(features)
        tokens = torch.tensor(self.tokenizer.encode(entry["transcript"]), dtype=torch.long)
        duration = entry.get("duration")
        if duration is None:
            duration = waveform.size(-1) / self.sample_rate
        return {
            "features": features,
            "feature_length": features.size(0),
            "tokens": tokens,
            "token_length": int(tokens.size(0)),
            "utt_id": entry["utt_id"],
            "seconds": duration,
        }


class HuggingFaceSpeechDataset(Dataset):
    """Speech dataset backed by a non-streaming Hugging Face ``datasets`` split.

    Audio is decoded and resampled to ``sample_rate`` by ``datasets.Audio``. When
    ``target_seconds`` is set, clips are truncated to that length and, if
    ``pad_to_target`` is True, shorter clips are zero-padded up to it.

    Raises:
        ValueError: streaming is requested, or the text/audio column is missing.
    """

    def __init__(self, dataset_cfg: Dict, split: str, tokenizer: TextTokenizer, feature_extractor: LogMelFeatureExtractor, sample_rate: int, apply_augment: bool = False, augment: Optional[SpecAugment] = None, target_seconds: Optional[float] = None, pad_to_target: bool = False):
        if dataset_cfg.get("streaming", False):
            raise ValueError("Streaming datasets are not supported in this notebook; please disable streaming.")
        self.dataset_name = dataset_cfg["name"]
        self.dataset_cfg = dataset_cfg
        self.split = split
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        self.sample_rate = sample_rate
        self.apply_augment = apply_augment
        self.augment = augment
        self.audio_column = dataset_cfg.get("audio_column", "audio")
        self.text_column = dataset_cfg.get("text_column", "text")
        self.target_seconds = target_seconds
        self.pad_to_target = pad_to_target
        self.target_num_frames = int(round(target_seconds * sample_rate)) if target_seconds else None
        load_kwargs = {
            "split": split,
            "cache_dir": dataset_cfg.get("cache_dir"),
        }
        config_name = dataset_cfg.get("config")
        if config_name:
            ds = load_dataset(self.dataset_name, config_name, **load_kwargs)
        else:
            ds = load_dataset(self.dataset_name, **load_kwargs)
        # Validate the schema *before* casting, so a missing audio column reports this clear
        # error instead of failing inside cast_column.
        for column in (self.text_column, self.audio_column):
            if column not in ds.column_names:
                raise ValueError(f"Column '{column}' not found in dataset columns {ds.column_names}.")
        self.dataset = ds.cast_column(self.audio_column, Audio(sampling_rate=self.sample_rate))

    def __len__(self) -> int:
        return self.dataset.num_rows

    @property
    def total_hours(self) -> float:
        """Exact corpus size in hours when every clip is padded/truncated to a fixed length.

        Without padding the true total is unknown without decoding all audio, so 0.0
        (rendered as "n/a") is returned instead of an upper bound presented as an estimate.
        """
        if self.target_num_frames is not None and self.pad_to_target:
            fixed_seconds = self.target_num_frames / self.sample_rate
            return len(self) * fixed_seconds / 3600.0
        return 0.0

    def _fix_duration(self, waveform: torch.Tensor) -> torch.Tensor:
        """Truncate to ``target_num_frames`` samples and optionally zero-pad up to it."""
        if self.target_num_frames is None:
            return waveform
        num_frames = waveform.size(-1)
        if num_frames > self.target_num_frames:
            # NOTE(review): the full transcript is kept even though the audio tail is cut,
            # so truncated clips may have words with no matching audio; consider filtering.
            waveform = waveform[..., : self.target_num_frames]
        elif num_frames < self.target_num_frames and self.pad_to_target:
            pad = self.target_num_frames - num_frames
            waveform = F.pad(waveform, (0, pad))
        return waveform

    def __getitem__(self, idx: int) -> Dict:
        """Featurise one example and tokenise its transcript."""
        example = self.dataset[idx]
        audio_dict = example[self.audio_column]
        waveform = torch.tensor(audio_dict["array"], dtype=torch.float32)
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        waveform = self._fix_duration(waveform)
        features = self.feature_extractor(waveform)
        if self.apply_augment and self.augment is not None:
            features = self.augment(features)
        tokens = torch.tensor(self.tokenizer.encode(example[self.text_column]), dtype=torch.long)
        # Duration of the audio actually fed to the model (after truncation / padding);
        # a short clip that was not padded must not be reported as target_seconds long.
        duration = waveform.size(-1) / self.sample_rate
        return {
            "features": features,
            "feature_length": features.size(0),
            "tokens": tokens,
            "token_length": int(tokens.size(0)),
            "utt_id": str(example.get("id", idx)),
            "seconds": duration,
        }


class SpeechDataCollator:
    """Drop unusable utterances and pad the rest into a CTC-ready batch.

    A sample is rejected when it has no tokens, has fewer than
    ``subsampling_factor * min_subsample_len_multiplier`` feature frames, or when the
    estimated number of encoder outputs (``frames // subsampling_factor - 1``) is smaller
    than its token count. Returns ``None`` when nothing survives, otherwise
    ``(features, input_lengths, targets, target_lengths, utt_ids)`` with features
    zero-padded to the longest sample (rounded up to ``pad_to_multiple_of``) and
    targets concatenated into one 1-D tensor.
    """

    def __init__(self, pad_to_multiple_of: Optional[int] = None, subsampling_factor: int = 1, min_subsample_len_multiplier: int = 1):
        self.pad_to_multiple_of = pad_to_multiple_of
        self.subsampling_factor = max(1, subsampling_factor)
        multiplier = max(1, min_subsample_len_multiplier)
        self.min_subsample_frames = max(1, self.subsampling_factor * multiplier)

    def _is_usable(self, sample: Dict) -> bool:
        n_tokens = sample["token_length"]
        n_frames = sample["feature_length"]
        if n_tokens == 0 or n_frames < self.min_subsample_frames:
            return False
        expected_outputs = max(1, n_frames // self.subsampling_factor - 1)
        return expected_outputs >= n_tokens

    def __call__(self, batch: List[Dict]):
        kept = list(filter(self._is_usable, batch))
        if not kept:
            return None
        lengths = [sample["feature_length"] for sample in kept]
        padded_len = max(lengths)
        multiple = self.pad_to_multiple_of
        if multiple and padded_len % multiple:
            padded_len += multiple - padded_len % multiple
        feat_dim = kept[0]["features"].size(1)
        features = torch.zeros(len(kept), padded_len, feat_dim, dtype=torch.float32)
        for row, sample in enumerate(kept):
            features[row, : sample["feature_length"]] = sample["features"]
        input_lengths = torch.tensor(lengths, dtype=torch.long)
        target_lengths = torch.tensor([sample["token_length"] for sample in kept], dtype=torch.long)
        targets = torch.cat([sample["tokens"] for sample in kept])
        utt_ids = [sample["utt_id"] for sample in kept]
        return features, input_lengths, targets, target_lengths, utt_ids

In [None]:
class WarmupExponentialDecayScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Linear warm-up to each base LR, then exponential decay to ``final_lr_scale * base_lr``.

    With ``s`` the 1-based count of scheduler steps (the implicit step taken on
    construction counts as 1), the LR multiplier is ``s / warmup_steps`` while
    ``s <= warmup_steps``. After that it is ``final_lr_scale ** progress``, where
    ``progress`` grows linearly from 0 to 1 between ``warmup_steps`` and ``total_steps``
    and is held at 1 afterwards.

    Raises:
        ValueError: ``final_lr_scale`` is not positive (its logarithm is undefined).
    """

    def __init__(self, optimizer, warmup_steps: int, total_steps: int, final_lr_scale: float, last_epoch: int = -1):
        if final_lr_scale <= 0:
            raise ValueError(f"final_lr_scale must be positive, got {final_lr_scale}")
        self.warmup_steps = max(1, warmup_steps)
        # Clamp against the *effective* warm-up length so the decay span is never zero
        # (warmup_steps=0 previously allowed total_steps == warmup_steps -> division by zero).
        self.total_steps = max(total_steps, self.warmup_steps + 1)
        self.final_lr_scale = final_lr_scale
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = max(1, self.last_epoch + 1)
        if step <= self.warmup_steps:
            scale = step / self.warmup_steps
        else:
            decay_span = self.total_steps - self.warmup_steps
            progress = min(1.0, (step - self.warmup_steps) / decay_span)
            scale = math.exp(math.log(self.final_lr_scale) * progress)
        return [base_lr * scale for base_lr in self.base_lrs]


class AverageMeter:
    """Running weighted mean: each ``update(value, n)`` counts ``value`` ``n`` times."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything accumulated so far."""
        self.total, self.count = 0.0, 0

    def update(self, value: float, n: int = 1):
        self.count += n
        self.total += value * n

    @property
    def avg(self) -> float:
        """Weighted mean so far (0.0 before the first update)."""
        return self.total / max(1, self.count)


def save_checkpoint(state: Dict, checkpoint_dir: str, max_to_keep: int) -> Path:
    """Save ``state`` as ``epochNN_valX.XXXX.pt`` in ``checkpoint_dir`` and prune old checkpoints.

    ``state`` must contain ``epoch`` (int) and ``val_loss`` (float). The file is written to
    a temporary name and then renamed, so an interrupted save never leaves a truncated
    checkpoint under the final name.

    Pruning (only when ``max_to_keep > 0``) keeps the ``max_to_keep`` most recent epochs,
    ordered numerically (so epoch 100 is newer than epoch 99), plus the checkpoint with
    the lowest validation loss encoded in its filename, so the best model is never deleted.
    Files matching ``epoch*.pt`` whose names cannot be parsed are left untouched.

    Returns:
        Path of the checkpoint just written.
    """
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    ckpt_path = checkpoint_dir / f"epoch{state['epoch']:02d}_val{state['val_loss']:.4f}.pt"
    tmp_path = ckpt_path.with_name(ckpt_path.name + ".tmp")
    torch.save(state, tmp_path)
    tmp_path.replace(ckpt_path)
    if max_to_keep > 0:
        _prune_checkpoints(checkpoint_dir, max_to_keep)
    return ckpt_path


def _parse_checkpoint_name(path: Path) -> Optional[Tuple[int, float]]:
    """Return ``(epoch, val_loss)`` encoded in an ``epochNN_valX.XXXX.pt`` filename, or None."""
    stem = path.name[: -len(".pt")]
    epoch_part, sep, val_part = stem.partition("_val")
    if not sep or not epoch_part.startswith("epoch"):
        return None
    try:
        return int(epoch_part[len("epoch"):]), float(val_part)
    except ValueError:
        return None


def _prune_checkpoints(checkpoint_dir: Path, max_to_keep: int) -> None:
    """Delete checkpoints beyond the newest ``max_to_keep`` epochs, always sparing the best one."""
    parsed: List[Tuple[int, float, Path]] = []
    for path in checkpoint_dir.glob("epoch*.pt"):
        meta = _parse_checkpoint_name(path)
        if meta is not None:
            parsed.append((meta[0], meta[1], path))
    if len(parsed) <= max_to_keep:
        return
    parsed.sort(key=lambda item: item[0])
    keep = {item[2] for item in parsed[-max_to_keep:]}
    scored = [item for item in parsed if not math.isnan(item[1])]
    if scored:
        keep.add(min(scored, key=lambda item: item[1])[2])
    for _, _, path in parsed:
        if path not in keep:
            path.unlink(missing_ok=True)

In [None]:
def train_one_epoch(model: nn.Module, dataloader: DataLoader, criterion, optimizer, scheduler, scaler: GradScaler, device: torch.device, epoch: int, global_step: int, grad_accum_steps: int, grad_clip: float, log_interval: int, amp_enabled: bool) -> Tuple[int, float]:
    """Run one training epoch with gradient accumulation, optional AMP and gradient clipping.

    Batches the collator rejected (``None``) and batches in which any target is longer
    than the model's output sequence are skipped. A trailing partial accumulation window
    is flushed with one last optimizer step.

    Args:
        global_step: optimizer steps taken before this epoch; incremented per optimizer step
            (the scheduler is stepped once per optimizer step as well).
        grad_accum_steps: batches accumulated per optimizer step; values < 1 act as 1.
        grad_clip: max gradient L2 norm; ``None`` or a value <= 0 disables clipping
            (clipping to 0 would otherwise zero every gradient).
        log_interval: log every ``log_interval`` optimizer steps; <= 0 disables periodic logs.
        amp_enabled: run forward/loss under ``autocast`` with gradient scaling.

    Returns:
        ``(global_step, mean_train_loss)``, the loss being the sample-weighted mean of the
        un-scaled per-batch criterion values.
    """
    model.train()
    loss_meter = AverageMeter()
    optimizer.zero_grad(set_to_none=True)
    grad_accum_steps = max(1, grad_accum_steps)
    clip_enabled = grad_clip is not None and grad_clip > 0
    steps_in_accum = 0
    skipped_batches = 0
    start_time = time.time()

    def _optimizer_step() -> None:
        """Apply the accumulated gradients (unscale, clip, step, update scaler and LR)."""
        nonlocal global_step, steps_in_accum
        if steps_in_accum == 0:
            return
        scaler.unscale_(optimizer)
        if clip_enabled:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()
        global_step += 1
        steps_in_accum = 0

    for batch in dataloader:
        if batch is None:
            continue
        features, input_lengths, targets, target_lengths, _ = batch
        features = features.to(device)
        input_lengths = input_lengths.to(device)
        targets = targets.to(device)
        target_lengths = target_lengths.to(device)
        with autocast(enabled=amp_enabled):
            logits, logit_lengths = model(features, input_lengths)
        # CTC cannot align a target longer than the output sequence; skip such batches.
        if torch.any(target_lengths > logit_lengths):
            skipped_batches += 1
            continue
        with autocast(enabled=amp_enabled):
            # nn.CTCLoss wants (T, N, C) log-probabilities; assumes the model returns
            # (N, T, C) log-probs — confirm for the imported Conformer.
            loss = criterion(logits.transpose(0, 1), targets, logit_lengths, target_lengths)
            loss = loss / grad_accum_steps
        scaler.scale(loss).backward()
        steps_in_accum += 1
        loss_meter.update(loss.item() * grad_accum_steps, n=features.size(0))
        if steps_in_accum == grad_accum_steps:
            _optimizer_step()
            if log_interval > 0 and global_step % log_interval == 0:
                elapsed = time.time() - start_time
                current_lr = optimizer.param_groups[0]["lr"]
                logging.info(
                    f"Epoch {epoch:02d} | step {global_step} | loss {loss_meter.avg:.4f} | lr {current_lr:.2e} | {elapsed:.1f}s"
                )
                start_time = time.time()

    if steps_in_accum > 0:
        _optimizer_step()
    if skipped_batches:
        logging.info(f"Epoch {epoch:02d} skipped {skipped_batches} batches due to insufficient subsampled frames.")
    return global_step, loss_meter.avg


def evaluate(model: nn.Module, dataloader: DataLoader, criterion, device: torch.device, amp_enabled: bool) -> float:
    """Return the sample-weighted mean validation loss of ``model`` over ``dataloader``.

    Batches the collator rejected (``None``) and batches in which any target is longer
    than the model's output sequence are skipped. If no batch can be scored the function
    logs a warning and returns ``float("inf")`` rather than 0.0, so an empty validation
    pass can never be recorded as the best result. Leaves the model in eval mode.
    """
    model.eval()
    loss_meter = AverageMeter()
    skipped_batches = 0
    with torch.no_grad():
        for batch in dataloader:
            if batch is None:
                continue
            features, input_lengths, targets, target_lengths, _ = batch
            features = features.to(device)
            input_lengths = input_lengths.to(device)
            targets = targets.to(device)
            target_lengths = target_lengths.to(device)
            with autocast(enabled=amp_enabled):
                logits, logit_lengths = model(features, input_lengths)
                if torch.any(target_lengths > logit_lengths):
                    skipped_batches += 1
                    continue
                loss = criterion(logits.transpose(0, 1), targets, logit_lengths, target_lengths)
            loss_meter.update(loss.item(), n=features.size(0))
    if skipped_batches:
        logging.info(f"Validation skipped {skipped_batches} batches due to insufficient subsampled frames.")
    if loss_meter.count == 0:
        logging.warning("Validation scored no batches; reporting val loss as inf.")
        return float("inf")
    return loss_meter.avg

In [None]:
def build_dataloader(dataset: Dataset, collate_fn, loader_cfg: Dict, shuffle: bool) -> DataLoader:
    """Wrap ``dataset`` in a DataLoader configured from the ``dataloader`` config section.

    ``persistent_workers`` and ``prefetch_factor`` are only honoured when worker
    processes are used, because DataLoader rejects them with ``num_workers == 0``.
    The final partial batch is kept (``drop_last=False``).
    """
    workers = loader_cfg["num_workers"]
    options = dict(
        batch_size=loader_cfg["batch_size"],
        num_workers=workers,
        pin_memory=loader_cfg["pin_memory"],
        persistent_workers=loader_cfg["persistent_workers"] and workers > 0,
        collate_fn=collate_fn,
        drop_last=False,
    )
    prefetch = loader_cfg.get("prefetch_factor")
    if workers > 0 and prefetch:
        options["prefetch_factor"] = prefetch
    return DataLoader(dataset, shuffle=shuffle, **options)


def build_dataset(cfg: Dict, tokenizer: TextTokenizer, feature_extractor: LogMelFeatureExtractor, split: str, apply_augment: bool, augment: Optional[SpecAugment]):
    """Create the dataset for ``split`` from whichever source ``cfg["data"]`` configures.

    A non-empty ``hf_dataset`` section takes precedence and receives the fixed-length
    settings (``max_audio_seconds``, ``pad_to_max_seconds``). Otherwise the
    ``train_manifest`` is used for ``split == "train"`` and ``valid_manifest`` for
    any other split.

    Raises:
        ValueError: neither a Hugging Face dataset nor the required manifest is configured.
    """
    data_cfg = cfg["data"]
    hf_cfg = data_cfg.get("hf_dataset")
    if not hf_cfg:
        manifest_path = data_cfg.get("train_manifest" if split == "train" else "valid_manifest")
        if not manifest_path:
            raise ValueError("No dataset source configured. Provide Hugging Face settings or manifest paths.")
        return ManifestSpeechDataset(
            manifest_path,
            tokenizer,
            feature_extractor,
            data_cfg["sample_rate"],
            apply_augment=apply_augment,
            augment=augment,
        )
    return HuggingFaceSpeechDataset(
        hf_cfg,
        split=split,
        tokenizer=tokenizer,
        feature_extractor=feature_extractor,
        sample_rate=data_cfg["sample_rate"],
        apply_augment=apply_augment,
        augment=augment,
        target_seconds=data_cfg.get("max_audio_seconds"),
        pad_to_target=data_cfg.get("pad_to_max_seconds", False),
    )


def format_hours(hours: float) -> str:
    """Render an hour count such as ``~1.50h``; missing or non-positive values become ``n/a``."""
    return f"~{hours:.2f}h" if hours and hours > 0 else "n/a"


def run_training(cfg: Dict) -> Dict:
    """Build every Stage-1 component from ``cfg`` and train the Conformer with CTC loss.

    Args:
        cfg: Nested configuration with the sections ``data``, ``feature_extractor``,
            ``tokenizer``, ``augmentation``, ``model``, ``dataloader``, ``optim``,
            ``scheduler`` and ``trainer`` plus a top-level ``device`` string.

    Validation and checkpointing happen every ``val_interval`` epochs and always after
    the final epoch, so the last stretch of training is never lost.

    Returns:
        ``{"best_val_loss", "best_checkpoint", "global_step"}``: the best validation loss
        (``inf`` if never validated), the path of the checkpoint that achieved it (or
        None), and the total number of optimizer steps including resumed ones.
    """
    device = torch.device(cfg["device"])
    data_cfg = cfg["data"]
    model_cfg = cfg["model"]
    trainer_cfg = cfg["trainer"]
    # "hf_dataset" may be None when training from manifests; `.get(key, {})` would return
    # that None and crash on the chained `.get`, so normalise it to an empty mapping.
    hf_cfg = data_cfg.get("hf_dataset") or {}

    tokenizer_cfg = dict(cfg["tokenizer"])
    tokenizer_cfg.setdefault("lowercase", data_cfg.get("text_lowercase", True))
    tokenizer = TextTokenizer(tokenizer_cfg)
    feature_kwargs = dict(cfg["feature_extractor"])
    train_extractor = LogMelFeatureExtractor(sample_rate=data_cfg["sample_rate"], **feature_kwargs)
    valid_extractor = LogMelFeatureExtractor(sample_rate=data_cfg["sample_rate"], **feature_kwargs)
    augment = SpecAugment(**cfg["augmentation"]) if cfg["augmentation"].get("prob", 0.0) > 0 else None
    train_dataset = build_dataset(
        cfg,
        tokenizer,
        train_extractor,
        split=hf_cfg.get("train_split", "train"),
        apply_augment=True,
        augment=augment,
    )
    valid_dataset = build_dataset(
        cfg,
        tokenizer,
        valid_extractor,
        split=hf_cfg.get("valid_split", "validation"),
        apply_augment=False,
        augment=None,
    )
    # NOTE(review): subsampling_factor is not passed to Conformer; it only drives the
    # collator's padding/length filter and the frame-rate log below. Verify it matches
    # the encoder's real time reduction.
    subsampling_factor = max(1, model_cfg.get("subsampling_factor", 1))
    min_subsample_len_multiplier = model_cfg.get("min_subsample_len_multiplier", 1)
    collate_fn = SpeechDataCollator(
        pad_to_multiple_of=subsampling_factor,
        subsampling_factor=subsampling_factor,
        min_subsample_len_multiplier=min_subsample_len_multiplier,
    )
    train_loader = build_dataloader(train_dataset, collate_fn, cfg["dataloader"], shuffle=cfg["dataloader"].get("shuffle", True))
    valid_loader = build_dataloader(valid_dataset, collate_fn, cfg["dataloader"], shuffle=False)

    hours_train = format_hours(getattr(train_dataset, "total_hours", 0.0))
    hours_valid = format_hours(getattr(valid_dataset, "total_hours", 0.0))
    frame_ms = cfg["feature_extractor"].get("hop_length", 160) / data_cfg["sample_rate"] * 1000
    effective_stride = frame_ms * subsampling_factor
    logging.info(
        f"Train set: {len(train_dataset)} utterances ({hours_train}), "
        f"Valid set: {len(valid_dataset)} utterances ({hours_valid})"
    )
    logging.info(
        f"Subsampling factor {subsampling_factor} ⇒ encoder frame rate ≈ {effective_stride:.1f} ms"
    )

    num_classes = tokenizer.vocab_size
    model = Conformer(
        num_classes=num_classes,
        input_dim=model_cfg["input_dim"],
        encoder_dim=model_cfg["encoder_dim"],
        num_encoder_layers=model_cfg["num_layers"],
        num_attention_heads=model_cfg["num_attention_heads"],
        feed_forward_expansion_factor=model_cfg["feed_forward_expansion_factor"],
        conv_expansion_factor=model_cfg["conv_expansion_factor"],
        conv_kernel_size=model_cfg["conv_kernel_size"],
        input_dropout_p=model_cfg["dropout"],
        feed_forward_dropout_p=model_cfg["dropout"],
        attention_dropout_p=model_cfg["dropout"],
        conv_dropout_p=model_cfg["dropout"],
    ).to(device)
    logging.info(f"Conformer parameters: {model.count_parameters():,}")

    criterion = nn.CTCLoss(blank=tokenizer.blank_id, zero_infinity=True)
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=cfg["optim"]["peak_lr"],
        betas=cfg["optim"]["betas"],
        eps=cfg["optim"]["eps"],
        weight_decay=cfg["optim"]["weight_decay"],
    )
    scheduler = WarmupExponentialDecayScheduler(
        optimizer,
        warmup_steps=cfg["scheduler"]["warmup_steps"],
        total_steps=cfg["scheduler"]["total_steps"],
        final_lr_scale=cfg["scheduler"]["final_lr_scale"],
    )
    amp_enabled = bool(trainer_cfg["use_amp"] and torch.cuda.is_available())
    scaler = GradScaler(enabled=amp_enabled)

    start_epoch = 1
    global_step = 0
    best_val = float("inf")
    best_path: Optional[Path] = None
    resume_path = trainer_cfg.get("resume_from")
    if resume_path:
        ckpt = torch.load(resume_path, map_location=device)
        model.load_state_dict(ckpt["model_state"])
        optimizer.load_state_dict(ckpt["optim_state"])
        scheduler.load_state_dict(ckpt["scheduler_state"])
        if "scaler_state" in ckpt and amp_enabled and ckpt["scaler_state"] is not None:
            scaler.load_state_dict(ckpt["scaler_state"])
        start_epoch = ckpt["epoch"] + 1
        global_step = ckpt.get("global_step", 0)
        best_val = ckpt.get("best_val", best_val)
        # The resumed file is only the best checkpoint if it scored the best loss so far.
        if ckpt.get("val_loss") == best_val:
            best_path = Path(resume_path)
        logging.info(f"Resumed from {resume_path} (epoch {ckpt['epoch']})")

    num_epochs = trainer_cfg["num_epochs"]
    val_interval = trainer_cfg["val_interval"]
    for epoch in range(start_epoch, num_epochs + 1):
        global_step, train_loss = train_one_epoch(
            model,
            train_loader,
            criterion,
            optimizer,
            scheduler,
            scaler,
            device,
            epoch,
            global_step,
            cfg["optim"]["grad_accum_steps"],
            trainer_cfg["grad_clip"],
            trainer_cfg["log_interval"],
            amp_enabled,
        )
        # Always validate the final epoch so its weights get checkpointed; a non-positive
        # val_interval means "validate only at the end" instead of dividing by zero.
        should_validate = (val_interval > 0 and epoch % val_interval == 0) or epoch == num_epochs
        if should_validate:
            val_loss = evaluate(model, valid_loader, criterion, device, amp_enabled)
            improved = val_loss < best_val
            if improved:
                best_val = val_loss
            ckpt_state = {
                "epoch": epoch,
                "global_step": global_step,
                "val_loss": float(val_loss),
                "best_val": float(best_val),
                "model_state": model.state_dict(),
                "optim_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "scaler_state": scaler.state_dict() if amp_enabled else None,
                "config": cfg,
            }
            ckpt_path = save_checkpoint(ckpt_state, trainer_cfg["checkpoint_dir"], trainer_cfg["max_to_keep"])
            if improved:
                best_path = ckpt_path
            logging.info(
                f"Epoch {epoch:02d} | train loss {train_loss:.4f} | val loss {val_loss:.4f} | best {best_val:.4f}"
            )
        else:
            logging.info(f"Epoch {epoch:02d} | train loss {train_loss:.4f}")
    if best_path is not None and not best_path.exists():
        logging.warning(f"Best checkpoint {best_path} no longer exists (possibly pruned by max_to_keep).")
    return {"best_val_loss": best_val, "best_checkpoint": str(best_path) if best_path else None, "global_step": global_step}

In [None]:
# Run Stage-1 training. Failures keep the notebook kernel usable, but are logged with their
# full traceback instead of being reduced to a one-line message.
try:
    trainer_state = run_training(config)
    print(trainer_state)
except Exception as exc:
    logging.exception(f"Training loop aborted: {exc}")