[None][fix] DSv4 indexer: resize radix aux scratch in update_spec_dec_param#14443
Open
longcheng-nv wants to merge 3 commits into
Open
Conversation
…_param PR NVIDIA#14297 added persistent radix_aux_{indices,logits} scratch buffers in DSAtrtllmAttentionMetadata.__init__ sized to max_num_sequences * (1 + max_draft_tokens), and added a kernel-side TORCH_CHECK in IndexerTopKOp.cpp that the buffers' numel >= num_rows * blocks_per_row * index_topk. It also patched update_spec_dec_param to resize kv_lens_expanded_host (via create_expanded_buffers) and heuristic_scratch_values when max_draft_tokens changes at runtime, but missed the parallel radix buffers. When the framework reconfigures max_draft_tokens (e.g. spec decoding warmup -> real run, or disagg gen server picking up a different draft length), num_rows starts reflecting the new bound while the radix aux buffers stay at their construction-time size, triggering RuntimeError: radix_aux_{indices,logits} must hold at least num_rows*blocks_per_row*index_topk elements (got 10240 / 10240, need 16384) inside torch.ops.trtllm.indexer_topk_decode on the next forward step. This patch mirrors the existing heuristic_scratch_values resize block for the radix buffers, allocated unconditionally to match the __init__ path (the radix dispatcher can still run when enable_heuristic_topk=True falls back for small numColumns). Made-with: claude-code (https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
98bf3f2 to
f779a46
Compare
Collaborator
Author
|
/bot run --disable-fail-fast |
Five previously-uncovered files added to the single-B200 DS pre-merge list, covering kernel paths exercised by DeepSeek-V4 modeling: - unittest/_torch/thop/parallel/test_indexer_topk.py - unittest/_torch/attention/sparse/dsa/test_dsa_indexer.py - unittest/_torch/attention/sparse/dsa/test_dsa_sparse_mla.py - unittest/_torch/thop/parallel/test_dsv3_fused_a_gemm.py - unittest/_torch/thop/parallel/test_dsv3_router_gemm.py The first three guard the radix_aux scratch + update_spec_dec_param resize path fixed in the preceding commit; the last two cover torch.ops.trtllm.dsv3_fused_a_gemm_op and dsv3_router_gemm_op which modeling_deepseekv4.py invokes directly. All files declare world_size=1 / tp_size=1 and use skip_pre_blackwell or skip_pre_hopper gating, so they fit the single-B200 pre_merge condition. Broader cleanup of the remaining ~25 uncovered files under tests/unittest/_torch/thop/parallel/ (FP4 / FP8 / W4A* generic quant kernels) is tracked as a follow-up. Made-with: claude-code (https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
f779a46 to
975ef28
Compare
Collaborator
|
PR_Github #49873 [ run ] triggered by Bot. Commit: |
Collaborator
Author
|
/bot run --disable-fail-fast |
lfr-0531
approved these changes
May 22, 2026
Collaborator
|
PR_Github #49876 [ run ] triggered by Bot. Commit: |
Collaborator
|
PR_Github #49873 [ run ] completed with state |
…PRoundingMode
Migrate `cute.arch.fma_packed_f32x2(..., rnd=nvvm.RoundingModeKind.RN)` to
`rnd='rn'` (8 callsites in fp8_paged_mqa_logits.py). The cute DSL
FPRoundingMode parameter now accepts only string literals; the enum form
raises:
TypeError: Expected a string literal for FPRoundingMode, but got enum
'RoundingModeKind.RN'. Please pass a string instead (e.g., 'rn' instead
of RoundingModeKind.RN).
on the DSL kernel path used by torch.ops.trtllm.cute_dsl_fp8_paged_mqa_logits.
Drop now-unused `nvvm` from the cutlass._mlir.dialects import.
Verified locally on B200:
pytest tests/unittest/_torch/attention/sparse/dsa/test_dsa_indexer.py \
-k "test_indexer_decode_with_paged_kv_cache and dsl"
-> 8 passed in 11.85s (previously 8 failed with FPRoundingMode TypeError).
Made-with: Claude Code
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
Collaborator
Author
|
/bot run --disable-fail-fast |
Collaborator
|
PR_Github #49910 [ run ] triggered by Bot. Commit: |
Collaborator
|
PR_Github #49876 [ run ] completed with state |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Fix a buffer-sizing regression introduced together with the persistent
radix aux scratch in PR #14297, and backfill missing CI coverage for
DSv4-relevant unit tests in
tests/unittest/_torch/thop/parallel/andtests/unittest/_torch/attention/sparse/dsa/.Root cause
DSAtrtllmAttentionMetadata.__init__allocatesradix_aux_indices/radix_aux_logitssized tomax_num_sequences * (1 + max_draft_tokens) * kMaxBlocksPerRowDecode(=10) * num_sparse_topk,and
cpp/tensorrt_llm/thop/IndexerTopKOp.cppenforcesnumel >= num_rows * blocks_per_row * index_topkat everyindexer_topk_decodecall. The same PR taughtupdate_spec_dec_paramto resizekv_lens_expanded_hostandheuristic_scratch_valueswhenevermax_draft_tokenschanges atruntime — but the parallel radix buffers were not added to that
resize path.
When the framework reconfigures
max_draft_tokenspost-construction(spec decoding warmup → real run, MTP depth change, disagg gen server
picking up a larger draft length, etc.),
num_rowsreflects the newbound while the radix buffers stay at their construction-time size,
producing:
inside
torch.ops.trtllm.indexer_topk_decode. The math matches atypical reconfig:
10240 = max_num_sequences*(1+max_draft_tokens_init) * 10 * num_sparse_topk(
init=1Flash /init=0Pro),16384 = num_rows_new * blocks_per_row * index_topk.Fix
Mirror the existing
heuristic_scratch_valuesresize block forradix_aux_indices/radix_aux_logits. Unlike heuristic scratch theradix buffers are allocated unconditionally (the radix dispatcher can
fire when
enable_heuristic_topk=Truefalls back for smallnumColumns), so the resize is placed outside theif self.enable_heuristic_topk:guard but inside the existingif self.max_num_sequences * (1 + self.max_draft_tokens) != init_shape:gate.
CI coverage backfill
Investigation revealed that this regression slipped past pre-merge CI
because the relevant unit tests were never listed in any
tests/integration/test_lists/test-db/*.yml:tests/unittest/_torch/thop/parallel/test_indexer_topk.py— directlyexercises the persistent radix_aux scratch and CUDA Graph capture/replay
added in [None][fix] DSv4 indexer: stable radix aux scratch for CUDA Graph safety #14297 (its own docstring documents the very bug being fixed).
tests/unittest/_torch/attention/sparse/dsa/test_dsa_indexer.py—contains
TestPrepareRestoreAttnMetadataForDraftReplay, the naturalregression site for the resize path.
tests/unittest/_torch/attention/sparse/dsa/test_dsa_sparse_mla.py—DSA sparse-MLA forward, previously uncovered.
tests/unittest/_torch/thop/parallel/test_dsv3_fused_a_gemm.pyandtest_dsv3_router_gemm.py— the only two ops fromthop/parallel/that
modeling_deepseekv4.pyinvokes directly.All five files added to
l0_b200_ds.yml(single-B200 pre-merge, matchesworld_size=1declarations in the test files andskip_pre_blackwellgating).
Test Coverage
transition: the runtime
TORCH_CHECKno longer fires after the fix.CI going forward (
TestPrepareRestoreAttnMetadataForDraftReplay,test_indexer_topk_decode_radix_aux_cuda_graph_replay, and the_v4_cr4/_distTop-K parametric sweeps).IndexerTopKOp.cpp/indexerTopK.cubehaviour preserved.Notes
heuristic_scratch_values, so no new convention is introduced.tests/unittest/_torch/thop/parallel/(FP4 / FP8 / W4A* genericquant kernels) is tracked as a dedicated follow-up to keep this
fix-PR scope bounded.
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.🤖 Generated with Claude Code