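# Predict on FLEURS with confidence scores

Transcribe the English (`en_us`) train split of FLEURS with two ASR models — Wav2Vec2 with a 4-gram LM, and Whisper-base. For each model, save the predictions, token-level probabilities (`probs_tokens_pred`) and per-example WER to disk.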

In [1]:
from datasets import load_dataset
from utils import *

## I. Load data

In [2]:
# Load only the train split, so `fleurs_en` is a `Dataset` from the start rather than
# a `DatasetDict` that is later rebound to one of its splits.
# Drop the speaker/file metadata columns. Per the map outputs below, the columns kept
# are 'audio', 'transcription' and 'raw_transcription'.
FLEURS_METADATA_COLUMNS = ['id', 'num_samples', 'path', 'gender', 'lang_id', 'language', 'lang_group_id']
fleurs_en = (
    load_dataset("google/fleurs", "en_us", split="train")
    .remove_columns(FLEURS_METADATA_COLUMNS)
)

Found cached dataset fleurs (/home/antonin/.cache/huggingface/datasets/google___fleurs/en_us/2.0.0/aabb39fb29739c495517ac904e2886819b6e344702f0a5b5283cb178b087c94a)


  0%|          | 0/3 [00:00<?, ?it/s]

## II. Predict with Wav2Vec2 + 4-gram LM

In [3]:
# The checkpoint name indicates wav2vec2-base (960h) paired with a 4-gram LM decoder.
# `load_wav2vec_model` (presumably from utils) returns a pair, unpacked here as
# (processor, model).
WAV2VEC_CHECKPOINT = "patrickvonplaten/wav2vec2-base-960h-4-gram"
processor, model = load_wav2vec_model(WAV2VEC_CHECKPOINT)

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at patrickvonplaten/wav2vec2-base-960h-4-gram and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

Please use `allow_patterns` and `ignore_patterns` instead.


Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

In [6]:
# Run Wav2Vec2 inference one example at a time. `batched` is left at its default (False),
# so `predict_with_confidence_wav2vec` receives a single example per call.
# The former `batch_size=16` was dropped: `Dataset.map` only honours `batch_size` when
# `batched=True`, so it was silently ignored here.
# The raw 'audio' column is not carried into the predictions dataset.
fleurs_en_wav2vec = fleurs_en.map(
    predict_with_confidence_wav2vec,
    fn_kwargs={"model": model, "processor": processor},
    remove_columns=['audio'],
)
fleurs_en_wav2vec

  0%|          | 0/2602 [00:00<?, ?ex/s]

Dataset({
    features: ['transcription', 'raw_transcription', 'string_pred', 'tokens_pred', 'probs_tokens_pred', 'ground_truth', 'wer'],
    num_rows: 2602
})

In [8]:
# Persist the Wav2Vec2 predictions under `predictions_path` (presumably defined in utils).
wav2vec_predictions_dir = os.path.join(predictions_path, 'fleurs_en_wav2vec')
fleurs_en_wav2vec.save_to_disk(wav2vec_predictions_dir)

## III. Predict with Whisper

In [9]:
# Load Whisper-base. The second argument ('English') is passed straight through to
# `load_whisper_model`, whose returned pair is unpacked as (processor, model).
# NOTE: this rebinds `processor`/`model`, which previously held the Wav2Vec2 objects.
WHISPER_CHECKPOINT = 'openai/whisper-base'
processor, model = load_whisper_model(WHISPER_CHECKPOINT, 'English')

In [10]:
# Pin the decoder prompt to English transcription (language="en", task="transcribe").
# The same language code is forwarded to the prediction helper.
whisper_lang = "en"
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language=whisper_lang, task="transcribe")

# Batched map, but with one example per batch handed to `predict_with_confidence_whisper`.
# The raw 'audio' column is not carried into the predictions dataset.
fleurs_en_whisper = fleurs_en.map(
    predict_with_confidence_whisper,
    fn_kwargs={"processor": processor, "model": model, "lang": whisper_lang},
    batched=True,
    remove_columns=['audio'],
    batch_size=1,
)
fleurs_en_whisper

  0%|          | 0/2602 [00:00<?, ?ba/s]



Dataset({
    features: ['transcription', 'raw_transcription', 'string_pred', 'tokens_pred', 'probs_tokens_pred', 'ground_truth', 'wer'],
    num_rows: 2602
})

In [12]:
# Persist the Whisper predictions under `predictions_path` (presumably defined in utils).
whisper_predictions_dir = os.path.join(predictions_path, 'fleurs_en_whisper')
fleurs_en_whisper.save_to_disk(whisper_predictions_dir)