[None][feat] Add support for bidirectional sliding window attention mask to fmha_v2 #11212

Merged
djns99 merged 13 commits into NVIDIA:main from djns99:dastokes/fmha_v2_bidirection_sliding_window_attention
Mar 6, 2026
Conversation

@djns99
Collaborator

@djns99 djns99 commented Feb 3, 2026

Summary by CodeRabbit

  • New Features

    • Added support for bidirectional sliding window attention, extending available masking options for attention computation with configurable window sizes and mask types.
  • Tests

    • Introduced new parametrized test suite for bidirectional sliding window attention with variable window sizes and mask configurations across multiple hardware setups.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update the tava architecture diagram if there is a significant design change in the PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is given. If the Git commit ID has changed, this option is always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous, since lack of user care and validation can cause the top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous, since lack of user care and validation can cause the top of tree to break.

@djns99 djns99 requested review from PerkzZheng and nvrohanv February 3, 2026 04:16
@coderabbitai
Contributor

coderabbitai bot commented Feb 3, 2026

📝 Walkthrough

This change introduces bidirectional sliding window attention support to the FMHA v2 kernel stack by adding a new attention mask type (BIDIRECTIONAL_SLIDING_WINDOW) that propagates through kernel traits, mask implementations, and computation stages. The feature enables symmetric windowing around query positions, complementing existing unidirectional and causal masking modes. Associated test cases validate the new functionality with configurable sliding window sizes and mask types.
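As a point of reference, the symmetric windowing described above can be sketched in plain Python. This is an illustrative model of the mask semantics (a half-width of floor(window/2) on each side, matching the validation bounds quoted later in this summary), not the kernel code itself:

```python
def bidirectional_sliding_window_mask(seqlen, window):
    """Boolean mask[row][col]: query `row` attends to key `col` iff `col`
    lies within floor(window/2) positions of `row`, clamped to the sequence."""
    half = window // 2
    return [[max(0, r - half) <= c <= min(seqlen - 1, r + half)
             for c in range(seqlen)]
            for r in range(seqlen)]
```

For comparison, the existing causal sliding-window mode keeps only the left half of this window (keys with col <= row).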

Changes

Cohort / File(s) Summary
Test Suite
cpp/kernels/fmha_v2/fmha_test.py
New parametrized test test_trtllm_sliding_window_attention with sliding_window_size and mask_type dimensions; executes an fmha.exe subprocess with distinct head-dimension configs (d=128, d=64).
Setup & Code Generation
cpp/kernels/fmha_v2/setup.py
Added BIDIRECTIONAL_SLIDING_WINDOW = 3 to AttentionMaskType enum; shifted CUSTOM_MASK to value 4; introduced kernel trait variants (Kernel_traits_nl_bidirectional_sliding_window, Kernel_traits_nl_tiled_bidirectional_sliding_window); updated mask-type selection logic to return five values.
Kernel Traits (Core)
cpp/kernels/fmha_v2/src/fmha/kernel_traits.h
Added BIDIRECTIONAL_SLIDING_WINDOW_ATTENTION enum (MASK_VERSION == 5) across primary and variant Kernel_traits structs; shifted CUSTOM_MASK to MASK_VERSION == 6; refined CAUSAL_MASK to explicitly match versions 3 or 4.
Kernel Traits (Hopper)
cpp/kernels/fmha_v2/src/fmha/hopper/kernel_traits.h
Introduced BIDIRECTIONAL_SLIDING_WINDOW_ATTENTION flag (MASK_VERSION == 5); adjusted CAUSAL_MASK logic to match versions 3–4 only; updated documentation for extended MASK_VERSION range.
Kernel Traits (Warpspec)
cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h
Expanded mask-type handling: SLIDING_OR_CHUNKED_ATTENTION now includes type 3; added BIDIRECTIONAL_SLIDING_WINDOW_ATTENTION for type 3; USE_CUSTOM_MASK shifted to type 4; applied changes to both base and Hopper_qgmma_e4m3_fp32 specializations.
Mask Implementation
cpp/kernels/fmha_v2/src/fmha/mask.h
New template specializations: Mask<Traits, Cta_tile, 5> (bidirectional sliding window with seqlen_ member and dual is_valid overloads); moved prior V5 mask logic to V6; added Mask_hopper<Traits, Cta_tile, 5> variant.
Public API (Host)
cpp/kernels/fmha_v2/src/fused_multihead_attention.h
Added BIDIRECTIONAL_SLIDING_WINDOW enum value to Attention_mask_type; extended mask_type_to_string dispatch with new case.
Host Kernel Control
cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp
New command-line option -bidirectional-sliding-window-mask; default mask type set to SLIDING_OR_CHUNKED_CAUSAL when sliding_window_size is active without explicit mask type; constraint assertion preventing bidirectional mode with chunked attention; added bidirectional window bounds validation (si within [so − floor(window/2), so + floor(window/2)]).
Compute Logic
cpp/kernels/fmha_v2/src/fmha/warpspec/compute.h
Exposed BIDIRECTIONAL_SLIDING_WINDOW_ATTENTION enum; extended sliding/chunked masking branching: bidirectional path computes kv_left_mask_end from tile_offset_end minus half sliding window and updates kv_right_mask_start using window center.
DMA (Data Movement)
cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h
Exposed BIDIRECTIONAL_SLIDING_WINDOW_ATTENTION enum; enhanced compute_kv_tile_idx with a bidirectional branch: symmetric KV window computed as max(0, q_step_offset − window/2) to min(kv_steps*STEP_KV − 1, q_step_end + window/2).
Epilogue (Softmax/Masking)
cpp/kernels/fmha_v2/src/fmha/warpspec/epilogue.h
Added BIDIRECTIONAL_SLIDING_WINDOW_ATTENTION enum; broadened mask application logic to include bidirectional branch with assertion against chunked attention; computes sliding_window_end = min(seqlen−1, row + window/2) for constraint masking; adjusted quad_row_ initialization.
Kernel Loop Logic
cpp/kernels/fmha_v2/src/fused_multihead_flash_attention_kernel_noloop.h, cpp/kernels/fmha_v2/src/fused_multihead_flash_attention_kernel_noloop_tiled.h
Reworked sliding window masking to support bidirectional mode; introduced kv_loop_start, kv_loop_end, sliding_window_mask_left/right variables; bidirectional branch computes symmetric boundaries around query position; updated apply_sliding_window_mask to check range against [left, right] bounds; expanded CHECK_NEG_INF condition.
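To make the DMA/noloop bookkeeping above concrete, here is a hypothetical sketch of how a symmetric window maps to a range of KV tiles. The names (q_begin, step_kv, etc.) are illustrative and do not match the kernel's identifiers exactly:

```python
def kv_tile_range(q_begin, q_end, window, seqlen, step_kv):
    """First and last KV tile (of size step_kv) that can hold valid keys
    for query rows [q_begin, q_end] under a symmetric window."""
    half = window // 2
    kv_begin = max(0, q_begin - half)              # leftmost valid key
    kv_end = min(seqlen - 1, q_end + half)         # rightmost valid key
    return kv_begin // step_kv, kv_end // step_kv  # inclusive tile indices
```

Tiles strictly before the first or after the last index can be skipped entirely, which is the point of the compute_kv_tile_idx change summarized above.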

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks: 1 passed, 2 warnings

❌ Failed checks (2 warnings)

  • Docstring Coverage ⚠️ Warning — Docstring coverage is 25.93%, which is below the required threshold of 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
  • Description check ⚠️ Warning — The PR description is entirely template boilerplate with no substantive content explaining the feature, rationale, or implementation details. Provide a clear description of the bidirectional sliding window attention feature, explain the problem it solves, describe the implementation approach, list relevant test coverage, and ensure the PR title follows the format [ticket][type] Summary.

✅ Passed checks (1 passed)

  • Title check ✅ Passed — The title clearly and concisely describes the main feature being added: bidirectional sliding window attention mask support in fmha_v2, directly matching the substantial changes across test and kernel files.


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (11)
cpp/kernels/fmha_v2/src/fused_multihead_attention.h (1)

1-3: ⚠️ Potential issue | 🟡 Minor

Update the NVIDIA copyright year to 2026.

This file is modified in 2026, so the SPDX header should reflect the latest modification year.

🗓️ Suggested update
- * SPDX-FileCopyrightText: Copyright (c) 2011-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2011-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

As per coding guidelines, "All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification".

cpp/kernels/fmha_v2/src/fmha/hopper/kernel_traits.h (1)

1-3: ⚠️ Potential issue | 🟡 Minor

Update the NVIDIA copyright year to 2026.

This file now has 2026 changes and the SPDX header should reflect that.

🗓️ Suggested update
- * SPDX-FileCopyrightText: Copyright (c) 2011-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2011-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

As per coding guidelines, "All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification".

cpp/kernels/fmha_v2/fmha_test.py (1)

1-1: ⚠️ Potential issue | 🟡 Minor

Add the NVIDIA copyright header for this Python source.

This file is modified in 2026 and should include the SPDX header at the top.

As per coding guidelines, "All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification".

cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp (1)

1-3: ⚠️ Potential issue | 🟡 Minor

Update the NVIDIA copyright year to 2026.

This file now includes 2026 changes and the SPDX year should match.

🗓️ Suggested update
- * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

As per coding guidelines, "All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification".

cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h (1)

1-3: ⚠️ Potential issue | 🟡 Minor

Update the NVIDIA copyright year to 2026.

This file was modified in 2026 and the SPDX header should reflect that.

🗓️ Suggested update
- * SPDX-FileCopyrightText: Copyright (c) 2011-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2011-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

As per coding guidelines, "All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification".

cpp/kernels/fmha_v2/src/fused_multihead_flash_attention_kernel_noloop.h (1)

1-3: ⚠️ Potential issue | 🟡 Minor

Update the NVIDIA copyright year to 2026.

This file includes 2026 modifications and the SPDX year should be updated.

🗓️ Suggested update
- * SPDX-FileCopyrightText: Copyright (c) 2011-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2011-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

As per coding guidelines, "All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification".

cpp/kernels/fmha_v2/src/fmha/warpspec/compute.h (1)

1-3: ⚠️ Potential issue | 🟡 Minor

Update the NVIDIA copyright year to 2026.

This file now has 2026 changes and the SPDX year should be updated accordingly.

🗓️ Suggested update
- * SPDX-FileCopyrightText: Copyright (c) 2011-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2011-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

As per coding guidelines, "All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification".

cpp/kernels/fmha_v2/setup.py (3)

1-2: ⚠️ Potential issue | 🟡 Minor

Update SPDX copyright year to 2026.

This file has meaningful 2026 changes but the header still ends at 2025.

As per coding guidelines, all TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification.


3079-3109: ⚠️ Potential issue | 🟠 Major

Fix kernel-traits printing for the new bidirectional mask.

selected_mask_types now returns 5 values, but this block still treats index 3 as custom_mask and never emits bidirectional kernel specs. That will drop custom/bidirectional variants from print_kernel_traits output and the generated metadata. Update the tuple unpacking and add bidirectional snippet variants (mirror _sliding_or_chunked_causal).

💡 Suggested fix (tuple unpacking)
-        padding_mask = int(selected_types[0])
-        causal_mask = int(selected_types[1])
-        sliding_or_chunked_causal_mask = int(selected_types[2])
-        custom_mask = int(selected_types[3])
+        (padding_mask, causal_mask, sliding_or_chunked_causal_mask,
+         bidirectional_sliding_window_mask, custom_mask) = map(int, selected_types)

3205-3217: ⚠️ Potential issue | 🟠 Major

Add bidirectional mask handling in cubin metadata + launcher parsing.

get_cubin_header doesn’t strip _bidirectional_sliding_window from names or map it to AttentionMaskType, which will mislabel these kernels as padding and generate non-existent launcher names. Extend the normalization and mask mapping for the new suffix.

💡 Suggested fix (normalization + mapping)
-        tname = (kname.replace('flash_attention_', '').replace(
+        tname = (kname.replace('flash_attention_', '').replace(
             '_scale_max', '').replace('_nl', '').replace('_tiled', '').replace(
                 'tma_',
                 '').replace('ldgsts_', '').replace('causal_', '').replace(
-                    'alibi_', '').replace('softmax_', '').replace(
-                        'sliding_or_chunked_', '').replace(
+                    'alibi_', '').replace('softmax_', '').replace(
+                        'sliding_or_chunked_', '').replace(
+                            'bidirectional_sliding_window_', '').replace(
                             'custom_mask_', '').replace('qkv_', '').replace(
                                 'q_kv_', '').replace('q_paged_kv_', '').replace(
                                     'q_k_v_', '').replace('ws_', '').replace(
                                         'softcapping_',
                                         '').replace('sage_', '').replace(
                                             'skipSoftmax_',
                                             '').replace('output_', ''))
@@
-        # padding (0), causal_mask (1), sliding_or_chunked_causal_mask (2), custom_mask (3).
+        # padding (0), causal (1), sliding_or_chunked (2), bidirectional (3), custom (4).
         if '_custom_mask' in kname:
             attention_mask_type = AttentionMaskType.CUSTOM_MASK
+        elif '_bidirectional_sliding_window' in kname:
+            attention_mask_type = AttentionMaskType.BIDIRECTIONAL_SLIDING_WINDOW
         elif '_sliding_or_chunked_causal' in kname:
             attention_mask_type = AttentionMaskType.SLIDING_OR_CHUNKED_CAUSAL
@@
-                    mask_types = [
-                        '_sliding_or_chunked_causal', '_custom_mask', '_causal'
-                    ]
+                    mask_types = [
+                        '_sliding_or_chunked_causal',
+                        '_bidirectional_sliding_window',
+                        '_custom_mask',
+                        '_causal'
+                    ]

Also applies to: 3284-3292, 3346-3351

cpp/kernels/fmha_v2/src/fused_multihead_flash_attention_kernel_noloop_tiled.h (1)

1-3: ⚠️ Potential issue | 🟡 Minor

Update SPDX copyright year to 2026.

This header now has meaningful 2026 changes.

As per coding guidelines, all TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification.

🤖 Fix all issues with AI agents
In `@cpp/kernels/fmha_v2/fmha_test.py`:
- Around line 288-298: Replace the two subprocess.run invocations in
fmha_test.py so they pass an argv list instead of a single shell string (remove
shell=True) and remove the duplicated "-bf16" flag from each argument list;
specifically locate the two subprocess.run(...) calls that invoke "bin/fmha.exe"
and convert their command strings into lists like ["bin/fmha.exe", "-d", "128",
..., "-sliding-window-size", str(sliding_window_size), mask_type] (do the same
for the "-d 64" call) and set check=True without shell=True; also add the
required NVIDIA copyright header with the latest modification year at the top of
the file.
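For illustration, the argv-list form this comment asks for might look like the sketch below. The exact flag set used by fmha_test.py may differ, so treat the arguments here as placeholders:

```python
import subprocess

def build_fmha_cmd(d, sliding_window_size, mask_type):
    # List form avoids shell parsing entirely: no quoting issues, no shell=True.
    # mask_type would be e.g. "-bidirectional-sliding-window-mask".
    return ["bin/fmha.exe", "-d", str(d), "-bf16",
            "-sliding-window-size", str(sliding_window_size), mask_type]

# subprocess.run(build_fmha_cmd(128, 256, "-bidirectional-sliding-window-mask"),
#                check=True)  # check=True raises CalledProcessError on failure
```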

In
`@cpp/kernels/fmha_v2/src/fused_multihead_flash_attention_kernel_noloop_tiled.h`:
- Around line 191-203: The bidirectional sliding-window boundary calculations
using Kernel_traits::BIDIRECTIONAL_SLIDING_WINDOW_ATTENTION (kv_loop_start,
sliding_window_mask_left, kv_loop_end, sliding_window_mask_right) are marked
with a TODO and need validation; update the math to correctly compute left/right
KV loop bounds using params.sliding_window_size, q_sequence_start and tile sizes
Cta_tile_p::N/Cta_tile_p::M, add focused unit/integration tests that cover odd
window sizes, near-start and near-end q_sequence_start, and multi-row tile
alignment to assert correct clamping and tile-aligned results, then remove the
TODO (or if you cannot fully validate now, replace the TODO with a short comment
referencing a newly created follow-up issue tracking the exact edge-case tests
and expected invariants).

In `@cpp/kernels/fmha_v2/src/fused_multihead_flash_attention_kernel_noloop.h`:
- Around line 174-211: The bidirectional sliding-window math in the mask
computation is incorrect/uncertain: reconcile the kv_loop_end and
sliding_window_mask_right formulas inside the mask_sliding_window branch (when
Kernel_traits::BIDIRECTIONAL_SLIDING_WINDOW_ATTENTION) with the simpler warpspec
implementations (compute.h, dma.h, epilogue.h). Replace the current tile-rounded
expressions that use fmha::div_up and integer /2 with logic that (1) computes
the exact window half-width handling both even and odd
params.sliding_window_size (use explicit floor/ceil semantics), (2) then rounds
the resulting start/end to Cta_tile_p::N tile boundaries consistently (use
div_up for end, floor for start), and (3) set sliding_window_mask_right and
kv_loop_end from those canonical window boundaries; update/tests for even and
odd sliding_window_size to cover off-by-one cases. Ensure references:
Kernel_traits, q_sequence_start, params.sliding_window_size, Cta_tile_p::N when
making changes.
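The floor/ceil split this comment suggests can be sketched as follows. Which side should take the larger half for odd window sizes is left open by the review, so the convention below (floor on the left, ceil on the right) is only one possibility:

```python
def div_up(a, b):
    return (a + b - 1) // b

def tile_aligned_window(row_begin, row_end, window, seqlen, tile_n):
    """Exact window bounds for query rows [row_begin, row_end], then rounded
    to tile_n boundaries: floor for the start tile, div_up for the end."""
    left_half = window // 2           # floor
    right_half = window - left_half   # ceil, so left + right == window
    lo = max(0, row_begin - left_half)
    hi = min(seqlen - 1, row_end + right_half)
    return lo // tile_n, div_up(hi + 1, tile_n)  # [start_tile, end_tile)
```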
🧹 Nitpick comments (4)
cpp/kernels/fmha_v2/fmha_test.py (1)

285-287: Clarify bidirectional window-size semantics in the test.

Doubling sliding_window_size only for bidirectional masks implies different parameter meaning across modes; consider normalizing this in fmha.exe or documenting why tests must adjust it.

cpp/kernels/fmha_v2/src/fmha/warpspec/compute.h (1)

119-123: Clamp bidirectional mask boundaries before div_up.

In the bidirectional branch, tile_offset_end - sliding_window_size / 2 can go negative; clamp to 0 before div_up to avoid negative kv_left_mask_end and align with the max(0, row - w/2) logic in epilogue/DMA.

🔧 Suggested fix
-                kv_left_mask_end = div_up(tile_offset_end - params.sliding_window_size / 2, STEP_KV);
+                kv_left_mask_end = div_up(
+                    max(0, tile_offset_end - params.sliding_window_size / 2),
+                    STEP_KV);

Also applies to: 297-317
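A small sketch of the clamp-before-div_up fix; C++ integer division truncates toward zero, so without the clamp a negative numerator would silently yield a wrong tile index. The names mirror the diff above but are illustrative:

```python
def div_up(a, b):
    return (a + b - 1) // b

def kv_left_mask_end(tile_offset_end, sliding_window_size, step_kv):
    # Clamp the left window edge to 0 before rounding up to a tile index,
    # matching the max(0, row - window/2) logic in epilogue/DMA.
    return div_up(max(0, tile_offset_end - sliding_window_size // 2), step_kv)
```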

cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h (1)

584-586: Update comment to match new mask type semantics.

The comment for the Kernel_traits_Hopper_qgmma_e4m3_fp32 specialization still references the old mask type numbering ("sliding_window_causal (2)") rather than the updated semantics that includes bidirectional_sliding_window_attention (3) and custom_mask (4).

📝 Suggested comment update
-    // The attention mask type: padding (0), causal (1), sliding_window_causal (2).
-    // See fused_multihead_attention_kernel.h for description.
+    // The attention mask type: padding (0), causal (1), sliding_or_chunked_attention (2),
+    // bidirectional_sliding_window_attention (3), custom_mask (4). See fused_multihead_attention_kernel.h for
+    // description.
cpp/kernels/fmha_v2/src/fmha/mask.h (1)

507-544: Consider documenting the window size semantics for bidirectional mode.

The bidirectional window uses sliding_window_size_ / 2 on each side of the query position. This means a sliding_window_size_ of 4 gives a total window of 5 tokens (2 on each side plus the center), whereas the unidirectional V4 mask with the same size covers exactly 4 tokens. This asymmetry may be intentional but could cause confusion.

📝 Suggested comment clarification
     // Is a given position valid?
     inline __device__ bool is_valid(int row, int col) const
     {
-        // Is it a valid position in the sequence, i.e. are we in the lower triangle?
+        // Is it a valid position in the sequence?
+        // Window is centered on row: [row - window/2, row + window/2]
+        // Note: integer division means total window size is 2*(window/2)+1 for odd window sizes.
         return (col >= max(0, row - Base::sliding_window_size_ / 2))
             && (col <= min(seqlen_ - 1, row + Base::sliding_window_size_ / 2));
     }
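The off-by-one nuance this comment raises can be checked directly; the helper below just counts covered positions away from the sequence edges (an illustration, not kernel code):

```python
def bidirectional_span(window):
    # Tokens covered by a centered window, ignoring sequence-edge clamping:
    # floor(window/2) on each side plus the center position itself.
    half = window // 2
    return 2 * half + 1

# An even window of 4 covers 5 tokens, while the unidirectional mask with
# the same size covers exactly 4 -- the asymmetry discussed above.
```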

@djns99 djns99 changed the title [None][feat] Add support for bidirectional sliding window attention mask to … [None][feat] Add support for bidirectional sliding window attention mask to fmha_v2 Feb 5, 2026
@djns99 djns99 force-pushed the dastokes/fmha_v2_bidirection_sliding_window_attention branch from 37708ea to 65f4405 Compare February 5, 2026 03:20
@djns99
Collaborator Author

djns99 commented Feb 5, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #34895 [ run ] triggered by Bot. Commit: 65f4405

@tensorrt-cicd
Collaborator

PR_Github #34895 [ run ] completed with state SUCCESS. Commit: 65f4405
/LLM/main/L0_MergeRequest_PR pipeline #26917 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@djns99 djns99 force-pushed the dastokes/fmha_v2_bidirection_sliding_window_attention branch from bd0df57 to 62b296a Compare February 17, 2026 01:55
@djns99
Collaborator Author

djns99 commented Feb 17, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #36064 [ run ] triggered by Bot. Commit: bbc1111

@tensorrt-cicd
Collaborator

PR_Github #36064 [ run ] completed with state DISABLED
CI server is currently disabled for scheduled maintenance. Estimated completion time: 12 AM PST on 2/17.

@djns99 djns99 force-pushed the dastokes/fmha_v2_bidirection_sliding_window_attention branch from c240ad2 to fc04591 Compare February 22, 2026 21:57
@djns99
Collaborator Author

djns99 commented Feb 22, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #36437 [ run ] triggered by Bot. Commit: fc04591 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #36437 [ run ] completed with state DISABLED
CI server is currently disabled for scheduled maintenance. Estimated completion time: 6 PM PST on 2/22.

Link to invocation

Collaborator

@nvrohanv nvrohanv left a comment


Used these changes to generate cubins and they passed my integration tests for a TRT attention plugin. I tested padding, causal, causal sliding window, and bidirectional sliding window attention.

@djns99
Collaborator Author

djns99 commented Feb 22, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #36438 [ run ] triggered by Bot. Commit: fc04591 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #36438 [ run ] completed with state DISABLED
CI server is currently disabled for scheduled maintenance. Estimated completion time: 6 PM PST on 2/22.

Link to invocation

@djns99
Collaborator Author

djns99 commented Feb 23, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #36451 [ run ] triggered by Bot. Commit: fc04591 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #36451 [ run ] completed with state SUCCESS. Commit: fc04591
/LLM/main/L0_MergeRequest_PR pipeline #28198 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@djns99 djns99 force-pushed the dastokes/fmha_v2_bidirection_sliding_window_attention branch from bf39aaa to cf09dc6 Compare February 24, 2026 02:19
@djns99
Collaborator Author

djns99 commented Feb 24, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #36581 [ run ] triggered by Bot. Commit: cf09dc6 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #36581 [ run ] completed with state FAILURE. Commit: cf09dc6
/LLM/main/L0_MergeRequest_PR pipeline #28311 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@djns99 djns99 force-pushed the dastokes/fmha_v2_bidirection_sliding_window_attention branch from cf09dc6 to 178e3bf Compare February 24, 2026 23:53
@djns99
Collaborator Author

djns99 commented Feb 24, 2026

/bot run

@djns99 djns99 force-pushed the dastokes/fmha_v2_bidirection_sliding_window_attention branch from 178e3bf to ae19659 Compare March 2, 2026 02:42
@djns99
Collaborator Author

djns99 commented Mar 2, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #37259 [ run ] triggered by Bot. Commit: ae19659 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #37259 [ run ] completed with state SUCCESS. Commit: ae19659
/LLM/main/L0_MergeRequest_PR pipeline #28838 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@djns99
Collaborator Author

djns99 commented Mar 3, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #37414 [ run ] triggered by Bot. Commit: ae19659 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #37414 [ run ] completed with state SUCCESS. Commit: ae19659
/LLM/main/L0_MergeRequest_PR pipeline #28960 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

djns99 added 12 commits March 5, 2026 12:05
…fmha_v2

Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com>
Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com>
Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com>
Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com>
Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com>
Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com>
Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com>
Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com>
Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com>
Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com>
Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com>
Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com>
@djns99 djns99 force-pushed the dastokes/fmha_v2_bidirection_sliding_window_attention branch from ee4108c to 393b80b Compare March 4, 2026 23:06
@djns99
Collaborator Author

djns99 commented Mar 4, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #37766 [ run ] triggered by Bot. Commit: 393b80b Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #37766 [ run ] completed with state SUCCESS. Commit: 393b80b
/LLM/main/L0_MergeRequest_PR pipeline #29234 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com>
@djns99
Collaborator Author

djns99 commented Mar 5, 2026

/bot run --reuse-test

@tensorrt-cicd
Collaborator

PR_Github #37773 [ run ] triggered by Bot. Commit: 5085503 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #37773 [ run ] completed with state SUCCESS. Commit: 5085503
/LLM/main/L0_MergeRequest_PR pipeline #29240 completed with status: 'SUCCESS'

Link to invocation

@djns99 djns99 merged commit e699f23 into NVIDIA:main Mar 6, 2026
5 checks passed
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Mar 9, 2026
…ask to fmha_v2 (NVIDIA#11212)

Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com>


4 participants