# TokenSkip 分歧追踪 Notebook

**目标**: 对比 `skip_threshold=1` 和 `skip_threshold=0.9999` 的差异

**方法**: 从输出结果往回追溯，逐 step 对比，找出第一个分歧点

**关键问题**: 为什么 threshold=0.9999 时输出有重复/质量下降？

In [1]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from model.modeling_llada import LLaDAModelLM
from generate import get_transfer_index, get_num_transfer_tokens
import numpy as np
from collections import defaultdict

# Load the LLaDA-8B-Instruct checkpoint in bfloat16, move it to the GPU and
# switch it to eval mode; the tokenizer is loaded from the same checkpoint name.
device = 'cuda'
model = LLaDAModelLM.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', torch_dtype=torch.bfloat16).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct')
# Print the number of transformer blocks as a quick check that loading worked.
print(f"模型加载完成，{len(model.model.transformer.blocks)} 层")

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 6/6 [00:00<00:00,  7.83it/s]


模型加载完成，32 层


## 证明：prev_hidden 和 current_hidden 维度不匹配

直接模拟 `generate_with_dual_cache_tokenskip` 的行为，打印每一步的 hidden states 形状。

In [2]:
# Demonstrate that prev_hidden and current_hidden differ in sequence length.
# This cell replays by hand the call pattern of generate_with_dual_cache_tokenskip.

MASK_ID = 126336
prompt = "Who is Newton, physics?"
messages = [{"role": "user", "content": prompt}]
text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
input_ids = torch.tensor(tokenizer(text)['input_ids']).to(device).unsqueeze(0)

GEN_LENGTH = 128
BLOCK_LENGTH = 32
prompt_len = input_ids.shape[1]
total_len = prompt_len + GEN_LENGTH

# Initialise x: the prompt tokens followed by GEN_LENGTH mask tokens.
x = torch.full((1, total_len), MASK_ID, dtype=torch.long, device=device)
x[:, :prompt_len] = input_ids

print(f"prompt_len: {prompt_len}")
print(f"total_len: {total_len}")
print(f"x.shape: {x.shape}")
print()

# Range of block 0: the first BLOCK_LENGTH positions after the prompt.
s = prompt_len  # 19 in the recorded run
e = s + BLOCK_LENGTH  # 51 in the recorded run
print(f"Block 0 范围: [{s}, {e})")
print()

# ============ Step 0: full forward pass (as in generate_with_dual_cache_tokenskip) ============
print("=" * 60)
print("Step 0: 完整前向（模拟 generate_with_dual_cache_tokenskip）")
print("=" * 60)

with torch.no_grad():
    out_full = model(x, use_cache=True, output_hidden_states=True)

# Keep the KV cache and the per-layer hidden states of the full-length pass.
past_kv = out_full.past_key_values
prev_hidden = out_full.hidden_states

print(f"模型输入 x.shape: {x.shape}")
print(f"prev_hidden 层数: {len(prev_hidden)}")
print(f"prev_hidden[0].shape: {prev_hidden[0].shape}  <-- 完整序列长度 {total_len}")
print()

# Initial sampling (simplified: only the current block is updated).
# replace_position marks the block's positions within the full sequence;
# presumably the model uses it to decide which cache slots to overwrite — confirm in modeling_llada.
replace_position = torch.zeros_like(x, dtype=torch.bool)
replace_position[:, s:e] = True

# Mask positions up to the end of block 0; everything from e onwards is excluded.
global_mask = (x == MASK_ID)
global_mask[:, e:] = False
x0 = out_full.logits.argmax(dim=-1)
# Simplification: fill every selected mask position with the argmax token at once.
x = torch.where(global_mask, x0, x)

print(f"采样后 x 中 block 0 的 mask 数: {(x[:, s:e] == MASK_ID).sum().item()}")
print()

# ============ Step 1: feed only the block (as in generate_with_dual_cache_tokenskip) ============
print("=" * 60)
print("Step 1: 只输入 block 部分")
print("=" * 60)

with torch.no_grad():
    out_blk = model(
        x[:, s:e],  # only the block slice is fed to the model
        past_key_values=past_kv,
        use_cache=True,
        replace_position=replace_position,
        output_hidden_states=True,
        skip_layer_k=18,
        skip_threshold=1.0,  # start with 1.0 so that skipping is not triggered
        skip_outlier=0.7,
        prev_hidden=prev_hidden,  # hidden states of the FULL sequence are passed in
    )

current_hidden = out_blk.hidden_states

print(f"模型输入 x[:, s:e].shape: {x[:, s:e].shape}  <-- 只有 block 长度 {BLOCK_LENGTH}")
print(f"current_hidden 层数: {len(current_hidden)}")
print(f"current_hidden[0].shape: {current_hidden[0].shape}  <-- 只有 block 长度 {BLOCK_LENGTH}")
print()

# ============ Key comparison ============
print("=" * 60)
print("关键对比：prev_hidden vs current_hidden")
print("=" * 60)
print()
print(f"prev_hidden[0].shape:    {prev_hidden[0].shape}")
print(f"current_hidden[0].shape: {current_hidden[0].shape}")
print()
print(f">>> 维度不匹配！{prev_hidden[0].shape[1]} ≠ {current_hidden[0].shape[1]} <<<")
print()

# ============ What does this mismatch cause? ============
print("=" * 60)
print("在 Token Skip 判定中会发生什么？")
print("=" * 60)
print()
print("Token Skip 判定代码（modeling_llada.py line 1517-1524）：")
print("  for j in range(L):  # L = x.shape[1] = 32")
print("      h1 = prev_hidden[layer][0, j, :]   # j=0 → 序列位置 0")
print("      h2 = all_hidden_states[layer][0, j, :]  # j=0 → 序列位置 19")
print()
print("具体例子：")
print(f"  当 j=0 时:")
print(f"    h1 = prev_hidden[1][0, 0, :]  → 序列位置 0 (prompt 第一个 token)")
print(f"    h2 = current_hidden[1][0, 0, :] → 序列位置 {s} (block 第一个 token)")
print()
print(f"  这两个 token 完全不同！比较它们的 cos_sim 没有意义！")
print()

# Compute the cosine similarity the way the unaligned indexing would (index 0 on both sides).
h1 = prev_hidden[1][0, 0, :]  # absolute position 0 of the full sequence
h2 = current_hidden[1][0, 0, :]  # block index 0, i.e. absolute position s (19 in the recorded run)
cos_wrong = F.cosine_similarity(h1.unsqueeze(0).float(), h2.unsqueeze(0).float(), dim=-1).item()

# Same comparison with aligned positions: both sides refer to absolute position s.
h1_correct = prev_hidden[1][0, s, :]  # absolute position s in the full-sequence hidden states
h2_correct = current_hidden[1][0, 0, :]  # block index 0, which is also absolute position s
cos_correct = F.cosine_similarity(h1_correct.unsqueeze(0).float(), h2_correct.unsqueeze(0).float(), dim=-1).item()

print(f"错误的 cos_sim（位置 0 vs 位置 {s}）: {cos_wrong:.6f}")
print(f"正确的 cos_sim（位置 {s} vs 位置 {s}）: {cos_correct:.6f}")
print()
print(">>> 结论：当前代码比较的是完全不同的 token！<<<")

prompt_len: 19
total_len: 147
x.shape: torch.Size([1, 147])

Block 0 范围: [19, 51)

Step 0: 完整前向（模拟 generate_with_dual_cache_tokenskip）
模型输入 x.shape: torch.Size([1, 147])
prev_hidden 层数: 33
prev_hidden[0].shape: torch.Size([1, 147, 4096])  <-- 完整序列长度 147

采样后 x 中 block 0 的 mask 数: 0

Step 1: 只输入 block 部分
模型输入 x[:, s:e].shape: torch.Size([1, 32])  <-- 只有 block 长度 32
current_hidden 层数: 33
current_hidden[0].shape: torch.Size([1, 32, 4096])  <-- 只有 block 长度 32

关键对比：prev_hidden vs current_hidden

prev_hidden[0].shape:    torch.Size([1, 147, 4096])
current_hidden[0].shape: torch.Size([1, 32, 4096])

>>> 维度不匹配！147 ≠ 32 <<<

在 Token Skip 判定中会发生什么？

Token Skip 判定代码（modeling_llada.py line 1517-1524）：
  for j in range(L):  # L = x.shape[1] = 32
      h1 = prev_hidden[layer][0, j, :]   # j=0 → 序列位置 0
      h2 = all_hidden_states[layer][0, j, :]  # j=0 → 序列位置 19

具体例子：
  当 j=0 时:
    h1 = prev_hidden[1][0, 0, :]  → 序列位置 0 (prompt 第一个 token)
    h2 = current_hidden[1][0, 0, :] → 序列位置 19 (block 第一个 t

In [6]:
# Prepare the input: apply the chat template to the prompt and tokenise it
# into a tensor with a leading batch dimension of 1.
prompt = "Who is Newton, physics?"
messages = [{"role": "user", "content": prompt}]
text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
input_ids = torch.tensor(tokenizer(text)['input_ids']).to(device).unsqueeze(0)

MASK_ID = 126336
GEN_LENGTH = 128
BLOCK_LENGTH = 32
STEPS = 32
THRESHOLD = 0.9  # confidence threshold forwarded to get_transfer_index

# Token Skip hyperparameters
SKIP_LAYER_K = 18
SKIP_OUTLIER = 0.7

# Derived sizes. Both divisions are integer divisions, so any remainder of
# GEN_LENGTH / BLOCK_LENGTH or STEPS / num_blocks is silently dropped.
prompt_len = input_ids.shape[1]
total_len = prompt_len + GEN_LENGTH
num_blocks = GEN_LENGTH // BLOCK_LENGTH
steps_per_block = STEPS // num_blocks

print(f"prompt_len: {prompt_len}, total_len: {total_len}")
print(f"num_blocks: {num_blocks}, steps_per_block: {steps_per_block}")

prompt_len: 19, total_len: 147
num_blocks: 4, steps_per_block: 8


## 1. 完整运行两个版本，记录每一步的状态

In [3]:
def run_with_trace(model, prompt, skip_threshold, skip_layer_k=18, skip_outlier=0.7):
    """
    Run block-wise generation and record detailed state for every step.

    Relies on the notebook-level globals GEN_LENGTH, BLOCK_LENGTH, MASK_ID,
    num_blocks, steps_per_block and THRESHOLD, and on get_transfer_index /
    get_num_transfer_tokens imported from generate.

    Args:
        model: called as model(x, use_cache=..., ...); must expose .device and
            return an object with .logits, .past_key_values and .hidden_states.
        prompt: token-id tensor indexed as (batch, prompt_len).
        skip_threshold: forwarded to the model; also the bound the per-token
            average cosine similarity must exceed in the manual skip check.
        skip_layer_k: forwarded to the model; the manual check covers layers
            1 .. skip_layer_k - 1.
        skip_outlier: forwarded to the model; lower bound on the per-token
            minimum cosine similarity in the manual skip check.

    Returns:
        (final_x, trace_log). trace_log maps block index -> list of per-step
        dicts. Step-0 records carry the keys 'prev_hidden', 'active_mask' and
        'cos_sims' (all None) and have no 'skip_count' or '*_hidden_shape'
        keys; later steps carry 'prev_hidden_shape', 'current_hidden_shape',
        'active_mask', 'skip_count' and 'cos_sims'.
    """
    B = prompt.shape[0]
    Lp = int(prompt.shape[1])
    
    # Sequence buffer: prompt tokens followed by GEN_LENGTH mask tokens.
    x = torch.full((B, Lp + GEN_LENGTH), MASK_ID, dtype=torch.long, device=model.device)
    x[:, :Lp] = prompt
    
    trace_log = defaultdict(list)  # {block_idx: [step_info, ...]}
    
    for nb in range(num_blocks):
        # Absolute [s, e) range of this block inside x.
        s = Lp + nb * BLOCK_LENGTH
        e = s + BLOCK_LENGTH
        
        block_mask_index = (x[:, s:e] == MASK_ID)
        # NOTE(review): num_transfer_tokens is computed but never read again in this function.
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block)
        
        # Step 0: full forward pass over the whole sequence
        out_full = model(x, use_cache=True, output_hidden_states=True)
        past_key_values = out_full.past_key_values
        
        # Boolean marker of the block's positions within the full sequence;
        # presumably tells the model which cache slots to replace — confirm in modeling_llada.
        replace_position = torch.zeros_like(x, dtype=torch.bool)
        replace_position[:, s:e] = True
        
        # Mask positions up to the end of the current block; later blocks are excluded.
        global_mask_index = (x == MASK_ID)
        global_mask_index[:, e:] = False
        
        x0, transfer_index = get_transfer_index(
            out_full.logits, 0., "low_confidence", global_mask_index, x, None, THRESHOLD
        )
        
        # Record step 0
        step_info = {
            'step': 0,
            'x_before': x[:, s:e].clone(),
            'x_after': torch.where(transfer_index, x0, x)[:, s:e].clone(),
            'transfer_count': transfer_index[:, s:e].sum().item(),
            'mask_count_before': (x[:, s:e] == MASK_ID).sum().item(),
            'logits_sample': out_full.logits[:, s:e, :100].clone(),  # keep only the first 100 vocab entries
            'prev_hidden': None,
            'active_mask': None,
            'cos_sims': None,
        }
        trace_log[nb].append(step_info)
        
        x = torch.where(transfer_index, x0, x)
        # Hidden states of the FULL sequence; the first block-only step below compares against these.
        prev_hidden = out_full.hidden_states
        
        # Steps 1 .. steps_per_block - 1
        for i in range(1, steps_per_block):
            mask_count = (x[:, s:e] == MASK_ID).sum().item()
            # Stop this block early once no mask tokens remain in it.
            if mask_count == 0:
                break
            
            # Call the model on the block slice only (with Token Skip arguments)
            out_blk = model(
                x[:, s:e],
                past_key_values=past_key_values,
                use_cache=True,
                replace_position=replace_position,
                output_hidden_states=True,
                skip_layer_k=skip_layer_k,
                skip_threshold=skip_threshold,
                skip_outlier=skip_outlier,
                prev_hidden=prev_hidden,
            )
            logits_blk = out_blk.logits
            
            # === Recompute cos_sim by hand for the trace (intended to mirror the model's internal logic) ===
            cos_sims_all = []
            # True = token stays active; False = judged stable and skippable.
            active_mask_manual = torch.ones(BLOCK_LENGTH, dtype=torch.bool, device=x.device)
            
            if prev_hidden is not None:
                current_hidden = out_blk.hidden_states
                for j in range(BLOCK_LENGTH):
                    cos_sims_j = []
                    for layer in range(1, min(skip_layer_k, len(current_hidden))):
                        # Caution: prev_hidden and current_hidden can have different sequence lengths.
                        # At i == 1 prev_hidden comes from the full-sequence pass, so index j addresses
                        # absolute position j, whereas index j of current_hidden addresses block
                        # position s + j. Only batch row 0 is inspected.
                        if prev_hidden[layer].shape[1] > j and current_hidden[layer].shape[1] > j:
                            h1 = prev_hidden[layer][0, j, :]
                            h2 = current_hidden[layer][0, j, :]
                            # No float upcast here: the similarity is computed in the tensors' own dtype.
                            cos = F.cosine_similarity(h1.unsqueeze(0), h2.unsqueeze(0), dim=-1).item()
                            # Clamp values that overshoot 1.0.
                            cos = min(1.0, cos)
                            cos_sims_j.append(cos)
                    
                    cos_sims_all.append(cos_sims_j)
                    
                    if len(cos_sims_j) > 0:
                        avg_cos = sum(cos_sims_j) / len(cos_sims_j)
                        min_cos = min(cos_sims_j)
                        # Strict '>' on the average: with skip_threshold=1.0 and the clamp above,
                        # this condition can never hold.
                        if min_cos >= skip_outlier and avg_cos > skip_threshold:
                            active_mask_manual[j] = False  # stable, may be skipped
            
            # Compute the transfer index for the block
            mask_blk = (x[:, s:e] == MASK_ID)
            x0_blk, transfer_idx_blk = get_transfer_index(
                logits_blk, 0., "low_confidence", mask_blk, x[:, s:e], None, THRESHOLD
            )
            
            # Record this step
            step_info = {
                'step': i,
                'x_before': x[:, s:e].clone(),
                'x_after': None,  # filled in after x is updated below
                'transfer_count': transfer_idx_blk.sum().item(),
                'mask_count_before': mask_count,
                'logits_sample': logits_blk[:, :, :100].clone(),
                'prev_hidden_shape': [h.shape for h in prev_hidden] if prev_hidden else None,
                'current_hidden_shape': [h.shape for h in out_blk.hidden_states],
                'active_mask': active_mask_manual.clone(),
                'skip_count': (~active_mask_manual).sum().item(),
                'cos_sims': cos_sims_all,  # per-token list of per-layer cos_sim values
            }
            
            # Update x: splice the refreshed block back into the full sequence
            blk_old = x[:, s:e]
            blk_new = torch.where(transfer_idx_blk, x0_blk, blk_old)
            x = torch.cat([x[:, :s], blk_new, x[:, e:]], dim=1)
            
            step_info['x_after'] = x[:, s:e].clone()
            trace_log[nb].append(step_info)
            
            # From here on prev_hidden has block length, matching the next step's current_hidden.
            prev_hidden = out_blk.hidden_states
    
    return x, trace_log

In [7]:
# Run the threshold=1 variant. With skip_threshold=1.0 the traced skip check
# (avg_cos > skip_threshold, cos clamped to 1.0) can never mark a token as skippable.
print("运行 threshold=1 版本...")
with torch.no_grad():
    x_t1, trace_t1 = run_with_trace(model, input_ids, skip_threshold=1.0)

# Decode only the generated part (everything after the prompt).
ans_t1 = tokenizer.decode(x_t1[0, prompt_len:], skip_special_tokens=True)
print(f"threshold=1 输出:")
print(ans_t1)
print(f"\n输出长度: {len(x_t1[0, prompt_len:])} tokens")

运行 threshold=1 版本...
threshold=1 输出:
Isaac Newton was an English physicist and mathematician who

输出长度: 128 tokens


In [11]:
# Run the threshold=0.9999 variant.
# The value passed to run_with_trace must be the same 0.9999 that the labels
# below announce and that the later analysis cells hard-code when re-deriving
# skip decisions; otherwise the trace describes a different experiment than
# the one being reported.
print("运行 threshold=0.9999 版本...")
with torch.no_grad():
    x_t2, trace_t2 = run_with_trace(model, input_ids, skip_threshold=0.9999)

# Decode only the generated part (everything after the prompt).
ans_t2 = tokenizer.decode(x_t2[0, prompt_len:], skip_special_tokens=True)
print(f"threshold=0.9999 输出:")
print(ans_t2)
print(f"\n输出长度: {len(x_t2[0, prompt_len:])} tokens")

运行 threshold=0.9999 版本...


: 

## 2. 对比两个版本的最终输出

In [None]:
# Token-level comparison of the two generated continuations.
print("=" * 60)
print("Token 级别对比")
print("=" * 60)

# Generated token ids only (prompt stripped), as plain Python lists.
tokens_t1 = x_t1[0, prompt_len:].tolist()
tokens_t2 = x_t2[0, prompt_len:].tolist()

# Offsets (relative to the end of the prompt) where the two runs disagree.
diff_positions = [
    offset
    for offset, (tok_a, tok_b) in enumerate(zip(tokens_t1, tokens_t2))
    if tok_a != tok_b
]

# Report every disagreement with both decoded tokens side by side.
for i in diff_positions:
    t1, t2 = tokens_t1[i], tokens_t2[i]
    w1 = tokenizer.decode([t1])
    w2 = tokenizer.decode([t2])
    print(f"位置 {i}: threshold=1 -> '{w1}' (id={t1}), threshold=0.9999 -> '{w2}' (id={t2})")

print(f"\n总共 {len(diff_positions)} 个位置不同")
if diff_positions:
    print(f"第一个不同的位置: {diff_positions[0]}")

Token 级别对比

总共 0 个位置不同


## 3. 找出第一个分歧点（从 trace 往回查）

In [None]:
def find_first_divergence(trace_t1, trace_t2):
    """Find the first (block, step) at which two traces disagree.

    Blocks are visited in ascending index order and steps in recorded order.
    The block indices are taken from the traces themselves rather than from a
    notebook-level global, so the result cannot be affected by a stale or
    undefined ``num_blocks``. Lookups use ``.get`` so that neither trace is
    mutated (a defaultdict would otherwise grow empty entries) and plain
    dicts are accepted too.

    Args:
        trace_t1: mapping of block index -> list of step-info dicts. Each
            step-info dict must provide 'x_after' (tensor or None) and
            'logits_sample' (tensor).
        trace_t2: second trace with the same structure.

    Returns:
        A ``(block, step, reason)`` tuple. ``reason`` is one of
        "x_after diverged", "logits diverged" or "step count mismatch"; when
        the traces agree everywhere the result is
        ``(None, None, "no divergence found")``.
    """
    for nb in sorted(set(trace_t1) | set(trace_t2)):
        steps_t1 = trace_t1.get(nb, [])
        steps_t2 = trace_t2.get(nb, [])

        # Compare the steps both runs executed, in order.
        for i, (s1, s2) in enumerate(zip(steps_t1, steps_t2)):
            # Compare x_after first; it is skipped when either side has none.
            if s1['x_after'] is not None and s2['x_after'] is not None:
                if not torch.equal(s1['x_after'], s2['x_after']):
                    return nb, i, "x_after diverged"

            # Then compare the recorded logits slice with a small absolute tolerance.
            if not torch.allclose(s1['logits_sample'], s2['logits_sample'], atol=1e-3):
                return nb, i, "logits diverged"

        # One run executed more steps than the other: the first step that
        # exists on only one side is the divergence point.
        if len(steps_t1) != len(steps_t2):
            return nb, min(len(steps_t1), len(steps_t2)), "step count mismatch"

    return None, None, "no divergence found"

# Locate the first divergence between the two traces; div_block and div_step
# are None when the traces agree everywhere.
div_block, div_step, div_reason = find_first_divergence(trace_t1, trace_t2)
print(f"第一个分歧点: block={div_block}, step={div_step}, 原因={div_reason}")

第一个分歧点: block=None, step=None, 原因=no divergence found


In [None]:
# Inspect the divergence point in detail (skipped entirely when no divergence was found).
if div_block is not None and div_step is not None:
    print("=" * 60)
    print(f"分歧点详情: Block {div_block}, Step {div_step}")
    print("=" * 60)
    
    # The step records of both runs at the divergence point.
    s1 = trace_t1[div_block][div_step]
    s2 = trace_t2[div_block][div_step]
    
    # skip_count is read with .get because step-0 records do not carry that key;
    # active_mask is None for step-0 records.
    print(f"\n--- threshold=1 ---")
    print(f"mask_count_before: {s1['mask_count_before']}")
    print(f"transfer_count: {s1['transfer_count']}")
    print(f"skip_count: {s1.get('skip_count', 'N/A')}")
    print(f"active_mask sum: {s1['active_mask'].sum().item() if s1['active_mask'] is not None else 'N/A'}")
    
    print(f"\n--- threshold=0.9999 ---")
    print(f"mask_count_before: {s2['mask_count_before']}")
    print(f"transfer_count: {s2['transfer_count']}")
    print(f"skip_count: {s2.get('skip_count', 'N/A')}")
    print(f"active_mask sum: {s2['active_mask'].sum().item() if s2['active_mask'] is not None else 'N/A'}")
    
    # Compare x_after of both runs position by position
    if s1['x_after'] is not None and s2['x_after'] is not None:
        diff_mask = (s1['x_after'] != s2['x_after'])
        # Index [1] of the nonzero tuple holds the within-block positions (dim 1).
        diff_pos = diff_mask.nonzero(as_tuple=True)[1].tolist()
        print(f"\nx_after 不同的位置: {diff_pos}")
        for pos in diff_pos[:5]:  # show at most the first 5
            t1 = s1['x_after'][0, pos].item()
            t2 = s2['x_after'][0, pos].item()
            print(f"  位置 {pos}: t1={t1} ('{tokenizer.decode([t1])}'), t2={t2} ('{tokenizer.decode([t2])}')") 

## 4. 分析 cos_sim 分布

In [None]:
# Show the cos_sim distribution recorded for the threshold=0.9999 run at the divergence point.
if div_block is not None and div_step is not None:
    s2 = trace_t2[div_block][div_step]
    
    # cos_sims is None for step-0 records.
    if s2['cos_sims'] is not None:
        print("=" * 60)
        print(f"Block {div_block}, Step {div_step} 的 cos_sim 分布 (threshold=0.9999)")
        print("=" * 60)
        
        for j, cos_sims_j in enumerate(s2['cos_sims']):
            if len(cos_sims_j) > 0:
                avg = sum(cos_sims_j) / len(cos_sims_j)
                min_cos = min(cos_sims_j)
                max_cos = max(cos_sims_j)
                
                # Re-derive the skip decision. The 0.9999 threshold is hard-coded here,
                # so it only reflects the traced run if that run used the same value.
                skip = (min_cos >= SKIP_OUTLIER and avg > 0.9999)
                status = "SKIP" if skip else "KEEP"
                
                # Print only tokens that are skipped or whose average is close to the threshold
                if skip or avg > 0.99:
                    print(f"Token {j:2d}: avg={avg:.6f}, min={min_cos:.6f}, max={max_cos:.6f} -> {status}")

In [None]:
# Compare the recorded cos_sim values of the two runs at the divergence point (when available).
if div_block is not None and div_step is not None:
    s1 = trace_t1[div_block][div_step]
    s2 = trace_t2[div_block][div_step]
    
    print("=" * 60)
    print("cos_sim 对比 (应该相同，因为输入相同)")
    print("=" * 60)
    
    # cos_sims is None for step-0 records, in which case nothing is printed.
    if s1['cos_sims'] is not None and s2['cos_sims'] is not None:
        for j in range(min(len(s1['cos_sims']), len(s2['cos_sims']))):
            c1 = s1['cos_sims'][j]
            c2 = s2['cos_sims'][j]
            if len(c1) > 0 and len(c2) > 0:
                # Per-token average over the recorded layers; report only tokens whose averages differ.
                avg1 = sum(c1) / len(c1)
                avg2 = sum(c2) / len(c2)
                if abs(avg1 - avg2) > 1e-6:
                    print(f"Token {j}: t1_avg={avg1:.6f}, t2_avg={avg2:.6f} <- 不同!")

## 5. 深入检查：prev_hidden 形状问题

In [None]:
# Report how the hidden_states shapes change from step to step in the second trace.
print("=" * 60)
print("threshold=0.9999: hidden_states 形状变化")
print("=" * 60)

for nb in range(num_blocks):
    print(f"\n--- Block {nb} ---")
    for step_info in trace_t2[nb]:
        step = step_info['step']
        # .get is used because step-0 records do not carry these keys.
        prev_shape = step_info.get('prev_hidden_shape')
        curr_shape = step_info.get('current_hidden_shape')
        skip_count = step_info.get('skip_count', 0)
        
        # Step-0 records have no shape lists and are therefore not printed.
        if prev_shape and curr_shape:
            # Show only the first entry of each shape list as a representative
            prev_l0 = prev_shape[0] if len(prev_shape) > 0 else None
            curr_l0 = curr_shape[0] if len(curr_shape) > 0 else None
            
            # Flag the step when the previous and current shapes differ.
            shape_match = prev_l0 == curr_l0 if prev_l0 and curr_l0 else True
            flag = "" if shape_match else " <- SHAPE MISMATCH!"
            
            print(f"  Step {step}: prev={prev_l0}, curr={curr_l0}, skip_count={skip_count}{flag}")

threshold=0.9999: hidden_states 形状变化

--- Block 0 ---
  Step 1: prev=torch.Size([1, 147, 4096]), curr=torch.Size([1, 32, 4096]), skip_count=0 <- SHAPE MISMATCH!
  Step 2: prev=torch.Size([1, 32, 4096]), curr=torch.Size([1, 32, 4096]), skip_count=0
  Step 3: prev=torch.Size([1, 32, 4096]), curr=torch.Size([1, 32, 4096]), skip_count=0
  Step 4: prev=torch.Size([1, 32, 4096]), curr=torch.Size([1, 32, 4096]), skip_count=0
  Step 5: prev=torch.Size([1, 32, 4096]), curr=torch.Size([1, 32, 4096]), skip_count=0
  Step 6: prev=torch.Size([1, 32, 4096]), curr=torch.Size([1, 32, 4096]), skip_count=0
  Step 7: prev=torch.Size([1, 32, 4096]), curr=torch.Size([1, 32, 4096]), skip_count=0

--- Block 1 ---
  Step 1: prev=torch.Size([1, 147, 4096]), curr=torch.Size([1, 32, 4096]), skip_count=0 <- SHAPE MISMATCH!
  Step 2: prev=torch.Size([1, 32, 4096]), curr=torch.Size([1, 32, 4096]), skip_count=0
  Step 3: prev=torch.Size([1, 32, 4096]), curr=torch.Size([1, 32, 4096]), skip_count=0
  Step 4: prev=torc

## 6. 关键发现总结

In [None]:
# Summary of the key findings, built from variables produced by earlier cells
# (tokens_t1 / tokens_t2 / diff_positions, div_block / div_step / div_reason, trace_t1 / trace_t2).
print("=" * 60)
print("关键发现总结")
print("=" * 60)

print(f"\n1. 最终输出对比:")
print(f"   - threshold=1 输出长度: {len(tokens_t1)} tokens")
print(f"   - threshold=0.9999 输出长度: {len(tokens_t2)} tokens")
print(f"   - 不同位置数: {len(diff_positions)}")

print(f"\n2. 分歧点:")
print(f"   - Block: {div_block}")
print(f"   - Step: {div_step}")
print(f"   - 原因: {div_reason}")

# Total skip count per run. Records without a 'skip_count' key (step 0) count as 0.
# These counts come from the manual re-computation stored in the trace.
total_skip_t1 = sum(
    s.get('skip_count', 0) 
    for nb in range(num_blocks) 
    for s in trace_t1[nb]
)
total_skip_t2 = sum(
    s.get('skip_count', 0) 
    for nb in range(num_blocks) 
    for s in trace_t2[nb]
)

print(f"\n3. Token Skip 统计:")
print(f"   - threshold=1: 总共跳过 {total_skip_t1} 个 token（应该是 0）")
print(f"   - threshold=0.9999: 总共跳过 {total_skip_t2} 个 token")

关键发现总结

1. 最终输出对比:
   - threshold=1 输出长度: 128 tokens
   - threshold=0.9999 输出长度: 128 tokens
   - 不同位置数: 0

2. 分歧点:
   - Block: None
   - Step: None
   - 原因: no divergence found

3. Token Skip 统计:
   - threshold=1: 总共跳过 0 个 token（应该是 0）
   - threshold=0.9999: 总共跳过 0 个 token


## 7. 手动单步调试（可选）

In [None]:
# For finer-grained debugging, individual steps can be run by hand from here on.
# Initialise first: the prompt tokens followed by GEN_LENGTH mask tokens.

x_debug = torch.full((1, prompt_len + GEN_LENGTH), MASK_ID, dtype=torch.long, device=device)
x_debug[:, :prompt_len] = input_ids

# Range of the first block: the BLOCK_LENGTH positions right after the prompt
s = prompt_len
e = s + BLOCK_LENGTH

print(f"Debug 初始化完成")
print(f"x_debug shape: {x_debug.shape}")
print(f"Block 0 范围: [{s}, {e})")

Debug 初始化完成
x_debug shape: torch.Size([1, 147])
Block 0 范围: [19, 51)


In [None]:
# Step 0: full forward pass over the whole sequence (warms up the KV cache)
with torch.no_grad():
    out0 = model(x_debug, use_cache=True, output_hidden_states=True)

# Keep the cache and the full-sequence hidden states for the block-only step that follows.
past_kv = out0.past_key_values
prev_hidden_debug = out0.hidden_states

# Initial sampling: mask positions up to the end of block 0 are candidates;
# everything from e onwards is excluded.
global_mask = (x_debug == MASK_ID)
global_mask[:, e:] = False
x0_debug, transfer_debug = get_transfer_index(
    out0.logits, 0., "low_confidence", global_mask, x_debug, None, THRESHOLD
)
# Commit only the positions selected by transfer_debug.
x_debug = torch.where(transfer_debug, x0_debug, x_debug)

print(f"Step 0 完成")
print(f"转移了 {transfer_debug.sum().item()} 个 token")
print(f"当前 block 剩余 mask: {(x_debug[:, s:e] == MASK_ID).sum().item()}")
print(f"prev_hidden 层数: {len(prev_hidden_debug)}")
print(f"prev_hidden[0] shape: {prev_hidden_debug[0].shape}")

Step 0 完成
转移了 1 个 token
当前 block 剩余 mask: 31
prev_hidden 层数: 33
prev_hidden[0] shape: torch.Size([1, 147, 4096])


In [None]:
# Step 1: run the same block input with threshold=1 and threshold=0.9999 and compare the results.
replace_pos = torch.zeros_like(x_debug, dtype=torch.bool)
replace_pos[:, s:e] = True

# NOTE(review): both calls below reuse the same past_kv object. If the model updates
# the cache in place, the second call starts from the cache left behind by the
# first one — confirm against modeling_llada.

# threshold=1
with torch.no_grad():
    out1_t1 = model(
        x_debug[:, s:e],
        past_key_values=past_kv,
        use_cache=True,
        replace_position=replace_pos,
        output_hidden_states=True,
        skip_layer_k=SKIP_LAYER_K,
        skip_threshold=1.0,
        skip_outlier=SKIP_OUTLIER,
        prev_hidden=prev_hidden_debug,
    )

# threshold=0.9999
with torch.no_grad():
    out1_t2 = model(
        x_debug[:, s:e],
        past_key_values=past_kv,
        use_cache=True,
        replace_position=replace_pos,
        output_hidden_states=True,
        skip_layer_k=SKIP_LAYER_K,
        skip_threshold=0.9999,
        skip_outlier=SKIP_OUTLIER,
        prev_hidden=prev_hidden_debug,
    )

print(f"Step 1 对比:")
print(f"  threshold=1 logits shape: {out1_t1.logits.shape}")
print(f"  threshold=0.9999 logits shape: {out1_t2.logits.shape}")

# Largest absolute element-wise difference between the two logits tensors.
logits_diff = (out1_t1.logits - out1_t2.logits).abs().max().item()
print(f"  logits max diff: {logits_diff}")

if logits_diff > 0:
    print(f"\n  >>> logits 不同！Token Skip 产生了影响 <<<")
    
    # Report layers whose hidden_states shapes differ between the two calls
    print(f"\n  hidden_states 形状对比:")
    for i, (h1, h2) in enumerate(zip(out1_t1.hidden_states, out1_t2.hidden_states)):
        if h1.shape != h2.shape:
            print(f"    Layer {i}: t1={h1.shape}, t2={h2.shape} <- 不同!")

Step 1 对比:
  threshold=1 logits shape: torch.Size([1, 32, 126464])
  threshold=0.9999 logits shape: torch.Size([1, 32, 126464])
  logits max diff: 0.0
