[PyTorch] Add FA4 Support #2432
Conversation
Greptile Summary: This PR introduces Flash Attention 4 (FA4) support for PyTorch — importing
Confidence Score: 3/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[get_attention_backend] --> B{SM90?}
    B -- yes --> C{FA3 flag AND\nFA4 flag?}
    C -- yes --> D[Disable FA4\n⚠️ Missing v3_is_installed check]
    C -- no --> E[Keep FA4]
    B -- no --> E
    E --> F{flash_attention_backend set}
    F --> G[backends.py: use_flash_attn_4\nif backend > 4.0.0b]
    F --> H[backends.py: use_flash_attn_3\nif 3.0.0b < backend < 4.0.0\n⚠️ 4.0.0b8 satisfies this too]
    G --> I{use_flash_attn_4?}
    H --> J{use_flash_attn_3?}
    I -- yes --> K[flash_attn_func_v4 /\nflash_attn_varlen_func_v4]
    J -- yes, elif --> L[flash_attn_func_v3 /\nflash_attn_varlen_func_v3]
    I -- no, J no --> M[flash_attn_func v2]
```
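The overlap flagged in the FA3 branch follows from PEP 440 ordering: a pre-release such as `4.0.0b8` sorts *below* the final `4.0.0`, so it satisfies the `3.0.0b < backend < 4.0.0` range. A minimal sketch, assuming version comparison is done with `packaging.version` (PEP 440 semantics):

```python
from packaging.version import Version

# Under PEP 440, pre-releases sort BEFORE the final release, so an
# FA4 beta such as 4.0.0b8 still satisfies "< 4.0.0" and would slip
# into the FA3 range 3.0.0b < backend < 4.0.0.
fa4_beta = Version("4.0.0b8")
assert Version("3.0.0b1") < fa4_beta < Version("4.0.0")

# Using "4.0.0b0" as the upper bound (or checking the major
# component) keeps FA4 pre-releases out of the FA3 branch:
assert not (fa4_beta < Version("4.0.0b0"))
assert fa4_beta.major == 4
```

This is why the flowchart warns that `4.0.0b8` "satisfies this too": with the current bounds, installing an FA4 beta could enable both the FA3 and FA4 paths at once.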
```python
output = func(
    query_layer,
    key_layer,
    value_layer,
    softmax_scale=self.softmax_scale,
    causal="causal" in attn_mask_type,
```
causal_bottom_right treated identically to causal for FA4
causal="causal" in attn_mask_type evaluates to True for both "causal" and "causal_bottom_right". If FA4's flash_attn_func supports a separate bottom-right diagonal alignment flag (similar to how cuDNN fused attention distinguishes the two), passing only causal=True would produce incorrect results for causal_bottom_right configs.
This is consistent with the existing FA2 path, but since fa4_mask_causal_br is explicitly added as a test case, it is worth verifying that the FA4 causal parameter correctly implements both variants, or adding a dedicated causal_bottom_right kwarg if the FA4 API exposes one.
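The substring behavior itself is easy to confirm; the explicit dispatch below is only a hypothetical sketch — the bottom-right kwarg name is assumed, not taken from the FA4 API, and would need to be checked against the actual `flash_attn_func` signature:

```python
# "causal" is a substring of "causal_bottom_right", so the existing
# check maps both mask types to plain causal=True:
for attn_mask_type in ("causal", "causal_bottom_right"):
    assert ("causal" in attn_mask_type) is True

# Hypothetical sketch of an explicit dispatch, IF the FA4 API exposes
# a separate bottom-right alignment flag (kwarg name is invented):
def fa4_mask_kwargs(attn_mask_type: str) -> dict:
    kwargs = {"causal": "causal" in attn_mask_type}
    if attn_mask_type == "causal_bottom_right":
        kwargs["bottom_right_causal"] = True  # hypothetical kwarg
    return kwargs

assert fa4_mask_kwargs("causal") == {"causal": True}
assert fa4_mask_kwargs("causal_bottom_right") == {
    "causal": True,
    "bottom_right_causal": True,
}
```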
Force-pushed from 0708391 to 4760264 (Compare)
/te-ci pytorch
vcherepanov-nv left a comment:
Looks good overall, please consider (as also noted by AI):
- whether you need a tighter version check, i.e. a minimum required version for FA4
- whether it's cleaner to remove the currently dead / unreachable CP-related code
Approval is contingent upon prior corresponding changes on te_ci branch and green pipeline with the new tests.
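A minimal sketch of the tighter version gate suggested above, assuming PEP 440 comparison via `packaging.version`; the minimum FA4 version used here is a placeholder, since the thread does not specify one:

```python
from packaging.version import Version

# Placeholder minimum — the real minimum supported FA4 version
# should come from whatever TE actually validates against.
FA4_MIN_VERSION = Version("4.0.0b8")  # assumption, not from the PR

def use_flash_attn_4(installed_version: str) -> bool:
    """Gate FA4 on a concrete minimum instead of 'anything > 4.0.0b'."""
    v = Version(installed_version)
    return v.major == 4 and v >= FA4_MIN_VERSION

assert use_flash_attn_4("4.0.0b8")
assert not use_flash_attn_4("4.0.0b1")   # too old a pre-release
assert not use_flash_attn_4("3.0.0b1")   # FA3, not FA4
```

Pinning a concrete minimum also sidesteps the FA3/FA4 range overlap noted in the flowchart, since each branch then has an explicit lower bound.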
Hello! I was wondering if FA4 support on TE is still a work in progress? Thanks!
/te-ci pytorch L0 L3
Is there an ETA for this? FA4 support in TE would be a requirement for downstream FA4 support in megatron-core with TEDotProductAttention.
/te-ci pytorch L0 L3
Signed-off-by: Xin Yao <xiny@nvidia.com>
/te-ci pytorch L0 L3

@umiswing @bbuschkaemper Thanks for your attention. This PR is ready to be merged, likely today or tomorrow.
Description
Need help to install `flash-attn-4` in the CI container to enable FA4 tests.

Type of change
Checklist:
CI test time impact
- `L3_pytorch_FA_versions_test--B200_1GPU` increases from ~20 mins to ~40 mins.
- `L3_pytorch_FA_versions_test--H100_1GPU` increases from ~101 mins to ~127 mins.