[TRTLLM-10407][perf] Enable CuteDSL indexer_top_k in model#12236

Merged
yuxianq merged 38 commits into NVIDIA:main from limin2021:cute_dsl_single_pass_multi_cta_top_k
Mar 19, 2026

Conversation

limin2021 (Collaborator) commented Mar 16, 2026

Description

  1. Add a single-pass multi-CTA top_k kernel.
  2. Add a config option to enable CuTE DSL top_k in the model.
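The config option added later in this PR (`use_cute_dsl_topk`, with fallback to the CUDA C++ kernel) boils down to a small dispatch decision. A minimal Python sketch of that fallback logic — function and backend names here are illustrative, not the PR's actual code:

```python
def select_topk_backend(use_cute_dsl_topk: bool, cute_dsl_available: bool) -> str:
    """Sketch of the fallback described in this PR: prefer the CuTE DSL
    top-k kernel only when it is both enabled via config and importable;
    otherwise fall back to the existing CUDA C++ kernel.

    Backend names are hypothetical stand-ins for the real dispatch.
    """
    if use_cute_dsl_topk and cute_dsl_available:
        return "cute_dsl_topk"
    return "cuda_cpp_topk"
```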

1. Perf of the newly added single_pass_multi_cta kernel

Note:

  • dist: the newly added single-pass multi-CTA kernel
  • mc_dyn: two-pass multi-CTA kernel with dynamic workload output
  • mc_sta: two-pass multi-CTA kernel with fixed-length output from the first kernel
  • fi: FlashInfer multi-CTA kernel

dtype=float32 batch=1 next_n=1 top_k=2048:

| num_tokens | dist (us) | mc_dyn (us) | mc_sta (us) | fi (us) | fi/dist | fi/mc_dyn | fi/mc_sta |
|-----------:|----------:|------------:|------------:|--------:|--------:|----------:|----------:|
|       8192 |      7.95 |       12.27 |       11.82 |   11.80 |   1.48x |     0.96x |     1.00x |
|      16384 |     12.37 |       15.58 |       15.08 |   16.37 |   1.32x |     1.05x |     1.09x |
|      32768 |     19.88 |       20.73 |       20.36 |   25.09 |   1.26x |     1.21x |     1.23x |
|      65536 |     19.73 |       22.53 |       22.04 |   30.91 |   1.57x |     1.37x |     1.40x |
|     131072 |     19.63 |       26.76 |       26.26 |   34.84 |   1.77x |     1.30x |     1.33x |
|     262144 |     19.64 |       35.97 |       35.40 |   37.90 |   1.93x |     1.05x |     1.07x |

dtype=float32 batch=4 next_n=1 top_k=2048:

| num_tokens | dist (us) | mc_dyn (us) | mc_sta (us) | fi (us) | fi/dist | fi/mc_dyn | fi/mc_sta |
|-----------:|----------:|------------:|------------:|--------:|--------:|----------:|----------:|
|       8192 |      8.00 |       12.38 |       11.93 |   11.83 |   1.48x |     0.96x |     0.99x |
|      16384 |     12.44 |       15.73 |       15.31 |   16.40 |   1.32x |     1.04x |     1.07x |
|      32768 |     20.05 |       21.35 |       21.04 |   25.11 |   1.25x |     1.18x |     1.19x |
|      65536 |     19.94 |       23.36 |       22.83 |   33.37 |   1.67x |     1.43x |     1.46x |
|     131072 |     19.95 |       27.89 |       27.34 |   37.33 |   1.87x |     1.34x |     1.37x |
|     262144 |     20.00 |       37.02 |       36.55 |   40.07 |   2.00x |     1.08x |     1.10x |

dtype=float32 batch=8 next_n=1 top_k=2048:

| num_tokens | dist (us) | mc_dyn (us) | mc_sta (us) | fi (us) | fi/dist | fi/mc_dyn | fi/mc_sta |
|-----------:|----------:|------------:|------------:|--------:|--------:|----------:|----------:|
|       8192 |      7.98 |       12.40 |       11.98 |   11.83 |   1.48x |     0.95x |     0.99x |
|      16384 |     12.44 |       16.35 |       15.99 |   16.41 |   1.32x |     1.00x |     1.03x |
|      32768 |     20.28 |       21.38 |       21.06 |   25.14 |   1.24x |     1.18x |     1.19x |
|      65536 |     20.47 |       23.25 |       22.72 |   33.57 |   1.64x |     1.44x |     1.48x |
|     131072 |     20.25 |       28.15 |       27.57 |   37.45 |   1.85x |     1.33x |     1.36x |
|     262144 |     23.36 |       37.19 |       36.75 |   40.45 |   1.73x |     1.09x |     1.10x |

dtype=float32 batch=16 next_n=1 top_k=2048:

| num_tokens | dist (us) | mc_dyn (us) | mc_sta (us) | fi (us) | fi/dist | fi/mc_dyn | fi/mc_sta |
|-----------:|----------:|------------:|------------:|--------:|--------:|----------:|----------:|
|       8192 |      8.19 |       12.76 |       12.37 |   12.06 |   1.47x |     0.95x |     0.98x |
|      16384 |     12.69 |       16.68 |       16.30 |   16.67 |   1.31x |     1.00x |     1.02x |
|      32768 |     20.64 |       22.40 |       21.96 |   25.35 |   1.23x |     1.13x |     1.15x |
|      65536 |     20.79 |       23.89 |       23.35 |   33.75 |   1.62x |     1.41x |     1.45x |
|     131072 |     23.90 |       28.41 |       27.88 |   37.58 |   1.57x |     1.32x |     1.35x |
|     262144 |     30.23 |       49.08 |       48.47 |   40.75 |   1.35x |     0.83x |     0.84x |

dtype=float32 batch=32 next_n=1 top_k=2048:

| num_tokens | dist (us) | mc_dyn (us) | mc_sta (us) | fi (us) | fi/dist | fi/mc_dyn | fi/mc_sta |
|-----------:|----------:|------------:|------------:|--------:|--------:|----------:|----------:|
|       8192 |      8.23 |       12.88 |       12.47 |   12.09 |   1.47x |     0.94x |     0.97x |
|      16384 |     12.72 |       16.75 |       16.29 |   16.72 |   1.31x |     1.00x |     1.03x |
|      32768 |     20.98 |       22.82 |       22.36 |   25.48 |   1.21x |     1.12x |     1.14x |
|      65536 |     24.49 |       24.76 |       24.22 |   33.84 |   1.38x |     1.37x |     1.40x |
|     131072 |     30.99 |       40.25 |       39.54 |   37.72 |   1.22x |     0.94x |     0.95x |
|     262144 |     58.05 |       72.71 |       71.60 |   78.48 |   1.35x |     1.08x |     1.10x |

dtype=float32 batch=64 next_n=1 top_k=2048:

| num_tokens | dist (us) | mc_dyn (us) | mc_sta (us) | fi (us) | fi/dist | fi/mc_dyn | fi/mc_sta |
|-----------:|----------:|------------:|------------:|--------:|--------:|----------:|----------:|
|       8192 |      8.28 |       12.99 |       12.54 |   12.15 |   1.47x |     0.94x |     0.97x |
|      16384 |     12.78 |       17.12 |       16.72 |   16.79 |   1.31x |     0.98x |     1.00x |
|      32768 |     20.18 |       23.11 |       22.61 |   25.58 |   1.27x |     1.11x |     1.13x |
|      65536 |     31.96 |       33.09 |       32.63 |   34.19 |   1.07x |     1.03x |     1.05x |
|     131072 |     59.42 |       63.59 |       62.46 |   72.53 |   1.22x |     1.14x |     1.16x |
|     262144 |    113.94 |      110.46 |      108.60 |  116.13 |   1.02x |     1.05x |     1.07x |

dtype=float32 batch=128 next_n=1 top_k=2048:

| num_tokens | dist (us) | mc_dyn (us) | mc_sta (us) | fi (us) | fi/dist | fi/mc_dyn | fi/mc_sta |
|-----------:|----------:|------------:|------------:|--------:|--------:|----------:|----------:|
|       8192 |      8.42 |       13.87 |       13.43 |   12.29 |   1.46x |     0.89x |     0.92x |
|      16384 |     12.91 |       17.35 |       16.94 |   16.99 |   1.32x |     0.98x |     1.00x |
|      32768 |     20.37 |       31.52 |       30.96 |   25.80 |   1.27x |     0.82x |     0.83x |
|      65536 |     61.64 |       51.37 |       50.81 |   65.66 |   1.07x |     1.28x |     1.29x |
|     131072 |    103.79 |      101.03 |       99.19 |  108.08 |   1.04x |     1.07x |     1.09x |
|     262144 |    195.64 |      203.19 |      199.99 |  208.18 |   1.06x |     1.02x |     1.04x |

dtype=float32 batch=256 next_n=1 top_k=2048:

| num_tokens | dist (us) | mc_dyn (us) | mc_sta (us) | fi (us) | fi/dist | fi/mc_dyn | fi/mc_sta |
|-----------:|----------:|------------:|------------:|--------:|--------:|----------:|----------:|
|       8192 |     15.29 |       18.21 |       17.80 |   22.75 |   1.49x |     1.25x |     1.28x |
|      16384 |     24.20 |       25.84 |       25.42 |   31.95 |   1.32x |     1.24x |     1.26x |
|      32768 |     38.95 |       50.93 |       50.43 |   49.42 |   1.27x |     0.97x |     0.98x |
|      65536 |    119.26 |       79.23 |       78.63 |  127.70 |   1.07x |     1.61x |     1.62x |
|     131072 |    210.83 |      148.73 |      148.38 |  225.89 |   1.07x |     1.52x |     1.52x |
|     262144 |    350.05 |      286.21 |      285.58 |  372.78 |   1.06x |     1.30x |     1.31x |

2. CuTE DSL top_k model accuracy check

Passes the trtllm-eval check on the MMLU and GSM8K datasets.

[03/16/2026-13:19:43] [TRT-LLM] [I] MMLU weighted average accuracy: 87.67 (14042)

[03/16/2026-13:33:06] [TRT-LLM] [I] lm-eval gsm8k results (scores normalized to range 0~100):
|Tasks|Version|     Filter     |n-shot|  Metric   |   | Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|------:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |95.1478|±  |0.5918|
|     |       |strict-match    |     5|exact_match|↑  |95.1478|±  |0.5918|

[03/16/2026-13:33:06] [TRT-LLM] [I] lm-eval gsm8k average accuracy: 95.15

3. Model perf with CuTE DSL top_k

max_batch_size=32
max_num_tokens=8192            # chunked prefill: each chunk <= this size
# max_seq_len=32768              # 130K context window (KV cache capacity)
isl=32768
osl=1024
max_seq_len=isl+osl
max_num_requests=256
TP=8
EP=8

# python benchmarks/cpp/prepare_dataset.py \
#     --tokenizer=$model_card \
#     --stdout token-norm-dist \
#     --num-requests=$max_num_requests \
#     --input-mean=$isl \
#     --output-mean=$osl \
#     --input-stdev=0 \
#     --output-stdev=0 > ${DATASET}
[image: model perf results with CuTE DSL top_k]

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

limin2021 and others added 26 commits March 4, 2026 00:07
Add a CuTE DSL implementation of the filtered top-k kernel, optimized for
the Blackwell (SM100) architecture. The kernel uses a radix-based filtering
algorithm for efficient top-k selection.

Changes:
- Add CuTE DSL kernel implementation in blackwell/top_k/:
  - filtered_top_k_decode_varlen.py: Main decode kernel
  - filtered_top_k_varlen_util.py: Base classes and utilities
  - block_scan.py: Block-level prefix sum operations
  - __init__.py: Module exports
- Add torch custom op wrapper in cute_dsl_custom_ops.py:
  - CuteDSLTopKDecodeSingleCTARunner class with compilation caching
  - torch.library.custom_op registration for trtllm::cute_dsl_topk_decode_blackwell
- Add unit tests in test_indexer_topk.py:
  - test_cute_dsl_topk_decode with parameterized configurations
  - Test coverage for multiple batch sizes, dtypes, and sequence lengths

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Move FilteredTopKKernelVarlenDecode import to the main import section
with other Blackwell kernels, and remove redundant imports.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Clean up CuTE DSL top-k kernel code for production readiness:
- Remove debug print statements that would spam logs
- Fix assert statement style (use 'assert not' instead of '== False')
- Update copyright year to 2026

This addresses P0 items from code review.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…top-k

Add robust input validation to prevent runtime errors with clear error messages:
- Validate top_k range [1, 2048]
- Validate next_n > 0
- Validate num_copy_bits in {128, 256}
- Validate tensor dimensions and shapes
- Validate supported dtypes {fp16, bf16, fp32}
- Add helpful error messages for unsupported configurations

This addresses P0-3 from code review.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…runner

Add detailed documentation for CuteDSLTopKDecodeSingleCTARunner class and
forward method including:
- Class overview and purpose
- Compilation caching mechanism explanation
- Detailed parameter descriptions with constraints
- Return value specifications
- Error conditions and exception types
- Usage notes and performance considerations
- Architecture requirements and limitations

This addresses P0-4 from code review.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Add explanatory comments for hardcoded values:
- large_occupancy threshold (148 = number of SMs in Blackwell GPU)
- SMEM size thresholds (tuned values for occupancy optimization)
- filtered_topk_max_k limit (2048, can be extended)

Add TODOs for future improvements:
- Query SM count via API instead of hardcoding
- Further tuning of SMEM thresholds
- Support for top_k > 2048

This addresses P0-5 from code review and completes all P0 items.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
- Fix NameError: torch_dtype not defined in CuTE DSL topk custom op
- Add performance benchmark function comparing CuTE DSL vs standard indexer_topk_decode
- Benchmark includes CUDA event timing, warmup, throughput calculation, and correctness verification
- Support CLI interface with parameter sweep option

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…OOB bug

Add CuTE DSL multi-CTA top-k decode kernel that splits each row into
chunks processed by separate CTAs, then merges partial results. This
enables top-k on large vocabularies (32K+) that exceed single-CTA capacity.

Fix a bug in the merge kernel where large_occupancy forced
num_threads_per_cta=512, making tile width (8192) exceed the actual data
width (num_ctas_per_row * top_k). The OOB padding elements from _fill_oob
were counted in the radix histogram and could be selected as top-k
candidates with invalid indices. Cap num_threads_per_cta for merge_blocks
so tile width never exceeds max_num_cols.

Also fix trivial-case index lookup for merge_blocks in varlen_util, add
debug parameter, and update benchmark units from ms to us.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
… multi-CTA switching

Fix num_threads_per_cta overflow (2048 > CUDA 1024 limit) when
large_occupancy=True + merge_blocks=True by capping to 512.

Add cute_dsl_indexer_topk_decode unified op that auto-selects
single-CTA or multi-CTA based on dual-condition: SM utilization
< 15% AND multi-CTA waves <= 2, preventing performance regressions
at high batch sizes while maintaining speedup for low-occupancy cases.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
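The dual-condition auto-select above can be sketched in a few lines. This is a minimal illustration, not the PR's code; the function name, the exact utilization formula, and treating single-CTA as one CTA per row are assumptions (and a later commit in this PR retunes the threshold to 25%):

```python
import math

def select_cta_mode(num_rows: int, num_ctas_per_row: int, num_sms: int,
                    util_threshold: float = 0.15, max_waves: int = 2) -> str:
    """Sketch of the dual-condition heuristic: the single-CTA launch uses
    one CTA per row, so its SM utilization is num_rows / num_sms.  Switch
    to multi-CTA only when that utilization is low AND the multi-CTA grid
    still fits in at most `max_waves` waves over the SMs, so high batch
    sizes keep the single-CTA path and avoid regressions."""
    sm_utilization = num_rows / num_sms
    multi_cta_waves = math.ceil(num_rows * num_ctas_per_row / num_sms)
    if sm_utilization < util_threshold and multi_cta_waves <= max_waves:
        return "multi_cta"
    return "single_cta"
```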
…pk_decode test

Extract _run_cute_dsl_topk_test helper to deduplicate shared logic across
CuTE DSL top-k tests, and add test_cute_dsl_indexer_topk_decode covering
both single-CTA and multi-CTA paths via the unified op.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
- Replace hardcoded SM count (148) with torch.cuda.get_device_properties().multi_processor_count
- Fix typo: enable_persistent_dymamic_scheduling -> enable_persistent_dynamic_scheduling
- Remove unused .shape[1] expressions in compare_top_k_results
- Remove debug print statements in run_filtered_topk_decode
- Add OOM warnings for large intermediate buffer allocations (>1GB)
- Clean up __init__.py exports to only expose library API
- Move argparse import into __main__ block
- Fix duplicated "Row {row_idx}" in error message
- Fix copy-paste argparse description from GEMM example

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…pilation caching

In the DSA path for DeepSeek v3, num_cols (KV cache length) changes per
request and grows each decode step, causing a recompilation for every new
cache length. Bucketing num_cols to the next power of 2 dramatically
reduces recompilations while being safe because num_cols only controls
compile-time config (num_threads, smem sizes, index_type) and actual
data access is bounded by seq_lens.

For multi-CTA kernels, num_ctas_per_row is also included in the cache
key since different actual num_cols values can produce different CTA
counts while bucketing to the same power of 2.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
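The bucketing scheme above can be illustrated with a short sketch. The rounding helper mirrors the semantics of the `next_positive_power_of_2` utility the PR uses; the cache-key layout shown here is an assumption for illustration:

```python
def next_positive_power_of_2(n: int) -> int:
    # Round n up to the next power of two (exact powers map to themselves).
    assert n > 0
    return 1 << (n - 1).bit_length()

def compile_cache_key(num_cols: int, num_ctas_per_row: int) -> tuple:
    """Sketch of the bucketed compilation-cache key described above
    (field layout is an assumption).  Growing KV-cache lengths that bucket
    to the same power of two reuse one compiled kernel, while the CTA
    count -- derived from the *actual* num_cols -- still disambiguates
    multi-CTA configs that would otherwise collide."""
    return (next_positive_power_of_2(num_cols), num_ctas_per_row)
```

Bucketing is safe here because, per the commit message, `num_cols` only drives compile-time config (thread counts, smem sizes, index type) while actual data access is bounded by `seq_lens`.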
The three top-k call sites queried
torch.cuda.get_device_properties().multi_processor_count on every
forward call. Extract a module-level _get_num_sms() that caches the
value on first call, avoiding repeated device property queries in the
hot decode path.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…r CuTE DSL top-k

Refactor compilation logic out of forward() into classmethod _compile()
for both SingleCTA and MultiCTA runners. Add warmup_cute_dsl_topk_kernels()
that pre-compiles all (bucketed_num_cols, large_occupancy) combinations,
eliminating first-request compilation latency.

The warmup function is ready to be called from DSA initialization once
the CuTE DSL top-k kernel is integrated into the DSA path.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…FFI env stream

- Enable TVM FFI env stream (use_tvm_ffi_env_stream + --enable-tvm-ffi) to
  eliminate per-call CUDA stream acquisition overhead
- Hoist dtype mapping dict to module level to avoid per-call recreation
- Use semantic sym_int names (n_rows/n_cols/n_batch) for fake tensors
- Have indexer_topk_decode call Runner.forward() directly, bypassing the
  inner custom op dispatch layer
- Move warmup constants inside warmup function

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…ove _bucket_num_cols wrapper

- Convert CuteDSLTopKDecodeSingleCTARunner and CuteDSLTopKDecodeMultiCTARunner
  to use @classmethod for forward(), removing empty __init__ and unnecessary
  object instantiation per call.
- Remove _bucket_num_cols wrapper, use next_positive_power_of_2 directly.
- Remove undefined skip_if_no_cute_dsl/skip_if_not_blackwell calls from test
  helper; use pytest.mark.skipif decorators on test functions instead.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…p-k varlen

Add dynamic multi-CTA scheduling mode that assigns CTAs proportionally
to each row's actual sequence length, avoiding wasted work on short rows.

Key changes:
- ComputeDynamicCTAOffsets: CuTE DSL kernel computing row_cta_offsets and
  row_output_offsets via parallel CTA counting + sequential prefix sum
- FilteredTopKKernelVarlenDecode: 1D grid with binary search task mapping
  when enable_dynamic_multi_cta=True; varlen merge input support
- CuteDSLTopKDecodeMultiCTARunner: dynamic=True/False mode with upper-bound
  grid allocation to avoid DtoH sync
- Unit test updated for dynamic mode coverage

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…c multi-CTA top-k

Eliminate the separate ComputeDynamicCTAOffsets kernel by computing the
prefix sum of per-row CTA counts directly in shared memory within the
main filtered_topk_kernel. This reduces the dynamic multi-CTA path from
3 CUDA kernel launches to 2.

Key changes:
- filtered_topk_kernel: add in-kernel block_prefix_sum to build
  s_row_cta_offsets in shared memory, replacing global memory lookups
- Remove row_cta_offsets and row_output_offsets from kernel signatures
- Remove ComputeDynamicCTAOffsets kernel compilation and invocation
- Add dynamic parameter to cute_dsl_indexer_topk_decode for easy
  static/dynamic comparison

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…id for dynamic multi-CTA top-k

Switch dynamic multi-CTA from 1D grid + shared-memory prefix sum + binary
search to a simple 2D grid (num_rows, num_ctas_per_row) with per-CTA early
exit. This eliminates ~80 lines of prefix-sum/binary-search logic, removes
the 512-row limit, reduces shared memory usage, and simplifies the host side
by reusing the merge buffer. Also updates auto-select heuristics (bf16/fp16
threshold to 131072, SM utilization check to <25%) and makes dynamic=True
the default.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
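The 2D-grid scheduling above is easy to picture with a host-side sketch of which CTAs actually do work (helper name and signature are assumptions, not the kernel's code):

```python
import math

def active_ctas(seq_lens, chunk_size: int, num_ctas_per_row: int):
    """Sketch of the 2D-grid early-exit scheme described above.  The grid
    is (num_rows, num_ctas_per_row); CTA (row, c) does work only if chunk
    c overlaps that row's actual sequence length, and returns immediately
    otherwise -- no prefix sum or binary search needed."""
    active = []
    for row, seq_len in enumerate(seq_lens):
        ctas_needed = max(1, math.ceil(seq_len / chunk_size))
        for c in range(num_ctas_per_row):
            if c < ctas_needed:          # all other CTAs exit early
                active.append((row, c))
    return active
```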
…with first-task direct assignment

Improve the persistent dynamic scheduling by assigning the first task
directly via block index (avoiding an unnecessary atomic), and expose
load_balance parameter through the custom op API with test coverage.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…k kernel

Add FlashInfer-style fused multi-CTA distributed radix top-k kernel that
cooperatively finds the global pivot via multi-round radix select with
global histogram merging. Single kernel launch, no intermediate buffer,
no merge kernel.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…egion reordering in load_chunk_to_smem

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…nd SM-aware chunk heuristic for CuTE DSL top-k

- Add _get_or_alloc_buffer to SingleCTA/MultiCTA/Distributed runners with
  grow-only strategy for CUDA Graph capture/replay compatibility
- Change cute_dsl_indexer_topk_decode to in-place API (mutates output_indices)
  to eliminate per-call allocation overhead
- Add SM-aware chunk_size heuristic in Distributed runner that balances
  CTA parallelism vs reduce overhead based on batch size and SM count
- Rewrite warmup to enumerate all SingleCTA/MultiCTA/Distributed compile
  configs without requiring runtime max_seq_len
- Add use_cute_dsl_topk config field with fallback to CUDA C++ kernel

Signed-off-by: Li Min <limin@nvidia.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
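The grow-only buffer strategy mentioned above can be sketched as follows. Class and method names are assumptions, and the integer "capacity" stands in for a real tensor allocation; the point is that reusing an already-large-enough allocation keeps buffer addresses stable, which CUDA Graph capture/replay requires:

```python
class GrowOnlyBuffer:
    """Sketch of a grow-only scratch buffer: only (re)allocate when the
    requested size exceeds the current capacity, so repeated calls with
    sizes at or below the high-water mark reuse the same storage."""

    def __init__(self):
        self.capacity = 0
        self.num_allocs = 0  # counts real (re)allocations, for the demo

    def get_or_alloc(self, size: int) -> int:
        if size > self.capacity:
            self.capacity = size   # a real impl would allocate a new tensor
            self.num_allocs += 1
        return self.capacity       # a real impl would return the buffer
```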
@limin2021 limin2021 requested review from a team as code owners March 16, 2026 05:29
…ig in DSA indexer tests

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
limin2021 (Collaborator, Author):

/bot run

tensorrt-cicd (Collaborator):

PR_Github #39415 [ run ] triggered by Bot. Commit: 85e07c8

- Remove duplicate logger.info in Indexer (already logged in attention.py)
- Pass self.index_topk explicitly to indexer_topk_decode
- Use cls._compile(*key) to avoid duplicating compile args and key tuple
- Simplify output_indices allocation with or-expression
- Condense buffer size comments to max values only
- Add comment explaining smem overhead calculation

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Fix output_indices truthiness check, guard use_cute_dsl_topk on
CUTLASS DSL availability, pass correct dtype for warmup, query smem
capacity dynamically, and only warmup the active multi-CTA variant.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
@limin2021 limin2021 force-pushed the cute_dsl_single_pass_multi_cta_top_k branch from e677450 to 225223b Compare March 18, 2026 11:42
limin2021 (Collaborator, Author):

/bot run

tensorrt-cicd (Collaborator):

PR_Github #39452 [ run ] triggered by Bot. Commit: 225223b

@tensorrt-cicd
Copy link
Collaborator

PR_Github #39452 [ run ] completed with state SUCCESS. Commit: 225223b
/LLM/main/L0_MergeRequest_PR pipeline #30678 completed with status: 'SUCCESS'


@yuxianq yuxianq merged commit e940e58 into NVIDIA:main Mar 19, 2026
5 checks passed
limin2021 added a commit to limin2021/TensorRT-LLM that referenced this pull request Mar 19, 2026