## Transfer learning with a pre-trained Wav2Vec2 on small datasets for Urdu ASR:

This notebook explores the practicality of transfer learning (fine-tuning) for Automatic Speech Recognition. Our code mostly follows: [Fine-tune Wav2vec for English ASR](https://huggingface.co/blog/fine-tune-wav2vec2-english) by Patrick von Platen. Details on the Wav2Vec2 model can be found [here](https://ai.meta.com/blog/wav2vec-20-learning-the-structure-of-speech-from-raw-audio/). 

The [Wav2Vec2 model](https://github.com/facebookresearch/fairseq/tree/main/examples/wav2vec#wav2vec-20) used was pre-trained on the Librispeech corpus.

In addition, we wanted to explore the usefulness of pre-trained models in resource-starved environments. In particular, we wanted to test the following:
- The feasibility of low-end PCs for this task. Our machine had a GTX 1050M GPU with 4GB of available VRAM, a low-end GPU found in many affordable laptops. To run locally on such a device, we had to adapt our code.

- The practicality of fine-tuning the model on different languages. For this, we fine-tuned the model on publicly available datasets in Urdu.

- The possibility of expansion to different languages. While English and Urdu are vastly different languages, the north-western and northern regions of the subcontinent present dialect continua, both locally and areally. The distinction between language, dialect and accent is often fuzzy. While prestige dialects and languages have enough resources available for state-of-the-art ASR, dialects with lower prestige do not. Future work could explore how well ASR models trained on prestige dialects in the aforementioned regions perform on dialects with less prestige. For this, English and Urdu, while not entirely analogous, serve as a stand-in.


### Summarizing the Methodology and Results:
- Our methodology uses gradient accumulation and a batch size of 1 to allow for the model to run on such a low-end device.

- We obtained a WER of 0.225. Given our small dataset, this is a very good result. For comparison, state-of-the-art ASR models tend to have a WER of 0.150-0.200.
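
The gradient accumulation mentioned above can be illustrated with a toy example. This is a minimal sketch in plain Python, not the training loop used in this notebook: the 1-D least-squares loss and the helper names (`grad_single`, `full_batch_grad`, `accumulated_grad`) are assumptions made purely for illustration. It shows that accumulating per-sample gradients, each scaled by 1/N, gives the same gradient as one batch of N samples.

```python
# Toy illustration (hypothetical helpers): gradient accumulation with
# micro-batches of size 1 versus a single full batch.

def grad_single(w, x, y):
    """Gradient of 0.5 * (w * x - y) ** 2 with respect to w."""
    return (w * x - y) * x

def full_batch_grad(w, xs, ys):
    """Mean gradient over the whole batch, computed in one pass."""
    return sum(grad_single(w, x, y) for x, y in zip(xs, ys)) / len(xs)

def accumulated_grad(w, xs, ys, accumulation_steps):
    """Process one sample at a time and accumulate the scaled gradients."""
    total = 0.0
    for x, y in zip(xs, ys):
        total += grad_single(w, x, y) / accumulation_steps
    return total

xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
w = 0.5

g_full = full_batch_grad(w, xs, ys)
g_acc = accumulated_grad(w, xs, ys, accumulation_steps=len(xs))
print(g_full, g_acc)
```

Only one sample has to be held in memory at a time, at the cost of more forward and backward passes per weight update.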

### Imports:

In [1]:
import librosa
import regex as re
import pandas as pd
import random
import pyarrow as pa
from datasets import Dataset, DatasetDict, load_metric
from IPython.display import display, HTML
import json
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor, Wav2Vec2ForCTC, Trainer
import torch
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
import numpy as np
import os
import gc
import transformers
os.environ["TRANSFORMERS_OFFLINE"] = "1" # Allow transformers from huggingface to run offline.
num_speakers = 3

### Dataset Loading and Preprocessing:

#### Dataset:
Our dataset is composed of the recordings of the PRUS dataset and a compilation of news clippings. The speakers are undergraduates from LUMS.

- [The PRUS dataset](https://www.c-salt.org/downloads/prus): A compilation of 708 Urdu sentences generated using a greedy approach. The goal was to obtain a dataset with all tri-phoneme combinations (word boundaries included) possible within Urdu, with each tri-phoneme equally likely. The dataset consists of naturally occurring Urdu words and is grammatically correct. However, it is not representative of natural Urdu speech: the syntax, word order, sentence length and lexicon of many of the sentences are rarely or never part of natural Urdu speech.

- Newspaper clippings: A compilation of 505 headlines and sentences from various news articles and newspapers within the last few years. Compared with the PRUS dataset, this dataset is far more representative of modern, albeit highly formal Urdu speech.

##### Transcriptions and preprocessing:

- The vocabulary consists of the Urdu alphabet plus a few other characters deemed necessary.

- The data has been pre-processed to remove any characters that do not represent a phoneme.

- Harakat (used to demarcate vowels) are not considered part of the vocabulary given that the Urdu script is a partial Abjad and the reader is expected to infer them.

- The tashdeed (used to represent gemination) is not considered part of the vocabulary given that gemination follows specific rules and is readily inferred from the context.

- Digits are replaced with their word representations to facilitate the training of the model. This also means that for any new speech, any numbers spoken will be represented using words.

- In addition, a few special Unicode characters are removed (\u200c and \u200e).

In [2]:
prus_transcriptions = open("./Dataset/Trasncriptions - PRUS.txt" , 'r', encoding = "UTF-8")
prus_transcriptions = prus_transcriptions.readlines()

news_transcriptions = open("./Dataset/Transcriptions - News clippings.txt", 'r', encoding = "UTF-8")
news_transcriptions = news_transcriptions.readlines()

to_ignore_regex = "[\,\?\.\!\-\;\:\"\\()'؟۔’‘،ًٌَُِّٰٔٓ]"

# Word representation for each digit.
digit_words = {
    "0": "صفر",
    "1": "ایک",
    "2": "دو",
    "3": "تین",
    "4": "چار",
    "5": "پانچ",
    "6": "چھ",
    "7": "سات",
    "8": "آٹھ",
    "9": "نو",
}

def cleanTranscription(line):
    # Strip the newline and every character that is not part of the vocabulary.
    line = line.replace("\n", "")
    line = re.sub(to_ignore_regex, "", line)
    # Spell out digits as words.
    for digit, word in digit_words.items():
        line = line.replace(digit, word)
    # Remove the zero-width non-joiner and the left-to-right mark.
    return line.replace("\u200c", "").replace("\u200e", "")

# PRUS Dataset
prus_transcriptions = [cleanTranscription(line) for line in prus_transcriptions]

# News Clippings Dataset
news_transcriptions = [cleanTranscription(line) for line in news_transcriptions]

##### Audio data:

In [3]:
prus_path = "./Dataset/PRUS/"
news_path = "./Dataset/News_Clippings/"

# files = []
audios = []
transcriptions = []

# PRUS dataset
for i in range(num_speakers):
    for j in range(len(prus_transcriptions)):
        # File
        file_path = prus_path + str(j+1) + "_" + "speaker" + str(i+1) + "_prus.wav"
        # files.append(file_path)
        
        # Audio
        array, sampling_rate = librosa.load(file_path, sr=16000)
        audios.append({"array": array, "path": file_path, "sampling_rate": sampling_rate})

        # Transcriptions
        transcriptions.append(prus_transcriptions[j])

# News dataset
for i in range(num_speakers):
    for j in range(len(news_transcriptions)):
        # File
        file_path = news_path + str(j+1) + "_" + "speaker" + str(i+1) + "_news.wav"
        # files.append(file_path)
        
        # Audio
        array, sampling_rate = librosa.load(file_path, sr=16000)
        audios.append({"array": array, "path": file_path, "sampling_rate": sampling_rate})

        # Transcriptions
        transcriptions.append(news_transcriptions[j])


# Randomizing
# randomized_data = list(zip(files, audios, transcriptions))
randomized_data = list(zip(audios, transcriptions))
random.shuffle(randomized_data)

# files, audios, transcriptions = zip(*randomized_data)
audios, transcriptions = zip(*randomized_data)

randomized_data = 0

#### Train-Test Split:

The model is trained on 80% of the data and evaluated on the remaining 20%. The data is randomized before splitting.

In [7]:
dataset_size = len(audios)
train_split = 0.8
train_size = int(train_split * dataset_size)

train_files = []
train_audios = []
train_transcriptions = []

test_files = []
test_audios = []
test_transcriptions = []

# Training data
for i in range(0, train_size):
    #train_files.append(files[i])
    train_audios.append(audios[i])
    train_transcriptions.append(transcriptions[i])

# Test data
for i in range(train_size, dataset_size):
    #test_files.append(files[i])
    test_audios.append(audios[i])
    test_transcriptions.append(transcriptions[i])


# Creating the train and test datasets.
# df = pd.DataFrame({'file': train_files, 'audio': train_audios , 'text' : train_transcriptions})
df = pd.DataFrame({'audio': train_audios , 'transcription' : train_transcriptions})
train_dataset = Dataset(pa.Table.from_pandas(df))

# df = pd.DataFrame({'file': test_files, 'audio': test_audios , 'text' : test_transcriptions})
df = pd.DataFrame({'audio': test_audios , 'transcription' : test_transcriptions})
test_dataset = Dataset(pa.Table.from_pandas(df))

# Creating the overall urdu dataset used for fine-tuning the model.
ur_fine_dataset = DatasetDict({"train": train_dataset , "test": test_dataset})
ur_fine_dataset

train_dataset = 0
test_dataset = 0
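
The shuffle above is not seeded, so each run produces a different split. A reproducible variant could look like the sketch below. The function name `shuffled_split` and the seed value are assumptions for illustration and were not used for the reported results. The example reuses the dataset size of this notebook (3 speakers, 708 PRUS and 505 news sentences each).

```python
import random

def shuffled_split(items, train_fraction=0.8, seed=0):
    """Shuffle a copy of `items` with a fixed seed and split it into train/test lists."""
    rng = random.Random(seed)  # local generator, leaves the global random state untouched
    shuffled = list(items)
    rng.shuffle(shuffled)
    train_size = int(train_fraction * len(shuffled))
    return shuffled[:train_size], shuffled[train_size:]

# 3 speakers x (708 PRUS + 505 news) sentences = 3639 recordings
train_indices, test_indices = shuffled_split(range(3 * (708 + 505)))
print(len(train_indices), len(test_indices))  # 2911 728
```

The resulting sizes match the 2911 training and 728 test examples reported by the `map` call further down.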

#### Training dataset transcriptions after pre-processing:

In [5]:
def displayTranscriptions(dataset, samples=5):
    # Warn and clamp if more samples are requested than the dataset holds
    if samples > len(dataset):
        print("Warning: can only display a maximum of", len(dataset), "samples")
        samples = len(dataset)
    # Pick unique random indices
    indices = random.sample(range(len(dataset)), samples)
    
    dataframe = pd.DataFrame(dataset[indices]["transcription"])
    dataframe.columns = ["Transcriptions"]
    display(HTML(dataframe.to_html())) 

In [6]:
displayTranscriptions(ur_fine_dataset["train"], 10)

| | Transcriptions |
|---|---|
| 0 | خیال کیا جاتا ہے کہ اس طرح کی جاسوسی کا زیادہ تر نشانہ سماجی کارکن صحافی اور سیاستدان بنتے ہیں |
| 1 | بدصورت بانسری کی انوالومنٹ سے سیلرز کی ہیرو سے گراں تخریب نیچرل تھی |
| 2 | حروں نے شہابیے کو ساگر میں پھینکا اور پیداوار کے رزق کی ڈیل کرنے کے بعد گاوں سے اوجھل ہو گئے |
| 3 | سوم درجے کی جعلسازی کا بھیڑیا تھیلی اٹھایا اور اپنے پڑدادا کے جلو میں دب گیا ہے |
| 4 | غیر جانبدار افغان بھابی کے ساتھ نتھی گڑیا کا موڈ بدلیں تو خفت رفو ہو سکتی ہے |
| 5 | میلے کچیلے خام فوم کے سویٹر میں کلراٹھی چقندر باندھیں تو یہی شے خوبصورت ڈھال ہے |
| 6 | وائس آف امریکہ سے بات کرتے ہوئے ان کا کہنا تھا کہ وہ ڈرامے کے اختتام سے بالکل خوش نہیں |
| 7 | گاڑی کے نظام میں ایسی کوئی چیز نہیں ہے جو یہ تصدیق کر سکے کہ ویڈیو گیم کھیلنے والا ڈرائیور ہے یا اس کے ساتھ والی سیٹ پر بیٹھا ہوا مسافر |
| 8 | دنیا بھر میں پیٹرولیم مصنوعات کی قیمتوں میں مسلسل اضافے سے لوگ پریشان ہیں |
| 9 | گگلی کے تجربہ اور تشخیص پر ہدایت اور تعریفیں چاہتے ہیں |


#### Generating the vocabulary:

The initial vocabulary consists of the Urdu alphabet plus a few other characters deemed necessary.

In [8]:
def generateVocab(dataset):
    # Join all transcriptions once and collect the unique characters.
    text = " ".join(dataset["transcription"])
    return list(set(text))


vocabulary = generateVocab(ur_fine_dataset["train"])
vocabulary.append("|")
vocabulary.remove(" ")
vocabulary.append("[UNK]")
vocabulary.append("[PAD]")
print(vocabulary)

# Creating a vocabulary dictionary
vocabulary_dictionary = {}

i = 0
for char in vocabulary:
    vocabulary_dictionary[char] = i
    i += 1

vocabulary_size = len(vocabulary_dictionary)

['چ', 'ٹ', 'گ', 'پ', 'ء', 'خ', 'م', 'ت', 'ص', 'ڑ', 'ق', 'ژ', 'ں', 'ئ', 'ر', 'ع', 'ے', 'ط', 'ا', 'ڈ', 'ذ', 'ک', 'غ', 'س', 'ف', 'ج', 'ش', 'ہ', 'ن', 'ؤ', 'ب', 'ض', 'ح', 'ز', 'ظ', 'آ', 'و', 'ل', 'ی', 'ث', 'د', 'ھ', '|', '[UNK]', '[PAD]']
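
Because a Python `set` has no guaranteed order, the token IDs above can change between runs. A deterministic variant is sketched below. The helper name `build_vocab` and the sorted ordering are assumptions for illustration and are not what produced the vocabulary printed above.

```python
def build_vocab(transcriptions):
    """Map each character to a stable ID, then append the special tokens."""
    chars = set(" ".join(transcriptions))
    chars.discard(" ")  # spaces are represented by the word delimiter "|"
    tokens = sorted(chars) + ["|", "[UNK]", "[PAD]"]
    return {token: index for index, token in enumerate(tokens)}

# Placeholder Latin strings stand in for the Urdu transcriptions.
example_vocab = build_vocab(["ab c", "b d"])
print(example_vocab)
```

A stable mapping matters if the vocabulary file is ever regenerated after a model has been trained, since the output layer is tied to the token IDs.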


#### Converting to json:
The vocabulary is written to a JSON file so that it can be passed to the Wav2Vec2CTCTokenizer.

In [9]:
with open("vocabulary.json", "w") as json_file:
    json.dump(vocabulary_dictionary, json_file)

### Model preparing and training:

#### Tokenizer

Initializes the tokenizer, a Hugging Face class that converts the transcriptions into token IDs and decodes predicted IDs back into text.

In [10]:
tokenizer = Wav2Vec2CTCTokenizer("./vocabulary.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")

In [11]:
tokenizer.save_pretrained("./Model/final_model/")

('./Model/final_model/tokenizer_config.json',
 './Model/final_model/special_tokens_map.json',
 './Model/final_model/vocab.json',
 './Model/final_model/added_tokens.json')

#### Feature extractor

Initializes the feature extractor, a Hugging Face class that prepares audio data as input features for a model. More information can be found [here](https://huggingface.co/docs/transformers/main_classes/feature_extractor#:~:text=A%20feature%20extractor%20is%20in,for%20audio%20or%20vision%20models.).

In [12]:
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False)

In [13]:
feature_extractor.save_pretrained("./Model/final_model/")

['./Model/final_model/preprocessor_config.json']

#### Processor:

Combines the tokenizer and feature extractor into one.

In [14]:
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

#### Pre-processing the audio data for training:

- The audio is loaded and resampled (not needed in our case, since it was already loaded at 16 kHz).

- The input values are stored after processing. In this case, the audio data is only normalized.

- The transcriptions are encoded to label IDs.
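
The normalization step can be sketched in plain Python. As far as we understand, `do_normalize=True` rescales each waveform to zero mean and unit variance. The helper name `zero_mean_unit_var` and the small `eps` constant are assumptions for illustration, not a copy of the library code.

```python
import math

def zero_mean_unit_var(samples, eps=1e-7):
    """Rescale a waveform to zero mean and (approximately) unit variance."""
    mean = sum(samples) / len(samples)
    variance = sum((s - mean) ** 2 for s in samples) / len(samples)
    # eps guards against division by zero for silent (constant) audio
    return [(s - mean) / math.sqrt(variance + eps) for s in samples]

waveform = [0.1, -0.2, 0.4, 0.0, -0.3]  # made-up samples
normalized = zero_mean_unit_var(waveform)
print(normalized)
```

Normalizing per recording reduces the effect of differences in recording volume between speakers.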

In [15]:
def preprocessingAudio(batch):
    audio = batch["audio"]

    # batched output is "un-batched" to ensure mapping is correct
    batch["input_values"] = processor(audio["array"], sampling_rate = audio["sampling_rate"]).input_values[0]
    batch["input_length"] = len(batch["input_values"])
    
    with processor.as_target_processor():
        batch["labels"] = processor(batch["transcription"]).input_ids
    return batch

In [16]:
ur_fine_dataset = ur_fine_dataset.map(preprocessingAudio, remove_columns = ur_fine_dataset.column_names["train"], num_proc=1)

Map:   0%|          | 0/2911 [00:00<?, ? examples/s]



Map:   0%|          | 0/728 [00:00<?, ? examples/s]

#### Initializing the model:

In [36]:
@dataclass
class DataCollatorCTCWithPadding:
    
    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",
            )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        labels = labels.cuda()

        input_values = batch['input_values']
        input_values = input_values.cuda()

      
        batch["labels"] = labels
        batch["input_values"] = input_values

        return batch
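
The label masking inside the collator can be illustrated without any tensors. The sketch below pads label sequences to a common length and fills the padded positions with -100, the value the collator uses so that those positions are ignored by the loss. The helper name `pad_and_mask_labels` is hypothetical.

```python
def pad_and_mask_labels(label_sequences, ignore_index=-100):
    """Pad every label sequence to the longest one, marking padding with ignore_index."""
    max_length = max(len(sequence) for sequence in label_sequences)
    return [
        list(sequence) + [ignore_index] * (max_length - len(sequence))
        for sequence in label_sequences
    ]

masked = pad_and_mask_labels([[5, 6, 7], [8]])
print(masked)  # [[5, 6, 7], [8, -100, -100]]
```

With a batch size of 1, as used in this notebook, every batch holds a single sequence, so no padding is actually added.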

In [37]:
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

# The metric used is Word Error Rate (WER), as it is the most commonly used metric for ASR models.
wer_metric = load_metric("wer")
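
For reference, WER is the word-level edit distance (substitutions, deletions and insertions) divided by the number of reference words. The sketch below is a plain-Python illustration with a hypothetical helper name, `word_error_rate`. It is not the implementation behind `load_metric("wer")`, although we assume both aggregate errors over all sentences before dividing.

```python
def word_error_rate(references, hypotheses):
    """Total word-level edit distance divided by the total number of reference words."""
    total_errors = 0
    total_words = 0
    for reference, hypothesis in zip(references, hypotheses):
        ref_words, hyp_words = reference.split(), hypothesis.split()
        # previous[j]: edit distance between the processed reference prefix and hyp_words[:j]
        previous = list(range(len(hyp_words) + 1))
        for i, ref_word in enumerate(ref_words, start=1):
            current = [i]
            for j, hyp_word in enumerate(hyp_words, start=1):
                cost = 0 if ref_word == hyp_word else 1
                current.append(min(previous[j] + 1,          # deletion
                                   current[j - 1] + 1,       # insertion
                                   previous[j - 1] + cost))  # match or substitution
            previous = current
        total_errors += previous[-1]
        total_words += len(ref_words)
    return total_errors / total_words

# One substitution (b -> x) and one deletion (d) over four reference words.
print(word_error_rate(["a b c d"], ["a x c"]))  # 0.5
```

Since insertions count as errors, WER can exceed 1.0 when the hypothesis is much longer than the reference.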

#### Computing the WER on the evaluation data:

In [38]:
def computeWER(predicted):
    predicted_logits = predicted.predictions
    predicted_ids = np.argmax(predicted_logits, axis=-1)

    predicted.label_ids[predicted.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(predicted_ids)
    
    # We do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(predicted.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}
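
The decoding step relies on CTC conventions: repeated tokens are collapsed, the blank (here the padding token) is dropped, and the word delimiter becomes a space. The sketch below shows this on token strings. The helper `ctc_greedy_decode` is hypothetical and only approximates what we assume `processor.batch_decode` does, including why labels are decoded with `group_tokens=False`.

```python
def ctc_greedy_decode(tokens, pad_token="[PAD]", word_delimiter="|", group_tokens=True):
    """Collapse repeats, drop blanks and turn word delimiters into spaces."""
    if group_tokens:
        # Keep a token only if it differs from the one directly before it.
        tokens = [tok for i, tok in enumerate(tokens) if i == 0 or tok != tokens[i - 1]]
    chars = [" " if tok == word_delimiter else tok for tok in tokens if tok != pad_token]
    # Normalize whitespace left over from repeated or trailing delimiters.
    return " ".join("".join(chars).split())

# Frame-level predictions: repeats are merged, blanks separate true double letters.
print(ctc_greedy_decode(["h", "h", "[PAD]", "i", "|", "|", "y", "[PAD]", "o", "o"]))  # hi yo
print(ctc_greedy_decode(["l", "[PAD]", "l"]))  # ll
# Reference labels contain genuine repeats, so they must not be grouped.
print(ctc_greedy_decode(["l", "l", "|", "a"], group_tokens=False))  # ll a
```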

#### Loading the pre-trained Wav2Vec2 checkpoint:

In [39]:
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base", 
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size = vocabulary_size
)

Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2ForCTC: ['quantizer.weight_proj.weight', 'project_hid.bias', 'project_hid.weight', 'quantizer.weight_proj.bias', 'quantizer.codevectors', 'project_q.bias', 'project_q.weight']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['lm_head.weight', 'lm_head.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predicti

#### Fine-tuning training arguments:

In [40]:
from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir="./Model",
  group_by_length=True,
  per_device_train_batch_size=1,
  per_device_eval_batch_size=1,
  evaluation_strategy="steps",
  num_train_epochs=30,
  fp16=False, # Set to True for better performance
  gradient_checkpointing=True,
  save_steps=50,
  eval_steps=50,
  logging_steps=50,
  learning_rate=1e-4,
  weight_decay=0.005,
  warmup_steps=1000,
  save_total_limit=2,
  push_to_hub=False, # Set to false if running offline
  dataloader_pin_memory=False,
  gradient_accumulation_steps=30 # Accumulate gradients over 30 batches of size 1 to emulate a larger batch within limited memory
)
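
These arguments determine the number of optimizer steps and the learning-rate curve. The sketch below reproduces the arithmetic in plain Python. It assumes the Trainer's default linear schedule (linear warmup, then linear decay to zero) and that partial accumulation windows at the end of an epoch are dropped. The variable and function names are illustrative.

```python
import math

num_train_samples = 2911           # size of the training split
per_device_train_batch_size = 1
gradient_accumulation_steps = 30
num_train_epochs = 30
warmup_steps = 1000
peak_learning_rate = 1e-4

batches_per_epoch = math.ceil(num_train_samples / per_device_train_batch_size)
updates_per_epoch = batches_per_epoch // gradient_accumulation_steps
total_updates = updates_per_epoch * num_train_epochs
print(updates_per_epoch, total_updates)  # 97 2910

def linear_schedule(step):
    """Learning rate after `step` optimizer updates (assumed linear warmup and decay)."""
    if step < warmup_steps:
        return peak_learning_rate * step / warmup_steps
    remaining = max(0, total_updates - step)
    return peak_learning_rate * remaining / (total_updates - warmup_steps)

for step in (50, 1000, 1050):
    print(step, linear_schedule(step))
```

Under these assumptions the warmup ends after roughly 10 of the 30 epochs, and one logging interval of 50 steps covers about half an epoch.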

#### Trainer:

In [41]:
trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=computeWER,
    train_dataset = ur_fine_dataset["train"],
    eval_dataset = ur_fine_dataset["test"],
    tokenizer=processor.feature_extractor
)

In [42]:
torch.cuda.empty_cache()
gc.collect()
torch.cuda.memory_summary(device=None, abbreviated=False)

trainer.train()



  0%|          | 0/2910 [00:00<?, ?it/s]



{'loss': 12.1758, 'learning_rate': 5e-06, 'epoch': 0.52}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 8.392762184143066, 'eval_wer': 1.0, 'eval_runtime': 265.0949, 'eval_samples_per_second': 2.746, 'eval_steps_per_second': 2.746, 'epoch': 0.52}




{'loss': 4.9365, 'learning_rate': 1e-05, 'epoch': 1.03}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 3.467226028442383, 'eval_wer': 1.0, 'eval_runtime': 261.0537, 'eval_samples_per_second': 2.789, 'eval_steps_per_second': 2.789, 'epoch': 1.03}




{'loss': 3.4088, 'learning_rate': 1.5e-05, 'epoch': 1.55}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 3.271307945251465, 'eval_wer': 1.0, 'eval_runtime': 263.1092, 'eval_samples_per_second': 2.767, 'eval_steps_per_second': 2.767, 'epoch': 1.55}




{'loss': 3.2356, 'learning_rate': 2e-05, 'epoch': 2.06}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 3.1610107421875, 'eval_wer': 1.0, 'eval_runtime': 269.128, 'eval_samples_per_second': 2.705, 'eval_steps_per_second': 2.705, 'epoch': 2.06}




{'loss': 3.1421, 'learning_rate': 2.5e-05, 'epoch': 2.58}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 3.11612606048584, 'eval_wer': 1.0, 'eval_runtime': 269.4766, 'eval_samples_per_second': 2.702, 'eval_steps_per_second': 2.702, 'epoch': 2.58}




{'loss': 3.1074, 'learning_rate': 3e-05, 'epoch': 3.09}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 3.098893404006958, 'eval_wer': 1.0, 'eval_runtime': 270.9497, 'eval_samples_per_second': 2.687, 'eval_steps_per_second': 2.687, 'epoch': 3.09}




{'loss': 3.0951, 'learning_rate': 3.5e-05, 'epoch': 3.61}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 3.0826168060302734, 'eval_wer': 1.0, 'eval_runtime': 268.0129, 'eval_samples_per_second': 2.716, 'eval_steps_per_second': 2.716, 'epoch': 3.61}




{'loss': 3.0873, 'learning_rate': 4e-05, 'epoch': 4.12}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 3.0749130249023438, 'eval_wer': 1.0, 'eval_runtime': 267.5487, 'eval_samples_per_second': 2.721, 'eval_steps_per_second': 2.721, 'epoch': 4.12}




{'loss': 3.0761, 'learning_rate': 4.5e-05, 'epoch': 4.64}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 3.0334696769714355, 'eval_wer': 1.0, 'eval_runtime': 267.2168, 'eval_samples_per_second': 2.724, 'eval_steps_per_second': 2.724, 'epoch': 4.64}




{'loss': 2.8768, 'learning_rate': 5e-05, 'epoch': 5.15}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 2.4343087673187256, 'eval_wer': 0.9988654413433175, 'eval_runtime': 267.0482, 'eval_samples_per_second': 2.726, 'eval_steps_per_second': 2.726, 'epoch': 5.15}




{'loss': 2.0997, 'learning_rate': 5.500000000000001e-05, 'epoch': 5.67}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 1.5502562522888184, 'eval_wer': 0.8225550260948491, 'eval_runtime': 268.1847, 'eval_samples_per_second': 2.715, 'eval_steps_per_second': 2.715, 'epoch': 5.67}




{'loss': 1.4734, 'learning_rate': 6e-05, 'epoch': 6.18}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 1.1019192934036255, 'eval_wer': 0.6912487708947886, 'eval_runtime': 268.1733, 'eval_samples_per_second': 2.715, 'eval_steps_per_second': 2.715, 'epoch': 6.18}




{'loss': 1.1732, 'learning_rate': 6.500000000000001e-05, 'epoch': 6.7}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.9048839211463928, 'eval_wer': 0.6446562287270252, 'eval_runtime': 215.4731, 'eval_samples_per_second': 3.379, 'eval_steps_per_second': 3.379, 'epoch': 6.7}




{'loss': 0.9855, 'learning_rate': 7e-05, 'epoch': 7.21}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.7670953869819641, 'eval_wer': 0.5870962862113305, 'eval_runtime': 153.257, 'eval_samples_per_second': 4.75, 'eval_steps_per_second': 4.75, 'epoch': 7.21}




{'loss': 0.8512, 'learning_rate': 7.500000000000001e-05, 'epoch': 7.73}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.6952930092811584, 'eval_wer': 0.545117615914076, 'eval_runtime': 150.6908, 'eval_samples_per_second': 4.831, 'eval_steps_per_second': 4.831, 'epoch': 7.73}




{'loss': 0.777, 'learning_rate': 8e-05, 'epoch': 8.24}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.6270283460617065, 'eval_wer': 0.5287043340140686, 'eval_runtime': 155.7298, 'eval_samples_per_second': 4.675, 'eval_steps_per_second': 4.675, 'epoch': 8.24}




{'loss': 0.7094, 'learning_rate': 8.5e-05, 'epoch': 8.76}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.5807090997695923, 'eval_wer': 0.5037440435670524, 'eval_runtime': 154.276, 'eval_samples_per_second': 4.719, 'eval_steps_per_second': 4.719, 'epoch': 8.76}




{'loss': 0.6405, 'learning_rate': 9e-05, 'epoch': 9.28}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.5506355166435242, 'eval_wer': 0.4925497314877846, 'eval_runtime': 151.9131, 'eval_samples_per_second': 4.792, 'eval_steps_per_second': 4.792, 'epoch': 9.28}




{'loss': 0.6126, 'learning_rate': 9.5e-05, 'epoch': 9.79}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.5124320983886719, 'eval_wer': 0.4785568413886998, 'eval_runtime': 150.7415, 'eval_samples_per_second': 4.829, 'eval_steps_per_second': 4.829, 'epoch': 9.79}




{'loss': 0.5562, 'learning_rate': 0.0001, 'epoch': 10.31}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.47998449206352234, 'eval_wer': 0.4591937069813176, 'eval_runtime': 151.6243, 'eval_samples_per_second': 4.801, 'eval_steps_per_second': 4.801, 'epoch': 10.31}




{'loss': 0.5187, 'learning_rate': 9.738219895287959e-05, 'epoch': 10.82}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.4567178189754486, 'eval_wer': 0.4325693971711671, 'eval_runtime': 152.7046, 'eval_samples_per_second': 4.767, 'eval_steps_per_second': 4.767, 'epoch': 10.82}




{'loss': 0.4883, 'learning_rate': 9.476439790575917e-05, 'epoch': 11.34}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.43991777300834656, 'eval_wer': 0.42114817336056276, 'eval_runtime': 153.5424, 'eval_samples_per_second': 4.741, 'eval_steps_per_second': 4.741, 'epoch': 11.34}




{'loss': 0.4674, 'learning_rate': 9.214659685863875e-05, 'epoch': 11.85}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.43530893325805664, 'eval_wer': 0.4116935178882081, 'eval_runtime': 153.0964, 'eval_samples_per_second': 4.755, 'eval_steps_per_second': 4.755, 'epoch': 11.85}




{'loss': 0.4187, 'learning_rate': 8.952879581151833e-05, 'epoch': 12.37}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.4159089922904968, 'eval_wer': 0.39891082368958475, 'eval_runtime': 154.9479, 'eval_samples_per_second': 4.698, 'eval_steps_per_second': 4.698, 'epoch': 12.37}




{'loss': 0.4096, 'learning_rate': 8.691099476439791e-05, 'epoch': 12.88}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.4110824167728424, 'eval_wer': 0.3916496482868164, 'eval_runtime': 154.4314, 'eval_samples_per_second': 4.714, 'eval_steps_per_second': 4.714, 'epoch': 12.88}




{'loss': 0.3838, 'learning_rate': 8.429319371727749e-05, 'epoch': 13.4}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.38556620478630066, 'eval_wer': 0.3812873458891158, 'eval_runtime': 150.5022, 'eval_samples_per_second': 4.837, 'eval_steps_per_second': 4.837, 'epoch': 13.4}




{'loss': 0.3626, 'learning_rate': 8.167539267015707e-05, 'epoch': 13.91}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.39861273765563965, 'eval_wer': 0.3750850918992512, 'eval_runtime': 151.393, 'eval_samples_per_second': 4.809, 'eval_steps_per_second': 4.809, 'epoch': 13.91}




{'loss': 0.3464, 'learning_rate': 7.905759162303665e-05, 'epoch': 14.43}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.37521669268608093, 'eval_wer': 0.3690341123969443, 'eval_runtime': 151.5491, 'eval_samples_per_second': 4.804, 'eval_steps_per_second': 4.804, 'epoch': 14.43}




{'loss': 0.3389, 'learning_rate': 7.643979057591623e-05, 'epoch': 14.94}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.3773742914199829, 'eval_wer': 0.36283185840707965, 'eval_runtime': 152.9948, 'eval_samples_per_second': 4.758, 'eval_steps_per_second': 4.758, 'epoch': 14.94}




{'loss': 0.3126, 'learning_rate': 7.382198952879581e-05, 'epoch': 15.46}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.36431893706321716, 'eval_wer': 0.35746161409878224, 'eval_runtime': 156.2885, 'eval_samples_per_second': 4.658, 'eval_steps_per_second': 4.658, 'epoch': 15.46}




{'loss': 0.2971, 'learning_rate': 7.12041884816754e-05, 'epoch': 15.97}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.3718809187412262, 'eval_wer': 0.3567808789047727, 'eval_runtime': 153.104, 'eval_samples_per_second': 4.755, 'eval_steps_per_second': 4.755, 'epoch': 15.97}




{'loss': 0.2799, 'learning_rate': 6.858638743455498e-05, 'epoch': 16.49}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.3621373474597931, 'eval_wer': 0.3489902427955525, 'eval_runtime': 150.5702, 'eval_samples_per_second': 4.835, 'eval_steps_per_second': 4.835, 'epoch': 16.49}




{'loss': 0.2785, 'learning_rate': 6.596858638743456e-05, 'epoch': 17.0}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.3542162775993347, 'eval_wer': 0.33809847969140006, 'eval_runtime': 154.2163, 'eval_samples_per_second': 4.721, 'eval_steps_per_second': 4.721, 'epoch': 17.0}




{'loss': 0.2475, 'learning_rate': 6.335078534031414e-05, 'epoch': 17.52}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.34813857078552246, 'eval_wer': 0.3377202934725059, 'eval_runtime': 150.4849, 'eval_samples_per_second': 4.838, 'eval_steps_per_second': 4.838, 'epoch': 17.52}




{'loss': 0.254, 'learning_rate': 6.073298429319372e-05, 'epoch': 18.04}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.3441419005393982, 'eval_wer': 0.3365100975720445, 'eval_runtime': 154.0655, 'eval_samples_per_second': 4.725, 'eval_steps_per_second': 4.725, 'epoch': 18.04}




{'loss': 0.2343, 'learning_rate': 5.81151832460733e-05, 'epoch': 18.55}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.3361930549144745, 'eval_wer': 0.3324256864079873, 'eval_runtime': 148.5492, 'eval_samples_per_second': 4.901, 'eval_steps_per_second': 4.901, 'epoch': 18.55}




{'loss': 0.2303, 'learning_rate': 5.5497382198952887e-05, 'epoch': 19.07}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.3295890986919403, 'eval_wer': 0.321912109522729, 'eval_runtime': 153.3698, 'eval_samples_per_second': 4.747, 'eval_steps_per_second': 4.747, 'epoch': 19.07}




{'loss': 0.2168, 'learning_rate': 5.287958115183246e-05, 'epoch': 19.58}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.3416559398174286, 'eval_wer': 0.3278874517812571, 'eval_runtime': 157.1729, 'eval_samples_per_second': 4.632, 'eval_steps_per_second': 4.632, 'epoch': 19.58}




{'loss': 0.2198, 'learning_rate': 5.026178010471204e-05, 'epoch': 20.1}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.3367469906806946, 'eval_wer': 0.32229029574162316, 'eval_runtime': 151.4035, 'eval_samples_per_second': 4.808, 'eval_steps_per_second': 4.808, 'epoch': 20.1}




{'loss': 0.2001, 'learning_rate': 4.764397905759162e-05, 'epoch': 20.61}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.33665335178375244, 'eval_wer': 0.30716284698585583, 'eval_runtime': 151.167, 'eval_samples_per_second': 4.816, 'eval_steps_per_second': 4.816, 'epoch': 20.61}




{'loss': 0.2033, 'learning_rate': 4.50261780104712e-05, 'epoch': 21.13}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.3217971920967102, 'eval_wer': 0.30882686634899026, 'eval_runtime': 151.1469, 'eval_samples_per_second': 4.817, 'eval_steps_per_second': 4.817, 'epoch': 21.13}




{'loss': 0.1923, 'learning_rate': 4.240837696335079e-05, 'epoch': 21.64}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.3307955265045166, 'eval_wer': 0.306179562816731, 'eval_runtime': 149.2358, 'eval_samples_per_second': 4.878, 'eval_steps_per_second': 4.878, 'epoch': 21.64}




{'loss': 0.1951, 'learning_rate': 3.9790575916230365e-05, 'epoch': 22.16}


  0%|          | 0/728 [00:00<?, ?it/s]

{'eval_loss': 0.33941909670829773, 'eval_wer': 0.2948339762499054, 'eval_runtime': 156.5059, 'eval_samples_per_second': 4.652, 'eval_steps_per_second': 4.652, 'epoch': 22.16}




{'loss': 0.1802, 'learning_rate': 3.717277486910995e-05, 'epoch': 22.67}



{'eval_loss': 0.31174254417419434, 'eval_wer': 0.29733000529460707, 'eval_runtime': 159.961, 'eval_samples_per_second': 4.551, 'eval_steps_per_second': 4.551, 'epoch': 22.67}




{'loss': 0.1837, 'learning_rate': 3.455497382198953e-05, 'epoch': 23.19}



{'eval_loss': 0.32396313548088074, 'eval_wer': 0.2909764768171848, 'eval_runtime': 157.0471, 'eval_samples_per_second': 4.636, 'eval_steps_per_second': 4.636, 'epoch': 23.19}




{'loss': 0.1698, 'learning_rate': 3.1937172774869115e-05, 'epoch': 23.7}



{'eval_loss': 0.3133009076118469, 'eval_wer': 0.2884048105287043, 'eval_runtime': 160.4001, 'eval_samples_per_second': 4.539, 'eval_steps_per_second': 4.539, 'epoch': 23.7}




{'loss': 0.1728, 'learning_rate': 2.931937172774869e-05, 'epoch': 24.22}



{'eval_loss': 0.30827319622039795, 'eval_wer': 0.2865138794342334, 'eval_runtime': 161.7534, 'eval_samples_per_second': 4.501, 'eval_steps_per_second': 4.501, 'epoch': 24.22}




{'loss': 0.1651, 'learning_rate': 2.6701570680628273e-05, 'epoch': 24.73}



{'eval_loss': 0.3188498914241791, 'eval_wer': 0.2874971636033583, 'eval_runtime': 165.5186, 'eval_samples_per_second': 4.398, 'eval_steps_per_second': 4.398, 'epoch': 24.73}




{'loss': 0.1609, 'learning_rate': 2.4083769633507854e-05, 'epoch': 25.25}



{'eval_loss': 0.3089900612831116, 'eval_wer': 0.28182437031994556, 'eval_runtime': 161.2452, 'eval_samples_per_second': 4.515, 'eval_steps_per_second': 4.515, 'epoch': 25.25}




{'loss': 0.1588, 'learning_rate': 2.1465968586387435e-05, 'epoch': 25.76}



{'eval_loss': 0.3090582489967346, 'eval_wer': 0.28099236063837835, 'eval_runtime': 172.1621, 'eval_samples_per_second': 4.229, 'eval_steps_per_second': 4.229, 'epoch': 25.76}




{'loss': 0.1525, 'learning_rate': 1.8848167539267016e-05, 'epoch': 26.28}



{'eval_loss': 0.30835938453674316, 'eval_wer': 0.27834505710611906, 'eval_runtime': 159.5166, 'eval_samples_per_second': 4.564, 'eval_steps_per_second': 4.564, 'epoch': 26.28}




{'loss': 0.1501, 'learning_rate': 1.6230366492146596e-05, 'epoch': 26.79}



{'eval_loss': 0.31023889780044556, 'eval_wer': 0.27978216473791695, 'eval_runtime': 161.9051, 'eval_samples_per_second': 4.496, 'eval_steps_per_second': 4.496, 'epoch': 26.79}




{'loss': 0.1477, 'learning_rate': 1.3612565445026179e-05, 'epoch': 27.31}



{'eval_loss': 0.30927330255508423, 'eval_wer': 0.2781181453747825, 'eval_runtime': 164.079, 'eval_samples_per_second': 4.437, 'eval_steps_per_second': 4.437, 'epoch': 27.31}




{'loss': 0.1443, 'learning_rate': 1.099476439790576e-05, 'epoch': 27.83}



{'eval_loss': 0.3016777038574219, 'eval_wer': 0.2767566749867635, 'eval_runtime': 172.2781, 'eval_samples_per_second': 4.226, 'eval_steps_per_second': 4.226, 'epoch': 27.83}




{'loss': 0.1436, 'learning_rate': 8.37696335078534e-06, 'epoch': 28.34}



{'eval_loss': 0.3099626898765564, 'eval_wer': 0.27683231223054233, 'eval_runtime': 166.2358, 'eval_samples_per_second': 4.379, 'eval_steps_per_second': 4.379, 'epoch': 28.34}




{'loss': 0.14, 'learning_rate': 5.759162303664922e-06, 'epoch': 28.86}



{'eval_loss': 0.3092484176158905, 'eval_wer': 0.27607593979275397, 'eval_runtime': 179.6336, 'eval_samples_per_second': 4.053, 'eval_steps_per_second': 4.053, 'epoch': 28.86}




{'loss': 0.141, 'learning_rate': 3.1413612565445026e-06, 'epoch': 29.37}



{'eval_loss': 0.3074033856391907, 'eval_wer': 0.27456319491717723, 'eval_runtime': 167.2455, 'eval_samples_per_second': 4.353, 'eval_steps_per_second': 4.353, 'epoch': 29.37}




{'loss': 0.1358, 'learning_rate': 5.235602094240838e-07, 'epoch': 29.89}



{'eval_loss': 0.3078209161758423, 'eval_wer': 0.2758490280614174, 'eval_runtime': 170.7182, 'eval_samples_per_second': 4.264, 'eval_steps_per_second': 4.264, 'epoch': 29.89}




{'train_runtime': 70994.548, 'train_samples_per_second': 1.23, 'train_steps_per_second': 0.041, 'train_loss': 1.0566150097502875, 'epoch': 29.99}


TrainOutput(global_step=2910, training_loss=1.0566150097502875, metrics={'train_runtime': 70994.548, 'train_samples_per_second': 1.23, 'train_steps_per_second': 0.041, 'train_loss': 1.0566150097502875, 'epoch': 29.99})

#### Saving the processor and the model to a local directory:

This saves the processor and the model for future use. The model can be further fine-tuned on additional data.

In [43]:
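model.save_pretrained("./Model/final_model")
processor.save_pretrained("./Model/final_model")  # save the processor alongside the model, as described above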

### Evaluating on the test data:

#### Reloading the model from the saved directory:

In [44]:
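from transformers import Wav2Vec2ForCTC  # assumed model class, as in the blog post this notebook follows
model = Wav2Vec2ForCTC.from_pretrained("./Model/final_model")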

#### Predicting the transcriptions:

In [45]:
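def testTranscriptions(test_data):
    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        # Add a batch dimension: (seq_len,) -> (1, seq_len).
        input_values = torch.tensor(test_data["input_values"]).unsqueeze(0)
        logits = model(input_values).logits

    # Greedy CTC decoding: take the most likely token per frame;
    # batch_decode then merges repeated tokens and drops the blank token.
    pred_ids = torch.argmax(logits, dim=-1)
    test_data["pred"] = processor.batch_decode(pred_ids)[0]
    # The labels are already the final token sequence, so repeated tokens must not be merged.
    test_data["gold"] = processor.decode(test_data["labels"], group_tokens=False)

    return test_data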
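
To make the two decoding calls above concrete, here is a minimal pure-Python sketch of greedy CTC decoding. It is an illustration, not the tokenizer's actual implementation: the toy vocabulary, the blank id of 0 and the helper name `ctc_greedy_collapse` are all hypothetical.

```python
def ctc_greedy_collapse(frame_ids, blank_id=0):
    """Merge consecutive repeated ids, then drop the blank id."""
    collapsed = []
    previous = None
    for token_id in frame_ids:
        # Keep a token only when it starts a new run and is not the blank.
        if token_id != previous and token_id != blank_id:
            collapsed.append(token_id)
        previous = token_id
    return collapsed


# Hypothetical toy vocabulary; id 0 plays the role of the CTC blank.
toy_vocab = {0: "<pad>", 1: "a", 2: "l"}

# Per-frame argmax ids: the blank between the two runs of "l" keeps them separate.
frame_ids = [1, 1, 0, 2, 2, 0, 2, 2]
decoded = "".join(toy_vocab[i] for i in ctc_greedy_collapse(frame_ids))
print(decoded)

# Label ids carry no blanks, so collapsing them would wrongly merge a doubled letter.
label_ids = [1, 2, 2]
print(ctc_greedy_collapse(label_ids))
```

The second call shows why the reference labels are decoded with `group_tokens=False`: merging repeats in a sequence that has no blanks would delete genuinely doubled characters.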

In [46]:
output = ur_fine_dataset["test"].map(testTranscriptions, remove_columns=ur_fine_dataset["test"].column_names)


#### Computing Test WER:

In [47]:
print("Test WER: {:.3f}".format(wer_metric.compute(predictions = output["pred"], references = output["gold"])))

Test WER: 0.225
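
For reference, word error rate is the word-level edit distance (substitutions, deletions and insertions) between prediction and reference, divided by the number of reference words. The sketch below illustrates that standard definition in pure Python, aggregated over all utterances. It is not the code behind `wer_metric`; the helper names `word_edit_distance` and `corpus_wer` are hypothetical, and the example sentences are made up.

```python
def word_edit_distance(reference_words, predicted_words):
    """Levenshtein distance between two word lists (rolling-row DP)."""
    previous_row = list(range(len(predicted_words) + 1))
    for i, ref_word in enumerate(reference_words, start=1):
        current_row = [i]
        for j, pred_word in enumerate(predicted_words, start=1):
            substitution = previous_row[j - 1] + (ref_word != pred_word)
            deletion = previous_row[j] + 1      # reference word missing from prediction
            insertion = current_row[j - 1] + 1  # extra word in prediction
            current_row.append(min(substitution, deletion, insertion))
        previous_row = current_row
    return previous_row[-1]


def corpus_wer(references, predictions):
    """Total word errors divided by total reference words."""
    total_errors = 0
    total_words = 0
    for reference, prediction in zip(references, predictions):
        ref_words = reference.split()
        total_errors += word_edit_distance(ref_words, prediction.split())
        total_words += len(ref_words)
    return total_errors / total_words


# Made-up example: one substitution and one deletion over six reference words.
example_refs = ["the cat sat", "on the mat"]
example_preds = ["the cat sit", "on mat"]
example_wer = corpus_wer(example_refs, example_preds)
print("Example WER: {:.3f}".format(example_wer))
```

Under this definition, a WER of 0.225 corresponds to roughly one word error for every four to five reference words.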


#### Saving the predicted and original transcriptions to text files:

In [48]:
def saveToText(output):
    # Write one transcription per line, so the two files stay line-aligned.
    # The with-block closes each file automatically; no explicit close() is needed.
    with open("test_gold.txt", "w", encoding="UTF-8") as file:
        for gold in output["gold"]:
            file.write(gold + "\n")

    with open("test_pred.txt", "w", encoding="UTF-8") as file:
        for pred in output["pred"]:
            file.write(pred + "\n")


saveToText(output)
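
Because the two files are written line by line in the same order, they can be read back as aligned pairs for manual error inspection. The following is a minimal sketch under that assumption. The helpers `load_aligned_pairs` and `mismatched_pairs` are hypothetical, and the demo writes made-up lines to a temporary directory rather than touching the real output files.

```python
import os
import tempfile


def load_aligned_pairs(gold_path, pred_path):
    """Read two line-aligned text files and return (gold, pred) pairs."""
    with open(gold_path, encoding="UTF-8") as gold_file:
        gold_lines = [line.rstrip("\n") for line in gold_file]
    with open(pred_path, encoding="UTF-8") as pred_file:
        pred_lines = [line.rstrip("\n") for line in pred_file]
    if len(gold_lines) != len(pred_lines):
        raise ValueError("files are not line-aligned")
    return list(zip(gold_lines, pred_lines))


def mismatched_pairs(pairs):
    """Return (index, gold, pred) for every pair whose texts differ."""
    return [(i, gold, pred) for i, (gold, pred) in enumerate(pairs) if gold != pred]


# Demo on made-up data written to a temporary directory.
with tempfile.TemporaryDirectory() as tmp_dir:
    gold_path = os.path.join(tmp_dir, "test_gold.txt")
    pred_path = os.path.join(tmp_dir, "test_pred.txt")
    with open(gold_path, "w", encoding="UTF-8") as f:
        f.write("first line\nsecond line\n")
    with open(pred_path, "w", encoding="UTF-8") as f:
        f.write("first line\nsecond lime\n")
    demo_mismatches = mismatched_pairs(load_aligned_pairs(gold_path, pred_path))

print(demo_mismatches)
```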