在 Pythia-70M 上复现 StreamingLLM、SnapKV 和 TreeKV,并提出三种改进方案,在 wikitext-2 和 pg-19 数据集上进行 PPL 和速度评测。
This project is my individual implementation of AI2801 NLP.
| 文件 | 方法 | 核心思想 |
|---|---|---|
baseline.py |
Full KV Cache | 标准自回归推理,全量保留 KV Cache |
streaming_llm.py |
StreamingLLM | 保留前 k_sink 个 Sink token + 最近 window_size 个 token,丢弃中间 |
snapkv.py |
SnapKV | Prefill 阶段用末尾 obs_window 个 query 的 attention 选出重要 KV 位置 |
src/treekv.py |
TreeKV | 几何预算分配:历史区按块分配 top-k 预算,近块多远块少(树形结构) |
improved.py |
三种改进 | 在 SnapKV 基础上改进选取策略(见下方分析) |
conda create -n llm_accel python=3.10
conda activate llm_accel
pip install torch --index-url https://download.pytorch.org/whl/cu118
pip install transformers datasets numpycd utils/
python prepare_data.py自动下载并缓存到 ./cache/:
- Pythia-70M 模型权重
- wikitext-2-raw-v1 测试集
- pg-19 测试集前 5 个样本(保存为
cache/pg19_test_5samples.json)
cd src/
# Baseline(Full KV)
python baseline.py
# StreamingLLM(k_sink=4, window=512)
python streaming_llm.py --k_sink 4 --window_size 512 --max_eval_tokens 2048
# SnapKV(k_ratio=0.5, 保留 50% KV)
python snapkv.py --k_ratio 0.5 --obs_window 32 --local_window 32
# TreeKV(几何预算 top-k 选取,n_levels=4 → 预算比例 1:2:4:8)
python src/treekv.py --k_ratio 0.5 --obs_window 32 --local_window 32 --n_levels 4
# 改进方法(三种对比)
python improved.py --k_ratio 0.5 --k_sink 4 --obs_window 32 --local_window 32| 方法 | wikitext-2 PPL ↓ | pg-19 PPL ↓ |
|---|---|---|
| Baseline (Full KV) | 39.85 | 13.75 |
| StreamingLLM (k_sink=4, window=512) | 302.41 | 167.15 |
| SnapKV (k_ratio=0.5) | 42.24 | 31.30 |
| TreeKV (k_ratio=0.5, n_levels=4) | 41.90 | 31.16 |
| + Sink 保护 (snapkv_sink) | 42.23 | 31.30 |
| + Adaptive Per-Head (snapkv_adaptive) | 42.24 | 31.30 |
| + Sink + Adaptive (snapkv_sink_adaptive) | 42.23 | 31.30 |
| 方法 | decode tps ↑ | 峰值显存 (MB) ↓ | KV 压缩比 |
|---|---|---|---|
| Baseline (Full KV) | 277 | 148 | 1× |
| StreamingLLM (k_sink=4, window=512) | 306 | 148 | ~0.43× |
| SnapKV (k_ratio=0.5) | 231 | 284 | 0.5× |
| TreeKV (k_ratio=0.5, n_levels=4) | 45 | 286 | 0.5× |
| + Sink 保护 | 226 | 289 | 0.5× |
| + Adaptive Per-Head | 223 | 289 | 0.5× |
| + Sink + Adaptive | 223 | 289 | 0.5× |
测试配置:Pythia-70M,CUDA,fp32(eager)/ fp16(sdpa),prompt=21 tokens,生成 200 tokens。
PPL 评估:wikitext-2 全量 test set(288k tokens),pg-19 前 8192 tokens,滑动窗口 max_length=2048 / stride=512。
StreamingLLM 的 PPL 大幅高于 Baseline(302 vs 40),这是预期内的结果,不意味着实现有误。其设计目标是无限长序列的稳定生成,而非最小化 PPL:
- token-by-token 推理时,中间大量上下文被丢弃,模型看不到足够的历史信息;
k_sink=4, window=512配置下,有效上下文始终只有 ~516 token,而 wikitext-2 滑动窗口评估需要跨越更长距离的依赖;- 优势在于显存恒定(148 MB,与 Baseline 相同)且 decode 速度最快(306 tps)。
SnapKV 在 50% KV 压缩率下 PPL 仅损失约 6%(39.85 → 42.24),是有效的压缩方案。
显存反而更高(284 MB vs 148 MB)的原因:SnapKV 使用 attn_implementation="eager" + fp32 以获取 attention weights,而 Baseline 使用 sdpa + fp16,前者的中间 attention tensor 占用更大。若在相同精度下比较,SnapKV 的 KV Cache 本身确实更小。
TreeKV 在 50% 压缩率下 wikitext-2 PPL=41.90,优于 SnapKV(42.24),是本实验中压缩方法里 PPL 最低的方案。
核心优势在于几何预算分配:将历史区均分为 n_levels=4 块,各块预算比例为 1:2:4:8(最旧块最少,最新历史块最多)。这一分配符合语言建模的局部性原理——越近的 token 对下一词预测贡献越大,给它们更多 KV 预算自然带来更低的 PPL。
decode tps 仅 45(远低于 SnapKV 的 231),原因:同样使用 eager + fp32,且基准 prompt 只有 21 tokens(不触发压缩),decode 阶段 KV 无限增长,失去压缩方法应有的速度优势。在真实长 prompt 场景下,压缩后 KV 较小,速度应与 SnapKV 相近。
三种改进(Sink 保护、Adaptive Per-Head、组合)对 PPL 的提升均在 0.01 量级以内,未能拉开明显差距。
分析原因:
- Pythia-70M 头数少(8 heads/layer):各 head 的 attention 熵差异有限,entropy-weighted 聚合与等权平均在小模型上几乎等价;
- Sink token 已被自然选中:在 50% 压缩率下,前 4 个 token 的 attention score 本身就足以进入 top-k,强制保留与否差别微乎其微;
- 固定 k_ratio 限制了上限:在总 KV 预算不变的前提下,改变选取策略的收益空间本就有限。
这一结论本身也有参考价值:在小规模模型(< 1B)上,KV 选取策略的精细化带来的收益极为有限;这类改进的价值更多体现在头数多(≥ 32)、上下文更长的大模型场景中。
- Xiao et al., Efficient Streaming Language Models with Attention Sinks, 2023. [arxiv] [code]
- Li et al., SnapKV: LLM Knows What You are Looking for Before Generation, 2024. [arxiv] [code]
- Lian et al., TreeKV: Smooth Key-Value Cache Compression with Tree Structures, IJCAI 2025. [arxiv]
- Biderman et al., Pythia: A Suite for Analyzing Large Language Models Across Training and Scaling, 2023. [model]
