# LLaDA Inference Profiling on HumanEval

This notebook runs inference on the LLaDA model using the HumanEval dataset and collects wall time statistics for profiling.

## Beneficial Statistics for Profiling
For inference profiling, especially with diffusion models like LLaDA, the following statistics are beneficial:
1.  **Total Wall Time (Latency)**: The total time taken to generate a complete solution.
2.  **Time Per Step**: Since LLaDA is a diffusion model, measuring the time taken per diffusion step is crucial.
3.  **Throughput**: The number of samples generated per second; most informative when batching is used.
4.  **Memory Usage**: Peak GPU memory consumption.


In [None]:
import os
import torch
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datasets import load_dataset

os.environ['HF_HOME'] = './hf_models/'
from transformers import AutoTokenizer, AutoModel

# Import local generate function
from generate import generate

# Select the compute device: the GPU when torch can see one, the CPU otherwise.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")


Using device: cuda


In [4]:
# Load Model and Tokenizer
# The Instruct checkpoint is used, following the chat.py example.
model_id = 'GSAI-ML/LLaDA-8B-Instruct'
print(f"Loading model: {model_id}")

# trust_remote_code=True is passed because the checkpoint ships its own
# configuration/modeling code that transformers has to execute.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# Load the weights in bfloat16, then move them to the selected device and
# switch the module to eval mode.
model = AutoModel.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
model = model.to(device).eval()

# Generation is run with left padding; override the tokenizer only when it is
# configured differently.
if tokenizer.padding_side != 'left':
    tokenizer.padding_side = 'left'


Loading model: GSAI-ML/LLaDA-8B-Instruct


  `use_auth_token` is passed to a function, the `use_auth_token` value is passed
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
A new version of the following files was downloaded from https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct:
- configuration_llada.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct:
- configuration_llada.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct:
- modeling_l

KeyboardInterrupt: 

In [None]:
# Load the test split of HumanEval and report how many problems it contains.
print("Loading HumanEval dataset...")
ds = load_dataset("openai_humaneval", split="test")
problem_count = len(ds)
print(f"Loaded {problem_count} problems.")

# Show the prompt of the first problem as a quick sanity check.
first_problem = ds[0]
print("\nSample Problem:")
print(first_problem['prompt'])


In [None]:
def run_inference(model, tokenizer, prompt_text, steps=128, gen_length=128, block_length=32):
    """
    Run one generation for a single prompt and measure its wall time.

    The prompt is wrapped as a single user message with the tokenizer's chat
    template, tokenized, and handed to the local ``generate`` function with
    temperature 0, cfg_scale 0 and 'low_confidence' remasking.

    Args:
        model: Model object forwarded to ``generate``.
        tokenizer: Tokenizer used for the chat template, encoding and decoding.
        prompt_text: Raw prompt string, used as the user message content.
        steps: Forwarded to ``generate`` as ``steps``.
        gen_length: Forwarded to ``generate`` as ``gen_length``.
        block_length: Forwarded to ``generate`` as ``block_length``.

    Returns:
        tuple: ``(generated_text, wall_time)`` where ``generated_text`` is the
        decoded output with the prompt tokens sliced off and ``wall_time`` is
        the elapsed time of the ``generate`` call in seconds.
    """
    # Prepare input
    messages = [{"role": "user", "content": prompt_text}]
    formatted_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

    # The chat template already produced the full prompt string, so no extra
    # special tokens are added while encoding.
    inputs = tokenizer(formatted_prompt, return_tensors="pt", padding=True, add_special_tokens=False)
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    # CUDA kernels are launched asynchronously. Without synchronizing, the
    # timer could include work queued before this call, or stop before the
    # generation kernels have actually finished.
    on_cuda = input_ids.is_cuda
    if on_cuda:
        torch.cuda.synchronize(input_ids.device)

    # Measure time
    start_time = time.perf_counter()

    with torch.no_grad():
        out = generate(
            model,
            input_ids,
            attention_mask=attention_mask,
            steps=steps,
            gen_length=gen_length,
            block_length=block_length,
            temperature=0.,
            cfg_scale=0.,
            remasking='low_confidence'
        )

    # Wait for all queued GPU work before reading the clock.
    if on_cuda:
        torch.cuda.synchronize(input_ids.device)

    end_time = time.perf_counter()
    wall_time = end_time - start_time

    # Decode output: drop the prompt positions and keep only the generated part
    # of the single sequence in the batch.
    generated_text = tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True)[0]

    return generated_text, wall_time

print("Inference function defined.")


In [None]:
# Run Inference on a Subset
requested_samples = 5  # Adjust as needed
# Clamp to the dataset size: ds[i] is evaluated outside the try block, so an
# out-of-range index would abort the whole cell before df_results is built.
num_samples = min(requested_samples, len(ds))
results = []

print(f"Running inference on first {num_samples} samples...")

for i in range(num_samples):
    problem = ds[i]
    prompt = problem['prompt']
    task_id = problem['task_id']

    print(f"Processing {task_id}...")

    # Best effort: a problem that fails is reported and skipped so the
    # remaining problems still get timed.
    try:
        # Using parameters from chat.py/generate.py examples
        output, duration = run_inference(
            model,
            tokenizer,
            prompt,
            steps=128,
            gen_length=128,
            block_length=32
        )

        results.append({
            "task_id": task_id,
            "wall_time": duration,
            # Length of the decoded text in characters (not tokens).
            "output_length": len(output),
            "output": output
        })
        print(f"  Time: {duration:.4f}s")

    except Exception as e:
        print(f"  Error: {e}")

# One row per successfully processed problem; empty when every run failed.
df_results = pd.DataFrame(results)


In [None]:
# Summarize the recorded wall times (seconds) across all successful runs.
if df_results.empty:
    print("No results to analyze.")
else:
    latency = df_results['wall_time']
    stats = {
        "Mean Latency": latency.mean(),
        "Median Latency": latency.median(),
        "Std Dev": latency.std(),
        "Min": latency.min(),
        "Max": latency.max(),
        "P95": latency.quantile(0.95),
        "P99": latency.quantile(0.99),
    }

    print("Wall Time Statistics (seconds):")
    for label, value in stats.items():
        print(f"{label}: {value:.4f}")


In [None]:
# Plot the distribution of wall times with the mean marked as a dashed line.
if not df_results.empty:
    latency = df_results['wall_time']
    mean_latency = latency.mean()

    plt.figure(figsize=(10, 6))
    sns.histplot(latency, kde=True, bins=10)
    plt.title('Inference Wall Time Distribution')
    plt.xlabel('Time (seconds)')
    plt.ylabel('Count')
    plt.axvline(
        mean_latency,
        color='r',
        linestyle='--',
        label=f"Mean: {mean_latency:.2f}s",
    )
    plt.legend()
    plt.show()
