In [None]:
# Install the Python packages via pip (-q: quiet, -U: upgrade accelerate).
!pip install -q jiwer
!pip install -q evaluate
!pip install -qU accelerate
!pip install -q transformers[torch]
# Clone the SALT repository and install the packages listed in its
# requirements file.
!git clone https://github.com/sunbirdai/salt.git
!pip install -qr salt/requirements.txt

In [None]:
# Switches selecting which experiment trackers are set up in this cell and
# later reported to by the Trainer.
use_wandb = False
use_mlflow = True

if use_wandb:
  # Install/import wandb only when enabled, configure it through environment
  # variables, then log in.
  !pip install -q wandb
  import wandb
  %set_env WANDB_LOG_MODEL=True
  %set_env WANDB_WATCH=all
  %set_env WANDB_NOTEBOOK_NAME=whisper_base_en_sb.ipynb
  wandb.login()

if use_mlflow:
  !pip install -q mlflow
  ## requirements to log system/GPU metrics in mlflow
  !pip install -q psutil
  !pip install -q pynvml
  import os
  from getpass import getpass
  import mlflow
  import mlflow.pytorch
  from mlflow import MlflowClient

  # Set MLflow tracking credentials
  # (prompted with getpass so they are not echoed or stored in the notebook,
  # then exported as environment variables)
  MLFLOW_TRACKING_USERNAME = getpass('Enter the MLFLOW_TRACKING_USERNAME: ')
  os.environ['MLFLOW_TRACKING_USERNAME'] = MLFLOW_TRACKING_USERNAME

  MLFLOW_TRACKING_PASSWORD = getpass('Enter the MLFLOW_TRACKING_PASSWORD: ')
  os.environ['MLFLOW_TRACKING_PASSWORD'] = MLFLOW_TRACKING_PASSWORD

  # Set the MLflow tracking URI
  mlflow.set_tracking_uri('https://mlflow-sunbird-ce0ecfc14244.herokuapp.com/')

In [None]:
from torch import nn
import torch
import transformers
from dataclasses import dataclass, field
from typing import Union, List, Dict, Any
import string
import os
import json
import datasets
import numpy as np
import yaml
import evaluate
import mlflow
import salt.dataset
import salt.metrics
from salt.utils import DataCollatorCTCWithPadding as dcwp
import huggingface_hub

In [None]:
# Log in to the Hugging Face Hub from inside the notebook; the training
# config in this notebook sets push_to_hub: True.
huggingface_hub.notebook_login()

In [None]:
# The following code prepares datasets with one local language plus English
# `language` is a SALT language code; it is interpolated into the YAML config
# (run name, output directory, dataset names and source/target languages).
language = 'lug'

In [None]:
# Experiment configuration, written as YAML and parameterized by `language`.
#
# Sections:
#   * top-level keys: pretrained checkpoint and MLflow experiment/run names
#   * training_args:  keyword arguments for Seq2SeqTrainingArguments
#   * train/validation: dataset specifications consumed by salt.dataset.create
#
# hub_model_id follows `language` (like output_dir) so that switching the
# language cannot push a differently-trained model into the Luganda repo.
# NOTE(review): the Common Voice (`lg`), FLEURS (`lg_ug`) and validation
# (`ucfd_lug`) entries are still Luganda-specific and must be edited by hand
# when `language` is changed.
# NOTE(review): the Wav2Vec2ForCTC_args section is not read by any cell visible
# in this notebook; it is kept unchanged in case other code relies on it.
yaml_config = f'''
pretrained_model: openai/whisper-medium
mlflow_experiment_name : stt-whisper-callcentre
mlflow_run_name: {language}_eng_from_pretrained

training_args:
    output_dir: ./whisper-medium-sb-{language}-eng
    per_device_train_batch_size: 12
    per_device_eval_batch_size: 4
    gradient_accumulation_steps: 1  # increase by 2x for every 2x decrease in batch size
    learning_rate: 1.0e-4
    warmup_steps: 100
    max_steps: 20000
    gradient_checkpointing: True
    gradient_checkpointing_kwargs:
      use_reentrant: True
    fp16: True
    eval_strategy: steps
    predict_with_generate: True
    generation_max_length: 100
    save_steps: 200
    eval_steps: 100
    logging_steps: 25
    load_best_model_at_end: True
    metric_for_best_model: loss
    greater_is_better: False
    push_to_hub: True
    hub_model_id: jq/whisper-medium-sb-{language}-eng
    save_total_limit: 2

Wav2Vec2ForCTC_args:
    attention_dropout: 0.0
    hidden_dropout: 0.0
    feat_proj_dropout: 0.0
    layerdrop: 0.0
    ctc_loss_reduction: mean
    ignore_mismatched_sizes: True

train:
    huggingface_load:
        - path: Sunbird/salt
          name: multispeaker-{language}
          split: train
        - path: Sunbird/salt
          name: multispeaker-eng
          split: train
        # Common Voice training data can optionally be added
        - path: mozilla-foundation/common_voice_13_0
          split: train
          name: lg
          trust_remote_code: True
        # And Google FLEURS
        - path: google/fleurs
          split: train
          name: lg_ug
          trust_remote_code: True
    source:
      type: speech
      language: [{language},eng]
      preprocessing:
        # Downsample some examples to 8KHz (to simulate phone audio)
        - set_sample_rate:
            rate: 8_000
            p: 0.5
        # Then upsample again
        - set_sample_rate:
            rate: 16_000
        - augment_audio_noise:
            max_relative_amplitude: 0.5
    target:
      type: text
      language: [{language},eng]
    shuffle: True
validation:
    huggingface_load:
        # Evaluate on call center data
        - path: Sunbird/salt-practical-eval
          name: ucfd_eng
          split: test
        - path: Sunbird/salt-practical-eval
          name: ucfd_lug
          split: test
    source:
      type: speech
      language: [{language},eng]
      preprocessing:
        - set_sample_rate:
            rate: 16_000
    target:
      type: text
      language: [{language},eng]
'''

# Parse the YAML and build the training / validation datasets from their
# respective sections.
config = yaml.safe_load(yaml_config)
train_ds = salt.dataset.create(config['train'])
valid_ds = salt.dataset.create(config['validation'])

In [None]:
# Preview the training dataset with the SALT helper (N=10, with 'source'
# passed as the audio feature). NOTE(review): the exact rendering is defined
# in salt.utils — confirm there.
salt.utils.show_dataset(train_ds, audio_features=['source'], N=10)

In [None]:
# Load the feature extractor, processor and model from the same pretrained
# checkpoint named in the config.
pretrained_name = config['pretrained_model']

feature_extractor = transformers.WhisperFeatureExtractor.from_pretrained(pretrained_name)
processor = transformers.WhisperProcessor.from_pretrained(
    pretrained_name,
    task="transcribe",
    language=None,  # no fixed language: language tokens are added per example
)
model = transformers.WhisperForConditionalGeneration.from_pretrained(pretrained_name)

In [None]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate speech-to-text examples into one padded training batch.

    Audio features and label token ids are padded separately, because they
    have different lengths and use different padding methods.

    Attributes:
        processor: Object exposing ``feature_extractor.pad`` and
            ``tokenizer.pad``.
        decoder_start_token_id: Token id that, when it leads every label
            sequence in the batch, is stripped from the labels.
    """

    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad ``features`` and return the batch with a ``labels`` entry.

        Args:
            features: Examples, each holding ``input_features`` and ``labels``.

        Returns:
            The padded audio batch, extended with ``labels`` in which padded
            positions are set to -100.
        """
        # Audio side: hand the features to the feature extractor for padding.
        audio_items = [{"input_features": item["input_features"]} for item in features]
        batch = self.processor.feature_extractor.pad(audio_items, return_tensors="pt")

        # Text side: pad the token id sequences to the longest in the batch.
        label_items = [{"input_ids": item["labels"]} for item in features]
        padded = self.processor.tokenizer.pad(label_items, return_tensors="pt")

        # Padded positions become -100 so that the loss ignores them.
        is_padding = padded.attention_mask.ne(1)
        labels = padded["input_ids"].masked_fill(is_padding, -100)

        # Drop the leading start token when every sequence has it; it is
        # appended again later anyway.
        every_row_starts_with_bos = bool((labels[:, 0] == self.decoder_start_token_id).all())
        if every_row_starts_with_bos:
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

# Collator instance used by the trainer; the start-token id is taken from
# the loaded model's config.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    decoder_start_token_id=model.config.decoder_start_token_id,
    processor=processor,
)

In [None]:
# Optional sanity check: set to True to inspect one validation example and
# verify that its transcript survives a tokenizer encode/decode round trip.
view_example = False

if view_example:
  # Take the first example of the validation set.
  example = next(iter(valid_ds))

  # A bare `example.keys()` expression displays nothing inside an `if` block,
  # so print the keys explicitly.
  print(f"Example keys:          {list(example.keys())}")

  input_str = example["target"]
  labels = processor.tokenizer(input_str).input_ids
  decoded_with_special = processor.tokenizer.decode(labels, skip_special_tokens=False)
  decoded_str = processor.tokenizer.decode(labels, skip_special_tokens=True)

  print(f"Input:                 {input_str}")
  print(f"Decoded w/ special:    {decoded_with_special}")
  print(f"Decoded w/out special: {decoded_str}")
  print(f"Are equal:             {input_str == decoded_str}")

In [None]:
# Mapping from SALT languages to Whisper language tokens
# Keys are SALT language codes; values are tokenizer token ids.
# prepare_dataset inserts the id for each example's target language at
# index 1 of its label sequence.
# NOTE(review): presumably the ids for the local languages reuse existing
# Whisper language-token slots — confirm against the tokenizer vocabulary.
language_id_tokens = {
    'eng': 50259,
    'ach': 50357,
    'lgg': 50356,
    'lug': 50355,
    'nyn': 50354,
    'teo': 50353,
}

def prepare_dataset(example):
    """Turn one raw SALT example into model-ready features and labels.

    Args:
        example: Mapping with ``source`` (audio), ``target`` (text),
            ``source.language`` and ``target.language`` entries.

    Returns:
        A dict with ``input_features``, ``labels`` (token ids with the target
        language's id token inserted at index 1) and the two language fields
        copied through.
    """
    # Log-Mel input features computed from the audio under 'source'.
    mel_features = feature_extractor(
        example["source"], sampling_rate=16000).input_features[0]

    # Target text encoded to label ids.
    token_ids = processor.tokenizer(example["target"]).input_ids

    # The language ID token goes into the second position of the sequence.
    target_language = example["target.language"]
    token_ids.insert(1, language_id_tokens[target_language])

    return {
        "input_features": mel_features,
        "labels": token_ids,
        "source.language": example["source.language"],
        "target.language": target_language,
    }

In [None]:
# Run prepare_dataset over both splits, dropping the raw "source"/"target"
# columns from the results.
train_data, val_data = (
    split.map(prepare_dataset, remove_columns=["source", "target"])
    for split in (train_ds, valid_ds)
)

In [None]:
# Build the metrics callback handed to the trainer as `compute_metrics`.
# WER and CER are loaded through `evaluate`; the validation dataset, the
# tokenizer and the processor are passed to the SALT helper.
# NOTE(review): how the helper uses them (e.g. per-language breakdown, what
# log_first_N_predictions=5 logs) is defined in salt.metrics — confirm there.
compute_metrics = salt.metrics.multilingual_eval_fn(
      valid_ds, [evaluate.load('wer'), evaluate.load('cer')],
      processor.tokenizer, log_first_N_predictions=5,
      speech_processor=processor)

In [None]:
# Clear the checkpoint's forced decoder ids and empty its suppressed-token
# list. NOTE(review): presumably so generation is not forced to a fixed
# language/task prefix, since prepare_dataset puts the language token into
# the labels — confirm against the transformers generation docs.
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

In [None]:
# Trackers the Trainer should report to, chosen by the use_wandb / use_mlflow
# switches.
report_to = []

# Work on a copy so the parsed YAML config itself is left unmodified.
training_kwargs = dict(config["training_args"])

if use_wandb:
    report_to.append('wandb')
if use_mlflow:
    report_to.append('mlflow')
    # Make sure the configured experiment exists, then select it.
    experiment_name = config['mlflow_experiment_name']
    if not mlflow.get_experiment_by_name(experiment_name):
      mlflow.create_experiment(experiment_name)
    mlflow.set_experiment(experiment_name)
    # The config defines mlflow_run_name, but it was never passed on, so the
    # run got the Trainer's default name. Forward it as run_name unless the
    # training_args section already sets one explicitly.
    if config.get('mlflow_run_name'):
        training_kwargs.setdefault('run_name', config['mlflow_run_name'])

training_args = transformers.Seq2SeqTrainingArguments(
  **training_kwargs,
  report_to=report_to
)

In [None]:
# Assemble the trainer from the model, the prepared datasets, the padding
# collator and the metrics callback.
trainer = transformers.Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_data,
    eval_dataset=val_data,
    compute_metrics=compute_metrics,
    # The feature extractor is what gets passed in the `tokenizer` slot here.
    tokenizer=processor.feature_extractor,
)

In [None]:
# Run fine-tuning with the settings from config['training_args'].
trainer.train()

In [None]:
# Publish the fine-tuned model together with its processor.
# `tokenizer` was never defined in this notebook (NameError); the tokenizer
# lives on `processor`, and pushing the processor uploads it along with the
# feature extractor.
# The repo name follows `language`, so it stays in step with the data the
# model was trained on ('lug' gives the original "whisper-medium-lug-eng").
hub_repo_name = f"whisper-medium-{language}-eng"
model.push_to_hub(hub_repo_name)
processor.push_to_hub(hub_repo_name)

Try running the model on the first test example

In [None]:
# Transcribe the first validation example with the fine-tuned model.
example = next(iter(valid_ds))
input_features = processor(
    example["source"], sampling_rate=16000, return_tensors="pt").input_features
with torch.no_grad():
    # Move the inputs to whatever device the model is on, rather than
    # hard-coding "cuda", so the cell also runs on CPU-only runtimes.
    predicted_ids = model.generate(input_features.to(model.device))[0]
# Decoded without skipping special tokens, so they appear in the output.
transcription = processor.decode(predicted_ids)
print(transcription)