<a href="https://colab.research.google.com/github/SunbirdAI/leb/blob/main/notebooks/Multilingual_ASR_training_single_language_%2B_English.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

- MMS ASR training should take around 2 hours on a single RTX 6000 Ada GPU.
- We fine-tune from the existing MMS adapters (which are available for all the SALT languages). This means that we have to reuse each language's existing tokenizer. Note that in some cases the tokenizer vocabulary is incomplete: e.g. it might be missing the 'q' or 'z' characters for some languages.
- This notebook mixes some Ugandan English into the training samples, so that the resulting model has some multilingual capability -- at least enough to identify English phrases when a speaker code-switches.

In [None]:
%%capture
# Install the training dependencies. `%%capture` swallows this cell's output so
# the pip logs don't flood the notebook; remove it when debugging an install.
# NOTE(review): none of these installs are version-pinned, so results may drift
# as upstream releases change — consider pinning once a working set is known.
!pip install -q jiwer
!pip install -q evaluate
!pip install -qU accelerate
!pip install -q transformers[torch]
# The `leb` repo provides the dataset / metrics / utils helpers imported in the
# next cell. Its requirements are installed after transformers, so a version pin
# in that file (if any) would override the unpinned install above.
!git clone https://github.com/sunbirdai/leb.git
!pip install -qr leb/requirements.txt
!pip install -q mlflow

## requirements to log system/GPU metrics in mlflow
# (used by `log_system_metrics=True` when the MLflow run is started below)
!pip install psutil
!pip install pynvml

In [None]:
from torch import nn
import torch
from transformers import (
    AutoFeatureExtractor,
    AutoModelForCTC,
    AutoProcessor,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    is_apex_available,
    set_seed,
)
from dataclasses import dataclass, field
from typing import Union, List, Dict
import string
import os
import json
import datasets
import numpy as np
import yaml
import evaluate
import mlflow
from getpass import getpass
import leb.dataset
import leb.metrics
from leb.utils import DataCollatorCTCWithPadding as dcwp
import mlflow.pytorch
from mlflow import MlflowClient
import huggingface_hub
from transformers.models.wav2vec2.modeling_wav2vec2 import WAV2VEC2_ADAPTER_SAFE_FILE
from safetensors.torch import save_file as safe_save_file

# import wandb # If using weights and biases

In [None]:
# Tracking-server credentials are handed to MLflow through environment
# variables. They are prompted for with getpass (input is not echoed), so no
# secret is ever written into the notebook.
MLFLOW_TRACKING_USERNAME = getpass('Enter the MLFLOW_TRACKING_USERNAME: ')
os.environ.update(MLFLOW_TRACKING_USERNAME=MLFLOW_TRACKING_USERNAME)

MLFLOW_TRACKING_PASSWORD = getpass('Enter the MLFLOW_TRACKING_PASSWORD: ')
os.environ.update(MLFLOW_TRACKING_PASSWORD=MLFLOW_TRACKING_PASSWORD)

# Every run started by this notebook is logged to the shared Sunbird server.
mlflow.set_tracking_uri('https://mlflow-sunbird-ce0ecfc14244.herokuapp.com/')

In [None]:
# Local language to train alongside Ugandan-accented English. Use one of the
# non-English codes in the `languages` mapping of the next cell.
language = "lgg"

In [None]:
# Languages currently available in the SALT multispeaker STT dataset, mapped to
# the codes used in the `multispeaker-<code>` dataset config names.
languages = {
    "acholi": "ach",
    "lugbara": "lgg",
    "luganda": "lug",
    "ateso": "teo",
    "runyankole": "nyn",
    "english": "eng"
}

# `language` is interpolated into the dataset names below and English is always
# mixed in, so fail fast on a typo (or on picking English itself) instead of
# surfacing it later as an obscure dataset-loading error.
supported_local_languages = sorted(set(languages.values()) - {"eng"})
if language not in supported_local_languages:
    raise ValueError(
        f"Unsupported language {language!r}; expected one of "
        f"{supported_local_languages}.")

yaml_config = f'''
pretrained_model: facebook/mms-1b-all
pretrained_adapter: {language}
mlflow_experiment_name : stt-multilingual
mlflow_run_name: {language}_from_pretrained
adapter_save_id: {language}+eng

training_args:
    output_dir: stt
    per_device_train_batch_size: 24
    gradient_accumulation_steps: 2
    # NOTE(review): recent transformers releases rename this option to
    # `eval_strategy`; update it if the installed version rejects it.
    evaluation_strategy: steps
    max_steps: 1200
    gradient_checkpointing: True
    gradient_checkpointing_kwargs:
      use_reentrant: True
    fp16: True
    save_steps: 100
    eval_steps: 100
    logging_steps: 100
    learning_rate: 3.0e-4
    warmup_steps: 100
    save_total_limit: 2
    # push_to_hub: True
    load_best_model_at_end: True
    metric_for_best_model: loss
    greater_is_better: False
    weight_decay: 0.01

Wav2Vec2ForCTC_args:
    attention_dropout: 0.0
    hidden_dropout: 0.0
    feat_proj_dropout: 0.0
    layerdrop: 0.0
    ctc_loss_reduction: mean
    ignore_mismatched_sizes: True

train:
    huggingface_load:
        - path: Sunbird/salt
          name: multispeaker-{language}
          split: train
        # Ugandan English speech, mixed into training so that the model can
        # recognise English phrases (e.g. when a speaker code-switches).
        - path: Sunbird/salt
          name: multispeaker-eng
          split: train
    source:
      type: speech
      language: [{language},eng]
      preprocessing:
        - set_sample_rate:
            rate: 16_000
        - augment_audio_noise:
            max_relative_amplitude: 0.5
    target:
      type: text
      language: [{language},eng]
      preprocessing:
        - lower_case
        - clean_and_remove_punctuation:
            allowed_punctuation: "'"
    shuffle: True
validation:
    huggingface_load:
        - path: Sunbird/salt
          name: multispeaker-{language}
          split: dev
        - path: Sunbird/salt
          name: multispeaker-eng
          split: dev
    source:
      type: speech
      language: [{language},eng]
      preprocessing:
        - set_sample_rate:
            rate: 16_000
    target:
      type: text
      language: [{language},eng]
      preprocessing:
        - lower_case
        - clean_and_remove_punctuation:
            allowed_punctuation: "'"
'''

config = yaml.safe_load(yaml_config)
train_ds = leb.dataset.create(config['train'])
valid_ds = leb.dataset.create(config['validation'])

In [None]:
# Preview five training examples (the 'source' column is rendered as audio) to
# sanity-check the preprocessing before spending GPU time on training.
leb.utils.show_dataset(train_ds, N=5, audio_features=['source'])

In [None]:
if config.get('pretrained_adapter'):
  # Fine-tuning from an existing MMS adapter: we have to use the vocabulary
  # that adapter was trained with, so load the matching tokenizer.
  tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(config['pretrained_model'])
  tokenizer.set_target_lang(config['pretrained_adapter'])
else:
  # Otherwise, build a fresh character vocabulary. This assumes preprocessing
  # leaves only lower-case letters, digits, spaces and apostrophes (the only
  # punctuation allowed by `clean_and_remove_punctuation` in the config).
  # A dedicated name is used so the global `language` that builds the config
  # is not clobbered if cells are re-run out of order.
  vocab_lang = '-'.join(config['train']['source']['language'])
  vocab = list(string.ascii_lowercase) + list(string.digits)
  # '|' is the word delimiter the tokenizer substitutes for spaces.
  vocab += ['[UNK]', '[PAD]', '|', "'"]
  vocab_dict = {vocab_lang: {v: i for i, v in enumerate(vocab)}}
  with open("vocab.json", "w") as vocab_file:
      json.dump(vocab_dict, vocab_file)
  tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
      "./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|",
      target_lang=vocab_lang)

# Audio front end: mono input at 16 kHz (matching the `set_sample_rate`
# preprocessing), normalised, with attention masks for padded batches.
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1, sampling_rate=16000, padding_value=0.0,
    do_normalize=True, return_attention_mask=True)
processor = Wav2Vec2Processor(
    feature_extractor=feature_extractor, tokenizer=tokenizer)
data_collator = dcwp(processor=processor, padding=True)

In [None]:
def prepare_dataset(batch):
    """Attach model inputs and CTC labels to a batch of examples.

    `batch["source"]` is the speech column (set to 16 kHz by the
    `set_sample_rate` preprocessing in the config) and `batch["target"]` the
    transcripts. The processor turns the audio into `input_values` and the
    text into token-id `labels`. No sampling-rate validation is done here.
    """
    audio_inputs = processor(batch["source"], sampling_rate=16000)
    batch["input_values"] = audio_inputs.input_values
    text_inputs = processor(text=batch["target"])
    batch["labels"] = text_inputs.input_ids
    return batch

# Apply the same batched preprocessing to the training and validation sets.
final_train_dataset, final_val_dataset = (
    split_ds.map(prepare_dataset, batched=True, batch_size=4)
    for split_ds in (train_ds, valid_ds)
)

In [None]:
# Load the MMS base model with the dropout/CTC settings from the config. The
# output layer is sized to the chosen tokenizer; `ignore_mismatched_sizes`
# (also from the config) lets differently-shaped weights be re-initialised
# instead of raising.
model = Wav2Vec2ForCTC.from_pretrained(
    config['pretrained_model'],
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
    **config["Wav2Vec2ForCTC_args"],
)

# Start from the pretrained language adapter when one is configured,
# otherwise from freshly initialised adapter layers.
if not config.get('pretrained_adapter'):
  model.init_adapter_layers()
else:
  model.load_adapter(config['pretrained_adapter'])

# Train only the adapters: freeze the base wav2vec2 model, then make sure
# every adapter parameter has gradients enabled.
model.freeze_base_model()
adapter_weights = model._get_adapters()
for param in adapter_weights.values():
    param.requires_grad_(True)

In [None]:
# Evaluation function handed to the Trainer: scores validation predictions
# with WER and CER (and, per its argument, logs the first two predictions for
# a quick qualitative check).
compute_metrics = leb.metrics.multilingual_eval_fn(
    valid_ds,
    [evaluate.load('wer'), evaluate.load('cer')],
    tokenizer,
    log_first_N_predictions=2,
    speech_processor=processor,
)

In [None]:
training_args = TrainingArguments(
    **config["training_args"],
    # Built-in reporting integrations are disabled; MLflow logging is
    # (presumably) handled by the leb callback passed to the Trainer below.
    report_to="none",
)

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=final_train_dataset,
    eval_dataset=final_val_dataset,
    tokenizer=processor.feature_extractor,
    callbacks=[leb.utils.MlflowExtendedLoggingCallback()]
)

# `mlflow.set_experiment` creates the experiment if it does not exist yet,
# so no separate existence check is needed.
experiment_name = config['mlflow_experiment_name']
mlflow.set_experiment(experiment_name)

with mlflow.start_run(run_name=config['mlflow_run_name'], log_system_metrics=True) as run:

    mlflow.set_tag("developer", os.environ['MLFLOW_TRACKING_USERNAME'])

    mlflow.log_params(config)

    train_output = trainer.train()

    # Final evaluation of the model the Trainer ends up with (the best
    # checkpoint by eval loss, as `load_best_model_at_end` is set in config).
    eval_metrics = trainer.evaluate()

    # Save the full model; the preprocessor config and training args written
    # alongside it are logged as artifacts below.
    trainer.save_model()

    # Save just the adapter weights, named with transformers' adapter file
    # pattern for `adapter_save_id`.
    adapter_file = os.path.join(
        training_args.output_dir,
        WAV2VEC2_ADAPTER_SAFE_FILE.format(config['adapter_save_id']))
    safe_save_file(model._get_adapters(), adapter_file, metadata={"format": "pt"})

    # Log every file from the single output directory, reusing `adapter_file`
    # rather than rebuilding its name by hand.
    artifact_path = "model_artifacts"
    for artifact in (
        os.path.join(training_args.output_dir, "preprocessor_config.json"),
        os.path.join(training_args.output_dir, "training_args.bin"),
        adapter_file,
    ):
        mlflow.log_artifact(artifact, artifact_path)
