# TORGO `baseline_evaluation.py` — Colab Notebook (Explained)

## What this `.py` file is for
The script evaluates **pretrained Whisper** models on the TORGO dataset and reports:
- **WER** (Word Error Rate) and **CER** (Character Error Rate)
- Metrics broken down by **speech status** (e.g., dysarthric vs healthy)
- Error breakdown counts: **substitutions**, **deletions**, **insertions**
- A summary table comparing multiple model sizes (tiny/base/small)

### Key definitions
- **Reference (ref)**: the ground-truth transcription (what was actually said)
- **Hypothesis (hyp)**: the model’s predicted transcription
- **WER (Word Error Rate)**:  
  \[ \text{WER} = \frac{S + D + I}{N} \]  
  where `S`=substitutions, `D`=deletions, `I`=insertions, `N`=number of reference words.
- **CER (Character Error Rate)**: same idea as WER but measured at the character level.

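> In plain language: WER is the number of word-level edits (substitutions, deletions, insertions) needed to turn the model output into the correct text, divided by the number of words in the correct text. Because insertions count too, WER can exceed 100%.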
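To make the formula concrete, here is a minimal, hand-rolled sketch. It is not jiwer's implementation. It aligns two token lists with a Levenshtein dynamic program and counts S, D and I along one minimal-cost alignment. When several alignments tie, the split between S, D and I can differ from jiwer's, but the total S + D + I, and therefore WER, is the same. `edit_counts` and `error_rate` are hypothetical helper names.

```python
def edit_counts(ref: list[str], hyp: list[str]) -> tuple[int, int, int]:
    """Return (substitutions, deletions, insertions) for one minimal-cost alignment."""
    # cell[i][j] = (total_edits, S, D, I) for aligning ref[:i] with hyp[:j]
    cell = [[(0, 0, 0, 0)] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(1, len(ref) + 1):
        cell[i][0] = (i, 0, i, 0)  # every reference token deleted
    for j in range(1, len(hyp) + 1):
        cell[0][j] = (j, 0, 0, j)  # every hypothesis token inserted
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            if ref[i - 1] == hyp[j - 1]:
                cell[i][j] = cell[i - 1][j - 1]  # match: no extra edit
                continue
            t, s, d, n = cell[i - 1][j - 1]
            sub = (t + 1, s + 1, d, n)
            t, s, d, n = cell[i - 1][j]
            dele = (t + 1, s, d + 1, n)
            t, s, d, n = cell[i][j - 1]
            ins = (t + 1, s, d, n + 1)
            cell[i][j] = min(sub, dele, ins)  # fewest total edits wins
    _, s, d, n = cell[len(ref)][len(hyp)]
    return s, d, n


def error_rate(reference: str, hypothesis: str) -> float:
    """WER = (S + D + I) / N, with N = number of reference words."""
    ref_words = reference.split()
    if not ref_words:
        raise ValueError("reference must contain at least one word")
    s, d, i = edit_counts(ref_words, hypothesis.split())
    return (s + d + i) / len(ref_words)


print(edit_counts("the cat sat on the mat".split(), "the cat sit on mat".split()))  # (1, 1, 0)
print(error_rate("hello", "hello hello world"))                                   # 2.0
```

The first call finds one substitution (`sat` to `sit`) and one deletion (`the`), so WER = 2/6. The second call returns 2.0 (200%): two inserted words against a one-word reference. Running `edit_counts` on `list(reference)` and `list(hypothesis)` instead gives a character-level analogue of CER.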


## 0) Setup: Imports and model registry

### Why these imports exist
- `torch`: runs Whisper inference and chooses device (cuda/mps/cpu)
- `numpy`: audio arrays are NumPy arrays
- `jiwer`: computes WER/CER + word-level error breakdown
- `transformers`: loads Whisper model + processor
- `datasets`: loads saved TORGO dataset from disk and decodes audio
- `argparse/json/pathlib/defaultdict`: CLI + saving results + grouping by status

### `WHISPER_MODELS`
This dictionary maps a short name (e.g. `"tiny"`) to the Hugging Face model id.


In [6]:
import argparse
import json
from collections import defaultdict
from pathlib import Path

import torch
import numpy as np
from jiwer import wer, cer, process_words
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_from_disk, Audio


WHISPER_MODELS = {
    "tiny": "openai/whisper-tiny",
    "base": "openai/whisper-base",
    "small": "openai/whisper-small",
}
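Because the evaluation loop looks names up with `WHISPER_MODELS[model_name]`, a typo such as `"smal"` only fails with a bare `KeyError` once the loop reaches it, possibly after the earlier models have already been evaluated. A small guard like the one below checks every requested name up front. `resolve_models` is a hypothetical helper, and the commented-out registry entries are assumptions: `openai/whisper-medium` and `openai/whisper-large-v3` are published Whisper checkpoints, but they are much larger and were not part of the original comparison.

```python
WHISPER_MODELS = {
    "tiny": "openai/whisper-tiny",
    "base": "openai/whisper-base",
    "small": "openai/whisper-small",
    # Assumed extra sizes; uncomment to include them (much slower on CPU/MPS).
    # "medium": "openai/whisper-medium",
    # "large-v3": "openai/whisper-large-v3",
}


def resolve_models(requested: list[str], registry: dict[str, str]) -> list[tuple[str, str]]:
    """Map short names to Hugging Face ids, failing fast on unknown names."""
    unknown = [name for name in requested if name not in registry]
    if unknown:
        raise ValueError(f"Unknown model name(s) {unknown}; choose from {sorted(registry)}")
    return [(name, registry[name]) for name in requested]


print(resolve_models(["tiny", "small"], WHISPER_MODELS))
```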


## 1) `transcribe_audio(model, processor, audio_array, sr, device)`

### Purpose
Transcribe **one audio sample** using a Whisper model.

### Step-by-step
1. Use `processor(...)` to convert the raw audio waveform (`audio_array`) into Whisper input features (a log-Mel spectrogram).
2. Move features onto the selected device (`cpu`, `cuda`, or `mps`).
3. Disable gradients with `torch.no_grad()` (faster and uses less memory for inference).
4. Use `model.generate(...)` to produce token ids for the transcription.
5. Decode tokens back to text with `processor.batch_decode(...)`.
6. Normalize output by stripping whitespace and lowercasing.

### Why lowercasing?
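WER compares words as exact strings, so `Hello` and `hello` would count as a substitution. Lowercasing both the reference and the hypothesis removes that casing mismatch. It does not remove punctuation: Whisper normally emits punctuated text, so if the references are unpunctuated, a hypothesis word like `hello.` still counts as an error against `hello`.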
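If punctuation turns out to inflate WER, one option is to normalize both sides the same way before scoring. The function below is a hypothetical `normalize_text` helper, not part of the original script: it lowercases, maps the typographic apostrophe to a plain one, turns other punctuation into spaces (keeping apostrophes so contractions like "don't" stay one word), and collapses whitespace. Using it on both `reference` and `hypothesis` would replace the plain `.strip().lower()` calls; it changes the reported numbers, so it should be applied identically for every model being compared.

```python
import re

_PUNCT = re.compile(r"[^\w\s']")  # anything that is not a word char, whitespace or apostrophe


def normalize_text(text: str) -> str:
    """Lowercase, replace punctuation with spaces, and collapse runs of whitespace."""
    text = text.lower().replace("\u2019", "'")  # typographic apostrophe -> plain apostrophe
    text = _PUNCT.sub(" ", text)
    return " ".join(text.split())


print(normalize_text("  Hello, World! "))  # hello world
print(normalize_text("Don't stop."))       # don't stop
```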


In [7]:
def transcribe_audio(model, processor, audio_array: np.ndarray, sr: int, device: str) -> str:
    """Transcribe a single audio sample using Whisper."""
    input_features = processor(
        audio_array, sampling_rate=sr, return_tensors="pt"
    ).input_features.to(device)

    with torch.no_grad():
        predicted_ids = model.generate(input_features)

    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return transcription.strip().lower()
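One detail hidden inside `processor(...)`: with its default settings, Whisper's feature extractor pads or truncates every clip to a 30-second window at 16 kHz (480,000 samples) before computing the log-Mel features, so a single `generate` call only sees the first 30 seconds. TORGO utterances are typically short, so this should rarely matter, but it is cheap to check. The sketch below mirrors that pad/trim step in NumPy; `pad_or_trim`, `seconds_lost`, `WHISPER_SR` and `WHISPER_WINDOW` are names chosen here for illustration (the `openai-whisper` package has its own `pad_or_trim`, which this is not).

```python
import numpy as np

WHISPER_SR = 16_000                # Whisper expects 16 kHz audio
WHISPER_WINDOW = 30 * WHISPER_SR   # 30 s = 480,000 samples


def pad_or_trim(audio: np.ndarray, length: int = WHISPER_WINDOW) -> np.ndarray:
    """Zero-pad or cut a 1-D waveform to exactly `length` samples."""
    if audio.shape[0] >= length:
        return audio[:length]
    return np.pad(audio, (0, length - audio.shape[0]))  # zeros appended at the end


def seconds_lost(audio: np.ndarray, sr: int = WHISPER_SR) -> float:
    """Seconds beyond the 30 s window that a single pass would drop (0.0 for short clips)."""
    return max(0.0, audio.shape[0] / sr - 30.0)


clip = np.zeros(3 * WHISPER_SR, dtype=np.float32)  # a 3-second clip
print(pad_or_trim(clip).shape)                    # (480000,)
print(seconds_lost(np.zeros(35 * WHISPER_SR)))    # 5.0
```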


## 2) `evaluate_model(model_name, model_id, dataset, device)`

### Purpose
Evaluate **one Whisper model** on the **test split** (or on the first available split if the dataset has no `test` split) and compute metrics:
- WER/CER overall
- WER/CER per `speech_status`
- Error breakdown counts per group

### Key steps
1. Load Whisper processor + model from `model_id`
2. Iterate over the **test split**
3. For each sample:
   - Get `reference` transcription (ground truth)
   - Get `speech_status` group label
   - Get audio waveform (`audio["array"]`) and sample rate (`audio["sampling_rate"]`)
   - Produce `hypothesis` transcription via `transcribe_audio(...)`
   - Append reference/hypothesis into a group bucket (`group_results[status]`)
4. After the loop, compute metrics for each group using `jiwer`:
   - `wer(refs, hyps)`
   - `cer(refs, hyps)`
   - `process_words(refs, hyps)` gives substitutions/deletions/insertions counts
5. Also compute **overall** WER/CER by aggregating all refs/hyps
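When `wer` or `cer` receives lists, jiwer pools the error counts across all pairs and divides by the total number of reference words (or characters); it does not average per-utterance rates. The distinction matters on data with many one-word prompts, where a single wrong word is a 100% error for that utterance. The self-contained sketch below uses hypothetical helpers (`word_errors`, `corpus_wer`, `mean_utterance_wer`) and invented toy data: it groups pairs by status in the same way as `evaluate_model` and compares the two kinds of average.

```python
from collections import defaultdict


def word_errors(ref: str, hyp: str) -> int:
    """Word-level Levenshtein distance, i.e. S + D + I for one pair."""
    r, h = ref.split(), hyp.split()
    prev = list(range(len(h) + 1))
    for i, rw in enumerate(r, start=1):
        cur = [i] + [0] * len(h)
        for j, hw in enumerate(h, start=1):
            cur[j] = min(prev[j] + 1,               # deletion
                         cur[j - 1] + 1,            # insertion
                         prev[j - 1] + (rw != hw))  # substitution or match
        prev = cur
    return prev[-1]


def corpus_wer(refs: list[str], hyps: list[str]) -> float:
    """Pooled WER: total edits / total reference words."""
    return sum(map(word_errors, refs, hyps)) / sum(len(r.split()) for r in refs)


def mean_utterance_wer(refs: list[str], hyps: list[str]) -> float:
    """Unweighted mean of per-utterance WER (not what the script reports)."""
    return sum(word_errors(r, h) / len(r.split()) for r, h in zip(refs, hyps)) / len(refs)


# Toy (status, reference, hypothesis) triples, invented for illustration only.
samples = [
    ("dysarthric", "the quick brown fox", "the quick brown fox"),
    ("dysarthric", "yes", "chess"),
    ("healthy", "the quick brown fox", "the quick brown box"),
]

groups = defaultdict(lambda: {"refs": [], "hyps": []})
for status, ref, hyp in samples:
    groups[status]["refs"].append(ref)
    groups[status]["hyps"].append(hyp)

for status, data in groups.items():
    print(status, corpus_wer(data["refs"], data["hyps"]), mean_utterance_wer(data["refs"], data["hyps"]))
# dysarthric 0.2 0.5
# healthy 0.25 0.25
```

In the toy dysarthric group the pooled WER is 0.2 while the per-utterance mean is 0.5, because pooling weights each utterance by its number of words. Both are legitimate summaries; the script's numbers are the pooled kind.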

### Why group by `speech_status`?
Dysarthric speech is typically much harder for ASR models trained mostly on typical speech. Reporting each group separately shows how large that gap is instead of hiding it inside a single overall number.


In [8]:
def evaluate_model(model_name: str, model_id: str, dataset, device: str) -> dict:
    """Evaluate a single Whisper model on the test split."""
    print(f"\nEvaluating {model_name} ({model_id})...")

    processor = WhisperProcessor.from_pretrained(model_id)
    model = WhisperForConditionalGeneration.from_pretrained(model_id).to(device)
    model.eval()

    # Collect results grouped by speech_status
    group_results = defaultdict(lambda: {"refs": [], "hyps": []})

    test_data = dataset["test"] if "test" in dataset else dataset[list(dataset.keys())[0]]
    total = len(test_data)

    for i, sample in enumerate(test_data):
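        # Strip before the emptiness check so missing, empty and whitespace-only
        # references are all skipped (they have no words to score against).
        reference = (sample.get("transcription") or "").strip().lower()
        if not reference:
            continue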

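        status = sample.get("speech_status", "unknown")
        # Handle ClassLabel encoding (int -> string): prefer the label names stored in the
        # dataset's features; fall back to the assumed order 0=dysarthric, 1=healthy.
        if isinstance(status, int):
            feature = test_data.features.get("speech_status")
            if hasattr(feature, "int2str"):
                status = feature.int2str(status)
            else:
                status_map = {0: "dysarthric", 1: "healthy"}
                status = status_map.get(status, f"unknown_{status}")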
        audio = sample["audio"]

        hypothesis = transcribe_audio(
            model, processor, audio["array"], audio["sampling_rate"], device
        )

        group_results[status]["refs"].append(reference)
        group_results[status]["hyps"].append(hypothesis)

        if (i + 1) % 100 == 0:
            print(f"  Processed {i + 1}/{total} samples...")

    # Compute metrics
    report = {"model": model_name, "model_id": model_id, "groups": {}}

    all_refs, all_hyps = [], []

    for status, data in group_results.items():
        refs, hyps = data["refs"], data["hyps"]
        all_refs.extend(refs)
        all_hyps.extend(hyps)

        group_wer = wer(refs, hyps)
        group_cer = cer(refs, hyps)

        # Error type breakdown
        output = process_words(refs, hyps)
        report["groups"][status] = {
            "wer": group_wer,
            "cer": group_cer,
            "substitutions": output.substitutions,
            "deletions": output.deletions,
            "insertions": output.insertions,
            "num_samples": len(refs),
        }

    # Overall metrics
    if all_refs:
        report["overall_wer"] = wer(all_refs, all_hyps)
        report["overall_cer"] = cer(all_refs, all_hyps)
        report["total_samples"] = len(all_refs)

    print(f"  Done. Overall WER: {report.get('overall_wer', 0)*100:.1f}%")
    return report


## 3) `print_report(reports)`

### Purpose
Pretty-print a comparison table for multiple evaluated models.

### What it prints
1. A header for results
2. A table with columns:
   - Model name
   - Overall WER
   - Overall CER
   - Dysarthric WER
   - Healthy WER
3. The **best model** (lowest overall WER)
4. Per-group details for the best model:
   - WER, CER
   - number of samples
   - substitutions, deletions, insertions
5. The WER gap between dysarthric and healthy speech

### Why choose “best” by overall WER?
WER is the most common headline metric for ASR quality.
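Overall WER pools every word, so it leans toward whichever group contributes more reference words; the model that is best overall is not necessarily best on dysarthric speech. The sketch below, with hypothetical helpers `pick_best` and `wer_gap` and invented report values shaped like the dicts built by `evaluate_model`, shows how the choice can flip with the criterion, and reports the gap both in percentage points and as a ratio.

```python
from typing import Optional


def pick_best(reports: list[dict], group: Optional[str] = None) -> dict:
    """Lowest overall WER, or lowest WER within `group` if given; missing values rank last."""
    def score(report: dict) -> float:
        if group is None:
            value = report.get("overall_wer")
        else:
            value = report.get("groups", {}).get(group, {}).get("wer")
        return float("inf") if value is None else value
    return min(reports, key=score)


def wer_gap(report: dict) -> Optional[tuple[float, float]]:
    """(dysarthric - healthy) in percentage points, and dysarthric / healthy as a ratio."""
    groups = report.get("groups", {})
    dys = groups.get("dysarthric", {}).get("wer")
    healthy = groups.get("healthy", {}).get("wer")
    if dys is None or healthy is None:
        return None
    ratio = dys / healthy if healthy > 0 else float("inf")
    return (dys - healthy) * 100, ratio


# Invented numbers for illustration; named toy_reports so the real `reports` is not overwritten.
toy_reports = [
    {"model": "A", "overall_wer": 0.30, "groups": {"dysarthric": {"wer": 0.5}, "healthy": {"wer": 0.125}}},
    {"model": "B", "overall_wer": 0.25, "groups": {"dysarthric": {"wer": 0.625}, "healthy": {"wer": 0.0625}}},
]
print(pick_best(toy_reports)["model"])                      # B
print(pick_best(toy_reports, group="dysarthric")["model"])  # A
print(wer_gap(toy_reports[0]))                              # (37.5, 4.0)
```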


In [9]:
def print_report(reports: list[dict]):
    """Print formatted comparison of all models."""
    print("\n" + "=" * 70)
    print("BASELINE EVALUATION RESULTS")
    print("=" * 70)

    # Model comparison
    def pct(value) -> str:
        # Show "n/a" rather than a misleading 0.0% when a metric or group is missing
        return f"{value*100:>11.1f}%" if value is not None else f"{'n/a':>12}"

    print(f"\n{'Model':>10} {'Overall WER':>12} {'Overall CER':>12} {'Dysarthric':>12} {'Healthy':>12}")
    print("-" * 62)
    for r in reports:
        groups = r.get("groups", {})
        print(
            f"{r['model']:>10} "
            f"{pct(r.get('overall_wer'))} "
            f"{pct(r.get('overall_cer'))} "
            f"{pct(groups.get('dysarthric', {}).get('wer'))} "
            f"{pct(groups.get('healthy', {}).get('wer'))}"
        )

    # Best model
    best = min(reports, key=lambda r: r.get("overall_wer", float("inf")))
    print(f"\nBest model: {best['model']} (WER: {best.get('overall_wer', 0)*100:.1f}%)")

    # Per-group details for best model
    print(f"\n{'Per-Group Details (best model: ' + best['model'] + ')':=^70}")
    print(f"  {'Group':<14} {'WER':>8} {'CER':>8} {'Samples':>8} {'Sub':>6} {'Del':>6} {'Ins':>6}")
    print("  " + "-" * 56)
    for group, data in sorted(best["groups"].items()):
        print(
            f"  {group:<14} "
            f"{data['wer']*100:>7.1f}% "
            f"{data['cer']*100:>7.1f}% "
            f"{data['num_samples']:>8} "
            f"{data['substitutions']:>6} "
            f"{data['deletions']:>6} "
            f"{data['insertions']:>6}"
        )

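    # WER gap in percentage points (test for presence, not truthiness: 0.0 is a valid WER)
    dys = best["groups"].get("dysarthric", {}).get("wer")
    healthy = best["groups"].get("healthy", {}).get("wer")
    if dys is not None and healthy is not None:
        print(f"\n  WER gap (dysarthric - healthy): {(dys - healthy)*100:.1f} percentage points")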


## 4) `main()` — orchestrates the full evaluation

### Purpose
Turn the evaluation functions into a command-line tool that:
1. Loads a local saved TORGO dataset (`load_from_disk`)
2. Evaluates multiple Whisper model sizes
3. Prints a summary report
4. Saves results to JSON

### Key steps
1. Parse CLI args:
   - `--models`: which Whisper sizes to evaluate
   - `--input`: directory containing `torgo_dataset/`
   - `--output`: where to save JSON results
2. Pick `device`:
   - `cuda` if GPU is available
   - otherwise `mps` (Apple Silicon) if available
   - otherwise `cpu`
3. Load the dataset and cast the audio column to 16 kHz (the sampling rate Whisper expects)
4. Evaluate each requested model
5. Print comparison report and save results
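The notebook does not show `main()` itself. A minimal sketch of its argument parsing, assuming exactly the three flags listed above, might look like this. `build_parser` is a hypothetical name, and the defaults are copied from the notebook's `args` cell, so they may not match the real script.

```python
import argparse

MODEL_CHOICES = ["tiny", "base", "small"]  # keys of WHISPER_MODELS


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Baseline Whisper evaluation on TORGO")
    parser.add_argument("--models", nargs="+", choices=MODEL_CHOICES, default=list(MODEL_CHOICES),
                        help="Whisper sizes to evaluate")
    parser.add_argument("--input", default="../audio/torgo",
                        help="directory containing torgo_dataset/")
    parser.add_argument("--output", default="./baseline_results.json",
                        help="where to save the JSON results")
    return parser


# Pass an explicit list: in Jupyter/Colab, sys.argv holds the kernel's own arguments.
cli_args = build_parser().parse_args(["--models", "tiny", "base"])
print(cli_args.models, cli_args.input)
```

With `choices`, argparse rejects an unknown size such as `--models huge` with a usage error before any model is downloaded.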

### In Colab
In a notebook there is no command line to parse, so the cells below build the same `args` object by hand with `argparse.Namespace` and then run the steps of `main()` one cell at a time.


In [10]:
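# Stand-in for the arguments main() would parse from the command line.
# `input` must be a folder that contains torgo_dataset/; relative paths are
# resolved against the notebook's current working directory.
args = argparse.Namespace(
    models=["tiny", "base", "small"],   # keys of WHISPER_MODELS
    input="../audio/torgo",
    output="./baseline_results.json",   # where the JSON report is written
)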


In [11]:
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")

Using device: mps


In [12]:
dataset_path = Path(args.input) / "torgo_dataset"
print(f"Loading dataset from {dataset_path}...")
dataset = load_from_disk(str(dataset_path))
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

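`load_from_disk` raises `FileNotFoundError` when the folder is missing, which easily happens when a relative `input` path is resolved against a different working directory (Colab notebooks usually start in `/content`). A small pre-check, sketched below with a hypothetical `find_dataset_dir` helper, turns that into a message showing the absolute path that was tried. It only inspects the filesystem; it does not validate the dataset's contents.

```python
from pathlib import Path


def find_dataset_dir(input_dir: str, name: str = "torgo_dataset") -> Path:
    """Return <input_dir>/<name> as an absolute path, or raise with a helpful message."""
    path = (Path(input_dir) / name).resolve()
    if not path.is_dir():
        raise FileNotFoundError(
            f"{path} not found (working directory: {Path.cwd()}). "
            "Point `input` at the folder that contains torgo_dataset/."
        )
    return path


# Possible use in the loading cell:
# dataset = load_from_disk(str(find_dataset_dir(args.input)))
```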

In [None]:
reports = []
for model_name in args.models:
    report = evaluate_model(model_name, WHISPER_MODELS[model_name], dataset, device)
    reports.append(report)

print_report(reports)

In [None]:
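# Save results (create the output folder first if it does not exist)
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)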
with open(output_path, "w") as f:
    json.dump(reports, f, indent=2)
print(f"\nResults saved to {output_path}")
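Because the results are plain JSON, they can be reloaded later (for example, to compare a baseline against a fine-tuned model) without rerunning inference. A minimal sketch, assuming the layout written above (a list of report dicts, each with a `groups` mapping); `load_group_wers` is a hypothetical helper.

```python
import json
from pathlib import Path


def load_group_wers(path: str) -> dict[str, dict[str, float]]:
    """Return {model: {group: wer, ..., "overall": overall_wer}} from a saved results file."""
    with open(Path(path)) as f:
        reports = json.load(f)  # local name; does not touch the notebook's `reports`
    table = {}
    for report in reports:
        row = {group: data["wer"] for group, data in report.get("groups", {}).items()}
        if "overall_wer" in report:
            row["overall"] = report["overall_wer"]
        table[report["model"]] = row
    return table


# table = load_group_wers("./baseline_results.json")
```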