In [None]:
!pip install -q datasets
!pip install -q transformers
!pip install -q librosa

In [None]:
import torch
import datasets
import transformers
import librosa
from IPython import display

Get some data to test on

In [None]:
# Load the "test" split of the "sema_eng" configuration of the
# Sunbird/salt-practical-eval dataset, then wrap it in an iterator so that
# later cells can pull one example at a time with next().
eval_dataset = datasets.load_dataset(
    path="Sunbird/salt-practical-eval",
    name="sema_eng",
    split="test",
)
dataset_iterator = iter(eval_dataset)

Get references to the model and processor

In [None]:
# Hub checkpoint shared by the processor and the model; defined once so the
# two can never drift apart.
MODEL_ID = "jq/whisper-large-v2-multilingual-prompts-corrected"

# Prefer the GPU, but fall back to CPU instead of crashing when CUDA is not
# available (inference will be slow, but the notebook still runs).
device = "cuda" if torch.cuda.is_available() else "cpu"

# language=None: no language is forced on the processor here -- presumably so
# that the model predicts the language token itself (TODO confirm).
processor = transformers.WhisperProcessor.from_pretrained(
    MODEL_ID, language=None, task="transcribe")

model = transformers.WhisperForConditionalGeneration.from_pretrained(
    MODEL_ID).to(device)
# Note: If a pipeline is already loaded in memory, then we can just use:
# model = whisper_pipeline.model

Get a mapping from token IDs to language codes

In [None]:
# Language code -> ID of the language token used for that language.
salt_whisper_language_id_tokens = {
    'eng': 50259,
    'ach': 50357,
    'lgg': 50356,
    'lug': 50355,
    'nyn': 50354,
    'teo': 50353,
}

# Inverse lookup (token ID -> language code), used to interpret model output.
token_to_language = {
    token_id: language_code
    for language_code, token_id in salt_whisper_language_id_tokens.items()
}

Get a test audio clip and resample it to 16 kHz, the sampling rate that is passed to the processor below.

In [None]:
# Take the next example from the evaluation set; re-running this cell
# advances to the following clip.
example = next(dataset_iterator)
clip = example['audio']

# Resample from the clip's own sampling rate to 16 kHz.
audio = librosa.resample(
    clip['array'], orig_sr=clip['sampling_rate'], target_sr=16000)

# Last expression of the cell: renders an inline audio player.
display.Audio(audio, rate=16000)

Run the model and pull out the language token

In [None]:
# TODO: The results are probably better if silences are removed.

# Run feature extraction and generation on whichever device the model lives
# on, instead of assuming "cuda".
model_device = str(model.device)

input_features = processor(
    audio,
    sampling_rate=16000,
    return_tensors="pt",
    do_normalize=True,
    device=model_device).input_features

# Note that we don't need all the tokens corresponding to the full text:
# we're just interested in the language detection here. So save time by
# stopping after generating just a few tokens.
with torch.no_grad():
    predicted_ids = model.generate(
        input_features.to(model_device),
        max_new_tokens=5
    )[0]

# The language token is read from position 1 of the generated sequence
# (presumably right after the start-of-transcript token -- TODO confirm).
language_token = int(predicted_ids[1])

# None if the token is not one of the known language-ID tokens.
detected_language = token_to_language.get(language_token)

In [None]:
# Report the detected language code (None if the token was not recognised).
print(f'Language detected: {detected_language}')

# Also show the raw decoded text of the few generated tokens, for inspection.
decoded_tokens = processor.decode(predicted_ids)
print(decoded_tokens)