# Custom Dataset Creation using Wav2Vec2

### Author: RENAUD Thomas Marc Frederic
### Last Modified: Dec. 3, 2023

**Abstract:**
This project draws inspiration from the paper by Dutta et al., "Error Correction in ASR using Sequence-to-Sequence Models" [1]. Its primary objective is to fine-tune BART (Bidirectional and Auto-Regressive Transformers) to correct errors introduced by an automatic speech recognition (ASR) model. The absence of a dedicated dataset of plausible ASR errors is a challenge. To overcome it, we reuse an existing ASR dataset and simulate errors: a Wav2Vec2 model transcribes the audio recordings with greedy (language-model-free) decoding, and each imperfect transcription is paired with its reference transcript. This notebook implements that dataset-generation step.

**Acknowledgment:**
This notebook builds on the work of Patrick von Platen, available [here](https://huggingface.co/blog/fine-tune-wav2vec2-english). The original code has been adapted to this project's goals: debugging statements and dataset-generation code have been added, and every change from the original is marked with a `# TR` comment.

**Objective:**
The dataset produced by this notebook provides the training pairs needed to fine-tune BART into a powerful post-editing tool for ASR output.

**Input:**
- TIMIT dataset [2] (test split)
- Wav2Vec2 checkpoint fine-tuned on TIMIT (`patrickvonplaten/wav2vec2-base-timit-demo-google-colab`)

**Output:**
The result is a CSV file with two columns, `pred_str` (the Wav2Vec2 transcription, errors included) and `text` (the reference transcript), ready for the subsequent steps of the error-correction process.
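Downstream, each CSV row can serve as one BART training pair, with the ASR transcription as the source and the reference transcript as the target. The sketch below is an assumption about how the file might be consumed, not part of this notebook: it reads the CSV with Python's standard `csv` module, relies on the `pred_str`/`text` header written by `to_csv(..., index=False)`, and uses a hypothetical helper name, `load_error_pairs`. The two sample rows are copied from the results table in this notebook.

```python
import csv
import io
from typing import Iterable, List, Tuple


def load_error_pairs(lines: Iterable[str]) -> List[Tuple[str, str]]:
    """Return (asr_hypothesis, reference) pairs from the exported CSV.

    Hypothetical helper: assumes the header `pred_str,text` produced by
    `df_results.to_csv(file_path, index=False)`.
    """
    reader = csv.DictReader(lines)
    pairs = []
    for row in reader:
        # Source for BART = noisy ASR output, target = reference transcript
        pairs.append((row["pred_str"], row["text"]))
    return pairs


# In-memory sample; with the real file, pass open(file_path, newline="")
sample = io.StringIO(
    "pred_str,text\n"
    "are you looking for employment,Are you looking for employment?\n"
    "pam gives driving lessons on thorse days,Pam gives driving lessons on Thursdays.\n"
)
pairs = load_error_pairs(sample)
print(pairs[1])  # ('pam gives driving lessons on thorse days', 'Pam gives driving lessons on Thursdays.')
```

Keeping the reader in the standard library avoids a pandas dependency in the training script; `csv` also handles quoted fields if a transcript ever contains a comma.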

## **References**

[1] Samrat Dutta, Shreyansh Jain, Ayush Maheshwari, Souvik Pal, Ganesh Ramakrishnan, and Preethi Jyothi. Error correction in ASR using sequence-to-sequence models. 2022.

[2] John S. Garofolo, Lori F. Lamel, William M. Fisher, Jonathan G. Fiscus, David S. Pallett, Nancy L. Dahlgren, and Victor Zue. TIMIT acoustic-phonetic continuous speech corpus. 1993.

```python
!pip install datasets==1.18.3
!pip install transformers==4.24.0
!pip install huggingface_hub==0.11
!pip install jiwer  # required by the "wer" metric

import torch
import numpy as np

from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from datasets import load_dataset, load_metric
from IPython.display import clear_output

clear_output(wait=False)  # hide the pip logs
```

```python
# Processor = feature extractor + CTC tokenizer; model = Wav2Vec2 fine-tuned on TIMIT
checkpoint = "patrickvonplaten/wav2vec2-base-timit-demo-google-colab"
processor = Wav2Vec2Processor.from_pretrained(checkpoint)
model = Wav2Vec2ForCTC.from_pretrained(checkpoint).cuda()  # requires a GPU runtime
clear_output(wait=False) # TR
```

```python
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    """Display `num_examples` distinct, randomly chosen rows of a dataset as an HTML table."""
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        # redraw until we get an index that has not been picked yet
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)

    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))
```

```python
# Load the TIMIT test split and keep only the audio and its transcript
timit = load_dataset("timit_asr", split="test")
timit = timit.remove_columns(["phonetic_detail", "word_detail", "dialect_region", "id", "sentence_type", "speaker_id"])

def prepare_dataset(batch):
    audio = batch["audio"]

    # batched output is "un-batched" to ensure mapping is correct
    batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    batch["input_length"] = len(batch["input_values"])

    # The original tokenized the transcript inside `processor.as_target_processor()`:
    # batch["labels"] = processor(batch["text"]).input_ids
    # Here the raw transcript is kept instead, since it is the reference text written to the CSV.
    batch["labels"] = batch["text"] # TR
    return batch

timit_preprocess = timit.map(prepare_dataset, remove_columns=timit.column_names, num_proc=4)
clear_output(wait=False) # TR
print(timit_preprocess) # TR
show_random_elements(timit_preprocess, 1) # TR
```

```
Dataset({
    features: ['input_values', 'input_length', 'labels'],
    num_rows: 1680
})
```


| | input_values | input_length | labels |
|---|---|---|---|
| 0 | [-0.029056508094072342, -0.0010434247087687254, -0.011548331938683987, -0.0010434247087687254, 0.00420902855694294, ...] | 41677 | Weatherproof galoshes are very useful in Seattle. |
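The `input_values` above are not raw PCM samples: the Wav2Vec2 feature extractor can standardize each utterance to zero mean and unit variance before it reaches the model. The sketch below reproduces that per-utterance standardization with NumPy. It assumes the checkpoint's feature extractor was saved with `do_normalize=True` and uses a small epsilon of `1e-7` for numerical stability; both are assumptions, and the function name `standardize` is hypothetical. A synthetic sine wave stands in for a TIMIT recording.

```python
import numpy as np


def standardize(waveform: np.ndarray, eps: float = 1e-7) -> np.ndarray:
    """Zero-mean, unit-variance scaling of one utterance (assumes do_normalize=True)."""
    waveform = np.asarray(waveform, dtype=np.float32)
    return (waveform - waveform.mean()) / np.sqrt(waveform.var() + eps)


# Synthetic 1-second, 16 kHz signal with a DC offset, standing in for real audio
t = np.arange(16_000) / 16_000
audio = 0.1 * np.sin(2 * np.pi * 220 * t) + 0.02
values = standardize(audio)
print(float(values.mean()), float(values.std()))  # close to 0 and 1
```

Standardizing per utterance removes loudness and DC-offset differences between speakers and recording sessions, which is why the values hover around zero regardless of the original amplitude.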


```python
def map_to_result(batch):
    with torch.no_grad():
        input_values = torch.tensor(batch["input_values"], device=model.device).unsqueeze(0)
        logits = model(input_values).logits
        # print(logits.size()) # TR

    # Greedy CTC decoding: most likely token per frame, then collapse into text
    pred_ids = torch.argmax(logits, dim=-1)
    batch["pred_str"] = processor.batch_decode(pred_ids)[0]
    batch["text"] = batch["labels"] # TR
    # batch["text"] = processor.decode(batch["labels"], group_tokens=False)

    return batch

results = timit_preprocess.map(map_to_result, remove_columns=timit_preprocess.column_names)

print(results) # TR
```


```
Dataset({
    features: ['pred_str', 'text'],
    num_rows: 1680
})
```
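The commented-out `print(logits.size())` in `map_to_result` would show a tensor of shape `(1, frames, vocab_size)`. The number of frames is fixed by the convolutional feature encoder, which downsamples the 16 kHz waveform. The sketch below computes it for the random example displayed earlier (`input_length` = 41677 samples), assuming the default `Wav2Vec2Config` values for the base model, convolution kernels `(10, 3, 3, 3, 3, 2, 2)` and strides `(5, 2, 2, 2, 2, 2, 2)`, and applying the unpadded convolution length formula layer by layer. The helper names are hypothetical.

```python
SAMPLE_RATE = 16_000  # TIMIT sampling rate in Hz

# Assumed Wav2Vec2Config defaults for the base model's feature encoder
CONV_KERNEL = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDE = (5, 2, 2, 2, 2, 2, 2)


def duration_seconds(num_samples: int, sample_rate: int = SAMPLE_RATE) -> float:
    """Length of the recording in seconds."""
    return num_samples / sample_rate


def num_logit_frames(num_samples: int) -> int:
    """Time steps after the convolutional feature encoder (no padding)."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNEL, CONV_STRIDE):
        # output length of a 1-D convolution without padding
        length = (length - kernel) // stride + 1
    return length


input_length = 41677  # value shown for the random example
print(f"{duration_seconds(input_length):.2f} s -> {num_logit_frames(input_length)} frames")
# prints: 2.60 s -> 129 frames
```

The overall stride is 5 × 2⁶ = 320 samples, so each logit frame covers roughly 20 ms of audio; one second of audio (`num_logit_frames(16_000)`) yields 49 frames.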


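```python
wer_metric = load_metric("wer")

# Kept from the original training notebook, where the Trainer calls it during evaluation.
# It is not used below: here the WER is computed directly on the decoded strings.
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

# Note: references keep their casing and punctuation, predictions do not,
# so such differences are counted as word errors
print("Test WER: {:.3f}".format(wer_metric.compute(predictions=results["pred_str"], references=results["text"])))
```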


```
Test WER: 0.408
```
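Word error rate is the word-level edit distance between hypothesis and reference divided by the number of reference words, WER = (S + D + I) / N, with S substitutions, D deletions, I insertions and N reference words. Because the references keep their casing and punctuation while the model outputs lowercase text without punctuation, a correctly recognized word such as `Don't` vs `don't` still counts as a substitution. The self-contained sketch below implements corpus-level WER (total errors over total reference words) and a hypothetical `normalize` helper that lowercases text and strips punctuation other than apostrophes. That normalization rule is an assumption, and it was not applied when computing the 0.408 above.

```python
import re
from typing import List


def normalize(text: str) -> str:
    """Hypothetical normalization: lowercase, keep letters, digits, spaces, apostrophes."""
    text = re.sub(r"[^a-z0-9' ]", "", text.lower())
    return " ".join(text.split())


def word_errors(reference: str, hypothesis: str) -> int:
    """Minimum substitutions + deletions + insertions (word-level Levenshtein)."""
    ref, hyp = reference.split(), hypothesis.split()
    # prev[j] = distance between the reference prefix processed so far and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,             # deletion of reference word r
                curr[j - 1] + 1,         # insertion of hypothesis word h
                prev[j - 1] + (r != h),  # substitution, or match when equal
            )
        prev = curr
    return prev[-1]


def wer(references: List[str], hypotheses: List[str]) -> float:
    """Corpus WER: total word errors over total reference words."""
    errors = sum(word_errors(r, h) for r, h in zip(references, hypotheses))
    return errors / sum(len(r.split()) for r in references)


# Row 1 of the results table: only casing and punctuation differ
refs = ["Don't ask me to carry an oily rag like that."]
hyps = ["don't ask me to carry an oily rag like that"]
print(wer(refs, hyps))  # 0.2 (2 substitutions out of 10 words)
print(wer([normalize(r) for r in refs], [normalize(h) for h in hyps]))  # 0.0
```

Normalizing both sides before scoring measures only recognition errors. For the BART dataset itself, the raw references may still be the better targets, since restoring case and punctuation is a correction the post-editor could learn.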


```python
# Convert the results to a pandas DataFrame
df_results = results.to_pandas() # TR
df_results # TR
```

| | pred_str | text |
|---|---|---|
| 0 | the bunglo was plesntly situated near the shor | The bungalow was pleasantly situated near the ... |
| 1 | don't ask me to carry an oily rag like that | Don't ask me to carry an oily rag like that. |
| 2 | are you looking for employment | Are you looking for employment? |
| 3 | she had your dark suit in greasy wash water al... | She had your dark suit in greasy wash water al... |
| 4 | at twilight on the twelfth day we'll have shible | At twilight on the twelfth day we'll have Chab... |
| ... | ... | ... |
| 1675 | pam gives driving lessons on thorse days | Pam gives driving lessons on Thursdays. |
| 1676 | he rubbed his eyes sleepily with one huge paw | He rubbed his eyes sleepily with one huge paw. |
| 1677 | ait felguns were captured in position | Eight field guns were captured in position. |
| 1678 | alloinam silveoware can often be flimsy | Aluminum silverware can often be flimsy. |
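The `pred_str` column comes from greedy CTC decoding: `processor.batch_decode(pred_ids)` merges consecutive repeated tokens, drops the blank token, and turns the word-delimiter token into a space. The pure-Python sketch below illustrates that rule on a toy vocabulary. The vocabulary and the constant names are hypothetical; the only assumption carried over from the Wav2Vec2 CTC tokenizer is its default of `<pad>` as the blank and `|` as the word delimiter. The same rule explains why `compute_metrics` decodes labels with `group_tokens=False`: reference labels must not have their repeated letters merged.

```python
from itertools import groupby
from typing import List

# Hypothetical toy vocabulary; "<pad>" plays the role of the CTC blank
VOCAB = ["<pad>", "|", "e", "h", "l", "o", "w", "r", "d"]
BLANK = "<pad>"
WORD_DELIM = "|"


def greedy_ctc_decode(frame_ids: List[int]) -> str:
    """Collapse per-frame argmax ids into a transcript (greedy CTC)."""
    tokens = [VOCAB[i] for i in frame_ids]
    # 1) merge runs of the same token emitted on consecutive frames
    merged = [token for token, _ in groupby(tokens)]
    # 2) drop blanks; a blank between two identical letters keeps both
    letters = [token for token in merged if token != BLANK]
    # 3) the word delimiter becomes a space
    return "".join(" " if t == WORD_DELIM else t for t in letters).strip()


# Frames spelling "hello world": a blank separates the two "l"s of "hello"
frames = [3, 3, 2, 4, 4, 0, 4, 5, 0, 1, 1, 6, 5, 7, 4, 4, 8, 0]
print(greedy_ctc_decode(frames))  # hello world
```

Because each frame is decided independently and no language model rescoring is applied, spelling-like errors such as `bunglo` or `thorse days` in the table above are typical of this decoding scheme, which is what makes the output useful as simulated ASR errors.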


```python
# Export the DataFrame to a CSV file on Google Drive (Colab only)
from google.colab import drive

drive.mount('/content/gdrive') # TR

file_path = '/content/gdrive/MyDrive/ENSC/3A/CI/wav2vec2-timit-asr-error.csv' # TR

# Save the DataFrame to a CSV file
df_results.to_csv(file_path, index=False) # TR

# Print a message indicating the file has been saved
print(f'DataFrame has been saved to: {file_path}') # TR
```

```
Mounted at /content/gdrive
DataFrame has been saved to: /content/gdrive/MyDrive/ENSC/3A/CI/wav2vec2-timit-asr-error.csv
```
