# Inference Example with Medusa

In this Jupyter notebook, we run Medusa inference over the GSM8K test questions, capture the output of every decoder layer's MLP with forward hooks, and compare the layers pairwise using cross-entropy and KL divergence.

In [1]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# os.environ["CUDA_VISIBLE_DEVICES"] = "3" # define GPU id, remove if you want to use all GPUs available
import torch
from tqdm import tqdm
import time
from contextlib import contextmanager
import numpy as np
import pandas as pd
from medusa.model.modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
from medusa.model.medusa_model import MedusaModel
from medusa.model.kv_cache import *
from medusa.model.utils import *
from medusa.model.medusa_choices import *
import transformers
from huggingface_hub import hf_hub_download
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm
  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


## Medusa Forward Function

We define the `medusa_forward` function, which runs Medusa decoding on a tokenized prompt and records wall-clock timings for each stage, together with the forward hooks that capture each layer's MLP output.


In [2]:
# Most recent output of every hooked module, keyed by the name given at hook creation.
activations = {}


def capture_activation(layer_name):
    """Build a forward hook that files a module's output under ``layer_name``.

    The returned callable takes the ``(module, inputs, output)`` arguments of a
    PyTorch forward hook. Every invocation replaces the previous entry for
    ``layer_name`` in the module-level ``activations`` dict with a detached
    CPU copy of ``output``.
    """

    def _store_output(_module, _inputs, output):
        cpu_copy = output.detach().cpu()
        activations[layer_name] = cpu_copy

    return _store_output

In [3]:
@contextmanager
def timed(wall_times, key):
    """Time the enclosed block and append the elapsed seconds to ``wall_times[key]``.

    Args:
        wall_times: mapping from stage name to a list of durations; the list
            for ``key`` must already exist.
        key: stage name under which the measurement is stored.

    When CUDA is available the device is synchronized on both sides of the
    block so that asynchronously queued kernels are included in the
    measurement. Nothing is recorded if the block raises.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        # Drain work queued *before* the block first, so it is not billed to this stage.
        torch.cuda.synchronize()
    # perf_counter is monotonic, unlike time.time(), which can jump with clock adjustments.
    start = time.perf_counter()
    yield
    if use_cuda:
        # Wait for the kernels launched inside the block before stopping the clock.
        torch.cuda.synchronize()
    wall_times[key].append(time.perf_counter() - start)

def register_hooks(model):
    """Hook the MLP of every decoder layer so its output is captured.

    Walks ``model.base_model.model.layers`` in order and registers, on each
    layer's ``mlp`` submodule, a forward hook built by ``capture_activation``.
    Hook names are 1-based: ``layer_1_mlp``, ``layer_2_mlp``, ...

    The handles returned by ``register_forward_hook`` are discarded, so this
    function offers no way to remove the hooks again.
    """
    decoder_layers = model.base_model.model.layers
    for number, decoder_layer in enumerate(decoder_layers, start=1):
        hook_fn = capture_activation(f"layer_{number}_mlp")
        decoder_layer.mlp.register_forward_hook(hook_fn)

def medusa_forward(input_ids, model, tokenizer, medusa_choices, temperature, posterior_threshold, posterior_alpha, max_steps = 512):
    """Run the Medusa decoding loop on a prompt and time each stage.

    After a one-off setup, each iteration calls, in order,
    ``generate_candidates`` -> ``tree_decoding`` -> ``evaluate_posterior`` ->
    ``update_inference_inputs``. The loop ends after ``max_steps`` iterations
    or as soon as ``tokenizer.eos_token_id`` appears in the tokens that follow
    the prompt.

    Args:
        input_ids: prompt token ids; indexed as ``input_ids.shape[1]`` and
            ``input_ids[0, input_len:]``, so a 2-D tensor whose row 0 is the
            sequence being decoded.
        model: Medusa model; must expose ``base_model``. The Medusa buffers,
            the chosen ``medusa_choices`` and the KV-cache objects are cached
            on it as attributes and reused on later calls.
        tokenizer: only ``eos_token_id`` is read.
        medusa_choices: tree specification handed to ``generate_medusa_buffers``.
        temperature, posterior_threshold, posterior_alpha: forwarded unchanged
            to ``evaluate_posterior``.
        max_steps: upper bound on the number of decoding iterations.

    Returns:
        ``(input_ids, new_token, idx, wall_times)``: the extended ids, the
        running ``new_token`` value maintained by ``update_inference_inputs``
        (presumably the number of generated tokens; TODO confirm), the index
        of the last iteration executed, and a dict of per-stage durations
        keyed ``'init'``, ``'medusa'``, ``'tree'``, ``'posterior'``, ``'update'``.

    NOTE(review): ``idx`` is only bound by the loop, so ``max_steps <= 0``
    makes the ``return`` raise ``UnboundLocalError``.
    """
    wall_times = {'medusa': [], 'tree': [], 'posterior': [], 'update': [], 'init': []}
    
    with timed(wall_times, 'init'):
        # Reuse the buffers from a previous call only if they were built for the same choices.
        if hasattr(model, "medusa_choices") and model.medusa_choices == medusa_choices:
            # Load the cached medusa buffer
            medusa_buffers = model.medusa_buffers
        else:
            # Initialize the medusa buffer
            medusa_buffers = generate_medusa_buffers(
                medusa_choices, device=model.base_model.device
            )
        # Cache on the model so the next call can take the fast path above.
        model.medusa_buffers = medusa_buffers
        model.medusa_choices = medusa_choices

        # Initialize the past key and value states
        if hasattr(model, "past_key_values"):
            past_key_values = model.past_key_values
            past_key_values_data = model.past_key_values_data
            current_length_data = model.current_length_data
            # Reset the past key and value states
            current_length_data.zero_()
        else:
            (
                past_key_values,
                past_key_values_data,
                current_length_data,
            ) = initialize_past_key_values(model.base_model)
            model.past_key_values = past_key_values
            model.past_key_values_data = past_key_values_data
            model.current_length_data = current_length_data

        # Prompt length; used below to look for EOS only in the generated suffix.
        input_len = input_ids.shape[1]
        reset_medusa_mode(model)
        medusa_logits, logits = initialize_medusa(
                input_ids, model, medusa_buffers["medusa_attn_mask"], past_key_values
        )
    new_token = 0

    for idx in range(max_steps): 
        # Stage 1: build candidate continuations from the current logits.
        with timed(wall_times, 'medusa'):
            candidates, tree_candidates = generate_candidates(
                    medusa_logits,
                    logits,
                    medusa_buffers["tree_indices"],
                    medusa_buffers["retrieve_indices"],
                )

        # Stage 2: run the model over the candidate tree.
        with timed(wall_times, 'tree'):
            medusa_logits, logits, outputs = tree_decoding(
                    model,
                    tree_candidates,
                    past_key_values,
                    medusa_buffers["medusa_position_ids"],
                    input_ids,
                    medusa_buffers["retrieve_indices"],
                )

        # Stage 3: choose the candidate to keep and how much of it to accept.
        with timed(wall_times, 'posterior'):
            best_candidate, accept_length = evaluate_posterior(
                    logits, candidates, temperature, posterior_threshold, posterior_alpha
                )
        
        # Stage 4: fold the accepted tokens into input_ids and refresh logits / cache state.
        with timed(wall_times, 'update'):
            input_ids, logits, medusa_logits, new_token = update_inference_inputs(
                    input_ids,
                    candidates,
                    best_candidate,
                    accept_length,
                    medusa_buffers["retrieve_indices"],
                    outputs,
                    logits,
                    medusa_logits,
                    new_token,
                    past_key_values_data,
                    current_length_data,
                )

        # Stop once EOS shows up anywhere in the generated part of row 0.
        if tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
            break

    return input_ids, new_token, idx, wall_times


## Model Loading

We load the model and tokenizer using the specified paths and configurations.


In [4]:
# Local checkpoint directory (absolute path on the author's machine; adjust before running elsewhere).
model_name = '/workspace/laurel/models/medusa-vicuna-7b-v1.3'
# Load the weights in float16 and place the whole model on the GPU.
model = MedusaModel.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="cuda"
)
tokenizer = model.get_tokenizer()

# Tree specification later handed to medusa_forward.
# NOTE(review): mc_sim_7b_63 is not defined in this notebook; presumably it arrives via the
# star-import of medusa.model.medusa_choices — confirm.
medusa_choices = mc_sim_7b_63

You are using a model of type llama to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
  return torch.load(checkpoint_file, map_location=map_location)
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:09<00:00,  4.85s/it]
Some weights of MedusaModelLlama were not initialized from the model checkpoint at /workspace/laurel/models/vicuna-7b-v1.3 and are newly initialized: ['medusa_head.3.0.lin

In [5]:
# Attach the activation-capturing forward hook to the MLP of every decoder layer.
# NOTE(review): run once per loaded model — the hook handles are not kept, so re-running
# this cell presumably stacks a second set of hooks instead of replacing the first.
register_hooks(model)

## Setting Inference Parameters

Next, we set some parameters that will be used during the inference process.


In [6]:
# Decoding settings: passed to medusa_forward, which forwards all three to evaluate_posterior.
temperature = 0.  # presumably 0 means greedy (non-sampled) acceptance — TODO confirm
posterior_threshold = 0.09
posterior_alpha = 0.3

## Setting The Prompt

We load the GSM8K test questions from a local parquet file and flatten each question onto a single line; these are the prompts used in the rest of the notebook.


In [7]:
# Parquet file with the prompts (the path points at the GSM8K "main" test split on the
# author's machine; adjust before running elsewhere).
path="/workspace/laurel/project/Medusa/layerwise/data/gsm8k/main/test-00000-of-00001.parquet"
df = pd.read_parquet(path)
# Flatten each question onto one line. regex=False pins a literal replacement: the default
# for `regex` differs between pandas versions, and a single-character pattern triggers a
# FutureWarning on pandas 1.x when it is left implicit.
test = df["question"].str.replace("\n", " ", regex=False)

## Performing Inference

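Using the parameters and the function defined above, we run Medusa decoding on every prompt and, after each run, compute the pairwise cross-entropy and KL divergence between the captured per-layer MLP outputs.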


In [11]:
# Sanity check run in a separate Python process (not the notebook kernel): prints whether
# torch was built with CUDA, the name of device 0, and whether CUDA is usable.
!python -c "import torch; print(torch.backends.cuda.is_built()); print(torch.cuda.get_device_name(0)); print(torch.cuda.is_available())"

True
NVIDIA RTX A5000
True


In [None]:
def _pairwise_layer_losses(layer_outputs):
    """Compare every ordered pair of captured per-layer outputs.

    Each output is treated as a set of logits over its last dimension.

    Args:
        layer_outputs: sequence of 2-D tensors, one per layer, all with the
            same shape (rows are positions, columns are features).

    Returns:
        ``(ce, kl)``: two nested lists of floats. ``ce[i][j]`` is the mean
        cross-entropy of layer ``i``'s distribution against the argmax of
        layer ``j``; ``kl[i][j]`` is ``KL(layer_j || layer_i)`` averaged over
        rows.
    """
    # Upcast once: the hooks store tensors in the model's dtype (float16 in this notebook),
    # and softmax/log in half precision on CPU is imprecise and not supported everywhere.
    outputs = [out.float() for out in layer_outputs]
    # Hoisted out of the pair loop: these depend on a single layer only.
    log_probs = [F.log_softmax(out, dim=-1) for out in outputs]
    hard_targets = [out.argmax(dim=-1) for out in outputs]

    ce, kl = [], []
    for out_i, log_p_i in zip(outputs, log_probs):
        ce_row, kl_row = [], []
        for log_p_j, target_j in zip(log_probs, hard_targets):
            ce_loss = F.cross_entropy(out_i, target_j, reduction='mean')
            # F.kl_div expects `input` as log-probabilities and, with log_target=True, the
            # target as log-probabilities too. Both are passed as-is: taking .log() of values
            # that are already log-probs (<= 0) yields NaN/-inf.
            kl_loss = F.kl_div(log_p_i, log_p_j, reduction='batchmean', log_target=True)
            ce_row.append(ce_loss.item())
            kl_row.append(kl_loss.item())
        ce.append(ce_row)
        kl.append(kl_row)
    return ce, kl


all_cross_entropy_losses = []
all_kl_div_losses = []

# One hooked MLP per decoder layer, stored as layer_1_mlp .. layer_N_mlp by register_hooks.
num_layers = len(model.base_model.model.layers)

for prompt in test:
    with torch.inference_mode():
        input_ids = tokenizer([prompt]).input_ids
        output_ids, new_token, idx, wall_time = medusa_forward(
            torch.as_tensor(input_ids).cuda(),
            model,
            tokenizer,
            medusa_choices,
            temperature,
            posterior_threshold,
            posterior_alpha,
        )
        # Keep only the generated continuation (drop the prompt tokens).
        output_ids = output_ids[0][len(input_ids[0]) :]
        output_text = tokenizer.decode(output_ids, skip_special_tokens=True)

        # Every hooked forward pass overwrites its entry in `activations`, so these are the
        # outputs of the most recent model call made inside medusa_forward. [0] takes the
        # first batch row — presumably leaving (positions, hidden); TODO confirm.
        layer_outputs = [
            activations[f'layer_{layer}_mlp'][0] for layer in range(1, num_layers + 1)
        ]
        cross_entropy_losses, kl_div_losses = _pairwise_layer_losses(layer_outputs)

        all_cross_entropy_losses.append(cross_entropy_losses)
        all_kl_div_losses.append(kl_div_losses)

In [None]:
# Stack the per-prompt loss matrices into arrays whose first axis is the prompt,
# then average over that axis to get one layer-by-layer matrix per metric.
all_cross_entropy_losses = np.asarray(all_cross_entropy_losses)
all_kl_div_losses = np.asarray(all_kl_div_losses)
mean_cross_entropy_loss = all_cross_entropy_losses.mean(axis=0)
mean_kl_div_loss = all_kl_div_losses.mean(axis=0)

In [None]:
import matplotlib.pyplot as plt


def _show_layer_heatmap(matrix, title):
    """Draw a layer-by-layer loss matrix as a labelled heatmap and show it.

    Args:
        matrix: 2-D array whose rows follow the outer loop index ``i`` and
            whose columns follow the inner loop index ``j`` of the loss
            computation.
        title: figure title.
    """
    n_rows, n_cols = matrix.shape
    fig, ax = plt.subplots()
    image = ax.imshow(
        matrix,
        cmap='viridis',
        interpolation='none',
        # Shift the default extent by one so ticks read as 1-based layer numbers.
        extent=(0.5, n_cols + 0.5, n_rows + 0.5, 0.5),
    )
    # imshow puts rows on the y-axis and columns on the x-axis, so i is vertical, j horizontal.
    ax.set_xlabel('Layer j')
    ax.set_ylabel('Layer i')
    ax.set_title(title)
    # Without a colour scale the heatmap values cannot be read off.
    fig.colorbar(image, ax=ax)
    plt.show()


# Cross-Entropy Loss
_show_layer_heatmap(mean_cross_entropy_loss, 'avg Cross-Entropy Loss between Layers')

# KL Divergence Loss
_show_layer_heatmap(mean_kl_div_loss, 'avg KL Divergence Loss between Layers')