In [1]:
!nvidia-smi

Sat Mar 27 00:55:33 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.51.05    Driver Version: 450.51.05    CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100S-PCI...  Off  | 00000000:00:0A.0 Off |                    0 |
| N/A   34C    P0    25W / 250W |      0MiB / 32510MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
%%capture
import sys
if 'google.colab' in sys.modules:
    !pip install datasets==1.4.1
    # !pip install transformers
    !pip install git+https://github.com/huggingface/transformers.git@refs/pull/10826/head
    !pip install torchaudio
    !pip install librosa soundfile pyloudnorm
    !pip install jiwer
    !pip install wandb
    from google.colab import drive
    drive.mount('/content/gdrive/')

In [3]:
# from datasets import set_caching_enabled
# set_caching_enabled(False)

In [4]:
import os
import wandb

# W&B entity that runs are logged under; read back into Python as `entity`.
%env WANDB_ENTITY=arampacha
entity = os.environ["WANDB_ENTITY"]

# W&B project that collects the runs; read back into Python as `project_name`.
%env WANDB_PROJECT=xlsr-czech
project_name = os.environ["WANDB_PROJECT"]

# Set to false here, so the trained model is NOT uploaded to W&B as an Artifact.
%env WANDB_LOG_MODEL=false 

# Disable logging of gradients to speed things up a little.
%env WANDB_WATCH = false

env: WANDB_ENTITY=arampacha
env: WANDB_PROJECT=xlsr-czech
env: WANDB_LOG_MODEL=false
env: WANDB_WATCH=false


In [5]:
language_code = 'cs'
language_name = 'czech'
base_model = "facebook/wav2vec2-large-xlsr-53"
# pretrain_model = f"armpacha/wav2vec2-large-xlsr-{language_name}"

data_dir = f"/workspace/data/{language_code}"
cache_dir = "/workspace/.cache"
output_models_dir = f"/workspace/output_models/{language_code}/wav2vec2-large-xlsr-{language_name}"

In [6]:
from datasets import (load_dataset, load_metric, ClassLabel)
import random
import pandas as pd
import numpy as np
from IPython.display import display, HTML
import IPython.display as ipd
import re
import json
import torch
import torch.nn as nn
import torchaudio
import librosa
from torch_audiomentations import Compose, AddBackgroundNoise, Gain

  '"torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE" flag is deprecated and will be removed in 0.9.0. '


In [7]:
import transformers
from transformers import (
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    is_apex_available,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint

## Prepare Data, Tokenizer, Feature Extractor

### Create Wav2Vec2CTCTokenizer

In [8]:
from datasets import load_dataset, load_metric

In [9]:
# Use the language configured in `language_code` instead of a second hard-coded "cs",
# so switching language only requires editing the configuration cell.
# Training uses train+validation; the official test split is kept for evaluation.
common_voice_train = load_dataset("common_voice", language_code, split="train+validation", cache_dir=cache_dir)
common_voice_test = load_dataset("common_voice", language_code, split="test", cache_dir=cache_dir)

Reusing dataset common_voice (/workspace/.cache/common_voice/cs/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f)
Reusing dataset common_voice (/workspace/.cache/common_voice/cs/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f)


Many ASR datasets only provide the target text, `'sentence'`, for each audio file `'path'`. Common Voice actually provides much more information about each audio file, such as the `'accent'`, etc. However, we want to keep the notebook as general as possible, so we will only consider the transcribed text for fine-tuning.



In [10]:
common_voice_train = common_voice_train.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])
common_voice_test = common_voice_test.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])

Let's write a short function to display some random samples of the dataset and run it a couple of times to get a feeling for the transcriptions.

In [11]:
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    """Display `num_examples` distinct, randomly chosen rows of `dataset` as an HTML table.

    Args:
        dataset: Object supporting ``len()`` and indexing with a list of row
            indices, returning something ``pd.DataFrame`` can consume.
        num_examples: Number of distinct rows to show; must not exceed ``len(dataset)``.

    Raises:
        AssertionError: If more rows are requested than the dataset contains.
    """
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    # random.sample draws without replacement in a single call, replacing the
    # manual "retry until unseen" loop (which slows down as num_examples nears len(dataset)).
    picks = random.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))

In [12]:
show_random_elements(common_voice_train.remove_columns(["path"]), num_examples=20)

Unnamed: 0,sentence
0,"Ne, ne, mlčte!"
1,Studoval také na univerzitě Gregoriana v Římě.
2,Prý už žil ve vlastním světě.
3,Zastupitelstvo je sedmičlenné.
4,Weil se rozhodl odjet do Paříže.
5,Hoši si to naučení vzali k srdci a dali se zase do karet.
6,Podrobnosti stanoví zákon.
7,Vlastní ostrov je protáhlý z jihozápadu na severovýchod.
8,O tom jsem ale mluvit nechtěl.
9,Stojí v pokoji bez obrazů.


In [13]:
import re
# Punctuation stripped from every transcript.
# The hyphen is escaped so that it is a literal character: written as `!-;` it forms
# a character RANGE (U+0021..U+003B) that also deletes digits, apostrophes and "/",
# which in turn prevented the "/" -> " " replacement below from ever matching.
# A raw string is used so `\)` and `\(` are not invalid Python string escapes.
# NOTE(review): digits and apostrophes present in the data are now kept and will
# show up in the extracted vocabulary.
chars_to_ignore_regex = r'[,?.!\-;:"“%‘”�«»—…\)\(*„]'

# Compiled once instead of on every mapped example.
_CHARS_TO_IGNORE_PATTERN = re.compile(chars_to_ignore_regex)

# Single-character rewrites applied after lower-casing: "/" becomes a space, the listed
# accented vowels are folded to their base letter, and the en dash is dropped.
_CHAR_REPLACEMENTS = str.maketrans({
    '/': ' ',
    'ä': 'a', 'á': 'a',
    'ö': 'o', 'ó': 'o',
    'è': 'e', 'é': 'e',
    'ï': 'i', 'í': 'i',
    'ü': 'u', 'ů': 'u',
    '–': None,
})

def remove_special_characters(batch):
    """Normalise the transcript stored in ``batch["sentence"]``.

    Steps, in order: strip the punctuation listed in ``chars_to_ignore_regex``,
    lower-case, append a single trailing space, then apply the single-character
    replacements (slash to space, accent folding, en-dash removal).

    Args:
        batch: Mapping with a string under the key ``"sentence"``.

    Returns:
        The same mapping, with ``"sentence"`` overwritten by the normalised text.
    """
    stripped = _CHARS_TO_IGNORE_PATTERN.sub('', batch["sentence"]).lower() + " "
    batch["sentence"] = stripped.translate(_CHAR_REPLACEMENTS)
    return batch

In [14]:
common_voice_train = common_voice_train.map(remove_special_characters)
common_voice_test = common_voice_test.map(remove_special_characters)

Loading cached processed dataset at /workspace/.cache/common_voice/cs/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-36252a4846063ac7.arrow
Loading cached processed dataset at /workspace/.cache/common_voice/cs/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-c34a47ff2880983c.arrow


In [15]:
show_random_elements(common_voice_train.remove_columns(["path"]))

Unnamed: 0,sentence
0,člen davu tomuto chovani většinou nedokaže odolat
1,firma ktera ji stavěla prý zapomněla na videokabel pro připojeni notebooku
2,ty dale zvyšuji tlumeni a přibližuji potencial harmonických k úrovni země
3,dalšimi autory teto serie jsou jan malý a jiři polaček
4,tomuto problemu se však zelena kniha vubec nevěnuje
5,prukopnik se take nazýva pionýr
6,těchto cyklu muže být v molekule sloučeniny několik
7,vzhled celeho východni břehu se tak významně změnil
8,z francie napřiklad po francouzske revoluci přestali přichazet dary úplně
9,po zabrani rakouska hral za německo


In [16]:
def extract_all_chars(batch):
    """Gather every distinct character used by the sentences of a batch.

    Args:
        batch: Mapping whose ``"sentence"`` entry is a list of strings.

    Returns:
        A dict with two single-element lists: ``"vocab"`` wraps the list of
        distinct characters (in no particular order) and ``"all_text"`` wraps
        the sentences joined by single spaces.
    """
    joined_text = " ".join(batch["sentence"])
    # a set keeps one copy of each character; its iteration order is arbitrary
    distinct_chars = [*set(joined_text)]
    return dict(vocab=[distinct_chars], all_text=[joined_text])

In [17]:
vocab_train = common_voice_train.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names)
vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




Now, we create the union of all distinct letters in the training dataset and test dataset and convert the resulting list into an enumerated dictionary.

In [18]:
vocab_list = ['[PAD]', '[UNK]'] + sorted(list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0])))

In [19]:
# from collections import Counter
# counts = Counter(vocab_train['all_text'][0]+vocab_test['all_text'][0])
# counts

In [20]:
vocab_dict = {v: k for k, v in enumerate(vocab_list)}
vocab_dict

{'[PAD]': 0,
 '[UNK]': 1,
 ' ': 2,
 'a': 3,
 'b': 4,
 'c': 5,
 'd': 6,
 'e': 7,
 'f': 8,
 'g': 9,
 'h': 10,
 'i': 11,
 'j': 12,
 'k': 13,
 'l': 14,
 'm': 15,
 'n': 16,
 'o': 17,
 'p': 18,
 'q': 19,
 'r': 20,
 's': 21,
 't': 22,
 'u': 23,
 'v': 24,
 'w': 25,
 'x': 26,
 'y': 27,
 'z': 28,
 'ú': 29,
 'ý': 30,
 'č': 31,
 'ď': 32,
 'ě': 33,
 'ň': 34,
 'ř': 35,
 'š': 36,
 'ť': 37,
 'ž': 38}

In [21]:
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]

In [22]:
len(vocab_dict)

39

In [23]:
with open(f'vocab_{language_code}.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

In [24]:
tokenizer = Wav2Vec2CTCTokenizer(f"./vocab_{language_code}.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")

### Create XLSR-Wav2Vec2 Feature Extractor

A XLSR-Wav2Vec2 feature extractor object requires the following parameters to be instantiated:

- `feature_size`: Speech models take a sequence of feature vectors as an input. While the length of this sequence obviously varies, the feature size should not. In the case of Wav2Vec2, the feature size is 1 because the model was trained on the raw speech signal ${}^2$.
- `sampling_rate`: The sampling rate at which the model is trained on.
- `padding_value`: For batched inference, shorter inputs need to be padded with a specific value
- `do_normalize`: Whether the input should be *zero-mean-unit-variance* normalized or not. Usually, speech models perform better when normalizing the input
- `return_attention_mask`: Whether the model should make use of an `attention_mask` for batched inference. In general, XLSR-Wav2Vec2 models should **always** make use of the `attention_mask`.

In [25]:
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

In [26]:
processor.save_pretrained(output_models_dir)

In [27]:
import warnings
warnings.simplefilter('ignore')

## Augmented dataset

In [28]:
from audiomentations import Compose, AddGaussianNoise, Gain, PitchShift
import soundfile as sf
import librosa
augment = Compose([
    AddGaussianNoise(min_amplitude=0.0001, max_amplitude=0.001, p=0.8),
    PitchShift(min_semitones=-1, max_semitones=1, p=0.8),
    Gain(min_gain_in_db=-6, max_gain_in_db=6, p=0.8)
])
def augmented_speech_file_to_array_fn(batch):
    """Load (or create and cache) an augmented 16 kHz version of one clip.

    The augmented waveform is cached next to the source clip as
    ``<path>augmented.wav``. If that file can be read it is reused; otherwise
    the source clip is decoded at 16 kHz, passed through ``augment`` and
    written to the cache.

    Args:
        batch: Mapping with ``"path"`` (audio file path) and ``"sentence"``.

    Returns:
        The same mapping with ``"speech"`` (waveform), ``"sampling_rate"``
        (always 16000) and ``"target_text"`` (copy of ``"sentence"``) added.
    """
    target_sr = 16_000
    cache_path = batch["path"] + "augmented.wav"
    try:
        speech_array, sampling_rate = sf.read(cache_path)
    except (RuntimeError, OSError):
        # Missing or unreadable cache file: regenerate it. Only I/O-style failures are
        # caught, so KeyboardInterrupt and programming errors are no longer swallowed.
        # Decoding straight to the target rate avoids librosa's default 22.05 kHz
        # load followed by a second resampling pass.
        speech_array, sampling_rate = librosa.load(batch["path"], sr=target_sr)
        speech_array = augment(samples=speech_array, sample_rate=sampling_rate)
        sf.write(cache_path, speech_array, sampling_rate, subtype='PCM_24')
    if sampling_rate != target_sr:
        # Cache files written by earlier runs may use a different rate.
        # Keyword arguments work across librosa versions (newer ones reject positional rates).
        speech_array = librosa.resample(speech_array, orig_sr=sampling_rate, target_sr=target_sr)
    batch["speech"] = speech_array
    batch["sampling_rate"] = target_sr
    batch["target_text"] = batch["sentence"]
    return batch


common_voice_train_augmented = common_voice_train.map(augmented_speech_file_to_array_fn, remove_columns=common_voice_train.column_names, num_proc=4)

    

HBox(children=(FloatProgress(value=0.0, description='#0', max=2444.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='#1', max=2443.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='#3', max=2443.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='#2', max=2443.0, style=ProgressStyle(description_width='i…







### Preprocess Data

In [29]:
common_voice_train[0]

{'path': '/workspace/.cache/downloads/extracted/3249973b5a34ed3c73f74007bcfbf160c81a6fb30754454ad429beb17fe53d1b/cv-corpus-6.1-2020-12-11/cs/clips/common_voice_cs_20493005.mp3',
 'sentence': 'je mi jedno jak to zařidiš '}

In [30]:
def resample(batch, orig_sr=48_000, target_sr=16_000):
    """Resample the waveform stored in ``batch["speech"]``.

    Args:
        batch: Mapping with the waveform under ``"speech"``.
        orig_sr: Sampling rate the waveform currently has (default 48000).
        target_sr: Sampling rate to convert to (default 16000).

    Returns:
        The same mapping with ``"speech"`` replaced by the resampled waveform
        and ``"sampling_rate"`` set to ``target_sr``.
    """
    # Rates are passed by keyword: newer librosa releases reject them positionally.
    batch["speech"] = librosa.resample(np.asarray(batch["speech"]), orig_sr=orig_sr, target_sr=target_sr)
    batch["sampling_rate"] = target_sr
    return batch

In [31]:
def speech_file_to_array_fn(batch):
    """Decode one audio clip into a 16 kHz waveform.

    Args:
        batch: Mapping with ``"path"`` (audio file path) and ``"sentence"``.

    Returns:
        The same mapping with ``"speech"`` (waveform), ``"sampling_rate"``
        (always 16000) and ``"target_text"`` (copy of ``"sentence"``) added.
    """
    target_sr = 16_000
    # Ask librosa for the target rate directly; its default would first resample
    # to 22.05 kHz and force a second resampling pass to reach 16 kHz.
    speech_array, sampling_rate = librosa.load(batch["path"], sr=target_sr)
    if sampling_rate != target_sr:
        # Defensive only: keeps the output contract even if the loader returns another rate.
        speech_array = librosa.resample(speech_array, orig_sr=sampling_rate, target_sr=target_sr)
    batch["speech"] = speech_array
    batch["sampling_rate"] = target_sr
    batch["target_text"] = batch["sentence"]
    return batch

In [32]:
import multiprocessing
multiprocessing.cpu_count()

56

In [33]:
common_voice_train = common_voice_train.map(speech_file_to_array_fn, remove_columns=common_voice_train.column_names, num_proc=8)
common_voice_test = common_voice_test.map(speech_file_to_array_fn, remove_columns=common_voice_test.column_names, num_proc=8)

   

Loading cached processed dataset at /workspace/.cache/common_voice/cs/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-9a2640d70e37112e.arrow
Loading cached processed dataset at /workspace/.cache/common_voice/cs/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-55cc798e5f8940f8.arrow


 

Loading cached processed dataset at /workspace/.cache/common_voice/cs/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-7a3f760f02d0ed98.arrow
Loading cached processed dataset at /workspace/.cache/common_voice/cs/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-06fe7a82114fed00.arrow


 

Loading cached processed dataset at /workspace/.cache/common_voice/cs/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-826789b76fe7b1bc.arrow


 

Loading cached processed dataset at /workspace/.cache/common_voice/cs/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-b3f4385a7476440f.arrow


  

Loading cached processed dataset at /workspace/.cache/common_voice/cs/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-c65c16041ae734b7.arrow
Loading cached processed dataset at /workspace/.cache/common_voice/cs/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-a94cd07945602542.arrow


        

Loading cached processed dataset at /workspace/.cache/common_voice/cs/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-75671a0240fe6367.arrow
Loading cached processed dataset at /workspace/.cache/common_voice/cs/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-f36604419097893e.arrow
Loading cached processed dataset at /workspace/.cache/common_voice/cs/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-d46b0f56c8ab18af.arrow
Loading cached processed dataset at /workspace/.cache/common_voice/cs/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-28ca6b649e227dbc.arrow
Loading cached processed dataset at /workspace/.cache/common_voice/cs/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-3fddaeac28e8ba59.arrow
Loading cached processed dataset at /workspace/.cache/common_voice/cs/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-991

## Add augmented data to training dataset

In [34]:
from datasets import concatenate_datasets
print("Merging original and augmented data...")
common_voice_train = concatenate_datasets([common_voice_train, common_voice_train_augmented])

Merging original and augmented data...


In [None]:
# import soundfile as sf
# import pyloudnorm as pyln

# def get_loudness_normalised(sa, sr):
#     # peak normalize audio to -1 dB
#     peak_normalized_audio = pyln.normalize.peak(sa, -1.0)

#     # measure the loudness first 
#     meter = pyln.Meter(sr) # create BS.1770 meter
#     loudness = meter.integrated_loudness(sa)

#     # loudness normalize audio to -12 dB LUFS
#     loudness_normalized_audio = pyln.normalize.loudness(sa, loudness, -12.0)

#     return loudness_normalized_audio

In [None]:
# def speech_file_to_array_loud_norm_fn(batch):
#     speech_array, sampling_rate = torchaudio.load(batch["path"])
    
#     # DO loudness normalisation
#     sa = get_loudness_normalised(speech_array[0].numpy(), sampling_rate)
    
#     batch["speech"] = sa
#     batch["sampling_rate"] = sampling_rate
#     batch["target_text"] = batch["sentence"]
#     return batch

In [None]:
# common_voice_train = common_voice_train.map(speech_file_to_array_loud_norm_fn)
# common_voice_test = common_voice_test.map(speech_file_to_array_fn)

In [35]:
import IPython.display as ipd
import numpy as np
import random

rand_int = random.randint(0, len(common_voice_train)-1)

print("Target text:", common_voice_train[rand_int]["target_text"])
print("Input array shape:", np.asarray(common_voice_train[rand_int]["speech"]).shape)
print("Sampling rate:", common_voice_train[rand_int]["sampling_rate"])

ipd.Audio(data=np.asarray(common_voice_train[rand_int]["speech"]), autoplay=True, rate=16000)

Target text: velmi úspěšna pro něj byla nasledujici sezona 
Input array shape: (112513,)
Sampling rate: 16000


In [36]:
def prepare_dataset(batch):
    """Turn a batch of waveforms and transcripts into model inputs and label ids.

    Args:
        batch: Batched mapping with ``"speech"``, ``"sampling_rate"`` and
            ``"target_text"`` lists.

    Returns:
        The same mapping with ``"input_values"`` (from the processor applied to
        the audio) and ``"labels"`` (from the processor applied to the text in
        target mode) added.

    Raises:
        AssertionError: If the batch does not contain exactly one distinct sampling rate.
    """
    distinct_rates = set(batch["sampling_rate"])
    # the whole batch is processed with a single sampling rate, so it must be uniform
    assert (
        len(distinct_rates) == 1
    ), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."

    shared_rate = batch["sampling_rate"][0]
    audio_features = processor(batch["speech"], sampling_rate=shared_rate)
    batch["input_values"] = audio_features.input_values

    # inside this context the processor handles text instead of audio
    with processor.as_target_processor():
        encoded_text = processor(batch["target_text"])
    batch["labels"] = encoded_text.input_ids
    return batch

In [37]:
common_voice_train = common_voice_train.map(prepare_dataset, remove_columns=common_voice_train.column_names, batch_size=8, num_proc=4, batched=True)
common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names, batch_size=8, num_proc=4, batched=True)

    

HBox(children=(FloatProgress(value=0.0, description='#1', max=611.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='#0', max=611.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='#3', max=611.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='#2', max=611.0, style=ProgressStyle(description_width='in…





  

HBox(children=(FloatProgress(value=0.0, description='#0', max=130.0, style=ProgressStyle(description_width='in…

 

HBox(children=(FloatProgress(value=0.0, description='#1', max=130.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='#2', max=130.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='#3', max=130.0, style=ProgressStyle(description_width='in…







## Training

The data is processed so that we are ready to start setting up the training pipeline. We will make use of 🤗's [Trainer](https://huggingface.co/transformers/master/main_classes/trainer.html?highlight=trainer) for which we essentially need to do the following:

- Define a data collator. In contrast to most NLP models, XLSR-Wav2Vec2 has a much larger input length than output length. *E.g.*, a sample of input length 50000 has an output length of no more than 100. Given the large input sizes, it is much more efficient to pad the training batches dynamically meaning that all training samples should only be padded to the longest sample in their batch and not the overall longest sample. Therefore, fine-tuning XLSR-Wav2Vec2 requires a special padding data collator, which we will define below

- Evaluation metric. During training, the model should be evaluated on the word error rate. We should define a `compute_metrics` function accordingly

- Load a pretrained checkpoint. We need to load a pretrained checkpoint and configure it correctly for training.

- Define the training configuration.

After having fine-tuned the model, we will correctly evaluate it on the test data and verify that it has indeed learned to correctly transcribe speech.

### Set-up Trainer

Let's start by defining the data collator. The code for the data collator was copied from [this example](https://github.com/huggingface/transformers/blob/9a06b6b11bdfc42eea08fa91d0c737d1863c99e3/examples/research_projects/wav2vec2/run_asr.py#L81).

Without going into too many details, in contrast to the common data collators, this data collator treats the `input_values` and `labels` differently and thus applies separate padding functions to them (again making use of XLSR-Wav2Vec2's context manager). This is necessary because in speech, input and output are of different modalities, meaning that they should not be treated by the same padding function.
Analogous to the common data collators, the padding tokens in the labels are replaced with `-100` so that those tokens are **not** taken into account when computing the loss.

In [38]:
import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    """
    Collate pre-processed CTC examples into a single, dynamically padded batch.

    Audio inputs and label ids are padded separately (each with its own length
    settings), and padded label positions are overwritten with ``-100``.

    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for processing the data; its ``pad`` method does the padding.
        padding (:obj:`bool` or :obj:`str`, `optional`, defaults to :obj:`True`):
            Padding strategy forwarded to ``processor.pad`` for both inputs and labels:
            * :obj:`True` or :obj:`'longest'`: pad to the longest sequence in the batch.
            * :obj:`'max_length'`: pad to the length given by the matching ``max_length*`` argument.
            * :obj:`False` or :obj:`'do_not_pad'`: no padding.
        max_length (:obj:`int`, `optional`):
            Maximum length forwarded when padding the ``input_values``.
        max_length_labels (:obj:`int`, `optional`):
            Maximum length forwarded when padding the ``labels``.
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set, ``input_values`` are padded to a multiple of this value.
        pad_to_multiple_of_labels (:obj:`int`, `optional`):
            If set, ``labels`` are padded to a multiple of this value.
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Audio and transcripts have different lengths and are padded by different
        # code paths, so separate them before padding.
        audio_examples = []
        label_examples = []
        for example in features:
            audio_examples.append({"input_values": example["input_values"]})
            label_examples.append({"input_ids": example["labels"]})

        batch = self.processor.pad(
            audio_examples,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        # inside this context `processor.pad` pads the label ids
        with self.processor.as_target_processor():
            padded_labels = self.processor.pad(
                label_examples,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",
            )

        # every position whose attention mask is not 1 is padding; mark it with -100
        # so the loss ignores it
        is_padding = padded_labels.attention_mask.ne(1)
        batch["labels"] = padded_labels["input_ids"].masked_fill(is_padding, -100)

        return batch

In [39]:
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

Next, the evaluation metric is defined. As mentioned earlier, the 
predominant metric in ASR is the word error rate (WER), hence we will use it in this notebook as well.

In [40]:
wer_metric = load_metric("wer")

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1764.0, style=ProgressStyle(description…




In [41]:
def compute_metrics(pred):
    """Compute the word error rate (WER) for one evaluation pass.

    Args:
        pred: Object exposing ``predictions`` (logits; the arg-max over the last
            axis gives token ids) and ``label_ids`` (reference ids in which
            ignored positions are marked with ``-100``).

    Returns:
        ``{"wer": <value returned by wer_metric.compute>}``.
    """
    pred_ids = np.argmax(pred.predictions, axis=-1)

    # Swap the -100 markers for the pad token on a COPY of the labels; the original
    # wrote into pred.label_ids and thereby modified the caller's array.
    label_ids = np.where(pred.label_ids == -100, processor.tokenizer.pad_token_id, pred.label_ids)

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

Now, we can load the pretrained `XLSR-Wav2Vec2` checkpoint. The tokenizer's `pad_token_id` must be used to define the model's `pad_token_id` or, in the case of `Wav2Vec2ForCTC`, also CTC's *blank token* ${}^2$. To save GPU memory, we enable PyTorch's [gradient checkpointing](https://pytorch.org/docs/stable/checkpoint.html) and also set the loss reduction to "*mean*".

Because the dataset is quite small (~6h of training data) and because Common Voice is quite noisy, fine-tuning Facebook's [wav2vec2-large-xlsr-53 checkpoint](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) seems to require some hyper-parameter tuning. Therefore, I had to play around a bit with different values for dropout, [SpecAugment](https://arxiv.org/abs/1904.08779)'s masking dropout rate, layer dropout, and the learning rate until training seemed to be stable enough. 

**Note**: When using this notebook to train XLSR-Wav2Vec2 on another language of Common Voice those hyper-parameter settings might not work very well. Feel free to adapt those depending on your use case. 

In [42]:
from transformers import Wav2Vec2ForCTC

# Load the checkpoint named in `base_model` (set in the configuration cell) rather than
# repeating the literal, so the checkpoint is chosen in exactly one place.
model = Wav2Vec2ForCTC.from_pretrained(
    base_model,
    attention_dropout=0.1,
    hidden_dropout=0.1,
    feat_proj_dropout=0.04,
    mask_time_prob=0.05,
    layerdrop=0.1,
    gradient_checkpointing=True,  # trades compute for lower GPU memory use
    ctc_loss_reduction="mean",
    ctc_zero_infinity=True,
    # the pad token id and the output size come from the tokenizer built earlier
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer)
)
# the convolutional feature extractor is not fine-tuned
model.freeze_feature_extractor()

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1451.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1261920069.0, style=ProgressStyle(descr…




Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-xlsr-53 and are newly initialized: ['lm_head.weight', 'lm_head.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Augmentations on the fly

In [43]:
# class AudioAug(nn.Module):
#     def __init__(self, aug, sample_rate=16_000):
#         super().__init__()
#         self.aug = aug
#         self.sample_rate = sample_rate
#     def forward(self, x):
#         return self.aug(x, sample_rate=self.sample_rate)

In [44]:
# aug = Compose(transforms = [
#     AddBackgroundNoise(min_amplitude=0.0001, max_amplitude=0.001, p=0.5),
#     Gain(min_gain_in_db=-6, max_gain_in_db=6, p=0.5)
# ])
# aug_module = AudioAug(aug)

The first component of XLSR-Wav2Vec2 consists of a stack of CNN layers that are used to extract acoustically meaningful - but contextually independent - features from the raw speech signal. This part of the model has already been sufficiently trained during pretraining and as stated in the [paper](https://arxiv.org/pdf/2006.13979.pdf) does not need to be fine-tuned anymore. 
Thus, we set `requires_grad` to `False` for all parameters of the *feature extraction* part; this is what the `model.freeze_feature_extractor()` call in the model-loading cell above does.

In a final step, we define all parameters related to training. 
To give more explanation on some of the parameters:
- `group_by_length` makes training more efficient by grouping training samples of similar input length into one batch. This can significantly speed up training time by heavily reducing the overall number of useless padding tokens that are passed through the model
- `learning_rate` and `weight_decay` were heuristically tuned until fine-tuning has become stable. Note that those parameters strongly depend on the Common Voice dataset and might be suboptimal for other speech datasets.

For more explanations on other parameters, one can take a look at the [docs](https://huggingface.co/transformers/master/main_classes/trainer.html?highlight=trainer#trainingarguments).

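**Note**: Checkpoints are written to `output_models_dir`, which is defined in the configuration cell near the top of the notebook. To save the trained models to Google Drive instead, point that path at the mounted drive (on Colab it is mounted under `/content/gdrive/`).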

In [46]:
from transformers import TrainingArguments

# Fine-tuning configuration handed to the Trainer.
training_args = TrainingArguments(
    output_dir=output_models_dir,  # same directory the processor was saved to earlier
    group_by_length=True,  # batch samples of similar input length to reduce padding
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    gradient_accumulation_steps=1,
    evaluation_strategy="steps",
    save_strategy="steps",
    max_steps=10000,  # training length is fixed in steps
    fp16=True,
    save_steps=500,  # checkpointing, evaluation and logging share a 500-step cadence
    eval_steps=500,
    logging_steps=500,
    learning_rate=3e-4,
    lr_scheduler_type='cosine',
    warmup_steps=500,
    save_total_limit=2,
    dataloader_num_workers=8,
    report_to = 'wandb',  # metrics go to the W&B entity/project set via %env above
    run_name = 'cz-aug'
)

In [47]:
from transformers import Trainer

# Wire the model, data, collator and metric together.
trainer = Trainer(
    model=model,
    data_collator=data_collator,  # pads inputs and labels per batch
    args=training_args,
    compute_metrics=compute_metrics,  # reports WER at each evaluation
    train_dataset=common_voice_train,  # original clips concatenated with their augmented copies
    eval_dataset=common_voice_test,
    # NOTE(review): the feature extractor (not the CTC tokenizer) is passed as `tokenizer`,
    # presumably so it is saved alongside checkpoints; confirm against the Trainer docs.
    tokenizer=processor.feature_extractor,
)

### Training

In [48]:
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mfastai_community[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.23 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Step,Training Loss,Validation Loss,Wer,Runtime,Samples Per Second
500,4.8286,2.668277,1.0,128.8732,32.156
1000,0.8452,0.340108,0.45283,131.9763,31.4
1500,0.2984,0.278121,0.365048,131.9763,31.4
2000,0.182,0.280746,0.324491,131.507,31.512
2500,0.1277,0.330042,0.313085,131.5382,31.504
3000,0.0956,0.350278,0.308567,131.7329,31.458
3500,0.0765,0.339232,0.283934,131.4376,31.528
4000,0.0626,0.339459,0.300988,131.4242,31.531
4500,0.0515,0.333291,0.281019,131.5284,31.506
5000,0.0445,0.366521,0.291659,131.9217,31.413


TrainOutput(global_step=10000, training_loss=0.3429147410392761, metrics={'train_runtime': 16662.5806, 'train_samples_per_second': 0.6, 'total_flos': 9.099791610614091e+19, 'epoch': 32.68, 'init_mem_cpu_alloc_delta': 74835, 'init_mem_gpu_alloc_delta': 1261915136, 'init_mem_cpu_peaked_delta': 18306, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 44827415, 'train_mem_gpu_alloc_delta': 3851747840, 'train_mem_cpu_peaked_delta': 371636900, 'train_mem_gpu_peaked_delta': 14896999424})

## TODO

In [None]:
# model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-turkish-demo").to("cuda")
# processor = Wav2Vec2Processor.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-turkish-demo")

Now, we will just take the first example of the test set, run it through the model and take the `argmax(...)` of the logits to retrieve the predicted token ids.

In [None]:
# input_dict = processor(common_voice_test["input_values"][0], return_tensors="pt", padding=True)

In [None]:
# logits = model(input_dict.input_values.to("cuda")).logits

# pred_ids = torch.argmax(logits, dim=-1)[0]

We adapted `common_voice_test` quite a bit so that the dataset instance does not contain the original sentence label anymore. Thus, we re-use the original dataset to get the label of the first example.

In [None]:
# common_voice_test_transcription = load_dataset("common_voice", "tr", data_dir="./cv-corpus-6.1-2020-12-11", split="test")

Finally, we can decode the example.

In [None]:
# print("Prediction:")
# print(processor.decode(pred_ids))

# print("\nReference:")
# print(common_voice_test_transcription["sentence"][0].lower())


Alright! The transcription can definitely be recognized from our prediction, but it is far from being perfect. Training the model a bit longer, spending more time on the data preprocessing, and especially using a language model for decoding would certainly improve the model's overall performance. 

For a demonstration model on a low-resource language, the results are acceptable, however 🤗.