In [1]:
import torch
import time
from transformers import AutoTokenizer
from model.modeling_llada import LLaDAModelLM
from generate import generate_with_dual_cache, generate_with_dual_cache_tokenskip

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the LLaDA-8B-Instruct checkpoint (bfloat16 weights) and its tokenizer.
device = 'cuda'
model = LLaDAModelLM.from_pretrained(
    'GSAI-ML/LLaDA-8B-Instruct',
    torch_dtype=torch.bfloat16,
)
# Move the loaded model to the target device, then call .eval() on it.
model = model.to(device).eval()
tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct')

Loading checkpoint shards: 100%|██████████| 6/6 [00:00<00:00,  7.36it/s]


In [3]:
# Prepare the input: wrap the question in the chat template, then tokenize it.
prompt = "Who is Newton, physics?"
m = [
    {"role": "user", "content": prompt},
]
text = tokenizer.apply_chat_template(
    m,
    add_generation_prompt=True,
    tokenize=False,  # return the formatted string; tokenization happens below
)
input_ids = (
    torch.tensor(tokenizer(text)['input_ids'])
    .to(device)
    .unsqueeze(0)  # add a leading batch dimension
)

In [4]:
# Baseline: time a single generate_with_dual_cache run and decode its completion.
#
# CUDA work can be queued asynchronously with respect to the host, so the
# device is synchronized before each clock read; otherwise pending GPU work
# could fall outside the measured interval. perf_counter() is used because it
# is a monotonic, high-resolution clock intended for measuring durations.
# NOTE(review): this is the first generation call after model loading, so the
# measured time may include one-off warm-up cost -- consider a warm-up run
# before comparing against later cells.
torch.cuda.synchronize()
start = time.perf_counter()
out1, nfe1 = generate_with_dual_cache(model, input_ids, steps=128, gen_length=128, block_length=32, threshold=0.9)
torch.cuda.synchronize()
t1 = time.perf_counter() - start
# Decode only the tokens that follow the prompt.
ans1 = tokenizer.decode(out1[0, input_ids.shape[1]:], skip_special_tokens=True)
print(f"Baseline: {t1:.2f}s, NFE={nfe1}")
print(ans1)

Baseline: 6.78s, NFE=71
Isaac Newton was an English physicist and mathematician who made significant contributions to the development of classical mechanics and optics. He is best known for his laws of motion, which describe the motion of objects, and his law of universal gravitation, which explains the force of gravity between objects. Newton's work laid the foundation for modern physics and is considered one of the most influential figures in the history of science.


In [5]:
# TokenSkip run (tunable hyperparameters).
SKIP_LAYER_K = 18       # number of leading layers used for the skip decision
SKIP_THRESHOLD = 1      # threshold on the mean cosine similarity
SKIP_OUTLIER = 0.7      # if any layer falls below this value, force computation
GEN_LENGTH = 128        # requested generation length; also used by the shape check below

# Synchronize around the timed region so queued CUDA work is fully counted,
# and use the monotonic perf_counter() clock for the duration.
torch.cuda.synchronize()
start = time.perf_counter()
out2, nfe2 = generate_with_dual_cache_tokenskip(
    model, input_ids, steps=128, gen_length=GEN_LENGTH, block_length=32, threshold=0.9,
    skip_layer_k=SKIP_LAYER_K, skip_threshold=SKIP_THRESHOLD, skip_outlier=SKIP_OUTLIER
)
torch.cuda.synchronize()
t2 = time.perf_counter() - start
# Decode only the tokens that follow the prompt.
ans2 = tokenizer.decode(out2[0, input_ids.shape[1]:], skip_special_tokens=True)
print(f"TokenSkip: {t2:.2f}s, NFE={nfe2}")
print(ans2)
# Check output shapes: both runs should be prompt length + generation length.
print(f"out1.shape: {out1.shape}")  # baseline
print(f"out2.shape: {out2.shape}")  # tokenskip
print(f"input_ids.shape: {input_ids.shape}")
print(f"预期 gen_length: {GEN_LENGTH}")
print(f"实际生成长度: {out2.shape[1] - input_ids.shape[1]}")

TokenSkip: 12.07s, NFE=71
Isaac Newton was an English physicist and mathematician who made significant contributions to the development of classical mechanics and optics. He is best known for his laws of motion, which describe the motion of objects, and his law of universal gravitation, which explains the force of gravity between objects. Newton's work laid the foundation for modern physics and is considered one of the most influential figures in the history of science.
out1.shape: torch.Size([1, 147])
out2.shape: torch.Size([1, 147])
input_ids.shape: torch.Size([1, 19])
预期 gen_length: 128
实际生成长度: 128


In [6]:
# TokenSkip run (tunable hyperparameters).
SKIP_LAYER_K = 16       # number of leading layers used for the skip decision
SKIP_THRESHOLD = 0.99   # threshold on the mean cosine similarity
SKIP_OUTLIER = 0.9      # if any layer falls below this value, force computation
GEN_LENGTH = 128        # requested generation length; also used by the shape check below

# Synchronize around the timed region so queued CUDA work is fully counted,
# and use the monotonic perf_counter() clock for the duration.
torch.cuda.synchronize()
start = time.perf_counter()
out2, nfe2 = generate_with_dual_cache_tokenskip(
    model, input_ids, steps=128, gen_length=GEN_LENGTH, block_length=32, threshold=0.9,
    skip_layer_k=SKIP_LAYER_K, skip_threshold=SKIP_THRESHOLD, skip_outlier=SKIP_OUTLIER
)
torch.cuda.synchronize()
t2 = time.perf_counter() - start
# Decode only the tokens that follow the prompt.
ans2 = tokenizer.decode(out2[0, input_ids.shape[1]:], skip_special_tokens=True)
print(f"TokenSkip: {t2:.2f}s, NFE={nfe2}")
print(ans2)
# Check output shapes: both runs should be prompt length + generation length.
print(f"out1.shape: {out1.shape}")  # baseline
print(f"out2.shape: {out2.shape}")  # tokenskip
print(f"input_ids.shape: {input_ids.shape}")
print(f"预期 gen_length: {GEN_LENGTH}")
print(f"实际生成长度: {out2.shape[1] - input_ids.shape[1]}")

TokenSkip: 9.09s, NFE=68
Isaac Newton was an English physicist, mathematician, and astronomer, and one of the most influential scientists of all time. He is best known for his work on classical mechanics, and his laws of motion and calculus are fundamental to modern physics physics. Newton's work laid the foundation for the development of modern mathematics and physics.
out1.shape: torch.Size([1, 147])
out2.shape: torch.Size([1, 147])
input_ids.shape: torch.Size([1, 19])
预期 gen_length: 128
实际生成长度: 128


In [7]:
# Display the most recent TokenSkip answer as a raw string repr (bare last-line expression).
ans2

"Isaac Newton was an English physicist, mathematician, and astronomer, and one of the most influential scientists of all time. He is best known for his work on classical mechanics, and his laws of motion and calculus are fundamental to modern physics physics. Newton's work laid the foundation for the development of modern mathematics and physics."

In [8]:
# Compare baseline vs. the most recent TokenSkip run.
# Speedup is t1 / t2, so a value below 1.0 means the TokenSkip run was slower.
# nfe1 / nfe2 are the second values returned by the two generate functions.
print(f"Speedup: {t1/t2:.2f}x")
print(f"NFE: {nfe1} -> {nfe2}")

Speedup: 0.75x
NFE: 71 -> 68


In [None]:
# GSM8K evaluation on 30 questions (TokenSkip mode, reusing the already-loaded model).
import random
import re
import math
from datasets import load_dataset

# TokenSkip hyperparameters (same values as the In [6] TokenSkip run).
SKIP_LAYER_K = 16
SKIP_THRESHOLD = 0.99
SKIP_OUTLIER = 0.9

# Seed the global RNG so the sampled subset is reproducible, then draw 30
# items from the GSM8K "main" test split.
random.seed(42)
gsm8k = load_dataset("openai/gsm8k", "main", split="test")
gsm8k_30 = random.sample(list(gsm8k), 30)

def extract_answer_strict(text):
    r"""Strict extraction: return the number that follows a '#### ' marker.

    The original note attributes the pattern to lm_eval, quoted there as
    r'#### (\-?[0-9\.\,]+)'; the pattern compiled below is the same expression
    without the redundant escapes. This docstring is a raw string so those
    backslashes are not parsed as (invalid) escape sequences.

    Args:
        text: The string to search.

    Returns:
        The captured group of the first match (an optional leading '-'
        followed by one or more digits, dots, or commas), or None when
        ``text`` contains no '#### <number>' marker.
    """
    match = re.search(r'#### (-?[0-9.,]+)', text)
    return match.group(1) if match else None

def extract_answer_flexible(text):
    """Flexible extraction: return the last number-like token found in ``text``.

    A token is either an optional '-', an optional dollar sign, and at least
    two characters drawn from digits, '.' and ','; or an optional '-' followed
    by one or more digits. The original note attributes the approach to
    lm_eval's flexible-extract filter, but the pattern compiled here places the
    dollar sign before the character class rather than inside it.

    Args:
        text: The string to scan.

    Returns:
        The text of the last match, exactly as it appears in ``text``
        (not normalized), or None if nothing matches.
    """
    last_hit = None
    # The pattern is a plain alternation of two groups, so the whole match
    # always equals whichever group matched; keep only the final one.
    for hit in re.finditer(r'(-?\$?[0-9.,]{2,})|(-?[0-9]+)', text):
        last_hit = hit.group(0)
    return last_hit

def get_gold_answer(answer_text):
    """Extract the reference number from a gold answer string (strict pattern).

    Args:
        answer_text: The gold answer text, expected to contain '#### <number>'.

    Returns:
        The captured number string of the first '#### ' marker, or the empty
        string when no marker is present.
    """
    found = re.search(r'#### (-?[0-9.,]+)', answer_text)
    if not found:
        return ""
    return found.group(1)

def normalize_answer(ans):
    """Normalize an answer string for exact-match comparison.

    Args:
        ans: The extracted answer, or None. Non-string values are converted
            with ``str``.

    Returns:
        None when ``ans`` is None; otherwise the stripped string with every
        ',' and '$' removed and all trailing '.' characters dropped.
    """
    if ans is None:
        return None
    cleaned = str(ans).strip()
    # Drop thousands separators and currency signs wherever they occur.
    for unwanted in (',', '$'):
        cleaned = cleaned.replace(unwanted, '')
    # Remove sentence-ending dots that were captured along with the number.
    return cleaned.rstrip('.')

# Evaluation loop: generate an answer for every sampled question and score it
# with both the strict ('#### <number>') and the flexible (last number) extractor.
strict_correct = 0
flexible_correct = 0
results_list = []

for i, item in enumerate(gsm8k_30):
    question = item['question']
    gold = normalize_answer(get_gold_answer(item['answer']))

    # Build the input: the bare question wrapped in the chat template.
    m = [{"role": "user", "content": question}]
    text = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
    input_ids = torch.tensor(tokenizer(text)['input_ids']).to(device).unsqueeze(0)

    # Generate (TokenSkip).
    out, nfe = generate_with_dual_cache_tokenskip(
        model, input_ids, steps=128, gen_length=256, block_length=32, threshold=0.9,
        skip_layer_k=SKIP_LAYER_K, skip_threshold=SKIP_THRESHOLD, skip_outlier=SKIP_OUTLIER
    )
    # Decode only the tokens that follow the prompt.
    ans_text = tokenizer.decode(out[0, input_ids.shape[1]:], skip_special_tokens=True)

    pred_strict = normalize_answer(extract_answer_strict(ans_text))
    pred_flex = normalize_answer(extract_answer_flexible(ans_text))

    # A missing or empty prediction counts as incorrect.
    is_strict = (pred_strict == gold) if pred_strict else False
    is_flex = (pred_flex == gold) if pred_flex else False

    # Booleans add as 0/1.
    strict_correct += is_strict
    flexible_correct += is_flex
    results_list.append({'pred_strict': pred_strict, 'pred_flex': pred_flex, 'gold': gold, 'strict': is_strict, 'flex': is_flex})

    symbol = '✓' if is_flex else '✗'
    # Report progress against the actual sample size rather than a literal 30,
    # so the counter stays correct if the subset size changes.
    print(f"[{i+1}/{len(gsm8k_30)}] strict={pred_strict}, flex={pred_flex}, gold={gold}, {symbol}")

# Summary statistics + stderr (binomial standard error: sqrt(p * (1 - p) / n)).
n = len(gsm8k_30)
acc_strict = strict_correct / n
acc_flex = flexible_correct / n
stderr_strict = math.sqrt(acc_strict * (1 - acc_strict) / n)
stderr_flex = math.sqrt(acc_flex * (1 - acc_flex) / n)

# Print a small results table with one row per extraction mode.
print(f"\n{'='*60}")
print(f"GSM8K 评估结果 (n={n}, tokenskip k={SKIP_LAYER_K} t={SKIP_THRESHOLD})")
print(f"{'='*60}")
print(f"| Metric           | Value  | Stderr |")
print(f"|------------------|--------|--------|")
print(f"| exact_match,flexible-extract | {acc_flex:.4f} | ±{stderr_flex:.4f} |")
print(f"| exact_match,strict-match     | {acc_strict:.4f} | ±{stderr_strict:.4f} |")

[1/30] strict=None, flex=1880.00, gold=2280, ✗
[2/30] strict=None, flex=1, gold=1, ✓
[3/30] strict=None, flex=5, gold=5, ✓
[4/30] strict=None, flex=12, gold=12, ✓
[5/30] strict=None, flex=364, gold=273, ✗
[6/30] strict=None, flex=30, gold=45, ✗
[7/30] strict=None, flex=3, gold=21, ✗
[8/30] strict=None, flex=20, gold=145, ✗
[9/30] strict=None, flex=60, gold=60, ✓
[10/30] strict=None, flex=122, gold=122, ✓
[11/30] strict=None, flex=None, gold=29, ✗
[12/30] strict=None, flex=80, gold=80, ✓
[13/30] strict=None, flex=6, gold=36, ✗
[14/30] strict=None, flex=1300, gold=1430, ✗
[15/30] strict=None, flex=7, gold=5, ✗
[16/30] strict=None, flex=5, gold=5, ✓
[17/30] strict=None, flex=5, gold=5, ✓
[18/30] strict=None, flex=18, gold=66, ✗
[19/30] strict=None, flex=15, gold=15, ✓
[20/30] strict=None, flex=2, gold=40, ✗
[21/30] strict=None, flex=93, gold=93, ✓
[22/30] strict=None, flex=300, gold=2000, ✗
[23/30] strict=None, flex=1520, gold=1520, ✓
[24/30] strict=None, flex=170, gold=11050, ✗
