# Single Config Experiment Runner

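Use this notebook to run exactly one experiment, whose definition (`EXPERIMENT_7`) is declared inline in the configuration cell further below.
It builds the dataloaders and model from that config and trains with early stopping on validation loss. After each epoch it prints the train/validation loss plus validation CER/WER. The best checkpoint (lowest validation loss) is then restored and evaluated on the held-out test split.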


In [1]:
import json
from dataclasses import replace
from pathlib import Path

import torch
from torch.utils.data import DataLoader
import copy

from config import Config
from src.dataset import JavaneseASRDataset, collate_fn
from src.features import LogMelFeatureExtractor
from src.vocab import Vocabulary
from src.decoder import GreedyDecoder
from src.utils import set_seed, read_transcript, count_parameters
from src.data_split import create_speaker_disjoint_split, load_split_info
from src.model import Seq2SeqASR
from scripts.train import train_one_epoch, validate_with_metrics


In [None]:
# Declare the single experiment directly here
EXPERIMENT_7 = {
    "name": "Inline: Char + CTC Joint",
    "description": "Character vocab with joint CTC-attention for alignment help.",
    "config": {
        "token_type": "char",
        "encoder_type": "pyramidal",
        "decoder_type": "lstm",
        "learning_rate": 1e-3,
        "num_epochs": 25,
        "use_ctc": True,
        "ctc_weight": 0.7,
        "encoder_hidden_size": 320,
        "decoder_dim": 640,
        "batch_size": 4
    },
}

# Override options for quick tweaks
MAX_EPOCHS = None       # set an int to cap runtime (e.g., 2 for a smoke test)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

base_config = Config()
# Every key in EXPERIMENT_7["config"] must be a Config field: dataclasses.replace
# rejects unknown keyword arguments, so a typo fails here instead of silently
# falling back to the default value.
current_config = replace(base_config, **EXPERIMENT_7["config"])
# Compare against None explicitly: a plain truthiness test would silently ignore
# an explicit MAX_EPOCHS = 0 and run the full schedule instead.
if MAX_EPOCHS is not None:
    current_config = replace(current_config, num_epochs=int(MAX_EPOCHS))
current_config = replace(current_config, device=DEVICE)

print(f"Selected experiment: {EXPERIMENT_7['name']}")
print(f"Device: {current_config.device}")
print(f"Epochs: {current_config.num_epochs}, batch_size: {current_config.batch_size}, lr: {current_config.learning_rate}")


Selected experiment: Inline: Char + CTC Joint
Device: cuda
Epochs: 25, batch_size: 4, lr: 0.001


In [3]:
def build_dataloaders(cfg: Config, num_workers: int = 4):
    """Build the vocabulary and the train/validation DataLoaders for ``cfg``.

    The speaker-disjoint split is loaded from ``cfg.split_info_path`` when that
    file exists. Otherwise it is created from ``cfg.transcript_file`` with
    ``cfg.seed`` and saved to that path.

    Args:
        cfg: Experiment configuration.
        num_workers: Worker processes per DataLoader (default 4, the value that
            was previously hard-coded).

    Returns:
        Tuple ``(vocab, train_loader, val_loader)``. Only the train loader
        shuffles, and SpecAugment is applied to train only (when
        ``cfg.apply_spec_augment`` is set).
    """
    transcripts = read_transcript(cfg.transcript_file)
    # NOTE(review): the vocabulary is built from the whole transcript file, i.e.
    # before train/val/test filtering. For char tokens this is mostly harmless,
    # but for word/subword tokens it would leak test vocabulary. Confirm intent.
    vocab = Vocabulary(token_type=cfg.token_type)
    vocab.build_from_transcripts(transcripts, min_freq=1)

    feature_extractor = LogMelFeatureExtractor(
        sample_rate=cfg.sample_rate,
        n_mels=cfg.n_mels
    )

    split_info_path = Path(cfg.split_info_path)
    if split_info_path.exists():
        split_dict = load_split_info(str(split_info_path))["split"]
    else:
        split_dict = create_speaker_disjoint_split(
            transcript_file=cfg.transcript_file,
            seed=cfg.seed,
            save_split_info=True,
            split_info_path=str(split_info_path)
        )

    def make_dataset(utt_ids, spec_augment):
        # Both splits share the vocab and feature extractor; only the utterance
        # filter and augmentation differ.
        return JavaneseASRDataset(
            audio_dir=cfg.audio_dir,
            transcript_file=cfg.transcript_file,
            vocab=vocab,
            feature_extractor=feature_extractor,
            apply_spec_augment=spec_augment,
            utt_id_filter=utt_ids
        )

    # Pinned host memory only speeds up host->CUDA copies, so follow the
    # configured device rather than mere CUDA availability.
    use_pin_memory = torch.cuda.is_available() and str(cfg.device).startswith("cuda")

    def make_loader(dataset, shuffle):
        return DataLoader(
            dataset,
            batch_size=cfg.batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn,
            num_workers=num_workers,
            pin_memory=use_pin_memory,
        )

    train_dataset = make_dataset(split_dict["train"], cfg.apply_spec_augment)
    val_dataset = make_dataset(split_dict["val"], False)

    train_loader = make_loader(train_dataset, shuffle=True)
    val_loader = make_loader(val_dataset, shuffle=False)
    return vocab, train_loader, val_loader


In [4]:
set_seed(current_config.seed)

vocab, train_loader, val_loader = build_dataloaders(current_config)
print(f"Vocabulary size: {len(vocab)}")
print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

model = Seq2SeqASR(
    vocab_size=len(vocab),
    input_dim=current_config.input_dim,
    encoder_hidden_size=current_config.encoder_hidden_size,
    encoder_num_layers=current_config.encoder_num_layers,
    decoder_dim=current_config.decoder_dim,
    attention_dim=current_config.attention_dim,
    embedding_dim=current_config.embedding_dim,
    dropout=current_config.dropout,
    use_ctc=current_config.use_ctc,
    ctc_weight=current_config.ctc_weight,
    encoder_type=current_config.encoder_type,
    decoder_type=current_config.decoder_type,
).to(current_config.device)
print(f"Model parameters: {count_parameters(model):,}")

optimizer = torch.optim.Adam(model.parameters(), lr=current_config.learning_rate)
# Passed to validate_with_metrics, which presumably uses it to decode the
# hypotheses behind val CER/WER.
decoder = GreedyDecoder(model, vocab, max_len=current_config.max_decode_len, device=current_config.device)

# Per-epoch history; entry i belongs to epoch i + 1.
train_losses, val_losses, val_cers, val_wers = [], [], [], []

# Early stopping on val_loss. An epoch counts as an improvement only if it beats
# the best so far by more than `min_delta`. Training stops after `patience`
# consecutive non-improving epochs.
patience = 5  # early stopping rounds
min_delta = 1e-4
bad_epochs = 0
best_val_loss = float("inf")
best_state = None
best_epoch = 0

for epoch in range(1, current_config.num_epochs + 1):
    train_loss = train_one_epoch(
        model,
        train_loader,
        optimizer,
        vocab,
        current_config.device,
        epoch,
        current_config.grad_clip_norm,
        encoder_type=current_config.encoder_type,
    )
    val_loss, val_cer, val_wer, _, _ = validate_with_metrics(
        model,
        val_loader,
        decoder,
        vocab,
        current_config.device,
        encoder_type=current_config.encoder_type,
    )

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    val_cers.append(val_cer)
    val_wers.append(val_wer)

    print(
        f"Epoch {epoch}/{current_config.num_epochs} - "
        f"train_loss: {train_loss:.4f} | val_loss: {val_loss:.4f} | "
        f"val_cer: {val_cer:.4f} | val_wer: {val_wer:.4f}"
    )

    if val_loss < best_val_loss - min_delta:
        best_val_loss = val_loss
        # Snapshot the weights in host memory. A deepcopy of a CUDA state_dict
        # would keep a second full copy of the model resident on the GPU for the
        # rest of the run.
        # copy.copy keeps the state_dict's own container (including any
        # `_metadata` attribute PyTorch attaches); only tensor values are
        # replaced with independent CPU copies.
        best_state = copy.copy(model.state_dict())
        for key, value in list(best_state.items()):
            if torch.is_tensor(value):
                best_state[key] = value.detach().to("cpu", copy=True)
        best_epoch = epoch
        bad_epochs = 0
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            print(f"Early stopping at epoch {epoch} (no val_loss improvement for {patience} epochs)")
            break

if best_state is not None:
    # load_state_dict copies the CPU snapshot back into the on-device parameters.
    model.load_state_dict(best_state)
    print(f"Restored best model from epoch {best_epoch} (val_loss={best_val_loss:.4f})")
else:
    print("No improvement tracked; using last epoch model")

Random seed set to 42
Built char-level vocabulary with 34 tokens
Audio file not found for utterance speaker46_f_nn_utt20
Audio file not found for utterance speaker46_f_nn_utt21
Audio file not found for utterance speaker46_f_nn_utt22
Audio file not found for utterance speaker46_f_nn_utt23
Audio file not found for utterance speaker46_f_nn_utt24
Audio file not found for utterance speaker46_f_nn_utt25
Audio file not found for utterance speaker46_f_nn_utt26
Audio file not found for utterance speaker46_f_nn_utt27
Audio file not found for utterance speaker46_f_nn_utt28
Audio file not found for utterance speaker46_f_nn_utt29
Filtered dataset: 2090 -> 1470 utterances
Validating audio files...
Loaded 1470 valid utterances from data/transcripts.csv
Audio file not found for utterance speaker46_f_nn_utt20
Audio file not found for utterance speaker46_f_nn_utt21
Audio file not found for utterance speaker46_f_nn_utt22
Audio file not found for utterance speaker46_f_nn_utt23
Audio file not found for utt

Epoch 1 [Train]: 100%|██████████| 368/368 [01:26<00:00,  4.25it/s, loss=1.8094]


Epoch 1/25 - train_loss: 2.4681 | val_loss: 2.7336 | val_cer: 4.9333 | val_wer: 4.2683


Epoch 2 [Train]: 100%|██████████| 368/368 [01:25<00:00,  4.29it/s, loss=1.7274]


Epoch 2/25 - train_loss: 1.8659 | val_loss: 2.3956 | val_cer: 3.5341 | val_wer: 3.2949


Epoch 3 [Train]: 100%|██████████| 368/368 [01:25<00:00,  4.30it/s, loss=1.2165]


Epoch 3/25 - train_loss: 1.6220 | val_loss: 2.2574 | val_cer: 4.1821 | val_wer: 4.1608


Epoch 4 [Train]: 100%|██████████| 368/368 [01:23<00:00,  4.41it/s, loss=1.2465]


Epoch 4/25 - train_loss: 1.4651 | val_loss: 2.0777 | val_cer: 1.2663 | val_wer: 1.7711


Epoch 5 [Train]: 100%|██████████| 368/368 [01:26<00:00,  4.27it/s, loss=1.4364]


Epoch 5/25 - train_loss: 1.3364 | val_loss: 1.9703 | val_cer: 0.9361 | val_wer: 1.3638


Epoch 6 [Train]: 100%|██████████| 368/368 [01:25<00:00,  4.33it/s, loss=1.4745]


Epoch 6/25 - train_loss: 1.2357 | val_loss: 1.9170 | val_cer: 0.8874 | val_wer: 1.3385


Epoch 7 [Train]: 100%|██████████| 368/368 [01:26<00:00,  4.25it/s, loss=1.0741]


Epoch 7/25 - train_loss: 1.1411 | val_loss: 1.9397 | val_cer: 0.7509 | val_wer: 1.2647


Epoch 8 [Train]: 100%|██████████| 368/368 [01:26<00:00,  4.25it/s, loss=0.9412]


Epoch 8/25 - train_loss: 1.0952 | val_loss: 1.9386 | val_cer: 0.6466 | val_wer: 0.9930


Epoch 9 [Train]: 100%|██████████| 368/368 [01:25<00:00,  4.30it/s, loss=0.4880]


Epoch 9/25 - train_loss: 1.0199 | val_loss: 1.8038 | val_cer: 0.5257 | val_wer: 0.9298


Epoch 10 [Train]: 100%|██████████| 368/368 [01:26<00:00,  4.24it/s, loss=0.9531]


Epoch 10/25 - train_loss: 0.9830 | val_loss: 1.8081 | val_cer: 0.4550 | val_wer: 0.8708


Epoch 11 [Train]: 100%|██████████| 368/368 [01:26<00:00,  4.25it/s, loss=0.9540]


Epoch 11/25 - train_loss: 0.9209 | val_loss: 1.7912 | val_cer: 0.6759 | val_wer: 1.1313


Epoch 12 [Train]: 100%|██████████| 368/368 [01:26<00:00,  4.27it/s, loss=1.1613]


Epoch 12/25 - train_loss: 0.8957 | val_loss: 1.8060 | val_cer: 0.4909 | val_wer: 0.8904


Epoch 13 [Train]: 100%|██████████| 368/368 [01:27<00:00,  4.22it/s, loss=0.7237]


Epoch 13/25 - train_loss: 0.8497 | val_loss: 1.8332 | val_cer: 0.5959 | val_wer: 1.0534


Epoch 14 [Train]: 100%|██████████| 368/368 [01:26<00:00,  4.25it/s, loss=1.2706]


Epoch 14/25 - train_loss: 0.8244 | val_loss: 1.7570 | val_cer: 0.4111 | val_wer: 0.7781


Epoch 15 [Train]: 100%|██████████| 368/368 [01:27<00:00,  4.23it/s, loss=1.5721]


Epoch 15/25 - train_loss: 0.7999 | val_loss: 1.7672 | val_cer: 0.3914 | val_wer: 0.7430


Epoch 16 [Train]: 100%|██████████| 368/368 [01:25<00:00,  4.31it/s, loss=0.6623]


Epoch 16/25 - train_loss: 0.7886 | val_loss: 1.7442 | val_cer: 0.4410 | val_wer: 0.8553


Epoch 17 [Train]: 100%|██████████| 368/368 [01:27<00:00,  4.21it/s, loss=0.6016]


Epoch 17/25 - train_loss: 0.7448 | val_loss: 1.7339 | val_cer: 0.3663 | val_wer: 0.7669


Epoch 18 [Train]: 100%|██████████| 368/368 [01:26<00:00,  4.27it/s, loss=0.9033]


Epoch 18/25 - train_loss: 0.7295 | val_loss: 1.7637 | val_cer: 0.4633 | val_wer: 0.8483


Epoch 19 [Train]: 100%|██████████| 368/368 [01:25<00:00,  4.32it/s, loss=0.6882]


Epoch 19/25 - train_loss: 0.7332 | val_loss: 1.7460 | val_cer: 0.4239 | val_wer: 0.8167


Epoch 20 [Train]: 100%|██████████| 368/368 [01:27<00:00,  4.23it/s, loss=0.3103]


Epoch 20/25 - train_loss: 0.7051 | val_loss: 1.7338 | val_cer: 0.4203 | val_wer: 0.7690


Epoch 21 [Train]: 100%|██████████| 368/368 [01:25<00:00,  4.29it/s, loss=0.2579]


Epoch 21/25 - train_loss: 0.6855 | val_loss: 1.7046 | val_cer: 0.3712 | val_wer: 0.7577


Epoch 22 [Train]: 100%|██████████| 368/368 [01:27<00:00,  4.22it/s, loss=0.2504]


Epoch 22/25 - train_loss: 0.6825 | val_loss: 1.7055 | val_cer: 0.3070 | val_wer: 0.6721


Epoch 23 [Train]: 100%|██████████| 368/368 [01:26<00:00,  4.24it/s, loss=0.6784]


Epoch 23/25 - train_loss: 0.6658 | val_loss: 1.7032 | val_cer: 0.4245 | val_wer: 0.8497


Epoch 24 [Train]: 100%|██████████| 368/368 [01:26<00:00,  4.25it/s, loss=1.4464]


Epoch 24/25 - train_loss: 0.6735 | val_loss: 1.6569 | val_cer: 0.2725 | val_wer: 0.6320


Epoch 25 [Train]: 100%|██████████| 368/368 [01:26<00:00,  4.25it/s, loss=0.8303]


Epoch 25/25 - train_loss: 0.6647 | val_loss: 1.7126 | val_cer: 0.2957 | val_wer: 0.6735
Restored best model from epoch 24 (val_loss=1.6569)


In [5]:
print("Finished.")
if train_losses:
    print(f"Final train_loss: {train_losses[-1]:.4f}")
if val_losses:
    print(f"Final val_loss: {val_losses[-1]:.4f}")
if val_wers:
    best_wer = min(val_wers)
    print(f"Best val WER: {best_wer:.4f}")
# The "Final" numbers above are from the last epoch and "Best val WER" is the
# minimum over all epochs. Neither is guaranteed to describe the model now in
# memory, which the training cell restored to the lowest-val_loss epoch.
# Report that epoch explicitly so the test results below can be matched to it.
if best_epoch and len(val_losses) >= best_epoch and len(val_wers) >= best_epoch:
    print(
        f"Restored model (epoch {best_epoch}): "
        f"val_loss: {val_losses[best_epoch - 1]:.4f} | val_wer: {val_wers[best_epoch - 1]:.4f}"
    )


Finished.
Final train_loss: 0.6647
Final val_loss: 1.7126
Best val WER: 0.6320


In [6]:
import random
import jiwer
from src.metrics import compute_batch_cer

def build_test_loader(cfg: Config, vocab: Vocabulary, num_workers: int = 4):
    """Build the held-out test dataset and a non-shuffled DataLoader over it.

    Test utterance ids are read from the split file at ``cfg.split_info_path``,
    the same file ``build_dataloaders`` loads or creates. ``vocab`` should be
    the training vocabulary so target token ids match the model.

    Args:
        cfg: Experiment configuration.
        vocab: Vocabulary used to encode targets.
        num_workers: DataLoader worker processes (default 4, the value that was
            previously hard-coded).

    Returns:
        Tuple ``(test_dataset, test_loader)``.

    Raises:
        FileNotFoundError: If the split info file has not been created yet.
        ValueError: If the split info contains no test ids.
    """
    split_info_path = Path(cfg.split_info_path)
    if not split_info_path.exists():
        raise FileNotFoundError(
            f"Split info not found at {split_info_path}; run build_dataloaders first to create it."
        )
    split_info = load_split_info(str(split_info_path))
    test_ids = split_info.get("split", {}).get("test", [])
    if not test_ids:
        raise ValueError("No test IDs found in split info; regenerate splits first.")

    feature_extractor = LogMelFeatureExtractor(
        sample_rate=cfg.sample_rate,
        n_mels=cfg.n_mels
    )

    # No augmentation at evaluation time.
    test_dataset = JavaneseASRDataset(
        audio_dir=cfg.audio_dir,
        transcript_file=cfg.transcript_file,
        vocab=vocab,
        feature_extractor=feature_extractor,
        apply_spec_augment=False,
        utt_id_filter=test_ids,
    )

    # Pinned host memory only speeds up host->CUDA copies, so follow the
    # configured device rather than mere CUDA availability.
    use_pin_memory = torch.cuda.is_available() and str(cfg.device).startswith("cuda")
    test_loader = DataLoader(
        test_dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=use_pin_memory,
    )
    return test_dataset, test_loader

test_dataset, test_loader = build_test_loader(current_config, vocab)
# Greedy decoder bound to the model as restored by the training cell.
decoder_eval = GreedyDecoder(model, vocab, max_len=current_config.max_decode_len, device=current_config.device)

print(f"Test utterances: {len(test_dataset)}; batches: {len(test_loader)}")

model.eval()
all_refs, all_hyps = [], []
total_loss = 0.0
total_cer = 0.0
total_samples = 0
num_batches = 0

with torch.no_grad():
    for batch in test_loader:
        features = batch["features"].to(current_config.device)
        feature_lengths = batch["feature_lengths"].to(current_config.device)
        targets = batch["targets"].to(current_config.device)
        target_lengths = batch["target_lengths"].to(current_config.device)
        transcripts = batch["transcripts"]

        # teacher_forcing_ratio=0.0 presumably makes the decoder consume its own
        # predictions. NOTE(review): confirm validate_with_metrics uses the same
        # ratio, otherwise test avg_loss is not comparable to val_loss.
        attention_logits, ctc_logits = model(features, feature_lengths, targets, teacher_forcing_ratio=0.0)
        # NOTE(review): assumes the pyramidal encoder shortens the time axis by
        # exactly 4x and other encoder types not at all. This must match
        # Seq2SeqASR (and, presumably, the computation inside
        # validate_with_metrics). Confirm.
        encoder_lengths = feature_lengths // 4 if current_config.encoder_type == "pyramidal" else feature_lengths
        loss = model.compute_loss(
            attention_logits=attention_logits,
            targets=targets,
            target_lengths=target_lengths,
            ctc_logits=ctc_logits,
            encoder_lengths=encoder_lengths,
            pad_idx=vocab.pad_idx,
            blank_idx=vocab.blank_idx,
        )

        total_loss += loss.item()
        num_batches += 1

        hyps = decoder_eval.decode(features, feature_lengths)
        # compute_batch_cer presumably returns the batch-mean CER; weighting by
        # batch size makes the final average per utterance.
        cer = compute_batch_cer(transcripts, hyps)
        total_cer += cer * len(transcripts)
        total_samples += len(transcripts)
        all_refs.extend(transcripts)
        all_hyps.extend(hyps)

# Loss is the mean of per-batch losses, CER is utterance-weighted, and WER is
# corpus-level (jiwer over all references/hypotheses at once).
avg_loss = total_loss / num_batches if num_batches else 0.0
avg_cer = total_cer / total_samples if total_samples else 0.0
avg_wer = jiwer.wer(all_refs, all_hyps) if all_refs else 0.0

print(f"Test avg_loss: {avg_loss:.4f} | avg_cer: {avg_cer:.4f} | avg_wer: {avg_wer:.4f}")

# Randomly sample 5 test utterances for inspection. A dedicated RNG seeded from
# the config keeps the selection identical across runs, regardless of how much
# of the global `random` state training consumed.
sample_rng = random.Random(current_config.seed)
n_show = min(5, len(test_dataset))
sample_indices = sample_rng.sample(range(len(test_dataset)), n_show)
print("Random sample of test predictions:")
with torch.no_grad():
    for idx in sample_indices:
        feats, tgt, transcript, utt_id = test_dataset[idx]
        # Decode a batch of one utterance; the length tensor matches that batch.
        feat_len = torch.tensor([feats.size(0)], dtype=torch.long)
        hyp = decoder_eval.decode(
            feats.unsqueeze(0).to(current_config.device),
            feat_len.to(current_config.device)
        )[0]
        print(f"[{utt_id}]REF: {transcript} HYP: {hyp}")

Audio file not found for utterance speaker46_f_nn_utt20
Audio file not found for utterance speaker46_f_nn_utt21
Audio file not found for utterance speaker46_f_nn_utt22
Audio file not found for utterance speaker46_f_nn_utt23
Audio file not found for utterance speaker46_f_nn_utt24
Audio file not found for utterance speaker46_f_nn_utt25
Audio file not found for utterance speaker46_f_nn_utt26
Audio file not found for utterance speaker46_f_nn_utt27
Audio file not found for utterance speaker46_f_nn_utt28
Audio file not found for utterance speaker46_f_nn_utt29
Filtered dataset: 2090 -> 410 utterances
Validating audio files...
Loaded 410 valid utterances from data/transcripts.csv
Test utterances: 410; batches: 103
Test avg_loss: 1.9075 | avg_cer: 0.3091 | avg_wer: 0.6773
Random sample of test predictions:
[speaker55_m_nn_utt08]REF: aku pengin turu rumiyin HYP: aku pengin turu pumiyin
[speaker07_m_n_utt28]REF: preinan sesuk aku pengin dolan ning grojogan sewu HYP: rakine aku pengin tolan neng k