## Sunbird ASR evaluation

This notebook evaluates the fine-tuned Whisper pipeline on the `ucfd_lug` and `ucfd_eng` test splits of [salt-practical-eval](https://huggingface.co/datasets/Sunbird/salt-practical-eval) and reports the word error rate (WER) for each split.

In [None]:
!pip install -q datasets
!pip install -q evaluate jiwer
!pip install -q transformers
!pip install -q librosa
!pip install -q soundfile

In [98]:
import os
import json
import string
import pandas as pd
import torch
import transformers
import datasets
from evaluate import load
import evaluate
import huggingface_hub
from tqdm.notebook import tqdm
import transformers
import peft
from transformers.models.whisper.english_normalizer import BasicTextNormalizer

In [None]:
# Interactive Hugging Face Hub login (prompts for an access token).
# NOTE(review): presumably required because the model or dataset used below is
# gated/private — confirm; if both are public this cell can be skipped.
huggingface_hub.notebook_login()

Load the model and set up an ASR pipeline

In [99]:
# Run on the GPU when one is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Half precision is only a safe choice on the GPU. On the CPU fallback path
# float16 inference is slow and not supported by every op, so use float32.
model_dtype = torch.float16 if device == "cuda" else torch.float32

# Speech-recognition pipeline wrapping the fine-tuned Whisper checkpoint.
whisper_pipeline = transformers.pipeline(
    "automatic-speech-recognition",
    model="jq/whisper-large-v2-multilingual-prompts-corrected",
    device=device,
    torch_dtype=model_dtype,
    model_kwargs={"attn_implementation": "sdpa"},  # Maybe a speedup?
)

# Word error rate metric used to score the transcriptions.
wer_metric = evaluate.load("wer", trust_remote_code=True)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

### Calculate the WER metric for each subset.

In [None]:
# Three alternative decoder prompts were tried; exactly one should be active.
# The prompt tensor is moved to `device` (defined alongside the pipeline) rather
# than a hard-coded 'cuda', so the cell also runs on the CPU fallback path.

# # Prompt 1: set the context of the speech.
# prompt_ids = whisper_pipeline.tokenizer.get_prompt_ids(
#     'Thank you for calling dfcu bank. How can I help you? ',
#     return_tensors='pt',
# ).to(device)

# # Prompt 2: Set the context in Luganda.
# prompt_ids = whisper_pipeline.tokenizer.get_prompt_ids(
#     'Webale kuwagira dfcu bank. Nkuyambe ntya leero? ',
#     return_tensors='pt',
# ).to(device)

# Prompt 3: add vocabulary then set context.
# The text lists domain vocabulary (product names, then place names) followed
# by the same greeting used in Prompt 1.
prompt_ids = whisper_pipeline.tokenizer.get_prompt_ids(
    'dfcu, Quick Banking app, QuickApp, Quick Online, Quick Banking platform, '
    'dfcu Personal Banking, mobile app, App store, Google Play Store, '
    'dfcu Quick Online, Quick Connect, internet banking, mobile banking, '
    'smartphone, national ID, passport, trust factor, Pinnacle Current Account,'
    ' dfcu SACCO account, savings account, Dembe account, Smart Plan account, '
    'Campus Plus account, Young Savers account, investment club account, '
    'joint account, Secondary Account Ku-Spot, personal loan, mobi loan, save '
    'for loan, home loan, agent banking, banking security, '
    '6th Street, Abayita Ababiri, Bugolobi, Bwaise, Entebbe Road, Impala, '
    'Jinja Road, Kampala Road, Kawempe, Kikuubo, Kireka, Kyadondo, Kyambogo, '
    'Lugogo, Makerere, Market Street, Naalya, Nabugabo, Sun City, Acacia, '
    'Entebbe Town, Kyengera, Luwum Street, Nateete, Ndeeba, Nsambya, Ntinda '
    'Shopping Centre (Capital Shoppers), Ntinda Trading Centre, Owino, '
    'William Street, Abim, Arua, Dokolo, Gulu, Hoima, Ibanda, Iganga, Ishaka, '
    'Isingiro, Jinja, Kabale, Kisoro, Kitgum, Lira, Luweero, Lyantonde, '
    'Masaka, Mbale, Mbarara, Mukono, Ntungamo, Pader, Pallisa, Rushere, '
    'Soroti, Tororo. '
    'Thank you for calling dfcu bank. How can I help you? ',
    return_tensors='pt',
).to(device)


# Generation options forwarded to the pipeline on every call: the prompt built
# above plus the decoding settings used for this evaluation.
generate_kwargs = dict(
    prompt_ids=prompt_ids,
    prompt_condition_type="first-segment",
    condition_on_prev_tokens=True,
    language=None,
    task="transcribe",
    num_beams=1,
)

# One normalizer is enough for every subset; it is applied to both predictions
# and references so that WER compares like with like.
normalizer = BasicTextNormalizer()

for subset in ["ucfd_eng", "ucfd_lug"]:
    eval_dataset = datasets.load_dataset("Sunbird/salt-practical-eval", subset, split="test")

    # Read the transcript column directly: iterating over whole rows would
    # decode every audio clip just to get at its text.
    references = list(eval_dataset["text"])

    # Transcribe each clip in dataset order so predictions line up with references.
    # TODO: Get batching working for ucfd_eng
    predictions = []
    for out in tqdm(whisper_pipeline(
        transformers.pipelines.pt_utils.KeyDataset(eval_dataset, "audio"), batch_size=1,
        generate_kwargs=generate_kwargs)
    ):
        predictions.append(out['text'])

    # Use the normalizer instance defined above (the previous call to an
    # undefined `normalise` raised NameError on a fresh kernel).
    wer_score = wer_metric.compute(
        predictions=[normalizer(p) for p in predictions],
        references=[normalizer(r) for r in references]
    )

    print(f"{subset} WER: {wer_score:.3f}")

  0%|          | 0/16 [00:00<?, ?it/s]

ucfd_eng WER: 0.254


  0%|          | 0/9 [00:00<?, ?it/s]