[None][feat] Multi-K (512/1024/2048) and Multi-dtype (fp32/bf16/fp16) GVR Top-K#13948

Merged
lfr-0531 merged 19 commits into NVIDIA:main from longcheng-nv:feat/gvr-topk-v123
May 12, 2026

longcheng-nv commented May 9, 2026

Summary

Follow-up to #13477 (Scheme X dispatcher). Extends the GVR (Guess-Verify-Refine) Top-K micro-kernel to:

  • Multi-K: 512 / 1024 / 2048 (was: 2048 only)
  • Multi-dtype: fp32 / bf16 / fp16 logits (was: fp32 only)

via a GvrParams<T, K> trait class with 9 specializations. K=2048 fp32 production path is byte-identical SASS-wise (V2e LOCK preserved across all 3 production K=2048 kernel variants).

Commit Breakdown (16 commits)

  1. feat: Multi-dtype GVR Top-K kernel: fp32/bf16/fp16
  2. feat: Multi-dtype GVR Top-K dispatcher: bf16/fp16 path
  3. feat: Multi-K (512/1024/2048) GVR Top-K via GvrParams trait class
  4. fix: Forward TopK template arg into blockFusedSnapIter helpers
  5. feat: Relax GVR Top-K caller gate to K ∈ {512, 1024, 2048}
  6. feat: Per-(T, K) kNumBins tuning saves 5-9 % on K=512/1024 cells
  7. feat: Defer dtype P3 conversion → −6 to −13 % on bf16/fp16 cells
  8. feat: kFTarget retune for V4 M=K (K=512/1024 cells)
  9. feat: Align standalone gvrTopKKernel launch_bounds with production
  10. chore: Strip internal experiment-ID lineage from GVR Top-K comments
  11. chore: Drop dead thresholdPos param + remove internal opt-registry tags
  12. refactor: Hoist Scheme X bounds + collapse heuristicTopKDecode launcher
  13. chore: Strip residual Opt-M / v0 references from selection-pass comment
  14. feat: Symmetric Radix / Insertion fallback for indexer Top-K bf16/fp16 path
  15. fix: GVR Top-K P4 snap-iter convergence + drop test tolerance widening (round-3 review fix)
  16. chore: GVR Top-K drop dead blockFusedSnapIterDtype + clarify kFTarget docstring (round-3 review fix)

Key Files (6 files / +1621 / −355)

| File | Change |
| --- | --- |
| cpp/tensorrt_llm/kernels/heuristic_topk.cuh | GvrParams<T,K> trait class; 9-way template specialization; per-(T,K) kNumBins / kFTarget tuning; deferred fp32 P3 conversion |
| cpp/tensorrt_llm/kernels/heuristicTopKDecode.cu | 3 launcher overloads (fp32/bf16/fp16); GvrParams-driven kernel dispatch |
| cpp/tensorrt_llm/kernels/heuristicTopKDecode.h | Multi-dtype API surface |
| cpp/tensorrt_llm/kernels/indexerTopK.cu | Multi-dtype dispatcher; symmetric Insertion / Radix fallback for bf16/fp16 |
| cpp/tensorrt_llm/kernels/IndexerTopK.h | dtype-templated invokeIndexerTopKDecode |
| cpp/tensorrt_llm/thop/IndexerTopKOp.cpp | Python op auto-routes by tensor.scalar_type() |

API

No public API change. The Python op torch.ops.trtllm.indexer_topk_decode
now auto-routes by logits.scalar_type() to the matching dispatcher;
bf16/fp16 paths require pre_idx and heuristic_scratch (matching
dtype). K is selected via the index_topk argument as before; values
outside {512, 1024, 2048} fall back to Radix / Insertion.

Test plan

  • Unit tests (pytest tests/unittest/_torch/thop/parallel/test_indexer_topk.py -k test_indexer_topk_decode):
    • test_indexer_topk_decode (5-arg / no pre_idx) — fp32 K∈{128, 2048} Radix/Insertion fallback covered (32 cases).
    • test_indexer_topk_decode_dist (7-arg / GVR Heuristic, parametrized in this PR over index_topk × dtype × num_tokens × batch_size × next_n × dist × success_ratio = 2268 cases) — all 9 new GVR (T, K) specializations covered. num_tokens=16384 cells hit GVR (1134); num_tokens=8192 cells route through Radix/Insertion fallback via Scheme X dispatcher (kSeqSmall=12288 gate). Post round-3 fix (commit 15): all 2268 cells pass under the strict default tolerance (1e-5) — flaky markers and the bin_tolerance widening are removed; the former 0.09 % K=512 / K=1024 intermittent at numRows ≈ 384 is fixed by the P4 snap-iter convergence guarantee.
  • Q9 ablation suite on B200 sm_100 — multi-K × multi-dtype × multi-BS perf characterization with bit-equivalence to torch.topk set (9 combos × 9 SWE-Bench-64K layers × 12 BS cells).
  • V2e LOCK SASS byte-identity verified (production K=2048 path).
  • CUDA Graph capture / replay safe via the warmup helper introduced in [None][perf] Scheme X L2-aware dispatcher and PDL launchers for sparse-attention GVR Top-K #13477 (no new framework hooks; static caches in dispatcher are dtype-independent).

V2e LOCK SASS byte-identity (production K=2048)

| Production kernel | md5 baseline | md5 post-PR |
| --- | --- | --- |
| heuristicTopKMultiRowKernel<2048> fp32 | a9f5c6b8 | a9f5c6b8 |
| heuristicTopKMultiRowKernelDtype<__half, 2048> | 9caa920c | 9caa920c |
| heuristicTopKMultiRowKernelDtype<__nv_bfloat16, 2048> | aede08cc | aede08cc |

K=2048 production SASS is byte-identical with pre-PR baseline for all
three dtypes. New K=512 / K=1024 specializations have stable SASS produced
fresh by the GvrParams<T,K> template.

Performance Results (B200 sm_100, SWE-Bench-64K decode logits)

BS=1 cross-layer pooled (Q9d 04b — 9 layers × 9 GVR combos vs Radix fp32 K=2048 baseline, production path)

Routed through torch.ops.trtllm.indexer_topk_decode (Scheme X v1.2 dispatcher →
heuristicTopKMultiRowKernel{,Dtype}<TopK> with preIdxOffset=+1, V3.2
decode semantics). 1 827 timed rows / variant (9 layers × stride-10 over rows
1..2024), 5 reps + 3 warmup, 128 MiB L2 flush before every launch.

| Variant | Latency (µs) | Speedup vs Radix |
| --- | --- | --- |
| Radix fp32 K=2048 (baseline) | 50.98 | 1.00× |
| K=512 fp32 | 20.16 | 2.53× |
| K=512 bf16 | 17.15 | 2.97× |
| K=512 fp16 | 18.02 | 2.83× |
| K=1024 fp32 | 22.85 | 2.23× |
| K=1024 bf16 | 19.26 | 2.65× |
| K=1024 fp16 | 18.82 | 2.71× |
| K=2048 fp32 | 26.27 | 1.94× |
| K=2048 bf16 | 21.15 | 2.41× |
| K=2048 fp16 | 21.25 | 2.40× |

GVR wins all 9 combos × 9 layers. Latency falls monotonically with K
(smaller K = smaller smem residency / higher occupancy), and both 16-bit
dtypes beat fp32 at every K (smaller dtype = smaller HBM-load + smem
footprint). Cross-validates the BS=1 column of Table 2 (Q9e 05) within
±10 % wall; both tables are now production path.

BS scaling — pooled across 9 layers, last row replicated (Q9e 05, production path)

Speedup vs Radix fp32 K=2048 baseline (50.1 µs at BS=1 → 93.5 µs at BS=400):

| Variant | BS=1 | BS=8 | BS=32 | BS=128 | BS=200 | BS=256 | BS=400 |
| --- | --- | --- | --- | --- | --- | --- | --- |
| K=2048 fp32 | 1.82× | 1.85× | 1.82× | 1.71× | 1.74× | 1.46× | 1.28× |
| K=2048 bf16 | 2.22× | 2.28× | 2.28× | 2.20× | 2.31× | 2.08× | 2.09× |
| K=2048 fp16 | 2.13× | 2.21× | 2.19× | 2.13× | 2.22× | 2.07× | 2.05× |
| K=1024 bf16 | 2.30× | 2.37× | 2.33× | 2.30× | 2.38× | 2.24× | 2.24× |
| K=512 bf16 | 2.62× | 2.70× | 2.67× | 2.54× | 2.55× | 2.34× | 2.29× |

K=2048 bf16 / fp16 hold a 2.05 – 2.31× speedup across the full BS range
(1 .. 400), and K=512 / K=1024 bf16 never drop below 2.24×. K=2048 fp32
fades from a peak of 1.85× (BS=8) to 1.28× (BS=400), with most of the
drop past BS=200 (smem residency tightens). Half-precision dtypes scale
gracefully because their lower smem footprint preserves occupancy
beyond 1 wave.

Reproducibility

  • BS=1 (Q9d 04b, production path): single-GPU nsys, ~3 min wall, 91 350 NVTX-tagged launches.
  • BS scaling (Q9e): single-GPU nsys, ~3 min wall, 5 400 NVTX-tagged
    launches (10 variants [9 GVR combos + Radix baseline] × 12 BS × 9 layers × 5 reps).
  • All variants use 128 MiB L2 flush before every launch (cold-cache).

Round-3 review fixes (commit 15)

Addresses lfr-0531 and mingyangHao review comments on the 0.09 %
intermittent at K=512 / K=1024:

  • Kernel (cpp/tensorrt_llm/kernels/heuristic_topk.cuh): the P4 snap
    loop's snap_limit = cand_count/4 was insufficient for cells where
    the initial bin-edge threshold sits far from the true K-th value. Bug:
    when the loop exited without converging (cgt >= kK), Pass 1 emitted K
    elements in scan order from the strictly-greater set (cgt > kK),
    missing some true top-K members. Fix: raise snap_limit to cand_count,
    which is the provably sufficient bound (each snap iter strictly
    decreases cgt or strictly increases cge by ≥ 1, so worst-case
    convergence takes cand_count − kK + 1 iters). The common path still
    breaks early at iter 1–3, so steady-state perf is unchanged.
  • Test (tests/unittest/_torch/thop/parallel/test_indexer_topk.py):
    remove pytest.mark.flaky(reruns=2) markers on K=512 / K=1024; drop
    the full_range/256 bin-quantized tolerance and the unsound
    |mean| * eps * 4 bf16/fp16 branch (bf16 → fp32 promotion inside the
    kernel is bit-exact and order-preserving, so no eps slack is
    warranted); restore compare_top_k_results default
    tolerance = 1e-5.
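
The convergence argument above can be illustrated with a small host-side model. This is a sketch under assumptions, not the kernel's code: `snapToKth`, `SnapResult`, and the snap rule (move the threshold onto the nearest candidate value in the needed direction) are hypothetical stand-ins chosen so that each snap strictly decreases cgt or strictly increases cge, as the fix description states.

```cpp
#include <cassert>
#include <limits>
#include <vector>

struct SnapResult {
    float threshold;  // final threshold value
    int snaps;        // number of snap steps performed
    bool converged;   // true when cgt < k <= cge
};

// Hypothetical model of a snap loop: move `thr` onto candidate values
// until it equals the k-th largest value (cgt < k <= cge), or give up
// after `snapLimit` snaps. Precondition: 1 <= k <= cand.size().
inline SnapResult snapToKth(const std::vector<float>& cand, int k, float thr, int snapLimit) {
    assert(k >= 1 && k <= static_cast<int>(cand.size()));
    for (int snaps = 0;; ++snaps) {
        int cgt = 0, cge = 0;  // counts of values > thr and >= thr
        for (float v : cand) {
            cgt += (v > thr) ? 1 : 0;
            cge += (v >= thr) ? 1 : 0;
        }
        if (cgt < k && cge >= k) return {thr, snaps, true};
        if (snaps == snapLimit) return {thr, snaps, false};
        if (cgt >= k) {
            // Too many strictly-greater values: snap up to the smallest value above thr.
            float next = std::numeric_limits<float>::infinity();
            for (float v : cand) if (v > thr && v < next) next = v;
            thr = next;
        } else {
            // Too few values at or above thr: snap down to the largest value below thr.
            float next = -std::numeric_limits<float>::infinity();
            for (float v : cand) if (v < thr && v > next) next = v;
            thr = next;
        }
    }
}
```

In this model a starting threshold far below the K-th value needs one snap per distinct value it has to cross, so a limit of cand_count/4 can run out while a limit of cand_count cannot.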

Additional cleanup in commit 4880bb6 (chore, no live SASS impact):

  • Removed dead blockFusedSnapIterDtype helper (-64 lines) and stale doc paragraphs — the dtype P4 snap path uses the fp32 blockFusedSnapIter<TopK> directly after commit 7's deferred-conversion optimization stored smem keys[] as fp32.
  • Expanded the kFTarget docstring to call out that the value is the secant solver's soft steering target (not the convergence condition), and that kFTarget < kK is intentional for small K — addresses the line-215 review question.

Local validation on B200 sm_100 with the fix + strict default
tolerance:

| Scope | Cells | Result | Wall |
| --- | --- | --- | --- |
| K=512 + K=1024 (all dist / dtype / BS / next_n / ratio) | 1512 | 1512/1512 | 137 s |
| K=2048 (same axes) | 756 | 756/756 | 70 s |
| Stress: 5× focused corner (logistic m0.47, BS=128, next_n=3 → numRows=384) | 5×36 | 180/180 | ~5 min |

0 failures across 2448 invocations (1512 + 756 + 180) under default strict tolerance.
The former 0.09 % intermittent corner is now stable.

Risks

  • Smem ceiling at K=2048 + fp32 + BS=200..256: the K=512 fp32 path
    shows an anomalous BS=200→256 latency spike (29 → 77 µs).
    Production deployment is bounded by Scheme X dispatcher (kBsLarge
    derived from L2 capacity), so the cliff is not reached in
    serving traffic — but flagged for follow-up perf work.
  • bf16 / fp16 path is decode-only: bf16 / fp16 prefill has not been
    added in this PR. Prefill remains fp32-only.

🤖 Generated with Claude Code

Summary by CodeRabbit

Release Notes

  • New Features
    • Added support for bfloat16 and float16 data types in TopK decoding operations (previously fp32-only)
    • Extended topK parameter support to values of 512, 1024, and 2048
    • Introduced utility function to determine GVR heuristic optimization eligibility


longcheng-nv and others added 14 commits May 9, 2026 16:23
Add a templated dtype path (gvrTopKJobDtype / gvrTopKKernelDtype) that
shares the original GVR algorithm but accepts bf16 or fp16 inputs.
fp32 path (gvrTopKJob / gvrTopKKernel and the multi-row caller in
heuristicTopKDecode.cu) is byte-untouched — option A from the design
doc — so the production fp32 kernel is guaranteed zero-regression.

Architecture:
- GvrDtypeTraits<float | __nv_bfloat16 | __half>: encapsulates cvt
  intrinsics (to_fp32 / from_fp32) and vector-load width
  (VEC_W=4 for fp32, VEC_W=8 for bf16/fp16).
- KernelSmem -> KernelSmemTpl<SmemKey> template + alias
  (using KernelSmem = KernelSmemTpl<float>) so existing call sites
  remain source/ABI-equivalent.
- New helpers blockCountGEDtype<InputT> (8-wide vector load) and
  blockFusedSnapIterDtype<SmemKey> (Tier-2 SmemKey reads with
  Trait::to_fp32 up-cast).
- New gvrTopKJobDtype<InputT> mirrors gvrTopKJob with five
  trait-driven substitution sites: P1 preIdx load, P2 vector
  blockCountGE, P3 collect, P4 histogram bin index, P4 emit.
- launchHeuristicTopK<T> routes T=float to the original kernel and
  T=__nv_bfloat16/__half to gvrTopKKernelDtype<T> via if constexpr;
  smem size is computed per-trait so dynamic smem opt-in matches.
- Three explicit instantiations: <float, int>, <__nv_bfloat16, int>,
  <__half, int>.

Tier-1 fp32 invariants preserved across all dtypes:
  threshold, val_lo/hi, range1, inv1, all accumulators, histogram
  bin index — kept in fp32. Only Tier-2 containers (input HBM,
  smem keys, outputValues) follow InputT.
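
A minimal host-side sketch of the trait idea, for illustration only. `HostBf16` and `DtypeTraitsSketch` are hypothetical stand-ins (the real code uses `__nv_bfloat16` / `__half` and CUDA conversion intrinsics); only the `to_fp32` / `from_fp32` / `VEC_W` names and the 4 vs 8 widths come from the description above. `from_fp32` here truncates, which is a simplification of the rounding a real intrinsic would apply.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Stand-in for __nv_bfloat16 so the sketch builds without CUDA headers.
struct HostBf16 {
    std::uint16_t bits;
};

// Primary template left undefined: unsupported dtypes fail to compile.
template <typename T>
struct DtypeTraitsSketch;

template <>
struct DtypeTraitsSketch<float> {
    static constexpr int VEC_W = 4;  // elements per vector load for fp32
    static float to_fp32(float v) { return v; }
    static float from_fp32(float v) { return v; }
};

template <>
struct DtypeTraitsSketch<HostBf16> {
    static constexpr int VEC_W = 8;  // elements per vector load for 16-bit types
    // bf16 is the upper 16 bits of an fp32, so widening is an exact bit shift.
    static float to_fp32(HostBf16 v) {
        std::uint32_t u = static_cast<std::uint32_t>(v.bits) << 16;
        float f;
        std::memcpy(&f, &u, sizeof f);
        return f;
    }
    // Simplified narrowing: drop the low 16 mantissa bits (truncation).
    static HostBf16 from_fp32(float f) {
        std::uint32_t u;
        std::memcpy(&u, &f, sizeof u);
        return HostBf16{static_cast<std::uint16_t>(u >> 16)};
    }
};
```

Because widening bf16 to fp32 only appends zero bits, comparisons done in fp32 see the same ordering as the stored 16-bit values, which is what lets the threshold and accumulators stay in fp32.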

ptxas (sm_100) post-change:
  gvrTopKKernel (fp32):                    REG=64, smem~59KB
  gvrTopKKernelDtype<__nv_bfloat16>:       REG=61, smem~47KB
  gvrTopKKernelDtype<__half>:              REG=61, smem~47KB
fp32 instantiation's resource line is character-identical to baseline.

Single-row micro-perf (B200, K=2048, nsys + 128MB L2 flush, 20 reps,
SWE-Bench-style temporal preIdx) shows 8-14 % bf16/fp16 speedup over
fp32 at N >= 32K. Multi-row Stage B is expected to widen the gap
toward the design's 1.3-1.5x prediction (HBM-bound regime).

Stage A acceptance evidence (REPORT_stageA.md in the standalone
ablation_study tree):
  - SASS gate: fp32 byte-identical
  - Algorithm correctness: mean overlap dtype-invariant (89.05-89.19%)
    on SWE-Bench-64K 9 layers x 200 rows
  - ptxas resources: no spills, three instantiations within budget

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
Stage B of 4d_multi_dtype_unified — wire the multi-dtype GVR Top-K
kernel (heuristic_topk.cuh, landed in commit 5992714) into the
TRT-LLM indexer Top-K dispatcher and torch op.

Architecture:
- heuristicTopKDecode.cu: add heuristicTopKMultiRowKernelDtype<InputT>
  (multi-row outer kernel mirroring heuristicTopKMultiRowKernel for
  bf16 / fp16 inputs) plus launchHeuristicTopKDecodeDtype<InputT> +
  __nv_bfloat16 / __half overload entry points.
- IndexerTopK.h / indexerTopK.cu: invokeIndexerTopKDecodeDtype<InputT>
  GVR-Heuristic-only dispatcher with bytesPerElem=2 kBsL2 doubling;
  bf16/fp16 path has no Radix fallback (production radix kernel is
  fp32-only), TLLM_CHECK aborts when GVR conditions not met.
- IndexerTopKOp.cpp: dtype switch (Float / BFloat16 / Half) over the
  torch op input; scratch dtype must match logits dtype.

fp32 path is byte-untouched (option A): heuristicTopKMultiRowKernel,
gvrTopKJob, gvrTopKKernel, and the fp32 invokeIndexerTopKDecode entry
all unchanged. Verified via Stage B.B7 ptxas — fp32 multi-row kernel
REG=40 / SHARED=1024 / STACK=16 byte-identical to baseline, bf16/fp16
multi-row kernels at REG=40 / SHARED=1024 / STACK=8 with no spill.

Validation:
- Multi-row correctness (B5): fp32/bf16/fp16 x BS={1,4,16} x 3 layers,
  864 cells, mean GVR-vs-torch.topk overlap dtype-invariant
  (L0 ~98%, L22 ~91%, L60 ~77%).
- Multi-row perf (B6, nsys + 128MB L2 flush): BS=1 1.20-1.23x,
  BS>=4 1.29-1.33x (matches DESIGN.md sec 6 prediction).
- Full BS=1 SWE-Bench sweep (PreC-T1, 9 layers x 2024 rows x 3 reps):
  GVR fp32 vs Radix 1.90x median, GVR bf16 vs Radix 2.23x,
  GVR bf16 vs GVR fp32 1.178x (Amdahl f~0.30 ceiling matches).
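
One way to read the Amdahl remark, sketched as arithmetic. The assumption here (not stated in the commit message) is that a fraction f ≈ 0.30 of the fp32 runtime is memory-bound and speeds up by s = 2 when elements shrink from 4 to 2 bytes; `amdahlSpeedup` is a hypothetical helper name.

```cpp
#include <cassert>
#include <cmath>

// Amdahl's law: overall speedup when a fraction f of the runtime is
// accelerated by a factor s and the remaining (1 - f) is unchanged.
constexpr double amdahlSpeedup(double f, double s) {
    return 1.0 / ((1.0 - f) + f / s);
}

// With nothing accelerated, the overall speedup is exactly 1.
static_assert(amdahlSpeedup(0.0, 2.0) == 1.0, "f = 0 gives no speedup");
```

Under that assumption, amdahlSpeedup(0.30, 2.0) = 1 / 0.85 ≈ 1.176, which is close to the measured 1.178x for GVR bf16 vs GVR fp32.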

Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
Add compile-time (InputT, TopK) dispatch for the GVR Heuristic Top-K
kernel via the new GvrParams<InputT, int TopK> trait class:

  - 9 specializations covering {fp32, bf16, fp16} x {512, 1024, 2048}
    with per-(T, K) (kFTarget, kC, kNumBins) tuned per the SWE-Bench-64K
    9-combo sim sweep (D_param_sweep_9combos REPORT.md). Primary
    template undefined - any unsupported combo triggers compile-time
    error rather than runtime fall-through.
  - Generic KernelSmemTplK<SmemKey, CCap, NumBinsT> parameterizes the
    smem layout. KernelSmemTpl<SmemKey> becomes an alias preserving
    the K=2048 default layout (legacy KernelSmem alias unchanged).
  - blockCountGE / blockFusedSnapIter (and Dtype mirrors) gain
    <int TopK = TOP_K, typename SmemT = ...> template parameters with
    defaults matching legacy callers - K=2048 callsites compile
    unchanged.
  - gvrTopKJob<TopK>, gvrTopKJobDtype<InputT, TopK>,
    gvrTopKKernel<TopK>, gvrTopKKernelDtype<InputT, TopK> are
    templated; bodies use constexpr aliases (kK / kCC / kBins /
    kFTarget) from Params.
  - launchHeuristicTopK<T> and launchHeuristicTopKDecode (fp32 +
    Dtype) dispatch on runtime topK in {512, 1024, 2048} via C++20
    generic-lambda switch. Each (T, K) gets its own kernel function
    pointer, smem size, and cudaFuncSetAttribute call.
  - Multi-row kernels (heuristicTopKMultiRowKernel<TopK>,
    heuristicTopKMultiRowKernelDtype<InputT, TopK>) get 9 explicit
    instantiations in the anonymous namespace of heuristicTopKDecode.cu
    so the runtime switch can take their address.
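
The trait-plus-runtime-switch pattern described above can be sketched on the host as follows. This is an illustration, not the PR's code: `GvrParamsSketch`, `launchSketch`, and `dispatchTopK` are hypothetical names, only the fp32 specializations are shown, and a plain switch stands in for the C++20 generic-lambda switch. The constants are the fp32 values quoted in this PR's commit messages as of this commit (later commits retune kFTarget and kNumBins for K=512/1024).

```cpp
#include <cassert>
#include <stdexcept>

// Primary template is declared but never defined, so an unsupported
// (T, K) pair is a compile-time error rather than a runtime fall-through.
template <typename T, int TopK>
struct GvrParamsSketch;

template <>
struct GvrParamsSketch<float, 512> {
    static constexpr int kK = 512;
    static constexpr int kFTarget = 512;
    static constexpr int kC = 5120;
    static constexpr int kNumBins = 2048;
};

template <>
struct GvrParamsSketch<float, 1024> {
    static constexpr int kK = 1024;
    static constexpr int kFTarget = 3072;
    static constexpr int kC = 5120;
    static constexpr int kNumBins = 2048;
};

template <>
struct GvrParamsSketch<float, 2048> {
    static constexpr int kK = 2048;
    static constexpr int kFTarget = 3072;
    static constexpr int kC = 6144;
    static constexpr int kNumBins = 2048;
};

// Hypothetical launcher body: a real one would take the kernel's address,
// set its smem attribute, and launch. Here it just reports the candidate
// capacity that the chosen specialization would use.
template <typename T, int K>
int launchSketch() {
    return GvrParamsSketch<T, K>::kC;
}

// Map the runtime topK onto one compile-time specialization.
template <typename T>
int dispatchTopK(int topK) {
    switch (topK) {
        case 512: return launchSketch<T, 512>();
        case 1024: return launchSketch<T, 1024>();
        case 2048: return launchSketch<T, 2048>();
        default: throw std::invalid_argument("unsupported topK");
    }
}
```

Each case instantiates a distinct function, which mirrors why every (T, K) pair gets its own kernel pointer and smem size.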

K=2048 fp32 specialization preserves V2e production identity
(kFTarget=3072, kC=6144) - SASS for gvrTopKKernel<2048> is 1761
instructions byte-identical to the pre-refactor baseline. K=512/1024
fp32 + all bf16/fp16 use the sim-sweep-recommended kC=5120 with
K-specific kFTarget (K=K, 3K, 4K). Reg pressure unchanged at 64
(fp32) / 62 (bf16/fp16) per thread - all 9 kernels reg-bound 2 CTA/SM
on B200 sm_100.

NOTE: indexerTopK.cu kHeuristicTopK gating still selects only
topK==2048 for the Heuristic dispatch path. Relaxing the predicate
to topK in {512, 1024, 2048} is deferred to V4 indexer integration
(separate caller-side change).

Validated by smoke compile (nvcc 13.1.115 -arch=sm_100 -std=c++20,
9 kernel entries with expected (kFTarget, kC) literals folded as
ptxas immediates).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
blockFusedSnapIter and blockFusedSnapIterDtype use TopK as a template
parameter (with TOP_K=2048 default for legacy callers). The previous
multi-K commit's gvrTopKJob<TopK> and gvrTopKJobDtype<InputT, TopK>
called these helpers WITHOUT forwarding TopK, so the helpers picked up
the default TOP_K=2048 and snapped to the wrong K-th boundary for
K!=2048.

Symptom (caught by Step 3a real-machine correctness sweep on
SWE-Bench-64K, 1791 rows per combo = 199 rows x 9 layers, via the JIT extension):

  K=2048 fp32  -> 100.00% pass (default == intended - coincidentally OK)
  K=512  fp32  ->  26.08% pass (326 fails out of 441)
  K=1024 bf16  ->  40.36% pass (263 fails out of 441)

Fix: pass <TopK> (single-dtype path) and <SmemKey, TopK> (dtype path)
explicitly. Two-line change.

After fix:
  - 9-combo correctness 16,119 / 16,119 = 100.00% on SWE-Bench-64K
    (9 layers x 199 rows x 9 (K, dtype) combos).
  - K=2048 fp32 SASS remains byte-identical to V2e baseline: 1761
    instructions, 0 diff lines after stripping address/encoding
    comments. The default-K instantiation is unchanged.

Note for future templated-kernel refactors: this trap is invisible to
smoke-compile and PTX literal-grep checks - both default and explicit
forms compile cleanly and fold the same constant for K=2048. Only
end-to-end correctness on K!=default surfaces it.
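
The trap is easy to reproduce in miniature. The names below (`snapBoundary`, `jobBuggy`, `jobFixed`, `TOP_K_DEFAULT`) are hypothetical and only model the shape of the bug: a helper with a defaulted template parameter called from a templated caller that forgets to forward its own argument.

```cpp
#include <cassert>

constexpr int TOP_K_DEFAULT = 2048;

// Helper with a defaulted template parameter, kept for legacy callers.
template <int TopK = TOP_K_DEFAULT>
constexpr int snapBoundary() {
    return TopK;
}

// Buggy caller: omits <TopK>, so the helper silently uses the default.
template <int TopK>
constexpr int jobBuggy() {
    return snapBoundary();
}

// Fixed caller: forwards its own template argument explicitly.
template <int TopK>
constexpr int jobFixed() {
    return snapBoundary<TopK>();
}

// Both forms agree for the default K, which is why a K=2048-only check passes.
static_assert(jobBuggy<2048>() == jobFixed<2048>(), "identical at the default K");
```

Both callers compile cleanly; only evaluating a non-default K (for example jobBuggy<512>() returning 2048) exposes the difference.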

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
Open the indexer-TopK decode dispatcher to admit GVR-Heuristic for
topK == 512 and topK == 1024 in addition to the existing topK == 2048.
The kernel/launcher (heuristicTopKDecode.cu) already has runtime
topK switch + GvrParams<T, K> specializations for all three K values
since 1bb1839; this commit extends only the caller-side gate so the
new K paths actually dispatch from invokeIndexerTopKDecode.

Two call sites in indexerTopK.cu are updated symmetrically:

  - fp32 dispatcher (invokeIndexerTopKDecode for float):
      topK == kHeuristicTopK     →  topK ∈ {512, 1024, 2048}
      preIdxCount == kHeuristicSize  →  preIdxCount == topK

  - bf16/fp16 dispatcher (invokeIndexerTopKDecodeDtype<InputT>):
      same gate change + TLLM_CHECK_WITH_INFO message updated to
      reference the new K-set and runtime-derived preIdxCount.

Behavioral guarantees:
  - K=2048 path is byte-identical to the previous gate (set
    membership reduces to the old equality when topK=2048 and
    preIdxCount=2048), so all Scheme X v1.2 production traffic is
    unaffected.
  - kSeqSmall / kBsLarge / effectiveSplitWorkThreshold are reused
    unchanged. Per-K threshold tuning can be layered later; the
    current values are conservative for K=512/1024 and only push more
    rows toward the radix fallback (no regression risk).
  - kHeuristicTopK / kHeuristicSize public constants in
    heuristicTopKDecode.h are kept (still 2048) for ABI/header
    compatibility; they are no longer read by indexerTopK.cu and can
    be deprecated separately.
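
The relaxed gate can be written as a small predicate. This sketch covers only the two conditions changed by this commit; the function names are hypothetical, and the remaining Scheme X conditions (kSeqSmall, kBsLarge, effectiveSplitWorkThreshold) are deliberately left out.

```cpp
#include <cassert>

// True for the K values that have GvrParams<T, K> specializations.
constexpr bool isGvrTopK(int topK) {
    return topK == 512 || topK == 1024 || topK == 2048;
}

// Caller-side gate after this change: set membership on topK, and the
// number of pre-indexed entries must equal topK (previously both were
// compared against the fixed constant 2048).
constexpr bool passesGvrGate(int topK, int preIdxCount) {
    return isGvrTopK(topK) && preIdxCount == topK;
}

// For topK = preIdxCount = 2048 the new gate accepts exactly what the old one did.
static_assert(passesGvrGate(2048, 2048), "K=2048 behavior is unchanged");
```

A request with topK=512 but preIdxCount=2048 fails this gate and would take the fallback path.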

Build verified: kernels_src compiles clean (only pre-existing C++20
lambda-template-parameter warnings, unrelated to this change).

Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
…/1024 cells

Phase-4 histogram resolution (NUM_BINS=2048 default in V2e) was over-resolved
for K=512/1024 cells. A 4-config sweep (48,600 reps in
01_micro_kernel_bs1/phase_timing_v2_multik_multidtype/REPORT_kbins_sweep.md)
showed kNumBins ≈ K is the sweet spot — saving 5-9% total per cell with no
correctness or smem-cap risk.

Changed 6 of 9 GvrParams<T, K>::kNumBins values:

  fp32  K=512  : NUM_BINS=2048 → 1024  (-7.0% total)
  fp32  K=1024 : NUM_BINS=2048 → 1024  (-5.9% total)
  fp32  K=2048 : NUM_BINS=2048 → 2048  (V2e LOCK, byte-identity preserved)
  bf16  K=512  : NUM_BINS=2048 →  512  (-8.2%, compromise rule)
  bf16  K=1024 : NUM_BINS=2048 →  512  (-7.8%, K/2 captures full saving)
  bf16  K=2048 : NUM_BINS=2048 → 2048  (no change, K=2048 sweep ~0%)
  fp16  K=512  : NUM_BINS=2048 →  512  (-5.7% total)
  fp16  K=1024 : NUM_BINS=2048 → 1024  (-6.1% total)
  fp16  K=2048 : NUM_BINS=2048 → 2048  (no change)

Why it works:
  - histogram_clear shrinks linearly with NumBins (cleared via NumBins/512
    iterations of zero-init).
  - bin_search via V2e OPT6 parallel 2-step is log(NumBins).
  - atomicAdd contention rises modestly: K=512 cell goes from 5120/2048=2.5
    cands/bin to 5120/512=10/bin, which only reaches the lower edge of the
    contention knee (~10-20 cands/bin where atomic serialization dominates).
  - snap iteration count is K-dependent, NOT NumBins-dependent, so smaller
    NumBins doesn't slow snap.
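
The arithmetic behind these bullets, as a small sketch. Two assumptions are mine rather than the commit's: histogram counters are 4 bytes each (consistent with the 8 KB quoted for 2048 bins) and the clearing block has 512 threads (read off the NumBins/512 remark). All function names are hypothetical.

```cpp
#include <cassert>

// Average number of candidates landing in each histogram bin.
constexpr double candsPerBin(int kC, int numBins) {
    return static_cast<double>(kC) / numBins;
}

// Shared-memory bytes used by the histogram (assumed 4-byte counters).
constexpr int histBytes(int numBins, int bytesPerCounter = 4) {
    return numBins * bytesPerCounter;
}

// Zero-init passes needed when each thread clears one bin per pass
// (assumed 512 threads per block).
constexpr int clearIters(int numBins, int blockSize = 512) {
    return numBins / blockSize;
}

static_assert(histBytes(2048) == 8 * 1024, "2048 bins occupy 8 KB");
```

For kC=5120, going from 2048 to 512 bins moves the load from 2.5 to 10 candidates per bin while cutting the histogram from 8 KB to 2 KB and the clear from 4 passes to 1.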

Why this is safe to deploy upstream:
  - K=2048 fp32 path's kNumBins is unchanged → SASS byte-identical to V2e
    production baseline → Scheme X v1.2 4-report verification NOT triggered
    (per Q9d Option B).
  - All 9 (T, K) cells were validated on the 9-combo set-eq tests in
    Step 3a (1,791 rows × 9 combos = 16,119 rows, 100% match) before this change. The
    kNumBins change does not affect the set of selected indices — only the
    histogram resolution during snap convergence.
  - Compile clean: kernels_src + tensorrt_llm + th_common rebuilt 09:46 UTC.

Side effect: K=512 / K=1024 dtype paths now use ~4 KB less smem (hist drops
from 8 KB to 2-4 KB). At K=512 bf16 with kC=5120, total smem drops from
~40 KB to ~36 KB — may unlock 5 CTA/SM at multi-row BS≥148, compounding
with the Q9e BS-scaling occupancy gain.

Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
…p16 cells

The bf16/fp16 path's Phase 3 ballot-free collect was doing per-element
Trait::from_fp32(val) write conversions inside the hot smem-write loop.
Q1b phase data showed dtype P3 was 1.0-2.3 µs slower than fp32 P3 across
K (despite reading same N logits) — the unpack8 + per-element from_fp32
in the 8-wide inner loop was 2× the per-ldg work of fp32's 4-wide path.

Solution: keep keys[] in smem as float (deferred conversion), convert
only at the K output writebacks (3 exit paths in gvrTopKJobDtype). Costs
~10 KB extra smem at K=2048 dtype (5120 keys × +2B); still fits 4 CTA/SM.
Switches to the fp32 helper blockFusedSnapIter<TopK> since the smem
layout now matches.

Wall verify (Q9d-equivalent: 9 layers × 200 rows × 5 reps under nsys,
P0.5 also active):
  Cell           Q9d original    Post-P0.5+P0    ∆
  K=512  fp32     26.91           25.28           -6.1%
  K=512  bf16     25.63           22.21           -13.4%
  K=512  fp16     24.54           21.50           -12.4%
  K=1024 fp32     30.69           29.41           -4.2%
  K=1024 bf16     28.80           25.41           -11.8%
  K=1024 fp16     27.81           24.58           -11.6%
  K=2048 fp32     33.92           33.92           0     (V2e LOCK preserved)
  K=2048 bf16     29.95           28.06           -6.3%
  K=2048 fp16     29.02           27.30           -5.9%
  Radix           52.96           52.99           0     (sanity check)

K=2048 fp32 byte-identity preserved → V2e Scheme X v1.2 4-report
verification window NOT triggered. fp32 path's gvrTopKJob is unchanged;
only the dtype path's gvrTopKJobDtype has the new fp32 keys[] layout.

Files changed:
  - cpp/tensorrt_llm/kernels/heuristic_topk.cuh (10 sites)
  - cpp/tensorrt_llm/kernels/heuristicTopKDecode.cu (2 sites)

Compile clean: kernels_src + tensorrt_llm + th_common rebuilt 10:11 UTC.
Q9 set-eq tests pending — re-run before merging upstream.

Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
Under V4 production where preIdx is built with M=K (not legacy M=2048),
P1 pmean shifts toward the right tail of the previous row's top-K
distribution. P2 secant therefore overshoots when kFTarget is set
relative to the V3.2-era M=2048 baseline, costing extra iterations
on K=512/1024 cells. The Sprint 2 sweep (`10_multi_cta_v1/M_eq_K_redesign/`)
identified K-proportional sweet spots: 0.75K for K=512 and 2.5K for K=1024,
applied uniformly across fp32/bf16/fp16.

Changes (6 of 9 GvrParams<T, K> specializations; K=2048 untouched):

  GvrParams<float, 512>           kFTarget   512 -> 384
  GvrParams<float, 1024>          kFTarget  3072 -> 2560
  GvrParams<__nv_bfloat16, 512>   kFTarget   512 -> 384
  GvrParams<__nv_bfloat16, 1024>  kFTarget  3072 -> 2560
  GvrParams<__half, 512>          kFTarget   512 -> 384
  GvrParams<__half, 1024>         kFTarget  3072 -> 2560

  GvrParams<*, 2048>              UNCHANGED (V2e LOCK preserved)

SWE-Bench-64K rebench (9 layers, 1827 cells, Q9d harness):

  gvr_K512  fp32  25.34 -> 24.58 us  (-3.03 %)
  gvr_K512  bf16  22.21 -> 21.25 us  (-4.32 %)
  gvr_K512  fp16  21.63 -> 20.38 us  (-5.77 %)  [best gain]
  gvr_K1024 fp32  29.54 -> 28.74 us  (-2.71 %)
  gvr_K1024 bf16  25.47 -> 24.80 us  (-2.64 %)
  gvr_K1024 fp16  24.64 -> 23.90 us  (-2.99 %)

  gvr_K2048 (all dtypes)          byte-identical SASS (V2e LOCK)

All 9 layers improve on every re-tuned cell; no per-layer regression.

Why this is safe to deploy upstream:
  - K=2048 fp32 hot path SASS verified byte-identical at md5 level
    (6/6 V2e LOCK kernels in libth_common.so:
      heuristicTopKMultiRowKernel<2048>             a9f5c6b8f26c..
      gvrTopKKernel<2048>                           e9ca7cdb3e46..
      heuristicTopKMultiRowKernelDtype<bf16, 2048>  aede08cc084e..
      heuristicTopKMultiRowKernelDtype<half, 2048>  9caa920c4368..
      gvrTopKKernelDtype<bf16, 2048>                af8ba491cba5..
      gvrTopKKernelDtype<half, 2048>                797315781bce..
    pre-patch and post-patch dumps identical.).
  - Scheme X v1.2 4-report regression NOT triggered (heuristic_topk.cuh-only
    change; dispatcher cpp/tensorrt_llm/kernels/indexerTopK.cu unchanged;
    K=2048 dispatcher hot path SASS-invariant).
  - Single-shot Scheme X smoke (L0 x BS={1,64,256,432}) confirms speedup
    within +/- 4 % of 2026-04-23 baseline at BS<432; Radix wall within
    +/- 2 % across all 4 BS.
  - Q9d set-equality 9-combo correctness validated in Sprint 2 prior to
    deployment.

Build: kernels_src + tensorrt_llm + th_common rebuilt 2026-05-08 06:02 UTC.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
…tion

The standalone gvrTopKKernel<TopK> + gvrTopKKernelDtype<InputT, TopK>
wrappers were declared with __launch_bounds__(BLOCK_SIZE, 1), giving
nvcc a minBlocksPerSM=1 hint. The production multi-row entry
heuristicTopKMultiRowKernel{,Dtype} (heuristicTopKDecode.cu:50, 107)
uses the single-arg form __launch_bounds__(BLOCK_SIZE) with no minimum
hint. Same gvrTopKJob{,Dtype} device function in both, but the launch
bounds difference produced divergent register allocations:

  Standalone (before):                Production multi-row:
    (float, 2048):  REG=64              REG=40
    (float, 1024):  REG=63              REG=61
    (float,  512):  REG=63              REG=40
    (bf16,  2048):  REG=61              REG=44
    (bf16,  1024):  REG=61              REG=40
    (bf16,   512):  REG=61              REG=40
    (half,  2048):  REG=61              REG=44
    (half,  1024):  REG=59              REG=54
    (half,   512):  REG=61              REG=40

This made the standalone wrapper run at 50 % theoretical occupancy
across all 9 specs (REG-bound 2 CTA/SM), while the production multi-row
path ran at 75-100 % theoretical for 5 of 9 specs. Local JIT extension
benchmarks (cuda_ext/topk_ext.cu via launchHeuristicTopK) and any
ablation comparing standalone wrappers therefore reflected occupancy
that did not match production behavior.

Removing the trailing `, 1` argument applies the same nvcc heuristic to
both entries:

  Standalone (after):
    (float, 2048):  REG=40              -> 3 CTA/SM (75 % theor)
    (float, 1024):  REG=40 (stack=16)   -> 3 CTA/SM
    (float,  512):  REG=40              -> 4 CTA/SM (100 %)
    (bf16,  2048):  REG=45              -> 2 CTA/SM (50 %)
    (bf16,  1024):  REG=40 (stack=8)    -> 3 CTA/SM
    (bf16,   512):  REG=32 (stack=8)    -> 4 CTA/SM (100 %)
    (half,  2048):  REG=45              -> 2 CTA/SM
    (half,  1024):  REG=40 (stack=8)    -> 3 CTA/SM
    (half,   512):  REG=32 (stack=8)    -> 4 CTA/SM (100 %)

Standalone now matches production reg counts within Δ <= 5; for K=512
specs and K=1024 fp16/bf16 standalone is even leaner than production
because the single-row body has no rowIdx outer loop or per-row preIdx
indexing.

Why this is safe to land:

  - Production heuristicTopKMultiRowKernel{,Dtype} is a different
    __global__ in heuristicTopKDecode.cu with its own __launch_bounds__,
    which this commit does not touch. SASS for the V4 production K=2048
    fp32 hot path is byte-identical (V2e LOCK preserved on production
    multi-row entry).

  - The standalone entry SASS does change. V2e LOCK byte-identity on
    the standalone wrappers is intentionally broken; gvrTopKKernel*
    is only invoked by the local JIT extension and the standalone
    ablation harnesses, never by TRT-LLM production runtime.

  - Q9d set-equality 9-combo correctness is unaffected: this commit
    changes only the kernel launch envelope (registers, CTA scheduling),
    not algorithm flow. Re-validation will follow on the post-alignment
    re-run of 09_precision_ablation/04_multik_multidtype_perf_bs1.

Build: kernels_src + tensorrt_llm + th_common rebuilt 2026-05-08 07:39 UTC.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
…ments

Remove development-history annotations from kernel source comments that
would not be meaningful to upstream readers:

- Internal experiment IDs (4d.D, 4d_multi_dtype_unified, Q1b-P0,
  Q10b-S2, Q10b-S3 retune annotations)
- Date stamps (2026-05-06, 2026-05-07, 2026-05-08)
- References to local ablation directories
  (04_kernel_optimizations/4d_multi_dtype_unified/,
   10_multi_cta_v1/M_eq_K_redesign/,
   ablation_study/gvr_phase_timing/...)
- Aspirational planning leftovers ("Step 2 must verify with cuobjdump",
  "Step 3 must run Scheme X v1.2 4-report regression to confirm")
- Stale Sprint-2 plan that contradicted the current code (claimed
  GvrParams<float, 2048>::kFTarget = 4096 but actual code is 3072)
- "Backwards-compat alias" / "byte-equivalent signature" framing for
  KernelSmemTpl / KernelSmem aliases: these are convenience aliases
  for the K=2048 default, not back-compat shims.

The Multi-K trait class (GvrParams<T, K>) gains a longer mechanism-level
explanation of why kFTarget, kC, and kNumBins vary across (T, K) cells:
preIdx-distribution shift under V3.2 decode semantics for kFTarget,
register/smem occupancy for kC, and atomic-contention-vs-Phase-4-setup
trade-off for kNumBins. fp32 K=2048 SASS byte-identity preservation is
called out at the relevant specialization.

No code logic changes — comment-only cleanup. ptxas output for all 9
GvrParams<T, K> specializations is identical to pre-cleanup.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
…stry tags

Two small follow-ups to commit 1d1d938:

1. Remove the unused `thresholdPos` parameter from `gvrTopKKernel`,
   `gvrTopKKernelDtype`, `launchHeuristicTopK`, and the
   `cudaLaunchKernelEx` call site. The parameter was carried as a
   placeholder for an older threshold-hint API that was never wired into
   the kernel body — `gvrTopKJob` / `gvrTopKJobDtype` neither receive
   it nor reference it. Removing it shrinks the public launcher signature
   from 9 args to 8 and matches the multi-row entry's signature shape.

   SASS impact (verified with cuobjdump on libth_common.so):
   - **Production multi-row K=2048 specs (3/3): byte-identical.**
     `heuristicTopKMultiRowKernel<2048>`,
     `heuristicTopKMultiRowKernelDtype<__half, 2048>`,
     `heuristicTopKMultiRowKernelDtype<__nv_bfloat16, 2048>` keep their
     pre-cleanup md5s — V2e LOCK on the production hot path is preserved.
   - **Standalone wrapper K=2048 specs (3/3): mangled signature changes
     by one parameter** (`...Pii` → `...Pi`), so md5s differ. These
     wrappers are only used by JIT consumers (e.g. ablation-side
     PyTorch extensions); production never calls them.

2. Strip leftover `OPT[N]` references from comments
   (`OPT4: Warp-Level Reduction Primitives`, `OPT5: when done==1...`,
   `OPT6: Parallel K-th bin search`, `OPT7: cache per-thread count...`).
   These were tags into an internal optimization registry; the same
   text remains as plain mechanism comments without the registry code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
… launcher

Two pure refactors with zero kernel SASS impact:

1. **`getSchemeXBounds()` helper (indexerTopK.cu)**: the fp32 and
   bf16/fp16 dispatchers each carried independent `std::call_once`
   caches for SM count, L2 capacity, and the `kSeqSmall` threshold
   (5 once_flags total, ~80 LOC of boilerplate duplication). Hoist
   into a single anonymous-namespace helper that takes
   `bytesPerElem` so kBsL2 still adapts per-dtype. Reduces once_flag
   count from 5 to 2 (helper + debug-flag).

2. **`launchHeuristicTopKDecodeImpl<InputT>` template
   (heuristicTopKDecode.cu)**: the fp32 launcher and
   `launchHeuristicTopKDecodeDtype<InputT>` carried near-identical
   lambda bodies (cudaFuncSetAttribute / cudaLaunchKernelEx scaffolding)
   differing only in (a) vector-load alignment requirement and (b) which
   kernel function pointer to take. Collapse into one templated impl
   in anonymous namespace with `if constexpr` for the kernel
   function-pointer selection; the three `launchHeuristicTopKDecode`
   overloads become thin forwarders.

   Side effect: the `switch (topK)` now has a `default:` arm that
   throws on unreachable input (defensive belt over the
   `TLLM_CHECK_WITH_INFO` guard).
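
The caching pattern behind item 1 can be sketched on the host side as follows. This is a minimal illustration, not the code in indexerTopK.cu: `DevicePropsSketch`, `queryDeviceSketch()`, `SchemeXBoundsSketch`, the field names, the placeholder device values, and the way the per-dtype bound is derived from L2 capacity are all assumptions. Only the `std::call_once` caching and the `bytesPerElem` parameter come from the description above.

```cpp
#include <cassert>
#include <cstddef>
#include <mutex>

namespace
{

// Hypothetical device properties; a real helper would query the CUDA runtime.
struct DevicePropsSketch
{
    int smCount;
    std::size_t l2Bytes;
};

// Stand-in for the device query (placeholder values, for illustration only).
inline DevicePropsSketch queryDeviceSketch()
{
    return DevicePropsSketch{132, std::size_t{64} << 20};
}

struct SchemeXBoundsSketch
{
    int smCount;
    std::size_t l2Bytes;
    std::size_t maxElemsInL2; // dtype-dependent: derived from bytesPerElem
};

// A single once_flag guards every dtype-independent query; only the cheap
// per-dtype arithmetic runs on each call.
inline SchemeXBoundsSketch getSchemeXBoundsSketch(std::size_t bytesPerElem)
{
    static std::once_flag flag;
    static DevicePropsSketch props;
    std::call_once(flag, [] { props = queryDeviceSketch(); });

    assert(bytesPerElem > 0);
    return SchemeXBoundsSketch{props.smCount, props.l2Bytes, props.l2Bytes / bytesPerElem};
}

} // namespace
```

Because the dtype enters only through the final division, the fp32 and bf16/fp16 dispatchers can share one cache while still seeing different element-count bounds.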

Also strips two missed `// 4d:` include tags in indexerTopK.cu that
the prior chore-cleanup commits did not catch.

Net diff: −40 LOC across two files. No kernel binary changes — V2e
SASS LOCK on K=2048 fp32 production hot path is preserved (kernel
templates are not touched).

Verified by building `make th_common -j8` (clean link).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
…s comment

The two-pass selection block at gvrTopKJob's Phase-4 had a comment
referencing internal experiment names ("Opt-M fix" / "v0 interleaved
version") that the prior chore-cleanup commits did not catch — the
audit regex enumerated `OPT[0-9]+` and `Q[0-9]b-P[0-9]+` patterns but
not `Opt-M` (Opt-M was one of 12 already-closed exploratory branches
in the project's internal registry).

Replace with a mechanism-only description that explains *why* the two
passes are sequential (deterministic selection under K-th-rank ties)
without referencing the internal v0/v1/v2 history.

Same comment block in `gvrTopKJobDtype` was already mechanism-only
(no Opt-M reference), so this is a single 7-line edit in the fp32
path.

No code logic changes, no SASS impact.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
…6/fp16 path

Templatizes topKPerRowDecode / topKPerRowJob / topKPerRowPrefill /
processHistogramStep on InputT (default float). HBM reads are cast to
float at load sites; the histogram + sort steps remain dtype-agnostic
since they operate on float keys in shared memory. fp32 SASS is
byte-identical to the pre-template baseline (template defaults).

Replaces the bf16/fp16 dispatcher's hard abort with the same fallback
chain the fp32 dispatcher uses:
  numColumns < kSortingAlgorithmThreshold               -> insertion sort
  kSortingAlgorithmThreshold <= numColumns < splitWork  -> radix sort
  numColumns >= splitWork                               -> unsupported (TLLM_CHECK)
The split-work tier requires float aux buffers the bf16/fp16 entry
does not expose; callers in that regime must use the fp32 entry.
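
The tier selection described above can be sketched as a small host function. This is an illustration under assumptions: `FallbackTierSketch` and `selectFallbackTierSketch` are hypothetical names, both thresholds are passed in because their real values are not stated here, and `std::invalid_argument` stands in for the `TLLM_CHECK` failure.

```cpp
#include <cassert>
#include <stdexcept>

// Tiers named in the commit message; the enum itself is hypothetical.
enum class FallbackTierSketch
{
    kInsertionSort,
    kRadixSort
};

// Picks the fallback for the bf16/fp16 entry from the column count.
inline FallbackTierSketch selectFallbackTierSketch(
    int numColumns, int sortingAlgorithmThreshold, int splitWorkThreshold)
{
    assert(sortingAlgorithmThreshold <= splitWorkThreshold);
    if (numColumns < sortingAlgorithmThreshold)
    {
        return FallbackTierSketch::kInsertionSort;
    }
    if (numColumns < splitWorkThreshold)
    {
        return FallbackTierSketch::kRadixSort;
    }
    // Split-work tier: needs float aux buffers this entry does not expose.
    throw std::invalid_argument("split-work tier unsupported for bf16/fp16; use the fp32 entry");
}
```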

Exposes canIndexerTopKDecodeUseGvr(numRows, numColumns, topK,
bytesPerElem) so callers can introspect the dispatcher decision
before allocating preIdx tensors or heuristicScratch buffers.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
@longcheng-nv
Collaborator Author

/bot run --disable-fail-fast

@coderabbitai
Contributor

coderabbitai Bot commented May 9, 2026

📝 Walkthrough

This PR extends TensorRT-LLM's top-K decoding kernels to support bfloat16 and float16 inputs alongside existing float32 support. The changes introduce compile-time dtype specialization via traits, refactor the GVR algorithm to accept compile-time TopK values (512, 1024, 2048), and extend public APIs and dispatchers across IndexerTopK and HeuristicTopK subsystems.

Changes

Multi-dtype Top-K Kernel Extension

| Layer / File(s) | Summary |
|---|---|
| **Public API**: cpp/tensorrt_llm/kernels/IndexerTopK.h, cpp/tensorrt_llm/kernels/heuristicTopKDecode.h | Adds bf16/fp16 overload declarations for invokeIndexerTopKDecode and launchHeuristicTopKDecode, plus canIndexerTopKDecodeUseGvr(...) predicate. |
| **Dtype Traits & Shared Memory**: cpp/tensorrt_llm/kernels/heuristic_topk.cuh | Introduces GvrDtypeTraits<T> specializations (float, bf16, fp16), GvrParams<T, TopK> compile-time constants per dtype/TopK pair, and templated KernelSmemTplK layout matching selected parameters. |
| **Block Helpers**: cpp/tensorrt_llm/kernels/heuristic_topk.cuh | Templatizes blockCountGE, blockFusedSnapIter, and dtype-specific variants (blockCountGEDtype, blockFusedSnapIterDtype) to use K-dependent traits and shared-memory layouts. |
| **GVR Algorithm**: cpp/tensorrt_llm/kernels/heuristic_topk.cuh | Refactors gvrTopKJob<TopK> and adds gvrTopKJobDtype<InputT, TopK> implementations; updates Phase 2–4 logic to use compile-time kCC/kBins/kK instead of fixed constants; deterministic partition with two-pass emission. |
| **Global Kernels & Instantiations**: cpp/tensorrt_llm/kernels/heuristic_topk.cuh, cpp/tensorrt_llm/kernels/heuristicTopKDecode.cu | Defines gvrTopKKernel<TopK> and gvrTopKKernelDtype<InputT, TopK> with extern shared memory; explicit instantiations for 6 (dtype × TopK) combinations. |
| **HeuristicTopK Dispatcher**: cpp/tensorrt_llm/kernels/heuristicTopKDecode.cu, cpp/tensorrt_llm/kernels/heuristic_topk.cuh | Implements launchHeuristicTopK and launchHeuristicTopKDecodeImpl<InputT> runtime dispatchers; validates topK against {512, 1024, 2048}; computes dtype-specific shared-memory size and selects kernel via switch(topK). |
| **IndexerTopK Kernel Templating**: cpp/tensorrt_llm/kernels/indexerTopK.cu | Extends processHistogramStep, processBins, topKPerRowJob, topKPerRowDecode, and topKPerRowPrefill to accept InputT template parameter; converts bf16/fp16 to float at load sites. |
| **IndexerTopK Dispatcher**: cpp/tensorrt_llm/kernels/indexerTopK.cu | Refactors Scheme X logic via new SchemeXBounds/getSchemeXBounds(...) helper; updates fp32 dispatch to use cached bounds; adds bf16/fp16 overloads with split-work restriction; implements canIndexerTopKDecodeUseGvr(...) predicate. |
| **Torch Integration**: cpp/tensorrt_llm/thop/IndexerTopKOp.cpp | Discovers logits dtype, validates heuristic_scratch dtype matching, dispatches to appropriate invokeIndexerTopKDecode overload with reinterpret_casted bf16/half pointers and dtype-matched scratch buffers. |

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 12.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title clearly and specifically describes the main changes: extending GVR Top-K to support multiple K values (512/1024/2048) and multiple data types (fp32/bf16/fp16). |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Description check | ✅ Passed | The PR description is comprehensive and well-structured, covering summary, commit breakdown, key files, API changes, test plan with detailed test coverage, performance results with tables, round-3 review fixes, and identified risks. |


Contributor

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (2)
cpp/tensorrt_llm/kernels/heuristic_topk.cuh (2)

1720-1726: ⚡ Quick win

Use getEnvEnablePDL() for consistency.

heuristicTopKDecode.cu (line 222) uses tensorrt_llm::common::getEnvEnablePDL() while this launcher manually parses TRTLLM_ENABLE_PDL. Using the common utility ensures consistent behavior and avoids duplicating env-parsing logic.

♻️ Suggested fix

Add the include at the top of the file:

`#include "tensorrt_llm/common/envUtils.h"`

Then replace the manual parsing:

-    // Honor the standard TRTLLM_ENABLE_PDL env var (default on; set "0" to
-    // disable).
-    bool enablePDL = true;
-    if (char const* env = std::getenv("TRTLLM_ENABLE_PDL"))
-    {
-        if (env[0] == '0' && env[1] == '\0')
-            enablePDL = false;
-    }
+    bool const enablePDL = tensorrt_llm::common::getEnvEnablePDL();
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@cpp/tensorrt_llm/kernels/heuristic_topk.cuh` around lines 1720 - 1726,
Replace the manual getenv parsing that sets the local bool enablePDL with a call
to the shared utility tensorrt_llm::common::getEnvEnablePDL(); add the include
"tensorrt_llm/common/envUtils.h" to the top of the file, remove the block that
directly reads TRTLLM_ENABLE_PDL and instead initialize enablePDL by calling
tensorrt_llm::common::getEnvEnablePDL() to keep behavior consistent with
heuristicTopKDecode.cu.

1684-1686: 💤 Low value

Remove unused SmemKey type alias.

SmemKey is defined from the trait but never used — SmemT correctly uses float for keys as documented in the optimization note at lines 1173-1176. The unused alias adds confusion.

♻️ Suggested fix
 template <typename InputT, int TopK = TOP_K>
 __global__ void __launch_bounds__(BLOCK_SIZE)
     gvrTopKKernelDtype(InputT const* __restrict__ input, int const N, int const* __restrict__ preIdx, int const M,
         int const topK, InputT* __restrict__ outputValues, int* __restrict__ outputIndices)
 {
-    using SmemKey = typename GvrDtypeTraits<InputT>::SmemKey;
-    using SmemT
-        = KernelSmemTplK<float, GvrParams<InputT, TopK>::kC, GvrParams<InputT, TopK>::kNumBins>; // dtype keys fp32
+    // dtype path uses fp32 smem keys (deferred convert to output).
+    using SmemT = KernelSmemTplK<float, GvrParams<InputT, TopK>::kC, GvrParams<InputT, TopK>::kNumBins>;
     extern __shared__ unsigned char smem_raw[];
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@cpp/tensorrt_llm/kernels/heuristic_topk.cuh` around lines 1684 - 1686, Remove
the unused type alias SmemKey introduced from GvrDtypeTraits<InputT>; it is
never referenced and causes confusion—delete the line "using SmemKey = typename
GvrDtypeTraits<InputT>::SmemKey;" and keep the existing SmemT definition that
intentionally uses float keys (KernelSmemTplK<float, GvrParams<InputT,
TopK>::kC, GvrParams<InputT, TopK>::kNumBins>), leaving GvrDtypeTraits, InputT
and TopK untouched.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 71e22bbe-2801-4613-8946-0f31f85d685a

📥 Commits

Reviewing files that changed from the base of the PR and between b26c000 and 17b5227.

📒 Files selected for processing (6)
  • cpp/tensorrt_llm/kernels/IndexerTopK.h
  • cpp/tensorrt_llm/kernels/heuristicTopKDecode.cu
  • cpp/tensorrt_llm/kernels/heuristicTopKDecode.h
  • cpp/tensorrt_llm/kernels/heuristic_topk.cuh
  • cpp/tensorrt_llm/kernels/indexerTopK.cu
  • cpp/tensorrt_llm/thop/IndexerTopKOp.cpp

@tensorrt-cicd
Collaborator

PR_Github #47533 [ run ] triggered by Bot. Commit: 17b5227 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #47533 [ run ] completed with state FAILURE. Commit: 17b5227
/LLM/main/L0_MergeRequest_PR pipeline #37450 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@longcheng-nv
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #47544 [ run ] triggered by Bot. Commit: 17b5227 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #47544 [ run ] completed with state FAILURE. Commit: 17b5227
/LLM/main/L0_MergeRequest_PR pipeline #37460 completed with status: 'FAILURE'

CI Report

Link to invocation

…C++20 templated lambda

CI build #47533 (PR NVIDIA#13948) failed on Build-x86_64 / Build-SBSA with 9
errors per arch in cudafe1.stub.c:

  error: explicit specialization of '__wrapper__device_stub_gvrTopKKernel<512>'
         after instantiation
  heuristic_topk.cuh:1148:20: implicit instantiation first required here

(also for K={1024,2048} on gvrTopKKernel and for gvrTopKKernelDtype<{bf16,
fp16}, {512,1024,2048}>.) Local build on sm_100f / nvcc 13.1 passed cleanly;
sm_89 / sm_120f cudafe1 stubgen tightens [temp.expl.spec]/6 ordering and
rejects the implicit-then-explicit sequence.

Two changes in cpp/tensorrt_llm/kernels/heuristic_topk.cuh:

1. Add 9 explicit kernel instantiations (gvrTopKKernel<{512,1024,2048}>
   plus gvrTopKKernelDtype<{__nv_bfloat16,__half}, {512,1024,2048}>) right
   after the gvrTopKKernelDtype definition. Forces nvcc to emit each host
   wrapper stub at a fixed point ahead of any address-take in
   launchHeuristicTopK. Mirrors the proven pattern at
   heuristicTopKDecode.cu:155-171 for the multi-row kernels (which is why
   those didn't fail). Header is included by exactly one TU so no ODR
   concern.

2. Replace the C++20 templated lambda `[&]<int TopK>() {...}` in
   launchHeuristicTopK with detail::launchHeuristicTopKImpl<T, TopK>
   function template + switch dispatch. Removes the C++20-extension
   warning #3288-D and eliminates the templated-lambda + capture +
   __global__-address-take quirk pattern that triggers cudafe1's
   stubgen ordering bug. Kernel body, GvrParams traits, __launch_bounds__,
   PDL attr, and cudaLaunchKernelEx call are all preserved — kernel SASS
   is unchanged for all 9 (T, K) instantiations.
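
The shape of change 2, combined with the explicit instantiations from change 1, can be sketched in host-only C++. This is an illustration, not the code in heuristic_topk.cuh: `kernelSketch`, `launchImplSketch`, and `launchSketch` are hypothetical stand-ins, the "kernel" is an ordinary function that returns its `TopK` so the dispatch is observable without a GPU, and `std::invalid_argument` is an assumed error type.

```cpp
#include <cassert>
#include <stdexcept>

// Host-only stand-in for a __global__ kernel template.
template <int TopK>
int kernelSketch()
{
    return TopK;
}

// Explicit instantiations placed right after the definition, ahead of any
// code that takes the function's address.
template int kernelSketch<512>();
template int kernelSketch<1024>();
template int kernelSketch<2048>();

namespace detail
{
// Ordinary function template in place of a C++20 `[&]<int TopK>() {...}` lambda.
template <int TopK>
int launchImplSketch()
{
    int (*fn)() = &kernelSketch<TopK>; // address-take after instantiation
    return fn();
}
} // namespace detail

// Runtime-to-compile-time dispatch on topK.
inline int launchSketch(int topK)
{
    switch (topK)
    {
    case 512: return detail::launchImplSketch<512>();
    case 1024: return detail::launchImplSketch<1024>();
    case 2048: return detail::launchImplSketch<2048>();
    default: throw std::invalid_argument("topK must be 512, 1024, or 2048");
    }
}
```

Each `case` names its template argument as a literal, so no generic lambda is needed and the code compiles as C++11.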

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
@longcheng-nv
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #47548 [ run ] triggered by Bot. Commit: e5ad15e Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #47548 [ run ] completed with state SUCCESS. Commit: e5ad15e
/LLM/main/L0_MergeRequest_PR pipeline #37463 completed with status: 'FAILURE'

CI Report

CI Agent Failure Analysis

Link to invocation

@longcheng-nv
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #47580 [ run ] triggered by Bot. Commit: e5ad15e Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #47652 [ run ] triggered by Bot. Commit: 607349e Link to invocation

…#13477 review follow-up)

PR NVIDIA#13477 reviewers (hyukn, yuxianq) requested on 2026-04-30 that
`warmup_heuristic_topk_decode` and its module-level idempotency cache
live next to their sole caller in
`tensorrt_llm/_torch/attention_backend/sparse/dsa.py` rather than in
`tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py` (which is the wrong
abstraction layer — it hosts `torch.library` op registrations, not
caller-specific warmup helpers). The relocation was acknowledged at
merge time but deferred to a follow-up; this commit closes that loop.

Mechanical move of:
  - `_HEURISTIC_TOPK_WARMUP_DONE` (Set) + `_HEURISTIC_TOPK_WARMUP_LOCK`
    (threading.Lock) idempotency state
  - `def warmup_heuristic_topk_decode(...)` (52 lines, including
    docstring; behavior unchanged)
  - `import threading` and `Set` from `typing` (dropped from
    cpp_custom_ops.py since no other code in that file uses them)

Caller in `Indexer.setup_module()` (dsa.py:1306) is updated from a
conditional cross-module import + `cpp_custom_ops.warmup_heuristic_topk_decode(...)`
to a direct local call. No other call sites exist in the tree.

This is a no-op for behavior: the function body, idempotency keying
(device, top_k, hint_size, num_cols), and CUDA Graph capture ordering
are byte-identical to the merged version from PR NVIDIA#13477.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
@longcheng-nv
Collaborator Author

/bot run --disable-fail-fast

@longcheng-nv longcheng-nv requested review from a team as code owners May 11, 2026 03:40
@tensorrt-cicd
Collaborator

PR_Github #47653 [ run ] triggered by Bot. Commit: 0c9f705 Link to invocation

Comment thread tests/unittest/_torch/thop/parallel/test_indexer_topk.py Outdated
Comment thread cpp/tensorrt_llm/kernels/heuristic_topk.cuh
Comment thread cpp/tensorrt_llm/kernels/heuristic_topk.cuh Outdated
Comment thread tests/unittest/_torch/thop/parallel/test_indexer_topk.py Outdated
@tensorrt-cicd
Collaborator

PR_Github #47653 [ run ] completed with state FAILURE. Commit: 0c9f705
/LLM/main/L0_MergeRequest_PR pipeline #37557 completed with status: 'FAILURE'

CI Report

CI Agent Failure Analysis

Link to invocation

Comment thread tests/unittest/_torch/thop/parallel/test_indexer_topk.py Outdated
longcheng-nv and others added 2 commits May 11, 2026 14:45
… tolerance widening

PR NVIDIA#13948 round-3 reviewers (mingyangHao, lfr-0531) flagged two coupled
issues on the new `test_indexer_topk_decode_dist` parametrization:

  (a) K=512 / K=1024 cells marked `pytest.mark.flaky(reruns=2)` to absorb
      a ~0.09 % intermittent failure rate.
  (b) The corresponding `bin_tolerance` was widened to
      `full_range/256` plus an `|mean| * eps * 4` bf16/fp16 branch,
      raising the risk of silently masking real wrong-answer regressions.

Reviewers correctly pointed out that GVR Top-K is an *exact* algorithm
(same-dtype `logits.topk` is the spec, bf16 -> fp32 promotion inside
the kernel is bit-exact and order-preserving, so the K-th cutoff is
identical in both comparison spaces). Any sorted-value-multiset gap is
a real kernel bug, not algorithmic noise. This patch fixes the bug and
removes the workaround.
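
Why the bf16 -> fp32 promotion adds no noise can be shown with a small host-side sketch. It relies on one fact about the formats: a bf16 value is the upper 16 bits of an IEEE-754 binary32 value, so promotion is a 16-bit left shift of the raw pattern. The helper names are hypothetical and the code works on raw bit patterns rather than CUDA's `__nv_bfloat16` type. fp16 -> fp32 promotion is likewise exact, since every fp16 value is representable in fp32, but its bit-level conversion is not shown here.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Promote a raw bf16 bit pattern to float: shift into the upper half of a
// 32-bit word and reinterpret. No rounding can occur.
inline float bf16BitsToFloat(std::uint16_t bits)
{
    std::uint32_t const wide = static_cast<std::uint32_t>(bits) << 16;
    float out;
    std::memcpy(&out, &wide, sizeof(out));
    return out;
}

// Truncating a promoted float back to 16 bits recovers the original pattern.
inline std::uint16_t floatToBf16BitsTruncate(float value)
{
    std::uint32_t wide;
    std::memcpy(&wide, &value, sizeof(wide));
    return static_cast<std::uint16_t>(wide >> 16);
}
```

Since promotion is an injective, exact map on non-NaN values, two bf16 logits compare the same way before and after promotion, which is why the K-th cutoff is identical in both spaces.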

Kernel fix (cpp/tensorrt_llm/kernels/heuristic_topk.cuh)
--------------------------------------------------------

P4 snap loop's prior bound `snap_limit = cand_count > 128 ? cand_count/4 : 32`
was insufficient for cells where the initial bin-edge threshold sits
far from the true K-th value. Falling through with `cgt >= kK` then
caused Pass 1 to emit K elements in candidate-buffer scan order from
the `cgt > kK` strictly-greater set, missing some true top-K members
whose buffer-position happened to be later than the cap.

Proof of new bound's sufficiency: each snap iter monotonically advances
the threshold to the next distinct value:

  - if `cgt >= K`:  raise thr to `min(v > thr)` -> cgt strictly drops by
                    `count(v == new_thr) >= 1`.
  - if `cge <  K`:  lower thr to `max(v < thr)` -> cge strictly rises by
                    `count(v == new_thr) >= 1`.
  - else (cgt < K <= cge): converged, break.

Worst case from initial (cgt = cand_count, cge = 0) to convergence
takes at most `cand_count - kK + 1` iters, so `snap_limit = cand_count`
is the tightest provably-sufficient bound. Common-path cells still
break early at iter 1-3 (the bin-edge initial thr is typically within
one bin width of the K-th value), so steady-state perf is unchanged;
only the former failure-path cells now take a handful of extra (sub-us)
snap iters before converging.

Applied to both the fp32 path (`gvrTopKJob`) and the dtype path
(`gvrTopKJobDtype`).
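
The convergence argument can be checked against a sequential host reference of the snap loop. This is a sketch under assumptions, not the kernel code: `snapToKthSketch` and `SnapResultSketch` are hypothetical, the loop is single-threaded rather than block-cooperative, and it assumes `1 <= k <= cand.size()` with no NaN candidates.

```cpp
#include <cassert>
#include <limits>
#include <vector>

struct SnapResultSketch
{
    float threshold; // final threshold
    int moves;       // number of threshold moves performed
    bool converged;  // true when cgt < k <= cge held on exit
};

// Moves thr one distinct candidate value at a time until
// count(v > thr) < k <= count(v >= thr), capped at cand.size() moves.
inline SnapResultSketch snapToKthSketch(std::vector<float> const& cand, int k, float thr)
{
    int const snapLimit = static_cast<int>(cand.size());
    assert(k >= 1 && k <= snapLimit);

    for (int moves = 0; moves <= snapLimit; ++moves)
    {
        int cgt = 0; // count(v >  thr)
        int cge = 0; // count(v >= thr)
        float nextUp = std::numeric_limits<float>::infinity();    // min(v > thr)
        float nextDown = -std::numeric_limits<float>::infinity(); // max(v < thr)
        for (float v : cand)
        {
            if (v > thr)
            {
                ++cgt;
                ++cge;
                nextUp = v < nextUp ? v : nextUp;
            }
            else if (v == thr)
            {
                ++cge;
            }
            else
            {
                nextDown = v > nextDown ? v : nextDown;
            }
        }

        if (cgt >= k)
        {
            thr = nextUp; // too many strictly above: raise to next distinct value
        }
        else if (cge < k)
        {
            thr = nextDown; // too few at or above: lower to next distinct value
        }
        else
        {
            return SnapResultSketch{thr, moves, true}; // cgt < k <= cge
        }
    }
    return SnapResultSketch{thr, snapLimit, false};
}
```

Starting far below or far above the K-th value, the reference walks one distinct value per move and never switches direction, which is what keeps the move count within `cand.size()`.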

Test fix (tests/unittest/_torch/thop/parallel/test_indexer_topk.py)
-------------------------------------------------------------------

  - Remove `pytest.mark.flaky(reruns=2)` markers on K=512 / K=1024.
  - Drop the `full_range/256` bin-quantized tolerance and the
    `|mean| * eps * 4` bf16/fp16 branch (unsound: bf16 quantization
    scales with the value magnitude near the K-th boundary, not with
    the distribution mean; furthermore, bf16->fp32 promotion inside
    the kernel introduces zero noise that needs eps slack).
  - Restore `compare_top_k_results` default `tolerance = 1e-5`.
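
The "sorted-value-multiset" criterion that the reviewers appeal to can be written down as a host-side check. This is a sketch under assumptions: `isExactTopKSketch` is a hypothetical helper, not the Python `compare_top_k_results` used by the test, and it compares with exact equality to illustrate the spec, whereas the test itself keeps its default `tolerance = 1e-5`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Returns true when `indices` selects exactly the multiset of the k largest
// logits (k = indices.size()). Under ties the indices may differ from a
// reference top-k, but the selected values may not.
inline bool isExactTopKSketch(std::vector<float> const& logits, std::vector<int> const& indices)
{
    std::size_t const k = indices.size();
    if (k > logits.size())
    {
        return false;
    }

    // Every index must be in range and used at most once.
    std::vector<int> sortedIdx(indices);
    std::sort(sortedIdx.begin(), sortedIdx.end());
    if (std::adjacent_find(sortedIdx.begin(), sortedIdx.end()) != sortedIdx.end())
    {
        return false;
    }
    if (k > 0 && (sortedIdx.front() < 0 || static_cast<std::size_t>(sortedIdx.back()) >= logits.size()))
    {
        return false;
    }

    // Values the kernel selected, sorted descending.
    std::vector<float> got;
    got.reserve(k);
    for (int idx : indices)
    {
        got.push_back(logits[static_cast<std::size_t>(idx)]);
    }
    std::sort(got.begin(), got.end(), std::greater<float>());

    // Reference: the k largest values, sorted descending.
    std::vector<float> want(logits);
    std::partial_sort(
        want.begin(), want.begin() + static_cast<std::ptrdiff_t>(k), want.end(), std::greater<float>());
    want.resize(k);

    return got == want; // exact comparison, no tolerance
}
```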

Local validation
----------------

Full `test_indexer_topk_decode_dist` sweep on B200 with the kernel
fix + strict default tolerance:

  | scope                                           | cells | result      | wall  |
  |-------------------------------------------------|-------|-------------|-------|
  | K=512 + K=1024 (all dist/dtype/BS/next_n/ratio) |  1512 | 1512/1512   | 137 s |
  | K=2048 (same axes)                              |   756 |  756/756    |  70 s |
  | Stress: 5x focused corner (logistic m0.47,      |  5x36 |  180/180    |  ~5 m |
  |   BS=128, next_n=3 -> numRows=384)              |       |             |       |

0 failures across 2448 invocations (no `flaky` markers, default 1e-5
tolerance). The former 0.09 % intermittent corner is now stable.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
…kFTarget docstring

Round-3 review follow-up (mingyangHao on `heuristic_topk.cuh`).

heuristic_topk.cuh:553 (mingyangHao): `blockFusedSnapIterDtype` is defined
but never called. After commit 7 (this PR) deferred the bf16/fp16 ->
fp32 P3 conversion so that smem `keys[]` are stored fp32 on all dtype
paths, the Phase-4 snap iteration reuses the fp32
`blockFusedSnapIter<TopK>` directly. The dtype-specialized variant was
left over from the pre-deferred-conversion design. Remove the unused
~64-line function plus its companion doc paragraph (lines 508-510) and
the stale `blockFusedSnapIter: blockFusedSnapIterDtype<SmemKey>` line in
`gvrTopKJobDtype`'s docstring (line 1178).

heuristic_topk.cuh:215 (mingyangHao, "Why for 512 topk, target is 384
which is smaller?"): expand the `kFTarget` docstring to call out that
the value is the secant solver's **soft steering target**, not the
convergence condition (which remains `count in [kK, kCC]`, asserted by
the `done = 1` check in `gvrTopKJob`'s P1+P2 scope). A `kFTarget < kK`
is intentional for small K: preIdx-seeded P1 lands the initial
threshold near the right tail of the prev-row top-K, so biasing the
target below kK pulls the next secant interpolation downward more
aggressively, reaching the legal `[kK, kCC]` band in fewer steps. The
concrete multipliers (0.75K / 2.5K / 1.5K) were tuned empirically over
V4 M=K cells in commit 8 of this PR.
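
The distinction between the steering target and the convergence band can be made concrete with a small trait sketch. This is an illustration under assumptions, not the `GvrParams<T, K>` definition: the sketch is keyed on K only, `KFTargetSketch` and `isConvergedSketch` are hypothetical names, the K=512 value (384) comes from the reviewer's question, the K=2048 value (3072) from the earlier cleanup commit, and the K=1024 value (2560) assumes the multipliers listed above map to K=512/1024/2048 in that order. `kCC` is passed in because its real values are not stated here.

```cpp
#include <cassert>

// Hypothetical per-K soft steering target for the secant solver.
template <int K>
struct KFTargetSketch;

template <>
struct KFTargetSketch<512>
{
    static constexpr int kK = 512;
    static constexpr int kFTarget = 384; // 0.75K, below kK on purpose
};

template <>
struct KFTargetSketch<1024>
{
    static constexpr int kK = 1024;
    static constexpr int kFTarget = 2560; // 2.5K (assumed mapping)
};

template <>
struct KFTargetSketch<2048>
{
    static constexpr int kK = 2048;
    static constexpr int kFTarget = 3072; // 1.5K
};

// Convergence is judged on the candidate count, independent of kFTarget.
constexpr bool isConvergedSketch(int count, int kK, int kCC)
{
    return count >= kK && count <= kCC;
}
```

For K=512 a count equal to the target (384) is not a converged state; the target only biases where the next interpolation lands.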

No code-level changes other than dead-code deletion + docstring
expansion — all live kernel SASS unchanged.

Local validation: full `test_indexer_topk_decode_dist` sweep
re-executed, 2268/2268 cells pass under strict default 1e-5 tolerance
(209 s wall on B200 sm_100).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
@longcheng-nv
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #47764 [ run ] triggered by Bot. Commit: 4880bb6 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #47764 [ run ] completed with state SUCCESS. Commit: 4880bb6
/LLM/main/L0_MergeRequest_PR pipeline #37656 completed with status: 'FAILURE'

CI Report

CI Agent Failure Analysis

Link to invocation

@longcheng-nv
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #47798 [ run ] triggered by Bot. Commit: 4880bb6 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #47798 [ run ] completed with state SUCCESS. Commit: 4880bb6
/LLM/main/L0_MergeRequest_PR pipeline #37689 completed with status: 'SUCCESS'

CI Report

Link to invocation

@longcheng-nv
Collaborator Author

Review Status Summary

All reviewer feedback has been addressed as of commit 4880bb6. CI is green on this HEAD: /bot run #47798 completed with pipeline status SUCCESS (the earlier #47764 run on the same commit finished in state SUCCESS, but its pipeline #37656 reported FAILURE), all 7 GitHub check-runs ✅, blossom-ci ✅.

Addressed threads

| Reviewer | Thread | Resolution |
|---|---|---|
| @mingyangHao | heuristic_topk.cuh:227, "Why for 512 topk, target is 384?" | Expanded kFTarget docstring with derivation rationale (4880bb6) |
| @mingyangHao | heuristic_topk.cuh, "function defined but never referenced" | Removed dead code: −64 lines body, −3 lines companion doc (4880bb6) |
| @mingyangHao | heuristic_topk.cuh:553, unroll pragma on 16-iter loops | Verified -O3 nvcc already unrolls; SASS byte-identical with/without pragma; no change needed |
| @mingyangHao | test_indexer_topk.py, "will be zero when mean=0, dedicated design?" | Fixed unsound bf16/fp16 quantization-noise formula; now scales with value magnitude near the K-th boundary (4880bb6) |
| @lfr-0531 | test_indexer_topk.py, "still accuracy issues at K=512/1024?" | Confirmed a real correctness bug (not tie-noise); GVR Top-K is now exact-match against same-dtype logits.topk |
| @lfr-0531 | (same thread), flaky(reruns=2) masking | Removed both flaky markers in d9be9d6; tests now pass deterministically |

Status

  • Mergeable: true (state: blocked — awaiting review approval)
  • All 6 inline review threads have a response from the author; none re-opened
  • Latest CI run on HEAD 4880bb600e: ✅ SUCCESS (2026-05-12 00:23 UTC)

@juney-nvidia could you take a look when you have a moment? @mingyangHao @lfr-0531 if my responses satisfactorily resolve your threads, please mark them resolved or leave any follow-up. Happy to address further comments. Thanks!

Collaborator

@yuxianq yuxianq left a comment


Approve the warmup_heuristic_topk_decode change in cpp_custom_ops.py.

Collaborator

@mingyangHao mingyangHao left a comment


LGTM

@longcheng-nv longcheng-nv removed the request for review from juney-nvidia May 12, 2026 05:13
@lfr-0531 lfr-0531 merged commit de190a0 into NVIDIA:main May 12, 2026
8 checks passed
yufeiwu-nv pushed a commit to yufeiwu-nv/TensorRT-LLM that referenced this pull request May 19, 2026
… GVR Top-K (NVIDIA#13948)

Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
7 participants