# **Fine-tuning XLS-R for Multi-Lingual ASR with 🤗 Transformers**

***New (11/2021)***: *This blog post has been updated to feature XLSR's successor, called [XLS-R](https://huggingface.co/models?other=xls_r)*.

Before we start, let's install `datasets` and `transformers`. We also need `torchaudio` to load audio files and `jiwer` to evaluate our fine-tuned model using the [word error rate (WER)](https://huggingface.co/metrics/wer) metric ${}^1$.

In [None]:
%%capture
!pip install datasets==1.18.3
!pip install transformers==4.11.3
!pip install torchaudio==0.10.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
!pip install jiwer

In [None]:
# Hugging Face write token for pushing to the model repo.
# Do not paste the token into the notebook; enter it in the `notebook_login()` prompt below.

In [None]:
from huggingface_hub import notebook_login

notebook_login()


In [None]:
%%capture
!apt install git-lfs

Connecting Google Drive:


In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
# # REMOVE PREVIOUS non empty DATA FOLDER
# import shutil
# shutil.rmtree('/content/My_Mer_final_DS')

In [None]:
# !unzip drive/MyDrive/wavs.zip
import zipfile
with zipfile.ZipFile('/content/gdrive/MyDrive/Mer_Final_DS.zip', 'r') as zip_ref:
    zip_ref.extractall('My_Mer_final_DS')

In [None]:
# Load the dataset with pandas first, because Hugging Face `load_dataset` cannot read this local dataset layout directly.
from sklearn.model_selection import train_test_split
import pandas as pd
from datasets import load_dataset

colnames=['path','sentence']
df  = pd.read_csv('/content/My_Mer_final_DS/NewMergedRiaSwaShiOpenSLR_SUPERFINALL/FInal_All_MER_spk_text.tsv',sep='\t',header=None,names = colnames)
df['path'] = '/content/My_Mer_final_DS/NewMergedRiaSwaShiOpenSLR_SUPERFINALL/Final_All_merged_wavs/'+df['path'] +'.wav'

train, test = train_test_split(df, test_size=0.1)

train.to_csv('/content/My_Mer_final_DS/NewMergedRiaSwaShiOpenSLR_SUPERFINALL/FInal_All_MER_spk_text_train.csv')
test.to_csv('/content/My_Mer_final_DS/NewMergedRiaSwaShiOpenSLR_SUPERFINALL/FInal_All_MER_spk_text_test.csv')
common_voice_train = load_dataset('csv', data_files='/content/My_Mer_final_DS/NewMergedRiaSwaShiOpenSLR_SUPERFINALL/FInal_All_MER_spk_text_train.csv',split= 'train')
common_voice_test = load_dataset('csv', data_files='/content/My_Mer_final_DS/NewMergedRiaSwaShiOpenSLR_SUPERFINALL/FInal_All_MER_spk_text_test.csv',split = 'train')
df.head()




Unnamed: 0,path,sentence
0,/content/My_Mer_final_DS/NewMergedRiaSwaShiOpe...,००७ मिलको दूरीमा
1,/content/My_Mer_final_DS/NewMergedRiaSwaShiOpe...,००७ मिलको दूरीमा
2,/content/My_Mer_final_DS/NewMergedRiaSwaShiOpe...,००७ मिलको दूरीमा
3,/content/My_Mer_final_DS/NewMergedRiaSwaShiOpe...,०११ देखि काङ्ग्रेसको
4,/content/My_Mer_final_DS/NewMergedRiaSwaShiOpe...,०१ सौर द्रव्यमान


In [None]:
print(len(common_voice_train))
len(common_voice_test)

149607


16624

In [None]:
# common_voice_train = common_voice_train.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])
# common_voice_test = common_voice_test.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])

common_voice_train = common_voice_train.remove_columns(["Unnamed: 0"])
common_voice_test = common_voice_test.remove_columns(["Unnamed: 0"])


In [None]:
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)

    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))

In [None]:
show_random_elements(common_voice_train.remove_columns(["path"]), num_examples=10)

Unnamed: 0,sentence
0,यो टेम्प्लेटले काम
1,बैंक्सका मार्गदर्शनमा उनले
2,प्रेरित गर्‍यो
3,चिन्तित छन्
4,जसको उपयोगले बाणलाई
5,अमेरिकीहरूको पैसामा ४० वर्षयताकै उच्च मूल्यवृद्धि छ
6,अधिकारी चिफ् डिस्ट्रिक
7,केट जनको जान
8,भेषराज शर्माद्वारा लेखिएको
9,बचाउने छन्


In [None]:
import re
# raw string, so that the backslashes reach the regex engine unchanged
chars_to_remove_regex = r'[\,\?\.\!\-\;\:\"\“\%\‘\”\�\']'

def remove_special_characters(batch):
    batch["sentence"] = re.sub(chars_to_remove_regex, '', batch["sentence"]).lower()
    return batch
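
To make the effect of this normalization concrete, here is a minimal, self-contained sketch. `BASE_PATTERN` lists the same punctuation as the pattern above (minus the replacement character). `EXTRA_PATTERN` and the `normalize` helper are hypothetical additions, not part of the notebook: they illustrate one way to also drop the danda, the backslash, and some invisible marks, all of which otherwise survive into the vocabulary built later.

```python
import re

# Same punctuation as the notebook's pattern, written as a raw string.
BASE_PATTERN = r'[,?.!\-;:"“%‘”\']'
# Hypothetical extension (an assumption, not from the notebook): danda, double danda,
# backslash, zero-width space, and left-to-right / right-to-left marks.
EXTRA_PATTERN = r'[।॥\\\u200b\u200e\u200f]'


def normalize(text, extra=False):
    """Strip punctuation, optionally strip the extra marks, then lowercase."""
    text = re.sub(BASE_PATTERN, '', text)
    if extra:
        text = re.sub(EXTRA_PATTERN, '', text)
    return text.lower()


print(normalize("Hello, World!"))
print(normalize("नमस्ते।", extra=True))
```

The zero-width joiner and non-joiner are deliberately left in place in this sketch, because they can change how Devanagari conjuncts are written; whether to strip them is a judgment call for the dataset at hand.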

In [None]:
common_voice_train = common_voice_train.map(remove_special_characters)
common_voice_test = common_voice_test.map(remove_special_characters)




In [None]:
show_random_elements(common_voice_train.remove_columns(["path"]))

Unnamed: 0,sentence
0,र निबन्धले मागेका निर्धारक
1,१९९० पछि पछि
2,भदै र लायोनेल
3,हुन्छ त्यो विचार
4,राष्ट्रिय कृषि कार्यक्रमहरूले
5,यसरी सिन्धुलीगढीको क्षेत्र
6,उठ्ने सम्भावना
7,अहिले पनि सम्झन्छु
8,उपत्यकासँग जोड्दछ
9,उनको लागि इतिहास


In [None]:
# def replace_hatted_characters(batch):
#     batch["sentence"] = re.sub('[â]', 'a', batch["sentence"])
#     batch["sentence"] = re.sub('[î]', 'i', batch["sentence"])
#     batch["sentence"] = re.sub('[ô]', 'o', batch["sentence"])
#     batch["sentence"] = re.sub('[û]', 'u', batch["sentence"])
#     return batch

In [None]:
# common_voice_train = common_voice_train.map(replace_hatted_characters)
# common_voice_test = common_voice_test.map(replace_hatted_characters)

In [None]:
def extract_all_chars(batch):
  all_text = " ".join(batch["sentence"])
  vocab = list(set(all_text))
  return {"vocab": [vocab], "all_text": [all_text]}

In [None]:
vocab_train = common_voice_train.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names)
vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)


In [None]:
vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))

In [None]:
vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}
vocab_dict

{' ': 0,
 '\\': 1,
 'a': 2,
 'b': 3,
 'c': 4,
 'e': 5,
 'f': 6,
 'k': 7,
 'o': 8,
 'ँ': 9,
 'ं': 10,
 'ः': 11,
 'अ': 12,
 'आ': 13,
 'इ': 14,
 'ई': 15,
 'उ': 16,
 'ऊ': 17,
 'ऋ': 18,
 'ए': 19,
 'ऐ': 20,
 'ऑ': 21,
 'ओ': 22,
 'औ': 23,
 'क': 24,
 'ख': 25,
 'ग': 26,
 'घ': 27,
 'ङ': 28,
 'च': 29,
 'छ': 30,
 'ज': 31,
 'झ': 32,
 'ञ': 33,
 'ट': 34,
 'ठ': 35,
 'ड': 36,
 'ढ': 37,
 'ण': 38,
 'त': 39,
 'थ': 40,
 'द': 41,
 'ध': 42,
 'न': 43,
 'प': 44,
 'फ': 45,
 'ब': 46,
 'भ': 47,
 'म': 48,
 'य': 49,
 'र': 50,
 'ऱ': 51,
 'ल': 52,
 'व': 53,
 'श': 54,
 'ष': 55,
 'स': 56,
 'ह': 57,
 '़': 58,
 'ा': 59,
 'ि': 60,
 'ी': 61,
 'ु': 62,
 'ू': 63,
 'ृ': 64,
 'े': 65,
 'ै': 66,
 'ॉ': 67,
 'ॊ': 68,
 'ो': 69,
 'ौ': 70,
 '्': 71,
 'ॐ': 72,
 '॑': 73,
 'ॠ': 74,
 '।': 75,
 '०': 76,
 '१': 77,
 '२': 78,
 '३': 79,
 '४': 80,
 '५': 81,
 '६': 82,
 '७': 83,
 '८': 84,
 '९': 85,
 '॰': 86,
 '\u200b': 87,
 '\u200c': 88,
 '\u200d': 89,
 '\u200e': 90,
 '\u200f': 91,
 '\u202f': 92}

In [None]:
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]

In [None]:
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
len(vocab_dict)

95
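
The vocabulary steps above can be traced on a toy corpus. This is a minimal sketch: `build_vocab` and the two sample strings are hypothetical and only mirror the cells above (collect characters, enumerate them in sorted order, rename the space to `|`, then append `[UNK]` and `[PAD]`).

```python
def build_vocab(sentences):
    """Build a character-level CTC vocabulary the same way as the cells above."""
    # Join with spaces so that the word delimiter is part of the character set.
    chars = sorted(set(" ".join(sentences)))
    vocab = {ch: idx for idx, ch in enumerate(chars)}
    # Make the word delimiter visible by renaming " " to "|".
    vocab["|"] = vocab.pop(" ")
    # The special tokens take the next free ids.
    vocab["[UNK]"] = len(vocab)
    vocab["[PAD]"] = len(vocab)
    return vocab


toy_vocab = build_vocab(["ab ba", "cab"])
print(toy_vocab)
```

Note that `[PAD]` always receives the highest id here, which is the id later reused as the CTC blank token.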

In [None]:
import json
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

In a final step, we use the JSON file to load the vocabulary into an instance of the `Wav2Vec2CTCTokenizer` class.

In [None]:
!pip install transformers
import transformers
from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [None]:
!python -m pip install huggingface_hub
!huggingface-cli login



If you want to reuse the newly created tokenizer with the fine-tuned model of this notebook, it is strongly advised to upload the `tokenizer` to the [🤗 Hub](https://huggingface.co/). Let's call the repo to which we will upload the files
`"wav2vec2-large-xls-r-300m-Nepali_ShishirAI"`:

In [None]:
repo_name = "wav2vec2-large-xls-r-300m-Nepali_ShishirAI"

and upload the tokenizer to the [🤗 Hub](https://huggingface.co/).

In [None]:
tokenizer.push_to_hub(repo_name)

CommitInfo(commit_url='https://huggingface.co/shishirml/wav2vec2-large-xls-r-300m-Nepali_ShishirAI/commit/02d9a1212f51c6d4ad82776ca6ebc32dce8ae92b', commit_message='Upload tokenizer', commit_description='', oid='02d9a1212f51c6d4ad82776ca6ebc32dce8ae92b', pr_url=None, pr_revision=None, pr_num=None)

Great, you can see the newly created repository under `https://huggingface.co/<your-username>/wav2vec2-large-xls-r-300m-Nepali_ShishirAI`

### Create `Wav2Vec2FeatureExtractor`

In [None]:
from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)

In [None]:
from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

In [None]:
processor.save_pretrained("/content/gdrive/MyDrive/wav2vec2-large-xls-r-nepali")

Next, we can prepare the dataset.

### Preprocess Data

So far, we have looked only at the transcriptions, not at the actual values of the speech signal. In addition to `sentence`, our datasets include a `path` column, which states the absolute path of the audio file. Because the data was loaded from CSV files, there is no `audio` column.


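XLS-R expects its input as a one-dimensional array sampled at 16 kHz. This means that each audio file has to be loaded and resampled.

The original Common Voice version of this tutorial relies on an `audio` column that `datasets` decodes automatically. Our dataset has no such column, so we load the files ourselves with `torchaudio` and resample them with `librosa`. Let's first look at one example.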
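
To illustrate what resampling to 16 kHz means for the array itself, here is a deliberately crude sketch based on linear interpolation with NumPy. `resample_linear` is a hypothetical helper and not what the notebook uses: `librosa.resample` applies a proper band-limited filter and should be preferred for real data. The sketch only shows how the number of samples changes with the sampling rate.

```python
import numpy as np


def resample_linear(signal, orig_sr, target_sr):
    """Resample a 1-D signal by linear interpolation (illustration only)."""
    signal = np.asarray(signal, dtype=np.float32)
    if orig_sr == target_sr:
        return signal
    # The duration stays the same, so the sample count scales with the rate.
    n_out = int(round(len(signal) * target_sr / orig_sr))
    t_in = np.arange(len(signal)) / orig_sr    # timestamps of the input samples
    t_out = np.arange(n_out) / target_sr       # timestamps of the output samples
    return np.interp(t_out, t_in, signal).astype(np.float32)


# One second of a 440 Hz tone recorded at 48 kHz.
one_second = np.sin(2 * np.pi * 440 * np.arange(48_000) / 48_000)
resampled = resample_linear(one_second, orig_sr=48_000, target_sr=16_000)
print(one_second.shape, resampled.shape)
```

One second of audio holds 48,000 samples at 48 kHz and 16,000 samples at 16 kHz, which is why a wrong sampling rate makes the model hear speech at the wrong speed.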

In [None]:
# common_voice_train[0]["audio"]
common_voice_train[0]

{'path': '/content/My_Mer_final_DS/NewMergedRiaSwaShiOpenSLR_SUPERFINALL/Final_All_merged_wavs/c588f74d17.wav',
 'sentence': 'धनकुमार चेम्जोङ हुन्'}

In [None]:
common_voice_train[0]["path"]

'/content/My_Mer_final_DS/NewMergedRiaSwaShiOpenSLR_SUPERFINALL/Final_All_merged_wavs/c588f74d17.wav'

In [None]:
common_voice_train[0]

{'path': '/content/My_Mer_final_DS/NewMergedRiaSwaShiOpenSLR_SUPERFINALL/Final_All_merged_wavs/c588f74d17.wav',
 'sentence': 'धनकुमार चेम्जोङ हुन्'}

In [None]:
import torchaudio

def speech_file_to_array_fn(batch):
    speech_array, sampling_rate = torchaudio.load(batch["path"])
    batch["speech"] = speech_array[0].numpy()
    batch["sampling_rate"] = sampling_rate
    batch["target_text"] = batch["sentence"]
    return batch

In [None]:
common_voice_train = common_voice_train.map(speech_file_to_array_fn, remove_columns=common_voice_train.column_names)
common_voice_test = common_voice_test.map(speech_file_to_array_fn, remove_columns=common_voice_test.column_names)


In [None]:
# REMOVE PREVIOUS non empty DATA FOLDER
import shutil
shutil.rmtree('/content/My_Mer_final_DS/NewMergedRiaSwaShiOpenSLR_SUPERFINALL/Final_All_merged_wavs')

In [None]:
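import librosa
import numpy as np

def resample(batch):
    # librosa >= 0.10 accepts the sampling rates as keyword arguments only;
    # use the rate recorded at load time instead of assuming 48 kHz
    batch["speech"] = librosa.resample(
        np.asarray(batch["speech"], dtype=np.float32),
        orig_sr=batch["sampling_rate"],
        target_sr=16_000,
    )
    batch["sampling_rate"] = 16_000
    return batch

In [None]:
common_voice_train = common_voice_train.map(resample, num_proc=4)
common_voice_test = common_voice_test.map(resample, num_proc=4)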

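**Note**: *The next two cells come from the original Common Voice tutorial and are commented out here, because our audio has already been loaded and resampled manually.*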


Great, we can see that the audio file has automatically been loaded. This is thanks to the new [`"Audio"` feature](https://huggingface.co/docs/datasets/package_reference/main_classes.html?highlight=audio#datasets.Audio) introduced in `datasets == 1.13.3`, which loads and resamples audio files on the fly when the column is accessed.

In the example above we can see that the audio data is loaded with a sampling rate of 48kHz whereas 16kHz are expected by the model. We can set the audio feature to the correct sampling rate by making use of [`cast_column`](https://huggingface.co/docs/datasets/package_reference/main_classes.html?highlight=cast_column#datasets.DatasetDict.cast_column):

In [None]:
# common_voice_train = common_voice_train.cast_column("audio", Audio(sampling_rate=16_000))
# common_voice_test = common_voice_test.cast_column("audio", Audio(sampling_rate=16_000))

Let's take a look at `"audio"` again.

In [None]:
# common_voice_train[0]["audio"]

{'array': array([ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
        -7.4556941e-05, -1.4621433e-05, -5.7861507e-05], dtype=float32),
 'path': '/root/.cache/huggingface/datasets/downloads/extracted/05be0c29807a73c9b099873d2f5975dae6d05e9f7d577458a2466ecb9a2b0c6b/cv-corpus-6.1-2020-12-11/tr/clips/common_voice_tr_21921195.mp3',
 'sampling_rate': 16000}

This seems to have worked! Let's listen to a couple of audio files to better understand the dataset and verify that the audio was correctly loaded.

**Note**: *You can click the following cell a couple of times to listen to different speech samples.*

In [None]:
import IPython.display as ipd
import numpy as np
import random

rand_int = random.randint(0, len(common_voice_train)-1)

print(common_voice_train[rand_int]["target_text"])
ipd.Audio(data=common_voice_train[rand_int]["speech"], autoplay=True, rate=16000)

चौडाइ ७५ सम्म हुन्छ


It seems like the data is now correctly loaded and resampled.

You can hear that the speakers vary, along with their speaking rate, accent, and background environment. Overall, the recordings sound acceptably clear though, which is to be expected from a crowd-sourced read speech corpus.

Let's do a final check that the data is correctly prepared, by printing the shape of the speech input, its transcription, and the corresponding sampling rate.

**Note**: *You can click the following cell a couple of times to verify multiple samples.*

In [None]:
rand_int = random.randint(0, len(common_voice_train)-1)

# print("Target text:", common_voice_train[rand_int]["sentence"])
# print("Input array shape:", common_voice_train[rand_int]["audio"]["array"].shape)
# print("Sampling rate:", common_voice_train[rand_int]["audio"]["sampling_rate"])

rand_int = random.randint(0, len(common_voice_train)-1)

print("Target text:", common_voice_train[rand_int]["target_text"])
print("Input array shape:", np.asarray(common_voice_train[rand_int]["speech"]).shape)
print("Sampling rate:", common_voice_train[rand_int]["sampling_rate"])

Target text: लमजुङ जिल्लामा अवस्थित
Input array shape: (30400,)
Sampling rate: 16000


Good! Everything looks fine - the data is a 1-dimensional array, the sampling rate always corresponds to 16kHz, and the target text is normalized.

Finally, we can leverage `Wav2Vec2Processor` to process the data to the format expected by `Wav2Vec2ForCTC` for training. To do so let's make use of Dataset's [`map(...)`](https://huggingface.co/docs/datasets/package_reference/main_classes.html?highlight=map#datasets.DatasetDict.map) function.

First, we check that all examples in a batch share the same sampling rate. The audio itself was already loaded and resampled above, so it is read from `batch["speech"]`.
Second, we extract the `input_values` from the audio arrays. In our case, the `Wav2Vec2Processor` only normalizes the data. For other speech models, however, this step can include more complex feature extraction, such as [Log-Mel feature extraction](https://en.wikipedia.org/wiki/Mel-frequency_cepstrum).
Third, we encode the transcriptions to label ids.
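
As a rough picture of what "only normalizes the data" means, the sketch below applies zero-mean, unit-variance normalization to a raw waveform with NumPy. `zero_mean_unit_var` and its `eps` value are assumptions for illustration; they mirror what `do_normalize=True` is understood to do, not the library's exact code.

```python
import numpy as np


def zero_mean_unit_var(signal, eps=1e-7):
    """Shift a waveform to zero mean and scale it to unit variance."""
    signal = np.asarray(signal, dtype=np.float32)
    # eps guards against division by zero for silent (all-zero) clips.
    return (signal - signal.mean()) / np.sqrt(signal.var() + eps)


rng = np.random.default_rng(0)
# A quiet clip with a constant offset, standing in for raw microphone samples.
raw = 0.01 * rng.standard_normal(16_000) + 0.2
normalized = zero_mean_unit_var(raw)
print(float(normalized.mean()), float(normalized.std()))
```

After this step, loud and quiet recordings reach the model on a comparable scale, which is the only feature extraction the raw-waveform input needs.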

**Note**: This mapping function is a good example of how the `Wav2Vec2Processor` class should be used. In "normal" context, calling `processor(...)` is redirected to `Wav2Vec2FeatureExtractor`'s call method. When wrapping the processor into the `as_target_processor` context, however, the same method is redirected to `Wav2Vec2CTCTokenizer`'s call method.
For more information please check the [docs](https://huggingface.co/transformers/master/model_doc/wav2vec2.html#transformers.Wav2Vec2Processor.__call__).

In [None]:
def prepare_dataset(batch):
    # check that all files have the correct sampling rate
    assert (
        len(set(batch["sampling_rate"])) == 1
    ), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."

    batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values

    with processor.as_target_processor():
        batch["labels"] = processor(batch["target_text"]).input_ids
    return batch

Let's apply the data preparation function to all examples.

In [None]:
# common_voice_train = common_voice_train.map(prepare_dataset, remove_columns=common_voice_train.column_names)
# common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names)

common_voice_train = common_voice_train.map(prepare_dataset, remove_columns=common_voice_train.column_names, batch_size=8, num_proc=4, batched=True)
common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names, batch_size=8, num_proc=4, batched=True)


**Note**: Currently `datasets` makes use of [`torchaudio`](https://pytorch.org/audio/stable/index.html) and [`librosa`](https://librosa.org/doc/latest/index.html) for audio loading and resampling. If you wish to implement your own customized data loading/sampling, feel free to just make use of the `"path"` column instead and disregard the `"audio"` column.

Long input sequences require a lot of memory. Because XLS-R is based on self-attention, its memory requirement scales quadratically with the input length for long input sequences (*cf.* [this](https://www.reddit.com/r/MachineLearning/comments/genjvb/d_why_is_the_maximum_input_sequence_length_of/) reddit post). If this demo crashes with an "Out-of-memory" error for you, you might want to uncomment the following lines to filter out all training sequences that are longer than 5 seconds.

In [None]:
#max_input_length_in_sec = 5.0
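# the prepared dataset has no "input_length" column, so the length of "input_values" is measured directly
#common_voice_train = common_voice_train.filter(lambda x: len(x) < max_input_length_in_sec * processor.feature_extractor.sampling_rate, input_columns=["input_values"])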
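
To get a feeling for the quadratic growth mentioned above, the following sketch counts how many feature frames the convolutional encoder produces for a clip and how many entries a single self-attention matrix then has. The kernel sizes and strides are assumed to match the default `Wav2Vec2Config` values; `feature_frames` and `attention_entries` are hypothetical helper names.

```python
# Kernel sizes and strides of the seven convolutional layers in the wav2vec 2.0
# feature encoder (assumed to match the default Wav2Vec2Config values).
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)


def feature_frames(num_samples):
    """Number of frames the feature encoder emits for `num_samples` raw samples."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNELS, CONV_STRIDES):
        length = (length - kernel) // stride + 1
    return length


def attention_entries(seconds, sampling_rate=16_000):
    """Entries in one attention matrix (frames x frames) for a clip of `seconds`."""
    frames = feature_frames(int(seconds * sampling_rate))
    return frames * frames


for seconds in (1, 5, 20):
    print(seconds, feature_frames(seconds * 16_000), attention_entries(seconds))
```

The frame count grows roughly linearly with duration (about 49 frames per second), so the attention matrix grows with its square. That is why dropping a few very long clips can free far more memory than their share of the dataset suggests.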

Awesome, now we are ready to start training!

## Training

The data is processed so that we are ready to start setting up the training pipeline. We will make use of 🤗's [Trainer](https://huggingface.co/transformers/master/main_classes/trainer.html?highlight=trainer) for which we essentially need to do the following:

- Define a data collator. In contrast to most NLP models, XLS-R has a much larger input length than output length. *E.g.*, a sample of input length 50000 has an output length of no more than 100. Given the large input sizes, it is much more efficient to pad the training batches dynamically meaning that all training samples should only be padded to the longest sample in their batch and not the overall longest sample. Therefore, fine-tuning XLS-R requires a special padding data collator, which we will define below

- Evaluation metric. During training, the model should be evaluated on the word error rate. We should define a `compute_metrics` function accordingly

- Load a pretrained checkpoint. We need to load a pretrained checkpoint and configure it correctly for training.

- Define the training configuration.

After having fine-tuned the model, we will correctly evaluate it on the test data and verify that it has indeed learned to correctly transcribe speech.

### Set-up Trainer

Let's start by defining the data collator. The code for the data collator was copied from [this example](https://github.com/huggingface/transformers/blob/7e61d56a45c19284cfda0cee8995fb552f6b1f4e/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py#L219).

Without going into too many details, in contrast to the common data collators, this data collator treats the `input_values` and `labels` differently and thus applies two separate padding functions to them (again making use of the XLS-R processor's context manager). This is necessary because in speech, input and output are of different modalities, meaning that they should not be treated by the same padding function.
Analogous to the common data collators, we replace the padding tokens in the labels with `-100` so that those tokens are **not** taken into account when computing the loss.
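
The padding logic can be sketched without any model code. In the NumPy example below, `pad_batch` is a hypothetical helper: it pads every sequence only to the longest one in its batch, using `0.0` for the audio inputs and `-100` for the labels.

```python
import numpy as np


def pad_batch(sequences, pad_value, dtype):
    """Right-pad variable-length sequences to the longest one in the batch."""
    max_len = max(len(seq) for seq in sequences)
    padded = np.full((len(sequences), max_len), pad_value, dtype=dtype)
    for row, seq in enumerate(sequences):
        padded[row, :len(seq)] = seq
    return padded


# Audio inputs and label ids have different lengths and different padding values.
inputs = pad_batch([[0.1, 0.2, 0.3], [0.5]], pad_value=0.0, dtype=np.float32)
labels = pad_batch([[4, 7], [9]], pad_value=-100, dtype=np.int64)
print(inputs)
print(labels)
```

The `-100` entries mark positions that the loss should skip, so a short transcription is not penalized for the padding added next to it.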

In [None]:
import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                return_tensors="pt",
            )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

In [None]:
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

Next, the evaluation metric is defined. As mentioned earlier, the
predominant metric in ASR is the word error rate (WER), hence we will use it in this notebook as well.
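
For readers who have not met WER before, here is a minimal pure-Python sketch. `word_error_rate` is a hypothetical per-sentence helper, not the `jiwer`-backed metric used below: it counts the word-level substitutions, insertions, and deletions needed to turn the hypothesis into the reference and divides by the number of reference words (the reference is assumed to be non-empty).

```python
def word_error_rate(reference, hypothesis):
    """Word-level edit distance divided by the number of reference words."""
    ref, hyp = reference.split(), hypothesis.split()
    # prev[j] holds the edit distance between the reference words seen so far and hyp[:j].
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1] / len(ref)


print(word_error_rate("लमजुङ जिल्लामा अवस्थित", "लमजुङ जिल्ला अवस्थित"))
```

One wrong word out of three gives a WER of one third. Because insertions also count as errors, WER can exceed 1.0 when the hypothesis is much longer than the reference.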

In [None]:
from datasets import load_metric

wer_metric = load_metric("wer")


The model will return a sequence of logit vectors:
$\mathbf{y}_1, \ldots, \mathbf{y}_m$ with $\mathbf{y}_1 = f_{\theta}(x_1, \ldots, x_n)[0]$ and $n \gg m$.

A logit vector $\mathbf{y}_i$ contains the log-odds for each token in the vocabulary we defined earlier, thus $\text{len}(\mathbf{y}_i) =$ `config.vocab_size`. We are interested in the most likely prediction of the model and thus take the `argmax(...)` of the logits. Also, we transform the encoded labels back to the original string by replacing `-100` with the `pad_token_id` and decoding the ids while making sure that consecutive tokens are **not** grouped to the same token in CTC style ${}^1$.

In [None]:
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

Now, we can load the pretrained checkpoint of [Wav2Vec2-XLS-R-300M](https://huggingface.co/facebook/wav2vec2-xls-r-300m). The tokenizer's `pad_token_id` must be used to define the model's `pad_token_id`, which in the case of `Wav2Vec2ForCTC` is also CTC's *blank token* ${}^2$. To save GPU memory, we enable PyTorch's [gradient checkpointing](https://pytorch.org/docs/stable/checkpoint.html) and also set the loss reduction to "*mean*".

Because the dataset is quite small (~6h of training data) and because Common Voice is quite noisy, fine-tuning Facebook's [wav2vec2-xls-r-300m checkpoint](https://huggingface.co/facebook/wav2vec2-xls-r-300m) seems to require some hyper-parameter tuning. Therefore, I had to play around a bit with different values for dropout, [SpecAugment](https://arxiv.org/abs/1904.08779)'s masking dropout rate, layer dropout, and the learning rate until training seemed to be stable enough.

**Note**: When using this notebook to train XLS-R on another language of Common Voice those hyper-parameter settings might not work very well. Feel free to adapt those depending on your use case.

In [None]:
from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-xls-r-300m",
    attention_dropout=0.0,
    hidden_dropout=0.0,
    feat_proj_dropout=0.0,
    mask_time_prob=0.05,
    layerdrop=0.0,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
)


Some weights of the model checkpoint at facebook/wav2vec2-xls-r-300m were not used when initializing Wav2Vec2ForCTC: ['project_hid.weight', 'project_q.bias', 'quantizer.weight_proj.weight', 'quantizer.weight_proj.bias', 'project_q.weight', 'project_hid.bias', 'quantizer.codevectors']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-xls-r-300m and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it 

The first component of XLS-R consists of a stack of CNN layers that are used to extract acoustically meaningful - but contextually independent - features from the raw speech signal. This part of the model has already been sufficiently trained during pretraining and as stated in the [paper](https://arxiv.org/pdf/2006.13979.pdf) does not need to be fine-tuned anymore.
Thus, we can set the `requires_grad` to `False` for all parameters of the *feature extraction* part.

In [None]:
model.freeze_feature_extractor()

In a final step, we define all parameters related to training.
To give more explanation on some of the parameters:
- `group_by_length` makes training more efficient by grouping training samples of similar input length into one batch. This can significantly speed up training time by heavily reducing the overall number of useless padding tokens that are passed through the model
- `learning_rate` and `weight_decay` were heuristically tuned until fine-tuning has become stable. Note that those parameters strongly depend on the Common Voice dataset and might be suboptimal for other speech datasets.
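
The benefit of `group_by_length` can be estimated with a few lines of arithmetic. The sketch below is a simplification: `padded_waste` is a hypothetical helper, and the lengths are fully sorted, whereas the real sampler only groups samples of similar length and keeps some randomness.

```python
def padded_waste(lengths, batch_size):
    """Total number of padding positions when batches are cut in the given order."""
    waste = 0
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start:start + batch_size]
        # Every sample is padded up to the longest sample in its batch.
        waste += sum(max(batch) - n for n in batch)
    return waste


# Made-up clip lengths (in frames), mixing short and long utterances.
lengths = [10, 100, 12, 95, 11, 98]
unsorted_waste = padded_waste(lengths, batch_size=3)
grouped_waste = padded_waste(sorted(lengths), batch_size=3)
print(unsorted_waste, grouped_waste)
```

With these toy numbers, mixing short and long clips wastes 268 padded positions, while grouping similar lengths wastes only 10.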

For more explanations on other parameters, one can take a look at the [docs](https://huggingface.co/transformers/master/main_classes/trainer.html?highlight=trainer#trainingarguments).

During training, a checkpoint will be uploaded asynchronously to the hub every 400 training steps. It allows you to also play around with the demo widget even while your model is still training.

**Note**: If one does not want to upload the model checkpoints to the hub, simply set `push_to_hub=False`.

In [None]:
from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir=repo_name,
  group_by_length=True,
  per_device_train_batch_size=16,
  gradient_accumulation_steps=2,
  evaluation_strategy="steps",
  num_train_epochs=30,
  gradient_checkpointing=True,
  fp16=True,
  save_steps=400,
  eval_steps=400,
  logging_steps=400,
  learning_rate=3e-4,
  warmup_steps=500,
  save_total_limit=2,
  push_to_hub=True,
)

Now, all instances can be passed to `Trainer` and we are ready to start training!

In [None]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=common_voice_train,
    eval_dataset=common_voice_test,
    tokenizer=processor.feature_extractor,
)

/content/wav2vec2-large-xls-r-300m-turkish-colab is already a clone of https://huggingface.co/patrickvonplaten/wav2vec2-large-xls-r-300m-turkish-colab. Make sure you pull the latest changes with `repo.git_pull()`.
Using amp fp16 backend




---

${}^1$ To allow models to become independent of the speaking rate, CTC groups consecutive identical tokens into a single token. However, the encoded labels should not be grouped when decoding, since they don't correspond to the predicted tokens of the model, which is why the `group_tokens=False` parameter has to be passed. If we didn't pass this parameter, a word like `"hello"` would incorrectly be encoded and then decoded as `"helo"`.

${}^2$ The blank token allows the model to predict a word such as `"hello"` by forcing it to insert the blank token between the two l's. A CTC-conformant prediction of `"hello"` by our model would be `[PAD] [PAD] "h" "e" "e" "l" "l" [PAD] "l" "o" "o" [PAD]`.
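
The two footnotes can be checked with a few lines of plain Python. This is a minimal sketch of the CTC collapsing rule (merge repeats, then drop blanks); `ctc_collapse` is a name introduced here for illustration and is not the processor's API.

```python
from itertools import groupby

BLANK = "[PAD]"  # the blank token, as in the footnote above


def ctc_collapse(tokens, blank=BLANK):
    """Merge consecutive identical tokens, then remove blank tokens."""
    merged = [token for token, _ in groupby(tokens)]
    return "".join(token for token in merged if token != blank)


# Model-style prediction: the blank between the two l's keeps them apart.
prediction = ["[PAD]", "[PAD]", "h", "e", "e", "l", "l", "[PAD]", "l", "o", "o", "[PAD]"]
print(ctc_collapse(prediction))  # hello

# Label-style sequence without blanks: grouping wrongly merges the double l.
labels = ["h", "e", "l", "l", "o"]
print(ctc_collapse(labels))  # helo
```

The second call shows why labels must be decoded without grouping: they contain no blank token to separate repeated characters.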

### Training

Training will take multiple hours depending on the GPU allocated to this notebook. While the trained model yields somewhat satisfying results on *Common Voice*'s Turkish test data, it is by no means an optimally fine-tuned model. The purpose of this notebook is just to demonstrate how to fine-tune XLS-R on an ASR dataset.

In case you want to use this Google Colab to fine-tune your model, you should make sure that your training doesn't stop due to inactivity. A simple hack to prevent this is to paste the following code into the console of this tab (*right mouse click -> inspect -> Console tab and insert code*).

```javascript
function ConnectButton(){
    console.log("Connect pushed");
    document.querySelector("#top-toolbar > colab-connect-button").shadowRoot.querySelector("#connect").click()
}
setInterval(ConnectButton,60000);
```

Depending on which GPU was allocated to your Google Colab, you might see an `"out-of-memory"` error here. In this case, it's probably best to reduce `per_device_train_batch_size` to 8 or even less and increase [`gradient_accumulation_steps`](https://huggingface.co/transformers/master/main_classes/trainer.html#trainingarguments).
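
The trade-off can be checked with a bit of arithmetic. The sketch below assumes a single GPU and that one optimizer update happens per full group of accumulated batches, with leftover batches at the end of an epoch not counted as an extra update. `effective_batch_size` and `optimization_steps` are helper names introduced here, not `Trainer` APIs. Under these assumptions, the settings used in this notebook give the 3270 optimization steps that the training log reports for 3478 examples.

```python
import math


def effective_batch_size(per_device_batch_size, accumulation_steps, num_devices=1):
    """Number of samples that contribute to a single optimizer update."""
    return per_device_batch_size * accumulation_steps * num_devices


def optimization_steps(num_examples, per_device_batch_size, accumulation_steps, num_epochs):
    """Optimizer updates over the whole run (single device assumed)."""
    batches_per_epoch = math.ceil(num_examples / per_device_batch_size)
    updates_per_epoch = max(batches_per_epoch // accumulation_steps, 1)
    return updates_per_epoch * num_epochs


# Settings used in this notebook: batch size 16, accumulation 2.
print(effective_batch_size(16, 2))            # 32
print(optimization_steps(3478, 16, 2, 30))    # 3270

# Memory-friendlier alternative with the same effective batch size.
print(effective_batch_size(8, 4))             # 32
print(optimization_steps(3478, 8, 4, 30))
```

Halving the batch size while doubling the accumulation steps keeps the effective batch size at 32, so the learning rate and warmup settings should behave similarly, at the cost of more forward and backward passes per update.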

In [None]:
trainer.train()

The following columns in the training set  don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.
***** Running training *****
  Num examples = 3478
  Num Epochs = 30
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 2
  Total optimization steps = 3270


| Step | Training Loss | Validation Loss | WER |
|---|---|---|---|
| 400 | 3.7652 | 0.625533 | 0.69329 |
| 800 | 0.3716 | 0.405785 | 0.456337 |
| 1200 | 0.1779 | 0.412796 | 0.408028 |
| 1600 | 0.1226 | 0.399832 | 0.394137 |
| 2000 | 0.0927 | 0.408412 | 0.371259 |
| 2400 | 0.0737 | 0.394867 | 0.351956 |
| 2800 | 0.0581 | 0.36983 | 0.33112 |
| 3200 | 0.0453 | 0.359626 | 0.324175 |


The following columns in the evaluation set  don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.
***** Running Evaluation *****
  Num examples = 1647
  Batch size = 8
Saving model checkpoint to wav2vec2-large-xls-r-300m-turkish-colab/checkpoint-400
Configuration saved in wav2vec2-large-xls-r-300m-turkish-colab/checkpoint-400/config.json
Model weights saved in wav2vec2-large-xls-r-300m-turkish-colab/checkpoint-400/pytorch_model.bin
Configuration saved in wav2vec2-large-xls-r-300m-turkish-colab/checkpoint-400/preprocessor_config.json
Configuration saved in wav2vec2-large-xls-r-300m-turkish-colab/preprocessor_config.json

TrainOutput(global_step=3270, training_loss=0.5767540113641582, metrics={'train_runtime': 16918.4457, 'train_samples_per_second': 6.167, 'train_steps_per_second': 0.193, 'total_flos': 1.2873093995838829e+19, 'train_loss': 0.5767540113641582, 'epoch': 30.0})

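The training loss and validation WER go down nicely: between step 400 and step 3200, the training loss falls from 3.77 to 0.05 and the WER from 0.69 to 0.32.

You can now upload the result of the training to the 🤗 Hub by executing this instruction: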

In [None]:
trainer.push_to_hub()

Saving model checkpoint to wav2vec2-large-xls-r-300m-turkish-colab
Configuration saved in wav2vec2-large-xls-r-300m-turkish-colab/config.json
Model weights saved in wav2vec2-large-xls-r-300m-turkish-colab/pytorch_model.bin
Configuration saved in wav2vec2-large-xls-r-300m-turkish-colab/preprocessor_config.json
Several commits (2) will be pushed upstream.
The progress bars may be unreliable.



To https://huggingface.co/patrickvonplaten/wav2vec2-large-xls-r-300m-turkish-colab
   56ebe74..fe76946  main -> main

Dropping the following result as it does not have all the necessary field:
{'dataset': {'name': 'common_voice', 'type': 'common_voice', 'args': 'tr'}}
To https://huggingface.co/patrickvonplaten/wav2vec2-large-xls-r-300m-turkish-colab
   fe76946..5f0d67b  main -> main



'https://huggingface.co/patrickvonplaten/wav2vec2-large-xls-r-300m-turkish-colab/commit/fe769461e4e2fb9534740e6c278a0cfabf268474'

You can now share this model with all your friends, family, and favorite pets: they can all load it with the identifier "your-username/the-name-you-picked", for instance:

```python
from transformers import AutoModelForCTC, Wav2Vec2Processor

model = AutoModelForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xls-r-300m-turkish-colab")
processor = Wav2Vec2Processor.from_pretrained("patrickvonplaten/wav2vec2-large-xls-r-300m-turkish-colab")
```

For more examples of how XLS-R can be fine-tuned, please take a look at the [official speech recognition examples](https://github.com/huggingface/transformers/tree/master/examples/pytorch/speech-recognition#examples).

### Evaluation

As a final check, let's load the model and verify that it indeed has learned to transcribe Turkish speech.

Let's first load the fine-tuned checkpoint.

In [None]:
model = Wav2Vec2ForCTC.from_pretrained(repo_name).to("cuda")
processor = Wav2Vec2Processor.from_pretrained(repo_name)

loading configuration file wav2vec2-large-xls-r-300m-turkish-colab/config.json
Model config Wav2Vec2Config {
  "_name_or_path": "facebook/wav2vec2-xls-r-300m",
  "activation_dropout": 0.0,
  "apply_spec_augment": true,
  "architectures": [
    "Wav2Vec2ForCTC"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "classifier_proj_size": 256,
  "codevector_dim": 768,
  "contrastive_logits_temperature": 0.1,
  "conv_bias": true,
  "conv_dim": [
    512,
    512,
    512,
    512,
    512,
    512,
    512
  ],
  "conv_kernel": [
    10,
    3,
    3,
    3,
    3,
    2,
    2
  ],
  "conv_stride": [
    5,
    2,
    2,
    2,
    2,
    2,
    2
  ],
  "ctc_loss_reduction": "mean",
  "ctc_zero_infinity": false,
  "diversity_loss_weight": 0.1,
  "do_stable_layer_norm": true,
  "eos_token_id": 2,
  "feat_extract_activation": "gelu",
  "feat_extract_dropout": 0.0,
  "feat_extract_norm": "layer",
  "feat_proj_dropout": 0.0,
  "feat_quantizer_dropout": 0.0,
  "final_dropout": 0.0,
  "grad


Now, we will just take the first example of the test set, run it through the model and take the `argmax(...)` of the logits to retrieve the predicted token ids.

In [None]:
import torch

# Pass the sampling rate explicitly; omitting it triggers a warning and can hide resampling errors.
input_dict = processor(common_voice_test[0]["input_values"], sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt", padding=True)

with torch.no_grad():  # inference only, no gradients needed
    logits = model(input_dict.input_values.to("cuda")).logits

pred_ids = torch.argmax(logits, dim=-1)[0]


We adapted `common_voice_test` quite a bit so that the dataset instance does not contain the original sentence label anymore. Thus, we re-use the original dataset to get the label of the first example.

In [None]:
common_voice_test_transcription = load_dataset("common_voice", "tr", data_dir="./cv-corpus-6.1-2020-12-11", split="test")

Using custom data configuration tr-ad9f7b76efa9f3a0


Downloading and preparing dataset common_voice/tr (download: 592.09 MiB, generated: 2.89 MiB, post-processed: Unknown size, total: 594.98 MiB) to /root/.cache/huggingface/datasets/common_voice/tr-ad9f7b76efa9f3a0/6.1.0/f7a9d973839b7706e9e281c19b7e512f31badf3c0fdbd21c671f3c4bf9acf3b9...



Dataset common_voice downloaded and prepared to /root/.cache/huggingface/datasets/common_voice/tr-ad9f7b76efa9f3a0/6.1.0/f7a9d973839b7706e9e281c19b7e512f31badf3c0fdbd21c671f3c4bf9acf3b9. Subsequent calls will reuse this data.



Finally, we can decode the example.

In [None]:
print("Prediction:")
print(processor.decode(pred_ids))

print("\nReference:")
print(common_voice_test_transcription[0]["sentence"].lower())

Prediction:
ha ta küçük şeyleri için bir büyük biş şeylir koğoluyor ve yeneküçük şeyler için bir birmizi incilkiyoruz

Reference:
hayatta küçük şeyleri kovalıyor ve yine küçük şeyler için birbirimizi incitiyoruz.


Alright! The transcription can definitely be recognized from our prediction, but it is not perfect yet. Training the model a bit longer, spending more time on the data preprocessing, and especially using a language model for decoding would certainly improve the model's overall performance.

For a demonstration model on a low-resource language, however, the results are quite acceptable 🤗.
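
To put a number on "not perfect yet", the word error rate can be computed by hand. The notebook relies on `jiwer` for this; the sketch below is a plain-Python re-implementation for illustration only, with `word_error_rate` being a helper name introduced here. It counts word-level substitutions, deletions, and insertions and divides by the number of reference words, using toy English sentences rather than the Turkish example above.

```python
def word_error_rate(reference, hypothesis):
    """Word-level edit distance divided by the number of reference words."""
    ref_words = reference.split()
    hyp_words = hypothesis.split()
    if not ref_words:
        raise ValueError("reference must contain at least one word")

    # prev[j]: edit distance between the processed reference prefix
    # and the first j hypothesis words.
    prev = list(range(len(hyp_words) + 1))
    for i, ref_word in enumerate(ref_words, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp_words, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr.append(min(
                prev[j] + 1,         # deletion of a reference word
                curr[j - 1] + 1,     # insertion of a hypothesis word
                prev[j - 1] + cost,  # match or substitution
            ))
        prev = curr
    return prev[-1] / len(ref_words)


# Toy example: one substitution ("sat" -> "sit") and one deletion ("the").
toy_wer = word_error_rate("the cat sat on the mat", "the cat sit on mat")
print(f"WER: {toy_wer:.3f}")
```

Note that punctuation and casing count as differences here, so reference and prediction should be normalized the same way before comparing them.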