# Token Skip 数据流追踪

这个 notebook 用于追踪 Token Skip 过程中的所有关键张量维度和数值。

In [1]:
import torch
import time
import sys
sys.path.insert(0, '.')

from transformers import AutoTokenizer
from model.modeling_llada import LLaDAModelLM
from generate import generate_with_dual_cache, generate_with_dual_cache_tokenskip
from trace_dataflow import Tracer, set_tracer, get_tracer

print("Imports done.")

  from .autonotebook import tqdm as notebook_tqdm


Imports done.


In [2]:
# 加载模型
device = 'cuda'
model = LLaDAModelLM.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', torch_dtype=torch.bfloat16).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct')
print("Model loaded.")

Loading checkpoint shards: 100%|██████████| 6/6 [00:00<00:00,  9.66it/s]


Model loaded.


In [3]:
# 准备输入
prompt = "Who is Newton, physics?"
m = [{"role": "user", "content": prompt}]
text = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
input_ids = torch.tensor(tokenizer(text)['input_ids']).to(device).unsqueeze(0)
print(f"Input shape: {input_ids.shape}")
print(f"Input tokens: {tokenizer.decode(input_ids[0])}")

Input shape: torch.Size([1, 19])
Input tokens: <|startoftext|><|start_header_id|>user<|end_header_id|>

Who is Newton, physics?<|eot_id|><|start_header_id|>assistant<|end_header_id|>




## 1. Baseline 追踪 (threshold=1, 无 skip)

In [4]:
# 创建追踪器并启用
tracer_baseline = Tracer(enabled=True, max_entries=10000)
set_tracer(tracer_baseline)

# 测试 baseline (threshold=1 意味着不会 skip)
SKIP_LAYER_K = 16
SKIP_THRESHOLD = 1.0  # 不会触发 skip
SKIP_OUTLIER = 0.7

print("Running baseline with tracing...")
start = time.time()
out1, nfe1 = generate_with_dual_cache_tokenskip(
    model, input_ids, steps=32, gen_length=32, block_length=32, threshold=0.9,
    skip_layer_k=SKIP_LAYER_K, skip_threshold=SKIP_THRESHOLD, skip_outlier=SKIP_OUTLIER
)
t1 = time.time() - start
ans1 = tokenizer.decode(out1[0, input_ids.shape[1]:], skip_special_tokens=True)
print(f"Baseline: {t1:.2f}s, NFE={nfe1}")
print(f"Output: {ans1}")
print(f"Total trace entries: {len(tracer_baseline.entries)}")

Running baseline with tracing...
Baseline: 12.49s, NFE=30
Output: 
Total trace entries: 9786


In [5]:
# 查看追踪摘要
print(tracer_baseline.summary())

Total entries: 9786
Events:
  ATTN_ENTRY: 960
  ATTN_AFTER_RESHAPE: 960
  ROPE_FWD_ENTRY: 960
  ROPE_BASELINE_PARAMS: 960
  ROPE_BASELINE_RESULT: 960
  KV_CACHE_BEFORE: 928
  KV_CACHE_UPDATE: 928
  KV_CACHE_AFTER_REPLACE: 928
  ROPE_REPLACE_BEFORE: 928
  ROPE_REPLACE_AFTER: 928
  SKIP_COS_SIM_DETAIL: 168
  ROPE_BASELINE_BEFORE: 32
  ROPE_BASELINE_AFTER: 32
  MODEL_FWD_ENTRY: 30
  MODEL_PREV_HIDDEN_INFO: 28
  SKIP_JUDGE_START: 28
  SKIP_JUDGE_RESULT: 28


In [6]:
# 转换为 DataFrame
df_baseline = tracer_baseline.to_dataframe()
print(f"DataFrame shape: {df_baseline.shape}")
df_baseline.head(20)

DataFrame shape: (9786, 140)


Unnamed: 0,timestamp,event,step,layer,block_idx,extra,shape_x,val_x_dtype,val_x_min,val_x_max,...,val_h2_dtype,val_h2_min,val_h2_max,val_h2_mean,val_cos_sim,shape_active_mask,val_num_stable,val_num_active,val_active_mask_dtype,val_cos_sims_all
0,2026-01-23T14:38:21.956634,MODEL_FWD_ENTRY,-1,-1,-1,,"[1, 51, 4096]",torch.bfloat16,-2.5,4.4375,...,,,,,,,,,,
1,2026-01-23T14:38:22.077663,ATTN_ENTRY,-1,-1,-1,,,,,,...,,,,,,,,,,
2,2026-01-23T14:38:22.078330,ATTN_AFTER_RESHAPE,-1,-1,-1,,,,,,...,,,,,,,,,,
3,2026-01-23T14:38:22.078709,ROPE_BASELINE_BEFORE,-1,-1,-1,,,,,,...,,,,,,,,,,
4,2026-01-23T14:38:22.080530,ROPE_FWD_ENTRY,-1,-1,-1,,,,,,...,,,,,,,,,,
5,2026-01-23T14:38:22.081367,ROPE_BASELINE_PARAMS,-1,-1,-1,,,,,,...,,,,,,,,,,
6,2026-01-23T14:38:22.093421,ROPE_BASELINE_RESULT,-1,-1,-1,,,,,,...,,,,,,,,,,
7,2026-01-23T14:38:22.093840,ROPE_BASELINE_AFTER,-1,-1,-1,,,,,,...,,,,,,,,,,
8,2026-01-23T14:38:22.660691,ATTN_ENTRY,-1,-1,-1,,,,,,...,,,,,,,,,,
9,2026-01-23T14:38:22.661841,ATTN_AFTER_RESHAPE,-1,-1,-1,,,,,,...,,,,,,,,,,


In [7]:
# 查看所有事件类型
print("Event types:")
print(df_baseline['event'].value_counts())

Event types:
event
ATTN_ENTRY                960
ROPE_BASELINE_RESULT      960
ATTN_AFTER_RESHAPE        960
ROPE_FWD_ENTRY            960
ROPE_BASELINE_PARAMS      960
ROPE_REPLACE_BEFORE       928
KV_CACHE_BEFORE           928
KV_CACHE_AFTER_REPLACE    928
KV_CACHE_UPDATE           928
ROPE_REPLACE_AFTER        928
SKIP_COS_SIM_DETAIL       168
ROPE_BASELINE_AFTER        32
ROPE_BASELINE_BEFORE       32
MODEL_FWD_ENTRY            30
MODEL_PREV_HIDDEN_INFO     28
SKIP_JUDGE_START           28
SKIP_JUDGE_RESULT          28
Name: count, dtype: int64


In [8]:
# 查看 SKIP_JUDGE 相关事件
df_skip_judge = df_baseline[df_baseline['event'].str.contains('SKIP_JUDGE')]
print(f"SKIP_JUDGE events: {len(df_skip_judge)}")
df_skip_judge

SKIP_JUDGE events: 56


Unnamed: 0,timestamp,event,step,layer,block_idx,extra,shape_x,val_x_dtype,val_x_min,val_x_max,...,val_h2_dtype,val_h2_min,val_h2_max,val_h2_mean,val_cos_sim,shape_active_mask,val_num_stable,val_num_active,val_active_mask_dtype,val_cos_sims_all
708,2026-01-23T14:38:23.319765,SKIP_JUDGE_START,-1,-1,-1,,,,,,...,,,,,,,,,,
715,2026-01-23T14:38:23.444939,SKIP_JUDGE_RESULT,-1,-1,-1,,,,,,...,,,,,,[32],0.0,32.0,torch.bool,"[{'j': 0, 'avg': 0.999, 'min': 0.9922}, {'j': ..."
1038,2026-01-23T14:38:23.699300,SKIP_JUDGE_START,-1,-1,-1,,,,,,...,,,,,,,,,,
1045,2026-01-23T14:38:23.804598,SKIP_JUDGE_RESULT,-1,-1,-1,,,,,,...,,,,,,[32],0.0,32.0,torch.bool,"[{'j': 0, 'avg': 0.9982, 'min': 0.9922}, {'j':..."
1368,2026-01-23T14:38:24.053313,SKIP_JUDGE_START,-1,-1,-1,,,,,,...,,,,,,,,,,
1375,2026-01-23T14:38:24.137656,SKIP_JUDGE_RESULT,-1,-1,-1,,,,,,...,,,,,,[32],0.0,32.0,torch.bool,"[{'j': 0, 'avg': 0.9997, 'min': 0.9961}, {'j':..."
1698,2026-01-23T14:38:24.376896,SKIP_JUDGE_START,-1,-1,-1,,,,,,...,,,,,,,,,,
1705,2026-01-23T14:38:24.477406,SKIP_JUDGE_RESULT,-1,-1,-1,,,,,,...,,,,,,[32],0.0,32.0,torch.bool,"[{'j': 0, 'avg': 0.9995, 'min': 0.9961}, {'j':..."
2028,2026-01-23T14:38:24.721489,SKIP_JUDGE_START,-1,-1,-1,,,,,,...,,,,,,,,,,
2035,2026-01-23T14:38:24.797991,SKIP_JUDGE_RESULT,-1,-1,-1,,,,,,...,,,,,,[32],0.0,32.0,torch.bool,"[{'j': 0, 'avg': 0.9987, 'min': 0.9961}, {'j':..."


## 2. Token Skip 追踪 (threshold=0.98, 触发 skip)

In [9]:
# 创建新的追踪器
tracer_skip = Tracer(enabled=True, max_entries=50000)
set_tracer(tracer_skip)

# 测试 token skip (threshold=0.98 应该触发 skip)
SKIP_LAYER_K = 16
SKIP_THRESHOLD = 0.98
SKIP_OUTLIER = 0.7

print("Running token skip with tracing...")
start = time.time()
try:
    out2, nfe2 = generate_with_dual_cache_tokenskip(
        model, input_ids, steps=32, gen_length=32, block_length=32, threshold=0.9,
        skip_layer_k=SKIP_LAYER_K, skip_threshold=SKIP_THRESHOLD, skip_outlier=SKIP_OUTLIER
    )
    t2 = time.time() - start
    ans2 = tokenizer.decode(out2[0, input_ids.shape[1]:], skip_special_tokens=True)
    print(f"TokenSkip: {t2:.2f}s, NFE={nfe2}")
    print(f"Output: {ans2}")
except Exception as e:
    print(f"ERROR: {e}")
    import traceback
    traceback.print_exc()

print(f"Total trace entries: {len(tracer_skip.entries)}")

Running token skip with tracing...
TokenSkip: 4.91s, NFE=16
Output: awaited**“blockListblockList
Total trace entries: 5866


In [10]:
# 查看追踪摘要
print(tracer_skip.summary())

Total entries: 5866
Events:
  ATTN_ENTRY: 512
  ATTN_AFTER_RESHAPE: 512
  ROPE_FWD_ENTRY: 512
  KV_CACHE_BEFORE: 480
  KV_CACHE_UPDATE: 480
  KV_CACHE_AFTER_REPLACE: 480
  ROPE_BASELINE_PARAMS: 288
  ROPE_BASELINE_RESULT: 288
  ROPE_REPLACE_BEFORE: 256
  ROPE_REPLACE_AFTER: 256
  ATTN_POSITION_IDS: 224
  ROPE_BEFORE_TOKENSKIP: 224
  ROPE_TOKENSKIP_PARAMS: 224
  ROPE_INDEX_SELECT: 224
  ROPE_TOKENSKIP_RESULT: 224
  ROPE_AFTER_TOKENSKIP: 224
  ROPE_SKIP: 224
  SKIP_COS_SIM_DETAIL: 84
  ROPE_BASELINE_BEFORE: 32
  ROPE_BASELINE_AFTER: 32
  MODEL_FWD_ENTRY: 16
  MODEL_PREV_HIDDEN_INFO: 14
  SKIP_JUDGE_START: 14
  SKIP_JUDGE_RESULT: 14
  SKIP_PARTIAL: 14
  SKIP_POSITION_UPDATE: 14


In [11]:
# 转换为 DataFrame
df_skip = tracer_skip.to_dataframe()
print(f"DataFrame shape: {df_skip.shape}")
df_skip.head(30)

DataFrame shape: (5866, 200)


Unnamed: 0,timestamp,event,step,layer,block_idx,extra,shape_x,val_x_dtype,val_x_min,val_x_max,...,val_q_after_dtype,val_q_after_min,val_q_after_max,val_q_after_mean,val_k_after_dtype,val_k_after_min,val_k_after_max,val_k_after_mean,val_replace_indices,val_reason
0,2026-01-23T14:38:34.484404,MODEL_FWD_ENTRY,-1,-1,-1,,"[1, 51, 4096]",torch.bfloat16,-2.5,4.4375,...,,,,,,,,,,
1,2026-01-23T14:38:34.485399,ATTN_ENTRY,-1,-1,-1,,,,,,...,,,,,,,,,,
2,2026-01-23T14:38:34.485980,ATTN_AFTER_RESHAPE,-1,-1,-1,,,,,,...,,,,,,,,,,
3,2026-01-23T14:38:34.486330,ROPE_BASELINE_BEFORE,-1,-1,-1,,,,,,...,,,,,,,,,,
4,2026-01-23T14:38:34.486802,ROPE_FWD_ENTRY,-1,-1,-1,,,,,,...,,,,,,,,,,
5,2026-01-23T14:38:34.487080,ROPE_BASELINE_PARAMS,-1,-1,-1,,,,,,...,,,,,,,,,,
6,2026-01-23T14:38:34.488518,ROPE_BASELINE_RESULT,-1,-1,-1,,,,,,...,,,,,,,,,,
7,2026-01-23T14:38:34.488866,ROPE_BASELINE_AFTER,-1,-1,-1,,,,,,...,,,,,,,,,,
8,2026-01-23T14:38:34.490467,ATTN_ENTRY,-1,-1,-1,,,,,,...,,,,,,,,,,
9,2026-01-23T14:38:34.491016,ATTN_AFTER_RESHAPE,-1,-1,-1,,,,,,...,,,,,,,,,,


In [12]:
# 查看所有事件类型
print("Event types:")
print(df_skip['event'].value_counts())

Event types:
event
ATTN_ENTRY                512
ATTN_AFTER_RESHAPE        512
ROPE_FWD_ENTRY            512
KV_CACHE_BEFORE           480
KV_CACHE_UPDATE           480
KV_CACHE_AFTER_REPLACE    480
ROPE_BASELINE_PARAMS      288
ROPE_BASELINE_RESULT      288
ROPE_REPLACE_BEFORE       256
ROPE_REPLACE_AFTER        256
ROPE_SKIP                 224
ROPE_TOKENSKIP_RESULT     224
ROPE_BEFORE_TOKENSKIP     224
ROPE_AFTER_TOKENSKIP      224
ROPE_INDEX_SELECT         224
ROPE_TOKENSKIP_PARAMS     224
ATTN_POSITION_IDS         224
SKIP_COS_SIM_DETAIL        84
ROPE_BASELINE_BEFORE       32
ROPE_BASELINE_AFTER        32
MODEL_FWD_ENTRY            16
MODEL_PREV_HIDDEN_INFO     14
SKIP_PARTIAL               14
SKIP_JUDGE_RESULT          14
SKIP_JUDGE_START           14
SKIP_POSITION_UPDATE       14
Name: count, dtype: int64


## 3. 关键事件分析

In [13]:
# 查看 SKIP_JUDGE_RESULT 事件（Token Skip 判定结果）
df_judge = df_skip[df_skip['event'] == 'SKIP_JUDGE_RESULT']
print(f"SKIP_JUDGE_RESULT events: {len(df_judge)}")
df_judge[['event', 'step', 'val_num_stable', 'val_num_active']].head(20)

SKIP_JUDGE_RESULT events: 14


Unnamed: 0,event,step,val_num_stable,val_num_active
715,SKIP_JUDGE_RESULT,-1,22.0,10.0
1095,SKIP_JUDGE_RESULT,-1,21.0,11.0
1475,SKIP_JUDGE_RESULT,-1,22.0,10.0
1855,SKIP_JUDGE_RESULT,-1,26.0,6.0
2235,SKIP_JUDGE_RESULT,-1,20.0,12.0
2615,SKIP_JUDGE_RESULT,-1,20.0,12.0
2995,SKIP_JUDGE_RESULT,-1,25.0,7.0
3375,SKIP_JUDGE_RESULT,-1,28.0,4.0
3755,SKIP_JUDGE_RESULT,-1,25.0,7.0
4135,SKIP_JUDGE_RESULT,-1,22.0,10.0


In [14]:
# 查看 SKIP_PARTIAL 事件（部分 skip）
df_partial = df_skip[df_skip['event'] == 'SKIP_PARTIAL']
print(f"SKIP_PARTIAL events: {len(df_partial)}")
if len(df_partial) > 0:
    print(df_partial[['event', 'step', 'shape_active_indices', 'shape_x_after_select']].head(10))

SKIP_PARTIAL events: 14
             event  step shape_active_indices shape_x_after_select
716   SKIP_PARTIAL    -1                 [10]        [1, 10, 4096]
1096  SKIP_PARTIAL    -1                 [11]        [1, 11, 4096]
1476  SKIP_PARTIAL    -1                 [10]        [1, 10, 4096]
1856  SKIP_PARTIAL    -1                  [6]         [1, 6, 4096]
2236  SKIP_PARTIAL    -1                 [12]        [1, 12, 4096]
2616  SKIP_PARTIAL    -1                 [12]        [1, 12, 4096]
2996  SKIP_PARTIAL    -1                  [7]         [1, 7, 4096]
3376  SKIP_PARTIAL    -1                  [4]         [1, 4, 4096]
3756  SKIP_PARTIAL    -1                  [7]         [1, 7, 4096]
4136  SKIP_PARTIAL    -1                 [10]        [1, 10, 4096]


In [15]:
# 查看 ROPE 相关事件
df_rope = df_skip[df_skip['event'].str.contains('ROPE')]
print(f"ROPE events: {len(df_rope)}")
print(df_rope['event'].value_counts())

ROPE events: 3008
event
ROPE_FWD_ENTRY           512
ROPE_BASELINE_RESULT     288
ROPE_BASELINE_PARAMS     288
ROPE_REPLACE_AFTER       256
ROPE_REPLACE_BEFORE      256
ROPE_INDEX_SELECT        224
ROPE_TOKENSKIP_RESULT    224
ROPE_BEFORE_TOKENSKIP    224
ROPE_TOKENSKIP_PARAMS    224
ROPE_AFTER_TOKENSKIP     224
ROPE_SKIP                224
ROPE_BASELINE_AFTER       32
ROPE_BASELINE_BEFORE      32
Name: count, dtype: int64


In [16]:
# 查看 ROPE_TOKENSKIP 事件
df_rope_ts = df_skip[df_skip['event'].str.contains('ROPE_TOKENSKIP')]
print(f"ROPE_TOKENSKIP events: {len(df_rope_ts)}")
if len(df_rope_ts) > 0:
    display_cols = [c for c in df_rope_ts.columns if 'shape' in c or c in ['event', 'step', 'layer']]
    print(df_rope_ts[display_cols].head(10))

ROPE_TOKENSKIP events: 448
                     event  step  layer shape_x shape_q_in shape_k_in  \
724  ROPE_TOKENSKIP_PARAMS    -1     -1     NaN        NaN        NaN   
726  ROPE_TOKENSKIP_RESULT    -1     -1     NaN        NaN        NaN   
737  ROPE_TOKENSKIP_PARAMS    -1     -1     NaN        NaN        NaN   
739  ROPE_TOKENSKIP_RESULT    -1     -1     NaN        NaN        NaN   
750  ROPE_TOKENSKIP_PARAMS    -1     -1     NaN        NaN        NaN   
752  ROPE_TOKENSKIP_RESULT    -1     -1     NaN        NaN        NaN   
763  ROPE_TOKENSKIP_PARAMS    -1     -1     NaN        NaN        NaN   
765  ROPE_TOKENSKIP_RESULT    -1     -1     NaN        NaN        NaN   
776  ROPE_TOKENSKIP_PARAMS    -1     -1     NaN        NaN        NaN   
778  ROPE_TOKENSKIP_RESULT    -1     -1     NaN        NaN        NaN   

    shape_v_in shape_q shape_k shape_v  ... shape_active_global  \
724        NaN     NaN     NaN     NaN  ...                 NaN   
726        NaN     NaN     NaN     

In [17]:
# 查看 KV_CACHE 相关事件
df_cache = df_skip[df_skip['event'].str.contains('KV_CACHE')]
print(f"KV_CACHE events: {len(df_cache)}")
print(df_cache['event'].value_counts())

KV_CACHE events: 1440
event
KV_CACHE_BEFORE           480
KV_CACHE_UPDATE           480
KV_CACHE_AFTER_REPLACE    480
Name: count, dtype: int64


In [18]:
# 查看 KV_CACHE_UPDATE 事件
df_cache_update = df_skip[df_skip['event'] == 'KV_CACHE_UPDATE']
print(f"KV_CACHE_UPDATE events: {len(df_cache_update)}")
if len(df_cache_update) > 0:
    display_cols = [c for c in df_cache_update.columns if 'shape' in c or 'replace' in c.lower() or c in ['event', 'step', 'layer']]
    print(df_cache_update[display_cols].head(10))

KV_CACHE_UPDATE events: 480
               event  step  layer shape_x val_has_replace_position shape_q_in  \
229  KV_CACHE_UPDATE    -1     -1     NaN                      NaN        NaN   
239  KV_CACHE_UPDATE    -1     -1     NaN                      NaN        NaN   
249  KV_CACHE_UPDATE    -1     -1     NaN                      NaN        NaN   
259  KV_CACHE_UPDATE    -1     -1     NaN                      NaN        NaN   
269  KV_CACHE_UPDATE    -1     -1     NaN                      NaN        NaN   
279  KV_CACHE_UPDATE    -1     -1     NaN                      NaN        NaN   
289  KV_CACHE_UPDATE    -1     -1     NaN                      NaN        NaN   
299  KV_CACHE_UPDATE    -1     -1     NaN                      NaN        NaN   
309  KV_CACHE_UPDATE    -1     -1     NaN                      NaN        NaN   
319  KV_CACHE_UPDATE    -1     -1     NaN                      NaN        NaN   

    shape_k_in shape_v_in shape_q shape_k  ... shape_position_ids  \
229        

## 4. 对比 baseline 和 token skip 的数据流

In [19]:
# 对比两者的 ATTN_ENTRY 事件
print("=== Baseline ATTN_ENTRY (first 5) ===")
df_attn_base = df_baseline[df_baseline['event'] == 'ATTN_ENTRY']
shape_cols = [c for c in df_attn_base.columns if 'shape' in c]
print(df_attn_base[['event', 'step', 'layer'] + shape_cols].head(5))

print("\n=== TokenSkip ATTN_ENTRY (first 10) ===")
df_attn_skip = df_skip[df_skip['event'] == 'ATTN_ENTRY']
shape_cols = [c for c in df_attn_skip.columns if 'shape' in c]
print(df_attn_skip[['event', 'step', 'layer'] + shape_cols].head(10))

=== Baseline ATTN_ENTRY (first 5) ===
         event  step  layer shape_x     shape_q_in     shape_k_in  \
1   ATTN_ENTRY    -1     -1     NaN  [1, 51, 4096]  [1, 51, 4096]   
8   ATTN_ENTRY    -1     -1     NaN  [1, 51, 4096]  [1, 51, 4096]   
15  ATTN_ENTRY    -1     -1     NaN  [1, 51, 4096]  [1, 51, 4096]   
22  ATTN_ENTRY    -1     -1     NaN  [1, 51, 4096]  [1, 51, 4096]   
29  ATTN_ENTRY    -1     -1     NaN  [1, 51, 4096]  [1, 51, 4096]   

       shape_v_in shape_q shape_k shape_v  ... shape_k_to_insert  \
1   [1, 51, 4096]     NaN     NaN     NaN  ...               NaN   
8   [1, 51, 4096]     NaN     NaN     NaN  ...               NaN   
15  [1, 51, 4096]     NaN     NaN     NaN  ...               NaN   
22  [1, 51, 4096]     NaN     NaN     NaN  ...               NaN   
29  [1, 51, 4096]     NaN     NaN     NaN  ...               NaN   

   shape_k_final shape_v_final shape_max_replace_pos shape_block_end_index  \
1            NaN           NaN                   NaN        

In [20]:
# 查看 position_ids 相关事件
df_pos = df_skip[df_skip['event'].str.contains('POSITION')]
print(f"POSITION events: {len(df_pos)}")
if len(df_pos) > 0:
    print(df_pos[['event', 'step', 'layer'] + [c for c in df_pos.columns if 'position' in c.lower()]].head(10))

POSITION events: 238
                    event  step  layer val_has_replace_position  \
717  SKIP_POSITION_UPDATE    -1     -1                      NaN   
719     ATTN_POSITION_IDS    -1     -1                      NaN   
732     ATTN_POSITION_IDS    -1     -1                      NaN   
745     ATTN_POSITION_IDS    -1     -1                      NaN   
758     ATTN_POSITION_IDS    -1     -1                      NaN   
771     ATTN_POSITION_IDS    -1     -1                      NaN   
784     ATTN_POSITION_IDS    -1     -1                      NaN   
797     ATTN_POSITION_IDS    -1     -1                      NaN   
810     ATTN_POSITION_IDS    -1     -1                      NaN   
823     ATTN_POSITION_IDS    -1     -1                      NaN   

    val_has_position_ids shape_new_replace_position shape_position_ids  \
717                  NaN                    [1, 51]               [10]   
719                  NaN                        NaN               [10]   
732                

## 5. 保存追踪数据

In [21]:
# 保存为 JSON
tracer_baseline.save_json('trace_baseline.json')
tracer_skip.save_json('trace_tokenskip.json')
print("Traces saved to trace_baseline.json and trace_tokenskip.json")

Traces saved to trace_baseline.json and trace_tokenskip.json


In [22]:
# 禁用追踪器
set_tracer(None)
print("Tracer disabled.")

Tracer disabled.


## 6. 手动检查关键张量

In [23]:
# 如果有错误，打印最后几条追踪
print("Last 20 trace entries:")
for e in tracer_skip.entries[-20:]:
    print(f"[{e.event}] step={e.step} layer={e.layer}")
    if e.shapes:
        print(f"  shapes: {e.shapes}")
    if e.values and len(str(e.values)) < 200:
        print(f"  values: {e.values}")

Last 20 trace entries:
[ROPE_TOKENSKIP_PARAMS] step=-1 layer=-1
  shapes: {'position_ids': [2], 'pos_sin': [1, 1, 21, 128], 'pos_cos': [1, 1, 21, 128]}
[ROPE_INDEX_SELECT] step=-1 layer=-1
  shapes: {'idx': [2], 'pos_sin_slice': [1, 1, 2, 128], 'pos_cos_slice': [1, 1, 2, 128]}
[ROPE_TOKENSKIP_RESULT] step=-1 layer=-1
  shapes: {'q_out': [1, 32, 2, 128], 'k_out': [1, 32, 2, 128]}
[ROPE_AFTER_TOKENSKIP] step=-1 layer=-1
  shapes: {'q_after': [1, 32, 2, 128], 'k_after': [1, 32, 2, 128]}
[KV_CACHE_UPDATE] step=-1 layer=-1
  shapes: {'replace_indices': [2], 'k_to_insert': [32, 2, 128]}
[KV_CACHE_AFTER_REPLACE] step=-1 layer=-1
  shapes: {'k_final': [1, 32, 51, 128], 'v_final': [1, 32, 51, 128]}
[ROPE_SKIP] step=-1 layer=-1
  values: {'reason': 'already_applied_above'}
[ATTN_ENTRY] step=-1 layer=-1
  shapes: {'q_in': [1, 2, 4096], 'k_in': [1, 2, 4096], 'v_in': [1, 2, 4096]}
[ATTN_POSITION_IDS] step=-1 layer=-1
  shapes: {'position_ids': [2]}
  values: {'position_ids': [19, 20], 'position_ids