# Custom Dataset Creation using Wav2Vec2

### Author: RENAUD Thomas Marc Frederic
### Last Modified: Dec. 3, 2023

**Abstract:**
This project draws inspiration from Samrat Dutta's paper on "Error Correction in ASR using Sequence-to-Sequence Models." [1]. The primary objective of this notebook is to fine-tune a Bidirectional Auto-Regressive Transformer (BART) [2] to rectify errors introduced by an ASR model. By introducing a model with linguistic intuition, the aim is to provide a supplementary layer of comprehension that can navigate and rectify the nuanced errors arising in ASR transcriptions.

**Acknowledgment:**
This notebook builds upon the documentation created by Hugging Face, accessible [here](https://huggingface.co/docs/transformers/model_doc/bart).

**Objective:**
The code provided in this notebook is a comprehensive overview of the fine-tuning of BART, transforming it into a powerfull post-editing tool.

**Input:**
- Personal custom dataset
- BART pre-trained checkpoint

**Output:**
The result is overall word error rate of the model.

## **References**

[1] Samrat Dutta, Shreyansh Jain, Ayush Maheshwari, Souvik Pal, Ganesh Ramakrishnan, and
Preethi Jyothi. Error correction in asr using sequence-to-sequence models. 2022

[2] Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer
Levy, Ves Stoyanov, and Luke Zettlemoyer. Bart: Denoising sequence-to-sequence pre-training
for natural language generation, translation, and comprehension. 2019.

In [2]:
!pip install transformers==4.24.0

import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BartForConditionalGeneration, BartTokenizer, AdamW, get_cosine_schedule_with_warmup
from sklearn.model_selection import train_test_split
import random
from google.colab import drive
import pandas as pd
from IPython.display import clear_output

clear_output(wait=False)

In [3]:
drive.mount('/content/gdrive')

# Replace 'your_file.csv' with the actual filename
file_path = '/content/gdrive/MyDrive/ENSC/3A/CI/wav2vec2-timit-asr-error.csv'

# Read the CSV file into a Pandas DataFrame
df_dataset = pd.read_csv(file_path, delimiter=',')
df_dataset

Mounted at /content/gdrive


Unnamed: 0,pred_str,text
0,the bunglo was plesntly situated near the shor,The bungalow was pleasantly situated near the ...
1,don't ask me to carry an oily rag like that,Don't ask me to carry an oily rag like that.
2,are you looking for employment,Are you looking for employment?
3,she had your dark suit in greasy wash water al...,She had your dark suit in greasy wash water al...
4,at twilight on the twelfth day we'll have shible,At twilight on the twelfth day we'll have Chab...
...,...,...
1675,pam gives driving lessons on thorse days,Pam gives driving lessons on Thursdays.
1676,he rubbed his eyes sleepily with one huge paw,He rubbed his eyes sleepily with one huge paw.
1677,ait felguns were captured in position,Eight field guns were captured in position.
1678,alloinam silveoware can often be flimsy,Aluminum silverware can often be flimsy.


In [4]:
import re

chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"]'

def format_output(text):
    # function to substitute all occurrences of the characters specified by
    # chars_to_ignore_regex with an empty string ('') in the text
    return re.sub(chars_to_ignore_regex, '', text).lower()

df_dataset["text"] = df_dataset["text"].apply(format_output)

In [5]:
X = df_dataset['pred_str']
y = df_dataset['text']

# Split the data into training and testing sets "random_state" is set for reproducibility.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Create separate DataFrames for training and testing sets
train_df = pd.DataFrame({'pred_str': X_train, 'text': y_train})
test_df = pd.DataFrame({'pred_str': X_test, 'text': y_test})

test_df

Unnamed: 0,pred_str,text
1603,destroy every file related to my audice,destroy every file related to my audits
482,are your grades higher or lower than nansies,are your grades higher or lower than nancy's
203,she had your dark suit in greasy wash water al...,she had your dark suit in greasy wash water al...
49,dissymber and jenuaerry are nice months to spi...,december and january are nice months to spend ...
937,the misquoat was retracted with an appoligy,the misquote was retracted with an apology
...,...,...
226,how way and freeway mean the same thing,highway and freeway mean the same thing
231,he remembered them well from his happy period,he remembered them well from his happy period
650,i'd rather not buy these shoes than be overcha...,i'd rather not buy these shoes than be overcha...
1511,heve on those roapes the boats come on stuck,heave on those ropes the boat's come unstuck


In [6]:
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

# Function resposible for the tokenization of the dataset
def preprocess_function(df):
    model_inputs = tokenizer(df["pred_str"].to_list(),
                             padding=True,
                             truncation=True,
                             return_tensors='pt')

    # Setup the tokenizer for targets
    labels = tokenizer(text_target=df["text"].to_list(),
                      padding=True,
                      truncation=True,
                      return_tensors='pt')

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_train_dataset = preprocess_function(train_df)
tokenized_test_dataset = preprocess_function(test_df)

clear_output(wait=False)
tokenized_train_dataset.keys()

dict_keys(['input_ids', 'attention_mask', 'labels'])

In [7]:
# Creation of Pytorch dataset
class PostEditingDataset(Dataset):
    def __init__(self, input_ids, attention_mask, labels):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.labels = labels

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {
            'input_ids': self.input_ids[idx],
            'attention_mask': self.attention_mask[idx],
            'labels': self.labels[idx]
        }

# Create datasets
train_dataset = PostEditingDataset(
    input_ids=tokenized_train_dataset['input_ids'],
    attention_mask=tokenized_train_dataset['attention_mask'],
    labels=tokenized_train_dataset['labels']
)

test_dataset = PostEditingDataset(
    input_ids=tokenized_test_dataset['input_ids'],
    attention_mask=tokenized_test_dataset['attention_mask'],
    labels=tokenized_test_dataset['input_ids']
)

# Set batch size
batch_size = 12

# Create dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [8]:
bart_model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')

model.safetensors:   0%|          | 0.00/558M [00:00<?, ?B/s]

Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
# Training hyperparameters
num_epochs = 10
optimizer = torch.optim.AdamW(bart_model.parameters(), lr=3e-5, weight_decay=0.1)
total_steps = len(train_dataloader) * num_epochs
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=int(0.1 * total_steps), num_training_steps=total_steps)

In [10]:
import time

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
bart_model.to(device)
bart_model.train()

for epoch in range(num_epochs):
    start_time = time.time()
    train_loss = 0.0
    for batch in train_dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = bart_model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        train_loss += loss.item()

    end_time = time.time()
    epoch_time = end_time - start_time

    avg_loss = train_loss / len(train_dataloader)
    print(f'Epoch: {epoch + 1}, Loss: {avg_loss}, Time: {epoch_time} seconds')


Epoch: 1, Loss: 5.200909559215818, Time: 18.676990270614624 seconds
Epoch: 2, Loss: 0.7315891428983637, Time: 15.404480457305908 seconds
Epoch: 3, Loss: 0.3424471190997532, Time: 13.8198401927948 seconds
Epoch: 4, Loss: 0.23476180134873306, Time: 13.635308504104614 seconds
Epoch: 5, Loss: 0.17101185276572192, Time: 15.968344926834106 seconds
Epoch: 6, Loss: 0.12839822926824646, Time: 15.627381086349487 seconds
Epoch: 7, Loss: 0.10880954136207167, Time: 14.00828766822815 seconds
Epoch: 8, Loss: 0.09404908478193517, Time: 15.180054187774658 seconds
Epoch: 9, Loss: 0.08544212452501856, Time: 15.144948244094849 seconds
Epoch: 10, Loss: 0.08447590550141675, Time: 15.431470155715942 seconds


In [11]:
!pip install jiwer
from torch.utils.data import DataLoader
from jiwer import wer

clear_output(wait=False)

In [12]:
def test_model(model, dataloader, device):
    model.eval()
    total_predictions = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            model.to(device)

            # Generate predictions
            generated_ids = bart_model.generate(input_ids, attention_mask=attention_mask, max_length=50)

            # Decode generated sequences
            predictions = [tokenizer.decode(gen_id, skip_special_tokens=True) for gen_id in generated_ids]

            total_predictions.extend(predictions)

    return total_predictions

def calculate_word_error_rate(predictions, labels):
    # Calculate Word Error Rate (WER) using jiwer library
    error_rate = wer(predictions, labels)
    return error_rate

# Test the model
predictions = test_model(bart_model, test_dataloader, device)

# Calculate global error rate
word_error_rate = calculate_word_error_rate(predictions, test_df["text"].to_list())

print(f'Word Error Rate (WER): {word_error_rate}')

Word Error Rate (WER): 0.09959072305593451


In [13]:
data = {'Predictions': predictions, 'Ground Truth': test_df["text"].to_list()}
df_predictions_ground_truth = pd.DataFrame(data)

df_predictions_ground_truth

Unnamed: 0,Predictions,Ground Truth
0,destroy every file related to my audits,destroy every file related to my audits
1,are your grades higher or lower than nancy's,are your grades higher or lower than nancy's
2,she had your dark suit in greasy wash water al...,she had your dark suit in greasy wash water al...
3,december and january are nice months to spend ...,december and january are nice months to spend ...
4,the misquote was retracted with an apology,the misquote was retracted with an apology
...,...,...
331,how way and freeway mean the same thing,highway and freeway mean the same thing
332,he remembered them well from his happy period,he remembered them well from his happy period
333,i'd rather not buy these shoes than be overcha...,i'd rather not buy these shoes than be overcha...
334,heaven on those nights the boats come on stuck,heave on those ropes the boat's come unstuck


In [14]:
data = {'Original predictions': test_df["pred_str"], 'Post-editing': predictions}
df_post_editing_results = pd.DataFrame(data)

df_post_editing_results

Unnamed: 0,Original predictions,Post-editing
1603,destroy every file related to my audice,destroy every file related to my audits
482,are your grades higher or lower than nansies,are your grades higher or lower than nancy's
203,she had your dark suit in greasy wash water al...,she had your dark suit in greasy wash water al...
49,dissymber and jenuaerry are nice months to spi...,december and january are nice months to spend ...
937,the misquoat was retracted with an appoligy,the misquote was retracted with an apology
...,...,...
226,how way and freeway mean the same thing,how way and freeway mean the same thing
231,he remembered them well from his happy period,he remembered them well from his happy period
650,i'd rather not buy these shoes than be overcha...,i'd rather not buy these shoes than be overcha...
1511,heve on those roapes the boats come on stuck,heaven on those nights the boats come on stuck
