# WHISPER TRANSFORMERS FOR SPEECH RECOGNITION

Whisper Speech Recognition for Low-Resource Languages. Fine-tune OpenAI's Whisper on AfriSpeech-200 (Twi language).

## SETUP

In [None]:
# Install dependencies
!pip install -q --upgrade --no-cache-dir librosa==0.9.2 --no-deps

!pip install soundfile audioread numpy scipy resampy datasets==3.0.0 jiwer evaluate

# %pip install -q "transformers[audio]" datasets==3.0.0 torchaudio librosa soundfile jiwer accelerate evaluate einops sentencepiece

## IMPORT REQUIRED LIBRARIES

In [2]:
import os
import torch
import librosa
import warnings
import torchaudio
import numpy as np
import seaborn as sns
from pathlib import Path
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import Dict, List, Union, Optional, Tuple
import shutil

from datasets import load_dataset, Dataset
from transformers import (
    WhisperForConditionalGeneration,
    WhisperProcessor,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)
from jiwer import wer, cer
import evaluate

warnings.filterwarnings("ignore")
if os.path.exists("/content/whisper-afrispeech"):
    shutil.rmtree("/content/whisper-afrispeech")

## CONFIGURATION

In [3]:
@dataclass
class Config:
    """Central configuration for the training pipeline"""
    # Model settings
    model_name: str = "openai/whisper-tiny"
    language_code: str = "tw"  # ISO code for Twi
    task: str = "transcribe"

    # Dataset settings
    dataset_name: str = "intronhealth/afrispeech-200"
    language: str = "twi"
    max_train_samples: Optional[int] = None
    max_val_samples: Optional[int] = None

    # Training settings
    output_dir: str = "/content/whisper-afrispeech"
    num_epochs: int = 3
    batch_size: int = 2
    gradient_accumulation_steps: int = 4
    learning_rate: float = 1e-5
    warmup_steps: int = 200

    # Audio settings
    target_sample_rate: int = 16000
    max_audio_len: int = 30  # seconds

    # Evaluation settings
    generation_max_length: int = 225

    # System settings
    use_fp16: bool = True
    num_workers: int = 2
    seed: int = 42

    def __post_init__(self):
        """Set random seeds for reproducibility"""
        torch.manual_seed(self.seed)
        np.random.seed(self.seed)
        sns.set(style="whitegrid")

        # Disable Weights & Biases (wandb) logging to prevent API prompts
        os.environ["WANDB_DISABLED"] = "true"
        os.environ["WANDB_SILENT"] = "true"
        os.environ["WANDB_MODE"] = "offline"


# Initialize configuration
config = Config()

## DATA COLLATOR

In [4]:
class DataCollatorSpeechSeq2SeqWithPadding:
    """
    Custom data collator for Whisper that handles:
    - Variable-length audio sequences
    - Padding for both audio features and text labels
    """

    def __init__(self, processor, padding: bool = True, return_tensors: str = "pt"):
        self.processor = processor
        self.padding = padding
        self.return_tensors = return_tensors

    def __call__(self, features: List[Dict[str, Union[torch.Tensor, np.ndarray]]]) -> Dict[str, torch.Tensor]:
        # Process input features (audio)
        input_features_list = self._process_audio_features(features)

        # Pad audio features
        padded_features = self._pad_audio_features(input_features_list)

        # Process and pad labels (text)
        padded_labels, attention_mask = self._process_labels(features)

        return {
            "input_features": padded_features,
            "labels": padded_labels,
            "attention_mask": attention_mask
        }

    def _process_audio_features(self, features: List[Dict]) -> List[torch.Tensor]:
        """Convert audio features to tensors"""
        processed = []
        for feature in features:
            feat = feature["input_features"]
            if isinstance(feat, np.ndarray):
                feat = torch.tensor(feat, dtype=torch.float32)
            elif not isinstance(feat, torch.Tensor):
                feat = torch.tensor(feat, dtype=torch.float32)
            processed.append(feat)
        return processed

    def _pad_audio_features(self, features: List[torch.Tensor]) -> torch.Tensor:
        """Pad audio features to same length"""
        max_length = max(feat.shape[-1] for feat in features)
        padded = []

        for feat in features:
            if feat.shape[-1] < max_length:
                padding = torch.zeros(feat.shape[0], max_length - feat.shape[-1], dtype=feat.dtype)
                feat = torch.cat([feat, padding], dim=-1)
            padded.append(feat)

        return torch.stack(padded)

    def _process_labels(self, features):
        """Process and pad labels (text)"""
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        padded_tokens = self.processor.tokenizer.pad(
            label_features,
            padding=self.padding,
            return_tensors=self.return_tensors
        )

        input_ids = padded_tokens["input_ids"]
        attention_mask = padded_tokens["attention_mask"]

        # Convert padding tokens → -100 for loss masking
        labels = input_ids.masked_fill(attention_mask.ne(1), -100)

        return labels, attention_mask


## DATA PREPROCESSING

In [12]:
class AudioPreprocessor:
    """Handle audio preprocessing for Whisper"""

    def __init__(self, processor, config: Config):
        self.processor = processor
        self.config = config

    def preprocess_audio(self, audio_array: np.ndarray, sample_rate: int) -> np.ndarray:
        """
        Preprocess audio: normalize, resample, and ensure correct format
        """
        # Handle empty audio
        if len(audio_array) == 0:
            return np.zeros(self.config.target_sample_rate, dtype=np.float32)

        # Ensure float32 and 1D
        audio_array = np.array(audio_array, dtype=np.float32)
        if len(audio_array.shape) > 1:
            audio_array = audio_array.flatten()

        # Normalize to prevent clipping
        if audio_array.max() > 1.0 or audio_array.min() < -1.0:
            audio_array = audio_array / np.max(np.abs(audio_array))

        # Resample if needed
        if sample_rate != self.config.target_sample_rate:
            audio_array = librosa.resample(
                audio_array,
                orig_sr=sample_rate,
                target_sr=self.config.target_sample_rate
            )

        return audio_array

    def prepare_example(self, batch: Dict) -> Dict:
        """Prepare a single example for training"""
        # Extract and preprocess audio
        audio = batch["audio"]
        audio_array = self.preprocess_audio(audio["array"], audio["sampling_rate"])

        # Extract features
        inputs = self.processor.feature_extractor(
            audio_array,
            sampling_rate=self.config.target_sample_rate,
            return_tensors="pt"
        )

        # Process transcript
        transcript = self._get_transcript(batch)
        labels = self.processor.tokenizer(transcript, return_tensors="pt")

        return {
            "input_features": inputs.input_features[0],
            "labels": labels.input_ids[0],
            "transcript_clean": transcript
        }

    def _get_transcript(self, batch: Dict) -> str:
        """Extract and clean transcript"""
        transcript = batch.get("transcript", batch.get("transcription", "")).lower()
        return transcript if transcript else " "

## DATASET LOADER

In [6]:
class DatasetLoader:
    """Load and prepare AfriSpeech-200 dataset"""

    def __init__(self, config: Config):
        self.config = config

    def load_data(self) -> Tuple[Dataset, Dataset, Dataset]:
        """
        Load train/validation/test splits from HuggingFace

        Returns:
            Tuple of (train_dataset, val_dataset, test_dataset)
        """
        print(f"Loading {self.config.dataset_name} ({self.config.language})...")

        dataset = load_dataset(
            self.config.dataset_name,
            self.config.language,
            trust_remote_code=True
        )

        # Verify splits exist
        required_splits = ['train', 'validation', 'test']
        missing = [s for s in required_splits if s not in dataset]
        if missing:
            raise ValueError(f"Missing splits: {missing}")

        train_data = dataset["train"]
        val_data = dataset["validation"]
        test_data = dataset["test"]

        # Limit samples if specified
        if self.config.max_train_samples:
            train_data = train_data.select(range(min(self.config.max_train_samples, len(train_data))))
        if self.config.max_val_samples:
            val_data = val_data.select(range(min(self.config.max_val_samples, len(val_data))))

        print(f"✓ Train: {len(train_data):,} samples")
        print(f"✓ Val: {len(val_data):,} samples")
        print(f"✓ Test: {len(test_data):,} samples")

        return train_data, val_data, test_data

In [7]:
# Load datasets
loader = DatasetLoader(config)
train_data, val_data, test_data = loader.load_data()

Loading intronhealth/afrispeech-200 (twi)...


afrispeech-200.py: 0.00B [00:00, ?B/s]

README.md: 0.00B [00:00, ?B/s]

accent_stats.py: 0.00B [00:00, ?B/s]

audio/twi/train/train_twi_0.tar.gz:   0%|          | 0.00/1.10G [00:00<?, ?B/s]

audio/twi/dev/dev_twi_0.tar.gz:   0%|          | 0.00/152M [00:00<?, ?B/s]

audio/twi/test/test_twi_0.tar.gz:   0%|          | 0.00/45.9M [00:00<?, ?B/s]

transcripts/twi/train.csv:   0%|          | 0.00/427k [00:00<?, ?B/s]

transcripts/twi/dev.csv:   0%|          | 0.00/61.0k [00:00<?, ?B/s]

transcripts/twi/test.csv:   0%|          | 0.00/19.2k [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]


Reading metadata...: 1315it [00:00, 46063.92it/s]


Generating validation split: 0 examples [00:00, ? examples/s]


Reading metadata...: 186it [00:00, 6196.90it/s]


Generating test split: 0 examples [00:00, ? examples/s]


Reading metadata...: 58it [00:00, 28221.54it/s]


✓ Train: 1,315 samples
✓ Val: 186 samples
✓ Test: 58 samples


## METRICS

In [8]:
class MetricsComputer:
    """Compute evaluation metrics (WER, CER) safely for Whisper"""

    def __init__(self, processor):
        self.processor = processor
        self.tokenizer = processor.tokenizer
        self.wer_metric = evaluate.load("wer")
        self.cer_metric = evaluate.load("cer")

    def compute_metrics(self, pred) -> Dict[str, float]:

        # 1. Get predictions
        pred_ids = pred.predictions
        if isinstance(pred_ids, tuple):  # sometimes model returns (logits,)
            pred_ids = pred_ids[0]

        # decode predicted text
        pred_str = self.tokenizer.batch_decode(
            pred_ids,
            skip_special_tokens=True
        )

        # 2. Decode labels
        label_ids = pred.label_ids

        # Replace masked values with pad_token_id (Whisper uses eos as pad)
        label_ids[label_ids == -100] = self.tokenizer.pad_token_id

        label_str = self.tokenizer.batch_decode(
            label_ids,
            skip_special_tokens=True
        )

        # 3. Compute metrics
        wer = self.wer_metric.compute(predictions=pred_str, references=label_str)
        cer = self.cer_metric.compute(predictions=pred_str, references=label_str)

        return {"wer": wer, "cer": cer}


## TRAINER

In [9]:
class WhisperTrainer:
    """Main training orchestrator"""

    def __init__(self, config: Config):
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Initialize components
        self.processor = self._load_processor()
        self.model = self._load_model()
        self.preprocessor = AudioPreprocessor(self.processor, config)
        self.metrics_computer = MetricsComputer(self.processor)
        self.data_collator = DataCollatorSpeechSeq2SeqWithPadding(self.processor)

        print(f"✓ Using device: {self.device}")
        print(f"✓ Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")

    def _load_processor(self) -> WhisperProcessor:
        """Load Whisper processor"""
        return WhisperProcessor.from_pretrained(
            self.config.model_name,
            language=self.config.language_code,
            task=self.config.task
        )

    def _load_model(self) -> WhisperForConditionalGeneration:
        """Load and configure Whisper model"""
        model = WhisperForConditionalGeneration.from_pretrained(self.config.model_name)
        model.config.forced_decoder_ids = None
        model.config.suppress_tokens = []
        model.to(self.device)
        return model

    def prepare_datasets(self, train_data: Dataset, val_data: Dataset, test_data: Dataset) -> Tuple:
        """Preprocess all datasets"""
        print("\nPreprocessing datasets...")

        processed_train = train_data.map(
            self.preprocessor.prepare_example,
            remove_columns=train_data.column_names,
            desc="Processing train",
            num_proc=1
        )

        processed_val = val_data.map(
            self.preprocessor.prepare_example,
            remove_columns=val_data.column_names,
            desc="Processing validation",
            num_proc=1
        )

        processed_test = test_data.map(
            self.preprocessor.prepare_example,
            remove_columns=test_data.column_names,
            desc="Processing test",
            num_proc=1
        )

        print(f"✓ Preprocessing complete!")
        return processed_train, processed_val, processed_test

    def train(self, train_dataset: Dataset, val_dataset: Dataset):
        """Run training"""
        training_args = Seq2SeqTrainingArguments(
            output_dir=self.config.output_dir,
            per_device_train_batch_size=self.config.batch_size,
            per_device_eval_batch_size=self.config.batch_size,
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            learning_rate=self.config.learning_rate,
            warmup_steps=self.config.warmup_steps,
            num_train_epochs=self.config.num_epochs,
            predict_with_generate=True,
            generation_max_length=self.config.generation_max_length,

            # Minimal logging - only training steps
            logging_steps=25,
            logging_strategy="steps",
            logging_first_step=True,

            # Disable evaluation during training for cleaner output
            eval_strategy="no",

            # Save settings
            save_strategy="epoch",
            save_total_limit=2,

            # Performance settings
            fp16=self.config.use_fp16 and torch.cuda.is_available(),
            dataloader_num_workers=self.config.num_workers,

            # Disable reporting
            report_to=[],
            disable_tqdm=False,  # Keep progress bar

            # Suppress warnings
            label_smoothing_factor=0.0,
        )

        trainer = Seq2SeqTrainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            data_collator=self.data_collator,
            processing_class=self.processor.feature_extractor,
            compute_metrics=self.metrics_computer.compute_metrics
        )

        print("\n" + "="*60)
        print("STARTING TRAINING")
        print("="*60)
        print(f"Total steps: {len(train_dataset) // (self.config.batch_size * self.config.gradient_accumulation_steps) * self.config.num_epochs}")
        print(f"Logging every {training_args.logging_steps} steps")


        train_result = trainer.train()

        # Save final model
        trainer.save_model()
        trainer.save_state()

        print("\n" + "="*60)
        print("TRAINING COMPLETE")
        print("="*60)

        print(f"\nFinal Training Metrics:")
        print(f"  Train Loss:           {train_result.metrics['train_loss']:.4f}")
        print(f"  Train Runtime:        {train_result.metrics['train_runtime']:.2f}s")
        print(f"  Samples/Second:       {train_result.metrics['train_samples_per_second']:.2f}")

        # Evaluation on validation set
        print("\n" + "="*60)
        print("EVALUATING ON VALIDATION SET")
        print("="*60)
        eval_results = trainer.evaluate(val_dataset)
        print(f"\nValidation Metrics:")
        print(f"  Validation Loss:      {eval_results['eval_loss']:.4f}")
        print(f"  Validation WER:       {eval_results['eval_wer']:.4f} ({eval_results['eval_wer']*100:.2f}%)")
        print(f"  Validation CER:       {eval_results['eval_cer']:.4f} ({eval_results['eval_cer']*100:.2f}%)")


        return trainer

    def evaluate(self, trainer: Seq2SeqTrainer, test_dataset: Dataset) -> Dict[str, float]:
        """Evaluate on test set"""
        print("\n" + "="*60)
        print("EVALUATING ON TEST SET")
        print("="*60)

        eval_results = trainer.evaluate(test_dataset)

        print(f"\nTest Results:")
        print(f"  WER:  {eval_results['eval_wer']:.4f} ({eval_results['eval_wer']*100:.2f}%)")
        print(f"  CER:  {eval_results['eval_cer']:.4f} ({eval_results['eval_cer']*100:.2f}%)")
        print(f"  Loss: {eval_results['eval_loss']:.4f}")

        return eval_results

In [13]:
# Initialize trainer
trainer_wrapper = WhisperTrainer(config)

Downloading builder script: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

✓ Using device: cuda
✓ Model parameters: 37,760,640


In [14]:
# Preprocess datasets
processed_train, processed_val, processed_test = trainer_wrapper.prepare_datasets(
    train_data, val_data, test_data
)


Preprocessing datasets...


Processing train:   0%|          | 0/1315 [00:00<?, ? examples/s]

Processing validation:   0%|          | 0/186 [00:00<?, ? examples/s]

Processing test:   0%|          | 0/58 [00:00<?, ? examples/s]

✓ Preprocessing complete!


In [15]:
# Train model
trainer = trainer_wrapper.train(processed_train, processed_val)


STARTING TRAINING
Total steps: 492
Logging every 25 steps


You're using a WhisperTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a WhisperTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss
1,5.8791
25,5.7886
50,4.397
75,2.8429
100,2.2794
125,2.0451
150,1.8133
175,1.555
200,1.3761
225,1.4019


You're using a WhisperTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a WhisperTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a WhisperTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a WhisperTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.



TRAINING COMPLETE

Final Training Metrics:
  Train Loss:           1.7062
  Train Runtime:        598.15s
  Samples/Second:       6.59

EVALUATING ON VALIDATION SET


You're using a WhisperTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a WhisperTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Using custom `forced_decoder_ids` from the (generation) config. This is deprecated in favor of the `task` and `language` flags/config options.
Transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English. This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`. See https://github.com/huggingface/transformers/pull/28687 for more details.



Validation Metrics:
  Validation Loss:      1.2588
  Validation WER:       0.4849 (48.49%)
  Validation CER:       0.2409 (24.09%)


In [16]:
# Evaluate on test set
test_results = trainer_wrapper.evaluate(trainer, processed_test)


EVALUATING ON TEST SET


You're using a WhisperTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a WhisperTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.



Test Results:
  WER:  0.3883 (38.83%)
  CER:  0.1765 (17.65%)
  Loss: 0.9018


In [17]:
def show_sample_predictions(
    trainer_wrapper: WhisperTrainer,
    test_dataset: Dataset,
    num_samples: int = 5
):
    """
    Display sample predictions from test set

    Args:
        trainer_wrapper: Trained WhisperTrainer instance
        test_dataset: Test dataset
        num_samples: Number of samples to display
    """
    print("\n" + "="*80)
    print(f"SAMPLE PREDICTIONS (First {num_samples} test samples)")
    print("="*80)

    trainer_wrapper.model.eval()

    for idx in range(min(num_samples, len(test_dataset))):
        sample = test_dataset[idx]

        # Get input features
        input_features = torch.tensor(sample["input_features"]).unsqueeze(0).to(trainer_wrapper.device)

        # Generate prediction
        with torch.no_grad():
            generated_tokens = trainer_wrapper.model.generate(input_features)

        prediction = trainer_wrapper.processor.batch_decode(generated_tokens, skip_special_tokens=True)[0]
        reference = sample["transcript_clean"]

        # Calculate WER for this sample
        sample_wer = wer([reference], [prediction])

        print(f"\n{'─'*80}")
        print(f"Sample {idx + 1}:")
        print(f"{'─'*80}")
        print(f"Reference:  {reference[:100]}{'...' if len(reference) > 100 else ''}")
        print(f"Prediction: {prediction[:100]}{'...' if len(prediction) > 100 else ''}")
        print(f"WER:        {sample_wer:.4f} ({sample_wer*100:.2f}%)")

In [18]:
# View sample predictions
show_sample_predictions(trainer_wrapper, processed_test, num_samples=5)

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.



SAMPLE PREDICTIONS (First 5 test samples)

────────────────────────────────────────────────────────────────────────────────
Sample 1:
────────────────────────────────────────────────────────────────────────────────
Reference:  proteins break down to release amino acids which are used as fuel for hepatic gluconeogenesis to mai...
Prediction: protein is breakdown to release amino acids, which are used as full for hepatic gluconeogenesis, so ...
WER:        0.4348 (43.48%)

────────────────────────────────────────────────────────────────────────────────
Sample 2:
────────────────────────────────────────────────────────────────────────────────
Reference:  aspiration is a potential risk in a patient who subsequently loses consciousness or fits and vomits....
Prediction:  aspiration is a potential risk in a patient to also upset country loses consciousness or feeds and ...
WER:        0.3750 (37.50%)

────────────────────────────────────────────────────────────────────────────────
Sample 3: