In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch

from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoConfig, GenerationConfig
from FastDLLM_inferencing.Fast_dLLM_v2_7B.modeling import Fast_dLLM_QwenForCausalLM


# Load LLaDA (the verifier) and place it on `device`; the inference helpers
# build their inputs on the model's device, so the model must live there too.
device = 'cuda'
verifier = AutoModel.from_pretrained(
    'GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True, dtype=torch.bfloat16
).to(device).eval()
verifier_tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True)


# Load Fast-dLLM v2 (the drafter).
model_name = "Efficient-Large-Model/Fast_dLLM_7B"

drafter_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Config from the Hub. NOTE: trust_remote_code=True permits the repo's custom
# config code (if any) to execute — this is NOT a "no remote code" load.
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

# Local modeling class + Hub weights, so the locally checked-out generate()
# implementation is the one that runs.
drafter = Fast_dLLM_QwenForCausalLM.from_pretrained(
    model_name,
    config=config,
    trust_remote_code=True,
    dtype="auto",
    device_map="auto",  # downloads weights from the Hub on first use
)

# (optional) generation parameters shipped with the repo
gen_config = GenerationConfig.from_pretrained(model_name)
drafter.generation_config = gen_config

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 6/6 [00:00<00:00, 150.15it/s]
Fetching 4 files: 100%|██████████| 4/4 [01:55<00:00, 28.96s/it] 
Loading checkpoint shards: 100%|██████████| 4/4 [00:16<00:00,  4.00s/it]


In [15]:
!nvidia-smi

Thu Nov  6 18:00:30 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.124.06             Driver Version: 570.124.06     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A40                     On  |   00000000:23:00.0 Off |                    0 |
|  0%   44C    P0             80W /  300W |   30847MiB /  46068MiB |      0%   E. Process |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
from LLaDA.generate import generate

# llada inference
def llada_inf(query, tokenizer=None, model=None, gen_length=256, steps=1,
              block_length=None, temperature=0.0, remasking='low_confidence'):
    """Generate a chat reply to `query` with LLaDA (the verifier).

    Args:
        query: user message text.
        tokenizer: chat tokenizer; defaults to the global `verifier_tokenizer`.
        model: LLaDA model; defaults to the global `verifier`.
        gen_length: number of tokens to generate.
        steps: number of sampling steps forwarded to `generate`. The default (1)
            is kept from the original code — NOTE(review): confirm a single
            step is intended.
        block_length: block size forwarded to `generate`; defaults to
            `gen_length` (one block covering the whole generation).
        temperature, remasking: forwarded unchanged to `generate`.

    Returns:
        The decoded reply with the prompt tokens stripped and special tokens skipped.
    """
    # Resolve defaults at call time so the helper uses whatever is loaded now.
    tokenizer = verifier_tokenizer if tokenizer is None else tokenizer
    model = verifier if model is None else model
    if block_length is None:
        block_length = gen_length

    messages = [{"role": "user", "content": query}]
    chat_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    # Build the prompt on the model's own device to avoid a CPU/GPU mismatch.
    prompt_ids = torch.tensor(tokenizer(chat_text)['input_ids'], device=model.device).unsqueeze(0)

    out = generate(
        model,
        prompt_ids,
        steps=steps,
        gen_length=gen_length,
        block_length=block_length,
        temperature=temperature,
        remasking=remasking,
    )

    # Only decode the generated continuation, not the echoed prompt.
    return tokenizer.batch_decode(out[:, prompt_ids.shape[1]:], skip_special_tokens=True)[0]


In [9]:
def fdllm_inf(prompt, tokenizer, model, max_new_tokens=256, small_block_size=8,
              threshold=0.95, steps=5, return_text=False):
    """Generate a chat reply to `prompt` with Fast-dLLM v2 (the drafter).

    Args:
        prompt: user message text.
        tokenizer: chat tokenizer for the drafter.
        model: drafter model whose `generate` accepts the Fast-dLLM keyword
            arguments below.
        max_new_tokens, small_block_size, threshold, steps: forwarded unchanged
            to `model.generate`; defaults equal the values previously hard-coded.
        return_text: if True, return only the generated continuation decoded to
            text (prompt stripped, special tokens skipped) instead of token ids.

    Returns:
        The id tensor returned by `model.generate` (prompt ids followed by the
        generated ids), or a string when `return_text=True`.
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    inputs = tokenizer([text], return_tensors="pt").to(model.device)
    prompt_ids = inputs["input_ids"]

    # Fast-dLLM v2 parallel decoding
    gen_ids = model.generate(
        prompt_ids,
        tokenizer=tokenizer,
        max_new_tokens=max_new_tokens,
        small_block_size=small_block_size,
        threshold=threshold,
        steps=steps,
    )

    if return_text:
        return tokenizer.decode(gen_ids[0][prompt_ids.shape[1]:], skip_special_tokens=True)
    return gen_ids

In [10]:
# Keep the draft ids under a name so later cells don't depend on `_` / `Out[N]`.
draft_ids = fdllm_inf("Hi what can you do?", drafter_tokenizer, drafter)
draft_ids

tensor([[151644,   8948,    198,   2610,    525,    264,  10950,  17847,     13,
         151645,    198, 151644,    872,    198,  13048,   1128,    646,    498,
            653,     30, 151645,    198, 151644,  77091,    198,   2121,    264,
         151665,   1614,     11,    358, 151665, 151665, 151665, 151665, 151665,
         151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665,
         151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665,
         151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665,
         151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665,
         151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665,
         151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665,
         151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665,
         151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665,
         151665, 151665, 151

In [None]:
def(draft_tensor):
    