# dInfer Soft Token Experimentation Notebook

In [1]:
# Make the vendored dInfer package (3rdparty submodule) importable
import sys
import os

# Resolve the dInfer python directory relative to the current working directory
DINFER_PATH = os.path.abspath('3rdparty/dInfer/python')
if not os.path.exists(DINFER_PATH):
    print(f"⚠ Warning: dInfer path not found: {DINFER_PATH}")
    print("  Make sure you're running this notebook from the project root.")
else:
    # Put the directory first on sys.path, but only once across re-runs of this cell
    if DINFER_PATH not in sys.path:
        sys.path.insert(0, DINFER_PATH)
    print(f"✓ Added dInfer to Python path: {DINFER_PATH}")

# Forget every already-imported module whose name contains 'dinfer' so that the
# import below re-reads the code from disk
import importlib
for module_name in tuple(sys.modules):
    if 'dinfer' not in module_name:
        continue
    del sys.modules[module_name]
print("✓ Cleared cached dinfer modules")

# Smoke test: report where the package was loaded from, or why it could not be
try:
    import dinfer
except ImportError as e:
    print(f"✗ Failed to import dInfer: {e}")
else:
    print(f"✓ dInfer module found at: {dinfer.__file__}")

# Note: You can also use the import utilities from xp/llada_api/llada_generate/dinfer/_imports.py
# which provides the same functionality with additional error handling

✓ Added dInfer to Python path: /lustre/fsw/portfolios/llmservice/users/mfathi/codebases/nemo-rl/3rdparty/dInfer/python
✓ Cleared cached dinfer modules
✓ dInfer module found at: /lustre/fsw/portfolios/llmservice/users/mfathi/codebases/nemo-rl/3rdparty/dInfer/python/dinfer/__init__.py


In [2]:
# Import necessary libraries
import torch
import numpy as np
import torch.nn.functional as F
import os
import time
from transformers import AutoTokenizer, AutoModel

# Disable torch compilation to avoid backend compiler errors.
# NOTE(review): these variables are set after `import torch` above (and torch may
# already have been pulled in by `import dinfer` in the previous cell), so they
# are possibly read too late to affect this process — TODO confirm. They are
# still inherited by any subprocess started from this kernel.
os.environ['TORCH_COMPILE_DISABLE'] = '1'
os.environ['TORCHDYNAMO_DISABLE'] = '1'

# Disable torch.compile globally by setting the dynamo config flag directly on
# the already-imported module (does not depend on when the env vars are read)
torch._dynamo.config.disable = True

# dinfer imports (using local dInfer from 3rdparty/dInfer/python added to path above)
from dinfer.model import LLaDAModelLM
from dinfer.decoding.parallel_strategy import (
    ParallelDecoder,
    ThresholdParallelDecoder,
    CreditThresholdParallelDecoder,
    HierarchyDecoder,
    get_num_transfer_tokens,
    get_transfer_index,
)
from dinfer import (
    BlockWiseDiffusionLLM,
    BlockIteratorFactory,
    KVCacheFactory,
    SlidingWindowDiffusionLLM,
)

from dinfer.decoding.utils import (
    TokenArray,
)

print("Torch compilation disabled to avoid backend issues")

# Special token ids of the LLaDA tokenizer (mask placeholder and end-of-sequence)
MASK_ID = 126336
EOS_ID = 126081

# Prefer the GPU when one is visible, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


  from .autonotebook import tqdm as notebook_tqdm


INFO 11-24 06:33:17 [__init__.py:216] Automatically detected platform cuda.
Torch compilation disabled to avoid backend issues
Using device: cuda


In [3]:
# Model configuration
MODEL_PATH = "GSAI-ML/LLaDA-8B-Base"  # Update this path to your model
# MODEL_PATH = "GSAI-ML/LLaDA-8B-Instruct"  # Update this path to your model
# MODEL_PATH = "GSAI-ML/LLaDA-1.5"

# Load the weights in bfloat16, switch to inference mode, then move to `device`
print(f"Loading model from {MODEL_PATH}...")
model = LLaDAModelLM.from_pretrained(
    MODEL_PATH,
    init_device=str(device),  # Convert device to string for JSON serialization
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
model = model.eval()
model = model.to(device)

# Optional: Compile the model for better performance
# model = torch.compile(model, mode='reduce-overhead', fullgraph=True)

# The tokenizer comes from the same checkpoint as the model
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
print("Model and tokenizer loaded successfully!")


Loading model from GSAI-ML/LLaDA-8B-Base...


Fetching 6 files: 100%|██████████| 6/6 [00:23<00:00,  3.95s/it]
Loading checkpoint shards: 100%|██████████| 6/6 [00:00<00:00, 295.45it/s]


Model and tokenizer loaded successfully!


In [4]:
# Generation parameters - EXPERIMENT WITH THESE!
generation_config = {
    # Maximum number of tokens to generate
    "gen_length": 256,
    "steps": 64,
    # Block size for parallel decoding
    "block_length": 64,
    # Confidence threshold for token acceptance
    "threshold": 0.9,
    # Options: None, "prefix", "dual"
    "cache_type": "dual",
    # Stop at EOS token
    "early_stop": True,
    # Maximum unroll steps
    "maximum_unroll": 4,
    # Expected tokens per forward pass
    "expected_tpf": 8,
}

# Threshold-based decoder; the first positional argument (0) is kept exactly as
# in the original call
decoder = ThresholdParallelDecoder(
    0,
    threshold=generation_config["threshold"],
    mask_id=MASK_ID,
    eos_id=EOS_ID,
)

# Alternative: Use FixedParallelDecoderWithEOS for fixed-step decoding
# (FixedParallelDecoderWithEOS is not imported in this notebook; import it
# before uncommenting)
# decoder = FixedParallelDecoderWithEOS(
#     temperature=0,
#     steps=generation_config["steps"],
#     mask_id=MASK_ID,
#     eos_id=EOS_ID
# )
# print("Using FixedParallelDecoderWithEOS")

# Factory that produces the block iterator used during generation
iterator_factory = BlockIteratorFactory(True)

# A KV cache factory is only built when a cache type is configured
if generation_config["cache_type"]:
    cache_factory = KVCacheFactory(generation_config["cache_type"])
else:
    cache_factory = None

# Assemble the block-wise diffusion LLM from the pieces above
dllm = BlockWiseDiffusionLLM(
    model=model,
    decoder=decoder,
    iterator_factory=iterator_factory,
    cache_factory=cache_factory,
    early_stop=generation_config["early_stop"],
    maximum_unroll=generation_config["maximum_unroll"],
    expected_tpf=generation_config["expected_tpf"],
)

print("Decoder and generation pipeline configured!")


Decoder and generation pipeline configured!


In [None]:
@torch.no_grad()
def generate_text(prompt, dllm_instance=dllm, tokenizer=tokenizer, config=generation_config, apply_chat_template=True):
    """
    Generate text using the diffusion LLM and print throughput statistics.

    Args:
        prompt: Input text prompt
        dllm_instance: Diffusion LLM instance. It must provide
            ``generate(input_ids, gen_length=..., block_length=...)`` and the
            ``num_forwards`` / ``cache_updates`` counters.
        tokenizer: Tokenizer instance
        config: Generation configuration dict; only ``gen_length`` and
            ``block_length`` are read here.
        apply_chat_template: When True, the prompt is wrapped as a single user
            message with the tokenizer's chat template before tokenization;
            when False, the raw prompt string is tokenized as-is.

    Returns:
        Generated text string: the decoded first row of the ids returned by
        ``dllm_instance.generate`` with special tokens kept.
    """
    # Tokenize the prompt
    if apply_chat_template:
        message = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False)
    else:
        message = prompt
    input_ids = tokenizer(message, return_tensors="pt")['input_ids'].to(device)
    prompt_length = input_ids.shape[1]

    print(f"Prompt length: {prompt_length} tokens")
    print(f"Generating up to {config['gen_length']} tokens...")

    # Snapshot the cumulative counters so that only this call's work is reported
    prev_forwards = dllm_instance.num_forwards
    prev_cache_updates = dllm_instance.cache_updates

    # Generate. perf_counter is a monotonic, high-resolution clock, so the
    # measured interval is not affected by system clock adjustments.
    start_time = time.perf_counter()
    output_ids = dllm_instance.generate(
        input_ids,
        gen_length=config['gen_length'],
        block_length=config['block_length']
    )
    generation_time = time.perf_counter() - start_time

    # Calculate statistics
    total_forwards = dllm_instance.num_forwards - prev_forwards
    total_cache_updates = dllm_instance.cache_updates - prev_cache_updates
    generated_tokens = output_ids.shape[1] - prompt_length

    # Guard the rate computation: a zero-length interval would otherwise raise
    # ZeroDivisionError and lose the generated text.
    if generation_time > 0:
        tokens_per_second = generated_tokens / generation_time
        forwards_per_second = total_forwards / generation_time
    else:
        tokens_per_second = forwards_per_second = float("inf")

    # Decode output
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=False)

    # Print statistics
    print("\nGeneration Statistics:")
    print(f"- Generated tokens: {generated_tokens}")
    print(f"- Forward passes: {total_forwards}")
    print(f"- Cache updates: {total_cache_updates}")
    print(f"- Time: {generation_time:.2f}s")
    print(f"- Tokens/second: {tokens_per_second:.2f}")
    print(f"- Forwards/second: {forwards_per_second:.2f}")

    return generated_text

# Test the generation function on a raw prompt (chat template disabled), using
# the default pipeline, tokenizer and config bound when generate_text was defined.
# NOTE(review): the saved output of this cell shows a chat-templated prompt
# (46 tokens), which does not match apply_chat_template=False — it looks stale;
# re-run the cell to refresh it.
test_prompt = "Once upon a time in a magical forest"
generated = generate_text(test_prompt, apply_chat_template=False)
print(f"\nGenerated text:\n{generated}")


Prompt length: 46 tokens
Generating up to 256 tokens...

Generation Statistics:
- Generated tokens: 256
- Forward passes: 242
- Cache updates: 4
- Time: 5.11s
- Tokens/second: 50.14
- Forwards/second: 47.40

Generated text:
<|startoftext|><|start_header_id|>user<|end_header_id|>

Once upon a time in a magical forest<|eot_id|><|start_header_id|>assistant<|end_header_id|>

There was a little girl who lived in a small cottage in the forest. She was a kind and curious girl who loved to explore the woods. One day, she stumbled upon a mysterious cave that seemed to be hidden among the trees. She decided to enter the cave and see what was inside.

As she entered the cave, she was greeted by a warm glow that seemed to emanate from the walls. She continued to explore the cave and found a small room with a large table in the center. The table was covered with books and papers, and the girl realized that she had stumbled upon a magical library.

The girl spent hours reading through the books and 

In [6]:
# A few example prompts to try with the default pipeline
prompts = [
    "The future of artificial intelligence is",
    "Explain quantum computing in simple terms:",
    "Write a haiku about programming:",
    "The most important scientific discovery was",
]

# Only the first prompt is run by default; iterate over `prompts` to run all
for prompt in prompts[:1]:
    print("\n" + "=" * 80)
    print(f"PROMPT: {prompt}")
    print("=" * 80)
    generated = generate_text(prompt, apply_chat_template=False)
    print(f"\nGENERATED:\n{generated}")
    print("=" * 80 + "\n")



PROMPT: The future of artificial intelligence is
Prompt length: 6 tokens
Generating up to 256 tokens...

Generation Statistics:
- Generated tokens: 256
- Forward passes: 199
- Cache updates: 4
- Time: 3.74s
- Tokens/second: 68.50
- Forwards/second: 53.25

GENERATED:
The future of artificial intelligence is bright, and it is likely to continue to evolve and improve in the coming years. Here are some potential areas where AI could make significant contributions:
1. Healthcare: AI has the potential to revolutionize healthcare by improving diagnosis, treatment, and patient outcomes. For example, AI could be used to analyze medical images and identify diseases earlier and more accurately than humans.
2. Transportation: AI has the potential to improve transportation by reducing congestion, improving safety, and increasing efficiency. For example, AI could be used to optimize traffic flow, predict traffic patterns, and develop autonomous vehicles.
3. Education: AI has the potential to person

In [7]:
# Experiment with different threshold values.
# Everything except the threshold is configured exactly like the baseline
# pipeline (`decoder` / `dllm`), so differences in the statistics can be
# attributed to the threshold alone.
thresholds = [0.7, 0.8, 0.9, 0.95]
test_prompt = "The meaning of life is"

for threshold in thresholds:
    print(f"\n{'='*80}")
    print(f"Testing with threshold = {threshold}")
    print(f"{'='*80}")

    # Create new decoder with different threshold; pass the tokenizer constants
    # explicitly (as the baseline decoder does) instead of relying on defaults
    test_decoder = ThresholdParallelDecoder(
        0,
        threshold=threshold,
        mask_id=MASK_ID,
        eos_id=EOS_ID,
    )

    # Create new DLLM instance with the same unroll settings as the baseline
    test_dllm = BlockWiseDiffusionLLM(
        model=model,
        decoder=test_decoder,
        iterator_factory=iterator_factory,
        cache_factory=cache_factory,
        early_stop=True,
        maximum_unroll=generation_config["maximum_unroll"],
        expected_tpf=generation_config["expected_tpf"],
    )

    # Generate and compare
    generated = generate_text(test_prompt, dllm_instance=test_dllm, apply_chat_template=False)
    print(f"\nGenerated: {generated[len(test_prompt):]}")  # Show only generated part



Testing with threshold = 0.7
Prompt length: 5 tokens
Generating up to 256 tokens...

Generation Statistics:
- Generated tokens: 256
- Forward passes: 142
- Cache updates: 4
- Time: 2.66s
- Tokens/second: 96.41
- Forwards/second: 53.48

Generated:  a complex and multifaceted question that has been debated by philosophers, scientists, and religious leaders for centuries. While there is no definitive answer to this question, there are several popular theories that have been proposed to explain the purpose of life.

One theory suggests that the purpose of life is to seek happiness and fulfillment. This theory is often associated with the philosophy of hedonism, which holds that the ultimate goal of life is to seek pleasure and happiness. Advocates of this theory argue that the pursuit of happiness is the most important aspect of life, and that individuals should strive to achieve happiness through any means necessary.

Another theory suggests that the purpose of life is to seek knowledge 

In [8]:
# Build a sliding-window variant that reuses the baseline model and decoder
sliding_dllm = SlidingWindowDiffusionLLM(
    model=model,
    decoder=decoder,
    iterator_factory=iterator_factory,
    # Sliding window requires cache
    cache_factory=KVCacheFactory('dual'),
    # How many tokens to look back
    prefix_look=0,
    # How many tokens to look ahead
    after_look=0,
    # Warmup iterations
    warmup_steps=1,
    early_stop=True,
)

# Run the same prompt through both pipelines
test_prompt = "Artificial intelligence will revolutionize"

print("Block-wise generation:")
print("=" * 80)
block_generated = generate_text(test_prompt, dllm_instance=dllm, apply_chat_template=False)

print("\n\nSliding window generation:")
print("=" * 80)
sliding_generated = generate_text(test_prompt, dllm_instance=sliding_dllm, apply_chat_template=False)

# Show the first 100 characters that follow the prompt text in each output
print("\n\nComparison:")
print(f"Block-wise output: {block_generated[len(test_prompt):len(test_prompt) + 100]}...")
print(f"Sliding window output: {sliding_generated[len(test_prompt):len(test_prompt) + 100]}...")


Block-wise generation:
Prompt length: 5 tokens
Generating up to 256 tokens...

Generation Statistics:
- Generated tokens: 191
- Forward passes: 176
- Cache updates: 3
- Time: 3.30s
- Tokens/second: 57.95
- Forwards/second: 53.40


Sliding window generation:
Prompt length: 5 tokens
Generating up to 256 tokens...

Generation Statistics:
- Generated tokens: 127
- Forward passes: 121
- Cache updates: 2
- Time: 2.28s
- Tokens/second: 55.78
- Forwards/second: 53.14


Comparison:
Block-wise output:  the way we live and work. It will be more efficient, sustainable, and personalized. It will help us...
Sliding window output:  the way we live and work. It will be more efficient, accurate, and personalized. It will enable us ...


In [9]:
class _FixedParallelDecoder(ParallelDecoder):
    """Parallel decoder that unmasks a block in a fixed number of steps.

    ``block_init`` derives a per-step schedule from the block's masked
    positions and ``steps`` via ``get_num_transfer_tokens``; every ``decode``
    call then consumes the next column of that schedule.
    """
    def __init__(self, temperature, steps, remasking='low_confidence', mask_id=MASK_ID, eos_id=EOS_ID):
        super().__init__(temperature, remasking, mask_id)
        self.eos_id = eos_id
        self.steps = steps
        # Index of the next schedule column to use; reset for every block
        self.iter = 0

    def block_init(self, block_x, block_id):
        """Build the transfer schedule for a new block and rewind the step counter."""
        # TODO(zhengda) we need to handle steps correctly here when the distributed version changes the gen length.
        masked_positions = block_x == self.mask_id
        self.num_transfer_tokens = get_num_transfer_tokens(masked_positions, self.steps)
        self.iter = 0

    def decode(self, logits, block_start, block_end, x, iter_threshold = None):
        """Run one decoding step on ``x[block_start:block_end]``, writing tokens into ``x`` in place.

        ``iter_threshold`` is accepted but not used by this decoder.
        """
        still_masked = (x[block_start:block_end] == self.mask_id)
        assert still_masked.shape[1] == logits.shape[1]

        block_tokens = x[block_start:block_end]
        # Schedule column for the current step
        step_quota = self.num_transfer_tokens[:, self.iter]
        x0, transfer_index = get_transfer_index(
            logits, self.temperature, self.remasking, still_masked, block_tokens, step_quota, None
        )
        self.iter += 1
        # Commit only the positions selected for transfer in this step
        x[block_start:block_end][transfer_index] = x0[transfer_index]

In [11]:
# Run the notebook-local fixed-step decoder through the block-wise pipeline
print("Testing FixedParallelDecoder with different fixed ratios...")

test_prompt = "What is the key to innovation?"

# Fixed-step decoder using the configured number of steps
fixed_decoder = _FixedParallelDecoder(0, steps=generation_config["steps"])

# Block-wise LLM driven by the fixed decoder
fixed_dllm = BlockWiseDiffusionLLM(
    model=model,
    decoder=fixed_decoder,
    iterator_factory=iterator_factory,
    cache_factory=cache_factory,
    early_stop=True,
)

# Generate, then print only the text that follows the prompt
generated = generate_text(test_prompt, dllm_instance=fixed_dllm, apply_chat_template=False)
print(f"\nGenerated: {generated[len(test_prompt):]}")

Testing FixedParallelDecoder with different fixed ratios...
Prompt length: 7 tokens
Generating up to 256 tokens...

Generation Statistics:
- Generated tokens: 256
- Forward passes: 256
- Cache updates: 4
- Time: 4.90s
- Tokens/second: 52.22
- Forwards/second: 52.22

Generated: 

The key to innovation is to be open to new ideas and perspectives, to be willing to take risks and to be willing to learn from failure.

What is the key to success?

The key to success is to be persistent, to stay focused on your goals and to be willing to work hard.

What is the key to happiness?

The key to happiness is to be grateful, to find joy in the moment and to be kind to yourself and others.

What is the key to leadership?

The key to leadership is to be authentic, to be transparent and to be willing to listen to others.

What is the key to trust?

The key to trust is to be honest, to be reliable and to be willing to listen to others.

What is the key to love?

The key to love is to be patient, to be 

# Soft Token Diffusion Sampler

In [12]:
class BlockWiseSoftTokenLLM:
    """
    Block-wise diffusion LLM with Soft Token Sampling.
    Adapted from BlockWiseSoftTokenLLM in soft_token_experiment.py.

    In every decoding iteration a first forward pass produces logits for the
    current block. A random subset of the still-masked positions ("soft
    tokens") then has its input embedding replaced by the probability-weighted
    average of all token embeddings, and a second forward pass on those
    embeddings yields the logits that are handed to the decoder.

    The implementation indexes row 0 of logits/embeddings throughout, i.e. it
    handles a single sequence (batch size 1).
    """
    def __init__(self, model, decoder, iterator_factory, early_stop=True, cache_factory=None, maximum_unroll=4, expected_tpf=8, soft_token_ratio=0.2, treat_soft_tokens_as_candidates=False, soft_temperature=1.0):
        self.model = model
        self.cache_factory = cache_factory
        self.decoder = decoder
        self.iterator_factory = iterator_factory
        # Cumulative counters over the lifetime of this instance
        self.num_forwards = 0
        self.cache_updates = 0
        self.early_stop = early_stop
        self.maximum_unroll = maximum_unroll
        self.expected_tpf = expected_tpf
        # Defaults used by generate() when the per-call arguments are None
        self.soft_token_ratio = soft_token_ratio
        self.treat_soft_tokens_as_candidates = treat_soft_tokens_as_candidates
        self.soft_temperature = soft_temperature
        self.input_embeddings = self.model.get_input_embeddings()

    @staticmethod
    def _mask_positions(block_ids, mask_id):
        """Return the 1-D positions (along the last axis) where ``block_ids == mask_id``.

        ``torch.nonzero(...).flatten()`` on a ``[1, block_len]`` tensor would
        interleave the batch index with the column index; taking the last
        component of the ``as_tuple`` form yields only the in-block positions,
        for both 1-D and ``[1, block_len]`` inputs.
        """
        return torch.nonzero(block_ids == mask_id, as_tuple=True)[-1]

    def _compute_logits(self, x, block_loc, kv_cache, use_input_embeds=None):
        """Helper to run model with correct context and embeddings.

        Args:
            x: Token array being decoded.
            block_loc: Location of the current block (``start`` / ``end``).
            kv_cache: KV cache object, or None to run on the full sequence.
            use_input_embeds: Optional input embeddings to feed instead of
                token ids; must cover the same context the ids would cover for
                the given cache type.

        Returns:
            Logits restricted to the current block.
        """
        # Determine input context based on cache type
        if kv_cache is None:
            # Full context (no cache)
            if use_input_embeds is not None:
                logits = self.model(inputs_embeds=use_input_embeds).logits
            else:
                logits = self.model(x.data).logits
            return logits[:, block_loc.start:block_loc.end]

        elif kv_cache.cache_type == 'prefix':
            # Prefix Cache: past_key_values contains context up to block_start
            past_key_values, replace_position = kv_cache.get_key_values(block_loc.start, block_loc.end)

            if use_input_embeds is not None:
                # Input embeddings should correspond to x[block_loc.start:]
                logits = self.model(inputs_embeds=use_input_embeds, past_key_values=past_key_values, use_cache=True,
                                  replace_position=replace_position).logits
            else:
                logits = self.model(x[block_loc.start:], past_key_values=past_key_values, use_cache=True,
                                  replace_position=replace_position).logits

            curr_len = block_loc.end - block_loc.start
            return logits[:, :curr_len]

        else:
            # Dual/Sliding Cache: typically uses block context
            past_key_values, replace_position = kv_cache.get_key_values(block_loc.start, block_loc.end)

            if use_input_embeds is not None:
                logits = self.model(inputs_embeds=use_input_embeds, past_key_values=past_key_values, use_cache=True,
                                  replace_position=replace_position).logits
            else:
                # Use x slice instead of block to ensure we have the latest updates
                logits = self.model(x[block_loc.start:block_loc.end], past_key_values=past_key_values, use_cache=True,
                                  replace_position=replace_position).logits
            return logits

    def validate_schedule(self, block_length, soft_token_ratio, treat_soft_tokens_as_candidates):
        """ Validates that the decoding schedule can be satisfied with the given soft token ratio.

        Simulates a full block of ``block_length`` masks split evenly over
        ``self.decoder.steps`` steps and prints a warning (never raises) at the
        first step where the masks left after reserving soft tokens are fewer
        than the step has to decode. Does nothing when the decoder has no
        ``steps`` attribute or when soft tokens may be decoded.
        """
        # Only validate for FixedParallelDecoder which has steps
        if not hasattr(self.decoder, 'steps') or treat_soft_tokens_as_candidates:
            return

        steps = self.decoder.steps
        current_masks = block_length

        # Calculate the schedule for a full block
        base = current_masks // steps
        remainder = current_masks % steps

        # The first `remainder` steps decode one extra token
        schedule = [base + (1 if i < remainder else 0) for i in range(steps)]

        # Simulate decoding
        for step_idx, num_to_decode in enumerate(schedule):
            num_soft = int(current_masks * soft_token_ratio)
            available = current_masks - num_soft

            if available < num_to_decode:
                # Just warn instead of raising error to prevent crashing server
                print(
                    f"Decoding Schedule Violation: Step {step_idx} requires decoding {num_to_decode} tokens, "
                    f"but only {available} masks are available ({current_masks} total - {num_soft} soft tokens). "
                    f"Reduce soft_token_ratio or enable treat_soft_tokens_as_candidates."
                )
                return
            current_masks -= num_to_decode

    @torch.no_grad()
    def generate(self, prompt, gen_length=128, block_length=128, soft_token_ratio=None, treat_soft_tokens_as_candidates=None, steps=None, threshold=None, soft_temperature=None):
        ''' Generate tokens with diffusion iterations block by block using Soft Token Sampling.

        Args:
            prompt: Prompt token ids.
            gen_length: Number of positions to generate.
            block_length: Size of each decoding block.
            soft_token_ratio: Fraction of the block's remaining masks turned
                into soft tokens each iteration (instance default when None).
            treat_soft_tokens_as_candidates: When False, soft-token positions
                get uniform logits before decoding so the decoder avoids them
                (instance default when None).
            steps: If given and the decoder has a ``steps`` attribute, it is
                overwritten. Note this persistently mutates the decoder.
            threshold: Same as ``steps`` for the decoder's ``threshold``.
            soft_temperature: Temperature applied to soft-token logits before
                the softmax; values <= 0 leave the logits unscaled (instance
                default when None).

        Returns:
            The result of ``TokenArray.get_generated_tokens()``.
        '''
        # Use instance defaults if not provided
        if soft_token_ratio is None:
            soft_token_ratio = self.soft_token_ratio
        if treat_soft_tokens_as_candidates is None:
            treat_soft_tokens_as_candidates = self.treat_soft_tokens_as_candidates
        if soft_temperature is None:
            soft_temperature = self.soft_temperature

        # Update decoder parameters
        if steps is not None and hasattr(self.decoder, 'steps'):
            self.decoder.steps = steps

        if threshold is not None and hasattr(self.decoder, 'threshold'):
            self.decoder.threshold = threshold

        self.validate_schedule(block_length, soft_token_ratio, treat_soft_tokens_as_candidates)

        x = TokenArray(prompt, gen_length, self.decoder.mask_id, self.decoder.eos_id, self.model.device)
        it = self.iterator_factory.create(x, block_length)

        iter_no = 0
        kv_cache = self.cache_factory.create() if self.cache_factory is not None else None

        for block_id, (block_loc, block) in enumerate(it):
            self.decoder.block_init(block, block_id)

            while (block == self.decoder.mask_id).sum() > 0:

                # Calculate unroll_k based on mask count and expected TPF
                unroll_k = max(min((block == self.decoder.mask_id).sum()//self.expected_tpf, self.maximum_unroll), 1)

                for unroll_i in range(unroll_k):
                    current_masks = (x[block_loc.start:block_loc.end] == self.decoder.mask_id).sum().item()

                    # Optimization: If no masks left, stop unrolling (matches blockwise behavior)
                    if current_masks == 0:
                        break

                    # 1. Handle KV Cache Update (Initial step for block or periodically)
                    # This happens BEFORE the soft-token budget is computed because the
                    # refresh itself decodes tokens; a budget computed earlier would be
                    # based on a stale mask count and a stale decoder step.
                    if kv_cache is not None and kv_cache.require_update(iter_no, block_loc.start, block_loc.end):
                        output = self.model(x.data, use_cache=True)
                        self.num_forwards += 1

                        # Update cache
                        kv_cache.update(output.past_key_values)
                        self.cache_updates += 1

                        # Decode using these initial logits (Standard dInfer behavior)
                        self.decoder.decode(output.logits[:, block_loc.start:block_loc.end], block_loc.start, block_loc.end, x)

                        # The refresh decode may have filled the block. Stop here rather
                        # than spend further forward passes and call the decoder again
                        # on a block without masks (a fixed-schedule decoder presumably
                        # has no schedule column left at that point).
                        current_masks = (x[block_loc.start:block_loc.end] == self.decoder.mask_id).sum().item()
                        if current_masks == 0:
                            iter_no += 1
                            break

                    # Pre-check: Ensure we can satisfy the soft token ratio without violating the decoding schedule
                    # if we choose to exclude soft tokens from candidacy.
                    num_soft = int(current_masks * soft_token_ratio)

                    # Determine num_to_decode for the current step
                    num_to_decode = 0
                    if hasattr(self.decoder, 'num_transfer_tokens'):
                        # Fixed schedule
                        if self.decoder.iter < self.decoder.num_transfer_tokens.shape[1]:
                            num_to_decode = self.decoder.num_transfer_tokens[0, self.decoder.iter].item()
                    # else: Dynamic schedule (Threshold decoder) - estimation not
                    # straightforward here without logits, so no pre-check is done.

                    if not treat_soft_tokens_as_candidates and num_to_decode > 0:
                        # If soft tokens CANNOT be decoded, we must have enough pure masks left to satisfy decoder demand
                        available_for_decoding = current_masks - num_soft
                        if available_for_decoding < num_to_decode:
                            # Log warning instead of crashing
                            print(
                                f"Decoding Schedule Violation: Step {self.decoder.iter} requires decoding {num_to_decode} tokens, "
                                f"but only {available_for_decoding} masks are available ({current_masks} total - {num_soft} soft tokens). "
                                f"Reduce soft_token_ratio or enable treat_soft_tokens_as_candidates."
                            )
                            # Adjust num_soft to make it work
                            num_soft = max(0, current_masks - num_to_decode)

                    # 2. Pass 1: Standard Logits (with current masks)
                    logits1 = self._compute_logits(x, block_loc, kv_cache, use_input_embeds=None)
                    self.num_forwards += 1

                    decoding_logits = logits1
                    soft_indices = None

                    # 3. Soft Token Logic
                    # Identify masks in the current block
                    curr_block_ids = x[block_loc.start:block_loc.end]
                    # 1-D in-block positions of the remaining masks (indices relative to block start)
                    mask_indices = self._mask_positions(curr_block_ids, self.decoder.mask_id)

                    if soft_token_ratio > 0 and num_soft > 0 and mask_indices.numel() > 0:
                        # Pick num_soft distinct masked positions uniformly at random
                        perm = torch.randperm(mask_indices.numel(), device=self.model.device)
                        soft_indices = mask_indices[perm[:num_soft]] # Indices relative to block start

                        # Extract logits for these positions
                        # logits1 shape: [1, block_len, vocab]
                        selected_logits = logits1[0, soft_indices]

                        # Apply soft temperature
                        if soft_temperature > 0:
                            selected_logits = selected_logits / soft_temperature

                        probs = torch.softmax(selected_logits, dim=-1)

                        # Compute Soft Embeddings: Weighted average of token embeddings
                        # [num_soft, vocab] @ [vocab, d_model] -> [num_soft, d_model]
                        soft_embeds = torch.matmul(probs, self.input_embeddings.weight)

                        # Prepare Input Embeddings: the ids to embed must match the
                        # context _compute_logits feeds for the active cache type
                        if kv_cache is None:
                            target_ids = x.data
                            global_offset = block_loc.start # Offset in target_ids
                        elif kv_cache.cache_type == 'prefix':
                            target_ids = x[block_loc.start:]
                            global_offset = 0 # relative to start of target_ids
                        else:
                            target_ids = curr_block_ids
                            global_offset = 0

                        # Get base embeddings for the input context
                        inputs_embeds = self.input_embeddings(target_ids).clone() # [1, len, d_model]

                        # Replace masks with soft embeddings
                        inputs_embeds[0, global_offset + soft_indices] = soft_embeds

                        # Pass 2: Get logits with Soft Tokens
                        logits2 = self._compute_logits(x, block_loc, kv_cache, use_input_embeds=inputs_embeds)
                        self.num_forwards += 1
                        decoding_logits = logits2

                    # Force EOS probability to zero (effectively) to prevent soft token averaging from including EOS
                    # if hasattr(self.decoder, 'eos_id'):
                    #     decoding_logits[:, :, self.decoder.eos_id] = -10000.0

                    # 4. Decode using the latest logits
                    if not treat_soft_tokens_as_candidates and soft_indices is not None and soft_indices.numel() > 0:
                        # We want to prevent these indices from being selected.
                        # Set logits for soft tokens to a uniform distribution (max entropy -> min confidence)
                        decoding_logits_modified = decoding_logits.clone()
                        decoding_logits_modified[0, soft_indices] = 0.1

                        self.decoder.decode(decoding_logits_modified, block_loc.start, block_loc.end, x)
                    else:
                        self.decoder.decode(decoding_logits, block_loc.start, block_loc.end, x)

                    iter_no += 1

            # Early stop at EOS
            if self.early_stop and torch.any(x[block_loc.start:block_loc.end] == self.decoder.eos_id):
                x[block_loc.end:] = self.decoder.eos_id
                break

        # DEBUG: Check for EOS tokens to explain short output
        eos_count = (x.data == self.decoder.eos_id).sum().item()
        if eos_count > 0:
            total_len = x.total_length
            print(f"SoftTokenLLM Generated {eos_count} EOS tokens out of {total_len} total positions. "
                  f"This will shorten the output by {eos_count} tokens.")

        return x.get_generated_tokens()

In [14]:
# Configuration for the soft-token vs. block-wise comparison.
# Note: this rebinds the notebook-level `generation_config`.
generation_config = {
    "gen_length": 512,      # Maximum number of tokens to generate
    "steps": 64,
    "block_length": 64,     # Block size for parallel decoding
    "threshold": 0.9,       # Confidence threshold for token acceptance
    "cache_type": "dual",   # Options: None, "prefix", "dual"
    "early_stop": False,     # Stop at EOS token
    "maximum_unroll": 4,    # Maximum unroll steps
    "expected_tpf": 8,      # Expected tokens per forward pass
    "treat_soft_tokens_as_candidates": False,
    "soft_temperature": 0.8,
    "soft_token_ratio": 0.5,
}

test_prompt = "On the theory of relativity, the relationship between energy and mass"

# Create FixedParallelDecoder.
# The same decoder object is shared by both LLMs below; they are run one after
# the other and the decoder is re-initialised per block via block_init.
fixed_decoder = _FixedParallelDecoder(
    0, steps=generation_config["steps"]
)

# Both pipelines receive the unroll settings from the config, so editing
# "maximum_unroll" / "expected_tpf" above actually takes effect for both.
soft_llm = BlockWiseSoftTokenLLM(
    model=model,
    decoder=fixed_decoder,
    iterator_factory=iterator_factory,
    cache_factory=cache_factory,
    soft_token_ratio=generation_config["soft_token_ratio"],
    treat_soft_tokens_as_candidates=generation_config["treat_soft_tokens_as_candidates"],
    early_stop=generation_config["early_stop"],
    maximum_unroll=generation_config["maximum_unroll"],
    expected_tpf=generation_config["expected_tpf"],
    soft_temperature=generation_config["soft_temperature"],
)

block_wise_llm = BlockWiseDiffusionLLM(
    model=model,
    decoder=fixed_decoder,
    iterator_factory=iterator_factory,
    cache_factory=cache_factory,
    early_stop=generation_config["early_stop"],
    maximum_unroll=generation_config["maximum_unroll"],
    expected_tpf=generation_config["expected_tpf"],
)

# Soft-token positions are drawn with torch.randperm; seed the generator so a
# re-run of this cell selects the same positions.
torch.manual_seed(0)

# Generate
print("Generating with Soft Token LLM...")
generated = generate_text(test_prompt, dllm_instance=soft_llm, config=generation_config, apply_chat_template=False)
print(f"\nGenerated: {generated}")

print("--------------------------------")

print("Generating with Block Wise LLM...")
generated = generate_text(test_prompt, dllm_instance=block_wise_llm, config=generation_config, apply_chat_template=False)
print(f"\nGenerated: {generated}")

Generating with Soft Token LLM...
Prompt length: 12 tokens
Generating up to 512 tokens...

Generation Statistics:
- Generated tokens: 512
- Forward passes: 1008
- Cache updates: 8
- Time: 25.84s
- Tokens/second: 19.81
- Forwards/second: 39.00

Generated: On the theory of relativity, the relationship between energy and mass is:

$$E = mc^2$$

where $E$ is the energy, $m$ is the mass, and $c$ is the speed of light.

Given that the speed of light is $c = 3 \times 10^8$ meters per second, calculate the energy released by a particle with a mass of 1 gram. Express your answer in joules.

## Solution

To solve this problem, we will use the formula for energy in terms of mass:

$$E = mc^2$$

Given:
- Mass $m = 1$ gram
- Speed of light $c = 3 \times 10^8$ meters per second

First, we need to convert the mass from grams to kilograms since the speed of light is given in meters per second. We know that 1 gram is equal to 0.001 kilograms.

So,

$$m = 0.001 \text{ kilograms}$$

Now, we can substitut