[None][fix] Plumb swiglu_limit through DeepGEMM and TRTLLMGen FP8 fused MoE#13767
Conversation
|
/bot run |
|
PR_Github #46816 [ run ] triggered by Bot. Commit: |
c81e578 to
f2ecdbf
Compare
|
PR_Github #46816 [ run ] completed with state
|
ac32c9c to
e07f48b
Compare
|
/bot run |
|
PR_Github #46889 [ run ] triggered by Bot. Commit: |
|
PR_Github #46889 [ run ] completed with state
|
…sed MoE
Forward an optional per-expert swiglu_limit through both fused-MoE FP8
paths so DeepSeek-V4-Flash-Base (FP8 block-scale on Blackwell) actually
applies its config-declared gate/up clamp on routed experts, matching
the swiglu_torch reference. Existing callers that pass no limit are
unaffected: the Triton kernel guards on HAS_SWIGLU_LIMIT, and the CUDA
kernels guard on a null swigluLimitPtr.
DeepGEMM Triton path (used by WIDEEP and DEEPGEMM moe_backends):
- silu_and_mul_masked_post_quant_fwd accepts an optional fp32 [g]
tensor; the kernel applies gate.clamp(max=limit) and
up.clamp(-limit, limit) before silu/mul.
- WideEPMoE / DeepGemmFusedMoE __init__ accept swiglu_limit and
propagate it to the underlying op via self.swiglu_limit.
- create_moe.py allow-list includes WideEPMoE / DeepGemmFusedMoE.
- DeepseekV4MoE supports_swiglu_limit set extends to those classes.
TRTLLMGen FP8 path (run_fp8_block_scale_moe):
- C++ binding accepts an optional gemm1_clamp_limit tensor of shape
[local_num_experts]; setOpsData forwards it to
activation::Data::swigluLimitPtr.
- Both activationKernel and activationDeepSeekKernel apply the clamp
after dequantization, before silu/mul. FP8 path treats the limit
as uniform across experts (reads index [0]); the per-expert tensor
shape is preserved for API symmetry with the NVFP4 path.
- fp8_block_scale_moe_runner custom op grows the kwarg; the autotuner
input list shifts replacement indices accordingly.
- moe_op_backend TRTLLM impl forwards the new kwarg; Flashinfer impl
raises NotImplementedError until its wrapper exposes the param.
- _check_configs is split: bias/alpha/beta still gate to NVFP4/MXFP4
(they need the fused-GEMM activation cubins); swiglu_limit also
accepts FP8 block-scale via the separate-activation kernel.
Empirical V4-Flash-Base GSM8K (8x B200, TP=8 EP=8, lm-eval 5-shot):
WIDEEP no clamp (silently dropped): 91.17
WIDEEP -> DeepGEMM op, clamp on: 92.23 (+1.06 abs, ~1.4 sigma)
TRTLLMGen FP8, clamp off: 92.23
TRTLLMGen FP8, clamp on (fixed kernel): 92.23 (no shift; FC1 rarely
trips +-10 in this path)
Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
Parametrize TestDeepSeekV4FlashBase::test_auto_dtype on moe_backend (WIDEEP, TRTLLM) and switch the 4xB300 pre_merge entry from the WIDEEP-hardcoded form to the TRTLLM variant. TRTLLMGen FP8 is the user-facing default on Blackwell (model_config.py::resolve_moe_backend) and ~7% faster per step than WIDEEP in our V4-Flash-Base GSM8K runs; holding the WIDEEP variant out of CI for now (still selectable manually). Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
Drop the per-element fp32 round-trip and per-CTA / per-program global
load that the FP8 swiglu_limit plumbing introduced. swiglu_limit is
uniform across experts on the FP8 paths (V4-Flash-Base config), so the
per-expert tensor was redundant; lifting it to a scalar value lets the
clamp run in native dtype and gets baked into the kernel.
DeepGEMM Triton kernel (silu_and_mul_masked_post_quant_fwd):
- Drop the bf16 -> fp32 -> clamp -> bf16 round-trip on `up`. Clamp
uses tl.cast(SWIGLU_LIMIT, input dtype); for V4 (limit ~7) this is
bf16-exact, so semantics are preserved.
- Replace the swiglu_limit_ptr argument with SWIGLU_LIMIT: tl.constexpr
(Python float baked into the JIT). Removes one global load per
program and lets the limit constant-fold.
- Wrapper now takes Optional[float] instead of Optional[Tensor].
TRTLLMGen FP8 separate-activation kernels (DevKernel.cu):
- Replace activation::Data::swigluLimitPtr with scalar swigluLimit +
hasSwigluLimit. Eliminates the per-CTA fp32 global load.
- Plumb a scalar gemm1_clamp_limit_value + has_gemm1_clamp_limit_value
through MoERunnerArgs. The pre-existing gemm1_clamp_limit pointer
is kept for NVFP4 / MXFP4 fused-activation cubins (which genuinely
consume per-expert limits via fc31_alpha rescaling).
- fp8BlockScaleMoe.cpp binding takes Optional<double>.
Autotuner (trtllm_gen_custom_ops.py):
- Drop gemm1_clamp_limit from the FP8 runner's input_tensors_for_tuner.
The limit doesn't influence tactic validity, so it was just
fragmenting the cache key. Pass through the runner constructor
instead.
MoE plumbing:
- Add swiglu_limit_scalar to MoE base + thread through create_moe,
DeepseekV4MoE, ConfigurableMoE, CutlassFusedMoE / DeepGemmFusedMoE /
WideEPMoE / TRTLLMGenFusedMoE constructors. FP8 paths read
self.swiglu_limit_scalar; NVFP4 paths still use self.swiglu_limit
(per-expert tensor) unchanged.
Validation (V4-Flash-Base, 8x B200, TP=8 EP=8, lm-eval gsm8k 5-shot):
WIDEEP backend (DeepGEMM Triton path, clamp load-bearing):
92.57 +/- 0.72 (vs reference 92.23 with clamp on; 91.17 with
clamp silently dropped). Clamp is correctly
applied through the optimized Triton kernel.
TRTLLM backend (TRTLLMGen FP8 separate-activation path):
90.90 +/- 0.79 (vs reference 92.23; within 95% CI. Clamp is a
no-op semantically in this path per the parent
commit's measurements.)
Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
DeepSeek-V4-Flash-Base is not yet on the CI shared model store (/scratch.trt_llm_data/llm-models/DeepSeek-V4-Flash-Base does not exist on B300 CI workers). When the path is missing, _ModelWrapper.__post_init__ leaves model as a string, is_local_model returns False, and LLM(...) falls into the HF download branch where snapshot_download rejects the slash-bearing path with HFValidationError. Comment out the test until the model is uploaded; leave a TODO pointing at the missing path. Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
…per fix TestDeepSeekV4Flash::test_nvfp4_4gpus_static_eplb[moe_backend=TRTLLM] fails on B300 with AttributeError: 'str' object has no attribute 'value' out of cuda/bindings/driver.pyx. The throw originates from kv_cache_manager_v2/_exceptions.py:49 calling drv.cuGetErrorString on a plain Python string instead of a CUresult enum, so the worker's actual init CUDA error is unreadable. Skip until that wrapper is fixed and the underlying NVFP4 + EPLB init failure can be diagnosed. Unrelated to the swiglu_limit FP8 path; surfaces independently in the NVFP4 fused-activation cubin path. Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
e07f48b to
865c8b6
Compare
|
/bot run |
|
PR_Github #46924 [ run ] triggered by Bot. Commit: |
|
PR_Github #46924 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #46973 [ run ] triggered by Bot. Commit: |
|
PR_Github #46973 [ run ] completed with state |
…ed MoE (NVIDIA#13767) Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> (cherry picked from commit 1a52b72) Signed-off-by: Yuhang He <58161490+heyuhhh@users.noreply.github.com> Signed-off-by: Fanrong Li <lfr-0531@users.noreply.github.com> (cherry picked from commit 7a9b0ca) Signed-off-by: Fanrong Li <lfr-0531@users.noreply.github.com>
Summary
Plumb optional per-expert
swiglu_limitthrough both Blackwell FP8 fused-MoE paths so DeepSeek-V4-Flash-Base's config-declared clamp is actually applied to routed experts. No-clamp callers are byte-identical (TritonHAS_SWIGLU_LIMITconstexpr; CUDA nullswigluLimitPtr).Files changed
DeepGEMM Triton path:
tensorrt_llm/quantization/utils/fp8_utils.pytensorrt_llm/_torch/modules/fused_moe/ops/moe_op_deepgemm.pytensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.pytensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.pytensorrt_llm/_torch/modules/fused_moe/create_moe.pytensorrt_llm/_torch/models/modeling_deepseekv4.pyTRTLLMGen FP8 path:
cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.hcpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.cucpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cucpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpptensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.pytensorrt_llm/_torch/modules/fused_moe/moe_op_backend.pytensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.pyAccuracy
V4-Flash-Base FP8, 8x B200, TP=8 EP=8, lm-eval gsm8k 5-shot, 1319 samples: