
[PyTorch] Add FA4 Support #2432

Merged
vcherepanov-nv merged 5 commits into NVIDIA:main from yaox12:xiny/fa4
Apr 17, 2026

Conversation

@yaox12
Member

@yaox12 yaox12 commented Nov 28, 2025

Description

  • Add FA4 support
  • Add tests

We need help installing flash-attn-4 in the CI container to enable the FA4 tests.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

CI test time impact

  • Time cost of L3_pytorch_FA_versions_test--B200_1GPU increases from ~20 mins to ~40 mins.
  • Time cost of L3_pytorch_FA_versions_test--H100_1GPU increases from ~101 mins to ~127 mins.

@yaox12 yaox12 marked this pull request as ready for review March 19, 2026 22:15
@greptile-apps
Contributor

greptile-apps Bot commented Mar 19, 2026

Greptile Summary

This PR introduces Flash Attention 4 (FA4) support for PyTorch — importing flash_attn.cute.interface, adding get_attention_backend filters across SM80/90/100/120, wiring the FA4 forward dispatch in FlashAttention.forward, and adding five new test groups with appropriate skip guards.

  • P1 – SM90 FA4 fully disabled when FA3 absent (utils.py:467): the SM90 preference check fires on use_flash_attention_3=1 (the env-var default) without checking FlashAttentionUtils.v3_is_installed; the cleanup at line 1289 then zeroes FA3 too, leaving no flash attention backend for users who install only FA4 on Hopper.
  • P1 – use_flash_attn_3 wrongly True for FA4 beta (backends.py:943-946): PEP 440 orders 4.0.0b8 < 4.0.0, so the range guard < PkgVersion("4.0.0") includes all 4.0.0bN versions; the upper bound should mirror the FA4 lower bound (< PkgVersion("4.0.0b")).

Confidence Score: 3/5

  • Two P1 logic bugs need to be fixed before merging: SM90 + FA4-only users get no flash attention backend, and FA4 beta versions incorrectly set use_flash_attn_3=True.
  • Two confirmed P1 defects on the changed code path: (1) the SM90 preference guard disables FA4 even when FA3 is not installed, and (2) the PEP 440 upper-bound fix for use_flash_attn_3 is incomplete for beta FA4 versions (4.0.0bN < 4.0.0 in PEP 440). Both are reproducible with the exact FA4 version pinned in this PR (4.0.0b8).
  • transformer_engine/pytorch/attention/dot_product_attention/utils.py (line 467) and transformer_engine/pytorch/attention/dot_product_attention/backends.py (lines 943-946)

Important Files Changed

  • transformer_engine/pytorch/attention/dot_product_attention/backends.py: Adds the FA4 import block, forward dispatch (flash_attn_func_v4 / flash_attn_varlen_func_v4), and the use_flash_attn_4 detection flag; the upper bound of the FA3 range check (< 4.0.0) incorrectly includes FA4 beta versions (4.0.0b8 < 4.0.0 in PEP 440), causing use_flash_attn_3 and use_flash_attn_4 to be True simultaneously.
  • transformer_engine/pytorch/attention/dot_product_attention/utils.py: Adds FlashAttentionUtils FA4 fields, set_flash_attention_4_params(), and get_attention_backend filters for FA4; the SM90 preference check (line 467) uses the env-var flag use_flash_attention_3 without gating on v3_is_installed, so FA4 is incorrectly disabled on SM90 when FA3 is not actually installed, leaving no flash attention backend.
  • tests/pytorch/attention/test_attention.py: Adds five FA4-specific test groups (base, MLA, sliding window, varlen, mask types) with appropriate skip guards; the sys.path prepend fix is a minor correctness improvement; test coverage is comprehensive.
  • qa/L3_pytorch_FA_versions_test/test.sh: Adds FA4 to the version matrices and installs flash-attn-4 with nvidia-cutlass-dsl[cu13]; on SM90, FA3 remains installed when the FA4 tests run, so the FA4-specific tests silently execute under FA3 backend preference.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[get_attention_backend] --> B{SM90?}
    B -- yes --> C{FA3 flag AND\nFA4 flag?}
    C -- yes --> D[Disable FA4\n⚠️ Missing v3_is_installed check]
    C -- no --> E[Keep FA4]
    B -- no --> E

    E --> F{flash_attention_backend set}
    F --> G[backends.py: use_flash_attn_4\nif backend > 4.0.0b]
    F --> H[backends.py: use_flash_attn_3\nif 3.0.0b < backend < 4.0.0\n⚠️ 4.0.0b8 satisfies this too]

    G --> I{use_flash_attn_4?}
    H --> J{use_flash_attn_3?}

    I -- yes --> K[flash_attn_func_v4 /\nflash_attn_varlen_func_v4]
    J -- yes, elif --> L[flash_attn_func_v3 /\nflash_attn_varlen_func_v3]
    I -- no, J no --> M[flash_attn_func v2]
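The dispatch order in the flowchart can be sketched as follows. The flag and function names mirror the diagram; they are assumptions about the TE code paths, not verified source.

```python
# Illustrative sketch of the backend dispatch order shown in the flowchart.
# Flag and function names mirror the diagram; they are assumptions, not
# the verified transformer_engine source.
def pick_flash_attn_func(use_flash_attn_4: bool, use_flash_attn_3: bool) -> str:
    if use_flash_attn_4:
        return "flash_attn_func_v4"   # FA4 path
    elif use_flash_attn_3:
        return "flash_attn_func_v3"   # FA3 path
    return "flash_attn_func"          # FA2 fallback

# Because FA4 is checked first, the version-range bug (both flags True
# for 4.0.0b8) still happens to dispatch to FA4 -- by accident of ordering.
print(pick_flash_attn_func(True, True))  # flash_attn_func_v4
```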

Comments Outside Diff (2)

  1. transformer_engine/pytorch/attention/dot_product_attention/utils.py, lines 467-470

    FA4 incorrectly disabled on SM90 when FA3 is not installed

    use_flash_attention_3 is set to 1 from the env-var default, regardless of whether FA3 is actually installed. On an SM90 machine where only FA4 is present, this check fires, disables FA4, and then the cleanup at line 1289 disables FA3 too (if use_flash_attention_3 and not FlashAttentionUtils.v3_is_installed). The net result is no flash attention backend selected at all.
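A minimal sketch of the suggested gating, assuming the names used in this report (use_flash_attention_3, use_flash_attention_4, FlashAttentionUtils.v3_is_installed); the actual guard in utils.py may be shaped differently.

```python
# Hypothetical sketch of the SM90 preference guard; names follow the
# review comment above, not the actual utils.py code.
class FlashAttentionUtils:
    v3_is_installed = False   # scenario: only FA4 is installed
    v4_is_installed = True

def sm90_keep_fa4(use_flash_attention_3: bool, use_flash_attention_4: bool) -> bool:
    """Return whether FA4 stays enabled after the SM90 preference check."""
    # Buggy guard: `if use_flash_attention_3 and use_flash_attention_4`
    # fires on the env-var default alone and disables FA4.
    # Fixed guard: only prefer FA3 when it is actually installed.
    if (use_flash_attention_3
            and FlashAttentionUtils.v3_is_installed
            and use_flash_attention_4):
        return False          # FA3 preferred on SM90; disable FA4
    return use_flash_attention_4

# With the env-var defaults (both flags on) and FA3 absent, FA4 survives:
print(sm90_keep_fa4(True, True))  # True
```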

  2. transformer_engine/pytorch/attention/dot_product_attention/backends.py, lines 943-946

    use_flash_attn_3 incorrectly True for FA4 beta versions

In PEP 440, pre-release versions sort below the corresponding release: 4.0.0b8 < 4.0.0. This means PkgVersion("3.0.0b") < PkgVersion("4.0.0b8") < PkgVersion("4.0.0") evaluates to True, so when the installed FA4 version is 4.0.0b8, both use_flash_attn_4 and use_flash_attn_3 are set to True. The intended upper bound for the FA3 range should align with the lower bound used for the FA4 range (4.0.0b, which normalizes to 4.0.0b0).
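The ordering can be checked directly with the packaging library (the same PkgVersion used in the range checks quoted above):

```python
from packaging.version import Version as PkgVersion

v = PkgVersion("4.0.0b8")  # the FA4 beta pinned by this PR

# Buggy FA3 range: the 4.0.0 upper bound also admits FA4 betas.
buggy_fa3_match = PkgVersion("3.0.0b") < v < PkgVersion("4.0.0")
# Corrected range: 4.0.0b normalizes to 4.0.0b0 and excludes them.
fixed_fa3_match = PkgVersion("3.0.0b") < v < PkgVersion("4.0.0b")
# The FA4 lower bound still matches the beta.
fa4_match = v >= PkgVersion("4.0.0b")

print(buggy_fa3_match, fixed_fa3_match, fa4_match)  # True False True
```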

Reviews (10): Last reviewed commit: "fix sm90"

Comment on lines +1037 to +1042
output = func(
query_layer,
key_layer,
value_layer,
softmax_scale=self.softmax_scale,
causal="causal" in attn_mask_type,
Contributor


P2 causal_bottom_right treated identically to causal for FA4

causal="causal" in attn_mask_type evaluates to True for both "causal" and "causal_bottom_right". If FA4's flash_attn_func supports a separate bottom-right diagonal alignment flag (similar to how cuDNN fused attention distinguishes the two), passing only causal=True would produce incorrect results for causal_bottom_right configs.

This is consistent with the existing FA2 path, but since fa4_mask_causal_br is explicitly added as a test case, it is worth verifying that the FA4 causal parameter correctly implements both variants, or adding a dedicated causal_bottom_right kwarg if the FA4 API exposes one.
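To illustrate the ambiguity, here is a toy version of the mask-type mapping; only the `causal` kwarg comes from the snippet above, and the function itself is illustrative rather than the FA4 API.

```python
# Toy reproduction of the mask handling flagged above; only the `causal`
# kwarg comes from the snippet -- the function itself is illustrative.
def fa4_mask_kwargs(attn_mask_type: str) -> dict:
    # Substring test: True for both "causal" and "causal_bottom_right",
    # so the two mask variants become indistinguishable downstream.
    return {"causal": "causal" in attn_mask_type}

print(fa4_mask_kwargs("causal"))               # {'causal': True}
print(fa4_mask_kwargs("causal_bottom_right"))  # {'causal': True}
print(fa4_mask_kwargs("no_mask"))              # {'causal': False}
```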

@yaox12 yaox12 force-pushed the xiny/fa4 branch 4 times, most recently from 0708391 to 4760264 on March 19, 2026 23:01
@yaox12
Copy link
Copy Markdown
Member Author

yaox12 commented Mar 19, 2026

/te-ci pytorch

@KshitijLakhani KshitijLakhani requested a review from mk-61 March 19, 2026 23:28
vcherepanov-nv previously approved these changes Mar 20, 2026
Collaborator

@vcherepanov-nv vcherepanov-nv left a comment


Looks good overall, please consider (as also noted by AI):

  • whether you need a tighter version check, i.e., a minimum required version for FA4
  • whether it is cleaner to remove the currently dead / unreachable CP-related code

Approval is contingent upon prior corresponding changes on te_ci branch and green pipeline with the new tests.

@umiswing

umiswing commented Apr 9, 2026

Hello! I was wondering if FA4 support on TE is still a work in progress? Thanks!


@yaox12
Member Author

yaox12 commented Apr 10, 2026

/te-ci pytorch L0 L3

@bbuschkaemper
Contributor

Is there an ETA for this? FA4 support in TE would be a requirement for downstream FA4 support in megatron-core with TEDotProductAttention.

@yaox12
Member Author

yaox12 commented Apr 14, 2026

/te-ci pytorch L0 L3

yaox12 added 5 commits April 14, 2026 09:15
Signed-off-by: Xin Yao <xiny@nvidia.com>
Signed-off-by: Xin Yao <xiny@nvidia.com>
Signed-off-by: Xin Yao <xiny@nvidia.com>
Signed-off-by: Xin Yao <xiny@nvidia.com>
Signed-off-by: Xin Yao <xiny@nvidia.com>
@yaox12
Member Author

yaox12 commented Apr 14, 2026

/te-ci pytorch L0 L3

@yaox12
Member Author

yaox12 commented Apr 16, 2026

@umiswing @bbuschkaemper Thanks for your attention. This PR is ready to be merged, likely today or tomorrow.

@vcherepanov-nv vcherepanov-nv merged commit c5a4fd5 into NVIDIA:main Apr 17, 2026
23 of 26 checks passed
