# Fine-tune Whisper Small on TORGO with LoRA

This notebook fine-tunes **Whisper Small** on the preprocessed TORGO dataset using **LoRA** (Low-Rank Adaptation).

It loads the processed `.wav` files from `audio/torgo/processed/` along with the companion `metadata.json` for transcriptions and speaker status.

**Prerequisites:**
1. Run `audio/data_loader.ipynb` to download the TORGO dataset
2. Run `audio/torgo_preprocessing.ipynb` to generate processed audio + metadata
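3. Run this notebook from the **project root** (or any directory below it — the first cell walks up the directory tree until it finds `asr/config.yaml`)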

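**Hardware note:** Configured for Apple Silicon (M4 MacBook Pro, MPS backend). Batch size and gradient-accumulation steps are read from the `training` section of `asr/config.yaml`; `batch_size=1` with gradient accumulation is intended to keep peak memory under ~10 GB.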

In [1]:
import os
import yaml
from pathlib import Path

# Locate the project root by walking up from the current working directory
# until a directory containing asr/config.yaml is found, so the notebook also
# works when it is launched from a subdirectory of the project.
CONFIG_RELPATH = Path("asr") / "config.yaml"
_launch_dir = Path(os.getcwd()).resolve()
PROJECT_ROOT = next(
    (
        candidate
        for candidate in (_launch_dir, *_launch_dir.parents)
        if (candidate / CONFIG_RELPATH).exists()
    ),
    None,
)
if PROJECT_ROOT is None:
    # Fail here with an actionable message instead of changing directory to the
    # filesystem root and failing later on a bare relative open().
    raise FileNotFoundError(
        f"Could not find {CONFIG_RELPATH} in {_launch_dir} or any parent directory. "
        "Start the notebook from inside the project."
    )
# All relative paths used by later cells are resolved against the project root.
os.chdir(PROJECT_ROOT)

with open(CONFIG_RELPATH) as f:
    config = yaml.safe_load(f)

# Top-level config sections consumed by the cells below.
model_cfg = config["model"]
lora_cfg = config["lora"]
data_cfg = config["data"]
output_cfg = config["output"]
training_cfg = config["training"]

print("Model:", model_cfg["name"])
print("Processed data:", data_cfg["processed_dir"])
print("Output dir:", output_cfg["model_dir"])

Model: openai/whisper-small
Processed data: audio/torgo/processed
Output dir: asr/checkpoints/whisper-lora


In [2]:
import json
import re
from datasets import Dataset, DatasetDict, Audio

# Resolve the data locations named in the config against the project root.
processed_dir = PROJECT_ROOT / data_cfg["processed_dir"]
metadata_path = PROJECT_ROOT / data_cfg["metadata_path"]
assert metadata_path.exists(), f"Metadata not found at {metadata_path}. Run audio/torgo_preprocessing.ipynb first."

# metadata maps split name -> {wav filename -> per-file record}; load_split()
# reads the "transcription" and "speech_status" fields of each record.
with open(metadata_path) as f:
    metadata = json.load(f)

# Optional dataset filters; both default to False when absent from the config.
dysarthric_only = data_cfg.get("dysarthric_only", False)
include_augmented = data_cfg.get("include_augmented", False)

# Filenames matched by this pattern are treated as augmented samples.
AUGMENT_PATTERN = re.compile(r"sample_\d{5}_.+\.wav")

def load_split(split_name):
    """Build the Hugging Face dataset for one split from `metadata`.

    Entries are dropped when the `dysarthric_only` filter excludes the speaker,
    when the file is an augmented sample and `include_augmented` is off, when
    the wav file is missing on disk, or when the transcription is empty.

    Args:
        split_name: Key into `metadata` and name of the sub-directory of
            `processed_dir` that holds the wav files.

    Returns:
        A `Dataset` with an `audio` column (cast to `Audio` at the configured
        sampling rate) and a `transcription` column.
    """
    split_dir = processed_dir / split_name
    kept_rows = []
    for wav_name, info in metadata[split_name].items():
        # Speaker filter: keep only dysarthric speech when requested.
        if dysarthric_only and info["speech_status"] != "dysarthria":
            continue
        # Augmentation filter: skip augmented files unless they are enabled.
        if not include_augmented and AUGMENT_PATTERN.match(wav_name):
            continue
        wav_file = split_dir / wav_name
        if wav_file.exists() and info["transcription"]:
            kept_rows.append((str(wav_file), info["transcription"]))

    columns = {
        "audio": [path for path, _ in kept_rows],
        "transcription": [text for _, text in kept_rows],
    }
    audio_feature = Audio(sampling_rate=data_cfg["sampling_rate"])
    return Dataset.from_dict(columns).cast_column("audio", audio_feature)

# Assemble the three splits in a fixed order: train, validation, test.
dataset = DatasetDict(
    {split: load_split(split) for split in ("train", "validation", "test")}
)

# Summarise which filters were applied and how many examples survived.
augmented_label = "included" if include_augmented else "excluded"
speaker_label = "dysarthric only" if dysarthric_only else "all speakers"
n_train, n_val, n_test = (len(dataset[split]) for split in ("train", "validation", "test"))

print(f"Augmented data: {augmented_label}")
print(f"Speech filter:  {speaker_label}")
print(dataset)
print(f"\nTrain: {n_train}  Val: {n_val}  Test: {n_test}")

  from .autonotebook import tqdm as notebook_tqdm


Augmented data: included
Speech filter:  all speakers
DatasetDict({
    train: Dataset({
        features: ['audio', 'transcription'],
        num_rows: 92617
    })
    validation: Dataset({
        features: ['audio', 'transcription'],
        num_rows: 1653
    })
    test: Dataset({
        features: ['audio', 'transcription'],
        num_rows: 1655
    })
})

Train: 92617  Val: 1653  Test: 1655


In [3]:
from transformers import AutoProcessor

# Pull the model and audio settings out of the config in one go.
model_name, language, task = (model_cfg[key] for key in ("name", "language", "task"))
sampling_rate, max_audio_length = (
    data_cfg[key] for key in ("sampling_rate", "max_audio_length")
)

# The processor bundles the feature extractor (audio -> model inputs) and the
# tokenizer (text -> label ids); both are also exposed as globals for later cells.
processor = AutoProcessor.from_pretrained(model_name, language=language, task=task)
feature_extractor, tokenizer = processor.feature_extractor, processor.tokenizer
print("Processor ready.")



Processor ready.


In [None]:
def prepare_dataset(example):
    """Convert one raw example into model inputs and label ids.

    Reads `example["audio"]` (decoded array plus sampling rate) and
    `example["transcription"]`, and adds:
      * `input_features`: first item of the feature extractor output,
      * `labels`: token ids of the transcription.

    Returns the same `example` mapping with the two new keys set.
    """
    waveform = example["audio"]
    extracted = feature_extractor(
        waveform["array"], sampling_rate=waveform["sampling_rate"]
    )
    # The extractor returns a batch; this example is its only element.
    example["input_features"] = extracted.input_features[0]

    encoded_text = tokenizer(example["transcription"])
    example["labels"] = encoded_text.input_ids
    return example

# Feature extraction replaces the raw `audio` / `transcription` columns with
# `input_features` / `labels` and rebinds `dataset`. Running the map a second
# time would therefore call `prepare_dataset` on examples that no longer have
# an `audio` column and fail, so skip it when the features already exist.
if "input_features" in dataset["train"].column_names:
    print("Features already extracted; skipping feature extraction.")
else:
    dataset = dataset.map(
        prepare_dataset,
        # Drop the raw columns so only the model-ready fields remain.
        remove_columns=dataset["train"].column_names,
        num_proc=2,
        desc="Extracting features",
    )
print(dataset)

Extracting features (num_proc=2):  78%|███████▊  | 71962/92617 [03:08<01:39, 208.34 examples/s]

In [18]:
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Pad a list of preprocessed examples into a single training batch.

    Audio features and label ids are padded independently (via the processor's
    feature extractor and tokenizer respectively). Padded label positions are
    set to -100, and a leading decoder-start token shared by every label
    sequence is stripped.

    The start token is compared against `decoder_start_token_id` rather than
    only `tokenizer.bos_token_id`: per the Hugging Face Whisper fine-tuning
    guide, Whisper label sequences open with `<|startoftranscript|>`, which is
    not the tokenizer's BOS token, so a BOS-only comparison never strips it.
    """

    # Object exposing `.feature_extractor.pad(...)` and `.tokenizer.pad(...)`.
    processor: Any
    # Id of the token that opens every label sequence. When left as None it is
    # resolved from the tokenizer (see `_resolve_start_token_id`).
    decoder_start_token_id: Union[int, None] = None

    def _resolve_start_token_id(self) -> Union[int, None]:
        """Return the id of the leading token that should be stripped from labels."""
        if self.decoder_start_token_id is not None:
            return self.decoder_start_token_id
        tokenizer = self.processor.tokenizer
        convert = getattr(tokenizer, "convert_tokens_to_ids", None)
        if callable(convert):
            sot_id = convert("<|startoftranscript|>")
            # Tokenizers without this token map it to the unknown-token id
            # (or None); ignore that case and fall back to BOS below.
            if sot_id is not None and sot_id != getattr(tokenizer, "unk_token_id", None):
                return sot_id
        return tokenizer.bos_token_id

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Collate `features` (dicts with `input_features` and `labels`) into a batch.

        Returns the padded feature batch with an added `labels` tensor in
        which padding positions are -100.
        """
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        # Replace padding with -100 so those positions are excluded from the loss.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # Drop the start token only when every sequence in the batch begins with it.
        start_id = self._resolve_start_token_id()
        if start_id is not None and labels.size(1) > 0 and bool((labels[:, 0] == start_id).all()):
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

# Collator instance handed to the trainer; it pads with the shared processor.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor)

In [19]:
import torch
from transformers import WhisperForConditionalGeneration
from peft import LoraConfig, get_peft_model

model = WhisperForConditionalGeneration.from_pretrained(model_name)
# Clear the generation constraints stored on the model config.
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []


def make_inputs_require_grad(module, inputs, output):
    """Forward hook that marks the hooked module's output as requiring grad."""
    output.requires_grad_(True)


# With LoRA the base weights are frozen, so the activations entering the
# checkpointed layers would not require grad and gradient checkpointing could
# not propagate gradients to the adapter weights. Marking the output of the
# encoder's first convolution as requiring grad is the workaround used in the
# PEFT Whisper training example.
model.model.encoder.conv1.register_forward_hook(make_inputs_require_grad)

lora_config = LoraConfig(
    r=lora_cfg["r"],
    lora_alpha=lora_cfg["alpha"],
    target_modules=lora_cfg["target_modules"],
    lora_dropout=lora_cfg["dropout"],
    bias="none",
)
model = get_peft_model(model, lora_config)
# Trade compute for memory: recompute activations during the backward pass.
model.gradient_checkpointing_enable()
model.print_trainable_parameters()

# Informational only: reports which accelerator backend is available.
device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

trainable params: 1,769,472 || all params: 243,504,384 || trainable%: 0.7266694631666262


In [20]:
import numpy as np
import evaluate
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, TrainerCallback, TrainerState, TrainerControl, TrainingArguments

# Word-error-rate metric used by compute_metrics below.
metric = evaluate.load("wer")


def compute_metrics(pred):
    """Compute word error rate for one evaluation pass.

    Args:
        pred: Object with `predictions` and `label_ids` arrays of token ids;
            label positions equal to -100 are treated as padding.

    Returns:
        `{"wer": value}` where `value` is whatever `metric.compute` returns for
        the decoded predictions against the decoded references.
    """
    predicted_ids = pred.predictions
    # -100 is not a valid token id; swap it for the pad token before decoding.
    reference_ids = np.where(pred.label_ids != -100, pred.label_ids, tokenizer.pad_token_id)

    hypotheses = tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)
    references = tokenizer.batch_decode(reference_ids, skip_special_tokens=True)

    return {"wer": metric.compute(predictions=hypotheses, references=references)}


In [None]:
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers import EarlyStoppingCallback

class SavePeftModelCallback(TrainerCallback):
    """Trainer callback that stores the adapter next to each checkpoint.

    On every save event it writes the model via `save_pretrained` into an
    `adapter_model` sub-folder of the current checkpoint directory and deletes
    a `pytorch_model.bin` file from that checkpoint if one is present.
    """

    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Save the adapter for the checkpoint at `state.global_step`; returns `control`."""
        checkpoint_name = f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        checkpoint_dir = os.path.join(args.output_dir, checkpoint_name)

        kwargs["model"].save_pretrained(os.path.join(checkpoint_dir, "adapter_model"))

        # Remove the full-weights file if the trainer wrote one for this checkpoint.
        full_weights_file = os.path.join(checkpoint_dir, "pytorch_model.bin")
        if os.path.exists(full_weights_file):
            os.remove(full_weights_file)
        return control

# Batch geometry from the config; the effective batch is what one optimizer
# step sees after gradient accumulation.
batch_size, grad_accum = (
    training_cfg[key] for key in ("batch_size", "gradient_accumulation_steps")
)
effective_batch = batch_size * grad_accum
# Floor division: these are approximate counts for display only.
steps_per_epoch = len(dataset["train"]) // effective_batch
approx_total_steps = steps_per_epoch * training_cfg["epochs"]

print(f"Batch size: {batch_size}  |  Grad accum: {grad_accum}  |  Effective batch: {effective_batch}")
print(f"Steps per epoch: ~{steps_per_epoch}  |  Total steps: ~{approx_total_steps}")

# Training configuration. Values that a reader may want to tune come from the
# `training` / `output` sections of the config file.
training_args = Seq2SeqTrainingArguments(
    output_dir=output_cfg["model_dir"],
    logging_dir=output_cfg["log_dir"],
    num_train_epochs=training_cfg["epochs"],
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=training_cfg["learning_rate"],
    warmup_steps=training_cfg["warmup_steps"],
    fp16=training_cfg["fp16"],
    eval_strategy="steps",
    eval_steps=training_cfg["eval_steps"],
    save_steps=training_cfg["save_steps"],
    gradient_accumulation_steps=grad_accum,
    load_best_model_at_end=True,
    metric_for_best_model=training_cfg["metric_for_best_model"],
    greater_is_better=training_cfg["greater_is_better"],
    # compute_metrics decodes `pred.predictions` as token ids, so evaluation
    # must run generation; without this flag the trainer returns logits instead
    # and `generation_max_length` has no effect.
    predict_with_generate=True,
    generation_max_length=225,
    logging_steps=training_cfg["logging_steps"],
    # The collator needs the `input_features` / `labels` columns as-is.
    remove_unused_columns=False,
    label_names=["labels"],
    save_total_limit=3,
    dataloader_pin_memory=False,
)

# Callbacks: store the adapter with every checkpoint, and stop early once the
# tracked metric has not improved for the configured number of evaluations.
trainer_callbacks = [
    SavePeftModelCallback(),
    EarlyStoppingCallback(early_stopping_patience=training_cfg["early_stopping_patience"]),
]

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    processing_class=processor.feature_extractor,
    callbacks=trainer_callbacks,
)
# Turn off the decoder key/value cache on the model config for training.
model.config.use_cache = False
print("Trainer ready.")

MPS detected: using batch_size=2, gradient_accumulation_steps=8 to avoid OOM.


RecursionError: maximum recursion depth exceeded

: 

In [10]:
# Run the fine-tuning loop with the trainer configured in the previous cell.
trainer.train()



RuntimeError: MPS backend out of memory (MPS allocated: 18.52 GiB, other allocations: 870.72 MiB, max allowed: 20.13 GiB). Tried to allocate 823.97 MiB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [None]:
# Write the trained model to <model_dir>/final_adapter, creating the directory
# (and any missing parents) first.
final_dir = Path(output_cfg["model_dir"], "final_adapter")
os.makedirs(final_dir, exist_ok=True)
trainer.save_model(str(final_dir))
print(f"Saved final LoRA adapter to {final_dir}")