## Whisper Fine-tuning in Google Colab

Whisper is pre-trained and fine-tuned using the cross-entropy objective function. At each decoding step, the model is trained to pick the correct target text token from a pre-defined vocabulary of text tokens.
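To make the objective concrete, here is a minimal sketch of the cross-entropy loss for a single decoding step. It uses only the standard library; the four-entry "vocabulary" and the logit values are made up for illustration, and `cross_entropy` is a hypothetical helper, not the loss implementation used by `transformers`.

```python
import math

def cross_entropy(logits, target):
    """Negative log-probability of the target token under softmax(logits)."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]

# Toy vocabulary of 4 tokens; the correct token is index 2 (illustrative values).
uniform_loss = cross_entropy([0.0, 0.0, 0.0, 0.0], target=2)    # model has no preference
confident_loss = cross_entropy([0.0, 0.0, 8.0, 0.0], target=2)  # model favours the target

print(uniform_loss, confident_loss)
```

The loss falls as the model puts more probability on the correct token, which is what both pre-training and fine-tuning push towards.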

Please refer to this detailed document - 


In [1]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Fri Apr 14 10:58:00 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-SXM...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   29C    P0    41W / 400W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')

Your runtime has 89.6 gigabytes of available RAM

You are using a high-RAM runtime!


### Setting up
* transformers -> load and train the Whisper model
* datasets -> training data
* librosa -> signal processing
* evaluate & jiwer -> performance of the model
* gradio -> demo app for the fine-tuned model

In [3]:
# quote the version specifiers so the shell does not treat ">" as an output redirect
!pip install "datasets>=2.6.1"
!pip install git+https://github.com/huggingface/transformers
!pip install librosa
!pip install "evaluate>=0.3.0"
!pip install jiwer
!pip install gradio

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/huggingface/transformers
  Cloning https://github.com/huggingface/transformers to /tmp/pip-req-build-00h3_61c
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers /tmp/pip-req-build-00h3_61c
  Resolved https://github.com/huggingface/transformers to commit 53c710d17bd43bf7c7edfc4187bba196ee9438ae
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
     7.8/7.8 MB 81.5 MB/s eta 0:00:00
Building wheels for collected packages: transformers
  Building wheel for transformers (pyp

In [4]:
from huggingface_hub import notebook_login

In [5]:
notebook_login()

Token is valid.
Your token has been saved in your configured git credential helpers (store).
Your token has been saved to /root/.cache/huggingface/token
Login successful


### Load Dataset
Common Voice is a series of crowd-sourced datasets in which speakers record text from Wikipedia in various languages. When I was working on this notebook, the latest edition of the Common Voice dataset was [mozilla-foundation/common_voice_13_0](https://huggingface.co/datasets/mozilla-foundation/common_voice_13_0); the code below loads the earlier `mozilla-foundation/common_voice_11_0` release.

For the language we will use Hindi (`hi`), which is spoken in India.

In [6]:
from datasets import load_dataset

In [7]:
from datasets import DatasetDict
common_voice = DatasetDict()

In [8]:
common_voice["train"] = load_dataset("mozilla-foundation/common_voice_11_0", "hi", split = "train+validation", use_auth_token = True)
common_voice["test"] = load_dataset("mozilla-foundation/common_voice_11_0", "hi", split = "test", use_auth_token = True)

Generating train split: 0 examples [00:00, ? examples/s]
Reading metadata...: 4361it [00:00, 106495.57it/s]
Generating validation split: 0 examples [00:00, ? examples/s]
Reading metadata...: 2179it [00:00, 125029.25it/s]
Generating test split: 0 examples [00:00, ? examples/s]
Reading metadata...: 2894it [00:00, 126736.51it/s]
Generating other split: 0 examples [00:00, ? examples/s]
Reading metadata...: 3328it [00:00, 87891.92it/s]
Generating invalidated split: 0 examples [00:00, ? examples/s]
Reading metadata...: 680it [00:00, 99623.69it/s]
Dataset common_voice_11_0 downloaded and prepared to /root/.cache/huggingface/datasets/mozilla-foundation___common_voice_11_0/hi/11.0.0/2c65b95d99ca879b1b1074ea197b65e0497848fd697fdb0582e0f6b75b6f4da0. Subsequent calls will reuse this data.

In [9]:
print(common_voice)

DatasetDict({
    train: Dataset({
        features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],
        num_rows: 6540
    })
    test: Dataset({
        features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],
        num_rows: 2894
    })
})


### ASR Pipeline

1. A feature extractor that pre-processes the raw audio inputs.
2. A model that performs the sequence-to-sequence mapping.
3. A tokenizer that post-processes the model outputs into text.

In [10]:
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")

In [11]:
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="Hindi", task="transcribe")

In [12]:
# Let's test it out -
input_str = common_voice["train"][0]["sentence"]
input_str

'हमने उसका जन्मदिन मनाया।'

In [13]:
labels = tokenizer(input_str).input_ids
labels

[50258,
 50276,
 50359,
 50363,
 44500,
 48521,
 35082,
 21981,
 8485,
 231,
 45938,
 41858,
 17937,
 8485,
 250,
 35082,
 27099,
 48521,
 3941,
 99,
 33279,
 35082,
 48449,
 35082,
 17937,
 48268,
 17937,
 8703,
 97,
 50257]

In [14]:
decoder_with_special = tokenizer.decode(labels, skip_special_tokens = False)
decoder_with_special

'<|startoftranscript|><|hi|><|transcribe|><|notimestamps|>हमने उसका जन्मदिन मनाया।<|endoftext|>'

In [15]:
decoded_str = tokenizer.decode(labels, skip_special_tokens=True)

In [16]:
decoded_str

'हमने उसका जन्मदिन मनाया।'

In [17]:
input_str == decoded_str

True
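The round trip above can also be traced by hand. The sketch below is only an illustration: the `SPECIAL_TOKENS` mapping is read off the two outputs above (the first four ids and the last id line up with the five special tokens in the decoded string), and `strip_special` is a hypothetical helper, not part of `transformers`.

```python
# Special-token ids inferred from the tokenizer outputs shown above
# (an assumption drawn from this notebook's output, not from the vocabulary file).
SPECIAL_TOKENS = {
    50258: "<|startoftranscript|>",
    50276: "<|hi|>",
    50359: "<|transcribe|>",
    50363: "<|notimestamps|>",
    50257: "<|endoftext|>",
}

# The label ids printed by the tokenizer cell above.
labels = [
    50258, 50276, 50359, 50363, 44500, 48521, 35082, 21981, 8485, 231,
    45938, 41858, 17937, 8485, 250, 35082, 27099, 48521, 3941, 99,
    33279, 35082, 48449, 35082, 17937, 48268, 17937, 8703, 97, 50257,
]

def strip_special(ids):
    """Hypothetical helper: keep only the ids that encode the sentence itself."""
    return [i for i in ids if i not in SPECIAL_TOKENS]

prefix = "".join(SPECIAL_TOKENS[i] for i in labels[:4])
text_ids = strip_special(labels)

print(prefix)
print(len(labels), len(text_ids))
```

The four-token prefix encodes the start marker, the language, the task, and the no-timestamps flag; `skip_special_tokens=True` drops these and the end-of-text marker before decoding.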

We will combine the feature extractor and the tokenizer by wrapping them in a single class, `WhisperProcessor`.

In [18]:
from transformers import WhisperProcessor

In [19]:
processor = WhisperProcessor.from_pretrained('openai/whisper-small', language = 'Hindi', task = 'transcribe')

### Prepare Data

In [20]:
print(common_voice['train'][0])

{'client_id': '0f018a99663f33afbb7d38aee281fb1afcfd07f9e7acd00383f604e1e17c38d6ed8adf1bd2ccbf927a52c5adefb8ac4b158ce27a7c2ed9581e71202eb302dfb3', 'path': '/root/.cache/huggingface/datasets/downloads/extracted/c090ee7e53af621f387a7ad442373c879defc9f799c91a68edc0f646c36f115f/common_voice_hi_26008353.mp3', 'audio': {'path': '/root/.cache/huggingface/datasets/downloads/extracted/c090ee7e53af621f387a7ad442373c879defc9f799c91a68edc0f646c36f115f/common_voice_hi_26008353.mp3', 'array': array([ 5.81611368e-26, -1.48634016e-25, -9.37040538e-26, ...,
        1.06425901e-07,  4.46416450e-08,  2.61450239e-09]), 'sampling_rate': 48000}, 'sentence': 'हमने उसका जन्मदिन मनाया।', 'up_votes': 2, 'down_votes': 0, 'age': '', 'gender': '', 'accent': '', 'locale': 'hi', 'segment': ''}


Note: each example has a 1-dimensional input audio array and a corresponding target transcription. Also observe that the sampling rate is 48 kHz, whereas the Whisper model expects audio sampled at 16 kHz, so we need to resample it.

In [21]:
from datasets import Audio

In [22]:
# cast_column() method does not change the audio in-place, but rather signals to datasets to resample audio samples on the fly the first time that they are loaded.

common_voice = common_voice.cast_column("audio", Audio(sampling_rate = 16000))

In [23]:
# Now, let's reload it -
print(common_voice["train"][0])

{'client_id': '0f018a99663f33afbb7d38aee281fb1afcfd07f9e7acd00383f604e1e17c38d6ed8adf1bd2ccbf927a52c5adefb8ac4b158ce27a7c2ed9581e71202eb302dfb3', 'path': '/root/.cache/huggingface/datasets/downloads/extracted/c090ee7e53af621f387a7ad442373c879defc9f799c91a68edc0f646c36f115f/common_voice_hi_26008353.mp3', 'audio': {'path': '/root/.cache/huggingface/datasets/downloads/extracted/c090ee7e53af621f387a7ad442373c879defc9f799c91a68edc0f646c36f115f/common_voice_hi_26008353.mp3', 'array': array([ 3.81639165e-17,  2.42861287e-17, -1.73472348e-17, ...,
       -1.30981789e-07,  2.63096808e-07,  4.77157300e-08]), 'sampling_rate': 16000}, 'sentence': 'हमने उसका जन्मदिन मनाया।', 'up_votes': 2, 'down_votes': 0, 'age': '', 'gender': '', 'accent': '', 'locale': 'hi', 'segment': ''}


We can see that the sampling rate is now 16 kHz, which means the array now holds roughly one amplitude value for every three in the original 48 kHz audio.
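The 3-to-1 relationship is plain arithmetic, sketched below. Both functions are hypothetical helpers written for this note. `naive_decimate` only illustrates the sample-count ratio: real resamplers, including the one `datasets` applies here, low-pass filter the signal first to avoid aliasing, so their output values differ from simply keeping every third sample.

```python
def resampled_length(n_samples, orig_sr, target_sr):
    """Approximate number of samples after resampling from orig_sr to target_sr."""
    return round(n_samples * target_sr / orig_sr)

def naive_decimate(samples, factor):
    """Keep every `factor`-th sample (no anti-aliasing filter; illustration only)."""
    return samples[::factor]

five_seconds_at_48k = 48_000 * 5
print(resampled_length(five_seconds_at_48k, 48_000, 16_000))  # samples in 5 s at 16 kHz
print(naive_decimate(list(range(9)), 3))
```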

In [24]:
def prepare_dataset(batch):
  # load and resample audio data from 48 to 16 KHz
  audio = batch["audio"]

  # compute the log-Mel input features from input audio array
  batch["input_features"] = feature_extractor(audio["array"], sampling_rate = audio["sampling_rate"]).input_features[0]

  # encode target text to label ids
  batch["labels"] = tokenizer(batch["sentence"]).input_ids
  return batch
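As a rough guide to what `input_features` looks like, the sketch below works out the shape of one log-Mel spectrogram. The constants are assumptions taken from the default `openai/whisper-small` preprocessor settings (30-second chunks, a hop of 160 samples, 80 Mel bins); check `feature_extractor` in your own session if you use a different checkpoint.

```python
SAMPLING_RATE = 16_000  # Hz, the rate we resampled to above
CHUNK_SECONDS = 30      # assumed: audio is padded or truncated to 30 s
HOP_LENGTH = 160        # assumed: samples between successive spectrogram frames
N_MELS = 80             # assumed: number of Mel frequency bins

n_samples = SAMPLING_RATE * CHUNK_SECONDS  # samples per padded clip
n_frames = n_samples // HOP_LENGTH         # spectrogram frames per clip
feature_shape = (N_MELS, n_frames)

print(n_samples, feature_shape)
```

Because every clip is padded or truncated to the same length, all examples share one feature shape, which is why the data collator later only needs to convert the audio features to tensors rather than pad them to different lengths.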

In [25]:
common_voice = common_voice.map(prepare_dataset, remove_columns = common_voice.column_names["train"], num_proc = 4)

Map (num_proc=4):   0%|          | 0/6540 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/2894 [00:00<?, ? examples/s]

### Training and Evaluation Pipeline

1. Define Data Collator : Takes pre-processed data and prepares PyTorch tensors ready for the model.

2. Evaluation Metrics : Evaluate the model using WER (Word Error Rate) metric.

3. Load the pre-trained model checkpoint and configure it for training.

4. Training arguments for the trainer.


#### Data Collator

In [26]:
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union

In [27]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
  processor: Any

  def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
    # from audio inputs return the torch tensors
    input_features = [{"input_features": feature["input_features"]} for feature in features]
    batch = self.processor.feature_extractor.pad(input_features, return_tensors = "pt")

    # get the tokenized label sequences
    label_features = [{"input_ids": feature["labels"]} for feature in features]
    # pad the labels to max length
    labels_batch = self.processor.tokenizer.pad(label_features, return_tensors = "pt")

    # replace padding with -100 so the loss ignores those positions
    labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

    # if a BOS token was prepended during tokenization, cut it here,
    # since it is added again later during training
    if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
      labels = labels[:, 1:]

    batch["labels"] = labels

    return batch

In [28]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor = processor)
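The label handling in the collator can be followed on plain lists. This is a simplified sketch, not the collator itself: `pad_and_mask` is a hypothetical helper, the short ids (11, 12, 13, 21) are made up, and 50258 / 50257 are the start and end ids seen in the tokenizer output earlier.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default

def pad_and_mask(label_seqs, bos_id):
    """Pad label sequences to equal length with -100, then drop a shared leading BOS."""
    max_len = max(len(seq) for seq in label_seqs)
    padded = [list(seq) + [IGNORE_INDEX] * (max_len - len(seq)) for seq in label_seqs]
    # Only strip the first column when every row starts with the BOS id.
    if all(row[0] == bos_id for row in padded):
        padded = [row[1:] for row in padded]
    return padded

batch = pad_and_mask(
    [[50258, 11, 12, 13, 50257], [50258, 21, 50257]],
    bos_id=50258,
)
print(batch)
```

The real end-of-text token in each row is kept as a training target, and only the positions added by padding become -100. The collator gets the same effect by masking on `attention_mask` rather than on the pad token id.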

#### Evaluation Metrics

In [29]:
import evaluate

metric = evaluate.load("wer")

In [30]:
def compute_metrics(pred):
  pred_ids = pred.predictions
  label_ids = pred.label_ids

  # replace -100 with the pad_token_id
  label_ids[label_ids == -100] = tokenizer.pad_token_id

  pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens = True)
  label_str = tokenizer.batch_decode(label_ids, skip_special_tokens = True)

  wer = 100 * metric.compute(predictions = pred_str, references = label_str)

  return {"wer": wer}
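For intuition about the number `compute_metrics` reports, here is a small standard-library sketch of word error rate for a single sentence pair. `word_error_rate` is a hypothetical helper written for this note. It assumes whitespace tokenization and a non-empty reference, and it is not the `evaluate` / `jiwer` implementation, which scores a whole set of predictions at once.

```python
def word_error_rate(reference, prediction):
    """WER in percent: word-level edit distance divided by the reference length."""
    ref, hyp = reference.split(), prediction.split()
    # prev[j] holds the edit distance between the ref words seen so far and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        cur = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            substitution = prev[j - 1] + (0 if ref_word == hyp_word else 1)
            deletion = prev[j] + 1       # reference word missing from prediction
            insertion = cur[j - 1] + 1   # extra word in prediction
            cur.append(min(substitution, deletion, insertion))
        prev = cur
    return 100 * prev[-1] / len(ref)

# One substitution ("sat" -> "sit") and one deletion ("down") over 4 reference words.
print(word_error_rate("the cat sat down", "the cat sit"))
```

Since insertions count as errors, WER can exceed 100 when the prediction is much longer than the reference.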

#### Load the Pre-trained Checkpoint

In [31]:
from transformers import WhisperForConditionalGeneration

In [32]:
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

In [33]:
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

#### Training Arguments

In [34]:
from transformers import Seq2SeqTrainingArguments

In [61]:
training_args = Seq2SeqTrainingArguments(
    output_dir = "./whisper-small-finetuned-hi-commonvoice",
    per_device_train_batch_size = 16,
    gradient_accumulation_steps = 1, # increase by 2x for every 2x decrease in batch size
    learning_rate = 1e-5,
    warmup_steps = 500,
    max_steps = 4000,
    gradient_checkpointing = True,
    fp16 = True,
    evaluation_strategy = "steps",
    per_device_eval_batch_size = 8,
    predict_with_generate = True,
    generation_max_length = 225,
    save_steps = 1000,
    eval_steps = 1000,
    logging_steps = 25,
    report_to = ["tensorboard"],
    load_best_model_at_end = True,
    metric_for_best_model = "wer",
    greater_is_better = False,
    push_to_hub = True
)
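The comment on `gradient_accumulation_steps` and the choice of `max_steps` can be checked with a little arithmetic. This sketch assumes a single GPU and a dataloader that keeps the final partial batch; `accumulation_steps_for` is a hypothetical helper, and 6540 is the training-set size printed earlier.

```python
import math

num_train_examples = 6540         # rows in common_voice["train"]
per_device_train_batch_size = 16
gradient_accumulation_steps = 1
num_devices = 1                   # assumption: one GPU
max_steps = 4000

effective_batch_size = (
    per_device_train_batch_size * gradient_accumulation_steps * num_devices
)
steps_per_epoch = math.ceil(num_train_examples / effective_batch_size)
epochs = max_steps / steps_per_epoch

def accumulation_steps_for(batch_size, target_effective=16):
    """Accumulation steps that keep the effective batch size at target_effective."""
    return target_effective // batch_size

print(effective_batch_size, steps_per_epoch, round(epochs, 2))
print(accumulation_steps_for(8), accumulation_steps_for(4))
```

Under these assumptions 4000 steps is just under 10 passes over the training data. Halving the per-device batch size to fit a smaller GPU means doubling the accumulation steps to keep the same effective batch.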

In [62]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args = training_args,
    model = model,
    train_dataset = common_voice["train"],
    eval_dataset = common_voice["test"],
    data_collator = data_collator,
    compute_metrics = compute_metrics,
    tokenizer = processor.feature_extractor,
)

Cloning https://huggingface.co/suvrobaner/whisper-small-finetuned-hi-commonvoice into local empty directory.


In [37]:
trainer.train()

`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...


| Step | Training Loss | Validation Loss | Wer |
|------|---------------|-----------------|-----|
| 1000 | 0.0895 | 0.29268 | 35.054601 |
| 2000 | 0.0221 | 0.345999 | 33.911792 |
| 3000 | 0.0022 | 0.409631 | 33.535088 |
| 4000 | 0.0005 | 0.434476 | 33.14992 |


TrainOutput(global_step=4000, training_loss=0.10362103581079282, metrics={'train_runtime': 13906.3588, 'train_samples_per_second': 4.602, 'train_steps_per_second': 0.288, 'total_flos': 1.845907654606848e+19, 'train_loss': 0.10362103581079282, 'epoch': 9.78})

### Pushing the checkpoint to the Hub

In [68]:
kwargs = {
    "dataset_tags": "mozilla-foundation/common_voice_11_0",
    "dataset": "Common Voice 11.0",
    "dataset_args": "config: hi, split: test",
    "language": "hi",
    "model_name": "Whisper Small Hi - Suvro Banerjee",
    "finetuned_from": "openai/whisper-small",
    "tasks": "automatic-speech-recognition",
    "tags": "hf-asr-leaderboard",
}

In [73]:
trainer.push_to_hub(**kwargs)
tokenizer.push_to_hub("suvrobaner/whisper-small-finetuned-hi-commonvoice")

CommitInfo(commit_url='https://huggingface.co/suvrobaner/whisper-small-finetuned-hi-commonvoice/commit/04a8d7ad59a28a636c04707d741b6b7e9d2c8393', commit_message='Upload tokenizer', commit_description='', oid='04a8d7ad59a28a636c04707d741b6b7e9d2c8393', pr_url=None, pr_revision=None, pr_num=None)

### Load the model from the Hugging Face Hub

In [74]:
from transformers import pipeline

In [75]:
import gradio as gr

In [76]:
pipe = pipeline(model = "suvrobaner/whisper-small-finetuned-hi-commonvoice")

In [77]:
def transcribe(audio):
    text = pipe(audio)["text"]
    return text

iface = gr.Interface(
    fn=transcribe, 
    inputs=gr.Audio(source="microphone", type="filepath"), 
    outputs="text",
    title="Whisper Small Hindi",
    description="Realtime demo for Hindi speech recognition using a fine-tuned Whisper small model.",
)

iface.launch()

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Note: opening Chrome Inspector may crash demo inside Colab notebooks.

To create a public link, set `share=True` in `launch()`.


More ASR examples are available here: https://github.com/huggingface/transformers/tree/main/examples/pytorch/speech-recognition