H₂O Attention Score KV Cache Eviction for On-Device LLM Inference

Work in Progress — This research is ongoing. Results and conclusions may be updated as additional optimizations are implemented and validated.

By AtomGradient

Zero-overhead attention score export from fused Metal kernels, enabling H₂O (Heavy-Hitter Oracle) KV cache eviction in Apple's MLX ecosystem. 2–3× better quality retention than sliding-window eviction at identical memory budgets.

Key Results

Model            Baseline PPL   Rotating (+Δ%)   H₂O-heavy (+Δ%)
Qwen3-4B-4bit    7.70           7.94 (+3.2%)     7.81 (+1.4%)
Qwen3-8B-4bit    6.26           6.43 (+2.7%)     6.32 (+0.9%)
Qwen3-8B-3bit    7.62           7.89 (+3.5%)     7.73 (+1.5%)

max_kv_size=256, 10 samples × 512 tokens, Apple M2 Ultra

  • Speed overhead: <0.3% (score export adds 1 conditional float store per KV position)
  • Memory overhead: ~0.1 GB (scores buffer, fixed regardless of sequence length)
  • Best configuration: H₂O-heavy (h=max/2, r=max/2-sink) degrades PPL by only +0.4% in a 20-sample evaluation

How It Works

H₂O tracks cumulative attention scores and retains three categories of KV entries:

  1. Sink tokens — first few positions (attention sinks)
  2. Heavy hitters — highest cumulative attention score (critical context)
  3. Recent window — most recent tokens

Everything else is evicted when the cache exceeds max_size.
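The selection policy above can be sketched in plain Python/NumPy. This is a hypothetical helper for illustration, not the repository's H2OKVCache; the parameter names mirror the CLI flags:

```python
import numpy as np

def h2o_keep_indices(cum_scores, sink_size, heavy_budget, recent_budget):
    """Return sorted indices of KV positions retained under the H2O policy."""
    n = len(cum_scores)
    keep = set(range(min(sink_size, n)))              # 1. sink tokens
    keep |= set(range(max(0, n - recent_budget), n))  # 3. recent window
    # 2. heavy hitters: highest cumulative score among the remaining middle
    middle = [i for i in range(n) if i not in keep]
    middle.sort(key=lambda i: cum_scores[i], reverse=True)
    keep |= set(middle[:heavy_budget])
    return sorted(keep)  # everything not returned is evicted

scores = np.array([5.0, 4.0, 0.1, 3.0, 0.2, 0.3, 2.5, 0.4, 0.5, 0.6])
kept = h2o_keep_indices(scores, sink_size=2, heavy_budget=2, recent_budget=3)
print(kept)  # [0, 1, 3, 6, 7, 8, 9]
```

Positions 0–1 survive as sinks, 7–9 as the recent window, and 3 and 6 as the two highest-scoring heavy hitters; the rest are evicted.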

Architecture

Standard SDPA:           H₂O SDPA (+1 store):
  Q, K, V                  Q, K, V
    │                        │
    ▼                        ▼
  sdpa_vector             sdpa_vector + scores_out
    │                        │         │
    ▼                        ▼         ▼
  Output                  Output    Scores
  (scores discarded)                  │
                                      ▼
                                  H₂O Eviction

The key insight: the fused Metal kernel already computes score = scale * Q · K + mask internally. We add a single conditional store to export it — zero extra FLOPs.
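A NumPy reference sketch of what the kernel computes and what the exported buffer contains (this is not the Metal kernel itself; shapes are assumed for a single decode step):

```python
import numpy as np

def sdpa_with_scores(q, k, v, scale, mask=None):
    """Reference SDPA that also returns the pre-softmax scores.
    q: [H, 1, D] (one decode step), k, v: [H, L, D].
    The fused kernel already computes `scores` internally; exporting it
    is one extra conditional store per KV position, not extra FLOPs."""
    scores = scale * (q @ k.transpose(0, 2, 1))   # [H, 1, L] = scale * Q·Kᵀ
    if mask is not None:
        scores = scores + mask
    probs = np.exp(scores - scores.max(-1, keepdims=True))
    probs = probs / probs.sum(-1, keepdims=True)
    out = probs @ v                               # [H, 1, D]
    return out, scores

rng = np.random.default_rng(0)
q = rng.standard_normal((2, 1, 4))
k = rng.standard_normal((2, 5, 4))
v = rng.standard_normal((2, 5, 4))
out, s = sdpa_with_scores(q, k, v, scale=0.5)
print(out.shape, s.shape)  # (2, 1, 4) (2, 1, 5)
```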

Repository Structure

H2OAttnScore/
├── mlx/                  ← MLX core: Metal kernel + C++ dispatch + Python binding
├── mlx-lm/               ← Python: H2OKVCache + model integration
├── mlx-swift/            ← Swift: C API binding + MLXFast wrapper
├── mlx-swift-lm/         ← Swift: H2OKVCache + AttentionUtils routing
├── benchmark_h2o.py      ← Speed benchmark script
├── benchmark_comprehensive.py  ← Multi-model benchmark
├── eval_h2o_ppl.py       ← PPL evaluation script
├── profile_layer_attention.py  ← Per-layer attention profiling
├── adaptive_h2o.py       ← Adaptive budget allocation algorithm
├── benchmark_results.json ← Raw benchmark data
├── profiling_results/    ← Per-layer attention profiles + plots
└── docs/                 ← Paper (LaTeX + PDF) + GitHub Pages site

Quick Start

Python (mlx-lm)

# Activate environment
source ~/Documents/mlx-community/3-11-mlx-community-env/bin/activate

# Install modified mlx-lm
cd mlx-lm && pip install -e .

# Generate with H₂O KV cache
python -m mlx_lm generate \
  --model mlx-community/Qwen3-4B-4bit \
  --prompt "Explain quantum computing" \
  --h2o-max-size 1024 \
  --h2o-heavy-budget 256 \
  --h2o-recent-budget 512 \
  --h2o-sink-size 4

# Run benchmark
python benchmark_comprehensive.py

# Run PPL evaluation
python eval_h2o_ppl.py --num-samples 10 --max-kv-size 256

Python API

from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Qwen3-4B-4bit")

response = generate(
    model, tokenizer,
    "Explain quantum computing",
    max_tokens=500,
    h2o_config={
        "max_size": 1024,
        "heavy_budget": 256,
        "recent_budget": 512,
        "sink_size": 4,
    },
)

Swift (mlx-swift-lm)

import MLXLMCommon

// Create H₂O cache for each layer
let cache = (0..<numLayers).map { _ in
    H2OKVCache(maxSize: 1024, heavyBudget: 256, recentBudget: 512, sinkSize: 4)
}

// Use with any model — transparent routing via attentionWithCacheUpdate
let output = attentionWithCacheUpdate(
    queries: queries, keys: keys, values: values,
    cache: cache[layerIdx], scale: scale, mask: mask
)

Modified Files

MLX Core (Metal Kernel)

  • mlx/backend/metal/kernels/sdpa_vector.h: output_scores function constant + scores_out buffer
  • mlx/backend/metal/scaled_dot_product_attention.cpp: C++ dispatch with scores parameter
  • mlx/fast_primitives.h: output_scores_ member
  • mlx/fast.h + mlx/fast.cpp: scaled_dot_product_attention_with_scores() API
  • python/src/fast.cpp: Python binding

mlx-lm (Python)

  • mlx_lm/models/cache.py: H2OKVCache class (~180 lines)
  • mlx_lm/models/base.py: kernel score routing + GQA handling
  • mlx_lm/generate.py: h2o_config parameter + CLI args

mlx-swift + mlx-swift-lm (Swift)

  • Cmlx/mlx-c/mlx/c/fast.cpp + .h: C API binding
  • MLX/MLXFast.swift: scaledDotProductAttentionWithScores()
  • MLXLMCommon/KVCache.swift: H2OKVCache class
  • MLXLMCommon/AttentionUtils.swift: H₂O routing

Technical Notes

RoPE Offset Bug (Fixed)

A critical bug was discovered during development: after KV cache eviction, self.offset was reset to the physical cache length, breaking RoPE positional encoding. The fix separates:

  • offset — total tokens processed (monotonically increasing, for RoPE)
  • _cache_len — physical cache entries (decreases after eviction)
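The separation can be sketched with a minimal hypothetical class (not the actual H2OKVCache):

```python
class EvictingCache:
    """Illustrates why the RoPE offset must not track physical cache length."""

    def __init__(self, max_size):
        self.max_size = max_size
        self.offset = 0      # total tokens processed; feeds RoPE positions
        self._cache_len = 0  # physical entries currently stored

    def append(self, n_new):
        self.offset += n_new      # monotonically increasing, never reset
        self._cache_len += n_new
        if self._cache_len > self.max_size:
            # eviction shrinks the physical cache but leaves offset alone
            self._cache_len = self.max_size

cache = EvictingCache(max_size=4)
for _ in range(6):
    cache.append(1)
print(cache.offset, cache._cache_len)  # 6 4
```

Had `offset` been reset to `_cache_len` (4) after eviction, the next token would be rotated as position 4 instead of position 6, corrupting the positional encoding.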

EMA Score Accumulation

Cumulative scores use exponential moving average with α=0.95 decay:

C_t = 0.95 × C_{t-1} + 0.05 × |s_t|

Absolute value is used since pre-softmax scores can be negative.
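A quick numeric sketch of the update rule, applied over a few hypothetical score values:

```python
def ema_update(prev, score, alpha=0.95):
    # C_t = alpha * C_{t-1} + (1 - alpha) * |s_t|
    return alpha * prev + (1 - alpha) * abs(score)

c = 0.0
for s in [2.0, -4.0, 1.0]:  # pre-softmax scores may be negative
    c = ema_update(c, s)
print(c)
```

The decay keeps the accumulator bounded and weights recent attention more heavily than a plain running sum would.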

GQA Support

For Grouped Query Attention, kernel-exported scores have shape [B, n_q_heads, L_q, L_k]. These are averaged across query head groups to produce per-KV-head scores before accumulation.
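A NumPy sketch of this reduction, with assumed head counts (8 query heads sharing 2 KV heads):

```python
import numpy as np

B, n_q_heads, n_kv_heads, L_q, L_k = 1, 8, 2, 1, 16
group = n_q_heads // n_kv_heads  # query heads per KV head

scores = np.random.default_rng(0).standard_normal((B, n_q_heads, L_q, L_k))
# Average each group of query heads down to its shared KV head
per_kv = scores.reshape(B, n_kv_heads, group, L_q, L_k).mean(axis=2)
print(per_kv.shape)  # (1, 2, 1, 16)
```

The `[B, n_kv_heads, L_q, L_k]` result is what feeds the per-KV-head EMA accumulator.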

Publication

Citation

@article{atomgradient2026h2o,
  title={H2O Attention Score KV Cache Eviction for On-Device LLM Inference},
  author={AtomGradient},
  year={2026},
  url={https://github.com/AtomGradient/H2OAttnScore}
}

Acknowledgements

This work builds upon the MLX framework by Apple and the H₂O algorithm by Zhang et al.
