## Fine-tune the OpenAI Whisper model

This notebook demonstrates how to fine-tune OpenAI's Whisper model for Swahili speech recognition. We'll take a pre-trained Whisper model and adapt it to better understand Swahili audio, improving its accuracy for this specific language.

This notebook follows these steps:
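1. **Environment Setup** - Install required packages. Uploading in step 8 also requires a Hugging Face Hub login. This notebook has no login cell, so authenticate beforehand (e.g. with `huggingface-cli login`).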
2. **Data Loading** - Load and inspect the Swahili audio dataset
3. **Data Preprocessing** - Convert audio and text to model-compatible format
4. **Model Configuration** - Set up the pre-trained Whisper model
5. **Training Setup** - Configure training parameters and metrics
6. **Fine-tuning** - Train the model on Swahili data
7. **Evaluation** - Assess model performance using Word Error Rate (WER)
8. **Model Deployment** - Save and upload the trained model
9. **Inference** - Transcribe a held-out set with the fine-tuned model and write a submission CSV


## Environment Setup

In [None]:
# Disable experiment tracking for cleaner output
import os
os.environ["WANDB_DISABLED"] = "true"

# Install required packages for speech recognition training
!pip install datasets transformers torch evaluate jiwer accelerate

In [None]:

# Import necessary libraries
import unicodedata
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import evaluate
import numpy as np
import torch
from datasets import Audio, load_dataset
from jiwer import wer
from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)



## Data Loading

In [None]:

# Load the Swahili speech dataset from the Hugging Face Hub.
# Each example pairs an audio clip with its transcription.
dataset = load_dataset("Denhotech/data_hf1")

# Later cells index dataset["train"] and dataset["test"]; fail here with a clear
# message rather than with a bare KeyError further down.
missing_splits = {"train", "test"} - set(dataset.keys())
if missing_splits:
    raise ValueError(f"Dataset is missing required split(s): {sorted(missing_splits)}")

print("Dataset loaded successfully!")
print(f"Training samples: {len(dataset['train'])}")
print(f"Test samples: {len(dataset['test'])}")


## Data Preprocessing

In [None]:

# Whisper's feature extractor expects 16 kHz audio. Cast the audio column
# unconditionally: inspecting only the first training clip would miss clips at
# other sampling rates elsewhere in the train split or anywhere in the test split.
print("Ensuring all audio is decoded at 16kHz...")
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

# Initialize the Whisper processor for Swahili.
# It bundles the audio feature extractor and the text tokenizer; language/task
# select the tokenizer's prefix tokens that are prepended to every label sequence.
processor = WhisperProcessor.from_pretrained(
    "openai/whisper-small",
    language="swahili",
    task="transcribe",
)

def prepare_dataset(batch):
    """
    Turn one raw example into Whisper training inputs.

    - ``input_features``: spectrogram features computed by ``processor`` from
      ``batch["audio"]["array"]`` at the clip's own ``sampling_rate``.
    - ``labels``: token ids of the transcript, taken from the ``"text"`` column
      when the example has one and from ``"sentence"`` otherwise.

    The example dict is updated in place and returned.
    """
    clip = batch["audio"]
    mel = processor(clip["array"], sampling_rate=clip["sampling_rate"]).input_features
    batch["input_features"] = mel[0]

    # Datasets name the transcript column differently; prefer "text".
    transcript_key = "sentence"
    if "text" in batch:
        transcript_key = "text"
    batch["labels"] = processor.tokenizer(batch[transcript_key]).input_ids

    return batch

# Run `prepare_dataset` over every split and drop the train split's original
# columns, so that only the model inputs produced above remain.
raw_column_names = dataset["train"].column_names
print("Processing dataset...")
dataset = dataset.map(prepare_dataset, remove_columns=raw_column_names)
print("Dataset preprocessing complete!")


## Data Collation Setup

In [None]:

@dataclass
class DataCollator:
    """
    Batch collator for Whisper sequence-to-sequence fine-tuning.

    Pads ``input_features`` with the processor's feature extractor and the
    ``labels`` token ids with its tokenizer. Label padding is replaced by -100
    so the loss ignores it.

    Attributes:
        processor: Whisper-style processor exposing ``feature_extractor.pad``,
            ``tokenizer.pad`` and ``tokenizer.convert_tokens_to_ids``.
        decoder_start_token_id: Token that the model prepends to the decoder
            inputs when it builds them from ``labels``. If every label sequence
            in a batch starts with it, that leading copy is dropped. Defaults to
            the tokenizer id of ``<|startoftranscript|>``. Pass
            ``model.config.decoder_start_token_id`` to pin it explicitly.
    """

    processor: Any
    decoder_start_token_id: Optional[int] = None

    def __post_init__(self) -> None:
        # Resolve the default here so `DataCollator(processor=...)` keeps working
        # without access to the model config.
        if self.decoder_start_token_id is None:
            self.decoder_start_token_id = self.processor.tokenizer.convert_tokens_to_ids(
                "<|startoftranscript|>"
            )

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Pad audio features to a common length within the batch.
        input_features = [{"input_features": f["input_features"]} for f in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Pad label token ids to the longest sequence in the batch.
        label_features = [{"input_ids": f["labels"]} for f in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Replace padding positions with -100 (ignored during loss calculation).
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # The Whisper tokenizer's prefix tokens already start each transcript with
        # <|startoftranscript|>, and the model prepends decoder_start_token_id
        # again when shifting labels right. Drop the leading copy so the model is
        # not trained to predict the start token after itself. Note that the
        # Whisper tokenizer's bos_token is <|endoftext|>, so it cannot be used here.
        if labels.shape[1] > 0 and bool((labels[:, 0] == self.decoder_start_token_id).all()):
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

# One collator instance, shared by the trainer for training and evaluation batches.
data_collator = DataCollator(processor)


## Model Setup

In [None]:

# Load the pre-trained Whisper-small checkpoint.
# It can already transcribe speech; fine-tuning adapts it to Swahili.
print("Loading pre-trained Whisper model...")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

# Pin the decoding language and task in the generation config. Evaluation-time
# generation (predict_with_generate) and the saved model then transcribe Swahili
# instead of relying on Whisper's automatic language detection.
model.generation_config.language = "swahili"
model.generation_config.task = "transcribe"
# Clear legacy forced decoder ids so they do not override the settings above.
model.generation_config.forced_decoder_ids = None

# Disable the decoder KV cache in the model config for the training forward passes.
model.config.use_cache = False


## Evaluation Metrics

In [None]:

# Word Error Rate (WER) is the standard speech-recognition metric. Lower is
# better; 0 means a perfect transcription. References and predictions are
# normalized identically before scoring, so casing, punctuation and diacritics
# are not counted as errors.


def normalize_text(text: str) -> str:
    """Normalize a transcript for WER scoring.

    Strips diacritics (e.g. "é" -> "e"), deletes punctuation characters
    (Unicode category P*), lower-cases, and collapses whitespace runs into
    single spaces with no leading or trailing space.
    """
    decomposed = unicodedata.normalize("NFKD", text)
    kept = (
        ch
        for ch in decomposed
        if not unicodedata.combining(ch) and not unicodedata.category(ch).startswith("P")
    )
    return " ".join("".join(kept).lower().split())


def compute_metrics(eval_pred):
    """Compute WER (as a percentage) for a Seq2SeqTrainer evaluation.

    Args:
        eval_pred: ``(predictions, label_ids)`` from the trainer. With
            ``predict_with_generate=True`` the predictions are generated token ids.

    Returns:
        ``{"wer": float}``: 100 x WER over examples whose normalized reference
        is non-empty, or NaN if there are no such examples.
    """
    pred_ids, label_ids = eval_pred
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]

    # -100 marks ignored/padded positions (set by the collator for labels, and
    # possibly used when predictions are padded). The tokenizer cannot decode it.
    pad_id = processor.tokenizer.pad_token_id
    pred_ids = np.where(pred_ids == -100, pad_id, pred_ids)
    label_ids = np.where(label_ids == -100, pad_id, label_ids)

    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    pairs = [
        (normalize_text(ref), normalize_text(hyp))
        for ref, hyp in zip(label_str, pred_str)
    ]
    # jiwer raises on empty reference strings (WER is undefined for them).
    pairs = [(ref, hyp) for ref, hyp in pairs if ref]
    if not pairs:
        return {"wer": float("nan")}

    references, hypotheses = zip(*pairs)
    # Positional (reference, hypothesis) works across jiwer 2.x and 3.x.
    return {"wer": 100 * wer(list(references), list(hypotheses))}



## Training Configuration

In [None]:
# Configure training parameters.
# Effective train batch size per device: 8 examples x 2 accumulation steps = 16.
training_args = Seq2SeqTrainingArguments(
    output_dir="whisper-swahili",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,
    learning_rate=1e-5,
    warmup_steps=100,
    num_train_epochs=3,
    eval_strategy="steps",
    eval_steps=500,
    # Checkpoints must line up with evaluations for load_best_model_at_end.
    save_steps=500,
    logging_steps=50,
    # fp16 mixed precision requires a CUDA GPU; fall back to fp32 elsewhere
    # instead of failing when the arguments are constructed.
    fp16=torch.cuda.is_available(),
    # Generate full transcripts during evaluation so WER can be computed.
    predict_with_generate=True,
    remove_unused_columns=False,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    # Experiment tracking is disabled for this notebook.
    report_to="none",
    push_to_hub=True,
    hub_model_id="Denhotech/asr_model",
)

# Initialize the trainer with model, data, and configuration.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    processing_class=processor,
)


## Model Fine-tuning

In [None]:
# Run fine-tuning. Evaluation (WER) and checkpointing follow the step schedule
# defined in `training_args`.
print("Starting fine-tuning process...")
trainer.train()
print("Fine-tuning completed!")

## Model Saving and Deployment

In [None]:

print("Saving trained model...")

# Save the fine-tuned model and processor locally, into the directory the
# trainer was configured with, so the two can never drift apart.
trainer.save_model()
processor.save_pretrained(training_args.output_dir)

# Upload to the Hub repo configured in `training_args`.
# This requires an authenticated Hugging Face session.
print("Uploading to Hugging Face Hub...")
trainer.push_to_hub()
processor.push_to_hub(training_args.hub_model_id)

print(f"Model successfully uploaded to: https://huggingface.co/{training_args.hub_model_id}")


## Final Inference and Submission File Generation

In [None]:
import pandas as pd
import torchaudio
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import torch
from tqdm import tqdm
import os

# Directory written by the save step above (the trainer's output_dir).
MODEL_DIR = "whisper-swahili"
# Whisper's feature extractor expects 16 kHz mono audio.
TARGET_SAMPLING_RATE = 16000

# Load the fine-tuned model
processor = WhisperProcessor.from_pretrained(MODEL_DIR)
model = WhisperForConditionalGeneration.from_pretrained(MODEL_DIR)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()

# Path to held-out validation dataset in CSV
# You need a CSV with at least two columns: `id` and `audio_filepath`
# e.g. validation.csv:
# id,audio_filepath
# 001,validation_audio/001.wav
# 002,validation_audio/002.wav

validation_csv_path = "validation.csv"  # 🔁 Replace with real path
# Read ids as strings so values like "001" keep their leading zeros in the submission.
validation_df = pd.read_csv(validation_csv_path, dtype={"id": str})

missing_columns = {"id", "audio_filepath"} - set(validation_df.columns)
if missing_columns:
    raise ValueError(f"{validation_csv_path} is missing column(s): {sorted(missing_columns)}")


def _load_mono_16k(audio_path):
    """Load an audio file as a 1-D mono waveform at TARGET_SAMPLING_RATE."""
    waveform, sample_rate = torchaudio.load(audio_path)
    # Downmix multi-channel audio to mono.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    # Resample from the file's actual rate (not an assumed 48 kHz).
    if sample_rate != TARGET_SAMPLING_RATE:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sample_rate, new_freq=TARGET_SAMPLING_RATE
        )
    return waveform.squeeze(0)


def _transcribe(audio_path):
    """Transcribe one audio file with the fine-tuned model and return the text."""
    waveform = _load_mono_16k(audio_path)
    # NOTE(review): the Whisper feature extractor works on fixed 30 s windows,
    # so longer clips are presumably truncated. Confirm the clip lengths.
    input_features = processor(
        waveform.numpy(),
        sampling_rate=TARGET_SAMPLING_RATE,
        return_tensors="pt",
    ).input_features.to(device)
    with torch.no_grad():
        # Explicit language/task instead of relying on auto language detection.
        predicted_ids = model.generate(input_features, language="swahili", task="transcribe")
    return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]


# Inference loop
print("Running inference on validation set...")
predictions = []
for sample_id, audio_path in tqdm(
    zip(validation_df["id"], validation_df["audio_filepath"]),
    total=len(validation_df),
):
    predictions.append({
        "id": sample_id,
        "predicted_text": _transcribe(audio_path),
    })

# Save to CSV in submission format
submission_df = pd.DataFrame(predictions)
submission_df.to_csv("submission_predictions.csv", index=False)

print("✅ Inference complete! Saved to submission_predictions.csv")
