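Tested on an RTX 6000 Ada 48GB (per-device batch size 2) and an H100 80GB (per-device batch size 16); scale `gradient_accumulation_steps` inversely with the batch size to keep the effective batch size constant.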

In [None]:
!pip install -q jiwer
!pip install -q evaluate
!pip install -qU accelerate
!pip install -Uq torch
!pip install -Uq torchvision
!pip install -q transformers[torch]
!pip install -q soundfile
!git clone https://github.com/jqug/salt.git
!pip install -qr salt/requirements.txt
!pip install -q peft

In [None]:
use_wandb = False
use_mlflow = True

import importlib.metadata
installed = [
    dist.metadata['Name']
    for dist in importlib.metadata.distributions()
]

if use_wandb:
  !pip install -q wandb
  import wandb
  %set_env WANDB_LOG_MODEL=True
  %set_env WANDB_WATCH=all
  %set_env WANDB_NOTEBOOK_NAME=whisper_base_en_sb.ipynb
  wandb.login()

if use_mlflow:
  if 'mlflow' not in installed:
      !pip install -q mlflow
      ## requirements to log system/GPU metrics in mlflow
  !pip install -q psutil
  !pip install -q pynvml
  import os
  from getpass import getpass
  import mlflow
  import mlflow.pytorch
  from mlflow import MlflowClient

  # Set MLflow tracking credentials
  MLFLOW_TRACKING_USERNAME = getpass('Enter the MLFLOW_TRACKING_USERNAME: ')
  os.environ['MLFLOW_TRACKING_USERNAME'] = MLFLOW_TRACKING_USERNAME

  MLFLOW_TRACKING_PASSWORD = getpass('Enter the MLFLOW_TRACKING_PASSWORD: ')
  os.environ['MLFLOW_TRACKING_PASSWORD'] = MLFLOW_TRACKING_PASSWORD

  # Set the MLflow tracking URI
  mlflow.set_tracking_uri('https://mlflow-sunbird-ce0ecfc14244.herokuapp.com/')
  mlflow.system_metrics.enable_system_metrics_logging()

In [None]:
import torch
import transformers
from dataclasses import dataclass, field
from typing import Union, List, Dict, Any
import string
import os
import json
import datasets
import numpy as np
import yaml
import evaluate
import salt.dataset
import salt.metrics
import salt.constants
from salt.utils import DataCollatorCTCWithPadding as dcwp
import huggingface_hub
import peft
import pandas as pd

In [None]:
# Authenticate with the Hugging Face Hub via an interactive widget. A token is
# needed later for `push_to_hub` (and presumably for any gated Sunbird datasets).
huggingface_hub.notebook_login()

In [None]:
# Development helper: re-import `salt.dataset` so edits to the cloned `salt`
# repository are picked up without restarting the kernel. Uncomment the two
# shell lines to fetch a fresh copy of the repository first.
# !rm -rf salt
# !git clone https://github.com/jqug/salt.git
from importlib import reload
reload(salt.dataset)

In [None]:
# Experiment configuration. This is a plain (non-f) string: nothing is
# interpolated into it, and an f-string would turn any YAML flow-mapping
# braces `{...}` added later into a formatting error.
yaml_config = '''
pretrained_model: openai/whisper-large-v3
mlflow_experiment_name : stt-whisper

use_peft: False
lora_config:
    r: 32
    lora_alpha: 64
    target_modules: ["q_proj", "v_proj"]
    lora_dropout: 0.05
    bias: "none"

training_args:
    output_dir: whisper-large-v3-multilingual
    per_device_train_batch_size: 2
    per_device_eval_batch_size: 2
    gradient_accumulation_steps: 32  # increase by 2x for every 2x decrease in batch size
    learning_rate: 1.0e-5
    warmup_steps: 500
    max_steps: 7500
    gradient_checkpointing: True
    gradient_checkpointing_kwargs:
      use_reentrant: True
    fp16: True
    eval_strategy: steps
    predict_with_generate: True
    generation_max_length: 100
    save_steps: 250
    eval_steps: 250
    logging_steps: 250
    load_best_model_at_end: True
    metric_for_best_model: loss
    greater_is_better: False
    push_to_hub: True
    hub_model_id: jq/whisper-large-v3-salt-plus-xog-myx-kin-swa
    save_total_limit: 2

train:
    download_datasets_in_parallel: True
    huggingface_load:
        - path: Sunbird/external-speech-data
          name: common-voice-sample-packed-lug
        - path: Sunbird/external-speech-data
          name: common-voice-sample-packed-swa
          split: train[:-25]
        - path: Sunbird/external-speech-data
          name: common-voice-sample-packed-kin
          split: train[:-25]
        - path: Sunbird/external-speech-data
          name: makerere-radio-speech
        - path: Sunbird/external-speech-data
          name: makerere-yogera-ach
        - path: Sunbird/external-speech-data
          name: makerere-yogera-lug
        - path: Sunbird/external-speech-data
          name: makerere-yogera-nyn
        # Save some myx and xog data for validation
        - path: Sunbird/external-speech-data
          name: makerere-yogera-myx
          split: train[:-100]
        - path: Sunbird/external-speech-data
          name: makerere-yogera-xog
          split: train[:-100]
        # Real-world data
        - path: Sunbird/salt-ucfd
          name: eng
          split: train
        - path: Sunbird/salt-ucfd
          name: lug
          split: train
        - path: Sunbird/salt-ucfd
          name: numbers-eng
          split: train
        - path: Sunbird/salt-ucfd
          name: numbers-lug
          split: train
        - path: Sunbird/salt-tracfm
          name: lug
          split: train
        # Main SALT ASR training data
        - path: Sunbird/salt
          name: multispeaker-lug
          split: train
        - path: Sunbird/salt
          name: multispeaker-eng
          split: train
        - path: Sunbird/salt
          name: multispeaker-ach
          split: train
        - path: Sunbird/salt
          name: multispeaker-lgg
          split: train
        - path: Sunbird/salt
          name: multispeaker-teo
          split: train
        - path: Sunbird/salt
          name: multispeaker-nyn
          split: train
        # Google FLEURS
        - path: google/fleurs
          split: train
          name: lg_ug
          trust_remote_code: True
        - path: google/fleurs
          split: train
          name: sw_ke
          trust_remote_code: True
    source:
      type: speech
      language: [lug,eng,ach,lgg,teo,nyn,myx,xog,swa,kin]
      preprocessing:
        # Downsample some examples to 8KHz (to simulate phone audio)
        - set_sample_rate:
            rate: 8_000
            p: 0.2
        # Then upsample again
        - set_sample_rate:
            rate: 16_000
        - normalize_audio
        - augment_audio_speed:
            low: 0.95
            high: 1.15
        - augment_audio_noise:
            max_relative_amplitude: 0.5
            noise_audio_repo:
                path: Sunbird/urban-noise
                name: small
                split: train
    target:
      type: text
      preprocessing:
        - ensure_text_ends_with_punctuation
      language: [lug,eng,ach,lgg,teo,nyn,myx,xog,swa,kin]
    shuffle: True
validation:
    huggingface_load:
        - path: Sunbird/salt
          name: multispeaker-eng
          split: dev
        - path: Sunbird/salt
          name: multispeaker-lug
          split: dev
        - path: Sunbird/salt
          name: multispeaker-ach
          split: dev
        - path: Sunbird/salt
          name: multispeaker-lgg
          split: dev
        - path: Sunbird/salt
          name: multispeaker-teo
          split: dev
        - path: Sunbird/salt
          name: multispeaker-nyn
          split: dev
        - path: Sunbird/external-speech-data
          name: makerere-yogera-myx
          split: train[-100:]
        - path: Sunbird/external-speech-data
          name: makerere-yogera-xog
          split: train[-100:]
        - path: Sunbird/external-speech-data
          name: common-voice-sample-packed-swa
          split: train[-25:]
        - path: Sunbird/external-speech-data
          name: common-voice-sample-packed-kin
          split: train[-25:]
    source:
      type: speech
      language: [lug,eng,ach,lgg,teo,nyn,myx,xog,swa,kin]
      preprocessing:
        - set_sample_rate:
            rate: 16_000
    target:
      type: text
      language: [lug,eng,ach,lgg,teo,nyn,myx,xog,swa,kin]
'''

# Parse the config and build the training (with augmentation) and validation datasets.
config = yaml.safe_load(yaml_config)
train_ds = salt.dataset.create(config['train'], verbose=True)
valid_ds = salt.dataset.create(config['validation'])

In [None]:
# Preview 5 training examples, rendering the `source` column as playable audio.
salt.utils.show_dataset(train_ds, audio_features=['source'], N=5)

In [None]:
# Everything is loaded from the same pretrained checkpoint. The processor has no
# fixed language: the language-ID token is inserted into each label sequence
# individually when examples are prepared.
processor = transformers.WhisperProcessor.from_pretrained(
    pretrained_model_name_or_path=config['pretrained_model'],
    language=None,
    task="transcribe",
)
feature_extractor = transformers.WhisperFeatureExtractor.from_pretrained(
    pretrained_model_name_or_path=config['pretrained_model'],
)
model = transformers.WhisperForConditionalGeneration.from_pretrained(
    pretrained_model_name_or_path=config['pretrained_model'],
)

In [None]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Pad a list of Whisper training examples into a single batch.

    Each feature is a mapping with "input_features" (audio features) and
    "labels" (token ids). Audio features are padded with the processor's
    feature extractor and labels with its tokenizer. Label padding is replaced
    with -100 so those positions are ignored by the loss.

    Attributes:
        processor: object exposing `feature_extractor.pad` and `tokenizer.pad`
            (e.g. a `WhisperProcessor`).
        decoder_start_token_id: token id that is stripped from the front of
            any label sequence that starts with it.
    """

    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Inputs and labels have different lengths and need different padding
        # methods, so they are padded separately.
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # If a BOS (decoder start) token was added during tokenization, cut it
        # here because it is appended again later. This is done per example,
        # not only when *every* row starts with it. Otherwise a mixed batch
        # (some rows prefixed with a prompt, some not) would keep a duplicated
        # start token on the un-prompted rows.
        # NOTE(review): prompt tokens at the front of a label sequence are kept
        # in the loss; Whisper's original training masks the loss on
        # previous-context text — confirm which behaviour is wanted.
        label_features = []
        for feature in features:
            label_ids = feature["labels"]
            if len(label_ids) > 0 and label_ids[0] == self.decoder_start_token_id:
                label_ids = label_ids[1:]
            label_features.append({"input_ids": label_ids})

        # Pad the labels to the longest sequence in the batch.
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Replace padding with -100 so it is ignored by the loss.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels
        return batch

# Batch builder for the Trainer, keyed to the model's own decoder start token.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    decoder_start_token_id=model.config.decoder_start_token_id,
    processor=processor,
)

Read in prompts: preceding context text that the model can be conditioned on to guide its transcription.

In [None]:
# Join every SALT sentence with its prompt on the shared `id` column, then build
# a per-language lookup table: sentence text -> prompt text.
sentences = (
    datasets.load_dataset('Sunbird/salt', 'text-all', split='train')
    .to_pandas()
)
prompts = (
    datasets.load_dataset('Sunbird/prompts', split='train')
    .to_pandas()
)
joined = sentences.merge(prompts, how='inner', on='id')

SALT_PROMPT_LANGUAGES = ['eng', 'ach', 'lgg', 'lug', 'nyn', 'teo']
sentence_to_prompt = {}
for language in SALT_PROMPT_LANGUAGES:
    sentence_to_prompt[language] = dict(
        zip(joined[language + '_text'], joined[language + '_prompt']))

In [None]:
language_id_tokens = salt.constants.SALT_LANGUAGE_TOKENS_WHISPER

def prepare_dataset(example, p_prompt = 0.5):
    """Turn one SALT example into Whisper input features and label ids.

    Args:
        example: dataset row with "source" (audio), "target" (text),
            "source.language" and "target.language" fields.
        p_prompt: probability of prefixing the labels with this sentence's
            known prompt, when one exists.

    Returns:
        dict with "input_features", "labels" (numpy array of token ids) and the
        two language fields copied through unchanged.
    """
    target_language = example["target.language"]

    features = feature_extractor(
        example["source"],
        sampling_rate=16000,
        device='cuda',
        do_normalize=True,
    ).input_features[0]

    # Tokenize the transcript, then place the language-ID token right after the
    # first token of the sequence.
    label_ids = processor.tokenizer(str(example["target"])).input_ids
    label_ids.insert(1, language_id_tokens[target_language])

    # Look up a prompt for this exact sentence, if the language has a prompt table.
    # NOTE(review): the lookup key is the target text *after* dataset
    # preprocessing (training targets pass through
    # ensure_text_ends_with_punctuation), while the table keys are the raw SALT
    # text — confirm they still match.
    prompts_for_language = sentence_to_prompt.get(target_language)
    prompt = None
    if prompts_for_language is not None:
        prompt = prompts_for_language.get(example["target"], None)

    # A random draw happens only when a prompt is available.
    if prompt and np.random.random() < p_prompt:
        label_ids = list(processor.get_prompt_ids(prompt)) + label_ids

    return {
        "input_features": features,
        "labels": np.array(label_ids),
        "source.language": example["source.language"],
        "target.language": target_language,
    }

In [None]:
# Convert every example to model inputs/labels with prepare_dataset, dropping the
# raw `source` audio and `target` text columns.
# NOTE(review): if these are streaming (iterable) datasets, map() is lazy, and the
# random prompt choice is re-drawn each time an example is read — confirm.
train_data = train_ds.map(prepare_dataset, remove_columns=["source", "target"])
val_data = valid_ds.map(prepare_dataset, remove_columns=["source", "target"])

In [None]:
# Multilingual WER/CER evaluation function from salt.metrics, logging the first
# 5 predictions.
# NOTE: currently unused — the `compute_metrics=` argument to the trainer is
# commented out, so evaluation reports the loss only.
compute_metrics = salt.metrics.multilingual_eval_fn(
      valid_ds, [evaluate.load('wer'), evaluate.load('cer')],
      processor.tokenizer, log_first_N_predictions=5,
      speech_processor=processor)

In [None]:
# Clear Whisper's forced decoder ids and token suppression list; the language
# token is supplied inside each label sequence instead.
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

# Optionally wrap the model with LoRA adapters configured by `lora_config`.
if config['use_peft']:
    # NOTE(review): prepare_model_for_kbit_training targets quantized models;
    # here the model is loaded in full precision — confirm this is intended.
    model = peft.prepare_model_for_kbit_training(model)
    lora_config = peft.LoraConfig(**config['lora_config'])
    # Presumably needed so gradients flow to the adapters while gradient
    # checkpointing (enabled in training_args) is active.
    model.enable_input_require_grads()
    model = peft.get_peft_model(model, lora_config)
    # NOTE(review): this value is stored in the model config, so it is also
    # saved with the pushed model — confirm that is intended.
    model.config.use_cache = False
    model.print_trainable_parameters()

Configure the trainer and launch training

In [None]:
# HF's MLflow integration takes its experiment name from the
# MLFLOW_EXPERIMENT_NAME environment variable; without this the
# `mlflow_experiment_name` config entry is never used.
if use_mlflow and config.get('mlflow_experiment_name'):
    os.environ['MLFLOW_EXPERIMENT_NAME'] = config['mlflow_experiment_name']

training_args = transformers.Seq2SeqTrainingArguments(
    **config["training_args"],
    # Report only to the tracking platforms enabled at the top of the notebook.
    report_to=[
        platform for platform, use in [("wandb", use_wandb), ("mlflow", use_mlflow)] if use],
)

trainer = transformers.Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    data_collator=data_collator,
    # WER/CER evaluation is disabled; the best checkpoint is chosen by eval loss.
    # compute_metrics=compute_metrics,
    processing_class=processor,
)

trainer.train()

Log the config settings to MLflow for reference

In [None]:
if use_mlflow:
    # transformers' MLflowCallback ends the run it started when training
    # finishes, so a bare `mlflow.log_params(...)` here would silently start a
    # new, separate run. Attach the config to the training run instead: the
    # still-active run if there is one, otherwise the most recently active one.
    training_run = mlflow.active_run() or mlflow.last_active_run()
    if training_run is None:
        print('No MLflow run found; config not logged.')
    else:
        run_id = training_run.info.run_id
        client = MlflowClient()
        # Store the full nested config as an artifact. As params, the nested
        # sections would be stringified and can exceed MLflow's param length limit.
        client.log_dict(run_id, config, 'config.yaml')
        # Also log the top-level scalar settings as searchable params.
        for key, value in config.items():
            if not isinstance(value, (dict, list)):
                client.log_param(run_id, key, value)

Push the processor and the full model (not just the LoRA adapter weights) to the Hugging Face Hub

In [None]:
hub_model_id = config['training_args']['hub_model_id']

# With LoRA enabled, `model` is a PeftModel whose push_to_hub uploads only the
# adapter weights. Merge the adapters into the base weights first so the full
# model is pushed.
if config['use_peft']:
    model = model.merge_and_unload()

processor.push_to_hub(hub_model_id)
model.push_to_hub(hub_model_id)

Try running the model on the first validation example

In [None]:
# Sanity check: transcribe the first validation example with the fine-tuned model.
example = next(iter(valid_ds))
input_features = processor(example["source"], sampling_rate=16000, return_tensors="pt").input_features

# Put the model in inference mode explicitly, and send the inputs to wherever
# the model actually lives rather than assuming "cuda".
model.eval()
with torch.no_grad():
    predicted_ids = model.generate(input_features.to(model.device))[0]
transcription = processor.decode(predicted_ids)
print(transcription)