[TRTLLM-10407][perf] Enable CuteDSL indexer_top_k in model#12236
Merged
yuxianq merged 38 commits into NVIDIA:main, Mar 19, 2026
Conversation
Add a CuTE DSL implementation of a filtered top-k kernel optimized for the Blackwell (SM100) architecture. The kernel uses a radix-based filtering algorithm for efficient top-k selection.
Changes:
- Add CuTE DSL kernel implementation in blackwell/top_k/:
  - filtered_top_k_decode_varlen.py: main decode kernel
  - filtered_top_k_varlen_util.py: base classes and utilities
  - block_scan.py: block-level prefix-sum operations
  - __init__.py: module exports
- Add torch custom op wrapper in cute_dsl_custom_ops.py:
  - CuteDSLTopKDecodeSingleCTARunner class with compilation caching
  - torch.library.custom_op registration for trtllm::cute_dsl_topk_decode_blackwell
- Add unit tests in test_indexer_topk.py:
  - test_cute_dsl_topk_decode with parameterized configurations
  - Test coverage for multiple batch sizes, dtypes, and sequence lengths
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
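The radix-based filtering idea behind the kernel can be sketched in plain Python (this is an illustrative sketch over non-negative 32-bit integer keys, not the production CuTE DSL kernel; float keys would first need an order-preserving bit transform). The pivot's digits are fixed one radix pass at a time, from most to least significant, by histogramming candidates and walking buckets from the largest digit down:

```python
def radix_topk_threshold(values, k, total_bits=32, digit_bits=8):
    """Return a pivot p such that at least k values are >= p, by fixing
    the pivot's digits from most- to least-significant."""
    prefix = 0          # digits of the pivot fixed so far
    remaining = k       # how many of the top-k still fall under the prefix
    mask = (1 << digit_bits) - 1
    for shift in range(total_bits - digit_bits, -1, -digit_bits):
        hist = [0] * (1 << digit_bits)
        for v in values:
            # only values whose higher digits match the fixed prefix
            if (v >> (shift + digit_bits)) == prefix:
                hist[(v >> shift) & mask] += 1
        # walk digit buckets from largest to smallest
        for d in range(mask, -1, -1):
            if hist[d] < remaining:
                remaining -= hist[d]
            else:
                prefix = (prefix << digit_bits) | d
                break
    return prefix

def radix_topk(values, k):
    """Filtered top-k: one pivot search, then a gather pass."""
    pivot = radix_topk_threshold(values, k)
    above = [(v, i) for i, v in enumerate(values) if v > pivot]
    ties = [(v, i) for i, v in enumerate(values) if v == pivot]
    return above + ties[:k - len(above)]  # keep only enough ties to reach k
```

The GPU kernel parallelizes the histogram and gather phases across threads; the sketch only shows the selection logic.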
Move FilteredTopKKernelVarlenDecode import to the main import section with other Blackwell kernels, and remove redundant imports. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com> Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Clean up CuTE DSL top-k kernel code for production readiness: - Remove debug print statements that would spam logs - Fix assert statement style (use 'assert not' instead of '== False') - Update copyright year to 2026 This addresses P0 items from code review. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com> Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…top-k
Add robust input validation to prevent runtime errors with clear error messages:
- Validate top_k range [1, 2048]
- Validate next_n > 0
- Validate num_copy_bits in {128, 256}
- Validate tensor dimensions and shapes
- Validate supported dtypes {fp16, bf16, fp32}
- Add helpful error messages for unsupported configurations
This addresses P0-3 from code review.
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…runner Add detailed documentation for CuteDSLTopKDecodeSingleCTARunner class and forward method including: - Class overview and purpose - Compilation caching mechanism explanation - Detailed parameter descriptions with constraints - Return value specifications - Error conditions and exception types - Usage notes and performance considerations - Architecture requirements and limitations This addresses P0-4 from code review. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com> Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Add explanatory comments for hardcoded values: - large_occupancy threshold (148 = number of SMs in Blackwell GPU) - SMEM size thresholds (tuned values for occupancy optimization) - filtered_topk_max_k limit (2048, can be extended) Add TODOs for future improvements: - Query SM count via API instead of hardcoding - Further tuning of SMEM thresholds - Support for top_k > 2048 This addresses P0-5 from code review and completes all P0 items. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com> Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
- Fix NameError: torch_dtype not defined in CuTE DSL topk custom op - Add performance benchmark function comparing CuTE DSL vs standard indexer_topk_decode - Benchmark includes CUDA event timing, warmup, throughput calculation, and correctness verification - Support CLI interface with parameter sweep option Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com> Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…OOB bug Add CuTE DSL multi-CTA top-k decode kernel that splits each row into chunks processed by separate CTAs, then merges partial results. This enables top-k on large vocabularies (32K+) that exceed single-CTA capacity. Fix a bug in the merge kernel where large_occupancy forced num_threads_per_cta=512, making tile width (8192) exceed the actual data width (num_ctas_per_row * top_k). The OOB padding elements from _fill_oob were counted in the radix histogram and could be selected as top-k candidates with invalid indices. Cap num_threads_per_cta for merge_blocks so tile width never exceeds max_num_cols. Also fix trivial-case index lookup for merge_blocks in varlen_util, add debug parameter, and update benchmark units from ms to us. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
… multi-CTA switching Fix num_threads_per_cta overflow (2048 > CUDA 1024 limit) when large_occupancy=True + merge_blocks=True by capping to 512. Add cute_dsl_indexer_topk_decode unified op that auto-selects single-CTA or multi-CTA based on dual-condition: SM utilization < 15% AND multi-CTA waves <= 2, preventing performance regressions at high batch sizes while maintaining speedup for low-occupancy cases. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
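The dual-condition auto-selection described above can be sketched as follows (hypothetical function name and threshold defaults taken from the commit text: utilization < 15%, waves <= 2):

```python
def select_topk_kernel(num_rows, num_ctas_per_row, num_sms,
                       util_threshold=0.15, max_waves=2):
    # Single-CTA mode launches one CTA per row; switch to multi-CTA only
    # when that would leave most SMs idle (utilization below threshold)
    # AND the larger multi-CTA grid still finishes in a couple of waves.
    sm_utilization = num_rows / num_sms
    multi_cta_waves = -(-(num_rows * num_ctas_per_row) // num_sms)  # ceil div
    if sm_utilization < util_threshold and multi_cta_waves <= max_waves:
        return "multi_cta"
    return "single_cta"
```

Requiring both conditions is what prevents the regression at high batch sizes: a large batch already saturates the SMs in single-CTA mode, so splitting rows further would only add merge overhead.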
…pk_decode test Extract _run_cute_dsl_topk_test helper to deduplicate shared logic across CuTE DSL top-k tests, and add test_cute_dsl_indexer_topk_decode covering both single-CTA and multi-CTA paths via the unified op. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
- Replace hardcoded SM count (148) with torch.cuda.get_device_properties().multi_processor_count
- Fix typo: enable_persistent_dymamic_scheduling -> enable_persistent_dynamic_scheduling
- Remove unused .shape[1] expressions in compare_top_k_results
- Remove debug print statements in run_filtered_topk_decode
- Add OOM warnings for large intermediate buffer allocations (>1GB)
- Clean up __init__.py exports to only expose library API
- Move argparse import into __main__ block
- Fix duplicated "Row {row_idx}" in error message
- Fix copy-paste argparse description from GEMM example
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…pilation caching In the DSA path for DeepSeek v3, num_cols (KV cache length) changes per request and grows each decode step, causing a recompilation for every new cache length. Bucketing num_cols to the next power of 2 dramatically reduces recompilations while being safe because num_cols only controls compile-time config (num_threads, smem sizes, index_type) and actual data access is bounded by seq_lens. For multi-CTA kernels, num_ctas_per_row is also included in the cache key since different actual num_cols values can produce different CTA counts while bucketing to the same power of 2. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
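The bucketing can be sketched as below (`topk_cache_key` is a hypothetical name; `next_positive_power_of_2` mirrors the TRT-LLM helper referenced later in this PR):

```python
def next_positive_power_of_2(n: int) -> int:
    # smallest power of two >= max(n, 1)
    return 1 << (max(n, 1) - 1).bit_length()

def topk_cache_key(num_cols: int, num_ctas_per_row: int):
    # Bucket the KV length to the next power of 2 so the compile cache
    # only misses at power-of-2 boundaries as the cache grows each step;
    # num_ctas_per_row stays in the key because two lengths that bucket
    # to the same power of 2 can still require different CTA counts.
    return (next_positive_power_of_2(num_cols), num_ctas_per_row)
```

With this scheme, KV lengths 4097 through 8192 all compile once, instead of once per length.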
The three top-k call sites queried torch.cuda.get_device_properties().multi_processor_count on every forward call. Extract a module-level _get_num_sms() that caches the value on first call, avoiding repeated device property queries in the hot decode path. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…r CuTE DSL top-k Refactor compilation logic out of forward() into classmethod _compile() for both SingleCTA and MultiCTA runners. Add warmup_cute_dsl_topk_kernels() that pre-compiles all (bucketed_num_cols, large_occupancy) combinations, eliminating first-request compilation latency. The warmup function is ready to be called from DSA initialization once the CuTE DSL top-k kernel is integrated into the DSA path. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…FFI env stream - Enable TVM FFI env stream (use_tvm_ffi_env_stream + --enable-tvm-ffi) to eliminate per-call CUDA stream acquisition overhead - Hoist dtype mapping dict to module level to avoid per-call recreation - Use semantic sym_int names (n_rows/n_cols/n_batch) for fake tensors - Have indexer_topk_decode call Runner.forward() directly, bypassing the inner custom op dispatch layer - Move warmup constants inside warmup function Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…ove _bucket_num_cols wrapper - Convert CuteDSLTopKDecodeSingleCTARunner and CuteDSLTopKDecodeMultiCTARunner to use @classmethod for forward(), removing empty __init__ and unnecessary object instantiation per call. - Remove _bucket_num_cols wrapper, use next_positive_power_of_2 directly. - Remove undefined skip_if_no_cute_dsl/skip_if_not_blackwell calls from test helper; use pytest.mark.skipif decorators on test functions instead. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…p-k varlen Add dynamic multi-CTA scheduling mode that assigns CTAs proportionally to each row's actual sequence length, avoiding wasted work on short rows. Key changes: - ComputeDynamicCTAOffsets: CuTE DSL kernel computing row_cta_offsets and row_output_offsets via parallel CTA counting + sequential prefix sum - FilteredTopKKernelVarlenDecode: 1D grid with binary search task mapping when enable_dynamic_multi_cta=True; varlen merge input support - CuteDSLTopKDecodeMultiCTARunner: dynamic=True/False mode with upper-bound grid allocation to avoid DtoH sync - Unit test updated for dynamic mode coverage Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…c multi-CTA top-k Eliminate the separate ComputeDynamicCTAOffsets kernel by computing the prefix sum of per-row CTA counts directly in shared memory within the main filtered_topk_kernel. This reduces the dynamic multi-CTA path from 3 CUDA kernel launches to 2. Key changes: - filtered_topk_kernel: add in-kernel block_prefix_sum to build s_row_cta_offsets in shared memory, replacing global memory lookups - Remove row_cta_offsets and row_output_offsets from kernel signatures - Remove ComputeDynamicCTAOffsets kernel compilation and invocation - Add dynamic parameter to cute_dsl_indexer_topk_decode for easy static/dynamic comparison Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…id for dynamic multi-CTA top-k Switch dynamic multi-CTA from 1D grid + shared-memory prefix sum + binary search to a simple 2D grid (num_rows, num_ctas_per_row) with per-CTA early exit. This eliminates ~80 lines of prefix-sum/binary-search logic, removes the 512-row limit, reduces shared memory usage, and simplifies the host side by reusing the merge buffer. Also updates auto-select heuristics (bf16/fp16 threshold to 131072, SM utilization check to <25%) and makes dynamic=True the default. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…with first-task direct assignment Improve the persistent dynamic scheduling by assigning the first task directly via block index (avoiding an unnecessary atomic), and expose load_balance parameter through the custom op API with test coverage. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…k kernel Add FlashInfer-style fused multi-CTA distributed radix top-k kernel that cooperatively finds the global pivot via multi-round radix select with global histogram merging. Single kernel launch, no intermediate buffer, no merge kernel. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…egion reordering in load_chunk_to_smem Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…nd SM-aware chunk heuristic for CuTE DSL top-k - Add _get_or_alloc_buffer to SingleCTA/MultiCTA/Distributed runners with grow-only strategy for CUDA Graph capture/replay compatibility - Change cute_dsl_indexer_topk_decode to in-place API (mutates output_indices) to eliminate per-call allocation overhead - Add SM-aware chunk_size heuristic in Distributed runner that balances CTA parallelism vs reduce overhead based on batch size and SM count - Rewrite warmup to enumerate all SingleCTA/MultiCTA/Distributed compile configs without requiring runtime max_seq_len - Add use_cute_dsl_topk config field with fallback to CUDA C++ kernel Signed-off-by: Li Min <limin@nvidia.com> Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
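The grow-only buffer strategy can be sketched like this (hypothetical class and method names; a `bytearray` stands in for a CUDA tensor):

```python
class GrowOnlyBufferCache:
    """Grow-only scratch-buffer cache. A request that fits reuses the
    existing buffer, so repeated calls (and CUDA Graph replays) see the
    same allocation; reallocation happens only when a strictly larger
    size is requested, and the buffer never shrinks."""

    def __init__(self):
        self._buffers = {}  # name -> bytearray (stand-in for a CUDA tensor)

    def get_or_alloc(self, name: str, nbytes: int) -> bytearray:
        buf = self._buffers.get(name)
        if buf is None or len(buf) < nbytes:
            buf = bytearray(nbytes)  # grow (or first allocation)
            self._buffers[name] = buf
        return buf
```

Callers must tolerate receiving a buffer larger than requested (e.g. by slicing to the needed size); that is the price of keeping addresses stable across CUDA Graph capture and replay.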
…ig in DSA indexer tests Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
yuxianq
reviewed
Mar 18, 2026
- Remove duplicate logger.info in Indexer (already logged in attention.py) - Pass self.index_topk explicitly to indexer_topk_decode - Use cls._compile(*key) to avoid duplicating compile args and key tuple - Simplify output_indices allocation with or-expression - Condense buffer size comments to max values only - Add comment explaining smem overhead calculation Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Fix output_indices truthiness check, guard use_cute_dsl_topk on CUTLASS DSL availability, pass correct dtype for warmup, query smem capacity dynamically, and only warmup the active multi-CTA variant. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
yuxianq
reviewed
Mar 19, 2026
yuxianq
approved these changes
Mar 19, 2026
Description
1. Perf for the newly added single_pass_multi_cta kernel.
Note:
dist: the newly added single-pass multi-CTA kernel
mc_dyn: two-pass multi-CTA kernel with dynamically sized output from the first kernel
mc_sta: two-pass multi-CTA kernel with fixed-length output from the first kernel
fi: FlashInfer multi-CTA kernel
Each benchmark sweeps dtype=float32, next_n=1, top_k=2048 over batch sizes 1, 4, 8, 16, 32, 64, 128, and 256; each table reports num_tokens, dist(us), mc_dyn(us), mc_sta(us), fi(us), and the ratios fi/dist, fi/mcd, fi/mcs. (The measured rows were not captured here.)
2. CuTE DSL top_k model accuracy check:
Passes the trtllm-eval check on the MMLU/GSM8K datasets.
[03/16/2026-13:19:43] [TRT-LLM] [I] MMLU weighted average accuracy: 87.67 (14042)
3. Model perf with CuTE DSL top_k
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.