# Fine-Tuning Whisper for Speech Recognition with AfriSpeech-200
**Course:** ICS553 Deep Learning - Prosit 3

## 1. Introduction
This notebook fine-tunes the **OpenAI Whisper** model (small) for Automatic Speech Recognition (ASR) on the **AfriSpeech-200** dataset. The goal is to adapt the pretrained model to better understand Pan-African accented English.

We will:
1.  Load the AfriSpeech-200 dataset (streaming mode).
2.  Preprocess the audio (resampling, feature extraction).
3.  Fine-tune the Whisper Large V2 model.
4.  Evaluate the performance using Word Error Rate (WER).
5.  Build a demo using Gradio.

## 2. Setup Environment



In [1]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Sun Nov 30 22:20:36 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   38C    P8              9W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
!pip install --upgrade accelerate
!pip install transformers==4.52.0 datasets==2.19.0
!pip install librosa evaluate jiwer gradio

Collecting transformers==4.52.0
  Downloading transformers-4.52.0-py3-none-any.whl.metadata (38 kB)
Collecting datasets==2.19.0
  Downloading datasets-2.19.0-py3-none-any.whl.metadata (19 kB)
Collecting tokenizers<0.22,>=0.21 (from transformers==4.52.0)
  Downloading tokenizers-0.21.4-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Collecting pyarrow-hotfix (from datasets==2.19.0)
  Downloading pyarrow_hotfix-0.7-py3-none-any.whl.metadata (3.6 kB)
Collecting fsspec<=2024.3.1,>=2023.1.0 (from fsspec[http]<=2024.3.1,>=2023.1.0->datasets==2.19.0)
  Downloading fsspec-2024.3.1-py3-none-any.whl.metadata (6.8 kB)
Reason for being yanked: <none given>[0m[33m
[0mDownloading transformers-4.52.0-py3-none-any.whl (10.5 MB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m10.5/10.5 MB[0m [31m45.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading datasets-2.19.0-py3-none-any.wh

In [15]:
#Set to False if not pushing to hub or just checking locally
PUSH_TO_HUB = False

from huggingface_hub import notebook_login
if PUSH_TO_HUB:
  notebook_login()

## 3. Imports and Configuration


In [3]:
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from datasets import load_dataset, Audio
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor
from transformers import WhisperForConditionalGeneration
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import pipeline
from tqdm import tqdm

import gradio as gr
import evaluate
import os

# Configuration
MODEL_NAME = "openai/whisper-small"
LANGUAGE = "english"
LANGUAGE_ABR = "en"
TASK = "transcribe"
DATASET_NAME = "tobiolatunji/afrispeech-200"
OUTPUT_DIR = "./afrispeech_ayarma_small"
ACCENT = 'all'


## 4. Load and Prepare Dataset
The dataset is loaded in streaming mode to efficiently handle large audio files. Unused columns are removed, and the transcript column is renamed for consistency.

Audio must be downsampled to 16 kHz before being provided to the Whisper feature extractor. The dataset‚Äôs built-in method is used to assign the correct sampling rate to the audio field. This operation does not modify the audio in place; instead, it instructs the datasets library to resample each audio sample lazily‚Äîthat is, the resampling occurs automatically the first time a given sample is accessed.



In [4]:
dataset = load_dataset(DATASET_NAME, ACCENT, streaming=True, trust_remote_code=True)

dataset = dataset.remove_columns(["speaker_id","path","audio_id","age_group","gender","accent","domain","country","duration"])
dataset = dataset.rename_column("transcript","sentence")

# Cast audio to 16kHz
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))


Downloading builder script: 0.00B [00:00, ?B/s]

Downloading readme: 0.00B [00:00, ?B/s]

Downloading extra modules: 0.00B [00:00, ?B/s]

## 5. Feature Engineering

Speech is represented by a 1-dimensional array that varies with time. The value of the array at any given time step is the signal's amplitude at that point. From the amplitude information alone, we can reconstruct the frequency spectrum of the audio and recover all acoustic features.

#### WhisperFeatureExtractor
The `WhisperFeatureExtractor` prepares the raw audio signal for the Transformer Encoder by performing two critical preprocessing operations to ensure input consistency:
*  Temporal Normalization (Padding/Truncation):Since Transformers generally require fixed-size inputs or standard batch shapes, the extractor standardizes the length of all audio samples to a fixed window of 30 seconds. Given Whisper's sampling rate of 16 kHz, this results in a fixed vector length of $N = 480,000$ samples.$$x_{norm} = \begin{cases}
\text{pad}(x, N) & \text{if } \text{length}(x) < N \\
\text{crop}(x, N) & \text{if } \text{length}(x) > N
\end{cases}$$
* Spectral Feature Extraction:The normalized audio vector is then transformed into the frequency domain. The extractor computes the log-Mel spectrogram, mapping the 1D audio waveform into a 2D spectro-temporal representation.$$\mathcal{F}(x_{norm}) \rightarrow \mathbf{X} \in \mathbb{R}^{n_{mels} \times T}$$For Whisper, this typically results in a feature map with $n_{mels} = 80$ frequency bins.

In [5]:
feature_extractor = WhisperFeatureExtractor.from_pretrained(MODEL_NAME)

preprocessor_config.json: 0.00B [00:00, ?B/s]

#### WhisperTokenizer
The output of the Whisper model is not raw text, but rather a sequence of token IDs. Each ID corresponds to the index of a specific sub-word unit in the model's vocabulary dictionary (which contains approximately 50k - 52k items for Whisper).
The WhisperTokenizer bridges the gap between these machine-interpretable integers and human-readable strings. It handles:
* Index Lookup: Identifying the sub-word unit associated with each predicted index.String
* Reconstruction: Concatenating these units to form the final transcript.E

Example Mapping:$$\text{Model Output: } [1169, 3797, 3332] \implies \text{Vocabulary Mapping: } [\text{"The"}, \text{" cat"}, \text{" sat"}] \implies \text{Result: "The cat sat"}$$

In [6]:
tokenizer = WhisperTokenizer.from_pretrained(MODEL_NAME, language=LANGUAGE, task=TASK)

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

normalizer.json: 0.00B [00:00, ?B/s]

added_tokens.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

#### WhisperProcessor
Initialization process is simplified by wrapping the `WhisperFeatureExtractor` and `WhisperTokenizer` into a single `WhisperProcessor` object. This creates a cohesive interface for transforming both the input signals and the target labels.Mathematically, the processor $P$ serves as a composite function handling the multimodal inputs:$$P(x_{audio}, x_{text}) \rightarrow \left( \mathbf{S}_{log-mel}, \mathbf{t}_{tokens} \right)$$Where:$\mathbf{S}_{log-mel}$ is the spectro-temporal representation fed into the Encoder.$\mathbf{t}_{tokens}$ is the sequence of token IDs used by the Decoder for autoregressive modeling.

In [7]:
processor = WhisperProcessor.from_pretrained(MODEL_NAME, language=LANGUAGE, task=TASK)

A function is defined to preprocess the data before it is passed to the model.

In [8]:
def prepare_dataset(batch):
    # load audio data
    audio = batch["audio"]

    # compute input length
    batch["input_length"] = len(batch["audio"])

    # compute log-Mel input features from input audio array
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # encode target text to label ids
    batch["labels"] = tokenizer(batch["sentence"]).input_ids

    # compute labels length
    batch["labels_length"] = len(batch["labels"])
    return batch

# pre-process
dataset = dataset.map(prepare_dataset)


## 6. Model Configuration
The pretrained `Whisper small` model is loaded and configured for fine-tuning.

Generation parameters are overridden so that no tokens are forced or suppressed during decoding. The `use_cache` option is disabled because it is incompatible with gradient checkpointing.


In [9]:
model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME)

# Override generation arguments
model.config.forced_decoder_ids = None
model.generation_config.forced_decoder_ids = None
model.generation_config.suppress_tokens = []
model.config.suppress_tokens = []
model.config.use_cache = False # Disable cache for gradient checkpointing
model.generation_config.language = LANGUAGE
model.generation_config.task = TASK


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/967M [00:00<?, ?B/s]

generation_config.json: 0.00B [00:00, ?B/s]

## 7. Data Filtering
Audio files longer than 30 seconds or containing invalid entries are excluded. Label sequences that surpass the maximum permitted length are also filtered out.



In [10]:
MAX_DURATION_IN_SECONDS = 30.0
max_input_length = MAX_DURATION_IN_SECONDS * 16000

def filter_inputs(input_length):
    """Filter inputs with zero input length or longer than 30s"""
    return 0 < input_length < max_input_length

max_label_length = model.config.max_length

def filter_labels(labels_length):
    """Filter label sequences longer than max length (448)"""
    return labels_length < max_label_length

# filter by audio length
dataset = dataset.filter(filter_inputs, input_columns=["input_length"])
# filter by label length
dataset = dataset.filter(filter_labels, input_columns=["labels_length"])


## 8. Data Collator & Metrics
The Data Collator handles two streams: fixed-size audio features for the Encoder and variable-length text for the Decoder.

**Encoder (`input_features`)**: The feature extractor outputs 30s log-Mel spectrograms. The collator stacks them into:

$$\mathbf{X}_{\text{batch}} \in \mathbb{R}^{B \times n_{\text{mels}} \times T}$$

**Decoder (`labels`)**: Text labels vary in length, so we pad dynamically and replace padding with `-100` to ignore it during loss:

$$
y_i =
\begin{cases}
\text{token_id}, & \text{valid text} \\
-100, & \text{padding}
\end{cases}
$$




In [11]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

#### Evaluation metrics

We'll use the word error rate (WER) metric, the 'de-facto' metric for assessing
ASR systems.

In [12]:
metric = evaluate.load("wer")

Downloading builder script: 0.00B [00:00, ?B/s]

We define a function that takes our model
predictions and returns the WER metric. T

In [13]:
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    label_ids[label_ids == -100] = tokenizer.pad_token_id

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}

## 9. Training
The dataset is split as follows: the first 500 samples are allocated to the validation set, the next 500 to the test set, and the remaining samples constitute the training set.



In [16]:
MAX_STEPS = 3000
WARM_UP_STEPS = 100
SAVE_STEP = 500

training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=2,
    learning_rate=1e-5,
    warmup_steps=WARM_UP_STEPS,
    max_steps=MAX_STEPS,
    gradient_checkpointing=True,
    fp16=True,
    eval_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=SAVE_STEP,
    eval_steps=SAVE_STEP,
    logging_steps=20,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=PUSH_TO_HUB,
)

In [17]:
# Split dataset
test_dataset = dataset['train'].take(500)              # First 500
val_dataset = dataset['train'].skip(500).take(500)     # Next 500
train_dataset = dataset['train'].skip(1000)            # The rest


In [None]:
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

processor.save_pretrained(training_args.output_dir)
print("Starting training...")
trainer.train()

  trainer = Seq2SeqTrainer(


Starting training...


Reading metadata...: 57819it [00:01, 42283.37it/s]


Step,Training Loss,Validation Loss,Wer
500,0.5725,0.455687,16.711926
1000,0.5284,0.44267,17.221091
1500,0.5196,0.406835,15.840688
2000,0.3895,0.389782,15.20706
2500,0.3543,0.384068,15.139172
3000,0.4652,0.385472,14.765784


Reading metadata...: 57819it [00:01, 49644.62it/s]
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Reading metadata...: 57819it [00:01, 51343.22it/s]
Reading metadata...: 57819it [00:01, 49761.49it/s]
Reading metadata...: 57819it [00:01, 46658.83it/s]
Reading metadata...: 57819it [00:01, 43397.50it/s]
Reading metadata...: 57819it [00:01, 48587.15it/s]
Reading metadata...: 57819it [00:01, 47379.47it/s]
There were missing keys in the checkpoint model loaded: ['proj_out.weight'].


TrainOutput(global_step=3000, training_loss=0.5413057045936585, metrics={'train_runtime': 7725.1612, 'train_samples_per_second': 12.427, 'train_steps_per_second': 0.388, 'total_flos': 2.77056413577216e+19, 'train_loss': 0.5413057045936585, 'epoch': 1.4083333333333332})

## 10. Inference & Demo



In [None]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
MAX_DURATION = 30

pipe = pipeline(
    "automatic-speech-recognition",
    model=OUTPUT_DIR,
    device=device,
)

# Disable forced decoder ids if you did that for training
pipe.model.config.forced_decoder_ids = None

preds, refs = [], []

for i, sample in enumerate(tqdm(test_dataset, total=500)):
    audio = sample["audio"]
    duration = len(audio["array"]) / audio["sampling_rate"]

    # Skip very long samples to avoid long-form mode
    if duration > MAX_DURATION:
        continue

    out = pipe(audio)
    preds.append(out["text"])
    refs.append(sample["sentence"])

wer_ft = 100 * metric.compute(predictions=preds, references=refs)

print(f"\n{'='*60}")
print(f"üìä EVALUATION RESULTS")
print(f"{'='*60}")
print(f"Fine-tuned model Test WER: {wer_ft:.2f}")


Device set to use cuda:0
  0%|          | 0/500 [00:00<?, ?it/s]
Reading metadata...: 0it [00:00, ?it/s][A
Reading metadata...: 1it [00:00,  2.69it/s][A
Reading metadata...: 12305it [00:00, 34135.45it/s][A
Reading metadata...: 19543it [00:00, 36522.00it/s][A
Reading metadata...: 30559it [00:00, 41459.76it/s][A
Reading metadata...: 42791it [00:00, 58542.28it/s][A
Reading metadata...: 57819it [00:01, 45709.92it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 500/500 [03:57<00:00,  2.10it/s]



üìä EVALUATION RESULTS
Fine-tuned model Test WER: 16.14


In [None]:
# Load a sample from the test set
sample = next(iter(test_dataset))

# Reset forced_decoder_ids
pipe.model.config.forced_decoder_ids = None
pipe.model.generation_config.forced_decoder_ids = None

# Perform inference
result = pipe(sample["audio"])
print(f"Prediction: {result['text']}")
print(f"Reference: {sample['sentence']}")

Reading metadata...: 57819it [00:01, 55755.35it/s]


Prediction: Similarly, an ATEI or a session as she said over are the anti-attentive drugs in a diabetic patient where they slow their progression of necrosis.
Reference: Similarly, an ACEI or a sartan is preferred over other antihypertensive drugs in diabetic patients where they slow the progression of nephropathy.


### Gradio Demo


In [None]:
# 1. Setup

#Use local trained model
# model_id = OUTPUT_DIR

#Deployed model with weights
model_id = 'Ajegetina/afrispeech_ayarma_small'

print(f"Loading {model_id} on {device}...")

# 2. Load the pipeline locally
pipe = pipeline(
    "automatic-speech-recognition",
    model=model_id,
    device=device,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)

#load deployed pipeline
pipe = pipeline(
    "automatic-speech-recognition",
    model=model_id,
    device=device,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)

pipe.model.config.forced_decoder_ids = None
pipe.model.generation_config.forced_decoder_ids = None

# 3. Define the function
def transcribe(audio):
    if audio is None:
        return "Please record audio first."

    result = pipe(audio)
    return result["text"]

# 4. Launch Gradio
iface = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(sources=["microphone"], type="filepath"),
    outputs="text",
    title="Whisper Large Amaritech",
    description="Realtime demo for speech recognition.",
)

iface.launch(debug=True)


In [None]:
kwargs = {
    "dataset_tags": "tobiolatunji/afrispeech-200",
    "dataset": "Afrispeech 200",
    "dataset_args": "accent: 'all'",
    "language": "en",
    "model_name": "Whisper small - Ayarma",
    "finetuned_from": "openai/whisper-small",
    "tasks": "automatic-speech-recognition",
}

trainer.push_to_hub(**kwargs)

Processing Files (0 / 0)      : |          |  0.00B /  0.00B            

New Data Upload               : |          |  0.00B /  0.00B            

  ...a_small/training_args.bin: 100%|##########| 5.84kB / 5.84kB            

  ...08841.1a5f9795ba4e.2568.0: 100%|##########| 39.6kB / 39.6kB            

  ...a_small/model.safetensors:   4%|4         | 41.9MB /  967MB            

CommitInfo(commit_url='https://huggingface.co/Ajegetina/afrispeech_ayarma_small/commit/39774c99cf6c67be4eb479d35fd953ddad4742fd', commit_message='End of training', commit_description='', oid='39774c99cf6c67be4eb479d35fd953ddad4742fd', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Ajegetina/afrispeech_ayarma_small', endpoint='https://huggingface.co', repo_type='model', repo_id='Ajegetina/afrispeech_ayarma_small'), pr_revision=None, pr_num=None)