# TokenSkip 全量中间变量追踪 Notebook

用 Hook 抓取所有可能的中间变量，深度对比 baseline 和 tokenskip。

In [1]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from model.modeling_llada import LLaDAModelLM
import numpy as np

device = 'cuda'
model = LLaDAModelLM.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', torch_dtype=torch.bfloat16).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct')
print(f"模型加载完成，{len(model.model.transformer.blocks)} 层")

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 6/6 [00:00<00:00,  7.69it/s]


模型加载完成，32 层


## 0. Run All 基础设置
说明：本 notebook 建议直接 Run All。此处集中放全局 import / logging / debug 参数，避免后续单元报错或顺序依赖。

In [2]:
import logging
import warnings
from generate import generate_with_dual_cache, generate_with_dual_cache_tokenskip, get_transfer_index, get_num_transfer_tokens

# 统一屏蔽一个无害警告（避免 Run All 被噪音打断）
warnings.filterwarnings("ignore", message="To copy construct from a tensor")

# logging 仅初始化一次
if not logging.getLogger().handlers:
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
    )

# 轻量调试参数（可按需调整）
DEBUG_BLOCKS = 1  # 只追踪前几个 block
DEBUG_STEPS = 8   # 每个 block 只追踪前几个 step


In [3]:
# 准备输入
prompt = "Who is Newton in physics?"
messages = [{"role": "user", "content": prompt}]
text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
input_ids = torch.tensor(tokenizer(text)['input_ids']).to(device).unsqueeze(0)

MASK_ID = 126336
GEN_LENGTH = 128
BLOCK_LENGTH = 32

prompt_len = input_ids.shape[1]
total_len = prompt_len + GEN_LENGTH

x_init = torch.full((1, total_len), MASK_ID, dtype=torch.long, device=device)
x_init[:, :prompt_len] = input_ids

print(f"prompt_len: {prompt_len}, total_len: {total_len}")

prompt_len: 19, total_len: 147


## 全量 Tracer 类

In [4]:
class FullTracer:
    """全量中间变量追踪器"""
    
    def __init__(self, model):
        self.model = model
        self.hooks = []
        self.traces = {}
        self._register_hooks()
    
    def _register_hooks(self):
        # 1. Embedding 层
        self.hooks.append(
            self.model.model.transformer.wte.register_forward_hook(
                self._make_hook('embedding')
            )
        )
        
        # 2. 每个 Block 的详细追踪
        for i, block in enumerate(self.model.model.transformer.blocks):
            # Block 整体输入输出
            self.hooks.append(
                block.register_forward_hook(self._make_hook(f'block_{i}'))
            )
            
            # Attention 前的 LayerNorm
            self.hooks.append(
                block.attn_norm.register_forward_hook(self._make_hook(f'block_{i}_attn_norm'))
            )
            
            # FFN 前的 LayerNorm
            self.hooks.append(
                block.ff_norm.register_forward_hook(self._make_hook(f'block_{i}_ff_norm'))
            )
            
            # Q/K/V 投影
            self.hooks.append(
                block.q_proj.register_forward_hook(self._make_hook(f'block_{i}_q_proj'))
            )
            self.hooks.append(
                block.k_proj.register_forward_hook(self._make_hook(f'block_{i}_k_proj'))
            )
            self.hooks.append(
                block.v_proj.register_forward_hook(self._make_hook(f'block_{i}_v_proj'))
            )
            
            # Attention 输出投影
            self.hooks.append(
                block.attn_out.register_forward_hook(self._make_hook(f'block_{i}_attn_out'))
            )
            
            # FFN 中间层
            self.hooks.append(
                block.ff_proj.register_forward_hook(self._make_hook(f'block_{i}_ff_proj'))
            )
            self.hooks.append(
                block.up_proj.register_forward_hook(self._make_hook(f'block_{i}_up_proj'))
            )
            
            # FFN 输出投影
            self.hooks.append(
                block.ff_out.register_forward_hook(self._make_hook(f'block_{i}_ffn_out'))
            )
            
            # RoPE
            self.hooks.append(
                block.rotary_emb.register_forward_hook(self._make_hook(f'block_{i}_rope'))
            )
        
        # 3. 最终 LayerNorm
        self.hooks.append(
            self.model.model.transformer.ln_f.register_forward_hook(
                self._make_hook('final_norm')
            )
        )
    
    def _make_hook(self, name):
        def hook(module, input, output):
            inp = input[0] if isinstance(input, tuple) and len(input) > 0 else input
            out = output[0] if isinstance(output, tuple) and len(output) > 0 else output
            
            self.traces[name] = {
                'input_shape': inp.shape if hasattr(inp, 'shape') else None,
                'output_shape': out.shape if hasattr(out, 'shape') else None,
                'input': inp.clone().detach() if hasattr(inp, 'clone') else inp,
                'output': out.clone().detach() if hasattr(out, 'clone') else out,
            }
            
            if hasattr(out, 'float'):
                out_f = out.float()
                self.traces[name]['stats'] = {
                    'mean': out_f.mean().item(),
                    'std': out_f.std().item(),
                    'min': out_f.min().item(),
                    'max': out_f.max().item(),
                    'abs_mean': out_f.abs().mean().item(),
                }
        return hook
    
    def clear(self):
        self.traces = {}
    
    def remove(self):
        for h in self.hooks:
            h.remove()
        self.hooks = []

print("FullTracer 定义完成")

FullTracer 定义完成


## 初始化

In [5]:
# 初始 forward
x = x_init.clone()
with torch.no_grad():
    out_init = model(x, use_cache=True, output_hidden_states=True)

past_kv = out_init.past_key_values

# 初始采样
x0 = out_init.logits.argmax(dim=-1)
mask_pos = (x == MASK_ID)
x = torch.where(mask_pos, x0, x)

print(f"初始 forward 完成，KV cache: {len(past_kv)} 层")
print(f"KV shape: {past_kv[0][0].shape}")

初始 forward 完成，KV cache: 32 层
KV shape: torch.Size([1, 32, 147, 128])


In [6]:
# 定义 block 和超参
s = prompt_len
e = s + BLOCK_LENGTH

replace_position = torch.zeros_like(x, dtype=torch.bool)
replace_position[:, s:e] = True

SKIP_LAYER_K = 8
SKIP_THRESHOLD = 0.95
SKIP_OUTLIER = 0.7

print(f"Block [{s}, {e})")
print(f"超参: K={SKIP_LAYER_K}, threshold={SKIP_THRESHOLD}, outlier={SKIP_OUTLIER}")

Block [19, 51)
超参: K=8, threshold=0.95, outlier=0.7


## Step 0: Baseline vs TokenSkip（prev_hidden=None）

In [7]:
tracer = FullTracer(model)

# Step 0: Baseline
x_baseline = x.clone()
past_kv_baseline = tuple((k.clone(), v.clone()) for k, v in past_kv)

tracer.clear()
with torch.no_grad():
    out0_baseline = model(
        x_baseline[:, s:e],
        past_key_values=past_kv_baseline,
        use_cache=True,
        replace_position=replace_position.clone(),
        output_hidden_states=True
    )
traces_baseline_0 = dict(tracer.traces)
print(f"Baseline Step 0: logits shape {out0_baseline.logits.shape}")

Baseline Step 0: logits shape torch.Size([1, 32, 126464])


In [8]:
# Step 0: TokenSkip（prev_hidden=None）
x_skip = x.clone()
past_kv_skip = tuple((k.clone(), v.clone()) for k, v in past_kv)

tracer.clear()
with torch.no_grad():
    out0_skip = model(
        x_skip[:, s:e],
        past_key_values=past_kv_skip,
        use_cache=True,
        replace_position=replace_position.clone(),
        output_hidden_states=True,
        skip_layer_k=SKIP_LAYER_K,
        skip_threshold=SKIP_THRESHOLD,
        skip_outlier=SKIP_OUTLIER,
        prev_hidden=None
    )
traces_skip_0 = dict(tracer.traces)
print(f"TokenSkip Step 0: logits shape {out0_skip.logits.shape}")

TokenSkip Step 0: logits shape torch.Size([1, 32, 126464])


In [9]:
# 对比 Step 0
diff0_logits = (out0_baseline.logits - out0_skip.logits).abs().max().item()
print(f"Step 0 logits diff: {diff0_logits:.6f}")
print(f"预期: 0（prev_hidden=None 不 skip）")
print(f"结果: {'✓ PASS' if diff0_logits == 0 else '✗ FAIL'}")

Step 0 logits diff: 0.000000
预期: 0（prev_hidden=None 不 skip）
结果: ✓ PASS


In [10]:
# 对比每层 block 输出
print("Step 0 逐层 block 输出对比:")
for i in range(32):
    key = f'block_{i}'
    if key in traces_baseline_0 and key in traces_skip_0:
        b_out = traces_baseline_0[key]['output']
        s_out = traces_skip_0[key]['output']
        if b_out.shape == s_out.shape:
            diff = (b_out - s_out).abs().max().item()
            if diff > 0:
                print(f"  Layer {i}: diff={diff:.6f} ***")
        else:
            print(f"  Layer {i}: SHAPE MISMATCH {b_out.shape} vs {s_out.shape}")

Step 0 逐层 block 输出对比:


In [11]:
# 采样更新 x
pred0 = out0_baseline.logits.argmax(dim=-1)
mask_blk = (x_baseline[:, s:e] == MASK_ID)
x_baseline[:, s:e] = torch.where(mask_blk, pred0, x_baseline[:, s:e])
x_skip[:, s:e] = torch.where(mask_blk, pred0, x_skip[:, s:e])

prev_hidden = out0_skip.hidden_states
print(f"prev_hidden: {len(prev_hidden)} 层")

prev_hidden: 33 层


## Step 1: 关键对比（有 prev_hidden）

In [12]:
# Step 1: Baseline
tracer.clear()
with torch.no_grad():
    out1_baseline = model(
        x_baseline[:, s:e],
        past_key_values=out0_baseline.past_key_values,
        use_cache=True,
        replace_position=replace_position.clone(),
        output_hidden_states=True
    )
traces_baseline_1 = dict(tracer.traces)
print(f"Baseline Step 1 完成")

Baseline Step 1 完成


In [13]:
# Step 1: TokenSkip threshold=1.0（不应该 skip）
tracer.clear()
with torch.no_grad():
    out1_skip_no = model(
        x_skip[:, s:e],
        past_key_values=out0_skip.past_key_values,
        use_cache=True,
        replace_position=replace_position.clone(),
        output_hidden_states=True,
        skip_layer_k=SKIP_LAYER_K,
        skip_threshold=1.0,
        skip_outlier=SKIP_OUTLIER,
        prev_hidden=prev_hidden
    )
traces_skip_no_1 = dict(tracer.traces)

diff1_no = (out1_baseline.logits - out1_skip_no.logits).abs().max().item()
print(f"threshold=1.0 vs Baseline logits diff: {diff1_no:.6f}")
print(f"结果: {'✓ PASS' if diff1_no == 0 else '✗ FAIL'}")

threshold=1.0 vs Baseline logits diff: 0.000000
结果: ✓ PASS


In [14]:
# 如果 threshold=1.0 有 diff，逐层检查
if diff1_no > 0:
    print("\nthreshold=1.0 逐层对比:")
    for i in range(32):
        key = f'block_{i}'
        if key in traces_baseline_1 and key in traces_skip_no_1:
            b_out = traces_baseline_1[key]['output']
            s_out = traces_skip_no_1[key]['output']
            if b_out.shape == s_out.shape:
                diff = (b_out - s_out).abs().max().item()
                if diff > 0:
                    print(f"  Layer {i}: diff={diff:.6f}")
            else:
                print(f"  Layer {i}: SHAPE MISMATCH {b_out.shape} vs {s_out.shape}")

In [15]:
# Step 1: TokenSkip threshold=0.95（可能 skip）
tracer.clear()
with torch.no_grad():
    out1_skip_yes = model(
        x_skip[:, s:e],
        past_key_values=out0_skip.past_key_values,
        use_cache=True,
        replace_position=replace_position.clone(),
        output_hidden_states=True,
        skip_layer_k=SKIP_LAYER_K,
        skip_threshold=SKIP_THRESHOLD,
        skip_outlier=SKIP_OUTLIER,
        prev_hidden=prev_hidden
    )
traces_skip_yes_1 = dict(tracer.traces)

diff1_yes = (out1_baseline.logits - out1_skip_yes.logits).abs().max().item()
print(f"threshold={SKIP_THRESHOLD} vs Baseline logits diff: {diff1_yes:.6f}")

threshold=0.95 vs Baseline logits diff: 24.875000


## 逐层形状对比

In [16]:
print("逐层输入形状对比（Baseline vs TokenSkip threshold=0.95）:")
print(f"Layer 0-{SKIP_LAYER_K-1}: Loop 1, Layer {SKIP_LAYER_K}-31: Loop 2")
print("-" * 80)

for i in range(32):
    key = f'block_{i}'
    b = traces_baseline_1.get(key, {})
    sk = traces_skip_yes_1.get(key, {})
    
    b_in = b.get('input_shape', 'N/A')
    sk_in = sk.get('input_shape', 'N/A')
    
    marker = " <-- split" if i == SKIP_LAYER_K else ""
    if b_in != sk_in:
        marker += " *** DIFF ***"
    
    print(f"Layer {i:2d}: baseline={b_in}, skip={sk_in}{marker}")

逐层输入形状对比（Baseline vs TokenSkip threshold=0.95）:
Layer 0-7: Loop 1, Layer 8-31: Loop 2
--------------------------------------------------------------------------------
Layer  0: baseline=torch.Size([1, 32, 4096]), skip=torch.Size([1, 32, 4096])
Layer  1: baseline=torch.Size([1, 32, 4096]), skip=torch.Size([1, 32, 4096])
Layer  2: baseline=torch.Size([1, 32, 4096]), skip=torch.Size([1, 32, 4096])
Layer  3: baseline=torch.Size([1, 32, 4096]), skip=torch.Size([1, 32, 4096])
Layer  4: baseline=torch.Size([1, 32, 4096]), skip=torch.Size([1, 32, 4096])
Layer  5: baseline=torch.Size([1, 32, 4096]), skip=torch.Size([1, 32, 4096])
Layer  6: baseline=torch.Size([1, 32, 4096]), skip=torch.Size([1, 32, 4096])
Layer  7: baseline=torch.Size([1, 32, 4096]), skip=torch.Size([1, 32, 4096])
Layer  8: baseline=torch.Size([1, 32, 4096]), skip=N/A <-- split *** DIFF ***
Layer  9: baseline=torch.Size([1, 32, 4096]), skip=N/A *** DIFF ***
Layer 10: baseline=torch.Size([1, 32, 4096]), skip=N/A *** DIFF ***
Lay

## Q/K/V 投影对比

In [17]:
print("Q/K/V 投影对比（关键层）:")
print("-" * 80)

for i in [0, SKIP_LAYER_K-1, SKIP_LAYER_K, SKIP_LAYER_K+1, 31]:
    print(f"\nLayer {i}:")
    for proj in ['q_proj', 'k_proj', 'v_proj']:
        key = f'block_{i}_{proj}'
        b = traces_baseline_1.get(key, {})
        sk = traces_skip_yes_1.get(key, {})
        
        b_out = b.get('output_shape', 'N/A')
        sk_out = sk.get('output_shape', 'N/A')
        
        diff_str = ""
        if 'output' in b and 'output' in sk:
            bo, so = b['output'], sk['output']
            if bo.shape == so.shape:
                diff_str = f", diff={( bo - so).abs().max().item():.6f}"
        
        print(f"  {proj}: baseline={b_out}, skip={sk_out}{diff_str}")

Q/K/V 投影对比（关键层）:
--------------------------------------------------------------------------------

Layer 0:
  q_proj: baseline=torch.Size([1, 32, 4096]), skip=torch.Size([1, 32, 4096]), diff=0.000000
  k_proj: baseline=torch.Size([1, 32, 4096]), skip=torch.Size([1, 32, 4096]), diff=0.000000
  v_proj: baseline=torch.Size([1, 32, 4096]), skip=torch.Size([1, 32, 4096]), diff=0.000000

Layer 7:
  q_proj: baseline=torch.Size([1, 32, 4096]), skip=torch.Size([1, 32, 4096]), diff=0.000000
  k_proj: baseline=torch.Size([1, 32, 4096]), skip=torch.Size([1, 32, 4096]), diff=0.000000
  v_proj: baseline=torch.Size([1, 32, 4096]), skip=torch.Size([1, 32, 4096]), diff=0.000000

Layer 8:
  q_proj: baseline=torch.Size([1, 32, 4096]), skip=N/A
  k_proj: baseline=torch.Size([1, 32, 4096]), skip=N/A
  v_proj: baseline=torch.Size([1, 32, 4096]), skip=N/A

Layer 9:
  q_proj: baseline=torch.Size([1, 32, 4096]), skip=N/A
  k_proj: baseline=torch.Size([1, 32, 4096]), skip=N/A
  v_proj: baseline=torch.Size([1, 3

## RoPE 对比

In [18]:
print("RoPE 对比（关键层）:")
print("-" * 80)

for i in [0, SKIP_LAYER_K-1, SKIP_LAYER_K, SKIP_LAYER_K+1, 31]:
    key = f'block_{i}_rope'
    b = traces_baseline_1.get(key, {})
    sk = traces_skip_yes_1.get(key, {})
    
    # RoPE 输出是 (q, k) tuple
    b_out = b.get('output_shape', 'N/A')
    sk_out = sk.get('output_shape', 'N/A')
    
    print(f"Layer {i}: baseline_out={b_out}, skip_out={sk_out}")
    
    # 详细看 q, k 形状
    if 'output' in b:
        bo = b['output']
        if isinstance(bo, tuple):
            print(f"  baseline q={bo[0].shape if hasattr(bo[0], 'shape') else 'N/A'}, k={bo[1].shape if len(bo)>1 and hasattr(bo[1], 'shape') else 'N/A'}")
    if 'output' in sk:
        so = sk['output']
        if isinstance(so, tuple):
            print(f"  skip     q={so[0].shape if hasattr(so[0], 'shape') else 'N/A'}, k={so[1].shape if len(so)>1 and hasattr(so[1], 'shape') else 'N/A'}")

RoPE 对比（关键层）:
--------------------------------------------------------------------------------
Layer 0: baseline_out=torch.Size([1, 32, 32, 128]), skip_out=torch.Size([1, 32, 32, 128])
Layer 7: baseline_out=torch.Size([1, 32, 32, 128]), skip_out=torch.Size([1, 32, 32, 128])
Layer 8: baseline_out=torch.Size([1, 32, 32, 128]), skip_out=N/A
Layer 9: baseline_out=torch.Size([1, 32, 32, 128]), skip_out=N/A
Layer 31: baseline_out=torch.Size([1, 32, 32, 128]), skip_out=N/A


## Attention 输出对比

In [19]:
print("Attention 输出对比（关键层）:")
print("-" * 80)

for i in [0, SKIP_LAYER_K-1, SKIP_LAYER_K, SKIP_LAYER_K+1, 31]:
    key = f'block_{i}_attn_out'
    b = traces_baseline_1.get(key, {})
    sk = traces_skip_yes_1.get(key, {})
    
    print(f"\nLayer {i}:")
    print(f"  baseline: out={b.get('output_shape')}")
    print(f"  skip:     out={sk.get('output_shape')}")
    
    if 'output' in b and 'output' in sk:
        bo, so = b['output'], sk['output']
        if bo.shape == so.shape:
            print(f"  diff: {(bo - so).abs().max().item():.6f}")

Attention 输出对比（关键层）:
--------------------------------------------------------------------------------

Layer 0:
  baseline: out=torch.Size([1, 32, 4096])
  skip:     out=torch.Size([1, 32, 4096])
  diff: 0.000000

Layer 7:
  baseline: out=torch.Size([1, 32, 4096])
  skip:     out=torch.Size([1, 32, 4096])
  diff: 0.000000

Layer 8:
  baseline: out=torch.Size([1, 32, 4096])
  skip:     out=None

Layer 9:
  baseline: out=torch.Size([1, 32, 4096])
  skip:     out=None

Layer 31:
  baseline: out=torch.Size([1, 32, 4096])
  skip:     out=None


## FFN 对比

In [20]:
print("FFN 对比（关键层）:")
print("-" * 80)

for i in [0, SKIP_LAYER_K-1, SKIP_LAYER_K, SKIP_LAYER_K+1, 31]:
    print(f"\nLayer {i}:")
    for proj in ['ff_proj', 'up_proj', 'ffn_out']:
        key = f'block_{i}_{proj}'
        b = traces_baseline_1.get(key, {})
        sk = traces_skip_yes_1.get(key, {})
        
        b_out = b.get('output_shape', 'N/A')
        sk_out = sk.get('output_shape', 'N/A')
        
        diff_str = ""
        if 'output' in b and 'output' in sk:
            bo, so = b['output'], sk['output']
            if bo.shape == so.shape:
                diff_str = f", diff={( bo - so).abs().max().item():.6f}"
        
        print(f"  {proj:8s}: baseline={b_out}, skip={sk_out}{diff_str}")

FFN 对比（关键层）:
--------------------------------------------------------------------------------

Layer 0:
  ff_proj : baseline=torch.Size([1, 32, 12288]), skip=torch.Size([1, 32, 12288]), diff=0.000000
  up_proj : baseline=torch.Size([1, 32, 12288]), skip=torch.Size([1, 32, 12288]), diff=0.000000
  ffn_out : baseline=torch.Size([1, 32, 4096]), skip=torch.Size([1, 32, 4096]), diff=0.000000

Layer 7:
  ff_proj : baseline=torch.Size([1, 32, 12288]), skip=torch.Size([1, 32, 12288]), diff=0.000000
  up_proj : baseline=torch.Size([1, 32, 12288]), skip=torch.Size([1, 32, 12288]), diff=0.000000
  ffn_out : baseline=torch.Size([1, 32, 4096]), skip=torch.Size([1, 32, 4096]), diff=0.000000

Layer 8:
  ff_proj : baseline=torch.Size([1, 32, 12288]), skip=N/A
  up_proj : baseline=torch.Size([1, 32, 12288]), skip=N/A
  ffn_out : baseline=torch.Size([1, 32, 4096]), skip=N/A

Layer 9:
  ff_proj : baseline=torch.Size([1, 32, 12288]), skip=N/A
  up_proj : baseline=torch.Size([1, 32, 12288]), skip=N/A
  ffn

## 手动计算 cos_sim 判定

In [21]:
L = BLOCK_LENGTH
all_cos_sims = []

for j in range(L):
    cos_sims_j = []
    for layer in range(1, SKIP_LAYER_K):
        h1 = prev_hidden[layer][0, j, :]
        h2 = out1_skip_no.hidden_states[layer][0, j, :]
        cos = F.cosine_similarity(h1.unsqueeze(0), h2.unsqueeze(0), dim=-1).item()
        cos = min(1.0, cos)
        cos_sims_j.append(cos)
    all_cos_sims.append(cos_sims_j)

all_min = [min(c) for c in all_cos_sims]
all_mean = [sum(c)/len(c) for c in all_cos_sims]

print(f"cos_sim 统计（每个 token 的 layer 1-{SKIP_LAYER_K-1}）:")
print(f"  token min cos: {min(all_min):.6f} ~ {max(all_min):.6f}")
print(f"  token mean cos: {min(all_mean):.6f} ~ {max(all_mean):.6f}")

cos_sim 统计（每个 token 的 layer 1-7）:
  token min cos: 0.992188 ~ 1.000000
  token mean cos: 0.997210 ~ 1.000000


In [22]:
# 判定
active_mask = []
for j in range(L):
    min_cos = min(all_cos_sims[j])
    mean_cos = sum(all_cos_sims[j]) / len(all_cos_sims[j])
    stable = min_cos >= SKIP_OUTLIER and mean_cos > SKIP_THRESHOLD
    active_mask.append(not stable)

num_active = sum(active_mask)
num_stable = L - num_active
active_indices = [i for i, a in enumerate(active_mask) if a]
stable_indices = [i for i, a in enumerate(active_mask) if not a]

print(f"判定结果:")
print(f"  Active: {num_active}, Stable: {num_stable}")
print(f"  Active: {active_indices[:15]}{'...' if len(active_indices) > 15 else ''}")
print(f"  Stable: {stable_indices[:15]}{'...' if len(stable_indices) > 15 else ''}")

判定结果:
  Active: 0, Stable: 32
  Active: []
  Stable: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]...


In [23]:
# 每个 token 详细
print("每个 token 的 cos_sim:")
for j in range(L):
    min_cos = min(all_cos_sims[j])
    mean_cos = sum(all_cos_sims[j]) / len(all_cos_sims[j])
    status = "stable" if not active_mask[j] else "active"
    print(f"  Token {j:2d}: min={min_cos:.6f}, mean={mean_cos:.6f} -> {status}")

每个 token 的 cos_sim:
  Token  0: min=0.996094, mean=0.998326 -> stable
  Token  1: min=0.992188, mean=0.998326 -> stable
  Token  2: min=0.992188, mean=0.998326 -> stable
  Token  3: min=0.996094, mean=0.999442 -> stable
  Token  4: min=0.992188, mean=0.998326 -> stable
  Token  5: min=0.996094, mean=0.998884 -> stable
  Token  6: min=0.992188, mean=0.997210 -> stable
  Token  7: min=1.000000, mean=1.000000 -> stable
  Token  8: min=0.992188, mean=0.997210 -> stable
  Token  9: min=0.992188, mean=0.997210 -> stable
  Token 10: min=0.992188, mean=0.998326 -> stable
  Token 11: min=1.000000, mean=1.000000 -> stable
  Token 12: min=1.000000, mean=1.000000 -> stable
  Token 13: min=0.996094, mean=0.998884 -> stable
  Token 14: min=0.992188, mean=0.997768 -> stable
  Token 15: min=0.996094, mean=0.998884 -> stable
  Token 16: min=0.992188, mean=0.997210 -> stable
  Token 17: min=0.996094, mean=0.998326 -> stable
  Token 18: min=0.996094, mean=0.999442 -> stable
  Token 19: min=0.992188, mean

## KV Cache 对比

In [24]:
print("KV Cache 对比（关键层）:")
print("-" * 80)

kv_base = out1_baseline.past_key_values
kv_skip = out1_skip_yes.past_key_values

for i in [0, SKIP_LAYER_K-1, SKIP_LAYER_K, SKIP_LAYER_K+1, 31]:
    k_base, v_base = kv_base[i]
    k_skip, v_skip = kv_skip[i]
    
    k_diff = (k_base - k_skip).abs().max().item()
    v_diff = (v_base - v_skip).abs().max().item()
    
    print(f"Layer {i:2d}: K_diff={k_diff:.6f}, V_diff={v_diff:.6f}")
    
    if k_diff > 0:
        diff_per_pos = (k_base - k_skip).abs().max(dim=1)[0].max(dim=-1)[0][0]
        diff_positions = (diff_per_pos > 0).nonzero(as_tuple=True)[0].tolist()
        print(f"         K diff positions: {diff_positions[:10]}{'...' if len(diff_positions) > 10 else ''}")

KV Cache 对比（关键层）:
--------------------------------------------------------------------------------
Layer  0: K_diff=0.000000, V_diff=0.000000
Layer  7: K_diff=0.000000, V_diff=0.000000
Layer  8: K_diff=0.000000, V_diff=0.000000
Layer  9: K_diff=0.000000, V_diff=0.000000
Layer 31: K_diff=0.000000, V_diff=0.000000


## Hidden States 逐 token 对比

In [25]:
h_base = out1_baseline.hidden_states[-1][0]
h_skip = out1_skip_yes.hidden_states[-1][0]

print(f"最后一层 hidden states:")
print(f"  Baseline: {h_base.shape}")
print(f"  Skip: {h_skip.shape}")

if h_base.shape == h_skip.shape:
    print("\n逐 token diff:")
    for j in range(min(L, h_base.shape[0])):
        diff = (h_base[j] - h_skip[j]).abs().max().item()
        status = "stable" if not active_mask[j] else "active"
        marker = " ***" if diff > 0.01 else ""
        print(f"  Token {j:2d} ({status:6s}): diff={diff:.6f}{marker}")

最后一层 hidden states:
  Baseline: torch.Size([32, 4096])
  Skip: torch.Size([32, 4096])

逐 token diff:
  Token  0 (stable): diff=179.000000 ***
  Token  1 (stable): diff=156.000000 ***
  Token  2 (stable): diff=164.000000 ***
  Token  3 (stable): diff=185.000000 ***
  Token  4 (stable): diff=190.000000 ***
  Token  5 (stable): diff=189.000000 ***
  Token  6 (stable): diff=191.000000 ***
  Token  7 (stable): diff=181.000000 ***
  Token  8 (stable): diff=184.000000 ***
  Token  9 (stable): diff=182.000000 ***
  Token 10 (stable): diff=184.000000 ***
  Token 11 (stable): diff=180.000000 ***
  Token 12 (stable): diff=166.000000 ***
  Token 13 (stable): diff=183.000000 ***
  Token 14 (stable): diff=179.000000 ***
  Token 15 (stable): diff=180.000000 ***
  Token 16 (stable): diff=172.000000 ***
  Token 17 (stable): diff=172.000000 ***
  Token 18 (stable): diff=178.000000 ***
  Token 19 (stable): diff=171.000000 ***
  Token 20 (stable): diff=172.000000 ***
  Token 21 (stable): diff=173.000000 *

## Logits 逐 token 对比

In [26]:
logits_base = out1_baseline.logits[0]
logits_skip = out1_skip_yes.logits[0]

print("Logits 逐 token 对比:")
for j in range(min(L, logits_base.shape[0])):
    diff = (logits_base[j] - logits_skip[j]).abs().max().item()
    status = "stable" if not active_mask[j] else "active"
    
    pred_base = logits_base[j].argmax().item()
    pred_skip = logits_skip[j].argmax().item()
    pred_match = "✓" if pred_base == pred_skip else "✗"
    
    print(f"  Token {j:2d} ({status:6s}): diff={diff:.6f}, pred={pred_match}")

Logits 逐 token 对比:
  Token  0 (stable): diff=19.250000, pred=✗
  Token  1 (stable): diff=17.375000, pred=✗
  Token  2 (stable): diff=19.000000, pred=✗
  Token  3 (stable): diff=20.750000, pred=✗
  Token  4 (stable): diff=20.250000, pred=✗
  Token  5 (stable): diff=22.875000, pred=✗
  Token  6 (stable): diff=22.500000, pred=✗
  Token  7 (stable): diff=22.125000, pred=✗
  Token  8 (stable): diff=23.000000, pred=✗
  Token  9 (stable): diff=20.625000, pred=✗
  Token 10 (stable): diff=18.750000, pred=✗
  Token 11 (stable): diff=19.750000, pred=✗
  Token 12 (stable): diff=20.875000, pred=✗
  Token 13 (stable): diff=20.000000, pred=✗
  Token 14 (stable): diff=21.000000, pred=✗
  Token 15 (stable): diff=23.500000, pred=✗
  Token 16 (stable): diff=19.875000, pred=✗
  Token 17 (stable): diff=20.750000, pred=✗
  Token 18 (stable): diff=22.375000, pred=✗
  Token 19 (stable): diff=20.375000, pred=✗
  Token 20 (stable): diff=20.750000, pred=✗
  Token 21 (stable): diff=20.875000, pred=✗
  Token 22 (s

## 清理 & 总结

In [None]:
tracer.remove()

print("="*60)
print("总结")
print("="*60)
print(f"\nStep 0 (prev_hidden=None):")
print(f"  logits diff: {diff0_logits:.6f} {'✓' if diff0_logits == 0 else '✗'}")
print(f"\nStep 1 (threshold=1.0):")
print(f"  logits diff: {diff1_no:.6f} {'✓' if diff1_no == 0 else '✗'}")
print(f"\nStep 1 (threshold={SKIP_THRESHOLD}):")
print(f"  Active: {num_active}, Stable: {num_stable}")
print(f"  logits diff: {diff1_yes:.6f}")
print("\n关键检查:")
print("  1. Step 0 diff = 0 ✓")
print("  2. threshold=1.0 diff = 0 -> 退化正确")
print("  3. threshold<1 有 skip -> 正常有差异")

总结

Step 0 (prev_hidden=None):
  logits diff: 0.000000 ✓

Step 1 (threshold=1.0):
  logits diff: 0.000000 ✓

Step 1 (threshold=0.95):
  Active: 0, Stable: 32
  logits diff: 24.875000

关键检查:
  1. Step 0 diff = 0 ✓
  2. threshold=1.0 diff = 0 -> 退化正确
  3. threshold<1 有 skip -> 正常有差异


## 修复诊断单元（替换上面报错/警告）
如果上面的单元报 NameError 或 warning，请改运行下面这两格。

In [None]:
from generate import generate_with_dual_cache, generate_with_dual_cache_tokenskip

# 重新跑：Baseline vs TokenSkip 的 mask 残留与长度
MASK_ID = safe_get("MASK_ID", 126336)
GEN_LENGTH = safe_get("GEN_LENGTH", 128)
BLOCK_LENGTH = safe_get("BLOCK_LENGTH", 32)
STEPS = safe_get("STEPS", 128)

if "input_ids" not in globals():
    raise RuntimeError("input_ids 未定义，请先运行上面的输入准备单元")

prompt_len = int(input_ids.shape[1])

out_base, nfe_base = generate_with_dual_cache(
    model,
    input_ids,
    steps=STEPS,
    gen_length=GEN_LENGTH,
    block_length=BLOCK_LENGTH,
    threshold=0.9,
)

out_skip, nfe_skip = generate_with_dual_cache_tokenskip(
    model,
    input_ids,
    steps=STEPS,
    gen_length=GEN_LENGTH,
    block_length=BLOCK_LENGTH,
    threshold=0.9,
    skip_layer_k=safe_get("SKIP_LAYER_K", 8),
    skip_threshold=safe_get("SKIP_THRESHOLD", 0.95),
    skip_outlier=safe_get("SKIP_OUTLIER", 0.7),
)

base_gen = out_base[0, prompt_len:]
skip_gen = out_skip[0, prompt_len:]

base_mask_res = count_mask_residual(base_gen, MASK_ID)
skip_mask_res = count_mask_residual(skip_gen, MASK_ID)

base_tokens_len = int((base_gen != MASK_ID).sum().item())
skip_tokens_len = int((skip_gen != MASK_ID).sum().item())

base_blocks = block_mask_stats(out_base[0], prompt_len, GEN_LENGTH, BLOCK_LENGTH, MASK_ID)
skip_blocks = block_mask_stats(out_skip[0], prompt_len, GEN_LENGTH, BLOCK_LENGTH, MASK_ID)

logging.info("Baseline: nfe=%d, mask_residual=%d, gen_tokens=%d", nfe_base, base_mask_res, base_tokens_len)
logging.info("TokenSkip: nfe=%d, mask_residual=%d, gen_tokens=%d", nfe_skip, skip_mask_res, skip_tokens_len)
logging.info("Baseline block mask residuals: %s", base_blocks)
logging.info("TokenSkip block mask residuals: %s", skip_blocks)

if skip_mask_res > 0:
    logging.warning("TokenSkip 存在 mask 残留，可能导致 decode 变短")

# 修复 warning：active_mask 兼容 tensor/list
if "x_skip" in globals() and "s" in globals() and "e" in globals() and "active_mask" in globals():
    mask_blk = (x_skip[:, s:e] == MASK_ID)[0]
    if isinstance(active_mask, torch.Tensor):
        stable_mask = ~active_mask
    else:
        stable_mask = ~torch.tensor(active_mask, device=x_skip.device)

    if stable_mask.all() and mask_blk.any():
        logging.error("检测到：全部稳定但仍有 mask，后续层可能被完全跳过")
    else:
        logging.info("稳定/Mask 检查完成，未触发危险状态")

: 

## 诊断：Mask 残留与稳定判定
用于定位 tokenskip 导致生成长度变短的问题。

In [29]:
import logging

def setup_logger(level=logging.INFO):
    logging.basicConfig(
        level=level,
        format="%(asctime)s %(levelname)s %(message)s",
    )


def count_mask_residual(tokens: torch.Tensor, mask_id: int) -> int:
    return int((tokens == mask_id).sum().item())


def block_mask_stats(tokens: torch.Tensor, prompt_len: int, gen_length: int, block_length: int, mask_id: int):
    stats = []
    num_blocks = gen_length // block_length
    for nb in range(num_blocks):
        s = prompt_len + nb * block_length
        e = s + block_length
        block = tokens[s:e]
        stats.append(count_mask_residual(block, mask_id))
    return stats


def safe_get(name, default=None):
    return globals().get(name, default)


setup_logger()

In [32]:
# 对比 Baseline vs TokenSkip 的 mask 残留与长度
MASK_ID = safe_get("MASK_ID", 126336)
GEN_LENGTH = safe_get("GEN_LENGTH", 128)
BLOCK_LENGTH = safe_get("BLOCK_LENGTH", 32)
STEPS = safe_get("STEPS", 128)

if "input_ids" not in globals():
    raise RuntimeError("input_ids 未定义，请先运行上面的输入准备单元")

prompt_len = int(input_ids.shape[1])

out_base, nfe_base = generate_with_dual_cache(
    model,
    input_ids,
    steps=STEPS,
    gen_length=GEN_LENGTH,
    block_length=BLOCK_LENGTH,
    threshold=0.9,
)

out_skip, nfe_skip = generate_with_dual_cache_tokenskip(
    model,
    input_ids,
    steps=STEPS,
    gen_length=GEN_LENGTH,
    block_length=BLOCK_LENGTH,
    threshold=0.9,
    skip_layer_k=safe_get("SKIP_LAYER_K", 8),
    skip_threshold=safe_get("SKIP_THRESHOLD", 0.95),
    skip_outlier=safe_get("SKIP_OUTLIER", 0.7),
)

base_gen = out_base[0, prompt_len:]
skip_gen = out_skip[0, prompt_len:]

base_mask_res = count_mask_residual(base_gen, MASK_ID)
skip_mask_res = count_mask_residual(skip_gen, MASK_ID)

base_tokens_len = int((base_gen != MASK_ID).sum().item())
skip_tokens_len = int((skip_gen != MASK_ID).sum().item())

base_blocks = block_mask_stats(out_base[0], prompt_len, GEN_LENGTH, BLOCK_LENGTH, MASK_ID)
skip_blocks = block_mask_stats(out_skip[0], prompt_len, GEN_LENGTH, BLOCK_LENGTH, MASK_ID)

logging.info("Baseline: nfe=%d, mask_residual=%d, gen_tokens=%d", nfe_base, base_mask_res, base_tokens_len)
logging.info("TokenSkip: nfe=%d, mask_residual=%d, gen_tokens=%d", nfe_skip, skip_mask_res, skip_tokens_len)
logging.info("Baseline block mask residuals: %s", base_blocks)
logging.info("TokenSkip block mask residuals: %s", skip_blocks)

if skip_mask_res > 0:
    logging.warning("TokenSkip 存在 mask 残留，可能导致 decode 变短")

: 

In [33]:
# 检查：稳定判定是否包含 mask 位置
if "prev_hidden" not in globals() or "out1_skip_no" not in globals():
    logging.warning("prev_hidden 或 out1_skip_no 未定义，请先运行 Step 0/1 的单元")
else:
    skip_layer_k = safe_get("SKIP_LAYER_K", 8)
    skip_threshold = safe_get("SKIP_THRESHOLD", 0.95)
    skip_outlier = safe_get("SKIP_OUTLIER", 0.7)

    cur_hidden = out1_skip_no.hidden_states  # 当前 step 的 hidden
    L = cur_hidden[0].shape[1]

    # 复现模型内的稳定判定
    active_mask = []
    for j in range(L):
        cos_sims = []
        for layer in range(1, min(skip_layer_k, len(cur_hidden))):
            h1 = prev_hidden[layer][0, j, :]
            h2 = cur_hidden[layer][0, j, :]
            cos = F.cosine_similarity(h1.unsqueeze(0), h2.unsqueeze(0), dim=-1).item()
            cos = min(1.0, cos)
            cos_sims.append(cos)
        stable = len(cos_sims) > 0 and min(cos_sims) >= skip_outlier and sum(cos_sims) / len(cos_sims) > skip_threshold
        active_mask.append(not stable)

    active_mask = torch.tensor(active_mask, device=cur_hidden[0].device)
    stable_mask = ~active_mask

    if "x_skip" not in globals() or "s" not in globals() or "e" not in globals():
        logging.warning("x_skip 或 s/e 未定义，无法检查 mask 交集")
    else:
        mask_blk = (x_skip[:, s:e] == MASK_ID)[0]
        stable_and_mask = stable_mask & mask_blk
        logging.info("mask positions in block: %d", int(mask_blk.sum().item()))
        logging.info("stable positions in block: %d", int(stable_mask.sum().item()))
        logging.info("stable AND mask positions: %d", int(stable_and_mask.sum().item()))
        if stable_and_mask.any():
            idx = stable_and_mask.nonzero(as_tuple=True)[0].tolist()
            logging.warning("存在 mask 被判定为稳定: indices=%s", idx[:20])

2026-01-23 13:41:59,730 INFO mask positions in block: 0
2026-01-23 13:41:59,731 INFO stable positions in block: 32
2026-01-23 13:41:59,732 INFO stable AND mask positions: 0


In [34]:
# 检查：是否出现“全部稳定但仍有 mask”的危险状态
if "x_skip" in globals() and "s" in globals() and "e" in globals():
    mask_blk = (x_skip[:, s:e] == MASK_ID)[0]
    if "active_mask" in globals():
        try:
            stable_mask = ~torch.tensor(active_mask, device=x_skip.device)
        except Exception:
            stable_mask = None
    else:
        stable_mask = None

    if stable_mask is not None:
        if stable_mask.all() and mask_blk.any():
            logging.error("检测到：全部稳定但仍有 mask，后续层可能被完全跳过")
    else:
        logging.warning("active_mask 未定义，跳过该检查")

## 诊断补充：逐 step 追踪 mask / stable 交集
定位“mask 被判稳定导致后半层跳过”的具体 step。

In [35]:
# 逐 step 追踪：mask 数量 / stable 数量 / stable∩mask
# 只跑前 DEBUG_BLOCKS 个 block、每个 block 前 DEBUG_STEPS 个 step，避免太重

if "x_init" not in globals():
    raise RuntimeError("x_init 未定义，请先运行上面的输入准备单元")

MASK_ID = safe_get("MASK_ID", 126336)
GEN_LENGTH = safe_get("GEN_LENGTH", 128)
BLOCK_LENGTH = safe_get("BLOCK_LENGTH", 32)
STEPS = safe_get("STEPS", 128)
SKIP_LAYER_K = safe_get("SKIP_LAYER_K", 8)
SKIP_THRESHOLD = safe_get("SKIP_THRESHOLD", 0.95)
SKIP_OUTLIER = safe_get("SKIP_OUTLIER", 0.7)

x = x_init.clone()
num_blocks = GEN_LENGTH // BLOCK_LENGTH
steps_per_block = STEPS // num_blocks

max_stable_mask = 0
first_hit = None

for nb in range(min(num_blocks, DEBUG_BLOCKS)):
    s = prompt_len + nb * BLOCK_LENGTH
    e = s + BLOCK_LENGTH

    # 当前 block 的 mask 数
    block_mask_index = (x[:, s:e] == MASK_ID)
    num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block)

    # Step 0: 全序列 forward
    out_full = model(x, use_cache=True)
    past_key_values = out_full.past_key_values

    replace_position = torch.zeros_like(x, dtype=torch.bool)
    replace_position[:, s:e] = True

    global_mask_index = (x == MASK_ID)
    global_mask_index[:, e:] = False

    quota0 = None  # 使用 threshold 模式
    x0, transfer_index = get_transfer_index(
        out_full.logits,
        temperature=0.0,
        remasking="low_confidence",
        mask_index=global_mask_index,
        x=x,
        num_transfer_tokens=quota0,
        threshold=0.9,
    )
    x = torch.where(transfer_index, x0, x)

    prev_hidden = None

    for i in range(1, min(steps_per_block, DEBUG_STEPS)):
        mask_blk = (x[:, s:e] == MASK_ID)
        mask_cnt = int(mask_blk.sum().item())
        if mask_cnt == 0:
            logging.info("Block %d Step %d: mask=0, break", nb, i)
            break

        # 为了避免触发模型内部 skip（可能导致 position_ids 形状不匹配），
        # 这里强制关闭 skip，仅用于生成隐藏状态并在外部判定稳定。
        out_blk = model(
            x[:, s:e],
            past_key_values=past_key_values,
            use_cache=True,
            replace_position=replace_position,
            output_hidden_states=True,
            skip_layer_k=None,
            prev_hidden=None,
        )

        # 复现稳定判定（和模型逻辑一致）
        if prev_hidden is not None:
            cur_hidden = out_blk.hidden_states
            L = cur_hidden[0].shape[1]
            stable_mask = torch.ones(L, dtype=torch.bool, device=cur_hidden[0].device)
            for j in range(L):
                cos_sims = []
                for layer in range(1, min(SKIP_LAYER_K, len(cur_hidden))):
                    h1 = prev_hidden[layer][0, j, :]
                    h2 = cur_hidden[layer][0, j, :]
                    cos = F.cosine_similarity(h1.unsqueeze(0), h2.unsqueeze(0), dim=-1).item()
                    cos = min(1.0, cos)
                    cos_sims.append(cos)
                if len(cos_sims) > 0 and min(cos_sims) >= SKIP_OUTLIER and sum(cos_sims) / len(cos_sims) > SKIP_THRESHOLD:
                    # 稳定
                    continue
                stable_mask[j] = False

            stable_cnt = int(stable_mask.sum().item())
            stable_and_mask = stable_mask & mask_blk[0]
            stable_and_mask_cnt = int(stable_and_mask.sum().item())
            logging.info(
                "Block %d Step %d: mask=%d, stable=%d, stable&mask=%d",
                nb, i, mask_cnt, stable_cnt, stable_and_mask_cnt
            )
            if stable_and_mask_cnt > 0:
                idx = stable_and_mask.nonzero(as_tuple=True)[0].tolist()
                logging.warning("Block %d Step %d: mask 被判稳定，indices=%s", nb, i, idx[:20])
                max_stable_mask = max(max_stable_mask, stable_and_mask_cnt)
                if first_hit is None:
                    first_hit = (nb, i, stable_and_mask_cnt)
        else:
            logging.info("Block %d Step %d: prev_hidden=None (不判定稳定)", nb, i)

        logits_blk = out_blk.logits
        x0_blk, transfer_idx_blk = get_transfer_index(
            logits_blk,
            temperature=0.0,
            remasking="low_confidence",
            mask_index=mask_blk,
            x=x[:, s:e],
            num_transfer_tokens=None,
            threshold=0.9,
        )

        blk_old = x[:, s:e]
        blk_new = torch.where(transfer_idx_blk, x0_blk, blk_old)
        x = torch.cat([x[:, :s], blk_new, x[:, e:]], dim=1)
        prev_hidden = out_blk.hidden_states

    logging.info("Block %d: step-wise 追踪结束", nb)

if first_hit is not None:
    nb, i, cnt = first_hit
    logging.error("结论: 发现 mask 被判稳定（首个命中: block=%d, step=%d, stable&mask=%d）", nb, i, cnt)
    logging.error("推断: skip 判定未排除 mask，导致后半层跳过 -> 生成长度变短")
else:
    logging.info("结论: 未发现 mask 被判稳定（本次 DEBUG 范围内）")


2026-01-23 13:42:05,153 INFO Block 0 Step 1: prev_hidden=None (不判定稳定)
2026-01-23 13:42:05,364 INFO Block 0 Step 2: mask=30, stable=30, stable&mask=29
2026-01-23 13:42:05,499 INFO Block 0 Step 3: mask=29, stable=30, stable&mask=29
2026-01-23 13:42:05,659 INFO Block 0 Step 4: mask=28, stable=31, stable&mask=28
2026-01-23 13:42:05,811 INFO Block 0 Step 5: mask=27, stable=30, stable&mask=26
2026-01-23 13:42:05,954 INFO Block 0 Step 6: mask=26, stable=31, stable&mask=26
2026-01-23 13:42:06,110 INFO Block 0 Step 7: mask=25, stable=31, stable&mask=25
2026-01-23 13:42:06,115 INFO Block 0: step-wise 追踪结束
