# Speculative Decoding

Speculative Decoding was proposed in *Fast Inference from Transformers via Speculative Decoding* by Yaniv Leviathan et al. from Google. It works on the premise that a faster assistant model very often generates the same tokens as a larger main model.

This project aims to test this premise using OpenAI's Whisper speech transcription model.

# Benchmarking Whisper large-v2


In [None]:
!pip install torch
!pip install transformers
!pip install accelerate

In [None]:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

# Use the GPU in half precision when CUDA is available; otherwise fall back
# to the CPU in full precision.
if torch.cuda.is_available():
    device, torch_dtype = torch.device("cuda"), torch.float16
else:
    device, torch_dtype = torch.device("cpu"), torch.float32

model_id = "openai/whisper-large-v2"

# Main model being benchmarked.
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    attn_implementation="sdpa",  # PyTorch scaled-dot-product attention
    use_safetensors=True,  # load safetensors weights rather than pickle
    torch_dtype=torch_dtype,
    # low_cpu_mem_usage=True,  # optional: faster loading
)

model.to(device)

# Processor that belongs to the same checkpoint as the model.
processor = AutoProcessor.from_pretrained(model_id)

In [None]:
!pip install datasets
from datasets import load_dataset

# Validation split of the "clean" configuration of the dummy LibriSpeech ASR dataset.
dataset = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy",
    "clean",
    split="validation",
)

In [None]:
# Helper for timing a single generate() call

import time

def time_gen(model, inputs, **kwargs):
    """Run ``model.generate`` once and measure how long the call takes.

    Args:
        model: Object exposing a ``generate`` method.
        inputs: Mapping unpacked as keyword arguments to ``generate``.
        **kwargs: Extra keyword arguments forwarded to ``generate``.

    Returns:
        An ``(outputs, gen_time)`` tuple: the value returned by ``generate``
        and the elapsed time in seconds.
    """
    # perf_counter() is monotonic and high-resolution, so the measured
    # interval cannot be skewed by system clock adjustments the way
    # time.time() can.
    start = time.perf_counter()
    outputs = model.generate(**inputs, **kwargs)
    gen_time = time.perf_counter() - start
    return outputs, gen_time

In [None]:
!pip install librosa
!pip install soundfile
!pip install langchain
!pip install sentence-transformers
!pip uninstall numpy -y
!pip install numpy==1.26.4
import numpy as np  # required by the `_no_nep50_warning` shim below
from tqdm import tqdm

# Accumulators for the baseline run: total generation time, decoded
# predictions, and normalized reference transcripts.
all_time = 0
pred = []
ref = []


def dummy_npwarn_decorator_factory():
    """Return a decorator that hands back its argument unchanged."""
    def npwarn_decorator(x):
        return x
    return npwarn_decorator


# Compatibility shim: keep NumPy's own `_no_nep50_warning` when it exists,
# otherwise install the no-op factory above so lookups of it succeed.
np._no_nep50_warning = getattr(np, '_no_nep50_warning', dummy_npwarn_decorator_factory)


for sample in tqdm(dataset):
    audio = sample["audio"]
    inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
    # Cast to the dtype the model was loaded with (float16 with CUDA, float32
    # otherwise); a hard-coded float16 mismatches the float32 model on CPU.
    inputs = inputs.to(device=device, dtype=torch_dtype)

    output, gen_time = time_gen(model, inputs)
    all_time += gen_time
    # Normalize both prediction and reference so they are compared in the same form.
    pred.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
    ref.append(processor.tokenizer._normalize(sample["text"]))

print(f"Total time taken: {all_time:.2f}s")

In [None]:
!pip install evaluate
!pip install jiwer

from evaluate import load

# Word error rate of the collected predictions against the references.
wer = load("wer")
wer_score = wer.compute(predictions=pred, references=ref)
print(wer_score)


# Using Speculative Decoding

In [None]:
from transformers import AutoModelForCausalLM

assistant_model_id = "distil-whisper/distil-large-v2"

# Assistant model for speculative decoding, loaded with the same
# `torch_dtype` as the main model.
assistant_model = AutoModelForCausalLM.from_pretrained(
    assistant_model_id,
    attn_implementation="sdpa",  # PyTorch scaled-dot-product attention
    use_safetensors=True,  # load safetensors weights rather than pickle
    low_cpu_mem_usage=True,  # faster loading
    torch_dtype=torch_dtype,
)

assistant_model.to(device)

In [None]:
def assisted_generate_with_time(model, inputs, **kwargs):
    """Run assisted generation once and measure how long the call takes.

    The module-level ``assistant_model`` is passed to ``model.generate`` as
    its ``assistant_model`` argument.

    Args:
        model: Object exposing a ``generate`` method.
        inputs: Mapping unpacked as keyword arguments to ``generate``.
        **kwargs: Extra keyword arguments forwarded to ``generate``.

    Returns:
        An ``(outputs, generation_time)`` tuple: the value returned by
        ``generate`` and the elapsed time in seconds.
    """
    # perf_counter() is monotonic and high-resolution, so the measured
    # interval cannot be skewed by system clock adjustments the way
    # time.time() can.
    start_time = time.perf_counter()
    outputs = model.generate(**inputs, assistant_model=assistant_model, **kwargs)
    generation_time = time.perf_counter() - start_time
    return outputs, generation_time

In [None]:
# Reset the accumulators so the assisted run is measured independently.
all_time = 0
pred = []
ref = []

for sample in tqdm(dataset):
    audio = sample["audio"]
    inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
    # Cast to the dtype the model was loaded with (float16 with CUDA, float32
    # otherwise); a hard-coded float16 mismatches the float32 model on CPU.
    inputs = inputs.to(device=device, dtype=torch_dtype)

    out, gen_time = assisted_generate_with_time(model, inputs)
    all_time += gen_time
    # Normalize both prediction and reference so they are compared in the same form.
    pred.append(processor.batch_decode(out, skip_special_tokens=True, normalize=True)[0])
    ref.append(processor.tokenizer._normalize(sample["text"]))

print(f"Total time taken: {all_time:.2f}s")

In [None]:
# Word error rate of the assisted-generation predictions against the references.
wer_score = wer.compute(predictions=pred, references=ref)
print(wer_score)

# Conclusion

Speculative decoding gives a slight speed-up in inference time while maintaining the same WER score.