# **Whisper Model fine-tuning for ASR**

In this file, you can easily fine-tune different variations of the Whisper model to your specific multilingual data based on a simple manifest. 

In [None]:
!pip install datasets
!pip install evaluate
!pip install --upgrade librosa

In [1]:
from datasets import Dataset
import pandas as pd
from datasets import Audio
import gc
import evaluate
import torch
import csv

If you want to run the code on Google Colab, you can mount your Google Drive with the cell below. If you are running locally, you don't need to run this cell.

In [3]:
# Mount Google Drive when running on Google Colab.
# The `google.colab` package only exists inside Colab, so on a local Jupyter
# kernel the import fails; in that case the mount is skipped instead of
# raising and aborting a "Run All".
try:
    from google.colab import drive
except ImportError:
    print("google.colab is not available; skipping Google Drive mount.")
else:
    drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# **Data preparation**
In this section, you can create the train.csv and test.csv files based on Kaldi's file. If you do not want to use Kaldi's files, you can easily write your code to generate train.csv and test.csv in this format:<br>
(one audio file and its corresponding text in each row)

     Path 	                                                         Sentence
______________________________________________________________________
/home/rf/kaldi/egs/Aref_mini_librispeech/s5/corpus/010/s010u128w.wav&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;lots of foreign movies have subtitles<br>
/home/rf/kaldi/egs/Aref_mini_librispeech/s5/corpus/029/s009u423n.wav&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;smash lightbulbs and their cash value will diminish to nothing<br>
/home/rf/kaldi/egs/Aref_mini_librispeech/s5/corpus/011/s011u328w.wav&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;gregory and tom chose to watch cartoons in the afternoon<br>
/home/rf/kaldi/egs/Aref_mini_librispeech/s5/corpus/018/s018u445w.wav&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;what is this large thing by the ironing board<br>
/home/rf/kaldi/egs/Aref_mini_librispeech/s5/corpus/017/s017u372w.wav&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;that diagram makes sense only after much study<br>
/home/rf/kaldi/egs/Aref_mini_librispeech/s5/corpus/023/s003u207n.wav&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;do you hear the sleigh bells ringing<br> 
..............................

In [None]:
## Preparing train.csv from the Kaldi files "wav_train.scp" and "text_train"
#
# Both Kaldi files carry the utterance ID in their first column:
#   wav_train.scp : <utt-id> <audio-path>
#   text_train    : <utt-id> <word> <word> ...
# Rows are paired by that ID rather than by line position, so an utterance that
# is missing from one file (or has an empty transcript) is skipped on its own
# instead of shifting every following transcript onto the wrong audio file.

# utt-id -> audio path, kept in wav_train.scp order (dicts preserve insertion order)
train_audio_paths = {}
with open('wav_train.scp', 'r', encoding='utf-8') as scp_file:
    for line in scp_file:
        fields = line.split()
        if len(fields) >= 2:
            train_audio_paths[fields[0]] = fields[1]  # second column is the audio file path

# utt-id -> transcript (every token after the ID, joined with single spaces)
train_transcripts = {}
with open('text_train', 'r', encoding='utf-8') as text_file:
    for line in text_file:
        fields = line.split()
        if len(fields) > 1:
            train_transcripts[fields[0]] = ' '.join(fields[1:])

# Keep only utterances that have both an audio path and a transcript.
train_rows = [
    (audio_path, train_transcripts[utt_id])
    for utt_id, audio_path in train_audio_paths.items()
    if utt_id in train_transcripts
]

# UTF-8 is set explicitly so non-ASCII transcripts are written the same way on
# every platform and can be read back by pd.read_csv with its default encoding.
with open('train.csv', 'w', newline='', encoding='utf-8') as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(['Path', 'Sentence'])  # header row
    writer.writerows(train_rows)

print(
    f"train.csv: wrote {len(train_rows)} rows "
    f"({len(train_audio_paths) - len(train_rows)} audio entries had no transcript, "
    f"{len(train_transcripts) - len(train_rows)} transcripts had no audio entry)"
)

In [None]:
## Preparing test.csv from the Kaldi files "wav_test.scp" and "text_test"
#
# Both Kaldi files carry the utterance ID in their first column:
#   wav_test.scp : <utt-id> <audio-path>
#   text_test    : <utt-id> <word> <word> ...
# Rows are paired by that ID rather than by line position, so an utterance that
# is missing from one file (or has an empty transcript) is skipped on its own
# instead of shifting every following transcript onto the wrong audio file.

# utt-id -> audio path, kept in wav_test.scp order (dicts preserve insertion order)
test_audio_paths = {}
with open('wav_test.scp', 'r', encoding='utf-8') as scp_file:
    for line in scp_file:
        fields = line.split()  # whitespace-separated columns
        if len(fields) >= 2:
            test_audio_paths[fields[0]] = fields[1]  # second column is the audio file path

# utt-id -> transcript (every token after the ID, joined with single spaces)
test_transcripts = {}
with open('text_test', 'r', encoding='utf-8') as text_file:
    for line in text_file:
        fields = line.split()  # whitespace-separated columns
        if len(fields) > 1:
            test_transcripts[fields[0]] = ' '.join(fields[1:])

# Keep only utterances that have both an audio path and a transcript.
test_rows = [
    (audio_path, test_transcripts[utt_id])
    for utt_id, audio_path in test_audio_paths.items()
    if utt_id in test_transcripts
]

# UTF-8 is set explicitly so non-ASCII transcripts are written the same way on
# every platform and can be read back by pd.read_csv with its default encoding.
with open('test.csv', 'w', newline='', encoding='utf-8') as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(['Path', 'Sentence'])  # header row
    writer.writerows(test_rows)

print(
    f"test.csv: wrote {len(test_rows)} rows "
    f"({len(test_audio_paths) - len(test_rows)} audio entries had no transcript, "
    f"{len(test_transcripts) - len(test_rows)} transcripts had no audio entry)"
)

# **Model Training**
You can easily train the Whisper model using the train.csv and test.csv files.
If you don't have enough GPU memory, you can use lighter versions of the Whisper model like "tiny" and "base" instead of the "small" version.
Please change the files' location to yours. In addition, you can switch the "language" parameter to your dataset's language.

In [2]:
## Read both CSV manifests (columns "Path" and "Sentence") into pandas.
train_df = pd.read_csv("/home/rf/makingdata_for_wav2vec/train.csv")
test_df = pd.read_csv("/home/rf/makingdata_for_wav2vec/test.csv")

## Per split: wrap the dataframe in a datasets.Dataset, then cast the "Path"
## column to an Audio feature with a 16 kHz sampling rate.
train_dataset = Dataset.from_pandas(train_df).cast_column(
    "Path", Audio(sampling_rate=16000)
)
test_dataset = Dataset.from_pandas(test_df).cast_column(
    "Path", Audio(sampling_rate=16000)
)

In [3]:
## Whisper pre-processing components: feature extractor, tokenizer and processor.
from transformers import WhisperFeatureExtractor, WhisperProcessor, WhisperTokenizer

# Single place to choose the checkpoint, language and task, so the three
# components below can never be loaded from mismatching settings.
# Keep WHISPER_CHECKPOINT in sync with the checkpoint name passed to
# WhisperForConditionalGeneration.from_pretrained later in the notebook.
WHISPER_CHECKPOINT = "openai/whisper-small"  # e.g. "openai/whisper-base" or "openai/whisper-tiny"
WHISPER_LANGUAGE = "English"  # switch to your dataset's language
WHISPER_TASK = "transcribe"

## Feature extractor: turns raw audio into the model's input features
feature_extractor = WhisperFeatureExtractor.from_pretrained(WHISPER_CHECKPOINT)

## Tokenizer: encodes target text to label ids for the chosen language/task
tokenizer = WhisperTokenizer.from_pretrained(
    WHISPER_CHECKPOINT, language=WHISPER_LANGUAGE, task=WHISPER_TASK
)

## Processor: bundles a feature extractor and a tokenizer behind one object
processor = WhisperProcessor.from_pretrained(
    WHISPER_CHECKPOINT, language=WHISPER_LANGUAGE, task=WHISPER_TASK
)

In [5]:
def prepare_dataset(examples):
    """Convert one manifest row into model inputs, in place.

    Args:
        examples: A single dataset example (despite the plural name) with a
            "Path" entry holding the decoded audio (a mapping with "array"
            and "sampling_rate") and a "Sentence" entry holding the transcript.

    Returns:
        The same mapping object, with "Path" and "Sentence" removed and two
        entries added: "input_features" (first item of the feature
        extractor's output) and "labels" (token ids of the transcript).
    """
    audio = examples["Path"]
    # Pass the audio's own sampling rate instead of a hard-coded 16000, so a
    # mismatch with what the feature extractor expects is not hidden.
    examples["input_features"] = feature_extractor(
        audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    del examples["Path"]

    # encode target text to label ids
    examples["labels"] = tokenizer(examples["Sentence"]).input_ids
    del examples["Sentence"]
    return examples

In [None]:
# Run prepare_dataset over every row of each split in a single process.
# Note: prepare_dataset deletes the "Path" and "Sentence" entries, so this cell
# cannot be re-run on the already-mapped datasets; reload them first.
train_dataset = train_dataset.map(function=prepare_dataset, num_proc=1)
test_dataset = test_dataset.map(function=prepare_dataset, num_proc=1)

In [7]:

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate examples with "input_features" and "labels" into one padded batch.

    Attributes:
        processor: Object exposing `feature_extractor.pad(...)` and
            `tokenizer.pad(...)`, plus `tokenizer.bos_token_id` and
            `tokenizer.convert_tokens_to_ids(...)`.
        decoder_start_token_id: Token id the model prepends to its decoder
            inputs. When left as None it is looked up from the tokenizer as
            the id of "<|startoftranscript|>".
    """
    processor: Any
    decoder_start_token_id: Optional[int] = None

    def _leading_token_ids(self) -> List[int]:
        """Return the token ids that must be cut when they lead every label row."""
        tokenizer = self.processor.tokenizer
        start_id = self.decoder_start_token_id
        if start_id is None:
            start_id = tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
        # bos_token_id is kept as a candidate so the original behaviour
        # (cutting a leading bos token) is preserved.
        return [tid for tid in (start_id, tokenizer.bos_token_id) if tid is not None]

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad audio features and labels separately and return them as one batch.

        Args:
            features: Examples, each with "input_features" and "labels".

        Returns:
            The padded feature batch with a "labels" tensor added, in which
            padded positions are set to -100.
        """
        # Inputs and labels have different lengths and need different padding
        # methods, so they are padded independently.
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # pad the tokenized label sequences to the longest one in the batch
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 so those positions are ignored by the loss
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # Whisper labels start with "<|startoftranscript|>", which is not the
        # tokenizer's bos_token_id; checking bos_token_id alone never cuts it,
        # and the model then prepends the start token a second time when it
        # derives decoder inputs from the labels. Cut the leading token when
        # every row starts with the decoder start token (or with bos).
        first_tokens = labels[:, 0]
        for token_id in self._leading_token_ids():
            if (first_tokens == token_id).all().cpu().item():
                labels = labels[:, 1:]
                break

        batch["labels"] = labels
        return batch

## Build the collator instance that is handed to the trainer, wrapping the processor loaded earlier.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor)

In [8]:
# Load the word error rate ("wer") metric from the `evaluate` library;
# compute_metrics calls metric.compute(...) on decoded predictions and references.
metric = evaluate.load("wer")

In [None]:
def compute_metrics(pred):
    """Compute the word error rate (in percent) for one evaluation pass.

    Args:
        pred: Object with `predictions` (generated token ids) and `label_ids`
            (reference token ids, with ignored positions marked as -100).

    Returns:
        A dict {"wer": value}, where value is 100 times the metric's result.
    """
    import numpy as np  # local import: numpy is not imported at the top of the notebook

    pred_ids = pred.predictions

    # Replace -100 with the pad token id so the labels can be decoded. A new
    # array is built instead of writing into pred.label_ids, so the caller's
    # label array still holds its -100 markers afterwards.
    label_ids = np.where(pred.label_ids == -100, tokenizer.pad_token_id, pred.label_ids)

    # decode both sides to text, dropping special tokens (including padding)
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}

In [None]:
# Load the pre-trained Whisper checkpoint that will be fine-tuned.
# Use the same checkpoint name as for the feature extractor, tokenizer and
# processor loaded earlier (e.g. swap "small" for "base" or "tiny" everywhere).
from transformers import WhisperForConditionalGeneration
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

In [11]:
# Clear the decoder token ids that the checkpoint's config would force, and
# empty its list of suppressed tokens.
# NOTE(review): presumably so that generation during evaluation is not
# constrained by the checkpoint defaults — confirm this is still the right
# mechanism for the installed transformers version.
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

Based on your dataset and Whisper model size, you can change the training parameters.

In [16]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Training hyper-parameters, grouped by concern. Values are unchanged; only
# the order of the keyword arguments differs.
training_args = Seq2SeqTrainingArguments(
    # --- output and logging ---
    output_dir="./whisper-small",  # change to a repo name of your choice
    report_to=["tensorboard"],
    push_to_hub=False,
    # --- batch size and optimisation ---
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=15000,
    # --- memory and precision ---
    gradient_checkpointing=True,
    fp16=True,
    # --- evaluation and checkpointing (both every 500 steps) ---
    evaluation_strategy="steps",
    eval_steps=500,
    save_steps=500,
    predict_with_generate=True,
    generation_max_length=225,
    # --- best-model selection: "wer", where a lower value is better ---
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
)

# Wire the model, datasets, collator and metric function into the trainer.
# The feature extractor (not the tokenizer) is what gets passed as `tokenizer`.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)



In [None]:
# Start fine-tuning; the returned object is kept because the next cell reads
# its `metrics` attribute.
print("step1 - training...")
train_result = trainer.train()

In [None]:
print("step2 - ")

# Metrics reported by trainer.train(), extended with the training-set size.
metrics = train_result.metrics
print("step3")
metrics["train_samples"] = len(train_dataset)

print("step4 - saving the model...")
trainer.save_model()
# The trainer was constructed with only the feature extractor as its
# `tokenizer`, so the full processor (feature extractor + tokenizer) is saved
# explicitly next to the model; this lets the output directory be loaded
# later without going back to the original checkpoint for the tokenizer.
processor.save_pretrained(trainer.args.output_dir)
print("model created!")

# Log and persist the training metrics, then the trainer state.
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()