Skip to content

[None][perf] Optimize DSv4 FP8 o_a_proj path#14446

Open
lfr-0531 wants to merge 8 commits into
NVIDIA:feat/deepseek_v4from
lfr-0531:user/fanrongl/fix-dsv4-cutedsl-api
Open

[None][perf] Optimize DSv4 FP8 o_a_proj path#14446
lfr-0531 wants to merge 8 commits into
NVIDIA:feat/deepseek_v4from
lfr-0531:user/fanrongl/fix-dsv4-cutedsl-api

Conversation

@lfr-0531
Copy link
Copy Markdown
Collaborator

@coderabbitai summary

Description

This PR pulls the original DeepSeek-V4 FP8 o_a_proj changes from #14254 onto the latest feat/deepseek_v4 branch and fixes the CuTe DSL API incompatibility seen in CI.

Changes included:

  • Keep DSv4 o_a_proj in native FP8 on SM100 and decouple it from use_cute_dsl_blockscaling_bmm.
  • Add the fused inverse-RoPE + FP8 quant path feeding cute_dsl_fp8_bmm_blackwell directly.
  • Keep the optimized Triton fused inverse-RoPE quant implementation as the default path and align its scale-buffer padding with the BMM consumer.
  • Update Blackwell CuTe DSL kernels from the removed cute.arch.ProxyKind / SharedSpace API to cute.arch.fence_view_async_shared().

Test Coverage

  • pre-commit run --files <changed files>
  • python -m py_compile $(rg --files tensorrt_llm/_torch/cute_dsl_kernels/blackwell -g '*.py')
  • B300 local minimal validation with current source: compile and run cute_dsl_fp8_bmm_blackwell; numerical diff vs BF16 reference was ~6.8e-4 (< 1e-3).

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

lishicheng1996-nv and others added 7 commits May 22, 2026 06:53
DSv4's o_a_proj BMM in `MLA._deepseek_v4_o_proj` is the only BMM that
chains fused_inv_rope_fp8_quant_vllm_port -> cute_dsl_fp8_bmm_blackwell.
Today both the init-time decision to keep o_a_proj in FP8 e4m3 (vs.
allocating an o_a_proj_dequant bf16 buffer) and the runtime decision to
take the fused inv-RoPE + cute-dsl FP8 BMM path are gated on
`self.use_cute_dsl_blockscaling_bmm`. That coupling is incidental:
- Only DSv4 has o_a_proj (DSv3 / V3 doesn't have this projection).
- The flag is meant to choose between cute_dsl and the bf16-fallback for
  the K/V absorption BMMs (`k_b_proj_trans`, `v_b_proj`). DSv4's o_a_proj
  doesn't have a non-cute-dsl alternative that's worth using on SM100:
  without the fused inv-RoPE + FP8 quant + cute-dsl BMM chain, the
  o_a_proj path falls back to bf16 dequant + torch.bmm, which throws
  away the FP8 weight savings and adds an extra quant op.

This PR removes `self.use_cute_dsl_blockscaling_bmm` from both gates so
DSv4 unconditionally uses the FP8-native o_a_proj + fused inv-RoPE +
cute-dsl FP8 BMM chain on SM100. The flag still controls K/V absorption
BMM kernel selection for both DSv3 and DSv4 - that remains a separate
concern and is unaffected by this change.

Scope:
- DSv4 + SM100 + FP8 block-scales: o_a_proj is now always FP8 native,
  always uses fused inv-RoPE + cute_dsl_fp8_bmm_blackwell.
- DSv4 + SM100 with use_cute_dsl_blockscaling_bmm=True (today): same
  behavior; no perf change.
- DSv4 + SM100 with use_cute_dsl_blockscaling_bmm=False (the default):
  the o_a_proj path switches from bf16-fallback to the fused FP8 chain.
- DSv3 + SM100: untouched (still allocates o_a_proj_dequant; the
  is_deepseek_v4 gate skips the FP8-native branch).
- All non-SM100 paths: untouched.

Signed-off-by: Shicheng Li <shicli@nvidia.com>
…t + Triton V2.7 fallback

Promote the fused inv-RoPE + FP8 quant kernel to CuTe DSL by default on SM100,
with the existing Triton kernel staying as the fallback via
`TLLM_DISABLE_CUTE_DSL_FUSED_INV_ROPE=1`.

Optimizations applied to the kernel chain (DSv4-Pro DEP8 M=8192 microbench
on GB200, baseline = `mla_rope_inplace + fp8_batched_quantize_1x128_permute102`):

 1. Triton V2.7: BLOCK_TOKENS_M dispatch (1/8/16/32 by M-threshold) + two
    separate `@triton.jit` functions for BTM=1 vs BTM>1 (Triton recompiles
    differently for the loop-body structure even with constexpr BTM=1, so
    splitting avoids a small-M regression that single-kernel V2.x showed).
    Per-block work tuned for memory-bound coalescing: num_warps=1 at BTM=1,
    num_warps=2 + num_stages=2 at BTM>=8.

 2. CuTe DSL V1.5 (new): mirrors Triton's BTM dispatch, plus register-
    prefetched pipelining of the next token's LDG.E.128 while consuming the
    current — matches Triton's `num_stages=2` schedule. 1 warp / CTA with
    per-128 absmax via warp shuffle-bfly. Skips the heavyweight TMA-bulk
    pipeline machinery (PipelineTmaAsync, warp-specialised
    producer/consumer) because per-CTA tile (16 KB at BTM=16) fits in
    registers without smem staging.

Fair-comparison microbench (GB200, lyris0105, n_groups=8, heads_per_group=16,
head_dim=512, quant_group=128, DeepGEMM-style bench: 500-iter warmup,
single-elapsed_time / N-iter mean):

 M     baseline    Triton V2.7   CuTe DSL V1.5   Triton vs base   CuTe vs base
 64    70.6 us     91.4 us       74.1 us         +29% slower      +5% slower
 256   74.5 us     74.7 us       66.1 us         tied             -11% faster
 1024  205.0 us    74.6 us       65.7 us         -64% faster      -68% faster
 2048  379.3 us    107.0 us      107.0 us        -72% faster      -72% faster
 4096  752.8 us    206.5 us      206.4 us        -73% faster      -73% faster
 8192  1496.1 us   395.8 us      395.7 us        -74% faster      -74% faster

At M >= 2048 both Triton V2.7 and CuTe DSL V1.5 saturate B200 HBM at
~4.11 TB/s (51% of 8 TB/s peak); CuTe DSL is 5-19% faster than Triton at
M <= 1024 thanks to BTM-prefetch.

Files:

 * tensorrt_llm/_torch/cute_dsl_kernels/blackwell/fused_inv_rope_fp8_quant.py (new)
   - Sm100FusedInvRopeFp8QuantKernel: BTM-templated kernel + register prefetch
   - `_fused_inv_rope_fp8_quant_impl_cute_dsl`: host wrapper with JIT cache
     keyed on (shape constants, BTM, dtype)
 * tensorrt_llm/_torch/custom_ops/triton_fused_inv_rope_fp8_quant.py
   - Triton V0 -> V2.7 (BLOCK_TOKENS_M dispatch + two-kernel split + nw=2)
   - Dispatch in `_fused_inv_rope_fp8_quant_impl` now prefers CuTe DSL by
     default; falls back to Triton when the cutlass DSL stack is absent or
     `TLLM_DISABLE_CUTE_DSL_FUSED_INV_ROPE=1` is set
 * tests/unittest/_torch/custom_ops/test_fused_inv_rope_fp8_quant.py
   - Cover both the default (CuTe DSL) backend and the Triton fallback at
     num_tokens in {3, 64, 257, 512, 1024, 2048, 8192}; both pass within
     ~6e-4 rel diff vs the legacy two-kernel reference path

Signed-off-by: Shicheng Li <shicli@nvidia.com>
… default; CuTe DSL opt-in

Follow-up to the previous commit. A 4-allocation median microbench on B200
(GB200, gb200 partition, lyris0035 / 0068 / 0136 / 0217 / 0053 / 0105 / 0179
nodes across all runs) shows the CuTe DSL backend and the optimized Triton
kernel are tied at every M:

  M     unoptimized   optimized triton   cute dsl   opt triton/cute dsl
  64       68.7              61.7           60.6    within ~3 us (noise)
  256      61.1              58.7           57.8    within ~1 us (noise)
  1024     74.2              64.5           64.6    tied (delta 0.1 us)
  2048     142.0             105.3          105.3   tied (delta < 0.1 us)
  4096     279.8             204.2          204.6   tied (delta 0.4 us)
  8192     554.4             391.7          391.7   tied (delta < 0.1 us)

Earlier "CuTe DSL is 12-19% faster at small M" finding was a single-allocation
artifact -- at M <= 1024 the per-chip clock variance dominates over any real
kernel-quality difference. With 4 independent allocations the medians coincide.

Action: make optimized Triton the default backend; keep CuTe DSL in tree as
an opt-in alternative via `TLLM_USE_CUTE_DSL_FUSED_INV_ROPE=1`. CuTe DSL
remains useful as the scaffold future TMA / warp-spec V2 will build on; the
register-prefetch BTM>1 path already matches Triton, so the V2 framework
investment can pivot to TMA-bulk + warp-spec for compute-bound shapes.

Changes (all to existing files; no new files):

 * tensorrt_llm/_torch/custom_ops/triton_fused_inv_rope_fp8_quant.py
   - Default backend is Triton (the path inlined in this file).
   - `_fused_inv_rope_fp8_quant_impl` selects CuTe DSL only when
     `TLLM_USE_CUTE_DSL_FUSED_INV_ROPE=1` is set (was: default-on with
     opt-out `TLLM_DISABLE_CUTE_DSL_FUSED_INV_ROPE=1`).
 * tensorrt_llm/_torch/cute_dsl_kernels/blackwell/fused_inv_rope_fp8_quant.py
   - Docstring header updated to reflect the opt-in env var name.
 * tests/unittest/_torch/custom_ops/test_fused_inv_rope_fp8_quant.py
   - Default test covers optimized Triton.
   - Opt-in test renamed to `test_fused_inv_rope_fp8_quant_neox_cute_dsl`
     and sets `TLLM_USE_CUTE_DSL_FUSED_INV_ROPE=1`.

Signed-off-by: Shicheng Li <shicli@nvidia.com>
The CuTe DSL backend added in commit 3815456 was never functional on
cluster. It used a 2-dot relative import (`from ..cute_dsl_utils import
IS_CUTLASS_DSL_AVAILABLE`) from `cute_dsl_kernels/blackwell/...` which
resolves to `cute_dsl_kernels.cute_dsl_utils` — a path that does not
exist. The dispatch's bare `except Exception:` silently swallowed the
ImportError and set `_cute_dsl_backend = None`, so every invocation
fell through to the Triton kernel.

After fixing the import path to 3 dots (`from ...cute_dsl_utils`),
the V1.5 kernel itself failed `cute.compile` on every M:

  - BTM=1 (M < 1024): opaque "ICE IR Verification Failed" inside the
    cutlass-DSL -> MLIR lowering.
  - BTM>1 (M >= 1024): "`pid_token` is None prior to this `for`, and
    update to Int32 inside of this `for` is not supported." The
    kernel's `pid_token = base_token + cutlass.Int32(m_in_block)` is
    assigned inside `cutlass.range(...)` without being initialised
    before it.

Neither failure mode is a 10-minute fix. The BTM>1 case needs the
kernel restructured to predeclare `pid_token` at function scope; the
BTM=1 ICE needs a minimal repro to escalate to the DSL team.

The optimized Triton kernel is the only functional backend either
way. Dropping CuTe DSL keeps the PR honest:

 * tensorrt_llm/_torch/custom_ops/triton_fused_inv_rope_fp8_quant.py
   - Remove the CuTe DSL lazy-import block and the env-gated dispatch
     in `_fused_inv_rope_fp8_quant_impl`.
   - Remove the now-unused `import os`.

 * tensorrt_llm/_torch/cute_dsl_kernels/blackwell/fused_inv_rope_fp8_quant.py
   - Deleted.

 * tests/unittest/_torch/custom_ops/test_fused_inv_rope_fp8_quant.py
   - Remove the `test_fused_inv_rope_fp8_quant_neox_cute_dsl` test
     and the `backend_env` plumbing; only the default (Triton) path
     remains.

Microbench numbers in the PR description are also updated — the
previous "CuTe DSL" column was always the Triton kernel running on
both sides (because of the silent fallback above), so the column
provided no information and the "tied within 0.05 us" claim was
trivially true (kernel vs itself). Description now compares
Unoptimized vs Optimized Triton directly.

A future PR can re-add a CuTe DSL backend once the V1.5 kernel
actually compiles.

Signed-off-by: Shicheng Li <shicli@nvidia.com>
…th BMM consumer

The optimized multi-token-per-block (BTM > 4) path padded the scale buffer's
m-dim to `pad_up(M, max(BTM, 4))`, but the downstream
`cute_dsl_fp8_bmm_blackwell` hard-codes `sf_m = pad_up(m, 4)` for its scale
m-stride. When `M` was not a multiple of `BTM` (e.g. M=1500 with BTM=16:
producer stride 1504 vs consumer assumed stride 1500), the BMM read scales
from drifting wrong physical offsets, silently corrupting dequant — dropping
DSv4-Pro DEP8 gsm8k strict-match by 13.3 pts (96.59 -> 83.32) and
flexible-extract by 6.8 pts (96.51 -> 89.69). TEP8 was also affected
(~1 pt). Microbenches at powers-of-2 M missed it because those M are already
multiples of BTM, so the producer/consumer strides happened to coincide.

Fix: keep the scale buffer padded to `pad_up(num_tokens, 4)` regardless of
BTM, matching the consumer's contract. The mblock kernel's grid may now
overshoot `scale_buf_m` when `pad_up(M, 4) % BTM != 0`; a new `scale_buf_m`
constexpr lets the inner loop bail cleanly on overshoot rows. Perf impact
is < 2% at all production M (within run-to-run noise on a single allocation).

Post-fix gsm8k (1319 questions, 5-shot):
  TEP8: flex 96.29 / strict 96.36  (was 95.22 / 95.22)
  DEP8: flex 96.44 / strict 96.51  (was 89.69 / 83.32, matches bf16 baseline)

Signed-off-by: Shicheng Li <shicli@nvidia.com>
…gnment

PR NVIDIA#14254 decoupled DSv4's FP8-native o_a_proj path from
use_cute_dsl_blockscaling_bmm: on SM100, DSv4 unconditionally takes the
fused inv-RoPE + FP8 quant + cute-dsl BMM chain and never falls back to
bf16 dequant. As a side effect, MLA.__init__ no longer allocates
self.o_a_proj_dequant for DSv4 (it stays None).

The test was setting `mla.o_a_proj_dequant.data = o_a_proj_bf16` for the
fp8 case, which now hits `AttributeError: 'NoneType' object has no
attribute 'data'`. That assignment was a no-op: the reference
computation reads `o_a_proj_bf16` directly (line 245), not via the
mla buffer. Drop the line and leave a comment explaining the new
invariant for future readers.

Signed-off-by: Shicheng Li <shicli@nvidia.com>
Signed-off-by: Fanrong Li <lfr-0531@users.noreply.github.com>
@lfr-0531 lfr-0531 requested review from a team as code owners May 22, 2026 07:11
@lfr-0531 lfr-0531 requested review from HuiGao-NV, PerkzZheng and hyukn and removed request for a team May 22, 2026 07:11
@lfr-0531 lfr-0531 requested review from lishicheng1996-nv and mingyangHao and removed request for HuiGao-NV and PerkzZheng May 22, 2026 07:12
@lfr-0531
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #49881 [ run ] triggered by Bot. Commit: cfe6f04 Link to invocation

@lfr-0531 lfr-0531 requested a review from xxi-nv May 22, 2026 07:20
@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #49881 [ run ] completed with state SUCCESS. Commit: cfe6f04
/LLM/main/L0_MergeRequest_PR pipeline #39465 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants