In [1]:
!pip install -q jiwer
!pip install -q evaluate
!pip install -qU accelerate
!pip install -q transformers[torch]
!git clone https://github.com/sunbirdai/salt.git
!pip install -qr salt/requirements.txt
!pip install -q peft

[0mCloning into 'salt'...
remote: Enumerating objects: 912, done.[K
remote: Counting objects: 100% (327/327), done.[K
remote: Compressing objects: 100% (157/157), done.[K
remote: Total 912 (delta 181), reused 243 (delta 148), pack-reused 585 (from 1)[K
Receiving objects: 100% (912/912), 20.83 MiB | 15.28 MiB/s, done.
Resolving deltas: 100% (469/469), done.
[0m

In [2]:
use_wandb = False
use_mlflow = True

if use_wandb:
  !pip install -q wandb
  import wandb
  %set_env WANDB_LOG_MODEL=True
  %set_env WANDB_WATCH=all
  %set_env WANDB_NOTEBOOK_NAME=whisper_base_en_sb.ipynb
  wandb.login()

if use_mlflow:
  !pip install -q mlflow
  ## requirements to log system/GPU metrics in mlflow
  !pip install -q psutil
  !pip install -q pynvml
  import os
  from getpass import getpass
  import mlflow
  import mlflow.pytorch
  from mlflow import MlflowClient

  # Set MLflow tracking credentials
  MLFLOW_TRACKING_USERNAME = getpass('Enter the MLFLOW_TRACKING_USERNAME: ')
  os.environ['MLFLOW_TRACKING_USERNAME'] = MLFLOW_TRACKING_USERNAME

  MLFLOW_TRACKING_PASSWORD = getpass('Enter the MLFLOW_TRACKING_PASSWORD: ')
  os.environ['MLFLOW_TRACKING_PASSWORD'] = MLFLOW_TRACKING_PASSWORD

  # Set the MLflow tracking URI
  mlflow.set_tracking_uri('https://mlflow-sunbird-ce0ecfc14244.herokuapp.com/')
  mlflow.system_metrics.enable_system_metrics_logging()

[0m

Enter the MLFLOW_TRACKING_USERNAME:  ········
Enter the MLFLOW_TRACKING_PASSWORD:  ········


In [5]:
import torch
import transformers
from dataclasses import dataclass, field
from typing import Union, List, Dict, Any
import string
import os
import json
import datasets
import numpy as np
import yaml
import evaluate
import salt.dataset
import salt.metrics
from salt.utils import DataCollatorCTCWithPadding as dcwp
import huggingface_hub
import peft
import pandas as pd

In [7]:
# Log in to the Hugging Face Hub through the interactive notebook widget.
# NOTE(review): presumably required for the gated/private datasets and for the
# `push_to_hub` calls later in the notebook — confirm.
huggingface_hub.notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [11]:
# Full experiment configuration, kept as YAML so that it can be logged and
# shared verbatim. A plain (non-f) string is used on purpose: nothing is
# interpolated, and YAML flow mappings written with braces would otherwise be
# misread as f-string placeholders.
yaml_config = '''
pretrained_model: jq/whisper-large-v2-multilingual
mlflow_experiment_name : stt-whisper-lug-eng

use_peft: False
lora_config:
    r: 32
    lora_alpha: 64
    target_modules: ["q_proj", "v_proj"]
    lora_dropout: 0.05
    bias: "none"

training_args:
    output_dir: whisper-large-v2-multilingual
    per_device_train_batch_size: 8
    per_device_eval_batch_size: 8
    gradient_accumulation_steps: 32  # increase by 2x for every 2x decrease in batch size
    learning_rate: 5.0e-6
    warmup_steps: 50
    max_steps: 500
    gradient_checkpointing: True
    gradient_checkpointing_kwargs:
      use_reentrant: True
    fp16: True
    eval_strategy: steps
    predict_with_generate: True
    generation_max_length: 100
    save_steps: 100
    eval_steps: 100
    logging_steps: 100
    load_best_model_at_end: True
    metric_for_best_model: loss
    greater_is_better: False
    push_to_hub: True
    hub_model_id: whisper-large-v2-multilingual-prompts-corrected
    save_total_limit: 2

train:
    huggingface_load:
        # Call centre data
        - path: Sunbird/salt-ucfd
          name: eng
          split: train
        - path: Sunbird/salt-ucfd
          name: lug
          split: train    
        - path: Sunbird/salt-ucfd
          name: numbers-eng
          split: train
        - path: Sunbird/salt-ucfd
          name: numbers-lug
          split: train   
        # Main SALT ASR training data
        - path: Sunbird/salt-corrected-corrected
          name: corrected-lug
          split: train
        - path: Sunbird/salt-corrected-corrected
          name: corrected-eng
          split: train
        - path: Sunbird/salt-corrected-corrected
          name: corrected-ach
          split: train
        - path: Sunbird/salt-corrected-corrected
          name: corrected-lgg
          split: train
        - path: Sunbird/salt-corrected-corrected
          name: corrected-teo
          split: train
        - path: Sunbird/salt-corrected-corrected
          name: corrected-nyn
          split: train
        # # Common Voice
        # - path: mozilla-foundation/common_voice_13_0
        #   split: train
        #   name: lg
        #   trust_remote_code: True
        # # Google FLEURS
        # - path: google/fleurs
        #   split: train
        #   name: lg_ug
        #   trust_remote_code: True
    source:
      type: speech
      language: [lug,eng,ach,lgg,teo,nyn]
      preprocessing:
        # Downsample some examples to 8KHz (to simulate phone audio) 
        - set_sample_rate:
            rate: 8_000
            p: 0.2
        # Then upsample again
        - set_sample_rate:
            rate: 16_000
        - augment_audio_noise:
            max_relative_amplitude: 0.5
    target:
      type: text
      language: [lug,eng,ach,lgg,teo,nyn]
    shuffle: True
validation:
    huggingface_load:
        - path: Sunbird/salt-corrected-corrected
          name: corrected-eng
          split: dev
        - path: Sunbird/salt-corrected-corrected
          name: corrected-lug
          split: dev
        - path: Sunbird/salt-corrected-corrected
          name: corrected-ach
          split: dev
        - path: Sunbird/salt-corrected-corrected
          name: corrected-lgg
          split: dev
        - path: Sunbird/salt-corrected-corrected
          name: corrected-teo
          split: dev
        - path: Sunbird/salt-corrected-corrected
          name: corrected-nyn
          split: dev
    source:
      type: speech
      language: [lug,eng,ach,lgg,teo,nyn]
      preprocessing:
        - set_sample_rate:
            rate: 16_000
    target:
      type: text
      language: [lug,eng,ach,lgg,teo,nyn]
'''

# Parse the YAML once, then build the training and validation datasets from
# their respective sections of the config.
config = yaml.safe_load(yaml_config)
train_ds = salt.dataset.create(config['train'])
valid_ds = salt.dataset.create(config['validation'])

In [9]:
# Preview five training examples, treating the 'source' column as audio.
salt.utils.show_dataset(train_ds, audio_features=['source'], N=5)

README.md:   0%|          | 0.00/2.05k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/449M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/198 [00:00<?, ? examples/s]

train-00000-of-00001.parquet:   0%|          | 0.00/99.3M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/47 [00:00<?, ? examples/s]

train-00000-of-00001.parquet:   0%|          | 0.00/75.4M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/328 [00:00<?, ? examples/s]

train-00000-of-00001.parquet:   0%|          | 0.00/7.33M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/27 [00:00<?, ? examples/s]

README.md:   0%|          | 0.00/4.47k [00:00<?, ?B/s]

train-00000-of-00002.parquet:   0%|          | 0.00/483M [00:00<?, ?B/s]

train-00001-of-00002.parquet:   0%|          | 0.00/491M [00:00<?, ?B/s]

dev-00000-of-00001.parquet:   0%|          | 0.00/19.0M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/19.5M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/5005 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/103 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/97 [00:00<?, ? examples/s]

Map:   0%|          | 0/5005 [00:00<?, ? examples/s]

train-00000-of-00002.parquet:   0%|          | 0.00/362M [00:00<?, ?B/s]

train-00001-of-00002.parquet:   0%|          | 0.00/358M [00:00<?, ?B/s]

dev-00000-of-00001.parquet:   0%|          | 0.00/14.8M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/15.6M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/4797 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/96 [00:00<?, ? examples/s]

Map:   0%|          | 0/4797 [00:00<?, ? examples/s]

train-00000-of-00002.parquet:   0%|          | 0.00/458M [00:00<?, ?B/s]

train-00001-of-00002.parquet:   0%|          | 0.00/373M [00:00<?, ?B/s]

dev-00000-of-00001.parquet:   0%|          | 0.00/17.3M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/16.9M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/4776 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/101 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/94 [00:00<?, ? examples/s]

Map:   0%|          | 0/4776 [00:00<?, ? examples/s]

train-00000-of-00003.parquet:   0%|          | 0.00/382M [00:00<?, ?B/s]

train-00001-of-00003.parquet:   0%|          | 0.00/392M [00:00<?, ?B/s]

train-00002-of-00003.parquet:   0%|          | 0.00/368M [00:00<?, ?B/s]

dev-00000-of-00001.parquet:   0%|          | 0.00/24.0M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/24.1M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/4732 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/96 [00:00<?, ? examples/s]

Map:   0%|          | 0/4732 [00:00<?, ? examples/s]

train-00000-of-00002.parquet:   0%|          | 0.00/466M [00:00<?, ?B/s]

train-00001-of-00002.parquet:   0%|          | 0.00/438M [00:00<?, ?B/s]

dev-00000-of-00001.parquet:   0%|          | 0.00/18.1M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/19.3M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/4626 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/95 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/94 [00:00<?, ? examples/s]

Map:   0%|          | 0/4626 [00:00<?, ? examples/s]

train-00000-of-00003.parquet:   0%|          | 0.00/308M [00:00<?, ?B/s]

train-00001-of-00003.parquet:   0%|          | 0.00/261M [00:00<?, ?B/s]

train-00002-of-00003.parquet:   0%|          | 0.00/424M [00:00<?, ?B/s]

dev-00000-of-00001.parquet:   0%|          | 0.00/20.2M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/21.4M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/4722 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/94 [00:00<?, ? examples/s]

Map:   0%|          | 0/4722 [00:00<?, ? examples/s]

Unnamed: 0,source,target,source.language,target.language
0,Your browser does not support the audio element.,Abantu bakahutuurwa kugira ngu bakore munonga kugira ngu babaase kureeberera gye bizinesi zaabo.,nyn,nyn
1,Your browser does not support the audio element.,Dayasesi katulika eyini Arua niri avi i ma kristiani azi okporu ni,lgg,lgg
2,Your browser does not support the audio element.,Land surveyors are very expensive.,eng,eng
3,Your browser does not support the audio element.,Lubanga dit,ach,ach
4,Your browser does not support the audio element.,0120021,eng,eng


In [10]:
# The feature extractor, processor and model are all restored from the same
# pretrained checkpoint named in the config.
checkpoint_name = config['pretrained_model']

feature_extractor = transformers.WhisperFeatureExtractor.from_pretrained(
    checkpoint_name)

# No language is fixed on the processor; the task is set to transcription.
processor = transformers.WhisperProcessor.from_pretrained(
    checkpoint_name, language=None, task="transcribe")

model = transformers.WhisperForConditionalGeneration.from_pretrained(
    checkpoint_name)

preprocessor_config.json:   0%|          | 0.00/339 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/283k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/494k [00:00<?, ?B/s]

normalizer.json:   0%|          | 0.00/52.7k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/34.6k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.19k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.32k [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/112k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/1.18G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/4.29k [00:00<?, ?B/s]

In [12]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate per-example audio features and label ids into one padded batch.

    Audio features and labels are padded independently (by the processor's
    feature extractor and tokenizer respectively). Padded label positions are
    replaced with -100.
    """
    # Object exposing `.feature_extractor.pad(...)` and `.tokenizer.pad(...)`.
    processor: Any
    # Token id that is stripped from the front of the labels when every row
    # of the batch starts with it.
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Build a batch dict containing the padded inputs plus a "labels" tensor."""
        # Inputs and labels have different lengths and padding rules, so they
        # are gathered into two separate lists and padded separately.
        audio_inputs = []
        label_inputs = []
        for feature in features:
            audio_inputs.append({"input_features": feature["input_features"]})
            label_inputs.append({"input_ids": feature["labels"]})

        batch = self.processor.feature_extractor.pad(audio_inputs, return_tensors="pt")
        padded_labels = self.processor.tokenizer.pad(label_inputs, return_tensors="pt")

        # Positions outside the attention mask are padding; mark them with
        # -100 so that they are ignored by the loss.
        is_padding = padded_labels.attention_mask.ne(1)
        labels = padded_labels["input_ids"].masked_fill(is_padding, -100)

        # Drop the leading start token only when *every* row begins with it,
        # since it gets added again later anyway.
        # NOTE(review): when only some rows of a batch carry a prompt prefix,
        # the first tokens differ and nothing is stripped — confirm that this
        # is the intended behaviour.
        every_row_starts_with_bos = bool(
            (labels[:, 0] == self.decoder_start_token_id).all())
        if every_row_starts_with_bos:
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

# Build the collator used by the trainer; the start token to strip from the
# labels is taken from the loaded model's config.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor, decoder_start_token_id=model.config.decoder_start_token_id)

Read in prompts: preceding text which is used to guide the model.

In [13]:
SALT_LANGUAGES = ['eng', 'ach', 'lgg', 'lug', 'nyn', 'teo']

# Load the sentence table and the prompt table, then pair them up by `id`
# (rows without a match on both sides are dropped by the inner join).
sentences = datasets.load_dataset(
    'Sunbird/salt', 'text-all', split='train').to_pandas()
prompts = datasets.load_dataset(
    'Sunbird/prompts', split='train').to_pandas()
joined = sentences.merge(prompts, on='id', how='inner')

# For every language: sentence text -> prompt text.
sentence_to_prompt = {
    language: dict(
        zip(joined[f'{language}_text'], joined[f'{language}_prompt']))
    for language in SALT_LANGUAGES
}

README.md:   0%|          | 0.00/9.98k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/6.89M [00:00<?, ?B/s]

dev-00000-of-00001.parquet:   0%|          | 0.00/163k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/171k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/23947 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/496 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/500 [00:00<?, ? examples/s]

README.md:   0%|          | 0.00/509 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/14.8M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/23947 [00:00<?, ? examples/s]

In [43]:
# Mapping from SALT languages to Whisper language tokens
language_id_tokens = {
    'eng': 50259,
    'ach': 50357,
    'lgg': 50356,
    'lug': 50355,
    'nyn': 50354,
    'teo': 50353,
}

def prepare_dataset(example, p_prompt = 0.5):
    """Convert one raw example into model inputs and label ids.

    Args:
        example: Mapping with the keys "source" (audio samples), "target"
            (transcription text), "source.language" and "target.language".
        p_prompt: Probability of prepending the known prompt for this
            sentence (if any) to the labels.

    Returns:
        Dict with "input_features" (log-Mel features), "labels" (numpy array
        of token ids) and the two language fields copied from `example`.

    Raises:
        KeyError: If "target.language" is not a key of `language_id_tokens`
            or of `sentence_to_prompt`.
    """
    # Extract the audio data from the 'source' key
    audio = example["source"]
    target_language = example["target.language"]

    # Compute log-Mel input features from the audio array. Use the GPU when
    # one is available, but fall back to the CPU instead of failing outright.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    input_features = feature_extractor(
        audio, sampling_rate=16000, device=device,
        do_normalize=True).input_features[0]

    # Encode target text to label ids
    labels = processor.tokenizer(str(example["target"])).input_ids

    # Insert the language ID token into the second position of the sequence.
    labels.insert(1, language_id_tokens[target_language])

    # If a prompt is known for a particular sentence, add it to the
    # training example with probability `p_prompt`. Non-string placeholders
    # (e.g. a NaN left over from the table join) are truthy but are not valid
    # prompts, so they are skipped. The random draw only happens when a usable
    # prompt exists, exactly as before.
    prompt = sentence_to_prompt[target_language].get(example["target"], None)
    if isinstance(prompt, str) and prompt and np.random.random() < p_prompt:
        prompt_ids = list(processor.get_prompt_ids(prompt))
        labels = prompt_ids + labels

    # Create a new dictionary with the processed data
    processed_example = {
        "input_features": input_features,
        "labels": np.array(labels),
        "source.language": example["source.language"],
        "target.language": target_language,
    }

    return processed_example

In [44]:
# Apply `prepare_dataset` to both splits, dropping the raw audio and text
# columns once features and labels have been derived from them.
# NOTE(review): the validation split also goes through the default
# `p_prompt=0.5`, so its labels randomly include prompts — confirm intended.
train_data = train_ds.map(prepare_dataset, remove_columns=["source", "target"])
val_data = valid_ds.map(prepare_dataset, remove_columns=["source", "target"])

In [45]:
# Evaluation callback for the trainer: built from the validation set with the
# WER and CER metrics, logging the first five predictions on each evaluation.
compute_metrics = salt.metrics.multilingual_eval_fn(
    valid_ds,
    [evaluate.load(metric_name) for metric_name in ('wer', 'cer')],
    processor.tokenizer,
    log_first_N_predictions=5,
    speech_processor=processor,
)

Downloading builder script:   0%|          | 0.00/4.49k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/5.60k [00:00<?, ?B/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Map:   0%|          | 0/103 [00:00<?, ? examples/s]

Map:   0%|          | 0/101 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Map:   0%|          | 0/95 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [46]:
# Clear any forced decoder ids and the token suppression list stored on the
# model config.
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

# Optionally wrap the model with LoRA adapters, using the hyper-parameters
# from the `lora_config` section of the config. The statements below are
# kept in this exact order.
if config['use_peft']:
    model = peft.prepare_model_for_kbit_training(model)
    lora_config = peft.LoraConfig(**config['lora_config'])
    model.enable_input_require_grads()
    model = peft.get_peft_model(model, lora_config)
    # NOTE(review): presumably disabled because caching is incompatible with
    # gradient checkpointing (see the trainer warning) — confirm.
    model.config.use_cache = False
    # Report how many parameters remain trainable after wrapping.
    model.print_trainable_parameters()

Launch the training

In [47]:
# File the MLflow run under the experiment named in the config. Without this
# call the configured `mlflow_experiment_name` is never used and the run ends
# up in MLflow's default experiment.
if use_mlflow:
    mlflow.set_experiment(config['mlflow_experiment_name'])

# Training arguments come straight from the config; `report_to` lists only
# the tracking platforms that were enabled at the top of the notebook.
training_args = transformers.Seq2SeqTrainingArguments(
  **config["training_args"],
  report_to= [
      platform for platform, use in [("wandb", use_wandb), ("mlflow", use_mlflow)] if use]
)

trainer = transformers.Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

trainer.train()

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)
max_steps is given, it will override any value given in num_train_epochs
2024/09/17 15:58:01 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


Step,Training Loss,Validation Loss,Wer Eng,Wer Lug,Wer Ach,Wer Lgg,Wer Teo,Wer Nyn,Wer Mean,Cer Eng,Cer Lug,Cer Ach,Cer Lgg,Cer Teo,Cer Nyn,Cer Mean
100,2.3622,0.564442,0.151,0.135,0.258,0.394,0.274,0.378,0.265,0.149,0.03,0.067,0.117,0.08,0.079,0.087
200,1.3837,0.522994,0.014,0.116,0.244,0.372,0.266,0.337,0.225,0.005,0.031,0.064,0.099,0.081,0.075,0.059
300,1.2559,0.48994,0.027,0.094,0.227,0.366,0.254,0.35,0.22,0.018,0.023,0.057,0.096,0.074,0.081,0.058
400,1.1864,0.480566,0.014,0.093,0.223,0.363,0.254,0.316,0.211,0.004,0.021,0.057,0.093,0.068,0.067,0.052
500,1.1603,0.484391,0.014,0.094,0.222,0.359,0.244,0.319,0.209,0.004,0.023,0.056,0.094,0.064,0.068,0.052


Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English.This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


First N predictions in eval set:
Prediction (eng to eng): "There are a number of wealth creation programs around agriculture.", True label: "There are a number of wealth creation programs around agriculture."
Prediction (eng to eng): "Nkubaayo abantu baingi, abakungu baingi, abakungu baingi baingi baingi baingi baingi baingi baingi baingi baingi baingi baingi baingi baingi baingi baingi baingi baingi baingi baingi baingi baingi baingi baingi baingi baing", True label: "They get information on, planting time, good crop agronomy and post-harvest practices."
Prediction (eng to eng): "Thermites have become a very big issue in this garden.", True label: "Termites have become a very big issue in this garden."
Prediction (eng to eng): "The leaves of the plant have been affected by the disease.", True label: "The leaves of the plant have been affected by the disease."
Prediction (eng to eng): "The whole world is in a pandemic.", True label: "The whole world is in a pandemic."


Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [], 'begin_suppress_tokens': [220, 50257]}
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
HTTP Error 500 thrown while requesting PUT https://hf-hub-lfs-us-east-1.s3-accelerate.amazonaws.com/repos/0a/cc/0accaa0f874904f8d21297265b13c1ace11ad98e39a745633610d07baa31d2d1/76b7b5f8fec0fe54afc471d5d23b342afe4bd347bbe837b2a7c184fafa23e5b1?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Content-Sha256=UNSIGNED-PAYLOAD&X-Amz-Credential=AKIA2JU7TKAQLC2QXPN7%2F20240917%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240917T170724Z&X-Amz-Expires=86400&X-Amz-Signature=39798f2c7fc75a2cff4e2535075af4f83cbdc2b7e2d74ba427d89f3b44d774d8&X-Amz-SignedHeaders=host&partNumber=101&uploadId=bBQ2gDuHsE3hfruMCBXY0qFbY9U7wXXWJT.dHLRPXYBQBDX_pyTLJR.IIgW3rKj93hl5LA5o3MdblT59JrtE9zExb1bEcfbmhuwADmxB6I17bgWxM16leYzJChK3FDFY&x-id=UploadPart
Retrying in 1s [Retry 1/5].
HTTP Err

Generating train split:   0%|          | 0/198 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/47 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/328 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/27 [00:00<?, ? examples/s]

First N predictions in eval set:
Prediction (eng to eng): "There are a number of wealth creation programs around agriculture.", True label: "There are a number of wealth creation programs around agriculture."
Prediction (eng to eng): "They get information on planting time, good crop agronomy and post-harvest practices.", True label: "They get information on, planting time, good crop agronomy and post-harvest practices."
Prediction (eng to eng): "Thermites have become a very big issue in this garden.", True label: "Termites have become a very big issue in this garden."
Prediction (eng to eng): "The leaves of the plant have been affected by the disease.", True label: "The leaves of the plant have been affected by the disease."
Prediction (eng to eng): "The whole world is in a pandemic.", True label: "The whole world is in a pandemic."


Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [], 'begin_suppress_tokens': [220, 50257]}
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


First N predictions in eval set:
Prediction (eng to eng): "There are a number of wealth creation programs around agriculture.", True label: "There are a number of wealth creation programs around agriculture."
Prediction (eng to eng): "They get information on planting time, good crop agronomy, and post-harvest practices.", True label: "They get information on, planting time, good crop agronomy and post-harvest practices."
Prediction (eng to eng): "Thermites have become a very big issue in this garden.", True label: "Termites have become a very big issue in this garden."
Prediction (eng to eng): "The leaves of the plant have been affected by the disease.", True label: "The leaves of the plant have been affected by the disease."
Prediction (eng to eng): "The whole world is in a pandemic.", True label: "The whole world is in a pandemic."


Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [], 'begin_suppress_tokens': [220, 50257]}
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


First N predictions in eval set:
Prediction (eng to eng): "There are a number of wealth creation programs around agriculture.", True label: "There are a number of wealth creation programs around agriculture."
Prediction (eng to eng): "They get information on planting time, good crop agronomy, and post-harvest practices.", True label: "They get information on, planting time, good crop agronomy and post-harvest practices."
Prediction (eng to eng): "Thermites have become a very big issue in this garden.", True label: "Termites have become a very big issue in this garden."
Prediction (eng to eng): "The leaves of the plant have been affected by the disease.", True label: "The leaves of the plant have been affected by the disease."
Prediction (eng to eng): "The whole world is in a pandemic.", True label: "The whole world is in a pandemic."


Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [], 'begin_suppress_tokens': [220, 50257]}
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


First N predictions in eval set:
Prediction (eng to eng): "There are a number of wealth creation programs around agriculture.", True label: "There are a number of wealth creation programs around agriculture."
Prediction (eng to eng): "They get information on planting time, good crop agronomy, and post-harvest practices.", True label: "They get information on, planting time, good crop agronomy and post-harvest practices."
Prediction (eng to eng): "Thermites have become a very big issue in this garden.", True label: "Termites have become a very big issue in this garden."
Prediction (eng to eng): "The leaves of the plant have been affected by the disease.", True label: "The leaves of the plant have been affected by the disease."
Prediction (eng to eng): "The whole world is in a pandemic.", True label: "The whole world is in a pandemic."


Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [], 'begin_suppress_tokens': [220, 50257]}
There were missing keys in the checkpoint model loaded: ['proj_out.weight'].
2024/09/17 21:42:11 INFO mlflow.tracking._tracking_service.client: 🏃 View run whisper-large-v2-multilingual at: https://mlflow-sunbird-ce0ecfc14244.herokuapp.com/#/experiments/0/runs/d79ffa675d12482ea3a5f3fe435a080e.
2024/09/17 21:42:11 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: https://mlflow-sunbird-ce0ecfc14244.herokuapp.com/#/experiments/0.
2024/09/17 21:42:12 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2024/09/17 21:42:12 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!


TrainOutput(global_step=500, training_loss=1.469713134765625, metrics={'train_runtime': 20649.1928, 'train_samples_per_second': 6.199, 'train_steps_per_second': 0.024, 'total_flos': 2.717149345579008e+20, 'train_loss': 1.469713134765625, 'epoch': 4.0855})

Log the config settings for reference

In [48]:
if use_mlflow:
    # By the time this cell runs, the training run has already been closed,
    # so a bare `log_params` would silently open a brand-new run and attach
    # the config to that instead. Re-open the most recent run so that the
    # config is stored next to the training metrics.
    last_run = mlflow.last_active_run()
    if mlflow.active_run() is None and last_run is not None:
        with mlflow.start_run(run_id=last_run.info.run_id):
            mlflow.log_params(config)
    else:
        # Either a run is still active (log into it), or no run has been
        # recorded in this session (keep the original behaviour).
        mlflow.log_params(config)

2024/09/17 22:04:06 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.


Save the full model (not just the adapter weights)

In [50]:
hub_model_id = config['training_args']['hub_model_id']

# When LoRA is enabled, `model` is a PEFT wrapper, and pushing the wrapper
# would upload the adapter weights only. Fold the adapter into the base
# weights first so that the full model is what gets published.
model_to_push = model.merge_and_unload() if config['use_peft'] else model

processor.push_to_hub(hub_model_id)
model_to_push.push_to_hub(hub_model_id)

README.md:   0%|          | 0.00/5.17k [00:00<?, ?B/s]

No files have been modified since last commit. Skipping to prevent empty commit.
Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [], 'begin_suppress_tokens': [220, 50257]}


CommitInfo(commit_url='https://huggingface.co/jq/whisper-large-v2-multilingual-prompts-corrected/commit/4461b1268e28451d8056f92e6fb78231747cd8cf', commit_message='Upload WhisperForConditionalGeneration', commit_description='', oid='4461b1268e28451d8056f92e6fb78231747cd8cf', pr_url=None, pr_revision=None, pr_num=None)

Try running the model on the first validation example

In [None]:
# Take the first example of the validation set and compute its input features.
example = next(iter(valid_ds))
input_features = processor(
    example["source"], sampling_rate=16000, return_tensors="pt").input_features

# Move the features to whichever device the model lives on, rather than
# assuming a GPU, so that this cell also works on a CPU-only kernel.
with torch.no_grad():
    predicted_ids = model.generate(input_features.to(model.device))[0]

# Decode the generated ids back to text (special tokens are kept).
transcription = processor.decode(predicted_ids)
print(transcription)