# Master Thesis: Second Version of an STT Model Based on the New State of the Art, Whisper

**Author**: Karin Thommen

**Date**: April 2023


---

**Content of the Notebook**: Fine-tuning and training of the OpenAI Whisper ASR model

---
**References**:
- https://huggingface.co/blog/fine-tune-whisper
- https://wandb.ai/parambharat/whisper_finetuning/reports/Fine-tuning-Whisper-ASR-models---VmlldzozMTEzNDE5
- https://github.com/vasistalodagala/whisper-finetune

## Step 1: Imports and Setup

In [None]:
%%capture
!pip install datasets
!pip install transformers==4.28.0
!pip install librosa
!pip install "evaluate>=0.3.0"
!pip install jiwer
!pip install gradio
!pip install audio-metadata
!pip install "dill<0.3.5"
!pip install git-lfs
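An unquoted `>=` in a `!pip install` line is interpreted by the shell as an output redirection, so the version constraint is silently dropped. It is therefore worth checking afterwards which versions were actually installed. A minimal sketch using only the standard library's `importlib.metadata`; the helper names `installed_version` and `report_versions` are hypothetical and not part of any library used in this notebook.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string of `package`, or None if it is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def report_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each package name to its installed version (None when not installed)."""
    return {name: installed_version(name) for name in packages}


if __name__ == "__main__":
    # a subset of the packages installed at the top of this notebook
    for name, ver in report_versions(["datasets", "transformers", "evaluate", "jiwer"]).items():
        print(f"{name:12s} {ver or 'NOT INSTALLED'}")
```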

In [None]:
import os
import re
import json
import random
import pickle
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

import dill
import numpy as np
import pandas as pd
import torch
import transformers
import audio_metadata
import IPython.display as ipd
from IPython.display import display, HTML

from datasets import load_dataset, Audio, load_metric, load_from_disk, DatasetDict, list_datasets
from datasets import ClassLabel, Dataset, Sequence
from datasets.fingerprint import Hasher

from transformers import WhisperTokenizer, WhisperTokenizerFast, WhisperProcessor, WhisperFeatureExtractor
from huggingface_hub import notebook_login

from google.colab import drive

In [None]:
transformers.__version__

'4.28.0'

In [None]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

## Step 2: Load Data

In [None]:
# log in to the Hugging Face account to access the dataset
notebook_login()

In [None]:
print(list_datasets())

    list_datasets is deprecated and will be removed in the next major version of datasets. Use 'huggingface_hub.list_datasets' instead.


['acronym_identification', 'ade_corpus_v2', 'adversarial_qa', 'aeslc', 'afrikaans_ner_corpus', 'ag_news', 'ai2_arc', 'air_dialogue', 'ajgt_twitter_ar', 'allegro_reviews', 'allocine', 'alt', 'amazon_polarity', 'amazon_reviews_multi', 'amazon_us_reviews', 'ambig_qa', 'americas_nli', 'ami', 'amttl', 'anli', 'app_reviews', 'aqua_rat', 'aquamuse', 'ar_cov19', 'ar_res_reviews', 'ar_sarcasm', 'arabic_billion_words', 'arabic_pos_dialect', 'arabic_speech_corpus', 'arcd', 'arsentd_lev', 'art', 'arxiv_dataset', 'ascent_kb', 'aslg_pc12', 'asnq', 'asset', 'assin', 'assin2', 'atomic', 'autshumato', 'facebook/babi_qa', 'banking77', 'bbaw_egyptian', 'bbc_hindi_nli', 'bc2gm_corpus', 'beans', 'best2009', 'bianet', 'bible_para', 'big_patent', 'billsum', 'bing_coronavirus_query_set', 'biomrc', 'biosses', 'blbooks', 'blbooksgenre', 'blended_skill_talk', 'blimp', 'blog_authorship_corpus', 'bn_hate_speech', 'bnl_newspapers', 'bookcorpus', 'bookcorpusopen', 'boolq', 'bprec', 'break_data', 'brwac', 'bsd_ja_en'

In [None]:
# load the dataset from the Hugging Face Hub (uploaded there beforehand from a local machine)
dataset = load_dataset("karinthommen/schawinski_V2")

Downloading and preparing dataset None/None to /root/.cache/huggingface/datasets/karinthommen___parquet/karinthommen--schawinski_V2-6a970eecd5fc90fe/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7...


Generating train split:   0%|          | 0/3544 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/645 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/647 [00:00<?, ? examples/s]

Dataset parquet downloaded and prepared to /root/.cache/huggingface/datasets/karinthommen___parquet/karinthommen--schawinski_V2-6a970eecd5fc90fe/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7. Subsequent calls will reuse this data.


In [None]:
# check if data loading worked
dataset["train"][0]

{'audio': {'path': 'Badran_Schawinski_13-05-2013_SPK0-Badran_Schawinski_13-05-2013-0001.wav',
  'array': array([ 0.00010681,  0.00018311,  0.00045776, ...,  0.00128174,
         -0.00854492, -0.01789856]),
  'sampling_rate': 44100},
 'transcription': '[music]',
 'duration': 26.31}

In [None]:
dataset.shape

{'train': (3544, 3), 'validation': (645, 3), 'test': (647, 3)}

In [None]:
# punctuation and special characters to delete (duplicates removed; \u0308 is the combining diaeresis)
chars_to_remove_regex = r'[,?.!\-;:"“%‘”�\'$()*_\u0308’•‹₂›–²½‑°`&+/=\[\]]'
# annotation tags that are not part of the spoken text
tags_to_remove_regex = r'\[(?:music|noise|speech-in-noise|breath_mouth_noise|no_relevant_speech|no-relevant-speech|laughter|speech-in-speech)\]'

def preprocess(batch):
  text = batch["transcription"]
  # remove annotation tags such as [music] or [laughter]
  text = re.sub(tags_to_remove_regex, '', text)
  # remove backslashes, slashes and asterisks
  text = re.sub(r'[\\/*]', '', text)
  # remove punctuation and special characters, then lowercase
  text = re.sub(chars_to_remove_regex, '', text).lower()
  batch["transcription"] = text.strip()
  return batch
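The cleaning done by `preprocess` can be exercised on plain strings, without loading any audio. The sketch below reproduces the same sequence of steps (strip annotation tags, drop punctuation, lowercase, trim) with a reduced punctuation class; the function name `clean_transcription` and the shortened `PUNCT_RE` are illustrative assumptions, not the exact regex used in this notebook.

```python
import re

# annotation tags that appear in the Schawinski transcripts (same list as in `preprocess`)
TAGS = ["music", "noise", "speech-in-noise", "breath_mouth_noise",
        "no_relevant_speech", "no-relevant-speech", "laughter", "speech-in-speech"]
TAG_RE = re.compile(r"\[(?:" + "|".join(re.escape(t) for t in TAGS) + r")\]")
# reduced punctuation class for illustration only (assumption: not the full original set)
PUNCT_RE = re.compile(r"[,?.!\-;:\"“”‘’'%$()*_/\\\[\]]")


def clean_transcription(text: str) -> str:
    """Remove known annotation tags and punctuation, then lowercase and trim."""
    text = TAG_RE.sub("", text)
    text = PUNCT_RE.sub("", text)
    return text.lower().strip()


if __name__ == "__main__":
    print(clean_transcription("[laughter] Ganz wiit obe!"))  # ganz wiit obe
```

Note the design consequence: an unknown tag such as `[foo]` only loses its brackets, so its label word stays in the transcription. Any new tag type has to be added to the tag list explicitly.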

In [None]:
# drop samples whose transcription starts with a [speech-in-speech] tag
dataset = dataset.filter(lambda example: not example["transcription"].startswith("[speech-in-speech]"))

Filter:   0%|          | 0/3544 [00:00<?, ? examples/s]

Filter:   0%|          | 0/645 [00:00<?, ? examples/s]

Filter:   0%|          | 0/647 [00:00<?, ? examples/s]

In [None]:
# apply preprocessing to dataset
dataset = dataset.map(preprocess, num_proc=1)

Map:   0%|          | 0/1962 [00:00<?, ? examples/s]

Map:   0%|          | 0/303 [00:00<?, ? examples/s]

Map:   0%|          | 0/393 [00:00<?, ? examples/s]

In [None]:
# remove samples whose transcription is empty after preprocessing
dataset["train"] = dataset["train"].filter(lambda example: len(example["transcription"])!=0)
dataset["validation"] = dataset["validation"].filter(lambda example: len(example["transcription"])!=0)
dataset["test"] = dataset["test"].filter(lambda example: len(example["transcription"])!=0)

Filter:   0%|          | 0/1962 [00:00<?, ? examples/s]

Filter:   0%|          | 0/303 [00:00<?, ? examples/s]

Filter:   0%|          | 0/393 [00:00<?, ? examples/s]

In [None]:
dataset.shape

{'train': (1857, 3), 'validation': (294, 3), 'test': (385, 3)}
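Comparing these counts with the shapes printed right after loading shows how much data the speech-in-speech filter and the empty-transcription filter removed. A small worked calculation that uses only the split sizes reported in this notebook:

```python
# split sizes printed by `dataset.shape` after loading and after filtering
before = {"train": 3544, "validation": 645, "test": 647}
after = {"train": 1857, "validation": 294, "test": 385}


def retention(before_counts, after_counts):
    """Fraction of samples kept per split, rounded to three decimals."""
    return {split: round(after_counts[split] / before_counts[split], 3) for split in before_counts}


if __name__ == "__main__":
    for split, frac in retention(before, after).items():
        print(f"{split:10s} kept {frac:.1%}")
```

Between 46% and 60% of each split survives the filtering, which is worth keeping in mind when comparing WER with models trained on the unfiltered transcripts.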

In [None]:
# load the Whisper tokenizer
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", task="transcribe")
# load feature extractor
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
# load the processor (wraps the feature extractor and the tokenizer)
processor = WhisperProcessor.from_pretrained("openai/whisper-small", task="transcribe")

## Step 3: Prepare the Dataset and Convert It into the Required Format

In [None]:
# downsample dataset to a sampling rate of 16kHz for the model
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
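The source clips are stored at 44.1 kHz (see the `sampling_rate` of the first training example above), while Whisper expects 16 kHz; `cast_column` makes `datasets` resample each clip when it is accessed. The toy sketch below only illustrates the arithmetic of sample-rate conversion using naive linear interpolation. It is an assumption-laden simplification: a proper resampler, such as the one `datasets` relies on, typically applies anti-aliasing filtering, which this code does not. The helper names are hypothetical.

```python
from typing import List, Sequence


def resampled_length(num_samples: int, orig_sr: int, target_sr: int) -> int:
    """Number of output samples covering the same duration at the target rate."""
    return int(round(num_samples * target_sr / orig_sr))


def linear_resample(signal: Sequence[float], orig_sr: int, target_sr: int) -> List[float]:
    """Naive linear-interpolation resampler (no low-pass filter, illustration only)."""
    n_out = resampled_length(len(signal), orig_sr, target_sr)
    if n_out == 0 or len(signal) == 0:
        return []
    step = orig_sr / target_sr  # input samples advanced per output sample
    last = len(signal) - 1
    out = []
    for i in range(n_out):
        pos = i * step
        left = min(int(pos), last)
        right = min(left + 1, last)
        frac = pos - left
        out.append(signal[left] * (1.0 - frac) + signal[right] * frac)
    return out


if __name__ == "__main__":
    # one second of audio at 44.1 kHz becomes 16,000 samples at 16 kHz
    print(resampled_length(44_100, 44_100, 16_000))
```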

In [None]:
# check that audio loading worked by playing a random sample and printing its transcription
rand_int = random.randint(0, len(dataset["train"])-1)
print(dataset["train"]["transcription"][rand_int])
ipd.Audio(data=dataset["train"][rand_int]["audio"]["array"], autoplay=True, rate=16000)

und jez luege mer das es guet gaat ich säg jez los ämaal  tuu musch mir bewiise öb sich s überhaubt loont mit dir z schaffe


In [None]:
# Check sentence, input array shape and sampling rate
rand_int = random.randint(0, len(dataset["train"])-1)

print("Target text:", dataset["train"][rand_int]["transcription"])
print("Input array shape:", dataset["train"][rand_int]["audio"]["array"].shape)
print("Sampling rate:", dataset["train"][rand_int]["audio"]["sampling_rate"])

Target text: ganz wiit obe
Input array shape: (14641,)
Sampling rate: 16000


In [None]:
# show a sentence decoded with and without special tokens (the format Whisper expects)
input_str = dataset["train"][0]["transcription"]
labels = tokenizer(input_str).input_ids
decoded_with_special = tokenizer.decode(labels, skip_special_tokens=False)
decoded_str = tokenizer.decode(labels, skip_special_tokens=True)

print(f"Input:                 {input_str}")
print(f"Decoded w/ special:    {decoded_with_special}")
print(f"Decoded w/out special: {decoded_str}")
print(f"Are equal:             {input_str == decoded_str}")

Input:                 da isch di eerschti taakschou vo de wuche froit m mi dass si au hüt aabig debii sind bi miir isch d schagglin badraan ässphee nazionaalräätin
Decoded w/ special:    <|startoftranscript|><|transcribe|><|notimestamps|>da isch di eerschti taakschou vo de wuche froit m mi dass si au hüt aabig debii sind bi miir isch d schagglin badraan ässphee nazionaalräätin<|endoftext|>
Decoded w/out special: da isch di eerschti taakschou vo de wuche froit m mi dass si au hüt aabig debii sind bi miir isch d schagglin badraan ässphee nazionaalräätin
Are equal:             True
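The decoded output above shows the prefix the Whisper tokenizer adds (`<|startoftranscript|><|transcribe|><|notimestamps|>`) and the closing `<|endoftext|>`. As a rough, string-level illustration of what `skip_special_tokens=True` achieves here, the hypothetical helpers below find and drop every `<|...|>` marker. This is an approximation only: the tokenizer itself filters token ids, not decoded text.

```python
import re
from typing import List

# matches markers such as <|startoftranscript|> or <|endoftext|>
SPECIAL_TOKEN_RE = re.compile(r"<\|[^|>]*\|>")


def special_tokens(decoded: str) -> List[str]:
    """List the special-token markers in a decoded string, in order."""
    return SPECIAL_TOKEN_RE.findall(decoded)


def strip_special_tokens(decoded: str) -> str:
    """Remove all special-token markers from a decoded string."""
    return SPECIAL_TOKEN_RE.sub("", decoded)
```

No language token (such as `<|de|>`) appears in the prefix because the tokenizer was loaded with `task="transcribe"` but without a `language` argument.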


In [None]:
# show format of train dataset
dataset["train"][0]

{'audio': {'path': 'Badran_Schawinski_13-05-2013_SPK0-Badran_Schawinski_13-05-2013-0002.wav',
  'array': array([-0.00837064, -0.04421842, -0.02962748, ..., -0.00506399,
         -0.00173904,  0.00242807]),
  'sampling_rate': 16000},
 'transcription': 'da isch di eerschti taakschou vo de wuche froit m mi dass si au hüt aabig debii sind bi miir isch d schagglin badraan ässphee nazionaalräätin',
 'duration': 6.39}

In [None]:
def prepare_dataset(batch):
    # the audio column is already resampled from 44.1 kHz to 16 kHz by cast_column above
    audio = batch["audio"]

    # compute log-Mel input features from input audio array
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    # encode target text to label ids
    batch["labels"] = tokenizer(batch["transcription"]).input_ids
    return batch
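`WhisperFeatureExtractor` pads or truncates every clip to a fixed 30-second window before computing the log-Mel spectrogram, so every `input_features` entry has the same shape regardless of clip length. The arithmetic below reproduces that shape. The constants are assumptions matching the defaults of the `openai/whisper-small` checkpoint; they can be compared with `feature_extractor.sampling_rate`, `feature_extractor.hop_length`, `feature_extractor.feature_size` and `model.config.max_target_positions`. The helper `fits_whisper` is a hypothetical pre-filter that this notebook does not apply.

```python
SAMPLING_RATE = 16_000      # Hz, rate Whisper expects
CHUNK_LENGTH_S = 30         # audio is padded/truncated to this many seconds
HOP_LENGTH = 160            # samples between consecutive STFT frames (10 ms)
N_MELS = 80                 # number of Mel bins for whisper-small
MAX_TARGET_POSITIONS = 448  # assumed decoder length limit of whisper-small


def input_feature_shape() -> tuple:
    """Shape (mel bins, frames) of one log-Mel input after padding/truncation."""
    n_samples = SAMPLING_RATE * CHUNK_LENGTH_S
    return (N_MELS, n_samples // HOP_LENGTH)


def fits_whisper(duration_s: float, n_label_tokens: int) -> bool:
    """Hypothetical pre-filter: clip is not truncated and labels fit the decoder."""
    return duration_s <= CHUNK_LENGTH_S and n_label_tokens <= MAX_TARGET_POSITIONS
```

Clips longer than 30 s would silently lose their tail while keeping the full transcription as target, which is why a duration filter of this kind is a common safeguard.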

In [None]:
dataset = dataset.map(prepare_dataset, num_proc=2)

Map (num_proc=2):   0%|          | 0/1857 [00:00<?, ? examples/s]

Map (num_proc=2):   0%|          | 0/294 [00:00<?, ? examples/s]

Map (num_proc=2):   0%|          | 0/385 [00:00<?, ? examples/s]

In [None]:
dataset = dataset.shuffle(seed=500)

# make sure that there is no empty transcription in dataset
dataset = dataset.filter(lambda example: len(example["transcription"])!=0)

Filter:   0%|          | 0/1857 [00:00<?, ? examples/s]

Filter:   0%|          | 0/294 [00:00<?, ? examples/s]

Filter:   0%|          | 0/385 [00:00<?, ? examples/s]

## Step 4: Fine-Tune and Train the Model

In [None]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if a bos token was prepended in the previous tokenization step,
        # cut it here, since it is prepended again later anyway
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch
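The label handling in the collator can be followed step by step on plain Python lists. The torch-free sketch below mirrors that logic: pad to the longest sequence, replace padded positions with `-100` using the attention mask, and drop the first column if every row starts with the given leading token. The function name `collate_labels` and the toy token ids are hypothetical.

```python
from typing import List

IGNORE_INDEX = -100  # label value ignored by the cross-entropy loss


def collate_labels(label_seqs: List[List[int]], pad_id: int, bos_id: int) -> List[List[int]]:
    """Pad label sequences and mask the padding with IGNORE_INDEX."""
    max_len = max(len(seq) for seq in label_seqs)
    # pad every sequence with pad_id and build the matching attention mask
    ids = [seq + [pad_id] * (max_len - len(seq)) for seq in label_seqs]
    mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in label_seqs]
    # mask by attention mask, not by comparing against pad_id
    labels = [[tok if m == 1 else IGNORE_INDEX for tok, m in zip(row, mrow)]
              for row, mrow in zip(ids, mask)]
    # drop the leading token only if every row starts with it
    if all(row[0] == bos_id for row in labels):
        labels = [row[1:] for row in labels]
    return labels
```

Masking through the attention mask rather than by comparing token ids matters for Whisper: to our knowledge its pad token is the same `<|endoftext|>` token that closes every label sequence, so an equality-based mask would also hide the genuine end-of-text target from the loss.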

In [None]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

In [None]:
import evaluate

metric = evaluate.load("wer")

In [None]:
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # decode predictions and references, dropping special tokens
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}
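`evaluate.load("wer")` delegates the computation to `jiwer`. The word error rate is the word-level edit distance between hypothesis and reference, summed over the corpus and divided by the total number of reference words: WER = (S + D + I) / N, with substitutions S, deletions D, insertions I and reference length N. A self-contained sketch of that definition for intuition and unit tests; `corpus_wer` is a hypothetical helper that splits on whitespace, which may differ from `jiwer`'s default text transformations in edge cases.

```python
from typing import List


def word_edit_distance(reference: str, hypothesis: str) -> int:
    """Minimum number of word substitutions, deletions and insertions."""
    ref, hyp = reference.split(), hypothesis.split()
    prev = list(range(len(hyp) + 1))  # distances from an empty reference prefix
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr[j] = min(prev[j] + 1,         # deletion
                          curr[j - 1] + 1,     # insertion
                          prev[j - 1] + cost)  # substitution or match
        prev = curr
    return prev[-1]


def corpus_wer(references: List[str], predictions: List[str]) -> float:
    """Total word errors divided by total reference words (a fraction, not %)."""
    errors = sum(word_edit_distance(r, p) for r, p in zip(references, predictions))
    n_words = sum(len(r.split()) for r in references)
    return errors / n_words
```

`compute_metrics` multiplies this fraction by 100, so the WER values logged during training are percentages.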

In [None]:
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(device)

In [None]:
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
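These two settings clear constraints stored in the pretrained config: `forced_decoder_ids` is a list of `(position, token_id)` pairs that pin specific tokens to fixed decoding positions, and `suppress_tokens` lists token ids whose logits are masked out. The toy sketch below only illustrates both ideas on a single greedy decoding step; the helper names are hypothetical, and `transformers` implements the real behaviour with logits processors inside `generate`.

```python
import math
from typing import List, Optional, Sequence, Tuple


def suppress(logits: Sequence[float], suppress_ids: List[int]) -> List[float]:
    """Return a copy of `logits` with the suppressed ids set to -inf."""
    out = list(logits)
    for idx in suppress_ids:
        out[idx] = -math.inf
    return out


def greedy_step(logits: Sequence[float], step: int, suppress_ids: List[int],
                forced_decoder_ids: Optional[List[Tuple[int, int]]] = None) -> int:
    """Pick the next token id: a forced id if one is pinned to `step`, else the argmax."""
    forced = dict(forced_decoder_ids or [])
    if step in forced:
        return forced[step]
    masked = suppress(logits, suppress_ids)
    return max(range(len(masked)), key=masked.__getitem__)
```

With `forced_decoder_ids = None` and `suppress_tokens = []`, the step reduces to a plain argmax over all logits.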

In [None]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./spont-whisper-default",  # change to a repo name of your choice
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=100,
    max_steps=500,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    #per_device_eval_batch_size=8,
    predict_with_generate=True,
    #generation_max_length=225,
    save_steps=100,
    eval_steps=100,
    logging_steps=100,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=True,
)
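With these arguments the effective batch size is 16 × 1 = 16, so the 500 optimizer steps cover 8,000 training examples. A short calculation links the settings to the numbers reported after training. The helper names are hypothetical, and the learning-rate function assumes the default linear warmup-then-decay schedule (`lr_scheduler_type="linear"`), since no other scheduler is set here.

```python
import math


def steps_per_epoch(n_train: int, per_device_batch_size: int, grad_accum: int, n_devices: int = 1) -> int:
    """Optimizer updates per epoch: batches per epoch divided by accumulation steps."""
    batches = math.ceil(n_train / (per_device_batch_size * n_devices))
    return max(batches // grad_accum, 1)


def linear_warmup_lr(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Learning rate after `step` updates: linear warmup, then linear decay to 0."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    return base_lr * max(0.0, (total_steps - step) / (total_steps - warmup_steps))


if __name__ == "__main__":
    spe = steps_per_epoch(n_train=1857, per_device_batch_size=16, grad_accum=1)
    print("steps per epoch:", spe)
    print("epochs in 500 steps:", round(500 / spe, 2))
    print("lr at step 300:", linear_warmup_lr(300, 1e-5, 100, 500))
```

500 / 117 ≈ 4.27 matches the `epoch` value in the `TrainOutput` further down, and 8,000 examples / 4,012 s ≈ 1.99 matches `train_samples_per_second`.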

In [None]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

Cloning https://huggingface.co/karinthommen/spont-whisper-default into local empty directory.


In [None]:
processor.save_pretrained(training_args.output_dir)

In [None]:
#torch.cuda.empty_cache()

### Version 2.2 using some default Whisper settings

In [None]:
trainer.train()

`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...


| Step | Training Loss | Validation Loss | WER (%) |
|---:|---:|---:|---:|
| 100 | 2.6173 | 1.282082 | 61.190631 |
| 200 | 0.8919 | 0.993969 | 51.366298 |
| 300 | 0.5235 | 0.938091 | 48.503578 |
| 400 | 0.3118 | 0.947916 | 47.49512 |
| 500 | 0.2056 | 0.945045 | 46.942095 |


TrainOutput(global_step=500, training_loss=0.9100223541259765, metrics={'train_runtime': 4011.9463, 'train_samples_per_second': 1.994, 'train_steps_per_second': 0.125, 'total_flos': 2.2913680785408e+18, 'train_loss': 0.9100223541259765, 'epoch': 4.27})
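The evaluation log can be read programmatically to see which checkpoint `load_best_model_at_end=True` selects when combined with `metric_for_best_model="wer"` and `greater_is_better=False`. The values below are copied from the training log of this run; `best_step` and `relative_reduction` are hypothetical helpers, not the `Trainer`'s internal code.

```python
# evaluation log of run 2.2: (step, training loss, validation loss, WER in %)
LOG = [
    (100, 2.6173, 1.282082, 61.190631),
    (200, 0.8919, 0.993969, 51.366298),
    (300, 0.5235, 0.938091, 48.503578),
    (400, 0.3118, 0.947916, 47.49512),
    (500, 0.2056, 0.945045, 46.942095),
]


def best_step(log, column: int, greater_is_better: bool = False) -> int:
    """Step whose value in `column` is best (lowest unless greater_is_better)."""
    pick = max if greater_is_better else min
    return pick(log, key=lambda row: row[column])[0]


def relative_reduction(first: float, last: float) -> float:
    """Relative improvement from `first` to `last` as a fraction of `first`."""
    return (first - last) / first


if __name__ == "__main__":
    print("best step by WER:", best_step(LOG, column=3))
    print("best step by validation loss:", best_step(LOG, column=2))
    print(f"relative WER reduction: {relative_reduction(LOG[0][3], LOG[-1][3]):.1%}")
```

Validation loss is lowest at step 300 and rises slightly afterwards, while WER keeps improving until step 500 (about 23% relative to step 100). Because the best model is chosen by WER, the step-500 checkpoint is the one loaded at the end.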

In [None]:
trainer.push_to_hub()