Works on an RTX 6000 Ada 48GB (per-device batch size 2) and an H100 80GB (per-device batch size 16).

In [None]:
# Install the notebook's runtime dependencies. Only jiwer carries a version pin;
# every other package resolves to whatever is current at install time.
!pip install -q jiwer==3.1.0
!pip install -q evaluate
!pip install -qU accelerate
!pip install -Uq torch
!pip install -q transformers[torch]
!pip install -q soundfile
# The SALT library is cloned from GitHub and its requirements are installed from the checkout.
!git clone https://github.com/SunbirdAI/salt.git
!pip install -qr salt/requirements.txt
!pip install -q peft
!pip install -q torchaudio torchvision

In [None]:
# Experiment-tracking backends to enable. The same flags are used later in the
# notebook to build the `report_to` list for the training arguments.
use_wandb = False
use_mlflow = True

import importlib.metadata
# Names of every distribution installed in the current environment; used below
# to skip the mlflow install when it is already present.
installed = [
    dist.metadata['Name']
    for dist in importlib.metadata.distributions()
]

if use_wandb:
  !pip install -q wandb
  import wandb
  %set_env WANDB_LOG_MODEL=True
  %set_env WANDB_WATCH=all
  %set_env WANDB_NOTEBOOK_NAME=whisper_base_en_sb.ipynb
  wandb.login()

if use_mlflow:
  if 'mlflow' not in installed:
      !pip install -q mlflow
  # psutil and pynvml are the requirements for logging system/GPU metrics in
  # mlflow; they are installed whether or not mlflow itself was already present.
  !pip install -q psutil
  !pip install -q pynvml
  import os
  from getpass import getpass
  import mlflow
  import mlflow.pytorch
  from mlflow import MlflowClient

  # Set MLflow tracking credentials. They are prompted for interactively so
  # that they never appear in the notebook source.
  MLFLOW_TRACKING_USERNAME = getpass('Enter the MLFLOW_TRACKING_USERNAME: ')
  os.environ['MLFLOW_TRACKING_USERNAME'] = MLFLOW_TRACKING_USERNAME

  MLFLOW_TRACKING_PASSWORD = getpass('Enter the MLFLOW_TRACKING_PASSWORD: ')
  os.environ['MLFLOW_TRACKING_PASSWORD'] = MLFLOW_TRACKING_PASSWORD
  # NOTE(review): the YAML config in this notebook sets mlflow_experiment_name
  # to "stt-whisper", which differs from the value below — confirm which one
  # is intended.
  os.environ["MLFLOW_EXPERIMENT_NAME"] = "kinyarwanda-asr"

  # Set the MLflow tracking URI
  mlflow.set_tracking_uri('https://mlflow-sunbird-ce0ecfc14244.herokuapp.com/')
  mlflow.system_metrics.enable_system_metrics_logging()

In [36]:
import torch
import transformers
from dataclasses import dataclass, field
from typing import Union, List, Dict, Any
import string
import os
import json
import datasets
import numpy as np
import yaml
import evaluate
import salt.dataset
import salt.metrics
import salt.constants
from salt.utils import DataCollatorCTCWithPadding as dcwp
import huggingface_hub
import peft
import pandas as pd
import tqdm.notebook as tqdm

In [5]:
# Authenticate with the Hugging Face Hub through the interactive notebook widget.
huggingface_hub.notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [6]:
# In case SALT library is modified and has to be reloaded:
# !rm -rf salt
# !git clone https://github.com/jqug/salt.git
#from importlib import reload
#reload(salt.dataset)

In [7]:
yaml_config = f'''
pretrained_model: openai/whisper-large-v3
mlflow_experiment_name : stt-whisper
num_workers: 8
use_peft: False
lora_config:
    r: 32
    lora_alpha: 64
    target_modules: ["q_proj", "v_proj"]
    lora_dropout: 0.05
    bias: "none"

training_args:
    output_dir: whisper-large-v3-multilingual
    per_device_train_batch_size: 16
    per_device_eval_batch_size: 16
    gradient_accumulation_steps: 4  # increase by 2x for every 2x decrease in batch size
    learning_rate: 1.0e-5
    warmup_steps: 100
    max_steps: 20000
    gradient_checkpointing: True
    gradient_checkpointing_kwargs:
      use_reentrant: True
    fp16: True
    eval_strategy: steps
    predict_with_generate: True
    generation_max_length: 200
    save_steps: 1000
    eval_steps: 200 # Was 250
    logging_steps: 200
    load_best_model_at_end: True
    metric_for_best_model: loss
    greater_is_better: False
    push_to_hub: False
    hub_model_id: jq/whisper-large-v3-kin-nyn-lug-xog
    save_total_limit: 2
    
train:
    download_datasets_in_parallel: True
    huggingface_load:
        # Main challenge dataset
        # Keep some for validation while training
        - path: jq/kinyarwanda-speech-hackathon
          split: train[:-100]
        # Yogera open data in related languages
        - path: Sunbird/external-speech-data
          name: common-voice-sample-packed-lug
        # - path: Sunbird/external-speech-data
        #   name: common-voice-sample-packed-swa
        #   split: train[:-25]
        - path: Sunbird/external-speech-data
          name: common-voice-sample-packed-kin
          split: train[:-25]
        - path: Sunbird/external-speech-data
          name: makerere-radio-speech
        # - path: Sunbird/external-speech-data
        #   name: makerere-yogera-ach
        - path: Sunbird/external-speech-data
          name: makerere-yogera-lug
        - path: Sunbird/external-speech-data
          name: makerere-yogera-nyn
        # # Save some myx and xog data for validation
        # - path: Sunbird/external-speech-data
        #   name: makerere-yogera-myx
        #   split: train[:-100]
        - path: Sunbird/external-speech-data
          name: makerere-yogera-xog
          split: train[:-100]
        # Non-open datasets excluded from hackathon
        # # Call centre data
        # - path: Sunbird/salt-ucfd
        #   name: eng
        #   split: train
        # - path: Sunbird/salt-ucfd
        #   name: lug
        #   split: train    
        # - path: Sunbird/salt-ucfd
        #   name: numbers-eng
        #   split: train
        # - path: Sunbird/salt-ucfd
        #   name: numbers-lug
        #   split: train  
        # - path: Sunbird/salt-tracfm
        #   name: lug
        #   split: train
        # Main SALT ASR training data
        - path: Sunbird/salt
          name: multispeaker-lug
          split: train
        - path: Sunbird/salt
          name: multispeaker-eng
          split: train
        # - path: Sunbird/salt
        #   name: multispeaker-ach
        #   split: train
        # - path: Sunbird/salt
        #   name: multispeaker-lgg
        #   split: train
        # - path: Sunbird/salt
        #   name: multispeaker-teo
        #   split: train
        - path: Sunbird/salt
          name: multispeaker-nyn
          split: train
        # Google FLEURS
        - path: google/fleurs
          split: train
          name: lg_ug
          trust_remote_code: True
        # - path: google/fleurs
        #   split: train
        #   name: sw_ke
        #   trust_remote_code: True
    source:
      type: speech
      language: [lug,eng,nyn,kin]
      preprocessing:
        # Downsample some examples to 8KHz (to simulate phone audio) 
        - set_sample_rate:
            rate: 8_000
            p: 0.1
        # Then upsample again
        - set_sample_rate:
            rate: 16_000
        - normalize_audio
        - augment_audio_speed:
            p: 0.2
            low: 0.95
            high: 1.15
        - augment_audio_noise:
            max_relative_amplitude: 0.5
            noise_audio_repo:
                path: Sunbird/urban-noise
                name: small
                split: train       
    target:
      type: text
      preprocessing:
        - ensure_text_ends_with_punctuation
      language: [lug,eng,nyn,kin]
    shuffle: True
validation:
    huggingface_load:
        # Held-out challenge data for validation
        - path: jq/kinyarwanda-speech-hackathon
          split: train[-100:]
        # SALT test data
        # - path: Sunbird/salt
        #   name: multispeaker-eng
        #   split: dev
        - path: Sunbird/salt
          name: multispeaker-lug
          split: dev
        # - path: Sunbird/salt
        #   name: multispeaker-ach
        #   split: dev
        # - path: Sunbird/salt
        #   name: multispeaker-lgg
        #   split: dev
        # - path: Sunbird/salt
        #   name: multispeaker-teo
        #   split: dev
        # - path: Sunbird/salt
        #   name: multispeaker-nyn
        #   split: dev
        # - path: Sunbird/external-speech-data
        #   name: makerere-yogera-myx
        #   split: train[-100:]
        # - path: Sunbird/external-speech-data
        #   name: makerere-yogera-xog
        #   split: train[-100:]
        # - path: Sunbird/external-speech-data
        #   name: common-voice-sample-packed-swa
        #   split: train[-25:]
        # - path: Sunbird/external-speech-data
        #   name: common-voice-sample-packed-kin
        #   split: train[-25:]
    source:
      type: speech
      language: [lug,kin]
      preprocessing:
        - set_sample_rate:
            rate: 16_000
    target:
      type: text
      language: [lug,kin]
'''

# Parse the YAML text into a nested dict, then build the training and
# validation datasets from their respective sections of the config.
config = yaml.safe_load(yaml_config)
train_ds = salt.dataset.create(config['train'], verbose=True)
valid_ds = salt.dataset.create(config['validation'])

In [8]:
# If needed, pre-load the main challenge dataset with multiple download workers
# ds = datasets.load_dataset('jq/kinyarwanda-speech-hackathon', split='train', num_proc=10)

In [9]:
# Preview 10 training examples, treating the 'source' column as audio.
salt.utils.show_dataset(train_ds, audio_features=['source'], N=10)

README.md:   0%|          | 0.00/13.3k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/4.32k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/700 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/13.2k [00:00<?, ?B/s]

fleurs.py:   0%|          | 0.00/12.5k [00:00<?, ?B/s]

train.tar.gz:   0%|          | 0.00/2.19G [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/76 [00:00<?, ?it/s]

train-00000-of-00003.parquet:   0%|          | 0.00/504M [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/475M [00:00<?, ?B/s]

train-00000-of-00003.parquet:   0%|          | 0.00/504M [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/435M [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/211M [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/346M [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/66.1M [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/77.9M [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/76 [00:00<?, ?it/s]

train-00000-of-00001.parquet:   0%|          | 0.00/56.3M [00:00<?, ?B/s]

dev-00000-of-00001.parquet:   0%|          | 0.00/1.35M [00:00<?, ?B/s]

Downloading data:   0%|          | 0/76 [00:00<?, ?files/s]

test-00000-of-00001.parquet:   0%|          | 0.00/1.41M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/5002 [00:00<?, ? examples/s]

dev-00000-of-00001.parquet:   0%|          | 0.00/1.62M [00:00<?, ?B/s]

dev-00000-of-00001.parquet:   0%|          | 0.00/1.18M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/10867 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/103 [00:00<?, ? examples/s]

train-00000-of-00076.parquet:   0%|          | 0.00/501M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/99 [00:00<?, ? examples/s]

test-00000-of-00001.parquet:   0%|          | 0.00/1.68M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/1.25M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/4884 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/4804 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/103 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/101 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/96 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/99 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/5402 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/6864 [00:00<?, ? examples/s]

train-00001-of-00003.parquet:   0%|          | 0.00/504M [00:00<?, ?B/s]

train-00001-of-00003.parquet:   0%|          | 0.00/504M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7639 [00:00<?, ? examples/s]

train-00001-of-00076.parquet:   0%|          | 0.00/495M [00:00<?, ?B/s]

train-00002-of-00003.parquet:   0%|          | 0.00/506M [00:00<?, ?B/s]

train-00002-of-00003.parquet:   0%|          | 0.00/506M [00:00<?, ?B/s]

train-00002-of-00076.parquet:   0%|          | 0.00/495M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/18489 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/18489 [00:00<?, ? examples/s]

train-00003-of-00076.parquet:   0%|          | 0.00/498M [00:00<?, ?B/s]

dev.tar.gz:   0%|          | 0.00/244M [00:00<?, ?B/s]

train-00004-of-00076.parquet:   0%|          | 0.00/499M [00:00<?, ?B/s]

test.tar.gz:   0%|          | 0.00/596M [00:00<?, ?B/s]

train-00005-of-00076.parquet:   0%|          | 0.00/496M [00:00<?, ?B/s]

train-00006-of-00076.parquet:   0%|          | 0.00/496M [00:00<?, ?B/s]

train-00007-of-00076.parquet:   0%|          | 0.00/496M [00:00<?, ?B/s]

train-00008-of-00076.parquet:   0%|          | 0.00/499M [00:00<?, ?B/s]

train-00009-of-00076.parquet:   0%|          | 0.00/493M [00:00<?, ?B/s]

train-00010-of-00076.parquet:   0%|          | 0.00/496M [00:00<?, ?B/s]

train-00011-of-00076.parquet:   0%|          | 0.00/500M [00:00<?, ?B/s]

train.tsv:   0%|          | 0.00/1.37M [00:00<?, ?B/s]

dev.tsv:   0%|          | 0.00/171k [00:00<?, ?B/s]

test.tsv:   0%|          | 0.00/417k [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

train-00012-of-00076.parquet:   0%|          | 0.00/499M [00:00<?, ?B/s]

train-00013-of-00076.parquet:   0%|          | 0.00/494M [00:00<?, ?B/s]

train-00014-of-00076.parquet:   0%|          | 0.00/499M [00:00<?, ?B/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

train-00015-of-00076.parquet:   0%|          | 0.00/498M [00:00<?, ?B/s]

train-00016-of-00076.parquet:   0%|          | 0.00/496M [00:00<?, ?B/s]

train-00017-of-00076.parquet:   0%|          | 0.00/498M [00:00<?, ?B/s]

train-00018-of-00076.parquet:   0%|          | 0.00/495M [00:00<?, ?B/s]

train-00019-of-00076.parquet:   0%|          | 0.00/497M [00:00<?, ?B/s]

train-00020-of-00076.parquet:   0%|          | 0.00/492M [00:00<?, ?B/s]

train-00021-of-00076.parquet:   0%|          | 0.00/497M [00:00<?, ?B/s]

train-00022-of-00076.parquet:   0%|          | 0.00/495M [00:00<?, ?B/s]

train-00023-of-00076.parquet:   0%|          | 0.00/496M [00:00<?, ?B/s]

train-00024-of-00076.parquet:   0%|          | 0.00/492M [00:00<?, ?B/s]

train-00025-of-00076.parquet:   0%|          | 0.00/494M [00:00<?, ?B/s]

train-00026-of-00076.parquet:   0%|          | 0.00/497M [00:00<?, ?B/s]

train-00027-of-00076.parquet:   0%|          | 0.00/498M [00:00<?, ?B/s]

train-00028-of-00076.parquet:   0%|          | 0.00/500M [00:00<?, ?B/s]

train-00029-of-00076.parquet:   0%|          | 0.00/498M [00:00<?, ?B/s]

train-00030-of-00076.parquet:   0%|          | 0.00/500M [00:00<?, ?B/s]

train-00031-of-00076.parquet:   0%|          | 0.00/495M [00:00<?, ?B/s]

train-00032-of-00076.parquet:   0%|          | 0.00/495M [00:00<?, ?B/s]

train-00033-of-00076.parquet:   0%|          | 0.00/499M [00:00<?, ?B/s]

train-00034-of-00076.parquet:   0%|          | 0.00/497M [00:00<?, ?B/s]

train-00035-of-00076.parquet:   0%|          | 0.00/496M [00:00<?, ?B/s]

train-00036-of-00076.parquet:   0%|          | 0.00/498M [00:00<?, ?B/s]

train-00037-of-00076.parquet:   0%|          | 0.00/498M [00:00<?, ?B/s]

train-00038-of-00076.parquet:   0%|          | 0.00/500M [00:00<?, ?B/s]

train-00039-of-00076.parquet:   0%|          | 0.00/497M [00:00<?, ?B/s]

train-00040-of-00076.parquet:   0%|          | 0.00/496M [00:00<?, ?B/s]

train-00041-of-00076.parquet:   0%|          | 0.00/497M [00:00<?, ?B/s]

train-00042-of-00076.parquet:   0%|          | 0.00/497M [00:00<?, ?B/s]

train-00043-of-00076.parquet:   0%|          | 0.00/500M [00:00<?, ?B/s]

train-00044-of-00076.parquet:   0%|          | 0.00/494M [00:00<?, ?B/s]

train-00045-of-00076.parquet:   0%|          | 0.00/496M [00:00<?, ?B/s]

train-00046-of-00076.parquet:   0%|          | 0.00/496M [00:00<?, ?B/s]

train-00047-of-00076.parquet:   0%|          | 0.00/498M [00:00<?, ?B/s]

train-00048-of-00076.parquet:   0%|          | 0.00/495M [00:00<?, ?B/s]

train-00049-of-00076.parquet:   0%|          | 0.00/497M [00:00<?, ?B/s]

train-00050-of-00076.parquet:   0%|          | 0.00/495M [00:00<?, ?B/s]

train-00051-of-00076.parquet:   0%|          | 0.00/499M [00:00<?, ?B/s]

train-00052-of-00076.parquet:   0%|          | 0.00/494M [00:00<?, ?B/s]

train-00053-of-00076.parquet:   0%|          | 0.00/496M [00:00<?, ?B/s]

train-00054-of-00076.parquet:   0%|          | 0.00/494M [00:00<?, ?B/s]

train-00055-of-00076.parquet:   0%|          | 0.00/493M [00:00<?, ?B/s]

train-00056-of-00076.parquet:   0%|          | 0.00/498M [00:00<?, ?B/s]

train-00057-of-00076.parquet:   0%|          | 0.00/496M [00:00<?, ?B/s]

train-00058-of-00076.parquet:   0%|          | 0.00/500M [00:00<?, ?B/s]

train-00059-of-00076.parquet:   0%|          | 0.00/495M [00:00<?, ?B/s]

train-00060-of-00076.parquet:   0%|          | 0.00/497M [00:00<?, ?B/s]

train-00061-of-00076.parquet:   0%|          | 0.00/506M [00:00<?, ?B/s]

train-00062-of-00076.parquet:   0%|          | 0.00/515M [00:00<?, ?B/s]

train-00063-of-00076.parquet:   0%|          | 0.00/514M [00:00<?, ?B/s]

train-00064-of-00076.parquet:   0%|          | 0.00/514M [00:00<?, ?B/s]

train-00065-of-00076.parquet:   0%|          | 0.00/514M [00:00<?, ?B/s]

train-00066-of-00076.parquet:   0%|          | 0.00/513M [00:00<?, ?B/s]

train-00067-of-00076.parquet:   0%|          | 0.00/513M [00:00<?, ?B/s]

train-00068-of-00076.parquet:   0%|          | 0.00/514M [00:00<?, ?B/s]

train-00069-of-00076.parquet:   0%|          | 0.00/512M [00:00<?, ?B/s]

train-00070-of-00076.parquet:   0%|          | 0.00/516M [00:00<?, ?B/s]

train-00071-of-00076.parquet:   0%|          | 0.00/499M [00:00<?, ?B/s]

train-00072-of-00076.parquet:   0%|          | 0.00/492M [00:00<?, ?B/s]

train-00073-of-00076.parquet:   0%|          | 0.00/492M [00:00<?, ?B/s]

train-00074-of-00076.parquet:   0%|          | 0.00/492M [00:00<?, ?B/s]

train-00075-of-00076.parquet:   0%|          | 0.00/491M [00:00<?, ?B/s]

test-00000-of-00003.parquet:   0%|          | 0.00/428M [00:00<?, ?B/s]

test-00001-of-00003.parquet:   0%|          | 0.00/443M [00:00<?, ?B/s]

test-00002-of-00003.parquet:   0%|          | 0.00/454M [00:00<?, ?B/s]

dev_test-00000-of-00003.parquet:   0%|          | 0.00/429M [00:00<?, ?B/s]

dev_test-00001-of-00003.parquet:   0%|          | 0.00/438M [00:00<?, ?B/s]

dev_test-00002-of-00003.parquet:   0%|          | 0.00/452M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/264754 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/9265 [00:00<?, ? examples/s]

Generating dev_test split:   0%|          | 0/9263 [00:00<?, ? examples/s]

Loading dataset shards:   0%|          | 0/75 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/76 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/76 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/75 [00:00<?, ?it/s]

jq/kinyarwanda-speech-hackathon: 264654 rows
Sunbird/external-speech-data_common-voice-sample-packed-lug: 18489 rows
Sunbird/external-speech-data_common-voice-sample-packed-kin: 18464 rows
Sunbird/external-speech-data_makerere-radio-speech: 10867 rows
Sunbird/external-speech-data_makerere-yogera-lug: 5402 rows
Sunbird/external-speech-data_makerere-yogera-nyn: 7639 rows
Sunbird/external-speech-data_makerere-yogera-xog: 6764 rows
Sunbird/salt_multispeaker-lug: 5002 rows
Sunbird/salt_multispeaker-eng: 4804 rows
Sunbird/salt_multispeaker-nyn: 4884 rows
google/fleurs_lg_ug: 2478 rows
Total rows: 349447


README.md:   0%|          | 0.00/2.03k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/48.5M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Unnamed: 0,source,target,source.language,target.language
0,Your browser does not support the audio element.,Abantu batatu ni ukuvuga abana babiri n'umuntuntu mukuru w'umukobwa yaje kuvoma amazi ateze ijerekani.,kin,kin
1,Your browser does not support the audio element.,"Icyumba cy'inama cyagenewe kwakira ibikorwa bitandukanye bihuriza hamwe abantu benshi. Akaba ari mu intebe zifite ibara ry'umukara, ndetse imbere ya buri ntebe hari ameza afite ibara ryikigina, imbere kandi mu nyubako hari ahagenewe gusomerwa, cyangwa gutangirwa imbwirwaruhame hagaragara ameza afite ibara ry'umweru, inyubako kandi irimo indangururamajwi ku buryo abari muri iki cyumba cy'inama bose babasha kumva amajwi arimo kuvugwa, n'uri gutanga imbwirwaruhame.",kin,kin
2,Your browser does not support the audio element.,"Inzu nziza cyane ifite isuku, irimo amarangi y'umweru, parafo z'umweru, hasi hari amakaro y'umukara, ifite inzugi nziza cyane zikomeye zirimo utubara tw'umweru, hagati irimo ababyeyi bagiye bicaye.",kin,kin
3,Your browser does not support the audio element.,"Eby'okurya ebirikwahukana bitebeekanise kurungi kandi byateebwa naha byikaro ebirikubaasa okureteera omuntu yabaasa kugira ekihika ky'okubirya, n'omuntu omwe ayetekanize ari aho haihi kureeba ngu yaabaasa kuba nabiheereza abo abeetekanize kubirya.",nyn,nyn
4,Your browser does not support the audio element.,"Abagabo babiri bahagaze imbere y'igikamyo kinini, umwe akaba yambaye ishati y'umweru n'ipantaro y'ubururu, undi akaba yambaye umupira urimo amabara y'umweru, icyatsi n'umukara, bakaba bafite urufunguzo rwanditseho Akagera mutoro.",kin,kin
5,Your browser does not support the audio element.,"Inyubako ifite parikingi irimo imodoka imwe, iruhande rwayo hariho abantu benshi binjira muri iyo nyubako, utarayigiramo kandi ku gisenge hariho amabati tegura, iruhande rwayo hari umugabo wambaye umupira w'umutuku afite ikayi ari kwandikamo.",kin,kin
6,Your browser does not support the audio element.,"Nindeeba abeegi baboojo na abaishiki bari ahaza karimagyezi, omwojo omwe ajwaire ebirahuri ondiijo omwishiki aha munwa asigireho ebirangi byokutukura akomire ishokye.",nyn,nyn
7,Your browser does not support the audio element.,"Inyubako isize irangi ry'umweru, irimo ibirahuri bibonerana, ndetse iri kugendwamo n'umugabo n'abandi bantu bicaye inyuma y'ibirahuri, harimo n'amatara ari kwaka ndetse n'ameza.",kin,kin
8,Your browser does not support the audio element.,"Ibiganza by'umuganga wambaye uturindantoki, ufite urushinge mu ntoki, uri kurutera undi muntu na we wambaye imyenda y'icyatsi, akaba ari kumutera urushinge kugira ngo amukingire indwara runaka cyangwa se amutera umuti w'indi ndwara runaka.",kin,kin
9,Your browser does not support the audio element.,"Umusore uri mu myitozo ngororamubiri, akaba yambaye agasengeri kari mu ibara ry'umukara ndetse akaba ashyize amaboko inyuma, aho arimo ananura kumwe akoresheje ukundi, bikamufasha kugorora imikaya ye ndetse bikamufasha no kugira ubuzima bwiza.",kin,kin


In [10]:
# All three components are loaded from the checkpoint named in the config.
checkpoint = config['pretrained_model']

feature_extractor = transformers.WhisperFeatureExtractor.from_pretrained(checkpoint)
# No language is fixed on the processor; a language ID token is inserted per
# example when the labels are built.
processor = transformers.WhisperProcessor.from_pretrained(
    checkpoint,
    language=None,
    task="transcribe",
)
model = transformers.WhisperForConditionalGeneration.from_pretrained(checkpoint)

preprocessor_config.json:   0%|          | 0.00/340 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/283k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.48M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/494k [00:00<?, ?B/s]

normalizer.json:   0%|          | 0.00/52.7k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/34.6k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.07k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.27k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.09G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/3.90k [00:00<?, ?B/s]

In [11]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate speech examples into a padded batch with loss-masked labels.

    Audio features and label token IDs are padded separately (by the
    processor's feature extractor and tokenizer respectively) because they
    have different lengths and need different padding methods.
    """

    # Object exposing `.feature_extractor.pad` and `.tokenizer.pad`.
    processor: Any
    # Token ID that is stripped from the front of the labels when every
    # sequence in the batch starts with it.
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Build one batch from a list of examples.

        Args:
            features: Examples, each holding "input_features" and "labels".

        Returns:
            The padded audio batch with an added "labels" entry in which
            padded positions are set to -100.
        """
        # Audio side: wrap each example and let the feature extractor pad it.
        audio_items = [{"input_features": item["input_features"]} for item in features]
        batch = self.processor.feature_extractor.pad(audio_items, return_tensors="pt")

        # Text side: pad the token sequences with the tokenizer.
        token_items = [{"input_ids": item["labels"]} for item in features]
        padded_tokens = self.processor.tokenizer.pad(token_items, return_tensors="pt")

        # Positions outside the attention mask become -100 so the loss skips them.
        is_padding = padded_tokens.attention_mask.ne(1)
        labels = padded_tokens["input_ids"].masked_fill(is_padding, -100)

        # Drop the leading start token only when every row begins with it,
        # since it is added again later anyway.
        every_row_has_start = (labels[:, 0] == self.decoder_start_token_id).all().cpu().item()
        batch["labels"] = labels[:, 1:] if every_row_has_start else labels

        return batch

# Collator instance passed to the trainer; the decoder start token ID is read
# from the loaded model's config.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor, decoder_start_token_id=model.config.decoder_start_token_id)

Read in prompts: preceding text which is used to guide the model.

In [12]:
# Load the SALT sentences and their prompts, then join the two tables on `id`
# so each sentence sits next to its prompt.
sentences = datasets.load_dataset('Sunbird/salt', 'text-all', split='train').to_pandas()
prompts = datasets.load_dataset('Sunbird/prompts', split='train').to_pandas()
joined = sentences.merge(prompts, on='id', how='inner')

SALT_PROMPT_LANGUAGES = ['eng', 'ach', 'lgg', 'lug', 'nyn', 'teo']

# Per-language lookup: sentence text -> prompt text.
sentence_to_prompt = {}
for language in SALT_PROMPT_LANGUAGES:
    # English sentences are stored under a differently named column.
    if language == 'eng':
        sentence_key = 'eng_source_text'
    else:
        sentence_key = f'{language}_text'
    sentence_column = joined[sentence_key]
    prompt_column = joined[f'{language}_prompt']
    sentence_to_prompt[language] = dict(zip(sentence_column, prompt_column))

train-00000-of-00001.parquet:   0%|          | 0.00/9.53M [00:00<?, ?B/s]

dev-00000-of-00001.parquet:   0%|          | 0.00/223k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/233k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/23947 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/496 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/500 [00:00<?, ? examples/s]

README.md:   0%|          | 0.00/509 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/14.8M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/23947 [00:00<?, ? examples/s]

In [13]:
language_id_tokens = salt.constants.SALT_LANGUAGE_TOKENS_WHISPER

def prepare_dataset(example, p_prompt = 0.5, device = None):
    """Convert one dataset example into Whisper input features and labels.

    Args:
        example: Mapping with "source" (audio, passed to the feature extractor
            as 16 kHz samples), "target" (transcript text), "source.language"
            and "target.language".
        p_prompt: Probability of prepending the known prompt for this sentence
            to the labels, when one exists.
        device: Device on which the feature extractor runs. When None, 'cuda'
            is used if available and 'cpu' otherwise.

    Returns:
        Dict with "input_features", "labels" (numpy array of token IDs) and
        the two language fields copied from the example.

    Raises:
        KeyError: If the target language has no entry in `language_id_tokens`.
    """
    if device is None:
        # Fall back to CPU so the function also works on machines without a GPU.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    audio = example["source"]
    input_features = feature_extractor(
        audio, sampling_rate=16000, device=device,
        do_normalize=True).input_features[0]

    # Encode target text to label ids
    labels = processor.tokenizer(str(example["target"])).input_ids

    # Insert the language ID token into the second position of the sequence.
    labels.insert(1, language_id_tokens[example["target.language"]])

    # If a prompt is known for a particular sentence, add it to the
    # training example with probability `p_prompt`.
    prompts_for_language = sentence_to_prompt.get(example["target.language"])
    if prompts_for_language is not None:
        prompt = prompts_for_language.get(example["target"], None)
        # Only a non-empty string is a usable prompt. The lookup is built from
        # pandas columns, where a missing cell is NaN — a truthy float that
        # would otherwise be handed to the tokenizer.
        if isinstance(prompt, str) and prompt and np.random.random() < p_prompt:
            prompt_ids = list(processor.get_prompt_ids(prompt))
            labels = prompt_ids + labels

    # Create a new dictionary with the processed data
    processed_example = {
        "input_features": input_features,
        "labels": np.array(labels),
        "source.language": example["source.language"],
        "target.language": example["target.language"]
    }

    return processed_example

In [14]:
# Apply feature extraction and tokenization to both splits, dropping the raw
# "source" and "target" columns afterwards.
# NOTE(review): the validation split is mapped with the default p_prompt=0.5,
# so validation labels may randomly receive prompts too — confirm this is
# intended.
train_data = train_ds.map(prepare_dataset, remove_columns=["source", "target"])
val_data = valid_ds.map(prepare_dataset, remove_columns=["source", "target"])

In [15]:
# Evaluation callback built by SALT from the validation dataset, using the WER
# and CER metrics from `evaluate`. The first 3 predictions are logged.
compute_metrics = salt.metrics.multilingual_eval_fn(
      valid_ds, [evaluate.load('wer'), evaluate.load('cer')],
      processor.tokenizer, log_first_N_predictions=3,
      speech_processor=processor)

Downloading builder script:   0%|          | 0.00/5.13k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/6.61k [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/76 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/76 [00:00<?, ?it/s]

In [16]:
# Clear any forced decoder IDs and suppressed tokens stored on the model config.
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

# Optionally wrap the model with LoRA adapters, using the `lora_config`
# section of the YAML config.
if config['use_peft']:
    # NOTE(review): the model above is loaded without any quantization
    # arguments — confirm that the k-bit preparation step is intended here.
    model = peft.prepare_model_for_kbit_training(model)
    lora_config = peft.LoraConfig(**config['lora_config'])
    model.enable_input_require_grads()
    model = peft.get_peft_model(model, lora_config)
    # presumably disabled because caching conflicts with gradient
    # checkpointing during training — TODO confirm
    model.config.use_cache = False
    model.print_trainable_parameters()

In [17]:
# If there was an interrupted training run, then reset mlflow
#mlflow.end_run()

Launch the training

In [None]:
# Training arguments come from the YAML config; metrics are reported to
# whichever tracking backends were switched on earlier in the notebook.
training_args = transformers.Seq2SeqTrainingArguments(
    report_to=[
        backend
        for backend, enabled in (("wandb", use_wandb), ("mlflow", use_mlflow))
        if enabled
    ],
    **config["training_args"],
)

trainer = transformers.Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_data,
    eval_dataset=val_data,
    compute_metrics=compute_metrics,
    processing_class=processor,
)

trainer.train()

In [None]:
# Post-mortem debugging is disabled by default: a bare `%debug` opens an
# interactive debugger prompt after a failure and stalls "Run All".
# Uncomment the line below temporarily when investigating a failed training run.
# %debug

Log the config settings for reference

In [None]:
# Record the notebook configuration as MLflow parameters.
# NOTE(review): `config` is a nested dict (its train/validation sections hold
# lists and dicts) — confirm MLflow accepts these values and that a run is
# still active at this point.
if use_mlflow:
    mlflow.log_params(config)

In [19]:
# Display the Hub repository ID that the processor and model are pushed to.
config['training_args']['hub_model_id']

'jq/whisper-large-v3-kin-nyn-lug-xog'

Save the full model (not just the adapter weights)

In [None]:
# Upload the processor and the model to the configured Hub repository,
# created as private.
# NOTE(review): when use_peft is True, `model` is the PEFT-wrapped model —
# confirm that this pushes the full weights rather than only the adapter.
processor.push_to_hub(config['training_args']['hub_model_id'], private=True)
model.push_to_hub(config['training_args']['hub_model_id'], private=True)

# Predictions on the test set

In [50]:
# Load the 'dev_test' split of the challenge dataset and decode its audio at 16 kHz.
test_ds = datasets.load_dataset('jq/kinyarwanda-speech-hackathon', split='dev_test')
test_ds = test_ds.cast_column("audio", datasets.Audio(sampling_rate=16000))

Resolving data files:   0%|          | 0/76 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/76 [00:00<?, ?it/s]

In [None]:
test_ids = []
test_transcriptions = []
# Reference transcripts; only filled when predicting on the labelled
# 'dev_test' subset (predict_full_test_set = False).
test_labels = []

# True: transcribe the whole 'test' split.
# False: transcribe the first 100 'dev_test' examples and keep their reference
# texts so the predictions can be scored.
predict_full_test_set = True

# Load the split BEFORE measuring its length, so N always describes the
# dataset that is actually iterated rather than one left over from an
# earlier cell.
if predict_full_test_set:
    test_ds = datasets.load_dataset('jq/kinyarwanda-speech-hackathon', split='test')
    N = len(test_ds)
else:
    test_ds = datasets.load_dataset('jq/kinyarwanda-speech-hackathon', split='dev_test')
    N = min(100, len(test_ds))

test_ds = test_ds.cast_column("audio", datasets.Audio(sampling_rate=16000))

# The language argument is the same for every example, so decode it once.
kin_language_token = processor.tokenizer.decode(
    salt.constants.SALT_LANGUAGE_TOKENS_WHISPER['kin'])

for i in tqdm.tqdm(range(N)):
    example = test_ds[i]
    input_features = processor(
        example["audio"]["array"], sampling_rate=16000, return_tensors="pt").input_features
    # Follow the model's device instead of assuming CUDA.
    input_features = input_features.to(model.device)
    predicted_ids = model.generate(
        input_features,
        num_beams=5,
        language=kin_language_token,
        forced_decoder_ids=None)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

    # Reference text is only collected on the labelled dev_test subset.
    if not predict_full_test_set:
        test_labels.append(example['text'])

    test_transcriptions.append(transcription)
    test_ids.append(example['id'])

Resolving data files:   0%|          | 0/76 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/9263 [00:00<?, ?it/s]

In [85]:
import jiwer

# Error rates of the predictions against the reference transcripts.
total_wer = jiwer.wer(test_labels, test_transcriptions)
total_cer = jiwer.cer(test_labels, test_transcriptions)
# Combined score: one minus a weighted error, with CER weighted 0.6 and WER 0.4.
score = 1 - (0.6 * total_cer + 0.4 * total_wer)

print(
    f"Word Error Rate (WER): {total_wer:.3f}",
    f"Character Error Rate (CER): {total_cer:.3f}",
    f"Score: {score:.3f}",
    sep="\n",
)

Word Error Rate (WER): 0.167
Character Error Rate (CER): 0.035
Score: 0.913


In [83]:
# No beam search
# NOTE(review): this cell is identical to the scoring cell above and has an
# earlier execution count; the output below presumably comes from a run of the
# prediction loop without `num_beams` — confirm.

import jiwer
# Error rates of the predictions against the reference transcripts.
total_wer = jiwer.wer(test_labels, test_transcriptions)
total_cer = jiwer.cer(test_labels, test_transcriptions)
# Combined score: one minus a weighted error, with CER weighted 0.6 and WER 0.4.
score = 1 - (0.6 * total_cer + 0.4 * total_wer)

print(f"Word Error Rate (WER): {total_wer:.3f}")
print(f"Character Error Rate (CER): {total_cer:.3f}")
print(f"Score: {score:.3f}")

Word Error Rate (WER): 0.171
Character Error Rate (CER): 0.037
Score: 0.909
