# Layer Skip Evaluation

在 GSM8K 上评测 Layer Skip 优化。

**Layer Skip 原理：**
- 基于 dual_cache，如果前一层的 input-output 相似度 > threshold，跳过当前层
- 跳过时 output = input (identity mapping)
- 上一步跳过的层，这一步必须重算（防止误差累积）

**实验配置:**
- 测试 7 种不同的 `layer_skip_threshold`: 0.950, 0.960, 0.970, 0.980, 0.990, 0.995, 0.998

In [None]:
import os
import torch
import gc

# Set GPU (modify as needed).
# NOTE(review): presumably this has to be set before anything in this process
# initializes CUDA for it to take effect - confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7'

# Environment settings
os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'
os.environ['HF_ALLOW_CODE_EVAL'] = '1'
os.environ['HF_DATASETS_TRUST_REMOTE_CODE'] = 'true'

# Change to the llada directory. The guard makes the cell safe to re-run:
# a second os.chdir('llada') from inside llada/ would raise FileNotFoundError.
if os.path.basename(os.getcwd()) != 'llada':
    os.chdir('llada')

# Create log directory (relative to the llada directory entered above)
os.makedirs('nlogs', exist_ok=True)

# Clear GPU cache and collect Python garbage
torch.cuda.empty_cache()
gc.collect()

cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")
if cuda_available:
    # Only query device 0 when a device exists; these calls raise otherwise.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Total VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
else:
    print("No CUDA device visible; skipping GPU details.")

## 1. 快速测试 (limit=30)

In [None]:
import subprocess
import datetime

# Quick smoke test: run a few thresholds sequentially on one GPU with a small
# sample limit, writing each run's output to its own log file.
task = "gsm8k"
fewshot = 3
limit = 30
gpu = 0

# (threshold, name)
# The numeric threshold must agree with the "thNNN" suffix in the name;
# otherwise the log file is labelled with a threshold that was not run.
experiments = [
    (0.950, "layer_skip_th950"),
    (0.970, "layer_skip_th970"),
    (0.990, "layer_skip_th990"),
]

for threshold, name in experiments:
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = f"nlogs/{task}_{name}_{timestamp}.log"

    model_args = [
        "model_path='GSAI-ML/LLaDA-8B-Instruct'",
        "gen_length=128",
        "steps=32",
        "block_length=32",
        "threshold=0.9",
        "use_cache=True",
        "layer_skip=True",
        f"layer_skip_threshold={threshold}",
        "show_speed=True",
    ]

    # Run through the shell: the command relies on the VAR=value prefix and
    # on the shell stripping the quotes around the model path.
    cmd = f"""CUDA_VISIBLE_DEVICES={gpu} accelerate launch eval_llada.py \\
        --tasks {task} --num_fewshot {fewshot} --limit {limit} \\
        --confirm_run_unsafe_code --model llada_dist \\
        --model_args {','.join(model_args)}"""

    print(f"\n{'='*60}")
    print(f"Running: {name}")
    print(f"Log file: {log_file}")
    print('='*60)

    # Blocks until the run finishes; stdout and stderr both go to the log.
    with open(log_file, 'w') as f:
        result = subprocess.run(cmd, shell=True, stdout=f, stderr=subprocess.STDOUT, text=True)

    # Surface failures instead of letting a crashed run look like a result.
    if result.returncode != 0:
        print(f"WARNING: {name} exited with code {result.returncode}")

    with open(log_file, 'r') as f:
        lines = f.readlines()

    print("Last 15 lines:")
    for line in lines[-15:]:
        print(line.rstrip())

## 2. 并行完整测试（7个 threshold）

In [None]:
import subprocess
import datetime

# Full run: launch one evaluation per threshold, each on its own GPU, then
# wait for all of them to finish.
task = "gsm8k"
fewshot = 4

# (threshold, gpu, name)
# Seven different layer_skip_threshold values
configs = [
    (0.950, 0, "layer_skip_th950"),
    (0.960, 1, "layer_skip_th960"),
    (0.970, 2, "layer_skip_th970"),
    (0.980, 3, "layer_skip_th980"),
    (0.990, 4, "layer_skip_th990"),
    (0.995, 5, "layer_skip_th995"),
    (0.998, 6, "layer_skip_th998"),
]

processes = []

for threshold, gpu, name in configs:
    log_file = f"nlogs/layer_skip_{task}_{name}_{datetime.datetime.now():%F_%H-%M-%S}.log"

    model_args = [
        "model_path='GSAI-ML/LLaDA-8B-Instruct'",
        "gen_length=256",
        "steps=256",
        "block_length=32",
        "threshold=0.9",
        "use_cache=True",
        "layer_skip=True",
        f"layer_skip_threshold={threshold}",
        "show_speed=True",
    ]

    cmd = f"""CUDA_VISIBLE_DEVICES={gpu} accelerate launch eval_llada.py \\
        --tasks {task} --num_fewshot {fewshot} \\
        --confirm_run_unsafe_code --model llada_dist \\
        --model_args {','.join(model_args)} \\
        --output_path evals_results/layer_skip/{task}-{name} --log_samples"""

    print(f"GPU{gpu} running: {name} (threshold={threshold})")
    # The child process receives its own copy of the log descriptor at spawn
    # time, so the parent's handle can be closed right away instead of being
    # leaked for the lifetime of the kernel.
    with open(log_file, "w") as log_handle:
        p = subprocess.Popen(cmd, shell=True, stdout=log_handle, stderr=subprocess.STDOUT)
    processes.append((p, name, log_file))

print(f"\n{len(processes)} tasks launched, waiting...")

for p, name, log in processes:
    p.wait()
    print(f"{name} finished (exit code: {p.returncode})")
    if p.returncode != 0:
        # Point at the log so a failed run is easy to diagnose.
        print(f"  WARNING: non-zero exit, see {log}")

## 3. 查看结果

In [None]:
import glob
import re
import os

def parse_log_metrics(content):
    """Extract evaluation metrics from the text of one log file.

    Each metric is taken from the first regex match in ``content``
    (``re.search`` semantics); a metric whose pattern does not match is
    returned as ``None``.

    Returns a dict with keys ``accuracy`` (percent), ``tokens_per_sec``,
    ``total_nfe``, ``time_sec``, ``skip_rate`` (percent),
    ``layers_computed`` and ``layers_skipped``.
    """
    def first_match(pattern, cast):
        match = re.search(pattern, content)
        return cast(match.group(1)) if match else None

    # NOTE(review): this pattern needs the number to follow "exact_match"
    # directly (after an optional , : or | and whitespace) - confirm it
    # matches the table format the evaluation actually prints.
    accuracy = first_match(r'exact_match[,:\|]?\s*([\d.]+)', float)

    return {
        'accuracy': accuracy * 100 if accuracy is not None else None,
        'tokens_per_sec': first_match(r'Tokens per second:\s*([\d.]+)', float),
        'total_nfe': first_match(r'Total NFE is (\d+)', int),
        'time_sec': first_match(r'Total time taken:\s*([\d.]+)', float),
        'skip_rate': first_match(r'Layer Skip Stats:.*skip_rate=([\d.]+)%', float),
        'layers_computed': first_match(r'Layer Skip Stats: computed=(\d+)', int),
        'layers_skipped': first_match(r'skipped=(\d+)', int),
    }


def format_metric(value, spec=""):
    """Format ``value`` with format spec ``spec``, or "N/A" when it is None.

    The check is ``is None`` rather than truthiness so that a legitimate
    0 / 0.0 (e.g. a 0% skip rate or zero skipped layers) is still shown.
    """
    return "N/A" if value is None else format(value, spec)


# Newest logs first.
log_files = sorted(glob.glob("nlogs/layer_skip_gsm8k*.log"), key=os.path.getmtime, reverse=True)

print("Layer Skip GSM8K Results:")
print("=" * 110)

results = []
for log_file in log_files[:14]:  # most recent 14 (7 thresholds x possibly several runs)
    name = os.path.basename(log_file)

    # errors='replace' keeps one undecodable byte from aborting the summary.
    with open(log_file, 'r', errors='replace') as f:
        content = f.read()

    results.append({'name': name, **parse_log_metrics(content)})

print(f"{'Name':<55} {'Acc%':<8} {'Tok/s':<10} {'NFE':<10} {'Time(s)':<10} {'Skip%':<8} {'Computed':<10} {'Skipped':<10}")
print("-" * 110)
for r in results:
    acc_str = format_metric(r['accuracy'], ".2f")
    speed_str = format_metric(r['tokens_per_sec'], ".1f")
    nfe_str = format_metric(r['total_nfe'])
    time_str = format_metric(r['time_sec'], ".1f")
    skip_str = format_metric(r['skip_rate'], ".2f")
    computed_str = format_metric(r['layers_computed'])
    skipped_str = format_metric(r['layers_skipped'])
    print(f"{r['name']:<55} {acc_str:<8} {speed_str:<10} {nfe_str:<10} {time_str:<10} {skip_str:<8} {computed_str:<10} {skipped_str:<10}")

## 4. 手动运行（可选）

In [None]:
# Single experiment
import subprocess
import datetime

# Settings for a single hand-launched experiment.
gpu = 0
limit = 100
threshold = 0.990  # change this value to try a different threshold

# Run label and log path; the label embeds the threshold scaled by 1000.
run_stamp = datetime.datetime.now().strftime("%F_%H-%M-%S")
name = "manual_layer_skip_th" + str(int(threshold * 1000))
log_file = "nlogs/manual_layer_skip_" + name + "_" + run_stamp + ".log"

# key=value pairs passed to the evaluator through --model_args.
model_settings = (
    ("model_path", "'GSAI-ML/LLaDA-8B-Instruct'"),
    ("gen_length", 256),
    ("steps", 256),
    ("block_length", 32),
    ("threshold", 0.9),
    ("use_cache", True),
    ("layer_skip", True),
    ("layer_skip_threshold", threshold),
    ("show_speed", True),
)
model_args = [f"{key}={value}" for key, value in model_settings]

# Assemble the shell command as backslash-continued lines.
cmd_lines = [
    f"CUDA_VISIBLE_DEVICES={gpu} accelerate launch eval_llada.py",
    f"--tasks gsm8k --num_fewshot 5 --limit {limit}",
    "--confirm_run_unsafe_code --model llada_dist",
    "--model_args " + ",".join(model_args),
    f"--output_path evals_results/layer_skip/manual-{name} --log_samples",
]
cmd = " \\\n    ".join(cmd_lines)

print(f"Running: {name}")
print(f"Log: {log_file}")
print(f"Command: {cmd[:200]}...")

# Nothing is launched by default. Uncomment to run:
# p = subprocess.Popen(cmd, shell=True, stdout=open(log_file, "w"), stderr=subprocess.STDOUT)
# p.wait()
# print(f"Finished (exit code: {p.returncode})")

## 5. 对比分析

In [None]:
import matplotlib.pyplot as plt
import numpy as np

import re  # imported here so this cell does not depend on another cell's imports

# Visualize the data collected in `results`, which the results cell (Section 3) builds.
# Look the name up through globals() so that a missing `results` reaches the
# hint in the else-branch below instead of raising NameError.
collected_results = globals().get('results')

if collected_results:
    # Keep only runs for which both accuracy and skip rate were parsed.
    valid_results = [
        r for r in collected_results
        if r['accuracy'] is not None and r['skip_rate'] is not None
    ]

    if valid_results:
        # Recover each run's threshold from the "thNNN" part of its log file
        # name (NNN / 1000); runs without such a suffix are placed at 0.
        for r in valid_results:
            th_match = re.search(r'th(\d+)', r['name'])
            r['threshold'] = int(th_match.group(1)) / 1000 if th_match else 0

        valid_results.sort(key=lambda x: x['threshold'])

        thresholds = [r['threshold'] for r in valid_results]
        accuracies = [r['accuracy'] for r in valid_results]
        skip_rates = [r['skip_rate'] for r in valid_results]
        # May contain None for runs whose speed line was not found.
        speeds = [r['tokens_per_sec'] for r in valid_results]

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        # Accuracy vs Threshold
        axes[0].plot(thresholds, accuracies, 'bo-', linewidth=2, markersize=8)
        axes[0].set_xlabel('Layer Skip Threshold')
        axes[0].set_ylabel('Accuracy (%)')
        axes[0].set_title('Accuracy vs Threshold')
        axes[0].grid(True, alpha=0.3)

        # Skip Rate vs Threshold
        axes[1].plot(thresholds, skip_rates, 'ro-', linewidth=2, markersize=8)
        axes[1].set_xlabel('Layer Skip Threshold')
        axes[1].set_ylabel('Skip Rate (%)')
        axes[1].set_title('Skip Rate vs Threshold')
        axes[1].grid(True, alpha=0.3)

        # Speed vs Threshold
        axes[2].plot(thresholds, speeds, 'go-', linewidth=2, markersize=8)
        axes[2].set_xlabel('Layer Skip Threshold')
        axes[2].set_ylabel('Tokens/second')
        axes[2].set_title('Speed vs Threshold')
        axes[2].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig('nlogs/layer_skip_analysis.png', dpi=150)
        plt.show()

        print("\n图表已保存至 nlogs/layer_skip_analysis.png")
    else:
        print("没有找到有效的结果数据")
else:
    print("请先运行 Section 3 获取结果数据")