Skip to content

[TRTLLM-35882][feat] Add cute dsl gvr top-k decode kernel#14602

Open
limin2021 wants to merge 26 commits into
NVIDIA:mainfrom
limin2021:cute-dsl-gvr-topk
Open

[TRTLLM-35882][feat] Add cute dsl gvr top-k decode kernel#14602
limin2021 wants to merge 26 commits into
NVIDIA:mainfrom
limin2021:cute-dsl-gvr-topk

Conversation

@limin2021
Copy link
Copy Markdown
Collaborator

@limin2021 limin2021 commented May 27, 2026

Review Change Stack

Description

Add cute dsl gvr top-k decode kernel.
(1) port cuda gvr kernel to cute dsl. Thx for Long's help, who gives the 1st version porting code.
(2) add some extra optimizations, e.g., unroll, 256bits vectorization, num_threads tune, to further improve the perf.

Test Coverage

# ut running in CI, used for production.
python -m pytest tests/unittest/_torch/attention/sparse/test_cute_dsl_gvr_topk_decode.py
# standard alone ut, which is not dependent on trtllm env. It don't run in CI.
python tests/scripts/cute_dsl_kernels/top_k/run_gvr_topk.py
python -m pytest tests/scripts/cute_dsl_kernels/top_k/run_gvr_topk.py

Performance

image

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • If PR introduces API changes, an appropriate PR label is added - either api-compatible or api-breaking. For api-breaking, include BREAKING in the PR title.

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

limin2021 added 11 commits May 20, 2026 08:29
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…invariant

- Replace from_dlpack/static-shape compile with make_fake_compact_tensor +
  sym_int for batch/num_tokens dims, keeping (dtype, top_k, next_n) as the
  cache key. Reduces unique compile entries from 810 to 27 across the bench
  sweep; correctness verified (no OOB writes from cache reuse with wrong
  shape) via 288-config pytest + cross-impl A/B match.

- Fix test_gvr_topk_decode: (1) pre_idx_count now uses top_k (matches CUDA
  dispatch precondition preIdxCount == topK at heuristic_topk.cuh:810);
  (2) tie-aware reference now masks logits to per-row effective_len
  = seq_len - next_n + 1, avoiding false negatives when next_n > 1 makes
  the kernel skip the last next_n-1 columns.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…imit / mask guard)

Four small mechanical alignments — each isolated, removes only redundant
work the CUDA reference does not do. Correctness verified: 288/288 in
test_gvr_topk_decode.py. Perf delta within measurement noise (~0.15us
estimated, 21us baseline DSL — under the ~0.5us spread floor) but the
changes match heuristic_topk.cuh semantics 1:1 and pave the way for
later batches.

- block_count_ge: drop the trailing barrier (gvr_topk_decode.py:422 ->
  removed). CUDA blockCountGE (heuristic_topk.cuh:441) returns without
  a sync because callers already insert their own __syncthreads after
  their tid==0 post-processing. The previous DSL trailing barrier was
  redundant (tid==0 reads its own write in-thread, no sync needed).

- Phase 4 snap_limit: change from cand_count>128 ? cand_count/4 : 32
  to cand_count (matches heuristic_topk.cuh:985). The older bound
  silently accepted a non-converged threshold in ~0.09 % of adversarial
  distributions; correctness improvement only, common case still
  converges in 1-3 iters.

- Phase 4 block_min/max: every thread now recomputes block_min/max from
  the warp-staged smem slots into local registers (matches
  heuristic_topk.cuh:891-898). Replaces the prior `tid==0 writes
  s_thr[1]/s_thr[2] then broadcast via __syncthreads` pattern, saving
  one block barrier in Phase 4.

- Phase 4 Pass 1/Pass 2 writeback: wrap popc + atomicAdd + shuffle in
  `if mask != 0` warp-uniform guard (mirrors heuristic_topk.cuh:1020,
  1045). Skips the atomic round-trip when no lane in the warp emits,
  most impactful for Pass 2 where only K-th-rank ties emit.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…le + early break)

Convert the Phase 2 secant refinement loop and the Phase 3 retry-shrink
loop from Python-unrolled `for in range(N)` (every body wrapped in an
`if not done:` guard) to runtime `while` with the convergence condition
in the loop predicate. This matches CUDA's pattern at heuristic_topk.cuh:
683 (Phase 2) and :769 (Phase 3 retry).

Previously, after the kernel converged at iteration k, the remaining
N-k unrolled bodies still each issued an LDS+ICMP+branch guard. With
secant typically converging at iter 3 of 15 and retry-shrink usually 0
of 10, this saved ~12 + ~10 = ~22 wasted guard sites per kernel call.

Tradeoff: lose Python-time const-fold of `if it == 0: f = min(f, 0.5)`,
which now becomes a runtime compare. CUDA does the same runtime compare
(heuristic_topk.cuh:698-699), so this is alignment not regression.

Measured impact (median config bf16 K=1024 N=32768 BS=1 next_n=2,
same-process A/B vs CUDA GVR, 5 reps alternating order):
  DSL_us  21.01 -> 20.28  (-0.73 us, -3.5%)
  C/G     0.869 -> 0.903  (+3.4 percentage points)
Above the ~1.5% bench_kineto spread floor. 288/288 tests pass.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Two SASS-alignment changes verified against the CUDA reference at the
median config (bf16 K=1024 N=32768 BS=1 next_n=2):

1. cute.make_ptr(..., cute.AddressSpace.gmem, ...) at the two 128-bit
   vec-load sites in block_count_ge and phase3_collect_candidates.
   Default AddressSpace.generic lowered to SASS LD.E.128; explicit gmem
   hint flips to LDG.E.128 (matches CUDA __ldg path, minus .CONSTANT
   which still requires CopyG2ROp+invariant).

2. phase1_preidx_stats: replace the runtime `while i < pre_idx_count`
   strided loop with `range_constexpr(pre_idx_count // num_threads)`.
   pre_idx.shape[1] is a compile-time constant (top_k baked into JIT
   cache key); supported top_k in {512, 1024, 2048} are all multiples
   of num_threads (512), so n_iters ∈ {1, 2, 4} unrolls cleanly. cute
   emits straight-line code (no BRA / ISETP / counter update) and
   issues both preIdx LDG.E and input LDG.E.U16 back-to-back, enabling
   LSU pipelining (in flight ILP). Mirrors what nvcc/ptxas does for
   the equivalent CUDA loop via auto-partial-unroll.

Bench (same-process A/B, 5 repeats × 100 iters, kineto + L2 flush):
   Before: C/G = 0.903 (DSL 10.7% slow)  -- post-Batch 2 baseline
   After:  C/G = 0.922 (DSL  8.5% slow)
   Δ = +1.9pp

Resource use after changes:
   regs/thread:  34 -> 39  (still 3 blocks/SM, occupancy unchanged 75%)
   dynamic smem: unchanged (~44 KB)
   total SASS instructions: 2935 -> 2944 (codegen ripple, mostly
   FMNMX3 +6; loop overhead ISETP/BRA -6/-2/-3 offset by +25 IMAD)

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…tail)

Replaces the runtime `while i + (vec_w - 1) < N` vec loop in
block_count_ge with a 4-way unrolled fast path + 1-way tail. The fast
path issues 4 independent LDG.E.128 per round (separate fragments so
cute schedules them concurrently), mirroring what nvcc/ptxas does for
the equivalent CUDA loop via auto-partial-unroll.

SASS verification at median config (bf16 K=1024 N=32768 BS=1 nn=2):
- 4 LDG.E.128 per inline at addresses base / base+0x2000 / +0x4000 /
  +0x6000 — exact match to CUDA's LDG.E.128.CONSTANT pattern (minus
  the CONSTANT cache hint, which still requires CopyG2ROp+invariant).
- Total LDG.E.128 count: 5 -> 21 (4 inlines * 4 + 4 tails + 1 phase3).
- Cute software-pipelines: 3 LDGs issued back-to-back, then consume of
  iter 0 starts while iter 3's LDG is issued in parallel. All 4 are
  in flight before HBM responds (latency ~600 cy >> 23 inst slots).

Resource impact:
- Regs/thread: 39 -> 39  (cute reuses fragment regs across loop body;
  Phase 4 likely remains the kernel-wide peak)
- Dynamic smem: unchanged (~44 KB)
- Static SASS size: 2944 -> 3672 inst (+25%)  -- code bloat acceptable,
  well within icache; Block Limit Reg = 3 unchanged at occupancy=75%.

Bench results (kineto, L2 flush, n_iters=30):

Median config (bf16 K=1024 N=32768 BS=1 nn=2), same-process A/B:
   Before this commit:  C/G = 0.922  (DSL  8.5% slow)
   After this commit:   C/G = 0.976  (DSL  2.4% slow)
   Delta:               +5.4pp

Full sweep (804 configs = 3 dtype * 3 top_k * 6 N * 5 BS * 3 next_n):
   Median  C/G: 0.860 (baseline post-Batch-2) -> 0.988 (now)
   Geomean C/G:           0.869 -> 0.999  (parity with CUDA)
   DSL faster:             17% ->  46%
   Within 5%:              13% ->  34%
   Within 10%:             28% ->  55%

By dtype: bf16 1.000, fp16 1.022, fp32 0.951 (fp32 has slightly less
runway since vec_w=4 vs 8 for bf16/fp16).
By N: gap remains at large N (>=64K: median ~0.87-0.90), where the
LSU-pipelining win is already saturated and other phases dominate.

The single-config worst slowdowns observed (C/G ~0.4) are concentrated
in nn=3 + small-mid N (4-32K) + BS>=64 configs whose CUDA-side numbers
also moved 5-15x between runs -- short-runtime measurement noise, not
real regressions.

This commit completes the SASS-alignment campaign objective (gap < 5%
on median config). Remaining ~10-13% at very large N is deferred.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Three new switches gate the block_count_ge vec-load fast path:

  enable_unroll_4 (default True): 4-way unrolled fast path
  enable_unroll_2 (default by dtype): 2-way cascade between fast and tail
  use_strided_layout (default by dtype): True → single make_ptr +
      (UNROLL, vec_w) strided layout (cute emits 4 LDG.E.128 sharing
      base reg with +0x2000/+0x4000/+0x6000 imm offsets, matching the
      CUDA SASS pattern). False → 4 separate make_ptr calls (matches
      the prior b459a8f commit style with 4 independent base regs).

Dtype-aware defaults (validated via per-config A/B testing on B200):

  bf16 / fp16: enable_unroll_2=True, use_strided_layout=True
      Strided cascade gives clean wins: cascade flips DSL from CUDA
      parity to consistently faster on small-N where the 4-way fast
      path doesn't fully cover N, and the medium 2-way path keeps two
      LDG.E.128 in flight. Strided layout keeps the SASS shared-base
      pattern that nvcc/ptxas auto-partial-unroll also produces.

  fp32: enable_unroll_2=False, use_strided_layout=False
      For fp32 (vec_w=4) the strided layout pushes regs 38 → 40 and
      regresses fp32 large-grid configs by 30-60pp (worst observed:
      K=1024 BS=128 nn=2 → 0.753 vs 1.364 with separate-ptrs). The
      cascade similarly hurts in 12% of fp32 configs. Separate-ptrs
      4-way unroll alone is the sweet spot.

Cache key includes the three switches so different settings produce
separate compiled kernels.

Full sweep results (804 configs, n_iters=30 kineto, L2 flush):

                       baseline   cascade-all   dtype-policy
  Median C/G:           0.988      1.011        1.006
  Geomean C/G:          0.999      1.047        1.038
  DSL faster %:           46%        54%          52%
  Within 10%:             55%        65%          65%

By dtype:
  bf16:  1.000 -> 1.043 (cascade wins preserved)
  fp16:  1.022 -> 1.038
  fp32:  0.951 -> 0.960 (anom K=1024 BS=128 fixed: 0.610 -> 1.038)

By N (the original "large-N gap"):
  N=8192:    1.097 -> 1.172 (+7pp, cascade hides medium-path remainder)
  N=65536:   0.897 -> 0.928 (+3pp)
  N=131072:  0.866 -> 0.923 (+6pp)

Remaining slow configs (fp32 K=2048 + BS>=64) were already <0.7 in
the baseline -- this commit doesn't introduce new regressions there.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Adds two new switches to the DSL GVR kernel:

  enable_phase3_unroll (default True): master gate for phase3_collect
      unrolling. When ON, the inner enable_unroll_4 / enable_unroll_2
      switches independently control 4-way fast and 2-way medium paths
      in phase3 (same semantics as block_count_ge). When OFF, only the
      tail 1-way loop runs.

  use_constant_hint (default False): True → CopyG2ROp(invariant=True)
      → SASS LDG.E.*.CONSTANT (read-only data cache, matches CUDA
      __ldg). Default False because cute's invariant lowering triggers
      aggressive rematerialization in LLVM/NVPTX (+272 inst, 4 spills,
      net -7pp geomean), outweighing the cache hint benefit.

Phase3_collect is now a 3-tier cascade (4-way fast + 2-way medium +
1-way tail) mirroring block_count_ge. The cascade gives:
  N>=65K:  +5-7%  (large-N main path, LSU pipelining wins)
  N<=32K:  -1-3%  (unroll setup overhead exceeds benefit at small N)
  Median geomean: +2.2pp from phase3 unroll alone

Resource analysis (bf16/fp16/fp32 x phase3 ON/OFF):
  REG/thread:
    bf16: 39 -> 39 (no change, cute reuses fragments)
    fp16: 39 -> 39 (no change)
    fp32: 38 -> 40 (+2, separate-ptrs path)
  Static SASS:
    bf16: 3936 -> 4368 (+11%)
    fp16: 4096 -> 4512 (+10%)
    fp32: 3368 -> 3480 (+3%)
  Theoretical Occupancy: 75% all configs (smem-limited to 3 blocks/SM,
    binding limit unaffected by phase3 unroll). Phase3 unroll has
    *zero* occupancy cost.

Wrapper signature gains both switches; cache key includes them so
different settings produce separate compiled kernels. A small helper
method _make_load_copy_atom() factors out the CopyG2ROp/Universal
selection to avoid Python if-else NameError inside @cute.jit scope.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…uristic

Phase 2 (block_count_ge) and Phase 3 (phase3_collect_candidates) replace
the manual `while + range_constexpr(UNROLL)` fast/medium-cascade unrolling
with a single `for k in cutlass.range(big_iters, unroll=4)` loop. LLVM's
loop unroll pass + GVN/CSE folds the 4 derived vec loads into the
CUDA-style shared base + immediate offsets pattern, emitting 4 back-to-back
LDG.E.128 [base+0x2000/0x4000/0x6000] instructions.

Add `min_blocks_per_mp` field on `GvrTopKKernel` and a 3-tier shape-aware
heuristic in the host wrapper:
  * n_vec_iters < 4         -> 0 (no launch_bounds, natural ptxas allocation)
  * num_rows <= 148 (B200 SMs) -> 1 (allow many regs, 4xLDG fold survives)
  * else                       -> 3 (keep 3 CTA/SM occupancy, ~42 reg cap)

The heuristic lifts fp32 K=512 large-N out of its regression zone (worst
case C/G 0.62 -> 1.07 at K=512 N=131072 BS=16 nn=2). Cache key extended so
each min_blocks value gets its own compiled kernel.

Random sweep vs phase3_unroll baseline (804 configs):
  geomean 1.060 -> 1.149, faster%-than-CUDA 62% -> 92%, losses 304 -> 63.

CUDA Graph: heuristic reads `logits.shape` (host int, no GPU sync) so
capture is safe; per-graph capture selects the right kernel per shape.
For dynamic-shape single-graph use, caller can pin `min_blocks_per_mp=3`
to disable the heuristic.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Three new kernel knobs on GvrTopKKernel + gvr_topk_decode host wrapper:

  * use_256bit_load (default False): emit LDG.E.256 (8 fp32 / 16 bf16-fp16
    elements per LDG) instead of LDG.E.128. Address alignment hint is
    raised from 16 to 32 bytes. Phase 2/3 unroll factor is dtype-aware:
    fp32 keeps unroll=4 (no cvt-to-fp32 overhead); bf16/fp16 drops to
    unroll=2 to limit the cvt register pressure that otherwise spills
    under min_blocks=3.

  * num_threads_per_block (default 512): configurable per-instance.
    BLOCK_SIZE / WARP_SIZE / NUM_WARPS are moved from module-level to
    GvrTopKKernel instance attrs (self.WARP_SIZE, self.num_threads,
    self.num_warps). Phase 1 preIdx loop gains an else branch for the
    K < num_threads case (e.g. num_threads=1024 with K=512): only the
    first K threads load a preIdx, others keep reduction-identity
    values which the warp/block reduces naturally absorb.

  * vec_bits / vec_align_bytes derived from use_256bit_load; cache key
    extended with use_256bit_load + num_threads_per_block.

Heuristic uses the resolved num_threads_per_block (not a hardcoded 512)
when computing n_vec_iters.

Tests parametrize use_256bit_load and num_threads_per_block; pytest
runs 288/288 PASS at use_256bit_load=True and at num_threads=1024.

Synth bench on BS<=128:
  - 128-bit + heuristic baseline: gm=1.131, faster%=99%, lose=9
  - 256-bit + heuristic        : gm=1.121 (fp32 wins +3pp; bf16/fp16
    flat-to-negative due to cvt-to-fp32 reg pressure spills under mb=3)

Random sweep on BS up to 128: 256-bit shows niche win on
(fp32, num_rows<=148, large N); should be opt-in.

Synth data generator (multi-BS) and bench script env vars
(DSL_USE_256BIT/DSL_MIN_BLOCKS/DSL_NUM_THREADS) live in the gvr-topk-opt
workspace and are not part of this commit.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…nobs

Add a fourth perf knob `enable_warp_parallel_reduce` to GvrTopKKernel +
gvr_topk_decode and replace the four `tid==0` serial loops over
num_warps slots with warp-parallel reduce/scan in warp 0:

  * Phase 1 block aggregate (4-way reduce):
        min/max/sum_f32/sum_i32 -> 4x warp_reduce in warp 0.
  * Phase 2 / blockCountGE total (1-way reduce):
        sum_i32 -> warp_reduce_sum_i32.
  * Phase 3 collect block prefix sum (exclusive scan):
        Hillis-Steele inclusive scan via block_scan.warp_scan, then
        exclusive = inclusive - val; total = inclusive at last lane.
  * Phase 2 secant aggregate (3-way reduce):
        packed sum_i32 + min_f32 + max_f32, with bound update on lane 0.

Default is False since at num_threads=512 (num_warps=16) the per-warp
ILP loss exceeds the serial-loop savings (~2pp regression on synth).
At num_threads=1024 (num_warps=32) the switch is essential -- without
it 1024 regresses vs baseline (gm 1.131 -> 1.123); with it 1024 wins
(gm -> 1.154 on synth BS<=128). Pair as
`enable_warp_parallel_reduce = (num_threads_per_block >= 1024)`.

Phase 1 also gains an `active_preidx_warps` optimization: when
`pre_idx_count < num_threads` (e.g. K=512 with num_threads=1024) only
the first ceil(K/32) warps have real data, so the warp_reduce + smem
write step is now gated to those warps. Saves ~30 cy/dummy-warp; the
full barrier afterwards still keeps all 1024 threads aligned for
Phase 2. The constexpr is clamped to num_warps so the K>num_threads
case (K=2048 with num_threads=512) doesn't index past the smem
buffers, and the same value drives both the warp_reduce gate and the
Site-1 block aggregate's smem read range.

Remove two now-dead switches:
  * `enable_unroll_2` -- only referenced in the commented-out manual
    2-way medium path that the `cutlass.range(unroll=4)` rewrite
    replaced.
  * `use_strided_layout` -- only referenced in the commented-out manual
    4-way strided-layout path, also replaced.

Cache key drops the two dead entries and gains
`enable_warp_parallel_reduce`. The cleanup is a no-op functionally
(the dead values were ignored by the active code paths) but removes
two cache-bucket dimensions.

Test parametrize expanded to 4-way matrix:
  next_n in {1, 2} (was {1, 2, 3, 4} -- trimmed to keep walltime)
  use_256bit_load in {False, True}
  num_threads_per_block in {512, 1024}
  enable_warp_parallel_reduce in {False, True}
1152 / 1152 PASS in 20:22.

Synth bench (BS<=128, threads=512 baseline -> threads=1024+wpON):
  geomean 1.131 -> 1.154  (+2.3pp)
  fp32 geomean 1.127 -> 1.177  (+5.0pp; up to +28pp at fp32 K=2048
                                N=131072 -- 1.50x vs CUDA)

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
@limin2021 limin2021 requested a review from a team as a code owner May 27, 2026 00:39
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 27, 2026

📝 Walkthrough

Walkthrough

This PR adds test coverage and public API exposure for the Blackwell CuTE gvr_topk_decode kernel. It exports kernel components through the public top_k module and provides a comprehensive parameterized CUDA correctness test with tie-aware validation logic.

Changes

GVR Top-K Kernel Export and Testing

Layer / File(s) Summary
Public API export
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/top_k/__init__.py
GvrParams, GvrTopKKernel, and gvr_topk_decode are now imported and added to __all__, making these symbols available to external consumers of the top_k package.
Correctness test with tie-aware validation
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/top_k/test_gvr_topk_decode.py
Added _make_inputs helper to generate deterministic single-row test tensors, _tie_aware_correct helper to validate kernel output using masked torch.topk reference, and parameterized test_gvr_topk_decode_correctness covering multiple dtypes, vocabulary sizes, K values, seeds, next_n configurations, and kernel launch flags, with CUDA synchronization and detailed failure reporting.

🎯 2 (Simple) | ⏱️ ~10 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (2 warnings)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 66.67% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
Description check ⚠️ Warning PR description provides implementation details and test coverage but lacks a properly formatted title following the template requirements. Add a PR title following the template: [JIRA ticket/NVBugs ID/GitHub issue/None][type] Summary. For example: [None][feat] Add cute dsl gvr top-k decode kernel.
✅ Passed checks (3 passed)
Check name Status Explanation
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Title check ✅ Passed The title clearly and specifically describes the main change: adding a CuTE DSL GVR Top-K decode kernel implementation with associated test suite.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In
`@tensorrt_llm/_torch/cute_dsl_kernels/blackwell/top_k/test_gvr_topk_decode.py`:
- Around line 29-45: The test picks argmax_idx over the full N, which breaks the
pre_idx[...,0] invariant when the kernel only scans the first (N - next_n + 1)
columns; update _make_inputs to compute an effective_scan = N - next_n + 1 (or
accept next_n as a parameter) and compute argmax_idx = int(logits[0,
:effective_scan].argmax().item()) so pre_idx_list[0] is in-range, then build
pre_idx as before; apply the same change to the other occurrence (the second
_make_inputs usage around the later test).
- Line 21: Remove the unused "from typing import Tuple" import and replace the
two occurrences of typing.Tuple[...] return annotations in this test module with
the native Python 3.10+ generic syntax tuple[...] (e.g., change "Tuple[int,
str]" to "tuple[int, str]") in the two functions in
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/top_k/test_gvr_topk_decode.py so
the file uses built-in tuple typing and no longer imports typing.Tuple.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 031d6484-acc6-4f3e-b7b3-e6defb77c52c

📥 Commits

Reviewing files that changed from the base of the PR and between c7e7fc5 and 882c767.

📒 Files selected for processing (3)
  • tensorrt_llm/_torch/cute_dsl_kernels/blackwell/top_k/__init__.py
  • tensorrt_llm/_torch/cute_dsl_kernels/blackwell/top_k/gvr_topk_decode.py
  • tensorrt_llm/_torch/cute_dsl_kernels/blackwell/top_k/test_gvr_topk_decode.py

Comment thread tensorrt_llm/_torch/cute_dsl_kernels/blackwell/top_k/test_gvr_topk_decode.py Outdated
Comment thread tensorrt_llm/_torch/cute_dsl_kernels/blackwell/top_k/test_gvr_topk_decode.py Outdated
limin2021 added 9 commits May 27, 2026 01:24
Mirrors the kFTarget=kK alignment for K=512/1024 (all dtypes) from CUDA
PR NVIDIA#14413 on the DSL GVR Top-K kernel so the DSL Phase-2 secant
behavior matches the new CUDA reference. Old pre-NVIDIA#14413 values kept as
inline comments for easy rollback.

Verified: 768/768 pytest configs pass for K=512/1024 across all dtypes,
N, next_n, use_256bit_load, num_threads_per_block, and
enable_warp_parallel_reduce.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Move module-level MAX_REFINE_ITERS / FLT_MAX / NEG_FLT_MAX into
instance attributes so all kernel-wide knobs live in one place.
Inline NUM_BINS_DEFAULT (2048) directly into the GvrParams table
since it was only used in three K=2048 entries. Drop dead
MAX_CANDIDATES.

Pure refactor — values, control flow, and DSL IR are unchanged.
Also removes the previously-commented-out A/B layout/unroll dead
code in block_count_ge.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Mirrors heuristicTopKDecode.cu PR NVIDIA#14219 cr-aware branch in the DSL
GVR Top-K kernel. compress_ratio=1 (default) preserves DSv3.2 behavior
exactly; compress_ratio=4 enables the DSv4 (overlap-compressor) indexer
path:
  * pre_idx_offset = 0 (vs (row % next_n) + 1 for cr=1) — in compressed-
    index space, new entries append at the end so prev-step indices
    remain valid as-is.
  * N = actual_kv_len / cr — logits/preIdx live in compressed-token-
    index space when cr > 1.

GvrParams TABLE is also keyed by (dtype, K, cr) so V3.2 and V4 use
their respectively tuned kFTarget values:
  cr=1 (V3.2): kFTarget = 384 (K=512) / 2560 (K=1024), pre-NVIDIA#14413.
  cr=4 (V4):   kFTarget = kK   = 512 (K=512) / 1024 (K=1024), PR NVIDIA#14413.
  K=2048: identical across cr (V4 doesn't natively use K=2048).

Cache key includes compress_ratio so different cr settings compile
separate kernels. assert restricts compress_ratio in {1, 4}.

Verified: 1152/1152 pytest configs pass on cr=1 default path.
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…3 mb

Two paired host-wrapper heuristic refinements:

1. enable_warp_parallel_reduce: bool → Optional[bool] = None, default
   auto-coupled to num_threads_per_block: enabled iff threads == 1024 (32
   warps, where serial tid==0 cost dominates). At threads == 512 (16
   warps) the warp-parallel path measured a ~2pp synth regression so it
   stays off. Cache key sees the concrete bool. Explicit True/False still
   overrides for A/B testing.

2. tier-3 (large grid + large N) min_blocks_per_mp hardcoded "= 3"
   replaced by a (T, dtype) lookup:
     T == 1024 or dtype == fp32 → mb=2
     T == 512  and dtype in (bf16, fp16) → mb=3
   Derived from BS{256,384,512} × N{16K,32K,65K} × all 9 (dtype, K) sweep
   (gvr-topk-opt/sweep_tv_mb_kineto/mb_sweep.png). Old mb=3 default
   regressed by 25-37% on (T=512 + fp32 + large N/BS) configs because
   cap=42 starves the 4-LDG-inflight ILP (fp32 vec_w=4 × unroll-4 needs
   50+ regs). bf16/fp16 keep mb=3 since cvt-to-fp32 ILP fits in 40 regs
   and the extra CTA/SM (3 vs 2) hides cvt latency.

Pure default-policy change — no behavioral effect when caller passes
explicit values. Verified: pytest smoke 4/4 on cr=1 default path.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Two more host-wrapper Optional[*]=None defaults so callers no longer
need to pick threads/vec-bits per shape:

  num_threads_per_block (None default):
    1024 iff num_rows <= num_sms (1 CTA/SM bound) AND N >= 65536
    (so each of the 1024 threads has meaningful vec-loop work).
    Otherwise 512.

  use_256bit_load (None default):
    True iff dtype == fp32 AND N >= 16384.
    Half-prec (bf16/fp16) cvt-to-fp32 doubles fragment reg footprint
    and regresses 5-11% at K=512/1024; LDG already saturates at 128b
    anyway. fp32 N=8K dips 5-8% with 256b at small grid so the N
    threshold excludes that single tier.

Cache key sees concrete values; (None, X) and (None, Y) hash apart.
Explicit values still override for A/B testing.

Derivation: sweep BS{1,4,16,64,128,256,384,512} x N{4K..131K} x all 9
(dtype, K), gvr-topk-opt/sweep_tv_kineto/auto_speedup.csv. Net vs
baseline (T=512, V=128):
  - median speedup vs CUDA  1.09x -> 1.10x
  - mean   speedup vs CUDA  1.11x -> 1.15x  (+3.8pp)
  - max    speedup vs CUDA  1.45x -> 1.52x
  - 21 of 22 sp<1 configs were already sp<1 in baseline (BS=384 grid
    quirk, unrelated to this change). 1 new config introduces a 0.8pp
    sp<1 dip (within bench noise).

Pure default-policy change. Verified: 4-case auto-path smoke + pytest
smoke 4/4 on cr=1 default.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Decode runs under CUDA graph, where the (T, V) heuristic baked in at
capture time is reused across all replays. The capture-time
logits.shape[1] is typically much smaller than peak runtime N, so the
captured kernel misses the large-N (T=1024, V=256) path. Add an
optional max_seq_len hint so the caller (e.g. dsa.py) can pass the
peak compressed-N for the model; the heuristic then tunes the captured
kernel for the peak.

Usage guidance baked into the docstring + inline comment:
  * CUDA Graph mode: CALLER MUST PASS max_seq_len.
  * Eager mode: leave max_seq_len=None (heuristic adapts per call).

Rules with max_seq_len:
  * T=1024 threshold becomes dtype-aware to avoid half-prec K=512/1024
    small-N replay regression (14-16% when forced T=1024 at small N):
      fp32 -> 65536 (small-N replay 1-9% loss, net win)
      half -> 131072 (only forced at very large peak)
  * V=256 still gated by fp32 + N >= 16384.

Without max_seq_len, dtype-split is NOT applied because per-call
adaptive decisions never force T=1024 onto small N — heuristic only
fires for N >= 65536 by definition — so the half-prec N=65K-128K
+4-6% T=1024 win is preserved.

Cache key sees concrete (T, V), so different max_seq_len hints compile
distinct kernels. Pure default-policy extension. Verified with 4-case
auto smoke (no hint / fp32+131K / bf16+131K / bf16+200K).

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
The min_blocks_per_mp tier heuristic was still computing n_vec_iters
from logits.shape[1] (capture-time N). In graph mode with max_seq_len
hint, this would stick small-capture-N calls in tier-0 (mb=0) and
miss the tier-3 occupancy choice for large-N replays — same pitfall
the (T, V) heuristic was fixed against in the previous commit.

Switch to N_dec (= max_seq_len if provided, else logits.shape[1]) so
the tier classification is consistent with how T/V are picked.

Smoke 4/4.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Add a wave-fit branch in the fp32 tier-3 path: when num_rows ∈ (296, 444]
(i.e. fits 1 wave at 3 CTAs/SM but needs partial 2nd wave at 2 CTAs/SM
with num_sms=148), pick mb=3 instead of mb=2. This recovers ~15% perf
on fp32 BS=384 across (K, N) — verified against CUDA which already uses
__launch_bounds__(BS, 3) for this exact reason.

Math:
  mb=2 cap → 2 CTAs/SM × 148 SMs = 296 CTAs in 1 wave.
  mb=3 cap → 3 CTAs/SM × 148 SMs = 444 CTAs in 1 wave.
For BS=384 (× next_n=1):
  mb=2: 384 / 296 = 1.30 waves → tail wave wastes ~70% SMs.
  mb=3: 384 / 444 = 0.86 waves → 1 wave fits, max SM utilization.

Verified perf gains (fp32 T=512, both V=128 and V=256 default paths):
  fp32 K=512  N=4K-32K BS=384: +11-23%
  fp32 K=1024 N=4K-32K BS=384: +16-19%
  fp32 K=2048 N=8K-32K BS=384: +5-9%

Other BS unaffected:
  BS ≤ 296 (192, 256): mb=2 already fits 1 wave → rule keeps mb=2 (no change)
  BS > 444 (512): both need >1 wave → rule keeps mb=2 (ILP > occupancy)

Half-prec heuristic unchanged (already uses mb=3 in tier-3 via the
dtype-split path from a prior commit).

Bench artifacts: gvr-topk-opt/auto_full_bench/fp32_bs384_cluster/
(mb sweep CSV + NCU reports + drivers). Smoke: pytest 4/4 + spot tests
across BS={256, 384, 512}.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Functional change: limit wave-fit mb=3 branch to N <= 32768. Beyond
that threshold the kernel becomes bandwidth-bound and mb=3's 3-way L2
sharing causes contention; mb=2's lower occupancy gives each CTA more
bandwidth and wins +21-30% at fp32 K=512 N=65K BS=384.

The full wave-fit rule for fp32 tier-3 is now:
  if 2*num_sms < num_rows <= 3*num_sms and N_dec <= 32768:
      mb = 3
  else:
      mb = 2

Also cleans up file comments: remove obsolete TODO list, trim refs to
specific CUDA line numbers, simplify class docstring.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
@limin2021 limin2021 changed the title Cute dsl gvr topk [TRTLLM-35237][feat] Add cute dsl gvr top-k decode kernel May 28, 2026
@limin2021 limin2021 changed the title [TRTLLM-35237][feat] Add cute dsl gvr top-k decode kernel [TRTLLM-35882][feat] Add cute dsl gvr top-k decode kernel May 28, 2026
Registers torch.ops.trtllm.cute_dsl_gvr_topk_decode as the production
entry point for the cuTe DSL GVR Top-K decode kernel (Blackwell SM100).
Op writes values + indices into caller-allocated buffers (mutates_args
style), matching the existing cute_dsl_indexer_topk_decode pattern so
the DSA indexer pipeline can drop it in.

CuteDSLGvrTopKDecodeRunner takes ownership of the JIT compile cache and
the auto-heuristic for T (threads/block), V (vec-load width),
min_blocks_per_mp and enable_warp_parallel_reduce. The previous
module-level wrapper in gvr_topk_decode.py is removed; standalone
bench / A-B testing with the full tuning knob set lives in
tests/scripts/cute_dsl_kernels/top_k/run_gvr_topk.py.

Tests:
- tests/unittest/.../test_cute_dsl_gvr_topk_decode.py: production
  correctness sweep via the op (dtype x K x N x next_n x batch_size x
  compress_ratio) with vectorized tie-aware + strict sort+allclose check.
- tests/scripts/.../run_gvr_topk.py: dual-mode driver -- pytest sweep
  over T/V/wp knobs and standalone CLI for single-case verification.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
@limin2021 limin2021 requested a review from a team as a code owner May 28, 2026 11:21
@limin2021 limin2021 requested a review from yizhang-nv May 28, 2026 11:21
@limin2021
Copy link
Copy Markdown
Collaborator Author

/bot run

@limin2021 limin2021 requested review from hyukn, longcheng-nv and yuxianq and removed request for HuiGao-NV, leslie-fang25 and yizhang-nv May 28, 2026 11:22
@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #50762 [ run ] triggered by Bot. Commit: d474eb9 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #50762 [ run ] completed with state SUCCESS. Commit: d474eb9
/LLM/main/L0_MergeRequest_PR pipeline #40240 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@limin2021
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #50917 [ run ] triggered by Bot. Commit: d474eb9 Link to invocation

limin2021 added 5 commits May 29, 2026 02:55
The DSA indexer pipeline (dsa.py) and CUDA indexer_topk_decode op both
only read top-K indices from the kernel output; the value buffer is
caller-allocated scratch that's never consumed. Add a kernel-level
return_output_values switch so the cuTe DSL kernel can elide all
STG.value stores when the caller doesn't need them.

Kernel (gvr_topk_decode.py):
- GvrTopKKernel gains return_output_values: bool = True.
- All 9 STG.value sites + the output_values_row slice are gated under
  cutlass.const_expr(self.return_output_values), letting cute.compile
  eliminate the dead writes when False.

Op (cute_dsl_custom_ops.py):
- CuteDSLGvrTopKDecodeRunner adds return_output_values to the compile
  cache key + _compile signature; forward() hardcodes False, drops the
  output_values arg, and passes None for the value-output slot at
  launch (mirrors the optional-fake-tensor pattern at
  CuteDSLTopKDecodeMultiCTARunner._compile).
- trtllm::cute_dsl_gvr_topk_decode op signature drops output_values;
  mutates_args is now ("output_indices",), aligning with CUDA's
  indexer_topk_decode which also only exposes indices.

Tests:
- tests/unittest/.../test_cute_dsl_gvr_topk_decode.py drops the
  output_values buffer alloc + op kwarg (all 144 cases pass with the
  sort+allclose strict check).
- tests/scripts/.../run_gvr_topk.py wrapper exposes
  return_output_values as a knob so the standalone driver can still
  capture written values; the _compile cache + cute.compile
  out_values_fake placeholder are conditional on the flag.

SASS verification at bf16 K=1024 N=8K BS=384 confirms 88 STG.E.U16
writes are eliminated (kernel cubin -6KB, total SASS -416 lines).
On B200 SXM5 + synth_data, v5 (return_output_values=False) gives a
median 1.2% latency improvement over v4 with sp<1 configs nearly
halved (52 -> 27).

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
* Modernize type hints: replace ``typing.Tuple[...]`` with the built-in
  ``tuple[...]`` (Python 3.10+) and drop the ``from typing import Tuple``
  import in both ``_make_inputs`` / ``_tie_aware_correct`` helpers.
* Fix the ``pre_idx[..., 0]`` argmax invariant for next_n > 1: argmax
  must come from the kernel's effective scan range
  ``[0, N - next_n + 1)``, not full ``[0, N)``. With the prior full-N
  argmax, an index landing in the ``[N_eff, N)`` tail could violate the
  CUDA-side ``preIdxCount == topK`` dispatch precondition (kernel still
  produced correct top-K because pre_idx is only a Phase-1 hint, but
  the test was technically exercising the kernel under invalid input).

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
The kernel derives batch_size as logits.shape[0] / next_n implicitly,
then sizes pre_idx / seq_lens / output_indices accordingly. If the
divisibility breaks, the failure modes are either an OOB write or a
ZeroDivisionError raised from deep inside the JIT-compiled kernel —
neither is actionable for callers. Add an upfront check in the op
body so the contract violation surfaces with a clear message.

Other invariants (top_k in {512,1024,2048}, compress_ratio in {1,4},
logits dtype) are already enforced by GvrTopKKernel.__init__ via
GvrParams.get and the dtype switch.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
* Drop the single-use N_cols local in CuteDSLGvrTopKDecodeRunner.forward;
  fold it directly into the N_dec ternary for less noise.
* Add the input shape / dtype / knob signature to the info_once dedup
  key. Without the signature the first call's log message hid every
  subsequent shape from production diagnostics; now each new shape
  emits a single log line on its first run.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Both _tie_aware_check (unittest) and _tie_aware_correct (run_gvr_topk
standalone driver) previously assumed the reference scan range was
``N - next_n + ofs + 1`` per row. That matches the kernel for cr=1
and for cr>=2 with next_n in {1, 2}, but breaks for cr>=2 with
next_n>=3 because floor-division by cr makes per-row N_eff vary
within a group in ways the simple closed form can't express.

Switch both reference helpers to mirror the kernel's exact formula:

    actual_kv_len = seq_lens[row // next_n] - next_n + (row % next_n) + 1
    N_eff = actual_kv_len // compress_ratio  # cr=1 is identity

This requires the helpers to take ``seq_lens`` (and compress_ratio for
the standalone driver) so the reference can compute per-row N_eff
exactly as the kernel does. With this, any (next_n, cr) combo is
testable without the floor-division mismatch.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #50917 [ run ] completed with state SUCCESS. Commit: d474eb9
/LLM/main/L0_MergeRequest_PR pipeline #40379 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@limin2021
Copy link
Copy Markdown
Collaborator Author

/bot run

@limin2021 limin2021 requested a review from Kefeng-Duan May 29, 2026 05:34
@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #50978 [ run ] triggered by Bot. Commit: c055ad6 Link to invocation

Copy link
Copy Markdown
Collaborator

@longcheng-nv longcheng-nv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review — comment (approve in principle; please address the points below)

Thanks @limin2021 for the cuTe DSL port. This PR is well-scoped — single
operator, no production wiring, in-kernel auto-heuristic confined to the new
kernel itself (i.e. doesn't touch any cross-algorithm dispatcher). The
structural design is clean and I'm happy to see this land once the items
below are addressed.

I'm leaving this as COMMENT rather than REQUEST_CHANGES because the
substance of the kernel + tests is good — the asks below are mostly about
PR-body evidence and a few small footguns.


✅ What's right

Functional correctness — strong test design.
test_cute_dsl_gvr_topk_decode.py covers a 144-cell parametric sweep
(dtype × top_k × N × next_n × batch_size × cr) and uses a tie-aware
multi-row reference check (out-of-range / duplicates / n_below /
sort+allclose vs torch.topk). The per-row N_eff formula mirrors the
kernel's exact contract (seq_lens - next_n + ofs + 1) // cr, including
next_n>=3 + cr>=2 floor-division cases. This is significantly more
careful than many CuTe DSL kernel PRs that rely on torch.allclose on
sorted values alone.

Code quality — knobs and rationale are documented.
The forward() auto-heuristic in cute_dsl_custom_ops.py is layered
with concrete rationale per knob:

  • num_threads_per_block (T) — 1024 iff num_rows ≤ num_sms AND N_dec ≥ 65K
    (or 131K for graph-capture half-prec)
  • use_256bit_load (V) — fp32 AND N_dec ≥ 16K; half-prec excluded
    because cvt-to-fp32 doubles fragment reg footprint (5-11% regression
    on K=512/1024)
  • enable_warp_parallel_reduce — only at T=1024; T=512 costs ~2pp
  • min_blocks_per_mp — 3-tier with dtype-conditional ordering (fp32
    prioritizes LDG-ILP via mb=2 cap=64; half-prec prioritizes cvt-ILP)

The max_seq_len graph-capture hint is a real correctness fix —
without it, graph captured at small N would replay with T=512 even
when N=256K at replay time. Easy to miss; good catch.

Algorithm — GvrParams table justifies its choices.
18 entries (3 dtypes × 3 K × 2 cr), with explicit per-cell kFTarget
values. The docstring distinguishes cr=1 (V3.2) vs cr=4 (V4) tuning
origins and cites evidence: "cross-prompt swe-bench shows 1.5-2.2×
P2-iter reduction vs V3.2's kFTarget=384/2560"
.

Architecture — auto-heuristic is INSIDE the new kernel.
Unlike a cross-algorithm dispatcher, this heuristic only chooses between
knob combinations for THIS kernel. When gvr_topk_decode is re-optimized
later, the heuristic updates with it — no main-level constant to rot.
This is the right abstraction layer for tuning of this kind.


⚠️ Asks before merge

1. Replace the perf PNG with a numerical table

The PR body has one performance image but no numerical detail.
Reviewers cannot validate, future readers cannot reproduce. Please
publish a markdown table covering at least:

dtype K N BS next_n cute-dsl μs CUDA prod GVR μs speedup
fp32 2048 64K 1 1 ? ? ?
fp32 2048 64K 32 1 ? ? ?
fp32 512 4K 1 1 ? ? ?
bf16 512 64K 32 2 ? ? ?
...

(Your _fmin_f32_inline docstring already implies there's a measured
gap — "~8-10 us of the cuTe vs prod GVR gap at fp32 K=2048 BS=1" —
make the rest of those measurements visible.)

Also include 2-3 knob-decision validation rows: e.g., one row
showing T=1024 beats T=512 at the heuristic's switch point N=65K
(num_rows ≤ num_sms), and one showing V=256-bit beats V=128-bit at
fp32 N=16K. That justifies the auto-heuristic's specific boundaries.

2. Add an alignment runtime check when use_256bit_load=True

cute_dsl_custom_ops.py justifies the 32-byte alignment as:

"PyTorch CUDA allocations are 256-byte aligned; Phase 2/3 offsets
are multiples of vec_w * elem_bytes = 32 bytes."

This holds for torch.empty(...) outputs but breaks for views into
larger allocations with non-aligned offsets
. If a caller eventually
wires this op to take a view (likely once it lands in the V4 indexer
pipeline), misalignment will trigger silent corruption or a fault.

Suggested addition to forward():

if use_256bit_load:
    assert logits.data_ptr() % 32 == 0, (
        f"256-bit vec load requires 32B-aligned logits.data_ptr(), "
        f"got {logits.data_ptr()} % 32 = {logits.data_ptr() % 32}"
    )

Cheap, catches a real footgun.

3. Clarify return_output_values policy

The op-level forward hardcodes return_output_values = False with the
comment "DSA indexer pipeline only consumes indices, mirroring CUDA's
indexer_topk_decode"
. But the kernel supports True. Two options:

  • If False is permanent for the trtllm op: remove the True branch
    from the kernel (or sentinel-tag it with # TODO(...): remove when production needs values)
  • If True will be re-enabled later: document the eventual caller in
    a TODO so the dead-looking code path is intentional

Currently this reads ambiguously to a future reader.

4. Add 1-2 tests exercising realistic preIdx hit-rate

_make_inputs() builds pre_idx[:, 0] = argmax(logits) and
pre_idx[:, 1:] = arange(1, top_k). That's a worst-case "only slot 0
is meaningful" scenario, exercising kernel robustness to junk pre_idx.

But it does not exercise the realistic case where
|preIdx ∩ topK| / top_k is 0.3-0.8 (matching production
V3.2 / V4 captures). The kernel's whole purpose — the "Guess" phase —
saves work proportional to hit-rate; the current tests don't have any
cell where Guess actually carries weight.

Suggested cell:

# Realistic preIdx: ~50% of top_k slots actually match torch.topk
ref_topk = logits.topk(top_k, dim=-1).indices
keep_mask = torch.rand(ref_topk.shape, device=device) < 0.5
random_fill = torch.randint(0, N, ref_topk.shape, device=device, dtype=torch.int32)
pre_idx = torch.where(keep_mask, ref_topk.int(), random_fill)

This catches "Guess phase short-circuits and returns wrong result"
bugs that the current argmax-only setup cannot.


Minor (optional follow-up)

  • The _fmin_f32_inline PTX-asm workaround (cute DSL has cute.arch.fmax
    but not cute.arch.fmin) — please file an upstream cute-dsl issue
    and link it in the comment, so future maintainers know when the
    workaround can be removed
  • GvrParams is an 18-entry table; future K additions (256? 4096?)
    need 6 new rows. Worth a follow-up to consider whether kFTarget
    could be computed (e.g. kFTarget = K for cr=4) rather than
    tabulated — not blocking for this PR
  • tests/scripts/cute_dsl_kernels/top_k/run_gvr_topk.py is 500 lines
    and not in CI per its docstring — please add a brief top-of-file
    comment explaining why it's standalone (no trtllm env dep) and how
    it relates to the CI'd unittest

Happy to re-review and approve as soon as 1-4 are addressed.

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #50978 [ run ] completed with state ABORTED. Commit: c055ad6

Link to invocation

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants