[TRTLLM-35882][feat] Add cute dsl gvr top-k decode kernel #14602
limin2021 wants to merge 26 commits
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…invariant
- Replace from_dlpack/static-shape compile with make_fake_compact_tensor + sym_int for batch/num_tokens dims, keeping (dtype, top_k, next_n) as the cache key. Reduces unique compile entries from 810 to 27 across the bench sweep; correctness verified (no OOB writes from cache reuse with wrong shape) via 288-config pytest + cross-impl A/B match.
- Fix test_gvr_topk_decode:
  (1) pre_idx_count now uses top_k (matches CUDA dispatch precondition preIdxCount == topK at heuristic_topk.cuh:810);
  (2) tie-aware reference now masks logits to per-row effective_len = seq_len - next_n + 1, avoiding false negatives when next_n > 1 makes the kernel skip the last next_n-1 columns.
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
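A minimal sketch of why making the batch/num_tokens dims symbolic shrinks the compile cache. Only the key structure (dtype, top_k, next_n) and the 810 -> 27 counts come from the commit message; the concrete N, batch-size and next_n values below are hypothetical placeholders chosen to give a 6 x 5 x 3 grid.

```python
from itertools import product

# dtype / top_k values appear in the PR; the remaining axes are
# hypothetical placeholders (assumption), sized 3 x 6 x 5.
dtypes = ["bf16", "fp16", "fp32"]
top_ks = [512, 1024, 2048]
next_ns = [1, 2, 3]
seq_lens = [4096, 8192, 16384, 32768, 65536, 131072]
batch_sizes = [1, 4, 16, 64, 128]

sweep = list(product(dtypes, top_ks, next_ns, seq_lens, batch_sizes))

# Static-shape compile: every distinct shape becomes its own cache entry.
static_keys = {(d, k, nn, n, bs) for d, k, nn, n, bs in sweep}

# Shape-invariant compile: batch / num_tokens dims are symbolic, so only
# (dtype, top_k, next_n) remains in the key.
symbolic_keys = {(d, k, nn) for d, k, nn, _n, _bs in sweep}

print(len(static_keys), len(symbolic_keys))
```

The trade-off is that a kernel compiled once per key must stay correct for every shape that maps to it, which is why the commit pairs the change with an OOB-write check.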
…imit / mask guard)
Four small mechanical alignments: each isolated, removes only redundant work the CUDA reference does not do. Correctness verified: 288/288 in test_gvr_topk_decode.py. Perf delta within measurement noise (~0.15us estimated, 21us baseline DSL, under the ~0.5us spread floor) but the changes match heuristic_topk.cuh semantics 1:1 and pave the way for later batches.
- block_count_ge: drop the trailing barrier (gvr_topk_decode.py:422 -> removed). CUDA blockCountGE (heuristic_topk.cuh:441) returns without a sync because callers already insert their own __syncthreads after their tid==0 post-processing. The previous DSL trailing barrier was redundant (tid==0 reads its own write in-thread, no sync needed).
- Phase 4 snap_limit: change from cand_count>128 ? cand_count/4 : 32 to cand_count (matches heuristic_topk.cuh:985). The older bound silently accepted a non-converged threshold in ~0.09 % of adversarial distributions; correctness improvement only, common case still converges in 1-3 iters.
- Phase 4 block_min/max: every thread now recomputes block_min/max from the warp-staged smem slots into local registers (matches heuristic_topk.cuh:891-898). Replaces the prior `tid==0 writes s_thr[1]/s_thr[2] then broadcast via __syncthreads` pattern, saving one block barrier in Phase 4.
- Phase 4 Pass 1/Pass 2 writeback: wrap popc + atomicAdd + shuffle in `if mask != 0` warp-uniform guard (mirrors heuristic_topk.cuh:1020, 1045). Skips the atomic round-trip when no lane in the warp emits, most impactful for Pass 2 where only K-th-rank ties emit.
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…le + early break)
Convert the Phase 2 secant refinement loop and the Phase 3 retry-shrink loop from Python-unrolled `for in range(N)` (every body wrapped in an `if not done:` guard) to runtime `while` with the convergence condition in the loop predicate. This matches CUDA's pattern at heuristic_topk.cuh:683 (Phase 2) and :769 (Phase 3 retry).
Previously, after the kernel converged at iteration k, the remaining N-k unrolled bodies still each issued an LDS+ICMP+branch guard. With secant typically converging at iter 3 of 15 and retry-shrink usually 0 of 10, this saves ~12 + ~10 = ~22 wasted guard sites per kernel call.
Tradeoff: lose Python-time const-fold of `if it == 0: f = min(f, 0.5)`, which now becomes a runtime compare. CUDA does the same runtime compare (heuristic_topk.cuh:698-699), so this is alignment, not regression.
Measured impact (median config bf16 K=1024 N=32768 BS=1 next_n=2, same-process A/B vs CUDA GVR, 5 reps alternating order):
DSL_us 21.01 -> 20.28 (-0.73 us, -3.5%)
C/G 0.869 -> 0.903 (+3.4 percentage points)
Above the ~1.5% bench_kineto spread floor. 288/288 tests pass.
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
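The guard-site arithmetic in this commit can be reproduced with a small host-side model. This is a plain-Python sketch, not the DSL code: `unrolled_wasted_guards` is a hypothetical helper that only counts how many `if not done:` guards still execute after convergence in a fully unrolled loop.

```python
def unrolled_wasted_guards(max_iters: int, converged_after: int) -> int:
    """Count guards that run after convergence in a fully unrolled loop.

    converged_after is the number of loop bodies that do real work;
    0 means the loop was already converged on entry.
    """
    done = converged_after == 0
    wasted = 0
    for it in range(max_iters):
        if done:
            wasted += 1  # body is skipped, but the guard itself still ran
            continue
        if it + 1 >= converged_after:
            done = True  # this body was the last one doing real work
    return wasted


# Secant refinement: converges at iter 3 of 15.
# Retry-shrink: usually 0 of 10 iterations needed.
secant = unrolled_wasted_guards(15, 3)
retry = unrolled_wasted_guards(10, 0)
print(secant, retry, secant + retry)
```

A runtime `while` with the convergence test in its predicate exits at the first failed check, so none of these trailing guards are issued.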
Two SASS-alignment changes verified against the CUDA reference at the
median config (bf16 K=1024 N=32768 BS=1 next_n=2):
1. cute.make_ptr(..., cute.AddressSpace.gmem, ...) at the two 128-bit
vec-load sites in block_count_ge and phase3_collect_candidates.
Default AddressSpace.generic lowered to SASS LD.E.128; explicit gmem
hint flips to LDG.E.128 (matches CUDA __ldg path, minus .CONSTANT
which still requires CopyG2ROp+invariant).
2. phase1_preidx_stats: replace the runtime `while i < pre_idx_count`
strided loop with `range_constexpr(pre_idx_count // num_threads)`.
pre_idx.shape[1] is a compile-time constant (top_k baked into JIT
cache key); supported top_k in {512, 1024, 2048} are all multiples
of num_threads (512), so n_iters ∈ {1, 2, 4} unrolls cleanly. cute
emits straight-line code (no BRA / ISETP / counter update) and
issues both preIdx LDG.E and input LDG.E.U16 back-to-back, enabling
LSU pipelining (in flight ILP). Mirrors what nvcc/ptxas does for
the equivalent CUDA loop via auto-partial-unroll.
Bench (same-process A/B, 5 repeats × 100 iters, kineto + L2 flush):
Before: C/G = 0.903 (DSL 10.7% slow) -- post-Batch 2 baseline
After: C/G = 0.922 (DSL 8.5% slow)
Δ = +1.9pp
Resource use after changes:
regs/thread: 34 -> 39 (still 3 blocks/SM, occupancy unchanged 75%)
dynamic smem: unchanged (~44 KB)
total SASS instructions: 2935 -> 2944 (codegen ripple, mostly
FMNMX3 +6; loop overhead ISETP/BRA -6/-2/-3 offset by +25 IMAD)
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
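For reading the bench lines, a small sketch of how a C/G ratio maps to the "DSL x% slow" figure. It assumes C/G means CUDA latency divided by DSL latency, which is an inference from the earlier numbers (21.01 us at 0.869 and 20.28 us at 0.903 imply the same CUDA time), not something the commit states. `dsl_slowdown_pct` is a hypothetical helper name.

```python
def dsl_slowdown_pct(c_over_g: float) -> float:
    """Percent by which DSL is slower than CUDA, assuming C/G = CUDA_us / DSL_us."""
    return (1.0 / c_over_g - 1.0) * 100.0


for cg in (0.903, 0.922):
    print(f"C/G = {cg:.3f} -> DSL {dsl_slowdown_pct(cg):.1f}% slow")
```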
…tail)
Replaces the runtime `while i + (vec_w - 1) < N` vec loop in block_count_ge with a 4-way unrolled fast path + 1-way tail. The fast path issues 4 independent LDG.E.128 per round (separate fragments so cute schedules them concurrently), mirroring what nvcc/ptxas does for the equivalent CUDA loop via auto-partial-unroll.
SASS verification at median config (bf16 K=1024 N=32768 BS=1 nn=2):
- 4 LDG.E.128 per inline at addresses base / base+0x2000 / +0x4000 / +0x6000: exact match to CUDA's LDG.E.128.CONSTANT pattern (minus the CONSTANT cache hint, which still requires CopyG2ROp+invariant).
- Total LDG.E.128 count: 5 -> 21 (4 inlines * 4 + 4 tails + 1 phase3).
- Cute software-pipelines: 3 LDGs issued back-to-back, then consumption of iter 0 starts while iter 3's LDG is issued in parallel. All 4 are in flight before HBM responds (latency ~600 cy >> 23 inst slots).
Resource impact:
- Regs/thread: 39 -> 39 (cute reuses fragment regs across loop body; Phase 4 likely remains the kernel-wide peak)
- Dynamic smem: unchanged (~44 KB)
- Static SASS size: 2944 -> 3672 inst (+25%) -- code bloat acceptable, well within icache; Block Limit Reg = 3 unchanged at occupancy=75%.
Bench results (kineto, L2 flush, n_iters=30):
Median config (bf16 K=1024 N=32768 BS=1 nn=2), same-process A/B:
Before this commit: C/G = 0.922 (DSL 8.5% slow)
After this commit: C/G = 0.976 (DSL 2.4% slow)
Delta: +5.4pp
Full sweep (804 configs from the 3 dtype * 3 top_k * 6 N * 5 BS * 3 next_n = 810 grid):
Median C/G: 0.860 (baseline post-Batch-2) -> 0.988 (now)
Geomean C/G: 0.869 -> 0.999 (parity with CUDA)
DSL faster: 17% -> 46%
Within 5%: 13% -> 34%
Within 10%: 28% -> 55%
By dtype: bf16 1.000, fp16 1.022, fp32 0.951 (fp32 has slightly less runway since vec_w=4 vs 8 for bf16/fp16).
By N: gap remains at large N (>=64K: median ~0.87-0.90), where the LSU-pipelining win is already saturated and other phases dominate.
The single-config worst slowdowns observed (C/G ~0.4) are concentrated in nn=3 + small-mid N (4-32K) + BS>=64 configs whose CUDA-side numbers also moved 5-15x between runs -- short-runtime measurement noise, not real regressions. This commit completes the SASS-alignment campaign objective (gap < 5% on median config). Remaining ~10-13% at very large N is deferred. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Three new switches gate the block_count_ge vec-load fast path:
enable_unroll_4 (default True): 4-way unrolled fast path
enable_unroll_2 (default by dtype): 2-way cascade between fast and tail
use_strided_layout (default by dtype): True → single make_ptr +
(UNROLL, vec_w) strided layout (cute emits 4 LDG.E.128 sharing
base reg with +0x2000/+0x4000/+0x6000 imm offsets, matching the
CUDA SASS pattern). False → 4 separate make_ptr calls (matches
the prior b459a8f commit style with 4 independent base regs).
Dtype-aware defaults (validated via per-config A/B testing on B200):
bf16 / fp16: enable_unroll_2=True, use_strided_layout=True
Strided cascade gives clean wins: cascade flips DSL from CUDA
parity to consistently faster on small-N where the 4-way fast
path doesn't fully cover N, and the medium 2-way path keeps two
LDG.E.128 in flight. Strided layout keeps the SASS shared-base
pattern that nvcc/ptxas auto-partial-unroll also produces.
fp32: enable_unroll_2=False, use_strided_layout=False
For fp32 (vec_w=4) the strided layout pushes regs 38 → 40 and
regresses fp32 large-grid configs by 30-60pp (worst observed:
K=1024 BS=128 nn=2 → 0.753 vs 1.364 with separate-ptrs). The
cascade similarly hurts in 12% of fp32 configs. Separate-ptrs
4-way unroll alone is the sweet spot.
Cache key includes the three switches so different settings produce
separate compiled kernels.
Full sweep results (804 configs, n_iters=30 kineto, L2 flush):
              baseline   cascade-all   dtype-policy
Median C/G:   0.988      1.011         1.006
Geomean C/G:  0.999      1.047         1.038
DSL faster %: 46%        54%           52%
Within 10%:   55%        65%           65%
By dtype:
bf16: 1.000 -> 1.043 (cascade wins preserved)
fp16: 1.022 -> 1.038
fp32: 0.951 -> 0.960 (anom K=1024 BS=128 fixed: 0.610 -> 1.038)
By N (the original "large-N gap"):
N=8192: 1.097 -> 1.172 (+7pp, cascade hides medium-path remainder)
N=65536: 0.897 -> 0.928 (+3pp)
N=131072: 0.866 -> 0.923 (+6pp)
Remaining slow configs (fp32 K=2048 + BS>=64) were already <0.7 in
the baseline -- this commit doesn't introduce new regressions there.
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
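The dtype-aware defaults described in this commit amount to a small lookup. The sketch below is illustrative only: `default_vec_load_switches` and `cache_key` are hypothetical helper names, and the dtype strings stand in for whatever dtype objects the real wrapper uses.

```python
def default_vec_load_switches(dtype: str) -> dict:
    """Return the per-dtype defaults for the three vec-load switches."""
    if dtype in ("bf16", "fp16"):
        # Half precision: cascade + strided layout gave clean wins.
        return {"enable_unroll_4": True, "enable_unroll_2": True,
                "use_strided_layout": True}
    if dtype == "fp32":
        # fp32: strided layout raised register pressure, cascade hurt some configs.
        return {"enable_unroll_4": True, "enable_unroll_2": False,
                "use_strided_layout": False}
    raise ValueError(f"unsupported dtype: {dtype}")


def cache_key(dtype: str, top_k: int, next_n: int, switches: dict) -> tuple:
    """Fold the switches into the compile-cache key so settings never collide."""
    return (dtype, top_k, next_n,
            switches["enable_unroll_4"],
            switches["enable_unroll_2"],
            switches["use_strided_layout"])


print(cache_key("fp32", 1024, 2, default_vec_load_switches("fp32")))
```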
Adds two new switches to the DSL GVR kernel:
enable_phase3_unroll (default True): master gate for phase3_collect
unrolling. When ON, the inner enable_unroll_4 / enable_unroll_2
switches independently control 4-way fast and 2-way medium paths
in phase3 (same semantics as block_count_ge). When OFF, only the
tail 1-way loop runs.
use_constant_hint (default False): True → CopyG2ROp(invariant=True)
→ SASS LDG.E.*.CONSTANT (read-only data cache, matches CUDA
__ldg). Default False because cute's invariant lowering triggers
aggressive rematerialization in LLVM/NVPTX (+272 inst, 4 spills,
net -7pp geomean), outweighing the cache hint benefit.
Phase3_collect is now a 3-tier cascade (4-way fast + 2-way medium +
1-way tail) mirroring block_count_ge. The cascade gives:
N>=65K: +5-7% (large-N main path, LSU pipelining wins)
N<=32K: -1-3% (unroll setup overhead exceeds benefit at small N)
Median geomean: +2.2pp from phase3 unroll alone
Resource analysis (bf16/fp16/fp32 x phase3 ON/OFF):
REG/thread:
bf16: 39 -> 39 (no change, cute reuses fragments)
fp16: 39 -> 39 (no change)
fp32: 38 -> 40 (+2, separate-ptrs path)
Static SASS:
bf16: 3936 -> 4368 (+11%)
fp16: 4096 -> 4512 (+10%)
fp32: 3368 -> 3480 (+3%)
Theoretical Occupancy: 75% all configs (smem-limited to 3 blocks/SM,
binding limit unaffected by phase3 unroll). Phase3 unroll has
*zero* occupancy cost.
Wrapper signature gains both switches; cache key includes them so
different settings produce separate compiled kernels. A small helper
method _make_load_copy_atom() factors out the CopyG2ROp/Universal
selection to avoid Python if-else NameError inside @cute.jit scope.
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…uristic
Phase 2 (block_count_ge) and Phase 3 (phase3_collect_candidates) replace the manual `while + range_constexpr(UNROLL)` fast/medium-cascade unrolling with a single `for k in cutlass.range(big_iters, unroll=4)` loop. LLVM's loop unroll pass + GVN/CSE folds the 4 derived vec loads into the CUDA-style shared base + immediate offsets pattern, emitting 4 back-to-back LDG.E.128 [base+0x2000/0x4000/0x6000] instructions.
Add `min_blocks_per_mp` field on `GvrTopKKernel` and a 3-tier shape-aware heuristic in the host wrapper:
* n_vec_iters < 4 -> 0 (no launch_bounds, natural ptxas allocation)
* num_rows <= 148 (B200 SMs) -> 1 (allow many regs, 4xLDG fold survives)
* else -> 3 (keep 3 CTA/SM occupancy, ~42 reg cap)
The heuristic lifts fp32 K=512 large-N out of its regression zone (worst case C/G 0.62 -> 1.07 at K=512 N=131072 BS=16 nn=2). Cache key extended so each min_blocks value gets its own compiled kernel.
Random sweep vs phase3_unroll baseline (804 configs): geomean 1.060 -> 1.149, faster%-than-CUDA 62% -> 92%, losses 304 -> 63.
CUDA Graph: heuristic reads `logits.shape` (host int, no GPU sync) so capture is safe; per-graph capture selects the right kernel per shape. For dynamic-shape single-graph use, caller can pin `min_blocks_per_mp=3` to disable the heuristic.
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
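The 3-tier rule from this commit can be written as a pure host-side function. This is a sketch under stated assumptions: `pick_min_blocks_per_mp` is a hypothetical name, and `n_vec_iters` is taken as an input because the commit does not spell out how it is derived from N, the vector width and the thread count.

```python
B200_NUM_SMS = 148  # SM count quoted in the commit message


def pick_min_blocks_per_mp(n_vec_iters: int, num_rows: int,
                           num_sms: int = B200_NUM_SMS) -> int:
    """Pick the launch-bounds min-blocks value from host-visible shape info."""
    if n_vec_iters < 4:
        return 0  # no launch_bounds, let ptxas allocate registers naturally
    if num_rows <= num_sms:
        return 1  # at most one CTA per SM anyway, so allow many registers
    return 3      # large grid: keep 3 CTAs/SM occupancy


print(pick_min_blocks_per_mp(n_vec_iters=8, num_rows=16))
```

Because the inputs are plain host integers, the decision needs no GPU synchronization, which is what makes it safe during CUDA Graph capture.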
Three new kernel knobs on GvrTopKKernel + gvr_topk_decode host wrapper:
* use_256bit_load (default False): emit LDG.E.256 (8 fp32 / 16 bf16-fp16
elements per LDG) instead of LDG.E.128. Address alignment hint is
raised from 16 to 32 bytes. Phase 2/3 unroll factor is dtype-aware:
fp32 keeps unroll=4 (no cvt-to-fp32 overhead); bf16/fp16 drops to
unroll=2 to limit the cvt register pressure that otherwise spills
under min_blocks=3.
* num_threads_per_block (default 512): configurable per-instance.
BLOCK_SIZE / WARP_SIZE / NUM_WARPS are moved from module-level to
GvrTopKKernel instance attrs (self.WARP_SIZE, self.num_threads,
self.num_warps). Phase 1 preIdx loop gains an else branch for the
K < num_threads case (e.g. num_threads=1024 with K=512): only the
first K threads load a preIdx, others keep reduction-identity
values which the warp/block reduces naturally absorb.
* vec_bits / vec_align_bytes derived from use_256bit_load; cache key
extended with use_256bit_load + num_threads_per_block.
Heuristic uses the resolved num_threads_per_block (not a hardcoded 512)
when computing n_vec_iters.
Tests parametrize use_256bit_load and num_threads_per_block; pytest
runs 288/288 PASS at use_256bit_load=True and at num_threads=1024.
Synth bench on BS<=128:
- 128-bit + heuristic baseline: gm=1.131, faster%=99%, lose=9
- 256-bit + heuristic : gm=1.121 (fp32 wins +3pp; bf16/fp16
flat-to-negative due to cvt-to-fp32 reg pressure spills under mb=3)
Random sweep on BS up to 128: 256-bit shows niche win on
(fp32, num_rows<=148, large N); should be opt-in.
Synth data generator (multi-BS) and bench script env vars
(DSL_USE_256BIT/DSL_MIN_BLOCKS/DSL_NUM_THREADS) live in the gvr-topk-opt
workspace and are not part of this commit.
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
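A short sketch of how the vector width, alignment hint and unroll factor follow from `use_256bit_load` as described above. `vec_load_params` is a hypothetical helper, and the dtype strings are placeholders for the real dtype objects.

```python
DTYPE_BITS = {"fp32": 32, "bf16": 16, "fp16": 16}


def vec_load_params(dtype: str, use_256bit_load: bool = False) -> dict:
    """Derive vec-load parameters for one (dtype, load width) choice."""
    vec_bits = 256 if use_256bit_load else 128
    vec_w = vec_bits // DTYPE_BITS[dtype]   # elements per LDG
    vec_align_bytes = vec_bits // 8         # 16 B for 128-bit, 32 B for 256-bit
    # With 256-bit loads, half precision drops to unroll=2 to limit the
    # cvt-to-fp32 register pressure; fp32 keeps unroll=4.
    unroll = 2 if (use_256bit_load and dtype != "fp32") else 4
    return {"vec_bits": vec_bits, "vec_w": vec_w,
            "vec_align_bytes": vec_align_bytes, "unroll": unroll}


for dt in ("fp32", "bf16"):
    print(dt, vec_load_params(dt, use_256bit_load=True))
```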
…nobs
Add a fourth perf knob `enable_warp_parallel_reduce` to GvrTopKKernel +
gvr_topk_decode and replace the four `tid==0` serial loops over
num_warps slots with warp-parallel reduce/scan in warp 0:
* Phase 1 block aggregate (4-way reduce):
min/max/sum_f32/sum_i32 -> 4x warp_reduce in warp 0.
* Phase 2 / blockCountGE total (1-way reduce):
sum_i32 -> warp_reduce_sum_i32.
* Phase 3 collect block prefix sum (exclusive scan):
Hillis-Steele inclusive scan via block_scan.warp_scan, then
exclusive = inclusive - val; total = inclusive at last lane.
* Phase 2 secant aggregate (3-way reduce):
packed sum_i32 + min_f32 + max_f32, with bound update on lane 0.
Default is False since at num_threads=512 (num_warps=16) the per-warp
ILP loss exceeds the serial-loop savings (~2pp regression on synth).
At num_threads=1024 (num_warps=32) the switch is essential -- without
it 1024 regresses vs baseline (gm 1.131 -> 1.123); with it 1024 wins
(gm -> 1.154 on synth BS<=128). Pair as
`enable_warp_parallel_reduce = (num_threads_per_block >= 1024)`.
Phase 1 also gains an `active_preidx_warps` optimization: when
`pre_idx_count < num_threads` (e.g. K=512 with num_threads=1024) only
the first ceil(K/32) warps have real data, so the warp_reduce + smem
write step is now gated to those warps. Saves ~30 cy/dummy-warp; the
full barrier afterwards still keeps all 1024 threads aligned for
Phase 2. The constexpr is clamped to num_warps so the K>num_threads
case (K=2048 with num_threads=512) doesn't index past the smem
buffers, and the same value drives both the warp_reduce gate and the
Site-1 block aggregate's smem read range.
Remove two now-dead switches:
* `enable_unroll_2` -- only referenced in the commented-out manual
2-way medium path that the `cutlass.range(unroll=4)` rewrite
replaced.
* `use_strided_layout` -- only referenced in the commented-out manual
4-way strided-layout path, also replaced.
Cache key drops the two dead entries and gains
`enable_warp_parallel_reduce`. The cleanup is a no-op functionally
(the dead values were ignored by the active code paths) but removes
two cache-bucket dimensions.
Test parametrize expanded to 4-way matrix:
next_n in {1, 2} (was {1, 2, 3, 4} -- trimmed to keep walltime)
use_256bit_load in {False, True}
num_threads_per_block in {512, 1024}
enable_warp_parallel_reduce in {False, True}
1152 / 1152 PASS in 20:22.
Synth bench (BS<=128, threads=512 baseline -> threads=1024+wpON):
geomean 1.131 -> 1.154 (+2.3pp)
fp32 geomean 1.127 -> 1.177 (+5.0pp; up to +28pp at fp32 K=2048
N=131072 -- 1.50x vs CUDA)
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
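The Phase 3 block prefix sum described above (Hillis-Steele inclusive scan, then exclusive = inclusive - val) can be modeled on the host. This is a plain-Python sketch of the technique, not the `block_scan.warp_scan` implementation; each list slot stands for one lane of warp 0 holding one warp's candidate count.

```python
def warp_inclusive_scan(vals: list) -> list:
    """Hillis-Steele inclusive scan: log2(n) steps, every lane active each step."""
    lanes = list(vals)
    offset = 1
    while offset < len(lanes):
        prev = list(lanes)  # all lanes read the previous step's values
        for lane in range(offset, len(lanes)):
            # Equivalent to reading from the lane `offset` positions below.
            lanes[lane] = prev[lane] + prev[lane - offset]
        offset *= 2
    return lanes


def block_prefix(warp_counts: list) -> tuple:
    """Return (exclusive prefix per warp, block total)."""
    inclusive = warp_inclusive_scan(warp_counts)
    exclusive = [inc - v for inc, v in zip(inclusive, warp_counts)]
    return exclusive, inclusive[-1]  # total lives in the last lane


offsets, total = block_prefix([3, 0, 5, 2])
print(offsets, total)
```

Each warp then writes its candidates starting at its own exclusive offset, and the total gives the block-wide candidate count.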
Walkthrough
This PR adds test coverage and public API exposure for the Blackwell CuTE
Changes
GVR Top-K Kernel Export and Testing
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In
`@tensorrt_llm/_torch/cute_dsl_kernels/blackwell/top_k/test_gvr_topk_decode.py`:
- Around line 29-45: The test picks argmax_idx over the full N, which breaks the
pre_idx[...,0] invariant when the kernel only scans the first (N - next_n + 1)
columns; update _make_inputs to compute an effective_scan = N - next_n + 1 (or
accept next_n as a parameter) and compute argmax_idx = int(logits[0,
:effective_scan].argmax().item()) so pre_idx_list[0] is in-range, then build
pre_idx as before; apply the same change to the other occurrence (the second
_make_inputs usage around the later test).
- Line 21: Remove the unused "from typing import Tuple" import and replace the
two occurrences of typing.Tuple[...] return annotations in this test module with
the native Python 3.10+ generic syntax tuple[...] (e.g., change "Tuple[int,
str]" to "tuple[int, str]") in the two functions in
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/top_k/test_gvr_topk_decode.py so
the file uses built-in tuple typing and no longer imports typing.Tuple.
Files selected for processing (3)
- tensorrt_llm/_torch/cute_dsl_kernels/blackwell/top_k/__init__.py
- tensorrt_llm/_torch/cute_dsl_kernels/blackwell/top_k/gvr_topk_decode.py
- tensorrt_llm/_torch/cute_dsl_kernels/blackwell/top_k/test_gvr_topk_decode.py
Mirrors the kFTarget=kK alignment for K=512/1024 (all dtypes) from CUDA PR NVIDIA#14413 on the DSL GVR Top-K kernel so the DSL Phase-2 secant behavior matches the new CUDA reference. Old pre-NVIDIA#14413 values kept as inline comments for easy rollback. Verified: 768/768 pytest configs pass for K=512/1024 across all dtypes, N, next_n, use_256bit_load, num_threads_per_block, and enable_warp_parallel_reduce. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Move module-level MAX_REFINE_ITERS / FLT_MAX / NEG_FLT_MAX into instance attributes so all kernel-wide knobs live in one place. Inline NUM_BINS_DEFAULT (2048) directly into the GvrParams table since it was only used in three K=2048 entries. Drop dead MAX_CANDIDATES. Pure refactor — values, control flow, and DSL IR are unchanged. Also removes the previously-commented-out A/B layout/unroll dead code in block_count_ge. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Mirrors heuristicTopKDecode.cu PR NVIDIA#14219 cr-aware branch in the DSL GVR Top-K kernel. compress_ratio=1 (default) preserves DSv3.2 behavior exactly; compress_ratio=4 enables the DSv4 (overlap-compressor) indexer path:
* pre_idx_offset = 0 (vs (row % next_n) + 1 for cr=1): in compressed-index space, new entries append at the end so prev-step indices remain valid as-is.
* N = actual_kv_len / cr: logits/preIdx live in compressed-token-index space when cr > 1.
GvrParams TABLE is also keyed by (dtype, K, cr) so V3.2 and V4 use their respectively tuned kFTarget values:
cr=1 (V3.2): kFTarget = 384 (K=512) / 2560 (K=1024), pre-NVIDIA#14413.
cr=4 (V4): kFTarget = kK = 512 (K=512) / 1024 (K=1024), PR NVIDIA#14413.
K=2048: identical across cr (V4 doesn't natively use K=2048).
Cache key includes compress_ratio so different cr settings compile separate kernels. assert restricts compress_ratio in {1, 4}. Verified: 1152/1152 pytest configs pass on cr=1 default path.
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…3 mb
Two paired host-wrapper heuristic refinements:
1. enable_warp_parallel_reduce: bool → Optional[bool] = None, default
auto-coupled to num_threads_per_block: enabled iff threads == 1024 (32
warps, where serial tid==0 cost dominates). At threads == 512 (16
warps) the warp-parallel path measured a ~2pp synth regression so it
stays off. Cache key sees the concrete bool. Explicit True/False still
overrides for A/B testing.
2. tier-3 (large grid + large N) min_blocks_per_mp hardcoded "= 3"
replaced by a (T, dtype) lookup:
T == 1024 or dtype == fp32 → mb=2
T == 512 and dtype in (bf16, fp16) → mb=3
Derived from BS{256,384,512} × N{16K,32K,65K} × all 9 (dtype, K) sweep
(gvr-topk-opt/sweep_tv_mb_kineto/mb_sweep.png). Old mb=3 default
regressed by 25-37% on (T=512 + fp32 + large N/BS) configs because
cap=42 starves the 4-LDG-inflight ILP (fp32 vec_w=4 × unroll-4 needs
50+ regs). bf16/fp16 keep mb=3 since cvt-to-fp32 ILP fits in 40 regs
and the extra CTA/SM (3 vs 2) hides cvt latency.
Pure default-policy change — no behavioral effect when caller passes
explicit values. Verified: pytest smoke 4/4 on cr=1 default path.
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Two more host-wrapper Optional[*]=None defaults so callers no longer
need to pick threads/vec-bits per shape:
num_threads_per_block (None default):
1024 iff num_rows <= num_sms (1 CTA/SM bound) AND N >= 65536
(so each of the 1024 threads has meaningful vec-loop work).
Otherwise 512.
use_256bit_load (None default):
True iff dtype == fp32 AND N >= 16384.
Half-prec (bf16/fp16) cvt-to-fp32 doubles fragment reg footprint
and regresses 5-11% at K=512/1024; LDG already saturates at 128b
anyway. fp32 N=8K dips 5-8% with 256b at small grid so the N
threshold excludes that single tier.
Cache key sees concrete values; (None, X) and (None, Y) hash apart.
Explicit values still override for A/B testing.
Derivation: sweep BS{1,4,16,64,128,256,384,512} x N{4K..131K} x all 9
(dtype, K), gvr-topk-opt/sweep_tv_kineto/auto_speedup.csv. Net vs
baseline (T=512, V=128):
- median speedup vs CUDA 1.09x -> 1.10x
- mean speedup vs CUDA 1.11x -> 1.15x (+3.8pp)
- max speedup vs CUDA 1.45x -> 1.52x
- 21 of 22 sp<1 configs were already sp<1 in baseline (BS=384 grid
quirk, unrelated to this change). 1 new config introduces a 0.8pp
sp<1 dip (within bench noise).
Pure default-policy change. Verified: 4-case auto-path smoke + pytest
smoke 4/4 on cr=1 default.
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Decode runs under CUDA graph, where the (T, V) heuristic baked in at
capture time is reused across all replays. The capture-time
logits.shape[1] is typically much smaller than peak runtime N, so the
captured kernel misses the large-N (T=1024, V=256) path. Add an
optional max_seq_len hint so the caller (e.g. dsa.py) can pass the
peak compressed-N for the model; the heuristic then tunes the captured
kernel for the peak.
Usage guidance baked into the docstring + inline comment:
* CUDA Graph mode: CALLER MUST PASS max_seq_len.
* Eager mode: leave max_seq_len=None (heuristic adapts per call).
Rules with max_seq_len:
* T=1024 threshold becomes dtype-aware to avoid half-prec K=512/1024
small-N replay regression (14-16% when forced T=1024 at small N):
fp32 -> 65536 (small-N replay 1-9% loss, net win)
half -> 131072 (only forced at very large peak)
* V=256 still gated by fp32 + N >= 16384.
Without max_seq_len, dtype-split is NOT applied because per-call
adaptive decisions never force T=1024 onto small N — heuristic only
fires for N >= 65536 by definition — so the half-prec N=65K-128K
+4-6% T=1024 win is preserved.
Cache key sees concrete (T, V), so different max_seq_len hints compile
distinct kernels. Pure default-policy extension. Verified with 4-case
auto smoke (no hint / fp32+131K / bf16+131K / bf16+200K).
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
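The (T, V) selection rules from this commit and the previous one can be combined into one host-side function. This is a sketch, not the wrapper's code: `pick_threads_and_vec_bits` is a hypothetical name, dtype strings are placeholders, and it assumes the `num_rows <= num_sms` condition still applies when `max_seq_len` is given (the commit only says the N threshold becomes dtype-aware).

```python
from typing import Optional


def pick_threads_and_vec_bits(dtype: str, num_rows: int, n_cols: int,
                              num_sms: int = 148,
                              max_seq_len: Optional[int] = None) -> tuple:
    """Return (threads per block, vec-load bits) for one launch shape."""
    # With a hint (CUDA Graph mode) decide on the peak N, not the capture-time N.
    n_dec = max_seq_len if max_seq_len is not None else n_cols
    if max_seq_len is None:
        t1024_threshold = 65536  # eager: per-call adaptive, no dtype split
    else:
        t1024_threshold = 65536 if dtype == "fp32" else 131072
    num_threads = 1024 if (num_rows <= num_sms and n_dec >= t1024_threshold) else 512
    vec_bits = 256 if (dtype == "fp32" and n_dec >= 16384) else 128
    return num_threads, vec_bits


# Graph capture at small N, tuned for a 131072 peak.
print(pick_threads_and_vec_bits("fp32", 1, 4096, max_seq_len=131072))
print(pick_threads_and_vec_bits("bf16", 1, 4096, max_seq_len=65536))
```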
The min_blocks_per_mp tier heuristic was still computing n_vec_iters from logits.shape[1] (capture-time N). In graph mode with max_seq_len hint, this would stick small-capture-N calls in tier-0 (mb=0) and miss the tier-3 occupancy choice for large-N replays — same pitfall the (T, V) heuristic was fixed against in the previous commit. Switch to N_dec (= max_seq_len if provided, else logits.shape[1]) so the tier classification is consistent with how T/V are picked. Smoke 4/4. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Add a wave-fit branch in the fp32 tier-3 path: when num_rows ∈ (296, 444]
(i.e. fits 1 wave at 3 CTAs/SM but needs partial 2nd wave at 2 CTAs/SM
with num_sms=148), pick mb=3 instead of mb=2. This recovers ~15% perf
on fp32 BS=384 across (K, N) — verified against CUDA which already uses
__launch_bounds__(BS, 3) for this exact reason.
Math:
mb=2 cap → 2 CTAs/SM × 148 SMs = 296 CTAs in 1 wave.
mb=3 cap → 3 CTAs/SM × 148 SMs = 444 CTAs in 1 wave.
For BS=384 (× next_n=1):
mb=2: 384 / 296 = 1.30 waves → tail wave wastes ~70% SMs.
mb=3: 384 / 444 = 0.86 waves → 1 wave fits, max SM utilization.
Verified perf gains (fp32 T=512, both V=128 and V=256 default paths):
fp32 K=512 N=4K-32K BS=384: +11-23%
fp32 K=1024 N=4K-32K BS=384: +16-19%
fp32 K=2048 N=8K-32K BS=384: +5-9%
Other BS unaffected:
BS ≤ 296 (192, 256): mb=2 already fits 1 wave → rule keeps mb=2 (no change)
BS > 444 (512): both need >1 wave → rule keeps mb=2 (ILP > occupancy)
Half-prec heuristic unchanged (already uses mb=3 in tier-3 via the
dtype-split path from a prior commit).
Bench artifacts: gvr-topk-opt/auto_full_bench/fp32_bs384_cluster/
(mb sweep CSV + NCU reports + drivers). Smoke: pytest 4/4 + spot tests
across BS={256, 384, 512}.
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Functional change: limit wave-fit mb=3 branch to N <= 32768. Beyond
that threshold the kernel becomes bandwidth-bound and mb=3's 3-way L2
sharing causes contention; mb=2's lower occupancy gives each CTA more
bandwidth and wins +21-30% at fp32 K=512 N=65K BS=384.
The full wave-fit rule for fp32 tier-3 is now:
if 2*num_sms < num_rows <= 3*num_sms and N_dec <= 32768:
mb = 3
else:
mb = 2
Also cleans up file comments: remove obsolete TODO list, trim refs to
specific CUDA line numbers, simplify class docstring.
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
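The wave arithmetic and the final fp32 tier-3 rule from the two commits above, as a small sketch. `waves` and `fp32_tier3_min_blocks` are hypothetical helper names; the rule itself is copied from the commit message.

```python
def waves(num_ctas: int, ctas_per_sm: int, num_sms: int = 148) -> float:
    """Number of waves needed to run num_ctas at a given CTAs/SM cap."""
    return num_ctas / (ctas_per_sm * num_sms)


def fp32_tier3_min_blocks(num_rows: int, n_dec: int, num_sms: int = 148) -> int:
    """mb=3 only when the grid fits one wave at 3 CTAs/SM but not at 2, and N is small."""
    if 2 * num_sms < num_rows <= 3 * num_sms and n_dec <= 32768:
        return 3
    return 2


print(f"BS=384 mb=2: {waves(384, 2):.2f} waves, mb=3: {waves(384, 3):.2f} waves")
for bs, n in ((256, 8192), (384, 8192), (384, 65536), (512, 8192)):
    print(bs, n, "-> mb =", fp32_tier3_min_blocks(bs, n))
```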
Registers torch.ops.trtllm.cute_dsl_gvr_topk_decode as the production entry point for the cuTe DSL GVR Top-K decode kernel (Blackwell SM100). Op writes values + indices into caller-allocated buffers (mutates_args style), matching the existing cute_dsl_indexer_topk_decode pattern so the DSA indexer pipeline can drop it in.
CuteDSLGvrTopKDecodeRunner takes ownership of the JIT compile cache and the auto-heuristic for T (threads/block), V (vec-load width), min_blocks_per_mp and enable_warp_parallel_reduce. The previous module-level wrapper in gvr_topk_decode.py is removed; standalone bench / A-B testing with the full tuning knob set lives in tests/scripts/cute_dsl_kernels/top_k/run_gvr_topk.py.
Tests:
- tests/unittest/.../test_cute_dsl_gvr_topk_decode.py: production correctness sweep via the op (dtype x K x N x next_n x batch_size x compress_ratio) with vectorized tie-aware + strict sort+allclose check.
- tests/scripts/.../run_gvr_topk.py: dual-mode driver -- pytest sweep over T/V/wp knobs and standalone CLI for single-case verification.
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
/bot run
PR_Github #50762 [ run ] triggered by Bot. Commit:
PR_Github #50762 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #50917 [ run ] triggered by Bot. Commit:
The DSA indexer pipeline (dsa.py) and CUDA indexer_topk_decode op both
only read top-K indices from the kernel output; the value buffer is
caller-allocated scratch that's never consumed. Add a kernel-level
return_output_values switch so the cuTe DSL kernel can elide all
STG.value stores when the caller doesn't need them.
Kernel (gvr_topk_decode.py):
- GvrTopKKernel gains return_output_values: bool = True.
- All 9 STG.value sites + the output_values_row slice are gated under
cutlass.const_expr(self.return_output_values), letting cute.compile
eliminate the dead writes when False.
Op (cute_dsl_custom_ops.py):
- CuteDSLGvrTopKDecodeRunner adds return_output_values to the compile
cache key + _compile signature; forward() hardcodes False, drops the
output_values arg, and passes None for the value-output slot at
launch (mirrors the optional-fake-tensor pattern at
CuteDSLTopKDecodeMultiCTARunner._compile).
- trtllm::cute_dsl_gvr_topk_decode op signature drops output_values;
mutates_args is now ("output_indices",), aligning with CUDA's
indexer_topk_decode which also only exposes indices.
Tests:
- tests/unittest/.../test_cute_dsl_gvr_topk_decode.py drops the
output_values buffer alloc + op kwarg (all 144 cases pass with the
sort+allclose strict check).
- tests/scripts/.../run_gvr_topk.py wrapper exposes
return_output_values as a knob so the standalone driver can still
capture written values; the _compile cache + cute.compile
out_values_fake placeholder are conditional on the flag.
SASS verification at bf16 K=1024 N=8K BS=384 confirms 88 STG.E.U16
writes are eliminated (kernel cubin -6KB, total SASS -416 lines).
On B200 SXM5 + synth_data, v5 (return_output_values=False) gives a
median 1.2% latency improvement over v4 with sp<1 configs nearly
halved (52 -> 27).
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
* Modernize type hints: replace ``typing.Tuple[...]`` with the built-in ``tuple[...]`` (Python 3.10+) and drop the ``from typing import Tuple`` import in both ``_make_inputs`` / ``_tie_aware_correct`` helpers.
* Fix the ``pre_idx[..., 0]`` argmax invariant for next_n > 1: argmax must come from the kernel's effective scan range ``[0, N - next_n + 1)``, not full ``[0, N)``. With the prior full-N argmax, an index landing in the ``[N_eff, N)`` tail could violate the CUDA-side ``preIdxCount == topK`` dispatch precondition (kernel still produced correct top-K because pre_idx is only a Phase-1 hint, but the test was technically exercising the kernel under invalid input).
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
The kernel derives batch_size as logits.shape[0] / next_n implicitly,
then sizes pre_idx / seq_lens / output_indices accordingly. If the
divisibility breaks, the failure modes are either an OOB write or a
ZeroDivisionError raised from deep inside the JIT-compiled kernel —
neither is actionable for callers. Add an upfront check in the op
body so the contract violation surfaces with a clear message.
Other invariants (top_k in {512,1024,2048}, compress_ratio in {1,4},
logits dtype) are already enforced by GvrTopKKernel.__init__ via
GvrParams.get and the dtype switch.
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
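A minimal sketch of the kind of upfront check this commit describes. The helper name derive_batch_size and the error wording are hypothetical; only the contract itself (logits.shape[0] must be a multiple of next_n) comes from the commit message, and the actual check lives in the op body in whatever form the author chose.

```python
def derive_batch_size(num_rows: int, next_n: int) -> int:
    """Return num_rows // next_n, failing early on a contract violation.

    Hypothetical helper for illustration; num_rows stands in for
    logits.shape[0].
    """
    if next_n <= 0:
        raise ValueError(f"next_n must be positive, got {next_n}")
    if num_rows % next_n != 0:
        # Surface the violation here instead of deep inside the JIT-compiled kernel.
        raise ValueError(
            f"logits.shape[0] ({num_rows}) must be divisible by next_n ({next_n})"
        )
    return num_rows // next_n
```

Raising at the op boundary keeps the failure message tied to the caller's inputs rather than to kernel internals.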
* Drop the single-use N_cols local in CuteDSLGvrTopKDecodeRunner.forward; fold it directly into the N_dec ternary for less noise.
* Add the input shape / dtype / knob signature to the info_once dedup key. Without the signature the first call's log message hid every subsequent shape from production diagnostics; now each new shape emits a single log line on its first run.
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Both _tie_aware_check (unittest) and _tie_aware_correct (run_gvr_topk
standalone driver) previously assumed the reference scan range was
``N - next_n + ofs + 1`` per row. That matches the kernel for cr=1
and for cr>=2 with next_n in {1, 2}, but breaks for cr>=2 with
next_n>=3 because floor-division by cr makes per-row N_eff vary
within a group in ways the simple closed form can't express.
Switch both reference helpers to mirror the kernel's exact formula:
actual_kv_len = seq_lens[row // next_n] - next_n + (row % next_n) + 1
N_eff = actual_kv_len // compress_ratio # cr=1 is identity
This requires the helpers to take ``seq_lens`` (and compress_ratio for
the standalone driver) so the reference can compute per-row N_eff
exactly as the kernel does. With this, any (next_n, cr) combo is
testable without the floor-division mismatch.
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
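The per-row formula in this commit can be worked through in plain Python. The function name n_eff_per_row is an assumption made for this sketch; the two expressions inside the loop are the ones quoted in the commit message.

```python
def n_eff_per_row(seq_lens: list[int], next_n: int, compress_ratio: int = 1) -> list[int]:
    """Effective scan length for every logits row (batch_size * next_n rows)."""
    out = []
    for row in range(len(seq_lens) * next_n):
        actual_kv_len = seq_lens[row // next_n] - next_n + (row % next_n) + 1
        out.append(actual_kv_len // compress_ratio)  # cr=1 is identity
    return out


# One request with seq_len=12 and next_n=3:
print(n_eff_per_row([12], next_n=3, compress_ratio=1))  # [10, 11, 12]
print(n_eff_per_row([12], next_n=3, compress_ratio=4))  # [2, 2, 3]
```

With cr=4 the three rows of one group no longer differ by exactly one column each, which is why the reference has to apply the floor division per row.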
PR_Github #50917 [ run ] completed with state
/bot run
PR_Github #50978 [ run ] triggered by Bot. Commit: |
longcheng-nv
left a comment
Review — comment (approve in principle; please address the points below)
Thanks @limin2021 for the cuTe DSL port. This PR is well-scoped — single
operator, no production wiring, in-kernel auto-heuristic confined to the new
kernel itself (i.e. doesn't touch any cross-algorithm dispatcher). The
structural design is clean and I'm happy to see this land once the items
below are addressed.
I'm leaving this as COMMENT rather than REQUEST_CHANGES because the
substance of the kernel + tests is good — the asks below are mostly about
PR-body evidence and a few small footguns.
✅ What's right
Functional correctness — strong test design.
test_cute_dsl_gvr_topk_decode.py covers a 144-cell parametric sweep
(dtype × top_k × N × next_n × batch_size × cr) and uses a tie-aware
multi-row reference check (out-of-range / duplicates / n_below /
sort+allclose vs torch.topk). The per-row N_eff formula mirrors the
kernel's exact contract (seq_lens - next_n + ofs + 1) // cr, including
next_n>=3 + cr>=2 floor-division cases. This is significantly more
careful than many CuTe DSL kernel PRs that rely on torch.allclose on
sorted values alone.
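For readers unfamiliar with tie-aware checking, here is a rough single-row sketch of the four conditions named above. It is not the test's actual code: the function name, the failure tags, and the use of plain Python lists are assumptions, and exact equality stands in for the allclose comparison the real test applies to floating-point values.

```python
def tie_aware_check(logits_row, out_indices, top_k, n_eff):
    """Return a list of failure tags; an empty list means the row passes.

    Assumes len(out_indices) == top_k and n_eff >= top_k.
    """
    if any(i < 0 or i >= n_eff for i in out_indices):
        return ["out_of_range"]  # cannot safely index logits_row below
    failures = []
    if len(set(out_indices)) != len(out_indices):
        failures.append("duplicates")
    # Reference: sort only the columns the kernel is allowed to scan.
    ref_sorted = sorted(logits_row[:n_eff], reverse=True)
    kth_value = ref_sorted[top_k - 1]
    # Any selected value below the K-th largest is wrong regardless of ties.
    if any(logits_row[i] < kth_value for i in out_indices):
        failures.append("n_below")
    got_sorted = sorted((logits_row[i] for i in out_indices), reverse=True)
    if got_sorted != ref_sorted[:top_k]:
        failures.append("values_mismatch")
    return failures
```

Comparing sorted values rather than indices is what makes the check tie-aware: with logits [5, 3, 5, 1, 9] and top_k=2, both [4, 0] and [4, 2] are accepted.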
Code quality — knobs and rationale are documented.
The forward() auto-heuristic in cute_dsl_custom_ops.py is layered
with concrete rationale per knob:
- num_threads_per_block (T): 1024 iff num_rows ≤ num_sms AND N_dec ≥ 65K (or 131K for graph-capture half-prec)
- use_256bit_load (V): fp32 AND N_dec ≥ 16K; half-prec excluded because cvt-to-fp32 doubles fragment reg footprint (5-11% regression on K=512/1024)
- enable_warp_parallel_reduce: only at T=1024; T=512 costs ~2pp
- min_blocks_per_mp: 3-tier with dtype-conditional ordering (fp32 prioritizes LDG-ILP via mb=2 cap=64; half-prec prioritizes cvt-ILP)
The max_seq_len graph-capture hint is a real correctness fix —
without it, graph captured at small N would replay with T=512 even
when N=256K at replay time. Easy to miss; good catch.
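To make the graph-capture point concrete, here is a simplified sketch of the thread-count decision only. Everything in it is an assumption for illustration: the function name, the argument names, and the exact thresholds (65536 and 131072 are one reading of "65K" and "131K"); the real forward() in cute_dsl_custom_ops.py may differ in all of these.

```python
def pick_num_threads(num_rows, num_sms, n_actual, max_seq_len=None,
                     half_prec_graph_capture=False):
    """Hypothetical reduction of the T heuristic to a pure function."""
    # Under graph capture the N seen at capture time may be much smaller than
    # the N at replay, so the decision uses the max_seq_len hint when given.
    n_dec = max_seq_len if max_seq_len is not None else n_actual
    threshold = 131072 if half_prec_graph_capture else 65536  # assumed values
    return 1024 if (num_rows <= num_sms and n_dec >= threshold) else 512


# Captured at N=4096 without a hint: stuck at 512 threads for every replay.
print(pick_num_threads(num_rows=8, num_sms=148, n_actual=4096))
# Same capture with the hint: the decision reflects the largest replay N.
print(pick_num_threads(num_rows=8, num_sms=148, n_actual=4096, max_seq_len=262144))
```

The num_sms value of 148 in the usage lines is only a placeholder, not a claim about any particular GPU.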
Algorithm — GvrParams table justifies its choices.
18 entries (3 dtypes × 3 K × 2 cr), with explicit per-cell kFTarget
values. The docstring distinguishes cr=1 (V3.2) vs cr=4 (V4) tuning
origins and cites evidence: "cross-prompt swe-bench shows 1.5-2.2×
P2-iter reduction vs V3.2's kFTarget=384/2560".
Architecture — auto-heuristic is INSIDE the new kernel.
Unlike a cross-algorithm dispatcher, this heuristic only chooses between
knob combinations for THIS kernel. When gvr_topk_decode is re-optimized
later, the heuristic updates with it — no main-level constant to rot.
This is the right abstraction layer for tuning of this kind.
⚠️ Asks before merge
1. Replace the perf PNG with a numerical table
The PR body has one performance image but no numerical detail.
Reviewers cannot validate it, and future readers cannot reproduce it.
Please publish a markdown table covering at least:
| dtype | K | N | BS | next_n | cute-dsl μs | CUDA prod GVR μs | speedup |
|---|---|---|---|---|---|---|---|
| fp32 | 2048 | 64K | 1 | 1 | ? | ? | ? |
| fp32 | 2048 | 64K | 32 | 1 | ? | ? | ? |
| fp32 | 512 | 4K | 1 | 1 | ? | ? | ? |
| bf16 | 512 | 64K | 32 | 2 | ? | ? | ? |
| ... |
(Your _fmin_f32_inline docstring already implies there's a measured
gap — "~8-10 us of the cuTe vs prod GVR gap at fp32 K=2048 BS=1" —
make the rest of those measurements visible.)
Also include 2-3 knob-decision validation rows: e.g., one row
showing T=1024 beats T=512 at the heuristic's switch point N=65K
(num_rows ≤ num_sms), and one showing V=256-bit beats V=128-bit at
fp32 N=16K. That justifies the auto-heuristic's specific boundaries.
2. Add an alignment runtime check when use_256bit_load=True
cute_dsl_custom_ops.py justifies the 32-byte alignment as:
"PyTorch CUDA allocations are 256-byte aligned; Phase 2/3 offsets
are multiples of vec_w * elem_bytes = 32 bytes."
This holds for torch.empty(...) outputs but breaks for views into
larger allocations with non-aligned offsets. If a caller eventually
wires this op to take a view (likely once it lands in the V4 indexer
pipeline), misalignment will trigger silent corruption or a fault.
Suggested addition to forward():
    if use_256bit_load:
        assert logits.data_ptr() % 32 == 0, (
            f"256-bit vec load requires 32B-aligned logits.data_ptr(), "
            f"got {logits.data_ptr()} % 32 = {logits.data_ptr() % 32}"
        )

Cheap, catches a real footgun.
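The arithmetic behind the view concern can be sketched without a GPU. The helper below is hypothetical and models the address of a view's first element as base pointer plus element offset times element size, which is how data_ptr() behaves for a tensor with a nonzero storage offset; the base address is an arbitrary 256-byte-aligned value chosen for the example.

```python
def is_aligned_for_256bit(base_ptr: int, elem_offset: int, elem_bytes: int) -> bool:
    """True if the first element of the view sits on a 32-byte boundary."""
    return (base_ptr + elem_offset * elem_bytes) % 32 == 0


base = 0x7F0000000100  # arbitrary example address, 256-byte aligned

print(is_aligned_for_256bit(base, 0, 4))  # fresh fp32 allocation: True
print(is_aligned_for_256bit(base, 3, 4))  # fp32 view starting at element 3: False
print(is_aligned_for_256bit(base, 8, 4))  # fp32 view starting at element 8: True
print(is_aligned_for_256bit(base, 8, 2))  # bf16 view starting at element 8: False
```

Only offsets that are multiples of 32 / elem_bytes elements keep the 256-bit path safe, so a runtime check is the simplest guard.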
3. Clarify return_output_values policy
The op-level forward hardcodes return_output_values = False with the
comment "DSA indexer pipeline only consumes indices, mirroring CUDA's
indexer_topk_decode". But the kernel supports True. Two options:
- If False is permanent for the trtllm op: remove the True branch
from the kernel (or sentinel-tag it with # TODO(...): remove when production needs values)
- If True will be re-enabled later: document the eventual caller in
a TODO so the dead-looking code path is intentional
Currently this reads ambiguously to a future reader.
4. Add 1-2 tests exercising realistic preIdx hit-rate
_make_inputs() builds pre_idx[:, 0] = argmax(logits) and
pre_idx[:, 1:] = arange(1, top_k). That's a worst-case "only slot 0
is meaningful" scenario, exercising kernel robustness to junk pre_idx.
But it does not exercise the realistic case where
|preIdx ∩ topK| / top_k is 0.3-0.8 (matching production
V3.2 / V4 captures). The kernel's whole purpose — the "Guess" phase —
saves work proportional to hit-rate; the current tests don't have any
cell where Guess actually carries weight.
Suggested cell:
    # Realistic preIdx: ~50% of top_k slots actually match torch.topk
    ref_topk = logits.topk(top_k, dim=-1).indices
    keep_mask = torch.rand(ref_topk.shape, device=device) < 0.5
    random_fill = torch.randint(0, N, ref_topk.shape, device=device, dtype=torch.int32)
    pre_idx = torch.where(keep_mask, ref_topk.int(), random_fill)

This catches "Guess phase short-circuits and returns wrong result"
bugs that the current argmax-only setup cannot.
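A test built this way could also record the hit-rate it actually achieved, since random fills can land on true top-K indices or repeat each other. The helper below is a hypothetical pure-Python sketch of the |preIdx ∩ topK| / top_k metric for one row; it is not part of the PR.

```python
def pre_idx_hit_rate(pre_idx_row: list[int], topk_row: list[int]) -> float:
    """Fraction of the true top-K indices that appear in the pre_idx hint."""
    top_k = len(topk_row)
    # Set intersection ignores duplicate hint entries, so repeats do not inflate the rate.
    return len(set(pre_idx_row) & set(topk_row)) / top_k


print(pre_idx_hit_rate([1, 2, 3, 4], [3, 4, 5, 6]))  # 0.5
print(pre_idx_hit_rate([0, 1, 2, 3], [0, 9, 8, 7]))  # 0.25
```

Asserting that the measured rate falls in the intended 0.3-0.8 band would confirm the cell really exercises the Guess phase.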
Minor (optional follow-up)
- The _fmin_f32_inline PTX-asm workaround (cute DSL has cute.arch.fmax
but not cute.arch.fmin): please file an upstream cute-dsl issue
and link it in the comment, so future maintainers know when the
workaround can be removed
- GvrParams is an 18-entry table; future K additions (256? 4096?)
need 6 new rows. Worth a follow-up to consider whether kFTarget
could be computed (e.g. kFTarget = K for cr=4) rather than
tabulated; not blocking for this PR
- tests/scripts/cute_dsl_kernels/top_k/run_gvr_topk.py is 500 lines
and not in CI per its docstring; please add a brief top-of-file
comment explaining why it's standalone (no trtllm env dep) and how
it relates to the CI'd unittest
Happy to re-review and approve as soon as 1-4 are addressed.
PR_Github #50978 [ run ] completed with state |
Description
Add cute dsl gvr top-k decode kernel.
(1) Port the CUDA GVR kernel to CuTe DSL. Thanks to Long, who provided the first version of the ported code.
(2) Add extra optimizations (e.g., loop unrolling, 256-bit vectorized loads, num_threads tuning) to further improve performance.
Test Coverage
Performance
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
If PR introduces API changes, an appropriate PR label is added - either api-compatible or api-breaking. For api-breaking, include BREAKING in the PR title.
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.