# **Fine-Tune W2V2-Bert for low-resource ASR with 🤗 Transformers**

## Introduction

In December 2023, Meta AI released [Wav2Vec2-BERT](https://huggingface.co/docs/transformers/main/en/model_doc/wav2vec2-bert) as a building block of [Seamless Communication](https://ai.meta.com/research/seamless-communication/), their family of AI translation models.

[Wav2Vec2-BERT](https://huggingface.co/docs/transformers/main/en/model_doc/wav2vec2-bert) is the result of a series of improvements based on an original model: **Wav2Vec2**, a pre-trained model for Automatic Speech Recognition (ASR) released in [September 2020](https://ai.facebook.com/blog/wav2vec-20-learning-the-structure-of-speech-from-raw-audio/) by *Alexei Baevski, Michael Auli, and Alex Conneau*.  With as little as 10 minutes of labeled audio data, Wav2Vec2 could be fine-tuned to achieve 5% word-error rate performance on the [LibriSpeech](https://huggingface.co/datasets/librispeech_asr) dataset, demonstrating for the first time low-resource transfer learning for ASR.

Following a series of multilingual improvements ([XLSR](https://huggingface.co/docs/transformers/model_doc/xlsr_wav2vec2), [XLS-R](https://huggingface.co/docs/transformers/model_doc/xls_r) and [MMS](https://huggingface.co/docs/transformers/model_doc/mms)), Wav2Vec2-BERT is a versatile 580M-parameter audio model pre-trained on **4.5M** hours of unlabeled audio data covering **more than 143 languages**. For comparison, **XLS-R** used almost **half a million** hours of audio data in **128 languages**, and **MMS** checkpoints were pre-trained on more than **half a million hours of audio** in over **1,400 languages**. Scaling pre-training to millions of hours enables Wav2Vec2-BERT to achieve even more competitive results in speech-related tasks, whatever the language.

To use it for ASR, Wav2Vec2-BERT can be fine-tuned with Connectionist Temporal Classification (CTC), an algorithm for training neural networks on sequence-to-sequence problems where the alignment between input and output is not known in advance, such as ASR and handwriting recognition.
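To make the CTC output format concrete: a CTC model emits one token per audio frame, and the transcription is recovered by first merging repeated consecutive tokens and then removing the special *blank* token. The blank is what allows double letters, since two identical characters survive the merge only if a blank separates them. The sketch below is a minimal greedy-decoding illustration; it assumes the blank is spelled `[PAD]` (the convention adopted later in this notebook), and the frame sequence and the helper name `ctc_greedy_collapse` are made up for the example.

```python
from itertools import groupby


def ctc_greedy_collapse(frame_tokens, blank="[PAD]"):
    """Turn one predicted token per frame into the output token sequence."""
    # 1) merge runs of identical consecutive tokens
    merged = [token for token, _ in groupby(frame_tokens)]
    # 2) remove the blank token
    return [token for token in merged if token != blank]


# Made-up frame-level predictions for the word "hello" ("|" marks the word boundary)
frames = ["h", "h", "e", "[PAD]", "l", "l", "[PAD]", "l", "o", "o", "|"]
print("".join(ctc_greedy_collapse(frames)))  # hello|
```

Without the blank between the two runs of `l`, the merge step would have produced `helo|`.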

The aim of this notebook is to fine-tune a Wav2Vec2-BERT model, more specifically the pre-trained checkpoint [**facebook/w2v-bert-2.0**](https://huggingface.co/facebook/w2v-bert-2.0), for ASR using open-source tools and models. It first presents the complete pre-processing pipeline, then fine-tunes W2V2-BERT on it.

We fine-tune the model on the low-resource Armenian subset of [Common Voice 17.0](https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0), which contains 48h of validated training data.

## Notebook Setup

In [None]:
import torch

# Set device to GPU (CUDA) if available, otherwise fallback to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


Before we start, let's install `datasets` and `transformers`. Also, we need `accelerate` for training, `torchaudio` to load audio files and `jiwer` to evaluate our fine-tuned model using the [word error rate (WER)](https://huggingface.co/metrics/wer) metric ${}^1$.

In [None]:
%%capture
!pip install datasets
!pip install --upgrade transformers
!pip install torchaudio
!pip install jiwer
!pip install accelerate -U


We will upload training checkpoints directly to the [🤗 Hub](https://huggingface.co/) while training.

To do so, we have to store our authentication token from the Hugging Face website. Enter your Hub authentication token when prompted below; you can find it [here](https://huggingface.co/settings/tokens):

In [None]:
from huggingface_hub import notebook_login

notebook_login()



## Load the Data

In [None]:
from datasets import load_dataset, load_metric, Audio

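# `token=True` replaces the deprecated `use_auth_token=True`;
# `trust_remote_code=True` is required because Common Voice is loaded through a dataset script
common_voice_train = load_dataset("mozilla-foundation/common_voice_17_0", "hy-AM", split="train+validation", token=True, trust_remote_code=True)
common_voice_test = load_dataset("mozilla-foundation/common_voice_17_0", "hy-AM", split="test", token=True, trust_remote_code=True)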


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.




## Drop the unused columns

In [None]:
common_voice_train = common_voice_train.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])
common_voice_test = common_voice_test.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])

## Display some random elements

In [None]:
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    # Sample distinct row indices without replacement
    picks = random.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))

show_random_elements(common_voice_train.remove_columns(["path", "audio"]), num_examples=10)


| | sentence | variant |
|---|---|---|
| 0 | եղբորիցդ նամակ չունի՞ք այս մոտիկ ժամանակներս: | |
| 1 | Այս աշխարհի բոլոր սիրահարները կապված են միմյանց թելերով։ | |
| 2 | Սինգլը դարձել է խմբի թողարկած վերջին երգը։ | |
| 3 | Բարբառներում հանդիպում են նաև այլ ձայնավորներ (օրինակ՝ քմայնացած)։ | |
| 4 | Սայուկիի մասին լուսաբանել են հետևյալ զանգվածային լրատվամիջոցները։ | |
| 5 | Այժմ անցկացվում են սպորտային պարերի մրցույթներ։ | |
| 6 | Գետերը կարճ են, բայց՝ ջրառատ, նավարկելի, միմյանց միացած են ջրանցքներով։ | |
| 7 | Հանդիսանում է ատլանտյան օվկիանոսի տարածքին պատկանող գետ։ | |
| 8 | Մտնում է երկրի՝ ժամանակակից արվեստին նվիրված խոշորագույն թանգարանների հնգյակի մեջ։ | |
| 9 | Անգլիայի գավաթը համարվում է աշխարհի ամենահին ֆուտբոլային առաջնությունը։ | |


## Perform Data Cleaning (remove special characters)

In [None]:
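import re

# Characters to strip: ASCII/typographic punctuation plus Armenian punctuation:
# U+0589 (full stop ։), U+058A (hyphen), U+0559-U+055F (left half ring, apostrophe,
# emphasis mark, exclamation mark, comma ՝, question mark ՞, abbreviation mark),
# U+2019 (right single quotation mark) and U+02BB (modifier letter turned comma).
# A raw string avoids Python's invalid-escape-sequence warnings.
chars_to_remove_regex = r'[,?.!\-;:"“%‘”�\'»«\u0589\u058A\u0559\u055A\u055B\u055C\u055D\u055E\u055F\u2019\u02BB()`´…]'

def remove_special_characters(batch):
    # Remove the characters above, then lowercase the transcription
    batch["sentence"] = re.sub(chars_to_remove_regex, '', batch["sentence"]).lower()

    return batch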

common_voice_train = common_voice_train.map(remove_special_characters)
common_voice_test = common_voice_test.map(remove_special_characters)
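As a quick sanity check, the cleaning step can be run on a single sentence, here one of the samples displayed above. The pattern below is meant to be equivalent to `chars_to_remove_regex` (the Armenian block U+0559 to U+055F is written as a range); `CHARS_TO_REMOVE` and `clean` are illustrative names, not part of the pipeline.

```python
import re

# Equivalent to chars_to_remove_regex, with U+0559..U+055F written as a range
CHARS_TO_REMOVE = r'[,?.!\-;:"“%‘”�\'»«\u0589\u058A\u0559-\u055F\u2019\u02BB()`´…]'


def clean(sentence):
    # Drop punctuation, then lowercase (Armenian has upper/lower case too)
    return re.sub(CHARS_TO_REMOVE, "", sentence).lower()


sample = "Բարբառներում հանդիպում են նաև այլ ձայնավորներ (օրինակ՝ քմայնացած)։"
# Prints the sentence lowercased, without the parentheses, the Armenian comma ՝ and the full stop ։
print(clean(sample))
print(clean("Hello, World!"))  # hello world
```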



In [None]:
show_random_elements(common_voice_train.remove_columns(["path","audio"]))


| | sentence | variant |
|---|---|---|
| 0 | երգը նույն տեմպով հնչում է կրկներգում | |
| 1 | ամալֆի ափի երկայնքով երթևեկում են ավտոբուսներ և լաստանավեր | |
| 2 | իր գուշակությունները նա շարադրել է քառյակների կատրենների ձևով | |
| 3 | իր կառուցվածքով ու ֆունկցիայով ներզատիչ գեղձ է | |
| 4 | այն իրենից ներկայացնում է չափածո վեպ | |
| 5 | սեռը գոյականի հատկանիշներից է և յուրաքանչյուր սեռ ունի խոնարհվելու իր ձևը | |
| 6 | հեռուստաալիքը լիովին նվիրված է թենիսին | |
| 7 | բուսական աշխարհից զուրկ է | |
| 8 | խոշոր ազգային փոքրամասնություններն են լեզգիները ռուսները թալիշները ավարները թաթարները և վրացիները | |
| 9 | տիրապետել է բազում մարտարվեստների այդ թվում կարատե ուշու թեկվոնդո | |


## Extract All the Unique Characters from the Dataset (create a vocab)

In [None]:
def extract_all_chars(batch):
  all_text = " ".join(batch["sentence"])
  vocab = list(set(all_text))
  return {"vocab": [vocab], "all_text": [all_text]}

vocab_train = common_voice_train.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names)
vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)



In [None]:
vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))

vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}
vocab_dict

{' ': 0,
 'ա': 1,
 'բ': 2,
 'գ': 3,
 'դ': 4,
 'ե': 5,
 'զ': 6,
 'է': 7,
 'ը': 8,
 'թ': 9,
 'ժ': 10,
 'ի': 11,
 'լ': 12,
 'խ': 13,
 'ծ': 14,
 'կ': 15,
 'հ': 16,
 'ձ': 17,
 'ղ': 18,
 'ճ': 19,
 'մ': 20,
 'յ': 21,
 'ն': 22,
 'շ': 23,
 'ո': 24,
 'չ': 25,
 'պ': 26,
 'ջ': 27,
 'ռ': 28,
 'ս': 29,
 'վ': 30,
 'տ': 31,
 'ր': 32,
 'ց': 33,
 'ւ': 34,
 'փ': 35,
 'ք': 36,
 'օ': 37,
 'ֆ': 38,
 'և': 39}

To make it clearer that `" "` has its own token class, we give it a more visible character `|`. We also add an "unknown" token so that the model can later deal with characters not encountered in Common Voice's training set.

In [None]:
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]

Finally, we also add a padding token that corresponds to CTC's "*blank token*".


The "blank token" is a core component of the CTC algorithm. For more information, please take a look at the "Alignment" section [here](https://distill.pub/2017/ctc/).

In [None]:
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
len(vocab_dict)


42
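The vocabulary steps above (collect the characters, sort them, let `|` take over the id of `" "`, then append `[UNK]` and `[PAD]`) can be condensed into a single helper. `build_ctc_vocab` is a hypothetical function written for illustration and shown here on a toy corpus; it is not part of any library.

```python
def build_ctc_vocab(sentences):
    """Hypothetical helper: character vocabulary for a CTC tokenizer."""
    # Every distinct character that appears in the (cleaned) transcriptions
    chars = set()
    for sentence in sentences:
        chars.update(sentence)
    # Sorting gives a deterministic character -> id mapping
    vocab = {char: idx for idx, char in enumerate(sorted(chars))}
    # Make the word boundary visible: "|" takes over the id of " "
    if " " in vocab:
        vocab["|"] = vocab.pop(" ")
    # Unknown token, then padding token (which doubles as the CTC blank)
    vocab["[UNK]"] = len(vocab)
    vocab["[PAD]"] = len(vocab)
    return vocab


toy_vocab = build_ctc_vocab(["ab ba", "abc"])
print(toy_vocab)  # {'a': 1, 'b': 2, 'c': 3, '|': 0, '[UNK]': 4, '[PAD]': 5}
```

Applied to the Armenian data, the same logic yields the 40 characters listed earlier (with `|` reusing the id of the space) plus `[UNK]` and `[PAD]`, i.e. the 42 entries counted above.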

Let's now save the vocabulary as a JSON file.

In [None]:
import json
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)


### Initialize the tokenizer

In a final step, we use the JSON file to load the vocabulary into an instance of the `Wav2Vec2CTCTokenizer` class.

In [None]:
from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
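`Wav2Vec2CTCTokenizer` works at the character level. As a rough mental model, the plain-Python sketch below approximates its encoding step: spaces are replaced by the word delimiter `|`, and each character is mapped to its id, with characters missing from the vocabulary falling back to `[UNK]`. This is a simplified illustration under our own assumptions, not the library's implementation (which also handles special tokens and other options); `encode_chars` and `toy_vocab` are hypothetical names.

```python
def encode_chars(text, vocab, word_delimiter="|", unk="[UNK]"):
    """Approximate character-level CTC encoding (illustrative only)."""
    # Spaces become the visible word-delimiter token
    text = text.replace(" ", word_delimiter)
    # One id per character; characters missing from the vocab map to [UNK]
    return [vocab.get(char, vocab[unk]) for char in text]


toy_vocab = {"|": 0, "a": 1, "b": 2, "[UNK]": 3, "[PAD]": 4}
print(encode_chars("ab ba!", toy_vocab))  # [1, 2, 0, 2, 1, 3]
```

The `!` is not in the toy vocabulary, so it is encoded as `[UNK]` (id 3); this is why the cleaning step removes punctuation before the vocabulary is built.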


Let's upload the tokenizer to the [🤗 Hub](https://huggingface.co/).

In [None]:
repo_name = "w2v-bert-2.0-armenian-CV17.0"

In [None]:
tokenizer.push_to_hub(repo_name)

CommitInfo(commit_url='https://huggingface.co/anah1tbaghdassarian/w2v-bert-2.0-armenian-CV17.0/commit/161f1415ca01b07f478ca47d4e573ce1ac305c67', commit_message='Upload tokenizer', commit_description='', oid='161f1415ca01b07f478ca47d4e573ce1ac305c67', pr_url=None, pr_revision=None, pr_num=None)

### Create `SeamlessM4TFeatureExtractor`

In [None]:
from transformers import SeamlessM4TFeatureExtractor

feature_extractor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")





For improved user-friendliness, the feature extractor and tokenizer are *wrapped* into a single `Wav2Vec2BertProcessor` class so that one only needs a `model` and `processor` object.

In [None]:
from transformers import Wav2Vec2BertProcessor

processor = Wav2Vec2BertProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer)
processor.push_to_hub(repo_name)


CommitInfo(commit_url='https://huggingface.co/anah1tbaghdassarian/w2v-bert-2.0-armenian-CV17.0/commit/ec095587549c09636df4e38c561c78d49c7ef810', commit_message='Upload processor', commit_description='', oid='ec095587549c09636df4e38c561c78d49c7ef810', pr_url=None, pr_revision=None, pr_num=None)

## Preprocess the data (resample to 16kHz)

In [None]:
common_voice_train = common_voice_train.cast_column("audio", Audio(sampling_rate=16_000))
common_voice_test = common_voice_test.cast_column("audio", Audio(sampling_rate=16_000))

In [None]:
def prepare_dataset(batch):
    audio = batch["audio"]
    batch["input_features"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    batch["input_length"] = len(batch["input_features"])

    batch["labels"] = processor(text=batch["sentence"]).input_ids
    return batch
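Note that `input_length` counts feature vectors, not audio samples. The arithmetic below sketches how many feature vectors to expect for a clip, assuming the `SeamlessM4TFeatureExtractor` defaults as we understand them: 80 log-mel bins computed over 25 ms windows (400 samples) with a 10 ms hop (160 samples), no centering, and pairs of consecutive frames stacked into one 160-dimensional vector (`stride=2`). Treat these constants as assumptions and compare the result with `len(batch["input_features"])` on a real example.

```python
# Assumed SeamlessM4TFeatureExtractor framing parameters (check them on your install)
SAMPLING_RATE = 16_000
WIN_LENGTH = 400    # 25 ms analysis window at 16 kHz
HOP_LENGTH = 160    # 10 ms hop at 16 kHz
NUM_MEL_BINS = 80   # log-mel filterbank size
STRIDE = 2          # consecutive frames stacked into one feature vector


def expected_feature_shape(num_samples: int) -> tuple[int, int]:
    """Approximate (input_length, feature_dim) for a clip of num_samples."""
    if num_samples < WIN_LENGTH:
        return 0, NUM_MEL_BINS * STRIDE
    # Non-centered framing: one frame per full window
    num_frames = 1 + (num_samples - WIN_LENGTH) // HOP_LENGTH
    # Frames that do not fill a complete group of STRIDE frames are dropped
    num_frames -= num_frames % STRIDE
    return num_frames // STRIDE, NUM_MEL_BINS * STRIDE


for seconds in (1, 5, 10):
    print(seconds, "s ->", expected_feature_shape(seconds * SAMPLING_RATE))
# 1 s -> (49, 160)
# 5 s -> (249, 160)
# 10 s -> (499, 160)
```

A five-second clip thus yields roughly 250 input vectors, whereas the sample transcriptions shown earlier are a few dozen characters long: the input sequence is much longer than the target sequence.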

In [None]:
common_voice_train = common_voice_train.map(prepare_dataset, remove_columns=common_voice_train.column_names)
common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names)


## Training

The data is processed so that we are ready to start setting up the training pipeline. We will make use of 🤗's [Trainer](https://huggingface.co/transformers/master/main_classes/trainer.html?highlight=trainer) for which we essentially need to do the following:

- Define a data collator. In contrast to most NLP models, W2V-BERT has a much larger input length than output length. Given the large input sizes, it is much more efficient to pad the training batches dynamically, meaning that each training sample is only padded to the longest sample in its batch rather than to the overall longest sample. Therefore, fine-tuning W2V-BERT requires a special padding data collator, which we define below.

- Evaluation metric. During training, the model should be evaluated on the word error rate (WER); we also report the character error rate (CER). We define a `compute_metrics` function accordingly.

- Load a pretrained checkpoint. We need to load a pretrained checkpoint and configure it correctly for training.

- Define the training configuration.

The model is evaluated on the test data at the end of every epoch, which lets us verify that it has indeed learned to transcribe speech correctly.

### Set-up Trainer

Let's start by defining the data collator. The code for the data collator was copied from [this example](https://github.com/huggingface/transformers/blob/7e61d56a45c19284cfda0cee8995fb552f6b1f4e/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py#L219).

Without going into too many details, in contrast to the common data collators, this data collator treats the `input_features` and `labels` differently and thus applies two separate padding functions to them. This is necessary because in speech recognition the input and output are of different modalities, so they should not be treated by the same padding function.
Analogous to the common data collators, the padding tokens in the labels are replaced with `-100` so that those tokens are **not** taken into account when computing the loss.

In [None]:
import torch

from dataclasses import dataclass
from typing import Dict, List, Union

@dataclass
class DataCollatorCTCWithPadding:

    processor: Wav2Vec2BertProcessor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )

        labels_batch = self.processor.pad(
            labels=label_features,
            padding=self.padding,
            return_tensors="pt",
        )
        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
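The label handling of `DataCollatorCTCWithPadding` can be illustrated in plain Python: pad every label sequence only up to the longest one in the batch (dynamic padding), build the matching attention mask, and replace padded positions with `-100`. The helper `pad_and_mask_labels` is hypothetical; the collator itself does this on tensors with `processor.pad` and `masked_fill`.

```python
def pad_and_mask_labels(label_seqs, pad_id, ignore_index=-100):
    """Dynamically pad label id sequences and hide the padding from the loss."""
    # Pad only up to the longest sequence in *this* batch
    max_len = max(len(seq) for seq in label_seqs)
    batch = []
    for seq in label_seqs:
        num_pad = max_len - len(seq)
        padded = seq + [pad_id] * num_pad
        attention_mask = [1] * len(seq) + [0] * num_pad
        # Row-wise equivalent of masked_fill(attention_mask.ne(1), -100)
        batch.append([tok if mask == 1 else ignore_index for tok, mask in zip(padded, attention_mask)])
    return batch


# 41 is the id of [PAD] in the 42-entry vocabulary built above
labels = [[12, 5, 0, 7], [3, 9]]
print(pad_and_mask_labels(labels, pad_id=41))  # [[12, 5, 0, 7], [3, 9, -100, -100]]
```

Masking via the attention mask rather than by comparing ids with `pad_id` means a label id that happens to equal the padding id inside a real sequence would still be kept.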


Next, the evaluation metric is defined.

In [None]:
wer_metric = load_metric("wer")
cer_metric = load_metric("cer")


The model will return a sequence of logit vectors:
$\mathbf{y}_1, \ldots, \mathbf{y}_m$ with $\mathbf{y}_1 = f_{\theta}(x_1, \ldots, x_n)[0]$ and $n \gg m$.

A logit vector $\mathbf{y}_i$ contains the log-odds for each token (here, each character) in the vocabulary we defined earlier, thus $\text{len}(\mathbf{y}_i) =$ `config.vocab_size`. We are interested in the most likely prediction of the model and thus take the `argmax(...)` of the logits. Also, we transform the encoded labels back to the original string by replacing `-100` with the `pad_token_id` and decoding the ids while making sure that consecutive tokens are **not** grouped to the same token in CTC style ${}^1$.
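To see why the labels are decoded with `group_tokens=False`, consider a reference word with a double letter. Grouping is right for the model's frame-level predictions, where repeated frames stand for one character, but applied to the reference labels it would merge genuine double letters. The sketch below is a simplified, hypothetical `decode` helper that mimics what we understand `Wav2Vec2CTCTokenizer` decoding to do (optionally group repeats, drop `[PAD]`, turn `|` into spaces); it is not the library code.

```python
from itertools import groupby


def decode(ids, id_to_token, pad_token="[PAD]", delimiter="|", group_tokens=True):
    """Simplified CTC-style decoding of a list of token ids (illustrative only)."""
    tokens = [id_to_token[i] for i in ids]
    if group_tokens:
        # Merge runs of identical tokens, as for frame-level predictions
        tokens = [token for token, _ in groupby(tokens)]
    # Drop padding / blank tokens, then map the word delimiter back to spaces
    text = "".join(token for token in tokens if token != pad_token)
    return text.replace(delimiter, " ").strip()


id_to_token = {0: "|", 1: "a", 2: "l", 3: "[PAD]"}
reference_ids = [1, 2, 2, 1, 3, 3]  # "alla" followed by two padding positions

print(decode(reference_ids, id_to_token))                      # ala  (double "l" lost)
print(decode(reference_ids, id_to_token, group_tokens=False))  # alla
```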

In [None]:
import numpy as np

def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer, "cer": cer}
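The `wer` and `cer` metrics loaded above are computed at corpus level: the total number of substitutions, deletions and insertions divided by the total number of reference words (or characters). The sketch below re-implements that definition with a plain Levenshtein distance to make the two numbers reported during training concrete. It is an illustration, not the `jiwer`-based implementation behind the metrics, and details such as whitespace normalization may differ; `edit_distance` and `corpus_error_rate` are hypothetical names.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two token sequences."""
    prev = list(range(len(hyp) + 1))  # distances from an empty ref prefix
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,             # deletion
                curr[j - 1] + 1,         # insertion
                prev[j - 1] + (r != h),  # substitution (0 if tokens match)
            )
        prev = curr
    return prev[-1]


def corpus_error_rate(references, predictions, unit="word"):
    """Total edit operations divided by total reference tokens (words or characters)."""
    tokenize = str.split if unit == "word" else list
    errors = sum(edit_distance(tokenize(r), tokenize(p)) for r, p in zip(references, predictions))
    total = sum(len(tokenize(r)) for r in references)
    return errors / total


refs = ["the cat sat"]
preds = ["the cat sit"]
print(round(corpus_error_rate(refs, preds), 3))               # 0.333
print(round(corpus_error_rate(refs, preds, unit="char"), 3))  # 0.091
```

A single wrong character costs a whole word in WER (1 of 3 words) but only one character in CER (1 of 11), which is why the CER values in the training log are much lower than the WER values.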


Now, we can load the pretrained checkpoint of [W2V-BERT 2.0](https://huggingface.co/facebook/w2v-bert-2.0). The tokenizer's `pad_token_id` must be used to define the model's `pad_token_id`, which, in the case of `Wav2Vec2BertForCTC`, is also CTC's *blank token* ${}^2$. We set `vocab_size` to the size of our own vocabulary, so the CTC output layer is newly initialized, and `add_adapter=True` adds a small adapter module on top of the encoder whose weights are also newly initialized (see the warning below). We set the CTC loss reduction to "*mean*". To save GPU memory, PyTorch's [gradient checkpointing](https://pytorch.org/docs/stable/checkpoint.html) is enabled through `gradient_checkpointing=True` in the training arguments.

We disable all dropout layers, as well as time masking (`mask_time_prob`) and LayerDrop (`layerdrop`). Note that no weights are frozen: the pre-trained encoder, the adapter and the CTC head are all fine-tuned.

**Note**: When using this notebook to train W2V-BERT on another Common Voice language, these hyperparameter settings might not work well. Feel free to adapt them to your use case.

In [None]:
from transformers import Wav2Vec2BertForCTC

model = Wav2Vec2BertForCTC.from_pretrained(
    "facebook/w2v-bert-2.0",
    attention_dropout=0.0,
    hidden_dropout=0.0,
    feat_proj_dropout=0.0,
    mask_time_prob=0.0,
    layerdrop=0.0,
    ctc_loss_reduction="mean",
    add_adapter=True,
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
)





Some weights of Wav2Vec2BertForCTC were not initialized from the model checkpoint at facebook/w2v-bert-2.0 and are newly initialized: ['adapter.layers.0.ffn.intermediate_dense.bias', 'adapter.layers.0.ffn.intermediate_dense.weight', 'adapter.layers.0.ffn.output_dense.bias', 'adapter.layers.0.ffn.output_dense.weight', 'adapter.layers.0.ffn_layer_norm.bias', 'adapter.layers.0.ffn_layer_norm.weight', 'adapter.layers.0.residual_conv.bias', 'adapter.layers.0.residual_conv.weight', 'adapter.layers.0.residual_layer_norm.bias', 'adapter.layers.0.residual_layer_norm.weight', 'adapter.layers.0.self_attn.linear_k.bias', 'adapter.layers.0.self_attn.linear_k.weight', 'adapter.layers.0.self_attn.linear_out.bias', 'adapter.layers.0.self_attn.linear_out.weight', 'adapter.layers.0.self_attn.linear_q.bias', 'adapter.layers.0.self_attn.linear_q.weight', 'adapter.layers.0.self_attn.linear_v.bias', 'adapter.layers.0.self_attn.linear_v.weight', 'adapter.layers.0.self_attn_conv.bias', 'adapter.layers.0.self_

In a final step, we define all parameters related to training.
Some of the parameters deserve more explanation:
- `group_by_length` makes training more efficient by grouping training samples of similar input length into one batch. This can significantly speed up training by heavily reducing the number of useless padding tokens passed through the model.
- `learning_rate` was tuned heuristically until fine-tuning became stable. Note that these parameters strongly depend on the Common Voice dataset and might be suboptimal for other speech datasets.

In [None]:
from transformers import TrainingArguments

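training_args = TrainingArguments(
  output_dir=repo_name,
  group_by_length=True,
  per_device_train_batch_size=16,
  gradient_accumulation_steps=2,
  eval_strategy="epoch",  # called `evaluation_strategy` in older transformers releases
  save_strategy="epoch",
  logging_strategy="epoch",
  num_train_epochs=5,
  gradient_checkpointing=True,
  fp16=True,
  learning_rate=5e-5,
  warmup_steps=500,
  push_to_hub=True,
  metric_for_best_model="wer",
  # lower WER is better; without this, the Trainer assumes higher is better
  # for any metric whose name does not end in "loss"
  greater_is_better=False,
  load_best_model_at_end=True,
)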
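With these settings, the number of optimizer steps follows from simple arithmetic: each update sees `per_device_train_batch_size × gradient_accumulation_steps` examples. The sketch below assumes a single GPU and the 10,394-example train+validation split used in this notebook; the rounding mirrors how we understand the `Trainer` to count update steps per epoch, which may vary slightly between versions.

```python
import math

num_train_examples = 10_394        # size of the train+validation split used in this notebook
per_device_train_batch_size = 16
gradient_accumulation_steps = 2
num_train_epochs = 5
num_devices = 1                    # assumption: training on a single GPU
warmup_steps = 500

effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_devices
batches_per_epoch = math.ceil(num_train_examples / (per_device_train_batch_size * num_devices))
update_steps_per_epoch = max(batches_per_epoch // gradient_accumulation_steps, 1)
total_update_steps = update_steps_per_epoch * num_train_epochs

print(effective_batch_size, update_steps_per_epoch, total_update_steps)  # 32 325 1625
print(f"warmup covers {warmup_steps / total_update_steps:.0%} of training")  # warmup covers 31% of training
```

The result matches the `global_step=1625` reported in the `TrainOutput` after training, and shows that `warmup_steps=500` covers close to a third of this short run.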


Now, all instances can be passed to the `Trainer` and we are ready to start training!

In [None]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=common_voice_train,
    eval_dataset=common_voice_test,
    tokenizer=processor.feature_extractor,
)


Train!

In [None]:
trainer.train()



| Epoch | Training Loss | Validation Loss | WER | CER |
|---|---|---|---|---|
| 1 | 1.6647 | 0.220714 | 0.260478 | 0.045169 |
| 2 | 0.1807 | 0.177397 | 0.218325 | 0.038203 |
| 3 | 0.111 | 0.144711 | 0.167138 | 0.029502 |
| 4 | 0.0672 | 0.13029 | 0.143864 | 0.025206 |
| 5 | 0.04 | 0.120222 | 0.128809 | 0.022677 |




TrainOutput(global_step=1625, training_loss=0.41273853067251354, metrics={'train_runtime': 10405.6579, 'train_samples_per_second': 4.994, 'train_steps_per_second': 0.156, 'total_flos': 8.023670067538299e+18, 'train_loss': 0.41273853067251354, 'epoch': 5.0})

Push to the Hub!

In [None]:
trainer.push_to_hub()


CommitInfo(commit_url='https://huggingface.co/anah1tbaghdassarian/w2v-bert-2.0-armenian-CV17.0/commit/a92fba04268826e11b8a3bb82459d81e4e281254', commit_message='End of training', commit_description='', oid='a92fba04268826e11b8a3bb82459d81e4e281254', pr_url=None, pr_revision=None, pr_num=None)