In [1]:
# Tutorial: fine-tuning Whisper (following the Hugging Face blog post below)
# https://huggingface.co/blog/fine-tune-whisper

In [2]:
# Environment check: is PyTorch's MPS (Apple Metal) backend available, and which
# Python interpreter (virtualenv) is this kernel running?
import torch
print("MPS:", torch.backends.mps.is_available())

import sys
print(sys.executable)

MPS: True
/Users/zuzamakowska/Documents/Africa/Project/Low-resource-languages/venv/bin/python


In [3]:
# Run `huggingface-cli whoami` in a terminal to check that you are logged in to Hugging Face.

## Load the Common Voice dataset (Swahili) from local TSV files

In [4]:
from datasets.utils.logging import set_verbosity_info
set_verbosity_info()

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# api key: echo $MDC_API_KEY

from datasets import load_dataset, Features, Value, Audio

import csv

# Every column is read as a plain string (vote counts included); most of the
# metadata columns are dropped further down.
features = Features({
    "client_id": Value("string"),
    "path": Value("string"),
    "sentence_id": Value("string"),
    "sentence": Value("string"),
    "sentence_domain": Value("string"),
    "up_votes": Value("string"),
    "down_votes": Value("string"),
    "age": Value("string"),
    "gender": Value("string"),
    "accents": Value("string"),
    "variant": Value("string"),
    "locale": Value("string"),
    "segment": Value("string"),
})

# Directory of the extracted Common Voice release for Swahili.
CV_DIR = "../data/cv-corpus-23.0-2025-09-05/sw"

ds = load_dataset(
    "csv",
    data_files={
        "train": f"{CV_DIR}/train.tsv",
        "validation": f"{CV_DIR}/dev.tsv",
        "test": f"{CV_DIR}/test.tsv",
    },
    delimiter="\t",
    # Common Voice TSVs do not quote/escape fields: with the default quoting, a
    # sentence containing `"` can be treated as a quoted field and mis-parsed.
    quoting=csv.QUOTE_NONE,
    features=features,
)


Using custom data configuration default-f2ac76f0bf43f341
Found cached dataset csv (/Users/zuzamakowska/.cache/huggingface/datasets/csv/default-f2ac76f0bf43f341/0.0.0/a43390c7ecea6519ff2ce9d10005c8750601c9e456069be5efbd2747df45f420)


In [6]:
print(ds["train"][0])
print(ds["train"].features)

{'client_id': 'f0e3f1cbc3526ea2273da567cb59583344ad1606b7be2ad13a902e8b971af62a3da34c20d4ee0e9ee59812079b0400ef7d2091514e02e74fa44c27a2c9ca5862', 'path': 'common_voice_sw_30558307.mp3', 'sentence_id': '5071bd0e179a4bb6d66e567969bcedb020858ddde6ed0eda07174e909497afe0', 'sentence': 'deLima alifunga mabao mawili kwenye fainali ya kombe la dunia', 'sentence_domain': None, 'up_votes': '2', 'down_votes': '0', 'age': 'thirties', 'gender': 'female_feminine', 'accents': None, 'variant': None, 'locale': 'sw', 'segment': None}
{'client_id': Value('string'), 'path': Value('string'), 'sentence_id': Value('string'), 'sentence': Value('string'), 'sentence_domain': Value('string'), 'up_votes': Value('string'), 'down_votes': Value('string'), 'age': Value('string'), 'gender': Value('string'), 'accents': Value('string'), 'variant': Value('string'), 'locale': Value('string'), 'segment': Value('string')}


In [7]:
import os

# Clips live in the `clips/` folder next to the TSVs loaded above. The relative
# path is resolved once, so the stored paths are absolute and do not depend on the
# working directory when the audio is decoded later.
CLIPS_DIR = os.path.abspath("../data/cv-corpus-23.0-2025-09-05/sw/clips")


def fix_path(batch, clips_dir=CLIPS_DIR):
    """Turn a Common Voice clip filename into a full path to the audio file.

    Meant for ``ds.map(fix_path)``: ``batch`` is one example dict whose
    ``"path"`` entry is rewritten in place; the example is then returned.

    Args:
        batch: example with a ``"path"`` key holding the clip filename
            (e.g. ``common_voice_sw_30558307.mp3``).
        clips_dir: directory containing the clip files.

    Returns:
        The same ``batch`` with ``"path"`` joined onto ``clips_dir``. A path that
        is already absolute is left unchanged (``os.path.join`` semantics), so
        re-running the map does not prefix the directory twice.
    """
    batch["path"] = os.path.join(clips_dir, batch["path"])
    return batch

ds = ds.map(fix_path)
print(ds["train"].features)

Loading cached processed dataset at /Users/zuzamakowska/.cache/huggingface/datasets/csv/default-f2ac76f0bf43f341/0.0.0/a43390c7ecea6519ff2ce9d10005c8750601c9e456069be5efbd2747df45f420/cache-c3ef7b19a29cc9c2_*_of_00001.arrow
Loading cached processed dataset at /Users/zuzamakowska/.cache/huggingface/datasets/csv/default-f2ac76f0bf43f341/0.0.0/a43390c7ecea6519ff2ce9d10005c8750601c9e456069be5efbd2747df45f420/cache-47e7b05a26992a9b_*_of_00001.arrow
Loading cached processed dataset at /Users/zuzamakowska/.cache/huggingface/datasets/csv/default-f2ac76f0bf43f341/0.0.0/a43390c7ecea6519ff2ce9d10005c8750601c9e456069be5efbd2747df45f420/cache-92b28f3e847a198a_*_of_00001.arrow


{'client_id': Value('string'), 'path': Value('string'), 'sentence_id': Value('string'), 'sentence': Value('string'), 'sentence_domain': Value('string'), 'up_votes': Value('string'), 'down_votes': Value('string'), 'age': Value('string'), 'gender': Value('string'), 'accents': Value('string'), 'variant': Value('string'), 'locale': Value('string'), 'segment': Value('string')}


In [8]:
from datasets import Audio
ds = ds.cast_column("path", Audio(sampling_rate=16000))

In [9]:

print(ds["train"].features)
print(ds["train"][0])


{'client_id': Value('string'), 'path': Audio(sampling_rate=16000, decode=True, num_channels=None, stream_index=None), 'sentence_id': Value('string'), 'sentence': Value('string'), 'sentence_domain': Value('string'), 'up_votes': Value('string'), 'down_votes': Value('string'), 'age': Value('string'), 'gender': Value('string'), 'accents': Value('string'), 'variant': Value('string'), 'locale': Value('string'), 'segment': Value('string')}
{'client_id': 'f0e3f1cbc3526ea2273da567cb59583344ad1606b7be2ad13a902e8b971af62a3da34c20d4ee0e9ee59812079b0400ef7d2091514e02e74fa44c27a2c9ca5862', 'path': <datasets.features._torchcodec.AudioDecoder object at 0x149d1f280>, 'sentence_id': '5071bd0e179a4bb6d66e567969bcedb020858ddde6ed0eda07174e909497afe0', 'sentence': 'deLima alifunga mabao mawili kwenye fainali ya kombe la dunia', 'sentence_domain': None, 'up_votes': '2', 'down_votes': '0', 'age': 'thirties', 'gender': 'female_feminine', 'accents': None, 'variant': None, 'locale': 'sw', 'segment': None}


In [10]:
# Keep only the audio path, the transcript ("sentence") and "variant"; the other
# Common Voice metadata columns are not used for fine-tuning.
ds = ds.remove_columns(['client_id', 'sentence_id', 'sentence_domain', 'up_votes', 'down_votes', 'age', 'gender', 'accents', 'locale', 'segment'])

In [11]:
ds

DatasetDict({
    train: Dataset({
        features: ['path', 'sentence', 'variant'],
        num_rows: 46611
    })
    validation: Dataset({
        features: ['path', 'sentence', 'variant'],
        num_rows: 11692
    })
    test: Dataset({
        features: ['path', 'sentence', 'variant'],
        num_rows: 11944
    })
})

In [12]:
print(ds['train'].column_names)
print(ds['train'].features)

['path', 'sentence', 'variant']
{'path': Audio(sampling_rate=16000, decode=True, num_channels=None, stream_index=None), 'sentence': Value('string'), 'variant': Value('string')}


In [13]:
print(ds["train"][0])

{'path': <datasets.features._torchcodec.AudioDecoder object at 0x13c9ed600>, 'sentence': 'deLima alifunga mabao mawili kwenye fainali ya kombe la dunia', 'variant': None}


In [14]:
# Return NumPy values, rename "path" -> "audio" (the column now yields decoded
# audio rather than a filename), and peek at the first 10 samples of one waveform.
ds = ds.with_format("numpy")
ds = ds.rename_column("path", "audio")
sample = ds["train"][0]
print(sample["audio"]["array"][:10])

[-7.3472851e-12 -4.9304753e-12  1.4526930e-12  1.2472109e-11
  1.9373041e-11  3.6777239e-12 -7.8702288e-12 -2.6043463e-11
 -9.5683704e-13  4.1811121e-11]


In [15]:
print(ds['train'][0])

{'audio': <datasets.features._torchcodec.AudioDecoder object at 0x17789fc10>, 'sentence': np.str_('deLima alifunga mabao mawili kwenye fainali ya kombe la dunia'), 'variant': None}


## Feature Extraction

In [16]:
from transformers import WhisperTokenizer

# Tokenizer set up for Swahili transcription: encoded labels start with the
# <|startoftranscript|><|sw|><|transcribe|><|notimestamps|> prefix (see the decoded output below).
# NOTE(review): `padding='longest'` is given to from_pretrained, not to the
# tokenizer call; confirm it has any effect — label padding happens in the data collator.
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="Swahili", task="transcribe", padding='longest')


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [17]:
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained('openai/whisper-small')

In [18]:
input_str = ds['train'][0]['sentence']
# labels = tokenizer(input_str).input_ids
input_str

np.str_('deLima alifunga mabao mawili kwenye fainali ya kombe la dunia')

In [19]:
labels = tokenizer(input_str).input_ids
labels

[50258,
 50318,
 50359,
 50363,
 1479,
 43,
 4775,
 419,
 351,
 1063,
 64,
 275,
 5509,
 78,
 463,
 86,
 2312,
 350,
 15615,
 1200,
 283,
 491,
 5103,
 2478,
 5207,
 650,
 635,
 10234,
 654,
 50257]

In [20]:
decoded_with_special = tokenizer.decode(labels, skip_special_tokens=False)
decoded_with_special

'<|startoftranscript|><|sw|><|transcribe|><|notimestamps|>deLima alifunga mabao mawili kwenye fainali ya kombe la dunia<|endoftext|>'

In [21]:
decoded_str = tokenizer.decode(labels, skip_special_tokens=True)
decoded_str

'deLima alifunga mabao mawili kwenye fainali ya kombe la dunia'

In [22]:
raw_tokens = tokenizer(input_str)
raw_tokens

{'input_ids': [50258, 50318, 50359, 50363, 1479, 43, 4775, 419, 351, 1063, 64, 275, 5509, 78, 463, 86, 2312, 350, 15615, 1200, 283, 491, 5103, 2478, 5207, 650, 635, 10234, 654, 50257], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [23]:
decoded_tokens = tokenizer.convert_ids_to_tokens(labels)
print(decoded_tokens)

['<|startoftranscript|>', '<|sw|>', '<|transcribe|>', '<|notimestamps|>', 'de', 'L', 'ima', 'Ä al', 'if', 'ung', 'a', 'Ä m', 'aba', 'o', 'Ä ma', 'w', 'ili', 'Ä k', 'wen', 'ye', 'Ä f', 'ain', 'ali', 'Ä ya', 'Ä kom', 'be', 'Ä la', 'Ä dun', 'ia', '<|endoftext|>']


### WhisperProcessor

In [24]:
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained('openai/whisper-small', language='Swahili', task='transcribe')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [25]:
print(ds["train"][0])

{'audio': <datasets.features._torchcodec.AudioDecoder object at 0x17aec7880>, 'sentence': np.str_('deLima alifunga mabao mawili kwenye fainali ya kombe la dunia'), 'variant': None}


In [26]:
from datasets import Audio

# NOTE(review): this column (then named "path") was already cast to
# Audio(sampling_rate=16000) earlier, so this re-cast looks redundant; presumably
# kept to guarantee 16 kHz decoding for the Whisper feature extractor.
ds = ds.cast_column('audio', Audio(sampling_rate=16000))

In [27]:
print(ds["train"][0])

{'audio': <datasets.features._torchcodec.AudioDecoder object at 0x178f504f0>, 'sentence': np.str_('deLima alifunga mabao mawili kwenye fainali ya kombe la dunia'), 'variant': None}


In [28]:
def prepare_dataset(batch):
    """Attach Whisper model inputs and training targets to one example.

    Intended for ``ds.map(prepare_dataset, ...)``: the example dict is updated
    in place and returned.

    Adds:
        input_features: output of the notebook-level ``feature_extractor`` for
            the decoded clip (entry ``[0]``, since a single clip is passed).
        labels: token ids of ``batch["sentence"]`` from the notebook-level
            ``tokenizer`` (including its language/task prefix tokens).
    """
    clip = batch["audio"]
    extracted = feature_extractor(clip["array"], sampling_rate=clip["sampling_rate"])
    batch["input_features"] = extracted.input_features[0]

    encoded = tokenizer(batch["sentence"])
    batch["labels"] = encoded.input_ids

    return batch

In [29]:
# Compute input_features and labels for every split using 4 worker processes.
# The original columns (audio, sentence, variant) are kept alongside the new ones.
preprocessed_ds = ds.map(prepare_dataset, num_proc=4)

Loading cached processed dataset at /Users/zuzamakowska/.cache/huggingface/datasets/csv/default-f2ac76f0bf43f341/0.0.0/a43390c7ecea6519ff2ce9d10005c8750601c9e456069be5efbd2747df45f420/cache-fca951e89734764d_*_of_00004.arrow
Concatenating 4 shards
Loading cached processed dataset at /Users/zuzamakowska/.cache/huggingface/datasets/csv/default-f2ac76f0bf43f341/0.0.0/a43390c7ecea6519ff2ce9d10005c8750601c9e456069be5efbd2747df45f420/cache-0172993911fa61d3_*_of_00004.arrow
Concatenating 4 shards
Loading cached processed dataset at /Users/zuzamakowska/.cache/huggingface/datasets/csv/default-f2ac76f0bf43f341/0.0.0/a43390c7ecea6519ff2ce9d10005c8750601c9e456069be5efbd2747df45f420/cache-4df469bd0b16d126_*_of_00004.arrow
Concatenating 4 shards


In [30]:
preprocessed_ds

DatasetDict({
    train: Dataset({
        features: ['audio', 'sentence', 'variant', 'input_features', 'labels'],
        num_rows: 46611
    })
    validation: Dataset({
        features: ['audio', 'sentence', 'variant', 'input_features', 'labels'],
        num_rows: 11692
    })
    test: Dataset({
        features: ['audio', 'sentence', 'variant', 'input_features', 'labels'],
        num_rows: 11944
    })
})

In [31]:
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained('openai/whisper-small')

In [32]:
# Make generate() transcribe Swahili rather than detect the language or translate.
model.generation_config.language = "swahili"
model.generation_config.task = "transcribe"

# Clear forced decoder ids; presumably so the decoder prompt comes from the
# language/task settings above rather than a fixed id list (as in the linked tutorial).
model.generation_config.forced_decoder_ids = None

In [33]:
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate preprocessed Whisper examples into one padded training batch.

    Audio features and token labels have different lengths and need different
    padding, so each is padded by its own component of ``processor``:

    * ``input_features`` go through ``processor.feature_extractor.pad`` and are
      returned as PyTorch tensors.
    * ``labels`` go through ``processor.tokenizer.pad``; padded positions are then
      set to -100 so the loss ignores them.
    * If every label row begins with ``decoder_start_token_id``, that first column
      is dropped, because the model prepends the start token itself.
    """

    # Object exposing `.feature_extractor` and `.tokenizer`, each with a `.pad(...)` method.
    processor: Any
    # Token id checked for (and stripped) at position 0 of every label row.
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        audio_inputs = []
        label_inputs = []
        for example in features:
            audio_inputs.append({"input_features": example["input_features"]})
            label_inputs.append({"input_ids": example["labels"]})

        padded = self.processor.feature_extractor.pad(audio_inputs, return_tensors="pt")
        padded_labels = self.processor.tokenizer.pad(label_inputs, return_tensors="pt")

        # Positions the tokenizer padded (attention_mask != 1) become -100 so they are ignored by the loss.
        is_padding = padded_labels.attention_mask.ne(1)
        label_ids = padded_labels["input_ids"].masked_fill(is_padding, -100)

        starts_with_decoder_start = (label_ids[:, 0] == self.decoder_start_token_id).all().cpu().item()
        if starts_with_decoder_start:
            label_ids = label_ids[:, 1:]

        padded["labels"] = label_ids
        return padded


In [34]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

In [35]:
import evaluate

metric = evaluate.load("wer")

In [38]:
def compute_metrics(pred):
    """Compute word error rate (WER, in percent) for a Seq2SeqTrainer evaluation.

    Args:
        pred: the prediction object passed by the trainer. ``pred.predictions``
            holds generated token ids (``predict_with_generate=True``) and
            ``pred.label_ids`` holds the reference ids from the data collator.

    Returns:
        ``{"wer": float}``; the key matches ``metric_for_best_model="wer"``.
    """
    import numpy as np

    pred_ids = pred.predictions
    # Some model/trainer setups return a tuple whose first item is the token ids.
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]

    # -100 marks ignored positions. The collator writes it into label padding, and
    # the trainer may also use it to pad generated sequences of different lengths
    # when concatenating eval batches. batch_decode cannot handle a negative id,
    # so map it to the pad token, which skip_special_tokens then drops.
    # np.where builds new arrays instead of mutating pred.label_ids in place.
    pad_id = tokenizer.pad_token_id
    pred_ids = np.where(np.asarray(pred_ids) != -100, pred_ids, pad_id)
    label_ids = np.where(np.asarray(pred.label_ids) != -100, pred.label_ids, pad_id)

    # we do not want to group tokens when computing the metrics
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

In [37]:
from transformers import Seq2SeqTrainingArguments

# Hyperparameters for fine-tuning whisper-small on Common Voice Swahili.
training_args = Seq2SeqTrainingArguments(
    output_dir="../models/whisper-small-sw",
    # effective batch size per device = 16 x 1 accumulation step
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,
    learning_rate=1e-5,
    warmup_steps=500,
    # max_steps overrides num_train_epochs (the trainer logs this when it is created)
    max_steps=5000,
    # trades extra compute for lower activation memory; it also forces
    # use_cache=False during training (warning printed by trainer.train()).
    gradient_checkpointing=True,
    # NOTE(review): this kernel runs on MPS (see the environment check cell); confirm fp16
    # mixed precision is actually supported/active on MPS with the installed transformers/accelerate.
    fp16=True,
    # NOTE(review): newer transformers releases rename this argument to `eval_strategy`
    # (`evaluation_strategy` is deprecated) — update when upgrading.
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    # evaluate with generate() so compute_metrics receives generated token ids for WER
    predict_with_generate=True,
    # maximum generated sequence length (tokens) during evaluation
    generation_max_length=225,
    # save_steps is kept equal to eval_steps: load_best_model_at_end needs a
    # checkpoint saved at each evaluation step.
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    # must match the key returned by compute_metrics; lower WER is better
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
)




In [44]:
from transformers import Seq2SeqTrainer

# Evaluate on the validation split. Evaluation results drive checkpoint selection
# (load_best_model_at_end / metric_for_best_model="wer"), so evaluating on the
# test split here would leak it into model selection and bias the final test WER.
# Keep preprocessed_ds["test"] for a single final evaluation after training.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=preprocessed_ds["train"],
    eval_dataset=preprocessed_ds["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    # Feature extractor handed to the trainer (presumably so it is saved with checkpoints).
    # NOTE(review): newer transformers versions name this argument `processing_class`.
    tokenizer=processor.feature_extractor,
)


max_steps is given, it will override any value given in num_train_epochs


In [None]:
# `log_level` is a Seq2SeqTrainingArguments field, not one of Trainer.train()'s
# documented parameters (resume_from_checkpoint, trial, ignore_keys_for_eval).
# Raise the transformers logging verbosity here instead, then call train() plainly.
from transformers.utils import logging as hf_logging

hf_logging.set_verbosity_info()
trainer.train()

  return fn(*args, **kwargs)
`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...


In [None]:
# These metadata columns were already dropped by the earlier remove_columns cell,
# and remove_columns raises for names that are not in the dataset. Remove only
# the ones still present so this cell is safe to (re-)run.
unused_columns = ['client_id', 'sentence_id', 'sentence_domain', 'up_votes', 'down_votes', 'age', 'gender', 'accents', 'locale', 'segment']
ds = ds.remove_columns([col for col in unused_columns if col in ds["train"].column_names])

In [42]:
print(ds)
print(ds["train"].num_rows)
print(ds["validation"].num_rows)
print(ds["test"].num_rows)

DatasetDict({
    train: Dataset({
        features: ['audio', 'sentence', 'variant'],
        num_rows: 46611
    })
    validation: Dataset({
        features: ['audio', 'sentence', 'variant'],
        num_rows: 11692
    })
    test: Dataset({
        features: ['audio', 'sentence', 'variant'],
        num_rows: 11944
    })
})
46611
11692
11944


In [43]:
print(ds["train"][0])

{'audio': <datasets.features._torchcodec.AudioDecoder object at 0xfb7a58d90>, 'sentence': np.str_('deLima alifunga mabao mawili kwenye fainali ya kombe la dunia'), 'variant': None}
