<a href="https://colab.research.google.com/github/SunbirdAI/salt/blob/main/notebooks/ASR_correction_training_FLAN_T5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
# Environment setup. The %%capture magic above hides all installer output
# from this cell.

# Hugging Face libraries, upgraded to the newest available release (-U).
!pip install -U transformers
!pip install -U datasets
!pip install accelerate
!pip install sentencepiece
!pip install sacremoses
# MLflow is imported below for experiment tracking (-q keeps pip quiet).
!pip install -q mlflow
# NOTE(review): psutil/pynvml are presumably needed for the
# log_system_metrics=True option used when the MLflow run is started - confirm.
!pip install psutil
!pip install pynvml

# Fetch the SALT repository into ./salt and install its requirements; the
# `salt.*` imports in the next cell rely on this checkout being present.
!git clone https://github.com/sunbirdai/salt.git
!pip install -r salt/requirements.txt

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import transformers
import datasets
import evaluate
import tqdm
import salt.dataset
import salt.utils
import salt.metrics
import yaml
from IPython import display
import getpass
import mlflow

In [None]:
# Prompt for the MLflow credentials (getpass does not echo the typed value)
# and store each one both in a notebook variable and in the environment
# variable of the same name.
MLFLOW_TRACKING_USERNAME = os.environ['MLFLOW_TRACKING_USERNAME'] = getpass.getpass(
    'Enter the MLFLOW_TRACKING_USERNAME: ')  # enter your provided username
MLFLOW_TRACKING_PASSWORD = os.environ['MLFLOW_TRACKING_PASSWORD'] = getpass.getpass(
    'Enter the MLFLOW_TRACKING_PASSWORD: ')  # enter your provided password

# Point the MLflow client at the remote tracking server.
mlflow.set_tracking_uri('https://mlflow-sunbird-ce0ecfc14244.herokuapp.com/')

In [None]:
# Show the GPU status table when PyTorch can see a CUDA device; on a
# CPU-only runtime the shell command is skipped.
if torch.cuda.is_available():
  !nvidia-smi

In [None]:
# Directory for output files; it is substituted into both `output_dir` and
# `logging_dir` of the training arguments below.
drive_folder = "./artifacts"

# Create the directory with the standard library instead of the `%mkdir`
# magic: this does not depend on IPython aliases, also creates missing parent
# directories, and is a no-op when the folder already exists.
os.makedirs(drive_folder, exist_ok=True)

effective_train_batch_size = 3000
train_batch_size = 15
eval_batch_size = train_batch_size

# Number of per-device batches accumulated to reach the effective batch size
# (integer division, so any remainder is dropped).
gradient_accumulation_steps = effective_train_batch_size // train_batch_size

# Everything in one yaml string, so that it can all be logged to MLFlow.
# The {placeholders} are filled in by str.format() further down.
# NOTE(review): recent transformers releases rename `evaluation_strategy` to
# `eval_strategy`; confirm the installed version still accepts the key below.
yaml_config = '''
training_args:
  output_dir: "{drive_folder}"
  evaluation_strategy: steps
  eval_steps: 10
  save_steps: 40
  warmup_steps: 10
  num_train_epochs: 3
  gradient_accumulation_steps: {gradient_accumulation_steps}
  learning_rate: 3.0e-4  # Include decimal point to parse as float
  per_device_train_batch_size: {train_batch_size}
  per_device_eval_batch_size: {eval_batch_size}
  weight_decay: 0.01
  save_total_limit: 3
  predict_with_generate: True
  fp16: False
  logging_dir: "{drive_folder}"
  load_best_model_at_end: True
  metric_for_best_model: loss
  seed: 42
  hub_model_id: asr-correction-flan-t5
  push_to_hub: True

mlflow_run_name: correction-with-ambiguity
mlflow_experiment_name : asr-correction

max_input_length: 224
max_output_length: 224
eval_pretrained_model: False
early_stopping_patience: 4
data_dir: .
model_checkpoint: google/flan-t5-base
'''

yaml_config = yaml_config.format(
    drive_folder=drive_folder,
    train_batch_size=train_batch_size,
    eval_batch_size=eval_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
)

# Parse the filled-in YAML into a plain dict; the `training_args` section is
# passed as keyword arguments to Seq2SeqTrainingArguments.
config = yaml.safe_load(yaml_config)

training_settings = transformers.Seq2SeqTrainingArguments(
    **config["training_args"])

In [None]:
# Load the tokenizer and the seq2seq model from the same pretrained
# checkpoint named in the config.
tokenizer = transformers.T5Tokenizer.from_pretrained(config['model_checkpoint'])
# Fixed: the subscript was closed with ')' instead of ']', which made the
# whole cell a SyntaxError.
model = transformers.T5ForConditionalGeneration.from_pretrained(config['model_checkpoint'])

In [None]:
# Id written into label positions that are only padding (-100 is the value
# the Hugging Face docs describe as ignored by PyTorch loss functions).
label_pad_token_id = -100

# Collator that batches the tokenized examples for the seq2seq trainer.
collator_options = {'model': model, 'label_pad_token_id': label_pad_token_id}
data_collator = transformers.DataCollatorForSeq2Seq(tokenizer, **collator_options)

In [None]:
def preprocess(examples):
    """Tokenize a batch of source/target text pairs for seq2seq training.

    Args:
        examples: Batch mapping with 'source' and 'target' entries holding
            the input texts and the reference texts.

    Returns:
        The tokenizer output for the sources, with the target token ids
        stored under 'labels'.
    """
    model_inputs = tokenizer(
        examples['source'],
        max_length=config['max_input_length'],
        truncation=True)
    # Targets are tokenized in their own call so that they are truncated at
    # the configured output length rather than at the input length.
    target_encodings = tokenizer(
        text_target=examples['target'],
        max_length=config['max_output_length'],
        truncation=True)
    model_inputs['labels'] = target_encodings['input_ids']
    return model_inputs

# Training split: every row of the published 'train' split except the
# last 100, which are held out for evaluation below.
train_dataset = datasets.load_dataset(
    'jq/salt-asr-correction', split='train[:-100]')
# Shuffle with the seed already used for training so that the example order
# (and the preview printed below) is the same on every run.
train_dataset = train_dataset.shuffle(seed=config['training_args']['seed'])

# Evaluation split: the last 100 rows of the same 'train' split.
eval_dataset = datasets.load_dataset(
    'jq/salt-asr-correction', split='train[-100:]')

salt.utils.show_dataset(train_dataset, N=10)

# Tokenize both splits in batches. The raw text/language columns are removed
# from the training split only.
train_dataset = train_dataset.map(
    preprocess,
    batched=True,
    num_proc=6,
    remove_columns=['source', 'source.language', 'target', 'target.language'])
# NOTE(review): the eval split keeps its original columns, presumably because
# the metrics function below reads them - confirm against salt.metrics.
eval_dataset = eval_dataset.map(
    preprocess,
    batched=True)

# Metric callback for the trainer, built from the eval split, the CER metric
# and the tokenizer.
compute_metrics = salt.metrics.multilingual_eval_fn(
    eval_dataset, [evaluate.load('cer')],
    tokenizer, log_first_N_predictions=10)


In [None]:
# Start from the checkpoint's own generation settings, cap the number of
# generated tokens at the configured output length, and attach the result to
# the training arguments.
# Fixed: the subscript was closed with ')' instead of ']', which made the
# whole cell a SyntaxError.
gen_cfg = transformers.GenerationConfig.from_pretrained(config['model_checkpoint'])
gen_cfg.max_new_tokens = config['max_output_length']
training_settings.generation_config = gen_cfg

In [None]:
# Callbacks: the SALT MLflow logging callback, plus early stopping configured
# with the patience value from the config.
training_callbacks = [
    salt.utils.MlflowExtendedLoggingCallback(),
    transformers.EarlyStoppingCallback(
        early_stopping_patience=config['early_stopping_patience']),
]

# Assemble the trainer from the objects prepared in the earlier cells.
trainer = transformers.Seq2SeqTrainer(
    model=model,
    args=training_settings,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=training_callbacks,
)

# Select the MLflow experiment named in the config, creating it first when
# the tracking server does not know it yet.
experiment_name = config['mlflow_experiment_name']
existing_experiment = mlflow.get_experiment_by_name(experiment_name)
if not existing_experiment:
    mlflow.create_experiment(experiment_name)
mlflow.set_experiment(experiment_name)

# Train inside a single MLflow run: tag it with the logged-in user, record
# the whole config as run parameters, then start training.
run_options = {
    'run_name': config['mlflow_run_name'],
    'log_system_metrics': True,
}
with mlflow.start_run(**run_options) as run:
    mlflow.set_tag("developer", os.environ['MLFLOW_TRACKING_USERNAME'])
    mlflow.log_params(config)
    trainer.train()

In [None]:
# Upload the tokenizer first and then the model to the Hub repository named
# in the training arguments.
hub_repo_id = config['training_args']['hub_model_id']
for artifact in (tokenizer, model):
    artifact.push_to_hub(hub_repo_id)