# Token Skip Evaluation

在 GSM8K 上评测 Dual Cache + Token Skip 优化。

**Token Skip 原理（新版）：**
- 比较 $h_{t-1}$ 与 $h_{t-2}$（上一 step 和上上一 step 的最终 hidden state）的 cosine similarity
- 对于相似度 > threshold 的 token，**完全跳过**当前 step 的计算
- 被跳过的 token 使用上一 step 的 KV cache
- 被跳过的 token 下一步强制更新（防止误差累积）

**与原版 KV Reuse 的区别：**
- 原版：在 forward 中途判定，部分层 skip
- 新版：在 forward 之前判定，整个 forward 都 skip

**实验配置：**
1. Baseline: Dual Cache（不使用 Token Skip）
2. Token Skip + 不同 `skip_threshold` (0.90, 0.95, 0.99)
3. Token Skip + 不同 `force_full_every_k` (0, 1, 2, 3；下方各 cell 目前只运行 k=1 与 k=3)

In [None]:
import os
import torch
import gc

# Restrict the GPUs visible to this kernel (modify as needed). This only takes
# effect if CUDA has not yet been initialised in this process; importing torch
# by itself does not initialise CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7'

# Environment settings for Hugging Face hub / datasets / code-eval metrics
os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'
os.environ['HF_ALLOW_CODE_EVAL'] = '1'
os.environ['HF_DATASETS_TRUST_REMOTE_CODE'] = 'true'

# Change to the llada directory. Guarded so that re-running this cell while the
# kernel is already inside llada/ does not fail with FileNotFoundError.
if os.path.basename(os.path.normpath(os.getcwd())) != 'llada':
    os.chdir('llada')

# Create log directory (relative to llada/)
os.makedirs('nlogs', exist_ok=True)

# Collect unreachable Python objects first so tensors they hold are released,
# then hand the cached CUDA blocks back to the driver.
gc.collect()
torch.cuda.empty_cache()

print(f"CUDA available: {torch.cuda.is_available()}")
# get_device_name / get_device_properties raise without a usable CUDA device.
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Total VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

## 1. 快速测试 (limit=30)

In [None]:
import subprocess
import datetime

task = "gsm8k"
fewshot = 3
limit = 30
gpu = 1

# (skip_threshold, force_full_every_k, name); skip_threshold=None disables Token Skip.
experiments = [
    (None, None, "baseline"),
    (0.95, 3, "skip_th095_k3"),
    (0.95, 1, "skip_th095_k1"),  # K=1 should be equivalent to the baseline
    (0.99, 3, "skip_th099_k3"),
    (0.90, 3, "skip_th090_k3"),
]

for skip_threshold, force_k, name in experiments:
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = f"nlogs/{task}_{name}_{timestamp}.log"

    model_args = [
        "model_path='GSAI-ML/LLaDA-8B-Instruct'",
        "gen_length=128",
        "steps=32",
        "block_length=32",
        "threshold=0.9",
        "use_cache=True",
        "dual_cache=True",
        "show_speed=True",
    ]

    if skip_threshold is not None:
        model_args.extend([
            "token_skip=True",
            f"skip_threshold={skip_threshold}",
            f"force_full_every_k={force_k}",
        ])
    else:
        model_args.append("token_skip=False")

    # Run through the shell: the CUDA_VISIBLE_DEVICES=... prefix is a shell
    # assignment, and the shell strips the quotes around model_path.
    cmd = f"""CUDA_VISIBLE_DEVICES={gpu} accelerate launch eval_llada.py \\
        --tasks {task} --num_fewshot {fewshot} --limit {limit} \\
        --confirm_run_unsafe_code --model llada_dist \\
        --model_args {','.join(model_args)}"""

    print(f"\n{'='*60}")
    print(f"Running: {name}")
    print(f"Log file: {log_file}")
    print('='*60)

    with open(log_file, 'w') as f:
        result = subprocess.run(cmd, shell=True, stdout=f, stderr=subprocess.STDOUT)

    # Report the exit status explicitly so failed runs are not silently
    # mistaken for successful ones when only the log tail is shown.
    print(f"{name} finished (exit code: {result.returncode})")

    with open(log_file, 'r') as f:
        lines = f.readlines()
    print("Last 10 lines:")
    for line in lines[-10:]:
        print(line.rstrip())

## 2. 并行完整测试

In [None]:
import subprocess
import datetime

task = "gsm8k"
fewshot = 4

# (skip_threshold, force_full_every_k, gpu, name); skip_threshold=None disables Token Skip.
configs = [
    (None, None, 0, "baseline"),
    (0.95, 3, 1, "skip_th095_k3"),
    (0.95, 1, 2, "skip_th095_k1"),
    (0.99, 3, 3, "skip_th099_k3"),
]

processes = []

for skip_threshold, force_k, gpu, name in configs:
    # Explicit %Y-%m-%d rather than %F, which is a platform-specific strftime extension.
    log_file = f"nlogs/token_skip_{task}_{name}_{datetime.datetime.now():%Y-%m-%d_%H-%M-%S}.log"

    # The quick-test cell passes use_cache=True together with dual_cache=True;
    # pass it here too so the full runs use the same cache configuration.
    # NOTE(review): assumes eval_llada.py only honours dual_cache when
    # use_cache is set — confirm against eval_llada.py.
    skip_args = "use_cache=True,dual_cache=True"
    if skip_threshold is not None:
        skip_args += f",token_skip=True,skip_threshold={skip_threshold},force_full_every_k={force_k}"
    else:
        skip_args += ",token_skip=False"

    cmd = f"""CUDA_VISIBLE_DEVICES={gpu} accelerate launch eval_llada.py \\
        --tasks {task} --num_fewshot {fewshot} \\
        --confirm_run_unsafe_code --model llada_dist \\
        --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length=256,steps=256,block_length=32,threshold=0.9,{skip_args},show_speed=True \\
        --output_path evals_results/token_skip/gsm8k-{name} --log_samples"""

    print(f"GPU{gpu} running: {name}")
    # The child process receives its own copy of the file descriptor, so the
    # parent's handle can be closed as soon as Popen returns. This avoids
    # leaking one open file per launched run.
    with open(log_file, "w") as log_handle:
        p = subprocess.Popen(cmd, shell=True, stdout=log_handle, stderr=subprocess.STDOUT)
    processes.append((p, name, log_file))

print(f"\n{len(processes)} tasks launched, waiting...")

for p, name, log in processes:
    p.wait()
    print(f"{name} finished (exit code: {p.returncode})")

## 3. 查看结果

In [None]:
import glob
import re
import os

# Regexes for the metrics printed into each eval log.
# The accuracy pattern accepts the lm-eval results table row with or without
# the higher-is-better arrow column ("|exact_match|↑  |0.7733|" and
# "|exact_match|0.7733|"), as well as "exact_match: 0.77" and
# "'exact_match,strict-match': 0.77". Numbers must start with a digit so a
# lone "." can never be captured.
ACC_PATTERN = r"exact_match(?:,[\w-]+)?['\"]?[\s,:|]*(?:[↑↓]\s*\|\s*)?(\d+(?:\.\d+)?)"
SPEED_PATTERN = r'Tokens per second:\s*(\d+(?:\.\d+)?)'
NFE_PATTERN = r'Total NFE is (\d+)'
TIME_PATTERN = r'Total time taken:\s*(\d+(?:\.\d+)?)'


def parse_log_metrics(content):
    """Extract accuracy, throughput, NFE and wall time from one eval log.

    Args:
        content: Full text of a log file written by the evaluation cells.

    Returns:
        dict with keys 'accuracy' (first exact_match value * 100),
        'tokens_per_sec', 'total_nfe' and 'time_sec'. Each value comes from
        the first match of its pattern, or is None when the pattern is absent.
    """
    def first(pattern, cast):
        match = re.search(pattern, content)
        return cast(match.group(1)) if match else None

    acc = first(ACC_PATTERN, float)
    return {
        'accuracy': acc * 100 if acc is not None else None,
        'tokens_per_sec': first(SPEED_PATTERN, float),
        'total_nfe': first(NFE_PATTERN, int),
        'time_sec': first(TIME_PATTERN, float),
    }


def format_metric(value, spec=""):
    """Format *value* with *spec*, or return "N/A" when it is missing.

    Uses an explicit None check so legitimate zero values (e.g. 0% accuracy)
    are printed instead of being reported as N/A.
    """
    return "N/A" if value is None else format(value, spec)


log_files = sorted(glob.glob("nlogs/token_skip_gsm8k*.log"), key=os.path.getmtime, reverse=True)

print("Token Skip GSM8K Results:")
print("=" * 90)

results = []
for log_file in log_files[:10]:  # the 10 most recently modified logs
    with open(log_file, 'r') as f:
        content = f.read()
    results.append({'name': os.path.basename(log_file), **parse_log_metrics(content)})

print(f"{'Name':<55} {'Acc%':<8} {'Tok/s':<10} {'NFE':<10} {'Time(s)':<10}")
print("-" * 90)
for r in results:
    acc_str = format_metric(r['accuracy'], '.2f')
    speed_str = format_metric(r['tokens_per_sec'], '.1f')
    nfe_str = format_metric(r['total_nfe'])
    time_str = format_metric(r['time_sec'], '.1f')
    print(f"{r['name']:<55} {acc_str:<8} {speed_str:<10} {nfe_str:<10} {time_str:<10}")

## 4. 手动运行（可选）

In [None]:
# Single experiment
import subprocess
import datetime

skip_threshold = 0.95      # None -> baseline (Token Skip disabled)
force_full_every_k = 3
limit = 100
gpu = 0
RUN = False                # set to True to actually launch the evaluation


def manual_run_name(threshold, k):
    """Return the run name used for the log file and the output path.

    round() is used instead of int() because thresholds such as 0.57 * 100
    evaluate to 56.99999999999999 in floating point and would be truncated.
    The None check matches the one used when building skip_args, so a
    threshold of 0.0 is not mislabelled as the baseline.
    """
    if threshold is None:
        return "manual_baseline"
    return f"manual_th{round(threshold * 100)}_k{k}"


name = manual_run_name(skip_threshold, force_full_every_k)
log_file = f"nlogs/manual_token_skip_{name}_{datetime.datetime.now():%Y-%m-%d_%H-%M-%S}.log"

# Same cache configuration as the quick-test cell (use_cache=True + dual_cache=True).
skip_args = "use_cache=True,dual_cache=True"
if skip_threshold is not None:
    skip_args += f",token_skip=True,skip_threshold={skip_threshold},force_full_every_k={force_full_every_k}"
else:
    skip_args += ",token_skip=False"

cmd = f"""CUDA_VISIBLE_DEVICES={gpu} accelerate launch eval_llada.py \\
    --tasks gsm8k --num_fewshot 5 --limit {limit} \\
    --confirm_run_unsafe_code --model llada_dist \\
    --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length=256,steps=256,block_length=32,threshold=0.9,{skip_args},show_speed=True \\
    --output_path evals_results/token_skip/manual-{name} --log_samples"""

print(f"Running: {name}")
print(f"Log: {log_file}")
print(f"Command: {cmd[:150]}...")

if RUN:
    # The child keeps its own copy of the descriptor, so the parent's handle
    # is closed right after launch instead of being leaked.
    with open(log_file, "w") as log_handle:
        p = subprocess.Popen(cmd, shell=True, stdout=log_handle, stderr=subprocess.STDOUT)
    p.wait()
    print(f"Finished (exit code: {p.returncode})")

## 5. 快速查看最近日志

In [None]:
import glob
import os

# Print the last 20 lines of the most recently modified log in nlogs/.
log_files = sorted(glob.glob("nlogs/*.log"), key=os.path.getmtime, reverse=True)

if not log_files:
    print("No log files found.")
else:
    latest_log = log_files[0]
    print(f"Latest log: {latest_log}")
    print("=" * 60)
    with open(latest_log, 'r') as f:
        tail = f.readlines()[-20:]
    for line in tail:
        print(line.rstrip())