# Speculative Decoding Using a Pre-trained Distil-Whisper Model for Multilingual Speech

## Multilingual Speech Transcription

In [1]:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

# Choose device and precision together: half precision on a CUDA GPU,
# full precision when falling back to the CPU.
if torch.cuda.is_available():
    device, torch_dtype = "cuda:0", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

model_id = "openai/whisper-large-v3"

# Main model: loaded in the precision chosen above, then moved to the device.
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation="sdpa",
).to(device)

# Processor for the same checkpoint; later cells use it to prepare the audio
# inputs and to decode the generated token ids.
processor = AutoProcessor.from_pretrained(model_id)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from datasets import load_dataset

# Test split of the "fi" configuration; each sample is read below through its
# "audio" and "normalized_text" fields.
dataset = load_dataset(
    "mpanda27/voxpopuli_fi_pseudo_labelled",
    "fi",
    split="test",
)

In [3]:
import time

def generate_with_time(model, inputs, **kwargs):
    """Run ``model.generate`` and measure how long the call takes.

    Args:
        model: Object exposing a ``generate(**kwargs)`` method.
        inputs: Mapping that is unpacked into ``generate`` as keyword arguments.
        **kwargs: Extra keyword arguments forwarded to ``generate`` unchanged.

    Returns:
        A ``(outputs, generation_time)`` tuple: whatever ``generate`` returned
        and the elapsed time of the call in seconds.
    """
    # perf_counter is monotonic, so the measured interval cannot go negative or
    # jump if the system wall clock is adjusted while generation is running.
    start_time = time.perf_counter()
    outputs = model.generate(**inputs, **kwargs)
    generation_time = time.perf_counter() - start_time
    return outputs, generation_time

In [4]:
from tqdm import tqdm

# Baseline run: transcribe every sample with the main model alone, summing the
# time spent inside generate and collecting normalized predictions/references.
all_time = 0
predictions = []
references = []

for sample in tqdm(dataset):
    audio = sample["audio"]
    inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
    # Cast to the dtype the model was loaded with (float32 on CPU, float16 on
    # GPU); a hard-coded float16 would not match the model's dtype on CPU.
    inputs = inputs.to(device=device, dtype=torch_dtype)

    output, gen_time = generate_with_time(model, inputs)
    all_time += gen_time
    predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
    # Normalize the reference text as well so both sides of the WER
    # comparison are in normalized form.
    references.append(processor.tokenizer._normalize(sample["normalized_text"]))

# Total generation time in seconds (data loading and decoding are excluded).
print(all_time)

  0%|          | 0/199 [00:00<?, ?it/s]Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English.This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`.
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
100%|██████████| 199/199 [13:10<00:00,  3.97s/it] 

704.5053243637085





In [5]:
from evaluate import load

# Load the "wer" metric once; it is reused for the assisted run further down.
wer = load("wer")

# Score the baseline transcriptions against the normalized references.
baseline_wer = wer.compute(references=references, predictions=predictions)
print(baseline_wer)

0.1504890895410083


### Distilled Model

In [6]:
from transformers import AutoModelForCausalLM

assistant_model_id = "mpanda27/distil-whisper-large-v3"

# Draft model handed to `generate` as `assistant_model`; loaded with the same
# precision and attention settings as the main model and placed on the same device.
# NOTE(review): this is loaded through AutoModelForCausalLM rather than
# AutoModelForSpeechSeq2Seq — presumably the checkpoint is meant to reuse the
# main model's encoder; confirm this holds for this checkpoint.
assistant_model = AutoModelForCausalLM.from_pretrained(
    assistant_model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation="sdpa",
).to(device)

In [7]:
def assisted_generate_with_time(model, inputs, **kwargs):
    """Run ``model.generate`` with an assistant model and time the call.

    Args:
        model: Object exposing a ``generate(**kwargs)`` method.
        inputs: Mapping that is unpacked into ``generate`` as keyword arguments.
        **kwargs: Extra keyword arguments forwarded to ``generate``. If
            ``assistant_model`` is given here it is used as-is; otherwise the
            notebook-level ``assistant_model`` is passed.

    Returns:
        A ``(outputs, generation_time)`` tuple: whatever ``generate`` returned
        and the elapsed time of the call in seconds.
    """
    # Only fall back to the global draft model when the caller did not supply
    # one; previously an explicit `assistant_model=` raised a duplicate-keyword
    # TypeError.
    if "assistant_model" not in kwargs:
        kwargs["assistant_model"] = assistant_model
    # perf_counter is monotonic, so the interval is unaffected by wall-clock
    # adjustments made while generation is running.
    start_time = time.perf_counter()
    outputs = model.generate(**inputs, **kwargs)
    generation_time = time.perf_counter() - start_time
    return outputs, generation_time

In [8]:
# Speculative-decoding run: same loop as the baseline, but generation goes
# through assisted_generate_with_time so the assistant model is used.
all_time = 0
predictions_distilled = []
references = []

for sample in tqdm(dataset):
    audio = sample["audio"]
    inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
    # Cast to the dtype the models were loaded with (float32 on CPU, float16 on
    # GPU); a hard-coded float16 would not match the models' dtype on CPU.
    inputs = inputs.to(device=device, dtype=torch_dtype)

    output, gen_time = assisted_generate_with_time(model, inputs)
    all_time += gen_time
    predictions_distilled.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
    # References are rebuilt here so this cell does not depend on the baseline
    # loop having populated them.
    references.append(processor.tokenizer._normalize(sample["normalized_text"]))

# Total generation time in seconds (data loading and decoding are excluded).
print(all_time)

  0%|          | 0/199 [00:00<?, ?it/s]From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to the legacy tuple of tuples format). If you want to keep returning the legacy format, please set `return_legacy_cache=True`.
100%|██████████| 199/199 [04:41<00:00,  1.41s/it]

266.5648021697998





In [9]:
# Score the transcriptions from the assisted (speculative decoding) run.
distilled_wer = wer.compute(references=references, predictions=predictions_distilled)
print(distilled_wer)

0.1504890895410083


In [10]:
# assistant_model_id = "openai/whisper-tiny"

# assistant_model = AutoModelForSpeechSeq2Seq.from_pretrained(
#     assistant_model_id,
#     torch_dtype=torch_dtype,
#     low_cpu_mem_usage=True,
#     use_safetensors=True,
#     attn_implementation="sdpa",
# )

# assistant_model.to(device);

In [11]:
# dataset = load_dataset("sanchit-gandhi/voxpopuli_dummy", "nl", split="validation")

In [12]:
# all_time = 0
# predictions = []
# references = []

# for sample in tqdm(dataset):
#     audio = sample["audio"]
#     inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
#     inputs = inputs.to(device=device, dtype=torch.float16)

#     output, gen_time = generate_with_time(model, inputs, language="nl", task="transcribe")
#     all_time += gen_time
#     predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
#     references.append(processor.tokenizer._normalize(sample["normalized_text"]))

# wer_result = wer.compute(predictions=predictions, references=references)

# print("Time:", all_time)
# print("WER:", wer_result)

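Right! We have our baseline time of 117 seconds and a WER of 12.8%. Note that these figures presumably refer to the Dutch (`nl`) configuration in the commented-out cells of this section; they were not produced by the Finnish run earlier in this notebook, which took about 705 seconds with a WER of 15.0%. Let's re-run the generation process using speculative decoding: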

In [13]:
# all_time = 0
# predictions = []
# references = []

# for sample in tqdm(dataset):
#     audio = sample["audio"]
#     inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
#     inputs = inputs.to(device=device, dtype=torch.float16)

#     output, gen_time = assisted_generate_with_time(model, inputs, language="nl", task="transcribe")
#     all_time += gen_time
#     predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
#     references.append(processor.tokenizer._normalize(sample["normalized_text"]))

# wer_result = wer.compute(predictions=predictions, references=references)

# print("Time:", all_time)
# print("WER:", wer_result)