<a href="https://colab.research.google.com/github/NolanChai/lign167-whisper/blob/main/fine_tune_1500.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Whisper Fine-Tuning on OCSC Child Speech

## Overview
Fine-tune OpenAI's English-only Whisper-Medium model (`openai/whisper-medium.en`) on the Ohio Child Speech Corpus (OCSC) to improve automatic speech recognition for children aged 4-9. This notebook implements the training pipeline, including W&B logging and checkpoint management.

### Install Dependencies
Install pinned versions of Transformers, Datasets, and the evaluation libraries, and remove `peft` to avoid version conflicts, since this notebook does full fine-tuning rather than LoRA. The cell also runs `nvidia-smi` to confirm that a CUDA GPU is available and to report its type. Training requires GPU acceleration.


In [None]:
!nvidia-smi

!pip -q install \
  transformers==4.45.2 \
  datasets==2.20.0 \
  evaluate==0.4.2 \
  huggingface_hub==0.26.2 \
  soundfile==0.12.1 \
  wandb

# ffmpeg is usually present, but just in case (no sudo needed in Colab)
!command -v ffmpeg >/dev/null || apt-get -y install -qq ffmpeg

# Remove peft to avoid version conflicts with transformers
!pip -q uninstall -y peft

### Environment Setup

Configure:
- Google Drive mounting for persistent storage
- Base directories for data and outputs
- Random seeds for reproducibility (seed=42)
- Device selection (CUDA if available)

Create a timestamped run directory on Google Drive to persist checkpoints across Colab disconnections. Each run gets a unique identifier.
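The `%Y%m%d-%H%M%S` stamp uses fixed-width, zero-padded fields, so sorting such names as plain strings also sorts them chronologically. The manifest-loading cell later relies on this when it picks the last timestamp folder by name. A minimal sketch of the naming scheme, where `make_run_name` and `latest_stamp` are hypothetical helpers (the notebook builds the name inline):

```python
from datetime import datetime
from typing import List

STAMP_FMT = "%Y%m%d-%H%M%S"  # same format string the setup cell passes to strftime


def make_run_name(when: datetime, prefix: str = "whisper-medium-ocsc-ft", suffix: str = "phase") -> str:
    """Hypothetical helper: build a run name shaped like the one in the setup cell."""
    return f"{prefix}-{when.strftime(STAMP_FMT)}-{suffix}"


def latest_stamp(stamps: List[str]) -> str:
    """Hypothetical helper: zero-padded stamps sort chronologically as strings."""
    return sorted(stamps)[-1]


run_name = make_run_name(datetime(2025, 12, 3, 15, 45, 30))
newest = latest_stamp(["20251201-213109", "20251130-090000", "20251201-080000"])
print(run_name)  # whisper-medium-ocsc-ft-20251203-154530-phase
print(newest)    # 20251201-213109
```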

In [None]:
import os, re, io, subprocess, shutil, random
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import soundfile as sf
import torch
from tqdm import tqdm

from google.colab import drive
from huggingface_hub import snapshot_download

# Mount Drive
drive.mount("/content/drive")

# Base dirs
BASE_COLAB = Path("/content")
DATA = BASE_COLAB / "data"; DATA.mkdir(parents=True, exist_ok=True)
RAW  = DATA / "ocsc_raw"
MANI = DATA / "manifests"; MANI.mkdir(parents=True, exist_ok=True)

print("DATA:", DATA)
print("RAW:", RAW)
print("MANI:", MANI)

# Seeds
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# Drive run directory
BASE_DRIVE = Path("/content/drive/MyDrive")
RUNS_DIR   = BASE_DRIVE / "ocsc_whisper_runs"
RUNS_DIR.mkdir(parents=True, exist_ok=True)

timestamp   = datetime.now().strftime("%Y%m%d-%H%M%S")
RUN_NAME    = f"whisper-medium-ocsc-ft-{timestamp}-phase"
RUN_OUTPUT_DIR = RUNS_DIR / RUN_NAME
RUN_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

print("Run name:      ", RUN_NAME)
print("RUN_OUTPUT_DIR:", RUN_OUTPUT_DIR)

Mounted at /content/drive
DATA: /content/data
RAW: /content/data/ocsc_raw
MANI: /content/data/manifests
Device: cuda
Run name:       whisper-medium-ocsc-ft-20251203-154530-phase
RUN_OUTPUT_DIR: /content/drive/MyDrive/ocsc_whisper_runs/whisper-medium-ocsc-ft-20251203-154530-phase


### Weights & Biases Integration
Initialize W&B for experiment tracking: it logs training metrics and checkpoints and makes runs easy to compare. Stale run IDs in the environment are cleared first, so a fresh run is started instead of resuming an old one.

In [None]:
import wandb

# Make sure we don't accidentally resume an old W&B run
os.environ.pop("WANDB_RUN_ID", None)
os.environ.pop("WANDB_RUN_PATH", None)

os.environ["WANDB_PROJECT"]   = "ocsc-whisper"      # your project
os.environ["WANDB_WATCH"]     = "false"
os.environ["WANDB_LOG_MODEL"] = "checkpoint"        # log checkpoints as model artifacts

wandb.login()  # paste API key if prompted

wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:
wandb: No netrc file found, creating one.
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
wandb: Currently logged in as: nochai (noulan) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


True

### Load Preprocessed Manifests
Load train/dev/test CSV manifests created by the preprocessing notebook. These contain audio paths, timestamps, and normalized transcriptions for each utterance.

In [None]:
# Converted manifests from Notebook A, stored on Drive
CONV_ROOT = BASE_DRIVE / "ocsc_converted"
stamps = [p for p in CONV_ROOT.iterdir() if p.is_dir()]
assert stamps, "No converted manifests found under /content/drive/MyDrive/ocsc_converted"

STAMP_DIR = sorted(stamps, key=lambda p: p.name)[-1]  # latest timestamp folder
print("Using manifests from:", STAMP_DIR)

for name in ["ocsc_manifest_utterances.csv", "ocsc_train.csv", "ocsc_dev.csv", "ocsc_test.csv"]:
    src = STAMP_DIR / name
    assert src.exists(), f"Missing {src}"
    shutil.copy2(src, MANI / name)

df_train = pd.read_csv(MANI / "ocsc_train.csv")
df_dev   = pd.read_csv(MANI / "ocsc_dev.csv")
df_test  = pd.read_csv(MANI / "ocsc_test.csv")

print("Train rows:", len(df_train))
print("Dev rows:",   len(df_dev))
print("Test rows:",  len(df_test))
df_train.head()

Using manifests from: /content/drive/MyDrive/ocsc_converted/20251201-213109
Train rows: 96908
Dev rows: 11230
Test rows: 25408


| Unnamed: 0 | session_id | age_folder | audio_path | cha_path | speaker_id | age_years | age_bucket | task | start_s | end_s | text | norm_text | dur_s |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 4022 | 4 | /content/data/ocsc_raw/Eng-NA/OCSC/4/4022.wav | /content/4022.cha | CHI_4022 | 4.0 | 4-5 | IntroRobot | 27.352 | 28.891 | hello . 27352_28891 | hello | 1.539 |
| 1 | 4022 | 4 | /content/data/ocsc_raw/Eng-NA/OCSC/4/4022.wav | /content/4022.cha | CHI_4022 | 4.0 | 4-5 | IntroRobot | 32.158 | 33.708 | my name is Teigan . 32158_33708 | my name is teigan | 1.55 |
| 2 | 4022 | 4 | /content/data/ocsc_raw/Eng-NA/OCSC/4/4022.wav | /content/4022.cha | CHI_4022 | 4.0 | 4-5 | IntroRobot | 48.155 | 49.945 | mine's rainbow . 48155_49945 | mine's rainbow | 1.79 |
| 3 | 4022 | 4 | /content/data/ocsc_raw/Eng-NA/OCSC/4/4022.wav | /content/4022.cha | CHI_4022 | 4.0 | 4-5 | IntroRobot | 64.699 | 65.845 | xxx . 64699_65845 | xxx | 1.146 |
| 4 | 4022 | 4 | /content/data/ocsc_raw/Eng-NA/OCSC/4/4022.wav | /content/4022.cha | CHI_4022 | 4.0 | 4-5 | IntroRobot | 65.845 | 66.642 | &-um . 65845_66642 | um | 0.797 |
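The `dur_s` column should equal `end_s - start_s`, and because Whisper's encoder sees at most a 30-second window, any longer clip would be truncated by the feature extractor. A stdlib-only sketch of such a check, with rows given as plain dicts. `check_rows` and the `id` key are hypothetical, the first two rows copy values from the preview above, and the third row is an invented long clip used only to exercise the check:

```python
from typing import Dict, List, Tuple

WHISPER_MAX_SECONDS = 30.0  # Whisper's fixed input window


def check_rows(rows: List[Dict], tol: float = 1e-3) -> Tuple[List, List]:
    """Return (ids whose dur_s disagrees with end_s - start_s, ids longer than 30 s)."""
    mismatched, too_long = [], []
    for r in rows:
        dur = float(r["end_s"]) - float(r["start_s"])
        if abs(dur - float(r["dur_s"])) > tol:
            mismatched.append(r["id"])
        if dur > WHISPER_MAX_SECONDS:
            too_long.append(r["id"])
    return mismatched, too_long


rows = [
    {"id": 0, "start_s": 27.352, "end_s": 28.891, "dur_s": 1.539},
    {"id": 1, "start_s": 32.158, "end_s": 33.708, "dur_s": 1.55},
    {"id": 99, "start_s": 100.0, "end_s": 135.5, "dur_s": 35.5},  # invented long clip
]
mismatched, too_long = check_rows(rows)
print(mismatched, too_long)  # [] [99]
```

On the real manifests the same checks can be run on the DataFrame columns, e.g. `(df_train.end_s - df_train.start_s).gt(30).sum()`.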


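### Audio Clip Extraction
Download the audio tree from the Hugging Face dataset `NolanChai/childes-ocsc` if it is not already cached locally. Audio is stored separately from the manifests so that clips can be loaded on the fly during training.

`resolve_audio_path`: Map a manifest `audio_path` to an existing file by trying the path as-is, then its `Eng-NA/OCSC/...` suffix under `AUDIO_ROOT`, then a `.wav`/`.mp3` extension swap, and finally a search by file stem.

`load_clip_ffmpeg`: Extract audio segments using ffmpeg subprocess calls.
- Resolves the audio path with `resolve_audio_path`
- Slices audio by the start/end timestamps from the manifest
- Downmixes to mono and resamples to 16 kHz (Whisper's expected input format)
- Returns a `(waveform, sample_rate)` tuple, with the waveform as a 1-D float32 PyTorch tensor

ffmpeg is used instead of librosa/torchaudio for stability in the Colab environment.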
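The two path/command steps can be isolated as pure functions, which makes them easy to test without any audio on disk. A minimal sketch; `remap_to_root` and `build_ffmpeg_cmd` are hypothetical names, and the `/old/mount` prefix is an invented example of a stale manifest path. Placing `-ss` before `-i` makes it an input option, so ffmpeg seeks within the input file before decoding:

```python
import re
from pathlib import PurePosixPath
from typing import List, Optional


def remap_to_root(p_str: str, audio_root: str) -> Optional[str]:
    """Re-anchor a manifest path on audio_root via its Eng-NA/OCSC/... suffix (no I/O)."""
    m = re.search(r"(Eng-NA/OCSC/.+)$", p_str)
    if not m:
        return None
    rel = PurePosixPath(m.group(1)).relative_to("Eng-NA/OCSC")
    return str(PurePosixPath(audio_root) / rel)


def build_ffmpeg_cmd(path: str, start_s: float, end_s: float, target_sr: int = 16000) -> List[str]:
    """Argument list that slices [start_s, end_s) into mono WAV on stdout."""
    dur = max(0.01, float(end_s) - float(start_s))  # guard against empty/negative spans
    return [
        "ffmpeg", "-hide_banner", "-loglevel", "error",
        "-ss", f"{float(start_s):.3f}",     # before -i: seek within the input
        "-i", path,
        "-t", f"{dur:.3f}",                 # slice duration in seconds
        "-ac", "1", "-ar", str(target_sr),  # downmix to mono, resample
        "-f", "wav", "pipe:1",              # write WAV bytes to stdout
    ]


moved = remap_to_root("/old/mount/Eng-NA/OCSC/4/4022.wav", "/content/data/ocsc_raw/Eng-NA/OCSC")
cmd = build_ffmpeg_cmd(moved, 27.352, 28.891)
print(moved)  # /content/data/ocsc_raw/Eng-NA/OCSC/4/4022.wav
print(" ".join(cmd))
```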


In [None]:
# Download NolanChai/childes-ocsc if not already present
if not (RAW / "Eng-NA" / "OCSC").exists():
    print("Downloading OCSC audio tree from HF dataset NolanChai/childes-ocsc...")
    snapshot_download(
        repo_id="NolanChai/childes-ocsc",
        repo_type="dataset",
        local_dir=str(RAW),
        local_dir_use_symlinks=False,
    )

AUDIO_ROOT = RAW / "Eng-NA" / "OCSC"
print("Audio root:", AUDIO_ROOT)
assert AUDIO_ROOT.exists()

AUDIO_EXTS = [".wav", ".mp3"]

def resolve_audio_path(p_str: str) -> str:
    """
    Robustly resolve audio_path from manifest to actual file:
    - If path exists as is, use it.
    - Else try suffix starting from Eng-NA/OCSC.
    - Else try swapping .wav/.mp3.
    - Else last-resort search by stem under AUDIO_ROOT.
    """
    p = Path(p_str)
    if p.exists():
        return str(p)

    # Try suffix starting from Eng-NA/OCSC
    m = re.search(r"(Eng-NA/OCSC/.+)$", str(p))
    if m:
        rel = Path(m.group(1))
        cand = AUDIO_ROOT / rel.relative_to("Eng-NA/OCSC")
        if cand.exists():
            return str(cand)
        for ext in AUDIO_EXTS:
            if cand.with_suffix(ext).exists():
                return str(cand.with_suffix(ext))

    # Swap extension in place
    for ext in AUDIO_EXTS:
        alt = p.with_suffix(ext)
        if alt.exists():
            return str(alt)

    # Last resort: search by stem
    stem = p.stem
    hits = list(AUDIO_ROOT.rglob(f"{stem}.*"))
    for h in hits:
        if h.suffix.lower() in AUDIO_EXTS:
            return str(h)

    raise FileNotFoundError(f"Audio not found for: {p_str}")

def load_clip_ffmpeg(path: str, start_s: float, end_s: float, target_sr: int = 16000):
    """
    Slice [start_s, end_s) using ffmpeg into memory and return (waveform_tensor, sr).
    Avoids torchaudio/librosa, stable in Colab.
    """
    path = resolve_audio_path(path)
    dur  = max(0.01, float(end_s) - float(start_s))

    cmd = [
        "ffmpeg", "-hide_banner", "-loglevel", "error",
        "-ss", f"{float(start_s):.3f}",
        "-i", path,
        "-t", f"{dur:.3f}",
        "-ac", "1", "-ar", str(target_sr),
        "-f", "wav", "pipe:1",
    ]
    out = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True)
    data, sr = sf.read(io.BytesIO(out.stdout), dtype="float32", always_2d=False)
    if data.ndim > 1:
        data = data.mean(axis=1)
    return torch.from_numpy(data), sr

### Model Initialization

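Load the English-only Whisper-Medium model (`openai/whisper-medium.en`):
- **Processor**: Tokenizer + feature extractor, always loaded from the base model
- **Model**: Initialized from the latest W&B model artifact when one can be downloaded, otherwise from the pretrained base weights. Only the weights are loaded; optimizer and scheduler state are not restored, so this is a warm start rather than a resumed run

Clear Whisper's default forced decoder IDs and suppressed tokens so that generation during evaluation is not constrained by the pretrained prompt settings while the model adapts to child speech.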

In [None]:
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# Base model id you used originally
BASE_MODEL_ID = "openai/whisper-medium.en"

# Processor never changes during training; always load from base
processor = WhisperProcessor.from_pretrained(
    BASE_MODEL_ID,
    language="en",
    task="transcribe",
)

# Try to load the latest W&B model artifact as initialization
api = wandb.Api()
ENTITY        = "noulan"
PROJECT       = "ocsc-whisper"
ARTIFACT_NAME = "model-whisper-medium-ocsc-ft"  # base name from your artifacts

WANDB_CKPT_ROOT = Path("/content/wandb_checkpoints")
WANDB_CKPT_ROOT.mkdir(exist_ok=True)

model = None
try:
    artifact_id = f"{ENTITY}/{PROJECT}/{ARTIFACT_NAME}:latest"
    artifact = api.artifact(artifact_id, type="model")
    artifact_dir = Path(artifact.download(root=str(WANDB_CKPT_ROOT)))
    print("Loaded weights from W&B artifact:", artifact_id)
    print("Artifact directory:", artifact_dir)

    model = WhisperForConditionalGeneration.from_pretrained(str(artifact_dir))
except Exception as e:
    print("Could NOT load W&B model artifact; falling back to base model.")
    print("Reason:", repr(e))
    model = WhisperForConditionalGeneration.from_pretrained(BASE_MODEL_ID)

model.to(device)

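# For fine-tuning, drop Whisper's forced prompts/suppression.
# generate() reads these settings from generation_config, so clear them there too.
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
model.generation_config.forced_decoder_ids = None
model.generation_config.suppress_tokens = []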

print("Model dtype:", next(model.parameters()).dtype)

### Evaluation Metrics

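Define WER (Word Error Rate) and CER (Character Error Rate) computation:
- WER = (Substitutions + Deletions + Insertions) / Reference Words, the primary metric for ASR evaluation
- CER applies the same formula at the character level, dividing by the number of reference characters
- Both are computed on decoded text (not token IDs), after mapping the -100 label mask back to the pad token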
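Both metrics reduce to a Levenshtein edit distance, over words for WER and over characters for CER. A stdlib sketch for checking the numbers by hand; the helper names are hypothetical and the example strings come from the manifest preview. As we understand it, `evaluate`'s `wer` pools errors across the whole evaluation set (total errors over total reference words) rather than averaging per-utterance rates, which `corpus_wer` imitates:

```python
from typing import List, Sequence


def edit_distance(ref: Sequence, hyp: Sequence) -> int:
    """Levenshtein distance: minimum substitutions + deletions + insertions."""
    prev = list(range(len(hyp) + 1))  # distance from an empty ref prefix to each hyp prefix
    for i, r in enumerate(ref, start=1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            cur[j] = min(
                prev[j] + 1,             # deletion (ref token missing from hyp)
                cur[j - 1] + 1,          # insertion (extra hyp token)
                prev[j - 1] + (r != h),  # substitution, or free match
            )
        prev = cur
    return prev[-1]


def wer(ref: str, hyp: str) -> float:
    words = ref.split()
    return edit_distance(words, hyp.split()) / len(words)


def cer(ref: str, hyp: str) -> float:
    return edit_distance(ref, hyp) / len(ref)


def corpus_wer(refs: List[str], hyps: List[str]) -> float:
    """Pool errors over all utterances, then divide by the total reference words."""
    errors = sum(edit_distance(r.split(), h.split()) for r, h in zip(refs, hyps))
    return errors / sum(len(r.split()) for r in refs)


print(wer("my name is teigan", "my name is tegan"))  # 0.25 (1 substitution / 4 words)
print(cer("my name is teigan", "my name is tegan"))  # 1/17 (one deleted character)
print(corpus_wer(["hello", "my name is teigan"], ["hello", "my name is tegan"]))  # 0.2
```

Because insertions count as errors, WER can reach or exceed 1.0: `wer("hello", "hello there")` is 1.0.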

In [None]:
!pip -q install evaluate jiwer

import evaluate
wer_metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")

def compute_metrics(pred):
    """
    pred: transformers.EvalPrediction
      - pred.predictions: generated token IDs (because predict_with_generate=True)
      - pred.label_ids: label token IDs (with -100 where we ignore)
    """
    pred_ids = pred.predictions
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]
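    pred_ids = np.asarray(pred_ids)
    # Generated IDs can be padded with -100 when eval batches are concatenated; map to pad before decoding
    pred_ids = np.where(pred_ids == -100, processor.tokenizer.pad_token_id, pred_ids)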

    label_ids = pred.label_ids.copy()
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str  = processor.tokenizer.batch_decode(pred_ids,  skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer, "cer": cer}

### PyTorch Dataset for Streaming Audio

`AudioUttDataset`: Custom dataset that:
1. Reads utterance metadata from manifest DataFrame
2. Loads audio clips on-the-fly using ffmpeg (no pre-extraction needed)
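3. Computes log-Mel spectrogram features with the Whisper feature extractor (80 Mel bins × 3000 frames, since every clip is padded or truncated to Whisper's 30 s window)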
4. Tokenizes transcription text as decoder labels

This on-the-fly approach avoids storing pre-extracted clips on disk, at the cost of one ffmpeg call per example.
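The fixed `(80, 3000)` feature shape in the probe output follows from Whisper's feature-extractor settings: 16 kHz audio, a 160-sample hop (10 ms per frame), and a 30-second window. A short sketch of that arithmetic, with the constants written out as assumptions rather than read from the processor. It also shows how much of each window is padding for a short child utterance such as the 1.539 s "hello" clip in the manifest preview:

```python
SAMPLE_RATE = 16_000  # Hz; load_clip_ffmpeg resamples to this rate
HOP_LENGTH = 160      # samples between successive STFT frames (10 ms at 16 kHz)
CHUNK_SECONDS = 30    # Whisper pads or truncates every input to this length
N_MELS = 80           # Mel bins per frame for whisper-medium


def window_frames(seconds: int = CHUNK_SECONDS, sr: int = SAMPLE_RATE, hop: int = HOP_LENGTH) -> int:
    """Number of feature frames covering a window of `seconds`."""
    return seconds * sr // hop


def padding_fraction(clip_seconds: float, window_seconds: float = CHUNK_SECONDS) -> float:
    """Share of the fixed window that is padding rather than speech."""
    return max(0.0, 1.0 - clip_seconds / window_seconds)


frames = window_frames()                  # 30 * 16000 // 160
print((N_MELS, frames))                   # (80, 3000)
print(round(padding_fraction(1.539), 4))  # 0.9487
```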

`DataCollatorSpeechSeq2SeqWithPadding`: Prepare batches for training:
- **Encoder inputs**: Pad log-Mel features to a common length within the batch (a safeguard, since the feature extractor already pads every clip to 3000 frames)
- **Decoder labels**: Pad token sequences, mask padding with -100 (ignored in loss)

Handles variable-length audio and text within each batch.

Before training, load one small batch to verify that shapes and dtypes are correct; this catches data-pipeline bugs early.
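Label padding interacts with Whisper's special tokens: its tokenizer uses `<|endoftext|>` both as the padding token and as the end-of-transcript marker, so replacing every pad id with -100 after padding also removes the genuine end-of-text target from the loss. A pure-Python sketch contrasting the two strategies on toy token lists; `EOT`, the other ids, and both helper names are hypothetical:

```python
from typing import List

IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss by default
EOT = 50256          # hypothetical id standing in for <|endoftext|> (also the pad id here)


def pad_then_mask_by_value(seqs: List[List[int]], pad_id: int = EOT) -> List[List[int]]:
    """Pad with pad_id, then mask every pad_id; this also masks real end-of-text tokens."""
    width = max(len(s) for s in seqs)
    padded = [s + [pad_id] * (width - len(s)) for s in seqs]
    return [[IGNORE_INDEX if t == pad_id else t for t in row] for row in padded]


def pad_with_ignore_index(seqs: List[List[int]]) -> List[List[int]]:
    """Pad directly with IGNORE_INDEX, so only the added positions are masked."""
    width = max(len(s) for s in seqs)
    return [s + [IGNORE_INDEX] * (width - len(s)) for s in seqs]


batch = [[11, 12, EOT], [11, 12, 13, 14, EOT]]
print(pad_then_mask_by_value(batch))  # [[11, 12, -100, -100, -100], [11, 12, 13, 14, -100]]
print(pad_with_ignore_index(batch))   # [[11, 12, 50256, -100, -100], [11, 12, 13, 14, 50256]]
```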

In [None]:
import torch
from torch.utils.data import DataLoader
from dataclasses import dataclass
from typing import Dict, List, Any

class AudioUttDataset(torch.utils.data.Dataset):
    def __init__(self, df: pd.DataFrame, processor: WhisperProcessor):
        self.df = df.reset_index(drop=True)
        self.proc = processor

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        r = self.df.iloc[idx]
        y, sr = load_clip_ffmpeg(
            r["audio_path"],
            float(r["start_s"]),
            float(r["end_s"]),
        )
        feats = self.proc.feature_extractor(
            y.numpy(),
            sampling_rate=sr,
            return_attention_mask=False,
        )
        labels = self.proc.tokenizer(r["norm_text"]).input_ids

        return {
            "input_features": feats["input_features"][0],  # (80, T) float32
            "labels": labels,                              # list[int]
        }

train_ds = AudioUttDataset(df_train, processor)
dev_ds   = AudioUttDataset(df_dev,   processor)

print("Train examples:", len(train_ds))
print("Dev examples:",   len(dev_ds))

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        # ----- Encoder features -----
        feats = [f["input_features"] for f in features]
        max_T = max(f.shape[-1] for f in feats)

        feat_tensors = []
        for f in feats:
            if f.shape[-1] < max_T:
                pad_T = max_T - f.shape[-1]
                f = np.pad(
                    f,
                    pad_width=((0, 0), (0, pad_T)),
                    mode="constant",
                    constant_values=0.0,
                )
            feat_tensors.append(torch.tensor(f, dtype=torch.float32))
        input_features = torch.stack(feat_tensors)  # (B, 80, max_T)

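        # ----- Decoder labels -----
        # Pad directly with -100 (ignored by the loss). Masking every pad_token_id
        # after padding would also hide the real <|endoftext|> target, because
        # Whisper's tokenizer uses that same token for padding.
        label_tensors = [torch.tensor(f["labels"], dtype=torch.long) for f in features]
        labels = torch.nn.utils.rnn.pad_sequence(
            label_tensors,
            batch_first=True,
            padding_value=-100,
        )

        # The model prepends decoder_start_token_id (<|startoftranscript|>) when it
        # shifts labels right, so drop it here if the tokenizer already added it
        sot_id = self.processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
        if (labels[:, 0] == sot_id).all():
            labels = labels[:, 1:]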

        return {
            "input_features": input_features,
            "labels": labels,
        }

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

# quick shape sanity check
probe_loader = DataLoader(train_ds, batch_size=2, shuffle=False, collate_fn=data_collator)
probe_batch = next(iter(probe_loader))
for k, v in probe_batch.items():
    print(k, v.shape, v.dtype)

Train examples: 96908
Dev examples: 11230
input_features torch.Size([2, 80, 3000]) torch.float32
labels torch.Size([2, 10]) torch.int64
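During training, `WhisperForConditionalGeneration` builds its decoder inputs from `labels` by shifting them one position to the right: it inserts `decoder_start_token_id` (`<|startoftranscript|>`) at the front and replaces -100 with the pad id. If the labels already begin with `<|startoftranscript|>`, the decoder therefore sees that token twice. A sketch of the shift on plain lists; the token ids and the `shift_right` name are hypothetical placeholders:

```python
from typing import List

IGNORE_INDEX = -100


def shift_right(labels: List[List[int]], start_id: int, pad_id: int) -> List[List[int]]:
    """Decoder inputs from labels: prepend start_id, drop the last label, map -100 to pad_id."""
    out = []
    for row in labels:
        shifted = [start_id] + row[:-1]
        out.append([pad_id if t == IGNORE_INDEX else t for t in shifted])
    return out


SOT, EOT = 1, 2  # hypothetical stand-ins for <|startoftranscript|> and <|endoftext|>
with_sot = [[SOT, 7, 8, EOT, IGNORE_INDEX]]
without_sot = [[7, 8, EOT, IGNORE_INDEX, IGNORE_INDEX]]

print(shift_right(with_sot, start_id=SOT, pad_id=EOT))     # [[1, 1, 7, 8, 2]]  (SOT twice)
print(shift_right(without_sot, start_id=SOT, pad_id=EOT))  # [[1, 7, 8, 2, 2]]
```

This is why the collator in Hugging Face's Whisper fine-tuning blog post drops a leading `<|startoftranscript|>` from the labels.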


## Training Arguments

### Hyperparameters
- **Epochs**: 2 (target, may not complete due to compute constraints)
- **Batch size**: 16 per device
- **Learning rate**: 1e-5 with 10% warmup
- **Weight decay**: 0.01
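With these settings the length of the run can be worked out in advance. A minimal sketch, assuming a single GPU, `gradient_accumulation_steps=1`, and a last partial batch that is kept, which is how we understand the Trainer to derive total steps and warmup steps (`ceil` of the warmup fraction). `schedule` is a hypothetical helper:

```python
import math
from typing import Tuple


def schedule(n_examples: int, per_device_bs: int, n_devices: int = 1, grad_accum: int = 1,
             epochs: int = 2, warmup_ratio: float = 0.1,
             eval_steps: int = 500) -> Tuple[int, int, int, int]:
    """Return (steps_per_epoch, total_steps, warmup_steps, n_evals)."""
    batches_per_epoch = math.ceil(n_examples / (per_device_bs * n_devices))  # last batch kept
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)                  # optimizer updates
    total_steps = math.ceil(epochs * steps_per_epoch)
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    n_evals = total_steps // eval_steps                                        # at 500, 1000, ...
    return steps_per_epoch, total_steps, warmup_steps, n_evals


print(schedule(n_examples=96_908, per_device_bs=16))  # (6057, 12114, 1212, 24)
print(math.ceil(11_230 / 16))                         # 702 dev batches per evaluation
```

By the same arithmetic, a checkpoint at step 1000 corresponds to 16,000 training examples, about a sixth of one epoch.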

### Checkpointing Strategy
- Evaluate and save every 500 steps
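- Keep at most 1 checkpoint (`save_total_limit=1`) to manage storage; because `load_best_model_at_end=True`, the Trainer may also retain the best checkpoint when it differs from the latest one
- Load the best model at the end, ranked by eval loss (the default when `metric_for_best_model` is unset)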
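Each checkpoint directory contains a `trainer_state.json` whose `log_history` list records the training and evaluation logs: every entry has a `step`, and evaluation entries carry `eval_loss`. The Trainer also records its own choice in `best_model_checkpoint`. A small sketch for inspecting that history by hand; `best_eval_step` is hypothetical and the numbers are made up, not results from this run:

```python
import json
from typing import Optional, Tuple


def best_eval_step(state: dict, metric: str = "eval_loss",
                   greater_is_better: bool = False) -> Optional[Tuple[int, float]]:
    """Return (step, value) of the best logged evaluation, or None if none were logged."""
    evals = [e for e in state.get("log_history", []) if metric in e]
    if not evals:
        return None
    pick = max if greater_is_better else min
    best = pick(evals, key=lambda e: e[metric])
    return best["step"], best[metric]


# Made-up excerpt shaped like trainer_state.json (training logs carry "loss", evals "eval_loss")
state = json.loads("""
{"log_history": [
  {"step": 450, "loss": 0.71},
  {"step": 500, "eval_loss": 0.64, "eval_wer": 0.31},
  {"step": 1000, "eval_loss": 0.58, "eval_wer": 0.28}
]}
""")
print(best_eval_step(state))  # (1000, 0.58)
```

For a real run, the state can be read with `json.loads(Path(checkpoint_dir, "trainer_state.json").read_text())`.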

### Memory Optimization
- Gradient checkpointing enabled (trades compute for memory)
- FP16 disabled for stability

### Seq2Seq Trainer Setup
Initialize the Hugging Face `Seq2SeqTrainer` with the model, datasets, data collator, metrics function, and training arguments. The trainer handles the training loop, evaluation, checkpointing, and W&B logging.

In [None]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

training_args = Seq2SeqTrainingArguments(
    output_dir=str(RUN_OUTPUT_DIR),

    # ===== Core =====
    num_train_epochs=2,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=1,

    learning_rate=1e-5,
    warmup_ratio=0.1,
    weight_decay=0.01,
    label_smoothing_factor=0.0,

    fp16=False,
    gradient_checkpointing=True,

    # ===== Eval & logging =====
    eval_strategy="steps",
    eval_steps=500,
    logging_strategy="steps",
    logging_steps=50,

    # ===== Checkpoints =====
    save_strategy="steps",
    save_steps=500,
    save_total_limit=1,

    load_best_model_at_end=True,
    metric_for_best_model=None,
    greater_is_better=False,

    # ===== Generation during eval =====
    predict_with_generate=True,
    generation_max_length=225,
    generation_num_beams=1,

    # ===== Misc / perf =====
    dataloader_num_workers=4,
    remove_unused_columns=False,
    report_to=["wandb"],
    run_name=RUN_NAME,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=dev_ds,
    tokenizer=processor.tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

### Execute Training
Start fine-tuning. Training progress is logged to W&B, and checkpoints are saved to Google Drive so they persist across Colab disconnections.

**Note**: A full 2-epoch run may take multiple days on a Colab GPU. Monitor W&B for loss curves and eval metrics.

In [None]:
train_result = trainer.train()
print(train_result)

# Save final model + processor for later eval / inference
OUT_DIR = BASE_DRIVE / "experiments" / RUN_NAME / "final"
OUT_DIR.mkdir(parents=True, exist_ok=True)
model.save_pretrained(str(OUT_DIR))
processor.save_pretrained(str(OUT_DIR))
print("Saved fine-tuned model to:", OUT_DIR)

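### Export Final Model
The training cell above already writes the fine-tuned model and processor to `experiments/<RUN_NAME>/final` on Google Drive, as a self-contained directory loadable with `from_pretrained()`. The next cell remounts Drive (for example, after a runtime restart) and recursively lists the checkpoint directories under `ocsc_whisper_runs` to confirm what was persisted.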

In [None]:
from pathlib import Path

from google.colab import drive
drive.mount("/content/drive")

BASE_DRIVE = Path("/content/drive/MyDrive")
RUNS_DIR   = BASE_DRIVE / "ocsc_whisper_runs"

!ls -R "$RUNS_DIR"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/ocsc_whisper_runs:
ocsc-whisper-medium-ocsc-20250303-121530
ocsc-whisper-medium-ocsc-20251202-144503
whisper-medium-ocsc-ft-20251202-153226
whisper-medium-ocsc-ft-20251202-153450
whisper-medium-ocsc-ft-20251203-154530-phase
whisper-medium-ocsc-ft-phase2

/content/drive/MyDrive/ocsc_whisper_runs/ocsc-whisper-medium-ocsc-20250303-121530:
checkpoint-1000  checkpoint-500

/content/drive/MyDrive/ocsc_whisper_runs/ocsc-whisper-medium-ocsc-20250303-121530/checkpoint-1000:
added_tokens.json	normalizer.json		 tokenizer_config.json
config.json		optimizer.pt		 trainer_state.json
generation_config.json	rng_state.pth		 training_args.bin
merges.txt		scheduler.pt		 vocab.json
model.safetensors	special_tokens_map.json

/content/drive/MyDrive/ocsc_whisper_runs/ocsc-whisper-medium-ocsc-20250303-121530/checkpoint-500:
added_tokens.json	normalizer.json		 t

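### Verify Saved Checkpoints
List the checkpoints saved for this run and identify the latest one with `get_last_checkpoint`. The returned path can be passed to `trainer.train(resume_from_checkpoint=...)` to resume training, or to `from_pretrained()` to evaluate an intermediate checkpoint.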

In [None]:
RUN_DIR = RUNS_DIR / "whisper-medium-ocsc-ft-20251203-154530-phase"
!ls "$RUN_DIR"

checkpoint-1000


In [None]:
from transformers.trainer_utils import get_last_checkpoint
last_ckpt = get_last_checkpoint(str(RUN_DIR))
print("Last local checkpoint:", last_ckpt)

Last local checkpoint: /content/drive/MyDrive/ocsc_whisper_runs/whisper-medium-ocsc-ft-20251203-154530-phase/checkpoint-1000
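`get_last_checkpoint` compares the numeric step in `checkpoint-<step>` folder names rather than the names as strings. This matters because `checkpoint-500` sorts after `checkpoint-1000` lexicographically, as the `ls` listing above shows. A stdlib sketch of the same idea, roughly mirroring what we understand the Transformers helper to do; `last_checkpoint` is a hypothetical stand-in:

```python
import re
import tempfile
from pathlib import Path
from typing import Optional

CKPT_RE = re.compile(r"^checkpoint-(\d+)$")  # e.g. checkpoint-1000


def last_checkpoint(run_dir: str) -> Optional[str]:
    """Path of the checkpoint-<step> subdirectory with the highest step, or None."""
    ckpts = [p for p in Path(run_dir).iterdir() if p.is_dir() and CKPT_RE.match(p.name)]
    if not ckpts:
        return None
    return str(max(ckpts, key=lambda p: int(CKPT_RE.match(p.name).group(1))))


with tempfile.TemporaryDirectory() as tmp:
    for name in ["checkpoint-500", "checkpoint-1000", "runs"]:
        (Path(tmp) / name).mkdir()
    picked = Path(last_checkpoint(tmp)).name
    naive = sorted(p.name for p in Path(tmp).iterdir() if p.name.startswith("checkpoint-"))[-1]

print(picked)  # checkpoint-1000
print(naive)   # checkpoint-500 (a plain string sort picks the wrong one)
```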
