# WHISPER TRANSFORMERS FOR SPEECH RECOGNITION

Whisper Speech Recognition for Low-Resource Languages. Fine-tune OpenAI's Whisper on AfriSpeech-200 (Twi language).

## SETUP

In [None]:
# Install dependencies
!pip install -q --upgrade --no-cache-dir librosa==0.9.2 --no-deps

!pip install soundfile audioread numpy scipy resampy datasets==3.0.0 jiwer evaluate

# %pip install -q "transformers[audio]" datasets==3.0.0 torchaudio librosa soundfile jiwer accelerate evaluate einops sentencepiece

## IMPORT REQUIRED LIBRARIES

In [2]:
import os
import torch
import shutil
import librosa
import warnings
import torchaudio
import numpy as np
import seaborn as sns
import torch.nn as nn
import types
from functools import wraps
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import Dict, List, Union, Optional, Tuple


from datasets import load_dataset, Dataset
from transformers import (
    WhisperForConditionalGeneration,
    WhisperProcessor,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)
from transformers.utils import logging as transformers_logging
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel
from jiwer import wer, cer
from tqdm.auto import tqdm
import evaluate

## SUPPRESS ALL WARNINGS

In [3]:
# Suppress Python warnings
warnings.filterwarnings('ignore')

# Suppress transformers logging
transformers_logging.set_verbosity_error()

# Suppress specific warnings that still appear
os.environ['TRANSFORMERS_VERBOSITY'] = 'error'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# Suppress specific warning patterns
for pattern in [
    '.*fast tokenizer.*',
    '.*attention mask.*',
    '.*forced_decoder_ids.*',
    '.*Moving the following attributes.*',
    '.*missing keys.*',
    '.*Transcription using.*',
    '.*deprecated.*'
]:
    warnings.filterwarnings('ignore', message=pattern)

## CONFIGURATION

In [4]:
@dataclass
class Config:
    """Central configuration for the training pipeline"""
    # System settings
    use_lora: bool = True

    # Model settings
    model_name: str = "openai/whisper-tiny"
    language_code: str = "tw"  # ISO code for Twi
    task: str = "transcribe"

    # LoRA settings - defined as fields, values will be set in __post_init__
    lora_r: Optional[int] = None
    lora_alpha: Optional[int] = None
    lora_dropout: Optional[float] = None
    lora_target_modules: Optional[List[str]] = None

    # Dataset settings
    dataset_name: str = "intronhealth/afrispeech-200"
    language: str = "twi"
    max_train_samples: Optional[int] = None
    max_val_samples: Optional[int] = None

    # Training settings
    output_dir: str = "/content/whisper-afrispeech"
    num_epochs: int = 3
    batch_size: int = 2
    gradient_accumulation_steps: int = 4
    learning_rate: float = 1e-3  # Higher LR for LoRA
    warmup_steps: int = 200

    # Audio settings
    target_sample_rate: int = 16000
    max_audio_len: int = 30  # seconds

    # Evaluation settings
    generation_max_length: int = 225

    use_fp16: bool = True
    num_workers: int = 2
    seed: int = 42

    def __post_init__(self):
        """Set random seeds for reproducibility and LoRA specific settings"""
        torch.manual_seed(self.seed)
        np.random.seed(self.seed)

        if self.use_lora:
            if self.lora_r is None:
                self.lora_r = 32  # LoRA rank
            if self.lora_alpha is None:
                self.lora_alpha = 64  # LoRA alpha (scaling factor)
            if self.lora_dropout is None:
                self.lora_dropout = 0.05
            # Set default LoRA target modules for Whisper if not provided
            if self.lora_target_modules is None:
                self.lora_target_modules = [
                    "q_proj", "v_proj",  # Attention layers
                    "k_proj", "out_proj",  # More attention layers
                    "fc1", "fc2"  # Feed-forward layers
                ]
        else:
            # Ensure LoRA-related fields are None if LoRA is not used
            self.lora_r = None
            self.lora_alpha = None
            self.lora_dropout = None
            self.lora_target_modules = None

        # Disable Weights & Biases (wandb) logging to prevent API prompts
        os.environ["WANDB_DISABLED"] = "true"
        os.environ["WANDB_SILENT"] = "true"
        os.environ["WANDB_MODE"] = "offline"


root_dir = "/content/whisper-afrispeech"
if os.path.exists(root_dir):
    shutil.rmtree(root_dir)

# Initialize configuration
config = Config(output_dir=root_dir, use_lora=False) # Disabled LoRA (use_lora=False)

## DATASET LOADER

In [5]:
class DatasetLoader:
    """Load and prepare AfriSpeech-200 dataset"""

    def __init__(self, config: Config):
        self.config = config

    def load_data(self) -> Tuple[Dataset, Dataset, Dataset]:
        """
        Load train/validation/test splits from HuggingFace

        Returns:
            Tuple of (train_dataset, val_dataset, test_dataset)
        """
        print(f"Loading {self.config.dataset_name} ({self.config.language})...")

        dataset = load_dataset(
            self.config.dataset_name,
            self.config.language,
            trust_remote_code=True
        )

        # Verify splits exist
        required_splits = ['train', 'validation', 'test']
        missing = [s for s in required_splits if s not in dataset]
        if missing:
            raise ValueError(f"Missing splits: {missing}")

        train_data = dataset["train"]
        val_data = dataset["validation"]
        test_data = dataset["test"]

        # Limit samples if specified
        if self.config.max_train_samples:
            train_data = train_data.select(range(min(self.config.max_train_samples, len(train_data))))
        if self.config.max_val_samples:
            val_data = val_data.select(range(min(self.config.max_val_samples, len(val_data))))

        print(f"✓ Train: {len(train_data):,} samples")
        print(f"✓ Val: {len(val_data):,} samples")
        print(f"✓ Test: {len(test_data):,} samples")

        return train_data, val_data, test_data

In [6]:
# Load datasets
loader = DatasetLoader(config)
train_data, val_data, test_data = loader.load_data()

Loading intronhealth/afrispeech-200 (twi)...


afrispeech-200.py: 0.00B [00:00, ?B/s]

README.md: 0.00B [00:00, ?B/s]

accent_stats.py: 0.00B [00:00, ?B/s]

audio/twi/train/train_twi_0.tar.gz:   0%|          | 0.00/1.10G [00:00<?, ?B/s]

audio/twi/dev/dev_twi_0.tar.gz:   0%|          | 0.00/152M [00:00<?, ?B/s]

audio/twi/test/test_twi_0.tar.gz:   0%|          | 0.00/45.9M [00:00<?, ?B/s]

transcripts/twi/train.csv:   0%|          | 0.00/427k [00:00<?, ?B/s]

transcripts/twi/dev.csv:   0%|          | 0.00/61.0k [00:00<?, ?B/s]

transcripts/twi/test.csv:   0%|          | 0.00/19.2k [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]


Reading metadata...: 1315it [00:00, 41279.43it/s]


Generating validation split: 0 examples [00:00, ? examples/s]


Reading metadata...: 186it [00:00, 28110.13it/s]


Generating test split: 0 examples [00:00, ? examples/s]


Reading metadata...: 58it [00:00, 22623.42it/s]


✓ Train: 1,315 samples
✓ Val: 186 samples
✓ Test: 58 samples


## DATA COLLATOR

In [7]:
class DataCollatorSpeechSeq2SeqWithPadding:
    """
    Custom data collator for Whisper that handles:
    - Variable-length audio sequences
    - Padding for both audio features and text labels
    """

    def __init__(self, processor, padding: bool = True, return_tensors: str = "pt"):
        self.processor = processor
        self.padding = padding
        self.return_tensors = return_tensors

    def __call__(self, features: List[Dict[str, Union[torch.Tensor, np.ndarray]]]) -> Dict[str, torch.Tensor]:
        """
        Process a batch of features and return what Whisper needs.

        Returns:
            Dictionary with:
            - 'input_features': Audio features [batch, n_mels, time]
            - 'labels': Token IDs with -100 for padding
        """
        # Process input features (audio)
        input_features_list = self._process_audio_features(features)

        # Pad audio features
        padded_features = self._pad_audio_features(input_features_list)

        # Process and pad labels (text)
        padded_labels = self._process_labels(features)

        # Return batch
        batch = {
            "input_features": padded_features,
            "labels": padded_labels
        }

        return batch

    def _process_audio_features(self, features: List[Dict]) -> List[torch.Tensor]:
        """Convert audio features to tensors"""
        processed = []
        for feature in features:
            feat = feature["input_features"]
            if isinstance(feat, np.ndarray):
                feat = torch.tensor(feat, dtype=torch.float32)
            elif not isinstance(feat, torch.Tensor):
                feat = torch.tensor(feat, dtype=torch.float32)
            processed.append(feat)
        return processed

    def _pad_audio_features(self, features: List[torch.Tensor]) -> torch.Tensor:
        """Pad audio features to same length"""
        max_length = max(feat.shape[-1] for feat in features)
        padded = []

        for feat in features:
            if feat.shape[-1] < max_length:
                padding = torch.zeros(feat.shape[0], max_length - feat.shape[-1], dtype=feat.dtype)
                feat = torch.cat([feat, padding], dim=-1)
            padded.append(feat)

        return torch.stack(padded)

    def _process_labels(self, features) -> torch.Tensor:
        """
        Process and pad labels (text tokens).

        Returns tensor with -100 for padding positions (ignored in loss)
        """
        # Extract label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        # Pad the sequences
        padded_tokens = self.processor.tokenizer.pad(
            label_features,
            padding=self.padding,
            return_tensors=self.return_tensors
        )

        # Get padded input_ids and attention_mask
        input_ids = padded_tokens["input_ids"]
        attention_mask = padded_tokens["attention_mask"]

        # Replace padding token id with -100 so it's ignored in loss (for loss masking)
        labels = input_ids.masked_fill(attention_mask.ne(1), -100)

        return labels

## DATA PREPROCESSING

In [8]:
class AudioPreprocessor:
    """Handle audio preprocessing for Whisper"""

    def __init__(self, processor, config: Config):
        self.processor = processor
        self.config = config

    def preprocess_audio(self, audio_array: np.ndarray, sample_rate: int) -> np.ndarray:
        """
        Preprocess audio: normalize, resample, and ensure correct format
        """
        # Handle empty audio
        if len(audio_array) == 0:
            return np.zeros(self.config.target_sample_rate, dtype=np.float32)

        # Ensure float32 and 1D
        audio_array = np.array(audio_array, dtype=np.float32)
        if len(audio_array.shape) > 1:
            audio_array = audio_array.flatten()

        # Normalize to prevent clipping
        if audio_array.max() > 1.0 or audio_array.min() < -1.0:
            audio_array = audio_array / np.max(np.abs(audio_array))

        # Resample if needed
        if sample_rate != self.config.target_sample_rate:
            audio_array = librosa.resample(
                audio_array,
                orig_sr=sample_rate,
                target_sr=self.config.target_sample_rate
            )

        return audio_array

    def prepare_example(self, batch: Dict) -> Dict:
        """Prepare a single example for training"""
        # Extract and preprocess audio
        audio = batch["audio"]
        audio_array = self.preprocess_audio(audio["array"], audio["sampling_rate"])

        # Extract features
        inputs = self.processor.feature_extractor(
            audio_array,
            sampling_rate=self.config.target_sample_rate,
            return_tensors="pt"
        )

        # Process transcript
        transcript = self._get_transcript(batch)
        labels = self.processor.tokenizer(transcript, return_tensors="pt")

        return {
            "input_features": inputs.input_features[0],
            "labels": labels.input_ids[0],
            "transcript_clean": transcript
        }

    def _get_transcript(self, batch: Dict) -> str:
        """Extract and clean transcript"""
        transcript = batch.get("transcript", batch.get("transcription", "")).lower()
        return transcript if transcript else " "

## METRICS

In [9]:
class MetricsComputer:
    """Compute evaluation metrics (WER, CER, BLEU) safely for Whisper"""

    def __init__(self, processor):
        self.processor = processor
        self.tokenizer = processor.tokenizer
        self.wer_metric = evaluate.load("wer")
        self.cer_metric = evaluate.load("cer")
        self.bleu_metric = evaluate.load("bleu")

    def compute_metrics(self, pred) -> Dict[str, float]:

        # 1. Get predictions
        pred_ids = pred.predictions
        if isinstance(pred_ids, tuple):  # sometimes model returns (logits,)
            pred_ids = pred_ids[0]

        # decode predicted text
        pred_str = self.tokenizer.batch_decode(
            pred_ids,
            skip_special_tokens=True
        )

        # 2. Decode labels
        label_ids = pred.label_ids

        # Replace masked values with pad_token_id (Whisper uses eos as pad)
        label_ids[label_ids == -100] = self.tokenizer.pad_token_id

        label_str = self.tokenizer.batch_decode(
            label_ids,
            skip_special_tokens=True
        )

        # 3. Compute WER and CER
        wer = self.wer_metric.compute(predictions=pred_str, references=label_str)
        cer = self.cer_metric.compute(predictions=pred_str, references=label_str)

        # 4. Compute BLEU score
        # BLEU expects references as lists of lists and predictions as lists
        bleu = self.bleu_metric.compute(
            predictions=pred_str,
            references=[[ref] for ref in label_str]
        )

        return {
            "wer": wer,
            "cer": cer,
            "bleu": bleu["bleu"]
        }

## TRAINER

In [10]:
class WhisperTrainer:
    """Main training orchestrator"""

    def __init__(self, config: Config):
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Initialize components
        self.processor = self._load_processor()
        self.model = self._load_model()
        self.preprocessor = AudioPreprocessor(self.processor, config)
        self.metrics_computer = MetricsComputer(self.processor)
        self.data_collator = DataCollatorSpeechSeq2SeqWithPadding(self.processor)

        print(f"✓ Using device: {self.device}")
        print(f"✓ Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")

    def _load_processor(self) -> WhisperProcessor:
        """Load Whisper processor"""
        return WhisperProcessor.from_pretrained(
            self.config.model_name,
            language=self.config.language_code,
            task=self.config.task
        )

    def _load_model(self) -> WhisperForConditionalGeneration:
        """Load and configure Whisper model"""
        model = WhisperForConditionalGeneration.from_pretrained(self.config.model_name)
        model.config.forced_decoder_ids = None
        model.config.suppress_tokens = []

        # Apply LoRA if enabled
        if self.config.use_lora:
            model = self._apply_lora(model)

        model.to(self.device)
        return model

    def _apply_lora(self, model):
        """Apply LoRA adapters to the model with proper forward wrapping"""
        print(f"\n{'='*70}")
        print("APPLYING LoRA")
        print(f"{'='*70}")

        # Configure LoRA
        lora_config = LoraConfig(
            r=self.config.lora_r,
            lora_alpha=self.config.lora_alpha,
            target_modules=self.config.lora_target_modules,
            lora_dropout=self.config.lora_dropout,
            bias="none",
            task_type="SEQ_2_SEQ_LM"
        )

        # Prepare model for training (freezes base model)
        model = prepare_model_for_kbit_training(model)

        # Add LoRA adapters
        model = get_peft_model(model, lora_config)

        # CRITICAL FIX: Wrap the forward method to handle PEFT's argument passing
        self._patch_peft_forward(model)

        # Print trainable parameters
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in model.parameters())

        print(f"✓ LoRA Configuration:")
        print(f"  Rank (r):              {self.config.lora_r}")
        print(f"  Alpha:                 {self.config.lora_alpha}")
        print(f"  Dropout:               {self.config.lora_dropout}")
        print(f"  Target modules:        {', '.join(self.config.lora_target_modules)}")
        print(f"\n✓ Parameter Statistics:")
        print(f"  Trainable params:      {trainable_params:,} ({100 * trainable_params / total_params:.2f}%)")
        print(f"  Total params:          {total_params:,}")
        print(f"  Memory reduction:      ~{100 * (1 - trainable_params / total_params):.1f}%")
        print(f"{'='*70}\n")

        return model

    def _patch_peft_forward(self, model):
        """
        Patch PEFT model's forward to handle Whisper's argument requirements.

        The issue: PEFT's forward method explicitly passes input_ids as a named
        argument to the base model, but Whisper uses input_features instead.

        We need to patch the PEFT forward to map input_ids -> input_features.
        """
        # Store original PEFT forward method
        original_peft_forward = model.forward

        def whisper_compatible_forward(
            self,
            input_ids=None,
            attention_mask=None,
            inputs_embeds=None,
            decoder_input_ids=None,
            decoder_attention_mask=None,
            decoder_inputs_embeds=None,
            labels=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            task_ids=None,
            **kwargs
        ):
            """
            Modified forward that handles Whisper's input_features instead of input_ids.

            PEFT expects input_ids, but Whisper expects input_features.
            If input_features is in kwargs, use it. Otherwise use input_ids.
            """
            # Check if input_features is provided (Whisper's audio input)
            input_features = kwargs.pop('input_features', None)

            # If input_features exists, use it; otherwise fall back to input_ids
            # For Whisper, input_features IS the primary input (not input_ids)
            if input_features is not None:
                # Call base_model directly with input_features
                # Bypass PEFT's input_ids handling
                with self._enable_peft_forward_hooks(**kwargs):
                    kwargs_clean = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
                    return self.base_model(
                        input_features=input_features,  # Use input_features, not input_ids
                        attention_mask=attention_mask,
                        decoder_input_ids=decoder_input_ids,
                        decoder_attention_mask=decoder_attention_mask,
                        decoder_inputs_embeds=decoder_inputs_embeds,
                        labels=labels,
                        output_attentions=output_attentions,
                        output_hidden_states=output_hidden_states,
                        return_dict=return_dict,
                        **kwargs_clean
                    )
            else:
                # Fall back to original PEFT forward (for non-Whisper use cases)
                return original_peft_forward(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    inputs_embeds=inputs_embeds,
                    decoder_input_ids=decoder_input_ids,
                    decoder_attention_mask=decoder_attention_mask,
                    decoder_inputs_embeds=decoder_inputs_embeds,
                    labels=labels,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                    task_ids=task_ids,
                    **kwargs
                )

        # Bind the new forward method to the PEFT model
        model.forward = types.MethodType(whisper_compatible_forward, model)

        print("✓ Applied PEFT forward wrapper for Whisper compatibility")

    def prepare_datasets(self, train_data: Dataset, val_data: Dataset, test_data: Dataset) -> Tuple:
        """Preprocess all datasets"""
        print("\nPreprocessing datasets...")

        processed_train = train_data.map(
            self.preprocessor.prepare_example,
            remove_columns=train_data.column_names,
            desc="Processing train",
            num_proc=1
        )

        processed_val = val_data.map(
            self.preprocessor.prepare_example,
            remove_columns=val_data.column_names,
            desc="Processing validation",
            num_proc=1
        )

        processed_test = test_data.map(
            self.preprocessor.prepare_example,
            remove_columns=test_data.column_names,
            desc="Processing test",
            num_proc=1
        )

        print(f"✓ Preprocessing complete!")
        return processed_train, processed_val, processed_test

    def train(self, train_dataset: Dataset, val_dataset: Dataset):
        """Run training"""
        training_args = Seq2SeqTrainingArguments(
            output_dir=self.config.output_dir,
            per_device_train_batch_size=self.config.batch_size,
            per_device_eval_batch_size=self.config.batch_size,
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            learning_rate=self.config.learning_rate,
            warmup_steps=self.config.warmup_steps,
            num_train_epochs=self.config.num_epochs,
            predict_with_generate=True,
            generation_max_length=self.config.generation_max_length,

            # Minimal logging - only training steps
            logging_steps=25,
            logging_strategy="steps",
            logging_first_step=True,

            # Disable evaluation during training for cleaner output
            eval_strategy="no",

            # Save settings
            save_strategy="epoch",
            save_total_limit=2,

            # Performance settings
            fp16=self.config.use_fp16 and torch.cuda.is_available(),
            dataloader_num_workers=self.config.num_workers,

            # Disable reporting
            report_to=[],
            disable_tqdm=False,

            # For PEFT models, we keep remove_unused_columns=False
            # because PEFT's forward signature differs from base model
            remove_unused_columns=False if self.config.use_lora else True,
            label_names=["labels"],

            # Suppress warnings
            label_smoothing_factor=0.0,
        )

        trainer = Seq2SeqTrainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            data_collator=self.data_collator,
            processing_class=self.processor.feature_extractor,
            compute_metrics=self.metrics_computer.compute_metrics
        )

        print("\n" + "="*70)
        print("STARTING TRAINING")
        print("="*70)
        print(f"Total steps: {len(train_dataset) // (self.config.batch_size * self.config.gradient_accumulation_steps) * self.config.num_epochs}")
        print(f"Logging every {training_args.logging_steps} steps")
        print("="*70 + "\n")

        # Train
        train_result = trainer.train()

        # Save final model
        trainer.save_model()
        trainer.save_state()

        # Final evaluation on validation set
        print("\n" + "="*70)
        print("EVALUATING ON VALIDATION SET")
        print("="*70)
        eval_results = trainer.evaluate(val_dataset)

        print("\n" + "="*70)
        print("TRAINING COMPLETE")
        print("="*70)
        print(f"\nFinal Training Metrics:")
        print(f"  Train Loss:           {train_result.metrics['train_loss']:.4f}")
        print(f"  Train Runtime:        {train_result.metrics['train_runtime']:.2f}s")
        print(f"  Samples/Second:       {train_result.metrics['train_samples_per_second']:.2f}")
        print(f"\nValidation Metrics:")
        print(f"  Validation Loss:      {eval_results['eval_loss']:.4f}")
        print(f"  Validation WER:       {eval_results['eval_wer']:.4f} ({eval_results['eval_wer']*100:.2f}%)")
        print(f"  Validation CER:       {eval_results['eval_cer']:.4f} ({eval_results['eval_cer']*100:.2f}%)")
        print(f"  Validation BLEU:      {eval_results['eval_bleu']:.4f}")
        print("="*70)

        return trainer

    def evaluate(self, trainer: Seq2SeqTrainer, test_dataset: Dataset) -> Dict[str, float]:
        """Evaluate on test set"""
        print("\n" + "="*70)
        print("EVALUATING ON TEST SET")
        print("="*70)

        eval_results = trainer.evaluate(test_dataset)

        print(f"\nTest Results:")
        print(f"  WER:  {eval_results['eval_wer']:.4f} ({eval_results['eval_wer']*100:.2f}%)")
        print(f"  CER:  {eval_results['eval_cer']:.4f} ({eval_results['eval_cer']*100:.2f}%)")
        print(f"  Loss: {eval_results['eval_loss']:.4f}")
        print(f"  BLEU: {eval_results['eval_bleu']:.4f}")
        print("="*70)

        return eval_results

    def detailed_evaluation(self, test_dataset: Dataset, num_samples: int = None) -> Dict[str, any]:
        """
        Comprehensive evaluation with detailed metrics

        Args:
            test_dataset: Test dataset to evaluate on
            num_samples: Number of samples to evaluate (None = all)

        Returns:
            Dictionary containing detailed metrics and predictions
        """
        print("\n" + "="*70)
        print("DETAILED EVALUATION")
        print("="*70)

        self.model.eval()

        all_predictions = []
        all_references = []
        sample_wers = []
        sample_cers = []

        num_samples = num_samples or len(test_dataset)

        print(f"Evaluating {num_samples} samples...")

        for idx in tqdm(range(num_samples), desc="Evaluating"):
            sample = test_dataset[idx]

            # Get input features
            input_features = torch.tensor(sample["input_features"]).unsqueeze(0).to(self.device)

            # Generate prediction
            with torch.no_grad():
                generated_tokens = self.model.generate(
                    input_features,
                    max_length=self.config.generation_max_length
                )

            prediction = self.processor.batch_decode(generated_tokens, skip_special_tokens=True)[0]
            reference = sample["transcript_clean"]

            all_predictions.append(prediction)
            all_references.append(reference)

            # Calculate per-sample metrics
            try:
                sample_wer = wer([reference], [prediction])
                sample_cer = cer([reference], [prediction])
                sample_wers.append(sample_wer)
                sample_cers.append(sample_cer)
            except:
                pass

        # Calculate overall metrics
        overall_wer = wer(all_references, all_predictions)
        overall_cer = cer(all_references, all_predictions)

        # Calculate statistics
        results = {
            'overall_wer': overall_wer,
            'overall_cer': overall_cer,
            'mean_wer': np.mean(sample_wers),
            'std_wer': np.std(sample_wers),
            'median_wer': np.median(sample_wers),
            'min_wer': np.min(sample_wers),
            'max_wer': np.max(sample_wers),
            'mean_cer': np.mean(sample_cers),
            'std_cer': np.std(sample_cers),
            'predictions': all_predictions,
            'references': all_references,
            'num_samples': num_samples
        }

        # Print results
        print("\n" + "="*70)
        print("EVALUATION RESULTS")
        print("="*70)
        print(f"\nOverall Metrics (Corpus-level):")
        print(f"   WER: {overall_wer:.4f} ({overall_wer*100:.2f}%)")
        print(f"   CER: {overall_cer:.4f} ({overall_cer*100:.2f}%)")

        print(f"\nPer-Sample Statistics:")
        print(f"   Mean WER:     {results['mean_wer']:.4f} ± {results['std_wer']:.4f}")
        print(f"   Median WER:   {results['median_wer']:.4f}")
        print(f"   Min WER:      {results['min_wer']:.4f}")
        print(f"   Max WER:      {results['max_wer']:.4f}")
        print(f"\n   Mean CER:     {results['mean_cer']:.4f} ± {results['std_cer']:.4f}")

        print(f"\nSamples evaluated: {num_samples}")
        print("="*70)

        return results


### Initialize Whisper Trainer

In [11]:
# Initialize trainer
trainer_wrapper = WhisperTrainer(config)

preprocessor_config.json: 0.00B [00:00, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

normalizer.json: 0.00B [00:00, ?B/s]

added_tokens.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/151M [00:00<?, ?B/s]

generation_config.json: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

Downloading extra modules:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

Downloading extra modules: 0.00B [00:00, ?B/s]

✓ Using device: cuda
✓ Model parameters: 37,760,640


### Preprocess Datasets (Train, Test, and Validation Sets)

In [12]:
# Preprocess datasets
processed_train, processed_val, processed_test = trainer_wrapper.prepare_datasets(
    train_data, val_data, test_data
)


Preprocessing datasets...


Processing train:   0%|          | 0/1315 [00:00<?, ? examples/s]

Processing validation:   0%|          | 0/186 [00:00<?, ? examples/s]

Processing test:   0%|          | 0/58 [00:00<?, ? examples/s]

✓ Preprocessing complete!


### Baseline Model Performance

Establishes a baseline performance of the pre-trained model (whisper-tiny).

In [13]:
# Initialize baseline seq2seq trainer with (`use_lora=False`)
baseline_trainer = WhisperTrainer(config)

# Create Seq2SeqTrainingArguments for baseline evaluation
baseline_training_args = Seq2SeqTrainingArguments(
    output_dir=baseline_trainer.config.output_dir, # Use the same output_dir or a temporary one
    predict_with_generate=True,
    generation_max_length=baseline_trainer.config.generation_max_length,
    eval_strategy="no", # This is fine, as we explicitly call evaluate later
    per_device_eval_batch_size=baseline_trainer.config.batch_size,
    fp16=baseline_trainer.config.use_fp16 and torch.cuda.is_available(),
    dataloader_num_workers=baseline_trainer.config.num_workers,
    report_to=[],
    disable_tqdm=True, # No need for tqdm during this step
    remove_unused_columns=False,
    label_names=["labels"],
)

# Create a Seq2SeqTrainer instance for the baseline model
baseline_seq2seq_trainer = Seq2SeqTrainer(
    model=baseline_trainer.model,
    args=baseline_training_args,
    data_collator=baseline_trainer.data_collator,
    processing_class=baseline_trainer.processor.feature_extractor,
    compute_metrics=baseline_trainer.metrics_computer.compute_metrics
)

# Evaluate the baseline model using the correctly instantiated Seq2SeqTrainer
baseline_results = baseline_trainer.evaluate(baseline_seq2seq_trainer, processed_test)


✓ Using device: cuda
✓ Model parameters: 37,760,640

EVALUATING ON TEST SET
{'eval_loss': 4.962335109710693, 'eval_model_preparation_time': 0.007, 'eval_wer': 0.6930022573363431, 'eval_cer': 0.4180456366486675, 'eval_bleu': 0.29322770968972534, 'eval_runtime': 18.3534, 'eval_samples_per_second': 3.16, 'eval_steps_per_second': 1.58}

Test Results:
  WER:  0.6930 (69.30%)
  CER:  0.4180 (41.80%)
  Loss: 4.9623
  BLEU: 0.2932


### Model Training

In [14]:
# Train model
trainer = trainer_wrapper.train(processed_train, processed_val)


STARTING TRAINING
Total steps: 492
Logging every 25 steps



Step,Training Loss
1,5.8791
25,3.4683
50,1.7468
75,2.3416
100,3.2096
125,3.8866
150,4.4145
175,4.3452
200,4.2089
225,4.8751



EVALUATING ON VALIDATION SET



TRAINING COMPLETE

Final Training Metrics:
  Train Loss:           3.3057
  Train Runtime:        612.74s
  Samples/Second:       6.44

Validation Metrics:
  Validation Loss:      5.1814
  Validation WER:       0.9160 (91.60%)
  Validation CER:       0.7013 (70.13%)
  Validation BLEU:      0.0495


### Model Evaluation

In [15]:
# Evaluate on test set
results = trainer_wrapper.evaluate(trainer, processed_test)


EVALUATING ON TEST SET



Test Results:
  WER:  0.9707 (97.07%)
  CER:  0.6851 (68.51%)
  Loss: 4.9232
  BLEU: 0.0611


### LoRA Application

In [16]:
config_lora = Config(use_lora=True)

# Initialize trainer with LoRA
trainer_lora_wrapper = WhisperTrainer(config_lora)

trainer_lora = trainer_lora_wrapper.train(processed_train, processed_val)

results_lora = trainer_lora_wrapper.evaluate(trainer_lora, processed_test)


APPLYING LoRA
✓ Applied PEFT forward wrapper for Whisper compatibility
✓ LoRA Configuration:
  Rank (r):              32
  Alpha:                 64
  Dropout:               0.05
  Target modules:        q_proj, v_proj, k_proj, out_proj, fc1, fc2

✓ Parameter Statistics:
  Trainable params:      2,162,688 (5.42%)
  Total params:          39,923,328
  Memory reduction:      ~94.6%

✓ Using device: cuda
✓ Model parameters: 39,923,328

STARTING TRAINING
Total steps: 492
Logging every 25 steps



Step,Training Loss
1,5.8791
25,4.4042
50,1.8558
75,1.6625
100,1.6871
125,1.6879
150,1.5769
175,1.3521
200,1.2423
225,1.3884



EVALUATING ON VALIDATION SET



TRAINING COMPLETE

Final Training Metrics:
  Train Loss:           1.3474
  Train Runtime:        610.69s
  Samples/Second:       6.46

Validation Metrics:
  Validation Loss:      1.4095
  Validation WER:       0.4589 (45.89%)
  Validation CER:       0.2321 (23.21%)
  Validation BLEU:      0.4369

EVALUATING ON TEST SET

Test Results:
  WER:  0.3363 (33.63%)
  CER:  0.1482 (14.82%)
  Loss: 1.0295
  BLEU: 0.5515


### Inference

In [26]:
def show_sample_predictions(
    trainer_wrapper: WhisperTrainer,
    test_dataset: Dataset,
    num_samples: int = 5
):
    """
    Display sample predictions from test set

    Args:
        trainer_wrapper: Trained WhisperTrainer instance
        test_dataset: Test dataset
        num_samples: Number of samples to display
    """
    print("\n" + "="*120)
    print(f"SAMPLE PREDICTIONS (First {num_samples} test samples)")
    print("="*120)

    trainer_wrapper.model.eval()

    for idx in range(min(num_samples, len(test_dataset))):
        sample = test_dataset[idx]

        # Get input features
        input_features = torch.tensor(sample["input_features"]).unsqueeze(0).to(trainer_wrapper.device)

        # Generate prediction
        # CRITICAL: Use keyword argument, not positional
        # PEFT models require this for proper argument handling
        with torch.no_grad():
            generated_tokens = trainer_wrapper.model.generate(
                input_features=input_features,  # Use keyword argument
                max_length=trainer_wrapper.config.generation_max_length
            )

        prediction = trainer_wrapper.processor.batch_decode(generated_tokens, skip_special_tokens=True)[0]
        reference = sample["transcript_clean"]

        # Calculate WER for this sample
        sample_wer = wer([reference], [prediction])

        print(f"\n{'─'*120}")
        print(f"Sample {idx + 1}:")
        print(f"{'─'*120}")
        print(f"Reference:  {reference[:100]}{'...' if len(reference) > 100 else ''}")
        print(f"Prediction: {prediction[:100]}{'...' if len(prediction) > 100 else ''}")
        print(f"WER:        {sample_wer:.4f} ({sample_wer*100:.2f}%)")

In [27]:
show_sample_predictions(trainer_wrapper, processed_test, num_samples=5)


SAMPLE PREDICTIONS (First 5 test samples)

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Sample 1:
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Reference:  proteins break down to release amino acids which are used as fuel for hepatic gluconeogenesis to mai...
Prediction: prewed 16 was a nastywednesday, for hepaticlasixracheal 1990s ranging to maintain fecalcemia

WER:        0.9130 (91.30%)

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Sample 2:
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Reference:  aspiration is a potential risk in a patient who subsequently loses consciousness or fits and vomits....
Prediction: response are important determinant in location and location and tissues a

In [28]:
show_sample_predictions(trainer_lora_wrapper, processed_test, num_samples=5)


SAMPLE PREDICTIONS (First 5 test samples)

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Sample 1:
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Reference:  proteins break down to release amino acids which are used as fuel for hepatic gluconeogenesis to mai...
Prediction: rottings breakdown to release aminocourses which i used as foam for hepatic gluconeogen as a sore as...
WER:        0.6957 (69.57%)

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Sample 2:
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Reference:  aspiration is a potential risk in a patient who subsequently loses consciousness or fits and vomits....
Prediction: aspiration is a potential risk in efficient to also observe ent