In [None]:
!pip install -q jiwer
!pip install -q evaluate
!pip install -qU accelerate
!pip install -q transformers[torch]
!git clone https://github.com/sunbirdai/salt.git
!pip install -qr salt/requirements.txt

In [None]:
use_wandb = False
use_mlflow = True

if use_wandb:
  !pip install -q wandb
  import wandb
  %set_env WANDB_LOG_MODEL=True
  %set_env WANDB_WATCH=all
  %set_env WANDB_NOTEBOOK_NAME=whisper_base_en_sb.ipynb
  wandb.login()

if use_mlflow:
  !pip install -q mlflow
  ## requirements to log system/GPU metrics in mlflow
  !pip install -q psutil
  !pip install -q pynvml
  import os
  from getpass import getpass
  import mlflow
  import mlflow.pytorch
  from mlflow import MlflowClient

  # Set MLflow tracking credentials
  MLFLOW_TRACKING_USERNAME = getpass('Enter the MLFLOW_TRACKING_USERNAME: ')
  os.environ['MLFLOW_TRACKING_USERNAME'] = MLFLOW_TRACKING_USERNAME

  MLFLOW_TRACKING_PASSWORD = getpass('Enter the MLFLOW_TRACKING_PASSWORD: ')
  os.environ['MLFLOW_TRACKING_PASSWORD'] = MLFLOW_TRACKING_PASSWORD

  # Set the MLflow tracking URI
  mlflow.set_tracking_uri('https://mlflow-sunbird-ce0ecfc14244.herokuapp.com/')
  mlflow.system_metrics.enable_system_metrics_logging()

In [None]:
import torch
import transformers
from dataclasses import dataclass, field
from typing import Union, List, Dict, Any
import string
import os
import json
import datasets
import numpy as np
import yaml
import evaluate
import salt.dataset
import salt.metrics
from salt.utils import DataCollatorCTCWithPadding as dcwp
import huggingface_hub
import peft

In [None]:
huggingface_hub.notebook_login()

In [18]:

yaml_config = f'''
pretrained_model: openai/whisper-large-v2
mlflow_experiment_name : stt-whisper-lug-eng
mlflow_run_name: whisper-large-lug-eng

use_peft: True
lora_config:
    r: 32
    lora_alpha: 64
    target_modules: ["q_proj", "v_proj"]
    lora_dropout: 0.05
    bias: "none"

training_args:
    output_dir: ./artifacts
    per_device_train_batch_size: 16
    per_device_eval_batch_size: 8
    gradient_accumulation_steps: 16  # increase by 2x for every 2x decrease in batch size
    learning_rate: 1.0e-3
    warmup_steps: 50
    max_steps: 2000
    gradient_checkpointing: True
    gradient_checkpointing_kwargs:
      use_reentrant: True
    fp16: True
    eval_strategy: steps
    predict_with_generate: True
    generation_max_length: 100
    save_steps: 50
    eval_steps: 50
    logging_steps: 50
    load_best_model_at_end: True
    metric_for_best_model: loss
    greater_is_better: False
    push_to_hub: True
    hub_model_id: whisper-large-multilingual-adapter
    save_total_limit: 3

train:
    huggingface_load:
        # Call centre data
        - path: Sunbird/salt-ucfd
          name: ucfd_eng
          split: train
        - path: Sunbird/salt-ucfd
          name: ucfd_lug
          split: train    
        # Main SALT ASR training data
        - path: Sunbird/salt
          name: multispeaker-lug
          split: train
        - path: Sunbird/salt
          name: multispeaker-eng
          split: train
        - path: Sunbird/salt
          name: multispeaker-ach
          split: train
        - path: Sunbird/salt
          name: multispeaker-lgg
          split: train
        - path: Sunbird/salt
          name: multispeaker-teo
          split: train
        - path: Sunbird/salt
          name: multispeaker-nyn
          split: train
        # Common Voice
        - path: mozilla-foundation/common_voice_13_0
          split: train
          name: lg
          trust_remote_code: True
        # Google FLEURS
        - path: google/fleurs
          split: train
          name: lg_ug
          trust_remote_code: True
    source:
      type: speech
      language: [lug,eng,ach,lgg,teo,nyn]
      preprocessing:
        # Downsample some examples to 8KHz (to simulate phone audio) 
        - set_sample_rate:
            rate: 8_000
            p: 0.2
        # Then upsample again
        - set_sample_rate:
            rate: 16_000
        - augment_audio_noise:
            max_relative_amplitude: 0.5
    target:
      type: text
      language: [lug,eng,ach,lgg,teo,nyn]
    shuffle: True
validation:
    huggingface_load:
        - path: Sunbird/salt
          name: multispeaker-eng
          split: dev
        - path: Sunbird/salt
          name: multispeaker-lug
          split: dev
        - path: Sunbird/salt
          name: multispeaker-ach
          split: dev
        - path: Sunbird/salt
          name: multispeaker-lgg
          split: dev
        - path: Sunbird/salt
          name: multispeaker-teo
          split: dev
        - path: Sunbird/salt
          name: multispeaker-nyn
          split: dev
    source:
      type: speech
      language: [lug,eng,ach,lgg,teo,nyn]
      preprocessing:
        - set_sample_rate:
            rate: 16_000
    target:
      type: text
      language: [lug,eng,ach,lgg,teo,nyn]
'''

config = yaml.safe_load(yaml_config)
train_ds = salt.dataset.create(config['train'])
valid_ds = salt.dataset.create(config['validation'])

In [None]:
salt.utils.show_dataset(train_ds, audio_features=['source'], N=5)

In [None]:
feature_extractor = transformers.WhisperFeatureExtractor.from_pretrained(
    config['pretrained_model'])
processor = transformers.WhisperProcessor.from_pretrained(
    config['pretrained_model'], language=None, task="transcribe")
model = transformers.WhisperForConditionalGeneration.from_pretrained(
    config['pretrained_model'])

In [None]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]    
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor, decoder_start_token_id=model.config.decoder_start_token_id)

In [None]:
# Mapping from SALT languages to Whisper language tokens
language_id_tokens = {
    'eng': 50259,
    'ach': 50357,
    'lgg': 50356,
    'lug': 50355,
    'nyn': 50354,
    'teo': 50353,
}

def prepare_dataset(example):
    # Extract the audio data from the 'source' key
    audio = example["source"]

    # Compute log-Mel input features from the audio array
    input_features = feature_extractor(
        audio, sampling_rate=16000, device='cuda').input_features[0]

    # Encode target text to label ids
    labels = processor.tokenizer(example["target"]).input_ids

    # Insert the language ID token into the second position of the sequence.
    labels.insert(1, language_id_tokens[example["target.language"]])

    # Create a new dictionary with the processed data
    processed_example = {
        "input_features": input_features,
        "labels": labels,
        "source.language": example["source.language"],
        "target.language": example["target.language"]
    }

    return processed_example

In [None]:
train_data = train_ds.map(prepare_dataset, remove_columns=["source", "target"])
val_data = valid_ds.map(prepare_dataset, remove_columns=["source", "target"])

In [None]:
compute_metrics = salt.metrics.multilingual_eval_fn(
      valid_ds, [evaluate.load('wer'), evaluate.load('cer')],
      processor.tokenizer, log_first_N_predictions=5,
      speech_processor=processor)

In [None]:
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []


if config['use_peft']:
    model = peft.prepare_model_for_kbit_training(model)
    lora_config = peft.LoraConfig(**config['lora_config'])
    model.enable_input_require_grads()
    model = peft.get_peft_model(model, lora_config)
    model.config.use_cache = False
    model.print_trainable_parameters()

Launch the training

In [19]:
training_args = transformers.Seq2SeqTrainingArguments(
  **config["training_args"],
  report_to= [
      platform for platform, use in [("wandb", use_wandb), ("mlflow", use_mlflow)] if use]
)

trainer = transformers.Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

trainer.train()

max_steps is given, it will override any value given in num_train_epochs
Too many dataloader workers: 4 (max is dataset.n_shards=1). Stopping 3 dataloader workers.


RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 32, in fetch
    data.append(next(self.dataset_iter))
  File "/opt/conda/lib/python3.10/site-packages/datasets/iterable_dataset.py", line 2011, in __iter__
    yield from self._iter_pytorch()
  File "/opt/conda/lib/python3.10/site-packages/datasets/iterable_dataset.py", line 1936, in _iter_pytorch
    for key, example in ex_iterable:
  File "/opt/conda/lib/python3.10/site-packages/datasets/iterable_dataset.py", line 954, in __iter__
    yield from self._iter()
  File "/opt/conda/lib/python3.10/site-packages/datasets/iterable_dataset.py", line 1045, in _iter
    transformed_example.update(self.function(*function_args, **self.fn_kwargs))
  File "/tmp/ipykernel_10238/1754533562.py", line 16, in prepare_dataset
    input_features = feature_extractor(
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/whisper/feature_extraction_whisper.py", line 306, in __call__
    input_features = extract_fbank_features(input_features[0], device)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/whisper/feature_extraction_whisper.py", line 136, in _torch_extract_fbank_features
    waveform = waveform.to(device)
  File "/opt/conda/lib/python3.10/site-packages/torch/cuda/__init__.py", line 288, in _lazy_init
    raise RuntimeError(
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method


Log the config settings for reference

In [None]:
if use_mlflow:
    last_run_id = mlflow.last_active_run().info.run_id
    with mlflow.start_run(run_id=last_run_id):
        mlflow.log_params(config)

Save the full model (not just the adapter weights)

In [None]:
model = model.merge_and_unload()
model.push_to_hub(config['training_args']['hub_model_id'] + '-merged')
processor.push_to_hub(config['training_args']['hub_model_id'] + '-merged')

Try running the model on the first test example

In [None]:
example = next(iter(valid_ds))
input_features = processor(example["source"], sampling_rate=16000, return_tensors="pt").input_features
with torch.no_grad():
    predicted_ids = model.generate(input_features.to("cuda"))[0]
transcription = processor.decode(predicted_ids)
print(transcription)