## Sunbird ASR evaluation

Application of the fine-tuned Whisper pipeline to the `ucfd_lug` and `ucfd_eng` test splits in [salt-practical-eval](https://huggingface.co/datasets/Sunbird/salt-practical-eval).

In [None]:
!pip install -q datasets
!pip install -q evaluate jiwer
!pip install -q transformers
!pip install -q librosa
!pip install -q soundfile

In [None]:
import os
import json
import string
import pandas as pd
import torch
import transformers
import datasets
import evaluate
import huggingface_hub
import tqdm.notebook as tqdm
import transformers
import numpy as np
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
import librosa
from IPython import display

from evaluate import load
from pprint import pprint

In [None]:
# Log in to the Hugging Face Hub from the notebook.
# NOTE(review): presumably required to download the fine-tuned model and the
# Sunbird datasets used below — confirm whether they are gated/private.
huggingface_hub.notebook_login()

Helper functions for text normalisation (`normalise` is currently defined in the hyperparameter-sweep cell further down). TODO: move these to salt.metrics

Load the model and set up an ASR pipeline

In [None]:
# Pick the GPU when one is visible to torch; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Wrap the fine-tuned multilingual Whisper checkpoint in an ASR pipeline
# placed on the selected device.
whisper_pipeline = transformers.pipeline(
    task="automatic-speech-recognition",
    model="jq/whisper-large-v2-multilingual-prompts-corrected",
    device=device,
)

# Word error rate metric from the `evaluate` library (backed by jiwer).
wer_metric = evaluate.load("wer", trust_remote_code=True)

### Calculate the WER metric for each subset.

In [None]:
# Prompt 1: set the context of the speech.
#prompt_ids = whisper_pipeline.tokenizer.get_prompt_ids(
#    'Thank you for calling dfcu bank. How can I help you? ',
#    return_tensors='pt',
#).to(device)

# Prompt 2: add vocabulary then set context.
# The prompt tensor has to sit on the same device as the model. Use the
# `device` chosen when the pipeline was built instead of hard-coding 'cuda',
# which crashes on CPU-only runtimes.
prompt_ids = whisper_pipeline.tokenizer.get_prompt_ids(
    'dfcu, Quick Banking app, QuickApp, Quick Online, Quick Banking platform, '
    'dfcu Personal Banking, mobile app, App store, Google Play Store, '
    'dfcu Quick Online, Quick Connect, internet banking, mobile banking, '
    'smartphone, national ID, passport, trust factor, Pinnacle Current Account,'
    ' dfcu SACCO account, savings account, Dembe account, Smart Plan account, '
    'Campus Plus account, Young Savers account, investment club account, '
    'joint account, Secondary Account Ku-Spot, personal loan, mobi loan, save '
    'for loan, home loan, agent banking, banking security, '
    '6th Street, Abayita Ababiri, Bugolobi, Bwaise, Entebbe Road, Impala, '
    'Jinja Road, Kampala Road, Kawempe, Kikuubo, Kireka, Kyadondo, Kyambogo, '
    'Lugogo, Makerere, Market Street, Naalya, Nabugabo, Sun City, Acacia, '
    'Entebbe Town, Kyengera, Luwum Street, Nateete, Ndeeba, Nsambya, Ntinda '
    'Shopping Centre (Capital Shoppers), Ntinda Trading Centre, Owino, '
    'William Street, Abim, Arua, Dokolo, Gulu, Hoima, Ibanda, Iganga, Ishaka, '
    'Isingiro, Jinja, Kabale, Kisoro, Kitgum, Lira, Luweero, Lyantonde, '
    'Masaka, Mbale, Mbarara, Mukono, Ntungamo, Pader, Pallisa, Rushere, '
    'Soroti, Tororo. Thank you for calling dfcu bank. How can I help you? ',
    return_tensors='pt',
).to(device)

# Then call the pipeline with prompts specified as follows.
generate_kwargs = {
    "prompt_ids": prompt_ids,
    "prompt_condition_type": "first-segment",
    "condition_on_prev_tokens": True,
    "language": None,
    "task": "transcribe",
    "num_beams": 1,
}

# One normaliser shared by every subset: lower-cases and strips punctuation
# so WER reflects word choice rather than formatting.
normalizer = BasicTextNormalizer()

for subset in ["ucfd_eng", "ucfd_lug"]:
    eval_dataset = datasets.load_dataset("Sunbird/salt-practical-eval", subset, split="test")
    references = [example["text"] for example in eval_dataset]

    # TODO: Get batching working for ucfd_eng. The pipeline seems to give an error
    # if the batch has some examples > 30s and some < 30s.
    outputs = whisper_pipeline(
        transformers.pipelines.pt_utils.KeyDataset(eval_dataset, "audio"),
        batch_size=9 if subset == 'ucfd_lug' else 1,
        generate_kwargs=generate_kwargs,
    )
    # The pipeline yields one output dict per input example, in dataset order.
    predictions = [
        out['text'] for out in tqdm.tqdm(outputs, total=len(eval_dataset))
    ]

    wer_score = wer_metric.compute(
        predictions=[normalizer(p) for p in predictions],
        references=[normalizer(r) for r in references],
    )

    print(f"{subset} WER: {wer_score:.3f}")

In [None]:
# Display the predictions from the last subset processed by the evaluation loop above.
predictions

In [None]:
# Display the reference transcripts matching the predictions above.
references

In [None]:
# Impact of batch processing on ucfd_lug:
# (presumably wall-clock time of the evaluation loop above — confirm)
# Batch size 1: 43 seconds
# Batch size 9: 15 seconds

In [None]:
# Quick check of how BasicTextNormalizer treats casing and repeated punctuation.
normalizer = BasicTextNormalizer()
normalizer("Hello Arsenal Fans?????!!!!!")

In [None]:
# TODO: Update the evaluation code to match the cell above.

# TODO: Decide which subsets of Sunbird/salt-practical-eval to
# include. SEMA and TRAC FM too? A few records from each if it
# takes too long to process everything.

# TODO: Include silence removal here? (Since that seems important
# for dealing with hallucinations and repetitions)

# TODO: Check effect of different prompts?

# Module-level normaliser used by `normalise` below.
normalizer = BasicTextNormalizer()

def normalise(s, allowed_punctuation="'"):
    '''Normalise a string, or every string in a list, before scoring WER.

    Each string is passed through the module-level `normalizer`
    (a BasicTextNormalizer), which lower-cases and removes punctuation so that
    WER measures which words were predicted rather than their formatting.

    Args:
        s: a single string, or a list of strings.
        allowed_punctuation: not used by the current implementation; kept so
            existing callers keep working. NOTE(review): BasicTextNormalizer
            does not look at it, so apostrophes are presumably stripped too —
            confirm.

    Returns:
        A list of normalised strings when `s` is a list, otherwise a single
        normalised string.
    '''
    if not isinstance(s, list):
        return normalizer(s)
    normalised_items = [normalizer(text) for text in s]
    return normalised_items

def evaluate_hyperparameters(
    repetition_penalty, no_repeat_ngram_size, whisper_pipeline, prompt_ids
):
    """Score one decoding configuration on the ucfd_lug and ucfd_eng test splits.

    Every example is transcribed on its own (no batching) using the supplied
    prompt and the given `repetition_penalty` / `no_repeat_ngram_size`. The
    predictions and references are normalised with `normalise` before WER is
    computed.

    Returns:
        dict mapping subset name ("ucfd_lug", "ucfd_eng") to its WER.
    """
    decode_options = dict(
        prompt_ids=prompt_ids,
        prompt_condition_type="first-segment",
        condition_on_prev_tokens=True,
        task="transcribe",
        language=None,
        forced_decoder_ids=None,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
    )

    scores = {}
    for subset_name in ("ucfd_lug", "ucfd_eng"):
        split_data = datasets.load_dataset(
            "Sunbird/salt-practical-eval", subset_name, split="test"
        )
        hypotheses, targets = [], []

        for row in tqdm.tqdm(split_data, desc=f"Evaluating {subset_name}"):
            targets.append(row["text"])
            output = whisper_pipeline(row["audio"], generate_kwargs=decode_options)
            hypotheses.append(output["text"])

        scores[subset_name] = wer_metric.compute(
            predictions=normalise(hypotheses), references=normalise(targets)
        )

    return scores


# Main hyperparameter sweep function
def run_hyperparameter_sweep(
    whisper_pipeline, prompt_ids, *, previous_best_wer=0.778, evaluate_fn=None
):
    """Grid-search decoding hyperparameters and report the best combination.

    Every (repetition_penalty, no_repeat_ngram_size) pair in the grid is scored
    by `evaluate_fn`. The pair with the lowest WER, averaged over all returned
    subsets, is reported.

    Args:
        whisper_pipeline: ASR pipeline, passed through to `evaluate_fn`.
        prompt_ids: prompt token ids, passed through to `evaluate_fn`.
        previous_best_wer: WER from an earlier run. It is only used to report
            whether this sweep improved on it, never as the starting "best".
        evaluate_fn: callable `(repetition_penalty, no_repeat_ngram_size,
            whisper_pipeline, prompt_ids) -> {subset_name: wer}`. Defaults to
            `evaluate_hyperparameters`.

    Returns:
        dict with "best_params" (None if nothing was evaluated) and
        "best_wer" (None if nothing was evaluated).
    """
    if evaluate_fn is None:
        evaluate_fn = evaluate_hyperparameters

    # Per the transformers generation docs, repetition_penalty=1.0 means "no
    # penalty". Values below 1.0 favour repeated tokens rather than penalising
    # them.
    repetition_penalties = np.linspace(0.1, 1.5, 5)  # 0.1, 0.45, 0.8, 1.15, 1.5

    no_repeat_ngram_sizes = [None, 5]

    # TODO: sweep beam sizes (1,2,3)

    # Start from "nothing evaluated yet" so the reported winner is always a
    # combination this sweep actually tried. The old code seeded a hard-coded
    # WER with parameters that were not in the grid.
    best_params = None
    best_wer = float("inf")

    # Perform grid search
    for rep_penalty in repetition_penalties:
        for ngram_size in no_repeat_ngram_sizes:
            print(
                f"Evaluating: repetition_penalty={rep_penalty:.2f}, no_repeat_ngram_size={ngram_size}"
            )
            results = evaluate_fn(
                rep_penalty, ngram_size, whisper_pipeline, prompt_ids
            )

            # Average WER across every evaluated subset
            avg_wer = sum(results.values()) / len(results)

            if avg_wer < best_wer:
                best_wer = avg_wer
                best_params = {
                    "repetition_penalty": rep_penalty,
                    "no_repeat_ngram_size": ngram_size,
                }

            for subset_name, subset_wer in results.items():
                print(f"{subset_name} WER: {subset_wer:.3f}")
            print(f"Average WER: {avg_wer:.3f}")
            print("-" * 40)

    if best_params is None:
        print("No hyperparameter combinations were evaluated.")
        return {"best_params": None, "best_wer": None}

    print("Best parameters found:")
    print(f"repetition_penalty: {best_params['repetition_penalty']:.2f}")
    print(f"no_repeat_ngram_size: {best_params['no_repeat_ngram_size']}")
    print(f"Best average WER: {best_wer:.3f}")
    if previous_best_wer is not None:
        verdict = "improved" if best_wer < previous_best_wer else "not improved"
        print(f"Previous best average WER: {previous_best_wer:.3f} ({verdict})")

    return {"best_params": best_params, "best_wer": best_wer}


In [None]:
# Run the full grid search. Each combination re-transcribes both ucfd subsets
# example by example, so this is slow.
run_hyperparameter_sweep(whisper_pipeline, prompt_ids)

In [None]:
# Display the pipeline object as a sanity check of what is loaded.
whisper_pipeline

In [None]:
# Trim surrounding whitespace that Whisper leaves on transcriptions, then
# bundle the pairs for saving.
predictions = [text.strip() for text in predictions]
results = {'predictions': predictions, 'references': references}

# Print each prediction directly above its reference for eyeballing.
for idx, hypothesis in enumerate(predictions):
    pprint(hypothesis)
    pprint(references[idx])
    pprint("=============")

In [None]:
# NOTE(review): this cell repeats the previous one verbatim; it is kept so
# re-running either cell gives the same output.
predictions = [text.strip() for text in predictions]
results = {'predictions': predictions, 'references': references}

# Print each prediction directly above its reference for eyeballing.
for idx, hypothesis in enumerate(predictions):
    pprint(hypothesis)
    pprint(references[idx])
    pprint("=============")

In [None]:
# Columns come from the dict keys ('predictions', 'references'); one row per example.
df = pd.DataFrame(results)
df.to_csv("results.csv", index=False)

In [None]:
# Preview the first rows of the saved results.
df.head()

## Evaluate on the ucfd test data from Google Drive

In [None]:
# Mount Google Drive (Colab only) to reach the shared dfcu test data.
from google.colab import drive
drive.mount('/gdrive')

In [None]:
# Extract the zipped ucfd English test split into the working directory (stdout discarded).
!unzip "/gdrive/Shared drives/Sunbirdsens/dfcu/data/test-data/test_ucfd_eng.zip" >>/dev/null

In [None]:
# Move the extracted Arrow files into a `test.hf` directory for `load_from_disk`.
# NOTE(review): `mkdir` fails if test.hf already exists, so this cell is not re-runnable as-is.
!mkdir test.hf
!mv data-00000-of-00001.arrow dataset_info.json state.json test.hf

In [None]:
# List the working directory to confirm the files were moved.
!ls -l

In [None]:
## For ucfd test data
# Load the extracted test split from disk and display its summary.
ucfd_test = datasets.Dataset.load_from_disk("test.hf")
ucfd_test

In [None]:
# Transcribe every clip in the held-out ucfd test set, one example at a time,
# reusing `generate_kwargs` from the evaluation cell above.
test_predictions = []
test_references = []
filenames = []

for example in tqdm.tqdm(ucfd_test):
    test_references.append(example["text"])
    filenames.append(example["filename"])

    # Use the ASR pipeline to get the transcription
    test_transcription = whisper_pipeline(example["audio"], generate_kwargs=generate_kwargs)["text"]
    test_predictions.append(test_transcription)

# `compute` takes `predictions=` / `references=`. The previous keyword names
# were not recognised, and `references` pointed at the earlier
# salt-practical-eval subset rather than this test set.
wer_score = wer_metric.compute(
    predictions=normalise(test_predictions),
    references=normalise(test_references),
)

print(f"ucfd_test WER: {wer_score:.3f}")

In [None]:
# Trim whitespace and bundle filenames/references/predictions for saving.
test_predictions = [pred.strip() for pred in test_predictions]
test_results = {'filename': filenames, 'test_references': test_references, 'test_predictions': test_predictions}

# Show (at most) the first 5 prediction/reference pairs. Slicing avoids the
# IndexError that `range(5)` raised on test sets with fewer than 5 examples.
for pred, ref in zip(test_predictions[:5], test_references[:5]):
    pprint(pred)
    pprint(ref)
    pprint("=============")

In [None]:
# One row per test clip, with filename, reference and prediction columns.
test_df = pd.DataFrame(test_results)
test_df.to_csv("test_results.csv", index=False)

In [None]:
import shutil

In [None]:
# Copy both result CSVs into the shared Drive test-data folder.
shutil.copy('results.csv', "/gdrive/Shared drives/Sunbirdsens/dfcu/data/test-data/")
shutil.copy('test_results.csv', "/gdrive/Shared drives/Sunbirdsens/dfcu/data/test-data/")

### Example of the effect of resampling

In [None]:
# Load the sema_eng test split and step through it example by example.
eval_dataset = datasets.load_dataset("Sunbird/salt-practical-eval", "sema_eng", split="test")
dataset_iterator = iter(eval_dataset)

In [None]:
# Advance to the next example (re-run this cell to move through the split).
example = next(dataset_iterator)

In [None]:
# Resample the raw waveform from its native rate to 16 kHz and to 20 kHz.
source_audio = example['audio']
audio = source_audio['array']
source_rate = source_audio['sampling_rate']
audio_resampled_16khz = librosa.resample(audio, orig_sr=source_rate, target_sr=16000)
audio_resampled_20khz = librosa.resample(audio, orig_sr=source_rate, target_sr=20000)

In [None]:
# Play the 16 kHz version at its true sample rate.
display.Audio(audio_resampled_16khz, rate=16000)

In [None]:
# NOTE(review): the 20 kHz waveform is played back at rate=16000, so it sounds
# slower and lower-pitched. This is presumably deliberate, to illustrate a
# sample-rate mismatch; use rate=20000 to hear it at normal speed.
display.Audio(audio_resampled_20khz, rate=16000)