In [1]:
# Tutorial: fine-tuning Whisper for speech recognition, following the Hugging Face guide below.
# https://huggingface.co/blog/fine-tune-whisper

In [2]:
# Environment sanity check: report whether PyTorch's MPS backend is available
# and which Python interpreter this kernel is running.
import sys

import torch

print("MPS:", torch.backends.mps.is_available())
print(sys.executable)

MPS: True
/Users/zuzamakowska/Documents/Africa/Project/Low-resource-languages/code/venv/bin/python


In [3]:
# Run `huggingface-cli whoami` in a terminal to check that you are logged in to Hugging Face.

## Download Common Voice dataset (Swahili)

In [4]:
# Switch the `datasets` logger to INFO so that configuration and cache
# messages are shown while the dataset is loaded.
from datasets.utils import logging as datasets_logging

datasets_logging.set_verbosity_info()

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
from datasets import load_dataset, Audio, DownloadConfig
from huggingface_hub import get_token

# Access token returned by huggingface_hub and the download settings for the
# request.
tok = get_token()
dcfg = DownloadConfig(resume_download=True)

# Load the "sw" data directory of Common Voice 17.0 from the
# "refs/convert/parquet" revision and declare the audio column as 16 kHz in
# the same expression.
ds = load_dataset(
    path="mozilla-foundation/common_voice_17_0",
    data_dir="sw",
    revision="refs/convert/parquet",
    token=tok,
    download_config=dcfg,
).cast_column("audio", Audio(sampling_rate=16000))
print(ds)


No config specified, defaulting to the single config: common_voice_17_0/default
Using custom data configuration default-d20c9c50bd16512f
Found cached dataset common_voice_17_0 (/Users/zuzamakowska/.cache/huggingface/datasets/mozilla-foundation___common_voice_17_0/default-d20c9c50bd16512f/0.0.0/91fadf4081526b3cf5edcf0157d15a949a2012a0)


DatasetDict({
    train: Dataset({
        features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment', 'variant'],
        num_rows: 46494
    })
    validation: Dataset({
        features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment', 'variant'],
        num_rows: 12251
    })
    test: Dataset({
        features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment', 'variant'],
        num_rows: 12253
    })
})


In [6]:
# Drop the metadata columns that are not used further on, keeping 'audio',
# 'sentence' and 'variant'. Only columns that are still present are removed,
# so re-running this cell after the columns are gone is a no-op instead of
# raising. The train split is used as the reference for the column names;
# all splits are expected to share the same columns.
unused_columns = ["accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes"]
columns_to_drop = [name for name in unused_columns if name in ds["train"].column_names]
if columns_to_drop:
    ds = ds.remove_columns(columns_to_drop)

In [7]:
# Display the splits with their remaining columns and row counts.
ds

DatasetDict({
    train: Dataset({
        features: ['audio', 'sentence', 'variant'],
        num_rows: 46494
    })
    validation: Dataset({
        features: ['audio', 'sentence', 'variant'],
        num_rows: 12251
    })
    test: Dataset({
        features: ['audio', 'sentence', 'variant'],
        num_rows: 12253
    })
})

In [8]:
# Print the train split's column names, then its feature schema, one per line.
print(ds['train'].column_names, ds['train'].features, sep="\n")

['audio', 'sentence', 'variant']
{'audio': Audio(sampling_rate=16000, decode=True, stream_index=None), 'sentence': Value('string'), 'variant': Value('string')}


## Feature Extraction

In [9]:
from transformers import WhisperTokenizer

# Tokenizer for the whisper-small checkpoint, configured for Swahili
# transcription.
# Padding is requested when the tokenizer is called (or when a batch is
# padded), not when it is constructed, so no `padding` option is passed here.
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="Swahili", task="transcribe")


In [12]:
from transformers import WhisperFeatureExtractor

# Feature extractor belonging to the same whisper-small checkpoint as the
# tokenizer.
feature_extractor = WhisperFeatureExtractor.from_pretrained(
    "openai/whisper-small",
)

In [15]:
# Use the transcript of the first training example as a sample sentence.
input_str = ds['train'][0]['sentence']
input_str

'Macho yangu yamefungwa'

In [16]:
# Encode the sample sentence and keep only the list of token ids.
labels = tokenizer(input_str)["input_ids"]
labels

[50258,
 50318,
 50359,
 50363,
 44,
 46574,
 5581,
 84,
 288,
 529,
 69,
 1063,
 4151,
 50257]

In [19]:
# Decode the ids back to text while keeping the special tokens visible.
decoded_with_special = tokenizer.decode(
    labels,
    skip_special_tokens=False,
)
decoded_with_special

'<|startoftranscript|><|sw|><|transcribe|><|notimestamps|>Macho yangu yamefungwa<|endoftext|>'

In [21]:
# Decode again with the special tokens stripped.
decoded_str = tokenizer.decode(
    labels,
    skip_special_tokens=True,
)
decoded_str

'Macho yangu yamefungwa'

In [22]:
# Full tokenizer output for the sample sentence, not just the ids.
raw_tokens = tokenizer(text=input_str)
raw_tokens

{'input_ids': [50258, 50318, 50359, 50363, 44, 46574, 5581, 84, 288, 529, 69, 1063, 4151, 50257], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [25]:
# Map each id to its token string to see how the sentence was split.
decoded_tokens = tokenizer.convert_ids_to_tokens(ids=labels)
print(decoded_tokens)

['<|startoftranscript|>', '<|sw|>', '<|transcribe|>', '<|notimestamps|>', 'M', 'acho', 'Ġyang', 'u', 'Ġy', 'ame', 'f', 'ung', 'wa', '<|endoftext|>']


### WhisperProcessor

In [26]:
from transformers import WhisperProcessor

# Processor for the whisper-small checkpoint, configured for Swahili
# transcription like the tokenizer.
processor = WhisperProcessor.from_pretrained(
    "openai/whisper-small",
    language="Swahili",
    task="transcribe",
)

In [28]:
# Print the first training example to see what each column holds.
print(ds["train"][0])



{'audio': <datasets.features._torchcodec.AudioDecoder object at 0x13050d690>, 'sentence': 'Macho yangu yamefungwa', 'variant': ''}


In [29]:
from datasets import Audio

# Declare the audio column as 16 kHz. The same cast is already applied right
# after the dataset is loaded, so this repeats that step.
ds = ds.cast_column(column="audio", feature=Audio(sampling_rate=16000))

In [30]:
# Print the first training example again after the audio column was re-cast.
print(ds["train"][0])

{'audio': <datasets.features._torchcodec.AudioDecoder object at 0x130493350>, 'sentence': 'Macho yangu yamefungwa', 'variant': ''}


In [31]:
def prepare_dataset(batch):
    """Add model inputs and target ids to a single dataset example.

    Reads ``batch["audio"]`` (its ``"array"`` and ``"sampling_rate"`` entries)
    and ``batch["sentence"]``, and relies on the notebook-level
    ``feature_extractor`` and ``tokenizer``.

    Args:
        batch: One example with at least the ``"audio"`` and ``"sentence"``
            fields.

    Returns:
        The same example, with ``"input_features"`` (first entry of the
        feature extractor's output) and ``"labels"`` (token ids of the
        sentence) added.
    """
    clip = batch["audio"]
    extracted = feature_extractor(clip["array"], sampling_rate=clip["sampling_rate"])
    # The extractor returns a batch of features; a single clip was passed in,
    # so only the first entry is kept.
    batch["input_features"] = extracted.input_features[0]

    encoded_sentence = tokenizer(batch["sentence"])
    batch["labels"] = encoded_sentence.input_ids
    return batch

In [32]:
# Apply prepare_dataset to every example of every split, using 4 worker
# processes.
preprocessed_ds = ds.map(
    function=prepare_dataset,
    num_proc=4,
)

Process #1 will write at /Users/zuzamakowska/.cache/huggingface/datasets/mozilla-foundation___common_voice_17_0/default-d20c9c50bd16512f/0.0.0/91fadf4081526b3cf5edcf0157d15a949a2012a0/cache-d9e2fc0a921ecb28_00000_of_00004.arrow
Process #2 will write at /Users/zuzamakowska/.cache/huggingface/datasets/mozilla-foundation___common_voice_17_0/default-d20c9c50bd16512f/0.0.0/91fadf4081526b3cf5edcf0157d15a949a2012a0/cache-d9e2fc0a921ecb28_00001_of_00004.arrow
Process #3 will write at /Users/zuzamakowska/.cache/huggingface/datasets/mozilla-foundation___common_voice_17_0/default-d20c9c50bd16512f/0.0.0/91fadf4081526b3cf5edcf0157d15a949a2012a0/cache-d9e2fc0a921ecb28_00002_of_00004.arrow
Process #4 will write at /Users/zuzamakowska/.cache/huggingface/datasets/mozilla-foundation___common_voice_17_0/default-d20c9c50bd16512f/0.0.0/91fadf4081526b3cf5edcf0157d15a949a2012a0/cache-d9e2fc0a921ecb28_00003_of_00004.arrow
Map (num_proc=4):   0%|          | 0/46494 [00:00<?, ? examples/s]Spawning 4 processes
C

In [34]:
# Display the mapped splits; 'input_features' and 'labels' should now be listed.
preprocessed_ds

DatasetDict({
    train: Dataset({
        features: ['audio', 'sentence', 'variant', 'input_features', 'labels'],
        num_rows: 46494
    })
    validation: Dataset({
        features: ['audio', 'sentence', 'variant', 'input_features', 'labels'],
        num_rows: 12251
    })
    test: Dataset({
        features: ['audio', 'sentence', 'variant', 'input_features', 'labels'],
        num_rows: 12253
    })
})

In [35]:
from transformers import WhisperForConditionalGeneration

# Load the pretrained whisper-small checkpoint that will be fine-tuned.
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-small",
)

In [36]:
# Configure generation for Swahili transcription on the model's own
# generation config object.
generation_settings = model.generation_config
generation_settings.language = "swahili"
generation_settings.task = "transcribe"

# Clear any forced decoder ids stored on the generation config.
generation_settings.forced_decoder_ids = None