# **Fine-tuning XLSR-Wav2Vec2 for Multi-Lingual ASR with 🤗 Transformers**

In [1]:
# --- Language ---------------------------------------------------------------
language_code = 'uk'          # Common Voice config name passed to load_dataset
language_name = 'ukrainian'   # used in output directory names

# --- Model ------------------------------------------------------------------
base_model = "facebook/wav2vec2-large-xlsr-53"   # multilingual checkpoint to fine-tune
# pretrain_model = f"armpacha/wav2vec2-large-xlsr-{language_name}"

# --- Paths ------------------------------------------------------------------
cache_dir = "/workspace/.cache"
data_dir = f"/workspace/data/{language_code}"
output_models_dir = (
    f"/workspace/output_models/{language_code}"
    f"/wav2vec2-large-xlsr-{language_name}-expr3"
)

In [2]:
import os
import wandb  # NOTE(review): not referenced in this cell; presumably imported to fail fast if wandb is missing

# W&B account/team the runs are logged under.
# `%env NAME = value` sets os.environ for this kernel (echoed as `env: NAME=value` in the output).
%env WANDB_ENTITY = wandb
entity = os.environ["WANDB_ENTITY"]

# Public W&B project the runs are grouped into.
%env WANDB_PROJECT = xlsr-ukrainian
project_name = os.environ["WANDB_PROJECT"]

# Do NOT upload the trained model as a W&B Artifact, and do not turn on
# wandb's gradient/parameter watching (both set to false).
%env WANDB_LOG_MODEL = false
%env WANDB_WATCH = false

env: WANDB_ENTITY=wandb
env: WANDB_PROJECT=xlsr-ukrainian
env: WANDB_LOG_MODEL=false
env: WANDB_WATCH=false


## Imports

In [3]:
from datasets import (load_dataset, load_metric, ClassLabel, Dataset)
import random
import pandas as pd
import numpy as np
from IPython.display import display, HTML
import IPython.display as ipd
import re
import json
import torch
import torchaudio
import librosa

  '"sox" backend is being deprecated. '


In [4]:
import transformers
from transformers import (
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    is_apex_available,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint

In [5]:
import warnings
# Silence every warning category for the rest of the session (e.g. the torchaudio
# "sox" backend deprecation notice shown in the imports cell output).
# NOTE(review): this also hides deprecation notices from transformers/datasets — consider narrowing.
warnings.simplefilter('ignore')

## Prepare Data, Tokenizer, Feature Extractor

In [6]:
# Validation is folded into training; the official test split is kept for evaluation.
common_voice_train, common_voice_test = (
    load_dataset("common_voice", language_code, split=split_name, cache_dir=cache_dir)
    for split_name in ("train+validation", "test")
)

Reusing dataset common_voice (/workspace/.cache/common_voice/uk/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f)
Reusing dataset common_voice (/workspace/.cache/common_voice/uk/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f)


## Gender-wise augmentations

In [7]:
from audiomentations import Compose, AddGaussianNoise, Gain, PitchShift
import soundfile as sf
import librosa

In [8]:
def is_male(s):
    """True for Common Voice rows whose `gender` field is 'male' (surrounding whitespace ignored)."""
    gender = s['gender'].strip()
    return gender == 'male'

male_ds = common_voice_train.filter(is_male)
len(male_ds)

Loading cached processed dataset at /workspace/.cache/common_voice/uk/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-2b7c319e1c78d934.arrow


4244

In [9]:
def is_female(s):
    """True for Common Voice rows whose `gender` field is 'female' (surrounding whitespace ignored)."""
    gender = s['gender'].strip()
    return gender == 'female'

female_ds = common_voice_train.filter(is_female)
len(female_ds)

Loading cached processed dataset at /workspace/.cache/common_voice/uk/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-c73501c5bb23df2f.arrow


1739

In [10]:
augment_male = Compose([
    AddGaussianNoise(min_amplitude=0.0001, max_amplitude=0.001, p=0.1),
    # Always raise male voices by 1-4 semitones.
    PitchShift(min_semitones=1, max_semitones=4, p=1),
    # Gain(min_gain_in_db=-6, max_gain_in_db=6, p=0.8)
])

def aug_male(batch):
    """Load one Common Voice row's audio, resample it to 16 kHz and apply `augment_male`.

    Sets ``batch["speech"]`` (augmented waveform as a NumPy array),
    ``batch["sampling_rate"]`` (16_000) and ``batch["target_text"]`` (copied from the
    raw, un-normalized ``batch["sentence"]``), and returns the same dict.
    """
    speech_array, sampling_rate = torchaudio.load(batch["path"])
    speech_array = torchaudio.transforms.Resample(sampling_rate, 16_000)(speech_array).squeeze().numpy()
    # The waveform is now at 16 kHz, so that is the rate the augmenter must be told;
    # passing the file's original rate would mis-describe the signal to PitchShift.
    speech_array = augment_male(samples=speech_array, sample_rate=16_000)
    batch["speech"] = speech_array
    batch["sampling_rate"] = 16_000
    batch["target_text"] = batch["sentence"]
    return batch

In [11]:
# Listen to a random original male clip. randint is inclusive on both ends, hence the - 1.
i = random.randint(0, len(male_ds) - 1)
ipd.Audio(male_ds[i]["path"], autoplay=True, rate=16000)

In [12]:
# Same clip after augmentation.
res = aug_male(male_ds[i])
speech_preview = np.asarray(res["speech"])
print("Target text:", res["target_text"])
print("Input array shape:", speech_preview.shape)
print("Sampling rate:", res["sampling_rate"])
ipd.Audio(data=speech_preview, autoplay=True, rate=16000)

Target text: Справа в тому що сучасна академія наук говорить що вона продовжує традиції тієї академії
Input array shape: (137472,)
Sampling rate: 16000


In [13]:
augment_female = Compose([
    AddGaussianNoise(min_amplitude=0.0001, max_amplitude=0.001, p=0.1),
    # Always lower female voices by 1-4 semitones.
    PitchShift(min_semitones=-4, max_semitones=-1, p=1),
    # Gain(min_gain_in_db=-6, max_gain_in_db=6, p=0.8)
])

def aug_female(batch):
    """Load one Common Voice row's audio, resample it to 16 kHz and apply `augment_female`.

    Sets ``batch["speech"]`` (augmented waveform as a NumPy array),
    ``batch["sampling_rate"]`` (16_000) and ``batch["target_text"]`` (copied from the
    raw, un-normalized ``batch["sentence"]``), and returns the same dict.
    """
    speech_array, sampling_rate = torchaudio.load(batch["path"])
    speech_array = torchaudio.transforms.Resample(sampling_rate, 16_000)(speech_array).squeeze().numpy()
    # The waveform is now at 16 kHz, so that is the rate the augmenter must be told;
    # passing the file's original rate would mis-describe the signal to PitchShift.
    speech_array = augment_female(samples=speech_array, sample_rate=16_000)
    batch["speech"] = speech_array
    batch["sampling_rate"] = 16_000
    batch["target_text"] = batch["sentence"]
    return batch

In [14]:
# Listen to a random original female clip. randint is inclusive on both ends, hence the - 1.
i = random.randint(0, len(female_ds) - 1)
ipd.Audio(female_ds[i]["path"], autoplay=True, rate=16000)

In [15]:
# Same clip after augmentation.
res = aug_female(female_ds[i])
speech_preview = np.asarray(res["speech"])
print("Target text:", res["target_text"])
print("Input array shape:", speech_preview.shape)
print("Sampling rate:", res["sampling_rate"])
ipd.Audio(data=speech_preview, autoplay=True, rate=16000)

Target text: — Зате ж ти єси русин, полянин...
Input array shape: (62592,)
Sampling rate: 16000


In [None]:
# male_ds = male_ds.map(aug_male, remove_columns=male_ds.column_names, num_proc=4)
# female_ds = female_ds.map(aug_female, remove_columns=female_ds.column_names, num_proc=4)

In [None]:
# male_ds.save_to_disk(cache_dir+'/ukr_male_aug_dataset')
# female_ds.save_to_disk(cache_dir+'/ukr_female_aug_dataset')

In [None]:
!ls /workspace/.cache/ukr_male_aug_dataset/

In [17]:
# Reload the pre-computed gender-augmented datasets from the cache directory.
male_ds, female_ds = (
    Dataset.load_from_disk(cache_dir + f'/ukr_{gender}_aug_dataset')
    for gender in ('male', 'female')
)

## Process original data

In [18]:
# Speaker/vote metadata is not needed for ASR; only `path` and `sentence` remain.
unused_columns = ["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"]
common_voice_train = common_voice_train.remove_columns(unused_columns)
common_voice_test = common_voice_test.remove_columns(unused_columns)

In [22]:
def show_random_elements(dataset, num_examples=10):
    """Render `num_examples` distinct, randomly chosen rows of `dataset` as an HTML table.

    Raises:
        AssertionError: if `num_examples` exceeds the number of rows.
    """
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    last_index = len(dataset) - 1
    picks = []
    seen = set()
    # Draw indices until `num_examples` distinct ones are collected (duplicates are redrawn).
    while len(picks) < num_examples:
        candidate = random.randint(0, last_index)
        if candidate in seen:
            continue
        seen.add(candidate)
        picks.append(candidate)

    display(HTML(pd.DataFrame(dataset[picks]).to_html()))

In [None]:
# show_random_elements(common_voice_train.remove_columns(["path"]), num_examples=20)

In [19]:
# Punctuation/symbols stripped from transcripts. Raw string so every backslash reaches `re` verbatim
# (avoids Python's invalid-escape warnings; the resulting pattern is unchanged).
chars_to_ignore_regex = r'[\,\?\.\!\-\;\:\"\“\%\‘\”\�\«\»\—\…\(\)\*]'

# Latin / Russian look-alike letters mapped to Ukrainian Cyrillic; the en dash is dropped.
_cyrillic_fixes = str.maketrans({
    'i': '\u0456',       # Latin i -> Cyrillic і
    'o': '\u043e',       # Latin o -> Cyrillic о
    'a': '\u0430',       # Latin a -> Cyrillic а
    'e': '\u0435',       # Latin e -> Cyrillic е (was wrongly mapped to а)
    '\u044b': '\u0438',  # Russian ы -> Ukrainian и
    '\u2013': None,      # en dash – is removed
})

def remove_special_characters(batch):
    """Normalize ``batch["sentence"]`` in place for CTC vocabulary building.

    Steps: unify apostrophes (' and `) to ’, strip `chars_to_ignore_regex`, lowercase,
    append a single trailing space, then replace Latin/Russian look-alike letters with
    their Ukrainian Cyrillic counterparts and drop en dashes.

    Returns:
        The same ``batch`` dict with the cleaned ``sentence``.
    """
    sentence = re.sub("['`]", '’', batch["sentence"])
    sentence = re.sub(chars_to_ignore_regex, '', sentence).lower() + " "
    # Single-character substitutions; none of the targets is itself a source, so one pass is enough.
    batch["sentence"] = sentence.translate(_cyrillic_fixes)
    return batch

In [20]:
# Normalize the transcripts of both splits.
common_voice_train, common_voice_test = (
    ds.map(remove_special_characters) for ds in (common_voice_train, common_voice_test)
)

Loading cached processed dataset at /workspace/.cache/common_voice/uk/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-28baf52f898467bd.arrow
Loading cached processed dataset at /workspace/.cache/common_voice/uk/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-63c296103cd95bec.arrow


In [23]:
preview_ds = common_voice_train.remove_columns(["path"])
show_random_elements(preview_ds, num_examples=10)

Unnamed: 0,sentence
0,батько пам’ятає і розповідає
1,о мово вкраїнська хто любить ії той любить мою україну
2,живе з сім’єю в будинку неподалік від києва
3,студенти що не здали математику будуть вивішені біля деканату
4,давай вип’ємо
5,туристам які планують сходження у високогір’я слід бути обережними
6,яка птиця така й пісня
7,готовеньке і кішка з’їсть
8,чорний кіт межи ними пробіг
9,але ми сьогодні не ведемо цієї програми


In [24]:
def extract_all_chars(batch):
    """Batched map fn: collect the distinct characters of all sentences in the batch.

    Returns a single row with ``vocab`` (list of distinct characters, unordered) and
    ``all_text`` (the batch's sentences joined by single spaces).
    """
    joined = " ".join(batch["sentence"])
    unique_chars = list(set(joined))
    return {"vocab": [unique_chars], "all_text": [joined]}

In [25]:
# One batch spanning the whole split (batch_size=-1) -> a single row holding its character set.
vocab_train, vocab_test = (
    ds.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=ds.column_names)
    for ds in (common_voice_train, common_voice_test)
)

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




In [26]:
# Special tokens first ([PAD] -> 0, [UNK] -> 1), then every character seen in train or test, sorted.
vocab_list = ["[PAD]", "[UNK]"] + sorted(set(vocab_train["vocab"][0]).union(vocab_test["vocab"][0]))

In [27]:
# Token -> id, where the id is the token's position in vocab_list.
vocab_dict = dict(zip(vocab_list, range(len(vocab_list))))
vocab_dict

{'[PAD]': 0,
 '[UNK]': 1,
 ' ': 2,
 'c': 3,
 'j': 4,
 'k': 5,
 'l': 6,
 'm': 7,
 'n': 8,
 'p': 9,
 'u': 10,
 'x': 11,
 'y': 12,
 'а': 13,
 'б': 14,
 'в': 15,
 'г': 16,
 'д': 17,
 'е': 18,
 'ж': 19,
 'з': 20,
 'и': 21,
 'й': 22,
 'к': 23,
 'л': 24,
 'м': 25,
 'н': 26,
 'о': 27,
 'п': 28,
 'р': 29,
 'с': 30,
 'т': 31,
 'у': 32,
 'ф': 33,
 'х': 34,
 'ц': 35,
 'ч': 36,
 'ш': 37,
 'щ': 38,
 'ь': 39,
 'ю': 40,
 'я': 41,
 'є': 42,
 'і': 43,
 'ї': 44,
 'ґ': 45,
 '’': 46}

In [28]:
# The tokenizer uses "|" as its word delimiter, so the space character's id is moved to "|".
vocab_dict["|"] = vocab_dict.pop(" ")

In [29]:
# Number of vocabulary entries (special tokens + characters).
len(vocab_dict)

47

In [None]:
# with open(f'vocab_{language_code}.json', 'w') as vocab_file:
#     json.dump(vocab_dict, vocab_file)

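In a final step, we use the json file `./vocab.json` to instantiate an object of the `Wav2Vec2CTCTokenizer` class. Note that the cell that would write the vocabulary to disk is commented out (and writes `vocab_{language_code}.json`), so `./vocab.json` must already exist and match `vocab_dict`.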

In [30]:
from transformers import Wav2Vec2CTCTokenizer

# NOTE(review): ./vocab.json is not written by this notebook (the commented-out json.dump
# cell writes vocab_{language_code}.json instead); confirm the file on disk matches vocab_dict.
tokenizer = Wav2Vec2CTCTokenizer(
    "./vocab.json",
    word_delimiter_token="|",
    pad_token="[PAD]",
    unk_token="[UNK]",
)

In [31]:
# Vocabulary size as seen by the tokenizer; compare with len(vocab_dict).
len(tokenizer)

47

## Randomly augmented dataset

In [None]:
from audiomentations import Compose, AddGaussianNoise, Gain, PitchShift
import soundfile as sf
import librosa
augment = Compose([
    AddGaussianNoise(min_amplitude=0.0001, max_amplitude=0.001, p=0.8),
    PitchShift(min_semitones=-1, max_semitones=1, p=0.8),
    Gain(min_gain_in_db=-6, max_gain_in_db=6, p=0.8)
])
def augmented_speech_file_to_array_fn(batch):
    """Return a row with a randomly augmented 16 kHz waveform, caching the augmentation on disk.

    The augmented clip is stored next to the source file as ``<path>augmented.wav`` and
    reused on later calls. Sets ``speech``, ``sampling_rate`` (16_000) and ``target_text``.
    """
    augmented_path = batch["path"] + "augmented.wav"
    try:
        speech_array, sampling_rate = sf.read(augmented_path)
    except (RuntimeError, OSError):
        # No readable cached augmentation yet (soundfile raises RuntimeError subclasses for
        # missing/unreadable files): load, augment and cache it. Only I/O failures are caught,
        # so e.g. KeyboardInterrupt is no longer swallowed.
        speech_array, sampling_rate = librosa.load(batch["path"])
        speech_array = augment(samples=speech_array, sample_rate=sampling_rate)
        sf.write(augmented_path, speech_array, sampling_rate, subtype='PCM_24')
    # Keyword arguments: newer librosa releases make orig_sr/target_sr keyword-only.
    batch["speech"] = librosa.resample(speech_array, orig_sr=sampling_rate, target_sr=16_000)
    batch["sampling_rate"] = 16_000
    batch["target_text"] = batch["sentence"]
    return batch


#common_voice_train_augmented = common_voice_train.map(augmented_speech_file_to_array_fn, remove_columns=common_voice_train.column_names, num_proc=4)

### Create XLSR-Wav2Vec2 Feature Extractor

A XLSR-Wav2Vec2 feature extractor object requires the following parameters to be instantiated:

- `feature_size`: Speech models take a sequence of feature vectors as an input. While the length of this sequence obviously varies, the feature size should not. In the case of Wav2Vec2, the feature size is 1 because the model was trained on the raw speech signal ${}^2$.
- `sampling_rate`: The sampling rate of the audio the model was trained on.
- `padding_value`: For batched inference, shorter inputs need to be padded with a specific value.
- `do_normalize`: Whether the input should be *zero-mean-unit-variance* normalized or not. Usually, speech models perform better when normalizing the input
- `return_attention_mask`: Whether the model should make use of an `attention_mask` for batched inference. In general, XLSR-Wav2Vec2 models should **always** make use of the `attention_mask`.

In [32]:
from transformers import Wav2Vec2FeatureExtractor

# Raw waveform input (one value per time step) at 16 kHz, zero padding,
# zero-mean/unit-variance normalization, and attention masks for padded batches.
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True,
)

In [33]:
from transformers import Wav2Vec2Processor

# One object wrapping both the audio feature extractor and the character tokenizer.
processor = Wav2Vec2Processor(
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
)

In [34]:
# Write the feature-extractor and tokenizer configuration into output_models_dir.
processor.save_pretrained(output_models_dir)

### Preprocess Data

So far, we have not looked at the actual values of the speech signal but just kept the path to its file in the dataset. `XLSR-Wav2Vec2` expects the audio file in the format of a 1-dimensional array, so in the first step, let's load all audio files into the dataset object.

Let's first check the serialization format of the downloaded audio files by looking at the first training sample.

In [35]:
# Inspect one row: `path` points at the source clip, `sentence` is the cleaned transcript.
common_voice_train[0]

{'path': '/workspace/.cache/downloads/extracted/bd7977a935ef2cde87dfb944844830662a9cc8b14cd2421c9ac1113177c4f8a8/cv-corpus-6.1-2020-12-11/uk/clips/common_voice_uk_23566158.mp3',
 'sentence': 'московитам дозволено створити свою державу а татарам чеченцям  ні але це  расизм '}

In [36]:
import torchaudio

def speech_file_to_array_fn(batch):
    """Decode ``batch["path"]``, resample it to 16 kHz and attach it to the row.

    Sets ``speech`` (1-D NumPy waveform after ``squeeze``), ``sampling_rate`` (16_000)
    and ``target_text`` (copied from ``sentence``); returns the same dict.
    """
    waveform, original_rate = torchaudio.load(batch["path"])
    resampler = torchaudio.transforms.Resample(original_rate, 16_000)
    batch["speech"] = resampler(waveform).squeeze().numpy()
    batch["sampling_rate"] = 16_000
    batch["target_text"] = batch["sentence"]
    return batch

In [37]:
# Replace path/sentence with the decoded waveform, its rate, and the target text.
common_voice_train, common_voice_test = (
    ds.map(speech_file_to_array_fn, remove_columns=ds.column_names, num_proc=4)
    for ds in (common_voice_train, common_voice_test)
)

    

Loading cached processed dataset at /workspace/.cache/common_voice/uk/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-dc1bbbe1fa6ff6de.arrow
Loading cached processed dataset at /workspace/.cache/common_voice/uk/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-e25d7c1e4bbcfec0.arrow
Loading cached processed dataset at /workspace/.cache/common_voice/uk/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-7e8e70544d87f35f.arrow
Loading cached processed dataset at /workspace/.cache/common_voice/uk/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-9a6b1051130bddeb.arrow


    

Loading cached processed dataset at /workspace/.cache/common_voice/uk/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-e977e7f9e4f8bcd9.arrow
Loading cached processed dataset at /workspace/.cache/common_voice/uk/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-f0deda4cecbf81b8.arrow
Loading cached processed dataset at /workspace/.cache/common_voice/uk/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-ff7c5812779c2ed8.arrow
Loading cached processed dataset at /workspace/.cache/common_voice/uk/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-d06e317300ef0a04.arrow


In [38]:
# Listen to one random, already-resampled training clip as a sanity check.
rand_int = random.randint(0, len(common_voice_train)-1)
# Fetch the row once: every `common_voice_train[idx]` access materializes the whole row,
# including the full audio array.
sample = common_voice_train[rand_int]
sample_speech = np.asarray(sample["speech"])

print("Target text:", sample["target_text"])
print("Input array shape:", sample_speech.shape)
print("Sampling rate:", sample["sampling_rate"])
ipd.Audio(data=sample_speech, autoplay=False, rate=16000)

Target text: як світ складається з атомів так історія будується з цеглинок 
Input array shape: (78720,)
Sampling rate: 16000


In [39]:
def prepare_dataset(batch):
    """Batched ``Dataset.map`` function turning audio/text into model inputs.

    Expects the columns ``speech``, ``sampling_rate`` and ``target_text`` and adds:
      * ``input_values`` - feature-extractor output for each waveform,
      * ``length``       - number of input values per example (read by the length-grouped sampler),
      * ``labels``       - tokenizer ids of ``target_text``.

    Raises:
        AssertionError: if any sampling rate in the batch differs from the feature
        extractor's ``sampling_rate``.
    """
    expected_rate = processor.feature_extractor.sampling_rate
    # Check that every clip is at the rate the feature extractor expects, not merely that
    # the rates within this batch agree with one another.
    assert set(batch["sampling_rate"]) == {expected_rate}, (
        f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."
    )

    batch["input_values"] = processor(batch["speech"], sampling_rate=expected_rate).input_values
    batch['length'] = [len(inp) for inp in batch['input_values']]
    with processor.as_target_processor():
        batch["labels"] = processor(batch["target_text"]).input_ids
    return batch

In [40]:
import datasets

def _normalize_target_text(batch):
    """Clean ``batch["target_text"]`` with `remove_special_characters`, as done for common_voice_train."""
    # rstrip() first so re-running on already-normalized text does not stack trailing spaces.
    cleaned = remove_special_characters({"sentence": batch["target_text"].rstrip()})
    batch["target_text"] = cleaned["sentence"]
    return batch

# The gender-augmented datasets copy `target_text` from the raw `sentence` column (aug_male /
# aug_female run before the transcript-cleaning step), so they still contain capitals and
# punctuation that are not in the lowercase vocabulary. Normalize them like the original split.
male_ds_clean = male_ds.map(_normalize_target_text)
female_ds_clean = female_ds.map(_normalize_target_text)
train_ds = datasets.concatenate_datasets([common_voice_train, male_ds_clean, female_ds_clean])
test_ds = common_voice_test

In [41]:
# Turn waveforms/text into input_values, length and labels, dropping the raw columns.
map_kwargs = dict(batch_size=8, num_proc=4, batched=True)
train_ds = train_ds.map(prepare_dataset, remove_columns=common_voice_train.column_names, **map_kwargs)
test_ds = test_ds.map(prepare_dataset, remove_columns=common_voice_test.column_names, **map_kwargs)

    

Loading cached processed dataset at /workspace/.cache/common_voice/uk/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-1e699bc3eedfc39b.arrow
Loading cached processed dataset at /workspace/.cache/common_voice/uk/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-15895e79fdb8fa7f.arrow
Loading cached processed dataset at /workspace/.cache/common_voice/uk/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-57a36c0e40b5ae6d.arrow
Loading cached processed dataset at /workspace/.cache/common_voice/uk/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-17e27233a4a982da.arrow


  

Loading cached processed dataset at /workspace/.cache/common_voice/uk/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-89eb846c2c7ff9a7.arrow


 

Loading cached processed dataset at /workspace/.cache/common_voice/uk/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-1f5ce7fe67c0410b.arrow


 

Loading cached processed dataset at /workspace/.cache/common_voice/uk/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-44920240e41f8749.arrow
Loading cached processed dataset at /workspace/.cache/common_voice/uk/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-d4c8d6c54df213f2.arrow


In [None]:
# common_voice_train = common_voice_train.filter(lambda sample: len(sample['input_values'])/len(sample['labels']) < 5000, num_proc=4)

## Training

The data is processed so that we are ready to start setting up the training pipeline. We will make use of 🤗's [Trainer](https://huggingface.co/transformers/master/main_classes/trainer.html?highlight=trainer) for which we essentially need to do the following:

- Define a data collator. In contrast to most NLP models, XLSR-Wav2Vec2 has a much larger input length than output length. *E.g.*, a sample of input length 50000 has an output length of no more than 100. Given the large input sizes, it is much more efficient to pad the training batches dynamically meaning that all training samples should only be padded to the longest sample in their batch and not the overall longest sample. Therefore, fine-tuning XLSR-Wav2Vec2 requires a special padding data collator, which we will define below

- Evaluation metric. During training, the model should be evaluated on the word error rate. We should define a `compute_metrics` function accordingly

- Load a pretrained checkpoint. We need to load a pretrained checkpoint and configure it correctly for training.

- Define the training configuration.

After having fine-tuned the model, we will correctly evaluate it on the test data and verify that it has indeed learned to correctly transcribe speech.

### Set-up Trainer

Let's start by defining the data collator. The code for the data collator was copied from [this example](https://github.com/huggingface/transformers/blob/9a06b6b11bdfc42eea08fa91d0c737d1863c99e3/examples/research_projects/wav2vec2/run_asr.py#L81).

Without going into too many details, in contrast to the common data collators, this data collator treats the `input_values` and `labels` differently and thus applies two separate padding functions to them (again making use of XLSR-Wav2Vec2's context manager). This is necessary because in speech the input and output are of different modalities, meaning that they should not be treated by the same padding function.
Analogous to the common data collators, the padding tokens in the labels are replaced with `-100` so that those tokens are **not** taken into account when computing the loss.

In [42]:
import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that dynamically pads a batch of ``input_values`` and ``labels`` for CTC training.

    Inputs and labels are padded by separate calls (they differ in length and padding value);
    padded label positions are then set to ``-100`` so they are ignored by the loss.

    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`):
            The processor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'` (default): Pad to the longest sequence in the batch (or no padding if
              only a single sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        max_length_labels (:obj:`int`, `optional`):
            Maximum length of the ``labels`` returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set, pad ``input_values`` to a multiple of the provided value. This is especially useful to
            enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.0 (Volta).
        pad_to_multiple_of_labels (:obj:`int`, `optional`):
            Same as ``pad_to_multiple_of``, applied to the ``labels``.
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Inputs and labels have different lengths and need different padding, so split them.
        input_features, label_features = [], []
        for feature in features:
            input_features.append({"input_values": feature["input_values"]})
            label_features.append({"input_ids": feature["labels"]})

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        # Inside this context the processor pads with the tokenizer instead of the feature extractor.
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",
            )

        # Positions where the attention mask is not 1 are padding; mark them -100 so the loss skips them.
        padding_positions = labels_batch.attention_mask.ne(1)
        batch["labels"] = labels_batch["input_ids"].masked_fill(padding_positions, -100)
        return batch

In [43]:
# Pads each batch only to the longest input / label sequence it contains.
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

In [44]:
# Word error rate metric used by compute_metrics.
wer_metric = load_metric("wer")

The model will return a sequence of logit vectors:
$\mathbf{y}_1, \ldots, \mathbf{y}_m$ with $\mathbf{y}_1 = f_{\theta}(x_1, \ldots, x_n)[0]$ and $n \gg m$.

A logit vector $\mathbf{y}_i$ contains the log-odds for each token (character) in the vocabulary we defined earlier, thus $\text{len}(\mathbf{y}_i) =$ `config.vocab_size`. We are interested in the most likely prediction of the model and thus take the `argmax(...)` of the logits. Also, we transform the encoded labels back to the original string by replacing `-100` with the `pad_token_id` and decoding the ids while making sure that consecutive tokens are **not** grouped to the same token in CTC style ${}^1$.

In [45]:
def compute_metrics(pred):
    """Compute the word error rate for a Trainer evaluation prediction.

    ``pred.predictions`` are logits whose argmax over the last axis gives token ids;
    ``pred.label_ids`` use ``-100`` for padding, which is mapped to the tokenizer's pad
    id before decoding.

    Returns:
        ``{"wer": <value returned by wer_metric.compute>}``.
    """
    pred_ids = np.argmax(pred.predictions, axis=-1)

    # Build a new array rather than writing into pred.label_ids, so the caller's
    # prediction object is not mutated as a side effect.
    label_ids = np.where(pred.label_ids == -100, processor.tokenizer.pad_token_id, pred.label_ids)

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}

Now, we can load the pretrained `XLSR-Wav2Vec2` checkpoint. The tokenizer's `pad_token_id` must be used to define the model's `pad_token_id`, or in the case of `Wav2Vec2ForCTC` also CTC's *blank token* ${}^2$. To save GPU memory, we enable PyTorch's [gradient checkpointing](https://pytorch.org/docs/stable/checkpoint.html) and also set the loss reduction to "*mean*".

Because the dataset is quite small (~6h of training data) and because Common Voice is quite noisy, fine-tuning Facebook's [wav2vec2-large-xlsr-53 checkpoint](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) seems to require some hyper-parameter tuning. Therefore, I had to play around a bit with different values for dropout, [SpecAugment](https://arxiv.org/abs/1904.08779)'s masking dropout rate, layer dropout, and the learning rate until training seemed to be stable enough. 

**Note**: When using this notebook to train XLSR-Wav2Vec2 on another language of Common Voice those hyper-parameter settings might not work very well. Feel free to adapt those depending on your use case. 

In [46]:
from transformers import Wav2Vec2ForCTC

# Config overrides applied on top of the pretrained checkpoint's configuration.
ctc_config_overrides = dict(
    # Regularization: dropouts, LayerDrop and SpecAugment-style time masking.
    attention_dropout=0.1,
    activation_dropout=0.05,
    hidden_dropout=0.1,
    feat_proj_dropout=0.008,
    mask_time_prob=0.05,
    layerdrop=0.05,
    final_dropout=0.1,
    # Memory / loss settings.
    gradient_checkpointing=True,
    ctc_loss_reduction="mean",
    ctc_zero_infinity=True,
    # CTC head sized to our tokenizer; its pad token doubles as the CTC blank token.
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
)
model = Wav2Vec2ForCTC.from_pretrained(base_model, **ctc_config_overrides)
# Freeze the convolutional feature encoder so its weights are not updated during fine-tuning.
model.freeze_feature_extractor()

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-xlsr-53 and are newly initialized: ['lm_head.weight', 'lm_head.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In a final step, we define all parameters related to training. 
To give more explanation on some of the parameters:
- `group_by_length` makes training more efficient by grouping training samples of similar input length into one batch. This can significantly speed up training time by heavily reducing the overall number of useless padding tokens that are passed through the model
- `learning_rate` was heuristically tuned until fine-tuning became stable (`weight_decay` is left at its default in this notebook). Note that those parameters strongly depend on the Common Voice dataset and might be suboptimal for other speech datasets.

For more explanations on other parameters, one can take a look at the [docs](https://huggingface.co/transformers/master/main_classes/trainer.html?highlight=trainer#trainingarguments).

**Note**: Checkpoints are written to `output_models_dir` (set in the configuration cell at the top); change it there if you want to save the trained models elsewhere, e.g. to a mounted Google Drive.

In [47]:
training_args = TrainingArguments(
    # Output location and W&B run identity.
    output_dir=output_models_dir,
    report_to='wandb',
    run_name='ukr-gen-aug-3',
    # Batching: batches of similar-length clips, 64 per device, no gradient accumulation.
    group_by_length=True,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    gradient_accumulation_steps=1,
    dataloader_num_workers=8,
    fp16=True,
    # Schedule: 10k steps, 500 warmup steps, cosine decay.
    # max_steps > 0 takes precedence over num_train_epochs.
    max_steps=10000,
    num_train_epochs=1,
    learning_rate=1e-4,
    warmup_steps=500,
    lr_scheduler_type='cosine',
    # Evaluate, log and checkpoint every 500 steps; keep at most 2 checkpoints on disk.
    evaluation_strategy="steps",
    save_strategy="steps",
    eval_steps=500,
    logging_steps=500,
    save_steps=500,
    save_total_limit=2,
)

In [48]:
from transformers import Trainer
from transformers.trainer_pt_utils import LengthGroupedSampler, DistributedLengthGroupedSampler
from torch.utils.data import DataLoader
import collections

class GroupedLengthsTrainer(Trainer):
    """`Trainer` that feeds precomputed per-example lengths to the length-grouped sampler.

    When ``args.group_by_length`` is set, lengths are read from the ``length_field_name``
    column of the train dataset (``'length'`` by default) instead of letting the sampler
    derive them. Pass ``length_field_name=None`` to fall back to the sampler's own logic.

    NOTE: ``length_field_name`` is the first positional parameter, so pass every
    ``Trainer`` argument by keyword.
    """

    # length_field_name should possibly be part of TrainingArguments instead
    def __init__(self, length_field_name='length', *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.length_field_name = length_field_name

    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
        train_data = self.train_dataset
        # Samplers only apply to sized, map-style datasets.
        is_map_style = (
            not isinstance(train_data, torch.utils.data.IterableDataset)
            and isinstance(train_data, collections.abc.Sized)
        )
        if not is_map_style:
            return None
        if not self.args.group_by_length:
            return super()._get_train_sampler()

        lengths = None if self.length_field_name is None else train_data[self.length_field_name]
        model_input_name = None if self.tokenizer is None else self.tokenizer.model_input_names[0]

        if self.args.world_size > 1:
            return DistributedLengthGroupedSampler(
                train_data,
                self.args.train_batch_size,
                num_replicas=self.args.world_size,
                rank=self.args.process_index,
                lengths=lengths,
                model_input_name=model_input_name,
            )
        return LengthGroupedSampler(
            train_data, self.args.train_batch_size, lengths=lengths, model_input_name=model_input_name
        )

# Build trainer indicating the name of the field that contains the lengths
trainer = GroupedLengthsTrainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    tokenizer=processor.feature_extractor,
)

In [None]:
# Launch fine-tuning; metrics are reported to W&B (report_to='wandb').
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mfastai_community[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.24 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Step,Training Loss,Validation Loss
