# Fine-tuning Template Notebook

In [None]:
# One-time environment setup: uncomment, run once, then restart the kernel.
# `%pip` (rather than `!pip`) installs into the environment of the running kernel.
# NOTE(review): "finetuning_requirements.txt.txt" looks like a doubled extension - confirm the real requirements filename.
# %pip uninstall -y torch torchaudio torchvision triton torchtext bitsandbytes
# %pip install -r finetuning_requirements.txt.txt --force-reinstall --no-cache-dir
# %pip install torchvision==0.19.1 --index-url https://download.pytorch.org/whl/cu121
# %pip install -U boto3==1.34.69 botocore==1.34.69 aiobotocore==2.12.3 s3fs==2024.3.1

In [None]:
import torch
import bitsandbytes as bnb
import triton
# Report the versions of the GPU-related stack and whether CUDA is usable
for _label, _value in (
    ("Torch:", torch.__version__),
    ("bitsandbytes:", bnb.__version__),
    ("Triton:", triton.__version__),
    ("CUDA available:", torch.cuda.is_available()),
):
    print(_label, _value)

In [None]:
import torchvision
# Confirm torch and torchvision come from matching builds
for _label, _version in (("Torch:", torch.__version__), ("Torchvision:", torchvision.__version__)):
    print(_label, _version)

In [None]:
import datetime as dt
import json
import os
import platform
import re
import subprocess
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import boto3
import numpy as np
import pandas as pd
from datasets import Audio, Dataset, DatasetDict, IterableDataset, load_dataset, load_from_disk
from jiwer import wer
from peft import (
    LoraConfig,
    AdaLoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from tqdm import tqdm
from transformers import (
    WhisperProcessor,
    WhisperTokenizer,
    WhisperFeatureExtractor,
    WhisperForConditionalGeneration,
    BitsAndBytesConfig,
    DataCollatorWithPadding,
    TrainingArguments,
    Trainer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

### Project Configuration
- Set `PROJECT_NAME` to `your_name/experiment_description`. It becomes part of the output directory and the S3 output prefix, which keeps experiments separate and traceable.

In [None]:
# Project Config
PROJECT_NAME = "dave/EXPERIMENT_NAME_HERE"
BUCKET = "asrelder-data"
CLIPS_PREFIX = "common_voice/23/cv-corpus-23.0-2025-09-05/en/clips/"
OUTPUT_PREFIX = f"experiments/{PROJECT_NAME}"
BASE_S3_PREFIX = f"s3://{BUCKET}/{CLIPS_PREFIX}"

# Keys
os.environ["AWS_ACCESS_KEY_ID"] = "FILL_ME_IN"
os.environ["AWS_SECRET_ACCESS_KEY"] = "FILL_ME_IN"
os.environ["AWS_DEFAULT_REGION"] = "FILL_ME_IN"

# Confirm
!aws sts get-caller-identity

# Test
sts = boto3.client("sts")
print(sts.get_caller_identity())

### Download the train, validation, and test CSVs from Google Drive

In [None]:
# Get csvs from GDrive
# Google Drive file id per split (the trailing comment names the source CSV)
DRIVE_FILE_IDS = {
    "train": "1AdCeMxDcE4rxqWSyPsfEh7TaS5dWjXD5",  # common_voices_23_balanced_on_60.csv
    "val": "1GzrujHvGwA7MA9awtQI4IFRQIdYcLiBO",  # common_voices_23_val_full.csv
    "test": "1bSjhB8WTDZWBTuppB-vU56AgOEzNNAeN",  # common_voices_23_test_full.csv
}

os.makedirs("data", exist_ok=True)


def download_from_drive(name, file_id, out_dir="data"):
    """Download one split CSV from Google Drive using the `gdown` CLI.

    Args:
        name: split name; the file is saved as ``{out_dir}/{name}.csv``.
        file_id: Google Drive file id.
        out_dir: destination directory (created if missing). Defaults to "data".

    Returns:
        The path of the downloaded CSV.

    Raises:
        subprocess.CalledProcessError: if `gdown` exits with a non-zero status.
    """
    os.makedirs(out_dir, exist_ok=True)
    out_path = f"{out_dir}/{name}.csv"
    url = f"https://drive.google.com/uc?id={file_id}"
    print(f"Downloading {name} split from Google Drive → {out_path}")
    subprocess.run(["gdown", "--fuzzy", url, "-O", out_path], check=True)
    return out_path


# Download every split (in train, val, test order) and record where each one landed
DATA_FILES = {split: download_from_drive(split, file_id) for split, file_id in DRIVE_FILE_IDS.items()}
TRAIN_PATH = DATA_FILES["train"]
VAL_PATH = DATA_FILES["val"]
TEST_PATH = DATA_FILES["test"]

In [None]:
# Print the header row of each split CSV so schema differences between splits are visible up front
for split in ("train", "val", "test"):
    csv_path = f"data/{split}.csv"
    header = pd.read_csv(csv_path, nrows=1).columns
    print(f"{csv_path} columns: {list(header)}")

In [None]:
# Add age_group for val and test; column present in train

# Word forms -> decade label. Matched by substring, in this order.
# Common Voice spells the 40-49 bucket "fourties", so that spelling is listed explicitly.
_AGE_WORDS_TO_DECADE = {
    "teen": "10s", "teens": "10s",
    "twenty": "20s", "twenties": "20s",
    "thirty": "30s", "thirties": "30s",
    "forty": "40s", "forties": "40s", "fourty": "40s", "fourties": "40s",
    "fifty": "50s", "fifties": "50s",
    "sixty": "60s", "sixties": "60s",
    "seventy": "70s", "seventies": "70s",
    "eighty": "80s", "eighties": "80s",
    "ninety": "90s", "nineties": "90s",
}


def _infer_age_group(age_value):
    """Map one raw 'age' value to a decade label such as '20s'; '' when missing or unrecognised."""
    if pd.isna(age_value):
        return ""
    s = str(age_value).strip().lower()
    # Two-digit ages, including "23.0": pandas reads a numeric column with missing values as float
    match = re.match(r"^(\d{2})(?:\.0+)?$", s)
    if match:
        return f"{int(match.group(1)) // 10 * 10}s"
    # Ranges like "25-34" take the decade of the lower bound
    match = re.match(r"(\d{2})\s*-\s*(\d{2})", s)
    if match:
        return f"{int(match.group(1)) // 10 * 10}s"
    # Words like 'twenties', 'forty', etc.
    for word, decade in _AGE_WORDS_TO_DECADE.items():
        if word in s:
            return decade
    return ""  # unknown or other format


def add_age_group_column(file_path: str):
    """
    Add an 'age_group' column to the CSV if it is missing.

    The value is derived from the 'age' column as a decade label, for example
    '23' -> '20s', '25-34' -> '20s', 'twenties' -> '20s'. Missing or unrecognised
    ages become ''. Files that already have 'age_group' are left untouched.
    NOTE: Modifies the file in place.
    """
    df = pd.read_csv(file_path)
    if "age_group" in df.columns:
        print(f"Skipping, 'age_group' already exists in {file_path}")
        return

    df["age_group"] = df["age"].apply(_infer_age_group)
    df.to_csv(file_path, index=False)
    print(f"Added 'age_group' to {file_path} ({len(df)} rows)")

# val/test lack 'age_group' (train already has it)
for split_file in ("data/val.csv", "data/test.csv"):
    add_age_group_column(split_file)

In [None]:
def add_s3_paths(file_path: str, base_prefix=None):
    """
    Turn relative entries of the CSV's 'path' column into full S3 URIs, in place.

    Args:
        file_path: CSV to rewrite.
        base_prefix: S3 prefix to prepend; defaults to the notebook's BASE_S3_PREFIX.
            A trailing "/" is ignored, so the result never contains "//" (S3 would
            treat "clips//x.mp3" as a different key from "clips/x.mp3").

    Values that already start with "s3://" and non-string values (e.g. NaN) are
    left unchanged, so running this twice is safe. If the file has no 'path'
    column it is not rewritten.
    """
    df = pd.read_csv(file_path)
    print(f"Processing {file_path} ({len(df)} rows)")
    if "path" not in df.columns:
        print(f"No 'path' column in {file_path}; nothing to update\n")
        return

    prefix = (BASE_S3_PREFIX if base_prefix is None else base_prefix).rstrip("/")

    def prepend_prefix(p):
        if isinstance(p, str) and not p.startswith("s3://"):
            return f"{prefix}/{p.lstrip('/')}"
        return p

    df["path"] = df["path"].apply(prepend_prefix)
    print("Updated 'path' column with S3 prefix")

    df.to_csv(file_path, index=False)
    print(f"Saved updated CSV: {file_path}\n")
    print(f"Updated 'path' for {file_path} ({len(df)} rows)")

# Apply to every split. Paths already starting with "s3://" are left as they are,
# so including train (or re-running the cell) is harmless.
for split_csv in ("data/train.csv", "data/val.csv", "data/test.csv"):
    add_s3_paths(split_csv)

### Load Whisper model

In [None]:
# Whisper model + language/task settings
WHISPER_MODEL = "openai/whisper-base"
LANGUAGE = "en"
TASK = "transcribe"

# Compute and training
# MIXED_PRECISION: "fp16", "bf16", or "no" (full fp32)
MIXED_PRECISION = "fp16"
GRADIENT_ACCUMULATION = 2
BATCH_SIZE_PER_DEVICE = 4
NUM_EPOCHS = 3
LEARNING_RATE = 0.0001
MAX_AUDIO_SECONDS = 30

# PEFT method
# NOTE: qlora, dora, adalora, none
PEFT_METHOD = "qlora"

# ReFT toggle (stubbed – not used by any cell yet)
# NOTE(review): the comment used to point at a "ReFT cell" that is not in this notebook - confirm.
ENABLE_REFT = False

# Fail fast on typos. An unknown precision would silently train in fp32, and an
# unknown PEFT method would only be rejected after the model download.
if MIXED_PRECISION not in ("fp16", "bf16", "no"):
    raise ValueError(f"MIXED_PRECISION must be 'fp16', 'bf16' or 'no', got {MIXED_PRECISION!r}")
if PEFT_METHOD not in ("qlora", "dora", "adalora", "none"):
    raise ValueError(f"PEFT_METHOD must be one of 'qlora', 'dora', 'adalora', 'none', got {PEFT_METHOD!r}")

print({
    "python": platform.python_version(),
    "cuda_available": torch.cuda.is_available(),
    "torch_version": torch.__version__,
    "device_count": torch.cuda.device_count(),
    "device_name": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu",
    "time": dt.datetime.now().isoformat(timespec="seconds"),
})

In [None]:
# SageMaker Utils
# Detect SageMaker environment, map S3 paths

def in_sagemaker() -> bool:
    """True when running inside a SageMaker job (any SM_* env var, or SAGEMAKER_JOB_NAME is set)."""
    return any(k.startswith("SM_") for k in os.environ) or os.environ.get("SAGEMAKER_JOB_NAME") is not None


def s3_join(*parts: str) -> str:
    """Join S3 path pieces into "bucket/key/..." (no scheme, single slashes).

    A leading "s3://" is removed from each piece, surrounding slashes are trimmed,
    and empty pieces are skipped. Empty pieces would otherwise produce "//", which
    S3 treats as part of a different key.
    """
    cleaned = []
    for part in parts:
        if part.startswith("s3://"):
            part = part[len("s3://"):]
        part = part.strip("/")
        if part:
            cleaned.append(part)
    return "/".join(cleaned)


print("In SageMaker:", in_sagemaker())

In [None]:
# Data cleaning
def clean_numeric_columns(file_path: str, numeric_cols=("variant", "segment")):
    """Rewrite the CSV in place with the listed columns coerced to numbers.

    Columns that are absent are ignored; values that cannot be parsed become NaN.
    """
    frame = pd.read_csv(file_path)
    present = [column for column in numeric_cols if column in frame.columns]
    for column in present:
        frame[column] = pd.to_numeric(frame[column], errors="coerce")
    frame.to_csv(file_path, index=False)
    print(f"Cleaned {file_path}: non-numeric entries coerced to NaN.")

# Coerce numeric columns in every split CSV
for split_csv in map("data/{}.csv".format, ("train", "val", "test")):
    clean_numeric_columns(split_csv)

In [None]:
# Data Loading
# Column names in the split CSVs: audio location and transcript
audio_col, text_col = "path", "sentence"

# The CSVs are read in streaming mode (streaming=True) rather than copying over
# ~150,000 audio clips; audio referenced by `path` is only fetched when decoded later.
# Non-streaming alternative: raw_datasets = load_dataset("csv", data_files=DATA_FILES)
print("Loading CSVs in streaming mode...")
raw_datasets = load_dataset("csv", data_files=DATA_FILES, streaming=True)
print(raw_datasets)

In [None]:
# Use half of the available CPUs (at least 1).
# os.cpu_count() returns None when the count cannot be determined, so fall back to 1
# instead of failing on `None // 2`.
print(f"CPU count: {os.cpu_count()}")
NUM_PROC = max(1, (os.cpu_count() or 1) // 2)

# Cache the audio files so that we don't have to stream them from S3
CACHE_DIR = "data/processed_whisper"

### Whisper Processor & Preprocessing

Convert each example into the tensors the model consumes:
- Load the audio from S3 or local disk
- Resample to 16 kHz
- Compute the short-time Fourier transform (STFT)
- Convert to the Mel scale
- Apply log compression
- Normalize to match Whisper's training statistics
- Tokenize the transcript into label ids for the decoder

In [None]:
# Whisper Processor & Preprocessing
processor = WhisperProcessor.from_pretrained(
    WHISPER_MODEL,
    language=LANGUAGE,
    task=TASK
)
feature_extractor: WhisperFeatureExtractor = processor.feature_extractor
tokenizer: WhisperTokenizer = processor.tokenizer
# Maximum clip length in samples at the feature extractor's sampling rate
MAX_INPUT_LENGTH = int(MAX_AUDIO_SECONDS * feature_extractor.sampling_rate)

# NOTE: Peek at one streamed example to discover the CSV columns
first_example = next(iter(raw_datasets["train"]))
column_names = list(first_example.keys())
print("Detected columns:", column_names)
# Drop *every* original column after preprocessing. The preprocessing function only
# returns `input_features` and `labels`; keeping the audio/text columns would also
# carry the audio column into the cached dataset and hand unused columns to the Trainer.
remove_columns = list(column_names)
print("Removed columns:", remove_columns)

# Preprocessing Function
def prepare_example(batch):
    """Turn one decoded example into Whisper encoder features and decoder label ids.

    Args:
        batch: a single example. `batch[audio_col]` is the decoded audio (dict-like
            with an "array" entry and usually "sampling_rate"); `batch[text_col]` is
            the transcript, possibly None when the CSV cell was empty.

    Returns:
        {"input_features": log-Mel features of the clip (truncated to MAX_AUDIO_SECONDS),
         "labels": token ids of the transcript}.
    """
    audio = batch[audio_col]
    # HF Audio feature will lazily decode from S3 on the fly
    arr = audio["array"]
    if isinstance(audio, dict):
        sr = audio.get("sampling_rate", feature_extractor.sampling_rate)
    else:
        # Non-dict decoded audio objects are expected to support item access too
        sr = audio["sampling_rate"]

    if arr.shape[0] > MAX_INPUT_LENGTH:
        arr = arr[:MAX_INPUT_LENGTH]

    inputs = feature_extractor(arr, sampling_rate=sr)
    text = batch[text_col]
    # A missing transcript is tokenised as "" rather than crashing the whole map
    labels = tokenizer("" if text is None else str(text)).input_ids
    return {
        "input_features": inputs["input_features"][0],
        "labels": labels
    }

# Testing Mode (subset for quick iteration)
TESTING_MODE = True
# Rows taken from EACH split (train / val / test) when TESTING_MODE is on
TESTING_SAMPLES_PER_SPLIT = 100


def _process_split(split, limit=None, desc="Processing Whisper"):
    """Convert one streamed CSV split into a Dataset of Whisper `input_features` / `labels`.

    Args:
        split: key of `raw_datasets` ("train", "val", "test").
        limit: if given, only the first `limit` rows of the stream are used.
        desc: progress-bar label for `.map`.
    """
    stream = raw_datasets[split]
    if limit is not None:
        stream = stream.take(limit)  # streaming-safe slice

    def _rows():
        yield from stream

    # Copy the CSV rows into an Arrow-backed Dataset first. `IterableDataset` has no
    # `save_to_disk` and its `.map` takes no `desc`, so the stream itself cannot be
    # processed and cached the way a regular Dataset can.
    table = Dataset.from_generator(_rows)
    table = table.cast_column(audio_col, Audio(sampling_rate=feature_extractor.sampling_rate))
    return table.map(
        prepare_example,
        # Per-split column list: val/test CSVs need not have exactly the train columns
        remove_columns=table.column_names,
        desc=f"{desc} ({split})",
    )


# Every split is processed so evaluation uses real validation data (not the training
# rows) and the test split is available for the final evaluation.
if TESTING_MODE:
    print(f"Running in TESTING_MODE ({TESTING_SAMPLES_PER_SPLIT} samples per split)...")
    processed = DatasetDict({
        split: _process_split(split, TESTING_SAMPLES_PER_SPLIT, "Processing Whisper sample")
        for split in raw_datasets
    })
    print(processed)
    processed.save_to_disk(CACHE_DIR)
    print(f"💾 Cached small sample to {CACHE_DIR}")

else:
    print("\n🚀 Running full preprocessing stream (this may take a while)...\n")
    start = time.time()
    processed = DatasetDict({
        split: _process_split(split, desc="Processing Whisper full stream")
        for split in raw_datasets
    })
    print(f"⏱ Completed in {(time.time() - start)/60:.2f} min")
    processed.save_to_disk(CACHE_DIR)
    print(f"Cached full stream to {CACHE_DIR}")

### Load the processed tensors from disk (inside data/ directory)

In [None]:
# Load the processed tensors from disk (written by the preprocessing cell to CACHE_DIR)
proc_datasets = load_from_disk(CACHE_DIR)
print(proc_datasets)

### Training Configuration

In [None]:
# Model & PEFT Setup
# NOTE: Do not use `task_type`, it causes 'input_ids' issues

# Define which modules to apply LoRA on for Whisper
# NOTE: I don't know that these are the "right" layers, they're from: https://github.com/openai/whisper/discussions/830
whisper_layers = [
    "q_proj",
    "k_proj",
    "v_proj",
    "out_proj",
    "fc1",
    "fc2",
]

# Choose the PEFT config *before* loading the model, so an unsupported
# PEFT_METHOD fails immediately instead of after the model download.
peft_config = None
if PEFT_METHOD == "qlora":
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        # task_type="SEQ_2_SEQ_LM",
        target_modules=whisper_layers,
    )
elif PEFT_METHOD == "dora":
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        # task_type="SEQ_2_SEQ_LM",
        target_modules=whisper_layers,
        use_dora=True,  # DoRA flag in PEFT
    )
elif PEFT_METHOD == "adalora":
    peft_config = AdaLoraConfig(
        init_r=12,   # starting rank
        target_r=8,  # target rank after adaptation
        beta1=0.85,
        beta2=0.85,
        tinit=200,
        tfinal=1000,
        deltaT=10,
        lora_alpha=32,
        lora_dropout=0.05,
        orth_reg_weight=0.5,
        target_modules=whisper_layers,
        # task_type="SEQ_2_SEQ_LM",
    )
else:
    # Covers "none" (full-parameter fine-tuning, not recommended on small GPUs) and typos
    raise ValueError(
        f"Unsupported PEFT_METHOD {PEFT_METHOD!r}: full parameter fine tuning is not supported "
        "right now. Provide one of 'qlora', 'dora', 'adalora'."
    )

# 4-bit quantisation is only used for QLoRA
bnb_config = None
load_in_4bit = False
if PEFT_METHOD == "qlora":
    load_in_4bit = True
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )

quant_kwargs = dict(device_map="auto", torch_dtype=torch.float16)
if bnb_config is not None:
    quant_kwargs["quantization_config"] = bnb_config

model = WhisperForConditionalGeneration.from_pretrained(
    WHISPER_MODEL,
    **quant_kwargs,
)

# Gradient checkpointing is very helpful for Whisper fine-tuning
model.gradient_checkpointing_enable()

# Prepare model for k-bit training if using QLoRA
if load_in_4bit:
    model = prepare_model_for_kbit_training(model)


def make_inputs_require_grad(module, inputs, output):
    """Forward hook: mark the encoder's first conv output as requiring grad.

    The encoder's inputs (audio features) and frozen conv stem produce tensors that
    do not require grad. With (reentrant) gradient checkpointing, that leaves the
    checkpointed encoder layers - and the LoRA weights inside them - without
    gradients. This follows the PEFT Whisper fine-tuning example.
    """
    output.requires_grad_(True)


model.model.encoder.conv1.register_forward_hook(make_inputs_require_grad)

if peft_config is not None:
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

### Training

In [None]:
# Hugging Face Training

# Metric
def compute_metrics(pred):
    """Compute the word error rate (WER) of generated transcripts.

    Args:
        pred: the Trainer's EvalPrediction. With `predict_with_generate=True`,
            `pred.predictions` holds generated token ids (possibly wrapped in a tuple)
            and `pred.label_ids` the reference ids with -100 at ignored positions.

    Returns:
        {"wer": float}; NaN when no example has a non-empty reference.
    """
    pred_ids = pred.predictions
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]
    pad_id = tokenizer.pad_token_id
    # -100 is not a real token id: it marks ignored label positions and may also pad
    # generated sequences of different lengths across eval batches, so map it back to
    # the pad id before decoding. np.where returns new arrays, so the Trainer's
    # buffers are not mutated in place.
    pred_ids = np.where(pred_ids == -100, pad_id, pred_ids)
    label_ids = np.where(pred.label_ids == -100, pad_id, pred.label_ids)
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    # jiwer rejects empty reference strings (e.g. rows with a missing transcript),
    # so score only examples that have a reference.
    pairs = [(ref, hyp) for ref, hyp in zip(label_str, pred_str) if ref.strip()]
    if not pairs:
        return {"wer": float("nan")}
    refs, hyps = map(list, zip(*pairs))
    return {"wer": wer(refs, hyps)}

# Force language/task for Whisper decoding (eval uses generate, see predict_with_generate).
# The language/task must be set on the model's generation config: decoder prompt ids
# that are computed but never passed to the model have no effect. forced_decoder_ids
# is cleared so it cannot conflict with these settings.
model.generation_config.language = LANGUAGE
model.generation_config.task = TASK
model.generation_config.forced_decoder_ids = None

args = Seq2SeqTrainingArguments(
    output_dir=f"./outputs/{PROJECT_NAME}",
    per_device_train_batch_size=BATCH_SIZE_PER_DEVICE,
    per_device_eval_batch_size=BATCH_SIZE_PER_DEVICE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION,
    learning_rate=LEARNING_RATE,
    num_train_epochs=NUM_EPOCHS,
    evaluation_strategy="steps",  # renamed `eval_strategy` in newer transformers releases
    eval_steps=500,
    save_steps=500,
    logging_steps=50,
    predict_with_generate=True,
    fp16=(MIXED_PRECISION == "fp16"),
    bf16=(MIXED_PRECISION == "bf16"),
    report_to=["none"],  # or ['tensorboard']
    gradient_checkpointing=True,
    # Without a task_type the model is a plain PeftModel whose forward takes only
    # *args/**kwargs, so the Trainer cannot infer the label column or the model's
    # input columns from the signature. Name the labels explicitly and keep all
    # dataset columns; the data collator selects what it needs.
    remove_unused_columns=False,
    label_names=["labels"],
)

# Data collator
def data_collator(features, label_tokenizer=None):
    """Batch preprocessed examples for the Seq2SeqTrainer.

    Args:
        features: list of dicts with "input_features" (per-clip features of equal
            shape) and "labels" (token ids of varying length).
        label_tokenizer: tokenizer used to pad the labels; defaults to the notebook's
            Whisper `tokenizer`.

    Returns:
        {"input_features": stacked tensor, "labels": padded tensor with -100 at padding}.
    """
    tok = tokenizer if label_tokenizer is None else label_tokenizer
    input_features = torch.stack([torch.tensor(f["input_features"]) for f in features])
    padded = tok.pad(
        {"input_ids": [f["labels"] for f in features]},
        padding=True,
        return_tensors="pt",
    )
    # Mask by the attention mask, not by pad_token_id: Whisper's pad token is
    # <|endoftext|>, so masking by id would also hide the real end-of-text token and
    # the model would never learn to stop.
    labels = padded["input_ids"].masked_fill(padded["attention_mask"].ne(1), -100)
    # The tokenizer already prepends <|startoftranscript|>, and the model prepends it
    # again when shifting labels right to build decoder inputs, so drop it here.
    sot_id = tok.convert_tokens_to_ids("<|startoftranscript|>")
    if labels.shape[1] > 0 and (labels[:, 0] == sot_id).all().item():
        labels = labels[:, 1:]
    return {"input_features": input_features, "labels": labels}

# Pick train/eval splits: load_from_disk returns a DatasetDict when splits were saved,
# otherwise a single Dataset.
if isinstance(proc_datasets, dict) or hasattr(proc_datasets, "keys"):
    train_data = proc_datasets["train"]
    # Explicit membership test: `a or b` would skip an empty split via len() == 0
    eval_data = next(
        (proc_datasets[name] for name in ("validation", "val") if name in proc_datasets),
        None,
    )
    if eval_data is None:
        raise ValueError(
            "proc_datasets has no 'validation' or 'val' split, but evaluation is enabled "
            "(evaluation_strategy='steps'). Re-run preprocessing for the validation split."
        )
else:
    # Single dataset — just use it for both train/eval if in testing mode.
    # The reported eval WER is then measured on training data (smoke test only).
    print("WARNING: proc_datasets has no splits; evaluating on the training data.")
    train_data = proc_datasets
    eval_data = proc_datasets

trainer = Seq2SeqTrainer(
# trainer = WhisperSeq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_data,
    eval_dataset=eval_data,
    tokenizer=processor.feature_extractor,  # logs shapes properly
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

print("Starting training...")
train_result = trainer.train()
print(train_result)
trainer.save_model()

### Evaluation and Save Results

In [None]:
# Evaluate on Test Split
test_metrics = {}
# Only split containers have named splits. On a plain Dataset, `"test" in ...` has no
# key lookup and may iterate every row, so check for a keys() mapping first.
if hasattr(proc_datasets, "keys") and "test" in proc_datasets:
    test_metrics = trainer.evaluate(proc_datasets["test"], metric_key_prefix="test")
    print(test_metrics)
else:
    print("No test split available in proc_datasets. Skipping.")

# Save metrics: full Trainer log history plus the test-split metrics.
# UTC timestamp without ':' so the filename is valid on every OS.
os.makedirs("metrics", exist_ok=True)
now = dt.datetime.now(dt.timezone.utc).strftime("%Y-%m-%d_%H-%M-%S")
metrics_path = f"metrics/results_{now}.json"
with open(metrics_path, "w") as f:
    json.dump(
        {"eval": trainer.state.log_history, "test": test_metrics},
        f,
        indent=2
    )
print(f"Saved metrics to {metrics_path}")

In [None]:
# Upload Artifacts to S3
# Set to True to copy this experiment's outputs and metrics to S3 (requires the AWS CLI).
UPLOAD_TO_S3 = False


def aws_cp(local_path: str, s3_uri: str):
    """Recursively copy `local_path` to `s3_uri` with `aws s3 cp --recursive`.

    Best effort: a failing command or a missing `aws` executable is reported,
    not raised, so the notebook can still finish.
    """
    cmd = ["aws", "s3", "cp", "--recursive", local_path, s3_uri]
    print(" ".join(cmd))
    try:
        subprocess.check_call(cmd)
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        print("Upload failed:", e)


if UPLOAD_TO_S3:
    # BUCKET / OUTPUT_PREFIX come from the Project Config cell
    S3_OUTPUT_URI = f"{BUCKET}/{OUTPUT_PREFIX}".rstrip("/")
    # Upload only this experiment's Trainer output_dir, not every project under ./outputs
    aws_cp(f"./outputs/{PROJECT_NAME}", f"s3://{S3_OUTPUT_URI}/outputs")
    aws_cp("./metrics", f"s3://{S3_OUTPUT_URI}/metrics")