# LLaDA Inference Profiling on HumanEval

This notebook runs LLaDA inference on prompts from the HumanEval dataset and collects wall-time statistics for profiling.

## Beneficial Statistics for Profiling
For inference profiling, especially with diffusion models like LLaDA, the following statistics are beneficial:
1.  **Total Wall Time (Latency)**: The total time taken to generate a complete solution.
2.  **Time Per Step**: Since LLaDA is a diffusion model, measuring the time taken per diffusion step is crucial.
3.  **Throughput**: Samples (or generated tokens) processed per second; most meaningful when requests are batched.
4.  **Memory Usage**: Peak GPU memory consumption.


In [None]:
import os
import subprocess
import torch
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datasets import load_dataset

os.environ['HF_HOME'] = '/root/LLaDA/hf_models/'
from transformers import AutoTokenizer, AutoModel

from generate import generate

# Setup device: prefer CUDA when available, otherwise run on CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Load Model and Tokenizer from the local Hugging Face cache only (no downloads).
model_id = 'GSAI-ML/LLaDA-8B-Instruct'
cache_path = '/root/LLaDA/hf_models/hub'
_hf_load_kwargs = dict(trust_remote_code=True, cache_dir=cache_path, local_files_only=True)

print(f"Loading model: {model_id}")
tokenizer = AutoTokenizer.from_pretrained(model_id, **_hf_load_kwargs)
model = AutoModel.from_pretrained(model_id, torch_dtype=torch.bfloat16, **_hf_load_kwargs)
model = model.to(device).eval()

# Generation uses left padding; setting it unconditionally is equivalent to the
# "only if not already left" check.
tokenizer.padding_side = 'left'

In [None]:
# Load HumanEval Dataset (test split) from a local on-disk copy.
print("Loading HumanEval dataset...")
dataset_path = '/root/LLaDA/hf_models/datasets/openai_humaneval'
ds = load_dataset(split="test", path=dataset_path)
problem_count = len(ds)
print(f"Loaded {problem_count} problems.")

# Show the first problem's prompt so the input format is visible.
print("\nSample Problem:")
sample_problem = ds[0]
print(sample_problem['prompt'])


In [None]:
def _cuda_sync(tensor):
    """Block until queued CUDA work on ``tensor``'s device has finished (no-op on CPU)."""
    if tensor.is_cuda:
        torch.cuda.synchronize(tensor.device)


def _transfer_schedule(block_mask, steps):
    """Split each row's masked-token count as evenly as possible over ``steps`` steps.

    Returns a (batch, steps) int64 tensor; the first ``count % steps`` steps
    commit one extra token. Modelled on the transfer schedule of LLaDA's
    reference generator — NOTE(review): keep in sync with generate.py.
    """
    mask_num = block_mask.sum(dim=1, keepdim=True)           # (batch, 1)
    base = mask_num // steps
    remainder = mask_num % steps
    step_ids = torch.arange(steps, device=block_mask.device).unsqueeze(0)  # (1, steps)
    return base + (step_ids < remainder).long()


@torch.no_grad()
def profiled_generate(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0., cfg_scale=0., remasking='low_confidence', mask_id=126336):
    """
    Masked-diffusion generation with per-step wall-time and token-count profiling.

    The output canvas (prompt followed by ``gen_length`` mask tokens) is built
    up front, so every forward pass sees the full sequence length. Generation
    proceeds semi-autoregressively over ``gen_length // block_length`` blocks.
    ``steps`` is the TOTAL number of denoising steps and is split evenly across
    the blocks. At each step every masked position gets a candidate token, and
    the highest-scoring candidates inside the current block are committed.
    Committed tokens are never re-masked.

    Args:
        model: callable returning an object with a ``.logits`` attribute;
            assumed (batch, seq, vocab) — TODO confirm for other backends.
        prompt: (batch, prompt_len) tensor of token ids.
        steps: total denoising steps; must be a multiple of the block count.
        gen_length: tokens to generate; must be a multiple of ``block_length``.
        block_length: tokens per semi-autoregressive block.
        temperature: 0 -> greedy argmax; > 0 -> multinomial sampling from
            ``softmax(logits / temperature)``.
        cfg_scale: > 0 enables classifier-free guidance against a copy of the
            sequence whose prompt tokens are replaced by ``mask_id``.
        remasking: 'low_confidence' (commit the most probable candidates) or
            'random' (commit random candidates).
        mask_id: id of the mask token.

    Returns:
        (x, step_details). ``x`` has shape (batch, prompt_len + gen_length).
        ``step_details`` holds one dict per forward step with these keys:
        ``block_idx``, ``step_idx``, ``duration`` (seconds, CUDA-synchronised),
        ``remasked_count`` (current-block positions still masked after the step),
        ``sampled_count`` (tokens committed this step), and
        ``total_generated_so_far`` (non-mask generated positions, summed over
        the batch).

    Raises:
        ValueError: for non-positive sizes, ``gen_length`` not divisible by
            ``block_length``, ``steps`` not divisible by the block count, or an
            unknown ``remasking`` strategy.
    """
    if steps <= 0 or gen_length <= 0 or block_length <= 0:
        raise ValueError("steps, gen_length and block_length must be positive")
    if gen_length % block_length != 0:
        raise ValueError(f"gen_length ({gen_length}) must be a multiple of block_length ({block_length})")
    num_blocks = gen_length // block_length
    if steps % num_blocks != 0:
        raise ValueError(f"steps ({steps}) must be a multiple of the number of blocks ({num_blocks})")
    if remasking not in ('low_confidence', 'random'):
        raise ValueError(f"unknown remasking strategy: {remasking!r}")
    steps_per_block = steps // num_blocks

    batch, prompt_len = prompt.shape
    # Fixed-length canvas: prompt followed by gen_length mask tokens.
    x = torch.full((batch, prompt_len + gen_length), mask_id, dtype=torch.long, device=prompt.device)
    x[:, :prompt_len] = prompt
    prompt_index = torch.zeros_like(x, dtype=torch.bool)
    prompt_index[:, :prompt_len] = True

    step_details = []
    neg_inf = float('-inf')

    for block_idx in range(num_blocks):
        block_start = prompt_len + block_idx * block_length
        block_end = block_start + block_length
        # Plain Python ints, so reading the schedule does not force a device sync each step.
        schedule = _transfer_schedule(x[:, block_start:block_end] == mask_id, steps_per_block).tolist()

        for step in range(steps_per_block):
            _cuda_sync(x)  # don't charge previously queued work to this step
            step_start = time.perf_counter()

            mask_index = x == mask_id

            # 1. Model forward pass (optionally with classifier-free guidance).
            if cfg_scale > 0:
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                cond_logits, uncond_logits = torch.chunk(model(torch.cat([x, un_x], dim=0)).logits, 2, dim=0)
                logits = uncond_logits + (cfg_scale + 1) * (cond_logits - uncond_logits)
            else:
                logits = model(x).logits

            # 2. Candidate token for every position.
            if temperature > 0:
                # float32: multinomial sampling over a low-precision softmax is fragile.
                probs = torch.softmax(logits.float() / temperature, dim=-1)
                x0 = torch.multinomial(probs.reshape(-1, probs.shape[-1]), 1).reshape(x.shape)
            else:
                x0 = torch.argmax(logits, dim=-1)

            # 3. Score candidates; only masked positions inside the current block are eligible.
            if remasking == 'low_confidence':
                x0_p = torch.gather(torch.softmax(logits, dim=-1), -1, x0.unsqueeze(-1)).squeeze(-1)
            else:
                x0_p = torch.rand(x0.shape, device=x0.device)
            x0_p[:, block_end:] = neg_inf  # never commit beyond the current block
            confidence = x0_p.masked_fill(~mask_index, neg_inf)
            x0 = torch.where(mask_index, x0, x)  # keep already-decided tokens unchanged

            # 4. Commit the top-k candidates for this step (k can differ per row).
            transfer_index = torch.zeros_like(mask_index)
            for row in range(batch):
                k = schedule[row][step]
                if k > 0:
                    top = torch.topk(confidence[row], k=k).indices
                    transfer_index[row, top] = True
            x = torch.where(transfer_index, x0, x)

            _cuda_sync(x)  # CUDA is asynchronous: wait so the timer covers the real work
            step_duration = time.perf_counter() - step_start

            # Bookkeeping happens after the timer stops, so it is not counted in `duration`.
            step_details.append({
                "block_idx": block_idx,
                "step_idx": step,
                "duration": step_duration,
                "remasked_count": int((x[:, block_start:block_end] == mask_id).sum().item()),
                "sampled_count": int(transfer_index.sum().item()),
                "total_generated_so_far": int((x[:, prompt_len:] != mask_id).sum().item()),
            })

    return x, step_details

def run_inference(model, tokenizer, prompt_text, steps=64, gen_length=64, block_length=32):
    """
    Generate a completion for one prompt and measure its end-to-end wall time.

    The prompt is wrapped as a single-turn user chat message, tokenized without
    extra special tokens, moved to the global ``device``, and passed to
    ``profiled_generate``.

    Args:
        model: model forwarded to ``profiled_generate``.
        tokenizer: tokenizer providing ``apply_chat_template``, ``__call__`` and
            ``batch_decode``.
        prompt_text: raw user prompt text.
        steps, gen_length, block_length: forwarded to ``profiled_generate``.

    Returns:
        (generated_text, wall_time, step_details):
        - generated_text: the decoded completion, with the prompt stripped and
          special tokens skipped.
        - wall_time: elapsed seconds around the generate call.
        - step_details: the per-step records from ``profiled_generate``.
    """
    # Prepare input
    messages = [{"role": "user", "content": prompt_text}]
    formatted_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

    inputs = tokenizer(formatted_prompt, return_tensors="pt", padding=True, add_special_tokens=False)
    input_ids = inputs['input_ids'].to(device)

    def _sync():
        # CUDA kernels run asynchronously; synchronize so the timer brackets the
        # actual GPU work instead of relying on syncs inside the callee.
        if input_ids.is_cuda:
            torch.cuda.synchronize(input_ids.device)

    # Measure time
    _sync()  # ensure the host->device copy of the prompt is not timed
    start_time = time.perf_counter()

    with torch.no_grad():
        out, step_details = profiled_generate(
            model,
            input_ids,
            steps=steps,
            gen_length=gen_length,
            block_length=block_length,
            temperature=0.,
            cfg_scale=0.,
            remasking='low_confidence'
        )

    _sync()
    wall_time = time.perf_counter() - start_time

    # Decode only the generated suffix (everything after the prompt tokens).
    generated_text = tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True)[0]

    return generated_text, wall_time, step_details

In [None]:
# Run Inference on a Subset
num_samples = 5
# Untimed warm-up runs: the first generation typically pays one-off CUDA
# initialisation costs that would otherwise skew the first sample's latency.
num_warmup = 1
# Using parameters from chat.py/generate.py examples
gen_kwargs = dict(steps=128, gen_length=128, block_length=32)
results = []
all_step_details = []

if num_warmup > 0 and len(ds) > 0:
    print(f"Warming up ({num_warmup} run(s))...")
    try:
        for _ in range(num_warmup):
            run_inference(model=model, tokenizer=tokenizer, prompt_text=ds[0]['prompt'], **gen_kwargs)
    except Exception as e:
        # Best-effort: a failed warm-up should not prevent the measured runs.
        print(f"  Warm-up error: {e}")

print(f"Running inference on first {num_samples} samples...")

for i in range(min(num_samples, len(ds))):
    problem = ds[i]
    prompt = problem['prompt']
    task_id = problem['task_id']

    print(f"Processing {task_id}...")

    try:
        # run_inference takes `prompt_text` and returns (text, wall_time, step_details).
        output, duration, step_details = run_inference(
            model=model,
            tokenizer=tokenizer,
            prompt_text=prompt,
            **gen_kwargs,
        )
    except Exception as e:
        # Best-effort: report the failure and continue with the remaining problems.
        import traceback
        print(f"  Error: {e}")
        traceback.print_exc()
        continue

    results.append({
        "task_id": task_id,
        "wall_time": duration,
        "num_steps": len(step_details),
        "output_length": len(output),
        "output": output
    })

    # Add task_id to each step detail
    for d in step_details:
        d['task_id'] = task_id
    all_step_details.extend(step_details)

    print(f"  Time: {duration:.4f}s")

df_results = pd.DataFrame(results)
df_steps = pd.DataFrame(all_step_details)

In [None]:
# Calculate Statistics: summary of per-sample wall times (guarding the empty case first).
if df_results.empty:
    print("No results to analyze.")
else:
    wall_times = df_results['wall_time']
    stats = {
        "Mean Latency": wall_times.mean(),
        "Median Latency": wall_times.median(),
        "Std Dev": wall_times.std(),
        "Min": wall_times.min(),
        "Max": wall_times.max(),
        "P95": wall_times.quantile(0.95),
        "P99": wall_times.quantile(0.99),
    }

    print("Wall Time Statistics (seconds):")
    for label, seconds in stats.items():
        print(f"{label}: {seconds:.4f}")


In [None]:
# Visualize Latency: histogram of per-sample wall times with the mean marked.
if not df_results.empty:
    wall = df_results['wall_time']
    mean_wall = wall.mean()
    plt.figure(figsize=(10, 6))
    sns.histplot(wall, kde=True, bins=10)
    plt.title('Inference Wall Time Distribution')
    plt.xlabel('Time (seconds)')
    plt.ylabel('Count')
    plt.axvline(mean_wall, color='r', linestyle='--', label=f"Mean: {mean_wall:.2f}s")
    plt.legend()
    plt.show()


# Display Step Statistics: a preview table, then one line plot per step metric, coloured by block.
if not df_steps.empty:
    print("\nStep-level Statistics (First 5 steps):")
    print(df_steps.head())

    step_plot_specs = (
        ('duration', 'o', 'Inference Time per Step across Blocks', 'Time (seconds)'),
        ('remasked_count', 'x', 'Tokens Remasked per Step', 'Count of Remasked Tokens'),
    )
    for metric, point_marker, plot_title, y_label in step_plot_specs:
        plt.figure(figsize=(12, 5))
        sns.lineplot(data=df_steps, x='step_idx', y=metric, hue='block_idx', marker=point_marker)
        plt.title(plot_title)
        plt.xlabel('Step Index')
        plt.ylabel(y_label)
        plt.grid(True)
        plt.show()