# Multi-Token Prediction and MEDUSA: Accelerating Large Language Models

This notebook demonstrates the implementation of two advanced techniques for accelerating language model inference:

1. **Multi-Token Prediction (MTP)** based on the paper "Better & Faster Large Language Models via Multi-token Prediction"
2. **MEDUSA** framework based on the paper "MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads"

Both techniques modify the standard next-token prediction task to predict multiple future tokens at once, using different architectural approaches.

## 1. Setup and Imports

In [None]:
!pip install matplotlib tqdm datasets transformers[torch] accelerate>=0.26.0

import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from copy import deepcopy

import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

# Use CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
# `device` is a notebook-wide global that later cells use to place input tensors.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Report library versions so runs can be reproduced.
import transformers
print(f"PyTorch version: {torch.__version__}")
print(f"Transformers version: {transformers.__version__}")

## 2. Data Preprocessing

We'll use the MetaMathQA dataset (`meta-math/MetaMathQA`), which includes mathematical and computational reasoning tasks.

In [None]:
# `data_preprocessing` is a project-local module; its helpers are not defined in this notebook.
from data_preprocessing import preprocess_data, explore_dataset, validate_tokenization
# Explore the dataset identified by "meta-math/MetaMathQA" (what gets reported is
# decided by `explore_dataset` itself).
explore_dataset("meta-math/MetaMathQA")

In [None]:
# Preprocessing configuration. These names stay at notebook scope so other
# cells can reuse them.
max_seq_len = 512
train_batch_size = 4
val_batch_size = 4
split_ratio = 0.1  # presumably the validation fraction -- TODO confirm in data_preprocessing
seed = 42
max_examples = 500  # presumably caps the number of examples loaded -- TODO confirm

# Build the train/validation dataloaders and the tokenizer for the "gpt2" checkpoint.
# NOTE(review): num_tokens_to_predict is hard-coded to 4 here; keep it in sync
# with the value used when training the multi-token model.
train_dataloader, val_dataloader, tokenizer = preprocess_data(
    model_name="gpt2",
    max_seq_len=max_seq_len,
    train_batch_size=train_batch_size,
    val_batch_size=val_batch_size,
    split_ratio=split_ratio,
    seed=seed,
    max_examples=max_examples,
    num_tokens_to_predict=4
)

## 3. Model Architectures

We'll implement and compare three different model architectures:
1. Standard Next-Token Prediction (NTP)
2. Multi-Token Prediction (MTP)
3. MEDUSA

### 3.1 Standard Next-Token Prediction Model

The standard model predicts the next token given the previous tokens.

In [None]:
def analyze_model_logits(models, prompt, tokenizer, top_k=5):
    """Summarize each model's next-token distribution for a single prompt.

    Args:
        models: Mapping of display name to model. A model with a ``backbone``
            attribute (MEDUSA) is queried through that backbone; a model with
            ``num_tokens_to_predict`` (multi-token) contributes only its first
            head; any other model is called directly.
        prompt: Text encoded with ``tokenizer.encode``.
        tokenizer: Tokenizer providing ``encode`` and ``decode``.
        top_k: Number of most likely tokens to report (capped at the vocab size).

    Returns:
        dict mapping each model name to a dict with:
            'tokens': decoded top-k tokens, most likely first;
            'probs': their probabilities under the full-vocabulary softmax;
            'entropy': entropy (in nats) of the full next-token distribution;
            'confidence': probability of the single most likely token.
    """
    results = {}

    # `device` is the notebook-level global chosen in the setup cell.
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    for model_name, model in models.items():
        model.eval()
        with torch.no_grad():
            if hasattr(model, 'backbone'):  # MEDUSA model
                logits = model.backbone(input_ids).logits[:, -1, :]
            elif hasattr(model, 'num_tokens_to_predict'):  # Multi-token model
                outputs = model(input_ids)
                if isinstance(outputs, dict) and 'logits' in outputs:
                    logits = outputs['logits']
                    if logits.dim() == 4:
                        # assumes [batch, heads, seq, vocab]: first head, last position
                        logits = logits[:, 0, -1, :]
                    else:
                        logits = logits[:, -1, :]
                else:
                    logits = outputs.logits[:, -1, :]
            else:  # Standard model
                logits = model(input_ids).logits[:, -1, :]

            # Normalize over the whole vocabulary first. Applying softmax to only
            # the top-k logits would renormalize over k tokens and overstate both
            # the reported probabilities and the confidence.
            all_probs = torch.softmax(logits, dim=-1)
            k = min(top_k, all_probs.size(-1))
            top_probs, top_indices = torch.topk(all_probs, k=k, dim=-1)

            top_tokens = [tokenizer.decode([idx.item()]) for idx in top_indices[0]]

            # The epsilon keeps log() finite for zero-probability tokens.
            entropy = -torch.sum(all_probs * torch.log(all_probs + 1e-10), dim=-1)

            results[model_name] = {
                'tokens': top_tokens,
                'probs': top_probs[0].tolist(),
                'entropy': entropy.item(),
                'confidence': top_probs[0][0].item(),
            }

    return results

def compare_generation_quality(models, prompts, tokenizer, max_length=50):
    """Greedily continue every prompt with every model and collect the text.

    Args:
        models: Mapping of display name to model. A model with a ``backbone``
            attribute (MEDUSA) generates through that backbone; a model with
            ``num_tokens_to_predict`` uses its own ``generate``; any other
            model is treated as a standard causal LM.
        prompts: Iterable of prompt strings.
        tokenizer: Tokenizer used to encode prompts and decode generations.
            Its ``pad_token_id`` is set to ``eos_token_id`` when missing.
        max_length: Passed to ``generate`` as ``max_new_tokens``.

    Returns:
        dict mapping prompt -> {model name -> generated text}. If a model
        raises, its entry is an "Error: ..." string (message truncated to 100
        characters) so one failing model does not abort the comparison.
    """
    results = {}

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    for prompt in prompts:
        print(f"\n{'='*60}")
        print(f"Prompt: '{prompt}'")
        print('='*60)

        results[prompt] = {}

        for model_name, model in models.items():
            model.eval()
            try:
                inputs = tokenizer(prompt, return_tensors="pt", padding=True)
                # `device` is the notebook-level global chosen in the setup cell.
                input_ids = inputs["input_ids"].to(device)
                attention_mask = inputs["attention_mask"].to(device)

                with torch.no_grad():
                    if hasattr(model, 'backbone'):  # MEDUSA
                        generated = model.backbone.generate(
                            input_ids,
                            attention_mask=attention_mask,
                            max_new_tokens=max_length,
                            do_sample=False,
                            pad_token_id=tokenizer.pad_token_id,
                            eos_token_id=tokenizer.eos_token_id
                        )
                    elif hasattr(model, 'num_tokens_to_predict'):  # Multi-token
                        # MultiTokenGPT2.generate() accepts only input_ids,
                        # attention_mask, max_new_tokens, temperature, top_k, top_p,
                        # do_sample and use_speculative. Passing eos_token_id made
                        # this branch fail, and the except clause below hid that.
                        generated = model.generate(
                            input_ids,
                            attention_mask=attention_mask,
                            max_new_tokens=max_length,
                            do_sample=False,
                            use_speculative=True
                        )
                    else:  # Standard next-token model
                        generated = model.generate(
                            input_ids,
                            attention_mask=attention_mask,
                            max_new_tokens=max_length,
                            do_sample=False,
                            pad_token_id=tokenizer.pad_token_id,
                            eos_token_id=tokenizer.eos_token_id,
                            repetition_penalty=1.1  # Reduce repetition
                        )

                generated_text = tokenizer.decode(generated[0], skip_special_tokens=True)
                results[prompt][model_name] = generated_text
                print(f"\n{model_name}:")
                print(f"'{generated_text}'")

            except Exception as e:
                # Deliberate best effort: record the failure and keep comparing.
                error_msg = f"Error: {str(e)[:100]}..."
                results[prompt][model_name] = error_msg
                print(f"\n{model_name}: {error_msg}")

    return results


In [None]:
# `train_standard_model` is a project-local module (not defined in this notebook).
from train_standard_model import train_standard_gpt2

# Output directory for the standard model; the MEDUSA cell later uses this
# same path as its base model.
standard_output_dir = "standard_gpt2_outputs"
# Training hyperparameters. The multi-token and MEDUSA training cells reuse
# these same notebook-level values.
num_train_epochs = 1
learning_rate = 5e-5
gradient_accumulation_steps = 8

print("Training standard next-token prediction model...")
# Train the baseline on the dataloaders produced by the preprocessing cell.
standard_model = train_standard_gpt2(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    tokenizer=tokenizer,
    output_dir=standard_output_dir,
    num_train_epochs=num_train_epochs,
    learning_rate=learning_rate,
    gradient_accumulation_steps=gradient_accumulation_steps
)

### 3.2 Multi-Token Prediction Model

In [None]:
# `train_modified_model` is a project-local module; `MultiTokenGPT2` is imported
# here because the perplexity cell uses it in an isinstance() check.
from train_modified_model import train_multi_token_gpt2, MultiTokenGPT2

multi_token_output_dir = "multi_gpt2_outputs"
# Number of future tokens predicted per position (one prediction head each).
num_tokens_to_predict = 4
# With 12 layers and 4 heads: trunk_layers = 12 - (num_tokens_to_predict - 1) = 9

# Reuses the epochs, learning rate and gradient accumulation set in the
# standard-model training cell.
multi_token_model = train_multi_token_gpt2(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    tokenizer=tokenizer,
    output_dir=multi_token_output_dir,
    num_tokens_to_predict=num_tokens_to_predict,
    num_train_epochs=num_train_epochs,
    learning_rate=learning_rate,
    gradient_accumulation_steps=gradient_accumulation_steps
)

# Print a short architecture summary.
# NOTE(review): 12 is presumably the GPT-2 layer count; this arithmetic only
# mirrors what train_multi_token_gpt2 is expected to build -- confirm there.
trunk_layers = 12 - (num_tokens_to_predict - 1)  # = 9 for num_tokens_to_predict = 4
print(f"- Base model: GPT-2")
print(f"- Trunk layers: {trunk_layers} (shared processing)")
print(f"- Prediction heads: {num_tokens_to_predict} (first_head + {num_tokens_to_predict-1} extra_heads)")

### 3.3 MEDUSA Model

MEDUSA uses the standard next-token prediction model trained in Section 3.1 as its backbone and adds extra decoding heads on top of it.

In [None]:
# `medusa` is a project-local module; `MedusaModel` is imported here because
# the perplexity cell uses it in an isinstance() check.
from medusa import train_medusa, MedusaModel, generate_text_with_medusa

medusa_output_dir = "medusa_outputs"
num_medusa_heads = 5  # MEDUSA typically uses 5 heads

# Use either a pretrained model or our standard model as the base.
# Here: the output directory of the standard model trained above
# (assumes train_standard_gpt2 saved a loadable checkpoint there -- TODO confirm).
base_model = standard_output_dir

print("Training MEDUSA model...")
medusa_model = train_medusa(
    base_model=base_model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    output_dir=medusa_output_dir,
    num_medusa_heads=num_medusa_heads,
    num_train_epochs=num_train_epochs,
    learning_rate=learning_rate,
    gradient_accumulation_steps=gradient_accumulation_steps,
    freeze_backbone=True  # MEDUSA-1: freeze backbone for efficient training
)

# Explain the architecture. The base-model line reports the value actually
# passed to train_medusa instead of a hard-coded "Pretrained GPT-2", which was
# wrong whenever the fine-tuned standard model is used as the backbone.
print("MEDUSA Model Architecture:")
print(f"- Base model: {base_model}")
print(f"- Number of MEDUSA heads: {num_medusa_heads}")
print("- Each head is a feed-forward network with SiLU activation and residual connection")
print("- Tree verification allows multiple candidate predictions to be verified in parallel")

### 3.4 Model Analysis and Diagnostics

In [None]:
def analyze_model_logits(models, prompt, tokenizer, top_k=5):
    """Report the most likely next tokens of each model for one prompt.

    Args:
        models: Mapping of display name to model. Dispatch is by attribute:
            ``backbone`` means MEDUSA (only the backbone is queried),
            ``num_tokens_to_predict`` means multi-token (first head only),
            anything else is called as a standard causal LM.
        prompt: Text encoded with ``tokenizer.encode``.
        tokenizer: Tokenizer providing ``encode`` and ``decode``.
        top_k: How many candidates to report (capped at the vocab size).

    Returns:
        dict of model name -> {'tokens', 'probs', 'entropy', 'confidence'}.
        'probs' are probabilities under the full-vocabulary softmax,
        'entropy' is the entropy (nats) of that distribution and
        'confidence' is the probability of the top-ranked token.
    """
    results = {}

    # `device` is the notebook-level global chosen in the setup cell.
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    for model_name, model in models.items():
        model.eval()
        with torch.no_grad():
            if hasattr(model, 'backbone'):  # medusa
                logits = model.backbone(input_ids).logits[:, -1, :]
            elif hasattr(model, 'num_tokens_to_predict'):  # mtp
                outputs = model(input_ids)
                if isinstance(outputs, dict) and 'logits' in outputs:
                    logits = outputs['logits']
                    if logits.dim() == 4:
                        # assumes [batch, heads, seq, vocab]: first head, last position
                        logits = logits[:, 0, -1, :]
                    else:
                        logits = logits[:, -1, :]
                else:
                    logits = outputs.logits[:, -1, :]
            else:
                logits = model(input_ids).logits[:, -1, :]

            # Take the softmax over the full vocabulary before selecting the
            # top-k. A softmax over just the top-k logits sums to 1 across k
            # tokens and therefore inflates every reported probability.
            all_probs = torch.softmax(logits, dim=-1)
            k = min(top_k, all_probs.size(-1))
            top_probs, top_indices = torch.topk(all_probs, k=k, dim=-1)

            # token strings
            top_tokens = [tokenizer.decode([idx.item()]) for idx in top_indices[0]]

            # entropy; the epsilon keeps log() finite for zero probabilities
            entropy = -torch.sum(all_probs * torch.log(all_probs + 1e-10), dim=-1)

            results[model_name] = {
                'tokens': top_tokens,
                'probs': top_probs[0].tolist(),
                'entropy': entropy.item(),
                'confidence': top_probs[0][0].item(),
            }

    return results

def compare_generation_quality(models, prompts, tokenizer, max_length=50):
    """Greedily generate a continuation of each prompt with each model.

    Models exposing ``backbone`` generate through it (MEDUSA); models exposing
    ``num_tokens_to_predict`` use speculative multi-token decoding; all others
    use standard generation with a 1.1 repetition penalty. Returns
    ``{prompt: {model_name: text}}``, where a failing model's text is a
    truncated "Error: ..." message.
    """
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    results = {}
    for prompt in prompts:
        print(f"\n{'='*60}")
        print(f"Prompt: '{prompt}'")
        print('='*60)
        per_model = {}
        results[prompt] = per_model

        for model_name, model in models.items():
            model.eval()
            try:
                encoded = tokenizer(prompt, return_tensors="pt", padding=True)
                input_ids = encoded["input_ids"].to(device)
                attention_mask = encoded["attention_mask"].to(device)

                with torch.no_grad():
                    # Choose the object to call and its model-specific options.
                    if hasattr(model, 'backbone'):  # MEDUSA
                        generator = model.backbone
                        extra_kwargs = {
                            "pad_token_id": tokenizer.pad_token_id,
                            "eos_token_id": tokenizer.eos_token_id,
                        }
                    elif hasattr(model, 'num_tokens_to_predict'):  # Multi-token
                        generator = model
                        extra_kwargs = {"use_speculative": True}
                    else:  # Standard
                        generator = model
                        extra_kwargs = {
                            "pad_token_id": tokenizer.pad_token_id,
                            "eos_token_id": tokenizer.eos_token_id,
                            "repetition_penalty": 1.1,
                        }

                    generated = generator.generate(
                        input_ids,
                        attention_mask=attention_mask,
                        max_new_tokens=max_length,
                        do_sample=False,
                        **extra_kwargs,
                    )

                generated_text = tokenizer.decode(generated[0], skip_special_tokens=True)
                per_model[model_name] = generated_text
                print(f"\n{model_name}:")
                print(f"'{generated_text}'")
            except Exception as e:
                error_msg = f"Error: {str(e)[:100]}..."
                per_model[model_name] = error_msg
                print(f"\n{model_name}: {error_msg}")

    return results


In [None]:
def compare_generation_quality(models, prompts, tokenizer, max_length=50):
    """Run greedy generation for every (prompt, model) pair and collect the text.

    The result maps each prompt to a dict of model name -> decoded output.
    If a model raises, the exception message (first 100 characters) is stored
    in place of its output.
    """
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    results = {}
    for prompt in prompts:
        print(f"\n{'='*60}")
        print(f"Prompt: '{prompt}'")
        print('='*60)

        outputs_for_prompt = results[prompt] = {}

        for model_name, model in models.items():
            model.eval()
            try:
                # Tokenize with an attention mask and move both tensors to `device`.
                batch = tokenizer(prompt, return_tensors="pt", padding=True)
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)

                # Arguments common to all three generation paths.
                shared = dict(
                    attention_mask=attention_mask,
                    max_new_tokens=max_length,
                    do_sample=False,
                )

                with torch.no_grad():
                    if hasattr(model, 'backbone'):
                        generated = model.backbone.generate(
                            input_ids,
                            pad_token_id=tokenizer.pad_token_id,
                            eos_token_id=tokenizer.eos_token_id,
                            **shared,
                        )
                    elif hasattr(model, 'num_tokens_to_predict'):
                        # MultiTokenGPT2.generate() only accepts input_ids,
                        # attention_mask, max_new_tokens, temperature, top_k,
                        # top_p, do_sample and use_speculative.
                        generated = model.generate(input_ids, use_speculative=True, **shared)
                    else:
                        generated = model.generate(
                            input_ids,
                            pad_token_id=tokenizer.pad_token_id,
                            eos_token_id=tokenizer.eos_token_id,
                            repetition_penalty=1.1,
                            **shared,
                        )
                    generated_text = tokenizer.decode(generated[0], skip_special_tokens=True)

                outputs_for_prompt[model_name] = generated_text
                print(f"\n{model_name}:")
                print(f"'{generated_text}'")

            except Exception as e:
                error_msg = f"Error: {str(e)[:100]}..."
                outputs_for_prompt[model_name] = error_msg
                print(f"\n{model_name}: {error_msg}")

    return results

print("done!")


In [None]:
def compare_generation_quality(models, prompts, tokenizer, max_length=50):
    """Compare greedy generations across models for a list of prompts.

    Returns ``{prompt: {model_name: generated_text}}``. A model that raises
    during tokenization, generation or decoding contributes an "Error: ..."
    string (message truncated to 100 characters) instead of text.
    """
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    results = {}
    for prompt in prompts:
        print(f"\n{'='*60}")
        print(f"Prompt: '{prompt}'")
        print('='*60)
        results[prompt] = {}

        for model_name, model in models.items():
            model.eval()
            try:
                tokenized = tokenizer(prompt, return_tensors="pt", padding=True)
                input_ids = tokenized["input_ids"].to(device)
                attention_mask = tokenized["attention_mask"].to(device)

                with torch.no_grad():
                    # MEDUSA exposes `backbone`; the multi-token model exposes
                    # `num_tokens_to_predict`; everything else is standard.
                    is_medusa = hasattr(model, 'backbone')
                    is_multi_token = not is_medusa and hasattr(model, 'num_tokens_to_predict')

                    target = model.backbone if is_medusa else model
                    if is_multi_token:
                        options = {"use_speculative": True}
                    else:
                        options = {
                            "pad_token_id": tokenizer.pad_token_id,
                            "eos_token_id": tokenizer.eos_token_id,
                        }
                        if not is_medusa:
                            options["repetition_penalty"] = 1.1

                    generated = target.generate(
                        input_ids,
                        attention_mask=attention_mask,
                        max_new_tokens=max_length,
                        do_sample=False,
                        **options,
                    )

                generated_text = tokenizer.decode(generated[0], skip_special_tokens=True)
            except Exception as e:
                error_msg = f"Error: {str(e)[:100]}..."
                results[prompt][model_name] = error_msg
                print(f"\n{model_name}: {error_msg}")
            else:
                results[prompt][model_name] = generated_text
                print(f"\n{model_name}:")
                print(f"'{generated_text}'")
    return results

print("done!")


In [None]:
# Models under comparison, keyed by the label used in every printed report.
models = {
    "Standard NTP": standard_model,
    "Multi-Token Prediction": multi_token_model,
    "MEDUSA": medusa_model,
}

print("=== LOGITS ANALYSIS ===")
test_prompt = "The Pythagorean theorem states that"
logits_analysis = analyze_model_logits(models, test_prompt, tokenizer)

# One report per model: summary statistics, then the ranked candidates.
for model_name, stats in logits_analysis.items():
    entropy, confidence = stats['entropy'], stats['confidence']
    print(f"\n{model_name} Top 5 Predictions:")
    print(f"Entropy: {entropy:.3f} | Confidence: {confidence:.3f}")
    ranked = zip(stats['tokens'], stats['probs'])
    for rank, (token, prob) in enumerate(ranked, start=1):
        print(f"  {rank}. '{token}' (Probability: {prob:.4f})")


In [None]:
print("\n\n=== GENERATION QUALITY COMPARISON ===")

# Short math prompts, each left open-ended so the model has to complete it.
test_prompts = [
    "The area of a triangle with sides 3, 4, and 5 is",
    "To solve x^2 + 5x + 6 = 0, we",
    "The derivative of f(x) = x^3 is"
]

# compare_generation_quality forwards max_length to generate() as max_new_tokens,
# so each model produces at most 30 new tokens per prompt.
generation_results = compare_generation_quality(models, test_prompts, tokenizer, max_length=30)


## 4. Evaluation and Comparison

Now, we'll evaluate and compare the three models in terms of:
1. Generation quality
2. Inference speed
3. Acceptance rate for multi-token models

### 4.1 Generation Quality Evaluation

We'll use perplexity as a metric to evaluate the language modeling quality of the models.

In [None]:
def calculate_perplexity(model, dataloader):
    """Compute next-token perplexity of ``model`` over ``dataloader``.

    Perplexity is ``exp(total negative log-likelihood / number of scored
    tokens)``. Position ``t`` is scored against label ``t + 1`` and labels
    equal to ``-100`` are ignored.

    Args:
        model: A ``MultiTokenGPT2`` (only its first head is scored), a
            ``MedusaModel`` (only its backbone is scored), or any other model
            whose output exposes a ``loss`` when called with ``labels``. Both
            classes must already be imported into the notebook.
        dataloader: Iterable of batches (dicts of tensors). Each batch must
            contain "labels"; "input_ids" and "attention_mask" are forwarded
            when present.

    Returns:
        float: The perplexity, or ``inf`` when no token could be scored.
    """
    model.eval()
    total_loss = 0.0
    total_tokens = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Calculating perplexity"):
            # Forward only the keys the models expect.
            inputs = {
                key: batch[key].to(device)
                for key in ("input_ids", "attention_mask", "labels")
                if key in batch
            }
            labels = inputs["labels"]

            # Every branch predicts label t+1 from position t, so the first
            # label is never scored. Count tokens on the shifted labels; using
            # the unshifted count mis-weights the mean loss of the MEDUSA and
            # standard branches and disagrees with the multi-token branch.
            shift_labels = labels[..., 1:].contiguous()
            num_scored = (shift_labels != -100).sum().item()
            if num_scored == 0:
                # Nothing to score; a mean loss over zero tokens would be NaN.
                continue

            if isinstance(model, MultiTokenGPT2):
                logits = model(**inputs)["logits"]
                if logits is None:
                    continue
                if logits.dim() == 4:
                    # assumes [batch, heads, seq, vocab]; keep the first head
                    logits = logits[:, 0]

                shift_logits = logits[..., :-1, :].contiguous()
                loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="sum")
                batch_loss = loss_fct(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                ).item()
            else:
                # MEDUSA is scored through its backbone; other models directly.
                scorer = model.backbone if isinstance(model, MedusaModel) else model
                mean_loss = getattr(scorer(**inputs), "loss", None)
                if mean_loss is None:
                    continue
                # assumes `loss` is the mean over the non-ignored shifted labels
                # (the Hugging Face causal-LM convention) -- TODO confirm for
                # custom models; the mean is turned back into a sum here.
                batch_loss = mean_loss.item() * num_scored

            total_loss += batch_loss
            total_tokens += num_scored

    if total_tokens == 0:
        return float("inf")  # no tokens were processed

    avg_loss = total_loss / total_tokens
    return torch.exp(torch.tensor(avg_loss)).item()

print("Calculating perplexity on validation set...")
# calculate_perplexity scores only the first head of the multi-token model and
# only the backbone of the MEDUSA model.
standard_ppl = calculate_perplexity(standard_model, val_dataloader)
mtp_ppl = calculate_perplexity(multi_token_model, val_dataloader)
medusa_ppl = calculate_perplexity(medusa_model, val_dataloader)

# Lower perplexity is better.
print(f"Standard model perplexity: {standard_ppl:.2f}")
print(f"Multi-token model perplexity: {mtp_ppl:.2f}")
print(f"MEDUSA model perplexity: {medusa_ppl:.2f}")

### 4.2 Inference Speed Comparison

Now, let's compare the inference speed of the three approaches.

In [None]:
from medusa import generate_text_with_medusa
import time

def _sync_if_cuda():
    """Wait for queued CUDA work so wall-clock timings cover all GPU compute."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def _generate_for_benchmark(model, tokenizer, prompt, inputs, max_new_tokens,
                            use_multi_token, use_medusa, tree_branching):
    """Run one generation and return the number of newly generated tokens."""
    prompt_len = inputs["input_ids"].shape[1]

    with torch.no_grad():  # inference only: never build autograd graphs
        if use_medusa:
            output_ids = generate_text_with_medusa(
                model, tokenizer, prompt,
                max_new_tokens=max_new_tokens, tree_branching=tree_branching)
            return len(output_ids[0]) - prompt_len

        generate_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": False,
            "attention_mask": inputs["attention_mask"],
        }
        if hasattr(model, "num_tokens_to_predict"):  # MultiTokenGPT2 model
            if use_multi_token:
                generate_kwargs["use_speculative"] = True
        else:  # Standard model
            generate_kwargs["pad_token_id"] = tokenizer.pad_token_id

        output_ids = model.generate(inputs["input_ids"], **generate_kwargs)
        return output_ids.shape[1] - prompt_len


def measure_inference_speed(model_name, model, tokenizer, prompt, max_new_tokens=100, num_runs=5, 
                           use_multi_token=False, use_medusa=False, tree_branching=(5,5,3,3,2)):
    """Measure greedy-generation throughput of a model on one prompt.

    One 20-token warm-up generation is run first and discarded, then
    ``num_runs`` generations are timed. Only generation is timed (decoding the
    result to text is excluded), and the device is synchronized around each
    run so asynchronous CUDA work is included in the measurement.

    Args:
        model_name: Label used in progress output and in the result.
        model: Model to benchmark.
        tokenizer: Tokenizer for ``prompt``; its ``pad_token_id`` is set to
            ``eos_token_id`` when missing.
        prompt: Prompt text.
        max_new_tokens: Generation budget for each timed run.
        num_runs: Number of timed runs to average over.
        use_multi_token: Pass ``use_speculative=True`` to a model that has
            ``num_tokens_to_predict``.
        use_medusa: Generate through ``generate_text_with_medusa``.
        tree_branching: Forwarded to ``generate_text_with_medusa``.

    Returns:
        dict with "model", "avg_time" (seconds per run), "avg_tokens" (new
        tokens per run) and "tokens_per_second" (avg_tokens / avg_time).
    """
    # `device` is the notebook-level global chosen in the setup cell.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Set pad_token_id to eos_token_id if not set
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    print(f"Warming up {model_name}...")
    _generate_for_benchmark(model, tokenizer, prompt, inputs, 20,
                            use_multi_token, use_medusa, tree_branching)

    times = []
    tokens_generated = []

    print(f"Running inference on {model_name}...")
    for i in range(num_runs):
        _sync_if_cuda()
        # perf_counter is monotonic and high resolution, unlike time.time().
        start_time = time.perf_counter()

        num_new_tokens = _generate_for_benchmark(
            model, tokenizer, prompt, inputs, max_new_tokens,
            use_multi_token, use_medusa, tree_branching)

        _sync_if_cuda()
        elapsed_time = time.perf_counter() - start_time

        times.append(elapsed_time)
        tokens_generated.append(num_new_tokens)

        print(f"Run {i+1}: Generated {num_new_tokens} tokens in {elapsed_time:.2f}s ({num_new_tokens/elapsed_time:.2f} tokens/s)")

    # Calculate average
    avg_time = sum(times) / len(times)
    avg_tokens = sum(tokens_generated) / len(tokens_generated)
    tokens_per_second = avg_tokens / avg_time

    return {
        "model": model_name,
        "avg_time": avg_time,
        "avg_tokens": avg_tokens,
        "tokens_per_second": tokens_per_second
    }

# Prompts used for the speed benchmark (this rebinds the notebook-level `test_prompts`).
test_prompts = [
    "What is the derivative of f(x) = x^3 + 2x^2 - 5x + 7?",
    "Solve the equation: 3x^2 - 12 = 0",
    "If a triangle has sides of length 3, 4, and 5, what is its area?"
]

# Run the inference speed tests: one result dict per (prompt, model) pair.
results = []

for prompt in test_prompts:
    print(f"\nTesting prompt: {prompt}")
    
    # Baseline: standard generation.
    standard_result = measure_inference_speed("Standard NTP", standard_model, tokenizer, prompt)
    results.append(standard_result)
    
    # Multi-token model with speculative decoding enabled.
    mtp_result = measure_inference_speed("Multi-Token Prediction", multi_token_model, tokenizer, prompt, use_multi_token=True)
    results.append(mtp_result)
    
    # MEDUSA, generated through generate_text_with_medusa.
    medusa_result = measure_inference_speed("MEDUSA", medusa_model, tokenizer, prompt, use_medusa=True)
    results.append(medusa_result)

import pandas as pd

# Average each metric over the prompts, per model.
speed_df = pd.DataFrame(results)
print("\nInference Speed Summary:")
print(speed_df.groupby('model').mean())

# Bar chart of mean throughput per model.
plt.figure(figsize=(10, 6))
avg_speeds = speed_df.groupby('model')['tokens_per_second'].mean()
avg_speeds.plot(kind='bar')
plt.title('Average Inference Speed Comparison')
plt.ylabel('Tokens per Second')
plt.xticks(rotation=0)
plt.grid(axis='y', alpha=0.3)
# Label every bar with its value, placed slightly above the bar top.
for i, v in enumerate(avg_speeds):
    plt.text(i, v + 0.5, f"{v:.1f}", ha='center')
plt.tight_layout()
plt.show()