<a href="https://colab.research.google.com/github/Ayman-Mansour/SDN-Dialect-Automatic-Speech-Recognition/blob/main/SDN_Dialect_XLSR_Wav2Vec2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Fine-tuning XLSR-Wav2Vec2 on Sudanese Dialect** #

This a follow up to the first Sudanese dialect ASR model using CNN and CTC based on DeepSpeech design [A Proposed Automatic Speech Recognition model for the Sudanese Dialect 	](http://repository.sustech.edu/handle/123456789/25521?show=full), in this notebook we investigate the use of Wav2Vec2 pretrained model for ASR [September 2020](https://ai.facebook.com/blog/wav2vec-20-learning-the-structure-of-speech-from-raw-audio/) by Alexei Baevski, Michael Auli, and Alex Conneau.
As Wav2Vec2 showed powerful performance on LibriSpeech Englis ASR dataset (click [here](https://arxiv.org/abs/2006.13979)).
XLSR *cross-lingual  speech representations* XLSR-Wav2Vec2 is able to learn speech representations that are going to be useful across multiple languages in this case the Sudanese Dialect.


Wav2Vec2 is a pretrained model for Automatic Speech Recognition (ASR) and was released in [September 2020](https://ai.facebook.com/blog/wav2vec-20-learning-the-structure-of-speech-from-raw-audio/) by Alexei Baevski, Michael Auli, and Alex Conneau.  Soon after the superior performance of Wav2Vec2 was demonstrated on the English ASR dataset LibriSpeech, *Facebook AI* presented XLSR-Wav2Vec2 (click [here](https://arxiv.org/abs/2006.13979)). XLSR stands for *cross-lingual  speech representations* and refers to XLSR-Wav2Vec2's ability to learn speech representations that are useful across multiple languages.

Similar to Wav2Vec2, XLSR-Wav2Vec2 learns powerful speech representations from hundreds of thousands of hours of speech in more than 50 languages of unlabeled speech. Similar, to [BERT's masked language modeling](http://jalammar.github.io/illustrated-bert/), the model learns contextualized speech representations by randomly masking feature vectors before passing them to a transformer network.

![wav2vec2_structure](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/xlsr_wav2vec2.png)

The authors show for the first time that massively pretraining an ASR model on cross-lingual unlabeled speech data, followed by language-specific fine-tuning on very little labeled data achieves state-of-the-art results. See Table 1-5 of the official [paper](https://arxiv.org/pdf/2006.13979.pdf).

As showwn in the previous reserch the Sudanese dialect concidered a low-resourced language, because of its dataset merky over 3 hours of labeled speech, hence by using XLSR-Wav2Vec2 pretrained model we are going to take advantage of speech representations.
Both researches share the use of Connectionist Tempoaral Classification (CTC) which is used to train neural networks  for sequence-to-sequence problems and mainly in Automatic Speech Recognition and handwriting recognition.

##Installation##
Before we start, let's install both `datasets` and `transformers` from master. Also, we need the `torchaudio` and `librosa` package to load audio files and the `jiwer` to evaluate our fine-tuned model using the [word error rate (WER)](https://huggingface.co/metrics/wer) metric ${}^1$.

In [1]:
%%capture
!pip install datasets
!pip install transformers
!pip install torchaudio
!pip install librosa
!pip install jiwer
!pip install lang-trans

##Loading SDN dialect dataset##

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
!cp "/content/drive/MyDrive/SDN Dialect Corpus v1.0.zip" "/content/"

In [4]:
!unzip SDN\ Dialect\ Corpus\ v1.0.zip

Archive:  SDN Dialect Corpus v1.0.zip
   creating: SDN Dialect Corpus v1.0/
   creating: SDN Dialect Corpus v1.0/transcripts/
  inflating: SDN Dialect Corpus v1.0/transcripts/HM.csv  
  inflating: SDN Dialect Corpus v1.0/transcripts/WB.csv  
   creating: SDN Dialect Corpus v1.0/wav/
  inflating: SDN Dialect Corpus v1.0/wav/hm_01_0001.wav  
  inflating: SDN Dialect Corpus v1.0/wav/hm_01_0002.wav  
  inflating: SDN Dialect Corpus v1.0/wav/hm_01_0003.wav  
  inflating: SDN Dialect Corpus v1.0/wav/hm_01_0004.wav  
  inflating: SDN Dialect Corpus v1.0/wav/hm_01_0005.wav  
  inflating: SDN Dialect Corpus v1.0/wav/hm_01_0006.wav  
  inflating: SDN Dialect Corpus v1.0/wav/hm_01_0007.wav  
  inflating: SDN Dialect Corpus v1.0/wav/hm_01_0008.wav  
  inflating: SDN Dialect Corpus v1.0/wav/hm_01_0009.wav  
  inflating: SDN Dialect Corpus v1.0/wav/hm_01_0010.wav  
  inflating: SDN Dialect Corpus v1.0/wav/hm_01_0011.wav  
  inflating: SDN Dialect Corpus v1.0/wav/hm_01_0012.wav  
  inflating: SDN Dia

## Preparing the dataset  ##

First merging HM.csv and WB.csv files into one file sdn.csv, 
Second split the data to train and validation sets

In [5]:
import pandas as pd
from glob import glob

interesting_files = glob("/content/SDN Dialect Corpus v1.0/transcripts/*.csv") # it grabs all the csv files from the directory you mention here

df_list = []
for filename in sorted(interesting_files):

    df_list.append(pd.read_csv(filename))
    full_df = pd.concat(df_list, sort=True)

# save the final file in same/different directory:
    full_df.to_csv("/content/SDN Dialect Corpus v1.0/transcripts/sdn.csv", index=False)

In [6]:
# -todo- apply Buckwalter or CODA (Conventional Orthography for Dialectal Arabic) transliteration for te dataset
from lang_trans.arabic import buckwalter
# converting buckwalter format to arabic letters
# text = buckwalter.untransliterate(text)
# text = re.sub("\s+", " ", text) # remove multiple spaces
sdn_full_dataset = pd.read_csv('/content/SDN Dialect Corpus v1.0/transcripts/sdn.csv')
sdn_full_dataset['text'] = sdn_full_dataset['text'].apply(lambda t: buckwalter.transliterate(t))
sdn_full_dataset['filename'] = sdn_full_dataset['filename'].apply(lambda f:'/content/SDN Dialect Corpus v1.0/wav/{}.wav'.format(f))
sdn_full_dataset

Unnamed: 0,filename,text
0,/content/SDN Dialect Corpus v1.0/wav/hm_01_000...,tEbAn tEbAn xAlS tEbAn
1,/content/SDN Dialect Corpus v1.0/wav/hm_01_000...,tEbAn tEbAn xAlS tEbAn
2,/content/SDN Dialect Corpus v1.0/wav/hm_01_000...,nHwlk AlEnAyp Almkvfp
3,/content/SDN Dialect Corpus v1.0/wav/hm_01_000...,Al<sm Alkrym
4,/content/SDN Dialect Corpus v1.0/wav/hm_01_000...,mHswbk AlnEysAn Alfy Al$Er HsAn
...,...,...
3544,/content/SDN Dialect Corpus v1.0/wav/wb_03_024...,>wwwww >hlA >hlA >hlA yA wd AlHsyn
3545,/content/SDN Dialect Corpus v1.0/wav/wb_03_024...,<zykn
3546,/content/SDN Dialect Corpus v1.0/wav/wb_03_024...,mbswT kdy mbswT kdy mAlk yA wd AlHsyn
3547,/content/SDN Dialect Corpus v1.0/wav/wb_03_024...,>nA mA jbtA lkm Allylp xbr


In [7]:
# from pathlib import Path
# import librosa
# import IPython.display as ipd
# SAMPLING_RATE = 16000
# # for i in range(len(sdn_full_dataset)):
# #   sdn_full_dataset.iloc[i]['audio'] = librosa.load(sdn_full_dataset.iloc[i]['filename'], sr=SAMPLING_RATE)

# # def audio_resampling(tr_dv_dst):
# audio = []
# trdst = sdn_full_dataset
# for i in range(len(sdn_full_dataset.index)):
#   # wav = librosa.load(sdn_full_dataset.iloc[i]['filename'], sr=SAMPLING_RATE)
#   audio.append(librosa.load(trdst.iloc[i]['filename'])[0])
#   # trdst.iloc[i]['audio'] = wav[0]
#   # trdst.iloc[i]['audio_sr'] = wav[1]
# # print(sdn_full_dataset['audio'])
# # wav
# len(audio)
import torchaudio
import librosa
import IPython.display as ipd
import numpy as np

sample = sdn_full_dataset.iloc[np.random.randint(0, len(sdn_full_dataset))]

path = sample["filename"]
print(sample["text"])
speech = torchaudio.load(path)
print(buckwalter.untrans(sample["text"]))
speech = speech[0].numpy().squeeze()

speech = librosa.resample(np.asarray(speech), 48_000, 16_000)
sp, sr = librosa.load(path)

# speech_d = np.array(sp_ar,  dtype=np.float32)
ipd.Audio(data=sp, autoplay=True, rate=16000)
# ipd.Audio(data=speech_d, autoplay=True, rate=16_000)
# from IPython.display import Audio
# wav = librosa.to


  wrAk HtY nrAk 
  وراك حتى نراك 


In [8]:
# # for i in range(len(audio)):
# trdst['audio'] = audio
#   # trdst['sr_audio'] = audio[i][1]
# # trdst
# # trdst['audio'] = trdst['audio'].apply(lambda f:"'{array({})'}'".format(f))
# type(audio[0])
# trdst

In [9]:
import os
train_eval_data = sdn_full_dataset
if not os.path.isfile('./content/SDN Dialect Corpus v1.0/transcripts/train.csv'):
    eval_data = train_eval_data.sample(n=int(len(train_eval_data) * 0.15 ))
    train_data = train_eval_data[~train_eval_data.isin(eval_data)]
    train_data = train_data[train_data.filename.notnull()]
    
    train_data.to_csv('/content/SDN Dialect Corpus v1.0/transcripts/train.csv')
    eval_data.to_csv('/content/SDN Dialect Corpus v1.0/transcripts/eval.csv')
else:
    train_data = pd.read_csv('/content/SDN Dialect Corpus v1.0/transcripts/train.csv')
    eval_data = pd.read_csv('/content/SDN Dialect Corpus v1.0/transcripts/eval.csv')

In [10]:
from datasets import load_dataset, load_metric


data_files = {
    "train": "/content/SDN Dialect Corpus v1.0/transcripts/train.csv", 
    "validation": "/content/SDN Dialect Corpus v1.0/transcripts/eval.csv",
}

dataset = load_dataset("csv", data_files=data_files)
train_dataset = dataset["train"]
eval_dataset = dataset["validation"]

print(train_dataset)
print(eval_dataset)

Using custom data configuration default-b4733e748ab5dfeb


Downloading and preparing dataset csv/default to /root/.cache/huggingface/datasets/csv/default-b4733e748ab5dfeb/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519...


Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-b4733e748ab5dfeb/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

Dataset({
    features: ['Unnamed: 0', 'filename', 'text'],
    num_rows: 3017
})
Dataset({
    features: ['Unnamed: 0', 'filename', 'text'],
    num_rows: 532
})


In [11]:
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))

In [12]:
show_random_elements(train_dataset)

Unnamed: 0.1,Unnamed: 0,filename,text
0,634,/content/SDN Dialect Corpus v1.0/wav/hm_09_0058.wav,hdyp hdyp mny lykm
1,940,/content/SDN Dialect Corpus v1.0/wav/hm_12_0057.wav,vqAftkm Alkrwyp EAlyp jdA lkn wrwyny HA tSrfw mEAw kyf
2,217,/content/SDN Dialect Corpus v1.0/wav/hm_04_0053.wav,xAf AlTryq ll>wrAq Almlwnp yEny zy AlTryq lmTEm AlqrASp
3,2099,/content/SDN Dialect Corpus v1.0/wav/hm_23_0009.wav,<jAzyk Allh Emlty lyk jdwl kmAn
4,2891,/content/SDN Dialect Corpus v1.0/wav/wb_01_0093.wav,>ywA >ywA bAlDbT kdA yA wd Alfky dA AlHSl w TbEA yAxwAnA bdl mA AlnAs ttfrq w txAf yAdwbk sxnw w qAmw yEny qlbw ETbrp
5,3182,/content/SDN Dialect Corpus v1.0/wav/wb_02_0137.wav,>yA >yA
6,860,/content/SDN Dialect Corpus v1.0/wav/hm_11_0087.wav,bEdyn yEny xlyk dymA m$rq Hdd >fkArk bwDwH AlgAyp tbrr Alwsylp zy mA qAl mykAfly
7,1592,/content/SDN Dialect Corpus v1.0/wav/hm_17_0072.wav,>Hsn w Allh fY EzAkm
8,1317,/content/SDN Dialect Corpus v1.0/wav/hm_14_0242.wav,q$rtA yA mnwfly q$rp $Aftny lyk HAjp hnAy dyk qAlt ly tbArk Allh EAynw lyhw zy Aljk AlmlmE
9,2996,/content/SDN Dialect Corpus v1.0/wav/wb_01_0198.wav,>hA yA mky >hA fY rAyk AlqAEd yfrkA mnw


In [13]:
# This step may not be applicable to the transliterated dataset for Buckealker depends on upper cases and $ symbol

# import re
# chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�]'

# def remove_special_characters(batch):
#     batch["text"] = re.sub(chars_to_ignore_regex, '', batch["text"]).lower() + " "
#     return batch

In [14]:
# This step may not be applicable to the transliterated dataset for Buckealker depends on upper cases and $ symbol

# train_dataset = train_dataset.map(remove_special_characters)
# eval_dataset = eval_dataset.map(remove_special_characters)

In [15]:
def extract_all_chars(batch):
  all_text = " ".join(batch["text"])
  vocab = list(set(all_text))
  return {"vocab": [vocab], "all_text": [all_text]}

In [16]:
vocab_train = train_dataset.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=train_dataset.column_names)
vocab_test = eval_dataset.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=eval_dataset.column_names)

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [17]:
vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))

In [18]:
vocab_dict = {v: k for k, v in enumerate(vocab_list)}
vocab_dict

{' ': 23,
 '!': 7,
 '$': 14,
 '&': 36,
 "'": 22,
 '*': 19,
 '<': 34,
 '>': 18,
 'A': 38,
 'D': 27,
 'E': 21,
 'H': 4,
 'S': 32,
 'T': 3,
 'Y': 16,
 'Z': 5,
 'b': 25,
 'd': 17,
 'f': 2,
 'g': 31,
 'h': 29,
 'j': 26,
 'k': 24,
 'l': 30,
 'm': 20,
 'n': 37,
 'p': 8,
 'q': 13,
 'r': 9,
 's': 28,
 't': 6,
 'v': 10,
 'w': 35,
 'x': 33,
 'y': 12,
 'z': 0,
 '|': 1,
 '}': 11,
 '؟': 39,
 'پ': 15}

In [19]:
# delete پ for its typo from wb_01_0100 text
# del vocab_dict['پ'] 
#del vocab_dict['؟']

# put + instead of space ' ' to make it visiable 
vocab_dict['+'] = vocab_dict[' '] 
del vocab_dict[' '] 

In [20]:
# add a padding token that corresponds to CTC's "blank token".
# The "blank token" is a core component of the CTC algorithm

vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
len(vocab_dict)
vocab_dict

{'!': 7,
 '$': 14,
 '&': 36,
 "'": 22,
 '*': 19,
 '+': 23,
 '<': 34,
 '>': 18,
 'A': 38,
 'D': 27,
 'E': 21,
 'H': 4,
 'S': 32,
 'T': 3,
 'Y': 16,
 'Z': 5,
 '[PAD]': 41,
 '[UNK]': 40,
 'b': 25,
 'd': 17,
 'f': 2,
 'g': 31,
 'h': 29,
 'j': 26,
 'k': 24,
 'l': 30,
 'm': 20,
 'n': 37,
 'p': 8,
 'q': 13,
 'r': 9,
 's': 28,
 't': 6,
 'v': 10,
 'w': 35,
 'x': 33,
 'y': 12,
 'z': 0,
 '|': 1,
 '}': 11,
 '؟': 39,
 'پ': 15}

In [21]:
import json
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

In [22]:
from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="+")

##Create XLSR-Wav2Vec2 Feature Extractor##

In [23]:
from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)

In [24]:
from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
processor.save_pretrained("/content/drive/MyDrive/wav2vec2-large-xlsr-sdn-dialect")

##Pre-process Data (audio files)##

In [25]:
# import seaborn as sb
# import librosa

# lens = []
# count = 0

# for _, row in sdn_full_dataset.iterrows():
#    wave, _ = librosa.load(row['filename'])
#    if wave is not None:
#        count += 1
#        lens.append(len(wave))
# secs = [x / 10000 for x in lens]       
# print(count)
# sb.distplot(secs, kde = False)

In [26]:
train_dataset[100]
# train_dataset[0]["audio"][7:-8].replace('\n', '')

{'Unnamed: 0': 118,
 'filename': '/content/SDN Dialect Corpus v1.0/wav/hm_02_0088.wav',
 'text': '  <tfDlw <tfDlw >hA yAhA dy Alxmsyn >zydkm ؟'}

In [27]:
# import IPython.display as ipd
# import numpy as np
# import random

# rand_int = random.randint(0, len(train_dataset)-1)
# array = train_dataset[rand_int]["audio"]
# ipd.Audio(data=array, autoplay=True, rate=16000)

In [28]:
import librosa
SAMPLING_RATE = 16_000

def prepare_dataset(batch):
    # audio = batch["filename"]

    # batched output is "un-batched"

   speech_array, sampling_rate = torchaudio.load(batch["filename"])
  #  print(speech_array)
   batch["speech"] = speech_array[0].numpy()
   batch["sampling_rate"] = sampling_rate
   batch["target_text"] = batch["text"]
    
   batch["speech"] = librosa.resample(np.asarray(batch["speech"]), 44_100, 16_000)
   batch["sampling_rate"] = 16_000

    # sp, sr = torchaudio.load(batch["filename"])
   batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"], padding=True).input_values[0]
    # batch["input_length"] = len(batch["input_values"])

    
   with processor.as_target_processor():
      batch["labels"] = processor(batch["target_text"]).input_ids
   return batch

In [29]:
train_dataset = train_dataset.map(prepare_dataset, remove_columns=train_dataset.column_names, num_proc=4)
eval_dataset = eval_dataset.map(prepare_dataset, remove_columns=eval_dataset.column_names, num_proc=4)

        

#0:   0%|          | 0/755 [00:00<?, ?ex/s]

#2:   0%|          | 0/754 [00:00<?, ?ex/s]

#1:   0%|          | 0/754 [00:00<?, ?ex/s]

#3:   0%|          | 0/754 [00:00<?, ?ex/s]

        

#1:   0%|          | 0/133 [00:00<?, ?ex/s]

#0:   0%|          | 0/133 [00:00<?, ?ex/s]

#2:   0%|          | 0/133 [00:00<?, ?ex/s]

#3:   0%|          | 0/133 [00:00<?, ?ex/s]

In [30]:
train_dataset[0]

{'input_values': [0.0069758049212396145,
  0.004911125171929598,
  0.004701679572463036,
  0.007324120495468378,
  -0.0016757575795054436,
  0.004652222152799368,
  -0.001986367627978325,
  -0.003172280266880989,
  -0.003281251061707735,
  -0.0037646705750375986,
  -0.0020856244955211878,
  0.005566761363297701,
  0.006770045030862093,
  -0.002125762403011322,
  -0.005799188278615475,
  -0.01304696686565876,
  -0.014825845137238503,
  -0.008388742804527283,
  -0.009839488193392754,
  -0.011095701716840267,
  -0.004993046168237925,
  -0.005414673127233982,
  -0.0036753483582288027,
  -0.015837304294109344,
  -0.012843452394008636,
  -0.014499572105705738,
  -0.011713088490068913,
  -0.010701647028326988,
  -0.02204842120409012,
  -0.009961970150470734,
  -0.0019101573852822185,
  -0.012268764898180962,
  -0.016552980989217758,
  -0.00662066088989377,
  -0.010700688697397709,
  -0.01598590612411499,
  -0.014264884404838085,
  -0.016226358711719513,
  -0.01348090823739767,
  -0.0094119580

In [31]:
# max_input_length_in_sec = 5.0
# train_dataset = train_dataset.filter(lambda x: x < max_input_length_in_sec * processor.feature_extractor.sampling_rate, input_columns=["input_length"])
# eval_dataset = eval_dataset.filter(lambda x: x < max_input_length_in_sec * processor.feature_extractor.sampling_rate, input_columns=["input_length"])

In [32]:
import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for proccessing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        max_length_labels (:obj:`int`, `optional`):
            Maximum length of the ``labels`` returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lenghts and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # input_features, label_features = input_features.to(device), label_features.to(device)
        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",
            )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

In [33]:
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

In [34]:
wer_metric = load_metric("wer")

Downloading builder script:   0%|          | 0.00/1.90k [00:00<?, ?B/s]

In [35]:
import numpy as np

def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

In [36]:
from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large-xlsr-53",
    attention_dropout=0.1,
    hidden_dropout=0.1,
    feat_proj_dropout=0.0,
    mask_time_prob=0.05,
    layerdrop=0.1,
    gradient_checkpointing=True, 
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=processor.tokenizer.vocab_size
)

Downloading:   0%|          | 0.00/1.73k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.18G [00:00<?, ?B/s]

Some weights of the model checkpoint at facebook/wav2vec2-large-xlsr-53 were not used when initializing Wav2Vec2ForCTC: ['project_q.bias', 'project_hid.bias', 'quantizer.weight_proj.weight', 'project_hid.weight', 'quantizer.weight_proj.bias', 'quantizer.codevectors', 'project_q.weight']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-xlsr-53 and are newly initialized: ['lm_head.weight', 'lm_head.bias']
You should probably TRAIN this model on a down-stream task to be able to u

In [37]:
model.freeze_feature_extractor()



In [38]:
model.gradient_checkpointing_enable()

In [39]:
# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# model.to(device)
# data = data.to(device)

In [40]:
from transformers import TrainingArguments

training_args = TrainingArguments(
  # output_dir="/content/drive/MyDrive/wav2vec2-large-xlsr-sdn-dialect",
  output_dir="./wav2vec2-large-xlsr-sdn-dialect",
  group_by_length=True,
  per_device_train_batch_size=2,
  gradient_accumulation_steps=2,
  evaluation_strategy="steps",
  num_train_epochs=10,
  fp16=True,
  save_steps=100,
  eval_steps=100,
  logging_steps=10,
  learning_rate=3e-4,
  warmup_steps=500,
  save_total_limit=2,
)

In [41]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=processor.feature_extractor,
)

Using amp half precision backend


In [42]:
torch.cuda.empty_cache()

In [None]:
trainer.train()

The following columns in the training set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: speech, target_text, sampling_rate. If speech, target_text, sampling_rate are not expected by `Wav2Vec2ForCTC.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 3017
  Num Epochs = 10
  Instantaneous batch size per device = 2
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 2
  Total optimization steps = 7540


Step,Training Loss,Validation Loss,Wer
100,4.7225,inf,1.0
200,2.9297,inf,1.0
300,2.9488,inf,1.0
400,2.9898,inf,1.0
500,2.791,inf,1.0
600,2.8943,inf,1.0
700,2.7725,inf,1.0
800,2.6735,inf,1.0
900,2.7711,inf,1.0
1000,2.7425,inf,1.0


The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: speech, target_text, sampling_rate. If speech, target_text, sampling_rate are not expected by `Wav2Vec2ForCTC.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 532
  Batch size = 8
Saving model checkpoint to ./wav2vec2-large-xlsr-sdn-dialect/checkpoint-100
Configuration saved in ./wav2vec2-large-xlsr-sdn-dialect/checkpoint-100/config.json
Model weights saved in ./wav2vec2-large-xlsr-sdn-dialect/checkpoint-100/pytorch_model.bin
Feature extractor saved in ./wav2vec2-large-xlsr-sdn-dialect/checkpoint-100/preprocessor_config.json
The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: speech, target_text, sampling_rate. If speech, target_text, sampling_rate are not expected by `Wav2Vec2ForCTC.forward`,  you can safely ignore this message.
****