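## Decode and translate TRAC FM voice poll samples

This notebook does three things with the Luganda TRAC FM voice-poll test set:

- Transcribes the recordings with a fine-tuned Whisper model.
- Translates the transcripts to English with a fine-tuned NLLB model.
- Reports the word error rate of the transcriptions.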

In [None]:
!pip install -q sentencepiece
!pip install -q datasets
!pip install -q transformers
!pip install -q librosa
!pip install -q soundfile
!git clone https://github.com/sunbirdai/salt.git
!pip install -qr salt/requirements.txt

In [None]:
import re
import string

import datasets
import evaluate
import huggingface_hub
import librosa
import pandas as pd
import torch
import transformers
from IPython import display
from tqdm.notebook import tqdm
from transformers.models.whisper.english_normalizer import BasicTextNormalizer

import salt.utils

# Suppress some non-informative warnings from Transformers
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
# Log in to the Hugging Face Hub through an interactive notebook widget.
# NOTE(review): presumably only needed if one of the model/dataset repos used
# below is gated or private; confirm. The token must be entered by hand, so a
# fresh "Run All" is not fully unattended.
huggingface_hub.notebook_login()

In [None]:
# Load the Whisper processor and model.
# Precision and attention backend are chosen here, at load time.
# transformers.pipeline forwards `model_kwargs` to `from_pretrained` only when
# it is given a model *name*. It does not apply them to an already-loaded
# model instance like `asr_model`.
model_path = 'jq/whisper-large-v3-salt'
processor = transformers.WhisperProcessor.from_pretrained(
    model_path, language=None, task="transcribe")
asr_model = transformers.WhisperForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype=torch.float16,  # flash attention requires fp16/bf16 weights
    attn_implementation=(
        "flash_attention_2" if transformers.utils.is_flash_attn_2_available()
        else "sdpa"),
)

In [None]:
# The tokenizer is taken from the base NLLB-200 distilled 1.3B checkpoint.
# The weights come from a separate fine-tuned many-to-many checkpoint.
translation_tokenizer = transformers.NllbTokenizer.from_pretrained(
    pretrained_model_name_or_path="facebook/nllb-200-distilled-1.3B",
)
translation_model = transformers.M2M100ForConditionalGeneration.from_pretrained(
    pretrained_model_name_or_path='jq/nllb-1.3B-many-to-many-pronouncorrection-charaug',
)

In [None]:
# Speech-recognition pipeline wrapping the already-loaded Whisper model.
# `asr_model` is passed as a model *instance*. A `model_kwargs` argument
# (e.g. attn_implementation) would only be forwarded to `from_pretrained`,
# so it would have no effect here and is omitted. The attention backend has
# to be chosen when the model is loaded.
# NOTE(review): whether `torch_dtype` casts an already-instantiated model
# depends on the transformers version. Confirm the model itself is fp16.
whisper_pipeline = transformers.pipeline(
    "automatic-speech-recognition",
    model=asr_model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    device='cuda:0',
    torch_dtype=torch.float16,
    generate_kwargs={
        "language": None,  # no fixed language: Whisper detects it
        "forced_decoder_ids": None,  # do not force language/task prompt tokens
        "repetition_penalty": 1.0,
        "no_repeat_ngram_size": 4,
        "num_beams": 3,
    },
    chunk_length_s=30,  # long audio is processed in 30-second chunks
    batch_size=1,  # Higher = faster on long audio but more GPU memory usage
)

In [None]:
def translate_sentence(text, source_language, target_language):
    """Translate a single sentence with the NLLB translation model.

    Args:
        text: The sentence to translate.
        source_language: Code of the input language, one of 'eng', 'ach',
            'lgg', 'lug', 'nyn', 'teo'.
        target_language: Code of the output language (same set of codes).

    Returns:
        The decoded translation of ``text`` as a string (first beam-search
        result, special tokens removed).

    Raises:
        KeyError: If either language code is not in the table below.
    """
    # Token IDs used as language tags for each supported language.
    # NOTE(review): some of these are presumably NLLB language tags that the
    # fine-tuned model repurposes for languages NLLB lacks. Confirm.
    tag_ids = {
        'eng': 256047,
        'ach': 256111,
        'lgg': 256008,
        'lug': 256110,
        'nyn': 256002,
        'teo': 256006,
    }

    if torch.cuda.is_available():
        run_on = torch.device("cuda")
    else:
        run_on = torch.device("cpu")

    batch = translation_tokenizer(text, return_tensors="pt").to(run_on)
    # Overwrite the first input token with the source-language tag.
    # NOTE(review): assumes the tokenizer puts its language tag at position 0
    # (non-legacy NllbTokenizer behaviour). Confirm for the installed version.
    batch['input_ids'][0][0] = tag_ids[source_language]

    model = translation_model.to(run_on)
    generated = model.generate(
        **batch,
        forced_bos_token_id=tag_ids[target_language],  # decode into the target language
        max_length=100,
        num_beams=5,
        repetition_penalty=1.1,
    )
    return translation_tokenizer.batch_decode(generated, skip_special_tokens=True)[0]

def split_into_sentences(text):
    """Split ``text`` into sentences, keeping each sentence's end punctuation.

    A sentence boundary is a ``.``, ``?`` or ``!`` followed by whitespace.
    Splitting only there keeps decimal numbers (``"3.5"``) and other in-word
    dots intact. Keeping the punctuation lets the translation model see
    whether a sentence is a question.

    Args:
        text: Input text. ``None``, ``""`` or a blank string yields ``[]``.

    Returns:
        List of stripped, non-empty sentences, in order.
    """
    if not text:
        return []
    # Zero-width lookbehind: split on the whitespace *after* the punctuation
    # so the punctuation stays attached to its sentence.
    pieces = re.split(r'(?<=[.?!])\s+', text.strip())
    return [piece.strip() for piece in pieces if piece.strip()]

def maybe_add_full_stop(s):
    """Return ``s`` ending in ASCII punctuation, appending '.' if needed.

    An empty string comes back as ``''``. Otherwise a full stop is appended
    unless the last character is already in ``string.punctuation``.
    """
    if len(s) == 0:
        return ''
    ends_with_punctuation = s[-1] in string.punctuation
    return s if ends_with_punctuation else s + '.'

def translate_sentences(text, source_language, target_language):
    """Translate multi-sentence ``text`` one sentence at a time.

    Steps:
      1. Split the text with ``split_into_sentences``.
      2. Translate each sentence independently with ``translate_sentence``.
      3. Give each translation a trailing full stop if it lacks final
         punctuation.
      4. Join the results with single spaces.

    Returns '' when no sentences are found.
    """
    translated = []
    for sentence in split_into_sentences(text):
        output = translate_sentence(sentence, source_language, target_language)
        translated.append(maybe_add_full_stop(output))
    return ' '.join(translated)

In [None]:
# Test split of the Luganda TRAC FM subset of the SALT practical evaluation set.
test_dataset = datasets.load_dataset(
    path="Sunbird/salt-practical-eval",
    name='trac_fm_lug',
    split="test",
)

Run speech recognition on each test example, then translate the Luganda transcript into English. Audio is resampled to 16 kHz first.

In [None]:
# Transcribe each test clip with Whisper, then translate the Luganda
# transcript to English. All four results for an example are appended
# together, and only after both steps succeed. This keeps the lists aligned
# even if the loop is interrupted or an example fails part-way.
all_audio = []
all_transcriptions = []
all_true_transcriptions = []
all_translations = []

for example in tqdm(test_dataset):
    # Resample every clip to 16 kHz, the sampling rate Whisper works at.
    audio = librosa.resample(
        example['audio']['array'],
        orig_sr=example['audio']['sampling_rate'],
        target_sr=16000)

    transcription = whisper_pipeline(audio)['text']
    translation = translate_sentences(transcription, 'lug', 'eng')

    all_true_transcriptions.append(example['text'])
    all_audio.append(audio)
    all_transcriptions.append(transcription)
    all_translations.append(translation)

Compute the word error rate (WER) of the predicted transcriptions against the reference transcriptions, then inspect the outputs.

In [None]:
# Score the transcriptions by word error rate (WER).
# Both sides go through Whisper's BasicTextNormalizer first, so surface
# formatting differences (e.g. casing, punctuation) are not counted as errors.
assert len(all_transcriptions) == len(all_true_transcriptions), (
    'prediction/reference lists are misaligned; re-run the decoding cell')
normalizer = BasicTextNormalizer()
wer_metric = evaluate.load("wer", trust_remote_code=True)
wer_score = wer_metric.compute(
    predictions=[normalizer(p) for p in all_transcriptions],
    references=[normalizer(r) for r in all_true_transcriptions])
print(f'Word error rate: {wer_score:.3f}')

Word error rate: 0.650


In [None]:
# Put the audio, reference transcript and model outputs side by side, then
# render every row for listening and inspection.
processed = pd.DataFrame()
for column_name, column_values in (
        ('audio', all_audio),
        ('transcription_truth', all_true_transcriptions),
        ('transcription_predicted', all_transcriptions),
        ('translation_predicted', all_translations)):
    processed[column_name] = column_values
salt.utils.show_dataset(
    datasets.Dataset.from_pandas(processed),
    audio_features=['audio'],
    N=len(processed),
)