# Mid-Layer Early-Exit 评测 Notebook

在 GSM8K 上评测 **Dual Cache + Mid-Layer Early-Exit** 的效果。

**工作原理：**
- 在每个 refinement step，先前向计算到第 k 层（即 `early_exit_layer`）
- 比较当前 hidden state 与上一步 hidden state 的 cosine similarity
- 如果相似度高于阈值（`early_exit_threshold`），跳过后续层的计算

**实验内容：**
1. Baseline: Dual Cache（不使用 early-exit）
2. Dual Cache + 不同 early_exit_layer (16, 20, 24, 26)，early_exit_threshold 固定为 0.9

In [None]:
import os
import torch
import gc

# Expose GPUs 0-7 to this kernel (modify as needed). The evaluation commands
# launched in later cells set CUDA_VISIBLE_DEVICES per command themselves.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7'

# Environment settings (inherited by the evaluation subprocesses)
os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'
os.environ['HF_ALLOW_CODE_EVAL'] = '1'
os.environ['HF_DATASETS_TRUST_REMOTE_CODE'] = 'true'

# Change to the llada directory. Guarded so that re-running this cell does not
# fail with FileNotFoundError once the kernel is already inside llada/.
if os.path.basename(os.getcwd()) != 'llada':
    os.chdir('llada')

# Create log directory (relative to llada/)
os.makedirs('nlogs', exist_ok=True)

# Release cached GPU memory and collect Python garbage from earlier runs
torch.cuda.empty_cache()
gc.collect()

print(f"CUDA available: {torch.cuda.is_available()}")
# Querying device 0 raises when no CUDA device is visible, so guard it.
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Total VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
else:
    print("No CUDA device visible; skipping GPU info.")

## 1. 快速验证: Baseline vs Mid-Layer Skip (limit=50)

In [None]:
# Quick test: only evaluate 50 samples, running each configuration in turn
import subprocess
import datetime

task = "gsm8k"
fewshot = 5
limit = 50  # Only test 50 samples
gpu = 0  # GPU ID

# Configuration: (early_exit_layer, name)
# None means no early-exit (baseline)
experiments = [
    (None, "baseline"),           # Baseline
    (24, "midskip_L24"),          # Early-exit at Layer 24
]

# Cosine-similarity threshold for mid-layer early exit (early_exit_threshold).
# Not the same thing as the hard-coded "threshold=0.9" model arg below, which
# is a separate eval_llada.py option.
threshold = 0.9

for exit_layer, name in experiments:
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = f"nlogs/{task}_{name}_{timestamp}.log"

    # Build model_args (comma-joined key=value pairs, the format lm-eval expects)
    # IMPORTANT: steps should be >= gen_length for good quality!
    # steps_per_block = steps / num_blocks = steps / (gen_length / block_length)
    # For gen_length=256, block_length=32: num_blocks=8
    # If steps=256: steps_per_block=32 (good)
    # If steps=32:  steps_per_block=4  (too few, bad quality!)
    model_args = [
        "model_path='GSAI-ML/LLaDA-8B-Instruct'",
        "gen_length=256",
        "steps=256",
        "block_length=32",
        "threshold=0.9",
        "use_cache=True",
        "dual_cache=True",
        "show_speed=True",
    ]

    if exit_layer is not None:
        model_args.extend([
            "mid_layer_skip=True",
            f"early_exit_layer={exit_layer}",
            f"early_exit_threshold={threshold}",
        ])
    else:
        model_args.append("mid_layer_skip=False")

    # Command format using accelerate launch and --model_args
    cmd = f"""CUDA_VISIBLE_DEVICES={gpu} accelerate launch eval_llada.py \
        --tasks {task} --num_fewshot {fewshot} --limit {limit} \
        --confirm_run_unsafe_code --model llada_dist \
        --model_args {','.join(model_args)}"""

    print(f"\n{'='*60}")
    print(f"Running: {name}")
    print(f"Command: {cmd[:200]}...")
    print(f"Log file: {log_file}")
    print('='*60)

    with open(log_file, 'w') as f:
        result = subprocess.run(cmd, shell=True, stdout=f, stderr=subprocess.STDOUT, text=True)

    # Report failures explicitly instead of relying on the log tail alone
    if result.returncode == 0:
        print(f"{name}: finished (exit code 0)")
    else:
        print(f"{name}: FAILED (exit code {result.returncode}) - see {log_file}")

    # Show the last few lines of the log. Decode as UTF-8 and tolerate stray
    # bytes: lm-eval tables contain non-ASCII characters such as "±".
    with open(log_file, 'r', encoding='utf-8', errors='replace') as f:
        tail = f.readlines()[-10:]
    print("Last 10 lines of output:")
    for line in tail:
        print(line.rstrip())

## 2. 完整测试: 不同 Early-Exit Layer 并行运行

In [None]:
import subprocess
import datetime

task = "gsm8k"
# 5-shot, the same as the quick test and the manual-run cell, so that the full
# results are directly comparable with them.
fewshot = 5
threshold = 0.9  # Cosine-similarity threshold for mid-layer early exit

# Configuration: (early_exit_layer, gpu, name)
# None means baseline (Dual Cache without early exit)
configs = [
    (None, 0, "baseline"),
    (16, 1, "midskip_L16"),
    (20, 2, "midskip_L20"),
    (24, 3, "midskip_L24"),
    (26, 4, "midskip_L26"),
]

processes = []

for early_exit_layer, gpu, name in configs:
    log_file = f"nlogs/midskip_{task}_{name}_{datetime.datetime.now():%F_%H-%M-%S}.log"

    # Cache flags must match the quick test: use_cache=True AND dual_cache=True.
    # In Fast-dLLM's eval_llada.py, dual_cache only takes effect when use_cache
    # is enabled (confirm against the local copy if it has been modified).
    skip_args = "use_cache=True,dual_cache=True"
    if early_exit_layer is not None:
        skip_args += f",mid_layer_skip=True,early_exit_layer={early_exit_layer},early_exit_threshold={threshold}"
    else:
        skip_args += ",mid_layer_skip=False"

    # steps should be >= gen_length for good quality:
    # gen_length=256, steps=256, block_length=32 -> steps_per_block=32
    cmd = f"""CUDA_VISIBLE_DEVICES={gpu} accelerate launch eval_llada.py \
        --tasks {task} --num_fewshot {fewshot} \
        --confirm_run_unsafe_code --model llada_dist \
        --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length=256,steps=256,block_length=32,threshold=0.9,{skip_args},show_speed=True \
        --output_path evals_results/midskip/gsm8k-{name} --log_samples"""

    print(f"GPU{gpu} running: {name}")
    # The child process gets its own copy of the log file descriptor, so the
    # parent's handle is closed as soon as Popen returns (no leaked handles).
    with open(log_file, "w") as log_handle:
        p = subprocess.Popen(cmd, shell=True, stdout=log_handle, stderr=subprocess.STDOUT)
    processes.append((p, name, log_file))

print(f"\n{len(processes)} tasks launched in parallel, waiting...")

for p, name, log in processes:
    p.wait()
    print(f"{name} finished (exit code: {p.returncode})")

## 3. 查看评测结果

In [None]:
import glob
import re
import os

# A number: digits with an optional fractional part (never a bare ".", which
# float() would reject).
_NUM = r'(\d+(?:\.\d+)?)'

# lm-eval reports the metric in its results table, e.g.
#   |gsm8k|3|flexible-extract|5|exact_match|↑  |0.7800|±  |0.0186|
# Recent versions insert a "↑" column between the metric name and its value,
# so skip any non-digit characters on the same line before the number.
# The first occurrence in the log wins.
_ACC_RE = re.compile(r'exact_match[^\d\n]*' + _NUM)
_SPEED_RE = re.compile(r'Tokens per second:\s*' + _NUM)
_NFE_RE = re.compile(r'Total NFE is (\d+)')
_TIME_RE = re.compile(r'Total time taken:\s*' + _NUM)
_EE_RE = re.compile(r'Early-Exit rate:\s*' + _NUM + '%')


def parse_log_metrics(content):
    """Extract evaluation metrics from the full text of one eval log.

    Returns a dict with keys 'accuracy' (exact_match scaled to a percentage),
    'tokens_per_sec', 'total_nfe' (int), 'time_sec' and 'early_exit_rate'
    (percent); each value is None when the log does not contain it.
    """
    def grab(pattern, cast):
        match = pattern.search(content)
        return cast(match.group(1)) if match else None

    acc = grab(_ACC_RE, float)
    return {
        'accuracy': acc * 100 if acc is not None else None,
        'tokens_per_sec': grab(_SPEED_RE, float),
        'total_nfe': grab(_NFE_RE, int),
        'time_sec': grab(_TIME_RE, float),
        'early_exit_rate': grab(_EE_RE, float),
    }


def format_metric(value, spec):
    """Format *value* with the format *spec*, or return 'N/A' if it is None.

    Uses an explicit None check so that legitimate zeros (e.g. a 0.0%
    early-exit rate) are printed rather than shown as 'N/A'.
    """
    return "N/A" if value is None else format(value, spec)


# Find recent log files (newest first) and extract key metrics
log_files = sorted(glob.glob("nlogs/midskip_gsm8k*.log"), key=os.path.getmtime, reverse=True)

header = f"{'Name':<55} {'Acc%':<8} {'Tok/s':<10} {'NFE':<10} {'Time(s)':<10} {'EE%':<8}"

print("Mid-Layer Skip GSM8K Results:")
print("=" * len(header))

results = []
for log_file in log_files[:10]:  # Most recent 10
    # Decode as UTF-8 and tolerate stray bytes: lm-eval tables contain
    # non-ASCII characters such as "±".
    with open(log_file, 'r', encoding='utf-8', errors='replace') as f:
        content = f.read()
    results.append({'name': os.path.basename(log_file), **parse_log_metrics(content)})

# Formatted output; separator lines span the full table width
print(header)
print("-" * len(header))
for r in results:
    print(
        f"{r['name']:<55} "
        f"{format_metric(r['accuracy'], '.2f'):<8} "
        f"{format_metric(r['tokens_per_sec'], '.1f'):<10} "
        f"{format_metric(r['total_nfe'], 'd'):<10} "
        f"{format_metric(r['time_sec'], '.1f'):<10} "
        f"{format_metric(r['early_exit_rate'], '.1f'):<8}"
    )

## 4. 手动运行单个实验（可选）

In [None]:
# For running a single experiment manually
import subprocess
import datetime

# Configuration
early_exit_layer = 24  # Set to None for baseline
threshold = 0.9  # Cosine-similarity threshold for mid-layer early exit
limit = 100  # Number of samples
gpu = 0
# Generation settings, matching the other cells. Keep steps >= gen_length
# (steps_per_block = steps / (gen_length / block_length)); with too few steps
# per block quality drops sharply.
gen_length = 256
steps = 256
block_length = 32
run_now = False  # Set to True to actually launch the evaluation

# Explicit None check, consistent with how skip_args is built below
name = f"manual_L{early_exit_layer}" if early_exit_layer is not None else "manual_baseline"
log_file = f"nlogs/manual_midskip_{name}_{datetime.datetime.now():%F_%H-%M-%S}.log"

# use_cache=True is required together with dual_cache=True (as in the quick test)
skip_args = "use_cache=True,dual_cache=True"
if early_exit_layer is not None:
    skip_args += f",mid_layer_skip=True,early_exit_layer={early_exit_layer},early_exit_threshold={threshold}"
else:
    skip_args += ",mid_layer_skip=False"

cmd = (
    f"CUDA_VISIBLE_DEVICES={gpu} accelerate launch eval_llada.py"
    f" --tasks gsm8k --num_fewshot 5 --limit {limit}"
    " --confirm_run_unsafe_code --model llada_dist"
    f" --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length={gen_length},steps={steps},"
    f"block_length={block_length},threshold=0.9,{skip_args},show_speed=True"
    f" --output_path evals_results/midskip/manual-{name} --log_samples"
)

print(f"Running: {name}")
print(f"Log: {log_file}")
print(f"Command: {cmd[:200]}...")

if run_now:
    # The log file is closed once the evaluation process has exited
    with open(log_file, "w") as log_handle:
        proc = subprocess.run(cmd, shell=True, stdout=log_handle, stderr=subprocess.STDOUT)
    print(f"Finished (exit code: {proc.returncode})")