In [2]:
import time
from datasets import load_dataset

In [7]:
dataset = load_dataset(
    "hf-internal-testing/librispeech_asr_demo", "clean", split="validation")

Found cached dataset librispeech_asr_demo (/Users/ryanselesnik/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_demo/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b)


In [17]:
dataset

Dataset({
    features: ['file', 'audio', 'text', 'speaker_id', 'chapter_id', 'id'],
    num_rows: 73
})

In [50]:
from transformers import pipeline
import whisper


def transcribe(model, audio, size='tiny'):
    if 'wav' in model:
        model = pipeline(task='automatic-speech-recognition',
                     model='facebook/wav2vec2-large-960h')
        text = model(audio) 
        return text[0]
    elif 'whisper' in model:
        model = whisper.load_model(size)
        text = model.transcribe(audio)
        return text['text']


In [61]:
import pandas as pd
import statistics

from regex import R
rtf_data = pd.DataFrame()
rtf_data['model'] = ['tiny', 'base']
# For each model size
for i, model_size in enumerate(rtf_data['model']):
    inference_times = []
    sample_durations = []
    rtfs = []
    # For each audio file in the dataset 
    for audio in dataset[:5]['audio']:
        audio_data = audio['array']
        start = time.perf_counter()
        transcribe('whisper', audio_data, size=model_size)
        inf_time = time.perf_counter() - start
        inference_times.append(inf_time)

        sample_duration = len(list(audio_data)) / 16000.0 # sample rate
        sample_durations.append(sample_duration)

        rtfs.append(inf_time / sample_duration)
        


    # append mean infereance time to the correspond model size row
    rtf_data.at[i, 'av_inf_time'] = statistics.fmean(inference_times)
    rtf_data.at[i, 'av_duration'] = statistics.fmean(sample_durations)
    rtf_data.at[i, 'av_RTF'] = statistics.fmean(rtfs)






In [62]:
rtf_data


Unnamed: 0,model,time,sample_length,RTF
0,tiny,4.179262,12.491,0.442954
1,base,8.56648,12.491,0.814013
