# LoRA Fine-tuning — Whisper Small on TORGO

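Fine-tunes Whisper Small on the TORGO speech corpus with LoRA adapters. All settings are read from `config.yaml`. Run the notebook with `asr/` as the working directory: the config is loaded from the current directory, and its data and output paths are resolved against the project root (the parent of `asr/`).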

In [7]:
import importlib, accelerate
importlib.reload(accelerate)

import numpy as np
import yaml, torch, evaluate
from pathlib import Path
from datasets import load_from_disk
from transformers import WhisperProcessor, WhisperForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer, EarlyStoppingCallback
from peft import LoraConfig, get_peft_model, TaskType
from torch.nn.utils.rnn import pad_sequence

In [2]:
# Data/output paths in config.yaml are resolved against the project root,
# i.e. the parent of the notebook's working directory (`asr/`).
PROJECT_ROOT = Path.cwd().parent
config_path = Path.cwd() / "config.yaml"
with config_path.open() as config_file:
    cfg = yaml.safe_load(config_file)
print(cfg["model"]["name"])

openai/whisper-small


In [3]:
# Load the pre-split train/test DatasetDict previously written with save_to_disk.
dataset_path = PROJECT_ROOT / cfg["data"]["dataset_path"]
dataset = load_from_disk(str(dataset_path))
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['audio', 'transcription', 'speech_status', 'gender', 'duration'],
        num_rows: 14896
    })
    test: Dataset({
        features: ['audio', 'transcription', 'speech_status', 'gender', 'duration'],
        num_rows: 1656
    })
})


In [4]:
model_name = cfg["model"]["name"]
language, task = cfg["model"]["language"], cfg["model"]["task"]

# The tokenizer is configured with language/task, so every label starts with the
# matching <|lang|><|task|> prompt tokens.
processor = WhisperProcessor.from_pretrained(model_name, language=language, task=task)
model = WhisperForConditionalGeneration.from_pretrained(model_name)

# Clear the legacy hard-coded prompt ids and token suppression on the model config.
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

# Evaluation decodes with generate(), which reads generation_config. Pin the same
# language/task the tokenizer uses for the labels so predictions are produced
# with the training prompt instead of relying on automatic language detection.
model.generation_config.language = language
model.generation_config.task = task
model.generation_config.forced_decoder_ids = None

Loading weights: 100%|██████████| 479/479 [00:00<00:00, 801.95it/s, Materializing param=model.encoder.layers.11.self_attn_layer_norm.weight]   


In [5]:
# The decoder KV cache is incompatible with gradient checkpointing during training.
model.config.use_cache = False
# Non-reentrant checkpointing still back-propagates into the LoRA adapters inside
# checkpointed layers even though the frozen inputs to those layers do not
# require grad; the reentrant variant can silently leave those adapters untrained.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

lora = cfg["lora"]
lora_config = LoraConfig(
    r=lora["r"],
    lora_alpha=lora["alpha"],
    lora_dropout=lora["dropout"],
    target_modules=lora["target_modules"],
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM,
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 1,769,472 || all params: 243,504,384 || trainable%: 0.7267


In [6]:
sr = cfg["data"]["sampling_rate"]
max_samples = int(cfg["data"]["max_audio_length"] * sr)

def prepare(batch):
    """Turn one raw dataset row into Whisper model inputs.

    Args:
        batch: a single row with ``audio`` (either a dict holding ``array`` and
            optionally ``sampling_rate``, or a bare sample array) and
            ``transcription`` (text).

    Returns:
        The same row with ``input_features`` (feature-extractor output for the
        clip truncated / zero-padded to ``max_samples``) and ``labels``
        (tokenizer ids of the transcription) added.

    Raises:
        ValueError: if the audio dict reports a sampling rate other than ``sr``.
    """
    audio = batch["audio"]
    if isinstance(audio, dict):
        # Features computed at the wrong rate are silently wrong, so refuse
        # instead of training on them.
        source_sr = audio.get("sampling_rate", sr)
        if source_sr != sr:
            raise ValueError(
                f"audio sampled at {source_sr} Hz but config expects {sr} Hz; "
                "resample the audio column before preparing features"
            )
        samples = audio["array"]
    else:
        samples = audio
    clip = np.asarray(samples, dtype=np.float32)[:max_samples]
    # Fixed-length input: truncate long clips, zero-pad short ones.
    padded = np.zeros(max_samples, dtype=np.float32)
    padded[: len(clip)] = clip
    batch["input_features"] = processor.feature_extractor(padded, sampling_rate=sr, return_tensors="np").input_features[0]
    batch["labels"] = processor.tokenizer(batch["transcription"]).input_ids
    return batch

# Replace every raw column with the model inputs produced by `prepare`.
raw_columns = dataset["train"].column_names
dataset = dataset.map(prepare, desc="Preparing", remove_columns=raw_columns)
print(dataset)

Preparing: 100%|██████████| 14896/14896 [01:26<00:00, 171.72 examples/s]
Preparing: 100%|██████████| 1656/1656 [00:08<00:00, 204.96 examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_features', 'labels'],
        num_rows: 14896
    })
    test: Dataset({
        features: ['input_features', 'labels'],
        num_rows: 1656
    })
})





In [8]:
pad_id = processor.tokenizer.pad_token_id
# Whisper builds decoder inputs by shifting `labels` right and prepending this id
# itself, so it must not also be the first label token.
decoder_start_id = processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")

def collate(batch):
    """Batch prepared examples for Seq2SeqTrainer.

    Args:
        batch: list of dicts with ``input_features`` (fixed-size features from
            ``prepare``) and ``labels`` (token id list).

    Returns:
        dict with stacked ``input_features`` and ``labels`` right-padded with
        -100, the index the loss ignores.
    """
    feats = torch.stack([torch.tensor(x["input_features"]) for x in batch])
    labels = pad_sequence(
        [torch.tensor(x["labels"]) for x in batch], batch_first=True, padding_value=-100
    )
    # The tokenizer emits <|startoftranscript|> as the first label; drop it so the
    # model does not see it twice at the start of the decoder input.
    if labels.size(1) > 0 and bool((labels[:, 0] == decoder_start_id).all()):
        labels = labels[:, 1:]
    return {"input_features": feats, "labels": labels}

In [9]:
wer_metric = evaluate.load("wer")

def compute_metrics(pred):
    """Compute word error rate for a Seq2SeqTrainer evaluation pass.

    Args:
        pred: evaluation prediction whose ``predictions`` are generated token ids
            (requires ``predict_with_generate=True``) and whose ``label_ids`` are
            the collated labels.

    Returns:
        dict with a single ``"wer"`` entry (lower is better).
    """
    pred_ids = pred.predictions
    # Some model outputs arrive as a tuple; the token ids are the first element.
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]
    # -100 marks ignored label positions, and the eval loop may also use it to pad
    # predictions of different lengths across batches. The tokenizer cannot decode
    # negative ids, so map both back to the pad token before decoding.
    pred_ids = np.where(pred_ids != -100, pred_ids, pad_id)
    label_ids = np.where(pred.label_ids != -100, pred.label_ids, pad_id)
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    return {"wer": wer_metric.compute(predictions=pred_str, references=label_str)}

In [10]:
tc = cfg["training"]
model_dir = str(PROJECT_ROOT / cfg["output"]["model_dir"])
log_dir = str(PROJECT_ROOT / cfg["output"]["log_dir"])
Path(model_dir).mkdir(parents=True, exist_ok=True)
Path(log_dir).mkdir(parents=True, exist_ok=True)

args = Seq2SeqTrainingArguments(
    output_dir=model_dir,
    # Hand every dataset column to the custom `collate` function untouched.
    remove_unused_columns=False,
    per_device_train_batch_size=tc["batch_size"],
    per_device_eval_batch_size=tc["batch_size"],
    num_train_epochs=tc["epochs"],
    learning_rate=tc["learning_rate"],
    warmup_steps=tc["warmup_steps"],
    fp16=tc["fp16"],
    eval_strategy="steps",
    eval_steps=tc["eval_steps"],
    # With load_best_model_at_end=True, save_steps must be a multiple of eval_steps.
    save_steps=tc["save_steps"],
    save_total_limit=2,
    gradient_accumulation_steps=tc["gradient_accumulation_steps"],
    logging_dir=log_dir,
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model=tc["metric_for_best_model"],
    greater_is_better=tc["greater_is_better"],
    # compute_metrics decodes token ids, and best-model selection / early stopping
    # depend on the "wer" it returns. Evaluate with generate() so predictions are
    # token ids rather than full-vocabulary logits.
    predict_with_generate=True,
    # Max tokens generated per utterance during evaluation. 225 is a fallback
    # default; override with training.generation_max_length in config.yaml and
    # keep it within the decoder's maximum target length.
    generation_max_length=tc.get("generation_max_length", 225),
)

`logging_dir` is deprecated and will be removed in v5.2. Please set `TENSORBOARD_LOGGING_DIR` instead.


In [11]:
# Early stopping is enabled only when a truthy patience is configured
# (missing, None or 0 disables it).
patience = tc.get("early_stopping_patience")
callbacks = [EarlyStoppingCallback(early_stopping_patience=patience)] if patience else []

# NOTE(review): the test split also drives checkpoint selection and early
# stopping, so its final WER is optimistically biased; consider a held-out split.
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    data_collator=collate,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    compute_metrics=compute_metrics,
    callbacks=callbacks,
)

In [12]:
# Run fine-tuning. Because load_best_model_at_end=True, the trainer restores the
# checkpoint with the best `metric_for_best_model` once training finishes.
# NOTE(review): the recorded run below failed with an MPS out-of-memory error.
# Lowering training.batch_size (and raising gradient_accumulation_steps to keep
# the effective batch size) is the usual first lever. TODO: re-run.
trainer.train()

  super().__init__(loader)


RuntimeError: MPS backend out of memory (MPS allocated: 18.48 GiB, other allocations: 870.78 MiB, max allowed: 20.13 GiB). Tried to allocate 823.97 MiB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [None]:
# Write the trained model and the processor into the same directory so they can
# be reloaded together. NOTE(review): since the model is a PEFT wrapper this
# presumably stores only the LoRA adapter weights; confirm before deployment.
trainer.save_model(output_dir=model_dir)
processor.save_pretrained(model_dir)
print(f"Saved to {model_dir}")