# Run Open Llama speculative sampling on Inf2 & Trn1

In speculative sampling, we use a smaller "draft model" to speculate future tokens. These are then sent to the larger "target model", which accepts or rejects them.

For a more detailed understanding, please refer to the original paper by DeepMind titled ["Accelerating Large Language Model Decoding with Speculative Sampling"](https://arxiv.org/abs/2302.01318).
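
The accept/reject step can be sketched in plain Python. This is a minimal illustration of the acceptance rule described in the paper, not the Neuron implementation: the function name `verify_draft` and the list-based probability tables are assumptions made for this sketch.

```python
import random


def verify_draft(draft_tokens, draft_probs, target_probs, rng=random):
    """One verification step of speculative sampling (illustrative only).

    draft_tokens: k token ids proposed by the draft model.
    draft_probs:  k distributions q_i (lists over the vocabulary) that the
                  draft tokens were sampled from.
    target_probs: k + 1 distributions p_i from the target model: one per
                  draft position, plus one for the position after the last
                  draft token.
    Returns the tokens to append to the sequence (1 to k + 1 tokens).
    """
    accepted = []
    for token, q, p in zip(draft_tokens, draft_probs, target_probs):
        # Accept the draft token with probability min(1, p(x) / q(x)).
        if rng.random() < min(1.0, p[token] / q[token]):
            accepted.append(token)
            continue
        # On rejection, resample from the residual max(0, p - q)
        # (random.choices normalizes the weights) and stop.
        residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
        accepted.append(rng.choices(range(len(p)), weights=residual)[0])
        return accepted
    # Every draft token was accepted: sample one extra token from the
    # target distribution for the next position.
    bonus = target_probs[len(draft_tokens)]
    accepted.append(rng.choices(range(len(bonus)), weights=bonus)[0])
    return accepted


# Toy vocabulary of 3 tokens; the target agrees with the draft everywhere,
# so both draft tokens are accepted and a third token is sampled.
q = [[0.5, 0.5, 0.0], [0.2, 0.8, 0.0]]
p = q + [[0.0, 0.0, 1.0]]
print(verify_draft([0, 1], q, p))  # [0, 1, 2]
```

When the draft and target distributions agree, every drafted token is kept, which is why a well-matched draft model lets the target model emit several tokens per forward pass.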

In this example, we perform speculative sampling using the Hugging Face models "openlm-research/open_llama_13b" and "openlm-research/open_llama_3b".
Here, the 13b model is the target model and the 3b model is the draft model.

The example has the following main sections:

1. Set up the Jupyter Notebook
2. Install dependencies
3. Download and construct the model
4. Split the model state_dict into multiple files
5. Perform speculative sampling

This Jupyter Notebook can be run on an Inf2 instance (inf2.48xlarge) or a Trn1 instance (trn1.32xlarge).

## Set up the Jupyter Notebook

The following steps set up Jupyter Notebook and launch this tutorial:

1. Clone the AWS Neuron Samples repo to your instance:

    git clone https://github.com/aws-neuron/aws-neuron-samples.git

2. Navigate to the transformers-neuronx inference samples folder:

    cd aws-neuron-samples/torch-neuronx/transformers-neuronx/inference

3. Follow the instructions in Jupyter Notebook QuickStart to run Jupyter Notebook on your instance.

4. Locate this tutorial in your Jupyter Notebook session (speculative_sampling.ipynb) and launch it. Follow the rest of the instructions in this tutorial.



## Install dependencies

This tutorial requires the following pip packages:

- torch-neuronx
- neuronx-cc
- sentencepiece
- transformers
- transformers-neuronx

Most of these packages will be installed when configuring your environment using the torch-neuronx inference setup guide. The additional dependencies must be installed here:


In [None]:
!pip install transformers-neuronx sentencepiece

## Download and construct the model

We download and construct the draft and target models using the Hugging Face from_pretrained method.


In [None]:
from transformers.models.auto import AutoModelForCausalLM

draft_model = AutoModelForCausalLM.from_pretrained("openlm-research/open_llama_3b", low_cpu_mem_usage=True)
target_model = AutoModelForCausalLM.from_pretrained("openlm-research/open_llama_13b", low_cpu_mem_usage=True)

## Split the model state_dict into multiple files

To reduce host memory usage, we recommend saving the model `state_dict` as multiple files rather than as the single monolithic file produced by `torch.save`. This "split-format" `state_dict` can be created using the `save_pretrained_split` function. With this checkpoint format, the Neuron model loader can load parameters directly into the Neuron device high-bandwidth memory (HBM) while keeping at most one layer of model parameters in CPU main memory.

In [4]:
from transformers_neuronx.module import save_pretrained_split

save_pretrained_split(draft_model, './open-llama-3b-split')
save_pretrained_split(target_model, './open-llama-13b-split')
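
The idea behind the split format can be illustrated with a small standard-library sketch. It is a conceptual stand-in only: `save_split` and `iter_split` are hypothetical helpers, the pickle-based file layout is an assumption made for the example, and neither reflects the on-disk format that `save_pretrained_split` actually writes. Plain lists stand in for tensors.

```python
import os
import pickle
import tempfile


def save_split(state_dict, directory):
    """Write each entry of state_dict to its own file; return the key order."""
    os.makedirs(directory, exist_ok=True)
    keys = list(state_dict)
    for index, key in enumerate(keys):
        with open(os.path.join(directory, f"p{index}.pkl"), "wb") as f:
            pickle.dump(state_dict[key], f)
    # A small index file records which key each numbered file holds.
    with open(os.path.join(directory, "index.pkl"), "wb") as f:
        pickle.dump(keys, f)
    return keys


def iter_split(directory):
    """Yield (key, value) pairs one at a time, so only one entry is resident."""
    with open(os.path.join(directory, "index.pkl"), "rb") as f:
        keys = pickle.load(f)
    for index, key in enumerate(keys):
        with open(os.path.join(directory, f"p{index}.pkl"), "rb") as f:
            yield key, pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    save_split({"layer0.weight": [1.0, 2.0], "layer1.weight": [3.0]}, tmp)
    for name, value in iter_split(tmp):
        # A real loader would copy `value` to the device here and then drop it.
        print(name, len(value))
```

Because the loader consumes one entry at a time, peak host memory is bounded by the largest single entry rather than by the whole checkpoint.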

## Perform speculative sampling

We now load and compile the draft model and the target model.
We use the Neuron `LlamaForSampling` class to load both models. Without extra configuration, autoregressive sampling is used by default.

Since the draft model performs regular autoregressive sampling, we load and compile it using the default options.
For the target model, we need to explicitly enable speculative decoding by calling `enable_speculative_decoder(k)`, which compiles the model to compute a window of k tokens at a time.

Note that when loading the models, we must use the same `tp_degree`. Attempting to use a different value for the draft/target model will result in a load failure.

In [None]:
import time
import torch
from transformers import AutoTokenizer
from transformers_neuronx.llama.model import LlamaForSampling

print("\nStarting to compile Draft Model....")
# Load draft model
draft_neuron_model = LlamaForSampling.from_pretrained('./open-llama-3b-split', n_positions=256, batch_size=1, tp_degree=8, amp='f32')
# compile to neuron 
draft_neuron_model.to_neuron()
print("\nCompleted compilation of Draft Model")

print("\nStarting to compile Target Model....")
# Load target model
target_neuron_model = LlamaForSampling.from_pretrained('./open-llama-13b-split', n_positions=256, batch_size=1, tp_degree=8, amp='f32')
# Enable speculative decoder
target_neuron_model.enable_speculative_decoder(4)
# compile to neuron 
target_neuron_model.to_neuron()
print("\nCompleted compilation of Target Model")

Next, we initialize the tokenizer and the text prompt.
By default, we use the `DefaultTokenAcceptor` provided by Neuron. This follows the same acceptance logic as the [original DeepMind paper](https://arxiv.org/abs/2302.01318).
If you want different acceptance logic, you can create your own token acceptor class, which must be a subclass of the `TokenAcceptor` class, and pass it to the `SpeculativeGenerator` class.

We then initialize the `SpeculativeGenerator` class, passing the draft model, the target model and the speculation length as arguments. Calling its `sample()` function returns the sampled token IDs, which we decode with the tokenizer to get the final text.

Finally, we compare the response generation time of speculative sampling with that of autoregressive sampling on the target model. In this example, speculative sampling is the faster of the two.

In [None]:
from transformers_neuronx.speculation import SpeculativeGenerator
from transformers import LlamaTokenizer

# Initialize tokenizer and text prompt
tokenizer = LlamaTokenizer.from_pretrained("openlm-research/open_llama_3b")
prompt = "Hello, I'm a generative AI language model."
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# create SpeculativeGenerator
spec_gen = SpeculativeGenerator(draft_neuron_model, target_neuron_model, 4)

# call speculative sampling on given input
start_spec_timer = time.time()

print("Starting to call Speculative Sampling..")
response = spec_gen.sample(
    input_ids=input_ids,
    sequence_length=50,
)
end_spec_timer = time.time()

generated_text = tokenizer.decode(response[0])
print(f"\nDecoded tokens: {generated_text}")

print(f"\nSpeculative sampling response generation took {end_spec_timer - start_spec_timer:.2f} s")

start_auto_r_timer = time.time()
autor_response = target_neuron_model.sample(input_ids=input_ids, sequence_length=50)
end_auto_r_timer = time.time()

print(f"\nAutoregressive sampling response generation took {end_auto_r_timer - start_auto_r_timer:.2f} s")
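
A single timed call can be noisy, so it may help to average over several runs after a warm-up call. The helper below is a minimal sketch using only the standard library; `time_call` and its `warmup` and `repeats` parameters are hypothetical names introduced for this example, not part of transformers-neuronx.

```python
import time


def time_call(fn, warmup=1, repeats=3):
    """Return the mean wall-clock time of fn() in seconds, after warm-up runs."""
    for _ in range(warmup):
        fn()  # warm-up runs are executed but not timed
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    return (time.perf_counter() - start) / repeats


# Possible usage with the objects defined in the cells above:
# spec_s = time_call(lambda: spec_gen.sample(input_ids=input_ids, sequence_length=50))
# auto_s = time_call(lambda: target_neuron_model.sample(input_ids=input_ids, sequence_length=50))
# print(f"Speculative: {spec_s:.2f} s, autoregressive: {auto_s:.2f} s, speedup: {auto_s / spec_s:.2f}x")
```

`time.perf_counter()` is used instead of `time.time()` because it is a monotonic, high-resolution clock intended for measuring short durations.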
