In [1]:
import torch
from transformers import AutoTokenizer
from model.modeling_llada import LLaDAModelLM
from generate import generate_with_dual_cache, generate_with_dual_cache_tokenskip

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# 加载模型
device = 'cuda'
model = LLaDAModelLM.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', torch_dtype=torch.bfloat16).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct')

Loading checkpoint shards: 100%|██████████| 6/6 [00:00<00:00,  8.80it/s]


In [3]:
# 准备输入
prompt = "Who is Newton, physics?"
m = [{"role": "user", "content": prompt}]
text = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
input_ids = torch.tensor(tokenizer(text)['input_ids']).to(device).unsqueeze(0)
print(f"Input length: {input_ids.shape[1]}")

Input length: 19


In [4]:
# ===== 单步 debug：只跑一个 block 的第一个 step =====
# 这样可以追踪 model forward 的细节

mask_id = 126336
gen_length = 128
block_length = 32

# 初始化序列
x = torch.full((1, input_ids.shape[1] + gen_length), mask_id, dtype=torch.long, device=device)
x[:, :input_ids.shape[1]] = input_ids
print(f"x shape: {x.shape}")

x shape: torch.Size([1, 147])


In [5]:
# ===== Baseline: 第一次 forward（无 skip）=====
with torch.no_grad():
    out_baseline = model(x, use_cache=True, output_hidden_states=True)

print(f"Baseline logits shape: {out_baseline.logits.shape}")
print(f"Baseline hidden_states: {len(out_baseline.hidden_states)} layers")
print(f"First hidden shape: {out_baseline.hidden_states[0].shape}")

Baseline logits shape: torch.Size([1, 147, 126464])
Baseline hidden_states: 33 layers
First hidden shape: torch.Size([1, 147, 4096])


In [6]:
# ===== TokenSkip: 第一次 forward（有 skip 参数但 prev_hidden=None）=====
with torch.no_grad():
    out_skip = model(
        x, use_cache=True, output_hidden_states=True,
        skip_layer_k=8, skip_threshold=0.95, skip_outlier=0.7,
        prev_hidden=None  # 第一次没有 prev
    )

print(f"Skip logits shape: {out_skip.logits.shape}")
print(f"Skip hidden_states: {len(out_skip.hidden_states)} layers")

Skip logits shape: torch.Size([1, 147, 126464])
Skip hidden_states: 33 layers


In [7]:
# ===== 对比第一次 forward 的 logits =====
diff = (out_baseline.logits - out_skip.logits).abs().max()
print(f"第一次 forward logits 最大差异: {diff.item():.6f}")
print(f"应该接近 0（因为第一次 prev_hidden=None，不会 skip）")

第一次 forward logits 最大差异: 0.000000
应该接近 0（因为第一次 prev_hidden=None，不会 skip）


In [8]:
# ===== 模拟第二次 forward（refinement step）=====
# 用第一次的 hidden_states 作为 prev_hidden

Lp = input_ids.shape[1]
s, e = Lp, Lp + block_length  # 第一个 block 的范围

# 构造 replace_position
replace_position = torch.zeros_like(x, dtype=torch.bool)
replace_position[:, s:e] = True

past_kv = out_baseline.past_key_values
prev_hidden = out_baseline.hidden_states

print(f"Block range: [{s}, {e})")
print(f"prev_hidden: {len(prev_hidden)} layers, shape {prev_hidden[0].shape}")

Block range: [19, 51)
prev_hidden: 33 layers, shape torch.Size([1, 147, 4096])


In [9]:
# ===== Baseline: 第二次 forward =====
with torch.no_grad():
    out2_baseline = model(
        x[:, s:e],
        past_key_values=past_kv,
        use_cache=True,
        replace_position=replace_position,
        output_hidden_states=True
    )

print(f"Baseline 2nd forward logits shape: {out2_baseline.logits.shape}")

Baseline 2nd forward logits shape: torch.Size([1, 32, 126464])


In [10]:
# ===== TokenSkip: 第二次 forward（有 prev_hidden）=====
# 这里会触发判定逻辑

SKIP_LAYER_K = 8
SKIP_THRESHOLD = 1.0  # 设为 1，理论上不会 skip
SKIP_OUTLIER = 0.7

with torch.no_grad():
    out2_skip = model(
        x[:, s:e],
        past_key_values=past_kv,
        use_cache=True,
        replace_position=replace_position,
        output_hidden_states=True,
        skip_layer_k=SKIP_LAYER_K,
        skip_threshold=SKIP_THRESHOLD,
        skip_outlier=SKIP_OUTLIER,
        prev_hidden=prev_hidden
    )

print(f"Skip 2nd forward logits shape: {out2_skip.logits.shape}")

Skip 2nd forward logits shape: torch.Size([1, 32, 126464])


In [11]:
# ===== 对比第二次 forward 的 logits =====
diff2 = (out2_baseline.logits - out2_skip.logits).abs().max()
print(f"第二次 forward logits 最大差异: {diff2.item():.6f}")
print(f"如果 threshold=1 导致无 skip，这里也应该接近 0")

第二次 forward logits 最大差异: 0.000000
如果 threshold=1 导致无 skip，这里也应该接近 0


In [12]:
# ===== 逐层对比 hidden states =====
print("逐层 hidden states 对比:")
for i in range(min(len(out2_baseline.hidden_states), len(out2_skip.hidden_states))):
    h1 = out2_baseline.hidden_states[i]
    h2 = out2_skip.hidden_states[i]
    if h1.shape == h2.shape:
        diff = (h1 - h2).abs().max().item()
        print(f"  Layer {i}: shape {h1.shape}, max diff = {diff:.6f}")
    else:
        print(f"  Layer {i}: shape mismatch! {h1.shape} vs {h2.shape}")

逐层 hidden states 对比:
  Layer 0: shape torch.Size([1, 32, 4096]), max diff = 0.000000
  Layer 1: shape torch.Size([1, 32, 4096]), max diff = 0.000000
  Layer 2: shape torch.Size([1, 32, 4096]), max diff = 0.000000
  Layer 3: shape torch.Size([1, 32, 4096]), max diff = 0.000000
  Layer 4: shape torch.Size([1, 32, 4096]), max diff = 0.000000
  Layer 5: shape torch.Size([1, 32, 4096]), max diff = 0.000000
  Layer 6: shape torch.Size([1, 32, 4096]), max diff = 0.000000
  Layer 7: shape torch.Size([1, 32, 4096]), max diff = 0.000000
  Layer 8: shape torch.Size([1, 32, 4096]), max diff = 0.000000
  Layer 9: shape torch.Size([1, 32, 4096]), max diff = 0.000000
  Layer 10: shape torch.Size([1, 32, 4096]), max diff = 0.000000
  Layer 11: shape torch.Size([1, 32, 4096]), max diff = 0.000000
  Layer 12: shape torch.Size([1, 32, 4096]), max diff = 0.000000
  Layer 13: shape torch.Size([1, 32, 4096]), max diff = 0.000000
  Layer 14: shape torch.Size([1, 32, 4096]), max diff = 0.000000
  Layer 15: sh

In [13]:
# ===== 检查 prev_hidden 和当前 hidden 的 cos sim =====
import torch.nn.functional as F

print(f"检查前 {SKIP_LAYER_K} 层的 cos sim（判定用）:")
print(f"prev_hidden 有 {len(prev_hidden)} 层")
print(f"out2_skip.hidden_states 有 {len(out2_skip.hidden_states)} 层")

# 注意: prev_hidden 是完整序列，out2_skip.hidden_states 是当前 block
# 需要对齐位置
for layer in range(1, min(SKIP_LAYER_K, len(out2_skip.hidden_states))):
    # prev_hidden[layer] 的 [s:e] 对应当前 block
    h_prev = prev_hidden[layer][:, s:e, :]  # 上一次的当前 block 位置
    h_curr = out2_skip.hidden_states[layer]  # 当前的
    
    if h_prev.shape == h_curr.shape:
        # 计算每个 token 的 cos sim
        cos_sims = F.cosine_similarity(h_prev, h_curr, dim=-1)  # (B, L)
        print(f"  Layer {layer}: cos_sim min={cos_sims.min():.4f}, max={cos_sims.max():.4f}, mean={cos_sims.mean():.4f}")
    else:
        print(f"  Layer {layer}: shape mismatch! {h_prev.shape} vs {h_curr.shape}")

检查前 8 层的 cos sim（判定用）:
prev_hidden 有 33 层
out2_skip.hidden_states 有 33 层
  Layer 1: cos_sim min=0.9961, max=1.0078, mean=1.0000
  Layer 2: cos_sim min=0.9922, max=1.0078, mean=1.0000
  Layer 3: cos_sim min=0.9883, max=1.0078, mean=1.0000
  Layer 4: cos_sim min=0.9961, max=1.0078, mean=1.0000
  Layer 5: cos_sim min=0.9922, max=1.0078, mean=1.0000
  Layer 6: cos_sim min=0.9922, max=1.0078, mean=1.0000
  Layer 7: cos_sim min=0.9961, max=1.0078, mean=1.0000


In [14]:
# ===== 模拟多步迭代：更新 x 后继续 forward =====
# 这才是真正的生成过程

import torch.nn.functional as F

# 用 baseline 的 logits 更新 x（简单 argmax）
x_baseline = x.clone()
x_skip = x.clone()

# 从 logits 获取预测的 token
pred_tokens = out_baseline.logits[:, s:e, :].argmax(dim=-1)  # (1, 32)

# 只更新部分位置（模拟 threshold=0.9 的行为）
# 这里简化：更新所有 [MASK] 位置
mask_positions = (x_baseline[:, s:e] == mask_id)
x_baseline[:, s:e] = torch.where(mask_positions, pred_tokens, x_baseline[:, s:e])
x_skip[:, s:e] = torch.where(mask_positions, pred_tokens, x_skip[:, s:e])

print(f"更新后 x 中非 MASK token 数量: {(x_baseline != mask_id).sum().item()}")

更新后 x 中非 MASK token 数量: 51


In [15]:
# ===== Step 2: 用更新后的 x 再跑一次 =====
# Baseline
with torch.no_grad():
    out3_baseline = model(
        x_baseline[:, s:e],
        past_key_values=out2_baseline.past_key_values,
        use_cache=True,
        replace_position=replace_position,
        output_hidden_states=True
    )

# TokenSkip（用 step 1 的 hidden states 作为 prev_hidden）
prev_hidden_step2 = out2_skip.hidden_states  # 注意：这只是当前 block 的 hidden states
with torch.no_grad():
    out3_skip = model(
        x_skip[:, s:e],
        past_key_values=out2_skip.past_key_values,
        use_cache=True,
        replace_position=replace_position,
        output_hidden_states=True,
        skip_layer_k=SKIP_LAYER_K,
        skip_threshold=SKIP_THRESHOLD,
        skip_outlier=SKIP_OUTLIER,
        prev_hidden=prev_hidden_step2
    )

print(f"Step 2 Baseline logits shape: {out3_baseline.logits.shape}")
print(f"Step 2 Skip logits shape: {out3_skip.logits.shape}")

Step 2 Baseline logits shape: torch.Size([1, 32, 126464])
Step 2 Skip logits shape: torch.Size([1, 32, 126464])


In [16]:
# ===== 对比 Step 2 的结果 =====
diff3 = (out3_baseline.logits - out3_skip.logits).abs().max()
print(f"Step 2 logits 最大差异: {diff3.item():.6f}")

# 检查 hidden states 差异
print("\nStep 2 逐层 hidden states 对比:")
for i in range(min(10, len(out3_baseline.hidden_states))):  # 只看前 10 层
    h1 = out3_baseline.hidden_states[i]
    h2 = out3_skip.hidden_states[i]
    if h1.shape == h2.shape:
        diff = (h1 - h2).abs().max().item()
        print(f"  Layer {i}: max diff = {diff:.6f}")
    else:
        print(f"  Layer {i}: shape mismatch! {h1.shape} vs {h2.shape}")

Step 2 logits 最大差异: 0.000000

Step 2 逐层 hidden states 对比:
  Layer 0: max diff = 0.000000
  Layer 1: max diff = 0.000000
  Layer 2: max diff = 0.000000
  Layer 3: max diff = 0.000000
  Layer 4: max diff = 0.000000
  Layer 5: max diff = 0.000000
  Layer 6: max diff = 0.000000
  Layer 7: max diff = 0.000000
  Layer 8: max diff = 0.000000
  Layer 9: max diff = 0.000000


In [17]:
# ===== 检查 Step 2 的 cos sim（输入变化后）=====
print(f"Step 2 cos sim（输入已变化，不再是纯 MASK）:")

for layer in range(1, min(SKIP_LAYER_K, len(out3_skip.hidden_states))):
    h_prev = prev_hidden_step2[layer]  # step 1 的 hidden states
    h_curr = out3_skip.hidden_states[layer]  # step 2 的 hidden states
    
    if h_prev.shape == h_curr.shape:
        cos_sims = F.cosine_similarity(h_prev, h_curr, dim=-1)
        print(f"  Layer {layer}: cos_sim min={cos_sims.min():.4f}, max={cos_sims.max():.4f}, mean={cos_sims.mean():.4f}")
    else:
        print(f"  Layer {layer}: shape mismatch! {h_prev.shape} vs {h_curr.shape}")

Step 2 cos sim（输入已变化，不再是纯 MASK）:
  Layer 1: cos_sim min=0.8086, max=0.9180, mean=0.8828
  Layer 2: cos_sim min=0.8008, max=0.9336, mean=0.8945
  Layer 3: cos_sim min=0.7969, max=0.9375, mean=0.8945
  Layer 4: cos_sim min=0.7930, max=0.9375, mean=0.8945
  Layer 5: cos_sim min=0.8203, max=0.9336, mean=0.8984
  Layer 6: cos_sim min=0.8242, max=0.9414, mean=0.8906
  Layer 7: cos_sim min=0.8047, max=0.9453, mean=0.8867


In [18]:
# ===== 测试真正的 skip（threshold=0.95）=====
print("=" * 60)
print("测试 threshold=0.95（会真正 skip 一些 token）")
print("=" * 60)

SKIP_THRESHOLD_REAL = 0.95

with torch.no_grad():
    out3_skip_real = model(
        x_skip[:, s:e],
        past_key_values=out2_skip.past_key_values,
        use_cache=True,
        replace_position=replace_position,
        output_hidden_states=True,
        skip_layer_k=SKIP_LAYER_K,
        skip_threshold=SKIP_THRESHOLD_REAL,
        skip_outlier=SKIP_OUTLIER,
        prev_hidden=prev_hidden_step2
    )

diff_real = (out3_baseline.logits - out3_skip_real.logits).abs().max()
print(f"与 Baseline 的 logits 最大差异: {diff_real.item():.6f}")

# 检查 hidden states shape（如果有 skip，可能会不同）
print(f"\nHidden states 数量: baseline={len(out3_baseline.hidden_states)}, skip={len(out3_skip_real.hidden_states)}")
for i in range(min(5, len(out3_skip_real.hidden_states))):
    print(f"  Layer {i}: baseline={out3_baseline.hidden_states[i].shape}, skip={out3_skip_real.hidden_states[i].shape}")

测试 threshold=0.95（会真正 skip 一些 token）
与 Baseline 的 logits 最大差异: 0.000000

Hidden states 数量: baseline=33, skip=33
  Layer 0: baseline=torch.Size([1, 32, 4096]), skip=torch.Size([1, 32, 4096])
  Layer 1: baseline=torch.Size([1, 32, 4096]), skip=torch.Size([1, 32, 4096])
  Layer 2: baseline=torch.Size([1, 32, 4096]), skip=torch.Size([1, 32, 4096])
  Layer 3: baseline=torch.Size([1, 32, 4096]), skip=torch.Size([1, 32, 4096])
  Layer 4: baseline=torch.Size([1, 32, 4096]), skip=torch.Size([1, 32, 4096])


In [19]:
# ===== 汇总 =====
print("=" * 60)
print("调试汇总")
print("=" * 60)
print(f"Step 0 (初始 forward):  diff = 0 (prev_hidden=None)")
print(f"Step 1 (threshold=1):   diff = {diff2.item():.6f}")
print(f"Step 2 (threshold=1):   diff = {diff3.item():.6f}")
print(f"Step 2 (threshold=0.95): diff = {diff_real.item():.6f}")
print()
print("如果 threshold=1 时 diff > 0，说明双 loop 逻辑有 bug")
print("如果 threshold=0.95 时 diff > 0 但输出合理，说明 skip 在工作")

调试汇总
Step 0 (初始 forward):  diff = 0 (prev_hidden=None)
Step 1 (threshold=1):   diff = 0.000000
Step 2 (threshold=1):   diff = 0.000000
Step 2 (threshold=0.95): diff = 0.000000

如果 threshold=1 时 diff > 0，说明双 loop 逻辑有 bug
如果 threshold=0.95 时 diff > 0 但输出合理，说明 skip 在工作


In [20]:
# ===== 完备 Trace：追踪双 loop 逻辑的实际执行 =====
# 用 hook 追踪每一层的 hidden state，找出问题

print("=" * 60)
print("完备 Trace：追踪模型内部执行")
print("=" * 60)

class LayerTracer:
    """用 hook 追踪每一层的 hidden state"""
    def __init__(self, model):
        self.traces = {}  # layer_idx -> hidden_state
        self.hooks = []
        # 给每个 block 注册 forward hook
        for i, block in enumerate(model.model.transformer.blocks):
            hook = block.register_forward_hook(self._make_hook(i))
            self.hooks.append(hook)
    
    def _make_hook(self, layer_idx):
        def hook(module, input, output):
            # output[0] 是 hidden state, output[1] 是 cache
            self.traces[layer_idx] = {
                'input_shape': input[0].shape,
                'output_shape': output[0].shape,
                'output': output[0].clone().detach()
            }
        return hook
    
    def clear(self):
        self.traces = {}
    
    def remove_hooks(self):
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

# 创建 tracer
tracer_baseline = LayerTracer(model)
tracer_skip = LayerTracer(model)

print("Tracer 已创建，开始追踪...")

完备 Trace：追踪模型内部执行
Tracer 已创建，开始追踪...


In [21]:
# ===== 用 Tracer 跑一次 forward，对比每层输入/输出形状 =====
# 关键：如果 skip 生效，Loop 2 的层应该看到不同的 input_shape

# 重新初始化
x_test = x.clone()
prev_hidden_test = out_baseline.hidden_states  # 用第一次的 hidden states

# Baseline
tracer_baseline.clear()
with torch.no_grad():
    out_trace_baseline = model(
        x_test[:, s:e],
        past_key_values=out_baseline.past_key_values,
        use_cache=True,
        replace_position=replace_position,
        output_hidden_states=True
    )

baseline_traces = {k: v.copy() for k, v in tracer_baseline.traces.items()}
print(f"Baseline 追踪完成，收集了 {len(baseline_traces)} 层")

# TokenSkip (threshold=0.95，应该触发 skip)
tracer_skip.clear()
# 注意：需要重新创建 tracer，因为同一个 model
tracer_baseline.remove_hooks()
tracer_skip = LayerTracer(model)

with torch.no_grad():
    out_trace_skip = model(
        x_test[:, s:e],
        past_key_values=out_baseline.past_key_values,
        use_cache=True,
        replace_position=replace_position,
        output_hidden_states=True,
        skip_layer_k=8,
        skip_threshold=0.95,  # 设为 0.95，应该触发 skip
        skip_outlier=0.7,
        prev_hidden=prev_hidden_test
    )

skip_traces = {k: v.copy() for k, v in tracer_skip.traces.items()}
print(f"TokenSkip 追踪完成，收集了 {len(skip_traces)} 层")
tracer_skip.remove_hooks()

Baseline 追踪完成，收集了 32 层
TokenSkip 追踪完成，收集了 32 层


In [22]:
# ===== 逐层对比：找出 skip 是否生效 =====
print("逐层 input/output shape 对比:")
print("如果 skip 生效，Layer 8+ 的 input_shape 应该 < 32（部分 token 被踢出）")
print()

SKIP_LAYER_K_TEST = 8
for layer in range(32):
    b = baseline_traces.get(layer, {})
    sk = skip_traces.get(layer, {})
    
    b_in = b.get('input_shape', 'N/A')
    b_out = b.get('output_shape', 'N/A')
    sk_in = sk.get('input_shape', 'N/A')
    sk_out = sk.get('output_shape', 'N/A')
    
    marker = ""
    if layer == SKIP_LAYER_K_TEST:
        marker = " <-- split_layer (判定点)"
    if sk_in != b_in:
        marker += " *** DIFF ***"
    
    print(f"Layer {layer:2d}: baseline={b_in} -> {b_out}, skip={sk_in} -> {sk_out}{marker}")

逐层 input/output shape 对比:
如果 skip 生效，Layer 8+ 的 input_shape 应该 < 32（部分 token 被踢出）

Layer  0: baseline=torch.Size([1, 32, 4096]) -> torch.Size([1, 32, 4096]), skip=torch.Size([1, 32, 4096]) -> torch.Size([1, 32, 4096])
Layer  1: baseline=torch.Size([1, 32, 4096]) -> torch.Size([1, 32, 4096]), skip=torch.Size([1, 32, 4096]) -> torch.Size([1, 32, 4096])
Layer  2: baseline=torch.Size([1, 32, 4096]) -> torch.Size([1, 32, 4096]), skip=torch.Size([1, 32, 4096]) -> torch.Size([1, 32, 4096])
Layer  3: baseline=torch.Size([1, 32, 4096]) -> torch.Size([1, 32, 4096]), skip=torch.Size([1, 32, 4096]) -> torch.Size([1, 32, 4096])
Layer  4: baseline=torch.Size([1, 32, 4096]) -> torch.Size([1, 32, 4096]), skip=torch.Size([1, 32, 4096]) -> torch.Size([1, 32, 4096])
Layer  5: baseline=torch.Size([1, 32, 4096]) -> torch.Size([1, 32, 4096]), skip=torch.Size([1, 32, 4096]) -> torch.Size([1, 32, 4096])
Layer  6: baseline=torch.Size([1, 32, 4096]) -> torch.Size([1, 32, 4096]), skip=torch.Size([1, 32, 4096]) -

In [23]:
# ===== 手动复现判定逻辑，检查为什么 skip 没生效 =====
print("=" * 60)
print("手动复现判定逻辑")
print("=" * 60)

# 关键：prev_hidden 的形状和当前 block 的对应关系
print(f"prev_hidden[0].shape = {prev_hidden_test[0].shape}")  # 完整序列 [1, 147, 4096]
print(f"当前 block 范围: [{s}, {e})")
print(f"block_length = {e - s}")

# 模型内部判定时，x 的形状是 [1, 32, 4096]（当前 block）
# 但 prev_hidden 的形状是 [1, 147, 4096]（完整序列）
# 问题：模型内部是否正确切片了 prev_hidden？

print()
print("检查 modeling_llada.py 中的判定逻辑:")
print("代码中 prev_hidden[layer] 是完整序列，而 all_hidden_states[layer] 是当前 block")
print("这可能是 bug！需要对 prev_hidden 做切片")

手动复现判定逻辑
prev_hidden[0].shape = torch.Size([1, 147, 4096])
当前 block 范围: [19, 51)
block_length = 32

检查 modeling_llada.py 中的判定逻辑:
代码中 prev_hidden[layer] 是完整序列，而 all_hidden_states[layer] 是当前 block
这可能是 bug！需要对 prev_hidden 做切片


In [24]:
# ===== 正确模拟 generate 流程 =====
print("=" * 60)
print("正确模拟 generate 流程")
print("=" * 60)

# Step 0: 初始 forward（完整序列）- 用于建立 KV cache
x_gen = x.clone()
with torch.no_grad():
    out_init = model(x_gen, use_cache=True, output_hidden_states=True)
past_kv_gen = out_init.past_key_values
print(f"Step 0: 初始 forward, x shape={x_gen.shape}, hidden_states[0] shape={out_init.hidden_states[0].shape}")

# 更新 x（模拟第一次采样）
logits_full = out_init.logits
x0 = logits_full.argmax(dim=-1)
mask_idx = (x_gen == mask_id)
x_gen = torch.where(mask_idx, x0, x_gen)
print(f"更新后 x_gen 中非 MASK token 数量: {(x_gen != mask_id).sum().item()}")

正确模拟 generate 流程
Step 0: 初始 forward, x shape=torch.Size([1, 147]), hidden_states[0] shape=torch.Size([1, 147, 4096])
更新后 x_gen 中非 MASK token 数量: 147


In [25]:
# ===== Refinement Step 1（Baseline vs TokenSkip）=====
# 关键：prev_hidden = None，这一步不会 skip

print("=" * 60)
print("Refinement Step 1")
print("=" * 60)

# 构造 replace_position
replace_position_gen = torch.zeros_like(x_gen, dtype=torch.bool)
replace_position_gen[:, s:e] = True

# Baseline
with torch.no_grad():
    out1_base = model(
        x_gen[:, s:e],
        past_key_values=past_kv_gen,
        use_cache=True,
        replace_position=replace_position_gen,
        output_hidden_states=True
    )

# TokenSkip（prev_hidden=None，不会 skip）
with torch.no_grad():
    out1_skip = model(
        x_gen[:, s:e],
        past_key_values=past_kv_gen,
        use_cache=True,
        replace_position=replace_position_gen,
        output_hidden_states=True,
        skip_layer_k=8,
        skip_threshold=0.95,
        skip_outlier=0.7,
        prev_hidden=None  # 关键：第一次是 None
    )

diff1 = (out1_base.logits - out1_skip.logits).abs().max()
print(f"Step 1: prev_hidden=None, diff={diff1.item():.6f} (应该=0)")
print(f"out1_skip.hidden_states[0].shape = {out1_skip.hidden_states[0].shape}")

# 这个 hidden_states 会作为下一步的 prev_hidden
prev_hidden_step1 = out1_skip.hidden_states
print(f"prev_hidden_step1 有 {len(prev_hidden_step1)} 层, shape={prev_hidden_step1[0].shape}")

Refinement Step 1
Step 1: prev_hidden=None, diff=0.000000 (应该=0)
out1_skip.hidden_states[0].shape = torch.Size([1, 32, 4096])
prev_hidden_step1 有 33 层, shape=torch.Size([1, 32, 4096])


In [26]:
# ===== Refinement Step 2（应该触发 skip）=====
# 关键：prev_hidden 形状是 [1, 32, 4096]（当前 block）

print("=" * 60)
print("Refinement Step 2（skip 应该生效）")
print("=" * 60)

# 更新 x_gen（模拟采样）
pred1 = out1_base.logits.argmax(dim=-1)
mask_blk = (x_gen[:, s:e] == mask_id)
x_gen[:, s:e] = torch.where(mask_blk, pred1, x_gen[:, s:e])
print(f"更新后 x_gen 中非 MASK token 数量: {(x_gen != mask_id).sum().item()}")

# Baseline
with torch.no_grad():
    out2_base = model(
        x_gen[:, s:e],
        past_key_values=out1_base.past_key_values,
        use_cache=True,
        replace_position=replace_position_gen,
        output_hidden_states=True
    )

# TokenSkip（threshold=1，不应该 skip）
with torch.no_grad():
    out2_skip_no = model(
        x_gen[:, s:e],
        past_key_values=out1_skip.past_key_values,
        use_cache=True,
        replace_position=replace_position_gen,
        output_hidden_states=True,
        skip_layer_k=8,
        skip_threshold=1.0,  # threshold=1，不 skip
        skip_outlier=0.7,
        prev_hidden=prev_hidden_step1
    )

# TokenSkip（threshold=0.95，应该 skip）
with torch.no_grad():
    out2_skip_yes = model(
        x_gen[:, s:e],
        past_key_values=out1_skip.past_key_values,
        use_cache=True,
        replace_position=replace_position_gen,
        output_hidden_states=True,
        skip_layer_k=8,
        skip_threshold=0.95,  # threshold=0.95，可能 skip
        skip_outlier=0.7,
        prev_hidden=prev_hidden_step1
    )

diff2_no = (out2_base.logits - out2_skip_no.logits).abs().max()
diff2_yes = (out2_base.logits - out2_skip_yes.logits).abs().max()

print(f"Step 2 (threshold=1):    diff={diff2_no.item():.6f} (应该=0)")
print(f"Step 2 (threshold=0.95): diff={diff2_yes.item():.6f} (如果有 skip，应该>0)")

Refinement Step 2（skip 应该生效）
更新后 x_gen 中非 MASK token 数量: 147
Step 2 (threshold=1):    diff=0.000000 (应该=0)
Step 2 (threshold=0.95): diff=26.000000 (如果有 skip，应该>0)


In [27]:
# ===== 手动计算判定逻辑应该看到的 cos_sim =====
print("=" * 60)
print("手动复现模型内部判定逻辑")
print("=" * 60)

# 模型内部：
# - prev_hidden[layer] 形状 [1, 32, 4096]（上一步的当前 block hidden states）
# - all_hidden_states[layer] 形状 [1, 32, 4096]（当前 block 的 hidden states）

# 我们需要比较 prev_hidden_step1 和 out2_skip_no 的前 K 层 hidden states

print(f"prev_hidden_step1 形状: {prev_hidden_step1[0].shape}")
print(f"out2_skip_no.hidden_states 形状: {out2_skip_no.hidden_states[0].shape}")
print()

SKIP_LAYER_K_SIM = 8
SKIP_THRESHOLD_SIM = 0.95
SKIP_OUTLIER_SIM = 0.7

# 模拟判定
L = 32  # block_length
active_mask = torch.ones(L, dtype=torch.bool)
all_cos_sims = []

for j in range(L):
    cos_sims_j = []
    for layer in range(1, SKIP_LAYER_K_SIM):
        h1 = prev_hidden_step1[layer][0, j, :]  # 上一步的 token j
        h2 = out2_skip_no.hidden_states[layer][0, j, :]  # 当前的 token j
        cos = F.cosine_similarity(h1.unsqueeze(0), h2.unsqueeze(0), dim=-1).item()
        cos_sims_j.append(cos)
    
    all_cos_sims.append(cos_sims_j)
    
    min_cos = min(cos_sims_j)
    mean_cos = sum(cos_sims_j) / len(cos_sims_j)
    stable = min_cos >= SKIP_OUTLIER_SIM and mean_cos >= SKIP_THRESHOLD_SIM
    active_mask[j] = not stable  # 稳定的标记为 False（踢出）

print(f"判定结果:")
print(f"  稳定（会 skip）的 token 数量: {(~active_mask).sum().item()}")
print(f"  不稳定（继续计算）的 token 数量: {active_mask.sum().item()}")
print()

# 打印几个 token 的 cos_sim
print("前 5 个 token 的 cos_sim:")
for j in range(min(5, L)):
    print(f"  Token {j}: {[f'{c:.4f}' for c in all_cos_sims[j]]}, min={min(all_cos_sims[j]):.4f}, mean={sum(all_cos_sims[j])/len(all_cos_sims[j]):.4f}")

手动复现模型内部判定逻辑
prev_hidden_step1 形状: torch.Size([1, 32, 4096])
out2_skip_no.hidden_states 形状: torch.Size([1, 32, 4096])

判定结果:
  稳定（会 skip）的 token 数量: 32
  不稳定（继续计算）的 token 数量: 0

前 5 个 token 的 cos_sim:
  Token 0: ['0.9961', '1.0000', '1.0078', '1.0000', '1.0000', '1.0078', '1.0000'], min=0.9961, mean=1.0017
  Token 1: ['1.0000', '1.0000', '1.0078', '1.0000', '1.0000', '1.0000', '0.9922'], min=0.9922, mean=1.0000
  Token 2: ['1.0078', '1.0000', '0.9961', '0.9961', '1.0000', '0.9922', '1.0000'], min=0.9922, mean=0.9989
  Token 3: ['1.0000', '1.0000', '0.9961', '0.9961', '0.9961', '0.9961', '1.0000'], min=0.9961, mean=0.9978
  Token 4: ['1.0000', '1.0078', '1.0000', '0.9961', '1.0000', '1.0000', '1.0000'], min=0.9961, mean=1.0006


In [28]:
# ===== 统计所有 token 的 cos_sim 分布 =====
print("=" * 60)
print("所有 token 的 cos_sim 统计")
print("=" * 60)

import numpy as np

all_min_cos = [min(c) for c in all_cos_sims]
all_mean_cos = [sum(c)/len(c) for c in all_cos_sims]

print(f"min(cos_sim) 分布:")
print(f"  min={min(all_min_cos):.4f}, max={max(all_min_cos):.4f}, mean={np.mean(all_min_cos):.4f}")
print(f"mean(cos_sim) 分布:")
print(f"  min={min(all_mean_cos):.4f}, max={max(all_mean_cos):.4f}, mean={np.mean(all_mean_cos):.4f}")
print()
print(f"判定条件: min >= {SKIP_OUTLIER_SIM} AND mean >= {SKIP_THRESHOLD_SIM}")
print(f"满足条件的 token 数量: {sum(1 for m, a in zip(all_min_cos, all_mean_cos) if m >= SKIP_OUTLIER_SIM and a >= SKIP_THRESHOLD_SIM)}")
print()

# 如果没有 token 满足条件，说明输入变化太大，cos_sim 下降
if all(a < SKIP_THRESHOLD_SIM for a in all_mean_cos):
    print("结论: 输入变化后，所有 token 的 cos_sim 都低于 threshold，没有 skip 发生")
    print("这是预期行为！因为输入从 [MASK] 变成了预测的 token")
else:
    print(f"有 {sum(1 for a in all_mean_cos if a >= SKIP_THRESHOLD_SIM)} 个 token 的 mean >= threshold")

所有 token 的 cos_sim 统计
min(cos_sim) 分布:
  min=0.9883, max=1.0000, mean=0.9948
mean(cos_sim) 分布:
  min=0.9967, max=1.0045, mean=1.0000

判定条件: min >= 0.7 AND mean >= 0.95
满足条件的 token 数量: 32

有 32 个 token 的 mean >= threshold


In [29]:
# ===== 最终诊断 =====
print("=" * 60)
print("最终诊断")
print("=" * 60)

print("1. 形状对齐检查:")
print(f"   prev_hidden shape: {prev_hidden_step1[0].shape}")
print(f"   current hidden shape: {out2_skip_no.hidden_states[0].shape}")
print(f"   ✓ 形状一致（都是当前 block 的 hidden states）")
print()

print("2. 判定逻辑检查:")
print(f"   threshold=1.0 时 diff: {diff2_no.item():.6f}")
if diff2_no.item() == 0:
    print("   ✓ threshold=1 时完全退化到 baseline（正确）")
else:
    print("   ✗ BUG: threshold=1 时应该与 baseline 一致！")
print()

print(f"3. Skip 触发检查:")
print(f"   threshold=0.95 时 diff: {diff2_yes.item():.6f}")
if diff2_yes.item() == 0:
    print("   cos_sim 统计显示所有 token 都不满足 skip 条件")
    print("   这是因为输入从 [MASK] 变成了预测 token，hidden states 变化较大")
    print("   这是预期行为，不是 bug")
else:
    print("   ✓ 有 token 被 skip，输出与 baseline 不同")
print()

print("4. 建议:")
print("   - 在真实生成场景中，只有当输入变化很小时 skip 才会触发")
print("   - 可以考虑降低 threshold（如 0.8）来增加 skip 率")
print("   - 或者在连续多个 step 输入不变的情况下测试 skip")

最终诊断
1. 形状对齐检查:
   prev_hidden shape: torch.Size([1, 32, 4096])
   current hidden shape: torch.Size([1, 32, 4096])
   ✓ 形状一致（都是当前 block 的 hidden states）

2. 判定逻辑检查:
   threshold=1.0 时 diff: 0.000000
   ✓ threshold=1 时完全退化到 baseline（正确）

3. Skip 触发检查:
   threshold=0.95 时 diff: 26.000000
   ✓ 有 token 被 skip，输出与 baseline 不同

4. 建议:
   - 在真实生成场景中，只有当输入变化很小时 skip 才会触发
   - 可以考虑降低 threshold（如 0.8）来增加 skip 率
   - 或者在连续多个 step 输入不变的情况下测试 skip
