In [1]:
%%capture
# Install the packages this notebook relies on; %%capture keeps the installer
# output out of the notebook.
!pip install -q jiwer
!pip install -q evaluate
!pip install -qU accelerate
!pip install -q transformers[torch]
# Fetch the SALT repository and install its requirements (imported later as `salt`).
!git clone https://github.com/sunbirdai/salt.git
!pip install -qr salt/requirements.txt

In [2]:
# Experiment-tracking switches. The same flags are read again when the
# `report_to` list for the training arguments is built.
use_wandb = False
use_mlflow = True

if use_wandb:
  # Install and configure Weights & Biases through environment variables, then log in.
  !pip install -q wandb
  import wandb
  %set_env WANDB_LOG_MODEL=True
  %set_env WANDB_WATCH=all
  %set_env WANDB_NOTEBOOK_NAME=whisper_base_en_sb.ipynb
  wandb.login()

if use_mlflow:
  !pip install -q mlflow
  # Extra packages required for logging system/GPU metrics in MLflow.
  !pip install -q psutil
  !pip install -q pynvml
  import os
  from getpass import getpass
  import mlflow
  import mlflow.pytorch
  from mlflow import MlflowClient

  # Prompt for the MLflow tracking credentials (getpass keeps them out of the
  # notebook source) and export them as environment variables.
  MLFLOW_TRACKING_USERNAME = getpass('Enter the MLFLOW_TRACKING_USERNAME: ')
  os.environ['MLFLOW_TRACKING_USERNAME'] = MLFLOW_TRACKING_USERNAME

  MLFLOW_TRACKING_PASSWORD = getpass('Enter the MLFLOW_TRACKING_PASSWORD: ')
  os.environ['MLFLOW_TRACKING_PASSWORD'] = MLFLOW_TRACKING_PASSWORD

  # Point the MLflow client at the remote tracking server.
  mlflow.set_tracking_uri('https://mlflow-sunbird-ce0ecfc14244.herokuapp.com/')

[0m

Enter the MLFLOW_TRACKING_USERNAME:  ········
Enter the MLFLOW_TRACKING_PASSWORD:  ········


In [3]:
from torch import nn
import torch
from transformers import (
    AutoFeatureExtractor,
    AutoModelForCTC,
    AutoProcessor,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    is_apex_available,
    set_seed,
    WhisperForConditionalGeneration,
    WhisperProcessor,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    WhisperTokenizer,
    WhisperFeatureExtractor,
)
from dataclasses import dataclass, field
from typing import Union, List, Dict, Any
import string
import os
import json
import datasets
import numpy as np
import yaml
import evaluate
import mlflow
import salt.dataset
import salt.metrics
from salt.utils import DataCollatorCTCWithPadding as dcwp
import huggingface_hub

In [4]:
# Authenticate with the Hugging Face Hub through the interactive login widget.
# Never paste an access token into the notebook source: anything committed here
# is public. Any token that was previously stored in this cell must be revoked.
huggingface_hub.notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [9]:
# The following code prepares datasets with one local language plus English.
# `language` is the code used to select the dataset configs and to name outputs.
language = 'lug'
# Derive the Hub repository name from `language` (as the output directory does),
# so that switching language cannot push to the repository of another language.
hub_model_id = f'whisper-medium-sb-{language}-eng'

In [37]:
# Experiment configuration, written as YAML with `language` and `hub_model_id`
# substituted in. It holds the pretrained checkpoint and MLflow names, the
# `training_args` mapping (expanded into Seq2SeqTrainingArguments further down),
# and the `train` / `validation` dataset specs passed to salt.dataset.create.
# NOTE(review): the Common Voice entry uses the fixed config name `lg` and does
# not follow `language` — confirm before switching to another language.
yaml_config = f'''
pretrained_model: openai/whisper-medium
mlflow_experiment_name : stt-whisper-{language}-eng
mlflow_run_name: {language}_eng_from_pretrained-[30k]

training_args:
    output_dir: ./whisper-medium-sb-{language}-eng
    per_device_train_batch_size: 16
    gradient_accumulation_steps: 1  # increase by 2x for every 2x decrease in batch size
    learning_rate: 1.0e-5
    warmup_steps: 500
    max_steps: 30000
    gradient_checkpointing: True
    gradient_checkpointing_kwargs:
      use_reentrant: True
    fp16: True
    eval_strategy: steps
    per_device_eval_batch_size: 8
    predict_with_generate: True
    generation_max_length: 100
    save_steps: 500
    eval_steps: 500
    logging_steps: 25
    load_best_model_at_end: True
    metric_for_best_model: loss
    greater_is_better: False
    push_to_hub: True
    hub_model_id: {hub_model_id}
    save_total_limit: 2

train:
    huggingface_load:
        - path: Sunbird/salt
          name: multispeaker-{language}
          split: train
        - path: Sunbird/salt
          name: multispeaker-eng
          split: train
        - path: mozilla-foundation/common_voice_13_0
          split: train
          name: lg
          trust_remote_code: True
    source:
      type: speech
      language: [{language},eng]
      preprocessing:
        - set_sample_rate:
            rate: 16_000
        - augment_audio_noise:
            max_relative_amplitude: 0.5
    target:
      type: text
      language: [{language},eng]
    shuffle: True
validation:
    huggingface_load:
        - path: Sunbird/salt
          name: multispeaker-{language}
          split: dev
        - path: Sunbird/salt
          name: multispeaker-eng
          split: dev
    source:
      type: speech
      language: [{language},eng]
      preprocessing:
        - set_sample_rate:
            rate: 16_000
    target:
      type: text
      language: [{language},eng]
'''

# Parse the YAML text, then build the training and validation datasets from
# their respective sections.
config = yaml.safe_load(yaml_config)
train_ds = salt.dataset.create(config['train'])
valid_ds = salt.dataset.create(config['validation'])

In [11]:
# Preview 10 training examples, treating the `source` column as audio.
salt.utils.show_dataset(train_ds, audio_features=['source'], N=10)

Downloading readme:   0%|          | 0.00/9.98k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/194M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/186M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/234M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/175M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/183M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/19.1M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/19.6M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/5016 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/103 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/97 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/240M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/228M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/248M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/15.0M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/15.8M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/4797 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/96 [00:00<?, ? examples/s]

Downloading builder script:   0%|          | 0.00/8.18k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/14.7k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/3.65k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/65.4k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/13.2k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.47G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.12G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/462M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/488M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.29G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.44G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/17.0M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.16M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.14M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/8.52M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/9.51M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]


Reading metadata...: 0it [00:00, ?it/s][A
Reading metadata...: 15693it [00:00, 156913.69it/s][A
Reading metadata...: 34172it [00:00, 173304.84it/s][A
Reading metadata...: 70813it [00:00, 180922.95it/s][A


Generating validation split: 0 examples [00:00, ? examples/s]


Reading metadata...: 13389it [00:00, 204948.53it/s]


Generating test split: 0 examples [00:00, ? examples/s]


Reading metadata...: 13420it [00:00, 205712.09it/s]


Generating other split: 0 examples [00:00, ? examples/s]


Reading metadata...: 0it [00:00, ?it/s][A
Reading metadata...: 36922it [00:00, 205368.32it/s][A


Generating invalidated split: 0 examples [00:00, ? examples/s]


Reading metadata...: 0it [00:00, ?it/s][A
Reading metadata...: 39159it [00:00, 202778.06it/s][A


Unnamed: 0,source,target,source.language,target.language
0,Your browser does not support the audio element.,"He informed people that he is an advocate of stability, social order and unity.",eng,eng
1,Your browser does not support the audio element.,Musigale awaka okuziyiza okusaasaana kw'akawuka ka kolona.,lug,lug
2,Your browser does not support the audio element.,Nga oluguudo luzimbibwa ebitongole bya gavumenti bingi ebikibaamu.,lug,lug
3,Your browser does not support the audio element.,Kale nno ebirowoozo ebiwabya ku nnimi zaffe enzaaliranwa tubyeggyemu.,lug,lug
4,Your browser does not support the audio element.,Omusumba Bumanye yeebazizza Katonda olw’obulamu bw'omugenzi.,lug,lug
5,Your browser does not support the audio element.,Okkiririza mu katonda?,lug,lug
6,Your browser does not support the audio element.,Engoye zaakaze dda era genda oziggyeyo.,lug,lug
7,Your browser does not support the audio element.,Who is a colleague?,eng,eng
8,Your browser does not support the audio element.,"N'endowooza embi eri ekintu, oyinza okukaluubirizibwa okubaako ky'otuukako.",lug,lug
9,Your browser does not support the audio element.,Abakulembeze baasomeseddwa ku ngeri y'okukozesaamu enkola empya.,lug,lug


In [38]:
# Language codes that each receive a dedicated <|code|> marker token.
new_languages = ['lug', 'ach', 'lgg', 'nyn', 'teo']

# Start from the pretrained tokenizer, configured for transcription and with no
# language pinned.
tokenizer = WhisperTokenizer.from_pretrained(
    config['pretrained_model'],
    task="transcribe",
    language=None,
)

# Register one special token per new language; add_tokens returns how many
# tokens were added.
new_tokens = [f"<|{lang}|>" for lang in new_languages]
num_added_tokens = tokenizer.add_tokens(
    new_tokens,
    special_tokens=True,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [39]:
# Load the pretrained feature extractor and pair it with the extended tokenizer
# in a single processor object.
feature_extractor = WhisperFeatureExtractor.from_pretrained(config['pretrained_model'])
processor = WhisperProcessor(feature_extractor, tokenizer)

In [40]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Pad a list of preprocessed examples into a single training batch.

    Audio features and label ids are padded separately because they have
    different lengths and need different padding methods. Padded label
    positions are replaced with -100 so they are ignored by the loss.

    Attributes:
        processor: Object exposing ``feature_extractor.pad`` and ``tokenizer.pad``.
        decoder_start_token_id: Token id removed from the front of the labels when
            every sequence in the batch starts with it. When left as ``None`` it is
            resolved from the tokenizer as the id of ``<|startoftranscript|>``.
    """
    processor: Any
    decoder_start_token_id: Union[int, None] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Collate ``features`` (dicts with ``input_features`` and ``labels``) into a batch.

        Returns the padded feature batch with an added ``labels`` tensor.
        """
        # Split inputs and labels: first treat the audio inputs by simply returning torch tensors.
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Pad the tokenized label sequences to the longest one in the batch.
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Replace padding with -100 so those positions are ignored by the loss.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # The start token is prepended again later, so it must not stay in the labels.
        # The labels built in this notebook begin with <|startoftranscript|>; comparing
        # against the tokenizer's bos_token_id (which, for Whisper tokenizers, is
        # expected to be <|endoftext|>) never matched, leaving a duplicated start token.
        start_token_id = self.decoder_start_token_id
        if start_token_id is None:
            start_token_id = self.processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")

        # Guard against zero-width label tensors before indexing the first column.
        if labels.shape[1] > 0 and (labels[:, 0] == start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

# Collator instance that is handed to the trainer as `data_collator`.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

In [41]:
# Toggle for the tokenizer round-trip preview below.
view_example = True

if view_example:
  # Take the first example from the training set.
  train_iterator = iter(train_ds)
  example = next(train_iterator)

  # Encode the target text with the tokenizer's default settings, then decode it
  # again with and without special tokens to check the text survives unchanged.
  input_str = example["target"]
  labels = tokenizer(input_str).input_ids
  decoded_with_special = tokenizer.decode(labels, skip_special_tokens=False)
  decoded_str = tokenizer.decode(labels, skip_special_tokens=True)

  print(f"Input:                 {input_str}")
  print(f"Decoded w/ special:    {decoded_with_special}")
  print(f"Decoded w/out special: {decoded_str}")
  print(f"Are equal:             {input_str == decoded_str}")

Input:                 Waliwo eby'okujaguza bingi ebikolebwa mu bwakabaka.
Decoded w/ special:    <|startoftranscript|><|transcribe|><|notimestamps|>Waliwo eby'okujaguza bingi ebikolebwa mu bwakabaka.<|endoftext|>
Decoded w/out special: Waliwo eby'okujaguza bingi ebikolebwa mu bwakabaka.
Are equal:             True


In [42]:
def prepare_dataset(example):
    """Turn one raw example into model inputs and label ids.

    Reads ``source`` (audio), ``target`` (text) and the two language fields
    from ``example``. Returns a dict with ``input_features``, ``labels`` and
    both language fields copied through unchanged.
    """
    # Extract audio features at 16 kHz and keep the first entry of the result.
    features = feature_extractor(example["source"], sampling_rate=16000).input_features[0]

    # English maps onto the <|en|> token; any other language uses its own code.
    lang = example['target.language']
    if lang == "eng":
        lang_token = "<|en|>"
    else:
        lang_token = f"<|{lang}|>"

    # The full label sequence is spelled out by hand, so the tokenizer is told
    # not to add special tokens of its own.
    target_text = f"<|startoftranscript|>{lang_token}<|transcribe|>{example['target']}<|endoftext|>"
    label_ids = tokenizer(target_text, add_special_tokens=False).input_ids

    return {
        "input_features": features,
        "labels": label_ids,
        "source.language": example["source.language"],
        "target.language": example["target.language"],
    }

In [43]:
# Hand-made English example (ten zero samples as audio) for a quick sanity check
# of prepare_dataset.
example_dict = {"source": np.zeros(10), "target":"I am going to school", "source.language": "eng", "target.language":"eng"}

In [44]:
# Run the preprocessing on the hand-made example and display the result.
prepare_dataset(example_dict)

{'input_features': array([[-1.5, -1.5, -1.5, ..., -1.5, -1.5, -1.5],
        [-1.5, -1.5, -1.5, ..., -1.5, -1.5, -1.5],
        [-1.5, -1.5, -1.5, ..., -1.5, -1.5, -1.5],
        ...,
        [-1.5, -1.5, -1.5, ..., -1.5, -1.5, -1.5],
        [-1.5, -1.5, -1.5, ..., -1.5, -1.5, -1.5],
        [-1.5, -1.5, -1.5, ..., -1.5, -1.5, -1.5]], dtype=float32),
 'labels': [50258, 50259, 50359, 40, 669, 516, 281, 1395, 50257],
 'source.language': 'eng',
 'target.language': 'eng'}

In [45]:
# Apply the same preprocessing to both splits (training first, then validation),
# dropping the raw audio and text columns once they have been converted.
train_data, val_data = (
    split.map(prepare_dataset, remove_columns=["source", "target"])
    for split in (train_ds, valid_ds)
)

In [46]:
# Build the evaluation callback from the raw validation set, the WER and CER
# metrics, the tokenizer and the speech processor; it logs the first 10 predictions.
compute_metrics = salt.metrics.multilingual_eval_fn(
    valid_ds,
    [evaluate.load(metric_name) for metric_name in ('wer', 'cer')],
    tokenizer,
    speech_processor=processor,
    log_first_N_predictions=10,
)

In [47]:
# Load the pretrained model and resize its token embeddings to the tokenizer
# length, which now includes the added language tokens.
model = WhisperForConditionalGeneration.from_pretrained(config['pretrained_model'])
model.resize_token_embeddings(len(tokenizer))

Embedding(51870, 1024)

In [49]:
# Call the model's freeze_encoder() helper before training.
# NOTE(review): intent appears to be training only the non-encoder weights — confirm.
model.freeze_encoder()

In [50]:
# Update the model config: clear any forced decoder ids and the list of
# suppressed tokens.
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

def custom_generate(self, input_features, language=None, task="transcribe", **kwargs):
    """Generate token ids through the parent class's ``generate``.

    The call skips ``WhisperForConditionalGeneration.generate`` and goes
    straight to the implementation inherited from its base classes, always
    with ``forced_decoder_ids=None``.

    Args:
        input_features: Model inputs, forwarded as the first positional argument.
        language: Accepted for signature compatibility; not used.
        task: Accepted for signature compatibility; not used.
        **kwargs: Forwarded to the parent ``generate``. ``repetition_penalty``
            defaults to 1.1 when the caller does not provide one.

    Returns:
        Whatever the parent ``generate`` returns.
    """
    # Use 1.1 only as a default: passing it as a fixed keyword next to **kwargs
    # raised "got multiple values for keyword argument" whenever a caller
    # supplied its own repetition_penalty.
    kwargs.setdefault("repetition_penalty", 1.1)
    return super(WhisperForConditionalGeneration, self).generate(
        input_features,
        forced_decoder_ids=None,
        **kwargs
    )

# Bind custom_generate to this model instance so that model.generate(...) calls it.
model.generate = custom_generate.__get__(model, WhisperForConditionalGeneration)

In [51]:
# Collect the enabled experiment trackers, W&B first and MLflow second.
report_to = [
    tracker
    for tracker, enabled in (('wandb', use_wandb), ('mlflow', use_mlflow))
    if enabled
]

# Expand the YAML `training_args` section into the trainer arguments and attach
# the tracker list.
training_args = Seq2SeqTrainingArguments(
  report_to=report_to,
  **config["training_args"],
)

In [52]:
# Assemble the trainer from the model, the preprocessed splits, the padding
# collator and the metric callback. The feature extractor is passed in the
# `tokenizer` slot.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_data,
    eval_dataset=val_data,
    tokenizer=processor.feature_extractor,
    compute_metrics=compute_metrics,
)

max_steps is given, it will override any value given in num_train_epochs


In [53]:
# Save the processor (feature extractor plus the extended tokenizer) into the
# training output directory.
processor.save_pretrained(training_args.output_dir)

[]

In [28]:
# trainer.evaluate()

In [None]:
# Select the experiment named in the config before opening the run; the
# `mlflow_experiment_name` setting was otherwise never applied, so runs ended up
# in whichever experiment happened to be active.
mlflow.set_experiment(config['mlflow_experiment_name'])

# Train inside an MLflow run named from the config, with system metrics logging
# enabled and the full config logged as run parameters.
with mlflow.start_run(run_name=config['mlflow_run_name'], log_system_metrics=True) as run:
    mlflow.log_params(config)
    trainer.train()

In [40]:
# Upload the trained model. The first positional argument of
# Trainer.push_to_hub is the commit message, not the repository: passing
# hub_model_id there only produced a commit titled with the repo name. The
# target repository comes from `hub_model_id` in the training arguments.
trainer.push_to_hub(commit_message="End of training")
# Upload the processor to the same repository (here the argument is the repo id).
processor.push_to_hub(hub_model_id)

Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [], 'begin_suppress_tokens': [220, 50257]}


CommitInfo(commit_url='https://huggingface.co/akera/whisper-medium-sb-lug-eng/commit/464203b0cfe7564c07ca1d1592f67a86c89f686d', commit_message='akera/whisper-medium-sb-lug-eng', commit_description='', oid='464203b0cfe7564c07ca1d1592f67a86c89f686d', pr_url=None, pr_revision=None, pr_num=None)

In [83]:
## Backup Uganda File
import librosa
from transformers import pipeline

In [142]:
# Load a local test recording, resampled to 16 kHz.
audio, sr  = librosa.load("SIMBA 5.4.mp3", sr=16000)

In [143]:
# audio

In [144]:
# Wrap the in-memory model, tokenizer and feature extractor in a speech
# recognition pipeline on the GPU, decoding with 5 beams at temperature 1.
asr_pipeline = pipeline(
    task="automatic-speech-recognition",
    model=model,
    feature_extractor=feature_extractor,
    tokenizer=tokenizer,
    generate_kwargs=dict(num_beams=5, temperature=1),
    device="cuda",
)

In [145]:
# whisper_base = pipeline(
#     "automatic-speech-recognition",
#     model="openai/whisper-base",
#     device="cuda"
# )

In [146]:
# whisper_base(audio)

In [147]:
# Convert the loaded waveform into model input features, returned as PyTorch tensors.
input_features = processor(
    audio, sampling_rate=sr, return_tensors="pt", device="cuda"
).input_features

In [148]:
# Move the features onto the GPU before calling generate.
input_features = input_features.to("cuda")

In [149]:
# Generate token ids; model.generate is the custom_generate method bound earlier.
predicted_ids = model.generate(input_features)

In [150]:
# Display the raw generated token ids.
predicted_ids

tensor([[50258, 50318, 50359, 50363,   426,    73, 12257, 29758,  2887, 51867,
          4013,   371,  7371,  8384,  1877, 10449,   275,    74,  6440,  6799,
           457, 10449,   297,  1275,    64,   417,  8707,  5406,   297,   432,
           339,   453,  3780,   371,  7371,  8384,    13,  3301,   339,  5159,
         46930,   619, 14610,    64,  1038, 10121,  3406,    84,  5406,  2752,
            13,  6056,   308,  2887,  1049,    89, 13275,   297,  5509,  1111,
          3274,   275, 10730,    84,   350,   418,   695,    13,   426,   271,
           332,   345,   388,   271,  4151,    13, 51867,    77,   619, 14610,
            64,  1038, 10121,  2478,   275,   716,    13,   426,  9994,  2562,
            64,   275, 15615,  3547,   410,    64,   417,   304,   473,  1667,
           417,   270,   901,    64,  6493,  5159,    13,   591,  2860,   320,
          9384,  4773, 14610,    64,   308,  2887,    13,  6777,   350,  3780,
           297,   455,   443, 12716,   288,  3548,  

In [151]:
# Look up the text of token id 50259.
tokenizer.decode(50259)

'<|en|>'

In [152]:
# Look up the text of token id 7497.
tokenizer.decode(7497)

' Could'

In [155]:
# Decode the generated ids into text, leaving out special tokens.
processor.batch_decode(predicted_ids, skip_special_tokens=True)

[' Njaza shamidaji vila ko futawa mkoma ka butawa ngana chiti wa ngechokula vila ko. Umchala wanake vela musimi omu wa mi. Na eida genzoku naba obata mdemu kore gu. Nisimadiliswa.nake vela musimi ya mne. Nasa anga mwen numbers aba chalate na chituala bulala. Kumbayoli ba vela eida. Ba kula nabemidi yomba ga natate.']