<a href="https://colab.research.google.com/github/SunbirdAI/salt/blob/main/notebooks/sample_asr_eval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Sunbird ASR evaluation

In [None]:
# Run ASR eval for SB and other models on our partner datasets
#
# Notebook location:
# 'Shared drives/Sunbird AI/Projects/African Language Technology/ASR Evaluation'
#
# Ideally, once the eval data is stable, we should move it to the
# Sunbird (SB) Hugging Face Hub and load it from there.
#
# Goal is to link the notebook to a leaderboard, where results are
# automatically updated for the different models as a way of tracking
# model improvements.

In [1]:
%%capture
!pip install datasets
!pip install evaluate jiwer
!pip install pyctcdecode
!pip install kenlm

In [21]:
#@title Import stuff

import os
import json
import string
import pandas as pd
import torch
import transformers
from datasets import Dataset, Audio
from evaluate import load
from huggingface_hub import hf_hub_download
from pyctcdecode import build_ctcdecoder
from transformers import (
    Wav2Vec2ForCTC,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2Processor,
    Wav2Vec2ProcessorWithLM,
    AutomaticSpeechRecognitionPipeline,
    AutoProcessor,
)
from transformers.pipelines.pt_utils import KeyDataset
from transformers import pipeline
from google.colab import drive

In [3]:
# Mount Google Drive so the zipped eval sets on the shared drive are reachable.
drive.mount('/gdrive')

# Use the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

Mounted at /gdrive


### Fetch eval data
Currently, the eval data is fetched from Google Drive. Once it is stable, it can be moved to the Sunbird (SB) Hugging Face Hub and fetched from there directly.

In [4]:
!unzip /gdrive/'Shared drives'/'Sunbird AI'/Projects/'African Language Technology'/'ASR Evaluation'/eval_ucfd_eng.zip >> /dev/null
!unzip /gdrive/'Shared drives'/'Sunbird AI'/Projects/'African Language Technology'/'ASR Evaluation'/eval_ucfd_lug.zip >> /dev/null
!unzip /gdrive/'Shared drives'/'Sunbird AI'/Projects/'African Language Technology'/'ASR Evaluation'/eval_sema_eng.zip >> /dev/null
!unzip /gdrive/'Shared drives'/'Sunbird AI'/Projects/'African Language Technology'/'ASR Evaluation'/eval_sema_lug.zip >> /dev/null
!unzip /gdrive/'Shared drives'/'Sunbird AI'/Projects/'African Language Technology'/'ASR Evaluation'/eval_trac_fm_lug.zip >> /dev/null

In [6]:
#@title Load eval data

def load_eval_data(folder_path):
  '''Load one evaluation split from a folder containing a CSV manifest and audio.

  The folder must hold at least one ``.csv`` file with ``filename`` and
  ``transcription`` columns; each ``filename`` is resolved relative to
  ``folder_path``. If several CSV files are present, the alphabetically first
  one is used, so the choice does not depend on directory listing order.

  Args:
    folder_path: Directory containing the CSV manifest and the audio files.

  Returns:
    A ``datasets.Dataset`` with an ``audio`` column (cast to ``Audio``) and a
    ``transcription`` column, in CSV row order.

  Raises:
    FileNotFoundError: If the folder contains no CSV file.
    ValueError: If the CSV lacks a ``filename`` or ``transcription`` column.
  '''
  # os.listdir order is arbitrary, so sort before picking the manifest.
  csv_files = sorted(f for f in os.listdir(folder_path) if f.lower().endswith('.csv'))
  if not csv_files:
    raise FileNotFoundError("No CSV file found in the folder.")

  df = pd.read_csv(os.path.join(folder_path, csv_files[0]))

  # Fail early with a clear message rather than an AttributeError further down.
  for column in ('filename', 'transcription'):
    if column not in df.columns:
      raise ValueError(f"'{column}' column not found in the CSV file.")

  # Audio paths in the CSV are relative to the eval folder.
  audio_paths = [os.path.join(folder_path, name) for name in df['filename']]

  return Dataset.from_dict(
      {'audio': audio_paths, 'transcription': df['transcription'].to_list()}
  ).cast_column('audio', Audio())


In [7]:
# One Dataset per extracted eval folder; the folder names match the zip
# archives extracted earlier in the notebook.
ucfd_eng_eval_data = load_eval_data('eval_ucfd_eng')
ucfd_lug_eval_data = load_eval_data('eval_ucfd_lug')
sema_eng_eval_data = load_eval_data('eval_sema_eng')
sema_lug_eval_data = load_eval_data('eval_sema_lug')
trac_fm_lug_eval_data = load_eval_data('eval_trac_fm_lug')

In [8]:
# Sanity check: show the features and row count of one loaded split.
sema_eng_eval_data

Dataset({
    features: ['audio', 'transcription'],
    num_rows: 12
})

In [11]:
#@title Models

def wav2vecpipeline(model_id, lang, lm_file, device=device, use_lm=True):
  '''Build a chunked CTC speech-recognition pipeline for a Wav2Vec2/MMS checkpoint.

  Loads ``model_id`` as ``Wav2Vec2ForCTC``, switches the tokenizer vocabulary
  and the model adapter to ``lang``, and wraps everything in an
  ``AutomaticSpeechRecognitionPipeline`` that processes audio in 10 s chunks
  with (4 s, 2 s) strides and returns word-level timestamps.

  Args:
    model_id: Hugging Face Hub repo id of the checkpoint.
    lang: Language code passed to ``set_target_lang`` and ``load_adapter``
      (e.g. "eng", "lug").
    lm_file: KenLM file name inside the repo's ``language_model/`` subfolder.
      Only downloaded when ``use_lm`` is True.
    device: Device for the model and the pipeline (defaults to the notebook's
      ``device``, captured when this function is defined).
    use_lm: If True, decode with a pyctcdecode beam-search decoder backed by
      ``lm_file``; otherwise the pipeline does plain CTC decoding.

  Returns:
    The configured ``AutomaticSpeechRecognitionPipeline``.
  '''
  model = Wav2Vec2ForCTC.from_pretrained(model_id).to(device)

  # Switch vocabulary and adapter weights to the target language.
  tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(model_id)
  tokenizer.set_target_lang(lang)
  model.load_adapter(lang)

  feature_extractor = Wav2Vec2FeatureExtractor(
      feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True
  )

  # Shared by both decoding modes: chunk + stride so long audio can be transcribed.
  pipeline_kwargs = dict(
      model=model,
      tokenizer=tokenizer,
      feature_extractor=feature_extractor,
      device=device,
      chunk_length_s=10,
      stride_length_s=(4, 2),
      return_timestamps="word",
  )

  if not use_lm:
    # No decoder is passed: without an LM the CTC logits are decoded with the
    # tokenizer alone.
    return AutomaticSpeechRecognitionPipeline(**pipeline_kwargs)

  lm_path = hf_hub_download(
      repo_id=model_id,
      filename=lm_file,
      subfolder="language_model",
  )

  # Decoder labels must follow token-id order. They are lower-cased, presumably
  # to match a lower-case KenLM vocabulary (TODO confirm for new LMs).
  vocab_dict = tokenizer.get_vocab()
  sorted_vocab_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}
  decoder = build_ctcdecoder(
      labels=list(sorted_vocab_dict.keys()),
      kenlm_model_path=lm_path,
  )

  # NOTE(review): building the processor presumably checks that the decoder
  # alphabet matches the tokenizer vocabulary; kept for that validation.
  processor_with_lm = Wav2Vec2ProcessorWithLM(
      feature_extractor=feature_extractor,
      tokenizer=tokenizer,
      decoder=decoder,
  )
  # The pipeline presumably only enables LM decoding when the feature extractor
  # reports a "...WithLM" processor class, so set it before building the pipeline.
  feature_extractor._set_processor_class("Wav2Vec2ProcessorWithLM")

  return AutomaticSpeechRecognitionPipeline(decoder=processor_with_lm.decoder, **pipeline_kwargs)

In [None]:
# Off-the-shelf Hugging Face ASR pipelines, used for the English eval.
whisperbase = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-base",
    device=device,
)
facebookmms = pipeline(
    "automatic-speech-recognition",
    model="facebook/mms-1b-all",
    device=device,
)
whisperSBFinetuned = pipeline(
    "automatic-speech-recognition",
    model="akera/whisper-base-sb-english",
    device=device,
)

# Sunbird MMS checkpoint with its English adapter, decoded with the 5-gram
# KenLM stored in the repo's language_model/ folder.
sunbirdmms = wav2vecpipeline(
    model_id="Sunbird/sunbird-mms",
    lang="eng",
    lm_file="eng_5gram.bin",
    device=device,
    use_lm=True,
)

# Base MMS with the Luganda adapter and no LM. lm_file is not read while
# use_lm=False; NOTE(review): confirm the file exists under this repo before
# switching the LM on.
facebooklugmms = wav2vecpipeline(
    model_id="facebook/mms-1b-all",
    lang="lug",
    lm_file="lug_eng_5gram.bin",
    device=device,
    use_lm=False,
)

In [25]:
#@title Predictions

def get_predictions(pipeline, eval_datasets):
  '''Transcribe every split in ``eval_datasets`` with an ASR pipeline.

  Args:
    pipeline: ASR pipeline (any callable) that takes a list of audio entries
      and yields one dict with a ``'text'`` key per entry.
    eval_datasets: Mapping of split name -> dataset exposing an ``'audio'`` column.

  Returns:
    Dict mapping each split name to its list of predicted transcripts, in
    dataset order.
  '''
  return {
      split_name: [output['text'] for output in pipeline(split_data['audio'])]
      for split_name, split_data in eval_datasets.items()
  }

In [32]:
# Split name -> Dataset. The keys become the labels in the WER results.
eng_eval_datasets = dict(
    ucfd_eng=ucfd_eng_eval_data,
    sema_eng=sema_eng_eval_data,
)

lug_eval_datasets = dict(
    ucfd_lug=ucfd_lug_eval_data,
    sema_lug=sema_lug_eval_data,
    trac_fm_lug=trac_fm_lug_eval_data,
)

In [27]:
# Eng eval
# Run each English-capable pipeline over every clip in both English splits.
# Each result maps split name -> list of predicted transcripts.
whisperbasepredictions = get_predictions(whisperbase, eng_eval_datasets)
facebookmmspredictions = get_predictions(facebookmms, eng_eval_datasets)
sunbirdmmspredictions = get_predictions(sunbirdmms, eng_eval_datasets)
whispersbfinetunedpredictions = get_predictions(whisperSBFinetuned, eng_eval_datasets)

In [28]:
# Lug eval
# Only the MMS pipeline with the Luganda adapter is run on the Luganda splits.
facebooklugmmspredictions = get_predictions(facebooklugmms, lug_eval_datasets)

In [34]:
#@title Calculate WER

def lower_case_and_strip_punctuation(string_list, allowed_punctuation="'"):
  '''Lower-case each string and remove ASCII punctuation from it.

  This helps when calculating WER: we care about which words were predicted,
  not about capitalisation or punctuation.

  Args:
    string_list: Iterable of strings to normalise.
    allowed_punctuation: Characters to keep even though they are in
      ``string.punctuation`` (by default the apostrophe, which appears inside
      words such as "ly'omuntu"). Characters that are not punctuation, or are
      repeated, are ignored. Pass "" or None to strip all punctuation.

  Returns:
    A new list of normalised strings, in input order.
  '''
  # Build the deletion table once instead of once per string.
  keep = set(allowed_punctuation or "")
  strip_chars = ''.join(c for c in string.punctuation if c not in keep)
  table = str.maketrans('', '', strip_chars)
  return [s.lower().translate(table) for s in string_list]

def get_wer(predictions, datasets):
  '''Word error rate, as a percentage, for each evaluation split.

  References and predictions are both passed through
  ``lower_case_and_strip_punctuation`` before scoring, so case and
  punctuation (apart from apostrophes) do not count as errors.

  Args:
    predictions: Mapping of split name -> list of predicted transcripts, e.g.
      the output of ``get_predictions``. Must contain every key of ``datasets``.
    datasets: Mapping of split name -> dataset with a ``"transcription"`` column.

  Returns:
    Dict mapping split name -> WER x 100, rounded to two decimals.
  '''
  wer_metric = load("wer")
  scores = {}
  for split_name in datasets:
    references = lower_case_and_strip_punctuation(datasets[split_name]["transcription"])
    hypotheses = lower_case_and_strip_punctuation(predictions[split_name])
    raw_wer = wer_metric.compute(references=references, predictions=hypotheses)
    scores[split_name] = round(100 * raw_wer, 2)
  return scores


In [39]:
# Score each model on the same splits its predictions were produced for:
# the English pipelines on eng_eval_datasets, the Luganda one on lug_eval_datasets.
wer_whisperbase = get_wer(whisperbasepredictions, eng_eval_datasets)
wer_facebookmms = get_wer(facebookmmspredictions, eng_eval_datasets)
wer_sunbirdmms = get_wer(sunbirdmmspredictions, eng_eval_datasets)
wer_whisperSBFinetuned = get_wer(whispersbfinetunedpredictions, eng_eval_datasets)
wer_facebooklugmms = get_wer(facebooklugmmspredictions, lug_eval_datasets)

In [40]:
# Print one pretty-printed block of per-split WER scores per model.
for model_label, model_wer in [
    ("Whisperbase", wer_whisperbase),
    ("FacebookMMS", wer_facebookmms),
    ("SunbirdMMS", wer_sunbirdmms),
    ("WhisperSBFinetuned", wer_whisperSBFinetuned),  # label previously had a stray leading "S"
    ("FacebookLugMMS", wer_facebooklugmms),
]:
    print(f"{model_label} WER: {json.dumps(model_wer, indent=4)}")

Whisperbase WER: {
    "ucfd_eng": 65.1,
    "sema_eng": 85.31
}
FacebookMMS WER: {
    "ucfd_eng": 90.25,
    "sema_eng": 80.81
}
SunbirdMMS WER: {
    "ucfd_eng": 53.32,
    "sema_eng": 50.47
}
SWhisperSBFinetuned WER: {
    "ucfd_eng": 46.6,
    "sema_eng": 47.16
}
FacebookLugMMS WER: {
    "ucfd_lug": 84.8,
    "sema_lug": 87.86,
    "trac_fm_lug": 63.22
}


In [20]:
def compare_predictions(predictions1, predictions2, predictions3, predictions4, ground_truths):
    """Print the ground truth next to four models' transcripts, one example at a time.

    The argument order fixes the printed labels: predictions1 -> "Wspr-FineTuned",
    predictions2 -> "SB-MMS", predictions3 -> "Wspr-Base",
    predictions4 -> "FacebookMMS". Output stops at the shortest input.
    """
    model_labels = ("Wspr-FineTuned", "SB-MMS", "Wspr-Base", "FacebookMMS")
    rows = zip(predictions1, predictions2, predictions3, predictions4, ground_truths)
    for example_number, (*model_outputs, truth) in enumerate(rows, start=1):
        print(f"Example {example_number}:")
        print(f"  Ground Truth: {truth}")
        for label, output in zip(model_labels, model_outputs):
            print(f"  {label}: {output}")
        # Blank line between examples.
        print()

# Split to inspect: any key of eng_eval_datasets or lug_eval_datasets.
ds = 'trac_fm_lug'

if ds in eng_eval_datasets:
    # Index the per-model dicts without rebinding them, so this cell can be
    # re-run and the full prediction dicts stay usable (e.g. by get_wer).
    ground_truths = eng_eval_datasets[ds]["transcription"]
    compare_predictions(
        whispersbfinetunedpredictions[ds],
        sunbirdmmspredictions[ds],
        whisperbasepredictions[ds],
        facebookmmspredictions[ds],
        ground_truths,
    )
else:
    # Only the Luganda MMS pipeline was run on the Luganda splits.
    predictions = facebooklugmmspredictions[ds]
    ground_truths = lug_eval_datasets[ds]["transcription"]
    for example_number, (pred, truth) in enumerate(zip(predictions, ground_truths), start=1):
        print(f"Example {example_number}:")
        print(f"  Ground Truth: {truth}")
        print(f"  FacebookLugMMS: {pred}")
        print()


Example 1:
  Ground Truth: Option B,Yee, kukyaalo kyaffe waaliwo omusomo ogw'ebyetaka, nategeera ddi lwenyina okukuuma etaka lyange na ddi lwenyina okulikozesa mubutuufu nga teli muntu yena bwekoseza na ddi lwendi yina ko obwa nanyini.
  FacebookLugMMS: option ba ye ku kyalo kyaffe waliyo omusomo gw'ebyettaka nategeera dd lwenina okuuma ettaka lyange naddi lwennina olikozesa mu butuufu nga teri muntu yena bwenkoseza naddi lwendirinako obwananyini

Example 2:
  Ground Truth: mwasuze mutya ba memba, mba lamusizaako mwena, Nze okusinzira ku kibuuzo ekya leero, nyenda ne kya. Simanyi butya ddembe ly'omuntu welitekeddwa ku beera ku, okusinzira ku enaku zino bye tuyitamu enakuzino, era nze mbadde simanyi nti eddembe ly'obuntu ku bulamu bwa nge nti ndye taaga kubanga bulikimu nkiyitamu buyisi, sibitegeera ko daala ku kye bayita ddembe lyobuntu.
  FacebookLugMMS: mwasuze muti abamemba alamusizaako mwena nze okusinziira okibuuzo ekya leero ngenda nekya simanyi buty eddembe lya muntu weliteekedd