# KV Reuse Evaluation

在 GSM8K 上评测 Dual Cache + KV Reuse 优化。

**KV Reuse 原理：**
- 比较当前 step 与上一 step 的 hidden states 的 cosine similarity
- 对于相似度 >= threshold 的 token，复用上一 step 的 K/V
- Q/Attention/FFN/logits 仍然重新计算
- 被复用的 token 下一步强制更新（防止误差累积）

**实验配置：**
1. Baseline: Dual Cache（不使用 KV reuse）
2. Dual Cache + 不同 `similarity_threshold` (0.60, 0.70, 0.80)（通过 `early_exit_threshold` 模型参数传入）

In [1]:
import os
import torch
import gc

# Expose all eight GPUs to this kernel (modify as needed). The experiment
# cells below pin a single GPU per run with a CUDA_VISIBLE_DEVICES prefix on
# each shell command. This assignment only affects this process if CUDA has
# not been initialised yet (on a re-run of the cell it has no effect here).
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7'

# Environment settings (inherited by the evaluation subprocesses)
os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'
os.environ['HF_ALLOW_CODE_EVAL'] = '1'
os.environ['HF_DATASETS_TRUST_REMOTE_CODE'] = 'true'

# Change to the llada directory. Guarded so that re-running this cell does
# not try to enter llada/llada and fail.
if os.path.basename(os.path.abspath(os.getcwd())) != 'llada':
    os.chdir('llada')

# Log directory used by all experiment cells (relative to llada/)
os.makedirs('nlogs', exist_ok=True)

# Drop unreachable Python objects first so that any CUDA tensors they held
# are freed before the caching allocator is asked to release its blocks.
gc.collect()
torch.cuda.empty_cache()

cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    # Querying device 0 raises without CUDA, so only report it when present.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Total VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

CUDA available: True
GPU: NVIDIA RTX A5000
Total VRAM: 22.06 GB


## 1. 快速测试 (limit=30)

In [3]:
import subprocess
import datetime

task = "gsm8k"
fewshot = 3
limit = 30
gpu = 1  # physical GPU index; the per-command prefix overrides the kernel-wide setting

# (threshold, name); threshold=None runs the Dual Cache baseline without KV reuse
experiments = [
    (None, "baseline"),
    (0.70, "kv_reuse_th070"),
    (0.80, "kv_reuse_th080"),
]

for threshold, name in experiments:
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = f"nlogs/{task}_{name}_{timestamp}.log"

    model_args = [
        "model_path='GSAI-ML/LLaDA-8B-Instruct'",
        "gen_length=128",
        "steps=32",
        "block_length=32",
        "threshold=0.9",
        "use_cache=True",
        "dual_cache=True",
        "show_speed=True",
    ]

    if threshold is not None:
        model_args.extend([
            "mid_layer_skip=True",
            f"early_exit_threshold={threshold}",  # Maps to similarity_threshold internally
        ])
    else:
        model_args.append("mid_layer_skip=False")

    cmd = f"""CUDA_VISIBLE_DEVICES={gpu} accelerate launch eval_llada.py \\
        --tasks {task} --num_fewshot {fewshot} --limit {limit} \\
        --confirm_run_unsafe_code --model llada_dist \\
        --model_args {','.join(model_args)}"""

    print(f"\n{'='*60}")
    print(f"Running: {name}")
    print(f"Log file: {log_file}")
    print('='*60)

    with open(log_file, 'w') as f:
        result = subprocess.run(cmd, shell=True, stdout=f, stderr=subprocess.STDOUT, text=True)

    # Report failures explicitly; otherwise a crashed run only shows up as an
    # unexpected tail of the log below.
    if result.returncode != 0:
        print(f"WARNING: {name} failed (exit code: {result.returncode}); see {log_file}")

    # The lm-eval results table contains non-ASCII symbols (e.g. "↑", "±").
    with open(log_file, 'r', encoding='utf-8', errors='replace') as f:
        lines = f.readlines()
    print("Last 10 lines:")
    for line in lines[-10:]:
        print(line.rstrip())


Running: baseline
Log file: nlogs/gsm8k_baseline_20260123_020654.log
Last 10 lines:
Total number of tokens generated: 3695
Total time taken: 68.66974472999573 seconds
Tokens per second: 53.808265686035156
Total NFE is 923
llada_dist (model_path=GSAI-ML/LLaDA-8B-Instruct,gen_length=128,steps=32,block_length=32,threshold=0.9,use_cache=True,dual_cache=True,show_speed=True,mid_layer_skip=False), gen_kwargs: (None), limit: 30.0, num_fewshot: 3, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     3|exact_match|↑  |0.5333|±  |0.0926|
|     |       |strict-match    |     3|exact_match|↑  |0.4000|±  |0.0910|


Running: kv_reuse_th070
Log file: nlogs/gsm8k_kv_reuse_th070_20260123_020842.log
Last 10 lines:
Total number of tokens generated: 3666
Total time taken: 74.38229727745056 seconds
Tokens per second: 49.28592300415039
Total NFE is 920
llada_dist

## 2. 并行完整测试

In [None]:
import subprocess
import datetime

task = "gsm8k"
fewshot = 4

# (threshold, gpu, name); threshold=None runs the Dual Cache baseline without KV reuse
configs = [
    (None, 0, "baseline"),
    (0.60, 1, "kv_reuse_th060"),
    (0.70, 2, "kv_reuse_th070"),
    (0.80, 3, "kv_reuse_th080"),
]

# use_cache=True is passed explicitly, as in the quick-test cell; presumably
# eval_llada.py only honours dual_cache when caching is enabled — confirm.
cache_args = "use_cache=True,dual_cache=True"

processes = []

for threshold, gpu, name in configs:
    log_file = f"nlogs/kv_reuse_{task}_{name}_{datetime.datetime.now():%F_%H-%M-%S}.log"

    if threshold is not None:
        skip_args = f"mid_layer_skip=True,early_exit_threshold={threshold}"
    else:
        skip_args = "mid_layer_skip=False"

    cmd = f"""CUDA_VISIBLE_DEVICES={gpu} accelerate launch eval_llada.py \\
        --tasks {task} --num_fewshot {fewshot} \\
        --confirm_run_unsafe_code --model llada_dist \\
        --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length=256,steps=256,block_length=32,threshold=0.9,{cache_args},{skip_args},show_speed=True \\
        --output_path evals_results/kv_reuse/{task}-{name} --log_samples"""

    print(f"GPU{gpu} running: {name}")
    # Keep the log handle so it can be closed once the child process exits.
    log_fh = open(log_file, "w")
    p = subprocess.Popen(cmd, shell=True, stdout=log_fh, stderr=subprocess.STDOUT)
    processes.append((p, name, log_file, log_fh))

print(f"\n{len(processes)} tasks launched, waiting...")

for p, name, log, log_fh in processes:
    p.wait()
    log_fh.close()
    print(f"{name} finished (exit code: {p.returncode})")

## 3. 查看结果

In [None]:
import glob
import re
import os

# One row of the lm-eval results table, e.g.
#   |gsm8k|      3|flexible-extract|     3|exact_match|↑  |0.5333|±  |0.0926|
# Newer lm-eval versions emit the "↑"/"↓" (higher-is-better) column after the
# metric name and older ones do not, so that column is optional here.
ACC_ROW_RE = re.compile(
    r'\|\s*(flexible-extract|strict-match)\s*\|[^|]*\|\s*exact_match\s*\|'
    r'(?:\s*[↑↓]\s*\|)?\s*([\d.]+)'
)


def _first_match(pattern, content, cast):
    """Return cast(group 1) of the first match of pattern in content, or None."""
    match = re.search(pattern, content)
    return cast(match.group(1)) if match else None


def parse_log_metrics(content):
    """Extract accuracy, speed and KV-reuse metrics from one evaluation log.

    Args:
        content: full text of a log written by the evaluation cells.

    Returns:
        dict with keys 'accuracy' (flexible-extract exact_match, in %),
        'accuracy_strict' (strict-match exact_match, in %), 'tokens_per_sec',
        'total_nfe', 'time_sec', 'kv_reuse_rate' and 'token_kv_rate'.
        Any metric not present in the log is None.
    """
    accuracies = {}
    for filter_name, value in ACC_ROW_RE.findall(content):
        # Keep the first occurrence of each filter.
        accuracies.setdefault(filter_name, float(value) * 100)

    return {
        'accuracy': accuracies.get('flexible-extract'),
        'accuracy_strict': accuracies.get('strict-match'),
        'tokens_per_sec': _first_match(r'Tokens per second:\s*([\d.]+)', content, float),
        'total_nfe': _first_match(r'Total NFE is (\d+)', content, int),
        'time_sec': _first_match(r'Total time taken:\s*([\d.]+)', content, float),
        # The lookbehind stops this from matching inside "token_kv_reuse_rate=".
        'kv_reuse_rate': _first_match(r'(?<!token_)kv_reuse_rate=([\d.]+)%', content, float),
        'token_kv_rate': _first_match(r'token_kv_reuse_rate=([\d.]+)%', content, float),
    }


def fmt_metric(value, spec):
    """Format value with a format spec, or return "N/A" when it is None.

    0 and 0.0 are real measurements (e.g. a baseline with no KV reuse) and are
    formatted normally rather than treated as missing.
    """
    return "N/A" if value is None else format(value, spec)


log_files = sorted(glob.glob("nlogs/kv_reuse_gsm8k*.log"), key=os.path.getmtime, reverse=True)

header = (
    f"{'Name':<50} {'Acc%':<8} {'Strict%':<8} {'Tok/s':<10} {'NFE':<10} "
    f"{'Time(s)':<10} {'KV%':<8} {'TKV%':<8}"
)

print("KV Reuse GSM8K Results:")
print("=" * len(header))

results = []
for log_file in log_files[:10]:
    # Logs contain non-ASCII table symbols ("↑", "±").
    with open(log_file, 'r', encoding='utf-8', errors='replace') as f:
        content = f.read()
    results.append({'name': os.path.basename(log_file), **parse_log_metrics(content)})

print(header)
print("-" * len(header))
for r in results:
    print(
        f"{r['name']:<50} "
        f"{fmt_metric(r['accuracy'], '.2f'):<8} "
        f"{fmt_metric(r['accuracy_strict'], '.2f'):<8} "
        f"{fmt_metric(r['tokens_per_sec'], '.1f'):<10} "
        f"{fmt_metric(r['total_nfe'], 'd'):<10} "
        f"{fmt_metric(r['time_sec'], '.1f'):<10} "
        f"{fmt_metric(r['kv_reuse_rate'], '.1f'):<8} "
        f"{fmt_metric(r['token_kv_rate'], '.1f'):<8}"
    )

## 4. 手动运行（可选）

In [None]:
# Single experiment
import subprocess
import datetime

threshold = 0.70  # set to None for the Dual Cache baseline (no KV reuse)
limit = 100
gpu = 0

kv_reuse = threshold is not None
# Zero-padded percentage ("th070") to match the run names in section 2.
# round() rather than int(): e.g. int(0.29 * 100) == 28 due to float rounding.
name = f"manual_th{round(threshold * 100):03d}" if kv_reuse else "manual_baseline"
log_file = f"nlogs/manual_kv_reuse_{name}_{datetime.datetime.now():%F_%H-%M-%S}.log"

# use_cache=True is passed explicitly, as in the quick-test cell; presumably
# eval_llada.py only honours dual_cache when caching is enabled — confirm.
cache_args = "use_cache=True,dual_cache=True"
if kv_reuse:
    skip_args = f"mid_layer_skip=True,early_exit_threshold={threshold}"
else:
    skip_args = "mid_layer_skip=False"

cmd = f"""CUDA_VISIBLE_DEVICES={gpu} accelerate launch eval_llada.py \\
    --tasks gsm8k --num_fewshot 5 --limit {limit} \\
    --confirm_run_unsafe_code --model llada_dist \\
    --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length=256,steps=256,block_length=32,threshold=0.9,{cache_args},{skip_args},show_speed=True \\
    --output_path evals_results/kv_reuse/manual-{name} --log_samples"""

print(f"Running: {name}")
print(f"Log: {log_file}")
print(f"Command: {cmd[:150]}...")

# Uncomment to run (the with-block closes the log file when the run ends):
# with open(log_file, "w") as log_fh:
#     p = subprocess.run(cmd, shell=True, stdout=log_fh, stderr=subprocess.STDOUT)
# print(f"Finished (exit code: {p.returncode})")
# print(f"Finished (exit code: {p.returncode})")