# Wav2vec model using Mozilla Common Voice [VI] dataset

### Loading data

In [1]:
import os
import re
import random
import json
import pandas as pd
from datasets import load_dataset
import IPython.display as ipd
from IPython.display import HTML

#### Load audio dataset

In [2]:
# Load the validated-clips listing (tab-separated, first row is the header).
# The "csv" loader exposes the whole file as a single "train" split.
cv_corpus_dataset = load_dataset("csv", data_files="../../data/cv-corpus-17.0-2024-03-15/vi/validated.tsv", sep="\t", header=0)
# A fixed seed keeps the 80/20 split identical across kernel restarts, so the
# vocabulary, training run and reported WER are reproducible.
ds_train_test = cv_corpus_dataset["train"].train_test_split(test_size=0.2, seed=42)
# Drop the metadata columns; only "path" and "sentence" are used below.
ds_train_test = ds_train_test.remove_columns(
    ["client_id" ,"up_votes", "down_votes", "age",
     "gender", "accents", "variant", "locale", "segment",
     "sentence_id", "sentence_domain"])

In [3]:
# Show the split sizes and the columns that remain after the cleanup above.
ds_train_test

DatasetDict({
    train: Dataset({
        features: ['path', 'sentence'],
        num_rows: 4108
    })
    test: Dataset({
        features: ['path', 'sentence'],
        num_rows: 1027
    })
})

#### Preview random sentences

In [4]:
def show_random_elements(dataset, num_examples=10):
    """Render ``num_examples`` distinct, randomly chosen rows of ``dataset`` as an HTML table.

    Args:
        dataset: Object supporting ``len()`` and indexing with a list of row
            indices, whose result can be passed to ``pd.DataFrame``.
        num_examples: Number of rows to display; must not exceed ``len(dataset)``.

    Raises:
        AssertionError: If ``num_examples`` is larger than the dataset.
    """
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    # random.sample draws without replacement, so no manual duplicate check is needed.
    picks = random.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])
    ipd.display(HTML(df.to_html()))

show_random_elements(ds_train_test["train"].remove_columns(["path"]))

Unnamed: 0,sentence
0,quân khốn đáng ghét
1,Tháng mười một lòng chợt thấy chênh chao
2,Sau khi tất cả những căn nhà sàn được dựng lên
3,"mà con bé nó lại khỏe mạnh, đi lại bình thường"
4,Để cho anh tới xóa buồn buồn đi
5,Tháng mười một đẹp lại duyên
6,Có hai hang cây
7,Đi lấy hàng mấy hôm cho nên đâu có biết đâu
8,cháu còn đến nhà chép bài vở cho thảo
9,Hương thơm mát rượi


#### Remove special characters and convert text to lowercase

In [5]:
chars_to_ignore_regex = r'[()“”‘’",.;:?!‑–—_\-\t\n]'

def remove_special_characters(batch):
    """Strip punctuation from one example's transcript and lowercase it.

    Every character matched by ``chars_to_ignore_regex`` is deleted from
    ``batch["sentence"]`` (nothing is inserted in its place), then the result
    is lowercased. The example is updated in place and returned.
    """
    stripped = re.sub(chars_to_ignore_regex, '', batch["sentence"])
    batch["sentence"] = stripped.lower()
    return batch

# Clean the transcripts of both splits, one example at a time.
ds_train_test = ds_train_test.map(remove_special_characters)

Map:   0%|          | 0/4108 [00:00<?, ? examples/s]

Map:   0%|          | 0/1027 [00:00<?, ? examples/s]

In [6]:
# Spot-check the cleaned transcripts (the "path" column is hidden to keep the table narrow).
show_random_elements(ds_train_test["train"].remove_columns(["path"]))

Unnamed: 0,sentence
0,ngày nay người người nhà nhà đem kuman về thờ
1,ngoài con tất cả
2,dứt lời quân vào trong xe rồi lái đi mất
3,đưa tình về cuối sông
4,vẫn thấy thật vui
5,đèn nến và tất cả mọi thứ
6,con đã thêm cao
7,tất cả đánh xe vào trong làng
8,mong con lớn khôn
9,mòn mỏi hao gầy


#### Extract all characters

In [7]:
def extract_all_chars(batch):
    """Collect the distinct characters used by a batch of transcripts.

    The sentences in ``batch["sentence"]`` are joined with single spaces.
    Returns a dict with two one-element lists: ``"vocab"`` holds the list of
    unique characters of the joined text, ``"all_text"`` holds the joined text.
    """
    joined = " ".join(batch["sentence"])
    unique_chars = [*set(joined)]
    return dict(vocab=[unique_chars], all_text=[joined])


def extract_all_chars_to_json(dataset, vocab_path="../../modules/models/vocabs/wav2vec_cvcorpus_vn.json"):
    """Build the character vocabulary of both splits and write it as JSON.

    Characters from the "train" and "test" splits are merged, sorted and
    numbered from 0. The space character is stored under "|" instead, and
    "[UNK]" and "[PAD]" are appended with the next free ids.

    Args:
        dataset: Dataset dict with "train" and "test" splits that contain a
            "sentence" column.
        vocab_path: Destination JSON file; defaults to the location the
            tokenizer cell reads from. Missing parent directories are created.

    Raises:
        KeyError: If no transcript contains a space character.
    """
    # batch_size=-1 hands each split to extract_all_chars as one single batch.
    vocabs = dataset.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=dataset.column_names["train"])
    vocab_list = sorted(set(vocabs["train"]["vocab"][0]) | set(vocabs["test"]["vocab"][0]))
    vocab_dict = {char: idx for idx, char in enumerate(vocab_list)}
    # Re-key the space as "|" so the word delimiter is a visible token.
    vocab_dict["|"] = vocab_dict.pop(" ")
    vocab_dict["[UNK]"] = len(vocab_dict)
    vocab_dict["[PAD]"] = len(vocab_dict)
    print("Number of characters:", len(vocab_dict))
    print("Characters:", vocab_dict.keys())

    vocab_dir = os.path.dirname(vocab_path)
    if vocab_dir:
        # Avoid a FileNotFoundError on a fresh checkout where the folder is absent.
        os.makedirs(vocab_dir, exist_ok=True)
    with open(vocab_path, 'w', encoding="utf-8") as vocab_file:
        json.dump(vocab_dict, vocab_file)

extract_all_chars_to_json(ds_train_test)

Map:   0%|          | 0/4108 [00:00<?, ? examples/s]

Map:   0%|          | 0/1027 [00:00<?, ? examples/s]

Number of characters: 97
Characters: dict_keys(["'", 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', 'à', 'á', 'â', 'ã', 'è', 'é', 'ê', 'ì', 'í', 'ò', 'ó', 'ô', 'õ', 'ù', 'ú', 'ý', 'ă', 'đ', 'ĩ', 'ũ', 'ơ', 'ư', 'ạ', 'ả', 'ấ', 'ầ', 'ẩ', 'ẫ', 'ậ', 'ắ', 'ằ', 'ẳ', 'ẵ', 'ặ', 'ẹ', 'ẻ', 'ẽ', 'ế', 'ề', 'ể', 'ễ', 'ệ', 'ỉ', 'ị', 'ọ', 'ỏ', 'ố', 'ồ', 'ổ', 'ỗ', 'ộ', 'ớ', 'ờ', 'ở', 'ỡ', 'ợ', 'ụ', 'ủ', 'ứ', 'ừ', 'ử', 'ữ', 'ự', 'ỳ', 'ỵ', 'ỷ', 'ỹ', '|', '[UNK]', '[PAD]'])


#### Convert file path to local absolute path

In [8]:
path_to_audio = os.path.abspath("../../data/cv-corpus-17.0-2024-03-15/vi/clips")

def abs_path_to_file(batch):
    """Replace the example's clip file name with its absolute path.

    ``batch["path"]`` is joined onto ``path_to_audio``; the example is
    updated in place and returned.
    """
    clip_name = batch["path"]
    batch["path"] = os.path.join(path_to_audio, clip_name)
    return batch

# Resolve the clip path of every example in both splits.
ds_train_test = ds_train_test.map(abs_path_to_file)

Map:   0%|          | 0/4108 [00:00<?, ? examples/s]

Map:   0%|          | 0/1027 [00:00<?, ? examples/s]

### Set up processor

In [9]:
from transformers import Wav2Vec2CTCTokenizer
from transformers import Wav2Vec2FeatureExtractor
from transformers import Wav2Vec2Processor

#### Tokenizer

In [10]:
# Build the CTC tokenizer from the vocabulary JSON written by
# extract_all_chars_to_json; the special-token strings must match the
# "[UNK]", "[PAD]" and "|" keys stored in that file.
tokenizer = Wav2Vec2CTCTokenizer(
    "../../modules/models/vocabs/wav2vec_cvcorpus_vn.json",
    unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")

#### Feature extractor

In [11]:
# Feature extractor for raw waveforms: 16 kHz sampling rate (the rate clips are
# resampled to in preprocessing), padding with 0.0, normalization enabled, and
# no attention mask returned.
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1, sampling_rate=16_000, 
    padding_value=0.0, do_normalize=True, 
    return_attention_mask=False)

#### Processor

In [12]:
# Bundle the feature extractor (audio side) and tokenizer (text side) into one
# processor object used by preprocessing, collation, metrics and saving.
processor = Wav2Vec2Processor(
    feature_extractor=feature_extractor, 
    tokenizer=tokenizer)

### Audio preview

In [13]:
import torch
import torchaudio
from torchaudio.functional import resample

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Sample rate (Hz) that every clip is resampled to before it reaches the processor.
model_sample_rate = 16_000

#### Load random audio file from dataset

In [14]:
random_index = random.randint(0, len(ds_train_test["train"]) - 1)
# Fetch the one row we need instead of pulling the entire "path" and
# "sentence" columns just to read a single value from each.
sample = ds_train_test["train"][random_index]
speech_file = sample["path"]
print("Sentence:", sample["sentence"])
ipd.Audio(speech_file)

Sentence: quân mở mắt ra mới biết mình vừa nằm mơ


#### Resampling to match model sample rate

In [15]:
# Decode the chosen clip and place the samples on the selected device.
waveform, sample_rate = torchaudio.load(speech_file)
waveform = waveform.to(device)

# Resample only when the clip's rate differs from the model's.
if sample_rate != model_sample_rate:
    waveform = resample(waveform, orig_freq=sample_rate, new_freq=model_sample_rate)
# Play back the first channel at the model's sample rate.
first_channel = waveform[0].tolist()
ipd.Audio(data=first_channel, rate=model_sample_rate)

### Preprocessing

In [16]:
def prepare_dataset(batch):
    """Turn one example (clip path + transcript) into model inputs and CTC labels.

    Args:
        batch: Single example with "path" (audio file path) and "sentence"
            (cleaned transcript).

    Returns:
        The same example with "input_values" (processed samples of the first
        audio channel at ``model_sample_rate``) and "labels" (token ids of the
        transcript) added.
    """
    # Decode the audio file into a waveform tensor.
    waveform, sample_rate = torchaudio.load(batch["path"])

    # Resample only when the clip's rate differs from the model's.
    if sample_rate != model_sample_rate:
        waveform = torchaudio.functional.resample(waveform, sample_rate, model_sample_rate)

    # Only the first channel is passed on.
    batch["input_values"] = processor(waveform[0].tolist(), sampling_rate=model_sample_rate).input_values[0]

    # Call the tokenizer directly: this is what the deprecated
    # `processor.as_target_processor()` context manager routed the call to.
    batch["labels"] = processor.tokenizer(batch["sentence"]).input_ids
    return batch

In [17]:
# Preprocess every example; the original columns are dropped so only the
# fields added by prepare_dataset ("input_values", "labels") remain.
ds_train_test = ds_train_test.map(prepare_dataset, remove_columns=ds_train_test.column_names["train"])

Map:   0%|          | 0/4108 [00:00<?, ? examples/s]



Map:   0%|          | 0/1027 [00:00<?, ? examples/s]

### Train the model

#### Data Collator CTC

In [18]:
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    """Collate preprocessed examples into one padded batch for CTC training.

    Audio inputs and label ids have different lengths, so they are padded
    separately: inputs through the processor, labels through its tokenizer.
    Padded label positions are set to -100 so the loss ignores them.
    """

    # Processor providing the feature extractor (inputs) and tokenizer (labels).
    processor: Wav2Vec2Processor
    # Padding strategy forwarded unchanged to both pad() calls.
    padding: Union[bool, str] = True
    # Optional maximum length for the padded inputs.
    max_length: Optional[int] = None
    # Optional maximum length for the padded labels.
    max_length_labels: Optional[int] = None
    # Optional multiple the padded input length is rounded up to.
    pad_to_multiple_of: Optional[int] = None
    # Optional multiple the padded label length is rounded up to.
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad a list of examples and return the batch as PyTorch tensors.

        Args:
            features: Examples carrying "input_values" and "labels".

        Returns:
            The padded input batch with an added "labels" tensor in which
            padding positions hold -100.
        """
        # Split inputs and labels since they have different lengths and need
        # different padding methods.
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        # Pad labels with the tokenizer directly; this replaces the deprecated
        # `as_target_processor()` context manager, which delegated to it.
        labels_batch = self.processor.tokenizer.pad(
            label_features,
            padding=self.padding,
            max_length=self.max_length_labels,
            pad_to_multiple_of=self.pad_to_multiple_of_labels,
            return_tensors="pt",
        )

        # Replace padding with -100 so those positions are ignored by the loss.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

In [19]:
# Collator used by the Trainer; padding=True is forwarded to both pad() calls.
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

#### Word Error Rate metric

In [20]:
import numpy as np
import evaluate

wer_metric = evaluate.load("wer", trust_remote_code=True)

def compute_metrics(pred):
    """Compute the word error rate for one evaluation pass.

    Args:
        pred: Object exposing ``predictions`` (logits) and ``label_ids``
            (label ids with -100 at padded positions).

    Returns:
        ``{"wer": <word error rate>}``.
    """
    pred_ids = np.argmax(pred.predictions, axis=-1)
    # Build a new array rather than writing into pred.label_ids, so the
    # caller's labels keep their -100 markers.
    label_ids = np.where(pred.label_ids == -100, processor.tokenizer.pad_token_id, pred.label_ids)
    pred_str = processor.batch_decode(pred_ids)

    # We do not want to group tokens when decoding the reference labels.
    label_str = processor.batch_decode(label_ids, group_tokens=False)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}

#### Import pretrained model

In [21]:
from transformers import Wav2Vec2ForCTC

# Load the pretrained base checkpoint for CTC, with the vocabulary size and
# padding id taken from the tokenizer. The checkpoint has no lm_head weights,
# so those are newly initialized (see the warning printed below).
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base", 
    ctc_loss_reduction="mean",
    vocab_size=processor.tokenizer.vocab_size,
    pad_token_id=processor.tokenizer.pad_token_id,
)
# Freeze the feature encoder before fine-tuning.
model.freeze_feature_encoder()

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['lm_head.bias', 'lm_head.weight', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


#### Setting up Trainer

In [22]:
from transformers import TrainingArguments

# Training configuration: 60 epochs at batch size 16 per device, fp16 with
# gradient checkpointing, and evaluation, checkpointing and logging every
# 500 steps. At most two checkpoints are kept under output_dir.
training_args = TrainingArguments(
    output_dir="../models",
    group_by_length=True,
    per_device_train_batch_size=16,
    eval_strategy="steps",
    num_train_epochs=60,
    fp16=True,
    gradient_checkpointing=True, 
    save_steps=500,
    eval_steps=500,
    logging_steps=500,
    learning_rate=1e-4,
    weight_decay=0.005,
    warmup_steps=1000,
    save_total_limit=2,
)

In [23]:
from transformers import Trainer

# Wire the model, collator, metric function and both splits into the Trainer.
# NOTE(review): the feature extractor is passed through the `tokenizer`
# argument — presumably so it is handled together with checkpoints; confirm
# against the installed transformers version.
trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=ds_train_test["train"],
    eval_dataset=ds_train_test["test"],
    tokenizer=processor.feature_extractor,
)

In [24]:
# Run fine-tuning; the recorded run below took about 20,400 s for 15,420 steps.
trainer.train()



Step,Training Loss,Validation Loss,Wer
500,7.9265,3.547183,1.0
1000,2.1806,0.93757,0.634318
1500,0.7743,0.551285,0.373269
2000,0.5086,0.474308,0.318024
2500,0.3856,0.433668,0.28155
3000,0.3091,0.442124,0.266241
3500,0.2602,0.465825,0.261049
4000,0.2326,0.45285,0.239483
4500,0.204,0.434571,0.228435
5000,0.1849,0.447168,0.227636




TrainOutput(global_step=15420, training_loss=0.48444980171094765, metrics={'train_runtime': 20432.0223, 'train_samples_per_second': 12.063, 'train_steps_per_second': 0.755, 'total_flos': 9.60312690107315e+18, 'train_loss': 0.48444980171094765, 'epoch': 60.0})

In [25]:
# Tabulate the Trainer's log history (training and evaluation entries per logging step).
pd.DataFrame(trainer.state.log_history)

Unnamed: 0,loss,grad_norm,learning_rate,epoch,step,eval_loss,eval_wer,eval_runtime,eval_samples_per_second,eval_steps_per_second,train_runtime,train_samples_per_second,train_steps_per_second,total_flos,train_loss
0,7.9265,1.937705,0.000050,1.945525,500,,,,,,,,,,
1,,,,1.945525,500,3.547183,1.000000,45.6644,22.490,2.825,,,,,
2,2.1806,3.166314,0.000100,3.891051,1000,,,,,,,,,,
3,,,,3.891051,1000,0.937570,0.634318,45.0076,22.818,2.866,,,,,
4,0.7743,1.794879,0.000097,5.836576,1500,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
56,0.0571,1.270281,0.000006,56.420233,14500,,,,,,,,,,
57,,,,56.420233,14500,0.482328,0.184771,61.9935,16.566,2.081,,,,,
58,0.0573,1.830534,0.000003,58.365759,15000,,,,,,,,,,
59,,,,58.365759,15000,0.467350,0.183440,277.4892,3.701,0.465,,,,,


#### Save model

In [26]:
model_dir_path = "../models/wav2vec_cvcorpus_vn"
# makedirs also creates a missing parent folder and is a no-op when the
# directory already exists, so this cell can safely be re-run.
os.makedirs(model_dir_path, exist_ok=True)
# Store the model and the processor side by side so they can be reloaded
# from the same directory.
trainer.save_model(model_dir_path)
processor.save_pretrained(model_dir_path)

[]

### Load trained model

In [27]:
from transformers import AutoModelForCTC, Wav2Vec2Processor

# Reload the fine-tuned model and processor from the directory written by the
# save step; the model is moved to `device` for inference.
model = AutoModelForCTC.from_pretrained("../models/wav2vec_cvcorpus_vn").to(device)
processor = Wav2Vec2Processor.from_pretrained("../models/wav2vec_cvcorpus_vn")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [28]:
def map_to_result(batch):
    """Transcribe one preprocessed example and decode its reference text.

    Args:
        batch: Single example with "input_values" and "labels".

    Returns:
        The example with "pred_str" (argmax-decoded model output) and "text"
        (reference decoded from the labels without grouping tokens) added.
    """
    with torch.no_grad():
        # Build the input on the device the model was moved to, instead of
        # hard-coding "cuda", so the cell also runs on CPU-only machines.
        input_values = torch.tensor(batch["input_values"], device=device).unsqueeze(0)
        logits = model(input_values).logits
        pred_ids = torch.argmax(logits, dim=-1)
        batch["pred_str"] = processor.batch_decode(pred_ids)[0]
        batch["text"] = processor.decode(batch["labels"], group_tokens=False)

    return batch

results = ds_train_test["test"].map(map_to_result, remove_columns=ds_train_test["test"].column_names)



Map:   0%|          | 0/1027 [00:00<?, ? examples/s]

In [29]:
# Compute first, then format: reusing double quotes inside a double-quoted
# f-string is a SyntaxError on Python versions before 3.12.
test_wer = wer_metric.compute(predictions=results["pred_str"], references=results["text"])
print(f"Test WER: {test_wer:.3f}")

Test WER: 0.175


In [30]:
# Compare a few predictions ("pred_str") with their references ("text").
show_random_elements(results)

Unnamed: 0,pred_str,text
0,thì hại chấc nhận anh ta và chăm sóc quận thận,thì hãy chấp nhận anh ta và chăm sóc cẩn thận
1,về nhà tời đó mẹ có kể lại công chuyện cho bố nghe,về nhà tối đó mẹ có kể lại câu chuyện cho bố nghe
2,đã ên ngày thác lú,đã yên ngày thác lũ
3,và tình ái là sợi dây vấn viết,và tình ái là sợi dây vấn vít
4,bà dã quỳ ngẩn ngơ một lúc,bà dã quỳ ngẩn ngơ một lúc
5,mày đừng về tìm tao nữa có được,mày đừng về tìm tao nữa có được
6,mồa hôi của a hội nhễ nhại ước đẫm cả áo,mồ hôi của a hội nhễ nhại ướt đẫm cả áo
7,hai thuốc rồi nhặt bong giúp gười ta,hái thuốc rồi nhặt bông giúp người ta
8,có hai chăng là con bẻ đó gây ra không,có hay chăng là con bé đó gây ra không
9,thì là cô ấy,thì là cô ấy
