In [None]:
import torch
import numpy as np
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

model_name = "GSAI-ML/LLaDA-8B-Instruct"

# NOTE: trust_remote_code=True executes Python shipped with the Hub repository;
# pin a `revision=` if you need the loaded code to be reproducible/audited.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True, torch_dtype=torch.bfloat16)

# Move model to GPU if available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device).eval()
print(f"Model loaded on {device}")

# Seed torch's RNG (all devices) *after* loading, so weight loading cannot shift the
# stream. The stochastic demos below (temperature > 0 Gumbel noise, 'random'
# remasking) then reproduce under Restart & Run All. Change SEED for new samples.
SEED = 0
torch.manual_seed(SEED)

In [None]:
def add_gumbel_noise(logits, temperature):
    '''
    Perturb logits so that a subsequent argmax draws a categorical sample (Gumbel-max trick).

    For temperature T > 0 this returns exp(l) / (-log u)^T with u ~ U[0, 1). Its log is
    T * (l / T + g) with g = -log(-log u) ~ Gumbel(0, 1), so the argmax matches sampling
    from softmax(l / T). The result is an unnormalized score; only its argmax is meaningful.

    Everything is computed in float64: arXiv:2409.02908 reports that low-precision Gumbel-max
    sampling in masked diffusion models lowers perplexity but hurts generation quality.

    Args:
        logits: Tensor of logits (any shape).
        temperature: Sampling temperature; 0 disables noise.

    Returns:
        ``logits`` itself when temperature == 0, otherwise a float64 tensor of the same shape.
    '''
    if temperature == 0:
        # Greedy decoding: the argmax of the untouched logits is wanted.
        return logits

    logits64 = logits.to(torch.float64)
    uniform = torch.rand_like(logits64)
    return torch.exp(logits64) / torch.pow(-torch.log(uniform), temperature)


def get_num_transfer_tokens(mask_index, steps):
    '''
    Precompute how many masked tokens to unmask at each of ``steps`` reverse-process steps.

    LLaDA's linear noise schedule discretizes [0, 1] into ``steps`` equal intervals, so the
    expected number of tokens transitioned per step is constant. Row ``b`` with ``n_b``
    masked tokens receives ``n_b // steps`` tokens every step, plus one extra token in each
    of the first ``n_b % steps`` steps, so every row sums exactly to ``n_b``.

    Args:
        mask_index: Boolean tensor of shape (B, L); True marks a masked position.
        steps: Number of steps to spread the masked tokens over.

    Returns:
        int64 tensor of shape (B, steps) on ``mask_index.device``.
    '''
    masked_per_row = mask_index.sum(dim=1, keepdim=True)        # (B, 1)
    per_step = masked_per_row // steps                           # (B, 1)
    leftover = masked_per_row % steps                            # (B, 1)

    step_ids = torch.arange(steps, device=mask_index.device)    # (steps,)
    # Broadcast (steps,) against (B, 1): True for the first `leftover` steps of each row.
    bonus = (step_ids < leftover).to(torch.int64)                # (B, steps)

    return per_step + bonus

print("Core LLaDA functions loaded: add_gumbel_noise, get_num_transfer_tokens")


In [None]:
import torch.nn.functional as F
import numpy as np
from typing import List, Optional, Dict, Tuple
import random

# Constants for LLaDA model
MASK_ID = 126336  # The token id of [MASK] in LLaDA tokenizer


@torch.no_grad()
def generate(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0.,
             cfg_scale=0., remasking='low_confidence', mask_id=MASK_ID):
    '''
    Sample a completion with LLaDA's reverse (unmasking) diffusion process.

    The generated region starts fully masked and is split into ``gen_length // block_length``
    blocks, which are denoised left to right. At each step the model predicts every masked
    position. Then, per row, the ``num_transfer_tokens`` most confident predictions inside
    the current block (or earlier) are committed.

    Args:
        model: Mask predictor. Called as ``model(x).logits`` and must expose ``.device``.
        prompt: LongTensor of shape (B, L) with prompt token ids. All rows share length L.
            No attention mask is used, so any padding is treated as ordinary prompt tokens.
        steps: Total sampling steps over all blocks. Must be positive and divisible by the
            number of blocks, and should be <= gen_length. If a block gets more steps than it
            has masked tokens, the surplus steps commit zero tokens and are skipped without
            a forward pass.
        gen_length: Generated answer length; must be a positive multiple of block_length.
        block_length: Block length, <= gen_length. If less than gen_length, blocks are
            generated semi-autoregressively.
        temperature: Gumbel-max sampling temperature; 0 means greedy argmax.
        cfg_scale: Unsupervised classifier-free guidance scale. When > 0, an unconditional
            copy with the prompt masked is run alongside, doubling the batch sent to ``model``.
        remasking: Confidence used to choose which predictions to commit:
            'low_confidence' (softmax probability of the predicted token) or 'random'.
        mask_id: Token id of [MASK] (126336 for LLaDA).

    Returns:
        LongTensor of shape (B, L + gen_length): the prompt followed by generated tokens.

    Raises:
        NotImplementedError: unknown ``remasking`` (raised before any forward pass).
        AssertionError: inconsistent ``steps`` / ``gen_length`` / ``block_length``.
    '''
    if remasking not in ('low_confidence', 'random'):
        raise NotImplementedError(remasking)
    assert block_length > 0, f"block_length must be positive, got {block_length}"
    assert gen_length > 0 and gen_length % block_length == 0, (
        f"gen_length ({gen_length}) must be a positive multiple of block_length ({block_length})")
    num_blocks = gen_length // block_length
    assert steps > 0 and steps % num_blocks == 0, (
        f"steps ({steps}) must be a positive multiple of the number of blocks ({num_blocks})")
    steps_per_block = steps // num_blocks

    batch_size, prompt_len = prompt.shape
    x = torch.full((batch_size, prompt_len + gen_length), mask_id, dtype=torch.long, device=model.device)
    x[:, :prompt_len] = prompt

    prompt_index = (x != mask_id)

    for num_block in range(num_blocks):
        block_start = prompt_len + num_block * block_length
        block_end = block_start + block_length
        block_mask_index = (x[:, block_start:block_end] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block)
        for i in range(steps_per_block):
            # A step that commits no token in any row cannot change x, so skip its forward pass.
            if not num_transfer_tokens[:, i].any():
                continue

            mask_index = (x == mask_id)
            if cfg_scale > 0.:
                # Unconditional branch: the same sequence with the prompt masked out.
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                logits = model(torch.cat([x, un_x], dim=0)).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x).logits

            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)  # (B, L + gen_length)

            if remasking == 'low_confidence':
                p = F.softmax(logits, dim=-1)
                x0_p = torch.squeeze(
                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # (B, L + gen_length)
            else:  # 'random'
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)

            # Never commit tokens in blocks to the right of the current one.
            x0_p[:, block_end:] = -np.inf

            # Only masked positions are candidates; unmasked ones keep their current token.
            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=int(num_transfer_tokens[j, i]))
                transfer_index[j, select_index] = True
            x[transfer_index] = x0[transfer_index]

    return x


print("LLaDA generate function loaded!")


In [None]:
# Sample prompts for testing different LLaDA strategies.
# Phrased as user messages; they are wrapped in the chat template before generation.
sample_prompts = [
    "What do you think the future of artificial intelligence will look like?",
    "Can you write me a short science fiction story about space exploration?",
    "I'm struggling to understand quantum computing. Can you explain it in simple terms?",
    "Could you help me write a creative story about a robot learning to feel emotions?",
    "What are some practical solutions we could implement to address climate change?",
    "How does machine learning actually work under the hood?",
    "I need help planning a healthy meal prep routine for the week",
    "What's the best way to learn a new programming language as a beginner?",
]

print("Sample prompts loaded:")
print("\n".join(f"{number}. {text}" for number, text in enumerate(sample_prompts, 1)))


In [None]:
def llada_generate_with_chat_template(prompt_text: str, **kwargs):
    """
    Wrap ``prompt_text`` as a single user turn, run LLaDA ``generate`` on it, and decode.

    Relies on the notebook-level ``tokenizer``, ``model`` and ``device``.

    Args:
        prompt_text: Content of the user message.
        **kwargs: Forwarded unchanged to ``generate`` (steps, gen_length, block_length,
            temperature, cfg_scale, remasking, mask_id).

    Returns:
        dict with
            'input': the chat-formatted prompt string that was tokenized,
            'output': decoded text of the generated positions only,
            'full_sequence': decoded text of prompt + generation.
        Both decodes use ``skip_special_tokens=True``.
    """
    chat = [{"role": "user", "content": prompt_text}]
    templated = tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)

    prompt_ids = tokenizer(templated, return_tensors="pt")['input_ids'].to(device)
    prompt_len = prompt_ids.shape[1]

    sequence = generate(model, prompt_ids, **kwargs)

    # Slice off the prompt positions so 'output' holds only the model's answer.
    completion_text = tokenizer.batch_decode(sequence[:, prompt_len:], skip_special_tokens=True)[0]
    full_text = tokenizer.batch_decode(sequence, skip_special_tokens=True)[0]

    return {'input': templated, 'output': completion_text, 'full_sequence': full_text}

# Test basic generation: 2 blocks of 32 tokens, 64 steps in total (one token per step).
test_prompt = sample_prompts[0]
print(f"Testing prompt: '{test_prompt}'")

basic_generation_params = dict(
    steps=64,
    gen_length=64,
    block_length=32,
    temperature=0.0,
    remasking='low_confidence',
)
result = llada_generate_with_chat_template(test_prompt, **basic_generation_params)

print(f"\nGenerated: {result['output']}")


In [None]:
# Explore different LLaDA sampling strategies
def explore_llada_strategies(prompt_text: str):
    """
    Generate ``prompt_text`` under several LLaDA sampling configurations and print a
    120-character preview of each output.

    Every configuration keeps ``steps <= gen_length``, as ``generate`` documents. With
    ``steps > gen_length``, each block gets more steps than masked tokens. The surplus
    steps commit zero tokens and only cost forward passes.

    Args:
        prompt_text: User message, wrapped in the chat template by
            ``llada_generate_with_chat_template``.

    Returns:
        dict mapping each strategy name to its generated text, or to
        ``"Error: <ExceptionType>: <message>"`` if that strategy raised.
    """
    print(f"\n{'='*80}")
    print(f"EXPLORING LLaDA STRATEGIES FOR: '{prompt_text}'")
    print(f"{'='*80}")

    strategies = [
        {
            'name': 'Deterministic (temp=0.0)',
            'params': {'steps': 64, 'gen_length': 64, 'block_length': 32, 'temperature': 0.0, 'remasking': 'low_confidence'}
        },
        {
            'name': 'Low Temperature (temp=0.5)',
            'params': {'steps': 64, 'gen_length': 64, 'block_length': 16, 'temperature': 0.5, 'remasking': 'low_confidence'}
        },
        {
            'name': 'Random Remasking',
            'params': {'steps': 64, 'gen_length': 64, 'block_length': 32, 'temperature': 0.0, 'remasking': 'random'}
        },
        {
            'name': 'Semi-Autoregressive (small blocks)',
            'params': {'steps': 32, 'gen_length': 64, 'block_length': 8, 'temperature': 0.0, 'remasking': 'low_confidence'}
        },
        {
            'name': 'With CFG Guidance',
            'params': {'steps': 64, 'gen_length': 64, 'block_length': 32, 'temperature': 0.0, 'cfg_scale': 1.5, 'remasking': 'low_confidence'}
        }
    ]

    results = {}
    preview_chars = 120

    for strategy in strategies:
        print(f"\n🔸 {strategy['name']}")
        try:
            result = llada_generate_with_chat_template(prompt_text, **strategy['params'])
            output = result['output']
            results[strategy['name']] = output
            print(f"✅ Output: {output[:preview_chars]}{'...' if len(output) > preview_chars else ''}")
        except Exception as e:
            # str() of an exception raised without a message (e.g. a bare `assert`) is empty;
            # prefix the type so the failure stays identifiable.
            error = f"{type(e).__name__}: {e}"
            print(f"❌ Error: {error}")
            results[strategy['name']] = f"Error: {error}"

    return results

# Test with different prompts: run the strategy sweep on the first two sample prompts.
for demo_prompt in sample_prompts[:2]:
    explore_llada_strategies(demo_prompt)


In [None]:
# Parameter Analysis: Understanding LLaDA's key parameters
def analyze_block_length_effect():
    """
    Print a preview of the same prompt generated with block_length in {8, 16, 32, 64}.

    ``steps`` and ``gen_length`` are both 64, so every configuration unmasks one token
    per step. Only the block partition changes, from 8 blocks of 8 tokens to 1 block of 64.
    """
    print("\n" + "="*60)
    print("BLOCK LENGTH ANALYSIS")
    print("="*60)

    prompt = "Explain how machine learning works"
    block_lengths = [8, 16, 32, 64]
    preview_chars = 100

    for block_length in block_lengths:
        print(f"\n🔹 Block Length: {block_length}")
        try:
            result = llada_generate_with_chat_template(
                prompt,
                steps=64,
                gen_length=64,
                block_length=block_length,
                temperature=0.0,
                remasking='low_confidence'
            )
            output = result['output']
            # Append an ellipsis only when the preview actually truncates the text.
            print(f"Output: {output[:preview_chars]}{'...' if len(output) > preview_chars else ''}")
        except Exception as e:
            # Include the type: str() of a message-less exception is empty.
            print(f"Error: {type(e).__name__}: {e}")

def analyze_temperature_effect():
    """
    Print a preview of the same prompt generated at temperatures 0.0, 0.3, 0.7 and 1.0.

    Temperature 0.0 is greedy. Higher values add Gumbel noise via ``add_gumbel_noise``,
    so their outputs depend on torch's global RNG state at call time.
    """
    print("\n" + "="*60)
    print("TEMPERATURE ANALYSIS")
    print("="*60)

    prompt = "Write a creative story about"
    temperatures = [0.0, 0.3, 0.7, 1.0]
    preview_chars = 100

    for temp in temperatures:
        print(f"\n🔹 Temperature: {temp}")
        try:
            result = llada_generate_with_chat_template(
                prompt,
                steps=64,
                gen_length=64,
                block_length=32,
                temperature=temp,
                remasking='low_confidence'
            )
            output = result['output']
            # Append an ellipsis only when the preview actually truncates the text.
            print(f"Output: {output[:preview_chars]}{'...' if len(output) > preview_chars else ''}")
        except Exception as e:
            # Include the type: str() of a message-less exception is empty.
            print(f"Error: {type(e).__name__}: {e}")

# Run analyses (block-length sweep first, then temperature sweep)
for run_analysis in (analyze_block_length_effect, analyze_temperature_effect):
    run_analysis()


In [None]:
# Advanced LLaDA Features: CFG and Remasking Strategies
def compare_remasking_strategies():
    """
    Print full outputs for 'low_confidence' vs 'random' remasking on the same prompt.

    gen_length=80 is split into 2 blocks of 40. steps=80 (== gen_length) gives each block
    40 steps, i.e. one token committed per step. More steps than that would only add
    forward passes that commit nothing.
    """
    print("\n" + "="*60)
    print("REMASKING STRATEGY COMPARISON")
    print("="*60)

    prompt = "The benefits of renewable energy include"

    strategies = ['low_confidence', 'random']

    for strategy in strategies:
        print(f"\n🔹 Remasking: {strategy}")
        try:
            result = llada_generate_with_chat_template(
                prompt,
                steps=80,
                gen_length=80,
                block_length=40,
                temperature=0.0,
                remasking=strategy
            )
            print(f"Output: {result['output']}")
        except Exception as e:
            # Include the type: str() of a message-less exception is empty.
            print(f"Error: {type(e).__name__}: {e}")

def explore_cfg_guidance():
    """
    Print a preview of the same prompt generated at cfg_scale 0.0, 1.0, 2.0 and 3.0.

    cfg_scale=0.0 is the unguided baseline. For cfg_scale > 0, ``generate`` also runs a
    prompt-masked copy of the sequence in the same forward pass and extrapolates away from it.
    """
    print("\n" + "="*60)
    print("CLASSIFIER-FREE GUIDANCE ANALYSIS")
    print("="*60)

    prompt = "In the future, artificial intelligence will"
    cfg_scales = [0.0, 1.0, 2.0, 3.0]
    preview_chars = 120

    for cfg_scale in cfg_scales:
        print(f"\n🔹 CFG Scale: {cfg_scale}")
        try:
            result = llada_generate_with_chat_template(
                prompt,
                steps=64,
                gen_length=64,
                block_length=32,
                temperature=0.0,
                cfg_scale=cfg_scale,
                remasking='low_confidence'
            )
            output = result['output']
            # Append an ellipsis only when the preview actually truncates the text.
            print(f"Output: {output[:preview_chars]}{'...' if len(output) > preview_chars else ''}")
        except Exception as e:
            # Include the type: str() of a message-less exception is empty.
            print(f"Error: {type(e).__name__}: {e}")

# Run advanced feature exploration (remasking comparison, then CFG sweep)
for run_exploration in (compare_remasking_strategies, explore_cfg_guidance):
    run_exploration()


In [None]:
# Practical Examples: Using LLaDA for Different Tasks
def demonstrate_practical_applications():
    """
    Run LLaDA on four task-style prompts and print each prompt preview and full output.

    Every task uses ``remasking='low_confidence'`` and keeps ``steps <= gen_length``.
    Extra steps beyond one token per step would commit nothing and only add forward passes.
    """
    print("\n" + "="*80)
    print("PRACTICAL LLaDA APPLICATIONS")
    print("="*80)

    applications = [
        {
            'task': 'Question Answering',
            'prompt': 'What are the main causes of climate change?',
            'params': {'steps': 100, 'gen_length': 100, 'block_length': 25, 'temperature': 0.0}
        },
        {
            'task': 'Creative Writing',
            'prompt': 'Write a short poem about the ocean',
            'params': {'steps': 64, 'gen_length': 80, 'block_length': 20, 'temperature': 0.7}
        },
        {
            'task': 'Code Explanation',
            'prompt': 'Explain what this Python function does: def fibonacci(n): return n if n <= 1 else fibonacci(n-1) + fibonacci(n-2)',
            'params': {'steps': 96, 'gen_length': 120, 'block_length': 30, 'temperature': 0.2}
        },
        {
            'task': 'Mathematical Problem',
            'prompt': 'Solve this step by step: If a train travels 120 km in 2 hours, what is its average speed?',
            'params': {'steps': 100, 'gen_length': 100, 'block_length': 25, 'temperature': 0.0}
        }
    ]
    prompt_preview_chars = 60

    for app in applications:
        print(f"\n🎯 {app['task']}")
        prompt = app['prompt']
        # Append an ellipsis only when the prompt preview actually truncates it.
        print(f"Prompt: {prompt[:prompt_preview_chars]}{'...' if len(prompt) > prompt_preview_chars else ''}")
        try:
            result = llada_generate_with_chat_template(
                prompt,
                remasking='low_confidence',
                **app['params']
            )
            print(f"Output: {result['output']}")
        except Exception as e:
            # Include the type: str() of a message-less exception is empty.
            print(f"Error: {type(e).__name__}: {e}")
        print("-" * 40)

# Run the practical-application demos (prints each task's prompt preview and output).
demonstrate_practical_applications()


In [None]:
# Summary and Tips for LLaDA Usage
print("\n" + "="*80)
print("🚀 LLaDA DIFFUSION LANGUAGE MODEL - COMPLETE IMPLEMENTATION")
print("="*80)

# The speed tips reflect this notebook's generate(): each sampling step is one model
# forward pass (two sequences per pass when cfg_scale > 0), whatever block_length is.
print("""
✅ IMPLEMENTED FEATURES:

1. 🎯 Core LLaDA Functions:
   • add_gumbel_noise() - Gumbel sampling for categorical distributions
   • get_num_transfer_tokens() - Linear noise schedule implementation
   • generate() - Full LLaDA generation with all features

2. 🔧 Key Parameters:
   • steps: Number of diffusion steps (32-128), keep steps <= gen_length
   • gen_length: Output sequence length
   • block_length: Semi-autoregressive block size
   • temperature: Sampling randomness (0.0-1.0+)
   • cfg_scale: Classifier-free guidance strength
   • remasking: 'low_confidence' or 'random'

3. 🎨 Sampling Strategies:
   • Deterministic (temp=0.0) - Consistent outputs
   • Stochastic (temp>0.0) - Creative/diverse outputs
   • CFG Guidance - Enhanced instruction following
   • Semi-autoregressive - Blocks denoised left to right (block_length < gen_length)

4. 📊 Analysis Tools:
   • Parameter comparison functions
   • Practical application examples
   • Side-by-side outputs across tasks

💡 USAGE TIPS:

• For factual Q&A: Use temp=0.0, low block_length
• For creative tasks: Use temp=0.5-1.0, larger blocks
• For instruction following: Add cfg_scale=1.5-3.0
• For speed: Reduce steps (one forward pass per step; block_length does not change the count; cfg_scale > 0 doubles the batch per pass)
• For quality: Increase steps, use 'low_confidence' remasking

🔗 Based on the LLaDA paper (Large Language Diffusion Models): arXiv:2502.09992
   Float64 Gumbel-max rationale used in add_gumbel_noise(): arXiv:2409.02908
""")

print("="*80)
print("🎉 Ready to explore diffusion language generation!")
print("="*80)
