In [1]:
%%capture
# Install the notebook's dependencies. The `%%capture` cell magic above
# silences the installer output for this whole cell.
!pip install -q jiwer
!pip install -q evaluate
!pip install -qU accelerate
!pip install -q transformers[torch]
# Temporarily use the jqug/leb repository, which carries the multilingual
# metric changes, and install its requirements.
# NOTE(review): the clone presumably fails if a `leb/` directory already exists
# from a previous run of this cell — confirm before re-running.
!git clone https://github.com/jqug/leb.git
!pip install -qr leb/requirements.txt
!pip install -q mlflow

## Requirements for logging system metrics in MLflow.
!pip install psutil
!pip install pynvml # useful if interested in logging GPU metrics too.

In [None]:
from torch import nn
import torch
from transformers import (
    AutoFeatureExtractor,
    AutoModelForCTC,
    AutoProcessor,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    is_apex_available,
    set_seed,
)
from dataclasses import dataclass, field
from typing import Union, List, Dict
import string
import os
import json
import datasets
import numpy as np
import yaml
import evaluate
import mlflow
from getpass import getpass
import leb.dataset
import leb.metrics
from leb.utils import DataCollatorCTCWithPadding as dcwp
from leb.utils import MlflowExtendedLoggingCallback
import mlflow.pytorch
from mlflow import MlflowClient
import huggingface_hub
from transformers.models.wav2vec2.modeling_wav2vec2 import WAV2VEC2_ADAPTER_SAFE_FILE
from safetensors.torch import save_file as safe_save_file

In [None]:
# Show the Hugging Face Hub login widget for this notebook session.
huggingface_hub.notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [5]:
# Parse the run configuration from the YAML file.
with open('asr_config.yml', 'r') as config_file:
    config = yaml.safe_load(config_file)

# Build the training and validation datasets from their config sections.
# Edit asr_config.yml before running this cell if different settings are
# required.
train_ds = leb.dataset.create(config['train'])
valid_ds = leb.dataset.create(config['validation'])

In [6]:
# Display the training dataset object to confirm it was created.
train_ds

IterableDataset({
    features: Unknown,
    n_shards: 1
})

In [None]:
# Preview the training data. Presumably this shows N=5 examples and treats the
# 'source' field as audio — see leb.utils.show_dataset to confirm.
leb.utils.show_dataset(train_ds, audio_features=['source'], N=5)

In [None]:
if config.get('pretrained_adapter'):
  # If fine-tuning from an existing adapter, we have to use the matching
  # vocabulary, so load the pretrained model's tokenizer and select the
  # adapter's language on it.
  tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(config['pretrained_model'])
  tokenizer.set_target_lang(config['pretrained_adapter'])
else:
  # Otherwise, create a new vocabulary: the letters a-z, the apostrophe, the
  # word delimiter '|' and the [UNK]/[PAD] special tokens.
  # NOTE(review): digits are not part of this vocabulary; confirm that the
  # preprocessing removes them (presumably they would otherwise be treated as
  # unknown tokens).
  languages = config['train']['source']['language']
  if isinstance(languages, str):
    # A bare string is ONE language code. Joining it directly would insert a
    # hyphen between every character ('lug' -> 'l-u-g').
    languages = [languages]
  # The vocabulary is keyed by the hyphen-joined language codes, and the same
  # key is passed to the tokenizer as its target language below.
  language = '-'.join(languages)
  vocab = list(string.ascii_lowercase)
  vocab += ['[UNK]', '[PAD]', '|', "'"]
  vocab_dict = {
      language: {v: i for i, v in enumerate(vocab)}
  }
  # Write vocab.json into the working directory so the tokenizer can be loaded
  # from "./".
  with open("vocab.json", "w") as vocab_file:
      json.dump(vocab_dict, vocab_file)
  tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
      "./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|",
      target_lang=language)

tokenizer_config.json:   0%|          | 0.00/397 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.34M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/96.0 [00:00<?, ?B/s]

In [None]:
# Feature extractor configured for a 16 kHz sampling rate, with input
# normalisation and attention masks switched on.
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True,
)

# Bundle the feature extractor and the tokenizer into a single processor.
processor = Wav2Vec2Processor(
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
)

# Batch collator built on that processor, with padding enabled.
data_collator = dcwp(padding=True, processor=processor)

In [None]:
def prepare_dataset(batch):
    """Add model inputs and label ids to a batch of examples.

    Runs the module-level `processor` over the batch's "source" entries
    (declared as sampled at 16 kHz) and over its "target" text, storing the
    results under "input_values" and "labels" respectively.

    Args:
        batch: Mapping with "source" and "target" entries. It is modified in
            place.

    Returns:
        The same `batch` object with "input_values" and "labels" added.
    """
    audio_features = processor(batch["source"], sampling_rate=16000)
    text_features = processor(text=batch["target"])

    batch["input_values"] = audio_features.input_values
    batch["labels"] = text_features.input_ids
    return batch

# Run prepare_dataset over both splits in batched mode, four examples per call.
map_kwargs = {"batched": True, "batch_size": 4}

final_train_dataset = train_ds.map(prepare_dataset, **map_kwargs)
final_val_dataset = valid_ds.map(prepare_dataset, **map_kwargs)

In [None]:
# Load the pretrained checkpoint named in the config. The padding token id and
# vocabulary size are taken from the tokenizer built earlier; any further
# model arguments come from the "Wav2Vec2ForCTC_args" config section.
ctc_tokenizer = processor.tokenizer
model = Wav2Vec2ForCTC.from_pretrained(
    config['pretrained_model'],
    pad_token_id=ctc_tokenizer.pad_token_id,
    vocab_size=len(ctc_tokenizer),
    **config["Wav2Vec2ForCTC_args"],
)

In [None]:
model.gradient_checkpointing_enable()

# Either continue from the configured pretrained adapter or start from newly
# initialised adapter layers.
pretrained_adapter = config.get('pretrained_adapter')
if pretrained_adapter:
    model.load_adapter(pretrained_adapter)
else:
    model.init_adapter_layers()

# Freeze the base model, then switch gradients back on for the adapter
# weights only.
model.freeze_base_model()
adapter_weights = model._get_adapters()
for adapter_param in adapter_weights.values():
    adapter_param.requires_grad = True

In [None]:
# Evaluation function for the Trainer, built over the validation set with the
# word error rate ('wer') as its only metric.
wer_metric = evaluate.load('wer')
compute_metrics = leb.metrics.multilingual_eval_fn(
    valid_ds,
    [wer_metric],
    tokenizer,
    speech_processor=processor,
    log_first_N_predictions=2,
)

In [None]:
# Training arguments come from the config file. `report_to` is fixed to "none"
# here, and MLflow logging is attached through the explicit callback below.
training_args = TrainingArguments(report_to="none", **config["training_args"])

mlflow_callback = MlflowExtendedLoggingCallback()
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=final_train_dataset,
    eval_dataset=final_val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
    callbacks=[mlflow_callback],
)

__Connect to the Sunbird MLflow Tracking Server__

In [None]:
# Path to the Google Cloud service-account key file. This is a placeholder:
# replace it with the real path before running the cell.
json_key_name = "path-to-your-serviceaccount.json"

# Authenticate the gcloud CLI with that key file. IPython substitutes the
# Python variable `json_key_name` into the shell command.
!gcloud auth activate-service-account --key-file=$json_key_name

# Set the Google Cloud credentials, with storage access
GOOGLE_APPLICATION_CREDENTIALS = json_key_name
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = GOOGLE_APPLICATION_CREDENTIALS

# Set MLflow tracking credentials. They are prompted for with getpass rather
# than hard-coded, then exported as environment variables.
MLFLOW_TRACKING_USERNAME = getpass('Enter the MLFLOW_TRACKING_USERNAME: ') # enter your provided username
os.environ['MLFLOW_TRACKING_USERNAME'] = MLFLOW_TRACKING_USERNAME

MLFLOW_TRACKING_PASSWORD = getpass('Enter the MLFLOW_TRACKING_PASSWORD: ') # enter your provided password
os.environ['MLFLOW_TRACKING_PASSWORD'] = MLFLOW_TRACKING_PASSWORD

# Set the MLflow tracking URI
mlflow.set_tracking_uri('https://mlflow-sunbird-ce0ecfc14244.herokuapp.com/')

In [None]:
def print_auto_logged_info(r, artifact_path="model"):
    """Print a summary of an MLflow run.

    Prints the run id, the artifact paths found under `artifact_path`, the
    logged params and metrics, and every tag whose key does not start with
    "mlflow.".

    Args:
        r: Run object exposing `info.run_id`, `data.tags`, `data.params` and
            `data.metrics`.
        artifact_path: Artifact directory of the run to list. Defaults to
            "model". Pass e.g. "model_artifacts" for runs that log their
            files under a different directory.
    """
    run_id = r.info.run_id
    # Keep user-defined tags only; keys prefixed with "mlflow." are skipped.
    tags = {k: v for k, v in r.data.tags.items() if not k.startswith("mlflow.")}
    artifacts = [f.path for f in MlflowClient().list_artifacts(run_id, artifact_path)]
    print(f"run_id: {run_id}")
    print(f"artifacts: {artifacts}")
    print(f"params: {r.data.params}")
    print(f"metrics: {r.data.metrics}")
    print(f"tags: {tags}")

In [None]:
# provide an experiment name.
# provide an experiment name.
experiment_name = "stt-multilingual"
# set_experiment activates the named experiment and creates it first when it
# does not exist yet, so no separate create_experiment call is needed. Real
# failures (bad credentials, unreachable server) surface here directly instead
# of being swallowed by a blanket `except Exception`.
mlflow.set_experiment(experiment_name)

In [None]:
with mlflow.start_run(run_name="lug-from-pretrained", log_system_metrics=True) as run:

    # Tag the run with the MLflow username set in the environment.
    mlflow.set_tag("developer", os.environ['MLFLOW_TRACKING_USERNAME'])

    mlflow.log_params(config)

    train_output = trainer.train()

    # evaluate the model to get the latest metrics including WER
    eval_metrics = trainer.evaluate()

    # Save the model files into training_args.output_dir.
    trainer.save_model()

    # Save the adapter weights on their own, next to the full model files.
    output_dir = training_args.output_dir
    adapter_file = WAV2VEC2_ADAPTER_SAFE_FILE.format('lug')
    adapter_file = os.path.join(output_dir, adapter_file)
    safe_save_file(model._get_adapters(), adapter_file, metadata={"format": "pt"})

    # Log the saved files from the directory they were written to
    # (training_args.output_dir). The experiment name is not a directory and
    # only worked as one when it happened to equal output_dir.
    artifact_path = "model_artifacts"
    for file_name in ("config.json", "preprocessor_config.json", "training_args.bin"):
        mlflow.log_artifact(os.path.join(output_dir, file_name), artifact_path)
    mlflow.log_artifact(adapter_file, artifact_path)
    # Logging the adapter should be sufficient, as the rest is the same as the
    # base model, so model.safetensors is deliberately not logged.

    # vocab.json only exists in the working directory when the vocabulary was
    # created in this notebook. When a pretrained adapter's tokenizer is used
    # there may be no local file, and logging it would fail the run at the very
    # end of training.
    if os.path.exists("vocab.json"):
        mlflow.log_artifact("vocab.json", ".")
    else:
        print("vocab.json not found in the working directory; skipping vocabulary upload.")
