In [1]:
import torch
import numpy as np
import pandas as pd
import time
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
import gc # For garbage collection
from vllm import LLM, SamplingParams

# --- Configuration ---
# Checkpoint identifiers handed to the tokenizer/model loaders.
LLADA_MODEL_NAME = "GSAI-ML/LLaDA-8B-Instruct"
LLAMA3_MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
# Prefer the GPU whenever PyTorch reports one; otherwise stay on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- Load Tokenizers ---
print("Loading tokenizers...")
llada_tokenizer = AutoTokenizer.from_pretrained(
    LLADA_MODEL_NAME,
    trust_remote_code=True,
)
llama3_tokenizer = AutoTokenizer.from_pretrained(LLAMA3_MODEL_NAME)

# When the Llama3 tokenizer defines no padding token, reuse its EOS token.
# Only the tokenizer is touched here; no model config is modified in this cell.
if llama3_tokenizer.pad_token is None:
    llama3_tokenizer.pad_token = llama3_tokenizer.eos_token

print("Tokenizers loaded. Models will be loaded on-demand to save memory.")
print(f"Using device: {DEVICE}")


INFO 08-12 11:31:22 [__init__.py:256] Automatically detected platform cuda.
Loading tokenizers...
Tokenizers loaded. Models will be loaded on-demand to save memory.
Using device: cuda


In [2]:
@torch.no_grad()
def get_num_transfer_tokens(mask_index, steps):
    """
    Computes the number of tokens to transition at each step of the diffusion process.

    Each row's masked-token count is split as evenly as possible over `steps`:
    every step gets `count // steps` tokens and the first `count % steps` steps
    get one extra, so each row of the result sums to that row's mask count.
    Adapted from the helper in the generation.ipynb notebook (vectorized here).

    Args:
        mask_index: 2-D tensor whose row-wise sum is the number of masked
            positions (summed over dim=1).
        steps: number of steps to spread the masked tokens over.

    Returns:
        Tensor of shape (mask_index.size(0), steps) on mask_index's device
        (int64 for boolean/integer masks).
    """
    mask_num = mask_index.sum(dim=1, keepdim=True)
    base = mask_num // steps
    remainder = mask_num % steps

    # Broadcast (1, steps) step indices against the (rows, 1) remainders to mark
    # the leading steps that take one extra token, instead of looping over rows
    # and slicing with a tensor (which costs a host sync per row on GPU).
    step_ids = torch.arange(steps, device=mask_index.device).unsqueeze(0)
    extra = (step_ids < remainder).to(torch.int64)
    return base + extra

# Last-resort mask-token id, used only when the tokenizer exposes no mask token.
# NOTE(review): 126336 is the value used by the reference LLaDA generation code;
# confirm it against the checkpoint actually being loaded.
_LLADA_FALLBACK_MASK_ID = 126336


def _resolve_mask_id(tokenizer, mask_id=None):
    """Pick the mask id: explicit argument, then tokenizer lookups, then the fallback."""
    if mask_id is not None:
        return mask_id
    vocab_id = tokenizer.get_vocab().get("[MASK]")
    if vocab_id is not None:
        return vocab_id
    tokenizer_mask_id = getattr(tokenizer, "mask_token_id", None)
    if tokenizer_mask_id is not None:
        return tokenizer_mask_id
    return _LLADA_FALLBACK_MASK_ID


@torch.no_grad()
def llada_generate_simple(model, tokenizer, prompt_text, gen_length=512, block_length=128,
                          steps=4, mask_id=None):
    """
    A simplified LLaDA generation function using greedy, confidence-ordered unmasking.
    It does not include advanced sampling like Gumbel noise, CFG, or temperature.

    The prompt is followed by `gen_length` mask tokens. Each step runs the model
    on the whole sequence, takes the argmax token at every position, and commits
    the masked positions with the highest softmax probability. Any positions
    still masked after `steps` iterations are filled by one final greedy pass.

    Args:
        model: callable returning an object with a `.logits` tensor for `model(x)`.
        tokenizer: tokenizer used for the chat template, encoding and decoding
            (also used to look up the mask id).
        prompt_text: user message to wrap in the chat template.
        gen_length: number of masked positions appended after the prompt.
        block_length: accepted for interface compatibility; this simplified
            sampler does not use it.
        steps: number of unmasking iterations (default 4, kept low for the
            speed benchmark).
        mask_id: explicit mask-token id; resolved from the tokenizer when None.

    Returns:
        The decoded text of the generated region (special tokens skipped).

    Raises:
        ValueError: if `steps` is smaller than 1.
    """
    if steps < 1:
        raise ValueError("steps must be >= 1")

    # Use the tokenizer that was passed in rather than a notebook-level global.
    mask_id = _resolve_mask_id(tokenizer, mask_id)

    # Format prompt
    messages = [{"role": "user", "content": prompt_text}]
    formatted_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    input_ids = tokenizer(formatted_prompt, return_tensors="pt")['input_ids'].to(DEVICE)
    prompt_len = input_ids.shape[1]

    # Initialize sequence with masks
    x = torch.full((1, prompt_len + gen_length), mask_id, dtype=torch.long, device=DEVICE)
    x[:, :prompt_len] = input_ids

    # Per-step quota, computed once from the initial number of masks.
    num_to_reveal = int((x == mask_id).sum()) // steps + 1

    for _ in range(steps):
        mask_index = (x == mask_id)
        remaining = int(mask_index.sum())
        if remaining == 0:
            break

        logits = model(x).logits

        # Simple greedy decoding
        x0 = torch.argmax(logits, dim=-1)

        # Confidence of the greedy token = its softmax probability.
        p = torch.softmax(logits, dim=-1)
        x0_p = torch.gather(p, dim=-1, index=x0.unsqueeze(-1)).squeeze(-1)

        # Positions that are not masked get -inf so they rank last.
        confidence = torch.where(mask_index, x0_p, -torch.inf)

        # Never ask for more positions than are still masked: otherwise topk
        # would return -inf entries and overwrite prompt / already-filled tokens.
        k = min(num_to_reveal, remaining)
        _, top_indices = torch.topk(confidence.view(-1), k=k)

        x.view(-1)[top_indices] = x0.view(-1)[top_indices]

    # Final greedy fill for any remaining masks
    mask_index = (x == mask_id)
    if mask_index.any():
        logits = model(x).logits
        x0 = torch.argmax(logits, dim=-1)
        x = torch.where(mask_index, x0, x)

    return tokenizer.batch_decode(x[:, prompt_len:], skip_special_tokens=True)[0]

print("LLaDA simple generation function loaded.")


LLaDA simple generation function loaded.


In [3]:
@torch.no_grad()
def llama3_generate(model, tokenizer, prompt_text, gen_length=512):
    """
    Standard autoregressive generation for Llama3.

    Args:
        model: model exposing a `generate` method.
        tokenizer: tokenizer providing the chat template and decoding.
        prompt_text: user message to wrap in the chat template.
        gen_length: upper bound on newly generated tokens (`max_new_tokens`).

    Returns:
        The decoded continuation only (prompt tokens removed, special tokens skipped).
    """
    messages = [{"role": "user", "content": prompt_text}]
    # return_dict=True yields input_ids together with the attention mask, so
    # generate() does not have to guess the mask while pad_token_id == eos_token_id.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(DEVICE)
    input_ids = inputs["input_ids"]

    # Generate text
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=inputs["attention_mask"],
        max_new_tokens=gen_length,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=False,  # Use greedy decoding for speed comparison
    )

    # Drop the prompt tokens and decode only the continuation.
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True)

print("Llama3 generation function loaded.")


Llama3 generation function loaded.


In [4]:
def vllm_generate(model, tokenizer, prompt_text, gen_length=512):
    """
    Generates text using the vLLM engine.

    Args:
        model: vLLM engine object exposing `generate(prompt, sampling_params)`.
        tokenizer: tokenizer used to render the chat template and tokenize it.
        prompt_text: user message to wrap in the chat template.
        gen_length: maximum number of tokens to generate (`max_tokens`).

    Returns:
        The text of the first completion of the first request.
    """
    messages = [{"role": "user", "content": prompt_text}]
    prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

    # The rendered template already contains the special tokens. Handing the raw
    # text to vLLM would have it add special tokens again (duplicate BOS), so
    # tokenize here without extra specials and pass the ids instead.
    prompt_token_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]

    sampling_params = SamplingParams(
        n=1,
        temperature=0.0,  # greedy decoding
        max_tokens=gen_length,
    )

    # Generate text
    outputs = model.generate({"prompt_token_ids": prompt_token_ids}, sampling_params)

    # Return the generated text from the first output
    return outputs[0].outputs[0].text

print("vLLM generation function loaded.")


vLLM generation function loaded.


In [5]:
# --- Benchmarking Framework ---
PROMPT_TEXT = "Explain the theory of general relativity in a few paragraphs."
GENERATION_LENGTHS = [512, 768, 1024]
NUM_TRIALS = 3  # Number of times to run each test to get an average


def _sync_device():
    """Wait for queued CUDA work to finish; does nothing when CUDA is unavailable."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def run_benchmark(model_name, gen_function, tokenizer, gen_length, **kwargs):
    """
    Runs the generation benchmark for a given model and returns timing statistics.

    One untimed warm-up call is followed by NUM_TRIALS timed calls of
    `gen_function` on PROMPT_TEXT.

    Args:
        model_name: label stored in the result under "Model".
        gen_function: callable invoked as
            `gen_function(model=..., tokenizer=..., prompt_text=..., gen_length=..., **func_args)`
            and returning the generated text.
        tokenizer: passed to `gen_function` and used to count generated tokens.
        gen_length: requested generation length, forwarded to `gen_function`.
        **kwargs: `model` (forwarded to `gen_function`) and optional `func_args`
            (dict of extra keyword arguments for `gen_function`).

    Returns:
        dict with keys "Model", "Gen Length", "Avg Time (s)", "Avg Tokens"
        and "Tokens/Sec" (average tokens divided by average time).
    """
    model = kwargs.get('model')
    func_args = kwargs.get('func_args') or {}

    def generate():
        # Single place for the call so warm-up and timed trials stay identical.
        return gen_function(
            model=model,
            tokenizer=tokenizer,
            prompt_text=PROMPT_TEXT,
            gen_length=gen_length,
            **func_args
        )

    total_time = 0
    total_tokens = 0

    # Warm-up run
    print(f"  Warm-up run for {gen_length} tokens...")
    generate()
    _sync_device()

    # Timed trials
    for i in range(NUM_TRIALS):
        print(f"  Trial {i+1}/{NUM_TRIALS} for {gen_length} tokens...")
        # perf_counter is monotonic and intended for measuring short durations.
        start_time = time.perf_counter()

        generated_text = generate()

        # Make sure asynchronous GPU work is included in the measurement.
        _sync_device()
        elapsed_time = time.perf_counter() - start_time

        # Count only the generated text; do not let encode() add BOS/EOS tokens.
        num_tokens = len(tokenizer.encode(generated_text, add_special_tokens=False))

        total_time += elapsed_time
        total_tokens += num_tokens

    avg_time = total_time / NUM_TRIALS
    avg_tokens = total_tokens / NUM_TRIALS
    tokens_per_sec = avg_tokens / avg_time if avg_time > 0 else 0

    return {
        "Model": model_name,
        "Gen Length": gen_length,
        "Avg Time (s)": avg_time,
        "Avg Tokens": avg_tokens,
        "Tokens/Sec": tokens_per_sec
    }

print("Benchmarking framework loaded.")


Benchmarking framework loaded.


In [None]:
# --- Run Benchmarks ---
# Fresh result list; the later benchmark cell appends to it.
results_list = []

# --- LLaDA Benchmark ---
print(f"Loading LLaDA model: {LLADA_MODEL_NAME}...")
llada_model = AutoModel.from_pretrained(
    LLADA_MODEL_NAME, 
    trust_remote_code=True, 
    torch_dtype=torch.bfloat16
).to(DEVICE).eval()
print("LLaDA model loaded.")

llada_params = {'block_length': 128}
try:
    for length in GENERATION_LENGTHS:
        # "\n" (not "\\n") so a real blank line is printed before each banner.
        print(f"\n{'='*20} Benchmarking LLaDA for {length} tokens {'='*20}")
        llada_results = run_benchmark(
            "LLaDA-8B",
            llada_generate_simple,
            llada_tokenizer,
            length,
            model=llada_model,
            func_args=llada_params
        )
        results_list.append(llada_results)
finally:
    # Drop the model reference even if a benchmark run raised, so the memory
    # is not held while the next model is loaded.
    print("\nClearing LLaDA model from memory...")
    del llada_model
    gc.collect()
    torch.cuda.empty_cache()
    print("LLaDA model cleared.")

Loading LLaDA model: GSAI-ML/LLaDA-8B-Instruct...


Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

LLaDA model loaded.
  Warm-up run for 512 tokens...
  Trial 1/3 for 512 tokens...
  Trial 2/3 for 512 tokens...
  Trial 3/3 for 512 tokens...
  Warm-up run for 768 tokens...
  Trial 1/3 for 768 tokens...
  Trial 2/3 for 768 tokens...
  Trial 3/3 for 768 tokens...
  Warm-up run for 1024 tokens...
  Trial 1/3 for 1024 tokens...
  Trial 2/3 for 1024 tokens...
  Trial 3/3 for 1024 tokens...
\nClearing LLaDA model from memory...
LLaDA model cleared.


: 

In [None]:
# --- Llama3 vLLM Benchmark ---
# "\n" (not "\\n") so a real newline is printed instead of a literal backslash-n.
print(f"\nLoading Llama3 model with vLLM: {LLAMA3_MODEL_NAME}...")
# vLLM handles device placement automatically
# NOTE(review): the recorded output of this cell shows the engine failing during
# initialization; the traceback is truncated, so the cause is not visible here.
llama3_model = LLM(model=LLAMA3_MODEL_NAME, trust_remote_code=True)
print("Llama3 vLLM model loaded.")

try:
    for length in GENERATION_LENGTHS:
        print(f"\n{'='*20} Benchmarking Llama3 (vLLM) for {length} tokens {'='*20}")
        llama3_results = run_benchmark(
            "Llama-3-8B (vLLM)",
            vllm_generate,
            llama3_tokenizer,
            length,
            model=llama3_model
        )
        results_list.append(llama3_results)
finally:
    # Drop the engine reference even if a benchmark run raised.
    print("\nClearing Llama3 model from memory...")
    del llama3_model
    gc.collect()
    torch.cuda.empty_cache()
    print("Llama3 model cleared.")

# Convert results to a DataFrame for easy analysis
results_df = pd.DataFrame(results_list)

print("\nBenchmark Complete!")
results_df

\nLoading Llama3 model with vLLM: meta-llama/Meta-Llama-3-8B-Instruct...
INFO 08-12 11:32:13 [config.py:583] This model supports multiple tasks: {'classify', 'score', 'reward', 'embed', 'generate'}. Defaulting to 'generate'.
INFO 08-12 11:32:13 [config.py:1693] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 08-12 11:32:25 [__init__.py:256] Automatically detected platform cuda.
INFO 08-12 11:32:31 [core.py:53] Initializing a V1 LLM engine (v0.8.0) with config: model='meta-llama/Meta-Llama-3-8B-Instruct', speculative_config=None, tokenizer='meta-llama/Meta-Llama-3-8B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, deco

Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:02,  1.23it/s]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:01<00:01,  1.28it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:01<00:00,  1.85it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.68it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.60it/s]



INFO 08-12 11:32:50 [loader.py:429] Loading weights took 2.61 seconds
INFO 08-12 11:32:50 [gpu_model_runner.py:1140] Model loading took 14.9595 GB and 3.535962 seconds
ERROR 08-12 11:32:59 [core.py:340] EngineCore hit an exception: Traceback (most recent call last):
ERROR 08-12 11:32:59 [core.py:340]   File "/home/mahan/.venvs/default/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 332, in run_engine_core
ERROR 08-12 11:32:59 [core.py:340]     engine_core = EngineCoreProc(*args, **kwargs)
ERROR 08-12 11:32:59 [core.py:340]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 08-12 11:32:59 [core.py:340]   File "/home/mahan/.venvs/default/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 287, in __init__
ERROR 08-12 11:32:59 [core.py:340]     super().__init__(vllm_config, executor_class, log_stats)
ERROR 08-12 11:32:59 [core.py:340]   File "/home/mahan/.venvs/default/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 62, in __init__
ERROR 08-12 11:32:59 [c

In [None]:
# --- Visualize Results ---
fig, ax = plt.subplots(figsize=(12, 7))

# Reshape to one row per generation length and one column per model, then draw
# grouped bars of throughput on the shared axes.
pivot_df = results_df.pivot(index='Gen Length', columns='Model', values='Tokens/Sec')
pivot_df.plot.bar(ax=ax, width=0.4)

# Title and axis labels
ax.set_title('LLaDA-8B vs. Llama-3-8B Generation Speed', fontsize=16)
ax.set_xlabel('Generation Length (Number of Tokens)', fontsize=12)
ax.set_ylabel('Tokens per Second', fontsize=12)

# Cosmetics: horizontal tick labels, dashed horizontal grid, titled legend
ax.tick_params(axis='x', rotation=0)
ax.grid(axis='y', linestyle='--', alpha=0.7)
ax.legend(title='Model')

# Print each bar's value (two decimals) just above it
_bar_label_kwargs = dict(fmt='%.2f', label_type='edge', fontsize=10, padding=3)
for container in ax.containers:
    ax.bar_label(container, **_bar_label_kwargs)

plt.tight_layout()
plt.show()
