Working on RTX 6000Ada 48GB (per-device batch size 2) and H100 80GB (per-device batch size 16)

In [None]:
!pip install -q jiwer==3.1.0
!pip install -q evaluate
!pip install -qU accelerate
!pip install -Uq torch
!pip install -q transformers[torch]
!pip install -q soundfile
!git clone https://github.com/SunbirdAI/salt.git
!pip install -qr salt/requirements.txt
!pip install -q peft
!pip install -q torchaudio torchvision

In [1]:
use_wandb = False
use_mlflow = True

import importlib.metadata
installed = [
    dist.metadata['Name']
    for dist in importlib.metadata.distributions()
]

if use_wandb:
  !pip install -q wandb
  import wandb
  %set_env WANDB_LOG_MODEL=True
  %set_env WANDB_WATCH=all
  %set_env WANDB_NOTEBOOK_NAME=whisper_base_en_sb.ipynb
  wandb.login()

if use_mlflow:
  if 'mlflow' not in installed:
      !pip install -q mlflow
      ## requirements to log system/GPU metrics in mlflow
  !pip install -q psutil
  !pip install -q pynvml
  import os
  from getpass import getpass
  import mlflow
  import mlflow.pytorch
  from mlflow import MlflowClient

  # Set MLflow tracking credentials
  MLFLOW_TRACKING_USERNAME = getpass('Enter the MLFLOW_TRACKING_USERNAME: ')
  os.environ['MLFLOW_TRACKING_USERNAME'] = MLFLOW_TRACKING_USERNAME

  MLFLOW_TRACKING_PASSWORD = getpass('Enter the MLFLOW_TRACKING_PASSWORD: ')
  os.environ['MLFLOW_TRACKING_PASSWORD'] = MLFLOW_TRACKING_PASSWORD
  os.environ["MLFLOW_EXPERIMENT_NAME"] = "kinyarwanda-asr"

  # Set the MLflow tracking URI
  mlflow.set_tracking_uri('https://mlflow-sunbird-ce0ecfc14244.herokuapp.com/')
  mlflow.system_metrics.enable_system_metrics_logging()

[0m

Enter the MLFLOW_TRACKING_USERNAME:  ········
Enter the MLFLOW_TRACKING_PASSWORD:  ········


In [3]:
import torch
import transformers
from dataclasses import dataclass, field
from typing import Union, List, Dict, Any
import string
import os
import json
import datasets
import numpy as np
import yaml
import evaluate
import salt.dataset
import salt.metrics
import salt.constants
from salt.utils import DataCollatorCTCWithPadding as dcwp
import huggingface_hub
import peft
import pandas as pd
import tqdm.notebook as tqdm

In [5]:
huggingface_hub.notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [6]:
# In case SALT library is modified and has to be reloaded:
# !rm -rf salt
# !git clone https://github.com/jqug/salt.git
#from importlib import reload
#reload(salt.dataset)

In [4]:
yaml_config = f'''
pretrained_model: jq/whisper-large-v3-kin-nyn-lug-xog # openai/whisper-large-v3

num_workers: 16
use_peft: False
lora_config:
    r: 32
    lora_alpha: 64
    target_modules: ["q_proj", "v_proj"]
    lora_dropout: 0.05
    bias: "none"

training_args:
    output_dir: whisper-large-v3-kin
    per_device_train_batch_size: 16
    per_device_eval_batch_size: 16
    gradient_accumulation_steps: 4  # increase by 2x for every 2x decrease in batch size
    learning_rate: 1.0e-5
    warmup_steps: 100
    max_steps: 20000
    gradient_checkpointing: True
    gradient_checkpointing_kwargs:
      use_reentrant: True
    fp16: True
    eval_strategy: steps
    predict_with_generate: True
    generation_max_length: 200
    save_steps: 1000
    eval_steps: 200 # Was 250
    logging_steps: 200
    load_best_model_at_end: True
    metric_for_best_model: loss
    greater_is_better: False
    push_to_hub: False
    hub_model_id: jq/whisper-large-v3-kin
    save_total_limit: 2
    
train:
    download_datasets_in_parallel: True
    huggingface_load:
        # Main challenge dataset
        - path: jq/kinyarwanda-speech-hackathon
          split: train
        - path: jq/kinyarwanda-speech-hackathon
          split: dev_test[1000:]
    source:
      type: speech
      language: [kin]
      preprocessing:
        # Downsample some examples to 8KHz (to simulate phone audio) 
        - set_sample_rate:
            rate: 8_000
            p: 0.05
        # Then upsample again
        - set_sample_rate:
            rate: 16_000
        - normalize_audio
        - augment_audio_speed:
            p: 0.2
            low: 0.95
            high: 1.15
        - augment_audio_noise:
            max_relative_amplitude: 0.5
            noise_audio_repo:
                path: Sunbird/urban-noise
                name: small
                split: train       
    target:
      type: text
      preprocessing:
        - ensure_text_ends_with_punctuation
      language: [kin]
    shuffle: True
validation:
    huggingface_load:
        # Held-out challenge data for validation
        - path: jq/kinyarwanda-speech-hackathon
          split: dev_test[:100]
    source:
      type: speech
      language: [kin]
      preprocessing:
        - set_sample_rate:
            rate: 16_000
    target:
      type: text
      language: [kin]
'''

config = yaml.safe_load(yaml_config)
train_ds = salt.dataset.create(config['train'], verbose=True)
valid_ds = salt.dataset.create(config['validation'])

In [5]:
# If needed, pre-load the main challenge dataset with multiple download workers
# ds = datasets.load_dataset('jq/kinyarwanda-speech-hackathon', split='train', num_proc=10)

In [5]:
salt.utils.show_dataset(train_ds, audio_features=['source'], N=10)

Resolving data files:   0%|          | 0/75 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/75 [00:00<?, ?it/s]

Using the latest cached version of the dataset since jq/kinyarwanda-speech-hackathon couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /root/.cache/huggingface/datasets/jq___kinyarwanda-speech-hackathon/default/0.0.0/90607206a75fcf60a683663f8826dd5013e0ef39 (last modified on Mon Jun 23 18:33:39 2025).


Loading dataset shards:   0%|          | 0/74 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/75 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/75 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/75 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/74 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/75 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/75 [00:00<?, ?it/s]

jq/kinyarwanda-speech-hackathon: 261657 rows
jq/kinyarwanda-speech-hackathon: 8263 rows
Total rows: 269920


Unnamed: 0,source,target,source.language,target.language
0,Your browser does not support the audio element.,"Icyapa kiyobora imodoka, icyo cyapa kikaba kiri mu ibara ry'umweru, ibimenyetso byereka imodoka bikaba biri mu ibara ry'umukara ibyo byose bikaba biri mu kinyampande cy'impande eshatu.",kin,kin
1,Your browser does not support the audio element.,Uburyo bwo kuhira imyaka nibura uburyo bwifashishwa mu gihe ibihingwa bitameze neza bitewe n'izuba cyangwa se aho hantu hava izuba ryinshi hagaterwa n'ibyo bindi kandi bigaterwa n'uko ikirere kimeze cyangwa se gihindagurika buri munsi imvura ikabura bigatuma abantu buhira cyangwa se abahinzi buhira kugira imyaka yabo ikomezaze kumera neza.,kin,kin
2,Your browser does not support the audio element.,"Igiceri kiriho ibirango bitandukanye, biriho insina ifite igitoki ndetse n'mwano, bigaragara ko ari icy'icumi cy'amafaranga y'u Rwanda akoreshwa ubungubu.",kin,kin
3,Your browser does not support the audio element.,"Amashu ubwoko bw'imboga abantu baryaga bazihingaga ahantu habaga amazi menshi, ubundi zamara gukura bakazisarura.Abantu bari kumenera izi mboga kugira ngo zikure neza, izi mboga iyo zeze abantu barazikekaga bakaziteka hari abazitogosaga abandi bakazikaranga.",kin,kin
4,Your browser does not support the audio element.,"Ikigo gishinzwe imisoro n'amahoro cyagaragaje ishusho yashyizweho telefone ifite ekara y'umweru, ku ruhande rw'iburyo hari amagambo ari mu rurimi rw'icyongereza, ndetse na nimero za telefoni zishobora kwandikirwaho mesaje ku rubuga rwabugenewe.",kin,kin
5,Your browser does not support the audio element.,"Akayira karimo umugabo wambaye inkweto y'umweru ndetse n'ipantaro y'umukara, akaba ari gusohoka yerekeza ku muryango imbere ye hari umukobwa ugeze hanze, hepfo yabo hari inzu zihubatse ndetse n'ibiti bihateye.",kin,kin
6,Your browser does not support the audio element.,"Ku mugezi amajerekani atatu, ijerekani imwe iri mu mugezi, iri kujyamo amazi, umwana ayifasheho, yambaye umupira w'umweru n'ipantaro, afashe no ku mugezi, ari kureba ahandi hantu, afite agasatsi gakeya.",kin,kin
7,Your browser does not support the audio element.,"Aha baragaragaza kamera zishoba kuba ziri mu mihanda runaka nko muri Kigali se, cyangwa mu tundi duce two mu Rwanda. Zifashishwa mu gucunga amakosa abera mu muhanda aterwa n'batwaye ibinyabiziga.",kin,kin
8,Your browser does not support the audio element.,"Umubyeyi ufite umwana mutoya mu ntoki, ari kumwe na muganga wambaye itaburiya ndetse kandi uyu muganga akaba yambaye uturindantoki, uyu muganga arimo aratanga serivise z'ubuvuzi kuri uyu umwana mutoya.",kin,kin
9,Your browser does not support the audio element.,"Imbere yanjye ndahabona ambiranse ziparitse, hari nindi ifunguye bari gukuramo umurwayi, hari abantu bicaye ku ntebe bategereje, uko bigaragara hano ari kwa muganga baje kwivuza.",kin,kin


In [6]:
feature_extractor = transformers.WhisperFeatureExtractor.from_pretrained(
    config['pretrained_model'])
processor = transformers.WhisperProcessor.from_pretrained(
    config['pretrained_model'], language=None, task="transcribe")
model = transformers.WhisperForConditionalGeneration.from_pretrained(
    config['pretrained_model'])

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]    
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor, decoder_start_token_id=model.config.decoder_start_token_id)

Read in prompts: preceding text which is used to guide the model.

In [8]:
ds = datasets.load_dataset('jq/kinyarwanda-speech-hackathon', split='train', num_proc=10)
text = list(ds['text'])
prompts = list(ds['prompt'])
sentence_to_prompt = {}
for t, p in zip(text, prompts):
    sentence_to_prompt[t] = p

Resolving data files:   0%|          | 0/75 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/75 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/74 [00:00<?, ?it/s]

In [9]:
language_id_tokens = salt.constants.SALT_LANGUAGE_TOKENS_WHISPER

def prepare_dataset(example, p_prompt = 0.5):    
    audio = example["source"]
    input_features = feature_extractor(
        audio, sampling_rate=16000, device='cuda',
        do_normalize=True).input_features[0]

    # Encode target text to label ids
    labels = processor.tokenizer(str(example["target"])).input_ids

    # Insert the language ID token into the second position of the sequence.
    labels.insert(1, language_id_tokens[example["target.language"]])

    # If a prompt is known for a particular sentence, add it to the
    # training example with probability `p_prompt`.
    prompt = sentence_to_prompt.get(example["target"], None)
    if prompt:
        if np.random.random() < p_prompt:
            prompt_ids = list(processor.get_prompt_ids(prompt))
            labels = prompt_ids + labels  

    # Create a new dictionary with the processed data
    processed_example = {
        "input_features": input_features,
        "labels": np.array(labels),
        "source.language": example["source.language"],
        "target.language": example["target.language"]
    }

    return processed_example

In [10]:
train_data = train_ds.map(prepare_dataset, remove_columns=["source", "target"])
val_data = valid_ds.map(prepare_dataset, remove_columns=["source", "target"])

In [11]:
compute_metrics = salt.metrics.multilingual_eval_fn(
      valid_ds, [evaluate.load('wer'), evaluate.load('cer')],
      processor.tokenizer, log_first_N_predictions=3,
      speech_processor=processor)

Resolving data files:   0%|          | 0/75 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/75 [00:00<?, ?it/s]

In [12]:
model.config.suppress_tokens = []
model.config.forced_decoder_ids = None
model.generation_config.forced_decoder_ids = None
model.generation_config = transformers.GenerationConfig(
    pad_token_id=processor.tokenizer.pad_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    decoder_start_token_id=model.config.decoder_start_token_id,
    use_cache=False
)

if config['use_peft']:
    model = peft.prepare_model_for_kbit_training(model)
    lora_config = peft.LoraConfig(**config['lora_config'])
    model.enable_input_require_grads()
    model = peft.get_peft_model(model, lora_config)
    model.config.use_cache = False
    model.print_trainable_parameters()

In [13]:
# If there was an interrupted training run, then reset mlflow
#mlflow.end_run()

Launch the training

In [None]:
training_args = transformers.Seq2SeqTrainingArguments(
  **config["training_args"],
  report_to= [
      platform for platform, use in [("wandb", use_wandb), ("mlflow", use_mlflow)] if use]
)

trainer = transformers.Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    processing_class=processor,
)

trainer.train()

2025/06/23 19:25:45 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.


Resolving data files:   0%|          | 0/75 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/75 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/75 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/75 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/74 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/75 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/75 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/74 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/75 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/75 [00:00<?, ?it/s]

jq/kinyarwanda-speech-hackathon: 261657 rows
jq/kinyarwanda-speech-hackathon: 8263 rows
Total rows: 269920


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...


Step,Training Loss,Validation Loss,Wer Kin,Wer Mean,Cer Kin,Cer Mean
200,0.2299,0.246787,0.101,0.101,0.026,0.026
400,0.1774,0.275631,0.103,0.103,0.027,0.027
600,0.175,0.255629,0.109,0.109,0.027,0.027


Resolving data files:   0%|          | 0/75 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/75 [00:00<?, ?it/s]

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


First N predictions in eval set:
Prediction (kin to kin): "Umugore wambaye umupira w'akazi, impuzankano iri mu ibara ry'umuhondo, handitseho amagambo yandikishije ibara ry'ubururu, afite igikoresho cy'itumanaho, gikoreshwa mu guhamagara no kwandika ubutumwa bugufi.", True label: "Umugore wambaye umupira w'akazi mpuzankano iri mu ibara ry'umuhondo handitseho amagambo yandikishije ibara ry'ubururu. Afite igikoresho cy'itumanaho gikoreshwa mu guhamagara no kwandika ubutumwa bugufi. "
Prediction (kin to kin): "Uburyo emutiyene yatangije wishyura amafaranga kuri telefone ukoresheje cyangwa mudasobwa batagize amafaranga na make bagukata wohereza.", True label: "Uburyo emutiyene yatangije wishyura amafaranga kuri terefone, ukoresheje cyangwa mudasobwa batagize amafaranga na make bagukata wohereza."
Prediction (kin to kin): "Umudamu uhagaze mu iduka r'inzobe, uri guseka, ufite imisatsi migufiya, inyuma ye hakaba hari etajeri iriho ibicuruzwa bigiye bitandukanye, amavuta, amasabune, ibiribwa nd

Resolving data files:   0%|          | 0/75 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/75 [00:00<?, ?it/s]

First N predictions in eval set:
Prediction (kin to kin): "Umugore wambaye umupira w'akazi mpuzankano iri mu ibara ry'umuhondo, handitseho amagambo yandikishije ibara ry'ubururu, afite igikoresho cy'itumanaho gikoreshwa mu guhamagara no kwandika ubutumwa bugufi.", True label: "Umugore wambaye umupira w'akazi mpuzankano iri mu ibara ry'umuhondo handitseho amagambo yandikishije ibara ry'ubururu. Afite igikoresho cy'itumanaho gikoreshwa mu guhamagara no kwandika ubutumwa bugufi. "
Prediction (kin to kin): "Uburyo emutiyeni yatangije wishyura amafaranga kuri terefone ukoresheje cyangwa mudasobwa batagize amafaranga na make bagukata wohereza.", True label: "Uburyo emutiyene yatangije wishyura amafaranga kuri terefone, ukoresheje cyangwa mudasobwa batagize amafaranga na make bagukata wohereza."
Prediction (kin to kin): "Umudamu uhagaze mu iduka rw'inzobe uri guseka ufite imisatsi migufiya, inyuma ye hakaba hari etageri iriho ibicuruzwa bigiye bitandukanye amavuta, amasabune, ibiribwa ndetse 

Resolving data files:   0%|          | 0/75 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/75 [00:00<?, ?it/s]

First N predictions in eval set:
Prediction (kin to kin): "Umugore wambaye umupira w'akazi mpuzankano iri mu ibara ry'umuhondo handitseho amagambo yandikishije ibara ry'ubururu, afite igikoresho cy'itumanaho gikoreshwa mu guhamagara no kwandika ubutumwa bugufi.", True label: "Umugore wambaye umupira w'akazi mpuzankano iri mu ibara ry'umuhondo handitseho amagambo yandikishije ibara ry'ubururu. Afite igikoresho cy'itumanaho gikoreshwa mu guhamagara no kwandika ubutumwa bugufi. "
Prediction (kin to kin): "Uburyo emutiyeni yatangije wishyura amafaranga kuri telefone ukoresheje cyangwa mudasobwa batagize amafaranga na make bagukata wohereza.", True label: "Uburyo emutiyene yatangije wishyura amafaranga kuri terefone, ukoresheje cyangwa mudasobwa batagize amafaranga na make bagukata wohereza."
Prediction (kin to kin): "Umudamu uhagaze mu iduka w'inzobe uri guseka ufite imisatsi migufiya, inyuma ye hakaba hari etajeri iriho ibicuruzwa bigiye bitandukanye amavuta, amasabune, ibiribwa ndetse n'

In [None]:
%debug

Log the config settings for reference

In [None]:
if use_mlflow:
    mlflow.log_params(config)

In [19]:
config['training_args']['hub_model_id']

'jq/whisper-large-v3-kin-nyn-lug-xog'

Save the full model (not just the adapter weights)

In [None]:
processor.push_to_hub(config['training_args']['hub_model_id'], private=True)
model.push_to_hub(config['training_args']['hub_model_id'], private=True)

# Predictions on the test set

In [50]:
test_ds = datasets.load_dataset('jq/kinyarwanda-speech-hackathon', split='dev_test')
test_ds = test_ds.cast_column("audio", datasets.Audio(sampling_rate=16000))

Resolving data files:   0%|          | 0/76 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/76 [00:00<?, ?it/s]

In [86]:
test_ids = []
test_transcriptions = []

predict_full_test_set = True

if predict_full_test_set:
    N = len(test_ds)
    test_ds = datasets.load_dataset('jq/kinyarwanda-speech-hackathon', split='test')
else:
    test_labels = []
    N = 100
    test_ds = datasets.load_dataset('jq/kinyarwanda-speech-hackathon', split='dev_test')

test_ds = test_ds.cast_column("audio", datasets.Audio(sampling_rate=16000))
for i in tqdm.tqdm(range(N)):   
    example = test_ds[i]
    input_features = processor(
        example["audio"]["array"], sampling_rate=16000, return_tensors="pt").input_features
    input_features = input_features.to('cuda')
    predicted_ids = model.generate(
        input_features,
        num_beams=5,
        language=processor.tokenizer.decode(salt.constants.SALT_LANGUAGE_TOKENS_WHISPER['kin']),
        forced_decoder_ids=None)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

    if quick_verification:
        test_labels.append(example['text'])

    test_transcriptions.append(transcription)
    test_ids.append(example['id'])

Resolving data files:   0%|          | 0/76 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/9263 [00:00<?, ?it/s]

In [89]:
with open('test.json') as f:
    test_metadata = json.load(f)

In [91]:
test_keys = list(test_metadata.keys())

In [92]:
test_keys[0]

'4ibA9OLWZTajRbwnWjjY'

In [87]:
len(test_transcriptions)

9263

In [93]:
predictions = {}
for i, pred in zip(test_ids, test_transcriptions):
    predictions[i] = pred

In [101]:
import string

def strip_punctuation(text):
    # Create a translation table to remove all punctuation
    translator = str.maketrans('', '', string.punctuation)
    return text.translate(translator)
    
with open('submission.csv', "w", encoding="utf-8") as f:
    f.write('id,transcription\n')
    for k in test_keys:
        pred = predictions.get(k)
        if not pred:
            print('No prediction for key ', k)
            f.write(f"{k},a\n")
        else:
            normalised_pred = strip_punctuation(pred.lower())
            f.write(f"{k},{normalised_pred}\n")

No prediction for key  lHwrxgnDqh3OV5yMjC3z
No prediction for key  3IRGkO7JucevGZghaSBn


In [102]:
!wc -l submission.csv

9266 submission.csv


In [100]:
!head submission.csv

id,transcription
4ibA9OLWZTajRbwnWjjY,ndabona umugabo uri kuri uhagaze wambaye kasike na jire akaba ashobora kuba atwara ibintu handitseho ngo kashi
ZarC9zz753YnLnE98mpK,pisine ku ruhande rwayo hari udutebe dutwikiriye nimitaka tubiri ku rundi ruhande naho hakaba hari akandi gatebe konyine hirya hakaba udutebe tundi hari nimitaka itwikiriye hirya yaho hakaba hari inzu iri kubakwa itariyuzura
1ai3w0iU2yUOeUtLoTSX,ubwishingizi ni ingenzi cyane kubera ko budufasha kandi ntaho batageze bahagize amashami yabo hano ni muri rusizi nkiki kirango kuko kibigaragaza mu ibara ryubururu amagambo yandikishije umweru ndetse numuhondo ubona ko rero ushobora kuza nawe ugatanga ikibazo cyawe bakakugoboka
IQFsYcsFTsGlnqftc8jg,imodoka ihagaze iri mu ibara ritukura iriho ibirango byamamaza isoko rikorera kuri murandasi hariho nimero zabo za telefone ngendanwa ndetse nahandi hose ushobora kubabona
Sd3umUI1wjqp5z5poHe6,ahantu bacururiza amata hari ameza ya purasitike imwe iteretseho ishage numufuniko indi ir

In [85]:
import jiwer
total_wer = jiwer.wer(test_labels, test_transcriptions)
total_cer = jiwer.cer(test_labels, test_transcriptions)
score = 1 - (0.6 * total_cer + 0.4 * total_wer)

print(f"Word Error Rate (WER): {total_wer:.3f}")
print(f"Character Error Rate (CER): {total_cer:.3f}")
print(f"Score: {score:.3f}")

Word Error Rate (WER): 0.167
Character Error Rate (CER): 0.035
Score: 0.913


In [83]:
# No beam search

import jiwer
total_wer = jiwer.wer(test_labels, test_transcriptions)
total_cer = jiwer.cer(test_labels, test_transcriptions)
score = 1 - (0.6 * total_cer + 0.4 * total_wer)

print(f"Word Error Rate (WER): {total_wer:.3f}")
print(f"Character Error Rate (CER): {total_cer:.3f}")
print(f"Score: {score:.3f}")

Word Error Rate (WER): 0.171
Character Error Rate (CER): 0.037
Score: 0.909
