In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch

from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoConfig, GenerationConfig
from FastDLLM_inferencing.Fast_dLLM_v2_7B.modeling import Fast_dLLM_QwenForCausalLM


# load LLaDa
device = 'cuda'
verifier = AutoModel.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True, dtype=torch.bfloat16)
verifier_tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True)


# load fast dLLM
model_name = "Efficient-Large-Model/Fast_dLLM_7B"

drafter_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# remote config (no remote code execution)
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

# using local class to load remote weights
drafter = Fast_dLLM_QwenForCausalLM.from_pretrained(
    model_name, 
    config=config, 
    trust_remote_code=True,
    dtype="auto",
    device_map="auto",)  # downloads weights from Hub

# (optional) generation parameters from the repo
gen_config = GenerationConfig.from_pretrained(model_name)
drafter.generation_config = gen_config

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 6/6 [00:00<00:00, 148.37it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.22s/it]


In [3]:
# ensure everything is on device
verifier.to(device)
drafter.to(device)

Fast_dLLM_QwenForCausalLM(
  (model): Fast_dLLM_QwenModel(
    (embed_tokens): Embedding(152064, 3584, padding_idx=151645)
    (layers): ModuleList(
      (0-27): 28 x Fast_dLLM_QwenDecoderLayer(
        (self_attn): Fast_dLLM_QwenAttention(
          (q_proj): Linear(in_features=3584, out_features=3584, bias=True)
          (k_proj): Linear(in_features=3584, out_features=512, bias=True)
          (v_proj): Linear(in_features=3584, out_features=512, bias=True)
          (o_proj): Linear(in_features=3584, out_features=3584, bias=False)
        )
        (mlp): Fast_dLLM_QwenMLP(
          (gate_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (up_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (down_proj): Linear(in_features=18944, out_features=3584, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Fast_dLLM_QwenRMSNorm((3584,), eps=1e-06)
        (post_attention_layernorm): Fast_dLLM_QwenRMSNorm((35

In [4]:
def fdllm_inf(prompt, tokenizer, model):
    
    prompt = "Give me a short introduction to large language model."
    messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt},
    {"role": "assistant", "content": "A large language model (LLM) is a type of artificial intelligence model designed to process and generate human"}
    ]

    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False
    )
    inputs = tokenizer([text], return_tensors="pt").to(model.device)

    gen_ids, past_key_values, past_block_key_values = model.generate(
    inputs["input_ids"],
    tokenizer=tokenizer,
    max_new_tokens=256,
    small_block_size=8,
    threshold=0.95,
    steps=16,
    )

    # response = tokenizer.decode(
    #     gen_ids[0][inputs["input_ids"].shape[1]:], 
    #     skip_special_tokens=False
    # )
    # print(response)

    # response = tokenizer.decode(
    #     gen_ids[0][inputs["input_ids"].shape[1]:], 
    #     skip_special_tokens=False
    # )

    return gen_ids

In [None]:
from LLaDA.generate import generate, generate_per_step

# llada inference
def llada_inf(model, tokenizer, context_tensor):
    gen_length = 256

    # m = [{"role": "user", "content": query}]
    # user_input = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
    # input_ids = tokenizer(user_input)['input_ids']
    input_ids = torch.tensor(context_tensor).to(device)
    # query = input_ids

    out = generate_per_step(model, 
        input_ids, 
        n = 1, 
        k = 1,
        gen_length=gen_length, 
        block_length=gen_length, 
        temperature=0.0, 
        remasking='low_confidence')

    answer = tokenizer.batch_decode(out, skip_special_tokens=True)[0]
    print(answer)
    return out


In [12]:
from inference import convert

drafted = fdllm_inf("Hi what can you do?", drafter_tokenizer, drafter)
print(drafted, len(drafted))

drafted_retok = convert(151665,126336, drafted, drafter_tokenizer, verifier_tokenizer)
print(drafted_retok, len(drafted_retok))

verified = llada_inf(verifier, verifier_tokenizer, drafted_retok)
print(verified)

tensor([[151644,   8948,    198,   2610,    525,    264,  10950,  17847,     13,
         151645,    198, 151644,    872,    198,  35127,    752,    264,   2805,
          16800,    311,   3460,   4128,   1614,     13, 151645,    198, 151644,
          77091,    198,     32,   3460,   4128,   1614,    320,   4086,     44,
              8,    374,    264,    943,    315,  20443,  11229,   1614,   6188,
            311,   1882,    323,   6923,   3738, 151645,    198,  11528,     13,
           1084,    525,  11136,  16176,    389,  10951,  14713,    315,   1467,
            821,     11, 151665,   1105,    311,   3535, 151665, 151665, 151665,
         151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665,
         151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665,
         151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665,
         151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665,
         151665, 151665, 151

  input_ids = torch.tensor(context_tensor).to(device)



Top-10 tokens per position at step 0:

Position 133:
  Rank   Token                          Probability 
  ------------------------------------------------
  1      <|endoftext|>                  0.234375     ◄ SELECTED
  2      ,                              0.057373     
  3       and                           0.027954     
  4      .                              0.027100     
  5      \n                             0.024658     
  6       to                            0.016479     
  7      Ms                             0.016479     
  8       of                            0.014526     
  9      <                              0.012817     
  10      language                      0.012024     

Position 134:
  Rank   Token                          Probability 
  ------------------------------------------------
  1      <|endoftext|>                  0.235352     ◄ SELECTED
  2      ,                              0.059570     
  3       and                           0.029907     
 