[None][feat] Multi-K (512/1024/2048) and Multi-dtype (fp32/bf16/fp16) GVR Top-K #13948
Add a templated dtype path (gvrTopKJobDtype / gvrTopKKernelDtype) that
shares the original GVR algorithm but accepts bf16 or fp16 inputs.
fp32 path (gvrTopKJob / gvrTopKKernel and the multi-row caller in
heuristicTopKDecode.cu) is byte-untouched — option A from the design
doc — so the production fp32 kernel is guaranteed zero-regression.
Architecture:
- GvrDtypeTraits<float | __nv_bfloat16 | __half>: encapsulates cvt
intrinsics (to_fp32 / from_fp32) and vector-load width
(VEC_W=4 for fp32, VEC_W=8 for bf16/fp16).
- KernelSmem -> KernelSmemTpl<SmemKey> template + alias
(using KernelSmem = KernelSmemTpl<float>) so existing call sites
remain source/ABI-equivalent.
- New helpers blockCountGEDtype<InputT> (8-wide vector load) and
blockFusedSnapIterDtype<SmemKey> (Tier-2 SmemKey reads with
Trait::to_fp32 up-cast).
- New gvrTopKJobDtype<InputT> mirrors gvrTopKJob with five
trait-driven substitution sites: P1 preIdx load, P2 vector
blockCountGE, P3 collect, P4 histogram bin index, P4 emit.
- launchHeuristicTopK<T> routes T=float to the original kernel and
T=__nv_bfloat16/__half to gvrTopKKernelDtype<T> via if constexpr;
smem size is computed per-trait so dynamic smem opt-in matches.
- Three explicit instantiations: <float, int>, <__nv_bfloat16, int>,
<__half, int>.
Tier-1 fp32 invariants preserved across all dtypes:
threshold, val_lo/hi, range1, inv1, all accumulators, histogram
bin index — kept in fp32. Only Tier-2 containers (input HBM,
smem keys, outputValues) follow InputT.
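The trait split described above can be illustrated on the host. This is a minimal sketch, not the PR's `GvrDtypeTraits`: `HostBf16` and `DtypeTraitsSketch` are hypothetical names, the bf16 type is modeled as the upper 16 bits of an fp32, and the 16-byte load width is an assumption that happens to reproduce VEC_W=4 for fp32 and VEC_W=8 for 2-byte types.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical host stand-in for __nv_bfloat16: the upper 16 bits of an fp32.
struct HostBf16 { std::uint16_t bits; };

// Primary template left undefined so an unsupported dtype fails to compile.
template <typename T> struct DtypeTraitsSketch;

template <> struct DtypeTraitsSketch<float> {
    static constexpr int VEC_W = 16 / sizeof(float);   // 4 elements per assumed 16-byte load
    static float to_fp32(float v) { return v; }
    static float from_fp32(float v) { return v; }
};

template <> struct DtypeTraitsSketch<HostBf16> {
    static constexpr int VEC_W = 16 / sizeof(HostBf16); // 8 elements per assumed 16-byte load
    static float to_fp32(HostBf16 v) {
        std::uint32_t u = static_cast<std::uint32_t>(v.bits) << 16;
        float f;
        std::memcpy(&f, &u, sizeof f);
        return f;
    }
    static HostBf16 from_fp32(float f) {
        std::uint32_t u;
        std::memcpy(&u, &f, sizeof u);
        u += 0x7FFFu + ((u >> 16) & 1u);  // round to nearest even (NaN not handled here)
        return HostBf16{static_cast<std::uint16_t>(u >> 16)};
    }
};

// Storage follows T, but the accumulator stays in fp32.
template <typename T>
float sumAsFp32(const T* data, int n) {
    assert(n >= 0);
    float acc = 0.0f;
    for (int i = 0; i < n; ++i) acc += DtypeTraitsSketch<T>::to_fp32(data[i]);
    return acc;
}
```

The point of the sketch is that only the container type changes with the dtype; the arithmetic in `sumAsFp32` is identical for every instantiation.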
ptxas (sm_100) post-change:
gvrTopKKernel (fp32): REG=64, smem~59KB
gvrTopKKernelDtype<__nv_bfloat16>: REG=61, smem~47KB
gvrTopKKernelDtype<__half>: REG=61, smem~47KB
fp32 instantiation's resource line is character-identical to baseline.
Single-row micro-perf (B200, K=2048, nsys + 128MB L2 flush, 20 reps,
SWE-Bench-style temporal preIdx) shows 8-14 % bf16/fp16 speedup over
fp32 at N >= 32K. Multi-row Stage B is expected to widen the gap
toward the design's 1.3-1.5x prediction (HBM-bound regime).
Stage A acceptance evidence (REPORT_stageA.md in the standalone
ablation_study tree):
- SASS gate: fp32 byte-identical
- Algorithm correctness: mean overlap dtype-invariant (89.05-89.19%)
on SWE-Bench-64K 9 layers x 200 rows
- ptxas resources: no spills, three instantiations within budget
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
Stage B of 4d_multi_dtype_unified: wire the multi-dtype GVR Top-K kernel
(heuristic_topk.cuh, landed in commit 5992714) into the TRT-LLM indexer
Top-K dispatcher and torch op.
Architecture:
- heuristicTopKDecode.cu: add heuristicTopKMultiRowKernelDtype<InputT>
(multi-row outer kernel mirroring heuristicTopKMultiRowKernel for
bf16 / fp16 inputs) plus launchHeuristicTopKDecodeDtype<InputT> +
__nv_bfloat16 / __half overload entry points.
- IndexerTopK.h / indexerTopK.cu: invokeIndexerTopKDecodeDtype<InputT>
GVR-Heuristic-only dispatcher with bytesPerElem=2 kBsL2 doubling;
bf16/fp16 path has no Radix fallback (production radix kernel is
fp32-only), TLLM_CHECK aborts when GVR conditions not met.
- IndexerTopKOp.cpp: dtype switch (Float / BFloat16 / Half) over the
torch op input; scratch dtype must match logits dtype.
fp32 path is byte-untouched (option A): heuristicTopKMultiRowKernel,
gvrTopKJob, gvrTopKKernel, and the fp32 invokeIndexerTopKDecode entry
all unchanged. Verified via Stage B.B7 ptxas: fp32 multi-row kernel
REG=40 / SHARED=1024 / STACK=16 byte-identical to baseline, bf16/fp16
multi-row kernels at REG=40 / SHARED=1024 / STACK=8 with no spill.
Validation:
- Multi-row correctness (B5): fp32/bf16/fp16 x BS={1,4,16} x 3 layers,
864 cells, mean GVR-vs-torch.topk overlap dtype-invariant
(L0 ~98%, L22 ~91%, L60 ~77%).
- Multi-row perf (B6, nsys + 128MB L2 flush): BS=1 1.20-1.23x,
BS>=4 1.29-1.33x (matches DESIGN.md sec 6 prediction).
- Full BS=1 SWE-Bench sweep (PreC-T1, 9 layers x 2024 rows x 3 reps):
GVR fp32 vs Radix 1.90x median, GVR bf16 vs Radix 2.23x,
GVR bf16 vs GVR fp32 1.178x (Amdahl f~0.30 ceiling matches).
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
Add compile-time (InputT, TopK) dispatch for the GVR Heuristic Top-K
kernel via the new GvrParams<InputT, int TopK> trait class:
- 9 specializations covering {fp32, bf16, fp16} x {512, 1024, 2048}
with per-(T, K) (kFTarget, kC, kNumBins) tuned per the SWE-Bench-64K
9-combo sim sweep (D_param_sweep_9combos REPORT.md). Primary
template undefined - any unsupported combo triggers compile-time
error rather than runtime fall-through.
- Generic KernelSmemTplK<SmemKey, CCap, NumBinsT> parameterizes the
smem layout. KernelSmemTpl<SmemKey> becomes an alias preserving
the K=2048 default layout (legacy KernelSmem alias unchanged).
- blockCountGE / blockFusedSnapIter (and Dtype mirrors) gain
<int TopK = TOP_K, typename SmemT = ...> template parameters with
defaults matching legacy callers - K=2048 callsites compile
unchanged.
- gvrTopKJob<TopK>, gvrTopKJobDtype<InputT, TopK>,
gvrTopKKernel<TopK>, gvrTopKKernelDtype<InputT, TopK> are
templated; bodies use constexpr aliases (kK / kCC / kBins /
kFTarget) from Params.
- launchHeuristicTopK<T> and launchHeuristicTopKDecode (fp32 +
Dtype) dispatch on runtime topK in {512, 1024, 2048} via C++20
generic-lambda switch. Each (T, K) gets its own kernel function
pointer, smem size, and cudaFuncSetAttribute call.
- Multi-row kernels (heuristicTopKMultiRowKernel<TopK>,
heuristicTopKMultiRowKernelDtype<InputT, TopK>) get 9 explicit
instantiations in the anonymous namespace of heuristicTopKDecode.cu
so the runtime switch can take their address.
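The "primary template undefined" pattern from the list above can be sketched as follows. `GvrParamsSketch` is a hypothetical stand-in for `GvrParams`; only two of the nine cells are shown. The fp32 K=2048 values and the kC=5120 figure are quoted from this thread, while the fp32 K=512 `kFTarget` and `kNumBins` values are inferred from other messages in the thread rather than read from the header.

```cpp
#include <cassert>

// Primary template is declared but never defined: any (InputT, TopK) pair
// without a specialization is a compile-time error, not a runtime fall-through.
template <typename InputT, int TopK>
struct GvrParamsSketch;

template <>
struct GvrParamsSketch<float, 2048> {
    static constexpr int kFTarget = 3072;
    static constexpr int kC = 6144;
    static constexpr int kNumBins = 2048;
};

template <>
struct GvrParamsSketch<float, 512> {
    static constexpr int kFTarget = 512;   // inferred, see lead-in
    static constexpr int kC = 5120;
    static constexpr int kNumBins = 2048;  // inferred, see lead-in
};

// Kernel bodies would read their constants through the trait like this.
template <typename InputT, int TopK>
constexpr int candidateCapacity() {
    return GvrParamsSketch<InputT, TopK>::kC;
}

static_assert(candidateCapacity<float, 2048>() == 6144, "fp32 K=2048 cell");
// candidateCapacity<float, 256>() would not compile: no specialization exists.
```

Because each value is a `static constexpr` member, the compiler can fold it into the instruction stream, which is consistent with the message's remark about literals appearing as ptxas immediates.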
K=2048 fp32 specialization preserves V2e production identity
(kFTarget=3072, kC=6144) - SASS for gvrTopKKernel<2048> is 1761
instructions byte-identical to the pre-refactor baseline. K=512/1024
fp32 + all bf16/fp16 use the sim-sweep-recommended kC=5120 with
K-specific kFTarget (K=K, 3K, 4K). Reg pressure unchanged at 64
(fp32) / 62 (bf16/fp16) per thread - all 9 kernels reg-bound 2 CTA/SM
on B200 sm_100.
NOTE: indexerTopK.cu kHeuristicTopK gating still selects only
topK==2048 for the Heuristic dispatch path. Relaxing the predicate
to topK in {512, 1024, 2048} is deferred to V4 indexer integration
(separate caller-side change).
Validated by smoke compile (nvcc 13.1.115 -arch=sm_100 -std=c++20,
9 kernel entries with expected (kFTarget, kC) literals folded as
ptxas immediates).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
blockFusedSnapIter and blockFusedSnapIterDtype use TopK as a template
parameter (with TOP_K=2048 default for legacy callers). The previous
multi-K commit's gvrTopKJob<TopK> and gvrTopKJobDtype<InputT, TopK>
called these helpers WITHOUT forwarding TopK, so the helpers picked up
the default TOP_K=2048 and snapped to the wrong K-th boundary for
K!=2048.
Symptom (caught by Step 3a real-machine correctness sweep on
SWE-Bench-64K, 199 rows x 9 layers = 1791 rows per combo via the JIT extension):
K=2048 fp32 -> 100.00% pass (default == intended - coincidentally OK)
K=512 fp32 -> 26.08% pass (326 fails out of 441)
K=1024 bf16 -> 40.36% pass (263 fails out of 441)
Fix: pass <TopK> (single-dtype path) and <SmemKey, TopK> (dtype path)
explicitly. Two-line change.
After fix:
- 9-combo correctness 16,119 / 16,119 = 100.00% on SWE-Bench-64K
(9 layers x 199 rows x 9 (K, dtype) combos).
- K=2048 fp32 SASS remains byte-identical to V2e baseline: 1761
instructions, 0 diff lines after stripping address/encoding
comments. The default-K instantiation is unchanged.
Note for future templated-kernel refactors: this trap is invisible to
smoke-compile and PTX literal-grep checks - both default and explicit
forms compile cleanly and fold the same constant for K=2048. Only
end-to-end correctness on K!=default surfaces it.
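The trap is easy to reproduce in isolation. In this minimal sketch all names (`snapTargetRank`, `jobBuggy`, `jobFixed`) are hypothetical stand-ins for the helpers and jobs named in the message; only the default-argument mechanics are the same.

```cpp
#include <cassert>

constexpr int TOP_K_DEFAULT = 2048;

// Stand-in for a helper such as blockFusedSnapIter: TopK has a default.
template <int TopK = TOP_K_DEFAULT>
int snapTargetRank() {
    return TopK;
}

// Buggy caller: is itself templated on TopK but does not forward it,
// so the helper silently uses the default.
template <int TopK>
int jobBuggy() {
    return snapTargetRank();
}

// Fixed caller: forwards TopK explicitly.
template <int TopK>
int jobFixed() {
    return snapTargetRank<TopK>();
}
```

Both callers compile without a diagnostic, and both return 2048 when instantiated with 2048, which is why only a run with K different from the default exposes the difference.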
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
Open the indexer-TopK decode dispatcher to admit GVR-Heuristic for
topK == 512 and topK == 1024 in addition to the existing topK == 2048.
The kernel/launcher (heuristicTopKDecode.cu) already has runtime topK
switch + GvrParams<T, K> specializations for all three K values since
1bb1839; this commit extends only the caller-side gate so the new K paths
actually dispatch from invokeIndexerTopKDecode.
Two call sites in indexerTopK.cu are updated symmetrically:
- fp32 dispatcher (invokeIndexerTopKDecode for float):
topK == kHeuristicTopK → topK ∈ {512, 1024, 2048}
preIdxCount == kHeuristicSize → preIdxCount == topK
- bf16/fp16 dispatcher (invokeIndexerTopKDecodeDtype<InputT>):
same gate change + TLLM_CHECK_WITH_INFO message updated to reference
the new K-set and runtime-derived preIdxCount.
Behavioral guarantees:
- K=2048 path is byte-identical to the previous gate (set membership
reduces to the old equality when topK=2048 and preIdxCount=2048), so
all Scheme X v1.2 production traffic is unaffected.
- kSeqSmall / kBsLarge / effectiveSplitWorkThreshold are reused
unchanged. Per-K threshold tuning can be layered later; the current
values are conservative for K=512/1024 and only push more rows toward
the radix fallback (no regression risk).
- kHeuristicTopK / kHeuristicSize public constants in
heuristicTopKDecode.h are kept (still 2048) for ABI/header
compatibility; they are no longer read by indexerTopK.cu and can be
deprecated separately.
Build verified: kernels_src compiles clean (only pre-existing C++20
lambda-template-parameter warnings, unrelated to this change).
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
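The gate change in the message above amounts to replacing two equalities with a set-membership test plus a runtime-derived equality. A minimal sketch of just that part, with hypothetical function names; the real dispatcher also applies the kSeqSmall / kBsLarge / split-work thresholds, which are not modeled here:

```cpp
#include <cassert>

// Hypothetical helper: the three K values the kernels are specialized for.
bool isSupportedGvrTopK(int topK) {
    return topK == 512 || topK == 1024 || topK == 2048;
}

// Old gate, with kHeuristicTopK == kHeuristicSize == 2048 as stated in the message.
bool gvrGateOld(int topK, int preIdxCount) {
    return topK == 2048 && preIdxCount == 2048;
}

// New gate: set membership, and preIdxCount must equal the runtime topK.
bool gvrGateNew(int topK, int preIdxCount) {
    return isSupportedGvrTopK(topK) && preIdxCount == topK;
}
```

For topK == 2048 the two predicates agree on every `preIdxCount`, which is the "reduces to the old equality" argument from the behavioral guarantees.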
…/1024 cells
Phase-4 histogram resolution (NUM_BINS=2048 default in V2e) was over-resolved
for K=512/1024 cells. A 4-config sweep (48,600 reps in
01_micro_kernel_bs1/phase_timing_v2_multik_multidtype/REPORT_kbins_sweep.md)
showed kNumBins ≈ K is the sweet spot — saving 5-9% total per cell with no
correctness or smem-cap risk.
Changed 6 of 9 GvrParams<T, K>::kNumBins values:
fp32 K=512 : NUM_BINS=2048 → 1024 (-7.0% total)
fp32 K=1024 : NUM_BINS=2048 → 1024 (-5.9% total)
fp32 K=2048 : NUM_BINS=2048 → 2048 (V2e LOCK, byte-identity preserved)
bf16 K=512 : NUM_BINS=2048 → 512 (-8.2%, compromise rule)
bf16 K=1024 : NUM_BINS=2048 → 512 (-7.8%, K/2 captures full saving)
bf16 K=2048 : NUM_BINS=2048 → 2048 (no change, K=2048 sweep ~0%)
fp16 K=512 : NUM_BINS=2048 → 512 (-5.7% total)
fp16 K=1024 : NUM_BINS=2048 → 1024 (-6.1% total)
fp16 K=2048 : NUM_BINS=2048 → 2048 (no change)
Why it works:
- histogram_clear shrinks linearly with NumBins (cleared via NumBins/512
iterations of zero-init).
- bin_search via V2e OPT6 parallel 2-step is log(NumBins).
- atomicAdd contention rises modestly: K=512 cell goes from 5120/2048=2.5
cands/bin to 5120/512=10/bin, which only reaches the lower edge of the
contention knee (~10-20 cands/bin where atomic serialization dominates).
- snap iteration count is K-dependent, NOT NumBins-dependent, so smaller
NumBins doesn't slow snap.
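The arithmetic behind the first and third bullets can be checked directly. This is a sketch with hypothetical helper names; the divisor 512 is taken from the "NumBins/512 iterations" wording above and is treated as a plain parameter here.

```cpp
#include <cassert>

// Average number of candidates that land in one histogram bin.
constexpr double candsPerBin(int candidateCap, int numBins) {
    return static_cast<double>(candidateCap) / numBins;
}

// Zero-init passes needed to clear the histogram, assuming `stride` bins are
// cleared per pass (512 per the message above).
constexpr int clearIterations(int numBins, int stride = 512) {
    return (numBins + stride - 1) / stride;
}

static_assert(clearIterations(2048) == 4, "2048 bins take four passes");
static_assert(clearIterations(512) == 1, "512 bins take one pass");
```

Going from 2048 to 512 bins therefore cuts the clear from four passes to one while raising the average load from 2.5 to 10 candidates per bin.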
Why this is safe to deploy upstream:
- K=2048 fp32 path's kNumBins is unchanged → SASS byte-identical to V2e
production baseline → Scheme X v1.2 4-report verification NOT triggered
(per Q9d Option B).
- All 9 (T, K) cells were validated on the 9-combo set-eq tests in
Step 3a (1,791 rows × 9 combos = 16,119 rows, 100% match) before this change. The
kNumBins change does not affect the set of selected indices — only the
histogram resolution during snap convergence.
- Compile clean: kernels_src + tensorrt_llm + th_common rebuilt 09:46 UTC.
Side effect: K=512 / K=1024 dtype paths now use ~4 KB less smem (hist drops
from 8 KB to 2-4 KB). At K=512 bf16 with kC=5120, total smem drops from
~40 KB to ~36 KB — may unlock 5 CTA/SM at multi-row BS≥148, compounding
with the Q9e BS-scaling occupancy gain.
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
…p16 cells
The bf16/fp16 path's Phase 3 ballot-free collect was doing per-element
Trait::from_fp32(val) write conversions inside the hot smem-write loop.
Q1b phase data showed dtype P3 was 1.0-2.3 µs slower than fp32 P3 across K
(despite reading same N logits): the unpack8 + per-element from_fp32 in
the 8-wide inner loop was 2× the per-ldg work of fp32's 4-wide path.
Solution: keep keys[] in smem as float (deferred conversion), convert only
at the K output writebacks (3 exit paths in gvrTopKJobDtype). Costs ~10 KB
extra smem at K=2048 dtype (5120 keys × +2B); still fits 4 CTA/SM.
Switches to the fp32 helper blockFusedSnapIter<TopK> since the smem layout
now matches.
Wall verify (Q9d-equivalent: 9 layers × 200 rows × 5 reps under nsys,
P0.5 also active):
Cell          Q9d original   Post-P0.5+P0   ∆
K=512 fp32    26.91          25.28          -6.1%
K=512 bf16    25.63          22.21          -13.4%
K=512 fp16    24.54          21.50          -12.4%
K=1024 fp32   30.69          29.41          -4.2%
K=1024 bf16   28.80          25.41          -11.8%
K=1024 fp16   27.81          24.58          -11.6%
K=2048 fp32   33.92          33.92          0 (V2e LOCK preserved)
K=2048 bf16   29.95          28.06          -6.3%
K=2048 fp16   29.02          27.30          -5.9%
Radix         52.96          52.99          0 (sanity check)
K=2048 fp32 byte-identity preserved → V2e Scheme X v1.2 4-report
verification window NOT triggered. fp32 path's gvrTopKJob is unchanged;
only the dtype path's gvrTopKJobDtype has the new fp32 keys[] layout.
Files changed:
- cpp/tensorrt_llm/kernels/heuristic_topk.cuh (10 sites)
- cpp/tensorrt_llm/kernels/heuristicTopKDecode.cu (2 sites)
Compile clean: kernels_src + tensorrt_llm + th_common rebuilt 10:11 UTC.
Q9 set-eq tests pending; re-run before merging upstream.
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
Under V4 production where preIdx is built with M=K (not legacy M=2048),
P1 pmean shifts toward the right tail of the previous row's top-K
distribution. P2 secant therefore overshoots when kFTarget is set
relative to the V3.2-era M=2048 baseline, costing extra iterations
on K=512/1024 cells. The Sprint 2 sweep (`10_multi_cta_v1/M_eq_K_redesign/`)
identified K-proportional sweet spots: 0.75K for K=512 and 2.5K for K=1024,
applied uniformly across fp32/bf16/fp16.
Changes (6 of 9 GvrParams<T, K> specializations; K=2048 untouched):
GvrParams<float, 512> kFTarget 512 -> 384
GvrParams<float, 1024> kFTarget 3072 -> 2560
GvrParams<__nv_bfloat16, 512> kFTarget 512 -> 384
GvrParams<__nv_bfloat16, 1024> kFTarget 3072 -> 2560
GvrParams<__half, 512> kFTarget 512 -> 384
GvrParams<__half, 1024> kFTarget 3072 -> 2560
GvrParams<*, 2048> UNCHANGED (V2e LOCK preserved)
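The re-tuned values follow directly from the K-proportional ratios quoted above (0.75K and 2.5K). A small sketch that reproduces the table; `retunedFTarget` is a hypothetical helper, and the K=2048 entry simply returns the unchanged 3072 mentioned elsewhere in this thread for fp32.

```cpp
#include <cassert>

// Ratios expressed as integer fractions so the arithmetic stays exact:
// 0.75K = 3K/4 for K=512, 2.5K = 5K/2 for K=1024.
constexpr int retunedFTarget(int K) {
    return K == 512  ? (3 * K) / 4
         : K == 1024 ? (5 * K) / 2
                     : 3072;  // K=2048 left untouched
}

static_assert(retunedFTarget(512) == 384, "0.75 * 512");
static_assert(retunedFTarget(1024) == 2560, "2.5 * 1024");
```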
SWE-Bench-64K rebench (9 layers, 1827 cells, Q9d harness):
gvr_K512 fp32 25.34 -> 24.58 us (-3.03 %)
gvr_K512 bf16 22.21 -> 21.25 us (-4.32 %)
gvr_K512 fp16 21.63 -> 20.38 us (-5.77 %) [best gain]
gvr_K1024 fp32 29.54 -> 28.74 us (-2.71 %)
gvr_K1024 bf16 25.47 -> 24.80 us (-2.64 %)
gvr_K1024 fp16 24.64 -> 23.90 us (-2.99 %)
gvr_K2048 (all dtypes) byte-identical SASS (V2e LOCK)
All 9 layers improve on every re-tuned cell; no per-layer regression.
Why this is safe to deploy upstream:
- K=2048 fp32 hot path SASS verified byte-identical at md5 level
(6/6 V2e LOCK kernels in libth_common.so:
heuristicTopKMultiRowKernel<2048> a9f5c6b8f26c..
gvrTopKKernel<2048> e9ca7cdb3e46..
heuristicTopKMultiRowKernelDtype<bf16, 2048> aede08cc084e..
heuristicTopKMultiRowKernelDtype<half, 2048> 9caa920c4368..
gvrTopKKernelDtype<bf16, 2048> af8ba491cba5..
gvrTopKKernelDtype<half, 2048> 797315781bce..
pre-patch and post-patch dumps identical.).
- Scheme X v1.2 4-report regression NOT triggered (heuristic_topk.cuh-only
change; dispatcher cpp/tensorrt_llm/kernels/indexerTopK.cu unchanged;
K=2048 dispatcher hot path SASS-invariant).
- Single-shot Scheme X smoke (L0 x BS={1,64,256,432}) confirms speedup
within +/- 4 % of 2026-04-23 baseline at BS<432; Radix wall within
+/- 2 % across all 4 BS.
- Q9d set-equality 9-combo correctness validated in Sprint 2 prior to
deployment.
Build: kernels_src + tensorrt_llm + th_common rebuilt 2026-05-08 06:02 UTC.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
…tion
The standalone gvrTopKKernel<TopK> + gvrTopKKernelDtype<InputT, TopK>
wrappers were declared with __launch_bounds__(BLOCK_SIZE, 1), giving
nvcc a minBlocksPerSM=1 hint. The production multi-row entry
heuristicTopKMultiRowKernel{,Dtype} (heuristicTopKDecode.cu:50, 107)
uses the single-arg form __launch_bounds__(BLOCK_SIZE) with no minimum
hint. Same gvrTopKJob{,Dtype} device function in both, but the launch
bounds difference produced divergent register allocations:
Standalone (before): Production multi-row:
(float, 2048): REG=64 REG=40
(float, 1024): REG=63 REG=61
(float, 512): REG=63 REG=40
(bf16, 2048): REG=61 REG=44
(bf16, 1024): REG=61 REG=40
(bf16, 512): REG=61 REG=40
(half, 2048): REG=61 REG=44
(half, 1024): REG=59 REG=54
(half, 512): REG=61 REG=40
This made the standalone wrapper run at 50 % theoretical occupancy
across all 9 specs (REG-bound 2 CTA/SM), while the production multi-row
path ran at 75-100 % theoretical for 5 of 9 specs. Local JIT extension
benchmarks (cuda_ext/topk_ext.cu via launchHeuristicTopK) and any
ablation comparing standalone wrappers therefore reflected occupancy
that did not match production behavior.
Removing the trailing `, 1` argument applies the same nvcc heuristic to
both entries:
Standalone (after):
(float, 2048): REG=40 -> 3 CTA/SM (75 % theor)
(float, 1024): REG=40 (stack=16) -> 3 CTA/SM
(float, 512): REG=40 -> 4 CTA/SM (100 %)
(bf16, 2048): REG=45 -> 2 CTA/SM (50 %)
(bf16, 1024): REG=40 (stack=8) -> 3 CTA/SM
(bf16, 512): REG=32 (stack=8) -> 4 CTA/SM (100 %)
(half, 2048): REG=45 -> 2 CTA/SM
(half, 1024): REG=40 (stack=8) -> 3 CTA/SM
(half, 512): REG=32 (stack=8) -> 4 CTA/SM (100 %)
Standalone now matches production reg counts within Δ <= 5; for K=512
specs and K=1024 fp16/bf16 standalone is even leaner than production
because the single-row body has no rowIdx outer loop or per-row preIdx
indexing.
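The CTA/SM figures in the tables above can be reproduced with a simple register-only occupancy model. Everything numeric here is an assumption rather than something stated in the message: a 64K-register file and 2048 resident threads per SM, `BLOCK_SIZE` = 512 (inferred because it makes the quoted numbers line up), and per-warp register allocation rounded up to 256 registers. Shared memory limits are ignored.

```cpp
#include <algorithm>
#include <cassert>

constexpr int kRegsPerSM = 65536;       // assumption
constexpr int kMaxThreadsPerSM = 2048;  // assumption
constexpr int kBlockSize = 512;         // assumption inferred from the quoted occupancy
constexpr int kWarpSize = 32;
constexpr int kRegAllocUnit = 256;      // assumption: per-warp allocation granularity

// CTAs per SM when registers (and the thread limit) are the only constraints.
int ctasPerSMByRegs(int regsPerThread) {
    assert(regsPerThread > 0);
    int regsPerWarp = regsPerThread * kWarpSize;
    regsPerWarp = ((regsPerWarp + kRegAllocUnit - 1) / kRegAllocUnit) * kRegAllocUnit;
    const int regsPerCta = regsPerWarp * (kBlockSize / kWarpSize);
    return std::min(kRegsPerSM / regsPerCta, kMaxThreadsPerSM / kBlockSize);
}

// Theoretical occupancy as a percentage of resident threads.
int theoreticalOccupancyPct(int regsPerThread) {
    return 100 * ctasPerSMByRegs(regsPerThread) * kBlockSize / kMaxThreadsPerSM;
}
```

Under these assumptions REG=64 gives 2 CTA/SM (50 %), REG=40 gives 3 (75 %), REG=32 gives 4 (100 %), and REG=45 falls back to 2, matching the "after" column.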
Why this is safe to land:
- Production heuristicTopKMultiRowKernel{,Dtype} is a different
__global__ in heuristicTopKDecode.cu with its own __launch_bounds__,
which this commit does not touch. SASS for the V4 production K=2048
fp32 hot path is byte-identical (V2e LOCK preserved on production
multi-row entry).
- The standalone entry SASS does change. V2e LOCK byte-identity on
the standalone wrappers is intentionally broken; gvrTopKKernel*
is only invoked by the local JIT extension and the standalone
ablation harnesses, never by TRT-LLM production runtime.
- Q9d set-equality 9-combo correctness is unaffected: this commit
changes only the kernel launch envelope (registers, CTA scheduling),
not algorithm flow. Re-validation will follow on the post-alignment
re-run of 09_precision_ablation/04_multik_multidtype_perf_bs1.
Build: kernels_src + tensorrt_llm + th_common rebuilt 2026-05-08 07:39 UTC.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
…ments
Remove development-history annotations from kernel source comments that
would not be meaningful to upstream readers:
- Internal experiment IDs (4d.D, 4d_multi_dtype_unified, Q1b-P0, Q10b-S2,
Q10b-S3 retune annotations)
- Date stamps (2026-05-06, 2026-05-07, 2026-05-08)
- References to local ablation directories
(04_kernel_optimizations/4d_multi_dtype_unified/,
10_multi_cta_v1/M_eq_K_redesign/, ablation_study/gvr_phase_timing/...)
- Aspirational planning leftovers ("Step 2 must verify with cuobjdump",
"Step 3 must run Scheme X v1.2 4-report regression to confirm")
- Stale Sprint-2 plan that contradicted the current code (claimed
GvrParams<float, 2048>::kFTarget = 4096 but actual code is 3072)
- "Backwards-compat alias" / "byte-equivalent signature" framing for
KernelSmemTpl / KernelSmem aliases: these are convenience aliases for
the K=2048 default, not back-compat shims.
The Multi-K trait class (GvrParams<T, K>) gains a longer mechanism-level
explanation of why kFTarget, kC, and kNumBins vary across (T, K) cells:
preIdx-distribution shift under V3.2 decode semantics for kFTarget,
register/smem occupancy for kC, and atomic-contention-vs-Phase-4-setup
trade-off for kNumBins. fp32 K=2048 SASS byte-identity preservation is
called out at the relevant specialization.
No code logic changes; comment-only cleanup. ptxas output for all 9
GvrParams<T, K> specializations is identical to pre-cleanup.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
…stry tags
Two small follow-ups to commit 1d1d938:
1. Remove the unused `thresholdPos` parameter from `gvrTopKKernel`,
`gvrTopKKernelDtype`, `launchHeuristicTopK`, and the `cudaLaunchKernelEx`
call site. The parameter was carried as a placeholder for an older
threshold-hint API that was never wired into the kernel body:
`gvrTopKJob` / `gvrTopKJobDtype` neither receive it nor reference it.
Removing it shrinks the public launcher signature from 9 args to 8 and
matches the multi-row entry's signature shape.
SASS impact (verified with cuobjdump on libth_common.so):
- **Production multi-row K=2048 specs (3/3): byte-identical.**
`heuristicTopKMultiRowKernel<2048>`,
`heuristicTopKMultiRowKernelDtype<__half, 2048>`,
`heuristicTopKMultiRowKernelDtype<__nv_bfloat16, 2048>` keep their
pre-cleanup md5s; V2e LOCK on the production hot path is preserved.
- **Standalone wrapper K=2048 specs (3/3): mangled signature changes by
one parameter** (`...Pii` → `...Pi`), so md5s differ. These wrappers are
only used by JIT consumers (e.g. ablation-side PyTorch extensions);
production never calls them.
2. Strip leftover `OPT[N]` references from comments
(`OPT4: Warp-Level Reduction Primitives`, `OPT5: when done==1...`,
`OPT6: Parallel K-th bin search`, `OPT7: cache per-thread count...`).
These were tags into an internal optimization registry; the same text
remains as plain mechanism comments without the registry code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
… launcher
Two pure refactors with zero kernel SASS impact:
1. **`getSchemeXBounds()` helper (indexerTopK.cu)**: the fp32 and
bf16/fp16 dispatchers each carried independent `std::call_once` caches
for SM count, L2 capacity, and the `kSeqSmall` threshold (5 once_flags
total, ~80 LOC of boilerplate duplication). Hoist into a single
anonymous-namespace helper that takes `bytesPerElem` so kBsL2 still
adapts per-dtype. Reduces once_flag count from 5 to 2 (helper +
debug-flag).
2. **`launchHeuristicTopKDecodeImpl<InputT>` template
(heuristicTopKDecode.cu)**: the fp32 launcher and
`launchHeuristicTopKDecodeDtype<InputT>` carried near-identical lambda
bodies (cudaFuncSetAttribute / cudaLaunchKernelEx scaffolding) differing
only in (a) vector-load alignment requirement and (b) which kernel
function pointer to take. Collapse into one templated impl in anonymous
namespace with `if constexpr` for the kernel function-pointer selection;
the three `launchHeuristicTopKDecode` overloads become thin forwarders.
Side effect: the `switch (topK)` now has a `default:` arm that throws on
unreachable input (defensive belt over the `TLLM_CHECK_WITH_INFO` guard).
Also strips two missed `// 4d:` include tags in indexerTopK.cu that the
prior chore-cleanup commits did not catch.
Net diff: −40 LOC across two files. No kernel binary changes; V2e SASS
LOCK on K=2048 fp32 production hot path is preserved (kernel templates
are not touched). Verified by building `make th_common -j8` (clean link).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
…s comment
The two-pass selection block at gvrTopKJob's Phase-4 had a comment
referencing internal experiment names ("Opt-M fix" / "v0 interleaved
version") that the prior chore-cleanup commits did not catch — the
audit regex enumerated `OPT[0-9]+` and `Q[0-9]b-P[0-9]+` patterns but
not `Opt-M` (Opt-M was one of 12 already-closed exploratory branches
in the project's internal registry).
Replace with a mechanism-only description that explains *why* the two
passes are sequential (deterministic selection under K-th-rank ties)
without referencing the internal v0/v1/v2 history.
Same comment block in `gvrTopKJobDtype` was already mechanism-only
(no Opt-M reference), so this is a single 7-line edit in the fp32
path.
No code logic changes, no SASS impact.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
…6/fp16 path
Templates topKPerRowDecode / topKPerRowJob / topKPerRowPrefill /
processHistogramStep on InputT (default float). HBM reads cast to float at
load sites; the histogram + sort steps remain dtype-agnostic since they
operate on float keys in shared memory. fp32 SASS is byte-identical to the
pre-template baseline (template defaults).
Replaces the bf16/fp16 dispatcher's hard abort with the same fallback
chain the fp32 dispatcher uses:
numColumns < kSortingAlgorithmThreshold -> insertion sort
kSortingAlgorithmThreshold <= N < splitWork -> radix sort
N >= splitWork -> unsupported (TLLM_CHECK)
The split-work tier requires float aux buffers the bf16/fp16 entry does
not expose; callers in that regime must use the fp32 entry.
Exposes canIndexerTopKDecodeUseGvr(numRows, numColumns, topK,
bytesPerElem) so callers can introspect the dispatcher decision before
allocating preIdx tensors or heuristicScratch buffers.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
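The three-tier fallback chain in the message above can be written as a small pure function. This is a sketch with hypothetical names (`DecodeFallback`, `pickDecodeFallback`); the thresholds are passed in as parameters because their actual values are not given in this thread.

```cpp
#include <cassert>

enum class DecodeFallback { InsertionSort, RadixSort, Unsupported };

// Mirrors the tiering described for the bf16/fp16 dispatcher once GVR is ruled out.
// sortingThreshold stands in for kSortingAlgorithmThreshold, splitWorkThreshold
// for the split-work cutoff; both are caller-supplied in this sketch.
DecodeFallback pickDecodeFallback(int numColumns, int sortingThreshold, int splitWorkThreshold) {
    assert(sortingThreshold <= splitWorkThreshold);
    if (numColumns < sortingThreshold) {
        return DecodeFallback::InsertionSort;
    }
    if (numColumns < splitWorkThreshold) {
        return DecodeFallback::RadixSort;
    }
    // The split-work tier needs float aux buffers the bf16/fp16 entry lacks.
    return DecodeFallback::Unsupported;
}
```

A caller would treat `Unsupported` as the signal to route through the fp32 entry instead, as the message recommends.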
/bot run --disable-fail-fast
📝 Walkthrough
This PR extends TensorRT-LLM's top-K decoding kernels to support bfloat16 and float16 inputs alongside existing float32 support. The changes introduce compile-time dtype specialization via traits, refactor the GVR algorithm to accept compile-time TopK values (512, 1024, 2048), and extend public APIs and dispatchers across IndexerTopK and HeuristicTopK subsystems.
Changes
Multi-dtype Top-K Kernel Extension
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
🧹 Nitpick comments (2)
cpp/tensorrt_llm/kernels/heuristic_topk.cuh (2)
1720-1726: ⚡ Quick win
Use getEnvEnablePDL() for consistency.
heuristicTopKDecode.cu (line 222) uses tensorrt_llm::common::getEnvEnablePDL() while this launcher manually parses TRTLLM_ENABLE_PDL. Using the common utility ensures consistent behavior and avoids duplicating env-parsing logic.
♻️ Suggested fix
Add the include at the top of the file:
#include "tensorrt_llm/common/envUtils.h"
Then replace the manual parsing:
-    // Honor the standard TRTLLM_ENABLE_PDL env var (default on; set "0" to
-    // disable).
-    bool enablePDL = true;
-    if (char const* env = std::getenv("TRTLLM_ENABLE_PDL"))
-    {
-        if (env[0] == '0' && env[1] == '\0')
-            enablePDL = false;
-    }
+    bool const enablePDL = tensorrt_llm::common::getEnvEnablePDL();
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/tensorrt_llm/kernels/heuristic_topk.cuh` around lines 1720 - 1726, Replace the manual getenv parsing that sets the local bool enablePDL with a call to the shared utility tensorrt_llm::common::getEnvEnablePDL(); add the include "tensorrt_llm/common/envUtils.h" to the top of the file, remove the block that directly reads TRTLLM_ENABLE_PDL and instead initialize enablePDL by calling tensorrt_llm::common::getEnvEnablePDL() to keep behavior consistent with heuristicTopKDecode.cu.
1684-1686: 💤 Low value
Remove unused SmemKey type alias.
SmemKey is defined from the trait but never used; SmemT correctly uses float for keys as documented in the optimization note at lines 1173-1176. The unused alias adds confusion.
♻️ Suggested fix
 template <typename InputT, int TopK = TOP_K>
 __global__ void __launch_bounds__(BLOCK_SIZE) gvrTopKKernelDtype(InputT const* __restrict__ input, int const N,
     int const* __restrict__ preIdx, int const M, int const topK, InputT* __restrict__ outputValues,
     int* __restrict__ outputIndices)
 {
-    using SmemKey = typename GvrDtypeTraits<InputT>::SmemKey;
-    using SmemT
-        = KernelSmemTplK<float, GvrParams<InputT, TopK>::kC, GvrParams<InputT, TopK>::kNumBins>; // dtype keys fp32
+    // dtype path uses fp32 smem keys (deferred convert to output).
+    using SmemT = KernelSmemTplK<float, GvrParams<InputT, TopK>::kC, GvrParams<InputT, TopK>::kNumBins>;
     extern __shared__ unsigned char smem_raw[];
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/tensorrt_llm/kernels/heuristic_topk.cuh` around lines 1684 - 1686, Remove the unused type alias SmemKey introduced from GvrDtypeTraits<InputT>; it is never referenced and causes confusion—delete the line "using SmemKey = typename GvrDtypeTraits<InputT>::SmemKey;" and keep the existing SmemT definition that intentionally uses float keys (KernelSmemTplK<float, GvrParams<InputT, TopK>::kC, GvrParams<InputT, TopK>::kNumBins>), leaving GvrDtypeTraits, InputT and TopK untouched.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 71e22bbe-2801-4613-8946-0f31f85d685a
📒 Files selected for processing (6)
cpp/tensorrt_llm/kernels/IndexerTopK.h
cpp/tensorrt_llm/kernels/heuristicTopKDecode.cu
cpp/tensorrt_llm/kernels/heuristicTopKDecode.h
cpp/tensorrt_llm/kernels/heuristic_topk.cuh
cpp/tensorrt_llm/kernels/indexerTopK.cu
cpp/tensorrt_llm/thop/IndexerTopKOp.cpp
PR_Github #47533 [ run ] triggered by Bot. Commit:
PR_Github #47533 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #47544 [ run ] triggered by Bot. Commit:
PR_Github #47544 [ run ] completed with state
…C++20 templated lambda
CI build #47533 (PR NVIDIA#13948) failed on Build-x86_64 / Build-SBSA with
9 errors per arch in cudafe1.stub.c:
error: explicit specialization of
'__wrapper__device_stub_gvrTopKKernel<512>' after instantiation
heuristic_topk.cuh:1148:20: implicit instantiation first required here
(also for K={1024,2048} on gvrTopKKernel and for
gvrTopKKernelDtype<{bf16, fp16}, {512,1024,2048}>.)
Local build on sm_100f / nvcc 13.1 passed cleanly; sm_89 / sm_120f cudafe1
stubgen tightens [temp.expl.spec]/6 ordering and rejects the
implicit-then-explicit sequence.
Two changes in cpp/tensorrt_llm/kernels/heuristic_topk.cuh:
1. Add 9 explicit kernel instantiations (gvrTopKKernel<{512,1024,2048}>
plus gvrTopKKernelDtype<{__nv_bfloat16,__half}, {512,1024,2048}>) right
after the gvrTopKKernelDtype definition. Forces nvcc to emit each host
wrapper stub at a fixed point ahead of any address-take in
launchHeuristicTopK. Mirrors the proven pattern at
heuristicTopKDecode.cu:155-171 for the multi-row kernels (which is why
those didn't fail). Header is included by exactly one TU so no ODR
concern.
2. Replace the C++20 templated lambda `[&]<int TopK>() {...}` in
launchHeuristicTopK with detail::launchHeuristicTopKImpl<T, TopK>
function template + switch dispatch. Removes the C++20-extension warning
#3288-D and eliminates the templated-lambda + capture +
__global__-address-take quirk pattern that triggers cudafe1's stubgen
ordering bug.
Kernel body, GvrParams traits, __launch_bounds__, PDL attr, and
cudaLaunchKernelEx call are all preserved; kernel SASS is unchanged for
all 9 (T, K) instantiations.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
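The second change in the message above (function template plus `switch` instead of a templated lambda) has a simple host-only shape. In this sketch `launchImplSketch` and `launchSketch` are hypothetical stand-ins; the impl just returns its compile-time `TopK` so the selected instantiation is observable, where the real code would set function attributes and launch a kernel.

```cpp
#include <cassert>
#include <stdexcept>

namespace detail {
// One instantiation per (T, TopK): in real code this is where the kernel
// function pointer and per-cell shared-memory size would be taken.
template <typename T, int TopK>
int launchImplSketch() {
    return TopK;
}
}  // namespace detail

// Runtime topK is mapped to a compile-time template argument by a plain switch,
// which needs nothing beyond C++11 and avoids templated lambdas entirely.
template <typename T>
int launchSketch(int topK) {
    switch (topK) {
    case 512: return detail::launchImplSketch<T, 512>();
    case 1024: return detail::launchImplSketch<T, 1024>();
    case 2048: return detail::launchImplSketch<T, 2048>();
    default: throw std::invalid_argument("unsupported topK");
    }
}
```

The `default` arm plays the same defensive role described for the decode launcher: an unsupported K fails loudly instead of falling through to some other instantiation.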
/bot run --disable-fail-fast
PR_Github #47548 [ run ] triggered by Bot. Commit:
PR_Github #47548 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #47580 [ run ] triggered by Bot. Commit:
PR_Github #47652 [ run ] triggered by Bot. Commit:
…#13477 review follow-up)

PR NVIDIA#13477 reviewers (hyukn, yuxianq) requested on 2026-04-30 that `warmup_heuristic_topk_decode` and its module-level idempotency cache live next to their sole caller in `tensorrt_llm/_torch/attention_backend/sparse/dsa.py` rather than in `tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py` (which is the wrong abstraction layer: it hosts `torch.library` op registrations, not caller-specific warmup helpers). The relocation was acknowledged at merge time but deferred to a follow-up; this commit closes that loop.

Mechanical move of:
- `_HEURISTIC_TOPK_WARMUP_DONE` (Set) + `_HEURISTIC_TOPK_WARMUP_LOCK` (threading.Lock) idempotency state
- `def warmup_heuristic_topk_decode(...)` (52 lines, including docstring; behavior unchanged)
- `import threading` and `Set` from `typing` (dropped from cpp_custom_ops.py since no other code in that file uses them)

Caller in `Indexer.setup_module()` (dsa.py:1306) is updated from a conditional cross-module import + `cpp_custom_ops.warmup_heuristic_topk_decode(...)` to a direct local call. No other call sites exist in the tree.

This is a no-op for behavior: the function body, idempotency keying (device, top_k, hint_size, num_cols), and CUDA Graph capture ordering are byte-identical to the merged version from PR NVIDIA#13477.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
/bot run --disable-fail-fast
PR_Github #47653 [ run ] triggered by Bot. Commit:
PR_Github #47653 [ run ] completed with state
… tolerance widening

PR NVIDIA#13948 round-3 reviewers (mingyangHao, lfr-0531) flagged two coupled issues on the new `test_indexer_topk_decode_dist` parametrization:

(a) K=512 / K=1024 cells marked `pytest.mark.flaky(reruns=2)` to absorb a ~0.09 % intermittent failure rate.
(b) The corresponding `bin_tolerance` was widened to `full_range/256` plus an `|mean| * eps * 4` bf16/fp16 branch, raising the risk of silently masking real wrong-answer regressions.

Reviewers correctly pointed out that GVR Top-K is an *exact* algorithm (same-dtype `logits.topk` is the spec, bf16 -> fp32 promotion inside the kernel is bit-exact and order-preserving, so the K-th cutoff is identical in both comparison spaces). Any sorted-value-multiset gap is a real kernel bug, not algorithmic noise. This patch fixes the bug and removes the workaround.

Kernel fix (cpp/tensorrt_llm/kernels/heuristic_topk.cuh)
--------------------------------------------------------
P4 snap loop's prior bound `snap_limit = cand_count > 128 ? cand_count/4 : 32` was insufficient for cells where the initial bin-edge threshold sits far from the true K-th value. Falling through with `cgt >= kK` then caused Pass 1 to emit K elements in candidate-buffer scan order from the `cgt > kK` strictly-greater set, missing some true top-K members whose buffer-position happened to be later than the cap.

Proof of new bound's sufficiency: each snap iter monotonically advances the threshold to the next distinct value:
- if `cgt >= K`: raise thr to `min(v > thr)` -> cgt strictly drops by `count(v == new_thr) >= 1`.
- if `cge < K`: lower thr to `max(v < thr)` -> cge strictly rises by `count(v == new_thr) >= 1`.
- else (cgt < K <= cge): converged, break.

Worst case from initial (cgt = cand_count, cge = 0) to convergence takes at most `cand_count - kK + 1` iters, so `snap_limit = cand_count` is the tightest provably-sufficient bound.

Common-path cells still break early at iter 1-3 (the bin-edge initial thr is typically within one bin width of the K-th value), so steady-state perf is unchanged; only the former failure-path cells now take a handful of extra (sub-us) snap iters before converging. Applied to both the fp32 path (`gvrTopKJob`) and the dtype path (`gvrTopKJobDtype`).

Test fix (tests/unittest/_torch/thop/parallel/test_indexer_topk.py)
-------------------------------------------------------------------
- Remove `pytest.mark.flaky(reruns=2)` markers on K=512 / K=1024.
- Drop the `full_range/256` bin-quantized tolerance and the `|mean| * eps * 4` bf16/fp16 branch (unsound: bf16 quantization scales with the value magnitude near the K-th boundary, not with the distribution mean; furthermore, bf16->fp32 promotion inside the kernel introduces zero noise that needs eps slack).
- Restore `compare_top_k_results` default `tolerance = 1e-5`.

Local validation
----------------
Full `test_indexer_topk_decode_dist` sweep on B200 with the kernel fix + strict default tolerance:

| scope | cells | result | wall |
|-------|-------|--------|------|
| K=512 + K=1024 (all dist/dtype/BS/next_n/ratio) | 1512 | 1512/1512 | 137 s |
| K=2048 (same axes) | 756 | 756/756 | 70 s |
| Stress: 5x focused corner (logistic m0.47, BS=128, next_n=3 -> numRows=384) | 5x36 | 180/180 | ~5 m |

0 failures across 2484 invocations (no `flaky` markers, default 1e-5 tolerance). The former 0.09 % intermittent corner is now stable.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
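The convergence argument in the commit message above can be exercised with a small host-side model. This is a minimal Python sketch, not the CUDA kernel: `snap_threshold` and its parameters are hypothetical names, and the sequential counts stand in for the block-wide `cgt` / `cge` counts the kernel computes. It applies the three cases from the proof and returns how many snap steps were taken, so the step count can be compared against `snap_limit = cand_count`.

```python
def snap_threshold(values, k, thr):
    """Snap `thr` onto the k-th largest value of `values`.

    Returns (threshold, steps). Converged means
    count(v > thr) < k <= count(v >= thr), as in the commit message.
    Hypothetical host-side model; names are illustrative only.
    """
    cand_count = len(values)
    assert 1 <= k <= cand_count
    snap_limit = cand_count  # the bound argued for in the commit message
    steps = 0
    while True:
        cgt = sum(1 for v in values if v > thr)   # strictly greater
        cge = sum(1 for v in values if v >= thr)  # greater or equal
        if cgt < k <= cge:
            return thr, steps                     # converged
        if steps == snap_limit:
            raise RuntimeError("snap loop did not converge within snap_limit")
        if cgt >= k:
            # Too many candidates above thr: raise to the next distinct value.
            thr = min(v for v in values if v > thr)
        else:
            # Here cge < k: lower to the next distinct value below thr.
            thr = max(v for v in values if v < thr)
        steps += 1


candidates = [5, 1, 4, 2, 3]
from_below = snap_threshold(candidates, k=2, thr=0.0)   # threshold starts under every value
from_above = snap_threshold(candidates, k=2, thr=10.0)  # threshold starts over every value
print(from_below, from_above)
```

In this model, starting below all five distinct candidates with K=2 takes `cand_count - K + 1 = 4` steps and starting above them takes `K = 2` steps; both stay within `snap_limit = cand_count`.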
…kFTarget docstring

Round-3 review follow-up (mingyangHao on `heuristic_topk.cuh`).

heuristic_topk.cuh:553 (mingyangHao): `blockFusedSnapIterDtype` is defined but never called. After commit 7 (this PR) deferred the bf16/fp16 -> fp32 P3 conversion so that smem `keys[]` are stored fp32 on all dtype paths, the Phase-4 snap iteration reuses the fp32 `blockFusedSnapIter<TopK>` directly. The dtype-specialized variant was left over from the pre-deferred-conversion design. Remove the unused ~64-line function plus its companion doc paragraph (lines 508-510) and the stale `blockFusedSnapIter: blockFusedSnapIterDtype<SmemKey>` line in `gvrTopKJobDtype`'s docstring (line 1178).

heuristic_topk.cuh:215 (mingyangHao, "Why for 512 topk, target is 384 which is smaller?"): expand the `kFTarget` docstring to call out that the value is the secant solver's **soft steering target**, not the convergence condition (which remains `count in [kK, kCC]`, asserted by the `done = 1` check in `gvrTopKJob`'s P1+P2 scope). A `kFTarget < kK` is intentional for small K: preIdx-seeded P1 lands the initial threshold near the right tail of the prev-row top-K, so biasing the target below kK pulls the next secant interpolation downward more aggressively, reaching the legal `[kK, kCC]` band in fewer steps. The concrete multipliers (0.75K / 2.5K / 1.5K) were tuned empirically over V4 M=K cells in commit 8 of this PR.

No code-level changes other than dead-code deletion + docstring expansion; all live kernel SASS unchanged.

Local validation: full `test_indexer_topk_decode_dist` sweep re-executed, 2268/2268 cells pass under strict default 1e-5 tolerance (209 s wall on B200 sm_100).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
/bot run --disable-fail-fast
PR_Github #47764 [ run ] triggered by Bot. Commit:
PR_Github #47764 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #47798 [ run ] triggered by Bot. Commit:
PR_Github #47798 [ run ] completed with state
Review Status Summary
All reviewer feedback has been addressed in commit 4880bb6. CI is fully green on this HEAD; two consecutive …
Addressed threads
Status
@juney-nvidia could you take a look when you have a moment? @mingyangHao @lfr-0531 if my responses satisfactorily resolve your threads, please mark them resolved or leave any follow-up. Happy to address further comments. Thanks! |
yuxianq
left a comment
Approve the warmup_heuristic_topk_decode change in cpp_custom_ops.py.
… GVR Top-K (NVIDIA#13948) Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com> Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Summary
Follow-up to #13477 (Scheme X dispatcher). Extends the GVR (Guess-Verify-Refine) Top-K micro-kernel to:
- multiple K values (512 / 1024 / 2048)
- multiple input dtypes (fp32 / bf16 / fp16)

via a `GvrParams<T, K>` trait class with 9 specializations. K=2048 fp32 production path is byte-identical SASS-wise (V2e LOCK preserved across all 3 production K=2048 kernel variants).

Commit Breakdown (16 commits)
- `GvrParams` trait class
- `TopK` template arg into `blockFusedSnapIter` helpers
- `kNumBins` tuning saves 5-9 % on K=512/1024 cells
- `kFTarget` retune for V4 M=K (K=512/1024 cells)
- `gvrTopKKernel` launch_bounds with production `thresholdPos` param + remove internal opt-registry tags
- `heuristicTopKDecode` launcher
- `blockFusedSnapIterDtype` + clarify `kFTarget` docstring (round-3 review fix)

Key Files (6 files / +1621 / −355)
- `cpp/tensorrt_llm/kernels/heuristic_topk.cuh`: `GvrParams<T,K>` trait class; 9-way template specialization; per-(T,K) `kNumBins` / `kFTarget` tuning; deferred fp32 P3 conversion
- `cpp/tensorrt_llm/kernels/heuristicTopKDecode.cu`: `GvrParams`-driven kernel dispatch
- `cpp/tensorrt_llm/kernels/heuristicTopKDecode.h`
- `cpp/tensorrt_llm/kernels/indexerTopK.cu`
- `cpp/tensorrt_llm/kernels/IndexerTopK.h`: `invokeIndexerTopKDecode`
- `cpp/tensorrt_llm/thop/IndexerTopKOp.cpp`: `tensor.scalar_type()`

API
No public API change. The Python op `torch.ops.trtllm.indexer_topk_decode` now auto-routes by `logits.scalar_type()` to the matching dispatcher; bf16/fp16 paths require `pre_idx` and `heuristic_scratch` (matching `dtype`). K is selected via the `index_topk` argument as before; values outside {512, 1024, 2048} fall back to Radix / Insertion.
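The routing rules stated in this PR (supported K set, the `pre_idx` requirement, and the `kSeqSmall=12288` gate mentioned in the test plan) can be summarized in a few lines. This is only an illustrative Python sketch under stated assumptions, not the Scheme X dispatcher: `choose_decode_path`, its parameters, and the returned labels are hypothetical; the strict `<` comparison at the sequence-length gate is an assumption; and the batch-size gate (`kBsLarge`) and dtype checks are deliberately left out.

```python
GVR_SUPPORTED_K = frozenset({512, 1024, 2048})  # K values with GVR specializations in this PR
K_SEQ_SMALL = 12288                             # gate value quoted in the test plan


def choose_decode_path(index_topk, num_tokens, has_pre_idx):
    """Return a label for the top-k path a decode call would take.

    Hypothetical model of the routing described in the PR text.
    """
    if index_topk not in GVR_SUPPORTED_K:
        return "radix_or_insertion"   # unsupported K falls back
    if not has_pre_idx:
        return "radix_or_insertion"   # 5-arg call without pre_idx uses the fallback
    if num_tokens < K_SEQ_SMALL:      # assumption: strict comparison at the gate
        return "radix_or_insertion"   # short sequences stay on the fallback
    return "gvr"


print(choose_decode_path(2048, 16384, True))  # a num_tokens=16384 cell
print(choose_decode_path(2048, 8192, True))   # a num_tokens=8192 cell
```

The two calls mirror the test-plan statement that `num_tokens=16384` cells hit GVR while `num_tokens=8192` cells route through the fallback.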
Test plan
- `pytest tests/unittest/_torch/thop/parallel/test_indexer_topk.py -k test_indexer_topk_decode`:
  - `test_indexer_topk_decode` (5-arg / no `pre_idx`): fp32 K∈{128, 2048} Radix/Insertion fallback covered (32 cases).
  - `test_indexer_topk_decode_dist` (7-arg / GVR Heuristic, parametrized in this PR over `index_topk × dtype × num_tokens × batch_size × next_n × dist × success_ratio` = 2268 cases): all 9 new GVR `(T, K)` specializations covered. `num_tokens=16384` cells hit GVR (1134); `num_tokens=8192` cells route through Radix/Insertion fallback via Scheme X dispatcher (`kSeqSmall=12288` gate). Post round-3 fix (commit 15): all 2268 cells pass under the strict default tolerance (1e-5); `flaky` markers and the `bin_tolerance` widening are removed; the former 0.09 % K=512 / K=1024 intermittent at `numRows ≈ 384` is fixed by the P4 snap-iter convergence guarantee.
- … `torch.topk` set (9 combos × 9 SWE-Bench-64K layers × 12 BS cells).

V2e LOCK SASS byte-identity (production K=2048)
| Kernel | SASS hash (baseline / this PR) | Match |
|--------|--------------------------------|-------|
| `heuristicTopKMultiRowKernel<2048>` fp32 | a9f5c6b8 / a9f5c6b8 | ✅ |
| `heuristicTopKMultiRowKernelDtype<__half, 2048>` | 9caa920c / 9caa920c | ✅ |
| `heuristicTopKMultiRowKernelDtype<__nv_bfloat16, 2048>` | aede08cc / aede08cc | ✅ |

K=2048 production SASS is byte-identical with pre-PR baseline for all three dtypes. New K=512 / K=1024 specializations have stable SASS produced fresh by the `GvrParams<T,K>` template.

Performance Results (B200 sm_100, SWE-Bench-64K decode logits)
BS=1 cross-layer pooled (Q9d 04b — 9 layers × 9 GVR combos vs Radix fp32 K=2048 baseline, production path)
Routed through `torch.ops.trtllm.indexer_topk_decode` (Scheme X v1.2 dispatcher → `heuristicTopKMultiRowKernel{,Dtype}<TopK>` with `preIdxOffset=+1`, V3.2 decode semantics). 1 827 timed rows / variant (9 layers × stride-10 over rows 1..2024), 5 reps + 3 warmup, 128 MiB L2 flush before every launch.
GVR wins all 9 combos × 9 layers. K↓ and dtype↓ are both monotonic
(smaller K = smaller smem residency / higher occupancy; smaller dtype =
smaller HBM-load + smem footprint). Cross-validates the BS=1 column of
Table 2 (Q9e 05) within ±10 % wall — both tables are now production path.
BS scaling — pooled across 9 layers, last row replicated (Q9e 05, production path)
Speedup vs Radix fp32 K=2048 baseline (50.1 µs at BS=1 → 93.5 µs at BS=400):
bf16 / fp16 maintain 2.0 – 2.3× speedup across the full BS range
(1 .. 400). fp32 fades from 1.85× to 1.28× as BS grows past 200 (smem
residency tightens). Half-precision dtypes scale gracefully because
their lower smem footprint preserves occupancy beyond 1 wave.
Reproducibility
launches (9 GVR combos × 12 BS × 9 layers × 5 reps).
Round-3 review fixes (commit 15)
Addresses lfr-0531 and mingyangHao review comments on the 0.09 %
intermittent at K=512 / K=1024:
- Kernel fix (`cpp/tensorrt_llm/kernels/heuristic_topk.cuh`): P4 snap loop's `snap_limit = cand_count/4` was insufficient for cells where the initial bin-edge threshold sits far from the true K-th value. Bug: on `cgt >= kK` non-convergence, Pass 1 emitted K elements in scan order from the `cgt > kK` strictly-greater set, missing some true top-K members. Fix: raise `snap_limit` to `cand_count`, which is the provably-sufficient bound (each snap iter strictly decreases `cgt` or strictly increases `cge` by ≥ 1, so worst-case convergence takes `cand_count − kK + 1` iters). Common path still breaks early at iter 1–3, so steady-state perf is unchanged.
- Test fix (`tests/unittest/_torch/thop/parallel/test_indexer_topk.py`): remove `pytest.mark.flaky(reruns=2)` markers on K=512 / K=1024; drop the `full_range/256` bin-quantized tolerance and the unsound `|mean| * eps * 4` bf16/fp16 branch (bf16 → fp32 promotion inside the kernel is bit-exact and order-preserving, so no eps slack is warranted); restore `compare_top_k_results` default `tolerance = 1e-5`.

Additional cleanup in commit 4880bb6 (chore, no live SASS impact):
- Remove the unused `blockFusedSnapIterDtype` helper (-64 lines) and stale doc paragraphs; the dtype P4 snap path uses the fp32 `blockFusedSnapIter<TopK>` directly after commit 7's deferred-conversion optimization stored smem `keys[]` as fp32.
- Expand the `kFTarget` docstring to call out that the value is the secant solver's soft steering target (not the convergence condition), and that `kFTarget < kK` is intentional for small K; addresses the line-215 review question.

Local validation on B200 sm_100 with the fix + strict default tolerance:
0 failures across 2484 invocations under default strict tolerance.
The former 0.09 % intermittent corner is now stable.
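The claim above that bf16 → fp32 promotion is bit-exact and order-preserving can be checked on the host. The following is a minimal Python sketch, not the kernel's conversion path: it relies only on the standard fact that a bfloat16 value is the upper 16 bits of an IEEE-754 binary32 pattern, and the helper names `bf16_bits_to_fp32` and `fp32_to_bits` are hypothetical. It covers finite bf16 values only and does not model fp16, which uses a different encoding.

```python
import struct


def bf16_bits_to_fp32(bits):
    """Widen a 16-bit bfloat16 pattern to a float by appending 16 zero bits."""
    return struct.unpack("<f", struct.pack("<I", (bits & 0xFFFF) << 16))[0]


def fp32_to_bits(x):
    """Return the binary32 bit pattern of x as an unsigned integer."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


# Every finite bf16 pattern (exponent field not all ones, so no inf/NaN).
finite = [b for b in range(0x10000) if (b & 0x7F80) != 0x7F80]

# Bit-exact: the widened value re-encodes to the same pattern with zero low bits.
lossless = all(fp32_to_bits(bf16_bits_to_fp32(b)) == (b << 16) for b in finite)

# Order-preserving: consecutive non-negative bf16 patterns widen to strictly
# increasing fp32 values, so a top-k cutoff lands on the same element either way.
non_negative = [bf16_bits_to_fp32(b) for b in range(0x7F80)]
order_preserved = all(lo < hi for lo, hi in zip(non_negative, non_negative[1:]))

print(lossless, order_preserved)
```

Because widening adds no rounding, a tolerance term scaled by the bf16 epsilon has nothing to absorb on the promotion step, which is the reasoning behind restoring the default `tolerance = 1e-5`.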
Risks
shows an anomalous BS=200→256 latency spike (29 → 77 µs).
Production deployment is bounded by Scheme X dispatcher (
kBsLargederived from L2 capacity), so the cliff is not reached in
serving traffic — but flagged for follow-up perf work.
added in this PR. Prefill remains fp32-only.
🤖 Generated with Claude Code
Summary by CodeRabbit
Release Notes