[None][perf] DSv4 o_a_proj: enable fused inv-RoPE+FP8 BMM by default + optimize kernel#14254
Conversation
DSv4's o_a_proj BMM in `MLA._deepseek_v4_o_proj` is the only BMM that chains fused_inv_rope_fp8_quant_vllm_port -> cute_dsl_fp8_bmm_blackwell. Today both the init-time decision to keep o_a_proj in FP8 e4m3 (vs. allocating an o_a_proj_dequant bf16 buffer) and the runtime decision to take the fused inv-RoPE + cute-dsl FP8 BMM path are gated on `self.use_cute_dsl_blockscaling_bmm`. That coupling is incidental: - Only DSv4 has o_a_proj (DSv3 / V3 doesn't have this projection). - The flag is meant to choose between cute_dsl and the bf16-fallback for the K/V absorption BMMs (`k_b_proj_trans`, `v_b_proj`). DSv4's o_a_proj doesn't have a non-cute-dsl alternative that's worth using on SM100: without the fused inv-RoPE + FP8 quant + cute-dsl BMM chain, the o_a_proj path falls back to bf16 dequant + torch.bmm, which throws away the FP8 weight savings and adds an extra quant op. This PR removes `self.use_cute_dsl_blockscaling_bmm` from both gates so DSv4 unconditionally uses the FP8-native o_a_proj + fused inv-RoPE + cute-dsl FP8 BMM chain on SM100. The flag still controls K/V absorption BMM kernel selection for both DSv3 and DSv4 - that remains a separate concern and is unaffected by this change. Scope: - DSv4 + SM100 + FP8 block-scales: o_a_proj is now always FP8 native, always uses fused inv-RoPE + cute_dsl_fp8_bmm_blackwell. - DSv4 + SM100 with use_cute_dsl_blockscaling_bmm=True (today): same behavior; no perf change. - DSv4 + SM100 with use_cute_dsl_blockscaling_bmm=False (the default): the o_a_proj path switches from bf16-fallback to the fused FP8 chain. - DSv3 + SM100: untouched (still allocates o_a_proj_dequant; the is_deepseek_v4 gate skips the FP8-native branch). - All non-SM100 paths: untouched. Signed-off-by: Shicheng Li <shicli@nvidia.com>
…t + Triton V2.7 fallback
Promote the fused inv-RoPE + FP8 quant kernel to CuTe DSL by default on SM100,
with the existing Triton kernel staying as the fallback via
`TLLM_DISABLE_CUTE_DSL_FUSED_INV_ROPE=1`.
Optimizations applied to the kernel chain (DSv4-Pro DEP8 M=8192 microbench
on GB200, baseline = `mla_rope_inplace + fp8_batched_quantize_1x128_permute102`):
1. Triton V2.7: BLOCK_TOKENS_M dispatch (1/8/16/32 by M-threshold) + two
separate `@triton.jit` functions for BTM=1 vs BTM>1 (Triton recompiles
differently for the loop-body structure even with constexpr BTM=1, so
splitting avoids a small-M regression that single-kernel V2.x showed).
Per-block work tuned for memory-bound coalescing: num_warps=1 at BTM=1,
num_warps=2 + num_stages=2 at BTM>=8.
2. CuTe DSL V1.5 (new): mirrors Triton's BTM dispatch, plus register-
prefetched pipelining of the next token's LDG.E.128 while consuming the
current — matches Triton's `num_stages=2` schedule. 1 warp / CTA with
per-128 absmax via warp shuffle-bfly. Skips the heavyweight TMA-bulk
pipeline machinery (PipelineTmaAsync, warp-specialised
producer/consumer) because per-CTA tile (16 KB at BTM=16) fits in
registers without smem staging.
Fair-comparison microbench (GB200, lyris0105, n_groups=8, heads_per_group=16,
head_dim=512, quant_group=128, DeepGEMM-style bench: 500-iter warmup,
single-elapsed_time / N-iter mean):
M baseline Triton V2.7 CuTe DSL V1.5 Triton vs base CuTe vs base
64 70.6 us 91.4 us 74.1 us +29% slower +5% slower
256 74.5 us 74.7 us 66.1 us tied -11% faster
1024 205.0 us 74.6 us 65.7 us -64% faster -68% faster
2048 379.3 us 107.0 us 107.0 us -72% faster -72% faster
4096 752.8 us 206.5 us 206.4 us -73% faster -73% faster
8192 1496.1 us 395.8 us 395.7 us -74% faster -74% faster
At M >= 2048 both Triton V2.7 and CuTe DSL V1.5 saturate B200 HBM at
~4.11 TB/s (51% of 8 TB/s peak); CuTe DSL is 5-19% faster than Triton at
M <= 1024 thanks to BTM-prefetch.
Files:
* tensorrt_llm/_torch/cute_dsl_kernels/blackwell/fused_inv_rope_fp8_quant.py (new)
- Sm100FusedInvRopeFp8QuantKernel: BTM-templated kernel + register prefetch
- `_fused_inv_rope_fp8_quant_impl_cute_dsl`: host wrapper with JIT cache
keyed on (shape constants, BTM, dtype)
* tensorrt_llm/_torch/custom_ops/triton_fused_inv_rope_fp8_quant.py
- Triton V0 -> V2.7 (BLOCK_TOKENS_M dispatch + two-kernel split + nw=2)
- Dispatch in `_fused_inv_rope_fp8_quant_impl` now prefers CuTe DSL by
default; falls back to Triton when the cutlass DSL stack is absent or
`TLLM_DISABLE_CUTE_DSL_FUSED_INV_ROPE=1` is set
* tests/unittest/_torch/custom_ops/test_fused_inv_rope_fp8_quant.py
- Cover both the default (CuTe DSL) backend and the Triton fallback at
num_tokens in {3, 64, 257, 512, 1024, 2048, 8192}; both pass within
~6e-4 rel diff vs the legacy two-kernel reference path
Signed-off-by: Shicheng Li <shicli@nvidia.com>
… default; CuTe DSL opt-in
Follow-up to the previous commit. A 4-allocation median microbench on B200
(GB200, gb200 partition, lyris0035 / 0068 / 0136 / 0217 / 0053 / 0105 / 0179
nodes across all runs) shows the CuTe DSL backend and the optimized Triton
kernel are tied at every M:
M unoptimized optimized triton cute dsl opt triton/cute dsl
64 68.7 61.7 60.6 within ~3 us (noise)
256 61.1 58.7 57.8 within ~1 us (noise)
1024 74.2 64.5 64.6 tied (delta 0.1 us)
2048 142.0 105.3 105.3 tied (delta < 0.1 us)
4096 279.8 204.2 204.6 tied (delta 0.4 us)
8192 554.4 391.7 391.7 tied (delta < 0.1 us)
Earlier "CuTe DSL is 12-19% faster at small M" finding was a single-allocation
artifact -- at M <= 1024 the per-chip clock variance dominates over any real
kernel-quality difference. With 4 independent allocations the medians coincide.
Action: make optimized Triton the default backend; keep CuTe DSL in tree as
an opt-in alternative via `TLLM_USE_CUTE_DSL_FUSED_INV_ROPE=1`. CuTe DSL
remains useful as the scaffold future TMA / warp-spec V2 will build on; the
register-prefetch BTM>1 path already matches Triton, so the V2 framework
investment can pivot to TMA-bulk + warp-spec for compute-bound shapes.
Changes (all to existing files; no new files):
* tensorrt_llm/_torch/custom_ops/triton_fused_inv_rope_fp8_quant.py
- Default backend is Triton (the path inlined in this file).
- `_fused_inv_rope_fp8_quant_impl` selects CuTe DSL only when
`TLLM_USE_CUTE_DSL_FUSED_INV_ROPE=1` is set (was: default-on with
opt-out `TLLM_DISABLE_CUTE_DSL_FUSED_INV_ROPE=1`).
* tensorrt_llm/_torch/cute_dsl_kernels/blackwell/fused_inv_rope_fp8_quant.py
- Docstring header updated to reflect the opt-in env var name.
* tests/unittest/_torch/custom_ops/test_fused_inv_rope_fp8_quant.py
- Default test covers optimized Triton.
- Opt-in test renamed to `test_fused_inv_rope_fp8_quant_neox_cute_dsl`
and sets `TLLM_USE_CUTE_DSL_FUSED_INV_ROPE=1`.
Signed-off-by: Shicheng Li <shicli@nvidia.com>
The CuTe DSL backend added in commit 3815456 was never functional on cluster. It used a 2-dot relative import (`from ..cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE`) from `cute_dsl_kernels/blackwell/...` which resolves to `cute_dsl_kernels.cute_dsl_utils` — a path that does not exist. The dispatch's bare `except Exception:` silently swallowed the ImportError and set `_cute_dsl_backend = None`, so every invocation fell through to the Triton kernel. After fixing the import path to 3 dots (`from ...cute_dsl_utils`), the V1.5 kernel itself failed `cute.compile` on every M: - BTM=1 (M < 1024): opaque "ICE IR Verification Failed" inside the cutlass-DSL -> MLIR lowering. - BTM>1 (M >= 1024): "`pid_token` is None prior to this `for`, and update to Int32 inside of this `for` is not supported." The kernel's `pid_token = base_token + cutlass.Int32(m_in_block)` is assigned inside `cutlass.range(...)` without being initialised before it. Neither failure mode is a 10-minute fix. The BTM>1 case needs the kernel restructured to predeclare `pid_token` at function scope; the BTM=1 ICE needs a minimal repro to escalate to the DSL team. The optimized Triton kernel is the only functional backend either way. Dropping CuTe DSL keeps the PR honest: * tensorrt_llm/_torch/custom_ops/triton_fused_inv_rope_fp8_quant.py - Remove the CuTe DSL lazy-import block and the env-gated dispatch in `_fused_inv_rope_fp8_quant_impl`. - Remove the now-unused `import os`. * tensorrt_llm/_torch/cute_dsl_kernels/blackwell/fused_inv_rope_fp8_quant.py - Deleted. * tests/unittest/_torch/custom_ops/test_fused_inv_rope_fp8_quant.py - Remove the `test_fused_inv_rope_fp8_quant_neox_cute_dsl` test and the `backend_env` plumbing; only the default (Triton) path remains. Microbench numbers in the PR description are also updated — the previous "CuTe DSL" column was always the Triton kernel running on both sides (because of the silent fallback above), so the column provided no information and the "tied within 0.05 us" claim was trivially true (kernel vs itself). Description now compares Unoptimized vs Optimized Triton directly. A future PR can re-add a CuTe DSL backend once the V1.5 kernel actually compiles. Signed-off-by: Shicheng Li <shicli@nvidia.com>
|
/bot run --add-multi-gpu-test |
|
PR_Github #49297 [ run ] triggered by Bot. Commit: |
|
PR_Github #49297 [ run ] completed with state
|
…th BMM consumer The optimized multi-token-per-block (BTM > 4) path padded the scale buffer's m-dim to `pad_up(M, max(BTM, 4))`, but the downstream `cute_dsl_fp8_bmm_blackwell` hard-codes `sf_m = pad_up(m, 4)` for its scale m-stride. When `M` was not a multiple of `BTM` (e.g. M=1500 with BTM=16: producer stride 1504 vs consumer assumed stride 1500), the BMM read scales from drifting wrong physical offsets, silently corrupting dequant — dropping DSv4-Pro DEP8 gsm8k strict-match by 13.3 pts (96.59 -> 83.32) and flexible-extract by 6.8 pts (96.51 -> 89.69). TEP8 was also affected (~1 pt). Microbenches at powers-of-2 M missed it because those M are already multiples of BTM, so the producer/consumer strides happened to coincide. Fix: keep the scale buffer padded to `pad_up(num_tokens, 4)` regardless of BTM, matching the consumer's contract. The mblock kernel's grid may now overshoot `scale_buf_m` when `pad_up(M, 4) % BTM != 0`; a new `scale_buf_m` constexpr lets the inner loop bail cleanly on overshoot rows. Perf impact is < 2% at all production M (within run-to-run noise on a single allocation). Post-fix gsm8k (1319 questions, 5-shot): TEP8: flex 96.29 / strict 96.36 (was 95.22 / 95.22) DEP8: flex 96.44 / strict 96.51 (was 89.69 / 83.32, matches bf16 baseline) Signed-off-by: Shicheng Li <shicli@nvidia.com>
|
/bot run --add-multi-gpu-test |
|
PR_Github #49326 [ run ] triggered by Bot. Commit: |
|
PR_Github #49326 [ run ] completed with state
|
mingyangHao
left a comment
There was a problem hiding this comment.
This is a good implementation, I think we can merge this ATM.
|
/bot run --add-multi-gpu-test |
…gnment PR NVIDIA#14254 decoupled DSv4's FP8-native o_a_proj path from use_cute_dsl_blockscaling_bmm: on SM100, DSv4 unconditionally takes the fused inv-RoPE + FP8 quant + cute-dsl BMM chain and never falls back to bf16 dequant. As a side effect, MLA.__init__ no longer allocates self.o_a_proj_dequant for DSv4 (it stays None). The test was setting `mla.o_a_proj_dequant.data = o_a_proj_bf16` for the fp8 case, which now hits `AttributeError: 'NoneType' object has no attribute 'data'`. That assignment was a no-op: the reference computation reads `o_a_proj_bf16` directly (line 245), not via the mla buffer. Drop the line and leave a comment explaining the new invariant for future readers. Signed-off-by: Shicheng Li <shicli@nvidia.com>
|
PR_Github #49412 [ run ] triggered by Bot. Commit: |
|
/bot run --add-multi-gpu-test |
|
PR_Github #49414 [ run ] triggered by Bot. Commit: |
|
PR_Github #49412 [ run ] completed with state |
|
PR_Github #49414 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #49830 [ run ] triggered by Bot. Commit: |
|
PR_Github #49830 [ run ] completed with state
|

Description
This PR does two things on DSv4's
o_a_projMLA path.1. Enable fused inv-RoPE + FP8 BMM for
o_a_projby defaultRemoves
self.use_cute_dsl_blockscaling_bmmfrom the two gates inMLA._deepseek_v4_o_proj(init-time FP8 weight retention + runtime kernel selection) so DSv4 unconditionally uses the FP8-nativeo_a_proj+ fused inv-RoPE + cute-dsl FP8 BMM chain on SM100. The flag still controls K/V absorption BMM kernel selection (k_b_proj_trans,v_b_proj) for both DSv3 and DSv4 — that is a separate kernel-choice concern and is unaffected here.2. Optimize the fused inv-RoPE + FP8 quant kernel
Making the fused chain the default means the kernel now runs on every DSv4 deployment. The original (vLLM-port Triton) was tuned for a small-M GEN-side regime and underperformed at the CTX-side M (≥ 2048). This PR refreshes the kernel with the changes below:
@triton.jitfunctions, one each for BTM=1 and BTM>1 — Triton recompiles differently with an outer loop scaffold even at constexpr BTM=1, so physically splitting the two regimes avoids a small-M regression a single-kernel variant showed.num_warps/num_stagesretune —nw=1, ns=1at BTM=1;nw=2, ns=2at BTM ≥ 8. Right band for memory-bound block-scale quant kernels:nw=4either breaks per-128 coalescing (at BTM=1) or starves SM occupancy via register pressure (at BTM=16).tl.range(0, BTM, num_stages=2)software-pipelines the next iter's load with the current iter's compute and store.Microbench results
Measured on GB200 (B200, HBM3e, ~8 TB/s peak per GPU), DSv4-Pro DEP8 shape (
n_groups=8, heads_per_group=16, head_dim=512 (nope=448+rope=64), quant_group=128). DeepGEMM-style timing (500-iter warmup, singlecuda.Event-pair elapsed_time / N-iter mean, L2 flush pre-warmup, large fp32 matmul to warm host launch path). Median across 4 independent same-allocation legs on 4 different gb200 nodes to suppress per-chip clock-state variance.Both legs run the same
fused_inv_rope_fp8_quant_vllm_portop; only the BTM dispatch differs. Unoptimized = the original vLLM-port single-config BTM=1 kernel (the pre-PR baseline). Optimized = the BTM-dispatched + retuned Triton kernel from §2.At M ≥ 2048 (the CTX-side regime for DSv4-Pro ISL=8192 workloads) the optimization yields a structural 26-29 % win over the unoptimized vLLM-port. The optimized Triton kernel saturates B200 HBM at ~4.11 TB/s = 51 % of the 8 TB/s peak. The unoptimized BTM=1 kernel sits at ~2.9 TB/s = 37 % of peak — the gap closed by the BTM dispatch + retune is reproducible across all 4 allocations.
At M = 1024 (BTM=8 in the optimized dispatch) the optimization yields a 13 % win — also reproducible across allocations.
At M ≤ 256 both Unoptimized and Optimized dispatch the BTM=1 kernel — same code path, so any difference is measurement noise. The cross-allocation spread at this regime (kernel wall ≈ 60-90 µs) is ±15 µs.
ncu --set fullon the optimized kernel at M=8192 confirms the structural ceiling: 86.5 % Compute SM throughput / 46.0 % memory throughput / 31.0 % DRAM throughput, 60.1 % achieved occupancy at the 62.5 % theoretical max (register-limited),Warp Cycles Per Issued Instruction = 10.98. The kernel is compute-bound on per-warp instruction-issue with no occupancy headroom to hide load latency; further gains beyond this point would need either a kernel-structure change (e.g. TMA-bulk + warp-spec) or a host-side improvement.End-to-end accuracy (gsm8k)
Goal: confirm the PR's "FP8-native o_a_proj + fused inv-RoPE + cute-dsl FP8 BMM" chain — which this PR makes the unconditional default for DSv4 on SM100 — produces the same gsm8k accuracy as the pre-PR bf16-fallback path that ran whenever
use_cute_dsl_blockscaling_bmm=False(the default).Setup. DSv4-Pro on GB200, TP8 EP8,
max_batch=16,kv_cache_free_gpu_mem_fraction=0.7, greedy viatrtllm-eval gsm8k --apply_chat_template, 1319 questions, 5-shot. Each row is one full eval run.A/B comparison (DEP8, attention DP on — the path most sensitive to the scale-buf padding bug fixed on this branch):
mla_rope_inplace→ bf16torch.bmm(with bf16-dequanto_a_proj)fused_inv_rope_fp8_quant_vllm_port→cute_dsl_fp8_bmm_blackwell(FP8o_a_proj)Δ is within run-to-run noise (stderr ≈ 0.5 on each score). DEP8 sees no accuracy regression from switching to the FP8 chain.
TEP8 sanity check (attention DP off) — only the PR-on number measured since this config has never had the failure mode that triggered the DEP8 investigation:
fused_inv_rope_fp8_quant_vllm_port→cute_dsl_fp8_bmm_blackwell(FP8o_a_proj)Comparable to DEP8's PR-on row; matches expected DSv4-Pro gsm8k performance.
Test Coverage
tests/unittest/_torch/custom_ops/test_fused_inv_rope_fp8_quant.pycovers the fused op atnum_tokens ∈ {3, 64, 257, 512, 1024, 2048, 8192}. The op matches the legacy 2-kernel reference (mla_rope_inplace + fp8_batched_quantize_1x128_permute102) within ~6e-4 relative diff in dequantized BF16 space — well inside the 1 % bound the test asserts.tests/unittest/_torch/attention/sparse/deepseek_v4/test_deepseek_v4_o_proj.pyexercises the DSv4o_a_projBMM on SM100. With this PR, that test now hits the FP8-native + fused chain by default (no need to manually setuse_cute_dsl_blockscaling_bmm=True).o_a_projpath activates without setting the flag and produces the expected fused-chain device-time profile.PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.