<a href="https://colab.research.google.com/github/sanchit-gandhi/codesnippets/blob/main/fine_tune_whisper_streaming_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-Tune Whisper With 🤗 Transformers and Streaming Mode

## Introduction

## Prepare Environment

First of all, let's try to secure a decent GPU for our Colab! Unfortunately, it's becoming much harder to get access to a good GPU with the free version of Google Colab. However, with Google Colab Pro / Pro+ one should have no issues in being allocated a V100 or P100 GPU.

To get a GPU, click _Runtime_ -> _Change runtime type_, then change _Hardware accelerator_ from _None_ to _GPU_.

We can verify that we've been assigned a GPU and view its specifications:

In [1]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Fri Sep 27 21:10:52 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.29.06              Driver Version: 545.29.06    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        On  | 00000000:01:00.0 Off |                  Off |
| 31%   28C    P8              15W / 450W |      3MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce RTX 4090        On  | 00000000:24:00.0 Off |  

Next, we need to update the Unix package `ffmpeg` to version 4:

In [3]:
!apt update
!apt install -y ffmpeg

Hit:1 http://archive.ubuntu.com/ubuntu jammy InRelease
Get:2 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]      [0m
Get:3 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [128 kB]        [0m
Hit:4 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease
Hit:5 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease   
Hit:6 http://archive.ubuntu.com/ubuntu jammy-backports InRelease
Fetched 257 kB in 1s (333 kB/s)33m[33m[33m[33m
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
96 packages can be upgraded. Run 'apt list --upgradable' to see them.
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
ffmpeg is already the newest version (7:4.4.2-0ubuntu0.22.04.1).
0 upgraded, 0 newly installed, 0 to remove and 96 not upgraded.


We'll employ several popular Python packages to fine-tune the Whisper model.
We'll use `datasets` to download and prepare our training data and 
`transformers` to load and train our Whisper model. We'll also require
the `soundfile` package to pre-process audio files, `evaluate` and `jiwer` to
assess the performance of our model. Finally, we'll
use `gradio` to build a flashy demo of our fine-tuned model.

In [4]:
!pip install git+https://github.com/huggingface/datasets
!pip install git+https://github.com/huggingface/transformers
!pip install transformers[torch]
!pip install librosa
!pip install evaluate>=0.3.0
!pip install jiwer
!pip install gradio
!pip install more-itertools

Collecting git+https://github.com/huggingface/datasets
  Cloning https://github.com/huggingface/datasets to /tmp/pip-req-build-9mrai7xa
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/datasets /tmp/pip-req-build-9mrai7xa
  Resolved https://github.com/huggingface/datasets to commit 0d4c4dfaf0190669137e612bec93889bf0e7c1ff
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
Collecting git+https://github.com/huggingface/transformers
  Cloning https://github.com/huggingface/transformers to /tmp/pip-req-build-_b4g4fn1
  Running command git clone --filter=blob:none --quiet https://gith

Linking the notebook to the Hugging Face Hub is straightforward - it simply requires entering your 
Hub authentication token when prompted. Find your Hub authentication token [here](https://huggingface.co/settings/tokens):

In [5]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

## Load Dataset with Streaming

This is where the magic happens! We'll first write a wrapper function around 🤗 Datasets `load_dataset` method. This function downloads the required splits using streaming mode by forcing `streaming=True` in the `load_dataset` method. Multiple splits can be combined (interleaved) by concatenating them with the "+" symbol when specifying the split name, e.g. `split=train+validation` will return a single split with the training and validation splits interleaved together. The function has the same arguments and key-word arguments as 🤗 Datasets `load_dataset` method, so we can use it in exactly the same way!

In [2]:
from datasets import interleave_datasets, load_dataset

def load_streaming_dataset(dataset_name, split, **kwargs):
    if "+" in split:
        # load multiple splits separated by the `+` symbol *with* streaming mode
        dataset_splits = [load_dataset(dataset_name, split=split_name, streaming=True, **kwargs) for split_name in split.split("+")]
        # interleave multiple splits to form one dataset
        interleaved_dataset = interleave_datasets(dataset_splits)
        return interleaved_dataset
    else:
        # load a single split *with* streaming mode
        dataset = load_dataset(dataset_name, split=split, streaming=True, **kwargs).cast_column("audio", Audio(decode=False))
        return dataset

We'll train our system on the Spanish split of [Common Voice 11](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0). We can see how much training data we have by viewing the [language page](https://commonvoice.mozilla.org/en/datasets) on the Common Voice website. The Spanish split has over 400 hours of labelled training data - that's enourmous! More than we could ever fit on a Google Colab or a standard workstation. But with streaming mode, we'll only download data as and when we need it, making training on this dataset possible!

Since Spanish is relatively high-resource, we'll only use the `train` split for training and the `test` split for evaluation. If you're training on a low-resource language, such as the Hindi split of Common Voice 11, it's worth combining the `train` and `validation` splits to give a larger training set. You can achieve this by setting: `split="train+validation"` for the training split.

If you're using a gated dataset, like Common Voice 11, ensure you have accepted the terms of use on the Hugging Face Hub: [mozilla-foundation/common_voice_11_0](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0). Once you have accepted the terms, you will have full access to the dataset and be able to load the data locally.

In [6]:
from datasets import IterableDatasetDict, Audio

raw_datasets = IterableDatasetDict()

raw_datasets["train"] = load_streaming_dataset("voa-engines/curated_dataset_v1", split="train")  # set split="train+validation" for low-resource
# raw_datasets["test"] = load_streaming_dataset("mozilla-foundation/common_voice_17_0", "pt", split="test")

NameError: name 'Audio' is not defined

## Prepare Processor and Pre-Process Data

The ASR pipeline can be de-composed into three stages: 
1) A feature extractor which pre-processes the raw audio-inputs
2) The model which performs the sequence-to-sequence mapping 
3) A tokenizer which post-processes the model outputs to text format

In 🤗 Transformers, the Whisper model has an associated feature extractor and tokenizer, 
called [WhisperFeatureExtractor](https://huggingface.co/docs/transformers/main/model_doc/whisper#transformers.WhisperFeatureExtractor)
and [WhisperTokenizer](https://huggingface.co/docs/transformers/main/model_doc/whisper#transformers.WhisperTokenizer) 
respectively. To make our lives simple, these two objects are wrapped under a single class, called the [WhisperProcessor](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperProcessor). We can call the WhisperProcessor to perform 
both the audio pre-processing and the text token post-processing. In doing so, we only need to keep track of two objects during training: 
the `processor` and the `model`.

If using a multilingual checkpoint, you should set the `"language"` to your target text language. You should also set the task to `"transcribe"` for speech recogntition and `"translate"` for speech translation. These arguments modify the behaviour of the tokenizer - they should be set correctly to ensure the target labels are encoded properly. These arguments should be omitted for English-only fine-tuning.

In [10]:
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("pierreguillou/whisper-medium-portuguese", language="Portuguese", task="transcribe")

### Pre-Process Data

We'll define our pre-processing strategy. We advise that you **do not** lower-case the transcriptions or remove punctuation unless mixing different datasets. This will enable you to fine-tune Whisper models that can predict punctuation and casing. Later, you will see how we can evaluate the predictions without punctuation or casing, so that the models benefit from the WER improvement obtained by normalising the transcriptions while still predicting fully formatted transcriptions.

In [11]:
from transformers.models.whisper.english_normalizer import BasicTextNormalizer

do_lower_case = False
do_remove_punctuation = False

normalizer = BasicTextNormalizer()

Now we can write a function to prepare our data ready for the model:
1. We load and resample the audio data by calling `batch["audio"]`. As explained above, 🤗 Datasets performs any necessary resampling operations on the fly.
2. We use the feature extractor to compute the log-Mel spectrogram input features from our 1-dimensional audio array.
3. We perform any optional pre-processing (lower-case or remove punctuation).
4. We encode the transcriptions to label ids through the use of the tokenizer.

In [12]:
import torchaudio
from io import BytesIO

def is_valid_example(example):
    # Check if 'audio' and 'audio']['bytes'] exist
    if 'audio' not in example or 'bytes' not in example['audio']:
        return False
    audio_bytes = example['audio']['bytes']
    if not audio_bytes:
        return False

    # Try to load the audio bytes to check if the audio is valid
    try:
        audio_file = BytesIO(audio_bytes)
        torchaudio.set_audio_backend("ffmpeg")  # Ensure FFmpeg is installed
        # Attempt to load the audio
        waveform, sample_rate = torchaudio.load(audio_file, format='mp3')
    except Exception as e:
        print(f"Skipping corrupted audio data: {e}")
        return False

    # Check if 'transcription' exists and is a valid non-empty string
    transcription = example.get("transcription")
    if not transcription or not isinstance(transcription, str):
        return False
    transcription = transcription.strip()
    if not transcription:
        return False

    # If all checks pass, the example is valid
    return True


In [13]:
import base64
from io import BytesIO
import torchaudio
import torch
import re

def prepare_dataset(batch):
    try:
        # Load MP3 audio bytes
        audio_bytes = batch['audio']['bytes']
        
        # Check if audio_bytes is empty or None
        if not audio_bytes:
            raise ValueError("Audio bytes are missing or null")

        # Wrap the audio bytes in a BytesIO object
        audio_file = BytesIO(audio_bytes)

        # Use torchaudio to load the MP3 audio
        try:
            # Ensure that the appropriate backend is used
            # You can use "sox_io" or "ffmpeg" depending on your installation
            torchaudio.set_audio_backend("ffmpeg")  # or "ffmpeg" if FFmpeg is installed
            incoming_waveform, sample_rate = torchaudio.load(audio_file, format='mp3')
        except Exception as e:
            raise ValueError(f"Failed to load MP3 audio with torchaudio: {e}")

        # Optional resampling to 16kHz if required
        target_sample_rate = 16000
        if sample_rate != target_sample_rate:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
            incoming_waveform = resampler(incoming_waveform)

        # Compute input features using your processor's feature extractor
        try:
            batch["input_features"] = processor.feature_extractor(
                incoming_waveform.squeeze().numpy(), sampling_rate=target_sample_rate
            ).input_features[0]
        except Exception as e:
            raise ValueError(f"Error extracting features: {e}")

        # Compute the input length in seconds
        batch["input_length"] = incoming_waveform.size(1) / target_sample_rate
        
        # Process transcription and labels
        # Try both 'transcription' and 'sentence' keys if applicable
        transcription = batch.get("transcription")
        # If 'transcription' is not available, you can try 'sentence' or another key
        # transcription = transcription or batch.get("sentence")

        # Check if transcription is missing or null
        if not transcription or not isinstance(transcription, str):
            raise ValueError("Transcription is missing or null")

        # Strip leading and trailing whitespace
        transcription = transcription.strip()
        if not transcription:
            raise ValueError("Transcription is empty after stripping")

        if do_lower_case:
            transcription = transcription.lower()

        if do_remove_punctuation:
            transcription = re.sub(punctuation_to_remove_regex, " ", transcription).strip()

        # Encode target text to label ids
        try:
            batch["labels"] = processor.tokenizer(transcription).input_ids
        except Exception as e:
            raise ValueError(f"Error tokenizing transcription: {e}")

        return batch
    
    except ValueError as ve:
        print(f"Skipping corrupted data: {ve}")
        return None  # Returning None will exclude this batch from the final dataset


We can apply the data preparation function to all of our training examples using 🤗 Datasets' `.map` method. We'll remove all of the columns from the raw training data, leaving just the `input_features` and `labels` defined in the `prepare_dataset` function:

In [12]:
from datasets import Audio, load_dataset
dataset_2 = load_dataset("voa-engines/curated_dataset_v1", split="train").cast_column("audio", Audio(decode=False))

README.md:   0%|          | 0.00/369 [00:00<?, ?B/s]

train-00000-of-00004.parquet:   0%|          | 0.00/440M [00:00<?, ?B/s]

train-00001-of-00004.parquet:   0%|          | 0.00/446M [00:00<?, ?B/s]

train-00002-of-00004.parquet:   0%|          | 0.00/446M [00:00<?, ?B/s]

train-00003-of-00004.parquet:   0%|          | 0.00/424M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7311 [00:00<?, ? examples/s]

In [13]:
filtered_dataset = dataset_2.filter(is_valid_example)

Filter:   0%|          | 0/7311 [00:00<?, ? examples/s]

  torchaudio.set_audio_backend("ffmpeg")  # Ensure FFmpeg is installed


We can now define how we shuffle the data in the train split. The size of the subset we load is set by the variable `buffer_size`. You can increase or decrease this depending on your memory constraints. In this example, the `buffer_size` is set to 500, meaning 500 samples are loaded before shuffling across the subset. The larger we set this value, the closer to True offline shuffling. The `seed` is set for reproducibility:

In [14]:
vectorized_datasets = filtered_dataset.map(
    prepare_dataset,
    remove_columns=['audio', 'transcription'],  # Remove unnecessary columns after processing
    batched=False
).with_format("torch")

Map:   0%|          | 0/7088 [00:00<?, ? examples/s]

  torchaudio.set_audio_backend("ffmpeg")  # or "ffmpeg" if FFmpeg is installed


In [14]:
# Assuming vectorized_datasets is a Hugging Face Dataset object
# train_test_data = vectorized_datasets.train_test_split(test_size=0.1)
from datasets import load_dataset
train_dataset = load_dataset("voa-engines/features_dataset_v1", split="train")
eval_dataset = load_dataset("voa-engines/features_dataset_v1", split="validation")

Using the latest cached version of the dataset since voa-engines/features_dataset_v1 couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /root/.cache/huggingface/datasets/voa-engines___features_dataset_v1/default/0.0.0/d631ee82eb2afe985ca3aeb3cbcf90524326b458 (last modified on Fri Sep 27 20:51:45 2024).
Using the latest cached version of the dataset since voa-engines/features_dataset_v1 couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /root/.cache/huggingface/datasets/voa-engines___features_dataset_v1/default/0.0.0/d631ee82eb2afe985ca3aeb3cbcf90524326b458 (last modified on Fri Sep 27 20:51:45 2024).


In [15]:
train_dataset = train_dataset.with_format(None)
iterable_dataset = train_dataset.to_iterable_dataset(num_shards=128)
train_dataset = iterable_dataset.shuffle(seed=42, buffer_size=400)

## Training and Evaluation

Now that we've prepared our data, we're ready to dive into the training pipeline. 
The [🤗 Trainer](https://huggingface.co/transformers/master/main_classes/trainer.html?highlight=trainer)
will do much of the heavy lifting for us. All we have to do is:

- Define a data collator: the data collator takes our pre-processed data and prepares PyTorch tensors ready for the model.

- Evaluation metrics: during evaluation, we want to evaluate the model using the [word error rate (WER)](https://huggingface.co/metrics/wer) metric. We need to define a `compute_metrics` function that handles this computation.

- Load a pre-trained checkpoint: we need to load a pre-trained checkpoint and configure it correctly for training.

- Define the training configuration: this will be used by the 🤗 Trainer to define the training schedule.

### Define a Data Collator

The data collator for a sequence-to-sequence speech model is unique in the sense that it 
treats the `input_features` and `labels` independently: the  `input_features` must be 
handled by the feature extractor and the `labels` by the tokenizer.

The `input_features` are already padded to 30s and converted to a log-Mel spectrogram 
of fixed dimension by action of the feature extractor, so all we have to do is convert the `input_features`
to batched PyTorch tensors. We do this using the feature extractor's `.pad` method with `return_tensors=pt`.

The `labels` on the other hand are un-padded. We first pad the sequences
to the maximum length in the batch using the tokenizer's `.pad` method. The padding tokens 
are then replaced by `-100` so that these tokens are **not** taken into account when 
computing the loss. We then cut the BOS token from the start of the label sequence as we 
append it later during training.

We can leverage the `WhisperProcessor` we defined earlier to perform both the 
feature extractor and the tokenizer operations:

In [16]:
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

Let's initialise the data collator we've just defined:

In [17]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

### Evaluation Metrics

We'll use the word error rate (WER) metric, the 'de-facto' metric for assessing 
ASR systems. For more information, refer to the WER [docs](https://huggingface.co/metrics/wer). We'll load the WER metric from 🤗 Evaluate:

In [18]:
import evaluate

metric = evaluate.load("wer")

We then simply have to define a function that takes our model 
predictions and returns the WER metric. This function, called
`compute_metrics`, first replaces `-100` with the `pad_token_id`
in the `label_ids` (undoing the step we applied in the 
data collator to ignore padded tokens correctly in the loss).
It then decodes the predicted and label ids to strings. Finally,
it computes the WER between the predictions and reference labels. 
Here, we have the option of evaluating with the 'normalised' transcriptions 
and predictions. We recommend you set this to `True` to benefit from the WER 
improvement obtained by normalising the transcriptions.

In [19]:
# evaluate with the 'normalised' WER
do_normalize_eval = True

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    if do_normalize_eval:
        pred_str = [normalizer(pred) for pred in pred_str]
        label_str = [normalizer(label) for label in label_str]
        # filtering step to only evaluate the samples that correspond to non-zero references:
        pred_str = [pred_str[i] for i in range(len(pred_str)) if len(label_str[i]) > 0]
        label_str = [label_str[i] for i in range(len(label_str)) if len(label_str[i]) > 0]
    
    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {"wer":wer}

### Load a Pre-Trained Checkpoint

Now let's load the pre-trained Whisper `small` checkpoint. Again, this 
is trivial through use of 🤗 Transformers!

In [20]:
!pip install accelerate

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [21]:
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("pierreguillou/whisper-medium-portuguese")

Override generation arguments - no tokens are forced as decoder outputs (see [`forced_decoder_ids`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.forced_decoder_ids)), no tokens are suppressed during generation (see [`suppress_tokens`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.suppress_tokens)). Set `use_cache` to False since we're using gradient checkpointing, and the two are incompatible:

In [22]:
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
model.config.use_cache = False

In [33]:
import os

os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"

In [36]:
from transformers import (
    WhisperForConditionalGeneration,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrainerCallback,
    set_seed,
)

# Define the training arguments (as above)
training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-medium-voa-v1",
    overwrite_output_dir=True,
    num_train_epochs=3,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=8,
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=5000,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    eval_steps=1000,
    save_steps=1000,
    logging_steps=25,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    seed=42,
    report_to=["tensorboard"],
    push_to_hub=True,
    dataloader_num_workers=4,
    disable_tqdm=False,
    remove_unused_columns=False,
)



In [40]:
from transformers import WhisperForConditionalGeneration, WhisperTokenizer
from accelerate import Accelerator
from torch.utils.data import DataLoader
import torch
import torch.nn.functional as F

# Initialize Accelerator
accelerator = Accelerator()

train_dataloader = DataLoader(
    train_dataset,
    batch_size=training_args.per_device_train_batch_size,
    shuffle=False,
    collate_fn=data_collator,
)
eval_dataloader = DataLoader(
    eval_dataset,
    batch_size=training_args.per_device_eval_batch_size,
    shuffle=False,
    collate_fn=data_collator,
)

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate)

# Prepare everything with accelerator
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

In [41]:
# Training loop
from tqdm.auto import tqdm

num_epochs = training_args.num_train_epochs
global_step = 0

for epoch in range(num_epochs):
    model.train()
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch}", disable=not accelerator.is_local_main_process)
    for batch in progress_bar:
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        global_step += 1

        if global_step % training_args.logging_steps == 0:
            progress_bar.set_postfix({"loss": loss.item()})

        if training_args.evaluation_strategy == "steps" and global_step % training_args.eval_steps == 0:
            # Evaluation code here
            model.eval()
            eval_loss = 0
            for eval_batch in eval_dataloader:
                with torch.no_grad():
                    outputs = model(**eval_batch)
                    eval_loss += outputs.loss.item()
            eval_loss = eval_loss / len(eval_dataloader)
            print(f"Validation Loss: {eval_loss}")
            model.train()

        if global_step >= training_args.max_steps:
            break

    if training_args.evaluation_strategy == "epoch":
        # Evaluation code here
        pass

    if global_step >= training_args.max_steps:
        break

Epoch 0: 0it [00:00, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 24.00 MiB. GPU 0 has a total capacity of 23.65 GiB of which 5.06 MiB is free. Process 1667366 has 23.63 GiB memory in use. Of the allocated memory 22.97 GiB is allocated by PyTorch, and 189.38 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# Save the model
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(training_args.output_dir, save_function=accelerator.save)
tokenizer.save_pretrained(training_args.output_dir)

### Define the Training Configuration

In the final step, we define all the parameters related to training. Here, you can set the `max_steps` to train for longer. For more detail on the training arguments, refer to the Seq2SeqTrainingArguments [docs](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments).

In [31]:
!pip install tensorboard

Collecting tensorboard
  Downloading tensorboard-2.18.0-py3-none-any.whl.metadata (1.6 kB)
Collecting absl-py>=0.4 (from tensorboard)
  Downloading absl_py-2.1.0-py3-none-any.whl.metadata (2.3 kB)
Collecting grpcio>=1.48.2 (from tensorboard)
  Downloading grpcio-1.66.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.9 kB)
Collecting markdown>=2.6.8 (from tensorboard)
  Downloading Markdown-3.7-py3-none-any.whl.metadata (7.0 kB)
Collecting protobuf!=4.24.0,>=3.19.6 (from tensorboard)
  Downloading protobuf-5.28.2-cp38-abi3-manylinux2014_x86_64.whl.metadata (592 bytes)
Collecting tensorboard-data-server<0.8.0,>=0.7.0 (from tensorboard)
  Downloading tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl.metadata (1.1 kB)
Collecting werkzeug>=1.0.1 (from tensorboard)
  Downloading werkzeug-3.0.4-py3-none-any.whl.metadata (3.7 kB)
Downloading tensorboard-2.18.0-py3-none-any.whl (5.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.5/5.5 M

In [43]:
import os
os.environ["NCCL_P2P_DISABLE"] = "0"
os.environ["NCCL_IB_DISABLE"] = "0"


In [35]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-medium-voa-v1",  # your repo name
    per_device_train_batch_size=64,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=5000,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=True,
)



**Note**: if one does not want to upload the model checkpoints to the Hub, 
set `push_to_hub=False`.

We then define a custom [Callback](https://huggingface.co/docs/transformers/main_classes/callback) that is called by the 🤗 Trainer on the end of each epoch. The Callback reinitialises and reshuffles the streaming dataset at the beginning of each new epoch - this gives different shuffling across our subsets for every epoch.

In [36]:
from transformers import TrainerCallback
from transformers.trainer_pt_utils import IterableDatasetShard
from torch.utils.data import IterableDataset

# trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch
class ShuffleCallback(TrainerCallback):
    def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
        if isinstance(train_dataloader.dataset, IterableDatasetShard):
            pass  # set_epoch() is handled by the Trainer
        elif isinstance(train_dataloader.dataset, IterableDataset):
            train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)

We can forward the training arguments to the 🤗 Trainer along with our model,
dataset, data collator, `compute_metrics` function and custom callback:

In [37]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor,
    callbacks=[ShuffleCallback()],
)

max_steps is given, it will override any value given in num_train_epochs


We'll save the model and processor to the output directory before training:

In [38]:
model.save_pretrained(training_args.output_dir)
processor.save_pretrained(training_args.output_dir)



[]

### Training

Training will take approximately 5-10 hours depending on the GPU
allocated to this Google Colab. If using this Google Colab directly to 
fine-tune a Whisper model, you should make sure that training isn't 
interrupted due to inactivity. A simple workaround to prevent this is 
to paste the following code into the console of this tab (_right mouse click_ 
-> _inspect_ -> _Console tab_ -> _insert code_).

```javascript
function ConnectButton(){
    console.log("Connect pushed"); 
    document.querySelector("#top-toolbar > colab-connect-button").shadowRoot.querySelector("#connect").click() 
}
setInterval(ConnectButton, 60000);
```

The peak GPU memory for the given training configuration is approximately 36GB. 
Depending on your GPU, it is possible that you will encounter a CUDA `"out-of-memory"` error when you launch training. 
In this case, you can reduce the `per_device_train_batch_size` incrementally by factors of 2 
and employ [`gradient_accumulation_steps`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments.gradient_accumulation_steps)
to compensate.

To launch training, simply execute:

In [39]:
trainer.train()

KeyboardInterrupt: 

(note that training may take some time to commence as we load the first training data samples with streaming mode)

We can label our checkpoint with the `whisper-event` tag on push by setting the appropriate key-word arguments (kwargs):

In [None]:
kwargs = {
    "dataset_tags": "voa-engines/features_dataset_v1",
    "dataset": "Common Voice 11.0",  # a 'pretty' name for the training dataset
    "language": "pt",
    "model_name": "Whisper Medium Portuguese - Voa Health",  # a 'pretty' name for your model
    "finetuned_from": "pierreguillou/whisper-medium-portuguese",
    "tasks": "automatic-speech-recognition",
    "tags": "whisper-event",
}

The training results can now be uploaded to the Hub. To do so, execute the `push_to_hub` command: