In [None]:
%%capture
!pip install datasets
!pip install evaluate jiwer
!pip install pyctcdecode
!pip install kenlm

In [None]:
import json

import pandas as pd
import torch
import transformers
from datasets import Dataset, Audio
from evaluate import load
from huggingface_hub import hf_hub_download
from pyctcdecode import build_ctcdecoder
from transformers import (
    Wav2Vec2ForCTC,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2Processor,
    Wav2Vec2ProcessorWithLM,
    AutomaticSpeechRecognitionPipeline,
    AutoProcessor,
)
from transformers.pipelines.pt_utils import KeyDataset
from transformers import pipeline

In [None]:
# Mount Google Drive under /gdrive (Colab only).
from google.colab import drive

drive.mount('/gdrive')

# Run inference on the first GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [None]:
!unzip /gdrive/'Shared drives'/'Sunbird AI'/Projects/'African Language Technology'/'ASR Evaluation'/eval_ucfd.zip >> /dev/null

In [None]:
cd eval_ucfd/

In [None]:
# Evaluation table; the `filename` and `transcript` columns are consumed further down.
df = pd.read_csv("eval_df.csv")

In [None]:
# Pair each audio file path with its reference transcript, then cast the path column
# to the `Audio` feature type.
audio_paths = df["filename"].tolist()
reference_texts = df["transcript"].tolist()

ucfd_eval_data = Dataset.from_dict(
    {"audio": audio_paths, "transcription": reference_texts}
).cast_column("audio", Audio())

# Show the dataset summary as the cell output.
ucfd_eval_data

## Models

In [None]:
# Both Whisper checkpoints are built through the high-level pipeline() factory
# with the same task name and device.
_ASR_TASK = "automatic-speech-recognition"

# Checkpoint "openai/whisper-base".
whisperbase = pipeline(_ASR_TASK, model="openai/whisper-base", device=device)

# "facebook/mms-1b-all" is not created through pipeline() here; a later cell
# assembles that pipeline by hand.

# Checkpoint "akera/whisper-base-sb-english".
whisperSBFinetuned = pipeline(
    _ASR_TASK,
    model="akera/whisper-base-sb-english",
    device=device,
)

In [None]:
# Build the `facebookmms` ASR pipeline from "facebook/mms-1b-all" by hand, so that the
# language adapter (and, optionally, a KenLM language model) can be chosen explicitly.
model_id = "facebook/mms-1b-all"
model = Wav2Vec2ForCTC.from_pretrained(model_id).to(device)

# The tokenizer and the acoustic model are switched to the same target language.
# NOTE(review): "lug" is selected here, while the optional LM below is "eng_5gram.bin" and
# the other models in this notebook are English checkpoints — confirm "lug" is intended.
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(model_id)
tokenizer.set_target_lang("lug")
model.load_adapter("lug")

# Feature extractor for 16 kHz mono input with normalization and attention masks.
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True
)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

# Vocabulary ordered by token id with lower-cased tokens; its keys are the label list
# handed to the CTC decoder in the LM branch.
vocab_dict = processor.tokenizer.get_vocab()
sorted_vocab_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}

# Toggle: decode with a KenLM language model (True) or with plain CTC decoding (False).
use_lm = False

if use_lm:
    # The KenLM file lives in the Sunbird repository. The id is spelled out here so this
    # cell does not depend on `lang_config`, which is only defined in a later cell.
    lm_repo_id = "Sunbird/sunbird-mms"
    lm_file_name = "eng_5gram.bin"
    lm_file_subfolder = "language_model"
    lm_file = hf_hub_download(
        repo_id=lm_repo_id,
        filename=lm_file_name,
        subfolder=lm_file_subfolder,
    )

    # Beam-search CTC decoder backed by the KenLM model.
    decoder = build_ctcdecoder(
        labels=list(sorted_vocab_dict.keys()),
        kenlm_model_path=lm_file,
    )

    # Processor bundle that carries the LM decoder.
    processor_with_lm = Wav2Vec2ProcessorWithLM(
        feature_extractor=feature_extractor,
        tokenizer=tokenizer,
        decoder=decoder,
    )
    # Tag the feature extractor as belonging to a *WithLM processor; the ASR pipeline
    # checks this tag (together with a non-None decoder) before using LM decoding.
    feature_extractor._set_processor_class("Wav2Vec2ProcessorWithLM")

    # Chunked inference (10 s windows, 4 s left / 2 s right stride) so long audio works.
    facebookmms = AutomaticSpeechRecognitionPipeline(
        model=model,
        tokenizer=processor_with_lm.tokenizer,
        feature_extractor=processor_with_lm.feature_extractor,
        decoder=processor_with_lm.decoder,
        device=device,
        chunk_length_s=10,
        stride_length_s=(4, 2),
        return_timestamps="word",
    )
else:
    # Plain CTC decoding. No `decoder` is passed: that argument is meant for a pyctcdecode
    # beam-search decoder and is only used for LM decoding; here the tokenizer decodes.
    facebookmms = AutomaticSpeechRecognitionPipeline(
        model=model,
        tokenizer=tokenizer,
        feature_extractor=feature_extractor,
        device=device,
        chunk_length_s=10,
        stride_length_s=(4, 2),
        return_timestamps="word",
    )


In [None]:
# Build the `sunbirdmms` ASR pipeline: the "Sunbird/sunbird-mms" checkpoint with its "eng"
# adapter, decoded with a KenLM 5-gram language model from the same repository.

# Hub repository per language code; used for both the checkpoint and the LM file.
lang_config = {
    "eng": "Sunbird/sunbird-mms",
}

model_id = lang_config["eng"]
model = Wav2Vec2ForCTC.from_pretrained(model_id).to(device)

# The tokenizer and the acoustic model are switched to the same target language.
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(model_id)
tokenizer.set_target_lang("eng")
model.load_adapter("eng")

# Feature extractor for 16 kHz mono input with normalization and attention masks.
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True
)
# `processor` is built directly from the pieces above; no separate AutoProcessor load is
# needed because nothing reads it before this assignment.
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

# Vocabulary ordered by token id with lower-cased tokens; its keys are the label list
# handed to the CTC decoder.
vocab_dict = processor.tokenizer.get_vocab()
sorted_vocab_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}

# Download the KenLM file from the model repository.
lm_file_name = "eng_5gram.bin"
lm_file_subfolder = "language_model"
lm_file = hf_hub_download(
    repo_id=lang_config["eng"],
    filename=lm_file_name,
    subfolder=lm_file_subfolder,
)

# Beam-search CTC decoder backed by the KenLM model.
decoder = build_ctcdecoder(
    labels=list(sorted_vocab_dict.keys()),
    kenlm_model_path=lm_file,
)

# Processor bundle that carries the LM decoder.
processor_with_lm = Wav2Vec2ProcessorWithLM(
    feature_extractor=feature_extractor,
    tokenizer=tokenizer,
    decoder=decoder,
)
# Tag the feature extractor as belonging to a *WithLM processor; the ASR pipeline
# checks this tag (together with a non-None decoder) before using LM decoding.
feature_extractor._set_processor_class("Wav2Vec2ProcessorWithLM")

# Chunked inference (10 s windows, 4 s left / 2 s right stride) so long audio works.
sunbirdmms = AutomaticSpeechRecognitionPipeline(
    model=model,
    tokenizer=processor_with_lm.tokenizer,
    feature_extractor=processor_with_lm.feature_extractor,
    decoder=processor_with_lm.decoder,
    device=device,
    chunk_length_s=10,
    stride_length_s=(4, 2),
    return_timestamps="word",
)


In [None]:
# Transcribe every clip with the MMS pipeline, keeping only the decoded text.
# Note: the "Predictions" section further down rebuilds this same list from scratch.
facebookmmspredictions = [
    output["text"] for output in facebookmms(KeyDataset(ucfd_eval_data, "audio"))
]


In [None]:
# Show the MMS transcripts collected so far as the cell output.
facebookmmspredictions

## Predictions

In [None]:
def _transcribe_all(asr_pipeline):
    """Run `asr_pipeline` over the 'audio' column of the evaluation set.

    Returns the list of decoded texts, one per example, in dataset order.
    """
    return [output["text"] for output in asr_pipeline(KeyDataset(ucfd_eval_data, "audio"))]


# Same model order as before: Sunbird MMS, Whisper base, Facebook MMS, fine-tuned Whisper.
sunbirdmmspredictions = _transcribe_all(sunbirdmms)
whisperbasepredictions = _transcribe_all(whisperbase)
facebookmmspredictions = _transcribe_all(facebookmms)
whispersbfinetunedpredictions = _transcribe_all(whisperSBFinetuned)

## Calculate WER

In [None]:
# Word-error-rate metric loaded through `evaluate.load`.
wer_metric = load(path="wer")

In [None]:
def _wer_percent(predictions):
    """Return the WER of `predictions` against the reference transcripts, as a percentage."""
    return 100 * wer_metric.compute(
        references=ucfd_eval_data["transcription"], predictions=predictions
    )


# Every model is scored against the same reference transcripts.
# NOTE(review): texts are passed to the metric exactly as produced — no casing or
# punctuation normalization happens in this notebook; confirm that is intended.
wer_whisperbase = _wer_percent(whisperbasepredictions)
wer_facebookmms = _wer_percent(facebookmmspredictions)
wer_sunbirdmms = _wer_percent(sunbirdmmspredictions)
wer_whisperSBFinetuned = _wer_percent(whispersbfinetunedpredictions)

In [None]:
# Report each model's WER as a percentage with two decimals, one line per model.
for model_label, wer_value in (
    ("WhisperBase", wer_whisperbase),
    ("FacebookMMS", wer_facebookmms),
    ("SunbirdMMS", wer_sunbirdmms),
    ("WhisperSBFinetuned", wer_whisperSBFinetuned),
):
    print(f"{model_label} WER: {wer_value:.2f}%")

In [None]:
def compare_predictions(predictions1, predictions2, predictions3, predictions4, ground_truths):
    """Print the reference and the four model outputs side by side for every example.

    Args:
        predictions1: Texts printed under the "Wspr-FineTuned" label.
        predictions2: Texts printed under the "SB-MMS" label.
        predictions3: Texts printed under the "Wspr-Base" label.
        predictions4: Texts printed under the "FacebookMMS" label.
        ground_truths: Reference texts printed under the "Ground Truth" label.

    The inputs are walked in parallel with ``zip``, so output stops at the shortest one.
    Examples are numbered from 1 and each is followed by a blank line.
    """
    aligned = zip(predictions1, predictions2, predictions3, predictions4, ground_truths)
    for number, (finetuned, sb_mms, base, fb_mms, reference) in enumerate(aligned, start=1):
        report_lines = [
            f"Example {number}:",
            f"  Ground Truth: {reference}",
            f"  Wspr-FineTuned: {finetuned}",
            f"  SB-MMS: {sb_mms}",
            f"  Wspr-Base: {base}",
            f"  FacebookMMS: {fb_mms}",
        ]
        print("\n".join(report_lines))
        # Blank separator line between examples.
        print()

# Reference transcripts from the same dataset the predictions were generated from.
ground_truths = ucfd_eval_data["transcription"]

# Argument order must match the labels printed by compare_predictions:
# fine-tuned Whisper, Sunbird MMS, Whisper base, Facebook MMS, then the references.
compare_predictions(
    whispersbfinetunedpredictions,
    sunbirdmmspredictions,
    whisperbasepredictions,
    facebookmmspredictions,
    ground_truths,
)

In [None]:
# Show the full list of fine-tuned Whisper transcripts as the cell output.
whispersbfinetunedpredictions