In [None]:
!pip install datasets >> /dev/null
!pip install transformers >> /dev/null
!pip install jiwer  >> /dev/null
!pip install evaluate >> /dev/null
!pip install accelerate -U >> /dev/null
!pip install wandb >> /dev/null

In [None]:
# Standard Libraries
import getpass
import glob
import json
import logging
import os
import re
import string
import sys
import zipfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

# External Libraries
import datasets
import gdown
import librosa
import numpy as np
import pandas as pd
import torch
import torchaudio
import transformers
from datasets import Audio, Dataset, DatasetDict, load_dataset
from evaluate import load
from packaging import version
from safetensors.torch import save_file as safe_save_file
from torch import nn
from transformers import (
    AutoFeatureExtractor,
    AutoModelForCTC,
    AutoProcessor,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    is_apex_available,
    set_seed,
)
from transformers.models.wav2vec2.modeling_wav2vec2 import WAV2VEC2_ADAPTER_SAFE_FILE
from transformers.trainer_utils import get_last_checkpoint, is_main_process

In [None]:

class SaltSpeechDataset:
    """Download, unpack and split the SALT speech data for each language.

    Everything is stored under the relative ``data/`` directory. After
    :meth:`prepare_datasets` has run, ``self.datasets[lang]`` holds a dict
    with the keys ``"train"``, ``"test"`` and ``"validation"``.
    """

    # Define constants
    BASE_DRIVE_URL = "https://drive.google.com/uc?id="

    # language -> (Drive id of the audio archive, Drive id of the transcript CSV)
    FILE_IDS = {
        "acholi": ("1FBwyASu5WaMK1P9ZSitpZuePWhT3QgGI", "1ZY7eYECAQK8PT_3U0MYIsqmYpLswJHs6"),
        "ateso": ("1vfUPJ9tqPnepp4lvt6J7QnwDo0G8wWBj", "1lWcOB-gaFDs6bdlNcfI6hKzbNkz5JeB1"),
        "luganda": ("1TK_Y7bdT9vtp3NygdbAXv_jdEoq0tl4T", "1IfHycK9ERbCMkaKutmQpkA2kXsi-JQGh"),
        "lugbara": ("1eRLQYuee8bVmnp1lm3ZMxVsmKJDYd9Qn", "1ncAccGnHP2AqoOg4RPIy6RZhf79SsRB6"),
        "runyankole": ("1d_4qLBN0RwZm1A1MELZ7L1qvjgIbgMi0", "1FzPLvhh2Aw_Uu-ELXOpFLgq5tzzrhXty"),
        "english": ("1QCVxoZvOxWSEED3NkaahQK5fVVxUZZdY", "17Cxz36qmYHVKgni2xOnhrnrtaEojICwQ")
    }

    def __init__(self):
        # language -> {"train": ..., "test": ..., "validation": ...}
        self.datasets = {}

    def _download(self, file_id, destination):
        """Fetch one Google Drive file to ``destination``.

        Raises:
            RuntimeError: if no file exists at ``destination`` afterwards, so
                a failed download is reported here instead of surfacing later
                as a missing CSV or an empty audio folder.
        """
        gdown.download(f"{self.BASE_DRIVE_URL}{file_id}", destination, quiet=False)
        if not os.path.exists(destination):
            raise RuntimeError(
                f"Download of {destination} (Google Drive id {file_id}) failed"
            )

    def download_files(self):
        """Download every audio archive and transcript CSV that is not on disk yet."""
        os.makedirs('data', exist_ok=True)

        for lang, (audio_id, transcript_id) in self.FILE_IDS.items():
            audio_file_path = f"data/{lang}-validated.zip"
            transcript_file_path = f"data/Prompt-{lang.capitalize()}.csv"

            if not os.path.exists(audio_file_path):
                self._download(audio_id, audio_file_path)
            if transcript_id and not os.path.exists(transcript_file_path):
                self._download(transcript_id, transcript_file_path)

    def extract_audio_files(self):
        """Extract every ``data/*.zip`` archive into ``data/``.

        Uses the standard-library ``zipfile`` module, so no external ``unzip``
        binary is required, existing files are overwritten without an
        interactive prompt, and a corrupt archive raises ``zipfile.BadZipFile``.
        """
        print("Exracting audio ..")
        for archive_path in sorted(glob.glob(os.path.join("data", "*.zip"))):
            with zipfile.ZipFile(archive_path) as archive:
                archive.extractall("data")

    def load_and_split_dataset(self, audio_dir, csv_path):
        """Load one transcript CSV and split it by its ``split`` column.

        Args:
            audio_dir: directory holding one sub-folder per transcript ``Key``.
            csv_path: transcript CSV; the ``Key`` and ``split`` columns are read.

        Returns:
            dict with ``"train"``, ``"test"`` and ``"validation"`` datasets
            (rows whose ``split`` is ``'train'``, ``'test'`` and ``'dev'``).
            Each row gains an ``audio_paths`` list of the ``.ogg`` files found
            under ``<audio_dir>/<Key>/``; the list may be empty.
        """
        print("Loading and splitting dataset")

        df = pd.read_csv(csv_path)

        def get_audio_paths(row):
            pattern = os.path.join(audio_dir, str(row['Key']), "*.ogg")
            # glob order is filesystem dependent; sort so the first entry
            # (the one used downstream) is the same on every run.
            return sorted(glob.glob(pattern))

        df['audio_paths'] = df.apply(get_audio_paths, axis=1)
        dataset = Dataset.from_pandas(df)

        return {
            "train": dataset.filter(lambda example: example['split'] == 'train'),
            "test": dataset.filter(lambda example: example['split'] == 'test'),
            "validation": dataset.filter(lambda example: example['split'] == 'dev')
        }

    def prepare_datasets(self):
        """Download, extract and load all languages into ``self.datasets``."""
        print("preparing dataset")
        self.download_files()
        self.extract_audio_files()

        for lang in self.FILE_IDS:
            audio_dir = f"data/{lang}-validated"
            csv_path = f"data/Prompt-{lang.capitalize()}.csv"
            self.datasets[lang] = self.load_and_split_dataset(audio_dir, csv_path)


In [None]:
@dataclass
class DataCollatorCTCWithPadding:
    """
    Collate CTC training examples into one padded batch.

    Audio inputs and label ids are padded separately through
    ``processor.pad`` because their lengths differ; padded label positions
    are then set to -100.

    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            Processor whose ``pad`` method is used for both inputs and labels.
        padding (:obj:`bool` or :obj:`str`, `optional`, defaults to :obj:`True`):
            Passed unchanged as ``padding=`` to both ``processor.pad`` calls.
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        # Inputs and labels have different lengths, so gather them into two
        # separate lists that are padded independently.
        audio_items = []
        label_items = []
        for feature in features:
            audio_items.append({"input_values": feature["input_values"]})
            label_items.append({"input_ids": feature["labels"]})

        batch = self.processor.pad(
            audio_items,
            padding=self.padding,
            return_tensors="pt",
        )
        labels_batch = self.processor.pad(
            labels=label_items,
            padding=self.padding,
            return_tensors="pt",
        )

        # Positions where the label attention mask is not 1 are padding;
        # set them to -100 to ignore loss correctly.
        padded_positions = labels_batch.attention_mask.ne(1)
        batch["labels"] = labels_batch["input_ids"].masked_fill(padded_positions, -100)

        return batch

In [None]:
# Download (if missing), extract and load every language listed in
# SaltSpeechDataset.FILE_IDS; the splits end up in dataset.datasets[lang].
dataset = SaltSpeechDataset()
dataset.prepare_datasets()

preparing dataset
Exracting audio ..
Loading and splitting dataset


Filter:   0%|          | 0/25006 [00:00<?, ? examples/s]

Filter:   0%|          | 0/25006 [00:00<?, ? examples/s]

Filter:   0%|          | 0/25006 [00:00<?, ? examples/s]

Loading and splitting dataset


Filter:   0%|          | 0/24708 [00:00<?, ? examples/s]

Filter:   0%|          | 0/24708 [00:00<?, ? examples/s]

Filter:   0%|          | 0/24708 [00:00<?, ? examples/s]

Loading and splitting dataset


Filter:   0%|          | 0/25006 [00:00<?, ? examples/s]

Filter:   0%|          | 0/25006 [00:00<?, ? examples/s]

Filter:   0%|          | 0/25006 [00:00<?, ? examples/s]

Loading and splitting dataset


Filter:   0%|          | 0/25006 [00:00<?, ? examples/s]

Filter:   0%|          | 0/25006 [00:00<?, ? examples/s]

Filter:   0%|          | 0/25006 [00:00<?, ? examples/s]

Loading and splitting dataset


Filter:   0%|          | 0/25006 [00:00<?, ? examples/s]

Filter:   0%|          | 0/25006 [00:00<?, ? examples/s]

Filter:   0%|          | 0/25006 [00:00<?, ? examples/s]

Loading and splitting dataset


Filter:   0%|          | 0/25006 [00:00<?, ? examples/s]

Filter:   0%|          | 0/25006 [00:00<?, ? examples/s]

Filter:   0%|          | 0/25006 [00:00<?, ? examples/s]

In [None]:
# Language to fine-tune on; used below as a key into dataset.datasets.
target_language = "lugbara"

In [None]:
# Dict of splits ("train", "test", "validation") for the chosen language.
target_lang_dataset = dataset.datasets[target_language]

In [None]:
# Unpack the three splits; the "validation" split is used as eval_dataset.
train_dataset, test_dataset, eval_dataset = (
    target_lang_dataset[split_name]
    for split_name in ("train", "test", "validation")
)

In [None]:
# Punctuation and quote characters stripped from transcripts before training.
chars_to_ignore = [
    ",",
    "?",
    ".",
    "!",
    "-",
    ";",
    ":",
    '""',
    "%",
    "'",
    '"',
    "�",
    "'",
    "\u2018",
    "\u2019",
]

# Each entry is escaped before being placed in the character class. Unescaped,
# the "-" sitting between "!" and ";" formed the range "!-;", which also
# removed digits, parentheses, "+", "/", "*" and other characters that are not
# in the list above. Digits and those symbols are now kept in the transcripts.
chars_to_ignore_regex = "[" + "".join(re.escape(char) for char in chars_to_ignore) + "]"

# Fixed-rate resampler from 48 kHz down to 16 kHz.
resampler = torchaudio.transforms.Resample(orig_freq=48_000, new_freq=16_000)


def remove_special_characters(batch):
    """Add a normalised ``"text"`` field derived from ``batch["Text"]``.

    Characters matched by ``chars_to_ignore_regex`` are removed, the result is
    lower-cased and a single trailing space is appended. The same ``batch``
    object is returned.
    """
    stripped = re.sub(chars_to_ignore_regex, "", batch["Text"])
    batch["text"] = f"{stripped.lower()} "
    return batch


# Replace the raw "Text" column with the normalised "text" column on the train
# and validation splits. test_dataset is not processed here.
# NOTE(review): the names are rebound in place, so this cell is presumably not
# re-runnable once "Text" has been removed - confirm before re-executing.
train_dataset = train_dataset.map(
    remove_special_characters, remove_columns=["Text"]
)
eval_dataset = eval_dataset.map(remove_special_characters, remove_columns=["Text"])


def extract_all_chars(batch):
    """Summarise a batch of transcripts as joined text plus distinct characters.

    Args:
        batch: mapping whose ``"text"`` entry is an iterable of strings.

    Returns:
        dict with two one-element lists: ``"vocab"`` holds the list of distinct
        characters (in set iteration order) and ``"all_text"`` holds the
        transcripts joined with single spaces.
    """
    joined_text = " ".join(batch["text"])
    distinct_chars = list(set(joined_text))
    summary = {"vocab": [distinct_chars], "all_text": [joined_text]}
    return summary


# Collect the character inventory of each split. batch_size=-1 hands the whole
# split to extract_all_chars as a single batch.
vocab_train = train_dataset.map(
    extract_all_chars,
    batched=True,
    batch_size=-1,
    keep_in_memory=True,
    remove_columns=train_dataset.column_names,
)
# Built from the validation split. This previously mapped over train_dataset
# again, so characters that occur only in the validation data were missing
# from the vocabulary.
vocab_test = eval_dataset.map(
    extract_all_chars,
    batched=True,
    batch_size=-1,
    keep_in_memory=True,
    remove_columns=eval_dataset.column_names,
)

# Character inventory: every symbol collected in vocab_train / vocab_test plus
# the full lowercase ASCII alphabet, in sorted order.
vocab_list = sorted(
    set(string.ascii_lowercase).union(
        vocab_train["vocab"][0],
        vocab_test["vocab"][0],
    )
)
vocab_dict = dict(zip(vocab_list, range(len(vocab_list))))

# The space keeps its id but is stored under the "|" key; the two special
# tokens are appended after every regular character.
vocab_dict["|"] = vocab_dict.pop(" ")
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)

# NOTE(review): target_language above is "lugbara" and the Hub repo is
# "akera/mms-lgg", but the vocabulary/adapter key here is "ach" - confirm
# whether this should be the Lugbara code instead.
target_lang = "ach"
new_vocab_dict = {target_lang: vocab_dict}

# vocab.json nests the vocabulary under the target-language key.
with open("vocab.json", "w") as vocab_file:
    json.dump(new_vocab_dict, vocab_file)


Map:   0%|          | 0/23947 [00:00<?, ? examples/s]

Map:   0%|          | 0/496 [00:00<?, ? examples/s]

Map:   0%|          | 0/23947 [00:00<?, ? examples/s]

Map:   0%|          | 0/23947 [00:00<?, ? examples/s]

In [None]:
# SECURITY: a Hugging Face access token used to be hard-coded in this cell.
# Treat that token as compromised and revoke it. The token is now read from
# the HF_TOKEN environment variable, falling back to an interactive prompt,
# so it never appears in the notebook source.
hf_token = os.environ.get("HF_TOKEN") or getpass.getpass("Hugging Face access token: ")

huggingface_dir = os.path.expanduser("~/.huggingface/")
os.makedirs(huggingface_dir, exist_ok=True)

token_path = os.path.join(huggingface_dir, "token")
with open(token_path, "w") as f:
    f.write(hf_token)
# Keep the stored credential readable by the current user only.
os.chmod(token_path, 0o600)


In [None]:
# Confirm which Hugging Face account the stored token authenticates as.
!huggingface-cli whoami

akera
[1morgs: [0m Sunbird


In [None]:
# !huggingface-cli login

In [None]:


# `tokenizer` and `feature_extractor` were used below without being defined
# anywhere in the notebook, so a fresh kernel failed with NameError here.
# NOTE(review): the settings below follow the usual MMS adapter fine-tuning
# recipe and the vocab.json / special tokens written earlier - confirm they
# match what was used for previously published checkpoints.
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
    "./",  # directory containing the vocab.json written above
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="|",
    target_lang=target_lang,
)
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16_000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True,
)

repo_name = "akera/mms-lgg"

# Announce the upload before it starts rather than after it has finished.
print("pushing tokenizer to hub")
tokenizer.push_to_hub(repo_name)

# Wav2Vec2Processor is already imported in the imports cell.
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


pushing tokenizer to hub


In [None]:
def speech_file_to_array_fn(batch):
    """Load the first audio file of an example as a mono 16 kHz waveform.

    Adds ``"speech"`` (1-D numpy array), ``"sampling_rate"`` (16000) and
    ``"target_text"`` (copy of ``batch["text"]``) to ``batch``.

    The sampling rate reported by the file is used for resampling. The
    previous version always applied a fixed 48 kHz -> 16 kHz resampler, which
    silently produced audio at the wrong rate for any file not recorded at
    48 kHz. Multi-channel files are averaged down to one channel so the
    result is always 1-D.

    Returns:
        The updated ``batch``, or ``None`` after printing a message when the
        file cannot be read. Filter unreadable files out beforehand (see
        ``check_audio_file``) rather than relying on the ``None`` return.
    """
    target_rate = 16_000
    try:
        speech_array, source_rate = torchaudio.load(batch["audio_paths"][0])
        # torchaudio.load yields (channels, frames); collapse the channel axis.
        mono = speech_array.mean(dim=0)
        if source_rate != target_rate:
            mono = torchaudio.functional.resample(mono, source_rate, target_rate)
        batch["speech"] = mono.numpy()
        batch["sampling_rate"] = target_rate
        batch["target_text"] = batch["text"]

        return batch
    except Exception as e:
        print(f"Could not process file {batch['audio_paths']}. Error: {str(e)}")
        return None

def check_audio_file(batch):
    """Report whether the first audio file of an example can be loaded.

    Returns:
        ``{"is_audio_ok": True}`` when ``torchaudio.load`` succeeds on the
        first entry of ``batch["audio_paths"]``; ``{"is_audio_ok": False}``
        when the entry is missing, empty or ``None``, or when loading raises
        (the error is printed).
    """
    paths = batch.get("audio_paths")
    has_first_path = bool(paths) and paths[0] is not None
    if not has_first_path:
        return {"is_audio_ok": False}

    try:
        torchaudio.load(paths[0])
    except Exception as e:
        print(f"Could not process file {paths[0]}. Error: {str(e)}")
        return {"is_audio_ok": False}
    return {"is_audio_ok": True}

def prepare_dataset(batch):
    """Turn a batch of waveforms and transcripts into model inputs and labels.

    Reads ``batch["speech"]``, ``batch["sampling_rate"]`` and
    ``batch["target_text"]``; writes ``batch["input_values"]`` and
    ``batch["labels"]`` using the global ``processor`` and returns ``batch``.
    Asserts that every example in the batch has the same sampling rate.
    """
    distinct_rates = set(batch["sampling_rate"])
    assert len(distinct_rates) == 1, (
        f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."
    )

    shared_rate = batch["sampling_rate"][0]
    audio_features = processor(batch["speech"], sampling_rate=shared_rate)
    batch["input_values"] = audio_features.input_values

    # Labels come from calling the processor with text= on the transcripts.
    text_features = processor(text=batch["target_text"])
    batch["labels"] = text_features.input_ids

    return batch


In [None]:
# Inspect the columns and row count of the training split after text cleaning.
train_dataset

Dataset({
    features: ['Key', 'split', 'audio_paths', 'text'],
    num_rows: 23947
})

In [None]:
# Add an 'is_audio_ok' column to the train and validation splits:
# check_audio_file tries to load the first audio file of every example and
# prints a message for each file that cannot be read.
train_dataset_with_check = train_dataset.map(check_audio_file)
eval_dataset_with_check = eval_dataset.map(check_audio_file)

Map:   0%|          | 0/23947 [00:00<?, ? examples/s]

Could not process file data/lugbara-validated/682/69.ogg. Error: Failed to open the input "data/lugbara-validated/682/69.ogg" (End of file).
Could not process file data/lugbara-validated/697/69.ogg. Error: Failed to open the input "data/lugbara-validated/697/69.ogg" (End of file).
Could not process file data/lugbara-validated/739/82.ogg. Error: Failed to open the input "data/lugbara-validated/739/82.ogg" (End of file).
Could not process file data/lugbara-validated/744/82.ogg. Error: Failed to open the input "data/lugbara-validated/744/82.ogg" (End of file).
Could not process file data/lugbara-validated/1007/70.ogg. Error: Failed to open the input "data/lugbara-validated/1007/70.ogg" (End of file).
Could not process file data/lugbara-validated/1135/15b9ba08-07a1-4e4c-b42e-b11bd778afc8.ogg. Error: Failed to open the input "data/lugbara-validated/1135/15b9ba08-07a1-4e4c-b42e-b11bd778afc8.ogg" (End of file).
Could not process file data/lugbara-validated/1346/99a9302b-9462-4589-aef0-30c47c0

Map:   0%|          | 0/496 [00:00<?, ? examples/s]

Could not process file data/lugbara-validated/1426/70.ogg. Error: Failed to open the input "data/lugbara-validated/1426/70.ogg" (End of file).
Could not process file data/lugbara-validated/4405/47.ogg. Error: Failed to open the input "data/lugbara-validated/4405/47.ogg" (End of file).
Could not process file data/lugbara-validated/4436/47.ogg. Error: Failed to open the input "data/lugbara-validated/4436/47.ogg" (End of file).


In [None]:
# Keep only the rows whose 'is_audio_ok' flag is True, i.e. examples that have
# an audio file which loaded successfully in the check above.
train_dataset = train_dataset_with_check.filter(lambda x: x["is_audio_ok"])
eval_dataset = eval_dataset_with_check.filter(lambda x: x["is_audio_ok"])

Filter:   0%|          | 0/23947 [00:00<?, ? examples/s]

Filter:   0%|          | 0/496 [00:00<?, ? examples/s]

In [None]:
# Decode the audio of every training example. The original columns are
# dropped; the columns written by speech_file_to_array_fn remain.
train_dataset1 = train_dataset.map(
    speech_file_to_array_fn,
    remove_columns=train_dataset.column_names,
    num_proc=1,
)

Map:   0%|          | 0/4771 [00:00<?, ? examples/s]

In [None]:
# Decode the audio of every validation example, mirroring the training split.
eval_dataset1 = eval_dataset.map(
    speech_file_to_array_fn,
    remove_columns=eval_dataset.column_names,
    num_proc=1,
)

Map:   0%|          | 0/98 [00:00<?, ? examples/s]

In [None]:
# Convert waveforms and transcripts into model inputs and label ids.
# The intermediate columns of train_dataset1 are dropped afterwards. The
# disabled remove_columns line referenced train_dataset, whose columns differ,
# so the raw audio was being carried along into the final dataset.
final_train_dataset = train_dataset1.map(
    prepare_dataset,
    remove_columns=train_dataset1.column_names,
    batch_size=4,
    batched=True,
    num_proc=1,
)

Map:   0%|          | 0/4771 [00:00<?, ? examples/s]

In [None]:
# Convert waveforms and transcripts into model inputs and label ids.
# The intermediate columns of eval_dataset1 are dropped afterwards. The
# disabled remove_columns line referenced train_dataset, whose columns differ,
# so the raw audio was being carried along into the final dataset.
final_eval_dataset = eval_dataset1.map(
    prepare_dataset,
    remove_columns=eval_dataset1.column_names,
    batch_size=4,
    batched=True,
    num_proc=1,
)

In [None]:
# datasets.load_metric is deprecated and absent from recent `datasets`
# releases; metrics are loaded through the `evaluate` package instead
# (`load` is imported from evaluate in the imports cell).
wer_metric = load("wer")

In [None]:
# Collator that pads audio inputs and label ids per batch via the processor.
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

# Word error rate metric, loaded through evaluate.load.
wer_metric = load("wer")

In [None]:
def compute_metrics(pred):
    """Compute the word error rate for one evaluation pass.

    Args:
        pred: object exposing ``predictions`` (logits) and ``label_ids``.

    Returns:
        ``{"wer": <value returned by wer_metric.compute>}``.

    ``pred.label_ids`` is left untouched: the -100 placeholders are replaced
    with the pad token id on a copy, not by writing into the caller's array.
    """
    pred_ids = np.argmax(pred.predictions, axis=-1)

    # -100 marks label positions ignored by the loss; map them back to the
    # pad token so the labels can be decoded.
    label_ids = np.where(
        pred.label_ids == -100, processor.tokenizer.pad_token_id, pred.label_ids
    )

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}

In [None]:
# Load the pretrained MMS checkpoint with every dropout/layerdrop rate set to
# zero. The pad token id and vocabulary size are taken from this notebook's
# tokenizer instead of the checkpoint's own configuration.
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/mms-1b-all",
    attention_dropout=0.0,
    hidden_dropout=0.0,
    feat_proj_dropout=0.0,
    layerdrop=0.0,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
    # presumably needed because vocab_size differs from the checkpoint's
    # output layer - TODO confirm
    ignore_mismatched_sizes=True,
)

In [None]:
# Enable gradient checkpointing, initialise the adapter layers, then freeze
# the base model.
model.gradient_checkpointing_enable()
model.init_adapter_layers()
model.freeze_base_model()

# Re-enable gradients for the adapter parameters only.
adapter_weights = model._get_adapters()
for param in adapter_weights.values():
    param.requires_grad = True

In [None]:
# NOTE(review): output_dir and run_name still say "mms-ach" although the data
# selected above is Lugbara - confirm whether they should be renamed.
training_args = TrainingArguments(
    output_dir="output/mms-ach",
    group_by_length=True,
    per_device_train_batch_size=2,
    evaluation_strategy="steps",
    num_train_epochs=5,
    gradient_checkpointing=True,
    fp16=True,
    save_steps=100,
    eval_steps=20,
    logging_steps=100,
    learning_rate=1e-3,
    warmup_steps=100,
    save_total_limit=2,
    push_to_hub=True,
    # Push the model to the repository the tokenizer was uploaded to. Without
    # this the Hub repository name is derived from output_dir ("mms-ach"), so
    # model and tokenizer would end up in different repositories.
    hub_model_id=repo_name,
    report_to="wandb",
    run_name="mms-ach",
    # Keep the checkpoint with the lowest word error rate.
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
)

In [None]:
# Assemble the Trainer from the model, collator, metric function and the
# prepared train/validation datasets.
trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=final_train_dataset,
    eval_dataset=final_eval_dataset,
    # The feature extractor (not the text tokenizer) is passed here.
    tokenizer=processor.feature_extractor,
)

In [None]:
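# Weights & Biases: authenticate with `wandb login` or the WANDB_API_KEY environment variable. What looked like an API key was pasted here; never keep keys in notebook cells, and revoke any key that was committed.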

In [None]:
# Start fine-tuning with the Trainer configured above.
trainer.train()

In [None]:
# Write the adapter weights into the output directory as a safetensors file
# whose name is derived from target_lang, then push the run to the Hub.
adapter_file = os.path.join(
    training_args.output_dir,
    WAV2VEC2_ADAPTER_SAFE_FILE.format(target_lang),
)
safe_save_file(model._get_adapters(), adapter_file, metadata={"format": "pt"})

trainer.push_to_hub()