Add a new SdpaFwdOp IR node for Flash Attention #2294
Conversation
(The PR was retitled from "SdpaOp IR node" to "SdpaFwdOp IR node for Flash Attention".)
!build
Is it correct to say that before accepting this in the ExprEval scheduler, we need to handle the other execution modes?
No. I mainly separated them to handle the mapping and the ID-graph workarounds separately from the node, to reduce the scope of this PR. As we discussed in today's meeting, at the moment we only plan on supporting Flash Attention, to support multi-GPU development. Once we support Flash Attention, we can revisit whether we need to add Memory Efficient Attention as well. There could be a few ways:
I don't understand how to do partial support. If we are given inputs that we cannot evaluate with Flash Attention, what will happen?
We will only accept the op if the backend identified in Thunder is Flash Attention: https://github.com/Lightning-AI/lightning-thunder/blob/9f0c50cc6df187cf5fd2e31240690fe2b5e9ccc1/thunder/executors/sdpaex.py#L618-L680. Do you think this will not be sufficient?
Makes sense to me. The logic for dispatching to Flash Attention takes place in Thunder, which seems fine.
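For illustration, here is a minimal sketch, not the actual Thunder code, of gating on the Flash Attention backend with PyTorch's public `torch.nn.attention` API (assuming a recent PyTorch release): dispatch is restricted to flash, and the op is declined when the inputs cannot run there.

```python
# Minimal sketch (not Thunder's implementation) of accepting an SDPA call
# only when the Flash Attention backend can serve it.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

def sdpa_flash_only(q, k, v, is_causal=False):
    """Return flash-attention output, or None if flash cannot handle the inputs."""
    try:
        # Restrict dispatch to Flash Attention; PyTorch raises a RuntimeError
        # if the inputs (dtype, head dim, mask, device, ...) are unsupported.
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
    except RuntimeError:
        return None  # decline the op; the caller falls back to another executor

if torch.cuda.is_available():
    q = k = v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
    out = sdpa_flash_only(q, k, v, is_causal=True)
```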
Is anyone still reviewing this PR? @jacobhinkle?
LGTM once this is addressed: #2294 (comment)
!build
Co-authored-by: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com>
!build
!build
The failing tests look unrelated.
Stacked on #2294.
1. Adds the producer-consumer mapping to the root domain map.
2. Adds `SDPAOp` to `ExprEvalScheduler`.
3. Modifies `ExprEvalSched::canSchedule` to skip computeAt checks and only use the compile-time check, since the expression evaluator scheduler will only accept segments with a single expression of type MatmulOp / LinearOp / SdpaOp.

Co-authored-by: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com>
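Point 3 amounts to a single-expression check. A hypothetical Python rendering of that compile-time check (the real logic is the C++ `ExprEvalScheduler`; the names below are illustrative, not the nvFuser API):

```python
# Hypothetical sketch of the compile-time check from point 3: the
# expression-evaluator scheduler claims a segment only when it holds a
# single expression of an accepted type, so per-tensor computeAt checks
# are unnecessary. Not actual nvFuser code.
ACCEPTED_OPS = {"MatmulOp", "LinearOp", "SdpaFwdOp"}

def can_schedule_compile_time(exprs: list) -> bool:
    """Accept iff the segment is exactly one Matmul/Linear/Sdpa expression."""
    return len(exprs) == 1 and type(exprs[0]).__name__ in ACCEPTED_OPS
```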
Issue #2278.
This PR originally added a new node with the same functionality as `torch.nn.functional.scaled_dot_product_attention` and enabled scheduling it through `ExprEvalScheduler`. Based on the PR discussions, it has been repurposed to introduce a new IR node `SdpaFwdOp` for the scaled dot product flash attention forward pass (see #2278 for details). This PR does not include changes to the scheduler.
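For reference, a naive sketch of the forward computation the new node represents, written in plain PyTorch and checked against the public op; this is illustrative, not nvFuser code:

```python
# Naive reference for softmax(Q K^T / sqrt(d)) V, matching
# torch.nn.functional.scaled_dot_product_attention semantics.
import math
import torch
import torch.nn.functional as F

def sdpa_reference(q, k, v, is_causal=False):
    scale = 1.0 / math.sqrt(q.size(-1))
    scores = q @ k.transpose(-2, -1) * scale
    if is_causal:
        sq, sk = scores.size(-2), scores.size(-1)
        mask = torch.ones(sq, sk, dtype=torch.bool, device=q.device).tril()
        scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

q = k = v = torch.randn(2, 8, 128, 64)  # (batch, heads, seq, head_dim)
assert torch.allclose(
    sdpa_reference(q, k, v, is_causal=True),
    F.scaled_dot_product_attention(q, k, v, is_causal=True),
    atol=1e-5,
)
```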
The next PRs will:
1. Add the producer-consumer mapping to the root domain map.
2. Enable scheduling `SdpaFwdOp` through `ExprEvalScheduler`.

After the completion of these tasks, we also aim to introduce Memory Efficient Attention.