## Sunbird ASR evaluation

Application of the fine-tuned Whisper pipeline to the `ucfd_lug` and `ucfd_eng` test splits in [salt-practical-eval](https://huggingface.co/datasets/Sunbird/salt-practical-eval).

In [None]:
# Install the third-party packages this notebook relies on (-q keeps pip quiet).
# NOTE(review): jiwer is presumably the backend of the "wer" metric, and
# librosa/soundfile are presumably needed for audio decoding — confirm.
!pip install -q datasets
!pip install -q evaluate jiwer
!pip install -q transformers
!pip install -q librosa
!pip install -q soundfile

In [None]:
import os
import json
import string
import pandas as pd
import torch
import transformers
import datasets
from evaluate import load
import huggingface_hub
import tqdm.notebook as tqdm
import transformers

In [None]:
# Authenticate this notebook session with the Hugging Face Hub.
# NOTE(review): presumably required to access the dataset and/or model used
# in the cells that follow — confirm.
huggingface_hub.notebook_login()

Helper functions for text normalisation. TODO: move these to salt.metrics

In [None]:
def _normalise_item(s, allowed_punctuation="'"):
  '''Convert a list of strings by converting to lower case and removing
  punctuation. This helps when calculating WER, as we're interested in which
  words were predicted more than the capitalisation or punctuation.'''
  s = s.lower()
  punct = list(string.punctuation)
  if allowed_punctuation:
      for allowed in allowed_punctuation:
          punct.remove(allowed)
  result_chars = []
  for c in s:
    if c not in punct:
      result_chars.append(c)
    else:
      result_chars.append(' ')
  return ''.join(result_chars)

def normalise(s, allowed_punctuation="'"):
    """Normalise a string, or each string in a list, for WER scoring.

    A list input yields a new list with `_normalise_item` applied to every
    element; any other input is passed to `_normalise_item` directly.
    `allowed_punctuation` is forwarded unchanged in both cases.
    """
    if not isinstance(s, list):
        return _normalise_item(s, allowed_punctuation)
    return [_normalise_item(text, allowed_punctuation) for text in s]

Load the model, set up an ASR pipeline, and load the WER metric.

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Speech-recognition pipeline backed by the fine-tuned Whisper checkpoint.
whisper_pipeline = transformers.pipeline(
    "automatic-speech-recognition",
    model="jq/whisper-medium-lug-eng",
    device=device,
)

# `datasets.load_metric` is deprecated and has been removed from recent
# `datasets` releases; metrics now live in the `evaluate` package, whose
# `load` function is already imported at the top of this notebook.
wer_metric = load("wer")

### Calculate the WER metric for each subset.

(TODO: Decide if CER might make more sense here)

The best word error rates seen so far are:

* English: 0.521
* Luganda: 0.828

In [None]:
# Decoding options forwarded to the model's generate call for every example.
# NOTE(review): `language=None` presumably leaves language selection to the
# model rather than forcing English or Luganda — confirm.
generate_kwargs = dict(
    task="transcribe",
    language=None,
    forced_decoder_ids=None,
    max_new_tokens=200,
)

for subset in ("ucfd_eng", "ucfd_lug"):
    eval_dataset = datasets.load_dataset(
        "Sunbird/salt-practical-eval", subset, split="test"
    )
    predictions, references = [], []

    # Transcribe one example at a time, keeping references and predictions
    # aligned by position.
    for example in tqdm.tqdm(eval_dataset):
        references.append(example["text"])
        output = whisper_pipeline(
            example["audio"], generate_kwargs=generate_kwargs
        )
        predictions.append(output["text"])

    # Score on normalised text so case and punctuation do not count as errors.
    wer_score = wer_metric.compute(
        references=normalise(references),
        predictions=normalise(predictions),
    )

    print(f"{subset } WER: {wer_score:.3f}")

Example of adding prompt IDs to give the model context. In practice, this seems to make the predictions worse on average.

In [None]:
# Two alternative prompts are shown below. Each assignment overwrites
# `prompt_ids`, so only the last variant that is run takes effect; comment
# out the one you do not want.
#
# The prompt tensor is moved to `device` (the same device the pipeline was
# created on) rather than a hard-coded 'cuda', so this cell also works on
# CPU-only machines.

# Variant 1, try adding domain-specific vocabulary.
prompt_ids = whisper_pipeline.tokenizer.get_prompt_ids(
    'dfcu, Quick Banking app, QuickApp, Quick Online, Quick Banking platform, '
    'dfcu Personal Banking, mobile app, App store, Google Play Store, '
    'dfcu Quick Online, Quick Connect, internet banking, mobile banking, '
    'smartphone, national ID, passport, trust factor, Pinnacle Current Account,'
    ' dfcu SACCO account, savings account, Dembe account, Smart Plan account, '
    'Campus Plus account, Young Savers account, investment club account, '
    'joint account, Secondary Account Ku-Spot, personal loan, mobi loan, save '
    'for loan, home loan, agent banking, banking security, '
    '6th Street, Abayita Ababiri, Bugolobi, Bwaise, Entebbe Road, Impala, '
    'Jinja Road, Kampala Road, Kawempe, Kikuubo, Kireka, Kyadondo, Kyambogo, '
    'Lugogo, Makerere, Market Street, Naalya, Nabugabo, Sun City, Acacia, '
    'Entebbe Town, Kyengera, Luwum Street, Nateete, Ndeeba, Nsambya, Ntinda '
    'Shopping Centre (Capital Shoppers), Ntinda Trading Centre, Owino, '
    'William Street, Abim, Arua, Dokolo, Gulu, Hoima, Ibanda, Iganga, Ishaka, '
    'Isingiro, Jinja, Kabale, Kisoro, Kitgum, Lira, Luweero, Lyantonde, '
    'Masaka, Mbale, Mbarara, Mukono, Ntungamo, Pader, Pallisa, Rushere, '
    'Soroti, Tororo',
    return_tensors='pt',
).to(device)

# Variant 2, try to set the context.
prompt_ids = whisper_pipeline.tokenizer.get_prompt_ids(
    'Thank you for calling dfcu bank. ',
    return_tensors='pt',
).to(device)

# Then call the pipeline with prompts specified as follows.
generate_kwargs = {
    "prompt_ids": prompt_ids,
    "prompt_condition_type": "first-segment",
    "condition_on_prev_tokens": True,
    "task": "transcribe",
    "language": None,
    "forced_decoder_ids": None,
    "max_new_tokens": 200,
}