# Speed Up Llama 3.1 70B with Speculative Decoding

In this guide, we'll show you how to implement speculative decoding with the Llama 3.1 70B model (base) and the Llama 3.2 3B model (assistant). The Transformers `generate` API accepts an `assistant_model` argument that enables speculative decoding. By the end, you'll see how this technique can speed up text generation by up to 2x, making your workflows faster and more efficient.
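To build intuition before loading real models, here is a minimal sketch of the draft-and-verify loop behind greedy speculative decoding. It is a toy illustration, not the Transformers implementation: `target_next` and `draft_next` are hypothetical stand-ins for the large model and the assistant, and `num_draft` is an assumed parameter name. In a real model, the verification step is a single forward pass of the large model over all drafted positions, which is where the speedup comes from.

```python
def target_next(tokens):
    """Toy stand-in for the large model: next token is the context sum mod 10."""
    return sum(tokens) % 10


def draft_next(tokens):
    """Toy stand-in for the assistant: agrees with the target except
    when the context length is a multiple of 4."""
    guess = sum(tokens) % 10
    return (guess + 1) % 10 if len(tokens) % 4 == 0 else guess


def greedy_decode(prompt, max_new_tokens):
    """Plain greedy decoding: one target call per generated token."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        tokens.append(target_next(tokens))
    return tokens


def speculative_decode(prompt, max_new_tokens, num_draft=4):
    """Greedy speculative decoding with a draft model (toy version)."""
    tokens = list(prompt)
    limit = len(prompt) + max_new_tokens
    target_passes = 0
    while len(tokens) < limit:
        # 1. The assistant cheaply drafts a few tokens ahead.
        draft = []
        for _ in range(num_draft):
            draft.append(draft_next(tokens + draft))

        # 2. The target checks every drafted position. With a real model
        #    this is one batched forward pass, so we count it once.
        target_passes += 1
        accepted = []
        for drafted in draft:
            expected = target_next(tokens + accepted)
            accepted.append(expected)  # the target's token is always kept
            if expected != drafted:
                break  # first mismatch: the rest of the draft is discarded
        else:
            # Every draft token matched, so the target adds one extra token.
            accepted.append(target_next(tokens + accepted))
        tokens.extend(accepted)
    return tokens[:limit], target_passes


output, passes = speculative_decode([1, 2, 3], max_new_tokens=20)
# The output is identical to plain greedy decoding, with fewer target passes.
assert output == greedy_decode([1, 2, 3], max_new_tokens=20)
print(f"{passes} target passes for 20 new tokens")
```

Because the target's own token is kept at every position, the result matches plain greedy decoding exactly; the assistant only affects how many tokens each target pass yields.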

## Imports and Setup

In [None]:
!pip install -Uq transformers accelerate

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
import torch
from tqdm import tqdm

In [None]:
# suppress warnings in the notebook
import logging
import warnings
logging.getLogger('transformers').setLevel(logging.ERROR)
warnings.filterwarnings('ignore')

In [None]:
# base models
checkpoint = "meta-llama/Meta-Llama-3.1-70B"      # <-- Larger Model
assistant_checkpoint = "meta-llama/Llama-3.2-3B"  # <-- Smaller Model

## Prepare Models

In [None]:
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
assistant_model = AutoModelForCausalLM.from_pretrained(
    assistant_checkpoint,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

### Prepare Inputs

In [None]:
prompt = "Alice and Bob"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

### Benchmark the text generation speed


In [None]:
max_new_tokens = 256
num_iterations = 10
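The benchmark cells that follow all repeat the same timing pattern. As a sketch, that pattern could be factored into a small helper. The name `measure_throughput` and its signature are assumptions for illustration and are not part of Transformers; `fake_generate` is a placeholder so the snippet runs without a GPU.

```python
import time


def measure_throughput(generate_fn, num_iterations, max_new_tokens):
    """Return tokens/sec, assuming each call generates max_new_tokens tokens."""
    duration = 0.0
    for _ in range(num_iterations):
        start = time.perf_counter()
        generate_fn()
        duration += time.perf_counter() - start
    return (num_iterations * max_new_tokens) / duration


def fake_generate():
    """Placeholder for a model.generate(...) call."""
    time.sleep(0.01)


print(f"Throughput: {measure_throughput(fake_generate, 3, 256):.4f} (tokens/sec)")
```

With the real models, `generate_fn` could be a `lambda` wrapping the `model.generate(...)` call, with or without `assistant_model`. The tokens/sec figure is only meaningful when every call produces exactly `max_new_tokens` tokens, which is what `eos_token_id=-1` ensures in the cells below.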

### Greedy Decoding

In [None]:
# warmup (eos_token_id=-1 matches no token, so every run generates exactly max_new_tokens)
for _ in range(2):
    model.generate(
        **model_inputs,
        do_sample=False,
        assistant_model=assistant_model,
        max_new_tokens=max_new_tokens,
        eos_token_id=-1,
    )

In [None]:
# without assistance
print("🐘 Big Model Generation")
duration = 0
for _ in tqdm(range(num_iterations)):
    start = time.time()
    outputs = model.generate(
        **model_inputs,
        do_sample=False,
        max_new_tokens=max_new_tokens,
        eos_token_id=-1,
    )
    end = time.time()
    duration = duration + (end-start)

print(f"\nThroughput: {(num_iterations * max_new_tokens) / (duration):.4f} (tokens/sec)")


# with assistance
print("\n\n🤝 Big Model Generation with Assistance")
duration = 0
for _ in tqdm(range(num_iterations)):
    start = time.time()
    outputs = model.generate(
        **model_inputs,
        do_sample=False,
        max_new_tokens=max_new_tokens,
        eos_token_id=-1,
        assistant_model=assistant_model,
    )
    end = time.time()
    duration = duration + (end-start)

print(f"\nThroughput: {(num_iterations * max_new_tokens) / (duration):.4f} (tokens/sec)")

### Multinomial Decoding

In [None]:
# warmup
for _ in range(2):
    model.generate(
        **model_inputs,
        do_sample=True,
        temperature=0.2,
        assistant_model=assistant_model,
        max_new_tokens=max_new_tokens,
        eos_token_id=-1,
    )

In [None]:
# without assistance
print("🐘 Big Model Generation (multinomial)")
duration = 0
for _ in tqdm(range(num_iterations)):
    start = time.time()
    outputs = model.generate(
        **model_inputs,
        do_sample=True,
        temperature=0.2,
        max_new_tokens=max_new_tokens,
        eos_token_id=-1,
    )
    end = time.time()
    duration = duration + (end-start)

print(f"\nThroughput: {(num_iterations * max_new_tokens) / (duration):.4f} (tokens/sec)")


# with assistance
print("\n\n🤝 Big Model Generation with Assistance (multinomial)")
duration = 0
for _ in tqdm(range(num_iterations)):
    start = time.time()
    outputs = model.generate(
        **model_inputs,
        do_sample=True,
        temperature=0.2,
        max_new_tokens=max_new_tokens,
        eos_token_id=-1,
        assistant_model=assistant_model,
    )
    end = time.time()
    duration = duration + (end-start)

print(f"\nThroughput: {(num_iterations * max_new_tokens) / (duration):.4f} (tokens/sec)")

## Conclusion


Base model throughput (tokens/sec):

| Decoding | Simple | Assisted |
| :-- | --: | --: |
| greedy | 4.9464 | **9.7564** |
| multinomial | 4.9309 | **6.2531** |


Assisted generation increases throughput in both settings: roughly 2x with greedy decoding and about 1.27x with multinomial sampling! 🎉
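The speedup factors follow directly from the throughput numbers reported above. A quick calculation, using only those measured values:

```python
# Measured throughput in tokens/sec, copied from the table above.
throughput = {
    "greedy": {"simple": 4.9464, "assisted": 9.7564},
    "multinomial": {"simple": 4.9309, "assisted": 6.2531},
}

# Speedup = assisted throughput / simple throughput.
speedup = {
    name: round(values["assisted"] / values["simple"], 2)
    for name, values in throughput.items()
}
print(speedup)  # {'greedy': 1.97, 'multinomial': 1.27}
```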

The speed gain comes at the cost of higher memory usage, because the assistant model is loaded alongside the base model, so it's important to weigh both metrics.
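For a rough sense of the memory side, the sketch below estimates weight memory in `bfloat16` (2 bytes per parameter). The parameter counts are assumptions taken from the nominal sizes in the model names (70B and 3B), not exact figures, and the estimate covers weights only: the KV cache and activations of both models add to it.

```python
BYTES_PER_PARAM_BF16 = 2  # bfloat16 stores each parameter in 2 bytes


def weight_gb(num_params):
    """Approximate weight memory in GB (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM_BF16 / 1e9


# Nominal parameter counts from the model names (assumed, not exact).
base_gb = weight_gb(70e9)
assistant_gb = weight_gb(3e9)
overhead_pct = round(100 * assistant_gb / base_gb, 1)

print(f"base: {base_gb:.0f} GB, assistant: {assistant_gb:.0f} GB")
print(f"extra weight memory from the assistant: {overhead_pct}%")
```

To measure rather than estimate, Transformers models expose `get_memory_footprint()`, which could be called on both `model` and `assistant_model` after loading.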

## Next Steps

To learn more about speculative decoding, we suggest reading:

- [Assisted Generation](https://huggingface.co/blog/assisted-generation): Learn more about assisted generation.
- [Speculative Decoding Docs](https://huggingface.co/docs/transformers/main/en/generation_strategies#speculative-decoding): See how Transformers decodes with an assistant model.

## Acknowledgements

1. [Vaibhav Srivastav](https://huggingface.co/reach-vb) for the thorough review and suggestions to make the tutorial better.
2. [Joao Gante](https://huggingface.co/joaogante) for clarifying my doubts on speculative decoding.