
Conversation

@benraha (Contributor) commented Sep 7, 2025

Motivation and Context

This PR has two changes:

  1. Replace the costly einsum with a simple matrix multiplication - this version is faster in both PyTorch and ONNX. In my opinion we lose some clarity, but since this runs in every forward pass, it is worth it.
  2. Add a fast path to the broadcast_kv_across_heads function - since the case where nothing needs to change occurs about 50% of the time, this simple check yields a solid performance benefit (see the sketch below).
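
As a rough illustration of both points, here is a minimal sketch in PyTorch. The function names, signatures, and tensor shapes below are assumptions for the example and do not mirror the repository's actual code:

```python
import torch


def project_einsum(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # einsum formulation: readable, but produces a more complex ONNX graph.
    # x: (batch, seq, d_model), w: (d_model, d_out)
    return torch.einsum("bsd,do->bso", x, w)


def project_matmul(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Equivalent matrix multiplication: flatten the leading dims, run one mm,
    # then restore the shape. Faster in PyTorch and exports to fewer ONNX nodes.
    batch, seq, d_model = x.shape
    return torch.mm(x.reshape(-1, d_model), w).reshape(batch, seq, -1)


def broadcast_kv_across_heads(kv: torch.Tensor, share_ratio: int) -> torch.Tensor:
    # Hypothetical signature. Fast path: when each query head already has its
    # own KV head (share_ratio == 1), nothing needs to be broadcast, so the
    # input is returned untouched instead of building an expand/reshape graph.
    if share_ratio == 1:
        return kv
    # kv: (batch, seq, n_kv_heads, d_head)
    batch, seq, n_kv_heads, d_head = kv.shape
    return (
        kv.unsqueeze(3)
        .expand(batch, seq, n_kv_heads, share_ratio, d_head)
        .reshape(batch, seq, n_kv_heads * share_ratio, d_head)
    )
```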

Public API Changes

  • No Public API changes
  • Yes, Public API changes (Details below)

How Has This Been Tested?

Locally.


Checklist

  • The changes have been tested locally.
  • Documentation has been updated (if the public API or usage changes).
  • An entry has been added to CHANGELOG.md (if relevant for users).
  • The code follows the project's style guidelines.
  • I have considered the impact of these changes on the public API.

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces two performance optimizations to the attention mechanism. The first replaces a costly einsum operation with a more efficient matrix multiplication using torch.mm, and the second adds a fast-path to broadcast_kv_across_heads for a common use case. Both changes are logical and aim to improve performance. My review includes a suggestion to further refine the matrix multiplication implementation for better readability and to use a more idiomatic PyTorch function.

@benraha (Contributor, Author) commented Sep 10, 2025

@priorphil can you please review? :)

@priorphil (Contributor) commented:

Thanks! Could you share some of the benchmarks (including hardware, dataset shapes, dtype, and timings) you ran for each of these changes, so I can get a feeling for the magnitude of the speedup? :)

@benraha (Contributor, Author) commented Sep 10, 2025

Sure! I export the model to ONNX, and I measure a 15% reduction in the number of nodes in the graph and roughly 10% faster inference.

I work with CPU inference, fitting on 1000 rows and running inference on a few rows at a time.
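
For a rough sense of how the einsum vs. matmul comparison can be measured on CPU, a micro-benchmark along these lines works; the shapes, dtype, and iteration count are illustrative assumptions, not the exact setup behind the numbers above:

```python
import time

import torch

# Hypothetical sizes for illustration: (batch, rows, d_model) input and a
# fused QKV projection weight.
x = torch.randn(8, 1000, 192)
w = torch.randn(192, 576)


def bench(fn, iters=200):
    fn()  # warm-up
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters


t_einsum = bench(lambda: torch.einsum("bsd,do->bso", x, w))
t_mm = bench(lambda: torch.mm(x.reshape(-1, 192), w).reshape(8, 1000, -1))
print(f"einsum: {t_einsum * 1e3:.2f} ms/iter, mm: {t_mm * 1e3:.2f} ms/iter")
```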

@priorphil (Contributor) left a comment

LGTM, thanks for the changes.

@priorphil merged commit f0ad402 into PriorLabs:main on Sep 10, 2025
10 checks passed
oscarkey pushed a commit that referenced this pull request on Nov 12, 2025: QKV calculation improvements in attention mechanism (#488) (#141)

* Record copied public PR 488

* QKV calculation improvements in attention mechanism (#488)

(cherry picked from commit f0ad402)

---------

Co-authored-by: mirror-bot <mirror-bot@users.noreply.github.com>
Co-authored-by: benraha <benraha@gmail.com>