In [12]:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, AutoTokenizer

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "openai/whisper-large-v3"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation="sdpa",
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id) 

In [2]:
from datasets import load_dataset

dataset = load_dataset("mpanda27/voxpopuli_fi_pseudo_labelled", "fi", split="test")

In [3]:
import time

def generate_with_time(model, inputs, **kwargs):
    start_time = time.time()
    outputs = model.generate(**inputs, **kwargs)
    generation_time = time.time() - start_time
    return outputs, generation_time

In [4]:
from tqdm import tqdm

all_time = 0
predictions = []
references = []

for sample in tqdm(dataset):
    audio = sample["audio"]
    inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
    inputs = inputs.to(device=device, dtype=torch.float16)

    output, gen_time = generate_with_time(model, inputs)
    all_time += gen_time
    predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
    references.append(processor.tokenizer._normalize(sample["normalized_text"]))

print(all_time)

  0%|          | 0/199 [00:00<?, ?it/s]Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English.This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`.
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
100%|██████████| 199/199 [12:59<00:00,  3.91s/it] 

704.0587303638458





In [5]:
from evaluate import load

wer = load("wer")

print(wer.compute(predictions=predictions, references=references))

0.1504890895410083


### Whisper Model tiny

In [13]:
assistant_model_id = "openai/whisper-tiny"

assistant_model = AutoModelForSpeechSeq2Seq.from_pretrained(
    assistant_model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation="sdpa",
)

assistant_model.to(device);
assistant_tokenizer = AutoTokenizer.from_pretrained(assistant_model_id)  # Load the tokenizer for the assistant model

In [14]:
def assisted_generate_with_time(model, inputs, **kwargs):
    start_time = time.time()
    outputs = model.generate(
        inputs.input_features,
        assistant_model=assistant_model,
        tokenizer=tokenizer,
        assistant_tokenizer=assistant_tokenizer,
        **kwargs,
    )
    generation_time = time.time() - start_time
    return outputs, generation_time

In [15]:
all_time = 0
predictions_distilled = []
references = []

for sample in tqdm(dataset):
    audio = sample["audio"]
    inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
    inputs = inputs.to(device=device, dtype=torch.float16)

    output, gen_time = assisted_generate_with_time(model, inputs)
    all_time += gen_time
    predictions_distilled.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
    references.append(processor.tokenizer._normalize(sample["normalized_text"]))

print(all_time)

  0%|          | 0/199 [00:18<?, ?it/s]


RuntimeError: Given groups=1, weight of size [384, 80, 3], expected input[1, 128, 3000] to have 80 channels, but got 128 channels instead

In [None]:
print(wer.compute(predictions=predictions_distilled, references=references))