In [2]:
from google.colab import drive
drive.mount('/content/drive',force_remount=True)

Mounted at /content/drive


In [None]:
!git clone https://github.com/speechbrain/speechbrain/
%cd ./speechbrain
!pip install -r requirements.txt
!pip install -e .

In [4]:
!pip install speechbrain
!pip install transformers



In [5]:
!pip install pyctcdecode



In [6]:
!pip install https://github.com/kpu/kenlm/archive/master.zip

Collecting https://github.com/kpu/kenlm/archive/master.zip
  Using cached https://github.com/kpu/kenlm/archive/master.zip (553 kB)
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [7]:
import os
import sys
import torch
import logging
import speechbrain as sb
from speechbrain.utils.distributed import run_on_main
from hyperpyyaml import load_hyperpyyaml
from pathlib import Path
import torchaudio.transforms as T
import torchaudio
import numpy as np
import kenlm
from pyctcdecode import build_ctcdecoder
import re

In [8]:
%cd /content/drive/MyDrive/tunisian_corpora/tunisian_without_wavlm

/content/drive/.shortcut-targets-by-id/1osdzH_xbf3qWTH3TWpiTi_dARUqT-KAm/tunisian_corpora/tunisian_without_wavlm


In [9]:
# NOTE(review): this folder is tunisian_without_wavlm but the YAML is
# *_withwavlm — confirm this is the intended configuration.
HPARAMS_PATH = "./hparams/train_tunisian_withwavlm.yaml"
hparams_file, run_opts, overrides = sb.parse_arguments([HPARAMS_PATH])

# Create the DDP process group (relevant when launched with distributed_launch=True).
sb.utils.distributed.ddp_init_group(run_opts)

# Resolve the YAML file (plus any overrides) into Python objects.
with open(hparams_file) as yaml_stream:
    hparams = load_hyperpyyaml(yaml_stream, overrides)

# Create the experiment folder and snapshot the hyperparameters into it.
sb.create_experiment_directory(
    experiment_directory=hparams["output_folder"],
    hyperparams_to_save=hparams_file,
    overrides=overrides,
)
def read_labels_file(labels_file):
    """Read the ordered label list from a label-encoder text file.

    Each line before the first line containing ``===`` has the form
    ``<quoted label> => <index>`` (e.g. ``'a' => 1``); everything from the
    separator line onwards is ignored.

    Args:
        labels_file: path to the UTF-8 ``label_encoder.txt`` file.

    Returns:
        list[str]: labels ordered by index, i.e. ``result[i]`` is the label
        stored with index ``i``.

    Raises:
        KeyError: if the indices are not exactly ``0 .. n-1``.
        ValueError / SyntaxError: if a line cannot be parsed.
    """
    import ast

    division = "==="
    numbers = {}
    with open(labels_file, "r", encoding="utf-8") as lf:
        for line in lf.read().splitlines():
            if division in line:
                break
            # rsplit: the index is always the last field; the quoted label
            # text could in principle contain "=>" itself.
            string, number = line.rsplit("=>", 1)
            # The label is written as a quoted Python literal; decode it
            # exactly (escapes such as a backslash) instead of slicing off
            # a fixed number of characters.
            numbers[int(number)] = ast.literal_eval(string.strip())
    return [numbers[x] for x in range(len(numbers))]
# Build the label list handed to pyctcdecode: entry 0 becomes "" (presumably
# the CTC blank) and the final entry is replaced by "1".
# NOTE(review): the last label is presumably the <unk> token — confirm against
# label_encoder.txt before relying on this mapping.
_encoder_labels = read_labels_file(
    os.path.join(hparams["save_folder"], "label_encoder.txt")
)
labels = ["", *_encoder_labels[1:-1], "1"]

# Dataset preparation: see dataio_prepare (reads the train/valid/test CSVs).

AttributeError: ignored

In [None]:
def dataio_prepare(hparams):
    """Prepare the datasets used by the Brain class and their processing pipelines.

    Args:
        hparams: loaded hyperparameter dict. Keys read here: ``data_folder``,
            ``train_csv``, ``valid_csv``, ``test_csv`` (iterable of CSV paths),
            ``sorting``, ``avoid_if_longer_than``, ``dataloader_options``,
            ``sample_rate``, ``save_folder``, ``blank_index``, ``unk_index``.
            Side effect: ``hparams["dataloader_options"]["shuffle"]`` is set to
            False when the training set is sorted.

    Returns:
        tuple: ``(train_data, valid_data, test_datasets, label_encoder)`` where
        ``test_datasets`` maps each test CSV's file stem to its dataset.

    Raises:
        NotImplementedError: if ``hparams["sorting"]`` is not ``random``,
            ``ascending`` or ``descending``.
    """

    # 1. Define datasets
    data_folder = hparams["data_folder"]

    train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
        csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
    )

    if hparams["sorting"] == "ascending":
        # we sort training data to speed up training and get better results.
        train_data = train_data.filtered_sorted(
            sort_key="duration",
            key_max_value={"duration": hparams["avoid_if_longer_than"]},
        )
        # when sorting do not shuffle in dataloader ! otherwise is pointless
        hparams["dataloader_options"]["shuffle"] = False

    elif hparams["sorting"] == "descending":
        train_data = train_data.filtered_sorted(
            sort_key="duration",
            reverse=True,
            key_max_value={"duration": hparams["avoid_if_longer_than"]},
        )
        # when sorting do not shuffle in dataloader ! otherwise is pointless
        hparams["dataloader_options"]["shuffle"] = False

    elif hparams["sorting"] == "random":
        pass

    else:
        raise NotImplementedError(
            "sorting must be random, ascending or descending"
        )

    valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
        csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
    )
    # We also sort the validation data so it is faster to validate
    valid_data = valid_data.filtered_sorted(sort_key="duration")
    test_datasets = {}
    for csv_file in hparams["test_csv"]:
        name = Path(csv_file).stem
        test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
            csv_path=csv_file, replacements={"data_root": data_folder}
        )
        test_datasets[name] = test_datasets[name].filtered_sorted(
            sort_key="duration"
        )

    datasets = [train_data, valid_data, *test_datasets.values()]

    # 2. Define audio pipeline:
    # One Resample module per source sample rate. Constructing a Resample
    # builds its filter kernel, so reuse it instead of rebuilding the module
    # for every utterance.
    resamplers = {}

    @sb.utils.data_pipeline.takes("wav")
    @sb.utils.data_pipeline.provides("sig")
    def audio_pipeline(wav):
        info = torchaudio.info(wav)
        sig = sb.dataio.dataio.read_audio(wav)
        resampler = resamplers.get(info.sample_rate)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(
                info.sample_rate, hparams["sample_rate"],
            )
            resamplers[info.sample_rate] = resampler
        return resampler(sig)

    sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
    label_encoder = sb.dataio.encoder.CTCTextEncoder()

    # 3. Define text pipeline: one token per character of the transcript.
    @sb.utils.data_pipeline.takes("wrd")
    @sb.utils.data_pipeline.provides(
        "wrd", "char_list", "tokens_list", "tokens"
    )
    def text_pipeline(wrd):
        yield wrd
        char_list = list(wrd)
        yield char_list
        tokens_list = label_encoder.encode_sequence(char_list)
        yield tokens_list
        tokens = torch.LongTensor(tokens_list)
        yield tokens

    sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
    # Reuse label_encoder.txt if it exists, otherwise build it from the
    # training transcripts (the "from_didatasets" spelling is SpeechBrain's).
    lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
    special_labels = {
        "blank_label": hparams["blank_index"],
        "unk_label": hparams["unk_index"]
    }
    label_encoder.load_or_create(
        path=lab_enc_file,
        from_didatasets=[train_data],
        output_key="char_list",
        special_labels=special_labels,
        sequence_input=True,
    )

    # 4. Set output:
    sb.dataio.dataset.set_output_keys(
        datasets, ["id", "sig", "wrd", "char_list", "tokens"],
    )
    return train_data, valid_data, test_datasets, label_encoder

class ASR(sb.core.Brain):
    """wav2vec2 + encoder + CTC recognizer trained with SpeechBrain's Brain loop.

    Reads modules ``wav2vec2``, ``enc``, ``ctc_lin`` and hparams
    ``log_softmax``, ``ctc_cost``, ``blank_index``, ``use_language_modelling``,
    ``warmup_steps``, ``cer_computer``, ``error_rate_computer``,
    ``lr_annealing_model``, ``lr_annealing_wav2vec``, ``train_logger``,
    ``epoch_counter``, ``wav2vec_opt_class``, ``model_opt_class``, ``model``
    and (optionally) ``augmentation``.

    The caller must set ``self.tokenizer`` (a CTC label encoder) before
    validation/testing, and ``hparams.wer_file`` before a TEST stage. LM
    decoding uses the module-level pyctcdecode ``decoder``.
    """

    def compute_forward(self, batch, stage):
        """Forward computations from the waveform batches to the output probabilities.

        Returns:
            tuple: ``(p_ctc, wav_lens)`` — log-probabilities from
            ``hparams.log_softmax`` and the batch's relative lengths.
        """

        batch = batch.to(self.device)
        wavs, wav_lens = batch.sig
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)

        # Augmentation is optional and only used while training.
        if stage == sb.Stage.TRAIN:
            if hasattr(self.hparams, "augmentation"):
                wavs = self.hparams.augmentation(wavs, wav_lens)

        # Forward pass
        feats = self.modules.wav2vec2(wavs, wav_lens)
        x = self.modules.enc(feats)
        logits = self.modules.ctc_lin(x)
        p_ctc = self.hparams.log_softmax(logits)

        return p_ctc, wav_lens

    def custom_encode(self, wavs, wav_lens):
        """Run the acoustic model on raw waveforms outside the Brain loop.

        Args:
            wavs: waveform batch tensor; moved to ``self.device``.
            wav_lens: relative lengths tensor, or ``None`` (forwarded to wav2vec2).

        Returns:
            tuple: ``(feats, p_ctc)`` — wav2vec2 features and CTC log-probabilities.
        """
        wavs = wavs.to(self.device)
        if wav_lens is not None:
            # Tensor.to() is not in-place: keep the returned tensor, otherwise
            # the lengths silently stay on their original device.
            wav_lens = wav_lens.to(self.device)

        # NOTE(review): no torch.no_grad() / eval() here, so gradients are
        # tracked and any dropout follows the modules' current mode — confirm
        # that is intended when this is used as a frozen feature extractor.
        feats = self.modules.wav2vec2(wavs, wav_lens)
        x = self.modules.enc(feats)
        logits = self.modules.ctc_lin(x)
        p_ctc = self.hparams.log_softmax(logits)

        return feats, p_ctc

    def compute_objectives(self, predictions, batch, stage):
        """Computes the CTC loss and, outside training, accumulates WER/CER.

        With ``hparams.use_language_modelling`` the module-level pyctcdecode
        ``decoder`` is used; otherwise greedy CTC decoding is mapped back to
        characters with ``self.tokenizer``.
        """

        p_ctc, wav_lens = predictions

        ids = batch.id
        tokens, tokens_lens = batch.tokens

        loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)

        if stage != sb.Stage.TRAIN:
            # Decode token terms to words
            if self.hparams.use_language_modelling:
                n_frames = p_ctc.shape[1]
                predicted_words = []
                for logs, rel_len in zip(p_ctc, wav_lens):
                    # Keep only the frames covered by this utterance's relative
                    # length so padding frames are not fed to the beam search.
                    valid_frames = int(torch.round(rel_len * n_frames))
                    text = decoder.decode(
                        logs[:valid_frames].detach().cpu().numpy()
                    )
                    predicted_words.append(text.split(" "))
            else:
                # Greedy decoding is only needed on this path.
                predicted_tokens = sb.decoders.ctc_greedy_decode(
                    p_ctc, wav_lens, blank_id=self.hparams.blank_index
                )
                predicted_words = [
                    "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
                    for utt_seq in predicted_tokens
                ]
            # Reference transcripts split into words
            target_words = [wrd.split(" ") for wrd in batch.wrd]

            self.wer_metric.append(ids, predicted_words, target_words)
            self.cer_metric.append(ids, predicted_words, target_words)

        return loss

    def fit_batch(self, batch):
        """Train the parameters given a single batch in input.

        The wav2vec2 optimizer (when wav2vec2 is not frozen) only steps once
        ``self.optimizer_step`` has reached ``hparams.warmup_steps``.
        """
        should_step = self.step % self.grad_accumulation_factor == 0
        # Managing automatic mixed precision
        # TOFIX: CTC fine-tuning currently is unstable
        # This is certainly due to CTC being done in fp16 instead of fp32
        if self.auto_mix_prec:
            with torch.cuda.amp.autocast():
                with self.no_sync():
                    outputs = self.compute_forward(batch, sb.Stage.TRAIN)
                loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
            with self.no_sync(not should_step):
                self.scaler.scale(
                    loss / self.grad_accumulation_factor
                ).backward()
            if should_step:

                if not self.hparams.wav2vec2.freeze:
                    self.scaler.unscale_(self.wav2vec_optimizer)
                self.scaler.unscale_(self.model_optimizer)
                if self.check_gradients(loss):
                    if not self.hparams.wav2vec2.freeze:
                        if self.optimizer_step >= self.hparams.warmup_steps:
                            self.scaler.step(self.wav2vec_optimizer)
                    self.scaler.step(self.model_optimizer)
                self.scaler.update()
                self.zero_grad()
                self.optimizer_step += 1
        else:
            # This is mandatory because HF models have a weird behavior with DDP
            # on the forward pass
            with self.no_sync():
                outputs = self.compute_forward(batch, sb.Stage.TRAIN)

            loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)

            with self.no_sync(not should_step):
                (loss / self.grad_accumulation_factor).backward()
            if should_step:
                if self.check_gradients(loss):
                    if not self.hparams.wav2vec2.freeze:
                        if self.optimizer_step >= self.hparams.warmup_steps:
                            self.wav2vec_optimizer.step()
                    self.model_optimizer.step()
                self.zero_grad()
                self.optimizer_step += 1

        self.on_fit_batch_end(batch, outputs, loss, should_step)
        return loss.detach().cpu()

    def evaluate_batch(self, batch, stage):
        """Computations needed for validation/test batches"""
        predictions = self.compute_forward(batch, stage=stage)
        with torch.no_grad():
            loss = self.compute_objectives(predictions, batch, stage=stage)
        return loss.detach()

    def on_stage_start(self, stage, epoch):
        """Gets called at the beginning of each epoch; resets the WER/CER trackers."""
        if stage != sb.Stage.TRAIN:
            self.cer_metric = self.hparams.cer_computer()
            self.wer_metric = self.hparams.error_rate_computer()

    def on_stage_end(self, stage, stage_loss, epoch):
        """Gets called at the end of an epoch: anneal LRs, log, checkpoint."""
        # Compute/store important stats
        stage_stats = {"loss": stage_loss}
        if stage == sb.Stage.TRAIN:
            self.train_stats = stage_stats
        else:
            stage_stats["CER"] = self.cer_metric.summarize("error_rate")
            stage_stats["WER"] = self.wer_metric.summarize("error_rate")

        # Perform end-of-iteration things, like annealing, logging, etc.
        if stage == sb.Stage.VALID:
            old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
                stage_stats["loss"]
            )
            old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
                stage_stats["loss"]
            )
            sb.nnet.schedulers.update_learning_rate(
                self.model_optimizer, new_lr_model
            )
            if not self.hparams.wav2vec2.freeze:
                sb.nnet.schedulers.update_learning_rate(
                    self.wav2vec_optimizer, new_lr_wav2vec
                )
            self.hparams.train_logger.log_stats(
                stats_meta={
                    "epoch": epoch,
                    "lr_model": old_lr_model,
                    "lr_wav2vec": old_lr_wav2vec,
                },
                train_stats=self.train_stats,
                valid_stats=stage_stats,
            )
            # Keep only the checkpoint with the lowest validation WER.
            self.checkpointer.save_and_keep_only(
                meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
            )
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stage_stats,
            )
            with open(self.hparams.wer_file, "w") as w:
                self.wer_metric.write_stats(w)

    def init_optimizers(self):
        "Initializes the wav2vec2 optimizer and model optimizer"

        # If the wav2vec encoder is unfrozen, we create the optimizer
        if not self.hparams.wav2vec2.freeze:
            self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
                self.modules.wav2vec2.parameters()
            )
            if self.checkpointer is not None:
                self.checkpointer.add_recoverable(
                    "wav2vec_opt", self.wav2vec_optimizer
                )

        self.model_optimizer = self.hparams.model_opt_class(
            self.hparams.model.parameters()
        )

        if self.checkpointer is not None:
            self.checkpointer.add_recoverable("modelopt", self.model_optimizer)

    def zero_grad(self, set_to_none=False):
        """Zero the gradients of every optimizer this Brain owns."""
        if not self.hparams.wav2vec2.freeze:
            self.wav2vec_optimizer.zero_grad(set_to_none)
        self.model_optimizer.zero_grad(set_to_none)



# Datasets and the CTC label encoder for the Tunisian ASR model.
train_data, valid_data, test_datasets, label_encoder = dataio_prepare(hparams)


# pyctcdecode beam-search decoder with a KenLM n-gram LM. ASR.compute_objectives
# reads this module-level `decoder` when use_language_modelling is enabled.
decoder = build_ctcdecoder(
    labels,
    kenlm_model_path="/content/drive/MyDrive/tunisian_corpora/tunisian_without_wavlm/lm_data/arpas/indomain.arpa",  # either .arpa or .bin file
    alpha=0.5,  # tuned on a val set
    beta=1,  # tuned on a val set
)

asr_brain = ASR(
    modules=hparams["modules"],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams["checkpointer"],
)

# Used by ASR.compute_objectives to turn greedy CTC tokens back into text.
asr_brain.tokenizer = label_encoder

# Test-set evaluation (disabled by default; set to True to run it).
RUN_TEST_EVALUATION = False
if RUN_TEST_EVALUATION:
    for k in test_datasets.keys():  # one entry per test CSV, keyed by file stem
        asr_brain.hparams.wer_file = os.path.join(
            hparams["output_folder"], "wer_{}.txt".format(k)
        )
        asr_brain.evaluate(
            test_datasets[k], test_loader_kwargs=hparams["dataloader_options"]
        )

In [None]:
from torch.nn.utils.rnn import pad_sequence
def load_paths(wavs_path):
    """Load audio files into one zero-padded batch tensor.

    Args:
        wavs_path: iterable of audio file paths readable by ``torchaudio.load``.

    Returns:
        torch.Tensor: shape ``(len(wavs_path), max_num_samples)``; shorter
        signals are right-padded with zeros.

    NOTE(review): sample rates are neither checked nor converted; the
    downstream encoders presumably expect one fixed rate — confirm the files match.
    """
    waveforms = []
    for path in wavs_path:
        waveform, _ = torchaudio.load(path)
        # torchaudio.load returns (channels, samples). Averaging the channels
        # leaves mono audio unchanged and downmixes multi-channel audio, which
        # pad_sequence could not batch otherwise.
        waveforms.append(waveform.mean(dim=0))
    # Right-pad with zeros to the longest signal. pad_sequence already returns
    # a new tensor; wrapping it in torch.tensor() only copied it again and
    # emitted a UserWarning.
    return pad_sequence(waveforms, batch_first=True)

# Smoke test: encode a batch made of two copies of the same sample and show
# the shapes of the features and the CTC posteriogram.
SAMPLE_WAV = "/content/drive/MyDrive/tunisian_corpora/tunisian_without_wavlm/samples/Salah10.wav"
waveform = load_paths([SAMPLE_WAV] * 2)
embeddings, posteriogram = asr_brain.custom_encode(waveform, None)
for tensor in (embeddings, posteriogram):
    print(tensor.shape)


In [None]:
from speechbrain.pretrained import EncoderASR,EncoderDecoderASR
import torchaudio
import  speechbrain as sb
import torch
from torch.nn.utils.rnn import pad_sequence
import torch
import speechbrain as sb
import numpy as np
import torch.optim as optim
import torch.nn as nn


In [None]:
%ls

In [None]:
# Pretrained CommonVoice wav2vec2 CTC models (French, English), fetched from
# `source` and cached under `savedir`.
# NOTE(review): no run_opts={"device": ...} is passed, so the models stay on
# the library's default device — confirm this matches the rest of the pipeline.
french_asr_model = EncoderASR.from_hparams(
    source="speechbrain/asr-wav2vec2-commonvoice-fr",
    savedir="pretrained_models/asr-wav2vec2-commonvoice-fr",
)
english_asr_model = EncoderASR.from_hparams(
    source="speechbrain/asr-wav2vec2-commonvoice-en",
    savedir="pretrained_models/asr-wav2vec2-commonvoice-en",
)


In [None]:
# UTILS FUNCTIONS
def get_size_dimensions(arr):
    """Return the length at each nesting level of a nested list.

    Only the first element is followed at each level, so a ragged list
    reports the shape of its first branch. Non-list input yields ``[]``.

    Examples:
        >>> get_size_dimensions([[1, 2, 3], [4, 5, 6]])
        [2, 3]
        >>> get_size_dimensions([])
        [0]
    """
    size_dimensions = []
    while isinstance(arr, list):
        size_dimensions.append(len(arr))
        if not arr:
            # Empty level: nothing deeper to inspect (arr[0] would raise IndexError).
            break
        arr = arr[0]
    return size_dimensions

def scale_array(batch, n):
    """Stretch every sequence in ``batch`` to length ``n`` by repeating elements.

    Nearest-neighbour upsampling: output position ``i`` takes element
    ``i * len(array) // n``. Every input element is kept, repeats are spread
    as evenly as possible, and ``n == len(array)`` returns the array unchanged.

    Args:
        batch: iterable of indexable sequences.
        n: target length; must be >= the length of every sequence.

    Returns:
        torch.Tensor: ``torch.tensor`` of the stretched sequences
        (shape ``(len(batch), n)`` for sequences of scalars).

    Raises:
        ValueError: if a sequence is longer than ``n``, or empty while ``n > 0``.
    """
    scaled_batch = []

    for array in batch:
        length = len(array)
        if n < length:
            raise ValueError("Cannot scale Array down")
        if length == 0 and n > 0:
            raise ValueError("Cannot scale an empty Array up")

        # The previous repeat count round(n/len)+1 always overshot and was cut
        # at n, silently dropping trailing elements (e.g. [1,2,3,4] -> [1,1,2,2]).
        scaled_batch.append([array[i * length // n] for i in range(n)])

    return torch.tensor(scaled_batch)


def load_paths(wavs_path):
    """Load audio files into one zero-padded batch tensor.

    Args:
        wavs_path: iterable of audio file paths readable by ``torchaudio.load``.

    Returns:
        torch.Tensor: shape ``(len(wavs_path), max_num_samples)``; shorter
        signals are right-padded with zeros.

    NOTE(review): sample rates are neither checked nor converted; the
    downstream encoders presumably expect one fixed rate — confirm the files match.
    """
    waveforms = []
    for path in wavs_path:
        waveform, _ = torchaudio.load(path)
        # torchaudio.load returns (channels, samples). Averaging the channels
        # leaves mono audio unchanged and downmixes multi-channel audio, which
        # pad_sequence could not batch otherwise.
        waveforms.append(waveform.mean(dim=0))
    # Right-pad with zeros to the longest signal. pad_sequence already returns
    # a new tensor; wrapping it in torch.tensor() only copied it again and
    # emitted a UserWarning.
    return pad_sequence(waveforms, batch_first=True)



def word_to_vec(input_string):
    """Encode a string as a list of integer character ids.

    Latin letters a-z map to 1-26, the 28 Arabic letters listed below map to
    27-54, and the space character maps to 55. Any other character (uppercase
    letters, digits, punctuation, ...) is skipped.
    """
    char_to_id = {
        'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9,
        'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17,
        'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25,
        'z': 26,
        'ا': 27, 'ب': 28, 'ت': 29, 'ث': 30, 'ج': 31, 'ح': 32, 'خ': 33, 'د': 34,
        'ذ': 35, 'ر': 36, 'ز': 37, 'س': 38, 'ش': 39, 'ص': 40, 'ض': 41, 'ط': 42,
        'ظ': 43, 'ع': 44, 'غ': 45, 'ف': 46, 'ق': 47, 'ك': 48, 'ل': 49, 'م': 50,
        'ن': 51, 'ه': 52, 'و': 53, 'ي': 54,
        ' ': 55,
    }

    encoded = []
    for ch in input_string:
        code = char_to_id.get(ch)
        if code is not None:
            encoded.append(code)
    return encoded


In [None]:
# Device that merge_strategy (and Mixer.compute_forward) moves features to.
device = 'cuda'
# Non-zero enables shape debug prints in merge_strategy / middle_layer.
verbose = 0
# FLOW LEVEL FUNCTIONS
def merge_strategy(embeddings1, embeddings2, embeddings3,post1, post2,post3):
    """Fuse three models' posteriograms and embeddings along dim 2.

    Every input is moved to the module-level ``device``. The result is
    ``[post1 | post2 | post3 | embeddings1 | embeddings2 | embeddings3]``
    concatenated on dim 2, so all six tensors must agree on dims 0 and 1
    (presumably batch and frames — TODO confirm the three models emit the
    same number of frames).
    """
    posteriograms = [p.to(device) for p in (post1, post2, post3)]
    embeddings = [e.to(device) for e in (embeddings1, embeddings2, embeddings3)]

    posteriograms_merged = torch.cat(posteriograms, dim=2)
    embeddings_merged = torch.cat(embeddings, dim=2)

    if verbose != 0:
        print('MERGED POST ', posteriograms_merged.shape)
        print('MERGED emb ', embeddings_merged.shape)

    fused = torch.cat((posteriograms_merged, embeddings_merged), dim=2)
    return fused.to(device)

def decode(model,wavs,wav_lens):
    """Encode and decode a waveform batch with ``model`` without tracking gradients.

    Args:
        model: object exposing ``device``, ``encode_batch(wavs, wav_lens)`` and
            ``decoding_function(encoder_out, wav_lens)`` (e.g. a SpeechBrain
            ``EncoderASR``).
        wavs: waveform batch, passed to ``encode_batch`` unchanged.
        wav_lens: relative lengths tensor; moved to ``model.device`` first.

    Returns:
        Whatever ``model.decoding_function`` returns.
    """
    with torch.no_grad():
        lens_on_device = wav_lens.to(model.device)
        encoded = model.encode_batch(wavs, lens_on_device)
        return model.decoding_function(encoded, lens_on_device)

def middle_layer(batch):
    """Extract and fuse features from the Tunisian, French and English models.

    An embedding and a posteriogram are computed for each model on the same
    waveform ``batch``; ``merge_strategy`` then concatenates them in FR, EN,
    TN order. Every utterance is given relative length 1.0, so padding is
    not masked.
    """
    # One relative length of 1.0 per utterance in the batch.
    rel_length = torch.ones(len(batch))

    tn_embeddings, tn_posteriogram = asr_brain.custom_encode(batch, None)

    fr_embeddings = french_asr_model.mods.encoder.wav2vec2(batch)
    fr_posteriogram = french_asr_model.encode_batch(batch, rel_length)

    # NOTE(review): unlike the French branch (wav2vec2 sub-module only), this
    # calls the whole English encoder, so en_embeddings is presumably its final
    # output rather than wav2vec2 features. Confirm this is intended — the
    # Mixer's input size depends on it.
    en_embeddings = english_asr_model.mods.encoder(batch)
    en_posteriogram = english_asr_model.encode_batch(batch, rel_length)

    if verbose != 0:
        print('[EMBEDDINGS] FR:', fr_embeddings.shape, "EN:", en_embeddings.shape, "TN:", tn_embeddings.shape)
        print('[POSTERIOGRAM] FR:', fr_posteriogram.shape, "EN:", en_posteriogram.shape, "TN:", tn_posteriogram.shape)

    return merge_strategy(
        fr_embeddings, en_embeddings, tn_embeddings,
        fr_posteriogram, en_posteriogram, tn_posteriogram,
    )



In [None]:

class Mixer(sb.core.Brain):
    """CTC model trained on fused features from three pretrained recognizers.

    ``compute_forward`` sends each waveform batch through ``middle_layer``
    (Tunisian ``asr_brain`` + French/English SpeechBrain models). Only
    ``hparams.model`` is optimized, through a single ``model_optimizer``.

    The caller must set ``self.tokenizer`` (a CTC label encoder) before any
    validation/test stage, and ``hparams.wer_file`` before a TEST stage. LM
    decoding uses the module-level pyctcdecode ``decoder``.
    """

    def compute_forward(self, batch, stage):
        """Forward computations from the waveform batches to the output probabilities.

        Returns:
            tuple: ``(p_ctc, wav_lens)``.
        """

        wavs, wav_lens = batch.sig
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)

        if stage == sb.Stage.TRAIN:
            if hasattr(self.hparams, "augmentation"):
                wavs = self.hparams.augmentation(wavs, wav_lens)

        multi_langual_feats = middle_layer(wavs)
        # NOTE(review): moves to the module-level `device`, not self.device.
        multi_langual_feats = multi_langual_feats.to(device)
        # NOTE(review): the encoder output `feats` is never used — ctc_lin is
        # applied to the raw fused features, so `enc` does not contribute to
        # the loss. ctc_lin(feats) is presumably the intended design, but it
        # requires ctc_lin's input size in mixer.yaml to match enc's output;
        # confirm before changing.
        feats, _ = self.modules.enc(multi_langual_feats)
        logits = self.modules.ctc_lin(multi_langual_feats)
        p_ctc = self.hparams.log_softmax(logits)

        return p_ctc, wav_lens

    def compute_objectives(self, predictions, batch, stage):
        """Computes the CTC loss and, outside training, accumulates WER/CER."""

        p_ctc, wav_lens = predictions

        ids = batch.id
        tokens, tokens_lens = batch.tokens

        loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)

        if stage != sb.Stage.TRAIN:
            # Decode token terms to words
            if self.hparams.use_language_modelling:
                n_frames = p_ctc.shape[1]
                predicted_words = []
                for logs, rel_len in zip(p_ctc, wav_lens):
                    # Keep only the frames covered by this utterance's relative
                    # length so padding frames are not fed to the beam search.
                    valid_frames = int(torch.round(rel_len * n_frames))
                    text = decoder.decode(
                        logs[:valid_frames].detach().cpu().numpy()
                    )
                    predicted_words.append(text.split(" "))
            else:
                # Greedy decoding is only needed on this path.
                predicted_tokens = sb.decoders.ctc_greedy_decode(
                    p_ctc, wav_lens, blank_id=self.hparams.blank_index
                )
                predicted_words = [
                    "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
                    for utt_seq in predicted_tokens
                ]
            # Reference transcripts split into words
            target_words = [wrd.split(" ") for wrd in batch.wrd]

            self.wer_metric.append(ids, predicted_words, target_words)
            self.cer_metric.append(ids, predicted_words, target_words)

        return loss

    def fit_batch(self, batch):
        """Train the parameters given a single batch in input"""
        should_step = self.step % self.grad_accumulation_factor == 0
        # Managing automatic mixed precision
        # TOFIX: CTC fine-tuning currently is unstable
        # This is certainly due to CTC being done in fp16 instead of fp32
        if self.auto_mix_prec:
            with torch.cuda.amp.autocast():
                with self.no_sync():
                    outputs = self.compute_forward(batch, sb.Stage.TRAIN)
                loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
            with self.no_sync(not should_step):
                self.scaler.scale(
                    loss / self.grad_accumulation_factor
                ).backward()
            if should_step:
                self.scaler.unscale_(self.model_optimizer)
                if self.check_gradients(loss):
                    self.scaler.step(self.model_optimizer)
                self.scaler.update()
                self.zero_grad()
                self.optimizer_step += 1
        else:
            # This is mandatory because HF models have a weird behavior with DDP
            # on the forward pass
            with self.no_sync():
                outputs = self.compute_forward(batch, sb.Stage.TRAIN)

            loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)

            with self.no_sync(not should_step):
                (loss / self.grad_accumulation_factor).backward()
            if should_step:
                if self.check_gradients(loss):
                    self.model_optimizer.step()
                self.zero_grad()
                self.optimizer_step += 1

        self.on_fit_batch_end(batch, outputs, loss, should_step)
        return loss.detach().cpu()

    def evaluate_batch(self, batch, stage):
        """Computations needed for validation/test batches"""
        predictions = self.compute_forward(batch, stage=stage)
        with torch.no_grad():
            loss = self.compute_objectives(predictions, batch, stage=stage)
        return loss.detach()

    def on_stage_start(self, stage, epoch):
        """Gets called at the beginning of each epoch; resets the WER/CER trackers."""
        if stage != sb.Stage.TRAIN:
            self.cer_metric = self.hparams.cer_computer()
            self.wer_metric = self.hparams.error_rate_computer()

    def on_stage_end(self, stage, stage_loss, epoch):
        """Gets called at the end of an epoch: anneal LR, log, checkpoint."""
        # Compute/store important stats
        stage_stats = {"loss": stage_loss}
        if stage == sb.Stage.TRAIN:
            self.train_stats = stage_stats
        else:
            stage_stats["CER"] = self.cer_metric.summarize("error_rate")
            stage_stats["WER"] = self.wer_metric.summarize("error_rate")

        # Perform end-of-iteration things, like annealing, logging, etc.
        if stage == sb.Stage.VALID:
            old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
                stage_stats["loss"]
            )
            sb.nnet.schedulers.update_learning_rate(
                self.model_optimizer, new_lr_model
            )
            stats_meta = {"epoch": epoch, "lr_model": old_lr_model}

            # The mixer's init_optimizers never creates a wav2vec optimizer, so
            # the wav2vec annealer is optional and only logged here; reading
            # self.wav2vec_optimizer / hparams.wav2vec2 unconditionally raised
            # AttributeError.
            if hasattr(self.hparams, "lr_annealing_wav2vec"):
                old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
                    stage_stats["loss"]
                )
                stats_meta["lr_wav2vec"] = old_lr_wav2vec
                wav2vec_optimizer = getattr(self, "wav2vec_optimizer", None)
                if wav2vec_optimizer is not None:
                    sb.nnet.schedulers.update_learning_rate(
                        wav2vec_optimizer, new_lr_wav2vec
                    )

            self.hparams.train_logger.log_stats(
                stats_meta=stats_meta,
                train_stats=self.train_stats,
                valid_stats=stage_stats,
            )
            # Keep only the checkpoint with the lowest validation WER.
            self.checkpointer.save_and_keep_only(
                meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
            )
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stage_stats,
            )
            with open(self.hparams.wer_file, "w") as w:
                self.wer_metric.write_stats(w)

    def init_optimizers(self):
        "Initializes the model optimizer (the mixer has no wav2vec2 optimizer)."

        self.model_optimizer = self.hparams.model_opt_class(
            self.hparams.model.parameters()
        )

        if self.checkpointer is not None:
            self.checkpointer.add_recoverable("modelopt", self.model_optimizer)

    def zero_grad(self, set_to_none=False):
        """Zero the gradients of the model optimizer."""
        self.model_optimizer.zero_grad(set_to_none)






In [None]:

%pip install PyArabic

In [None]:
# ONLY RUN ONCE TO CLEAN UP THE dataset
# SHOULD RUN AGAIN FOR DEV ? TEST ? TRAIN
"""
import csv
import pyarabic.araby as araby

# Open the input CSV file for reading and the output CSV file for writing
input_file_path = '/content/drive/MyDrive/tunisian_corpora/code_switched/dev.csv'
output_file_path = '/content/drive/MyDrive/tunisian_corpora/code_switched/dev_processed.csv'

column_name = 'wrd'
chars = set()
# Define your processing function
def process_data(data):
    pattern = r"<(?:en|fr)>|<\/(?:en|fr)>|\\(?:en|fr)>|\\(?:en|fr)>|,|\n|-|\*|\\|'|>|\b\d+\b|\/|\|"
    data= re.sub(pattern,"",data)
    data = re.sub('\ufeff|"|<',"",data)
    data = re.sub("ڨ" , "ق",data)
    data = araby.strip_diacritics(data)
    return data.lower()

# Open input CSV file for reading
with open(input_file_path, 'r', newline='') as input_file:
    csv_reader = csv.DictReader(input_file)
    fieldnames = csv_reader.fieldnames

    # Create a list to store processed rows
    processed_rows = []

    for row in csv_reader:
        # Process the data in the input column
        processed_data = process_data(row[column_name])

        # Update the row with the processed data in the output column
        row[column_name] = processed_data

        for char in row[column_name]:
          chars.add(char)

        processed_rows.append(row)

print(chars)
print(len(chars))

# Write the processed data to the output CSV file
with open(output_file_path, 'w', newline='') as output_file:
    csv_writer = csv.DictWriter(output_file, fieldnames=fieldnames)
    csv_writer.writeheader()
    csv_writer.writerows(processed_rows)
"""

In [None]:
hparams_file, run_opts, overrides = sb.parse_arguments(["/content/drive/MyDrive/tunisian_corpora/tunisian_without_wavlm//hparams/mixer.yaml"])

# If distributed_launch=True then
# create ddp_group with the right communication protocol
sb.utils.distributed.ddp_init_group(run_opts)

with open(hparams_file) as fin:
    hparams = load_hyperpyyaml(fin, overrides)

# Create experiment directory
sb.create_experiment_directory(
    experiment_directory=hparams["output_folder"],
    hyperparams_to_save=hparams_file,
    overrides=overrides,
)

# Datasets and label encoder for the mixer, built from the mixer YAML.
train_data, valid_data, test_datasets, label_encoder = dataio_prepare(hparams)

# Beam-search decoder with the in-domain KenLM model; Mixer.compute_objectives
# reads this module-level `decoder` when use_language_modelling is enabled.
# NOTE(review): `labels` was built from the Tunisian model's label_encoder.txt
# in an earlier cell, not from this mixer's label encoder — confirm they match.
decoder = build_ctcdecoder(
    labels,
    kenlm_model_path="/content/drive/MyDrive/tunisian_corpora/tunisian_without_wavlm/lm_data/arpas/indomain.arpa",  # either .arpa or .bin file
    alpha=0.5,  # tuned on a val set
    beta=1,  # tuned on a val set
)

mixer = Mixer(
    modules=hparams["modules"],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams["checkpointer"],
)

# NOTE(review): overriding the device after construction only changes where
# Mixer.compute_forward moves batches; the modules were presumably already
# placed from run_opts during construction — confirm this is intended.
mixer.device = 'cpu'
# Mixer.compute_objectives decodes greedy CTC tokens with self.tokenizer, so
# the encoder must be attached to the mixer (not to asr_brain).
mixer.tokenizer = label_encoder

mixer.fit(
    mixer.hparams.epoch_counter,
    train_data,
    valid_data,
    train_loader_kwargs=hparams["dataloader_options"],
    valid_loader_kwargs=hparams["test_dataloader_options"],
)

