[None][feat] Support sparse MQA/GQA attention #12470
Conversation
📝 Walkthrough

This PR refactors sparse attention parameter naming and control flow across C++ kernel dispatchers and Python attention backends. It separates paged-sparse-attention from sparse-MLA modes, renames MLA-specific fields to generic sparse-attention equivalents, and extends the Triton sparse-index conversion kernel to support per-KV-head and KV-factor dimensions.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed (2 warnings)
Actionable comments posted: 10
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h (1)

1-2: ⚠️ Potential issue | 🟡 Minor: Update copyright year.

The copyright year shows 2020-2023, but this file is being modified in 2025/2026. Per coding guidelines, the copyright header should reflect the year of the latest meaningful modification.

Suggested fix:

```diff
-/* * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+/* * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved.
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h` around lines 1-2: update the file header copyright years to include the current modification year (e.g., change "2020-2023" to "2020-2026" per project convention) so the top-of-file comment in decoderXQAImplCommon.h reflects the latest modification year; edit the existing comment block at the top of the file to replace the old year range and keep the rest of the header intact.

cpp/tensorrt_llm/kernels/xqaDispatcher.h (1)

1-2: ⚠️ Potential issue | 🟡 Minor: Update copyright year.

The copyright year shows 2020-2024, but this file is being modified in 2025/2026. Per coding guidelines, the copyright header should reflect the year of the latest meaningful modification.

Suggested fix:

```diff
-/* * Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
+/* * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved.
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/kernels/xqaDispatcher.h` around lines 1-2: update the copyright header in xqaDispatcher.h to reflect the current year of modification (replace "2020-2024" with the appropriate latest year, e.g., "2020-2026") so the file header aligns with project guidelines; modify the top-of-file comment block containing the copyright notice accordingly.
🧹 Nitpick comments (5)
cpp/tensorrt_llm/nanobind/thop/bindings.cpp (1)
69-69: Consider a temporary compatibility alias for sparse_mla_topk.

This renames a Python keyword on a bound API, so any downstream caller still passing sparse_mla_topk= will fail immediately. If this binding is used outside the in-tree wrappers, keeping both names for one release (or explicitly calling out the break) would make the migration safer.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/nanobind/thop/bindings.cpp` at line 69: the binding changed the Python keyword from sparse_mla_topk to num_sparse_topk, which will break callers using sparse_mla_topk; restore a temporary compatibility alias by allowing both keywords in the nanobind signature (e.g., accept both sparse_mla_topk and num_sparse_topk as nb::arg entries) so callers using sparse_mla_topk continue to work for one release before removal. Update the binding around the nb::arg(...) list where num_sparse_topk is defined to add a second nb::arg("sparse_mla_topk") = std::nullopt (or equivalent) mapping to the same parameter-handling code, and add a short deprecation comment.

cpp/tensorrt_llm/kernels/xqaDispatcher.cpp (3)
455-463: Use consistent type cast for kvPageIdxPtr.

Line 461 casts to int const*, while lines 438 and 451-452 use reinterpret_cast<KVCacheIndex::UnderlyingType const*> for the same field. Use the consistent type for clarity and type safety.

```diff
 // Sparse MQA/GQA attention: use trtllm-gen sparse kernel.
 else if (mFixedParams.useTllmGenSparseAttention)
 {
     tllmRunnerParams.mSparseAttention = true;
     tllmRunnerParams.mSparseTopK = params.sparse_params.sparse_topk;
     tllmRunnerParams.mMaskType = TrtllmGenAttentionMaskType::Causal;
-    tllmRunnerParams.kvPageIdxPtr = reinterpret_cast<int const*>(params.sparse_params.sparse_attn_indices);
+    tllmRunnerParams.kvPageIdxPtr = reinterpret_cast<KVCacheIndex::UnderlyingType const*>(params.sparse_params.sparse_attn_indices);
     tllmRunnerParams.kvPtr = params.sparse_params.sparse_kv_cache_pool;
 }
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/kernels/xqaDispatcher.cpp` around lines 455-463: the cast for kvPageIdxPtr is inconsistent; replace the reinterpret_cast<int const*> used when setting tllmRunnerParams.kvPageIdxPtr with the same type used elsewhere, i.e. reinterpret_cast<KVCacheIndex::UnderlyingType const*>, so the assignment in the block that sets tllmRunnerParams.mSparseAttention matches the typing of the other assignments.
300-306: Redundant assignment of mKernelType.

Line 304 sets mKernelType = FmhaKernelType::Generation, but this is already set unconditionally at line 285. Consider removing the redundant assignment.

```diff
 // Sparse MQA/GQA uses trtllm-gen sparse kernel.
 if (mFixedParams.useTllmGenSparseAttention)
 {
     tllmRunnerParams.mSparseAttention = true;
-    tllmRunnerParams.mKernelType = FmhaKernelType::Generation;
     tllmRunnerParams.mMaskType = TrtllmGenAttentionMaskType::Causal;
 }
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/kernels/xqaDispatcher.cpp` around lines 300-306: the assignment of tllmRunnerParams.mKernelType to FmhaKernelType::Generation inside the if-block is redundant because mKernelType is already set unconditionally earlier; remove the duplicate assignment in the block that checks mFixedParams.useTllmGenSparseAttention (leave the mSparseAttention and mMaskType assignments intact) so only the earlier initialization of mKernelType remains.
1-2: Update copyright year to include 2025/2026.

The file has been modified but the copyright header still shows 2020-2024. As per coding guidelines, the copyright year should reflect the year of the latest meaningful modification.

```diff
-/*
- * Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
+/*
+ * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved.
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/kernels/xqaDispatcher.cpp` around lines 1-2: update the file header copyright range in the top-of-file comment that begins with "/*" and the line containing "Copyright (c) 2020-2024, NVIDIA CORPORATION." to reflect the latest modification year (e.g., change "2020-2024" to "2020-2026"); edit the Copyright line in xqaDispatcher.cpp's leading comment block accordingly so the range includes 2025/2026.

tests/unittest/_torch/attention/sparse/test_sparse_attention.py (1)
6-21: Switch these direct imports to module-qualified imports.

This new module pulls several symbols straight into local scope. The repo's Python import convention is to import modules and use qualified names instead. As per coding guidelines, "When importing in Python, always maintain the namespace. Import the module, not individual classes or functions."

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/attention/sparse/test_sparse_attention.py` around lines 6-21: replace the direct symbol imports with module-qualified imports. Instead of "from tensorrt_llm._torch.attention_backend.sparse.kernel import triton_convert_req_index_to_global_index" and the other from-imports, import the modules (e.g., import tensorrt_llm._torch.attention_backend.sparse.kernel as sparse_kernel; import tensorrt_llm._torch.attention_backend.trtllm as trtllm; import tensorrt_llm._torch.metadata as metadata; import tensorrt_llm._torch.pyexecutor.resource_manager as resource_manager; import tensorrt_llm._utils as utils; import tensorrt_llm.bindings.executor as bindings_executor; import tensorrt_llm.mapping as mapping) and then update all uses of triton_convert_req_index_to_global_index, TrtllmAttention, TrtllmAttentionMetadata, KVCacheParams, KVCacheManager, str_dtype_to_binding, torch_dtype_to_str, KvCacheConfig, and Mapping to use their module-qualified names (e.g., sparse_kernel.triton_convert_req_index_to_global_index, trtllm.TrtllmAttention, metadata.KVCacheParams, resource_manager.KVCacheManager, utils.str_dtype_to_binding, utils.torch_dtype_to_str, bindings_executor.KvCacheConfig, mapping.Mapping).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h`:
- Around line 431-436: The code incorrectly treats params.mSparseTopK == 0 as
"attend to zero tokens" by setting maxAttentionWindow to
min(params.mMaxSeqLenKv, params.mSparseTopK); change the logic in the
params.mSparseAttention branch to treat mSparseTopK == 0 as "unknown" and fall
back to params.mMaxSeqLenKv unless params.mSparseTopK > 0; i.e., only apply the
min(...) reduction when params.mSparseTopK > 0 so maxAttentionWindow remains
params.mMaxSeqLenKv and downstream values like numCtasPerSeqKv cannot become
zero.
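The guarded fallback described above can be sketched as a small Python model (a hypothetical helper; the real logic lives in the C++ dispatcher and the names here are illustrative):

```python
def max_attention_window(max_seq_len_kv: int, sparse_topk: int,
                         sparse_attention: bool) -> int:
    """Pick the KV attention window for the sparse branch.

    A topk of 0 means "unknown/unset", not "attend to zero tokens",
    so the min() reduction is only applied when topk is positive.
    This keeps downstream values like numCtasPerSeqKv nonzero.
    """
    if sparse_attention and sparse_topk > 0:
        return min(max_seq_len_kv, sparse_topk)
    return max_seq_len_kv
```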
In `@cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h`:
- Around line 772-780: The sparse-attention branch sets 2-D KV layouts
(shapeK/shapeV, strideK/strideV, tileShapeKv) but later the DATA_TYPE_E2M1 path
still constructs tmaKSf_ and tmaVSf_ from options.kvSfPtr using the
dense/paged-KV shape/stride, which mixes dense scale factors with sparse values;
fix by detecting options.mSparseAttention inside the DATA_TYPE_E2M1 handling and
either (a) reject/throw when FP4 (DATA_TYPE_E2M1) + mSparseAttention are
combined, or (b) build tmaKSf_ and tmaVSf_ using the same sparse 2-D
shape/stride used for shapeK/shapeV (i.e., use shapeK/strideK and
shapeV/strideV) so scales are read with the correct layout; update code paths
that construct tmaKSf_/tmaVSf_ (and any usage of options.kvSfPtr) accordingly
and ensure tileShapeKv is consistently applied.
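Option (a), rejecting the unsupported combination, can be modeled with a hypothetical Python guard (the function name and dtype string are illustrative, not the real kernel-params API):

```python
def check_kv_scale_layout(dtype: str, sparse_attention: bool) -> None:
    # Hypothetical sketch: FP4 (E2M1) KV caches carry separate
    # scale-factor tensors; until those tensors are built with the
    # sparse 2-D shape/stride, reject the combination instead of
    # silently reading scales with the dense/paged layout.
    if dtype == "E2M1" and sparse_attention:
        raise ValueError(
            "FP4 (E2M1) KV cache with sparse attention is not supported: "
            "scale factors would be read with the dense/paged layout")
```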
In `@cpp/tensorrt_llm/thop/attentionOp.cpp`:
- Around line 403-407: The assignment to
op.mRuntimeSparseAttentionParams.sparse_kv_cache_pool is wrong—don’t index into
host_kv_cache_pool_pointers or call .item<int64_t>(), instead reuse the
already-computed host_primary_pool_pointer (the same pointer used for dense KV
cache) and apply the intra_pool_offset computed earlier so sparse attention
points at the current layer’s start; update the branch guarded by
op.mUseSparseAttention && use_kv_cache to set sparse_kv_cache_pool =
reinterpret_cast<char*>(host_primary_pool_pointer) plus the intra_pool_offset
(using the same offset logic/variable from lines 332–358), and remove the
host_kv_cache_pool_pointers.item<int64_t>() usage.
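The pointer arithmetic in the suggested fix can be modeled in Python with plain integer addresses (a hypothetical sketch; the real code applies the offset to a host pointer via reinterpret_cast):

```python
def sparse_kv_pool_address(host_primary_pool_pointer: int,
                           layer_idx: int,
                           bytes_per_layer: int) -> int:
    # Hypothetical model of the fix: reuse the dense KV cache's primary
    # pool pointer and add the per-layer byte offset, instead of reading
    # a separate pool-pointer tensor via .item<int64_t>().
    intra_pool_offset = layer_idx * bytes_per_layer
    return host_primary_pool_pointer + intra_pool_offset
```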
In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py`:
- Around line 647-649: The long RHS expression assigned to
self.skip_indexer_for_gen_reqs is causing an E123 continuation lint error;
modify the assignment for kv_lens[self.num_contexts:self.num_seqs].max().item()
<= self.num_sparse_topk - num_extra_kv_tokens by wrapping the entire comparison
expression in parentheses (or otherwise grouping it across lines) so the
continuation is explicit and the behavior of skip_indexer_for_gen_reqs remains
identical; target the assignment to self.skip_indexer_for_gen_reqs and the
variables kv_lens, num_contexts, num_seqs, num_sparse_topk, and
num_extra_kv_tokens when making the change.
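The parenthesized form the prompt asks for looks like this (with hypothetical stand-in values, since the real fields come from the metadata object):

```python
# Hypothetical stand-ins for the metadata fields (not the real objects).
kv_lens = [5, 7, 20, 24]          # per-request KV lengths
num_contexts, num_seqs = 2, 4     # generation requests occupy [2, 4)
num_sparse_topk = 32
num_extra_kv_tokens = 4

# Wrapping the whole comparison in parentheses makes the line
# continuation explicit, which avoids the E123-style lint error
# without changing the computed value.
skip_indexer_for_gen_reqs = (
    max(kv_lens[num_contexts:num_seqs])
    <= num_sparse_topk - num_extra_kv_tokens
)
```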
In `@tensorrt_llm/_torch/attention_backend/sparse/kernel.py`:
- Around line 1887-1920: The fallback computation for stride_factor in this
function is missing num_layers (currently: stride_factor = kv_factor *
num_kv_heads * BLOCK_SIZE), which will produce overlapping indices for
multi-layer caches; update the None branch that sets stride_factor to include
num_layers (compute stride_factor = num_layers * kv_factor * num_kv_heads *
BLOCK_SIZE) or, if num_layers may be unspecified/0, validate num_layers > 0 and
raise a clear error asking the caller to pass stride_factor explicitly; modify
the code around the existing "if stride_factor is None:" block and reference
BLOCK_SIZE, num_layers, kv_factor, and num_kv_heads when making the change.
- Around line 1852-1858: The code dereferences block_table_ptr using page_idx
computed from tok which can be -1 for padding; because page_idx <
max_num_blocks_per_req still evaluates true for -1 you must mask padded tokens
before computing or using page_idx. Fix by clamping or gating tok (e.g., create
tok_valid = tok >= 0 or tok_clamped = where(tok < 0, 0, tok)), compute page_idx
and token_in_page from the clamped value, and compute valid_page = (tok >= 0) &
(page_idx < max_num_blocks_per_req); then use that valid_page as the mask for
tl.load when reading bt_ptr into mem_pool_idx (references: tok, BLOCK_SIZE,
page_idx, token_in_page, valid_page, block_table_ptr, bt_ptr, bt_stride0,
bt_stride1, mem_pool_idx, max_num_blocks_per_req).
In `@tensorrt_llm/_torch/attention_backend/trtllm_gen.py`:
- Around line 237-253: The check_sparse_attention function currently claims
support for "skip_softmax" but trtllm_gen_attention never consumes skip-softmax
inputs; change check_sparse_attention (class TrtllmGenAttentionConfig usage) to
NOT return True for algorithm == "skip_softmax" — instead return False with a
clear "not supported yet" message. Also harden access to the algorithm by using
getattr(config.sparse_attention_config, "algorithm", None) (or equivalent) to
avoid AttributeError on unexpected types and update the returned messages for
"skip_softmax" and unknown algorithms accordingly.
In `@tests/unittest/_torch/attention/sparse/test_dsa_indexer.py`:
- Around line 589-593: The test block intended to model extra-KV gating for
multi-token decoding currently contains a no-op expression
`self.max_draft_tokens + 1`; replace this with an actual mutation so the mock
matches runtime behavior—e.g., change that line to increment the draft-token
limit (`self.max_draft_tokens += 1`) inside the `if self.num_generations > 0 and
self.enable_indexer_skip:` block so `skip_indexer_for_gen_reqs` is computed
against the correct adjusted value and the Flake8 E123/behavior mismatch is
resolved.
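The behavior difference the prompt describes, a bare expression versus an augmented assignment, reduces to this (hypothetical mock; field names follow the test's MockMetadata):

```python
class MockMetadata:
    # Hypothetical reduction of the test bug: a bare expression like
    # `self.max_draft_tokens + 1` computes a value and throws it away,
    # so the mock never models the extra-KV adjustment.
    def __init__(self, num_generations, enable_indexer_skip,
                 max_draft_tokens):
        self.num_generations = num_generations
        self.enable_indexer_skip = enable_indexer_skip
        self.max_draft_tokens = max_draft_tokens
        if self.num_generations > 0 and self.enable_indexer_skip:
            self.max_draft_tokens += 1  # augmented assignment mutates
```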
In `@tests/unittest/_torch/attention/sparse/test_sparse_attention.py`:
- Around line 1-3: This new test file lacks the required NVIDIA/Apache license
header; add the repository’s standard NVIDIA/Apache copyright header (with the
current modification year) at the very top of
tests/unittest/_torch/attention/sparse/test_sparse_attention.py, above the
module docstring (the triple-quoted string) so the file begins with the standard
header block before any code or comments.
- Around line 297-315: The helper misaligns request-row mapping when there are
non-decoding contexts because page_indices and request_ids are computed from all
metadata.request_ids while num_requests and seq_lens only cover active decode
requests; fix by computing page_indices and request_ids from only the active
requests (use metadata.request_ids[:num_requests] or an equivalent slice) before
calling kv_cache_manager.get_batch_cache_indices and before building
host_block_table, so req_idx_per_token (built from num_requests) and block_table
rows correspond to the same active-request subset; keep the same dtype/device
conversions for tensors (req_idx_per_token, host_block_table -> block_table).
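The intended alignment can be sketched as follows (a hypothetical pure-Python model of the helper; the real code works on tensors and the KV cache manager):

```python
def build_block_table(request_ids, num_requests, cache_indices_by_req,
                      max_blocks):
    # Hypothetical sketch of the fix: slice to the active decode requests
    # *before* looking up cache indices, so row r of the block table and
    # req_idx_per_token value r refer to the same request.
    active_ids = request_ids[:num_requests]
    table = []
    for rid in active_ids:
        blocks = cache_indices_by_req[rid]
        # Pad each row to a fixed width, as a host block table would be.
        table.append(blocks + [-1] * (max_blocks - len(blocks)))
    return table
```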
ℹ️ Review info

⚙️ Run configuration
- Configuration used: Path: .coderabbit.yaml
- Review profile: CHILL
- Plan: Pro
- Run ID: dbb780a2-4bea-468d-9268-919bce67d025

📒 Files selected for processing (24)
- cpp/tensorrt_llm/common/attentionOp.cpp
- cpp/tensorrt_llm/common/attentionOp.h
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h
- cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h
- cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h
- cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp
- cpp/tensorrt_llm/kernels/indexerTopK.cu
- cpp/tensorrt_llm/kernels/sparseAttentionKernels.h
- cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h
- cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaReduction.cu
- cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h
- cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h
- cpp/tensorrt_llm/kernels/xqaDispatcher.cpp
- cpp/tensorrt_llm/kernels/xqaDispatcher.h
- cpp/tensorrt_llm/nanobind/thop/bindings.cpp
- cpp/tensorrt_llm/thop/attentionOp.cpp
- cpp/tensorrt_llm/thop/attentionOp.h
- tensorrt_llm/_torch/attention_backend/sparse/dsa.py
- tensorrt_llm/_torch/attention_backend/sparse/kernel.py
- tensorrt_llm/_torch/attention_backend/trtllm.py
- tensorrt_llm/_torch/attention_backend/trtllm_gen.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py
- tests/unittest/_torch/attention/sparse/test_dsa_indexer.py
- tests/unittest/_torch/attention/sparse/test_sparse_attention.py
2de8183 to 1357f03 (Compare)
b6afa91 to 5c7319f (Compare)
/bot run --disable-fail-fast
PR_Github #42327 [ run ] triggered by Bot. Commit:
5c7319f to f88cfb8 (Compare)
Hi @lfr-0531 @PerkzZheng @QiJune, could you please take another look at this PR? I've rebased it and checked correctness locally, but pushed only some of the upper files for review. Once the review is finished, I'll push the remaining files and run CI. Thanks!
PR_Github #42327 [ run ] completed with state
f88cfb8 to 3c8c16d (Compare)
/bot run --disable-fail-fast
PR_Github #42974 [ run ] triggered by Bot. Commit:
PR_Github #42974 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #43025 [ run ] triggered by Bot. Commit:
/bot run --disable-fail-fast
PR_Github #43066 [ run ] triggered by Bot. Commit:
PR_Github #43798 [ run ] triggered by Bot. Commit:
PR_Github #43798 [ run ] completed with state
/bot run
PR_Github #43853 [ run ] triggered by Bot. Commit:
PR_Github #43853 [ run ] completed with state
/bot run
PR_Github #43939 [ run ] triggered by Bot. Commit:
PR_Github #43939 [ run ] completed with state
/bot run
/bot run --disable-fail-fast
PR_Github #44043 [ run ] triggered by Bot. Commit:
/bot kill
/bot run --disable-fail-fast
PR_Github #44084 [ kill ] triggered by Bot. Commit:
PR_Github #44043 [ run ] completed with state
PR_Github #44084 [ kill ] completed with state
PR_Github #44085 [ run ] triggered by Bot. Commit:
PR_Github #44085 [ run ] completed with state
/bot run
PR_Github #44124 [ run ] triggered by Bot. Commit:
PR_Github #44124 [ run ] completed with state
Upstream commit cf9963f (NVIDIA#12470) renamed the wrapper.plan() sparse topk parameter from sparse_mla_topk to num_sparse_topk while also updating all explicit kwarg call sites. This branch's refactor (04f5012) switched the call site from explicit kwargs to **sparse_params dict unpack, but built the dict with the old key sparse_mla_topk. The mismatch silently routed the topk value into wrapper.plan()'s **kwargs, leaving self.num_sparse_topk at its default 0. The kernel then received populated sparse_attn_indices with topk=0, producing NaN/Inf (Mode A of test_forward_sparse_mla_unified) and ~2.0 numerical divergence (Mode B + test_agrees_with_absorption_path).

Rename SparseParams.sparse_mla_topk to num_sparse_topk to match the Python/C++ API boundary. The DSAtrtllmAttentionMetadata semantic field sparse_mla_topk (sourced from sparse_attention_config.index_topk) is preserved; DSATrtllmAttention.sparse_params() still reads it and writes into params.num_sparse_topk for the plan() call.

Also sync the MockMetadata field name in test_dsa_indexer.py to match DSAtrtllmAttentionMetadata (used by prepare_dense_topk_indices when enable_indexer_skip triggers skip_indexer_for_ctx/gen_reqs).

Fixes test_forward_sparse_mla_unified (20 variants) and test_short_seq_mha.py::test_agrees_with_absorption_path on B200.

Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
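The failure mode described in this commit message, a renamed keyword silently absorbed by **kwargs, reduces to a few lines (a hypothetical minimal model, not the real wrapper):

```python
class Wrapper:
    def plan(self, num_sparse_topk=0, **kwargs):
        # Any unrecognized keyword (e.g. the stale sparse_mla_topk) is
        # silently swallowed by **kwargs instead of raising a TypeError,
        # so the default 0 is kept.
        self.num_sparse_topk = num_sparse_topk
        return self.num_sparse_topk

wrapper = Wrapper()
stale_params = {"sparse_mla_topk": 256}   # old key: ignored, default 0 kept
fixed_params = {"num_sparse_topk": 256}   # renamed key: actually applied
print(wrapper.plan(**stale_params))  # 0
print(wrapper.plan(**fixed_params))  # 256
```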
Summary by CodeRabbit
- Refactor
- Tests
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment /bot help.