Fine-Tune Whisper For Multilingual ASR with 🤗 Transformers
===========================================================

Introduction
------------

Whisper is a pre-trained model for automatic speech recognition (ASR) published in [September 2022](https://openai.com/blog/whisper/) by the authors Alec Radford et al. from OpenAI. Unlike many of its predecessors, such as [Wav2Vec 2.0](https://arxiv.org/abs/2006.11477), which are pre-trained on un-labelled audio data, Whisper is pre-trained on a vast quantity of __labelled__ audio-transcription data, 680,000 hours to be precise. This is an order of magnitude more data than the un-labelled audio data used to train Wav2Vec 2.0 (60,000 hours). What is more, 117,000 hours of this pre-training data is multilingual ASR data. This results in checkpoints that can be applied to over 96 languages, many of which are considered _low-resource_.

This quantity of labelled data enables Whisper to be pre-trained directly on the _supervised_ task of speech recognition, learning a speech-to-text mapping from the labelled audio-transcription pre-training data<sup>1</sup>. As a consequence, Whisper requires little additional fine-tuning to yield a performant ASR model. This is in contrast to Wav2Vec 2.0, which is pre-trained on the _unsupervised_ task of masked prediction. Here, the model is trained to learn an intermediate mapping from speech to hidden states from un-labelled, audio-only data. While unsupervised pre-training yields high-quality representations of speech, it does __not__ learn a speech-to-text mapping. This mapping is only learned during fine-tuning, thus requiring more fine-tuning to yield competitive performance.

------------------------------------------------------------------------

<sup>1</sup> The name Whisper follows from the acronym “WSPSR”, which stands for “Web-scale Supervised Pre-training for Speech Recognition”.

When scaled to 680,000 hours of labelled pre-training data, Whisper models demonstrate a strong ability to generalise to many datasets and domains. The pre-trained checkpoints achieve results competitive with state-of-the-art ASR systems, with a word error rate (WER) of roughly 3% on the test-clean subset of LibriSpeech ASR and a new state-of-the-art on TED-LIUM with 4.7% WER (_c.f._ Table 8 of the [Whisper paper](https://cdn.openai.com/papers/whisper.pdf)). The extensive multilingual ASR knowledge acquired by Whisper during pre-training can be leveraged for other low-resource languages; through fine-tuning, the pre-trained checkpoints can be adapted for specific datasets and languages to further improve upon these results.
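WER counts the word-level substitutions (S), deletions (D) and insertions (I) needed to turn a predicted transcription into the reference, divided by the number of words in the reference (N): WER = (S + D + I) / N, so lower is better. The sketch below computes it with a word-level Levenshtein distance; the function name `word_error_rate` is our own, and it is only meant to make the metric concrete (the packages used later for evaluation, `evaluate` and `jiwer`, handle this for us).

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """WER = (substitutions + deletions + insertions) / number of reference words."""
    ref = reference.split()
    hyp = hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    # prev[j] holds the edit distance between the first i-1 reference words
    # and the first j hypothesis words (two-row Levenshtein table).
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, hyp_word in enumerate(hyp, start=1):
            substitution = prev[j - 1] + (ref_word != hyp_word)
            deletion = prev[j] + 1
            insertion = curr[j - 1] + 1
            curr[j] = min(substitution, deletion, insertion)
        prev = curr
    return prev[-1] / len(ref)


# One substitution ("sat" -> "sit") and one deletion ("the") over 6 reference words.
print(word_error_rate("the cat sat on the mat", "the cat sit on mat"))
# prints 0.3333333333333333
```

Note that insertions are unbounded, so WER can exceed 100% when the model hallucinates many extra words.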

Whisper is a Transformer-based encoder-decoder model, also referred to as a _sequence-to-sequence_ model. It maps a _sequence_ of audio spectrogram features to a _sequence_ of text tokens. First, the feature extractor converts the raw audio inputs to a log-Mel spectrogram. The Transformer encoder then encodes the spectrogram to form a sequence of encoder hidden states. Finally, the decoder autoregressively predicts text tokens, conditional on both the previous tokens and the encoder hidden states. Figure 1 summarises the Whisper model.
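The autoregressive loop in the decoder can be sketched in a few lines. This is a toy illustration, not Whisper's actual decoder: `toy_step_logits` is a hypothetical stand-in that returns a score for every vocabulary entry given the tokens so far and the (here, integer) encoder states, and the search strategy is plain greedy decoding, taking the highest-scoring token at each step until an end-of-sequence id appears.

```python
from typing import Callable, List, Sequence

StepFn = Callable[[List[int], Sequence[int]], List[float]]


def greedy_decode(step_logits: StepFn, encoder_states: Sequence[int],
                  start_ids: Sequence[int], eos_id: int,
                  max_new_tokens: int = 50) -> List[int]:
    """Greedy autoregressive decoding: each prediction is fed back as input."""
    tokens = list(start_ids)
    for _ in range(max_new_tokens):
        # Condition on both the tokens generated so far and the encoder output.
        logits = step_logits(tokens, encoder_states)
        next_id = max(range(len(logits)), key=logits.__getitem__)
        tokens.append(next_id)
        if next_id == eos_id:
            break
    return tokens


VOCAB_SIZE, START_ID, EOS_ID = 6, 0, 5


def toy_step_logits(tokens: List[int], encoder_states: Sequence[int]) -> List[float]:
    """Hypothetical decoder: 'reads out' one encoder state per step, then emits EOS."""
    position = len(tokens) - 1  # tokens generated so far, after the single start token
    target = encoder_states[position] if position < len(encoder_states) else EOS_ID
    return [1.0 if v == target else 0.0 for v in range(VOCAB_SIZE)]


print(greedy_decode(toy_step_logits, encoder_states=[2, 3, 4],
                    start_ids=[START_ID], eos_id=EOS_ID))
# prints [0, 2, 3, 4, 5]
```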

<figure>
<img src="./img/whisper-architecture.svg" alt="Whisper architecture" style="width:100%">
<figcaption align = "center"><b>Figure 1:</b> Whisper model. The architecture follows the standard Transformer-based encoder-decoder model. A log-Mel spectrogram is input to the encoder. The last encoder hidden states are input to the decoder via cross-attention mechanisms. The decoder autoregressively predicts text tokens, jointly conditional on the encoder hidden states and previously predicted tokens. Figure source: <a href="https://openai.com/blog/whisper/">OpenAI Whisper Blog</a>.</figcaption>
</figure>

In a sequence-to-sequence model, the encoder transforms the audio inputs into a set of hidden state representations, extracting salient features from the speech. The decoder plays the role of a language model, processing the hidden state representations and generating the corresponding text transcriptions. Incorporating a language model __internally__ in the system architecture is termed _deep fusion_. This is in contrast to _shallow fusion_, where a language model is combined __externally__ with an encoder, such as with CTC + $n$-gram (_c.f._ [Internal Language Model Estimation](https://arxiv.org/pdf/2011.01991.pdf)). With deep fusion, the entire system can be trained end-to-end with the same training data and loss function, giving greater flexibility and generally superior performance (_c.f._ [ESB Benchmark](https://arxiv.org/abs/2210.13352)).
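For contrast, shallow fusion combines two separately trained models only at decoding time, typically by adding a weighted language-model log-probability to the acoustic model's score for each candidate transcription (practical systems often add a length bonus too, omitted here). A minimal sketch of that rescoring rule; the hypotheses, their log-probabilities and the weight `lm_weight` are made-up values for illustration.

```python
from typing import Dict, Tuple


def shallow_fusion_rerank(candidates: Dict[str, Tuple[float, float]],
                          lm_weight: float) -> str:
    """Pick the hypothesis maximising log p_asr(y|x) + lm_weight * log p_lm(y).

    `candidates` maps each hypothesis to (asr_log_prob, lm_log_prob). The two
    models are trained separately; they only meet in this scoring rule.
    """
    def fused_score(hypothesis: str) -> float:
        asr_log_prob, lm_log_prob = candidates[hypothesis]
        return asr_log_prob + lm_weight * lm_log_prob

    return max(candidates, key=fused_score)


# Made-up scores: the acoustic model slightly prefers an acoustically similar but
# implausible hypothesis; the external language model strongly prefers the other.
candidates = {
    "recognise speech": (-1.2, -4.0),
    "wreck a nice beach": (-1.0, -9.0),
}
print(shallow_fusion_rerank(candidates, lm_weight=0.0))  # prints wreck a nice beach
print(shallow_fusion_rerank(candidates, lm_weight=0.5))  # prints recognise speech
```

With deep fusion, by contrast, there is no separate `lm_log_prob` term to tune: the decoder's language-modelling ability is learned jointly with the encoder.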

Whisper is pre-trained and fine-tuned using the cross-entropy objective function, a standard objective function for training sequence-to-sequence systems on classification tasks. Here, the system is trained to correctly classify the target text token from a pre-defined vocabulary of text tokens.
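Concretely, at each decoder position the model outputs one score (logit) per vocabulary token; the loss is the negative log-probability of the correct token after a softmax, averaged over positions. A pure-Python sketch of this token-level cross-entropy; the `-100` marker for padded label positions that should not contribute to the loss is an assumption borrowed from PyTorch's default `ignore_index`.

```python
import math
from typing import List, Sequence

IGNORE_INDEX = -100  # assumption: mirrors PyTorch's default ignore_index


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_norm for x in logits]


def token_cross_entropy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Mean negative log-likelihood of the target tokens, skipping ignored positions."""
    losses = [
        -log_softmax(position_logits)[label]
        for position_logits, label in zip(logits, labels)
        if label != IGNORE_INDEX
    ]
    if not losses:
        raise ValueError("all label positions are ignored")
    return sum(losses) / len(losses)


# Three decoder positions over a 4-token vocabulary; the last label is padding.
logits = [[2.0, 0.5, 0.1, -1.0],
          [0.0, 3.0, 0.0, 0.0],
          [1.0, 1.0, 1.0, 1.0]]
labels = [0, 1, IGNORE_INDEX]
print(round(token_cross_entropy(logits, labels), 4))
```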

The Whisper checkpoints come in five configurations of varying model sizes. The smallest four are trained on either English-only or multilingual data. The largest checkpoints are multilingual only. All 11 of the pre-trained checkpoints are available on the [Hugging Face Hub](https://huggingface.co/models?search=openai/whisper). The checkpoints are summarised in the following table with links to the models on the Hub:

|   Size   | Layers | Width | Heads | Parameters |                     English-only                     |                    Multilingual                     |
|:--------:|:------:|:-----:|:-----:|:----------:|:----------------------------------------------------:|:---------------------------------------------------:|
|   tiny   |   4    |  384  |   6   |    39 M    |  [✓](https://huggingface.co/openai/whisper-tiny.en)  |   [✓](https://huggingface.co/openai/whisper-tiny)   |
|   base   |   6    |  512  |   8   |    74 M    |  [✓](https://huggingface.co/openai/whisper-base.en)  |   [✓](https://huggingface.co/openai/whisper-base)   |
|  small   |   12   |  768  |  12   |   244 M    | [✓](https://huggingface.co/openai/whisper-small.en)  |  [✓](https://huggingface.co/openai/whisper-small)   |
|  medium  |   24   | 1024  |  16   |   769 M    | [✓](https://huggingface.co/openai/whisper-medium.en) |  [✓](https://huggingface.co/openai/whisper-medium)  |
|  large   |   32   | 1280  |  20   |   1550 M   |                          x                           |  [✓](https://huggingface.co/openai/whisper-large)   |
| large-v2 |   32   | 1280  |  20   |   1550 M   |                          x                           | [✓](https://huggingface.co/openai/whisper-large-v2) |
| large-v3 |   32   | 1280  |  20   |   1550 M   |                          x                           | [✓](https://huggingface.co/openai/whisper-large-v3) |
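The Hub identifiers in the table follow a regular pattern: `openai/whisper-<size>` for the multilingual checkpoints, with an `.en` suffix for the English-only variants of the four smallest sizes. A small helper encoding this pattern; the function name `whisper_checkpoint_id` is our own, and the size lists simply mirror the table above.

```python
SIZES_WITH_EN_VARIANT = ("tiny", "base", "small", "medium")
MULTILINGUAL_ONLY_SIZES = ("large", "large-v2", "large-v3")


def whisper_checkpoint_id(size: str, english_only: bool = False) -> str:
    """Build the Hub id of a Whisper checkpoint following the naming in the table."""
    if size not in SIZES_WITH_EN_VARIANT + MULTILINGUAL_ONLY_SIZES:
        raise ValueError(f"unknown Whisper size: {size!r}")
    if english_only:
        if size not in SIZES_WITH_EN_VARIANT:
            raise ValueError(f"no English-only checkpoint exists for {size!r}")
        return f"openai/whisper-{size}.en"
    return f"openai/whisper-{size}"


print(whisper_checkpoint_id("small"))                      # openai/whisper-small
print(whisper_checkpoint_id("medium", english_only=True))  # openai/whisper-medium.en
# 4 sizes x 2 variants + 3 multilingual-only sizes = 11 checkpoints in total
print(2 * len(SIZES_WITH_EN_VARIANT) + len(MULTILINGUAL_ONLY_SIZES))  # 11
```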

For demonstration purposes, we'll fine-tune the multilingual version of the [small](https://huggingface.co/openai/whisper-small) checkpoint with 244M parameters (~1 GB). As for our data, we'll train and evaluate our system on Italian speech taken from the [Common Voice dataset](https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0). We'll show that with as little as 8 hours of fine-tuning data, we can achieve strong performance in this language.
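The ~1 GB figure follows from storing each parameter as a 32-bit float (4 bytes). A back-of-the-envelope sketch; it counts the weights only, so gradients, optimizer state and activations during training add substantially more memory on top.

```python
BYTES_PER_PARAM = {"float32": 4, "float16": 2}  # bytes per stored weight


def weight_size_gb(num_params: float, dtype: str = "float32") -> float:
    """Approximate size of the model weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


print(weight_size_gb(244e6))             # ~0.98 GB for whisper-small in float32
print(weight_size_gb(244e6, "float16"))  # ~0.49 GB in half precision
```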

Preparing Environment
---------------------

We'll employ several popular Python packages to fine-tune the Whisper model. We'll use `datasets[audio]` to download and prepare our training data, alongside `transformers` and `accelerate` to load and train our Whisper model. We'll also require the `soundfile` package to pre-process audio files, `evaluate` and `jiwer` to assess the performance of our model, and `tensorboard` to log our metrics. Finally, we'll use `gradio` to build a flashy demo of our fine-tuned model.
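Before running the rest of the notebook, it can help to confirm these packages are importable. A minimal sketch using the standard library's `importlib.util.find_spec`; it only checks that each module can be found, not its version. Since `datasets[audio]` installs the `datasets` module plus audio extras, we check for `datasets` itself.

```python
import importlib.util
from typing import Dict, Iterable

REQUIRED_MODULES = ("datasets", "transformers", "accelerate", "soundfile",
                    "evaluate", "jiwer", "tensorboard", "gradio")


def check_modules(names: Iterable[str] = REQUIRED_MODULES) -> Dict[str, bool]:
    """Map each top-level module name to whether it can be found in this environment."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


missing = [name for name, found in check_modules().items() if not found]
print("missing packages:", missing or "none")
```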

It is strongly recommended to upload model checkpoints directly to the [Hugging Face Hub](https://huggingface.co/) whilst training. The Hub provides:

- Integrated version control: you can be sure that no model checkpoint is lost during training.
- Tensorboard logs: track important metrics over the course of training.
- Model cards: document what a model does and its intended use cases.
- Community: an easy way to share and collaborate with the community!

Linking the notebook to the Hub is straightforward: it simply requires entering your Hub authentication token when prompted. You can find your Hub authentication token [here](https://huggingface.co/settings/tokens) (__this step is required in this notebook in order to access the Mozilla datasets later on__):

```python
from huggingface_hub import notebook_login

notebook_login()
```
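Outside a notebook (for example in a script on a remote machine), `notebook_login` has no widget to display. A sketch of a non-interactive alternative, assuming the token has been exported as an environment variable; the variable name `HF_TOKEN` and the helper `login_from_env` are our own choices here, while `huggingface_hub.login(token=...)` is the library's programmatic login function.

```python
import os


def login_from_env(var_name: str = "HF_TOKEN") -> bool:
    """Log in to the Hugging Face Hub with a token read from the environment.

    Returns True if a login was attempted, False if no token was found.
    """
    token = os.environ.get(var_name)
    if not token:
        return False
    # Imported lazily so the check above works even without huggingface_hub installed.
    from huggingface_hub import login

    login(token=token)
    return True


if not login_from_env():
    print("No HF_TOKEN found; fall back to notebook_login() or `huggingface-cli login`.")
```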


Preparing Feature Extractor, Tokenizer and Data
-----------------------------------------------

The ASR pipeline can be decomposed into three stages:

1. A feature extractor which pre-processes the raw audio inputs
2. The model which performs the sequence-to-sequence mapping
3. A tokenizer which post-processes the model outputs to text format

In 🤗 Transformers, the Whisper model has an associated feature extractor and tokenizer,
called [WhisperFeatureExtractor](https://huggingface.co/docs/transformers/main/model_doc/whisper#transformers.WhisperFeatureExtractor)
and [WhisperTokenizer](https://huggingface.co/docs/transformers/main/model_doc/whisper#transformers.WhisperTokenizer)
respectively.
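The three stages compose into a single transcription function. A sketch assuming a 16 kHz mono audio array and the 🤗 Transformers call conventions for these objects (the feature extractor returns `input_features`, the model exposes `generate`, and the tokenizer's `batch_decode` accepts `skip_special_tokens=True`); the helper name `transcribe` is our own. Because it only relies on these methods, it can be dry-run offline with hypothetical stand-in objects, as done at the bottom of the block.

```python
from types import SimpleNamespace
from typing import Any, List, Sequence


def transcribe(audio: Sequence[float], feature_extractor: Any, model: Any,
               tokenizer: Any, sampling_rate: int = 16_000) -> List[str]:
    """Run the three ASR stages in order and return the decoded transcriptions."""
    # 1. Feature extractor: raw waveform -> padded log-Mel input features.
    inputs = feature_extractor(audio, sampling_rate=sampling_rate, return_tensors="pt")
    # 2. Model: input features -> predicted token ids (autoregressive generation).
    predicted_ids = model.generate(inputs.input_features)
    # 3. Tokenizer: token ids -> text, dropping special tokens such as <|endoftext|>.
    return tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)


# Hypothetical stand-ins exposing the same method names, for an offline dry run.
class StubFeatureExtractor:
    def __call__(self, audio, sampling_rate, return_tensors):
        return SimpleNamespace(input_features=[len(audio)])


class StubModel:
    def generate(self, input_features):
        return [[1, 2, 3]]


class StubTokenizer:
    def batch_decode(self, ids, skip_special_tokens=False):
        return [" ".join(f"tok{i}" for i in seq) for seq in ids]


print(transcribe([0.0] * 16_000, StubFeatureExtractor(), StubModel(), StubTokenizer()))
# prints ['tok1 tok2 tok3']
```

With real objects, `model` would be a Whisper model with a `generate` method, such as `WhisperForConditionalGeneration`.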

### Loading WhisperFeatureExtractor

The Whisper feature extractor performs two operations:

1. Pads/truncates the audio inputs to 30s: any audio inputs shorter than 30s are padded to 30s with silence (zeros), and those longer than 30s are truncated to 30s
2. Converts the audio inputs to _log-Mel spectrogram_ input features, a visual representation of the audio and the form of the input expected by the Whisper model:

   <figure>
       <img src="./img/spectrogram.png" alt="Spectrogram" style="width:100%">
       <figcaption align = "center"><b>Figure 2:</b> Conversion of sampled audio array to log-Mel spectrogram. Left: sampled 1-dimensional audio signal. Right: corresponding log-Mel spectrogram. Figure source: <a href="https://ai.googleblog.com/2019/04/specaugment-new-data-augmentation.html">Google SpecAugment Blog</a>.</figcaption>
   </figure>
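The first operation is simple enough to sketch directly. A NumPy version assuming the audio has already been resampled to 16 kHz, so that 30 s corresponds to 480,000 samples; the helper name `pad_or_truncate` is hypothetical, and the feature extractor performs this step for us internally.

```python
import numpy as np

SAMPLING_RATE = 16_000                     # Hz, assumed input sampling rate
CHUNK_SECONDS = 30
N_SAMPLES = SAMPLING_RATE * CHUNK_SECONDS  # 480,000 samples


def pad_or_truncate(audio: np.ndarray, length: int = N_SAMPLES) -> np.ndarray:
    """Zero-pad (silence) or cut a 1-D waveform to exactly `length` samples."""
    if audio.shape[0] >= length:
        return audio[:length]
    return np.pad(audio, (0, length - audio.shape[0]))  # constant zeros by default


short_clip = np.ones(5 * SAMPLING_RATE, dtype=np.float32)  # 5 s clip
long_clip = np.ones(45 * SAMPLING_RATE, dtype=np.float32)  # 45 s clip
print(pad_or_truncate(short_clip).shape, pad_or_truncate(long_clip).shape)
# prints (480000,) (480000,)
```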

We'll load the feature extractor from the pre-trained checkpoint with the default values:

```python
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
```
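Because every input is padded or truncated to 30 s, every example produces a spectrogram of the same shape. The arithmetic below assumes the settings we understand the `whisper-small` feature extractor to use (16 kHz audio, a hop length of 160 samples, i.e. 10 ms, and 80 Mel bins; large-v3 uses 128 Mel bins instead) and that the encoder's convolutional stem halves the time resolution. It is a sketch: if in doubt, check `feature_extractor.hop_length` and `feature_extractor.feature_size` on the loaded object.

```python
SAMPLING_RATE = 16_000  # Hz
CHUNK_SECONDS = 30      # every input is padded/truncated to this length
HOP_LENGTH = 160        # samples between successive spectrogram frames (10 ms)
N_MELS = 80             # Mel frequency bins (assumed for whisper-small)
ENCODER_STRIDE = 2      # assumed time downsampling of the encoder's conv stem

n_samples = SAMPLING_RATE * CHUNK_SECONDS         # 480,000 samples
n_frames = n_samples // HOP_LENGTH                # 3,000 spectrogram frames
n_encoder_positions = n_frames // ENCODER_STRIDE  # 1,500 encoder hidden states

print(f"input_features shape per example: ({N_MELS}, {n_frames})")
print(f"encoder hidden states per example: {n_encoder_positions}")
```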

### Loading WhisperTokenizer

The Whisper model outputs a sequence of _token ids_. The tokenizer maps each of these token ids to its corresponding text string. For Italian, we can load the pre-trained tokenizer and use it for fine-tuning without any further modifications. We simply have to
specify the target language and the task. These arguments instruct the tokenizer to prepend the language and task tokens to the start of encoded label sequences:

```python
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="Italian", task="transcribe")
```
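With `language="Italian"` and `task="transcribe"`, each encoded label sequence starts with Whisper's special tokens for start-of-transcript, language, task and (when timestamps are not predicted) no-timestamps, and ends with an end-of-text token. A string-level sketch of that layout; the helper `format_label` is hypothetical, and the real tokenizer works with token ids rather than strings (you can inspect its actual output with `tokenizer.decode(tokenizer("ciao").input_ids)`).

```python
SPECIAL = {
    "start": "<|startoftranscript|>",
    "no_timestamps": "<|notimestamps|>",
    "end": "<|endoftext|>",
}


def format_label(text: str, language_code: str = "it", task: str = "transcribe") -> str:
    """Illustrative string form of a Whisper label sequence (hypothetical helper)."""
    if task not in ("transcribe", "translate"):
        raise ValueError("task must be 'transcribe' or 'translate'")
    prefix = (SPECIAL["start"] + f"<|{language_code}|>" + f"<|{task}|>"
              + SPECIAL["no_timestamps"])
    return prefix + text + SPECIAL["end"]


print(format_label("Buongiorno a tutti."))
# prints <|startoftranscript|><|it|><|transcribe|><|notimestamps|>Buongiorno a tutti.<|endoftext|>
```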

### (Optional) Extractor & Tokenizer Combined

To simplify using the feature extractor and tokenizer, we can wrap both into a single `WhisperProcessor` class. This processor object
wraps an instance of `WhisperFeatureExtractor` and one of `WhisperTokenizer`, and can be used on the audio inputs and model predictions as required.
In doing so, we only need to keep track of two objects during training: the `processor` and the `model`:

```python
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="Italian", task="transcribe")
```
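Conceptually, the processor just routes each input to the right component: audio goes to the feature extractor, text and predicted token ids go to the tokenizer, and when both audio and text are given the token ids are attached as training labels. A deliberately simplified sketch of that routing, not the actual `WhisperProcessor` source; the class `MiniProcessor` and the two stand-in components are hypothetical. On the real object, the wrapped components are available as `processor.feature_extractor` and `processor.tokenizer`.

```python
from typing import Any, Dict, Optional


class MiniProcessor:
    """Simplified illustration of a processor wrapping two components (hypothetical)."""

    def __init__(self, feature_extractor: Any, tokenizer: Any) -> None:
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer

    def __call__(self, audio: Optional[Any] = None, text: Optional[str] = None,
                 **audio_kwargs: Any) -> Dict[str, Any]:
        if audio is None and text is None:
            raise ValueError("pass `audio`, `text`, or both")
        if text is None:
            # Audio only: route to the feature extractor.
            return self.feature_extractor(audio, **audio_kwargs)
        if audio is None:
            # Text only: route to the tokenizer.
            return self.tokenizer(text)
        # Both: audio features for the encoder, token ids as training labels.
        outputs = dict(self.feature_extractor(audio, **audio_kwargs))
        outputs["labels"] = self.tokenizer(text)["input_ids"]
        return outputs

    def batch_decode(self, *args: Any, **kwargs: Any) -> Any:
        # Turning predicted token ids back into text is purely a tokenizer job.
        return self.tokenizer.batch_decode(*args, **kwargs)


class LengthExtractorStub:
    """Stand-in feature extractor: the 'features' are just the clip length."""
    def __call__(self, audio, **kwargs):
        return {"input_features": [len(audio)]}


class CharTokenizerStub:
    """Stand-in tokenizer: one id per character (its Unicode code point)."""
    def __call__(self, text):
        return {"input_ids": [ord(c) for c in text]}

    def batch_decode(self, ids, **kwargs):
        return ["".join(chr(i) for i in seq) for seq in ids]


processor_sketch = MiniProcessor(LengthExtractorStub(), CharTokenizerStub())
print(processor_sketch(audio=[0.0] * 10))                   # {'input_features': [10]}
print(processor_sketch(text="ciao"))                        # {'input_ids': [99, 105, 97, 111]}
print(processor_sketch(audio=[0.0] * 10, text="ciao"))      # features plus 'labels'
print(processor_sketch.batch_decode([[99, 105, 97, 111]]))  # ['ciao']
```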

Loading Data
------------

Common Voice is a series of crowd-sourced datasets where speakers record text from Wikipedia in various languages. We'll use the latest edition of the Common Voice dataset at the time of writing (version 17). As for our language, we'll fine-tune our model on __Italian__. 

> We can find the latest version of the Common Voice dataset by checking the [Mozilla Foundation organisation page on the Hugging Face Hub](https://huggingface.co/mozilla-foundation). Later versions cover more languages and contain more data per language.

Let's head to the Hub and view the dataset page for Common Voice: [mozilla-foundation/common_voice_17_0](https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0). __The first time we view this page, we'll be asked to accept the terms of use. After that, we'll be given full access to the dataset programmatically.__

Using 🤗 Datasets, downloading and preparing data is extremely simple. We can download and prepare each Common Voice split with a single call to `load_dataset`:

```python
from datasets import load_dataset, DatasetDict

common_voice = DatasetDict()

common_voice["train"] = load_dataset("mozilla-foundation/common_voice_17_0", "it", split="train")
common_voice["test"] = load_dataset("mozilla-foundation/common_voice_17_0", "it", split="test")

print(common_voice)
```


```
DatasetDict({
    train: Dataset({
        features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment', 'variant'],
        num_rows: 169771
    })
    test: Dataset({
        features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment', 'variant'],
        num_rows: 15155
    })
})
```


Most ASR datasets only provide input audio samples (`audio`) and the corresponding transcribed text (`sentence`). Common Voice contains additional metadata information, such as `accent` and `locale`, which we can disregard for ASR. Keeping the notebook as general as possible, we only consider the input audio and transcribed text for fine-tuning, discarding the additional metadata information:

```python
common_voice = common_voice.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes", "variant"])
```
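Rather than hard-coding the column names, we could derive them from the dataset itself, which keeps the cell correct if a future Common Voice release adds or renames metadata fields. A small sketch; the helper `columns_to_drop` is our own, and in the notebook it would be applied as `common_voice.remove_columns(columns_to_drop(common_voice["train"].column_names))`.

```python
from typing import Iterable, List

KEEP_COLUMNS = ("audio", "sentence")  # model input and target transcription


def columns_to_drop(column_names: Iterable[str], keep: Iterable[str] = KEEP_COLUMNS) -> List[str]:
    """Return every column except the ones needed for ASR fine-tuning."""
    names = list(column_names)
    keep = tuple(keep)
    missing = set(keep) - set(names)
    if missing:
        raise KeyError(f"expected columns not found: {sorted(missing)}")
    return [name for name in names if name not in keep]


# Column list as printed for the Common Voice 17 Italian splits above.
cv17_columns = ["client_id", "path", "audio", "sentence", "up_votes", "down_votes",
                "age", "gender", "accent", "locale", "segment", "variant"]
print(columns_to_drop(cv17_columns))
# prints ['client_id', 'path', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment', 'variant']
```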