
[None][feat] Support sparse mqa/gqa attention #12470

Merged
lfr-0531 merged 12 commits into NVIDIA:main from heyuhhh:user/yuhangh/add_sparse_attention_support
Apr 19, 2026


Conversation

Collaborator

@heyuhhh heyuhhh commented Mar 24, 2026

Summary by CodeRabbit

  • Refactor

    • Reorganized sparse attention parameter naming and configuration wiring across generation and context attention paths for improved clarity and consistency.
    • Updated sparse attention kernel dispatching logic to support both sparse MLA and sparse attention modes through unified parameter handling.
  • Tests

    • Added comprehensive test coverage for sparse attention functionality in both context and generation phases.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@heyuhhh heyuhhh requested review from a team as code owners March 24, 2026 01:34
Contributor

coderabbitai Bot commented Mar 24, 2026

📝 Walkthrough

Walkthrough

This PR refactors sparse attention parameter naming and control flow across C++ kernel dispatchers and Python attention backends. It separates paged-sparse-attention from sparse-MLA modes, renames MLA-specific fields to generic sparse-attention equivalents, and extends the Triton sparse-index conversion kernel to support per-KV-head and KV-factor dimensions.

Changes

Cohort / File(s) Summary
Sparse Attention Flag Management
cpp/tensorrt_llm/common/attentionOp.h, cpp/tensorrt_llm/common/attentionOp.cpp
Introduced new member flag mUseTllmGenSparseAttentionPaged and split useTllmGenSparseAttention() into a paged-only variant (useTllmGenSparseAttentionPaged()) and a new general-purpose variant supporting both sparse-MLA and sparse-attention paths. Updated plugin serialization data tuple.
FMHA/XQA Fixed Parameters
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h, cpp/tensorrt_llm/kernels/xqaDispatcher.h
Added useTllmGenSparseAttention boolean field to MHARunnerFixedParams and XqaFixedParams for enabling sparse attention in generation kernels.
XQA Kernel Parameters
cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h, cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h
Replaced use_sparse_attention with use_sparse_attention_gen_paged to denote paged generation sparse-attention mode; updated workspace allocation condition accordingly.
Sparse Attention Parameter Structures
cpp/tensorrt_llm/kernels/sparseAttentionKernels.h, cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h
Renamed sparse-MLA-specific fields to generic sparse-attention equivalents: sparse_mla_topk → sparse_topk, sparse_mla_kv_cache_pool → sparse_kv_cache_pool, mSparseMla → mSparseAttention, mSparseMlaTopK → mSparseTopK.
FMHA Dispatcher & Runner
cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp, cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h, cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h
Updated sparse path from useSparseMLA() to useTllmGenSparseAttention(), switched kernel hashing/selection from sparseMla to sparseAttention, and replaced MLA-specific topK heuristics with generic sparse-attention logic. Extended KV TMA shape/stride override to both K and V.
XQA & FMHA Dispatchers
cpp/tensorrt_llm/kernels/xqaDispatcher.cpp, cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp
Added/updated sparse-attention gating to use useTllmGenSparseAttention flag and route sparse top-k / sparse KV-cache pool from renamed sparse_params fields. Introduced paged sparse-attention (use_sparse_attention_gen_paged) branch alongside MLA sparse-attention branch.
FMHA Reduction Kernel
cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaReduction.cu
Renamed kernel parameter from sparseMla to sparseAttention and updated sparse topK lookup to use mNumSparseTopk instead of mSparseMlaTopK.
Kernel Include & C++ Wiring
cpp/tensorrt_llm/kernels/indexerTopK.cu
Added #include <cfloat> for floating-point constants support.
Python Bindings
cpp/tensorrt_llm/nanobind/thop/bindings.cpp, cpp/tensorrt_llm/thop/attentionOp.h, cpp/tensorrt_llm/thop/attentionOp.cpp
Renamed Python/C++ interface parameter from sparse_mla_topk to num_sparse_topk; updated runner virtual/override signatures; refactored sparse backend selection to dispatch between paged and MLA sparse-attention modes based on the presence of sparse_attn_offsets.
Python Sparse DSA Backend
tensorrt_llm/_torch/attention_backend/sparse/dsa.py, tensorrt_llm/_torch/attention_backend/sparse/kernel.py
Renamed metadata field sparse_mla_topk to num_sparse_topk. Extended Triton kernel _convert_req_index_to_global_index_kernel_with_stride_factor from 2D token-local layout to 4D head-aware layout supporting per-KV-head and per-KV-factor index conversion; added num_kv_heads and kv_factor parameters and updated block-table indexing to 3D format.
Python TrtllmAttention Wrapper
tensorrt_llm/_torch/attention_backend/trtllm.py, tensorrt_llm/_torch/attention_backend/trtllm_gen.py
Renamed sparse_mla_topk to num_sparse_topk in TrtllmAttentionWrapper.plan() and added sparse_attention_config parameter. Introduced TrtllmGenSupportChecker.check_sparse_attention() validation; added logic to force paged context FMHA for MQA/GQA sparse algorithm.
Python Custom Ops
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py
Updated thop.attention call argument from sparse_mla_topk=0 to num_sparse_topk=0.
Sparse Attention Tests
tests/unittest/_torch/attention/sparse/test_dsa_indexer.py, tests/unittest/_torch/attention/sparse/test_sparse_attention.py
Updated existing test metadata field reference. Added comprehensive new test module covering sparse attention for context and generation phases, including scenario-based test fixtures, per-head sparse index conversion validation, and reference implementations for context/generation sparse attention correctness verification.
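The head-aware index conversion described above (per-KV-head top-k token indices mapped to global paged-KV slots via a block table) can be sketched as a NumPy reference. The function name, argument layout, and padding convention are illustrative, not the actual Triton kernel's signature:

```python
import numpy as np

def convert_req_index_to_global_index(token_idx, block_table, page_size):
    """NumPy reference for the per-KV-head sparse index conversion.

    token_idx:   [num_reqs, num_kv_heads, topk] request-local token
                 positions, with -1 marking padded slots.
    block_table: [num_reqs, max_pages] global page ids per request.
    Returns global KV slot indices, preserving -1 for padding.
    """
    valid = token_idx >= 0
    page_idx = np.where(valid, token_idx // page_size, 0)
    tok_in_page = np.where(valid, token_idx % page_size, 0)
    # Broadcast the block table over the kv-head axis, then gather pages.
    pages = np.take_along_axis(block_table[:, None, :], page_idx, axis=2)
    return np.where(valid, pages * page_size + tok_in_page, -1)
```

A kv_factor dimension (e.g. separate K and V indices) would add one more broadcast axis in the same way; the Triton version additionally tiles this per program instance.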

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (2 warnings)

  • Docstring Coverage — ⚠️ Warning. Docstring coverage is 52.08%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Description check — ⚠️ Warning. The pull request description contains only the template structure, with empty Description, Test Coverage, and PR Checklist sections. Resolution: fill in the Description section with a clear explanation of what changes are being made and why, document the Test Coverage section with the specific test cases that validate the sparse MQA/GQA attention support, and complete the PR Checklist items.
✅ Passed checks (1 passed)
  • Title check — ✅ Passed. The title clearly and concisely summarizes the main change: adding support for sparse MQA/GQA attention, which aligns with the extensive kernel-level changes throughout the codebase.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 10

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h (1)

1-2: ⚠️ Potential issue | 🟡 Minor

Update copyright year.

The copyright year shows 2020-2023, but this file is being modified in 2025/2026. Per coding guidelines, the copyright header should reflect the year of the latest meaningful modification.

Suggested fix
 /*
- * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
+ * Copyright (c) 2020-2025, NVIDIA CORPORATION.  All rights reserved.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h`
around lines 1 - 2, Update the file header copyright years to include the
current modification year (e.g., change "2020-2023" to "2020-2026" or
"2020-2025-2026" per project convention) so the top-of-file comment in
decoderXQAImplCommon.h reflects the latest modification year; edit the existing
comment block at the top of the file to replace the old year range and keep the
rest of the header intact.
cpp/tensorrt_llm/kernels/xqaDispatcher.h (1)

1-2: ⚠️ Potential issue | 🟡 Minor

Update copyright year.

The copyright year shows 2020-2024, but this file is being modified in 2025/2026. Per coding guidelines, the copyright header should reflect the year of the latest meaningful modification.

Suggested fix
 /*
- * Copyright (c) 2020-2024, NVIDIA CORPORATION.  All rights reserved.
+ * Copyright (c) 2020-2025, NVIDIA CORPORATION.  All rights reserved.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tensorrt_llm/kernels/xqaDispatcher.h` around lines 1 - 2, Update the
copyright header in xqaDispatcher.h to reflect the current year of modification
(replace "2020-2024" with the appropriate latest year, e.g., "2020-2026") so the
file header aligns with project guidelines; modify the top-of-file comment block
containing the copyright notice accordingly.
🧹 Nitpick comments (5)
cpp/tensorrt_llm/nanobind/thop/bindings.cpp (1)

69-69: Consider a temporary compatibility alias for sparse_mla_topk.

This renames a Python keyword on a bound API, so any downstream caller still passing sparse_mla_topk= will fail immediately. If this binding is used outside the in-tree wrappers, keeping both names for one release—or explicitly calling out the break—would make the migration safer.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tensorrt_llm/nanobind/thop/bindings.cpp` at line 69, The binding changed
the Python keyword from sparse_mla_topk to num_sparse_topk, which will break
callers using sparse_mla_topk; restore a temporary compatibility alias by
allowing both keywords in the nanobind signature (e.g., accept both
sparse_mla_topk and num_sparse_topk as nb::arg entries) so callers using
sparse_mla_topk continue to work for one release before removal—update the
binding in bindings.cpp around the nb::arg(...) list where num_sparse_topk is
defined to add a second nb::arg("sparse_mla_topk") = std::nullopt (or
equivalent) mapping to the same parameter handling code and add a short
deprecation comment.
cpp/tensorrt_llm/kernels/xqaDispatcher.cpp (3)

455-463: Use consistent type cast for kvPageIdxPtr.

Line 461 casts to int const*, while lines 438 and 451-452 use reinterpret_cast<KVCacheIndex::UnderlyingType const*> for the same field. Use the consistent type for clarity and type safety.

             // Sparse MQA/GQA attention: use trtllm-gen sparse kernel.
             else if (mFixedParams.useTllmGenSparseAttention)
             {
                 tllmRunnerParams.mSparseAttention = true;
                 tllmRunnerParams.mSparseTopK = params.sparse_params.sparse_topk;
                 tllmRunnerParams.mMaskType = TrtllmGenAttentionMaskType::Causal;
-                tllmRunnerParams.kvPageIdxPtr = reinterpret_cast<int const*>(params.sparse_params.sparse_attn_indices);
+                tllmRunnerParams.kvPageIdxPtr = reinterpret_cast<KVCacheIndex::UnderlyingType const*>(params.sparse_params.sparse_attn_indices);
                 tllmRunnerParams.kvPtr = params.sparse_params.sparse_kv_cache_pool;
             }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tensorrt_llm/kernels/xqaDispatcher.cpp` around lines 455 - 463, The cast
for kvPageIdxPtr is inconsistent: replace the reinterpret_cast<int const*> used
when setting tllmRunnerParams.kvPageIdxPtr with the same type used elsewhere,
i.e. reinterpret_cast<KVCacheIndex::UnderlyingType const*>, so update the
assignment in the block that sets tllmRunnerParams.mSparseAttention to use
KVCacheIndex::UnderlyingType const* for kvPageIdxPtr to ensure consistent typing
with the other assignments.

300-306: Redundant assignment of mKernelType.

Line 304 sets mKernelType = FmhaKernelType::Generation, but this is already set unconditionally at line 285. Consider removing the redundant assignment.

         // Sparse MQA/GQA uses trtllm-gen sparse kernel.
         if (mFixedParams.useTllmGenSparseAttention)
         {
             tllmRunnerParams.mSparseAttention = true;
-            tllmRunnerParams.mKernelType = FmhaKernelType::Generation;
             tllmRunnerParams.mMaskType = TrtllmGenAttentionMaskType::Causal;
         }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tensorrt_llm/kernels/xqaDispatcher.cpp` around lines 300 - 306, The
assignment of tllmRunnerParams.mKernelType to FmhaKernelType::Generation inside
the if-block is redundant because mKernelType is already set unconditionally
earlier; remove the duplicate assignment in the block that checks
mFixedParams.useTllmGenSparseAttention (leave mSparseAttention and mMaskType
assignments intact) so only the earlier initialization of mKernelType remains in
the code.

1-2: Update copyright year to include 2025/2026.

The file has been modified but the copyright header still shows 2020-2024. As per coding guidelines, the copyright year should reflect the year of the latest meaningful modification.

-/*
- * Copyright (c) 2020-2024, NVIDIA CORPORATION.  All rights reserved.
+/*
+ * Copyright (c) 2020-2025, NVIDIA CORPORATION.  All rights reserved.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tensorrt_llm/kernels/xqaDispatcher.cpp` around lines 1 - 2, Update the
file header copyright range in the top-of-file comment that begins with "/*" and
the line containing "Copyright (c) 2020-2024, NVIDIA CORPORATION." to reflect
the latest modification year (e.g., change "2020-2024" to "2020-2026"); edit the
Copyright line in xqaDispatcher.cpp's leading comment block accordingly so the
range includes 2025/2026.
tests/unittest/_torch/attention/sparse/test_sparse_attention.py (1)

6-21: Switch these direct imports to module-qualified imports.

This new module pulls several symbols straight into local scope. The repo’s Python import convention is to import modules and use qualified names instead.

As per coding guidelines, "When importing in Python, always maintain the namespace. Import the module, not individual classes or functions."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/_torch/attention/sparse/test_sparse_attention.py` around lines
6 - 21, Replace the direct symbol imports with module-qualified imports: instead
of "from tensorrt_llm._torch.attention_backend.sparse.kernel import
triton_convert_req_index_to_global_index" and the other from-imports, import
their modules (e.g., import tensorrt_llm._torch.attention_backend.sparse.kernel
as sparse_kernel; import tensorrt_llm._torch.attention_backend.trtllm as trtllm;
import tensorrt_llm._torch.metadata as metadata; import
tensorrt_llm._torch.pyexecutor.resource_manager as resource_manager; import
tensorrt_llm._utils as utils; import tensorrt_llm.bindings.executor as
bindings_executor; import tensorrt_llm.mapping as mapping) and then update all
uses of triton_convert_req_index_to_global_index, TrtllmAttention,
TrtllmAttentionMetadata, KVCacheParams, KVCacheManager, str_dtype_to_binding,
torch_dtype_to_str, KvCacheConfig, and Mapping to use their module-qualified
names (e.g., sparse_kernel.triton_convert_req_index_to_global_index,
trtllm.TrtllmAttention, metadata.KVCacheParams, resource_manager.KVCacheManager,
utils.str_dtype_to_binding, utils.torch_dtype_to_str,
bindings_executor.KvCacheConfig, mapping.Mapping).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h`:
- Around line 431-436: The code incorrectly treats params.mSparseTopK == 0 as
"attend to zero tokens" by setting maxAttentionWindow to
min(params.mMaxSeqLenKv, params.mSparseTopK); change the logic in the
params.mSparseAttention branch to treat mSparseTopK == 0 as "unknown" and fall
back to params.mMaxSeqLenKv unless params.mSparseTopK > 0; i.e., only apply the
min(...) reduction when params.mSparseTopK > 0 so maxAttentionWindow remains
params.mMaxSeqLenKv and downstream values like numCtasPerSeqKv cannot become
zero.
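The intended guard amounts to treating a non-positive top-k as "no clamp"; a small Python sketch of that logic (function name is hypothetical, the C++ fix applies the same condition to maxAttentionWindow):

```python
def effective_attention_window(max_seq_len_kv: int, sparse_topk: int) -> int:
    # sparse_topk == 0 means "unknown", not "attend to zero tokens", so
    # only clamp the window when a positive top-k is actually configured.
    if sparse_topk > 0:
        return min(max_seq_len_kv, sparse_topk)
    return max_seq_len_kv
```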

In `@cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h`:
- Around line 772-780: The sparse-attention branch sets 2-D KV layouts
(shapeK/shapeV, strideK/strideV, tileShapeKv) but later the DATA_TYPE_E2M1 path
still constructs tmaKSf_ and tmaVSf_ from options.kvSfPtr using the
dense/paged-KV shape/stride, which mixes dense scale factors with sparse values;
fix by detecting options.mSparseAttention inside the DATA_TYPE_E2M1 handling and
either (a) reject/throw when FP4 (DATA_TYPE_E2M1) + mSparseAttention are
combined, or (b) build tmaKSf_ and tmaVSf_ using the same sparse 2-D
shape/stride used for shapeK/shapeV (i.e., use shapeK/strideK and
shapeV/strideV) so scales are read with the correct layout; update code paths
that construct tmaKSf_/tmaVSf_ (and any usage of options.kvSfPtr) accordingly
and ensure tileShapeKv is consistently applied.

In `@cpp/tensorrt_llm/thop/attentionOp.cpp`:
- Around line 403-407: The assignment to
op.mRuntimeSparseAttentionParams.sparse_kv_cache_pool is wrong—don’t index into
host_kv_cache_pool_pointers or call .item<int64_t>(), instead reuse the
already-computed host_primary_pool_pointer (the same pointer used for dense KV
cache) and apply the intra_pool_offset computed earlier so sparse attention
points at the current layer’s start; update the branch guarded by
op.mUseSparseAttention && use_kv_cache to set sparse_kv_cache_pool =
reinterpret_cast<char*>(host_primary_pool_pointer) plus the intra_pool_offset
(using the same offset logic/variable from lines 332–358), and remove the
host_kv_cache_pool_pointers.item<int64_t>() usage.

In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py`:
- Around line 647-649: The long RHS expression assigned to
self.skip_indexer_for_gen_reqs is causing an E123 continuation lint error;
modify the assignment for kv_lens[self.num_contexts:self.num_seqs].max().item()
<= self.num_sparse_topk - num_extra_kv_tokens by wrapping the entire comparison
expression in parentheses (or otherwise grouping it across lines) so the
continuation is explicit and the behavior of skip_indexer_for_gen_reqs remains
identical; target the assignment to self.skip_indexer_for_gen_reqs and the
variables kv_lens, num_contexts, num_seqs, num_sparse_topk, and
num_extra_kv_tokens when making the change.

In `@tensorrt_llm/_torch/attention_backend/sparse/kernel.py`:
- Around line 1887-1920: The fallback computation for stride_factor in this
function is missing num_layers (currently: stride_factor = kv_factor *
num_kv_heads * BLOCK_SIZE), which will produce overlapping indices for
multi-layer caches; update the None branch that sets stride_factor to include
num_layers (compute stride_factor = num_layers * kv_factor * num_kv_heads *
BLOCK_SIZE) or, if num_layers may be unspecified/0, validate num_layers > 0 and
raise a clear error asking the caller to pass stride_factor explicitly; modify
the code around the existing "if stride_factor is None:" block and reference
BLOCK_SIZE, num_layers, kv_factor, and num_kv_heads when making the change.
- Around line 1852-1858: The code dereferences block_table_ptr using page_idx
computed from tok which can be -1 for padding; because page_idx <
max_num_blocks_per_req still evaluates true for -1 you must mask padded tokens
before computing or using page_idx. Fix by clamping or gating tok (e.g., create
tok_valid = tok >= 0 or tok_clamped = where(tok < 0, 0, tok)), compute page_idx
and token_in_page from the clamped value, and compute valid_page = (tok >= 0) &
(page_idx < max_num_blocks_per_req); then use that valid_page as the mask for
tl.load when reading bt_ptr into mem_pool_idx (references: tok, BLOCK_SIZE,
page_idx, token_in_page, valid_page, block_table_ptr, bt_ptr, bt_stride0,
bt_stride1, mem_pool_idx, max_num_blocks_per_req).
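The masked-load fix above can be illustrated with a NumPy analogue of the Triton logic; names mirror the review comment but the helper itself is illustrative:

```python
import numpy as np

def safe_block_table_lookup(tok, block_table, page_size, max_blocks):
    """Mask padded (-1) tokens before any block-table dereference."""
    tok_clamped = np.where(tok < 0, 0, tok)       # never index with -1
    page_idx = tok_clamped // page_size
    tok_in_page = tok_clamped % page_size
    valid = (tok >= 0) & (page_idx < max_blocks)  # gate padding AND overflow
    # Masked load: read the table only where the index is valid.
    pages = block_table[np.where(valid, page_idx, 0)]
    return np.where(valid, pages * page_size + tok_in_page, -1)
```

In the Triton kernel the same gating would be expressed as the `mask=` argument of `tl.load`, so the padded lanes never touch memory.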

In `@tensorrt_llm/_torch/attention_backend/trtllm_gen.py`:
- Around line 237-253: The check_sparse_attention function currently claims
support for "skip_softmax" but trtllm_gen_attention never consumes skip-softmax
inputs; change check_sparse_attention (class TrtllmGenAttentionConfig usage) to
NOT return True for algorithm == "skip_softmax" — instead return False with a
clear "not supported yet" message. Also harden access to the algorithm by using
getattr(config.sparse_attention_config, "algorithm", None) (or equivalent) to
avoid AttributeError on unexpected types and update the returned messages for
"skip_softmax" and unknown algorithms accordingly.

In `@tests/unittest/_torch/attention/sparse/test_dsa_indexer.py`:
- Around line 589-593: The test block intended to model extra-KV gating for
multi-token decoding currently contains a no-op expression
`self.max_draft_tokens + 1`; replace this with an actual mutation so the mock
matches runtime behavior—e.g., change that line to increment the draft-token
limit (`self.max_draft_tokens += 1`) inside the `if self.num_generations > 0 and
self.enable_indexer_skip:` block so `skip_indexer_for_gen_reqs` is computed
against the correct adjusted value and the Flake8 E123/behavior mismatch is
resolved.

In `@tests/unittest/_torch/attention/sparse/test_sparse_attention.py`:
- Around line 1-3: This new test file lacks the required NVIDIA/Apache license
header; add the repository’s standard NVIDIA/Apache copyright header (with the
current modification year) at the very top of
tests/unittest/_torch/attention/sparse/test_sparse_attention.py, above the
module docstring (the triple-quoted string) so the file begins with the standard
header block before any code or comments.
- Around line 297-315: The helper misaligns request-row mapping when there are
non-decoding contexts because page_indices and request_ids are computed from all
metadata.request_ids while num_requests and seq_lens only cover active decode
requests; fix by computing page_indices and request_ids from only the active
requests (use metadata.request_ids[:num_requests] or an equivalent slice) before
calling kv_cache_manager.get_batch_cache_indices and before building
host_block_table, so req_idx_per_token (built from num_requests) and block_table
rows correspond to the same active-request subset; keep the same dtype/device
conversions for tensors (req_idx_per_token, host_block_table -> block_table).
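The fix boils down to slicing to the active requests before any per-request lookup; a minimal sketch with hypothetical helper names:

```python
import numpy as np

def build_active_block_table(all_request_ids, num_requests,
                             get_cache_indices, max_blocks):
    # Slice to the active decode requests BEFORE any per-request lookup,
    # so row i of the table matches req_idx_per_token values in
    # [0, num_requests).
    active_ids = all_request_ids[:num_requests]
    table = np.full((num_requests, max_blocks), -1, dtype=np.int64)
    for row, rid in enumerate(active_ids):
        pages = get_cache_indices(rid)
        table[row, :len(pages)] = pages
    return table
```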


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: dbb780a2-4bea-468d-9268-919bce67d025

📥 Commits

Reviewing files that changed from the base of the PR and between ffb1fed and 2de8183.

📒 Files selected for processing (24)
  • cpp/tensorrt_llm/common/attentionOp.cpp
  • cpp/tensorrt_llm/common/attentionOp.h
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h
  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h
  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h
  • cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp
  • cpp/tensorrt_llm/kernels/indexerTopK.cu
  • cpp/tensorrt_llm/kernels/sparseAttentionKernels.h
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaReduction.cu
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h
  • cpp/tensorrt_llm/kernels/xqaDispatcher.cpp
  • cpp/tensorrt_llm/kernels/xqaDispatcher.h
  • cpp/tensorrt_llm/nanobind/thop/bindings.cpp
  • cpp/tensorrt_llm/thop/attentionOp.cpp
  • cpp/tensorrt_llm/thop/attentionOp.h
  • tensorrt_llm/_torch/attention_backend/sparse/dsa.py
  • tensorrt_llm/_torch/attention_backend/sparse/kernel.py
  • tensorrt_llm/_torch/attention_backend/trtllm.py
  • tensorrt_llm/_torch/attention_backend/trtllm_gen.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py
  • tests/unittest/_torch/attention/sparse/test_dsa_indexer.py
  • tests/unittest/_torch/attention/sparse/test_sparse_attention.py

Comment thread cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h Outdated
Comment thread cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h Outdated
Comment thread cpp/tensorrt_llm/thop/attentionOp.cpp
Comment thread tensorrt_llm/_torch/attention_backend/sparse/dsa.py
Comment thread tensorrt_llm/_torch/attention_backend/sparse/kernel.py
Comment thread tensorrt_llm/_torch/attention_backend/sparse/kernel.py
Comment thread tensorrt_llm/_torch/attention_backend/trtllm_gen.py Outdated
Comment thread tests/unittest/_torch/attention/sparse/test_dsa_indexer.py
Comment thread tests/unittest/_torch/attention/sparse/test_sparse_attention.py
Comment thread tests/unittest/_torch/attention/sparse/test_sparse_attention.py
@heyuhhh heyuhhh requested a review from lfr-0531 March 24, 2026 02:29
@heyuhhh heyuhhh force-pushed the user/yuhangh/add_sparse_attention_support branch from 2de8183 to 1357f03 on March 24, 2026 13:30
Comment thread cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h
Comment thread cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h
Comment thread cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h Outdated
Comment thread cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaReduction.cu
Comment thread cpp/tensorrt_llm/thop/attentionOp.cpp
Comment thread cpp/tensorrt_llm/thop/attentionOp.cpp
Comment thread tensorrt_llm/_torch/attention_backend/trtllm.py
Comment thread cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h Outdated
@heyuhhh heyuhhh force-pushed the user/yuhangh/add_sparse_attention_support branch 2 times, most recently from b6afa91 to 5c7319f on April 8, 2026 10:17
@lfr-0531
Collaborator

lfr-0531 commented Apr 8, 2026

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #42327 [ run ] triggered by Bot. Commit: 5c7319f

@heyuhhh heyuhhh force-pushed the user/yuhangh/add_sparse_attention_support branch from 5c7319f to f88cfb8 on April 8, 2026 11:15
@heyuhhh
Collaborator Author

heyuhhh commented Apr 8, 2026

Hi @lfr-0531 @PerkzZheng @QiJune, could you please take another look at this PR? I've rebased it and verified correctness locally, but have pushed only some of the upper-level files for review. Once the review is finished, I'll push the remaining files and run CI. Thanks!

@tensorrt-cicd
PR_Github #42327 [ run ] completed with state FAILURE. Commit: 5c7319f
/LLM/main/L0_MergeRequest_PR pipeline #33117 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again


@heyuhhh heyuhhh requested a review from QiJune April 8, 2026 14:43
@QiJune left a comment

LGTM

@heyuhhh heyuhhh force-pushed the user/yuhangh/add_sparse_attention_support branch from f88cfb8 to 3c8c16d Compare April 13, 2026 06:27
@heyuhhh commented Apr 13, 2026

/bot run --disable-fail-fast

@tensorrt-cicd
PR_Github #42974 [ run ] triggered by Bot. Commit: 3c8c16d

@tensorrt-cicd
PR_Github #42974 [ run ] completed with state FAILURE. Commit: 3c8c16d
/LLM/main/L0_MergeRequest_PR pipeline #33629 completed with status: 'FAILURE'


@heyuhhh commented Apr 13, 2026

/bot run --disable-fail-fast

@tensorrt-cicd
PR_Github #43025 [ run ] triggered by Bot. Commit: 24e42f5

@heyuhhh commented Apr 13, 2026

/bot run --disable-fail-fast

@tensorrt-cicd
PR_Github #43066 [ run ] triggered by Bot. Commit: 742ba21

@tensorrt-cicd
PR_Github #43798 [ run ] triggered by Bot. Commit: 6892e22

@tensorrt-cicd
PR_Github #43798 [ run ] completed with state FAILURE. Commit: 6892e22
/LLM/main/L0_MergeRequest_PR pipeline #34275 completed with status: 'FAILURE'


@heyuhhh commented Apr 17, 2026

/bot run

@tensorrt-cicd
PR_Github #43853 [ run ] triggered by Bot. Commit: 19e9eb5

@tensorrt-cicd
PR_Github #43853 [ run ] completed with state FAILURE. Commit: 19e9eb5
/LLM/main/L0_MergeRequest_PR pipeline #34314 completed with status: 'FAILURE'


@heyuhhh commented Apr 17, 2026

/bot run

@tensorrt-cicd
PR_Github #43939 [ run ] triggered by Bot. Commit: 19e9eb5

@tensorrt-cicd
PR_Github #43939 [ run ] completed with state SUCCESS. Commit: 19e9eb5
/LLM/main/L0_MergeRequest_PR pipeline #34387 completed with status: 'FAILURE'


@heyuhhh commented Apr 17, 2026

/bot run

@heyuhhh commented Apr 17, 2026

/bot run --disable-fail-fast

@tensorrt-cicd
PR_Github #44043 [ run ] triggered by Bot. Commit: 53bc41a

@heyuhhh commented Apr 18, 2026

/bot kill

@heyuhhh commented Apr 18, 2026

/bot run --disable-fail-fast

@tensorrt-cicd
PR_Github #44084 [ kill ] triggered by Bot. Commit: 95036b7

@tensorrt-cicd
PR_Github #44043 [ run ] completed with state ABORTED. Commit: 53bc41a


@tensorrt-cicd
PR_Github #44084 [ kill ] completed with state SUCCESS. Commit: 95036b7
Successfully killed previous jobs for commit 95036b7


@tensorrt-cicd
PR_Github #44085 [ run ] triggered by Bot. Commit: 95036b7

@tensorrt-cicd
PR_Github #44085 [ run ] completed with state SUCCESS. Commit: 95036b7
/LLM/main/L0_MergeRequest_PR pipeline #34515 completed with status: 'FAILURE'


@heyuhhh commented Apr 18, 2026

/bot run

@tensorrt-cicd
PR_Github #44124 [ run ] triggered by Bot. Commit: 95036b7

@tensorrt-cicd
PR_Github #44124 [ run ] completed with state SUCCESS. Commit: 95036b7
/LLM/main/L0_MergeRequest_PR pipeline #34551 completed with status: 'SUCCESS'


@lfr-0531 lfr-0531 enabled auto-merge (squash) April 19, 2026 03:32
@lfr-0531 lfr-0531 merged commit cf9963f into NVIDIA:main Apr 19, 2026
5 checks passed
lfr-0531 added a commit to lfr-0531/TensorRT-LLM that referenced this pull request Apr 22, 2026
Upstream commit cf9963f (NVIDIA#12470) renamed the wrapper.plan() sparse
topk parameter from sparse_mla_topk to num_sparse_topk while also
updating all explicit kwarg call sites. This branch's refactor
(04f5012) switched the call site from explicit kwargs to **sparse_params
dict unpack, but built the dict with the old key sparse_mla_topk. The
mismatch silently routed the topk value into wrapper.plan()'s **kwargs,
leaving self.num_sparse_topk at its default 0. The kernel then received
populated sparse_attn_indices with topk=0, producing NaN/Inf (Mode A of
test_forward_sparse_mla_unified) and ~2.0 numerical divergence
(Mode B + test_agrees_with_absorption_path).

Rename SparseParams.sparse_mla_topk to num_sparse_topk to match the
Python/C++ API boundary. The DSAtrtllmAttentionMetadata semantic field
sparse_mla_topk (sourced from sparse_attention_config.index_topk) is
preserved; DSATrtllmAttention.sparse_params() still reads it and writes
into params.num_sparse_topk for the plan() call.

Also sync the MockMetadata field name in test_dsa_indexer.py to match
DSAtrtllmAttentionMetadata (used by prepare_dense_topk_indices when
enable_indexer_skip triggers skip_indexer_for_ctx/gen_reqs).

Fixes test_forward_sparse_mla_unified (20 variants) and
test_short_seq_mha.py::test_agrees_with_absorption_path on B200.

Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
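The failure mode described above is worth illustrating: when a call site switches from explicit keyword arguments to a `**dict` unpack, a stale key no longer raises a `TypeError` if the callee also accepts `**kwargs` — the value is silently swallowed and the renamed parameter keeps its default. The sketch below is a minimal, hypothetical reconstruction (not the TensorRT-LLM or FlashInfer code; only the parameter names `sparse_mla_topk` and `num_sparse_topk` come from the commit message):

```python
class Wrapper:
    """Hypothetical stand-in for the attention wrapper's plan() API."""

    def plan(self, num_sparse_topk=0, **kwargs):
        # After the rename, only `num_sparse_topk` is recognized; any stale
        # key (e.g. the pre-rename `sparse_mla_topk`) lands in **kwargs
        # without any error being raised.
        self.num_sparse_topk = num_sparse_topk
        self.ignored_kwargs = kwargs


wrapper = Wrapper()

# Dict built with the pre-rename key: the call succeeds, but the topk
# value never reaches the field the kernel reads, so it stays at 0.
stale_params = {"sparse_mla_topk": 64}
wrapper.plan(**stale_params)
assert wrapper.num_sparse_topk == 0
assert wrapper.ignored_kwargs == {"sparse_mla_topk": 64}

# Dict rebuilt with the renamed key, as in the fix: the value lands.
fixed_params = {"num_sparse_topk": 64}
wrapper.plan(**fixed_params)
assert wrapper.num_sparse_topk == 64
```

This is why the commit renames `SparseParams.sparse_mla_topk` to `num_sparse_topk` at the dict-construction site rather than relying on the call to fail loudly: with a catch-all `**kwargs` in the signature, a key mismatch is invisible until numerical results diverge.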