In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch

from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoConfig, GenerationConfig
from FastDLLM_inferencing.Fast_dLLM_v2_7B.modeling import Fast_dLLM_QwenForCausalLM


# load LLaDa
device = 'cuda'
verifier = AutoModel.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True, dtype=torch.bfloat16)
verifier_tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True)


# load fast dLLM
model_name = "Efficient-Large-Model/Fast_dLLM_7B"

drafter_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# remote config (no remote code execution)
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

# using local class to load remote weights
drafter = Fast_dLLM_QwenForCausalLM.from_pretrained(
    model_name, 
    config=config, 
    trust_remote_code=True,
    dtype="auto",
    device_map="auto",)  # downloads weights from Hub

# (optional) generation parameters from the repo
gen_config = GenerationConfig.from_pretrained(model_name)
drafter.generation_config = gen_config

2025-11-30 18:51:57.185904: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-11-30 18:51:57.247339: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-11-30 18:51:59.240738: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [3]:
# ensure everything is on device
verifier.to(device)
drafter.to(device)

Fast_dLLM_QwenForCausalLM(
  (model): Fast_dLLM_QwenModel(
    (embed_tokens): Embedding(152064, 3584, padding_idx=151645)
    (layers): ModuleList(
      (0-27): 28 x Fast_dLLM_QwenDecoderLayer(
        (self_attn): Fast_dLLM_QwenAttention(
          (q_proj): Linear(in_features=3584, out_features=3584, bias=True)
          (k_proj): Linear(in_features=3584, out_features=512, bias=True)
          (v_proj): Linear(in_features=3584, out_features=512, bias=True)
          (o_proj): Linear(in_features=3584, out_features=3584, bias=False)
        )
        (mlp): Fast_dLLM_QwenMLP(
          (gate_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (up_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (down_proj): Linear(in_features=18944, out_features=3584, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Fast_dLLM_QwenRMSNorm((3584,), eps=1e-06)
        (post_attention_layernorm): Fast_dLLM_QwenRMSNorm((35

In [16]:
# Model wrapper functions for Fast_dLLM and LLaDA
# Add this to your notebook or create a new file: model_wrappers.py

import torch
from LLaDA.generate import generate_per_step


def fastdllm_generate_fn(model, tokenizer, input_ids, num_steps, **kwargs):
    """
    Wrapper for Fast_dLLM generation to match pipeline interface.
    
    Args:
        model: Fast_dLLM model
        tokenizer: Fast_dLLM tokenizer
        input_ids: Input tensor
        num_steps: Number of steps
        **kwargs: Additional arguments (small_block_size, threshold, etc.)
    
    Returns:
        Generated tensor
    """
    # Extract Fast_dLLM specific parameters or use defaults
    small_block_size = kwargs.get('small_block_size', 8)
    threshold = kwargs.get('threshold', 0.95)
    max_new_tokens = kwargs.get('max_new_tokens', 256)
    mask_id = kwargs.get('mask_id', 151665)
    mask_positions = (input_ids == mask_id)
    if mask_positions.any():
        # Find first masked position
        first_mask_idx = torch.where(mask_positions[0])[0][0].item()
        prompt = input_ids[:, :first_mask_idx]
    else:
        # No masks, use entire input as prompt
        prompt = input_ids

    prompt_len = len(prompt)
    
    # Fast_dLLM returns (gen_ids, past_key_values, past_block_key_values)
    output = model.generate(
        input_ids,
        tokenizer=tokenizer,
        max_new_tokens=max_new_tokens,
        block_size=32,
        small_block_size=small_block_size,
        threshold=threshold,
        steps=num_steps,
    )
    
    return output


def llada_generate_fn(model, tokenizer, input_ids, num_steps, **kwargs):
    """
    Wrapper for LLaDA generation to match pipeline interface.
    
    Args:
        model: LLaDA model
        tokenizer: LLaDA tokenizer
        input_ids: Input tensor (should already have prompt + masked tokens)
        num_steps: Number of steps (n)
        **kwargs: Additional arguments (k, gen_length, block_length, etc.)
    
    Returns:
        Generated tensor
    """
    # Extract LLaDA specific parameters
    k = kwargs.get('k', 1)  # tokens per step
    gen_length = kwargs.get('gen_length', 256)
    block_length = kwargs.get('block_length', 256)
    temperature = kwargs.get('temperature', 0.0)
    remasking = kwargs.get('remasking', 'low_confidence')
    mask_id = kwargs.get('mask_id', 126336)
    
    # LLaDA's generate_per_step expects just the prompt, not prompt+masked
    # So we need to extract the prompt part
    # Assuming the masked tokens are at the end
    mask_positions = (input_ids == mask_id)
    if mask_positions.any():
        # Find first masked position
        first_mask_idx = torch.where(mask_positions[0])[0][0].item()
        prompt = input_ids[:, :first_mask_idx]
    else:
        # No masks, use entire input as prompt
        prompt = input_ids
    
    output = generate_per_step(
        model,
        prompt,
        n=num_steps,
        k=k,
        gen_length=gen_length,
        block_length=block_length,
        temperature=temperature,
        remasking=remasking,
        mask_id=mask_id
    )
    
    return output

In [17]:
from dual_pipeline import dual_diffusion_generate
from inference import convert
from verification_algos import confidence_threshold_verification

result = dual_diffusion_generate(
    # Models
    drafter_model=drafter,
    drafter_tokenizer=drafter_tokenizer,
    verifier_model=verifier,
    verifier_tokenizer=verifier_tokenizer,
    
    # Input
    query="Give me a short introduction to large language models.",
    max_new_tokens=256,
    
    # Steps
    num_drafter_steps=16,
    num_verifier_steps=1,
    
    # Mask IDs
    drafter_mask_id=151665,  # Fast_dLLM mask ID
    verifier_mask_id=126336,  # LLaDA mask ID
    
    # Custom generate functions
    drafter_generate_fn=fastdllm_generate_fn,
    verifier_generate_fn=llada_generate_fn,
    
    # Verification (None = use default trust_verifier)
    verification_fn=confidence_threshold_verification,
    
    # Iteration control
    max_iterations=4,  # Single draft-verify pass
    
    # Model-specific kwargs
    small_block_size=8,
    threshold=0.95,
    k=1,
    gen_length=256,
    block_length=256,
    temperature=0.0,
    remasking='low_confidence'
)

print("Generated text:")
print(result['output_text'])
print("\nStats:")
print(result['stats'])

Full decoded base prompt (with special tokens):
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Give me a short introduction to large language models.<|im_end|>
<|im_start|>assistant
|<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK>||<MASK

IndexError: index 304 is out of bounds for dimension 1 with size 256