# ⚡ WeDLM Inference

**Key Formula - Adjusted Entropy:**
$$\tilde{H}_j = H(P_j) + \lambda(j - j_{\min}), \quad \text{Fill if } \tilde{H}_j < \tau$$
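
As a rough illustration of the formula, the sketch below computes the adjusted entropy for three masked positions and applies the fill rule. It assumes $j_{\min}$ is the leftmost still-masked position and measures entropy in nats; the probability values and the helper names `entropy` and `adjusted_entropy` are made up for this example and are not part of the wedlm package.

```python
import math

def entropy(probs):
    """Shannon entropy (in nats) of a discrete distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)

def adjusted_entropy(dists, positions, lam=0.02):
    """H~_j = H(P_j) + lam * (j - j_min) for each masked position j."""
    j_min = min(positions)
    return [entropy(p) + lam * (j - j_min) for p, j in zip(dists, positions)]

# Toy 3-way distributions for three masked positions (illustrative values only)
dists = [
    [0.98, 0.01, 0.01],  # confident
    [0.40, 0.30, 0.30],  # uncertain
    [0.95, 0.04, 0.01],  # fairly confident, but further to the right
]
positions = [0, 1, 2]
tau = 0.4

h_adj = adjusted_entropy(dists, positions)
fill = [j for j, h in zip(positions, h_adj) if h < tau]
print([round(h, 3) for h in h_adj])
print("fill positions:", fill)  # positions 0 and 2 pass; position 1 stays masked
```

Position 1 stays masked because its distribution is nearly flat, while position 2 still passes the threshold despite the small positional penalty.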

## 📋 Quick Start
Run cells 1-3 for a quick demo of WeDLM parallel decoding.


In [None]:
# 1️⃣ Install Dependencies & Clone Repo
!pip install -q torch transformers accelerate

# Clone the repository (for wedlm package)
import os
if not os.path.exists('05_WeDLM_Reconciling_Diffusion_with_Causal_Attention'):
    !git clone https://github.com/Gaurav14cs17/05_WeDLM_Reconciling_Diffusion_with_Causal_Attention.git

# Add to Python path
import sys
sys.path.insert(0, '05_WeDLM_Reconciling_Diffusion_with_Causal_Attention')

# Try to import wedlm package
try:
    from wedlm import LLM, SamplingParams
    WEDLM_AVAILABLE = True
    print("✅ wedlm package imported successfully!")
except ImportError as e:
    WEDLM_AVAILABLE = False
    print(f"⚠️ wedlm import failed: {e}")
    print("Using standalone implementation...")

# 2️⃣ Load Model (standalone fallback)
import torch
import torch.nn.functional as F
import time
from transformers import AutoTokenizer, AutoModelForCausalLM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
MASK_TOKEN = "<|mask|>"
W, TAU, LAMBDA = 16, 0.4, 0.02  # Window size, threshold, position penalty

tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(MODEL, trust_remote_code=True,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32).to(device)
if MASK_TOKEN not in tokenizer.get_vocab():
    tokenizer.add_special_tokens({"additional_special_tokens": [MASK_TOKEN]})
    model.resize_token_embeddings(len(tokenizer))
MASK_ID = tokenizer.convert_tokens_to_ids(MASK_TOKEN)
print(f"✅ Model loaded on {device}")


In [None]:
@torch.no_grad()
def generate(prompt, max_tokens=100):
    model.eval()
    gen = tokenizer.encode(prompt, return_tensors="pt").to(device)[0].tolist()
    win, flags = [MASK_ID]*W, [True]*W
    steps, toks = 0, 0
    
    while toks < max_tokens:
        steps += 1
        logits = model(torch.tensor([gen + win], device=device)).logits[0]
        idx = [i for i,f in enumerate(flags) if f]
        if not idx: break
        
        mlogits = torch.stack([logits[len(gen)+i-1] for i in idx])
        probs = F.softmax(mlogits, dim=-1)
        H = -(probs * torch.log(probs + 1e-10)).sum(-1)
        pos = torch.tensor(idx, device=device, dtype=torch.float)
        Hadj = H + LAMBDA * (pos - pos[0])
        
        fill = (Hadj < TAU).nonzero(as_tuple=True)[0]
        if len(fill) == 0: fill = Hadj.argmin().unsqueeze(0)
        
        for k in fill.tolist():
            win[idx[k]] = mlogits[k].argmax().item()
            flags[idx[k]] = False
        
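        # Force progress: if slot 0 is still masked, fill it greedily
        # (mlogits[0] belongs to slot 0 here) so a mask is never committed
        if flags[0]:
            win[0], flags[0] = mlogits[0].argmax().item(), False
        commit = next((i for i,f in enumerate(flags) if f), len(win))
        commit = min(commit, max_tokens - toks)  # respect the token budget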
        gen.extend(win[:commit])
        toks += commit
        if tokenizer.eos_token_id in win[:commit]: break
        win = win[commit:] + [MASK_ID]*commit
        flags = flags[commit:] + [True]*commit
    
    return tokenizer.decode(gen, skip_special_tokens=True), {"steps": steps, "tokens": toks, "tok/step": toks / max(steps, 1)}

print("✅ Generator ready")


In [None]:
# Demo
prompts = ["Solve: 15 × 7 + 23 = ", "The capital of France is", "def fibonacci(n):"]
for p in prompts:
    t0 = time.time()
    out, stats = generate(p, 50)
    print(f"\n{'='*50}\nPrompt: {p}\nOutput: {out[len(p):]}\nStats: {stats}, Time: {time.time()-t0:.2f}s")


---

## 📚 Detailed Version (Alternative Implementation)

The cells above provide a compact demo. Below is a more detailed version with extensive comments.

## 📐 The Algorithm

```
┌─────────────────────────────────────────────────────────────────┐
│                    WeDLM Inference Flow                         │
│                                                                 │
│  Prefix (KV Cache)          Window (Processing)                 │
│  [The][quick][brown]        [M][M][M][M][M][M][M][M]            │
│                              │                                  │
│                              ▼                                  │
│                    1. Forward Pass (causal attention)           │
│                              │                                  │
│                              ▼                                  │
│                    2. Compute Entropy H(P) for each mask        │
│                              │                                  │
│                              ▼                                  │
│                    3. Fill positions where H̃ < threshold        │
│                              │                                  │
│                              ▼                                  │
│                    4. Commit prefix, slide window               │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
```
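
Step 4 of the flow above (commit the filled prefix, then slide the window) can be sketched on plain Python lists, without a model. This is only an illustrative toy: `MASK` and `commit_and_slide` are hypothetical names, and `None` stands in for the mask token ID.

```python
MASK = None  # hypothetical stand-in for the mask token ID

def commit_and_slide(window):
    """Commit the leading run of filled slots, then pad the window with new masks."""
    # Index of the first still-masked slot; everything before it is committed
    commit = next((i for i, tok in enumerate(window) if tok is MASK), len(window))
    committed = window[:commit]
    new_window = window[commit:] + [MASK] * commit
    return committed, new_window

window = ["jumps", "over", MASK, "lazy", MASK, MASK, MASK, MASK]
committed, window = commit_and_slide(window)
print(committed)  # ['jumps', 'over']
print(window)     # [None, 'lazy', None, None, None, None, None, None]
```

Note that `"lazy"` was filled out of order, so it stays in the window until the mask in front of it is resolved; only a contiguous filled prefix is committed.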


In [None]:
# 1️⃣ Install Dependencies & Setup
!pip install -q torch transformers accelerate
!pip install -q flash-attn --no-build-isolation 2>/dev/null || echo "FlashAttn not available (optional)"

# Clone repo and setup wedlm import
import os, sys
if not os.path.exists('05_WeDLM_Reconciling_Diffusion_with_Causal_Attention'):
    !git clone https://github.com/Gaurav14cs17/05_WeDLM_Reconciling_Diffusion_with_Causal_Attention.git
sys.path.insert(0, '05_WeDLM_Reconciling_Diffusion_with_Causal_Attention')

# Import wedlm
try:
    from wedlm import LLM, SamplingParams
    print("✅ wedlm package imported!")
except ImportError:
    print("⚠️ Using standalone implementation")

!nvidia-smi


In [None]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 2️⃣ Load Model (Using Qwen as base for demo, replace with WeDLM when available)
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"  # Replace with "tencent/WeDLM-8B-Instruct" for full model

print(f"Loading {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, trust_remote_code=True,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
).to(device)

# Add mask token
MASK_TOKEN = "<|mask|>"
if MASK_TOKEN not in tokenizer.get_vocab():
    tokenizer.add_special_tokens({"additional_special_tokens": [MASK_TOKEN]})
    model.resize_token_embeddings(len(tokenizer))
MASK_ID = tokenizer.convert_tokens_to_ids(MASK_TOKEN)

print(f"✅ Model loaded! Mask token ID: {MASK_ID}")


In [None]:
# 3️⃣ WeDLM Core Functions

def compute_entropy(logits):
    """
    Compute entropy: H(P) = -Σ p_i log(p_i)
    
    Low entropy = high confidence = safe to commit
    """
    probs = F.softmax(logits, dim=-1)
    return -(probs * torch.log(probs + 1e-10)).sum(dim=-1)

def select_positions_to_fill(entropy, remaining_indices, threshold=0.4, pos_penalty=0.02):
    """
    Select which mask positions to fill based on adjusted entropy.
    
    Adjusted entropy: H̃_j = H(P_j) + λ(j - j_min)
    
    Args:
        entropy: Raw entropy values
        remaining_indices: Position indices in window
        threshold: Fill if H̃ < threshold
        pos_penalty: λ - penalty for later positions
    """
    positions = torch.tensor(remaining_indices, device=entropy.device, dtype=torch.float)
    base_pos = positions[0]
    
    # Position penalty encourages left-to-right decoding
    adjusted = entropy + pos_penalty * (positions - base_pos)
    
    # Select low-entropy positions
    selected = (adjusted < threshold).nonzero(as_tuple=True)[0]
    
    if len(selected) == 0:
        # Fallback: select minimum entropy position
        selected = adjusted.argmin().unsqueeze(0)
    
    return selected.tolist()

print("✅ Core functions defined")


In [None]:
# 4️⃣ WeDLM Generation Function

@torch.no_grad()
def wedlm_generate(
    model, tokenizer, prompt,
    max_new_tokens=100,
    window_size=16,
    entropy_threshold=0.4,
    pos_penalty=0.02,
    temperature=0.0,
    verbose=False
):
    """
    WeDLM Streaming Parallel Decoding
    
    Args:
        model: Language model
        tokenizer: Tokenizer
        prompt: Input text
        max_new_tokens: Max tokens to generate
        window_size: Sliding window size (W)
        entropy_threshold: τ - threshold for parallel filling
        pos_penalty: λ - position penalty factor
        temperature: Sampling temperature (0 = greedy)
        verbose: Print step-by-step progress
    
    Returns:
        Generated text, stats dictionary
    """
    model.eval()
    mask_id = tokenizer.convert_tokens_to_ids(MASK_TOKEN)
    eos_id = tokenizer.eos_token_id
    
    # Encode prompt
    prompt_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)[0].tolist()
    generated = prompt_ids.copy()
    
    # Initialize window with masks
    window = [mask_id] * window_size
    window_mask_flags = [True] * window_size  # True = mask
    
    tokens_generated = 0
    steps = 0
    total_filled = 0
    
    while tokens_generated < max_new_tokens:
        steps += 1
        
        # Build input: prefix + window
        input_ids = torch.tensor([generated + window], device=device)
        
        # Forward pass
        outputs = model(input_ids)
        logits = outputs.logits[0]  # [seq_len, vocab]
        
        # Get logits for mask positions in window
        prefix_len = len(generated)
        
        # Find remaining mask positions
        mask_indices = [i for i, is_mask in enumerate(window_mask_flags) if is_mask]
        
        if not mask_indices:
            break
        
        # Get logits for mask positions (offset by -1 for next-token prediction)
        mask_logits = torch.stack([logits[prefix_len + i - 1] for i in mask_indices])
        
        # Compute entropy
        entropy = compute_entropy(mask_logits)
        
        # Select positions to fill
        fill_indices = select_positions_to_fill(
            entropy, mask_indices, entropy_threshold, pos_penalty
        )
        
        # Sample tokens
        for idx in fill_indices:
            pos = mask_indices[idx]
            pos_logits = mask_logits[idx]
            
            if temperature > 0:
                probs = F.softmax(pos_logits / temperature, dim=-1)
                token = torch.multinomial(probs, 1).item()
            else:
                token = pos_logits.argmax().item()
            
            window[pos] = token
            window_mask_flags[pos] = False
            total_filled += 1
        
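        # Force progress: if the first slot is still masked, fill it greedily
        # (mask_logits[0] belongs to slot 0 here) so a mask is never committed
        if window_mask_flags[0]:
            window[0] = mask_logits[0].argmax().item()
            window_mask_flags[0] = False
            total_filled += 1
        
        # Find committed prefix (consecutive non-masks from start)
        commit_count = 0
        for i in range(len(window)):
            if not window_mask_flags[i]:
                commit_count += 1
            else:
                break
        
        # Never commit more than the remaining token budget
        commit_count = min(commit_count, max_new_tokens - tokens_generated)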
        
        # Commit to output
        committed = window[:commit_count]
        generated.extend(committed)
        tokens_generated += commit_count
        
        if verbose:
            print(f"Step {steps}: Filled {len(fill_indices)}, Committed {commit_count}")
        
        # Check for EOS
        if eos_id in committed:
            break
        
        # Slide window
        window = window[commit_count:] + [mask_id] * commit_count
        window_mask_flags = window_mask_flags[commit_count:] + [True] * commit_count
    
    # Decode
    output_text = tokenizer.decode(generated, skip_special_tokens=True)
    
    stats = {
        "steps": steps,
        "tokens_generated": tokens_generated,
        "avg_tokens_per_step": tokens_generated / steps if steps > 0 else 0,
        "total_filled": total_filled
    }
    
    return output_text, stats

print("✅ Generation function defined")


In [None]:
# 5️⃣ Demo: Math Problem
print("=" * 70)
print("📝 DEMO: Math Problem Solving")
print("=" * 70)

prompt = "Solve step by step: What is 15 × 7 + 23?"

start = time.time()
output, stats = wedlm_generate(
    model, tokenizer, prompt,
    max_new_tokens=100,
    entropy_threshold=0.5,
    verbose=True
)
elapsed = time.time() - start

print(f"\n{'='*70}")
print(f"Prompt: {prompt}")
print(f"{'='*70}")
print(f"Output: {output[len(prompt):]}")
print(f"{'='*70}")
print(f"📊 Statistics:")
print(f"   • Tokens generated: {stats['tokens_generated']}")
print(f"   • Forward passes: {stats['steps']}")
print(f"   • Avg tokens/step: {stats['avg_tokens_per_step']:.2f}")
print(f"   • Time: {elapsed:.2f}s")
print(f"   • Speed: {stats['tokens_generated']/elapsed:.1f} tok/s")


In [None]:
# 6️⃣ Compare: Different Entropy Thresholds
print("\n" + "=" * 70)
print("🔬 Experiment: Effect of Entropy Threshold")
print("=" * 70)

test_prompt = "The capital of France is"

for threshold in [0.2, 0.4, 0.6, 0.8]:
    output, stats = wedlm_generate(
        model, tokenizer, test_prompt,
        max_new_tokens=30,
        entropy_threshold=threshold
    )
    print(f"\nτ = {threshold}:")
    print(f"   Steps: {stats['steps']}, Avg tok/step: {stats['avg_tokens_per_step']:.2f}")
    # Show only the first 50 characters of the continuation
    print(f"   Output: {output[len(test_prompt):len(test_prompt) + 50]}...")


---

## 📚 Summary

### Key Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| `window_size` | 16 | Number of mask tokens to process in parallel |
| `entropy_threshold` (τ) | 0.4 | Fill positions with H̃ < τ |
| `pos_penalty` (λ) | 0.02 | Penalty for later positions |

### Speed-Quality Tradeoff

- **Lower τ** → More conservative, higher quality, fewer tokens/step
- **Higher τ** → More aggressive, faster, potential quality drop
- **Higher λ** → Stronger left-to-right bias
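
The direction of these effects can be checked on a toy example with no model involved. The sketch below counts how many window positions pass the fill rule for different τ and λ; the entropy values and the helper `count_fillable` are hypothetical, chosen only to show the trend.

```python
def count_fillable(entropies, tau, lam):
    """Number of window positions whose adjusted entropy falls below tau."""
    return sum(1 for j, h in enumerate(entropies) if h + lam * j < tau)

# Hypothetical per-position entropies (nats) for an 8-slot window
entropies = [0.05, 0.10, 0.35, 0.15, 0.55, 0.25, 0.70, 0.30]

# Raising tau lets more positions through
for tau in (0.2, 0.4, 0.6, 0.8):
    print(f"tau={tau}: {count_fillable(entropies, tau, lam=0.02)} positions")

# Raising lam penalizes later positions, so fewer pass
for lam in (0.0, 0.02, 0.05):
    print(f"lam={lam}: {count_fillable(entropies, 0.4, lam)} positions")
```

This only counts positions that pass the threshold in a single step. How many tokens are actually committed also depends on whether the filled positions form a contiguous prefix of the window.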

### Recommended Settings

| Use Case | τ | λ | Expected Speedup |
|----------|---|---|------------------|
| Math/Code | 0.4-0.6 | 0.02 | 3-6× |
| General QA | 0.3-0.4 | 0.02 | 1.5-2× |
| Creative | 0.2-0.3 | 0.01 | 1-1.5× |

---

## 🔗 Resources

- [Paper](https://arxiv.org/abs/2512.22737)
- [Official Code](https://github.com/tencent/WeDLM)
- [HuggingFace Models](https://huggingface.co/collections/tencent/wedlm)
