<a href="https://colab.research.google.com/github/Michael-Sylvester/Ashesi-Deep-Learning/blob/main/Transformers_DL_Prosit_3_(whisper_base).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ASR with pretrained Transformers

This notebook fine-tunes a pretrained Whisper Transformer to transcribe doctors' spoken notes into text using the AfriSpeech-200 dataset.

Author: [Michael Kwabena Sylvester] Purpose: Medical dictation system for African hospitals Dataset: AfriSpeech-200 (https://huggingface.co/datasets/intronhealth/afrispeech-200)

In [None]:
# Check if packages are still installed in case of session restarts
import importlib
def is_package_installed(package_name):
  """Return True when `package_name` can be located by the import system.

  Args:
    package_name: Importable module or package name, e.g. 'datasets'.

  Returns:
    bool: True if a module spec is found, False otherwise.
  """
  # `import importlib` alone does not guarantee that the `importlib.util`
  # submodule is loaded, so import it explicitly before using it.
  import importlib.util

  try:
    return importlib.util.find_spec(package_name) is not None
  except (ImportError, ValueError):
    # find_spec raises instead of returning None when the parent of a dotted
    # name is missing (ModuleNotFoundError) or a module has no usable spec
    # (ValueError); both mean "not installed" for our purposes.
    return False

!pip install jiwer>=3.0.1
!pip install --upgrade transformers

if not is_package_installed('huggingface_hub'):
  !pip install huggingface_hub

# Downgrade datasets to library 2.18 or 2.19 so use remote script execution
if not is_package_installed('datasets') or importlib.metadata.version('datasets') not in ['2.18.0', '2.19.1']:
  !pip install datasets==2.19.1

# These imports are done separately because the packages have to be installed first and the dataset cache must be changed from the default

from datasets import load_dataset, Dataset, Audio, load_from_disk



In [None]:
import os
import re
import json
import math
from dataclasses import dataclass
from typing import List, Dict, Any, Union

import numpy as np
import torch
from datasets import load_dataset, load_from_disk, concatenate_datasets
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
import soundfile as sf
import librosa
from jiwer import wer, cer
from tqdm.auto import tqdm

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Report the accelerator model and its total memory when running on CUDA.
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB\n")


Using device: cuda
GPU: Tesla T4
Memory: 15.83 GB



In [None]:

# Locations of the Hugging Face download cache and of the saved dataset copy.
PERMANENT_CACHE = "/content/afrispeech_cache"
SAVE_PATH = "/content/Hugging_Face/afrispeech_saved"

# Point the `datasets` cache at the permanent location.
# NOTE(review): `datasets` is already imported in the install cell, so this
# variable is probably read too late to affect it — confirm. The loader passes
# cache_dir=PERMANENT_CACHE explicitly, which does not depend on it.
os.environ["HF_DATASETS_CACHE"] = PERMANENT_CACHE

# Create both directories up front; exist_ok keeps the cell safe to re-run.
for _required_dir in (PERMANENT_CACHE, SAVE_PATH):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir

In [None]:
# Seed every random number generator used in the session so that weight
# initialisation and any sampling are reproducible across runs.
import random

SEED = 42  # single source of truth for the seed value
random.seed(SEED)        # Python's built-in RNG
np.random.seed(SEED)     # NumPy's legacy global RNG
torch.manual_seed(SEED)  # PyTorch RNG
print("✓ Random seeds set for reproducibility")

✓ Random seeds set for reproducibility


In [None]:
# Run configuration: change MODEL_BASE to try another Whisper checkpoint.
# Available sizes: tiny, base, small, medium, large.
# (Original author's note: small is a good balance for medical transcription.)
MODEL_BASE: str = "openai/whisper-base"
TARGET_SR: int = 16_000  # sampling rate in Hz; Whisper uses 16 kHz audio
# NOTE(review): the loading cell visible in this notebook hard-codes its own
# split string and does not read SAMPLE_SPLIT — confirm which one is intended.
SAMPLE_SPLIT: str = "train[:20%]"
LANG_CONFIG: str = "en"  # Whisper language code

Load and explore dataset

In [None]:
# The "twi" configuration of AfriSpeech-200 is loaded below. The loader requests
# the first 30% of the train split, so the message reports that figure
# (it previously claimed 20%).
print("\nLoading 30% of Twi dataset...")

def Check_and_load_dataset(load_disk=True):
  """Return the dataset, preferring the copy saved at SAVE_PATH.

  Args:
    load_disk: When truthy, try `load_from_disk(SAVE_PATH)` first and fall
      back to `download_dataset()` if that raises FileNotFoundError. When
      falsy, always call `download_dataset()`.

  Returns:
    Whatever `load_from_disk` or `download_dataset` returns.
  """
  # Caller asked to bypass the saved copy entirely.
  if not load_disk:
    return download_dataset()

  try:
    print(f"✓ Attempting to load dataset from disk at {SAVE_PATH}...")
    saved_copy = load_from_disk(SAVE_PATH)
    print("✓ Dataset loaded successfully from disk.")
    return saved_copy
  except FileNotFoundError:
    # Nothing usable at SAVE_PATH: fetch a fresh copy instead.
    print("Dataset not found on disk or incomplete. Downloading from Hugging Face...")
    return download_dataset()

def download_dataset(config_name="twi", split="train[:30%]"):
    """Download an AfriSpeech-200 subset and save a copy at SAVE_PATH.

    Args:
        config_name: Dataset configuration to load. Defaults to "twi".
        split: Split expression passed to `load_dataset`. Defaults to the
            first 30% of the train split, the value this notebook has always
            used.

    Returns:
        The dataset object returned by `load_dataset`.
    """
    dataset = load_dataset(
        "intronhealth/afrispeech-200",
        config_name,
        streaming=False,
        cache_dir=PERMANENT_CACHE,
        split=split,
        verification_mode='no_checks',  # Skip verification that can hang
        # This dataset is built by a loading script hosted on the Hub; opt in
        # explicitly, as the library's warning says this will become mandatory.
        trust_remote_code=True,
    )

    print("✓ Dataset downloaded successfully")
    dataset.save_to_disk(SAVE_PATH)
    print(f"✓ Successfully saved dataset to {SAVE_PATH}")

    return dataset

# load_disk=False bypasses the copy saved on disk and always goes through download_dataset().
dataset = Check_and_load_dataset(load_disk=False)
print(f"Dataset loaded with {len(dataset)} samples.")


Loading 20% of Twi dataset...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Downloading builder script: 0.00B [00:00, ?B/s]

Downloading readme: 0.00B [00:00, ?B/s]

Downloading extra modules: 0.00B [00:00, ?B/s]

Downloading data:   0%|          | 0.00/1.10G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/152M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/45.9M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/427k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/61.0k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/19.2k [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]


Reading metadata...: 1315it [00:00, 50854.81it/s]


Generating validation split: 0 examples [00:00, ? examples/s]


Reading metadata...: 186it [00:00, 7526.03it/s]


Generating test split: 0 examples [00:00, ? examples/s]


Reading metadata...: 58it [00:00, 38898.25it/s]


✓ Dataset downloaded successfully


Saving the dataset (0/2 shards):   0%|          | 0/394 [00:00<?, ? examples/s]

✓ Successfully saved dataset to /content/Hugging_Face/afrispeech_saved
Dataset loaded with 394 samples.


In [None]:
# Quick look at the schema (column names) and the row count.
print(f"Example keys: {list(dataset.features)}")
print(f"Number of samples: {len(dataset)}")

Example keys: ['speaker_id', 'path', 'audio_id', 'audio', 'transcript', 'age_group', 'gender', 'accent', 'domain', 'country', 'duration']
Number of samples: 394


In [None]:
# Carve the loaded dataset into train / validation / test = 80% / 10% / 10%.
# Step 1: hold out 20% of the rows; the remaining 80% is the training set.
train_test_split = dataset.train_test_split(test_size=0.2, seed=42)
train_ds, temp_dataset = train_test_split['train'], train_test_split['test']

# Step 2: halve the held-out 20% into validation and test (10% each overall).
val_test_split = temp_dataset.train_test_split(test_size=0.5, seed=42)
val_ds, test_ds = val_test_split['train'], val_test_split['test']

# The intermediate objects are no longer needed.
del temp_dataset, val_test_split

print("\nDataset split:")
print("\n".join(
    f"  {label} samples: {len(part)}"
    for label, part in (("Training", train_ds), ("Validation", val_ds), ("Test", test_ds))
))


Dataset split:
  Training samples: 315
  Validation samples: 39
  Test samples: 40


In [None]:
# Section banner.
print("\n" + "=" * 70, "Exploring the Twi dataset...", "=" * 70, sep="\n")

# Show the first training example field by field.
print("\nSample entry:")
sample = train_ds[0]
for key, value in sample.items():
    if key != 'audio':
        print(f"  {key}: {value}")
        continue
    # The waveform itself is too large to print; summarise its shape and sampling rate.
    print(f"  {key}: {{array shape: {value['array'].shape}, sampling_rate: {value['sampling_rate']}}}")


Exploring the Twi dataset...

Sample entry:
  speaker_id: 3d82dd8788a22fe050e4d31d3a3f0f01
  path: /content/afrispeech_cache/downloads/extracted/03239d12530f1c0ebd8c4bb43ad20f7a0dab6f4cfa82a040874e11cdc674add2/539b7903-f4b7-402f-b105-f18945627850/834d518d3dfe5afdb2a6a16ff2ef4ead.wav
  audio_id: 539b7903-f4b7-402f-b105-f18945627850/834d518d3dfe5afdb2a6a16ff2ef4ead
  audio: {array shape: (186895,), sampling_rate: 44100}
  transcript: One thing we see on college campuses is that no one builds single-purpose spaces anymore.

  age_group: 19-25
  gender: Male
  accent: twi
  domain: general
  country: GH
  duration: 4.237981796264648


Load and process audio

In [None]:
def resample_to_16k(batch):
    """Add mono audio at TARGET_SR and a unified transcript field to one example.

    Args:
        batch: A single example (mapping) with an "audio" entry. The entry is
            either a dict holding 'array' and 'sampling_rate', or something
            `soundfile.read` can open (e.g. a file path). May also carry a
            "transcript" or "text" field.

    Returns:
        The same mapping with two keys added:
        - "speech": 1-D float32 numpy array sampled at TARGET_SR.
        - "transcript_text": the value of "transcript", else "text", else "".
    """
    audio = batch["audio"]
    if isinstance(audio, dict) and "array" in audio:
        # Already-decoded audio: {'array': ..., 'sampling_rate': ...}
        array = np.asarray(audio["array"], dtype=np.float32)
        sr = int(audio["sampling_rate"])
    else:
        # Otherwise decode it with soundfile.
        array, sr = sf.read(audio)
        array = array.astype(np.float32)

    # Down-mix to mono BEFORE resampling. Resampling operates along the last
    # axis, so a 2-D (frames, channels) array would otherwise be resampled
    # across its channels instead of across time.
    if array.ndim > 1:
        array = np.mean(array, axis=1)

    if sr != TARGET_SR:
        array = librosa.resample(array, orig_sr=sr, target_sr=TARGET_SR)
    batch["speech"] = array

    # Unify the transcript field name ("transcript" is preferred over "text").
    if "transcript" in batch:
        batch["transcript_text"] = batch["transcript"]
    elif "text" in batch:
        batch["transcript_text"] = batch["text"]
    else:
        batch["transcript_text"] = ""
    return batch

# Run the resampling step over every split, in train / validation / test order;
# each split gains the 'speech' and 'transcript_text' columns.
train_ds, val_ds, test_ds = (
    split.map(resample_to_16k) for split in (train_ds, val_ds, test_ds)
)


# Prepare the data to be loaded into the data loader

In [None]:
# Load Whisper processor (handles both feature extraction and tokenization).
# The language comes from LANG_CONFIG ("en") so it matches the code used for
# generation elsewhere in the notebook; the previous value "eng" is not a
# Whisper language identifier.
processor = WhisperProcessor.from_pretrained(
    MODEL_BASE,
    language=LANG_CONFIG,
    task="transcribe"
)

In [None]:
def prepare_dataset(batch):
    """Turn one resampled example into Whisper model inputs.

    Reads "speech" (audio sampled at TARGET_SR) and "transcript_text" from
    `batch`, and adds:
    - "input_features": first item of the processor's `input_features` output.
    - "labels": first item of the tokenizer's `input_ids` for the transcript.

    Returns the same mapping with both keys added.
    """
    # Audio -> feature-extractor output for the single clip.
    extracted = processor(
        batch["speech"],
        sampling_rate=TARGET_SR,
        return_tensors="pt",
    )
    batch["input_features"] = extracted.input_features[0]

    # Transcript -> token ids used as training targets.
    encoded_text = processor.tokenizer(batch["transcript_text"], return_tensors="pt")
    batch["labels"] = encoded_text.input_ids[0]

    return batch

# Train / validation keep only the model inputs produced by prepare_dataset.
train_ds = train_ds.map(function=prepare_dataset, remove_columns=train_ds.column_names)
val_ds = val_ds.map(function=prepare_dataset, remove_columns=val_ds.column_names)
# The test split keeps its original columns alongside the new ones
# (the inference cell reads 'speech' and 'transcript' from it).
test_ds = test_ds.map(function=prepare_dataset)

Map:   0%|          | 0/315 [00:00<?, ? examples/s]

Map:   0%|          | 0/39 [00:00<?, ? examples/s]

Map:   0%|          | 0/40 [00:00<?, ? examples/s]

DataCollator

In [None]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate prepared examples into a padded batch for Whisper training.

    Attributes:
        processor: Object exposing `feature_extractor.pad` and `tokenizer.pad`
            (here a WhisperProcessor).
        decoder_start_token_id: Id of the token the model prepends to the
            decoder input. When None it is looked up on the tokenizer as
            "<|startoftranscript|>".
    """
    processor: Any
    decoder_start_token_id: Union[int, None] = None

    def __call__(self, features):
        """Build a batch with "input_features", "attention_mask" and "labels".

        Args:
            features: List of examples, each with "input_features" and "labels".

        Returns:
            The padded feature batch with an all-ones "attention_mask" of
            shape (batch, frames) and "labels" whose padding positions are -100.
        """
        # Audio and text need different padding, so split them first.
        input_features = [{"input_features": f["input_features"]} for f in features]
        label_features = [{"input_ids": f["labels"]} for f in features]

        # Pad audio features
        batch = self.processor.feature_extractor.pad(
            input_features,
            return_tensors="pt"
        )

        # Encoder attention mask over the time axis: (batch, frames).
        # input_features is (batch, mel_bins, frames); the mask previously used
        # shape[:-1], i.e. (batch, mel_bins), which is the wrong axis.
        # All positions are marked as real frames.
        padded_audio = batch["input_features"]
        batch["attention_mask"] = torch.ones(
            (padded_audio.shape[0], padded_audio.shape[-1]), dtype=torch.long
        )

        # Pad text labels
        labels_batch = self.processor.tokenizer.pad(
            label_features,
            return_tensors="pt"
        )

        # Padding positions become -100 so the loss ignores them.
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )

        # The model shifts the labels right and prepends the decoder start
        # token itself, so a start token already present at position 0 of every
        # label row would be duplicated; drop it.
        start_id = self.decoder_start_token_id
        if start_id is None:
            start_id = self.processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
        if labels.shape[1] > 0 and (labels[:, 0] == start_id).all().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

# Collator instance handed to the trainer; it pads with the shared processor.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor)

Load model for fine-tuning (encoder-decoder / sequence-to-sequence)

In [None]:
# Load Whisper model
model = WhisperForConditionalGeneration.from_pretrained(MODEL_BASE)

# Optionally freeze encoder for faster training
# model.freeze_encoder()

# Important: set the language and task used during generation.
# `language` / `task` are the supported way to do this. `forced_decoder_ids`
# is the legacy mechanism for the same thing and conflicts with them when both
# are set, so it is cleared instead of being populated.
model.generation_config.language = LANG_CONFIG
model.generation_config.task = "transcribe"
model.generation_config.forced_decoder_ids = None


model.to(device)
print(f"Model loaded: {MODEL_BASE}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/290M [00:00<?, ?B/s]

generation_config.json: 0.00B [00:00, ?B/s]

Model loaded: openai/whisper-base
Model parameters: 72.59M


TrainingArguments and Trainer

In [None]:
# Training configuration.
# With this notebook's data the run is 200 optimizer steps in total
# (see global_step in the training output), so step-based values are chosen
# relative to that.
training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-medical-twi",
    per_device_train_batch_size=8,  # Adjust based on your GPU/CPU
    gradient_accumulation_steps=2,
    learning_rate=1e-5,
    # Was 300, which is more than the 200 total steps: the learning rate never
    # reached its target value. 20 steps is 10% of the run.
    warmup_steps=20,
    num_train_epochs=10,
    eval_strategy="steps",
    eval_steps=25,
    save_strategy="steps",
    # Save on the same schedule as evaluation so that load_best_model_at_end
    # has a checkpoint for every evaluated step (the default save interval of
    # 500 steps is longer than the whole run).
    save_steps=25,
    logging_steps=25,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,  # lower WER is better
    fp16=torch.cuda.is_available(),  # Use fp16 only if GPU available
    report_to=["none"],
    predict_with_generate=True,  # Important for Whisper
    generation_max_length=225
)

def compute_metrics(pred):
    """Compute word and character error rates for a batch of predictions.

    Args:
        pred: Object with `predictions` (generated token ids, or a tuple whose
            first element holds them) and `label_ids` (reference token ids in
            which -100 marks ignored positions).

    Returns:
        dict: {"wer": <word error rate>, "cer": <character error rate>}.
    """
    pad_id = processor.tokenizer.pad_token_id

    pred_ids = pred.predictions
    if isinstance(pred_ids, tuple):
        # Some models return extra outputs alongside the token ids.
        pred_ids = pred_ids[0]

    # -100 is a loss-masking sentinel, not a real token id, and cannot be
    # decoded. Swap it for the pad token in both arrays; np.where builds new
    # arrays, so the caller's `pred.label_ids` is left untouched.
    pred_ids = np.asarray(pred_ids)
    label_ids = np.asarray(pred.label_ids)
    pred_ids = np.where(pred_ids == -100, pad_id, pred_ids)
    label_ids = np.where(label_ids == -100, pad_id, label_ids)

    # Decode predictions and labels
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    # Reference first, hypothesis second.
    wer_score = wer(label_str, pred_str)
    cer_score = cer(label_str, pred_str)

    return {"wer": wer_score, "cer": cer_score}

# Assemble the trainer from the model, arguments, splits, collator and metric function.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    # `tokenizer=` is deprecated on the trainer (it emits a warning and is
    # scheduled for removal); `processing_class=` is its replacement.
    processing_class=processor.tokenizer,
)

  trainer = Seq2SeqTrainer(


In [None]:
# Run fine-tuning with the trainer configured above. The returned TrainOutput
# (global step, training loss, runtime metrics) is displayed as the cell output.
trainer.train()

You're using a WhisperTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss,Validation Loss,Wer,Cer
25,6.4531,5.743569,0.600349,0.329598
50,4.9355,4.139003,0.586387,0.324138
75,3.293,2.654808,0.558464,0.293391
100,2.047,1.958582,0.500873,0.260632
125,1.4627,1.720352,0.448517,0.234195
150,1.125,1.596068,0.455497,0.229598
175,0.8417,1.52501,0.450262,0.224713
200,0.6329,1.488349,0.43281,0.218678


There were missing keys in the checkpoint model loaded: ['proj_out.weight'].


TrainOutput(global_step=200, training_loss=2.598861131668091, metrics={'train_runtime': 521.8033, 'train_samples_per_second': 6.037, 'train_steps_per_second': 0.383, 'total_flos': 2.04308987904e+17, 'train_loss': 2.598861131668091, 'epoch': 10.0})

In [None]:
# Display the test split's schema. Unlike train/val it was mapped without
# remove_columns, so it keeps the raw columns next to the model inputs.
test_ds

Dataset({
    features: ['speaker_id', 'path', 'audio_id', 'audio', 'transcript', 'age_group', 'gender', 'accent', 'domain', 'country', 'duration', 'speech', 'transcript_text', 'input_features', 'labels'],
    num_rows: 40
})

In [None]:
def transcribe_audio(
    audio_array: np.ndarray,
    language: str = "en",
    task: str = "transcribe",
    max_new_tokens: int = 80,
    num_beams: int = 5,
):
    """Transcribe one audio clip with the fine-tuned Whisper model.

    Args:
        audio_array: Mono waveform sampled at TARGET_SR.
        language: Whisper language code passed to `generate`.
        task: Whisper task passed to `generate`.
        max_new_tokens: Upper bound on the number of generated tokens.
        num_beams: Beam width used for beam search.

    Returns:
        str: Decoded transcription of the first (only) sequence, with special
        tokens removed.
    """
    # Convert raw audio to input features for Whisper and move them to the
    # model's device. Only `input_features` is used below.
    inputs = processor(
        audio_array,
        sampling_rate=TARGET_SR,
        return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        predicted_ids = model.generate(
            input_features=inputs.input_features,
            # Language and task are passed directly; this replaces the legacy
            # `forced_decoder_ids` prompt, which conflicts with the
            # language/task set on the generation config.
            language=language,
            task=task,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            repetition_penalty=1.2,
            length_penalty=1.0,
            no_repeat_ngram_size=3
        )

    # Decode to text
    transcription = processor.batch_decode(
        predicted_ids,
        skip_special_tokens=True
    )[0]

    return transcription

# Test on a sample
idx = 0

# Fetch the row once instead of indexing the dataset for every field.
test_sample = test_ds[idx]

# Use the 'speech' field which contains the preprocessed (resampled to 16kHz) audio
sample_audio = test_sample["speech"]
transcription = transcribe_audio(sample_audio)
print(f"Transcription: {transcription}")
# Single quotes inside the f-string: reusing the outer double quotes is a
# SyntaxError on Python versions before 3.12.
print(f"Original: {test_sample['transcript']}")

Transcription:  As of Thursday evening, there were 258 confirmed cases of COVID-19 in Ontario.
Original: As of Thursday evening, there were 258 confirmed cases of COVID-19 in Ontario.



Inference example