# Finetuning OpenAI's Whisper for Singaporean speech
This finetune of OpenAI's state-of-the-art Automatic Speech Recognition (ASR) model aims to tackle a few nuances in the way Singaporeans speak that the original model struggles to recognise. More specifically:
1. Borrowing words from other languages within the same sentence, such as 'makan', 'kiasu' and 'paiseh'.
    - 'I understand that this math topic can be tough, but don't feel paiseh to ask questions if you're confused'.
2. Use of particles such as 'lor', 'leh' and 'lah' that can alter the meaning or perceived tone of sentences.
    - 'You can redo the assignment lor'. This sounds as if the teacher is allowing a redo of the assignment but is not enthusiastic about it.
    - 'You can redo the assignment lah'. This sounds like a straightforward suggestion or gentle encouragement.



## Use Case
For some context, the finetuned ASR model will play a key role in helping students study more efficiently through the transcription feature of our 'iOrganise' application. It will transcribe both standalone audio files and the audio in video files, making audio-based study materials easier to review. Providing text transcripts should streamline the review process, since reading typically takes less time than listening to the original audio. The transcribed text will also feed into the subject classification pipeline, a feature being developed by another member of my group.

![appdiagram.png](diagrams/appdiagram.png)

## Getting Started
I will begin by importing the necessary libraries. These libraries provide the functions and tools we will use to finetune our model.

In [1]:
# import libraries
import asyncio
import os
import re
import threading
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, List, Union

import chardet
import evaluate
import numpy as np
import torch
from datasets import IterableDatasetDict, load_dataset
from pydub import AudioSegment
from transformers import WhisperForConditionalGeneration, WhisperProcessor, Seq2SeqTrainingArguments, Seq2SeqTrainer, GenerationConfig

from local_loadingScript import *

As I will be training our model on a GPU, I will first check that PyTorch can detect a CUDA-capable device.

In [2]:
torch.device("cuda" if torch.cuda.is_available() else "cpu")

device(type='cuda')

## Data Source
Collecting and preparing data for ASR can be a challenging and tedious process. Thankfully, there is a dataset well suited to finetuning our model: the [National Speech Corpus (NSC)](https://www.imda.gov.sg/about-imda/emerging-technologies-and-research/artificial-intelligence/national-speech-corpus).

The corpus is shared via Dropbox and consists of 6 parts, each containing roughly 1000 hours of audio along with their transcripts:
1. Prompted recordings of phonetically-balanced scripts.
2. Prompted recordings of sentences randomly generated from words based on people, food, location, brands, etc.
3. General conversational data on topics covering daily life and playing games, recorded in both separate-room and same-room environments.
4. Conversational data that includes code-switching such as from Singaporean English to other languages such as Mandarin, Malay, etc.
5. Conversational data on debate and finance topics, with positive and negative emotions.
6. Conversational data in 3 styles (holiday/restaurant/hotel, bank/telephone/insurance, HDB/MOE/MSF).

Our application is designed for O-, N- and A-Level students in Singapore, and we expect audio files to come primarily from classroom settings or video lectures. The speech in these recordings may range from structured to conversational. To address the difficulties the pre-trained Whisper model has with such speech, I decided to finetune our model on Part 3 of the NSC, as its conversational nature should better represent the kind of speech typically found in classrooms.

The NSC Part 3 recordings are divided into two environments, each recorded with two different types of microphones. In the first environment, where speakers were in the same room, we selected recordings captured with the close-talk microphone to better isolate the main speaker’s voice. In the second environment, where speakers were in separate rooms, we chose recordings from the standing microphone instead of the telephone recordings.

I downloaded the 4 folders and organised them in the following format:
```
dataset
│
├── audio
│   ├── 3000-1.wav                              # Audio Same CloseMic
│   ├── 3000-2.wav                              # Audio Same CloseMic
│   ├── conf_2500_2500_00862025.wav             # Audio Seperate StandingMic
│   └── conf_2500_2500_00862177.wav             # Audio Seperate StandingMic
│
└── transcripts
    ├── 3000-1.TextGrid                         # Scripts Same
    ├── 3000-2.TextGrid                         # Scripts Same
    ├── conf_2500_2500_00862025.TextGrid        # Scripts Seperate
    └── conf_2500_2500_00862177.TextGrid        # Scripts Seperate
```

## Raw Data
Looking at our raw data, we would expect exactly one transcript per audio file. However, that is not the case: there are 492 audio files in the 'Audio Same CloseMic' folder but 494 transcripts in the 'Scripts Same' folder. Similarly, there are 431 audio files in the 'Audio Seperate StandingMic' folder but only 418 transcripts in the 'Scripts Seperate' folder. We therefore need to find the unmatched files and remove them, so that every remaining audio file has exactly one matching transcript, before we can proceed.

In [3]:
def remove_unmatched_files(audio_dir, transcript_dir):
    audio_files = os.listdir(audio_dir)
    transcript_files = os.listdir(transcript_dir)

    # remove the file extensions for comparison (sets make the membership checks below fast)
    audio_files_base = {os.path.splitext(f)[0] for f in audio_files}
    transcript_files_base = {os.path.splitext(f)[0] for f in transcript_files}

    unmatched_audio_files = []
    unmatched_transcript_files = []

    # check for every audio file if there's a corresponding transcript
    for audio_base in audio_files_base:
        if audio_base not in transcript_files_base:
            unmatched_audio_files.append(audio_base + '.wav')

    # check for every transcript file if there's a corresponding audio
    for transcript_base in transcript_files_base:
        if transcript_base not in audio_files_base:
            unmatched_transcript_files.append(transcript_base + '.TextGrid')

    if unmatched_audio_files:
        for file in unmatched_audio_files:
            # Delete unmatched audio files
            os.remove(os.path.join(audio_dir, file))

    if unmatched_transcript_files:
        for file in unmatched_transcript_files:
            # delete unmatched transcript files
            os.remove(os.path.join(transcript_dir, file))

    return unmatched_audio_files, unmatched_transcript_files

# remove unmatched files
unmatched_audio, unmatched_transcripts = remove_unmatched_files('dataset/audio', 'dataset/transcripts')

print('\nRemaining files in the audio directory:', len(os.listdir('dataset/audio')))
print('Remaining files in the transcript directory:', len(os.listdir('dataset/transcripts')))



Remaining files in the audio directory: 902
Remaining files in the transcript directory: 902
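Because `os.remove` is irreversible, it can be worth previewing the mismatches before deleting anything. The following is a minimal dry-run sketch, assuming (as `remove_unmatched_files` does) that the file name without its extension is the pairing key. `find_unmatched` is a hypothetical helper, not part of the original notebook, and the file names in the usage example are illustrative.

```python
import os
from typing import Iterable, Set, Tuple


def find_unmatched(audio_files: Iterable[str], transcript_files: Iterable[str]) -> Tuple[Set[str], Set[str]]:
    """Return (audio stems without a transcript, transcript stems without audio) without touching the disk."""
    audio_stems = {os.path.splitext(f)[0] for f in audio_files}
    transcript_stems = {os.path.splitext(f)[0] for f in transcript_files}
    # set differences give the orphans on each side
    return audio_stems - transcript_stems, transcript_stems - audio_stems


# usage with in-memory names (in practice: os.listdir('dataset/audio'), os.listdir('dataset/transcripts'))
orphan_audio, orphan_transcripts = find_unmatched(
    ["3000-1.wav", "3000-2.wav", "conf_2500_2500_00862025.wav"],
    ["3000-1.TextGrid", "conf_2500_2500_00862025.TextGrid", "3000-3.TextGrid"],
)
print(sorted(orphan_audio), sorted(orphan_transcripts))
```

Printing the two sets first makes it easy to confirm that only the expected orphans would be deleted.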


## Data Pre-Processing
Whisper transcribes audio in 30s windows: each window is converted into a log-Mel spectrogram, which the encoder-decoder model then turns into text. Before we can train the model, we need to process our raw dataset. We will do this by creating functions that:
1. Clean transcripts by removing annotations for paralinguistic phenomena, fillers, unknown words, unclear words and short pauses, and by stripping the brackets around particles so that the particles remain part of the text.
2. Normalise transcripts by removing punctuation and converting text to lowercase.
3. Splice out utterances from the audio using the timestamps provided in the transcripts, and combine consecutive utterances such that each segment is no longer than 30s. (The audio is resampled to 16kHz, the rate Whisper expects, later on when the input features are computed.)

I will start by defining a 'clean_transcript' function that will do the following:
- Remove annotations for paralinguistic phenomena such as breathing, coughing and laughing, which appear in the transcript as tags like (ppo), (ppb) and (ppl).
- Remove fillers or unknown words ('\<FIL/>'), unclear words ('\<UNK>'), short pauses ('\<S>') and similar tags.
- Strip the square brackets around particles such as '[lah]' while keeping the particle itself as part of the text.

In [4]:
def clean_transcript(lines):
    # pattern to match text inside quotes after 'text ='
    pattern = r'(?<=text = ")(.*?)(?=")'

    cleaned_lines = []
        
    for line in lines:
        # find and clean the 'text' field (inside the quotes after 'text = ')
        match = re.search(pattern, line)
        if match:
            original_text = match.group(0)  # get the 'text' content

            # handle edge case where parentheses/angled brackets may appear at the end of the string or before a punctuation 
            if bool(re.search(r'(\(.*?\)|\<.*?\>)(?=\s*[\.,!?;]|\s*$)', original_text)):
                # remove parentheses and their contents along with the space before them
                # remove angled brackets and their contents along with the space before them
                # remove square brackets but keep contents
                cleaned_text = re.sub(r'\s?(\(.*?\)|\<.*?\>)|\[|\]', '', original_text)
            else:
                # remove parentheses and their contents along with the space after them
                # remove angled brackets and their contents along with the space after them
                # remove square brackets but keep contents
                cleaned_text = re.sub(r'\(.*?\)\s?|\<.*?\>\s?|\[|\]', '', original_text)

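            # collapse any double spaces left behind by removed tags
            cleaned_text = re.sub(r'\s{2,}', ' ', cleaned_text).strip()

            # replace only the matched 'text' span, leaving the rest of the line untouched
            cleaned_line = line[:match.start()] + cleaned_text + line[match.end():]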
            
        else:
            cleaned_line = line

        cleaned_lines.append(cleaned_line)

    return cleaned_lines
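For illustration, the same three cleaning rules can be expressed as a standalone function on a single utterance string, which makes them easy to unit-test. This is a simplified sketch rather than the notebook's `clean_transcript`: it replaces every tag with a space and then collapses whitespace, instead of deciding whether to drop the space before or after each tag. `clean_text` is a hypothetical name, and the sample utterances are made up.

```python
import re


def clean_text(text: str) -> str:
    """Simplified sketch of the cleaning rules applied to a single utterance string."""
    text = re.sub(r"\([^)]*\)", " ", text)        # drop paralinguistic tags such as (ppb), (ppl)
    text = re.sub(r"<[^>]*>", " ", text)          # drop <FIL/>, <UNK>, <S> and similar tags
    text = re.sub(r"\[([^\]]*)\]", r"\1", text)   # unwrap particles: [lah] -> lah
    text = re.sub(r"\s+([.,!?;])", r"\1", text)   # no space left before punctuation
    return " ".join(text.split())                 # collapse repeated whitespace


print(clean_text("(ppb) don't feel [paiseh] <S> to ask <FIL/>"))  # don't feel paiseh to ask
print(clean_text("you can redo the assignment [lor] <UNK>."))     # you can redo the assignment lor.
```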

Next, I will define a function that normalises each transcript by removing punctuation and converting the text to lowercase, so that the labels the model learns from have consistent casing and punctuation.

In [5]:
def normalise_transcript(lines):
    # pattern to match text inside quotes after 'text ='
    pattern = r'(?<=text = ")(.*?)(?=")'
    
    normalised_lines = []
    
    for line in lines:
        match = re.search(pattern, line)
        if match:
            original_text = match.group(0)
            
            # remove unwanted punctuation, if needed, and convert to lowercase
            normalised_text = re.sub(r'[^\w\s]', '', original_text).lower()

            # replace only the matched 'text' span with the normalised version
            normalised_line = line[:match.start()] + normalised_text + line[match.end():]

        else:
            normalised_line = line

        # Add the normalized line to the list
        normalised_lines.append(normalised_line)
    
    return normalised_lines
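A quick illustration of what this normalisation does to a Singlish utterance. Because the pattern strips every character that is neither a word character nor whitespace, apostrophes disappear too, so contractions collapse into single tokens (`I'm` becomes `im`). `normalise_text` is a hypothetical single-string version of `normalise_transcript`, and the sample sentence is made up.

```python
import re


def normalise_text(text: str) -> str:
    # same rule as normalise_transcript: keep only word characters and whitespace, then lowercase
    return re.sub(r"[^\w\s]", "", text).lower()


print(normalise_text("Don't feel paiseh, I'm here to help lah!"))
# prints: dont feel paiseh im here to help lah
```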

Here, I will define a function that takes the path to a directory and processes every TextGrid transcript inside it using the pre-processing functions defined above. All the processed transcripts will be stored in the 'transcripts_processed' folder, as shown in the diagram below:
```
dataset
│
├── audio
│   ├── 3000-1.wav                              # from 'NSC Part 3/Audio Same CloseMic'
│   ├── 3000-2.wav                              # from 'NSC Part 3/Audio Same CloseMic'
│   └── conf_2500_2500_00862025.wav             # from 'NSC Part 3/Audio Seperate StandingMic'
│
├── transcripts
│   ├── 3000-1.TextGrid                         # from 'NSC Part 3/Scripts Same'
│   ├── 3000-2.TextGrid                         # from 'NSC Part 3/Scripts Same'
│   └── conf_2500_2500_00862025.TextGrid        # from 'NSC Part 3/Scripts Seperate'
│
└── transcripts_processed                       # (Newly created folder)
    ├── 3000-1_processed.TextGrid               # (processed transcript)
    └── 3000-2_processed.TextGrid               # (processed transcript)
```

In [6]:
async def detect_encoding(file_path):
    return await asyncio.to_thread(read_file_detect_encoding, file_path)

def read_file_detect_encoding(file_path):
    with open(file_path, 'rb') as file:
        raw_data = file.read()
        encoding = chardet.detect(raw_data)['encoding']
    return encoding

# helper function to read file content asynchronously
async def read_file(file_path, encoding):
    return await asyncio.to_thread(read_file_content, file_path, encoding)

def read_file_content(file_path, encoding):
    with open(file_path, 'r', encoding=encoding) as file:
        return file.readlines()

# helper function to write file content asynchronously
async def write_file(file_path, data):
    await asyncio.to_thread(write_file_content, file_path, data)

def write_file_content(file_path, data):
    with open(file_path, 'w', encoding='utf-8') as file:
        file.writelines(data)

async def process_transcript(lines):
    return await asyncio.to_thread(normalise_and_clean, lines)

def normalise_and_clean(lines):
    cleaned = clean_transcript(lines)
    return normalise_transcript(cleaned)

async def process_single_transcript(file_path, output_folder):
    # detect the encoding, read, clean/normalise and write one transcript;
    # each blocking step runs in a worker thread via asyncio.to_thread
    encoding = await detect_encoding(file_path)
    lines = await read_file(file_path, encoding)
    processed_transcript = await process_transcript(lines)

    # get the processed file path
    processed_path = os.path.join(output_folder, os.path.basename(file_path).replace(".TextGrid", "_processed.TextGrid"))

    # write the processed transcript to a file
    await write_file(processed_path, processed_transcript)

async def process_transcripts(input_folder, output_folder='dataset/transcripts_processed'):
    if not os.path.exists(input_folder):
        print(f"Error: The folder {input_folder} does not exist.")
        return

    # create the output folder if it doesn't exist
    os.makedirs(output_folder, exist_ok=True)

    # one task per file, so files are processed concurrently rather than one after another
    tasks = []
    for file_name in os.listdir(input_folder):
        file_path = os.path.join(input_folder, file_name)
        if os.path.isfile(file_path):
            tasks.append(process_single_transcript(file_path, output_folder))

    # wait for all tasks to complete
    await asyncio.gather(*tasks)
    print("All files processed and saved.")

Let us begin with the file pre-processing!

In [7]:
await process_transcripts('dataset/transcripts')

All files processed and saved.


In the TextGrid format, each utterance is stored as an interval with a start time (xmin), an end time (xmax) and its text. I will define a function that uses these timestamps to compute the duration of each utterance, extract the utterances from the audio, and combine consecutive utterances into a single audio file of at most 30s. The text of those consecutive utterances is combined in the same way and saved as a .txt file. The result is a set of compact sub-30s audio files, each paired with a transcript of exactly what was said in it, stored in 'audio_segment_output_folder' and 'transcript_segment_output_folder' respectively. For context, this is how a TextGrid file is formatted:
```
File type = "ooTextFile"
Object class = "TextGrid"

xmin = 0 
xmax = 30.0
tiers? <exists> 
size = 2
item []:
    item [1]:
        class = "IntervalTier"
        name = "audio-1"
        xmin = 0 
        xmax = 30.0
        intervals: size = 3
        intervals [1]:
            xmin = 0 
            xmax = 5.0 
            text = "Hello"
        intervals [2]:
            xmin = 5.0 
            xmax = 10.0 
            text = "how are you?"
```
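To make the format concrete, here is a minimal sketch that pulls `(xmin, xmax, text)` triples out of an excerpt like the one above. It uses a single regular expression per interval block rather than the three separate patterns used in `splice_audio`, and `parse_intervals` is a hypothetical helper. It assumes the plain `xmin = ... xmax = ... text = "..."` layout shown and does not handle escaped double quotes inside the text.

```python
import re
from typing import List, Tuple

INTERVAL_RE = re.compile(
    r'intervals \[\d+\]:\s*xmin = ([\d.]+)\s*xmax = ([\d.]+)\s*text = "(.*?)"',
    re.DOTALL,
)


def parse_intervals(textgrid: str) -> List[Tuple[float, float, str]]:
    """Return (start_seconds, end_seconds, text) for every interval in the TextGrid string."""
    return [(float(xmin), float(xmax), text) for xmin, xmax, text in INTERVAL_RE.findall(textgrid)]


sample = '''
        intervals: size = 2
        intervals [1]:
            xmin = 0
            xmax = 5.0
            text = "Hello"
        intervals [2]:
            xmin = 5.0
            xmax = 10.0
            text = "how are you?"
'''
print(parse_intervals(sample))  # [(0.0, 5.0, 'Hello'), (5.0, 10.0, 'how are you?')]
```

The `intervals: size = 2` header line is not picked up because the pattern requires the bracketed `intervals [n]:` form.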

In [7]:
def splice_audio(
        audio_path, 
        textgrid_path, 
        audio_segment_output_folder, 
        transcript_segment_output_folder, 
        max_duration=30
    ):       
    try:
        original_audio = AudioSegment.from_file(audio_path)
    except Exception as e:
        print(f"Error loading audio file {audio_path}: {e}")
        return

    # patterns for extracting intervals, xmin, xmax, and text
    interval_pattern = re.compile(r'intervals \[\d+\]:\s*(.*?)\s*(?=intervals|\Z)', re.DOTALL)
    xmin_pattern = re.compile(r'xmin = ([\d.]+)')
    xmax_pattern = re.compile(r'xmax = ([\d.]+)')
    text_pattern = re.compile(r'text = "(.*?)"')

    try:
        with open(textgrid_path, 'r', encoding='utf-8') as file:
            textgrid_content = file.read()
    except Exception as e:
        print(f"Error reading TextGrid file {textgrid_path}: {e}")
        return

    # find all interval blocks in the TextGrid file
    intervals = interval_pattern.findall(textgrid_content)

    # initialize lists for audio and transcript segments
    combined_audio_segments = []
    combined_transcript_segments = []

    current_audio_segment = AudioSegment.empty()
    current_transcript_segment = ''
    total_duration = 0

    for interval in intervals:
        try:
            # Extract xmin, xmax, and text values
            xmin = float(xmin_pattern.search(interval).group(1))
            xmax = float(xmax_pattern.search(interval).group(1))
            text_match = text_pattern.search(interval)
            text = text_match.group(1) if text_match else ""
        except AttributeError as e:
            print(f"Error extracting interval data from TextGrid: {e}")
            continue

        # exclude utterances that are empty
        if text == "":
            continue

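        # calculate the duration of the interval
        interval_duration = xmax - xmin

        # skip any single utterance longer than max_duration: Whisper's feature extractor only keeps
        # the first 30s of audio, so its transcript would no longer match what the model hears
        if interval_duration > max_duration:
            continue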

        try:
            # check if adding the current interval exceeds max duration
            if total_duration + interval_duration > max_duration:
                # if the current segment has data, save it
                if current_audio_segment and current_transcript_segment != '':
                    combined_audio_segments.append(current_audio_segment)
                    combined_transcript_segments.append(current_transcript_segment.rstrip())

                # start a new segment
                current_audio_segment = original_audio[xmin * 1000 : xmax * 1000]
                current_transcript_segment = text + ' '
                total_duration = interval_duration
            else:
                # add current interval to the segment
                current_audio_segment += original_audio[xmin * 1000 : xmax * 1000]
                current_transcript_segment += text + ' '
                total_duration += interval_duration
        except Exception as e:
            print(f"Error processing interval at {xmin}-{xmax} seconds: {e}")
            continue

    # add the last segment if necessary
    if current_audio_segment and current_transcript_segment != '':
        combined_audio_segments.append(current_audio_segment)
        combined_transcript_segments.append(current_transcript_segment.rstrip())

    # export audio and transcript segments
    try:
        for id_, (audio_segment, transcript_segment) in enumerate(zip(combined_audio_segments, combined_transcript_segments)):
            # save the audio segment as a .wav file
            audio_segment.export(os.path.join(audio_segment_output_folder, f"{os.path.splitext(os.path.basename(audio_path))[0]}_segment_{id_ + 1}.wav"), format="wav")
        
            # save the transcript segment as a .txt file
            with open(os.path.join(transcript_segment_output_folder, f"{os.path.splitext(os.path.basename(textgrid_path))[0]}_segment_{id_ + 1}.txt"), 'w', encoding='utf-8') as transcript_file:
                transcript_file.write(transcript_segment)
    except Exception as e:
        print(f"Error exporting audio or transcript segments: {e}")
        return
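The core of `splice_audio` is a greedy packing rule: walk through the utterances in order and start a new segment whenever adding the next one would push the running total past `max_duration`. Below is a minimal sketch of just that rule on `(xmin, xmax, text)` tuples, without any audio. `pack_intervals` is a hypothetical helper, and the utterances in the usage example are made up. It also skips any single utterance longer than `max_duration`, since Whisper's feature extractor truncates input audio to 30s and the transcript would then no longer match the audio.

```python
from typing import List, Tuple

Interval = Tuple[float, float, str]


def pack_intervals(intervals: List[Interval], max_duration: float = 30.0) -> List[Tuple[float, str]]:
    """Greedily merge consecutive non-empty utterances into segments of at most max_duration seconds.

    Returns (total_speech_seconds, combined_text) per segment. Gaps between utterances are not
    counted, mirroring how splice_audio concatenates only the utterance audio.
    """
    segments: List[Tuple[float, str]] = []
    texts: List[str] = []
    total = 0.0
    for xmin, xmax, text in intervals:
        duration = xmax - xmin
        if text == "" or duration > max_duration:
            continue  # skip empty labels and utterances too long for one window
        if total + duration > max_duration and texts:
            segments.append((total, " ".join(texts)))  # close the current segment
            texts, total = [], 0.0
        texts.append(text)
        total += duration
    if texts:
        segments.append((total, " ".join(texts)))  # flush the last segment
    return segments


utterances = [(0.0, 12.0, "okay ya"), (12.5, 25.0, ""), (25.0, 40.0, "what is the"), (41.0, 52.0, "oh okay")]
print(pack_intervals(utterances))  # [(27.0, 'okay ya what is the'), (11.0, 'oh okay')]
```

Greedy packing keeps utterances in their original order, so each combined transcript still reads as continuous speech.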

I will use this function to splice every audio file in the 'audio' folder with its processed transcript from the 'transcripts_processed' folder, writing the resulting segments to the 'audio_segments' and 'transcript_segments' folders. Multithreading is used because of the large number of files that need to be spliced.

In [7]:
def process_audio_files_in_directory(audio_dir, transcript_dir, audio_segment_output_folder, transcript_segment_output_folder):
    # ThreadPoolExecutor to process files concurrently
    with ThreadPoolExecutor() as executor:
        futures = []
        for audio_filename in os.listdir(audio_dir):
            if audio_filename.endswith('.wav'):
                audio_path = os.path.join(audio_dir, audio_filename)
                textgrid_filename = audio_filename.replace('.wav', '_processed.TextGrid')
                textgrid_path = os.path.join(transcript_dir, textgrid_filename)

                if os.path.exists(textgrid_path):
                    # submit the task to the executor for processing
                    futures.append(executor.submit(splice_audio, audio_path, textgrid_path, audio_segment_output_folder, transcript_segment_output_folder))
                else:
                    print(f"Warning: No corresponding TextGrid file for {audio_filename}. Skipping.")
        
        # wait for all futures to complete
        for future in futures:
            future.result()  # this will raise exceptions if there were any during execution

    print(f"All audio files and their corresponding transcript have been spliced and saved.")


audio_dir = 'dataset/audio'
transcript_dir = 'dataset/transcripts_processed'
audio_segment_output_folder='dataset/audio_segments' 
transcript_segment_output_folder='dataset/transcript_segments'

# create the output folders if they don't exist
if not os.path.exists(audio_segment_output_folder):
    os.makedirs(audio_segment_output_folder)

if not os.path.exists(transcript_segment_output_folder):
    os.makedirs(transcript_segment_output_folder)
    

process_audio_files_in_directory(audio_dir, transcript_dir, audio_segment_output_folder, transcript_segment_output_folder)

All audio files and their corresponding transcript have been spliced and saved.


## Loading Data
Each part of the NSC, including the Part 3 data used here, consists of roughly 1000 hours of speech, which is far too much to load into memory at once. To reduce memory usage, I used an IterableDatasetDict together with a custom data loading script ('local_loadingScript.py'), which lazily loads examples as they are needed.
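The loading script itself is not shown in this notebook. As a rough idea of what lazy loading means in this setting, here is a hypothetical generator that yields one `{'audio_path', 'transcript_path'}` record at a time, pairing files by the naming scheme that `splice_audio` produces (`<id>_segment_<n>.wav` with `<id>_processed_segment_<n>.txt`). It is only an illustration of the pairing and the lazy iteration; it is not `local_loadingScript.py`, and it does not cover how the train and validation splits are built.

```python
import os
import tempfile
from typing import Dict, Iterator


def iter_segment_pairs(audio_dir: str, transcript_dir: str) -> Iterator[Dict[str, str]]:
    """Yield audio/transcript path pairs one at a time instead of building the whole list in memory."""
    for name in sorted(os.listdir(audio_dir)):
        if not name.endswith(".wav") or "_segment_" not in name:
            continue
        stem, segment = name[: -len(".wav")].rsplit("_segment_", 1)
        transcript_path = os.path.join(transcript_dir, f"{stem}_processed_segment_{segment}.txt")
        if os.path.exists(transcript_path):
            yield {"audio_path": os.path.join(audio_dir, name), "transcript_path": transcript_path}


# usage on a throwaway directory layout
with tempfile.TemporaryDirectory() as root:
    audio_dir = os.path.join(root, "audio_segments")
    transcript_dir = os.path.join(root, "transcript_segments")
    os.makedirs(audio_dir)
    os.makedirs(transcript_dir)
    for fname in ["3000-1_segment_1.wav", "3000-1_segment_2.wav"]:
        open(os.path.join(audio_dir, fname), "wb").close()
    open(os.path.join(transcript_dir, "3000-1_processed_segment_1.txt"), "w").close()
    pairs = list(iter_segment_pairs(audio_dir, transcript_dir))
    print([os.path.basename(p["audio_path"]) for p in pairs])  # ['3000-1_segment_1.wav']
```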

As mentioned earlier, Whisper works on log-Mel spectrograms of 30s audio chunks. To train the model, we therefore also need to convert each audio chunk into a log-Mel spectrogram and tokenise its corresponding transcript. The function below does exactly that for every example in the iterable dataset before it is used for training.
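The input-side conventions involved can be sketched with NumPy alone. pydub's `get_array_of_samples` returns integer PCM samples, whereas Whisper's feature extractor expects floating-point samples in [-1, 1], and every input is padded or trimmed to 30s at 16kHz (480,000 samples). With the 160-sample (10ms) hop that Whisper's front end uses, that window corresponds to 3,000 spectrogram frames. `to_whisper_window` is a hypothetical helper, it assumes 16-bit mono audio, and the log-Mel computation itself is left to `WhisperFeatureExtractor`.

```python
import numpy as np

SAMPLE_RATE = 16_000  # Hz, the rate Whisper expects
CHUNK_SECONDS = 30    # Whisper's fixed input window
HOP_LENGTH = 160      # samples between spectrogram frames (10 ms at 16 kHz)


def to_whisper_window(int16_samples: np.ndarray) -> np.ndarray:
    """Scale 16-bit PCM to float32 in [-1, 1] and pad/trim to exactly 30 s."""
    audio = int16_samples.astype(np.float32) / 32768.0
    n_samples = SAMPLE_RATE * CHUNK_SECONDS
    if len(audio) >= n_samples:
        return audio[:n_samples]
    return np.pad(audio, (0, n_samples - len(audio)))  # zero-pad (silence) up to 30 s


pcm = np.array([0, 16384, -32768, 32767] * 4000, dtype=np.int16)  # 1 s of dummy audio
window = to_whisper_window(pcm)
print(window.shape, window.min(), window.max())  # 480000 samples, values within [-1, 1]
print("frames:", len(window) // HOP_LENGTH)       # frames: 3000
```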

In [8]:
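# setting the language and task makes the tokenizer add the <|en|> and <|transcribe|> prefix tokens to every label,
# matching the language and task forced at generation time
processor = WhisperProcessor.from_pretrained('openai/whisper-small', language='english', task='transcribe')

def prepare_dataset(batch):
    audio_path = batch['audio_path']
    transcript_path = batch['transcript_path']

    # load the audio as 16kHz, mono, 16-bit PCM
    audio = AudioSegment.from_file(audio_path)
    audio = audio.set_frame_rate(16000).set_channels(1).set_sample_width(2)
    # pydub returns integer samples, but Whisper's feature extractor expects floats in [-1, 1]
    audio_array = np.array(audio.get_array_of_samples(), dtype=np.float32) / 32768.0

    with open(transcript_path, 'r', encoding='utf-8') as f:
        transcript = f.read().strip()
    
    # perform feature extraction by computing log-Mel input features from input audio
    batch['input_features'] = torch.tensor(processor.feature_extractor(audio_array, sampling_rate=16000).input_features[0])
    # tokenize the transcripts
    batch['labels'] = processor.tokenizer(transcript).input_ids

    return batch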

dataset_train = load_dataset('local_loadingScript.py','all', split='train', streaming=True)
dataset_val = load_dataset('local_loadingScript.py','all', split='validation', streaming=True)

dataset = IterableDatasetDict()
dataset['train'] = dataset_train
dataset['validation'] = dataset_val

dataset_processed = dataset.map(prepare_dataset)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


## Evaluation Metrics
Before we begin finetuning our model, let's first define the metric we will use for comparison. One of the most widely adopted metrics for ASR tasks is Word Error Rate (WER), which is calculated as follows:

$WER = \frac{S + D + I}{N}$

Where:
- **S** is the number of substitutions
- **D** is the number of deletions
- **I** is the number of insertions
- **N** is the number of words in the reference

In this case, instead of writing my own function to calculate the WER, I will simply use the evaluate library to do it for me.

In [9]:
# calculate WER using predicted token array and ground truth token array
metric = evaluate.load('wer')

def compute_metrics(pred) -> Dict[str, float]:
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    
    # decode predictions and labels into text, dropping special tokens
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    
    return {"wer": wer}

## Loading Model
As this model is being trained locally on an RTX 4070 Ti Super, I will use a technique called gradient checkpointing to keep memory requirements manageable: instead of storing all intermediate activations, it recomputes some of them during the backward pass, trading extra computation time for lower memory usage. However, the model comes with key-value (KV) caching enabled by default, which poses a problem. KV caching speeds up transformer decoding during inference by reusing previously computed attention keys and values, but it is incompatible with gradient checkpointing during training. I will therefore disable KV caching in the model's configuration for training, re-enable it for generation, and set up the remaining generation configuration.

In [10]:
# load the pretrained model
model = WhisperForConditionalGeneration.from_pretrained('openai/whisper-small')

# set up generation config
generation_config = GenerationConfig.from_pretrained('openai/whisper-small')
generation_config.task = 'transcribe'
generation_config.language = 'en'

# disable KV caching on model
model.config.use_cache = False
# modify the behaviour of the original 'generate' function to enable KV caching
model.generate = partial(model.generate, generation_config=generation_config, use_cache=True)
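The `partial` trick works because `functools.partial` stores default keyword arguments that are merged into every later call. A toy illustration with a hypothetical `DummyModel` (a stand-in, not the Transformers API) shows the bound default being applied, and that an explicitly passed argument still takes precedence. This is how `model.config.use_cache` can stay `False` for training while calls made through `model.generate` run with `use_cache=True`.

```python
from functools import partial


class DummyModel:
    """Stand-in for a model whose generate() accepts keyword arguments (hypothetical, for illustration)."""

    def generate(self, **kwargs):
        return kwargs


model = DummyModel()
# bind a default so every later call to model.generate receives use_cache=True
model.generate = partial(model.generate, use_cache=True)

print(model.generate(max_new_tokens=5))                   # {'use_cache': True, 'max_new_tokens': 5}
print(model.generate(max_new_tokens=5, use_cache=False))  # an explicit argument overrides the bound value
```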

## Data Collator
When we convert our 30s audio chunks and transcripts into audio features and tokenised labels, every sequence in a batch must have the same length. The log-Mel features produced by the feature extractor are already padded to 30s, so they only need to be stacked into a batch, whereas the tokenised labels vary in length. This custom data collator pads the labels to the length of the longest sequence in the batch and then replaces the padding tokens with -100, so that they are ignored in the loss calculation during finetuning.

In [11]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int
    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{'input_features': feature['input_features']} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors='pt')

        # get the tokenized label sequences
        label_features = [{'input_ids': feature['labels']} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors='pt')

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch['input_ids'].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's appended later anyways
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

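        batch['labels'] = labels
        # the labels' attention mask is deliberately not added as 'attention_mask': Whisper treats
        # that key as a mask over the encoder's input features, not over the labels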

        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor, decoder_start_token_id=model.config.decoder_start_token_id)
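The label handling in the collator can be traced on a toy batch. This NumPy sketch reproduces the same three steps (pad to the longest sequence, replace padding with -100, strip a leading start token shared by all rows) using made-up token ids rather than Whisper's vocabulary; `pad_and_mask_labels` is a hypothetical helper. The value -100 is the default `ignore_index` of PyTorch's cross-entropy loss, which is why padded positions do not contribute to training.

```python
from typing import List

import numpy as np


def pad_and_mask_labels(label_seqs: List[List[int]], pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
    """Pad label sequences to the batch maximum, mark padding with -100, and drop a shared leading start token."""
    max_len = max(len(seq) for seq in label_seqs)
    labels = np.full((len(label_seqs), max_len), pad_token_id, dtype=np.int64)
    mask = np.zeros_like(labels)
    for row, seq in enumerate(label_seqs):
        labels[row, : len(seq)] = seq
        mask[row, : len(seq)] = 1
    labels = np.where(mask == 1, labels, -100)  # padded positions are ignored by the loss
    if (labels[:, 0] == decoder_start_token_id).all():
        labels = labels[:, 1:]  # the model prepends the start token itself when shifting labels right
    return labels


# toy ids: 1 = start token, 0 = pad, 9 = end token (hypothetical values)
print(pad_and_mask_labels([[1, 5, 6, 9], [1, 7, 9]], pad_token_id=0, decoder_start_token_id=1))
# [[   5    6    9]
#  [   7    9 -100]]
```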

## Model Training
I will be using Seq2SeqTrainer from Hugging Face's Transformers library to finetune the Whisper model on my dataset. In total, the model is trained on roughly 750 hours of audio from NSC Part 3, with the rest of the audio going to the validation and test sets (~95 hours each). Training runs for up to 8000 steps with a training batch size of 32 and a learning rate of 1e-5, roughly 40 times smaller than the learning rate used for pre-training.
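Since `training_args` below does not set `lr_scheduler_type`, the Trainer should fall back to its default linear schedule: the learning rate ramps up linearly over `warmup_steps` and then decays linearly to zero at `max_steps`. A small sketch of that shape, using the values from `training_args` (1e-5, 500 warmup steps, 8000 max steps), shows for example that the learning rate is still about 7.3e-6 at step 2500. `linear_warmup_lr` is a hypothetical helper written for illustration, not a Transformers function.

```python
def linear_warmup_lr(step: int, base_lr: float = 1e-5, warmup_steps: int = 500, max_steps: int = 8000) -> float:
    """Learning rate at a given step under linear warmup followed by linear decay to zero."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps  # warmup: 0 -> base_lr
    return base_lr * max(0.0, (max_steps - step) / (max_steps - warmup_steps))  # decay: base_lr -> 0


for step in (250, 500, 2500, 8000):
    print(step, linear_warmup_lr(step))
```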

In [13]:
training_args = Seq2SeqTrainingArguments(
    output_dir='./whisper-finetune',
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=1,
    learning_rate=1e-5, #40 times smaller than what was used for pre-training
    warmup_steps=500,
    max_steps=8000,
    gradient_checkpointing=True,
    fp16=True,
    eval_strategy='steps',
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=250,
    eval_steps=250,
    logging_steps=20,
    load_best_model_at_end=True,
    metric_for_best_model='wer',
    greater_is_better=False,
    push_to_hub=False,
    max_grad_norm=1.0,
    logging_dir='./logs',
    logging_strategy='steps',
)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=dataset_processed['train'],
    eval_dataset=dataset_processed['validation'],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

trainer.train()
# save the processor (feature extractor and tokenizer, including its vocabulary)
processor.save_pretrained('whisper-finetune/processor')

max_steps is given, it will override any value given in num_train_epochs


| Step | Training Loss | Validation Loss | WER (%) |
|-----:|--------------:|----------------:|--------:|
| 250 | 0.6083 | 0.622784 | 24.665307 |
| 500 | 0.3935 | 0.521898 | 21.434722 |
| 750 | 0.2776 | 0.476975 | 17.635924 |
| 1000 | 0.379 | 0.469981 | 18.228255 |
| 1250 | 0.4161 | 0.476523 | 18.135082 |
| 1500 | 0.4203 | 0.450703 | 17.28829 |
| 1750 | 0.5677 | 0.438338 | 16.878659 |
| 2000 | 0.3942 | 0.427253 | 16.195979 |
| 2250 | 0.5179 | 0.417884 | 15.925871 |
| 2500 | 0.3808 | 0.403043 | 15.115784 |


You have passed task=transcribe, but also have set `forced_decoder_ids` to [[1, None], [2, 50359]] which creates a conflict. `forced_decoder_ids` will be ignored in favor of task=transcribe.


KeyboardInterrupt: 

As the model's validation loss and WER had stopped improving, I stopped training early, since this is a sign that the model has converged. We can visualise the training process using TensorBoard. If TensorBoard refuses to connect, try clearing the temporary files under '.tensorboard-info'.

In [1]:
%load_ext tensorboard
%tensorboard --logdir "logs"

Now that finetuning is done, I will log in to Hugging Face via the CLI and then push the model to my Hugging Face repository. I will be taking checkpoint 3000, because the validation loss and WER stop improving after that point and are at their lowest there.

In [4]:
repo_name = 'Xycone/whisper-small-SGspeech-finetune'

model = WhisperForConditionalGeneration.from_pretrained('whisper-finetune/checkpoint-3000')
processor = WhisperProcessor.from_pretrained('whisper-finetune/processor')

model.push_to_hub(repo_name)
processor.push_to_hub(repo_name)


No files have been modified since last commit. Skipping to prevent empty commit.


CommitInfo(commit_url='https://huggingface.co/Xycone/whisper-small-SGspeech-finetune/commit/351b8f244d88e845687ef351d0311128a6e23c23', commit_message='Upload processor', commit_description='', oid='351b8f244d88e845687ef351d0311128a6e23c23', pr_url=None, pr_revision=None, pr_num=None)

## Model Evaluation: Pre-trained Vs Fine-tuned
Now that the Whisper model has been finetuned, I will compare the finetuned model against the pre-trained model on an example from the dataset.

In [13]:
from transformers import pipeline
transcript_path='dataset/transcript_segments/conf_2726_2726_00862618_processed_segment_50.txt'
audio_path='dataset/audio_segments/conf_2726_2726_00862618_segment_50.wav'

with open(transcript_path, 'r', encoding='utf-8') as f:
    transcript = f.read().strip()

pretrained_whisper = pipeline('automatic-speech-recognition', 'openai/whisper-small', return_timestamps=True, torch_dtype=torch.float16, device='cuda:0', generate_kwargs = {"task":"transcribe", "language":"<|en|>"})
pretrained_output = pretrained_whisper(audio_path)

finetuned_whisper = pipeline('automatic-speech-recognition', 'Xycone/whisper-small-SGspeech-finetune', return_timestamps=True, torch_dtype=torch.float16, device='cuda:0', generate_kwargs = {"task":"transcribe", "language":"<|en|>"})
finetuned_output = finetuned_whisper(audio_path)

print(f"Ground Truth:\n{transcript}\n")
print(f"Pretrained Output:\n{pretrained_output['text']}\n")
print(f"Finetuned Output:\n{finetuned_output['text']}\n")

You have passed task=transcribe, but also have set `forced_decoder_ids` to [[1, None], [2, 50359]] which creates a conflict. `forced_decoder_ids` will be ignored in favor of task=transcribe.


Ground Truth:
okay ya what is the oh okay self disappointed that means if they find out something about me then they will be disappointed that perhaps im that im selfish that i want to have my own space and not having to share with a in laws ya then like feeling that the the in laws staying with us is a intrusion to my family life yup ya ya is normal lor but is yes ya lor correct ya lor

Pretrained Output:
 Okay, yeah, what is the oh, okay self disappointed. That means if they find out something about me Then they'll be disappointed that perhaps I'm that I'm selfish that I want to have my own space and not having to share with a in-laws yeah, then like Feeling that the the in-laws thing with us is intrusion to my family life. Yeah, yeah, it's normal

Finetuned Output:
 okay yup what is the oh okay self disappointed that means if they find out something about me then theyll be disappointed that perhaps im that im selfish that i want to have my own space and not having to share with a in
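The pretrained output above is cased and punctuated, while the ground truth and the finetuned output are not, so comparing them as raw strings (or computing WER on them directly) would count formatting differences as recognition errors. One way to put them on an equal footing is to apply the same style of normalisation to the pretrained output before scoring. This is a minimal sketch; `normalise_for_comparison` is a hypothetical helper, and mapping hyphens to spaces is an assumption of this sketch, chosen because the reference writes 'in laws' where the pretrained model writes 'in-laws'.

```python
import re


def normalise_for_comparison(text: str) -> str:
    """Bring a cased, punctuated transcript into the lowercase, punctuation-free style of the labels."""
    text = text.replace("-", " ")          # assumption: treat hyphens as word breaks ("in-laws" -> "in laws")
    text = re.sub(r"[^\w\s]", "", text)    # drop the remaining punctuation, as normalise_transcript does
    return " ".join(text.lower().split())  # lowercase and collapse whitespace


print(normalise_for_comparison(" Okay, yeah, what is the oh, okay self disappointed. That means"))
# okay yeah what is the oh okay self disappointed that means
print(normalise_for_comparison("not having to share with a in-laws yeah, then like"))
# not having to share with a in laws yeah then like
```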

## Acknowledgements
This notebook follows the methods outlined in the Medium post ["Finetuning Whisper for the Singaporean Home Team Context"](https://medium.com/htx-dsai/finetuning-whisper-for-the-singaporean-home-team-context-a3ae1a6ae809) by Rachel LW Tan and the blog post ["Singlish-Whisper: Finetuning ASR for Singapore's Unique English"](https://www.jensenlwt.com/blog/singlish-whisper-finetuning-asr-for-singapore-unique-english) by Jenson Low. Both authors' approaches to finetuning OpenAI's Whisper model were applied in this notebook, with some minor modifications tailored to my use case.