[None][perf] Optimize DSv4 FP8 o_a_proj path#14446
Open
lfr-0531 wants to merge 8 commits into
Open
Conversation
DSv4's o_a_proj BMM in `MLA._deepseek_v4_o_proj` is the only BMM that chains fused_inv_rope_fp8_quant_vllm_port -> cute_dsl_fp8_bmm_blackwell. Today both the init-time decision to keep o_a_proj in FP8 e4m3 (vs. allocating an o_a_proj_dequant bf16 buffer) and the runtime decision to take the fused inv-RoPE + cute-dsl FP8 BMM path are gated on `self.use_cute_dsl_blockscaling_bmm`. That coupling is incidental: - Only DSv4 has o_a_proj (DSv3 / V3 doesn't have this projection). - The flag is meant to choose between cute_dsl and the bf16-fallback for the K/V absorption BMMs (`k_b_proj_trans`, `v_b_proj`). DSv4's o_a_proj doesn't have a non-cute-dsl alternative that's worth using on SM100: without the fused inv-RoPE + FP8 quant + cute-dsl BMM chain, the o_a_proj path falls back to bf16 dequant + torch.bmm, which throws away the FP8 weight savings and adds an extra quant op. This PR removes `self.use_cute_dsl_blockscaling_bmm` from both gates so DSv4 unconditionally uses the FP8-native o_a_proj + fused inv-RoPE + cute-dsl FP8 BMM chain on SM100. The flag still controls K/V absorption BMM kernel selection for both DSv3 and DSv4 - that remains a separate concern and is unaffected by this change. Scope: - DSv4 + SM100 + FP8 block-scales: o_a_proj is now always FP8 native, always uses fused inv-RoPE + cute_dsl_fp8_bmm_blackwell. - DSv4 + SM100 with use_cute_dsl_blockscaling_bmm=True (today): same behavior; no perf change. - DSv4 + SM100 with use_cute_dsl_blockscaling_bmm=False (the default): the o_a_proj path switches from bf16-fallback to the fused FP8 chain. - DSv3 + SM100: untouched (still allocates o_a_proj_dequant; the is_deepseek_v4 gate skips the FP8-native branch). - All non-SM100 paths: untouched. Signed-off-by: Shicheng Li <shicli@nvidia.com>
…t + Triton V2.7 fallback
Promote the fused inv-RoPE + FP8 quant kernel to CuTe DSL by default on SM100,
with the existing Triton kernel staying as the fallback via
`TLLM_DISABLE_CUTE_DSL_FUSED_INV_ROPE=1`.
Optimizations applied to the kernel chain (DSv4-Pro DEP8 M=8192 microbench
on GB200, baseline = `mla_rope_inplace + fp8_batched_quantize_1x128_permute102`):
1. Triton V2.7: BLOCK_TOKENS_M dispatch (1/8/16/32 by M-threshold) + two
separate `@triton.jit` functions for BTM=1 vs BTM>1 (Triton recompiles
differently for the loop-body structure even with constexpr BTM=1, so
splitting avoids a small-M regression that single-kernel V2.x showed).
Per-block work tuned for memory-bound coalescing: num_warps=1 at BTM=1,
num_warps=2 + num_stages=2 at BTM>=8.
2. CuTe DSL V1.5 (new): mirrors Triton's BTM dispatch, plus register-
prefetched pipelining of the next token's LDG.E.128 while consuming the
current — matches Triton's `num_stages=2` schedule. 1 warp / CTA with
per-128 absmax via warp shuffle-bfly. Skips the heavyweight TMA-bulk
pipeline machinery (PipelineTmaAsync, warp-specialised
producer/consumer) because per-CTA tile (16 KB at BTM=16) fits in
registers without smem staging.
Fair-comparison microbench (GB200, lyris0105, n_groups=8, heads_per_group=16,
head_dim=512, quant_group=128, DeepGEMM-style bench: 500-iter warmup,
single-elapsed_time / N-iter mean):
M baseline Triton V2.7 CuTe DSL V1.5 Triton vs base CuTe vs base
64 70.6 us 91.4 us 74.1 us +29% slower +5% slower
256 74.5 us 74.7 us 66.1 us tied -11% faster
1024 205.0 us 74.6 us 65.7 us -64% faster -68% faster
2048 379.3 us 107.0 us 107.0 us -72% faster -72% faster
4096 752.8 us 206.5 us 206.4 us -73% faster -73% faster
8192 1496.1 us 395.8 us 395.7 us -74% faster -74% faster
At M >= 2048 both Triton V2.7 and CuTe DSL V1.5 saturate B200 HBM at
~4.11 TB/s (51% of 8 TB/s peak); CuTe DSL is 5-19% faster than Triton at
M <= 1024 thanks to BTM-prefetch.
Files:
* tensorrt_llm/_torch/cute_dsl_kernels/blackwell/fused_inv_rope_fp8_quant.py (new)
- Sm100FusedInvRopeFp8QuantKernel: BTM-templated kernel + register prefetch
- `_fused_inv_rope_fp8_quant_impl_cute_dsl`: host wrapper with JIT cache
keyed on (shape constants, BTM, dtype)
* tensorrt_llm/_torch/custom_ops/triton_fused_inv_rope_fp8_quant.py
- Triton V0 -> V2.7 (BLOCK_TOKENS_M dispatch + two-kernel split + nw=2)
- Dispatch in `_fused_inv_rope_fp8_quant_impl` now prefers CuTe DSL by
default; falls back to Triton when the cutlass DSL stack is absent or
`TLLM_DISABLE_CUTE_DSL_FUSED_INV_ROPE=1` is set
* tests/unittest/_torch/custom_ops/test_fused_inv_rope_fp8_quant.py
- Cover both the default (CuTe DSL) backend and the Triton fallback at
num_tokens in {3, 64, 257, 512, 1024, 2048, 8192}; both pass within
~6e-4 rel diff vs the legacy two-kernel reference path
Signed-off-by: Shicheng Li <shicli@nvidia.com>
… default; CuTe DSL opt-in
Follow-up to the previous commit. A 4-allocation median microbench on B200
(GB200, gb200 partition, lyris0035 / 0068 / 0136 / 0217 / 0053 / 0105 / 0179
nodes across all runs) shows the CuTe DSL backend and the optimized Triton
kernel are tied at every M:
M unoptimized optimized triton cute dsl opt triton/cute dsl
64 68.7 61.7 60.6 within ~3 us (noise)
256 61.1 58.7 57.8 within ~1 us (noise)
1024 74.2 64.5 64.6 tied (delta 0.1 us)
2048 142.0 105.3 105.3 tied (delta < 0.1 us)
4096 279.8 204.2 204.6 tied (delta 0.4 us)
8192 554.4 391.7 391.7 tied (delta < 0.1 us)
Earlier "CuTe DSL is 12-19% faster at small M" finding was a single-allocation
artifact -- at M <= 1024 the per-chip clock variance dominates over any real
kernel-quality difference. With 4 independent allocations the medians coincide.
Action: make optimized Triton the default backend; keep CuTe DSL in tree as
an opt-in alternative via `TLLM_USE_CUTE_DSL_FUSED_INV_ROPE=1`. CuTe DSL
remains useful as the scaffold future TMA / warp-spec V2 will build on; the
register-prefetch BTM>1 path already matches Triton, so the V2 framework
investment can pivot to TMA-bulk + warp-spec for compute-bound shapes.
Changes (all to existing files; no new files):
* tensorrt_llm/_torch/custom_ops/triton_fused_inv_rope_fp8_quant.py
- Default backend is Triton (the path inlined in this file).
- `_fused_inv_rope_fp8_quant_impl` selects CuTe DSL only when
`TLLM_USE_CUTE_DSL_FUSED_INV_ROPE=1` is set (was: default-on with
opt-out `TLLM_DISABLE_CUTE_DSL_FUSED_INV_ROPE=1`).
* tensorrt_llm/_torch/cute_dsl_kernels/blackwell/fused_inv_rope_fp8_quant.py
- Docstring header updated to reflect the opt-in env var name.
* tests/unittest/_torch/custom_ops/test_fused_inv_rope_fp8_quant.py
- Default test covers optimized Triton.
- Opt-in test renamed to `test_fused_inv_rope_fp8_quant_neox_cute_dsl`
and sets `TLLM_USE_CUTE_DSL_FUSED_INV_ROPE=1`.
Signed-off-by: Shicheng Li <shicli@nvidia.com>
The CuTe DSL backend added in commit 3815456 was never functional on cluster. It used a 2-dot relative import (`from ..cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE`) from `cute_dsl_kernels/blackwell/...` which resolves to `cute_dsl_kernels.cute_dsl_utils` — a path that does not exist. The dispatch's bare `except Exception:` silently swallowed the ImportError and set `_cute_dsl_backend = None`, so every invocation fell through to the Triton kernel. After fixing the import path to 3 dots (`from ...cute_dsl_utils`), the V1.5 kernel itself failed `cute.compile` on every M: - BTM=1 (M < 1024): opaque "ICE IR Verification Failed" inside the cutlass-DSL -> MLIR lowering. - BTM>1 (M >= 1024): "`pid_token` is None prior to this `for`, and update to Int32 inside of this `for` is not supported." The kernel's `pid_token = base_token + cutlass.Int32(m_in_block)` is assigned inside `cutlass.range(...)` without being initialised before it. Neither failure mode is a 10-minute fix. The BTM>1 case needs the kernel restructured to predeclare `pid_token` at function scope; the BTM=1 ICE needs a minimal repro to escalate to the DSL team. The optimized Triton kernel is the only functional backend either way. Dropping CuTe DSL keeps the PR honest: * tensorrt_llm/_torch/custom_ops/triton_fused_inv_rope_fp8_quant.py - Remove the CuTe DSL lazy-import block and the env-gated dispatch in `_fused_inv_rope_fp8_quant_impl`. - Remove the now-unused `import os`. * tensorrt_llm/_torch/cute_dsl_kernels/blackwell/fused_inv_rope_fp8_quant.py - Deleted. * tests/unittest/_torch/custom_ops/test_fused_inv_rope_fp8_quant.py - Remove the `test_fused_inv_rope_fp8_quant_neox_cute_dsl` test and the `backend_env` plumbing; only the default (Triton) path remains. Microbench numbers in the PR description are also updated — the previous "CuTe DSL" column was always the Triton kernel running on both sides (because of the silent fallback above), so the column provided no information and the "tied within 0.05 us" claim was trivially true (kernel vs itself). Description now compares Unoptimized vs Optimized Triton directly. A future PR can re-add a CuTe DSL backend once the V1.5 kernel actually compiles. Signed-off-by: Shicheng Li <shicli@nvidia.com>
…th BMM consumer The optimized multi-token-per-block (BTM > 4) path padded the scale buffer's m-dim to `pad_up(M, max(BTM, 4))`, but the downstream `cute_dsl_fp8_bmm_blackwell` hard-codes `sf_m = pad_up(m, 4)` for its scale m-stride. When `M` was not a multiple of `BTM` (e.g. M=1500 with BTM=16: producer stride 1504 vs consumer assumed stride 1500), the BMM read scales from drifting wrong physical offsets, silently corrupting dequant — dropping DSv4-Pro DEP8 gsm8k strict-match by 13.3 pts (96.59 -> 83.32) and flexible-extract by 6.8 pts (96.51 -> 89.69). TEP8 was also affected (~1 pt). Microbenches at powers-of-2 M missed it because those M are already multiples of BTM, so the producer/consumer strides happened to coincide. Fix: keep the scale buffer padded to `pad_up(num_tokens, 4)` regardless of BTM, matching the consumer's contract. The mblock kernel's grid may now overshoot `scale_buf_m` when `pad_up(M, 4) % BTM != 0`; a new `scale_buf_m` constexpr lets the inner loop bail cleanly on overshoot rows. Perf impact is < 2% at all production M (within run-to-run noise on a single allocation). Post-fix gsm8k (1319 questions, 5-shot): TEP8: flex 96.29 / strict 96.36 (was 95.22 / 95.22) DEP8: flex 96.44 / strict 96.51 (was 89.69 / 83.32, matches bf16 baseline) Signed-off-by: Shicheng Li <shicli@nvidia.com>
…gnment PR NVIDIA#14254 decoupled DSv4's FP8-native o_a_proj path from use_cute_dsl_blockscaling_bmm: on SM100, DSv4 unconditionally takes the fused inv-RoPE + FP8 quant + cute-dsl BMM chain and never falls back to bf16 dequant. As a side effect, MLA.__init__ no longer allocates self.o_a_proj_dequant for DSv4 (it stays None). The test was setting `mla.o_a_proj_dequant.data = o_a_proj_bf16` for the fp8 case, which now hits `AttributeError: 'NoneType' object has no attribute 'data'`. That assignment was a no-op: the reference computation reads `o_a_proj_bf16` directly (line 245), not via the mla buffer. Drop the line and leave a comment explaining the new invariant for future readers. Signed-off-by: Shicheng Li <shicli@nvidia.com>
Signed-off-by: Fanrong Li <lfr-0531@users.noreply.github.com>
Collaborator
Author
|
/bot run --disable-fail-fast |
Collaborator
|
PR_Github #49881 [ run ] triggered by Bot. Commit: |
Collaborator
|
PR_Github #49881 [ run ] completed with state
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
@coderabbitai summary
Description
This PR pulls the original DeepSeek-V4 FP8 o_a_proj changes from #14254 onto the latest
feat/deepseek_v4branch and fixes the CuTe DSL API incompatibility seen in CI.Changes included:
o_a_projin native FP8 on SM100 and decouple it fromuse_cute_dsl_blockscaling_bmm.cute_dsl_fp8_bmm_blackwelldirectly.cute.arch.ProxyKind/SharedSpaceAPI tocute.arch.fence_view_async_shared().Test Coverage
pre-commit run --files <changed files>python -m py_compile $(rg --files tensorrt_llm/_torch/cute_dsl_kernels/blackwell -g '*.py')cute_dsl_fp8_bmm_blackwell; numerical diff vs BF16 reference was~6.8e-4(< 1e-3).PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.