Add a new SdpaFwdOp IR node for Flash Attention #2294

Merged: 19 commits from pm/sdpa into main on Jun 10, 2024
Conversation

@Priya2698 (Collaborator) commented on May 23, 2024:

Issue #2278.
This PR adds a new node with the same functionality as torch.nn.functional.scaled_dot_product_attention, and enables scheduling it through ExprEvalScheduler.

Based on the PR discussions, this PR has been repurposed to introduce a new IR node, SdpaFwdOp, for the scaled dot product flash attention forward pass (see #2278 for details).
This PR does not include changes to the scheduler.

The next PRs will:

  1. Add the producer-consumer mapping in root_domain_map and enable this op in ExprEvalScheduler.
  2. Add a Python API.
  3. Add a node for the backward pass, similar to this forward node.

After these tasks are complete, we also aim to introduce Memory Efficient Attention.
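For reference, the new node mirrors the forward semantics of PyTorch's torch.nn.functional.scaled_dot_product_attention. A minimal sketch of those reference semantics, with arbitrary illustrative shapes (not taken from this PR):

```python
import math
import torch
import torch.nn.functional as F

# Arbitrary illustrative shapes: (batch, heads, seq_len, head_dim).
q = torch.randn(2, 8, 128, 64)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)

# The reference op whose forward behavior SdpaFwdOp is modeled after.
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

# Equivalent math: softmax(Q @ K^T / sqrt(head_dim)) @ V.
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
manual = torch.softmax(scores, dim=-1) @ v
torch.testing.assert_close(out, manual, rtol=1e-4, atol=1e-4)
```

On CPU with float32 inputs this dispatches to the math backend; the flash attention backend targeted here requires half-precision CUDA tensors.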

Priya2698 changed the title from "[WIP] Add a new SdpaOp IR node" to "Add a new SdpaOp IR node" on May 23, 2024.
Priya2698 marked this pull request as ready for review on May 23, 2024.
(Resolved review threads: csrc/ir/internal_nodes.h, tests/cpp/utils.h.)
Priya2698 changed the title back to "[WIP] Add a new SdpaOp IR node" and marked the pull request as a draft on May 29, 2024.
Priya2698 changed the title to "Add a new SdpaFwdOp IR node for Flash Attention" on Jun 3, 2024.
@Priya2698 (Collaborator, Author) commented:

!build

Priya2698 marked this pull request as ready for review on June 3, 2024.
@jacobhinkle (Collaborator) left a review comment:

Is it correct to say that before accepting this in ExprEval scheduler, we need to handle the other execution modes?

(Review threads on csrc/ir/nodes.cpp.)
@Priya2698 (Collaborator, Author) replied:

> Is it correct to say that before accepting this in ExprEval scheduler, we need to handle the other execution modes?

No. I mainly separated them so that the mapping and the ID-graph workarounds can be handled separately from the node, to reduce the scope of this PR.

As we discussed in today's meeting, at the moment we only plan on supporting Flash Attention, to support multi-GPU development. Once we support Flash Attention, we can revisit whether we need to add Memory Efficient Attention as well. There could be a few ways (see the sketch after this list for the second option):

  1. Plumb down the backend info from Thunder and use that within our nodes: while the two implementations have different function signatures, there is overlap, so one possibility is to use a superset of the inputs and outputs. An alternative design would be distinct nodes for each implementation.
  2. Make the backend decision within nvFuser using the same logic as ATen/Thunder. See: https://github.com/Lightning-AI/lightning-thunder/blob/9f0c50cc6df187cf5fd2e31240690fe2b5e9ccc1/thunder/executors/sdpaex.py#L618-L680
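A rough sketch of what the second option could look like at the Python level. The helper names (can_use_flash_attention, run_sdpa) and the specific constraints (CUDA half-precision inputs, 4-D layout, head-size cap) are illustrative assumptions only; the actual conditions are the ones in the linked Thunder/ATen code.

```python
import torch
import torch.nn.functional as F

def can_use_flash_attention(q, k, v) -> bool:
    # Illustrative gate only; the real checks live in Thunder/ATen (see link above).
    if not (q.is_cuda and k.is_cuda and v.is_cuda):
        return False
    if q.dtype not in (torch.float16, torch.bfloat16):
        return False
    if q.dim() != 4:  # expecting (batch, heads, seq_len, head_dim)
        return False
    if q.size(-1) > 128:  # assumed head-size limit, for illustration only
        return False
    return True

def run_sdpa(q, k, v, dropout_p=0.0, is_causal=False):
    if not can_use_flash_attention(q, k, v):
        # In nvFuser terms: reject the op / fall back instead of claiming it.
        raise NotImplementedError("Inputs not supported by the flash backend.")
    # Restrict PyTorch's dispatcher to the flash kernel so execution matches
    # what an SdpaFwdOp evaluation would target.
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True, enable_math=False, enable_mem_efficient=False
    ):
        return F.scaled_dot_product_attention(
            q, k, v, dropout_p=dropout_p, is_causal=is_causal
        )
```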

@jacobhinkle (Collaborator) replied:

I don't understand how to do partial support. If we are given inputs that we cannot evaluate with Flash Attention, what will happen?

@Priya2698 (Collaborator, Author) replied:

> I don't understand how to do partial support. If we are given inputs that we cannot evaluate with Flash Attention, what will happen?

We will only accept the op if the backend identified by Thunder is Flash Attention: https://github.com/Lightning-AI/lightning-thunder/blob/9f0c50cc6df187cf5fd2e31240690fe2b5e9ccc1/thunder/executors/sdpaex.py#L618-L680.

Do you think this will not be sufficient?

@jacobhinkle (Collaborator) replied:

> Do you think this will not be sufficient?

Makes sense to me. The logic to dispatch to Flash Attention takes place in Thunder, which seems fine.

@naoyam (Collaborator) commented on Jun 6, 2024:

Is anyone still reviewing this PR? @jacobhinkle?

@jacobhinkle (Collaborator) left a review:

LGTM once this is addressed: #2294 (comment)

@Priya2698 (Collaborator, Author) commented:

!build

@Priya2698 (Collaborator, Author) commented:

!build

@naoyam (Collaborator) commented on Jun 7, 2024:

!build

@Priya2698 (Collaborator, Author) commented:

!build

@Priya2698 (Collaborator, Author) commented:

!build

@Priya2698 (Collaborator, Author) commented:

The failing tests look unrelated.

Priya2698 merged commit 23ee81d into main on Jun 10, 2024 (35 of 37 checks passed).
Priya2698 deleted the pm/sdpa branch on June 10, 2024.
Priya2698 added a commit that referenced this pull request on Jun 11, 2024:

Stacked on #2294.

  1. Adds the producer-consumer mapping to the root domain map.
  2. Adds `SDPAOp` to `ExprEvalScheduler`.
  3. Modifies `ExprEvalSched::canSchedule` to skip computeAt checks and only use the compile-time check, since the expression evaluator scheduler will only accept segments with a single expression of type MatmulOp / LinearOp / SdpaOp.

Co-authored-by: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com>