# dInfer Soft Token Experimentation Notebook

In [1]:
# Add dInfer to Python path from the 3rdparty submodule
import sys
import os

# Add the dInfer python directory to sys.path
DINFER_PATH = os.path.abspath('3rdparty/dInfer/python')
if os.path.exists(DINFER_PATH):
    if DINFER_PATH not in sys.path:
        sys.path.insert(0, DINFER_PATH)
    print(f"✓ Added dInfer to Python path: {DINFER_PATH}")
else:
    print(f"⚠ Warning: dInfer path not found: {DINFER_PATH}")
    print("  Make sure you're running this notebook from the project root.")

# Clear any cached imports to ensure we use the latest code
import importlib
for module_name in list(sys.modules.keys()):
    if 'dinfer' in module_name:
        del sys.modules[module_name]
print("✓ Cleared cached dinfer modules")

# Verify the import works
try:
    import dinfer
    print(f"✓ dInfer module found at: {dinfer.__file__}")
except ImportError as e:
    print(f"✗ Failed to import dInfer: {e}")

# Note: You can also use the import utilities from xp/llada_api/llada_generate/dinfer/_imports.py
# which provides the same functionality with additional error handling

✓ Added dInfer to Python path: /lustre/fsw/portfolios/llmservice/users/mfathi/codebases/nemo-rl/3rdparty/dInfer/python
✓ Cleared cached dinfer modules
✓ dInfer module found at: /lustre/fsw/portfolios/llmservice/users/mfathi/codebases/nemo-rl/3rdparty/dInfer/python/dinfer/__init__.py


In [2]:
# Import necessary libraries
import torch
import numpy as np
import torch.nn.functional as F
import os
import time
from transformers import AutoTokenizer, AutoModel

# Disable torch compilation to avoid backend compiler errors
os.environ['TORCH_COMPILE_DISABLE'] = '1'
os.environ['TORCHDYNAMO_DISABLE'] = '1'

# Disable torch.compile globally
torch._dynamo.config.disable = True

# dinfer imports (using local dInfer from 3rdparty/dInfer/python added to path above)
from dinfer.model import LLaDAModelLM
from dinfer.decoding.parallel_strategy import (
    ParallelDecoder,
    ThresholdParallelDecoder,
    CreditThresholdParallelDecoder,
    HierarchyDecoder,
    get_num_transfer_tokens,
    get_transfer_index,
)
from dinfer import (
    BlockWiseDiffusionLLM,
    BlockIteratorFactory,
    KVCacheFactory,
    SlidingWindowDiffusionLLM,
)

from dinfer.decoding.utils import (
    TokenArray,
)

print("Torch compilation disabled to avoid backend issues")

# LLaDA tokenizer constants
MASK_ID = 126336
EOS_ID = 126081

# Set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


  from .autonotebook import tqdm as notebook_tqdm


INFO 11-20 13:22:44 [__init__.py:216] Automatically detected platform cuda.
Torch compilation disabled to avoid backend issues
Using device: cuda


In [3]:
# Model configuration
MODEL_PATH = "GSAI-ML/LLaDA-8B-Instruct"  # Update this path to your model

# Load model
print(f"Loading model from {MODEL_PATH}...")
model = LLaDAModelLM.from_pretrained(
    MODEL_PATH, 
    trust_remote_code=True,
    torch_dtype=torch.bfloat16, 
    init_device=str(device)  # Convert device to string for JSON serialization
).eval()
model = model.to(device)

# Optional: Compile the model for better performance
# model = torch.compile(model, mode='reduce-overhead', fullgraph=True)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
print("Model and tokenizer loaded successfully!")


Loading model from GSAI-ML/LLaDA-8B-Instruct...


Fetching 6 files: 100%|██████████| 6/6 [00:12<00:00,  2.10s/it]
Loading checkpoint shards: 100%|██████████| 6/6 [00:00<00:00, 195.02it/s]


Model and tokenizer loaded successfully!


In [4]:
# Generation parameters - EXPERIMENT WITH THESE!
generation_config = {
    "gen_length": 256,      # Maximum number of tokens to generate
    "steps": 64,
    "block_length": 64,     # Block size for parallel decoding
    "threshold": 0.9,       # Confidence threshold for token acceptance
    "cache_type": "dual",   # Options: None, "prefix", "dual"
    "early_stop": True,     # Stop at EOS token
    "maximum_unroll": 4,    # Maximum unroll steps
    "expected_tpf": 8,      # Expected tokens per forward pass
}

# Create decoder with threshold strategy
decoder = ThresholdParallelDecoder(0, threshold=generation_config["threshold"], mask_id=MASK_ID, eos_id=EOS_ID)

# Alternative: Use FixedParallelDecoderWithEOS for fixed-step decoding
# decoder = FixedParallelDecoderWithEOS(
#     temperature=0,
#     steps=generation_config["steps"],
#     mask_id=MASK_ID,
#     eos_id=EOS_ID
# )
# print("Using FixedParallelDecoderWithEOS")

# Create iterator factory
iterator_factory = BlockIteratorFactory(True)

# Create KV cache factory if using caching
cache_factory = KVCacheFactory(generation_config["cache_type"]) if generation_config["cache_type"] else None

# Create the Diffusion LLM instance
dllm = BlockWiseDiffusionLLM(
    model=model,
    decoder=decoder,
    iterator_factory=iterator_factory,
    cache_factory=cache_factory,
    early_stop=generation_config["early_stop"],
    maximum_unroll=generation_config["maximum_unroll"],
    expected_tpf=generation_config["expected_tpf"]
)

print("Decoder and generation pipeline configured!")


Decoder and generation pipeline configured!


In [5]:
@torch.no_grad()
def generate_text(prompt, dllm_instance=dllm, tokenizer=tokenizer, config=generation_config):
    """
    Generate text using the diffusion LLM.
    
    Args:
        prompt: Input text prompt
        dllm_instance: Diffusion LLM instance
        tokenizer: Tokenizer instance
        config: Generation configuration dict
    
    Returns:
        Generated text string
    """
    # Tokenize the prompt
    input_ids = tokenizer(prompt, return_tensors="pt")['input_ids'].to(device)
    prompt_length = input_ids.shape[1]
    
    print(f"Prompt length: {prompt_length} tokens")
    print(f"Generating up to {config['gen_length']} tokens...")
    
    # Track statistics
    prev_forwards = dllm_instance.num_forwards
    prev_cache_updates = dllm_instance.cache_updates
    
    # Generate
    start_time = time.time()
    output_ids = dllm_instance.generate(
        input_ids, 
        gen_length=config['gen_length'], 
        block_length=config['block_length']
    )
    end_time = time.time()
    
    # Calculate statistics
    total_forwards = dllm_instance.num_forwards - prev_forwards
    total_cache_updates = dllm_instance.cache_updates - prev_cache_updates
    generated_tokens = output_ids.shape[1] - prompt_length
    generation_time = end_time - start_time
    
    # Decode output
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=False)
    
    # Print statistics
    print(f"\nGeneration Statistics:")
    print(f"- Generated tokens: {generated_tokens}")
    print(f"- Forward passes: {total_forwards}")
    print(f"- Cache updates: {total_cache_updates}")
    print(f"- Time: {generation_time:.2f}s")
    print(f"- Tokens/second: {generated_tokens/generation_time:.2f}")
    print(f"- Forwards/second: {total_forwards/generation_time:.2f}")
    
    return generated_text

# Test the generation function
test_prompt = "Once upon a time in a magical forest"
generated = generate_text(test_prompt)
print(f"\nGenerated text:\n{generated}")


Prompt length: 8 tokens
Generating up to 256 tokens...



Generation Statistics:
- Generated tokens: 255
- Forward passes: 236
- Cache updates: 4
- Time: 7.14s
- Tokens/second: 35.74
- Forwards/second: 33.08

Generated text:
Once upon a time in a magical forest, there lived a young girl named Lily. She was known for her kindness and bravery. One day, while exploring the forest, she stumbled upon a mysterious cave. As she entered the cave, she found a hidden room filled with sparkling crystals. The room was filled with a soft light, and Lily felt a strange sensation wash over her. Suddenly, she heard a gentle voice calling her name. She turned to see a wise old owl perched on a nearby branch. "Welcome, Lily," the owl said. "You have been chosen to be the guardian of the forest."

Lily was amazed and scared. She had never been chosen for such a responsibility before. The owl explained that she had been chosen because of her kindness and bravery. She would have to protect the forest from any harm and make sure that the creatures that lived ther

In [6]:
# Interactive prompt examples
prompts = [
    "The future of artificial intelligence is",
    "Explain quantum computing in simple terms:",
    "Write a haiku about programming:",
    "The most important scientific discovery was",
]

# Generate for each prompt
for prompt in prompts[:1]:  # Change to prompts to run all
    print(f"\n{'='*80}")
    print(f"PROMPT: {prompt}")
    print(f"{'='*80}")
    generated = generate_text(prompt)
    print(f"\nGENERATED:\n{generated}")
    print(f"{'='*80}\n")



PROMPT: The future of artificial intelligence is
Prompt length: 6 tokens
Generating up to 256 tokens...



Generation Statistics:
- Generated tokens: 9
- Forward passes: 36
- Cache updates: 1
- Time: 1.04s
- Tokens/second: 8.64
- Forwards/second: 34.57

GENERATED:
The future of artificial intelligence is bright, and the possibilities are endless.<|eot_id|>



In [7]:
# Experiment with different threshold values
thresholds = [0.7, 0.8, 0.9, 0.95]
test_prompt = "The meaning of life is"

for threshold in thresholds:
    print(f"\n{'='*80}")
    print(f"Testing with threshold = {threshold}")
    print(f"{'='*80}")
    
    # Create new decoder with different threshold
    test_decoder = ThresholdParallelDecoder(0, threshold=threshold)
    
    # Create new DLLM instance
    test_dllm = BlockWiseDiffusionLLM(
        model=model,
        decoder=test_decoder,
        iterator_factory=iterator_factory,
        cache_factory=cache_factory,
        early_stop=True
    )
    
    # Generate and compare
    generated = generate_text(test_prompt, dllm_instance=test_dllm)
    print(f"\nGenerated: {generated[len(test_prompt):]}")  # Show only generated part



Testing with threshold = 0.7
Prompt length: 5 tokens
Generating up to 256 tokens...



Generation Statistics:
- Generated tokens: 130
- Forward passes: 100
- Cache updates: 3
- Time: 2.89s
- Tokens/second: 45.05
- Forwards/second: 34.65

Generated:  a deeply personal and subjective question, and it can vary from person to person. For some, the meaning of life may be to find happiness, love, and fulfillment. For others, it may be to make a positive impact on the world or to achieve a sense of purpose and meaning.

It's important to remember that the meaning of life is not a fixed or absolute concept. Rather, it's something that can be explored and evolved over time, as we grow, learn, and have new experiences.

Ultimately, the meaning of life is something that each person must discover for themselves, through exploration, reflection, and experience.<|eot_id|>

Testing with threshold = 0.8
Prompt length: 5 tokens
Generating up to 256 tokens...

Generation Statistics:
- Generated tokens: 130
- Forward passes: 113
- Cache updates: 3
- Time: 3.26s
- Tokens/second: 39.91
- Fo

In [8]:
# Create sliding window DLLM
sliding_dllm = SlidingWindowDiffusionLLM(
    model=model,
    decoder=decoder,
    iterator_factory=iterator_factory,
    cache_factory=KVCacheFactory('dual'),  # Sliding window requires cache
    prefix_look=0,      # How many tokens to look back
    after_look=0,       # How many tokens to look ahead
    warmup_steps=1,     # Warmup iterations
    early_stop=True
)

# Compare block-wise vs sliding window
test_prompt = "Artificial intelligence will revolutionize"

print("Block-wise generation:")
print("="*80)
block_generated = generate_text(test_prompt, dllm_instance=dllm)

print("\n\nSliding window generation:")
print("="*80)
sliding_generated = generate_text(test_prompt, dllm_instance=sliding_dllm)

print("\n\nComparison:")
print(f"Block-wise output: {block_generated[len(test_prompt):][:100]}...")
print(f"Sliding window output: {sliding_generated[len(test_prompt):][:100]}...")


Block-wise generation:
Prompt length: 5 tokens
Generating up to 256 tokens...



Generation Statistics:
- Generated tokens: 70
- Forward passes: 83
- Cache updates: 2
- Time: 2.40s
- Tokens/second: 29.12
- Forwards/second: 34.53


Sliding window generation:
Prompt length: 5 tokens
Generating up to 256 tokens...

Generation Statistics:
- Generated tokens: 67
- Forward passes: 74
- Cache updates: 2
- Time: 2.16s
- Tokens/second: 31.07
- Forwards/second: 34.32


Comparison:
Block-wise output:  the way we live, work, and interact with each other. It will enable us to solve complex problems, i...
Sliding window output:  the way we live, work, and interact with each other. It will enable us to be more efficient, produc...


In [9]:
class _FixedParallelDecoder(ParallelDecoder):
    """ This decoder decodes tokens in a fixed number of steps.
    """
    def __init__(self, temperature, steps, remasking='low_confidence', mask_id=MASK_ID, eos_id=EOS_ID):
        super().__init__(temperature, remasking, mask_id)
        self.steps = steps
        self.iter = 0
        self.eos_id = eos_id

    def block_init(self, block_x, block_id):
        # TODO(zhengda) we need to handle steps correctly here when the distributed version changes the gen length.
        block_mask_index = block_x == self.mask_id
        self.num_transfer_tokens = get_num_transfer_tokens(block_mask_index, self.steps)
        self.iter = 0

    def decode(self, logits, block_start, block_end, x, iter_threshold = None):
        """ Decode the logits in a block.
        """
        mask_index = (x[block_start:block_end] == self.mask_id)
        assert mask_index.shape[1] == logits.shape[1]

        curr_x = x[block_start:block_end]
        x0, transfer_index = get_transfer_index(logits, self.temperature, self.remasking, mask_index, curr_x, self.num_transfer_tokens[:, self.iter], None)
        self.iter += 1
        x[block_start:block_end][transfer_index] = x0[transfer_index]

In [10]:
# Experiment with FixedParallelDecoder if available
print("Testing FixedParallelDecoder with different fixed ratios...")

test_prompt = "The key to innovation is"

# Create FixedParallelDecoder
fixed_decoder = _FixedParallelDecoder(
    0, steps=generation_config["steps"]
)

# Create DLLM with fixed decoder
fixed_dllm = BlockWiseDiffusionLLM(
    model=model,
    decoder=fixed_decoder,
    iterator_factory=iterator_factory,
    cache_factory=cache_factory,
    early_stop=True
)

# Generate
generated = generate_text(test_prompt, dllm_instance=fixed_dllm)
print(f"\nGenerated: {generated[len(test_prompt):][:100]}...")

Testing FixedParallelDecoder with different fixed ratios...
Prompt length: 5 tokens
Generating up to 256 tokens...

Generation Statistics:
- Generated tokens: 18
- Forward passes: 64
- Cache updates: 1
- Time: 1.93s
- Tokens/second: 9.32
- Forwards/second: 33.13

Generated:  to be open to new ideas and to be willing to take risks. It's not about...


# Soft Token Diffusion Sampler

In [11]:
class BlockWiseSoftTokenLLM:
    def __init__(self, model, decoder, iterator_factory, early_stop=True, cache_factory=None, maximum_unroll=4, expected_tpf=8, soft_token_ratio=0.2):
        self.model = model
        self.cache_factory = cache_factory
        self.decoder = decoder
        self.iterator_factory = iterator_factory
        self.num_forwards = 0
        self.cache_updates = 0
        self.early_stop = early_stop
        self.maximum_unroll = maximum_unroll
        self.expected_tpf = expected_tpf
        self.soft_token_ratio = soft_token_ratio
        self.input_embeddings = self.model.get_input_embeddings()
        
    @torch.no_grad()
    def generate(self, prompt, gen_length=128, block_length=128):
        ''' Generate tokens with diffusion iterations block by block using Soft Token Sampling.
        '''
        x = TokenArray(prompt, gen_length, self.decoder.mask_id, self.decoder.eos_id, self.model.device)
        it = self.iterator_factory.create(x, block_length)

        iter_no = 0
        kv_cache = self.cache_factory.create() if self.cache_factory is not None else None
        
        for block_id, (block_loc, block) in enumerate(it):
            self.decoder.block_init(block, block_id)

            while (block == self.decoder.mask_id).sum() > 0:
                
                # Helper to run model with correct context and embeddings
                def run_model(use_input_embeds=None):
                    # Determine input context based on cache type
                    if kv_cache is None:
                        # Full context (no cache)
                        if use_input_embeds is not None:
                            logits = self.model(inputs_embeds=use_input_embeds).logits
                        else:
                            logits = self.model(x.data).logits
                        return logits[:, block_loc.start:block_loc.end]
                        
                    elif kv_cache.cache_type == 'prefix':
                        # Prefix Cache: past_key_values contains context up to block_start
                        past_key_values, replace_position = kv_cache.get_key_values(block_loc.start, block_loc.end)
                        
                        if use_input_embeds is not None:
                            # Input embeddings should correspond to x[block_loc.start:]
                            logits = self.model(inputs_embeds=use_input_embeds, past_key_values=past_key_values, use_cache=True,
                                              replace_position=replace_position).logits
                        else:
                            logits = self.model(x[block_loc.start:], past_key_values=past_key_values, use_cache=True,
                                              replace_position=replace_position).logits
                        
                        curr_len = block_loc.end - block_loc.start
                        return logits[:, :curr_len]

                    else:
                        # Dual/Sliding Cache: typically uses block context
                        past_key_values, replace_position = kv_cache.get_key_values(block_loc.start, block_loc.end)
                        
                        if use_input_embeds is not None:
                             logits = self.model(inputs_embeds=use_input_embeds, past_key_values=past_key_values, use_cache=True,
                                              replace_position=replace_position).logits
                        else:
                             logits = self.model(block, past_key_values=past_key_values, use_cache=True,
                                              replace_position=replace_position).logits
                        return logits

                # 1. Handle KV Cache Update (Initial step for block or periodically)
                if kv_cache is not None and kv_cache.require_update(iter_no, block_loc.start, block_loc.end):
                    output = self.model(x.data, use_cache=True)
                    self.num_forwards += 1
                    
                    # Update cache
                    kv_cache.update(output.past_key_values)
                    self.cache_updates += 1
                    
                    # Decode using these initial logits
                    self.decoder.decode(output.logits[:, block_loc.start:block_loc.end], block_loc.start, block_loc.end, x)
                    iter_no += 1
                    continue

                # 2. Pass 1: Standard Logits (with current masks)
                logits1 = run_model(use_input_embeds=None)
                self.num_forwards += 1
                
                decoding_logits = logits1
                
                # 3. Soft Token Logic
                # Identify masks in the current block
                curr_block_ids = x[block_loc.start:block_loc.end]
                mask_mask = (curr_block_ids == self.decoder.mask_id)
                mask_indices = torch.nonzero(mask_mask).flatten() # Indices relative to block start
                
                if mask_indices.numel() > 0 and self.soft_token_ratio > 0:
                    # Sample masks to turn into soft tokens
                    num_soft = int(mask_indices.numel() * self.soft_token_ratio)
                    if num_soft > 0:
                        perm = torch.randperm(mask_indices.numel(), device=self.model.device)
                        soft_indices = mask_indices[perm[:num_soft]] # Indices relative to block start
                        
                        # Extract logits for these positions
                        # logits1 shape: [1, block_len, vocab]
                        selected_logits = logits1[0, soft_indices]
                        probs = torch.softmax(selected_logits, dim=-1)
                        
                        # Compute Soft Embeddings: Weighted average of token embeddings
                        # [num_soft, vocab] @ [vocab, d_model] -> [num_soft, d_model]
                        soft_embeds = torch.matmul(probs, self.input_embeddings.weight)
                        
                        # Prepare Input Embeddings
                        target_ids = None
                        global_offset = 0
                        
                        if kv_cache is None:
                            target_ids = x.data
                            global_offset = block_loc.start # Offset in target_ids
                        elif kv_cache.cache_type == 'prefix':
                            target_ids = x[block_loc.start:]
                            global_offset = 0 # relative to start of target_ids
                        else:
                            target_ids = curr_block_ids
                            global_offset = 0
                        
                        # Get base embeddings for the input context
                        # We clone to avoid modifying the model's internal cache if any (unlikely here)
                        inputs_embeds = self.input_embeddings(target_ids).clone() # [1, len, d_model]
                        
                        # Replace masks with soft embeddings
                        inputs_embeds[0, global_offset + soft_indices] = soft_embeds
                        
                        # Pass 2: Get logits with Soft Tokens
                        logits2 = run_model(use_input_embeds=inputs_embeds)
                        self.num_forwards += 1
                        decoding_logits = logits2

                # 4. Decode using the latest logits
                self.decoder.decode(decoding_logits, block_loc.start, block_loc.end, x)
                iter_no += 1

            # Early stop at EOS
            if self.early_stop and torch.any(x[block_loc.start:block_loc.end] == self.decoder.eos_id):
                x[block_loc.end:] = self.decoder.eos_id
                break

        return x.get_generated_tokens()

In [12]:
generation_config = {
    "gen_length": 1024,      # Maximum number of tokens to generate
    "steps": 16,
    "block_length": 128,     # Block size for parallel decoding
    "threshold": 0.9,       # Confidence threshold for token acceptance
    "cache_type": "dual",   # Options: None, "prefix", "dual"
    "early_stop": True,     # Stop at EOS token
    "maximum_unroll": 4,    # Maximum unroll steps
    "expected_tpf": 8,      # Expected tokens per forward pass
}

test_prompt = "About the theory of relativity: "

# Create FixedParallelDecoder
fixed_decoder = _FixedParallelDecoder(
    0, steps=generation_config["steps"]
)

soft_llm = BlockWiseSoftTokenLLM(
    model=model,
    decoder=fixed_decoder,
    iterator_factory=iterator_factory,
    cache_factory=cache_factory,
    early_stop=True,
    soft_token_ratio=0.5,
)

block_wise_llm = BlockWiseDiffusionLLM(
    model=model,
    decoder=fixed_decoder,
    iterator_factory=iterator_factory,
    cache_factory=cache_factory,
    early_stop=True,
)

# Generate
print("Generating with Soft Token LLM...")
generated = generate_text(test_prompt, dllm_instance=fixed_dllm, config=generation_config)
print(f"\nGenerated: {generated[len(test_prompt):]}...")

print("--------------------------------")

print("Generating with Block Wise LLM...")
generated = generate_text(test_prompt, dllm_instance=block_wise_llm, config=generation_config)
print(f"\nGenerated: {generated[len(test_prompt):]}...")

Generating with Soft Token LLM...
Prompt length: 7 tokens
Generating up to 1024 tokens...

Generation Statistics:
- Generated tokens: 140
- Forward passes: 128
- Cache updates: 2
- Time: 8.72s
- Tokens/second: 16.05
- Forwards/second: 14.68

Generated: 

The theory of relativity is proposed by Albert Einstein. It is a  two-part theory: Special
Relativity and General Relativity.

Special Relativity:

1. The speed of light is always constant and is independent of the observer's motion.
2. The laws of physics are the same for all observers in uniform motion.
3. Time and relative motion are inter.
4. Length and time time change for an object in motion relative to an observer.


General Relativity:

1. Gravity is not a force but rather a curvature of spacetime caused by mass and energy.
2. Objects move along the curvature of spacetime.<|eot_id|>...
--------------------------------
Generating with Block Wise LLM...
Prompt length: 7 tokens
Generating up to 1024 tokens...

Generation Statistic