Add new flash attn features to cuDNN SDPA API and remove fused attn #21228
base: main
Conversation
Cjkkkk commented on May 14, 2024:
- Add variable sequence length: accepts two additional tensors, seqlen_q and seqlen_kv, indicating the non-padded lengths, to reduce computation.
- Add MQA/GQA.
- Add broadcast bias: the bias can be broadcast over the batch/head dims.
- Add dbias calculation.
- Remove fused attn and default to flash attn.
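As a rough illustration of two of the features above (GQA and variable sequence lengths), here is a plain-JAX reference sketch; this is not the PR's cuDNN flash-attention path, and the function name and shapes are hypothetical:

```python
# Hypothetical reference sketch: grouped-query attention (fewer KV heads
# than query heads) with padding masks derived from seqlen_q / seqlen_kv.
# Plain JAX, not the cuDNN kernel this PR dispatches to.
import jax
import jax.numpy as jnp

def gqa_reference(q, k, v, seqlen_q, seqlen_kv):
    """q: [B, T, Nq, H]; k, v: [B, S, Nkv, H], with Nq a multiple of Nkv."""
    B, T, Nq, H = q.shape
    S, Nkv = k.shape[1], k.shape[2]
    # GQA: each KV head serves a contiguous group of query heads.
    groups = Nq // Nkv
    k = jnp.repeat(k, groups, axis=2)
    v = jnp.repeat(v, groups, axis=2)
    logits = jnp.einsum("btnh,bsnh->bnts", q, k) / jnp.sqrt(H)
    # Variable sequence length: mask key positions past the real length.
    kv_mask = jnp.arange(S)[None, :] < seqlen_kv[:, None]          # [B, S]
    logits = jnp.where(kv_mask[:, None, None, :], logits, -1e9)
    out = jnp.einsum("bnts,bsnh->btnh", jax.nn.softmax(logits), v)
    # Zero out the padded query rows as well.
    q_mask = jnp.arange(T)[None, :] < seqlen_q[:, None]            # [B, T]
    return out * q_mask[:, :, None, None].astype(out.dtype)
```

With uniform inputs, unmasked output rows average the value vectors, and rows past seqlen_q come back zeroed.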
@superbobry Hi Sergei, could you help review this PR?

@superbobry Hi Sergei, any updates on this?

No updates just yet, sorry. I will review some time tomorrow.
I did my best to read through, but these large diffs are really hard to get through. Please send smaller PRs for any follow-up changes.
I would also recommend asking someone from NVIDIA to review the cuDNN APIs etc.
```diff
@@ -41,10 +41,42 @@ class AttentionLayout(Enum):
   BTNH = 0
   BNTH = 1

+class MaskType(Enum):
+  NO_MASK = 0
```
OOC, why not use `None` instead when a mask is not specified?
Just a choice to make it more explicit.
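The trade-off being discussed can be sketched as follows; member names other than `NO_MASK`, and the helper `attention_config`, are illustrative rather than the PR's actual API:

```python
# Illustrative only: an explicit enum member for "no mask" versus an
# Optional[...] argument. Members besides NO_MASK are hypothetical.
from enum import Enum

class MaskType(Enum):
    NO_MASK = 0
    CAUSAL = 1

def attention_config(mask_type: MaskType = MaskType.NO_MASK) -> str:
    # With an enum, every call site spells out its intent explicitly,
    # and handling each case exhaustively is easy to check.
    if mask_type is MaskType.NO_MASK:
        return "no masking applied"
    return f"masking: {mask_type.name.lower()}"
```

Passing `None` would serve the same purpose, but an explicit `NO_MASK` member keeps the "no mask" case a first-class, nameable value.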
Understood, sorry for the large PR; I will create smaller ones next time. I think people from NVIDIA don't have access to approve and merge the PR?
Please address the comments, and then we can merge.
Comments addressed, sorry about the delay.

Can you squash the PR please?
Force-pushed from 627a064 to 403ad05.