Use this notebook to fine-tune and evaluate a Whisper model.

In [1]:
!pip install --upgrade pip
!pip install --upgrade datasets[audio] transformers accelerate evaluate jiwer tensorboard huggingface_hub

Collecting pip
  Downloading pip-25.0.1-py3-none-any.whl.metadata (3.7 kB)
Downloading pip-25.0.1-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m12.5 MB/s[0m eta [36m0:00:00[0m [36m0:00:01[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 23.3.1
    Uninstalling pip-23.3.1:
      Successfully uninstalled pip-23.3.1
Successfully installed pip-25.0.1
[0mCollecting transformers
  Downloading transformers-4.51.3-py3-none-any.whl.metadata (38 kB)
Collecting accelerate
  Downloading accelerate-1.6.0-py3-none-any.whl.metadata (19 kB)
Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting jiwer
  Downloading jiwer-3.1.0-py3-none-any.whl.metadata (2.6 kB)
Collecting tensorboard
  Downloading tensorboard-2.19.0-py3-none-any.whl.metadata (1.8 kB)
Collecting huggingface_hub
  Downloading huggingface_hub-0.30.2-py3-none-any.whl.metadat

In [None]:
import os
from getpass import getpass

from huggingface_hub import login

# Keep the secret out of the notebook: take the token from the HF_TOKEN
# environment variable, and fall back to a hidden interactive prompt when the
# variable is unset or empty.
TOKEN = os.environ.get("HF_TOKEN") or getpass("Hugging Face token: ")
login(token=TOKEN)

In [None]:
from datasets import load_from_disk, load_dataset
import pandas as pd
import random
import numpy as np
import torch

from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor, WhisperForConditionalGeneration
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

from dataclasses import dataclass
from typing import Any, Dict, List, Union
import evaluate

In [3]:
def set_seed(seed_value=42):
    """Seed the Python, NumPy and PyTorch (CPU and all-CUDA-device) generators.

    Args:
        seed_value: Seed handed unchanged to every generator (default 42).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed_value)


# Seed once with the default value before any data or model work starts.
set_seed()

In [None]:
import os

# Directory holding the parquet splits (relative to the working directory).
dataset_dir = "dataset/audio"

# Load both splits in one call; each split name maps to "<split>.parquet".
ds = load_dataset(
    "parquet",
    data_files={
        split: os.path.join(dataset_dir, f"{split}.parquet")
        for split in ("train", "test")
    },
)

In [6]:
# Pull the two splits out of the loaded dataset dictionary.
dataset_train, dataset_test = ds["train"], ds["test"]

In [7]:
# Feature extractor matching the checkpoint that is fine-tuned below.
feature_extractor = WhisperFeatureExtractor.from_pretrained(
    "distil-whisper/distil-large-v3.5"
)

In [8]:
# Tokenizer for the same checkpoint, configured for the transcription task.
tokenizer = WhisperTokenizer.from_pretrained(
    "distil-whisper/distil-large-v3.5",
    task="transcribe",
)

In [9]:
# Sanity check: encode the first training transcript and decode it back, both
# with and without special tokens, to confirm the text survives the round trip.
input_str = dataset_train["text"][0]
labels = tokenizer(input_str).input_ids
decoded_with_special, decoded_str = (
    tokenizer.decode(labels, skip_special_tokens=skip)
    for skip in (False, True)
)

print(
    f"Input:                 {input_str}",
    f"Decoded w/ special:    {decoded_with_special}",
    f"Decoded w/out special: {decoded_str}",
    f"Are equal:             {input_str == decoded_str}",
    sep="\n",
)

Input:                 Tënk sama kàrt.
Decoded w/ special:    <|startoftranscript|><|transcribe|><|notimestamps|>Tënk sama kàrt.<|endoftext|>
Decoded w/out special: Tënk sama kàrt.
Are equal:             True


In [10]:
labels

[50258, 50360, 50364, 51, 15101, 77, 74, 17768, 350, 1467, 17721, 13, 50257]

In [11]:
# Processor for the same checkpoint; the data collator uses its
# feature_extractor and tokenizer to pad batches.
processor = WhisperProcessor.from_pretrained(
    "distil-whisper/distil-large-v3.5",
    task="transcribe",
)

In [12]:
def prepare_dataset(batch):
    """Add model inputs and target ids to a single dataset example.

    Reads ``batch["audio"]`` (its ``"array"`` and ``"sampling_rate"`` entries)
    and ``batch["text"]``, then stores the extracted features under
    ``"input_features"`` and the tokenized transcript under ``"labels"``.

    Returns:
        The same ``batch`` object with the two new keys set.
    """
    audio = batch["audio"]

    # The extractor returns a batch of features; keep the first (only) entry.
    extracted = feature_extractor(
        audio["array"],
        sampling_rate=audio["sampling_rate"],
    )
    batch["input_features"] = extracted.input_features[0]

    # Target transcript -> label token ids.
    batch["labels"] = tokenizer(batch["text"]).input_ids
    return batch

In [13]:
# Add input features and labels to every training example
# (runs in a single process; num_proc=4 was tried and left disabled).
dataset_train = dataset_train.map(function=prepare_dataset)

In [14]:
# Add input features and labels to every test example
# (runs in a single process; num_proc=4 was tried and left disabled).
dataset_test = dataset_test.map(function=prepare_dataset)

In [15]:
# Pretrained checkpoint that will be fine-tuned.
model = WhisperForConditionalGeneration.from_pretrained(
    "distil-whisper/distil-large-v3.5"
)

In [16]:
# Clear any forced decoder ids stored on the generation config ...
model.generation_config.forced_decoder_ids = None

# ... and set the task and language used by generate() explicitly instead.
model.generation_config.task = "transcribe"
model.generation_config.language = "english"

In [17]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Turn a list of prepared examples into one padded batch.

    Audio features and label ids are padded separately because they have
    different lengths and need different padding methods.
    """

    # Object exposing ``feature_extractor.pad`` and ``tokenizer.pad``.
    processor: Any
    # Token id that is stripped from the labels when it leads every row.
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Collate ``features`` (each with ``"input_features"`` and ``"labels"``).

        Returns:
            The padded audio batch with an extra ``"labels"`` entry in which
            padded positions are set to -100.
        """
        # Audio side: hand the features to the extractor, which returns tensors.
        audio_items = [{"input_features": example["input_features"]} for example in features]
        batch = self.processor.feature_extractor.pad(audio_items, return_tensors="pt")

        # Label side: pad the token id sequences to a common length.
        token_items = [{"input_ids": example["labels"]} for example in features]
        padded_labels = self.processor.tokenizer.pad(token_items, return_tensors="pt")

        # Positions whose attention mask is not 1 are padding; -100 makes the
        # loss ignore them.
        is_padding = padded_labels.attention_mask.ne(1)
        labels = padded_labels["input_ids"].masked_fill(is_padding, -100)

        # When every row starts with the decoder start token (added during
        # tokenization), drop it here because it is added again later.
        every_row_has_bos = (labels[:, 0] == self.decoder_start_token_id).all().item()
        batch["labels"] = labels[:, 1:] if every_row_has_bos else labels

        return batch


In [18]:
# The collator pads with the processor and strips the model's decoder start
# token when it leads every label row.
start_token_id = model.config.decoder_start_token_id
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=start_token_id,
)

In [19]:
# Word error rate metric used by compute_metrics.
metric = evaluate.load(path="wer")

In [20]:
def compute_metrics(pred):
    """Compute the word error rate, as a percentage, for one evaluation pass.

    Args:
        pred: Prediction object exposing ``predictions`` (predicted token ids)
            and ``label_ids`` (reference token ids, with -100 at ignored
            positions).

    Returns:
        dict: ``{"wer": value}`` where ``value`` is 100 times the metric result.
    """
    pad_id = tokenizer.pad_token_id

    # -100 is a loss-masking sentinel, not a token id, so it must be replaced by
    # the pad token before decoding. np.where builds new arrays, leaving the
    # caller's ``pred.label_ids`` untouched (the previous in-place assignment
    # overwrote it).
    label_ids = np.where(np.asarray(pred.label_ids) != -100, pred.label_ids, pad_id)
    # Same guard for the predictions — presumably these can also carry -100
    # padding when batches of different lengths are gathered; harmless if not.
    pred_ids = np.where(np.asarray(pred.predictions) != -100, pred.predictions, pad_id)

    # Special tokens (including the pad tokens inserted above) are dropped.
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

In [21]:
# Training configuration, grouped by concern. Evaluation and checkpointing
# both happen every 500 steps, and the checkpoint with the lowest WER is
# reloaded when training finishes.
training_args = Seq2SeqTrainingArguments(
    # --- output and logging ---
    output_dir="/workspace/distil_whisper_checkpoints/",  # replace with correct path
    report_to=["tensorboard"],
    logging_steps=100,
    push_to_hub=False,
    # --- optimisation ---
    per_device_train_batch_size=8,  # 16 was the alternative noted by the author
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=1000,
    gradient_checkpointing=True,
    fp16=True,
    # --- evaluation ---
    eval_strategy="steps",
    eval_steps=500,
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    # --- checkpointing / model selection ---
    save_steps=500,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,  # lower WER is better
)

In [22]:
# Wire the model, data, collator and metric into the trainer.
# `processing_class` replaces the deprecated `tokenizer` argument, which
# triggers a deprecation warning on recent transformers releases. The same
# object (the feature extractor) is passed, so behaviour is unchanged.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=dataset_train,
    eval_dataset=dataset_test,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    processing_class=processor.feature_extractor,
)

  trainer = Seq2SeqTrainer(


In [23]:
import time

# Run fine-tuning and measure its wall-clock duration in whole seconds.
start_time = time.time()
trainer.train()
elapsed_seconds = int(time.time() - start_time)

# Build HH:MM:SS by hand: time.strftime("%H:%M:%S", time.gmtime(...)) wraps the
# hour field back to 00 after 24 hours, under-reporting long runs.
hours, remainder = divmod(elapsed_seconds, 3600)
minutes, seconds = divmod(remainder, 60)
training_time = f"{hours:02d}:{minutes:02d}:{seconds:02d}"
print("Training time : {0}".format(training_time))

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...


Step,Training Loss,Validation Loss,Wer
500,0.1558,0.138556,12.731682
1000,0.0306,0.047677,4.624978


The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
There were missing keys in the checkpoint model loaded: ['proj_out.weight'].


Training time : 00:44:24


In [24]:
trainer.evaluate()

{'eval_loss': 0.04767697677016258,
 'eval_wer': 4.624978347479647,
 'eval_runtime': 244.2501,
 'eval_samples_per_second': 2.624,
 'eval_steps_per_second': 0.332,
 'epoch': 3.115264797507788}