Skip to content

Upgrade expected attention with support for more models#126

Merged
alessiodevoto merged 4 commits intomainfrom
aledev/improved_ea
Aug 26, 2025
Merged

Upgrade expected attention with support for more models#126
alessiodevoto merged 4 commits intomainfrom
aledev/improved_ea

Conversation

@alessiodevoto
Copy link
Copy Markdown
Collaborator

@alessiodevoto alessiodevoto commented Aug 25, 2025

PR description

Right now expected attention does not support QK norm models (Gemma, Qwen). This PR:

  • changes EA code slightly to support them, by computing mean and covariance in query states.
  • makes the code of EA more standard wrt to other presses (naming conventions for Q layers and tensor dimensions)

To make sure the two methods are equivalent, I tested them with Llama 3.1-8B on Ruler 4096 with 0.5 compression, obtaining the same results with a delta < .5%, probably due to precision.

Also, I benchmarked the compression speed across 10 runs to make sure working in the query space does not slow down compression. Turns out it does not, here are the results:

Current implementation

Use covariance: True
   Seq len |    Time (ms) |   Std (ms)
--------------------------------------
       500 |      62.2459 |     2.1639
      1024 |      71.5522 |     1.0194
      2048 |     102.9493 |     1.1546
      4096 |     180.9449 |     1.4609
      8192 |     357.3696 |     1.9820
Use covariance: False
   Seq len |    Time (ms) |   Std (ms)
--------------------------------------
       500 |      63.4290 |     1.9213
      1024 |      73.5406 |     2.2720
      2048 |     102.5547 |     0.6407
      4096 |     180.7852 |     1.6742
      8192 |     357.9888 |     1.6556

New implementation

Use covariance: True
   Seq len |    Time (ms) |   Std (ms)
--------------------------------------
       500 |      60.9933 |     2.7061
      1024 |      70.5250 |     0.6948
      2048 |     102.1577 |     1.0129
      4096 |     179.3266 |     1.1186
      8192 |     357.6134 |     1.7375
Use covariance: False
   Seq len |    Time (ms) |   Std (ms)
--------------------------------------
       500 |      60.7389 |     1.8204
      1024 |      70.3875 |     0.5291
      2048 |     101.0313 |     0.7018
      4096 |     180.4022 |     1.1061
      8192 |     357.9737 |     0.9381

⚠️ EDIT: I also added a fix for #129

Checklist

  • Tests are working (make test)
  • Code is formatted correctly (make style, on errors try fix with make format)
  • Copyright header is included
  • All commits are signed-off using git commit -s
  • (new press) mypress_press.py is in the presses directory
  • (new press) MyPress is in __init__.py
  • (new press) README.md is updated with a 1 liner about the new press in the Available presses section
  • (new press) New press is in the default_presses list in tests/default_presses.py
  • (new press) A docstring is provided that follows the same structure as the existing ones

Signed-off-by: alessiodevoto <devoto.alessio@gmail.com>
Signed-off-by: alessiodevoto <devoto.alessio@gmail.com>
Copy link
Copy Markdown
Collaborator

@maxjeblick maxjeblick left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for the changes!
Overall, it looks good. I left some code comments. Most are minor.
Do you think one can extract get_query_states function and resue it within the codebase?

Comment thread kvpress/presses/base_press.py Outdated
Comment thread kvpress/presses/expected_attention_press.py Outdated
Comment thread kvpress/presses/expected_attention_press.py Outdated
Signed-off-by: alessiodevoto <devoto.alessio@gmail.com>
Signed-off-by: alessiodevoto <devoto.alessio@gmail.com>
Copy link
Copy Markdown
Collaborator

@maxjeblick maxjeblick left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM thanks!

@alessiodevoto alessiodevoto merged commit 3d2a1e1 into main Aug 26, 2025
2 checks passed
@alessiodevoto alessiodevoto deleted the aledev/improved_ea branch August 26, 2025 12:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants