In [1]:
%%capture
!pip install jiwer
!pip install evaluate
!pip install accelerate -U
!pip install transformers[torch]
!git clone https://github.com/sunbirdai/leb.git
!pip install -r leb/requirements.txt

In [3]:
from torch import nn
import torch
from transformers import (
    AutoFeatureExtractor,
    AutoModelForCTC,
    AutoProcessor,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    is_apex_available,
    set_seed,
)
from dataclasses import dataclass, field
from typing import Union, List, Dict
import string
import datasets
import numpy as np
import leb.dataset
import yaml
import evaluate
import mlflow
from getpass import getpass
import os
from leb.utils import DataCollatorCTCWithPadding as dcwp
from leb.utils import MlflowExtendedLoggingCallback

# ASR data example

In [4]:
# languages currently available in SALT multispeaker STT dataset
languages = {
    "acholi": "ach",
    "lugbara": "lgg",
    "luganda": "lug",
    "ateso": "teo",
    "runyankole": "nyn",
    "english": "eng"

}

# choose a language of interest
language = languages["acholi"]

# define the directory for output files
project_folder = "./stt"

if not os.path.exists(project_folder):
  %mkdir $project_folder
%cd $project_folder

yaml_config = '''
common_source: &common_source
  type: speech
  language: "{language}"
  preprocessing:
    - set_sample_rate:
        rate: 16_000

common_target: &common_target
  type: text
  language: "{language}"
  preprocessing:
    - lower_case
    - clean_and_remove_punctuation

training_args:
    output_dir: "{project_folder}"
    per_device_train_batch_size: 2
    evaluation_strategy: steps
    num_train_epochs: 5
    max_steps: 10000
    gradient_checkpointing: True
    fp16: True
    save_steps: 1000
    eval_steps: 1000
    logging_steps: 1000
    learning_rate: 0.003
    warmup_steps: 100
    save_total_limit: 2
    # push_to_hub: True
    load_best_model_at_end: True
    metric_for_best_model: wer
    greater_is_better: False
    weight_decay: 0.01

Wav2Vec2ForCTC_args:
    attention_dropout: 0.0
    hidden_dropout: 0.0
    feat_proj_dropout: 0.0
    layerdrop: 0.0
    ctc_loss_reduction: mean
    ignore_mismatched_sizes: True

train:
    huggingface_load:
        - path: Sunbird/salt
          name: multispeaker-{language}
          split: train

    source: *common_source
    target: *common_target
validation:
    huggingface_load:
        - path: Sunbird/salt
          name: multispeaker-{language}
          split: dev

    source: *common_source
    target: *common_target

'''

yaml_config = yaml_config.format(
    project_folder=project_folder,
    language=language
)

config = yaml.safe_load(yaml_config)
train_ds = leb.dataset.create(config['train'])
valid_ds = leb.dataset.create(config['validation'])

/content/stt


In [5]:
## HELPER FUNCTIONS

# Create dict for vocabulary
def extract_all_chars(batch):
    all_text = " ".join(batch["target"])
    vocab = list(set(all_text))
    return {"vocab": vocab, "all_text": [all_text]}

def prepare_dataset(batch):
    # check that all files have the correct sampling rate
    batch["input_values"] = processor(
        batch["source"], sampling_rate=16000
    ).input_values
    # Setup the processor for targets
    batch["labels"] = processor(text=batch["target"]).input_ids

    return batch

def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}

In [None]:
vocab_dict = {}

for item in train_ds:
    result = extract_all_chars(item)
    for char in result["vocab"]:
        vocab_dict[char] = 1

vocab_list = list(vocab_dict.keys())
vocab_dict = {v: k for k, v in enumerate(vocab_list)}

In [7]:
vocab_dict["|"] = vocab_dict[" "]
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)

In [8]:
new_vocab_dict = {language: vocab_dict}

In [9]:
new_vocab_dict

{'ach': {'e': 0,
  'd': 1,
  'y': 2,
  'o': 3,
  'b': 4,
  'a': 5,
  ' ': 6,
  't': 7,
  'k': 8,
  'm': 9,
  'n': 10,
  'l': 11,
  'r': 12,
  'p': 13,
  'i': 14,
  'g': 15,
  'u': 16,
  'w': 17,
  'j': 18,
  'c': 19,
  's': 20,
  'f': 21,
  'z': 22,
  'v': 23,
  'x': 24,
  'h': 25,
  'q': 26,
  '|': 6,
  '[UNK]': 28,
  '[PAD]': 29}}

In [10]:
import json
with open("vocab.json", "w") as vocab_file:
    json.dump(new_vocab_dict, vocab_file)

In [11]:
final_train_dataset = train_ds.map(
    prepare_dataset,
    batch_size=4,
    batched=True,
)

In [12]:
final_val_dataset = valid_ds.map(
    prepare_dataset,
    batch_size=4,
    batched=True,
)

In [13]:
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|", target_lang=language)
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)

processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

In [14]:
wer_metric = evaluate.load('wer')

Downloading builder script:   0%|          | 0.00/4.49k [00:00<?, ?B/s]

In [15]:
data_collator = dcwp(processor=processor, padding=True)

In [None]:
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/mms-1b-all",
  **config["Wav2Vec2ForCTC_args"],
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer)
)

In [17]:
model.gradient_checkpointing_enable()
model.init_adapter_layers()
model.freeze_base_model()

adapter_weights = model._get_adapters()
for param in adapter_weights.values():
    param.requires_grad = True

In [18]:
training_args = TrainingArguments(
  **config["training_args"],
    report_to="none"
)

In [19]:
# Initialize trainer with the custom callback
trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=final_train_dataset,
    eval_dataset=final_val_dataset,
    tokenizer=processor.feature_extractor,
    callbacks=[MlflowExtendedLoggingCallback()]
)


__Set up MLflow with access to Sunbird Server__

In [22]:
service_account = "mlflow-server@sb-gcp-project-01.iam.gserviceaccount.com"
json_key_name = "./sb-gcp-project-01-de9e848cf5a8.json" # replace with the path to your generated json key

!gcloud auth activate-service-account $service_account --key-file=$json_key_name

# Set the Google Cloud credentials, with storage access
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = json_key_name

# Set MLflow tracking credentials
MLFLOW_TRACKING_USERNAME = getpass('Enter the MLFLOW_TRACKING_USERNAME: ')
os.environ['MLFLOW_TRACKING_USERNAME'] = MLFLOW_TRACKING_USERNAME

MLFLOW_TRACKING_PASSWORD = getpass('Enter the MLFLOW_TRACKING_PASSWORD: ')
os.environ['MLFLOW_TRACKING_PASSWORD'] = MLFLOW_TRACKING_PASSWORD

# Set the MLflow tracking URI
mlflow.set_tracking_uri('https://mlflow-sunbird-ce0ecfc14244.herokuapp.com/')

Activated service account credentials for: [mlflow-server@sb-gcp-project-01.iam.gserviceaccount.com]


__[ALTERNATIVE]__

__Set up MLflow without access to Sunbird Server__

_Uncomment the code below that sets up mlflow locally and generates a link with ngrok._

In [None]:
# !pip install pyngrok --quiet

In [None]:
# from pyngrok import ngrok

# # run mlflow tracking UI in the background
# get_ipython().system_raw("mlflow ui --port 5000 &")

# # Terminate open tunnels if exist
# ngrok.kill()

# # Click this URL: https://dashboard.ngrok.com/get-started/your-authtoken to get your authtoken
# NGROK_AUTH_TOKEN = getpass('Enter the ngrok authtoken: ') # paste the token you copied
# ngrok.set_auth_token(NGROK_AUTH_TOKEN)

# # Open an HTTPs tunnel on port 5000 for http://localhost:5000
# ngrok_tunnel = ngrok.connect(addr="5000", proto="http", bind_tls=True)

# print("MLflow Tracking UI:", ngrok_tunnel.public_url)

# # older experiments can always be accessed via the generated ngrok public url

In [23]:
# provide an experiment name for mlflow
experiment_name = "stt"
try:
    mlflow.create_experiment(experiment_name)
    mlflow.set_experiment(experiment_name)
except Exception:
    mlflow.set_experiment(experiment_name)

In [None]:
with mlflow.start_run(run_name=f"leb-{language}-stt") as run:

    mlflow.set_tag("developer", "sharon") # replace with your name


    train_output = trainer.train()

    # evaluate the model to get the latest metrics including WER
    eval_metrics = trainer.evaluate()

    mlflow.log_params({
"training_args": config["training_args"],
        "Wav2Vec2ForCTC_args": config["Wav2Vec2ForCTC_args"],
        "pretrained_model":"facebook/mms-1b-all",
        "train_data": config["train"],
        "eval_data": config["validation"]
    })

    # Save and log the model
    trainer.save_model()

    artifact_path = "model_artifacts"
    mlflow.log_artifact(f"{experiment_name}/config.json", artifact_path)
    mlflow.log_artifact(f"{experiment_name}/preprocessor_config.json", artifact_path)
    mlflow.log_artifact(f"{experiment_name}/training_args.bin", artifact_path)
    mlflow.log_artifact(f"{experiment_name}/model.safetensors", artifact_path)
    mlflow.log_artifact("vocab.json", ".")




Step,Training Loss,Validation Loss,Wer
1000,0.8085,0.514606,0.475921


Logged Training Loss: 0.8085 at step: 1000
Logged Evaluation Metrics: {'eval_loss': 0.5146058201789856, 'eval_wer': 0.47592067988668557} at step: 1000


