# Fine-tuning Template Notebook

In [None]:
# Clean out stale wheels/caches
!pip cache purge -q
!rm -rf ~/.cache/torch_extensions

# Ensure specific versions of bnb, torch for compatibility
!pip uninstall -y torch torchvision torchaudio bitsandbytes
!pip install --index-url https://download.pytorch.org/whl/cu121 \
  torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1

!pip install bitsandbytes==0.45.0 triton==3.0.0 peft==0.12.0 accelerate==0.34.2 transformers==4.45.2

In [None]:
import torch, bitsandbytes as bnb, peft, transformers, triton
# Record the library versions and the visible GPU for this run.
print("torch", torch.__version__)
cuda_available = torch.cuda.is_available()
# get_device_name(0) raises when no CUDA device is visible, so only query it after the check.
if cuda_available:
    print(torch.cuda.get_device_name(0))
else:
    print("No CUDA device visible")
print("transformers", transformers.__version__)
print("bnb", bnb.__version__)
print("peft", peft.__version__)
print("triton", triton.__version__)
print("CUDA?", cuda_available)

In [None]:
from bitsandbytes.optim import AdamW8bit
import datetime as dt
import inspect
import json
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
from pathlib import Path
import platform
import re
import subprocess
import time
import boto3
from concurrent.futures import ThreadPoolExecutor
from datasets import Audio, Dataset, DatasetDict, IterableDataset, load_dataset, load_from_disk, concatenate_datasets
from jiwer import wer
from peft import (
    LoraConfig,
    AdaLoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from tqdm import tqdm
from transformers import (
    WhisperProcessor,
    WhisperTokenizer,
    WhisperFeatureExtractor,
    WhisperForConditionalGeneration,
    BitsAndBytesConfig,
    DataCollatorWithPadding,
    TrainingArguments,
    Trainer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    EarlyStoppingCallback,
    get_scheduler,
)

In [None]:
# Set cache for Sagemaker to fetch processed clips
# Every Hugging Face / torch cache location is redirected under one EFS directory.
# NOTE(review): transformers and datasets are already imported by earlier cells, and
# libraries commonly read cache-location variables at import time — confirm these
# settings actually take effect here, or move this cell above the imports.
EFS_CACHE = "/mnt/custom-file-systems/efs/fs-05e9c5047118066bb_fsap-0b12f9aa4c66d5d46/hf_cache"
os.environ["HF_HOME"] = EFS_CACHE
os.environ["HF_DATASETS_CACHE"] = os.path.join(EFS_CACHE, "datasets")
os.environ["HF_HUB_CACHE"] = os.path.join(EFS_CACHE, "hub")
os.environ["TORCH_HOME"] = os.path.join(EFS_CACHE, "torch")
os.environ["HF_HUB_DOWNLOAD_TMP"] = os.path.join(EFS_CACHE, "tmp")

# Only the temporary download directory is created explicitly here.
os.makedirs(os.environ["HF_HUB_DOWNLOAD_TMP"], exist_ok=True)
print("Hugging Face cache now points to EFS.")

### Project Configuration

In [None]:
# Project Config
BUCKET = "asrelder-data"
CLIPS_PREFIX = "common_voice/23/cv-corpus-23.0-2025-09-05/en/clips/"
# CLIPS_PREFIX ends with "/", so BASE_S3_PREFIX does too; anything joined onto it
# should not insert a second separator.
BASE_S3_PREFIX = f"s3://{BUCKET}/{CLIPS_PREFIX}"

# Keys
# WARNING: fill the placeholders in locally only; never commit real credentials.
# NOTE(review): these assignments overwrite any values already present in this
# process's environment — confirm that is intended when a role/profile is available.
os.environ["AWS_ACCESS_KEY_ID"] = "FILL_ME_IN"
os.environ["AWS_SECRET_ACCESS_KEY"] = "FILL_ME_IN"
os.environ["AWS_DEFAULT_REGION"] = "FILL_ME_IN"

# Confirm: ask the AWS CLI which identity the credentials resolve to
!aws sts get-caller-identity

# Test: the same identity check through boto3
sts = boto3.client("sts")
print(sts.get_caller_identity())

### Download the train, validation, and test CSVs from GDrive

In [None]:
# Google Drive file ids for each split's CSV
DRIVE_FILE_IDS = dict(
    train="1AdCeMxDcE4rxqWSyPsfEh7TaS5dWjXD5",  # common_voices_23_balanced_on_60.csv
    val="1GzrujHvGwA7MA9awtQI4IFRQIdYcLiBO",  # common_voices_23_val_full.csv
    test="1bSjhB8WTDZWBTuppB-vU56AgOEzNNAeN",  # common_voices_23_test_full.csv
)

# Local directory the CSVs are downloaded into
Path("data").mkdir(parents=True, exist_ok=True)

def download_from_drive(name, file_id):
    """Fetch one split's CSV from Google Drive with the gdown CLI.

    Args:
        name: Split name; the file is written to data/<name>.csv.
        file_id: Google Drive file id.

    Returns:
        The local path the CSV was written to.

    Raises:
        subprocess.CalledProcessError: if gdown exits with a non-zero status.
    """
    destination = f"data/{name}.csv"
    print(f"Downloading {name} split from Google Drive → {destination}")
    command = [
        "gdown",
        "--fuzzy",
        f"https://drive.google.com/uc?id={file_id}",
        "-O",
        destination,
    ]
    subprocess.run(command, check=True)
    return destination

# Download each split once (train, then val, then test) and expose the paths
# both as a mapping and as per-split constants.
DATA_FILES = {
    split_name: download_from_drive(split_name, DRIVE_FILE_IDS[split_name])
    for split_name in ("train", "val", "test")
}
TRAIN_PATH = DATA_FILES["train"]
VAL_PATH = DATA_FILES["val"]
TEST_PATH = DATA_FILES["test"]

In [None]:
# Peek at the header row of each split's CSV (only one data row is read)
for split in ("train", "val", "test"):
    path = f"data/{split}.csv"
    df = pd.read_csv(path, nrows=1)
    column_names = df.columns.tolist()
    print(f"{path} columns: {column_names}")

In [None]:
# Add age_group for val and test (column present in train)
def add_age_group_column(file_path: str):
    """
    Add an 'age_group' column to the CSV at `file_path` if it is missing.

    The value is derived from the 'age' column and is a decade label such as
    '20s' (e.g. 23 -> '20s', '25-34' -> '20s', 'twenties' -> '20s'). Values
    that cannot be interpreted, including missing ages, become ''.

    NOTE: Modifies the file in place. If 'age_group' already exists the file
    is left untouched.

    Raises:
        KeyError: if the CSV has neither an 'age_group' nor an 'age' column.
    """
    df = pd.read_csv(file_path)
    if "age_group" in df.columns:
        print(f"Skipping, 'age_group' already exists in {file_path}")
        return
    if "age" not in df.columns:
        raise KeyError(f"'age' column not found in {file_path}; cannot derive 'age_group'")

    # Built once per call rather than once per row. Insertion order matters
    # because matching is by substring and the first hit wins.
    words_to_decade = {
        "teen": "10s",
        "teens": "10s",
        "twenty": "20s",
        "twenties": "20s",
        "thirty": "30s",
        "thirties": "30s",
        "forty": "40s",
        "forties": "40s",
        # Common Voice metadata spells this decade "fourties"
        "fourty": "40s",
        "fourties": "40s",
        "fifty": "50s",
        "fifties": "50s",
        "sixty": "60s",
        "sixties": "60s",
        "seventy": "70s",
        "seventies": "70s",
        "eighty": "80s",
        "eighties": "80s",
        "ninety": "90s",
        "nineties": "90s",
    }

    def infer_age_group(age_value):
        if pd.isna(age_value):
            return ""
        # Normalize to string
        s = str(age_value).strip().lower()
        # Two-digit number; a numeric column containing blanks is read as float,
        # so accept an optional ".0" suffix (e.g. "23.0")
        numeric = re.match(r"^(\d{2})(?:\.0+)?$", s)
        if numeric:
            decade = int(numeric.group(1)) // 10 * 10
            return f"{decade}s"
        # Ranges like "25-34": use the lower bound
        match = re.match(r"(\d{2})\s*-\s*(\d{2})", s)
        if match:
            decade = int(match.group(1)) // 10 * 10
            return f"{decade}s"
        # Words like 'twenties', 'forty', etc.
        for k, v in words_to_decade.items():
            if k in s:
                return v
        return ""  # unknown or other format

    df["age_group"] = df["age"].apply(infer_age_group)
    df.to_csv(file_path, index=False)
    print(f"Added 'age_group' to {file_path} ({len(df)} rows)")

# Only val and test need the derived column
for _age_csv in ("data/val.csv", "data/test.csv"):
    add_age_group_column(_age_csv)

In [None]:
def add_s3_paths(file_path: str, base_prefix=None):
    """
    Prepend the full S3 URI prefix to the 'path' column of a CSV.

    Modifies the file in place. Values that already start with "s3://" and
    non-string values are left unchanged, so the function can be re-run safely.

    Args:
        file_path: CSV file to rewrite.
        base_prefix: S3 URI prefix to prepend. Defaults to the notebook-level
            BASE_S3_PREFIX. A trailing "/" on the prefix is ignored so the
            resulting URI never contains a doubled separator.
    """
    if base_prefix is None:
        base_prefix = BASE_S3_PREFIX
    # The prefix may already end with "/"; normalise so exactly one separator is inserted.
    prefix = base_prefix.rstrip("/")

    df = pd.read_csv(file_path)
    print(f"Processing {file_path} ({len(df)} rows)")
    if "path" not in df.columns:
        print(f"No 'path' column in {file_path}; nothing to update")
        return

    def prepend_prefix(p):
        if isinstance(p, str) and not p.startswith("s3://"):
            return f"{prefix}/{p.lstrip('/')}"
        return p

    df["path"] = df["path"].apply(prepend_prefix)
    df.to_csv(file_path, index=False)
    print(f"Updated 'path' for {file_path} ({len(df)} rows)")

# Apply to all three splits; rows whose path already starts with "s3://" are left unchanged
add_s3_paths("data/train.csv")
add_s3_paths("data/val.csv")
add_s3_paths("data/test.csv")

### Whisper model and training configuration

In [None]:
# Whisper model + language/task settings
# LANGUAGE and TASK are passed to WhisperProcessor.from_pretrained further down.
WHISPER_MODEL = "openai/whisper-base"
LANGUAGE = "en"
TASK = "transcribe"

# Compute and training
# NOTE(review): the training cells visible in this notebook take batch size, epochs and
# learning rate from the ablation CSV row, and prepare_args hard-codes gradient
# accumulation — confirm whether the four constants below are still used anywhere.
GRADIENT_ACCUMULATION = 2
BATCH_SIZE_PER_DEVICE = 4
NUM_EPOCHS = 3
LEARNING_RATE = 0.0001
# Clip length in seconds; multiplied by the sampling rate to get MAX_INPUT_LENGTH
MAX_AUDIO_SECONDS = 30

In [None]:
# SageMaker Utils
# Detect SageMaker environment, map S3 paths

def in_sagemaker() -> bool:
    """Return True when the environment has SAGEMAKER_JOB_NAME set or any variable named SM_*."""
    if os.environ.get("SAGEMAKER_JOB_NAME") is not None:
        return True
    for env_name in os.environ:
        if env_name.startswith("SM_"):
            return True
    return False

def s3_join(*parts: str) -> str:
    """Join path fragments with single slashes.

    Each fragment has its leading/trailing "/" stripped and any "s3://"
    removed, so the returned string carries no scheme.
    """
    cleaned = []
    for fragment in parts:
        cleaned.append(fragment.strip("/").replace("s3://", ""))
    return "/".join(cleaned)


print("In SageMaker:", in_sagemaker())

In [None]:
# Data cleaning
def clean_numeric_columns(file_path: str, numeric_cols=("variant", "segment")):
    """Coerce the listed columns of a CSV to numeric values, in place.

    Columns from `numeric_cols` that are absent from the file are ignored.
    Entries that cannot be parsed as numbers become NaN. The CSV at
    `file_path` is overwritten with the result.
    """
    frame = pd.read_csv(file_path)
    present = [name for name in numeric_cols if name in frame.columns]
    for name in present:
        frame[name] = pd.to_numeric(frame[name], errors="coerce")
    frame.to_csv(file_path, index=False)
    print(f"Cleaned {file_path}: non-numeric entries coerced to NaN.")


for split in ["train", "val", "test"]:
    clean_numeric_columns(f"data/{split}.csv")

In [None]:
# Data Loading
STREAMING = False
audio_col = "path"
text_col = "sentence"
print("Loading CSVs in streaming mode..." if STREAMING else "Loading CSVs into memory...")
raw_datasets = load_dataset("csv", data_files=DATA_FILES, streaming=STREAMING)
print(raw_datasets)

In [None]:
# Use half of available CPUs.
# os.cpu_count() returns None when the count cannot be determined; fall back to 1
# so the integer division below cannot fail.
_cpu_count = os.cpu_count() or 1
print(f"CPU count: {_cpu_count}")
NUM_PROC = max(1, _cpu_count // 2)
print(f"NUM_PROC: {NUM_PROC}")

# Cache the audio files so that we don't have to stream them from S3
CACHE_DIR = "data/processed_whisper"

### Whisper Processor & Preprocessing

Prepare the tensors for each audio clip that are sent to the model's encoder:
- Load audio from S3 or disk
- Resample to 16 kHz
- Compute short-time Fourier transform (STFT)
- Convert to Mel scale
- Apply log compression
- Normalize to match Whisper's training statistics

In [None]:
# Whisper Processor & Preprocessing
processor = WhisperProcessor.from_pretrained(WHISPER_MODEL, language=LANGUAGE, task=TASK)

# Convenience handles onto the processor's two components
feature_extractor: WhisperFeatureExtractor = processor.feature_extractor
tokenizer: WhisperTokenizer = processor.tokenizer

# Maximum clip length expressed in audio samples (seconds x samples per second)
MAX_INPUT_LENGTH = int(MAX_AUDIO_SECONDS * feature_extractor.sampling_rate)

### Load Preprocessed Dataset from EFS

In [None]:
FULL_DATASET_DIR = "/mnt/custom-file-systems/efs/fs-05e9c5047118066bb_fsap-0b12f9aa4c66d5d46/dave_sandbox/full_dataset"
ds = load_from_disk(FULL_DATASET_DIR)
# ds = load_from_disk("/mnt/sagemaker-nvme/datasets/filtered_5k_448tok")
# ds = load_from_disk("/mnt/custom-file-systems/efs/fs-05e9c5047118066bb_fsap-0b12f9aa4c66d5d46/dave_sandbox/filtered_15k_448tok")

# Quick sanity check of what was loaded: columns, feature schema, row count
for _summary in (ds.column_names, ds.features, len(ds)):
    print(_summary)

### Ablation Study

In [None]:
csv_path = "training/ablation_study.csv"
ablation_study_df = pd.read_csv(csv_path)

# Normalise headers to snake_case, then drop pandas' auto-named "unnamed" columns
normalized_columns = ablation_study_df.columns.str.strip().str.lower().str.replace(" ", "_")
ablation_study_df.columns = normalized_columns
is_named_column = ~ablation_study_df.columns.str.contains("^unnamed")
ablation_study_df = ablation_study_df.loc[:, is_named_column]

print(f"Total runs: {len(ablation_study_df)}")
ablation_study_df

In [None]:
# Filter to "To Do"
is_todo = ablation_study_df["status"].str.lower().isin(["to do"])
todo_df = ablation_study_df[is_todo].reset_index(drop=True)
print(f"{len(todo_df)} runs ready to launch.")

# Hyperparameter columns worth eyeballing before launching a run
summary_columns = [
    "run",
    "peft",
    "num_epochs",
    "learning_rate",
    "warmup_steps",
    "r",
    "alpha",
    "dropout",
    "target_layers",
    "lr_scheduler_type",
    "optimizer",
]
todo_df[summary_columns]

In [None]:
def _as_bool(value, default=False):
    """Interpret a CSV cell as a boolean; None/NaN/blank fall back to `default`."""
    if isinstance(value, (bool, np.bool_)):
        return bool(value)
    if value is None or (not isinstance(value, str) and pd.isna(value)):
        return default
    if isinstance(value, str):
        text = value.strip().lower()
        if text == "":
            return default
        if text in ("true", "t", "yes", "y", "1"):
            return True
        if text in ("false", "f", "no", "n", "0"):
            return False
        raise ValueError(f"Cannot interpret {value!r} as a boolean")
    return bool(value)


def _as_rank(value):
    """Return the rank as an int when it is integral, None when missing, else unchanged."""
    if value is None or (not isinstance(value, str) and pd.isna(value)):
        return None
    try:
        number = float(value)
    except (TypeError, ValueError):
        return value  # non-numeric text is passed through untouched
    return int(number) if number.is_integer() else value


def row_to_config(row):
    """
    Convert one CSV row into a structured training config dict.

    `row` is accessed by attribute (e.g. a pandas Series taken from the
    ablation-study frame). Text fields are lower-cased, integer fields are
    cast with int(), and 'target_layers' is split on commas.

    - 'load_in_4bit' is parsed as a real boolean: "FALSE"/"false"/0 become
      False, and a blank or NaN cell is treated as False. (bool() on the raw
      cell would turn both "FALSE" and NaN into True.)
    - 'rank_r' is None when the cell is empty, an int when the value is
      integral (16, 16.0 or "16"), and the raw value otherwise.
    """
    return {
        "run_id": int(row.run),
        "peft_method": row.peft.lower(),
        "target_modules": [m.strip() for m in str(row.target_layers).split(",")],
        "learning_rate": row.learning_rate,
        "num_epochs": int(row.num_epochs),
        "batch_size": int(row.batch_size),
        "load_in_4bit": _as_bool(row.load_in_4bit),
        "rank_r": _as_rank(row.r),
        "lora_alpha": row.alpha,
        "dropout": row.dropout,
        "weight_decay": row.weight_decay,
        "warmup_steps": int(row.warmup_steps),
        "optimizer": row.optimizer.lower(),
        "lr_scheduler_type": row.lr_scheduler_type.lower(),
        "notes": row.notes if pd.notna(row.notes) else "",
    }

In [None]:
def build_model(config):
    """
    Load the base Whisper model and wrap it with the PEFT adapter described by `config`.

    Args:
        config: dict as produced by row_to_config. Keys read here:
            'peft_method' ('lora', 'qlora', 'dora' or 'adalora'), 'run_id',
            and optionally 'rank_r', 'lora_alpha', 'dropout', 'target_modules'.
            An optional value that is absent, None or NaN falls back to its default.

    Returns:
        The PEFT-wrapped model, moved to CUDA when available, otherwise CPU.

    Raises:
        ValueError: if 'peft_method' is not one of the supported methods
            (raised before any weights are loaded).
    """
    def _setting(key, default):
        # config.get(key, default) is not enough: the key is usually present with a
        # None/NaN value (empty CSV cell), which must also fall back to the default.
        value = config.get(key)
        is_missing = value is None or (isinstance(value, float) and value != value)
        return default if is_missing else value

    peft_method = config["peft_method"]
    if peft_method not in ("lora", "qlora", "dora", "adalora"):
        raise ValueError(f"Unsupported PEFT method: {config['peft_method']}")

    load_kwargs = dict(device_map="auto", torch_dtype=torch.float16)

    # Quantization setup (QLoRA only): 4-bit NF4 with double quantization
    if peft_method == "qlora":
        load_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

    # Use the notebook-level model id so the model always matches the processor,
    # which is loaded from the same WHISPER_MODEL constant.
    model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL, **load_kwargs)

    if peft_method == "qlora":
        model = prepare_model_for_kbit_training(model)

    # Gradient checkpointing is switched on at the model level here.
    model.gradient_checkpointing_enable()

    # Sanitize numeric values before building PEFT config
    safe_alpha = make_json_safe(_setting("lora_alpha", 32))
    safe_dropout = make_json_safe(_setting("dropout", 0.05))
    safe_target_modules = make_json_safe(_setting("target_modules", ["q_proj", "k_proj", "v_proj"]))

    # Define PEFT config
    if peft_method in ("lora", "qlora", "dora"):
        peft_cfg = LoraConfig(
            r=int(_setting("rank_r", 16)),
            lora_alpha=safe_alpha,
            lora_dropout=safe_dropout,
            bias="none",
            target_modules=safe_target_modules,
            use_dora=peft_method == "dora",
        )
    else:  # adalora — uses its own fixed rank schedule; config['rank_r'] is not consulted
        peft_cfg = AdaLoraConfig(
            init_r=12,
            target_r=8,
            beta1=0.85,
            beta2=0.85,
            tinit=200,
            tfinal=1000,
            deltaT=10,
            lora_alpha=safe_alpha,
            lora_dropout=safe_dropout,
            target_modules=safe_target_modules,
        )

    model = get_peft_model(model, peft_cfg)
    model = model.to("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Loaded {config['peft_method']} model for run #{config['run_id']}")
    return model

In [None]:
def make_json_safe(x):
    """
    Return `x` with numpy scalar types replaced by native Python types.

    numpy integers, floats and booleans become int, float and bool.
    numpy arrays, lists and tuples are converted element by element
    (recursively) and come back as lists. Anything else is returned unchanged.
    """
    scalar_casts = (
        (np.integer, int),
        (np.floating, float),
        (np.bool_, bool),
    )
    for numpy_type, python_type in scalar_casts:
        if isinstance(x, numpy_type):
            return python_type(x)
    if isinstance(x, (np.ndarray, list, tuple)):
        return [make_json_safe(item) for item in x]
    return x

### Run Experiment(s)

In [None]:
# Target number of clips per split (trailing comments list smaller alternatives)
N_TRAIN = 10000  # 5000, 1000
N_VAL = 2000  # 1000, 300
N_TEST = 2000  # 1000, 300
# Pull 10% more than needed so the splits can still be filled after the label-length filter
N_TO_PULL = int((N_TRAIN + N_VAL + N_TEST) * 1.1)

In [None]:
# Cheap, local length filter applied BEFORE slicing into splits
MAX_LABEL_TOKENS = 448


def is_short(example):
    """Return True when the example's 'labels' sequence has at most MAX_LABEL_TOKENS entries."""
    n_label_tokens = len(example["labels"])
    return n_label_tokens <= MAX_LABEL_TOKENS

In [None]:
# Shuffle once, then take an oversampled slice.
# Clamp the slice to the dataset size: selecting indices past the end raises,
# which would happen with any of the smaller cached datasets.
n_to_pull = min(N_TO_PULL, len(ds))
oversample_ds = ds.shuffle(seed=5678).select(range(n_to_pull))
print(f"Fetched {len(oversample_ds)} clips. Need to check if they are under token limit...")

# Keep only clips whose label sequence fits within MAX_LABEL_TOKENS
sample_ds = oversample_ds.filter(lambda x: len(x["labels"]) <= MAX_LABEL_TOKENS, num_proc=1)
print(f"Found {len(sample_ds)} valid clips out of {len(oversample_ds)} total")

In [None]:
# CACHE_DIR_9K = "/mnt/sagemaker-nvme/datasets/filtered_5k_448tok"
# os.makedirs(CACHE_DIR, exist_ok=True)

# CACHE_DIR_9K = "/mnt/custom-file-systems/efs/fs-05e9c5047118066bb_fsap-0b12f9aa4c66d5d46/dave_sandbox/filtered_9k_448tok"
# os.makedirs(CACHE_DIR, exist_ok=True)

# sample_ds.save_to_disk(CACHE_DIR_9K)
# print(f"Saved 9K filtered dataset to {CACHE_DIR_9K}")

# CACHE_DIR_15K = "/mnt/custom-file-systems/efs/fs-05e9c5047118066bb_fsap-0b12f9aa4c66d5d46/dave_sandbox/filtered_15k_448tok"
# os.makedirs(CACHE_DIR, exist_ok=True)

# sample_ds.save_to_disk(CACHE_DIR_15K)
# print(f"Saved 15K filtered dataset to {CACHE_DIR_15K}")

In [None]:
# Carve the filtered pool into contiguous train / val / test slices.
# Each upper bound is capped at the pool size, so a short pool yields short
# (possibly empty) later splits rather than an index error.
n_available = len(sample_ds)
train_stop = min(N_TRAIN, n_available)
val_stop = min(N_TRAIN + N_VAL, n_available)
test_stop = min(N_TRAIN + N_VAL + N_TEST, n_available)

sample_ds_train = sample_ds.select(range(train_stop))
sample_ds_val = sample_ds.select(range(N_TRAIN, val_stop))
sample_ds_test = sample_ds.select(range(N_TRAIN + N_VAL, test_stop))

proc_datasets = DatasetDict({
    "train": sample_ds_train,
    "val": sample_ds_val,
    "test": sample_ds_test,
})
print(f"Train: {len(sample_ds_train)} | Val: {len(sample_ds_val)}")
for _split_name in ("train", "val", "test"):
    print(proc_datasets[_split_name])

#### Run experiment

In [None]:
### Run one experiment
# row_index = 0
# PROJECT_NAME = "dave/20251028_qlora_run1"

# row_index = 1
# PROJECT_NAME = "dave/20251029_dora_run2"

# row_index = 2
# PROJECT_NAME = "dave/20251029_adalora_run3"

# row_index = 4
# PROJECT_NAME = "dave/20251029_qlora_run5"

# row_index = 5
# PROJECT_NAME = "dave/20251029_qlora_run5"

# row_index = 7
# PROJECT_NAME = "dave/20251029_qlora_run7"

# row_index = 7
# PROJECT_NAME = "dave/20251029_qlora_run7_real7_previous_was6"

# row_index = 8
# PROJECT_NAME = "dave/20251029_dora_run9"

# row_index = 9
# PROJECT_NAME = "dave/20251029_dora_run10"

# row_index = 10
# PROJECT_NAME = "dave/20251031_qlora_run11"

# row_index = 11
# PROJECT_NAME = "dave/20251031_dora_run12"

# row_index = 12
# PROJECT_NAME = "dave/20251031_qlora_run13"

# row_index is a position in todo_df (filtered to "To Do" and re-indexed),
# not the run number stored in the CSV.
row_index = 13
PROJECT_NAME = "dave/20251031_qlora_run14"

row_config = row_to_config(todo_df.loc[row_index])
print("row_config:", row_config)
model = build_model(row_config)

# Sanity checks on what was actually loaded
first_param = next(model.parameters())
print("param dtype:", first_param.dtype)  # expect torch.float16
uses_4bit = any(isinstance(module, bnb.nn.Linear4bit) for module in model.modules())
print("has 4bit modules:", uses_4bit)

In [None]:
def compute_metrics(pred):
    """
    Compute word error rate (WER) between decoded predictions and references.

    Args:
        pred: prediction object exposing `.predictions` (token ids, or a tuple
            whose first element is the token ids) and `.label_ids`.

    Returns:
        {"wer": <float>} as computed by jiwer's `wer(references, hypotheses)`.
    """
    pad_id = tokenizer.pad_token_id

    pred_ids = pred.predictions
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]
    # -100 is an ignore marker, not a token id, so it cannot be decoded.
    # Replace it in the predictions too, in case padded positions carry it.
    pred_ids = np.asarray(pred_ids)
    pred_ids = np.where(pred_ids != -100, pred_ids, pad_id)

    # Build a new array instead of writing into pred.label_ids, so the caller's
    # labels are not modified as a side effect of computing metrics.
    label_ids = np.asarray(pred.label_ids)
    label_ids = np.where(label_ids != -100, label_ids, pad_id)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    return {"wer": wer(label_str, pred_str)}

In [None]:
def data_collator(features):
    """
    Collate preprocessed examples into a batch for the model.

    Args:
        features: list of examples, each with 'input_features' (audio features)
            and 'labels' (token ids).

    Returns:
        dict with 'input_features' (stacked tensor cast to the model's dtype)
        and 'labels' (padded token ids, with padded positions set to -100).
    """
    # Match the dtype of the global model's first parameter; fall back to
    # float32 when no model is available yet.
    try:
        model_dtype = next(model.parameters()).dtype
    except (NameError, StopIteration):
        model_dtype = torch.float32  # fallback

    # Stack audio features
    input_features = torch.stack([
        torch.tensor(f["input_features"], dtype=model_dtype)
        for f in features
    ])

    # Pad labels to the longest sequence in the batch
    padded = tokenizer.pad(
        {"input_ids": [f["labels"] for f in features]},
        padding=True,
        return_tensors="pt",
    )
    # Mask padding by position (attention mask), not by token value: when the pad
    # token shares its id with the end-of-text token, masking every occurrence of
    # that id would also hide the real end-of-text label from the loss.
    labels = padded["input_ids"].masked_fill(padded["attention_mask"].ne(1), -100)

    return {"input_features": input_features, "labels": labels}

In [None]:
def prepare_args(cfg, PROJECT_NAME, MIXED_PRECISION="fp16"):
    """
    Create Seq2SeqTrainingArguments from a run config dict.

    Args:
        cfg: run config (see row_to_config). Entries whose key is a
            Seq2SeqTrainingArguments parameter are passed through; batch size,
            epochs, learning rate, weight decay, scheduler type and warmup
            steps are read with defaults.
        PROJECT_NAME: used for the output directory (./outputs/<PROJECT_NAME>)
            and the logging directory (./logs/<PROJECT_NAME>).
        MIXED_PRECISION: "fp16" (default), "bf16", or "no"/"fp32"/None for
            full precision. DoRA runs always use full precision (see below).

    Returns:
        A Seq2SeqTrainingArguments instance.

    Raises:
        ValueError: if MIXED_PRECISION is not one of the accepted values.
    """
    precision = "no" if MIXED_PRECISION is None else str(MIXED_PRECISION).lower()
    if precision not in ("fp16", "bf16", "no", "fp32"):
        raise ValueError(
            f"Unsupported MIXED_PRECISION {MIXED_PRECISION!r}; expected 'fp16', 'bf16', 'no' or 'fp32'"
        )

    valid_keys = inspect.signature(Seq2SeqTrainingArguments).parameters.keys()
    filtered = {k: v for k, v in cfg.items() if k in valid_keys}

    peft_method = cfg.get("peft_method", "").lower()
    load_in_4bit = cfg.get("load_in_4bit", False)

    # Core run settings
    filtered["output_dir"] = f"./outputs/{PROJECT_NAME}"
    filtered["overwrite_output_dir"] = True

    # Batching
    filtered["per_device_train_batch_size"] = cfg.get("batch_size", 4)
    filtered["per_device_eval_batch_size"] = cfg.get("batch_size", 4)
    # NOTE(review): fixed here; the notebook-level GRADIENT_ACCUMULATION constant is not consulted
    filtered["gradient_accumulation_steps"] = 4
    filtered["dataloader_num_workers"] = 2

    # Training loop
    filtered["num_train_epochs"] = cfg.get("num_epochs", 5)
    filtered["learning_rate"] = cfg.get("learning_rate", 5e-5)
    filtered["weight_decay"] = cfg.get("weight_decay", 0.01)
    filtered["lr_scheduler_type"] = cfg.get("lr_scheduler_type", "linear")
    filtered["warmup_steps"] = cfg.get("warmup_steps", 0)

    # Evaluation / saving
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy` — revisit when the pinned version is upgraded.
    filtered["evaluation_strategy"] = "epoch"
    filtered["save_strategy"] = "epoch"
    filtered["eval_accumulation_steps"] = None
    filtered["eval_delay"] = 0
    filtered["load_best_model_at_end"] = True
    filtered["metric_for_best_model"] = "wer"
    filtered["greater_is_better"] = False  # lower WER is better
    filtered["save_total_limit"] = 3

    # Logging
    # NOTE: Need to set logging_strategy == epoch in order to get training loss
    filtered["logging_strategy"] = "epoch"
    filtered["report_to"] = ["tensorboard"]
    filtered["logging_dir"] = f"./logs/{PROJECT_NAME}"
    filtered["disable_tqdm"] = False

    # Precision / optimization
    # Driven by MIXED_PRECISION; both flags False means plain float32
    filtered["fp16"] = precision == "fp16"
    filtered["bf16"] = precision == "bf16"
    filtered["gradient_checkpointing"] = False  # Checkpointing re-computes every forward pass on the backward step
    filtered["predict_with_generate"] = True  # Need True in order to calc validation WER during training

    # Determinism / reproducibility
    filtered["seed"] = 5678
    filtered["dataloader_pin_memory"] = True

    # Handle for 4-bit quantized QLoRA training
    if peft_method == "qlora" and load_in_4bit:
        filtered["optim"] = "adamw_bnb_8bit"
    else:
        filtered["optim"] = "adamw_torch"

    # Adding on 11/1
    if peft_method == "dora":
        print("[DoRA detected] Disabling gradient clipping and AMP to avoid FP16 unscale errors.")
        filtered["max_grad_norm"] = 0.0
        filtered["fp16"] = False
        filtered["bf16"] = False

    # Sanitize numeric types (numpy scalars from the CSV) into native Python values
    safe_kwargs = {k: make_json_safe(v) for k, v in filtered.items()}

    return Seq2SeqTrainingArguments(**safe_kwargs)

In [None]:
# Training split, plus the first non-empty evaluation split among "val" / "validation"
train_data = proc_datasets["train"]
eval_data = None
for _eval_split_name in ("val", "validation"):
    _eval_candidate = proc_datasets.get(_eval_split_name)
    if _eval_candidate:
        eval_data = _eval_candidate
        break

print(f"Train samples: {len(train_data)}")
print(f"Eval samples: {len(eval_data) if eval_data else 0}")
print(f"Device: {model.device}")

In [None]:
def make_bnb_optimizer_and_scheduler(model, args, num_training_steps=None):
    """
    Explicitly create the bitsandbytes 8-bit AdamW optimizer and, when the
    total step count is known, a matching learning-rate scheduler.

    Args:
        model: model whose trainable parameters are optimized.
        args: training arguments; learning_rate, weight_decay,
            lr_scheduler_type, warmup_steps and max_steps are read.
        num_training_steps: total optimizer steps. When omitted,
            args.max_steps is used if it is positive.

    Returns:
        (optimizer, scheduler). `scheduler` is None when the total number of
        steps is unknown: args.max_steps is not positive for epoch-based
        training, and building a schedule from that value would give a wrong
        learning-rate curve. Passing (optimizer, None) to the Trainer lets it
        create the scheduler itself once it knows the step count.
    """
    # Frozen base weights take no optimizer updates; pass only trainable parameters.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = AdamW8bit(
        trainable_params,
        lr=args.learning_rate,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=args.weight_decay,
    )

    if num_training_steps is None and args.max_steps is not None and args.max_steps > 0:
        num_training_steps = args.max_steps
    if num_training_steps is None:
        return optimizer, None

    scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=num_training_steps,
    )
    return optimizer, scheduler

In [None]:
# Prepare trainer
args = prepare_args(row_config, PROJECT_NAME)
# The 8-bit AdamW optimizer is built explicitly and handed to the trainer below,
# for every run regardless of the run's PEFT method.
optimizer, lr_scheduler = make_bnb_optimizer_and_scheduler(model, args)
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_data,
    eval_dataset=eval_data,
    tokenizer=processor.feature_extractor,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    optimizers=(optimizer, lr_scheduler),
)
# Stop when validation stops improving by at least 1e-4 for 3 consecutive evaluations
trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=1e-4))

# NOTE: args.optim only reports the TrainingArguments setting; since an optimizer is
# passed explicitly above, the type printed on the next line is the one actually used.
print("Using optimizer:", trainer.args.optim)
print(type(trainer.optimizer))

In [None]:
# Fix for bitsandbytes optimizer missing .train() and .eval()
# Attach do-nothing train()/eval() methods to the optimizer, and to the optimizer
# it wraps (if any), wherever those attributes are absent.
if hasattr(trainer, "optimizer"):
    def _noop_train(*args, **kwargs):
        return None

    def _noop_eval(*args, **kwargs):
        return None

    opt = trainer.optimizer
    _patch_targets = [opt]
    if hasattr(opt, "optimizer"):
        inner_opt = opt.optimizer
        _patch_targets.append(inner_opt)

    for _target in _patch_targets:
        for _attr, _noop in (("train", _noop_train), ("eval", _noop_eval)):
            if not hasattr(_target, _attr):
                setattr(_target, _attr, _noop)

In [None]:
print("Starting training...")
train_result = trainer.train()
print(train_result)

# Save LoRA adapter (PEFT) weights to the trainer's output directory
# (./outputs/<PROJECT_NAME>, as set in prepare_args)
trainer.save_model()

# Save the feature extractor and the processor.
# Note: SAVE_DIR is relative to the working directory, not under ./outputs.
SAVE_DIR = f"{PROJECT_NAME}/for_hf"
processor.feature_extractor.save_pretrained(SAVE_DIR)
processor.save_pretrained(SAVE_DIR)

In [None]:
# Save a fully merged model, loadable with pipeline() without PeFT
# Note: MERGED_SAVE_DIR is relative to the working directory, not under ./outputs.
MERGED_SAVE_DIR = f"{PROJECT_NAME}/merged"

# Merge LoRA → base weights
# NOTE(review): for runs loaded in 4-bit, confirm that merging adapters into the
# quantized layers gives acceptable weights before publishing the merged model.
merged_model = model.merge_and_unload()

# Save merged model in the same structure, together with the processor files
merged_model.save_pretrained(MERGED_SAVE_DIR)
processor.feature_extractor.save_pretrained(MERGED_SAVE_DIR)
processor.save_pretrained(MERGED_SAVE_DIR)

In [None]:
### Chart loss + WER
# Split the trainer's log history into rows that carry a training loss and
# rows that carry an evaluation loss.
logs_df = pd.DataFrame(trainer.state.log_history)
train_logs = logs_df.dropna(subset=["loss"])
eval_logs = logs_df.dropna(subset=["eval_loss"])

# Create figure and axes
fig, ax1 = plt.subplots(figsize=(10, 6))

# Left axis: Losses
ax1.plot(train_logs["epoch"], train_logs["loss"], label="Training Loss", color="tab:blue", linewidth=2)
ax1.plot(eval_logs["epoch"], eval_logs["eval_loss"], label="Validation Loss", color="tab:orange", linewidth=2, marker="o")
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Loss")
ax1.tick_params(axis="y")
# ax1.set_ylim(bottom=-0.10)

# Right axis: WER (only drawn when the evaluation logs contain it)
ax2 = ax1.twinx()
if "eval_wer" in eval_logs.columns:
    ax2.plot(eval_logs["epoch"], eval_logs["eval_wer"], label="Validation WER", color="tab:red", linewidth=2, marker="x", linestyle="--")
    ax2.set_ylabel("Word Error Rate (WER)", rotation=270, labelpad=15)
    ax2.tick_params(axis="y")

# Title & grid
plt.title("Training and Validation Metrics over Epochs", fontsize=14, fontweight="bold", pad=40)
ax1.grid(True, which="both", linestyle="--", alpha=0.6)

# Combine the legend entries from both axes into one figure-level legend
lines, labels = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
fig.legend(
    lines + lines2,
    labels + labels2,
    loc="upper center",
    ncol=3,
    bbox_to_anchor=(0.5, 0.88),
    fontsize=11,
    frameon=False,
)

# Lay out within the lower 95% of the figure height
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()

### Evaluation and Save Results

In [None]:
# Evaluate on Test Split
test_metrics = {}
if "test" in proc_datasets:
    test_metrics = trainer.evaluate(proc_datasets["test"], metric_key_prefix="test")
    print(test_metrics)
else:
    print("No test split available in proc_datasets. Skipping.")

# Save metrics
os.makedirs("metrics", exist_ok=True)
# Timezone-aware UTC timestamp, formatted without ":" so the name is valid on every
# filesystem, and with a "T" separating date and time.
now = dt.datetime.now(dt.timezone.utc).strftime("%Y-%m-%dT%H-%M-%SZ")
fp = f"metrics/results_{now}.json"
with open(fp, "w") as f:
    json.dump({"eval": trainer.state.log_history, "test": test_metrics}, f, indent=2)
print(f"Saved metrics to {fp}")

### Push the fine-tuned model to Hugging Face

In [None]:
from huggingface_hub import HfApi, login, whoami

# WARNING: fill the placeholder in locally only; never commit a real access token,
# and clear this cell before sharing the notebook.
login(token="FILL_ME_IN")

In [None]:
api = HfApi()

In [None]:
# TO_PUSH_DIR = "dave/20251031_qlora_run13/merged"
# NOTE(review): hard-coded copy of f"{PROJECT_NAME}/merged" (MERGED_SAVE_DIR) for the
# current run — keep it in sync when PROJECT_NAME changes.
TO_PUSH_DIR = "dave/20251031_qlora_run14/merged"
# Upload the merged model folder to the shared model repository
api.upload_folder(
    folder_path=TO_PUSH_DIR,
    repo_id="MIDS-Choate-Kuruppu-Russell/aging-in-place-whisper",
    repo_type="model",
)