[PyTorch] Add pad_between_seqs support for non-CP and CP (A2A and P2P) with FA3 + THD (varlen) #2596
Greptile Summary
This PR extends
Confidence Score: 3/5
The production attention kernels are wired correctly for the supported P2P and A2A paths, but the new tests do not actually exercise the new code path due to a bool/string type mismatch, and a secondary ordering problem in the test runner would cause those tests to fail even after the type mismatch is fixed.
Multiple issues compound in the test layer: every pad_between_seqs=True CP test silently falls back to pad_between_seqs=False because run_dpa_with_cp compares the bool True against the string "True". Fixing that comparison would then expose the NaN-check-before-zeroing ordering problem, where FA3 tile-spillover NaN values in reference tensors trigger the assertion at line 562 before the zeroing code at line 277 can run. The net effect is that the feature ships without any working test coverage of the new code path. Additionally, all_gather CP mode silently ignores pad_between_seqs, and a2a+p2p lacks a production-time guard.
tests/pytorch/attention/test_attention_with_cp.py and tests/pytorch/attention/run_attention_with_cp.py need attention for the bool/string type mismatch and NaN assertion ordering; transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py needs a runtime guard or error for the all_gather+pad_between_seqs and a2a+p2p+pad_between_seqs combinations.
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["DotProductAttention.forward()"] --> B{cp_group?}
B -- No --> C["FlashAttention.forward()"]
B -- Yes --> D["attn_forward_func_with_cp()"]
C --> E{pad_between_seqs + FA3?}
E -- Yes --> F["Append cu_seqlens_q_padded as layout descriptor + set seqused_q/k kwargs"]
E -- No --> G["Append cu_seqlens_q as layout descriptor"]
D --> H{cp_comm_type}
H -- p2p / a2a+p2p --> I["AttnFuncWithCPAndKVP2P (pad_between_seqs ✓)"]
H -- a2a --> J["AttnFuncWithCPAndQKVOA2A (pad_between_seqs ✓)"]
H -- all_gather --> K["AttnFuncWithCPAndKVAllGather (pad_between_seqs ✗ not forwarded)"]
I --> L["cp_p2p_fwd_flash_attn(): seqused=diff(cu_seqlens_per_step), cu_seqlens→padded"]
J --> M["get_fa_args(seqused_q, seqused_k): cu_seqlens→cu_seqlens_padded"]
style K fill:#ffcccc
Reviews (50): Last reviewed commit: "Merge branch 'main' into flash_attn_pad_..."
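The bool/string mismatch described in the summary is easy to reproduce in isolation. The sketch below is illustrative only: `parse_flag` is a hypothetical helper, not a function from the Transformer Engine test suite, and it assumes the flag may arrive either as a real bool or as its string form after passing through a command line.

```python
def parse_flag(value):
    """Normalize a flag that may be a bool or its string form (hypothetical helper)."""
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        text = value.strip().lower()
        if text in ("true", "1"):
            return True
        if text in ("false", "0"):
            return False
    raise ValueError(f"cannot interpret {value!r} as a boolean flag")


# Comparing a bool against a string is always False in Python, so a check
# written as `flag == "True"` silently disables the feature when the caller
# passes the bool True.
naive_check = (True == "True")

# Normalizing first gives the same answer for both representations.
from_bool = parse_flag(True)
from_string = parse_flag("True")
```

Normalizing at the boundary where the value is received keeps the rest of the runner free of type-dependent comparisons.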
ea51821 to e338049
/te-ci pytorch L2
/te-ci pytorch L1
/te-ci pytorch L3
b0a3c64 to 057f406
/te-ci pytorch L3
00bdc92 to 0f48ebc
if not FlashAttentionUtils.v3_is_installed:
    pytest.skip("pad_between_seqs with CP requires Flash Attention v3!")
if cp_comm_type == "a2a+p2p":
    pytest.skip("pad_between_seqs is not yet supported with A2A+P2P CP comm type!")

if pad_between_seqs:
    dq, dk, dv = [torch.zeros_like(x) for x in [q_part, k_part, v_part]]
else:
    dq, dk, dv = [torch.empty_like(x) for x in [q_part, k_part, v_part]]
Just to confirm, we can't do this for fwd, right? Because fwd output is not allocated by us.
It's a limitation in the Flash Attention code: forward never mutates out (so any pre-zeroing is overwritten), while backward treats dq/dk/dv as in-place mutable (so pre-zeroing sticks). This zeroing also works only in the CP code, where we can provide the args.
None of the zeroing works for the non-CP path because we only have the forward call in TE.
FA3 / Hopper (hopper/flash_attn_interface.py)
- Forward: mutates_args=(), namespace flash_attn_3::_flash_attn_forward
- Backward: mutates_args=("dq", "dk", "dv"), namespace flash_attn_3::_flash_attn_backward
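The effect of pre-zeroing an in-place output can be illustrated with a toy model in plain Python. This is a sketch, not FA3 code: `toy_backward` is a hypothetical stand-in for a kernel that writes gradients in place only at valid (non-padding) token positions, and `GARBAGE` stands in for whatever an uninitialized `empty_like`-style buffer happens to contain.

```python
GARBAGE = float("nan")  # stand-in for the contents of uninitialized memory


def toy_backward(grad_buffer, valid_positions, grad_value=1.0):
    """Write gradients in place, touching only valid (non-padding) positions."""
    for pos in valid_positions:
        grad_buffer[pos] = grad_value
    return grad_buffer


total_tokens = 6
valid_positions = [0, 1, 2, 4]  # positions 3 and 5 are padding between sequences

# Pre-zeroed buffer: padding slots stay at 0.0 because the kernel never writes them.
dq_zeroed = toy_backward([0.0] * total_tokens, valid_positions)

# Uninitialized buffer: padding slots keep whatever was there before the call.
dq_uninit = toy_backward([GARBAGE] * total_tokens, valid_positions)
```

In this toy model the caller owns the buffer, which is the property the backward path relies on; a forward call that allocates its own output gives the caller no such hook.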
/te-ci pytorch L3
Add support for padding between sequences (pad_between_seqs) in the FlashAttention 3 backend when used with context parallelism (CP).
Key changes:
- backends.py: Pass fa_pad_between_seqs through to FA3 forward/backward
- context_parallel.py: Handle pad_between_seqs in A2A and P2P CP paths, zero FA3 padding garbage in CP forward, fix a2a backward alignment
- dot_product_attention.py: Auto-detect pad_between_seqs from cu_seqlens
- utils.py: Gate FA3 deterministic backward for hdim>=256, fix flash_attn_supported override for cross-attention and large head_dim, disable UnfusedDotProductAttention for pad_between_seqs, add SM100+ FA3 skip
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Add test parametrization for pad_between_seqs in flash attention tests. Update run_attention_with_cp.py to support the new parameter and fix batch boundary alignment in the non-CP FA3 path. Run tests in parallel when multiple GPUs are available. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Add deterministic CP test runs to L3 FA versions test. Support TE_PATH positional arg and fix GPU threshold for parallel test execution. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…raint The previous check disabled FA3 for deterministic mode whenever head_dim_qk > 128, which was overly conservative — FA3 forward supports deterministic execution at any head dim. The actual constraint from flash_api.cpp is that the backward pass does not support deterministic mode when max(head_size, head_size_v) >= 256. Narrow the gate to only disable FA3 during training (backward) and raise the threshold to >= 256, checking both head_dim_qk and head_dim_v to handle MLA configs with asymmetric head dimensions. Ref: https://github.com/Dao-AILab/flash-attention/blob/ac6f2eb5/hopper/flash_api.cpp#L1370 Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
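The narrowed gate described in this commit message can be sketched as a small predicate. This is a minimal illustration under stated assumptions: the function and constant names are hypothetical and do not come from utils.py; only the rule itself (backward-only, threshold at 256, checking both head dimensions) is taken from the message above.

```python
FA3_DET_BWD_MAX_HEAD_DIM = 256  # hypothetical name; backward determinism is unsupported at or above this


def fa3_allowed_in_deterministic_mode(head_dim_qk, head_dim_v, is_training):
    """Return True if FA3 may be used when deterministic execution is requested."""
    if not is_training:
        # Forward-only runs are deterministic at any head dimension.
        return True
    # The backward constraint looks at the larger of the two head dimensions,
    # which matters for MLA configs where head_dim_qk != head_dim_v.
    return max(head_dim_qk, head_dim_v) < FA3_DET_BWD_MAX_HEAD_DIM
```

Under the older rule (disable whenever head_dim_qk > 128), a training config with head dimension 192 would have lost FA3; under this sketch it keeps it.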
9c01601 to 4745f98
The pad_between_seqs gate in get_attention_backend only disabled FlashAttention 2, letting FA4 leak through to the test-time fused-vs-flash comparison. On B200 runners that install flash-attn-4, this caused test_dpa_qkv_layout_thd to compare FusedAttention against an FA4 output whose padded positions contain garbage, producing 48 numerics failures in L3_pytorch_FA_versions_test--B200_1GPU. The log message already claimed FA4 would be disabled — this change makes the code match the message: set use_flash_attention_4 = False alongside use_flash_attention_2 when pad_between_seqs is True. FA3 continues to support pad_between_seqs via seqused_q/seqused_k. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
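As a rough illustration of the fix, the gate can be modeled as a function over a dictionary of backend flags. This is a sketch, not the body of get_attention_backend: `apply_pad_between_seqs_gate` is a hypothetical helper, and the `use_flash_attention_3` key is assumed by analogy with the two flag names quoted in the commit message.

```python
def apply_pad_between_seqs_gate(backends, pad_between_seqs):
    """Return a copy of the backend flags with unsupported FlashAttention versions disabled."""
    gated = dict(backends)
    if pad_between_seqs:
        # Neither FA2 nor FA4 is supported with pad_between_seqs here,
        # so both are disabled together.
        gated["use_flash_attention_2"] = False
        gated["use_flash_attention_4"] = False
        # FA3 is left untouched: it receives actual lengths via seqused_q/seqused_k.
    return gated
```

Returning a copy rather than mutating the input keeps the original flags available for logging which backends were ruled out.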
pad_between_seqs support for non-CP and CP (A2A and P2P) with FA3 + THD (varlen)
…_attn_pad_bw_seqs
…_attn_pad_bw_seqs
Each (world_size) is served by one long-lived torchrun running run_attention_with_cp_pool.py. Tests submit work over rank-0 stdin as JSON and read results from rank-0 stdout, replacing the per-test torchrun launch path. NCCL init/destroy happens once per pool, not once per case, eliminating ~9s overhead per test and fixing L3 timeouts.
Why two pool sizes: cp_comm_type="a2a+p2p" needs world_size=4; everything else uses world_size=2. We can't resize an active PG, so one pool per world_size, routed by num_gpus. Pools spawn lazily on first use so a session that only exercises 2-GPU cases never pays the 4-GPU init cost.
Includes:
- PoolWorker class with sentinel-prefixed JSON protocol over rank-0 stdio (sentinel filters out torchrun status / library prints that share the stdout fd)
- Stderr ring buffer (200 lines / ~4 KB tail) attached to crash-path AssertionErrors so CI JUnit XML shows the real failure cause
- POOL_SUBMIT_TIMEOUT_SEC defaulting to 90 s (~6x p50 case wall on H100); override via NVTE_CP_POOL_TIMEOUT_SEC
- Stream race fix on max_logit_per_step in all-gather CP forward: wait_stream(flash_attn_streams[i-1]) before torch.maximum, so the read on the default stream doesn't race with the write on cp_stream in iteration i=2. The pool's persistent process exposed this latent race; per-process subprocess design happened to schedule it safely.
- Deep-copy of model_configs_flash_attn[model] to prevent in-place attn_mask_type mutation from leaking across pool cases
- Deterministic-mode skips for FusedAttention configs that OOM on sm90 under NVTE_ALLOW_NONDETERMINISTIC_ALGO=0
Preserves PR NVIDIA#2596 pad_between_seqs additions (fa_pad_between_seqs parameter through generate_input_shapes and run_dpa_with_cp, THD padding cleanup for FA3 tile-spillover comparison).
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
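The two stdio mechanisms named in this commit message, a sentinel-prefixed JSON protocol and a bounded stderr tail, can be sketched with the standard library alone. Everything below is an assumption made for illustration: the `SENTINEL` string, the helper names, and the payload fields are hypothetical, and the real PoolWorker class is not reproduced here.

```python
import json
from collections import deque

SENTINEL = "@@CP_POOL_RESULT@@ "  # hypothetical marker; the real prefix is not shown in the PR text


def encode_result(payload):
    """Format one result line the way a worker might print it on stdout."""
    return SENTINEL + json.dumps(payload)


def extract_results(stdout_lines):
    """Keep only sentinel-prefixed lines and decode their JSON payloads."""
    results = []
    for line in stdout_lines:
        if line.startswith(SENTINEL):
            results.append(json.loads(line[len(SENTINEL):]))
    return results


# A bounded deque acts as a ring buffer: once full, the oldest line is dropped.
stderr_tail = deque(maxlen=200)
for i in range(500):
    stderr_tail.append(f"stderr line {i}")

mixed_stdout = [
    "launcher status line",  # made-up noise sharing the stdout stream
    encode_result({"case": "cp_1_0", "status": "passed"}),
    "some library banner",
    encode_result({"case": "cp_2_1", "status": "failed"}),
]
decoded = extract_results(mixed_stdout)
```

The prefix check is what lets the test side ignore unrelated prints on the shared stream, and the `maxlen` bound keeps the crash report small no matter how long the pool has been running.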
bcb717e to bce64ff
Re-applying the formatting fixes that pre-commit.ci posted on PR NVIDIA#2596 after the previous push (commit bcb717e, overwritten by the cleanup rebase). for more information, see https://pre-commit.ci Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
/te-ci pytorch L3
…to flash_attn_pad_bw_seqs Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com> # Conflicts: # tests/pytorch/attention/test_attention_with_cp.py
287403f to d8e8ba4
/te-ci pytorch L3
Address two pending review comments:
1. The "auto-set when RUN_L3_TESTS=1" annotation on the base-image FA3 preinstall is no longer accurate; drop it so readers don't grep for a coupling that doesn't exist.
2. `flash_attn_interface` reads like a generic FA API even though the top-level shim is only created by the FA3 install. Switching to `import flash_attn_3` makes the FA3-specific intent unambiguous and matches the FA3 package layout produced by the source build.
Local validation on H100 (sm90) with FA3 active, TE worktree resolving to the editable install (verified via three-layer import check from /tmp):
test_attention_with_cp.py, parallel det+nondet: 45 passed / 0 failed nondet (3:52), 33 passed / 0 failed det (2:55). 33 pad-True nondet passes + 21 pad-True det passes confirm the FA3+THD+CP path is exercised; 5 det OOM cases skip cleanly via the existing inline guard.
Same test scope is exercised by L1_pytorch_distributed_unittest (parallel det+nondet) and the FA3 iteration of L3_pytorch_FA_versions_test; the changes here are L3-only documentation/detection tweaks and do not alter the Python test code, but the L1+L3 CP execution was re-run on the cleaned PR head end-to-end as proof.
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
if qkv_format == "thd" and config.num_heads >= 20 and get_device_compute_capability() == (9, 0):
    pytest.skip(
        "Deterministic FusedAttention backward with THD format OOMs on sm90"
        " for this particular test config since cuDNN reserves memory"
        " proportional to bHSS (known cuDNN issue)."
    )
The motivation for this makes sense to me, but the way we are skipping the test seems to view it through a slightly narrow lens. What I mean is that the main issue is total memory (bHSS), but we seem to be guarding on the number of heads only.
This skip guard would not be correct if tomorrow someone were to add a test with small b, S and H >= 20 (IIUC). It almost makes it seem that the issue is num_heads rather than the total memory.
Is there a better way to do this?
Good catch — gated on the actual b*H*S*S product instead of num_heads in d3bd4e4. Threshold of 1e9 empirically matches the existing 5-case skip set on the test_essential fused subset (cp_2_0, cp_2_2, cp_3_1, cp_4_2, cp_4_3 — bHSS 1.07B–4.29B) and lets the smaller configs (cp_1_0/cp_2_1/cp_2_4/cp_3_2/cp_3_4, all ~0.40B) keep running. Local det+nondet still 33/0 + 45/0 with 5 OOM skips fired by the new gate.
1. Det FusedAttention backward THD/sm90 OOM skip: gate on the actual memory pressure (b*H*S*S) instead of num_heads >= 20. The cuDNN workspace is proportional to bHSS, so a future config with H >= 20 but small b or S would be needlessly skipped under the old guard, while a config with H < 20 but large b*S that hit the same OOM wouldn't be caught. Threshold 1e9 empirically matches the existing 5-case skip set on the test_essential fused subset (cp_2_0, cp_2_2, cp_3_1, cp_4_2, cp_4_3; bHSS in 1.07B–4.29B) and lets cp_1_0/cp_2_1/cp_2_4/cp_3_2/cp_3_4 (bHSS ~0.40B) keep running.
2. L3 FA3 install comment: drop the "Dockerfile.base INSTALL_FA3=1" reference. The detection check is the contract; mentioning a specific image variable couples this script to an out-of-tree provisioning detail that may evolve independently.
Local validation on H100 (sm90) with FA3 active and TE worktree resolving to editable (verified via /tmp-cwd three-layer import check after reinstall; the /usr/local TE shadow had reappeared between sessions):
test_attention_with_cp.py, parallel det+nondet: 45 passed / 0 failed nondet (4:09), 33 passed / 0 failed det (3:14). 33 pad-True nondet passes + 21 pad-True det passes; 5 det OOM cases skip via the new bHSS gate, the same cases as the old num_heads-only gate.
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
/te-ci pytorch L3
…ation
Address review nits on the deterministic THD-backward OOM guard:
1. Replace the magic number 1_000_000_000 with the named constant SM90_DET_FUSED_THD_BWD_MAX_BHSS = 1 << 30, so the value is searchable and labeled.
2. Replace the prefatory comment with a short note tying the number to cuDNN's actual workspace request (~128 * bHSS bytes, measured on cuDNN 9.21.0 sm90; see local sweep). At bHSS = 1<<30 the request is 128 GiB, which doesn't fit on H100's 80 GB.
3. Flag the b>=3 caveat for future readers: cuDNN rounds the batch up internally so workspace grows super-linearly past b=2 (b=4 asks for 4x the b=2 workspace, not 2x). The current fused-essential matrix is all b=2, so the threshold stays correct for what the test exercises; the note is there so the next person doesn't have to rediscover it.
Skip set is unchanged: cp_2_0, cp_2_2, cp_3_1, cp_4_2, cp_4_3.
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
We measured the workspace request from outside cuDNN, so the comment should say "observed" rather than asserting what cuDNN does. Reframes the ~128 * bHSS bytes formula and the super-linear b>=3 behavior as empirical observations from our sweep. No code change. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
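The arithmetic behind the threshold can be checked with a short sketch. The constant name and the ~128 bytes per bHSS factor are taken from the commit messages above; the helper functions, the `>=` comparison, and the example shapes are assumptions made for illustration and are not copied from the test file. The estimate also ignores the observed super-linear growth for b >= 3.

```python
SM90_DET_FUSED_THD_BWD_MAX_BHSS = 1 << 30  # named threshold from the commit message
OBSERVED_BYTES_PER_BHSS = 128  # empirical factor quoted in the commit message


def bhss(batch, num_heads, seq_len_q, seq_len_kv):
    """Product b * H * S_q * S_kv that the observed workspace scales with."""
    return batch * num_heads * seq_len_q * seq_len_kv


def estimated_workspace_gib(batch, num_heads, seq_len_q, seq_len_kv):
    """Rough workspace estimate in GiB from the observed bytes-per-bHSS factor."""
    return OBSERVED_BYTES_PER_BHSS * bhss(batch, num_heads, seq_len_q, seq_len_kv) / (1 << 30)


def should_skip(batch, num_heads, seq_len_q, seq_len_kv):
    """Assumed comparison: skip once bHSS reaches the threshold."""
    return bhss(batch, num_heads, seq_len_q, seq_len_kv) >= SM90_DET_FUSED_THD_BWD_MAX_BHSS


# Hypothetical shapes chosen to land on either side of the threshold.
large = (2, 32, 4096, 4096)  # bHSS = 2**30, estimated 128 GiB
small = (2, 12, 4096, 4096)  # bHSS is about 0.40e9, estimated 48 GiB
```

Gating on the product rather than on any single dimension is what lets a config with many heads but short sequences keep running.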
Description
TLDR
Enable pad_between_seqs=True for FlashAttention 3 with THD format, both for context parallelism (A2A and P2P comm types) and for non-CP paths. Previously pad_between_seqs was only supported with FusedAttention.
Problem
When using THD format with variable-length sequences, sequences are padded for divisibility across CP ranks. With pad_between_seqs=True, the attention kernel needs to know the actual (unpadded) token counts so it doesn't compute attention over padding tokens. FusedAttention already handled this via cu_seqlens_q_padded, but FlashAttention (both FA2 and FA3) had pad_between_seqs hardcoded to False in the CP path, and FA2 was entirely disabled for pad_between_seqs + thd. FA3 can natively handle this via its seqused_q/seqused_k mechanism.
Solution
Use FA3's seqused_q/seqused_k tensors to communicate actual token counts per batch element. Pass cu_seqlens_q_padded for tensor memory layout while deriving seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] from the real cu_seqlens. This applies to both the CP path (A2A and P2P) and the non-CP path.
Fixes #2399
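The relationship between the two cumulative-length arrays described in the Solution can be sketched with plain Python lists in place of tensors. The helper names and the example batch are hypothetical; only the formula seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] and the idea of detecting padding by comparing the two arrays come from the PR description.

```python
def seqused_from_cu_seqlens(cu_seqlens):
    """Per-sequence token counts: cu_seqlens[1:] - cu_seqlens[:-1]."""
    return [end - start for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])]


def has_padding_between_seqs(cu_seqlens, cu_seqlens_padded):
    """Padding is present when the padded offsets differ from the actual ones."""
    return list(cu_seqlens) != list(cu_seqlens_padded)


# Hypothetical batch of three sequences with 5, 3, and 6 real tokens,
# each padded up to a multiple of 4 in the packed THD buffer.
cu_seqlens_q = [0, 5, 8, 14]          # cumulative actual lengths
cu_seqlens_q_padded = [0, 8, 12, 20]  # cumulative padded lengths (memory layout)

seqused_q = seqused_from_cu_seqlens(cu_seqlens_q)

# Sequence i occupies rows [start, start + seqused_q[i]) of the packed buffer;
# the remaining rows up to the next padded offset are padding.
valid_ranges = [
    (start, start + used)
    for start, used in zip(cu_seqlens_q_padded[:-1], seqused_q)
]
```

The padded offsets say where each sequence starts in memory, and the used counts say how far to read from that start, which is why both arrays are needed.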
Type of change
Changes
Please list the changes introduced in this PR:
context_parallel.py
- get_fa_args(): Add seqused_q/seqused_k parameters, pass through to FA3 forward and backward positional arg lists (replacing hardcoded Nones).
- cp_p2p_fwd_flash_attn()/cp_p2p_bwd_flash_attn(): Accept pad_between_seqs, cu_seqlens_q_padded, cu_seqlens_kv_padded. When enabled, derive seqused tensors and override cu_seqlens to padded versions (with half-padding for lower-triangle/upper-triangle sections).
- AttnFuncWithCPAndKVP2P: Thread pad_between_seqs and padded cu_seqlens through all forward/backward cp_p2p_fwd/bwd_flash_attn call sites. Save ctx.pad_between_seqs for backward.
- AttnFuncWithCPAndQKVOA2A.forward(): Add pad_between_seqs parameter. When enabled with FA3+THD, derive seqused and swap cu_seqlens for padded versions before calling get_fa_args().
- AttnFuncWithCPAndQKVOA2A.backward(): Same seqused/cu_seqlens override. Use zeros_like (not empty_like) for gradient init when pad_between_seqs since FA3 skips padding positions. Add extra None in return tuple for the new pad_between_seqs gradient slot.
- attn_forward_func_with_cp(): Pass pad_between_seqs in A2A args list.
backends.py
- FlashAttention.forward(): Accept cu_seqlens_q_padded/cu_seqlens_kv_padded. Detect pad_between_seqs by comparing padded vs actual cu_seqlens. Pass padded cu_seqlens to CP path. For non-CP FA3 path, derive and pass seqused_q/seqused_k.
dot_product_attention.py
- Pass cu_seqlens_q_padded/cu_seqlens_kv_padded through to FlashAttention.
utils.py
- pad_between_seqs + thd. FA3 handles this natively via seqused.
test_attention_with_cp.py
- Add @pytest.mark.parametrize("pad_between_seqs", [False, True]) to flash attention CP tests.
- Skip pad_between_seqs=True for non-THD formats, when FA3 is not installed, and for a2a+p2p comm type (not yet supported).
run_attention_with_cp.py
- Pass pad_between_seqs through generate_input_shapes() and run_dpa_with_cp().
- When pad_between_seqs, set cu_seqlens_q to actual lengths (not just for FusedAttention).
- nan_to_num(nan=0.0).
test_attention.py
- _run_dot_product_attention() (previously FlashAttention used original unpadded inputs).
- Pass cu_seqlens_q_padded/cu_seqlens_kv_padded and pad_between_seqs to DPA call for FlashAttention backend.
- Add pad_between_seqs=True to parametrize with skip for non-THD formats.
New Tests
CP tests (test_attention_with_cp.py)
Added @pytest.mark.parametrize("pad_between_seqs", [False, True]) to test_cp_with_flash_attention. Skip conditions: non-THD formats, FA3 not installed, a2a+p2p comm type.
5 new tests that run (all pad_between_seqs=True, thd, bf16):
- True-p2p-thd-cp_1_0-bf16
- True-p2p-thd-cp_2_1-bf16
- True-a2a-thd-cp_1_0-bf16
- True-a2a-thd-cp_1_2-bf16
- True-a2a-thd-cp_2_1-bf16
Non-CP tests (test_attention.py)
Added True to @pytest.mark.parametrize("pad_between_seqs", [False, True]) on test_dot_product_attention, with skip for non-THD. Also changed _run_dot_product_attention so FlashAttention uses padded inputs/cu_seqlens and receives pad_between_seqs=True.
48 new test IDs collected, but all are skipped because the main parametrize uses qkv_layout=None (defaults to sbhd, not thd). The non-CP pad_between_seqs + FA3 code path is exercised indirectly when other test functions call test_dot_product_attention with qkv_layout="thd_thd_thd" (e.g., test_dpa_softmax_thd).
Checklist: