<a href="https://colab.research.google.com/github/KelvinM9187/Supervised-Speech-Recognition-with-Transformers/blob/main/speech_recognition_twi_lora.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive inside the Colab VM. OUTPUT_DIR (defined in the config
# cell) lives under /content/drive, so the mount has to happen first.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Install required dependencies
# NOTE(review): no versions are pinned, so each fresh runtime may resolve to
# different releases of these packages — consider pinning for reproducibility.
!pip install -q --upgrade pip
!pip install -q datasets transformers accelerate bitsandbytes peft evaluate jiwer soundfile torchaudio librosa

# Interactive Hugging Face Hub login (prompts inside the notebook).
# Presumably required for the gated Common Voice dataset — confirm.
from huggingface_hub import notebook_login
notebook_login()

In [None]:
import math
from pathlib import Path
import random
import time
from tqdm import tqdm
from collections import defaultdict
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torchaudio
import librosa
from datasets import load_dataset, Dataset, Audio, DatasetDict
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
# Folder-name prefix of per-step checkpoints ("<prefix>-<global_step>"); used
# both when saving adapters and when searching for a checkpoint to resume from.
PREFIX_CHECKPOINT_DIR = "checkpoint"

from peft import LoraConfig, get_peft_model, PeftModel, prepare_model_for_kbit_training

In [None]:
# Base checkpoint that both the processor and the model are loaded from.
MODEL_NAME = "openai/whisper-base"
# Checkpoints, the final adapter and the processor are written here.
OUTPUT_DIR = "/content/drive/MyDrive/whisper_twi_checkpoints"
# NOTE(review): TARGET_HOURS and MAX_SAMPLES are not referenced anywhere else
# in the visible notebook — presumably meant to cap the training data; confirm.
TARGET_HOURS = 4.0
MAX_SAMPLES = 8000
# Sampling rate (Hz) the audio column is cast to before feature extraction.
SAMPLE_RATE = 16000

# Data split ratios
# Only the LibriSpeech fallback splits manually, and it reads TRAIN_RATIO and
# VAL_RATIO; the test split there is "whatever remains", so TEST_RATIO is unused.
TRAIN_RATIO = 0.8
VAL_RATIO = 0.1
TEST_RATIO = 0.1

# Training hyperparams
# Forwarded as-is to Seq2SeqTrainingArguments.
PER_DEVICE_BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = 8
LEARNING_RATE = 3e-4
NUM_TRAIN_EPOCHS = 3
SAVE_STEPS = 200
EVAL_STEPS = 200
LOGGING_STEPS = 50
FP16 = True

# LoRA config
# Forwarded as-is to LoraConfig; TARGET_MODULES names the submodules that
# receive adapters.
LORA_R = 8
LORA_ALPHA = 32
LORA_DROPOUT = 0.05
TARGET_MODULES = ["q_proj", "v_proj"]

# Seed applied to the Python, NumPy and torch RNGs right after this block.
SEED = 42

# Make sure the checkpoint directory exists before anything tries to write to it.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

# reproducibility: seed every RNG this notebook relies on
for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    seed_rng(SEED)
if device == "cuda":
    torch.cuda.manual_seed_all(SEED)

In [None]:
# Load Twi dataset
#
# Sources are tried in order: Common Voice (Twi), then FLEURS, then LibriSpeech
# (English; only useful for smoke-testing the pipeline). Whichever loads first
# ends up in `data_dict` with "train" / "validation" / "test" splits.

def _keep_standard_splits(splits):
    """Return a DatasetDict holding only the train/validation/test entries of ``splits``.

    Raises KeyError when one of the three splits is missing, which lets the
    caller fall through to the next data source.
    """
    return DatasetDict({name: splits[name] for name in ("train", "validation", "test")})


print("Loading Twi dataset from Common Voice...")

try:
    # Try loading Common Voice with Twi language
    dataset = load_dataset("mozilla-foundation/common_voice_16_1", "tw")
    print("Successfully loaded Common Voice Twi dataset!")

    data_dict = _keep_standard_splits(dataset)
    train_data = data_dict["train"]
    val_data = data_dict["validation"]
    test_data = data_dict["test"]

    print(f"Dataset sizes - Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}")

except Exception as e:
    # Deliberately broad: any failure (auth, network, missing config) should
    # move on to the next source instead of stopping the notebook.
    print(f"Could not load Common Voice Twi: {e}")
    print("Trying alternative approach with FLEURS dataset...")

    try:
        # Try FLEURS dataset which contains Twi
        # NOTE(review): confirm that "tw_gh" is an existing FLEURS config.
        dataset = load_dataset("google/fleurs", "tw_gh")
        print("Successfully loaded FLEURS Twi dataset!")

        data_dict = _keep_standard_splits(dataset)
        train_data = data_dict["train"]
        val_data = data_dict["validation"]
        test_data = data_dict["test"]

        print(f"Loaded FLEURS Twi: {len(train_data)} train, {len(val_data)} val, {len(test_data)} test")

    except Exception as e2:
        print(f"Could not load FLEURS Twi: {e2}")
        print("Using LibriSpeech as fallback for testing...")

        # Fallback to LibriSpeech.
        # The "clean" config has no split called "train" (its training data is
        # published as "train.100" and "train.360"), so request "train.100".
        dataset = load_dataset("librispeech_asr", "clean", split="train.100+validation+test")

        # Split manually: first TRAIN_RATIO of the rows, then VAL_RATIO, and
        # the remainder becomes the test split.
        n = len(dataset)
        n_train = int(n * TRAIN_RATIO)
        n_val = int(n * VAL_RATIO)

        train_data = dataset.select(range(0, n_train))
        val_data = dataset.select(range(n_train, n_train + n_val))
        test_data = dataset.select(range(n_train + n_val, n))

        data_dict = DatasetDict({
            "train": train_data,
            "validation": val_data,
            "test": test_data
        })

        print(f"Using LibriSpeech fallback: {len(train_data)} train, {len(val_data)} val, {len(test_data)} test")

# Cast audio to fixed sampling rate for Whisper
print("Converting audio to 16kHz...")
data_dict = data_dict.cast_column("audio", Audio(sampling_rate=SAMPLE_RATE))

In [None]:
# Load model & processor
# Both come from the same MODEL_NAME checkpoint; the processor bundles the
# feature extractor and tokenizer used during preprocessing and decoding.
print("Loading processor and model...")
processor = WhisperProcessor.from_pretrained(MODEL_NAME, task="transcribe")
model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME)
# Reset two generation-related config fields on the loaded model: no forced
# decoder prompt ids and an empty suppress-token list.
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

In [None]:
# Preprocess function
# Limits applied by keep_example after feature extraction:
#   MAX_INPUT_LENGTH - audio length in samples (duration x sampling rate)
#   MAX_LABEL_LENGTH - number of label tokens
MAX_DURATION_IN_SECONDS = 30.0
MAX_INPUT_LENGTH = int(MAX_DURATION_IN_SECONDS * SAMPLE_RATE)
MAX_LABEL_LENGTH = 448

def prepare_example(batch):
    """Turn one raw dataset row into Whisper training inputs.

    Reads ``batch["audio"]`` (a mapping with ``"array"`` and
    ``"sampling_rate"``) and the first non-empty of the ``"sentence"``,
    ``"transcription"`` or ``"text"`` fields, then stores four keys on
    ``batch`` and returns it:

    * ``input_features`` - first item of ``processor.feature_extractor(...)``
    * ``labels``         - token ids from ``processor.tokenizer``
    * ``input_length``   - number of audio samples
    * ``labels_length``  - ``len(labels)``, i.e. including any special tokens
      the tokenizer adds, so a length filter bounds what the decoder really sees

    The function never returns ``None``: a ``Dataset.map`` callback that
    returns ``None`` does not drop the row. When a row cannot be processed, a
    placeholder with the same four keys is returned instead, built from one
    second of silence and an empty transcript and marked with
    ``input_length == 0`` so that a filter requiring ``input_length > 0``
    (as ``keep_example`` does) removes it.

    Raises:
        Exception: only if even the placeholder cannot be built, which means
        the processor itself is unusable and skipping rows would hide it.
    """
    try:
        # audio is dataset.Audio with 'array' and 'sampling_rate'
        audio = batch["audio"]
        array = audio["array"]
        sr = audio["sampling_rate"]

        # Handle cases where audio might be None
        if array is None:
            # Create silent audio as fallback
            array = np.zeros(int(SAMPLE_RATE * 1.0))  # 1 second of silence
            sr = SAMPLE_RATE

        # Ensure audio is mono
        array = np.asarray(array)
        if array.ndim > 1:
            array = array.mean(axis=0)  # Convert to mono

        # Get transcription text
        transcription = batch.get("sentence") or batch.get("transcription") or batch.get("text") or ""

        # extract first (single) example features
        input_features = processor.feature_extractor(array, sampling_rate=sr).input_features[0]
        labels = processor.tokenizer(transcription).input_ids
        input_length = len(array)
    except Exception as e:
        print(f"Error processing example: {e}")
        silence = np.zeros(int(SAMPLE_RATE * 1.0))
        placeholder_labels = processor.tokenizer("").input_ids
        return {
            "input_features": processor.feature_extractor(silence, sampling_rate=SAMPLE_RATE).input_features[0],
            "labels": placeholder_labels,
            "input_length": 0,  # marks the row for removal by the length filter
            "labels_length": len(placeholder_labels),
        }

    batch["input_features"] = input_features
    batch["labels"] = labels
    # lengths for filtering
    batch["input_length"] = input_length
    batch["labels_length"] = len(labels)
    return batch

# Run feature extraction through a guard that logs failures instead of raising
print("Preparing dataset features...")

def safe_prepare_example(example):
    """Call ``prepare_example`` on ``example`` and return its result.

    Any exception raised by ``prepare_example`` is printed and swallowed, in
    which case ``None`` is returned.
    """
    prepared = None
    try:
        prepared = prepare_example(example)
    except Exception as e:
        print(f"Skipping example due to error: {e}")
    return prepared

# Extract model inputs for every split.
# There is intentionally no `filter(lambda x: x is not None)` step afterwards:
# a filter predicate receives each row as a dict, so that test is always true
# and would only cost an extra pass over the data. Unusable rows have to be
# removed by a filter on the columns themselves (see the length filter).
for split in ["train", "validation", "test"]:
    print(f"Processing {split} split...")
    current_columns = data_dict[split].column_names
    original_len = len(data_dict[split])
    data_dict[split] = data_dict[split].map(
        safe_prepare_example,
        remove_columns=current_columns,  # keep only the freshly computed columns
        num_proc=2
    )
    new_len = len(data_dict[split])
    print(f"{split}: {new_len}/{original_len} examples processed successfully")

# Filter by length
def keep_example(example):
    """Return True when ``example`` fits within the audio and label limits.

    A row is kept only if ``input_length`` is strictly between 0 and
    ``MAX_INPUT_LENGTH`` and ``labels_length`` is below ``MAX_LABEL_LENGTH``.
    Missing length fields count as 0.
    """
    audio_samples = example.get("input_length", 0)
    label_tokens = example.get("labels_length", 0)
    if audio_samples <= 0 or audio_samples >= MAX_INPUT_LENGTH:
        return False
    return label_tokens < MAX_LABEL_LENGTH

# Apply the length filter to each split, report how many rows survive, then
# drop the two helper columns that existed only to make that decision.
for split in ["train", "validation", "test"]:
    unfiltered = data_dict[split]
    original_len = len(unfiltered)
    within_limits = unfiltered.filter(keep_example)
    filtered_len = len(within_limits)
    print(f"{split}: {filtered_len}/{original_len} after length filtering")
    data_dict[split] = within_limits.remove_columns(["input_length", "labels_length"])

In [None]:
# Data collator for Whisper seq2seq
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import torch

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Batch collator that pads audio features and label ids separately.

    ``processor`` must expose ``feature_extractor.pad`` and ``tokenizer.pad``
    (both called with ``return_tensors="pt"``) plus ``tokenizer.bos_token_id``.
    """

    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Collate ``features`` into one padded batch.

        Each item must provide ``"input_features"`` and ``"labels"``. The
        returned batch is whatever ``feature_extractor.pad`` produced, with a
        ``"labels"`` entry added in which padded positions are set to -100.
        """
        feature_extractor = self.processor.feature_extractor
        tokenizer = self.processor.tokenizer

        # The two modalities are padded by different components, so split them.
        audio_items = []
        label_items = []
        for item in features:
            audio_items.append({"input_features": item["input_features"]})
            label_items.append({"input_ids": item["labels"]})

        batch = feature_extractor.pad(audio_items, return_tensors="pt")
        padded_labels = tokenizer.pad(label_items, return_tensors="pt")

        # -100 marks positions the loss should ignore.
        is_padding = padded_labels.attention_mask.ne(1)
        labels = padded_labels["input_ids"].masked_fill(is_padding, -100)

        # Strip the leading BOS column only when every row starts with it.
        every_row_starts_with_bos = (labels[:, 0] == tokenizer.bos_token_id).all().cpu().item()
        if every_row_starts_with_bos:
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

# Collator instance handed to the trainer; it shares the processor used for
# preprocessing so padding matches the extracted features and token ids.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

In [None]:
# Setup LoRA (PEFT) and prepare model for training
print("Preparing model for k-bit training and applying LoRA (PEFT)...")
# NOTE(review): the model above is loaded without quantization, so this call
# presumably only freezes the base weights — confirm it is still wanted.
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=TARGET_MODULES,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    # No task_type on purpose: "SEQ_2_SEQ_LM" selects a PEFT wrapper whose
    # forward() passes text-model arguments such as `input_ids` to the base
    # model, while Whisper's forward() takes `input_features`. The generic
    # wrapper forwards the batch unchanged.
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
model.to(device)

In [None]:
# Training arguments & Trainer
output_dir = OUTPUT_DIR

# The evaluation-schedule argument is called `eval_strategy` in newer
# transformers releases and `evaluation_strategy` in older ones. The install
# cell does not pin a version, so use whichever name the installed class defines.
_eval_strategy_kwarg = (
    "eval_strategy"
    if "eval_strategy" in getattr(Seq2SeqTrainingArguments, "__dataclass_fields__", {})
    else "evaluation_strategy"
)

training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    eval_steps=EVAL_STEPS,
    logging_steps=LOGGING_STEPS,
    save_steps=SAVE_STEPS,
    save_total_limit=5,
    learning_rate=LEARNING_RATE,
    fp16=FP16,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    # Generate during evaluation so WER is computed on decoded text.
    predict_with_generate=True,
    # The collator needs the "input_features" / "labels" columns left intact.
    remove_unused_columns=False,
    push_to_hub=False,
    # Best checkpoint = lowest validation WER.
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    **{_eval_strategy_kwarg: "steps"},
)

# Trainer callback that keeps checkpoints adapter-only
class SavePeftModelCallback(TrainerCallback):
    """On every save, write the model via ``save_pretrained`` into the checkpoint folder.

    The target is ``<output_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>/adapter_model``.
    A ``pytorch_model.bin`` sitting in the same checkpoint folder is deleted
    when present.
    """

    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Save ``kwargs["model"]`` for the current step and return ``control`` unchanged."""
        step_folder_name = f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        checkpoint_folder = os.path.join(args.output_dir, step_folder_name)

        kwargs["model"].save_pretrained(os.path.join(checkpoint_folder, "adapter_model"))

        full_weights_file = os.path.join(checkpoint_folder, "pytorch_model.bin")
        if not os.path.exists(full_weights_file):
            return control
        os.remove(full_weights_file)
        return control

# Metrics: WER
# `metric` is the scorer that compute_metrics calls with decoded predictions
# and references.
import evaluate
metric = evaluate.load("wer")

def compute_metrics(pred):
    """Compute word error rate for one evaluation pass.

    Args:
        pred: object with ``predictions`` (generated token ids, or a tuple
            whose first item holds them) and ``label_ids`` (reference token
            ids in which ignored positions are -100).

    Returns:
        ``{"wer": value}`` where ``value`` is whatever ``metric.compute``
        returns for the decoded predictions and references.
    """
    pad_id = processor.tokenizer.pad_token_id

    pred_ids = pred.predictions
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]

    # -100 is not a real token id. It appears in the labels (ignored
    # positions) and can also appear in the predictions when batches of
    # different lengths are padded together, so mask it in both before
    # decoding. np.where builds new arrays, leaving the caller's data intact.
    pred_ids = np.asarray(pred_ids)
    pred_ids = np.where(pred_ids != -100, pred_ids, pad_id)
    label_ids = np.asarray(pred.label_ids)
    label_ids = np.where(label_ids != -100, label_ids, pad_id)

    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    wer = metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}

import inspect

# Newer transformers releases accept the processing object under
# `processing_class` and deprecate `tokenizer`; older releases only know
# `tokenizer`. Pick the name the installed Seq2SeqTrainer advertises.
_processor_kwarg = (
    "processing_class"
    if "processing_class" in inspect.signature(Seq2SeqTrainer.__init__).parameters
    else "tokenizer"
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=data_dict["train"],
    eval_dataset=data_dict["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[SavePeftModelCallback()],
    # The feature extractor is passed in this slot, as in the original call.
    **{_processor_kwarg: processor.feature_extractor},
)

In [None]:
# Checkpoint for Resuming Training Sessions
def find_latest_checkpoint(output_dir):
    """Return the path of the highest-numbered checkpoint folder in ``output_dir``.

    Only sub-directories named ``<PREFIX_CHECKPOINT_DIR>-<digits>`` count.
    Other entries that merely start with the prefix (plain files, names such
    as ``checkpoint-best``) are ignored instead of raising ``ValueError``.

    Args:
        output_dir: directory that holds the checkpoint folders.

    Returns:
        The joined path of the checkpoint with the largest step number, or
        ``None`` when ``output_dir`` is not a directory or has no checkpoint.
    """
    if not os.path.isdir(output_dir):
        return None

    prefix = f"{PREFIX_CHECKPOINT_DIR}-"
    latest_step = -1
    latest_path = None
    for name in os.listdir(output_dir):
        if not name.startswith(prefix):
            continue
        step_text = name[len(prefix):]
        # isascii() guards against Unicode digits that int() cannot parse.
        if not (step_text.isascii() and step_text.isdigit()):
            continue
        candidate = os.path.join(output_dir, name)
        if not os.path.isdir(candidate):
            continue
        step = int(step_text)
        if step > latest_step:
            latest_step = step
            latest_path = candidate
    return latest_path

# Look for an earlier run to resume from; `latest_ckpt` stays None when the
# output directory is missing or holds no checkpoint folders.
latest_ckpt = find_latest_checkpoint(output_dir) if os.path.exists(output_dir) else None

if not latest_ckpt:
    print("No checkpoint found. Starting fresh training.")
else:
    print("Found checkpoint:", latest_ckpt)

In [None]:
# Training Session
print("Starting training...")
# A falsy `latest_ckpt` (no checkpoint found) means "start from scratch".
resume_path = latest_ckpt or None
train_result = trainer.train(resume_from_checkpoint=resume_path)

# Persist the trained adapter and the processor side by side
print("Saving adapter and processor to:", output_dir)
for artifact in (model, processor):
    artifact.save_pretrained(output_dir)

# Persist trainer state, then log and store the training metrics
trainer.save_state()
metrics = train_result.metrics
for record_metrics in (trainer.log_metrics, trainer.save_metrics):
    record_metrics("train", metrics)

# Evaluate on test set
print("Evaluating on test set...")
# Use a dedicated metric prefix: with the default "eval" prefix the test WER is
# appended to trainer.state.log_history as "eval_wer", where it cannot be told
# apart from the validation points recorded during training.
eval_res = trainer.evaluate(
    eval_dataset=data_dict["test"],
    metric_key_prefix="test",
    max_length=256,
    num_beams=1,
)
print("Test evaluation:", eval_res)
trainer.log_metrics("eval_test", eval_res)
trainer.save_metrics("eval_test", eval_res)

# Plotting loss & WER history
def _plot_curve(x_values, y_values, title, y_label, empty_message):
    """Draw one metric-versus-step curve, or print ``empty_message`` if there is no data.

    The figure is created only when there is something to draw, so an empty
    history does not leave a blank figure in the output.
    """
    if len(x_values) == 0:
        print(empty_message)
        return
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(x_values, y_values, marker="o")
    ax.set(title=title, xlabel="Step", ylabel=y_label)
    ax.grid(True)
    plt.show()


log_hist = trainer.state.log_history

# Training-loss entries carry a "loss" key, evaluation entries an "eval_wer" key.
loss_entries = [entry for entry in log_hist if "loss" in entry]
steps = [entry.get("step", None) for entry in loss_entries]
losses = [entry["loss"] for entry in loss_entries]

wer_entries = [entry for entry in log_hist if "eval_wer" in entry]
eval_steps = [entry.get("step", None) for entry in wer_entries]
eval_wers = [entry["eval_wer"] for entry in wer_entries]

# Plot training loss
_plot_curve(steps, losses, "Training loss", "Loss", "No training loss logs to plot.")

# Plot eval WER curve
_plot_curve(
    eval_steps,
    eval_wers,
    "Validation WER (lower is better)",
    "WER",
    "No eval WER logs to plot.",
)

# Minimal inference example
from transformers import pipeline

print("Loading final PEFT model for inference...")
# `model` has already been wrapped by get_peft_model, so wrapping it in
# PeftModel again would stack one adapter wrapper on another. Load the saved
# adapter onto a freshly loaded base model instead, configured like the
# training model.
inference_base = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME)
inference_base.config.forced_decoder_ids = None
inference_base.config.suppress_tokens = []
peft_model = PeftModel.from_pretrained(inference_base, OUTPUT_DIR)
peft_model.to(device)
peft_model.eval()
inference_pipe = pipeline(
    "automatic-speech-recognition",
    model=peft_model,
    feature_extractor=processor.feature_extractor,
    tokenizer=processor.tokenizer,
    device=0 if device == "cuda" else -1
)

# Test inference
if len(data_dict["test"]) > 0:
    sample = data_dict["test"][0]
    if "audio" in sample:
        # Raw audio is still available: run it through the pipeline.
        example_audio = sample["audio"]["array"]
        # Get transcription from available fields
        transcription = sample.get("sentence") or sample.get("transcription") or sample.get("text") or "No transcription available"
        prediction = inference_pipe(example_audio)["text"]
    else:
        # Preprocessing dropped the raw columns, so only "input_features" and
        # "labels" remain: generate from the stored features and recover the
        # reference text by decoding the label ids.
        features = torch.as_tensor(np.asarray(sample["input_features"]), dtype=torch.float32)
        features = features.unsqueeze(0).to(device)  # add the batch dimension
        with torch.no_grad():
            generated_ids = peft_model.generate(input_features=features, max_length=256)
        prediction = processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        transcription = (
            processor.tokenizer.decode(sample["labels"], skip_special_tokens=True)
            or "No transcription available"
        )
    print("Example ground truth:", transcription)
    print("Model transcription:", prediction)
else:
    print("No test samples available for inference demo")

print("All done. Checkpoints and final adapter saved to:", OUTPUT_DIR)