[None][feat] Add support for bidirectional sliding window attention mask to fmha_v2 #11212
Conversation
📝 Walkthrough

This change introduces bidirectional sliding window attention support to the FMHA v2 kernel stack by adding a new attention mask type (`BIDIRECTIONAL_SLIDING_WINDOW`) that propagates through kernel traits, mask implementations, and computation stages. The feature enables symmetric windowing around query positions, complementing the existing unidirectional and causal masking modes. Associated test cases validate the new functionality with configurable sliding window sizes and mask types.
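As a quick illustration of the new mask's semantics (based on the `is_valid` predicate proposed for `fmha/mask.h` in the review comments below), here is a Python sketch of the bidirectional window check; the function name is hypothetical:

```python
def is_valid_bidirectional(row: int, col: int, seqlen: int, window: int) -> bool:
    """Hypothetical reference check: KV position `col` is visible from query
    `row` iff it lies within +/- window//2 of the row, clipped to the sequence."""
    half = window // 2  # integer division, matching sliding_window_size_ / 2 in the kernel
    return max(0, row - half) <= col <= min(seqlen - 1, row + half)

# For row 5 with window 4 (half-width 2), the visible columns are 3..7.
```

This mirrors the symmetric windowing described in the walkthrough; unlike the causal modes, positions to the right of the query row remain visible up to the window edge.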
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed (warnings)
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (11)
cpp/kernels/fmha_v2/src/fused_multihead_attention.h (1)
1-3: ⚠️ Potential issue | 🟡 Minor

Update the NVIDIA copyright year to 2026.

This file is modified in 2026, so the SPDX header should reflect the latest modification year.

🗓️ Suggested update

```diff
- * SPDX-FileCopyrightText: Copyright (c) 2011-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2011-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```

As per coding guidelines, "All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification".
cpp/kernels/fmha_v2/src/fmha/hopper/kernel_traits.h (1)
1-3: ⚠️ Potential issue | 🟡 Minor

Update the NVIDIA copyright year to 2026.

This file now has 2026 changes and the SPDX header should reflect that.

🗓️ Suggested update

```diff
- * SPDX-FileCopyrightText: Copyright (c) 2011-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2011-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```

As per coding guidelines, "All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification".
cpp/kernels/fmha_v2/fmha_test.py (1)
1-1: ⚠️ Potential issue | 🟡 Minor

Add the NVIDIA copyright header for this Python source.
This file is modified in 2026 and should include the SPDX header at the top.
As per coding guidelines, "All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification".
cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp (1)
1-3: ⚠️ Potential issue | 🟡 Minor

Update the NVIDIA copyright year to 2026.

This file now includes 2026 changes and the SPDX year should match.

🗓️ Suggested update

```diff
- * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```

As per coding guidelines, "All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification".
cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h (1)
1-3: ⚠️ Potential issue | 🟡 Minor

Update the NVIDIA copyright year to 2026.

This file was modified in 2026 and the SPDX header should reflect that.

🗓️ Suggested update

```diff
- * SPDX-FileCopyrightText: Copyright (c) 2011-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2011-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```

As per coding guidelines, "All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification".
cpp/kernels/fmha_v2/src/fused_multihead_flash_attention_kernel_noloop.h (1)
1-3: ⚠️ Potential issue | 🟡 Minor

Update the NVIDIA copyright year to 2026.

This file includes 2026 modifications and the SPDX year should be updated.

🗓️ Suggested update

```diff
- * SPDX-FileCopyrightText: Copyright (c) 2011-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2011-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```

As per coding guidelines, "All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification".
cpp/kernels/fmha_v2/src/fmha/warpspec/compute.h (1)
1-3: ⚠️ Potential issue | 🟡 Minor

Update the NVIDIA copyright year to 2026.

This file now has 2026 changes and the SPDX year should be updated accordingly.

🗓️ Suggested update

```diff
- * SPDX-FileCopyrightText: Copyright (c) 2011-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2011-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```

As per coding guidelines, "All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification".
cpp/kernels/fmha_v2/setup.py (3)
1-2: ⚠️ Potential issue | 🟡 Minor

Update SPDX copyright year to 2026.
This file has meaningful 2026 changes but the header still ends at 2025.
As per coding guidelines, all TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification.
3079-3109: ⚠️ Potential issue | 🟠 Major

Fix kernel-traits printing for the new bidirectional mask.

`selected_mask_types` now returns 5 values, but this block still treats index 3 as `custom_mask` and never emits bidirectional kernel specs. That will drop custom/bidirectional variants from the `print_kernel_traits` output and the generated metadata. Update the tuple unpacking and add bidirectional snippet variants (mirror `_sliding_or_chunked_causal`).

💡 Suggested fix (tuple unpacking)

```diff
-        padding_mask = int(selected_types[0])
-        causal_mask = int(selected_types[1])
-        sliding_or_chunked_causal_mask = int(selected_types[2])
-        custom_mask = int(selected_types[3])
+        (padding_mask, causal_mask, sliding_or_chunked_causal_mask,
+         bidirectional_sliding_window_mask, custom_mask) = map(int, selected_types)
```
3205-3217: ⚠️ Potential issue | 🟠 Major

Add bidirectional mask handling in cubin metadata + launcher parsing.

`get_cubin_header` doesn't strip `_bidirectional_sliding_window` from names or map it to `AttentionMaskType`, which will mislabel these kernels as padding and generate non-existent launcher names. Extend the normalization and mask mapping for the new suffix.

💡 Suggested fix (normalization + mapping)

```diff
     tname = (kname.replace('flash_attention_', '').replace(
         '_scale_max', '').replace('_nl', '').replace('_tiled', '').replace(
             'tma_', '').replace('ldgsts_', '').replace('causal_', '').replace(
                 'alibi_', '').replace('softmax_', '').replace(
                     'sliding_or_chunked_', '').replace(
+                        'bidirectional_sliding_window_', '').replace(
                         'custom_mask_', '').replace('qkv_', '').replace(
                             'q_kv_', '').replace('q_paged_kv_', '').replace(
                                 'q_k_v_', '').replace('ws_', '').replace(
                                     'softcapping_', '').replace('sage_', '').replace(
                                         'skipSoftmax_', '').replace('output_', ''))
@@
-    # padding (0), causal_mask (1), sliding_or_chunked_causal_mask (2), custom_mask (3).
+    # padding (0), causal (1), sliding_or_chunked (2), bidirectional (3), custom (4).
     if '_custom_mask' in kname:
         attention_mask_type = AttentionMaskType.CUSTOM_MASK
+    elif '_bidirectional_sliding_window' in kname:
+        attention_mask_type = AttentionMaskType.BIDIRECTIONAL_SLIDING_WINDOW
     elif '_sliding_or_chunked_causal' in kname:
         attention_mask_type = AttentionMaskType.SLIDING_OR_CHUNKED_CAUSAL
@@
-    mask_types = [
-        '_sliding_or_chunked_causal', '_custom_mask', '_causal'
-    ]
+    mask_types = [
+        '_sliding_or_chunked_causal',
+        '_bidirectional_sliding_window',
+        '_custom_mask',
+        '_causal',
+    ]
```

Also applies to: 3284-3292, 3346-3351
cpp/kernels/fmha_v2/src/fused_multihead_flash_attention_kernel_noloop_tiled.h (1)
1-3: ⚠️ Potential issue | 🟡 Minor

Update SPDX copyright year to 2026.
This header now has meaningful 2026 changes.
As per coding guidelines, all TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification.
🤖 Fix all issues with AI agents
In `@cpp/kernels/fmha_v2/fmha_test.py`:
- Around line 288-298: Replace the two subprocess.run invocations in
fmha_test.py so they pass an argv list instead of a single shell string (remove
shell=True) and remove the duplicated "-bf16" flag from each argument list;
specifically locate the two subprocess.run(...) calls that invoke "bin/fmha.exe"
and convert their command strings into lists like ["bin/fmha.exe", "-d", "128",
..., "-sliding-window-size", str(sliding_window_size), mask_type] (do the same
for the "-d 64" call) and set check=True without shell=True; also add the
required NVIDIA copyright header with the latest modification year at the top of
the file.
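A sketch of what the suggested `subprocess` change could look like; the flag names and ordering here are illustrative and should be matched against the real invocations in `fmha_test.py`:

```python
import subprocess

def build_fmha_cmd(head_dim: int, sliding_window_size: int, mask_type: str) -> list:
    """Build the argv list; note "-bf16" appears exactly once (the duplicate is dropped)."""
    return [
        "bin/fmha.exe",
        "-d", str(head_dim),
        "-bf16",
        "-sliding-window-size", str(sliding_window_size),
        mask_type,
    ]

def run_fmha(head_dim: int, sliding_window_size: int, mask_type: str) -> None:
    # An argv list with check=True raises on failure and avoids shell re-parsing.
    subprocess.run(build_fmha_cmd(head_dim, sliding_window_size, mask_type), check=True)
```

Passing an argv list removes the quoting and injection hazards of `shell=True`, and `check=True` turns a nonzero exit code into a `CalledProcessError` instead of a silent failure.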
In `@cpp/kernels/fmha_v2/src/fused_multihead_flash_attention_kernel_noloop_tiled.h`:
- Around line 191-203: The bidirectional sliding-window boundary calculations
using Kernel_traits::BIDIRECTIONAL_SLIDING_WINDOW_ATTENTION (kv_loop_start,
sliding_window_mask_left, kv_loop_end, sliding_window_mask_right) are marked
with a TODO and need validation; update the math to correctly compute left/right
KV loop bounds using params.sliding_window_size, q_sequence_start and tile sizes
Cta_tile_p::N/Cta_tile_p::M, add focused unit/integration tests that cover odd
window sizes, near-start and near-end q_sequence_start, and multi-row tile
alignment to assert correct clamping and tile-aligned results, then remove the
TODO (or if you cannot fully validate now, replace the TODO with a short comment
referencing a newly created follow-up issue tracking the exact edge-case tests
and expected invariants).
In `@cpp/kernels/fmha_v2/src/fused_multihead_flash_attention_kernel_noloop.h`:
- Around line 174-211: The bidirectional sliding-window math in the mask
computation is incorrect/uncertain: reconcile the kv_loop_end and
sliding_window_mask_right formulas inside the mask_sliding_window branch (when
Kernel_traits::BIDIRECTIONAL_SLIDING_WINDOW_ATTENTION) with the simpler warpspec
implementations (compute.h, dma.h, epilogue.h). Replace the current tile-rounded
expressions that use fmha::div_up and integer /2 with logic that (1) computes
the exact window half-width handling both even and odd
params.sliding_window_size (use explicit floor/ceil semantics), (2) then rounds
the resulting start/end to Cta_tile_p::N tile boundaries consistently (use
div_up for end, floor for start), and (3) set sliding_window_mask_right and
kv_loop_end from those canonical window boundaries; update/tests for even and
odd sliding_window_size to cover off-by-one cases. Ensure references:
Kernel_traits, q_sequence_start, params.sliding_window_size, Cta_tile_p::N when
making changes.
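The boundary logic the comment asks for can be sketched as a host-side model. This is an assumption, not the kernel code: the helper names (`kv_loop_bounds`) are invented for illustration, and the scheme is exactly what the comment prescribes, namely exact window edges with explicit floor/ceil half-widths, then rounding to KV-tile boundaries:

```python
def div_up(a: int, b: int) -> int:
    """Ceiling division for non-negative a (mirrors fmha::div_up)."""
    return (a + b - 1) // b

def kv_loop_bounds(q_start: int, q_end: int, window: int, seqlen: int, step_kv: int):
    """Tile-aligned KV loop bounds for a window centered on rows [q_start, q_end).
    Floor half-width on the left, ceil half-width on the right, so odd window
    sizes are handled explicitly rather than by accidental truncation."""
    left = max(0, q_start - window // 2)              # exact left edge (floor)
    right = min(seqlen, q_end + (window + 1) // 2)    # exact right edge (ceil)
    kv_loop_start = (left // step_kv) * step_kv       # round start down to a tile
    kv_loop_end = div_up(right, step_kv) * step_kv    # round end up to a tile
    return kv_loop_start, kv_loop_end
```

For example, with `q_start=256`, `q_end=320`, `window=100`, and 64-wide KV tiles, the exact window is [206, 370) and the tile-aligned loop covers [192, 384).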
🧹 Nitpick comments (4)

cpp/kernels/fmha_v2/fmha_test.py (1)
cpp/kernels/fmha_v2/fmha_test.py (1)
285-287: Clarify bidirectional window-size semantics in the test.

Doubling `sliding_window_size` only for bidirectional masks implies a different parameter meaning across modes; consider normalizing this in `fmha.exe` or documenting why the tests must adjust it.

cpp/kernels/fmha_v2/src/fmha/warpspec/compute.h (1)
119-123: Clamp bidirectional mask boundaries before `div_up`.

In the bidirectional branch, `tile_offset_end - sliding_window_size / 2` can go negative; clamp to 0 before `div_up` to avoid a negative `kv_left_mask_end` and align with the `max(0, row - w/2)` logic in epilogue/DMA.

🔧 Suggested fix

```diff
- kv_left_mask_end = div_up(tile_offset_end - params.sliding_window_size / 2, STEP_KV);
+ kv_left_mask_end = div_up(
+     max(0, tile_offset_end - params.sliding_window_size / 2),
+     STEP_KV);
```

Also applies to: 297-317
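The hazard is easy to reproduce in a host-side model: with C-style truncating division, an unclamped operand drives `div_up` negative. The `c_div` helper below emulates C integer division (Python's `//` floors instead); names and values are illustrative:

```python
def c_div(a: int, b: int) -> int:
    """C-style integer division: truncation toward zero."""
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b >= 0) else -q

def div_up(a: int, b: int) -> int:
    return c_div(a + b - 1, b)

# A query tile near the sequence start: tile_offset_end - window/2 == -128, STEP_KV == 64.
unclamped = div_up(-128, 64)          # yields -1: a negative left-mask boundary
clamped = div_up(max(0, -128), 64)    # yields 0: correct after clamping
```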
cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h (1)
584-586: Update comment to match new mask type semantics.

The comment for the `Kernel_traits_Hopper_qgmma_e4m3_fp32` specialization still references the old mask type numbering ("sliding_window_causal (2)") rather than the updated semantics that include bidirectional_sliding_window_attention (3) and custom_mask (4).

📝 Suggested comment update

```diff
-    // The attention mask type: padding (0), causal (1), sliding_window_causal (2).
-    // See fused_multihead_attention_kernel.h for description.
+    // The attention mask type: padding (0), causal (1), sliding_or_chunked_attention (2),
+    // bidirectional_sliding_window_attention (3), custom_mask (4).
+    // See fused_multihead_attention_kernel.h for description.
```

cpp/kernels/fmha_v2/src/fmha/mask.h (1)
507-544: Consider documenting the window size semantics for bidirectional mode.

The bidirectional window uses `sliding_window_size_ / 2` on each side of the query position. With integer division, a `sliding_window_size_` of 4 gives a total window of 5 tokens (2 on each side plus the center), whereas the unidirectional V4 mask with the same size covers 4 tokens. This asymmetry may be intentional but could cause confusion.

📝 Suggested comment clarification

```diff
     // Is a given position valid?
     inline __device__ bool is_valid(int row, int col) const
     {
-        // Is it a valid position in the sequence, i.e. are we in the lower triangle?
+        // Is it a valid position in the sequence?
+        // The window is centered on row: [row - window/2, row + window/2].
+        // Note: with integer division, the total window size is 2*(window/2) + 1.
         return (col >= max(0, row - Base::sliding_window_size_ / 2))
             && (col <= min(seqlen_ - 1, row + Base::sliding_window_size_ / 2));
     }
```
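The token-count asymmetry can be made concrete with a small counting helper (hypothetical, mirroring the `is_valid` bounds above):

```python
def window_token_count(row: int, seqlen: int, window: int) -> int:
    """Number of visible KV positions for a window centered on `row`,
    clipped to the sequence bounds."""
    half = window // 2
    lo, hi = max(0, row - half), min(seqlen - 1, row + half)
    return hi - lo + 1

# Interior rows: window=4 covers 5 tokens (2 + center + 2); window=5 also covers 5.
# Near the sequence start the window is clipped.
```

So for even sizes the centered window covers `window + 1` tokens, one more than a unidirectional mask with the same parameter, which is exactly the confusion the suggested comment would prevent.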
Force-pushed from 37708ea to 65f4405
/bot run

PR_Github #34895 [ run ] triggered by Bot. Commit:

PR_Github #34895 [ run ] completed with state
Force-pushed from bd0df57 to 62b296a
/bot run

PR_Github #36064 [ run ] triggered by Bot. Commit:

PR_Github #36064 [ run ] completed with state
Force-pushed from c240ad2 to fc04591
/bot run

PR_Github #36437 [ run ] triggered by Bot. Commit:

PR_Github #36437 [ run ] completed with state
nvrohanv left a comment

Used these changes to generate cubins and they passed my integration tests for a TRT attention plugin. I tested padding, causal, causal sliding window, and bidirectional sliding window attention.
/bot run

PR_Github #36438 [ run ] triggered by Bot. Commit:

PR_Github #36438 [ run ] completed with state

/bot run

PR_Github #36451 [ run ] triggered by Bot. Commit:

PR_Github #36451 [ run ] completed with state
Force-pushed from bf39aaa to cf09dc6
/bot run

PR_Github #36581 [ run ] triggered by Bot. Commit:

PR_Github #36581 [ run ] completed with state
Force-pushed from cf09dc6 to 178e3bf
/bot run |
Force-pushed from 178e3bf to ae19659
/bot run

PR_Github #37259 [ run ] triggered by Bot. Commit:

PR_Github #37259 [ run ] completed with state
/bot run

PR_Github #37414 [ run ] triggered by Bot. Commit:

PR_Github #37414 [ run ] completed with state
…fmha_v2 Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com>
Force-pushed from ee4108c to 393b80b
/bot run

PR_Github #37766 [ run ] triggered by Bot. Commit:

PR_Github #37766 [ run ] completed with state
Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com>
/bot run --reuse-test

PR_Github #37773 [ run ] triggered by Bot. Commit:

PR_Github #37773 [ run ] completed with state
…ask to fmha_v2 (NVIDIA#11212) Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com>
Summary by CodeRabbit
New Features
Tests
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user-friendly way for developers to interact with a Jenkins server.
Run `/bot [-h|--help]` to print this help message. See details below for each supported subcommand.
Details
run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug (experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

- `--reuse-test (optional)pipeline-id` (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
- `--disable-reuse-test` (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests are run regardless of previous successes.
- `--disable-fail-fast` (OPTIONAL): Disable fail fast on build/tests/infra failures.
- `--skip-test` (OPTIONAL): Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
- `--stage-list "A10-PyTorch-1, xxx"` (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
- `--gpu-type "A30, H100_PCIe"` (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
- `--test-backend "pytorch, cpp"` (OPTIONAL): Skip test stages which don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with the tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
- `--only-multi-gpu-test` (OPTIONAL): Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
- `--disable-multi-gpu-test` (OPTIONAL): Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
- `--add-multi-gpu-test` (OPTIONAL): Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
- `--post-merge` (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- `--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx"` (OPTIONAL): Run the ordinary L0 pre-merge pipeline and the specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
- `--detailed-log` (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
- `--debug` (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the `stage-list` parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see `docs/source/reference/ci-overview.md` and the `scripts/test_to_stage_mapping.py` helper.

kill
`kill`: Kill all running builds associated with the pull request.
skip
`skip --comment COMMENT`: Skip testing for the latest commit on the pull request.
`--comment "Reason for skipping build/test"` is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline
`reuse-pipeline`: Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.