# Inference Example with Medusa

In this Jupyter notebook, we're going to demonstrate how to perform inference using the Medusa model on an interesting story prompt. Let's get the ball rolling!

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3" # define GPU id, remove if you want to use all GPUs available
import torch
from tqdm import tqdm
import time
from contextlib import contextmanager
import numpy as np
from medusa.model.modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
from medusa.model.medusa_model import MedusaModel
from medusa.model.kv_cache import *
from medusa.model.utils import *
from medusa.model.medusa_choices import *
import transformers
from huggingface_hub import hf_hub_download

In [None]:
# check checkpoint exists
checkpoint_path = "../austin_medusa/decoder_lm_head.pt"
if not os.path.exists(checkpoint_path):
    state = torch.load('../austin_medusa/medusa_two_decoder_lm_head.pt', map_location=torch.device('cpu'))
    new_state = {}
    new_state['medusa_decoder_layers'] = {}
    new_state['medusa_rms_norm'] = {}
    new_state['medusa_head'] = {}
    for key, value in state.items():
            if 'rms_norm' in key:
                new_state['medusa_rms_norm'][key.replace('medusa_rms_norm.','')] = value
            elif 'head' in key:
                new_state['medusa_head'][key.replace('medusa_head.','')] = value
            else:
                new_state['medusa_decoder_layers'][key.replace('medusa_decoder_layers.','')] = value
    torch.save(new_state, '../austin_medusa/decoder_lm_head.pt')

## Medusa Forward Function

We define the medusa_forward function that will be used for generating stories based on the provided prompts.


In [2]:
@contextmanager
def timed(wall_times, key):
    start = time.time()
    torch.cuda.synchronize()
    yield
    torch.cuda.synchronize()
    end = time.time()
    elapsed_time = end - start
    wall_times[key].append(elapsed_time)

def medusa_forward(input_ids, model, tokenizer, medusa_choices, temperature, posterior_threshold, posterior_alpha, max_steps = 512):
    wall_times = {'medusa': [], 'tree': [], 'posterior': [], 'update': [], 'init': []}
    
    with timed(wall_times, 'init'):
        if hasattr(model, "medusa_choices") and model.medusa_choices == medusa_choices:
            # Load the cached medusa buffer
            medusa_buffers = model.medusa_buffers
        else:
            # Initialize the medusa buffer
            medusa_buffers = generate_medusa_buffers(
                medusa_choices, device=model.base_model.device
            )
        model.medusa_buffers = medusa_buffers
        model.medusa_choices = medusa_choices

        # Initialize the past key and value states
        if hasattr(model, "past_key_values"):
            past_key_values = model.past_key_values
            past_key_values_data = model.past_key_values_data
            current_length_data = model.current_length_data
            # Reset the past key and value states
            current_length_data.zero_()
        else:
            (
                past_key_values,
                past_key_values_data,
                current_length_data,
            ) = initialize_past_key_values(model.base_model)
            model.past_key_values = past_key_values
            model.past_key_values_data = past_key_values_data
            model.current_length_data = current_length_data

        input_len = input_ids.shape[1]
        reset_medusa_mode(model)
        medusa_logits, logits = initialize_medusa(
                input_ids, model, medusa_buffers["medusa_attn_mask"], past_key_values
        )
    new_token = 0

    for idx in range(max_steps): 
        with timed(wall_times, 'medusa'):
            candidates, tree_candidates = generate_candidates(
                    medusa_logits,
                    logits,
                    medusa_buffers["tree_indices"],
                    medusa_buffers["retrieve_indices"],
                )

        with timed(wall_times, 'tree'):
            medusa_logits, logits, outputs = tree_decoding(
                    model,
                    tree_candidates,
                    past_key_values,
                    medusa_buffers["medusa_position_ids"],
                    input_ids,
                    medusa_buffers["retrieve_indices"],
                )

        with timed(wall_times, 'posterior'):
            best_candidate, accept_length = evaluate_posterior(
                    logits, candidates, temperature, posterior_threshold, posterior_alpha
                )
        
        with timed(wall_times, 'update'):
            input_ids, logits, medusa_logits, new_token = update_inference_inputs(
                    input_ids,
                    candidates,
                    best_candidate,
                    accept_length,
                    medusa_buffers["retrieve_indices"],
                    outputs,
                    logits,
                    medusa_logits,
                    new_token,
                    past_key_values_data,
                    current_length_data,
                )

        if tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
            break

    return input_ids, new_token, idx, wall_times


## Model Loading

We load the model and tokenizer using the specified paths and configurations.


In [3]:
model_name = '../austin_medusa'
model = MedusaModel.from_pretrained(
    model_name,
    medusa_num_heads = 3,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto"
)
tokenizer = model.get_tokenizer()

medusa_choices = mc_sim_7b_63_depth3



Overriding medusa_num_heads as: 3


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. If you see this, DO NOT PANIC! This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


## Setting Inference Parameters

Next, we set some parameters that will be used during the inference process.


In [4]:
temperature = 0.
posterior_threshold = 0.09
posterior_alpha = 0.3

## Setting The Prompt

The following is the story prompt we will use for generating our story in the demo.


In [5]:
prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Hi, could you share a tale about a charming llama that grows Medusa-like hair and starts its own coffee shop? ASSISTANT:"

## Performing Inference

Using the set parameters and the defined function, let's generate our story!


In [8]:
with torch.inference_mode():
    input_ids = tokenizer([prompt]).input_ids
    output_ids, new_token, idx, wall_time = medusa_forward(
                    torch.as_tensor(input_ids).cuda(),
                    model,
                    tokenizer,
                    medusa_choices,
                    temperature,
                    posterior_threshold,
                    posterior_alpha,
                )
    output_ids = output_ids[0][len(input_ids[0]) :]
    print("Output length:", output_ids.size(-1))
    print("Compression ratio:", new_token / idx)

Output length: 403
Compression ratio: tensor(2.5833, device='cuda:0')


## Decoding The Output

Let's decode the generated output to obtain our story.


In [9]:
output = tokenizer.decode(
                    output_ids,
                    spaces_between_special_tokens=False,
                )
print(output)

Once upon a time, in a small village nestled in the Andes mountains, there lived a charming llama named Luna. Luna was known for her kind heart and her love of coffee. She would often spend her afternoons sipping on a steaming cup of joe at the local café, chatting with the villagers and enjoying the warmth of the sun on her back.

One day, as Luna was grazing on some fresh grass, she noticed that her hair was starting to grow longer and thicker. At first, she didn't think much of it, but as the days went on, her hair continued to grow and change. It became thick and wiry, with sharp spikes protruding from it.

Luna was confused and a little scared by her new appearance. She had always been a gentle creature, and now she looked like a monster. She knew that she couldn't stay in the village anymore, so she set off on a journey to find a new home.

As she wandered through the mountains, Luna stumbled upon a beautiful clearing. In the center of the clearing stood a small cottage, with a s

## Analyzing Wall Times

We will now break down and analyze the wall times during the inference process.

You might notice a significant time consumption during the initialization phase. This is primarily due to the GPU cache initialization process on the first run.

For a clearer perspective, you can try rerunning the decoding segment again to observe the differences.

In [10]:
max_length = 50

def format_string(text, value, max_length):
    value_str = "{:.3f}".format(value)
    return f"{text:<{max_length - len(value_str)}}{value_str}"

time_init = np.sum(wall_time['init'] )
time_medusa = np.sum(wall_time['medusa'] )
time_tree = np.sum(wall_time['tree'] )
time_posterior = np.sum(wall_time['posterior'] )
time_update = np.sum(wall_time['update'] )
time_total = time_init + time_medusa + time_tree + time_posterior + time_update

print('='*max_length)
print(format_string("Wall time init: ", time_init, max_length))
print(format_string("Wall time medusa: ", time_medusa, max_length))
print(format_string("Wall time Tree: ", time_tree, max_length))
print(format_string("Wall time Posterior: ", time_posterior, max_length))
print(format_string("Wall time Update: ", time_update, max_length))
print('-'*max_length)
print(format_string("Wall time portion medusa: ", time_medusa / time_total, max_length))
print(format_string("Wall time portion Tree: ", time_tree / time_total, max_length))
print(format_string("Wall time portion Posterior: ", time_posterior / time_total, max_length))
print(format_string("Wall time portion Update: ", time_update / time_total, max_length))
print('-'*max_length)
print(format_string("Tokens/second: ", new_token / time_total, max_length))
print('='*max_length)

Wall time init:                              0.026
Wall time medusa:                            0.031
Wall time Tree:                              3.794
Wall time Posterior:                         0.024
Wall time Update:                            0.049
--------------------------------------------------
Wall time portion medusa:                    0.008
Wall time portion Tree:                      0.967
Wall time portion Posterior:                 0.006
Wall time portion Update:                    0.013
--------------------------------------------------
Tokens/second:                             102.687
