diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc
index 34f2ffd9a55..1874a9fc030 100644
--- a/custom_ops/gpu_ops/cpp_extensions.cc
+++ b/custom_ops/gpu_ops/cpp_extensions.cc
@@ -691,6 +691,10 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
                                     bool renormalize,
                                     float routed_scaling_factor);
 
+std::vector<paddle::Tensor> FusedCastSigmoidBias(const paddle::Tensor& input,
+                                                 const paddle::Tensor& bias,
+                                                 std::string cast_type);
+
 std::vector<paddle::Tensor> NoauxTcRedundant(
     paddle::Tensor& scores,
     paddle::Tensor& scores_with_bias,
@@ -1699,6 +1703,13 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
 
   m.def("noaux_tc", &NoauxTc, "noaux_tc for Deepseekv3 MoE compute");
 
+  m.def("fused_cast_sigmoid_bias",
+        &FusedCastSigmoidBias,
+        "Fused cast+sigmoid+bias for MoE gating scores",
+        py::arg("input"),
+        py::arg("bias"),
+        py::arg("cast_type") = std::string("float32"));
+
   m.def("noaux_tc_redundant",
         &NoauxTcRedundant,
         "noaux_tc_redundant for MoE compute");
diff --git a/custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu b/custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu
new file mode 100644
index 00000000000..f25084076c4
--- /dev/null
+++ b/custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu
@@ -0,0 +1,206 @@
+// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "helper.h"
+
+// Fused kernel: cast(input, cast_type) -> sigmoid -> scores,
+// scores + bias -> scores_with_bias
+//
+// For each element (token i, expert j):
+//   scores[i][j]           = OutT(sigmoid(float(input[i][j])))
+//   scores_with_bias[i][j] = OutT(sigmoid(float(input[i][j])) + bias[j])
+//
+// Input:  input  [num_tokens, num_experts]            bf16/fp16/fp32
+//         bias   [num_experts] or [1, num_experts]    fp32
+// Output: scores           [num_tokens, num_experts]  cast_type (fp32/fp16/bf16)
+//         scores_with_bias [num_tokens, num_experts]  cast_type (fp32/fp16/bf16)
+//
+// Precision guarantee:
+// All intermediate computations (cast, sigmoid, bias addition) are performed
+// in float32, regardless of input/output types. The cast to OutT only happens
+// at the final store. This matches the reference implementation:
+//   gate_fp32 = gate_out.cast("float32")
+//   scores_fp32 = sigmoid(gate_fp32)
+//   scores_with_bias_fp32 = scores_fp32 + bias  // bias is always float32
+//   scores = scores_fp32.cast(cast_type)
+//   scores_with_bias = scores_with_bias_fp32.cast(cast_type)
+//
+// When cast_type is "float32", the fused kernel is numerically identical to
+// the reference. For fp16/bf16 output, the only precision loss comes from
+// the final static_cast, equivalent to .cast() in the reference path.
+//
+// Note: bias is intentionally kept as float32 (not converted to OutT) to
+// ensure the addition s + bias[j] is always computed in full float32
+// precision before the final downcast.
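+//
+// Worked example (illustrative values, not taken from the test suite), for
+// a single element with input[i][j] = 1.0 and bias[j] = 0.25:
+//   s = 1.0f / (1.0f + expf(-1.0f)) ~= 0.731059f
+//   scores[i][j]           = OutT(0.731059f)                       // 0.731059 for fp32
+//   scores_with_bias[i][j] = OutT(0.731059f + 0.25f) = OutT(0.981059f)
+// With cast_type == "float32" both stores are exact; with "float16" or
+// "bfloat16" the two float32 values are rounded exactly once, at the store.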
+
+template <typename InT, typename OutT>
+__global__ void fused_cast_sigmoid_bias_kernel(
+    const InT* __restrict__ input,
+    const float* __restrict__ bias,
+    OutT* __restrict__ scores,
+    OutT* __restrict__ scores_with_bias,
+    const int num_experts) {
+  const int64_t token_idx = blockIdx.x;
+  const int64_t offset = token_idx * num_experts;
+
+  for (int j = threadIdx.x; j < num_experts; j += blockDim.x) {
+    // All intermediate computation in float32 for precision
+    float val = static_cast<float>(input[offset + j]);
+    float s = 1.0f / (1.0f + expf(-val));
+    // s (float32) + bias[j] (float32) -> float32 addition, then downcast
+    scores[offset + j] = static_cast<OutT>(s);
+    scores_with_bias[offset + j] = static_cast<OutT>(s + bias[j]);
+  }
+}
+
+// Vectorized version for better memory throughput
+template <typename InT, typename OutT, int kVecSize>
+__global__ void fused_cast_sigmoid_bias_vec_kernel(
+    const InT* __restrict__ input,
+    const float* __restrict__ bias,  // kept as float32 for full-precision add
+    OutT* __restrict__ scores,
+    OutT* __restrict__ scores_with_bias,
+    const int num_experts) {
+  const int64_t token_idx = blockIdx.x;
+  const int64_t offset = token_idx * num_experts;
+
+  using in_vec_t = AlignedVector<InT, kVecSize>;
+  using out_vec_t = AlignedVector<OutT, kVecSize>;
+  using bias_vec_t = AlignedVector<float, kVecSize>;  // float32 bias vectors
+
+  const int vec_count = num_experts / kVecSize;
+  for (int idx = threadIdx.x; idx < vec_count; idx += blockDim.x) {
+    const int base = idx * kVecSize;
+    in_vec_t in_vec;
+    bias_vec_t bias_vec;
+    Load<InT, kVecSize>(input + offset + base, &in_vec);
+    Load<float, kVecSize>(bias + base, &bias_vec);
+
+    out_vec_t s_vec, sb_vec;
+#pragma unroll
+    for (int i = 0; i < kVecSize; ++i) {
+      // All intermediate computation in float32 for precision
+      float val = static_cast<float>(in_vec[i]);
+      float s = 1.0f / (1.0f + expf(-val));
+      // s (float32) + bias_vec[i] (float32) -> float32 addition, then downcast
+      s_vec[i] = static_cast<OutT>(s);
+      sb_vec[i] = static_cast<OutT>(s + bias_vec[i]);
+    }
+
+    Store<OutT, kVecSize>(s_vec, scores + offset + base);
+    Store<OutT, kVecSize>(sb_vec, scores_with_bias + offset + base);
+  }
+
+  // Handle remaining elements (same float32 precision guarantee)
+  const int remaining_start = vec_count * kVecSize;
+  for (int j = remaining_start + threadIdx.x; j < num_experts;
+       j += blockDim.x) {
+    float val = static_cast<float>(input[offset + j]);
+    float s = 1.0f / (1.0f + expf(-val));
+    scores[offset + j] = static_cast<OutT>(s);
+    scores_with_bias[offset + j] = static_cast<OutT>(s + bias[j]);
+  }
+}
+
+static paddle::DataType ParseCastType(const std::string& cast_type) {
+  if (cast_type == "float32") return paddle::DataType::FLOAT32;
+  if (cast_type == "float16") return paddle::DataType::FLOAT16;
+  if (cast_type == "bfloat16") return paddle::DataType::BFLOAT16;
+  PD_THROW("Unsupported cast_type: " + cast_type +
+           ". Only float32, float16, bfloat16 are supported.");
+}
+
+std::vector<paddle::Tensor> FusedCastSigmoidBias(const paddle::Tensor& input,
+                                                 const paddle::Tensor& bias,
+                                                 std::string cast_type) {
+  auto input_shape = input.shape();
+  PD_CHECK(input_shape.size() == 2,
+           "input must be 2D [num_tokens, num_experts]");
+  auto bias_shape = bias.shape();
+  // Support both [num_experts] and [1, num_experts] bias shapes
+  PD_CHECK(
+      bias_shape.size() == 1 || (bias_shape.size() == 2 && bias_shape[0] == 1),
+      "bias must be 1D [num_experts] or 2D [1, num_experts]");
+
+  int64_t num_tokens = input_shape[0];
+  int64_t num_experts = input_shape[1];
+  int64_t bias_numel = (bias_shape.size() == 1) ? bias_shape[0] : bias_shape[1];
+  PD_CHECK(bias_numel == num_experts, "bias size must match num_experts");
+  PD_CHECK(bias.dtype() == paddle::DataType::FLOAT32,
+           "bias must be float32, got ",
+           bias.dtype());
+
+  auto place = input.place();
+  auto stream = input.stream();
+  auto out_dtype = ParseCastType(cast_type);
+
+  auto scores = paddle::empty({num_tokens, num_experts}, out_dtype, place);
+  auto scores_with_bias =
+      paddle::empty({num_tokens, num_experts}, out_dtype, place);
+
+  if (num_tokens == 0) {
+    return {scores, scores_with_bias};
+  }
+
+  dim3 grid(num_tokens);
+  int block_size = std::min(static_cast<int64_t>(1024), num_experts);
+  // Round up to warp size
+  block_size = ((block_size + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
+  dim3 block(block_size);
+
+  DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), in_scalar_t, {
+    DISPATCH_FLOAT_FP6_DTYPE(out_dtype, out_scalar_t, {
+      constexpr int kVecSize = 16 / sizeof(in_scalar_t);
+      if (num_experts % kVecSize == 0 && num_experts >= kVecSize) {
+        fused_cast_sigmoid_bias_vec_kernel<in_scalar_t, out_scalar_t, kVecSize>
+            <<<grid, block, 0, stream>>>(input.data<in_scalar_t>(),
+                                         bias.data<float>(),
+                                         scores.data<out_scalar_t>(),
+                                         scores_with_bias.data<out_scalar_t>(),
+                                         num_experts);
+      } else {
+        fused_cast_sigmoid_bias_kernel<in_scalar_t, out_scalar_t>
+            <<<grid, block, 0, stream>>>(input.data<in_scalar_t>(),
+                                         bias.data<float>(),
+                                         scores.data<out_scalar_t>(),
+                                         scores_with_bias.data<out_scalar_t>(),
+                                         num_experts);
+      }
+    });
+  });
+
+  return {scores, scores_with_bias};
+}
+
+std::vector<paddle::DataType> FusedCastSigmoidBiasInferDtype(
+    const paddle::DataType& input_dtype,
+    const paddle::DataType& bias_dtype,
+    std::string cast_type) {
+  auto out_dtype = ParseCastType(cast_type);
+  return {out_dtype, out_dtype};
+}
+
+std::vector<std::vector<int64_t>> FusedCastSigmoidBiasInferShape(
+    const std::vector<int64_t>& input_shape,
+    const std::vector<int64_t>& bias_shape) {
+  return {input_shape, input_shape};
+}
+
+PD_BUILD_STATIC_OP(fused_cast_sigmoid_bias)
+    .Inputs({"input", "bias"})
+    .Outputs({"scores", "scores_with_bias"})
+    .Attrs({"cast_type: std::string"})
+    .SetKernelFn(PD_KERNEL(FusedCastSigmoidBias))
+    .SetInferShapeFn(PD_INFER_SHAPE(FusedCastSigmoidBiasInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(FusedCastSigmoidBiasInferDtype));
diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py
index 4339bdec028..7b1bda32510 100644
--- a/custom_ops/setup_ops.py
+++ b/custom_ops/setup_ops.py
@@ -331,6 +331,7 @@ def find_end_files(directory, end_str):
         "gpu_ops/fused_rotary_position_encoding.cu",
         "gpu_ops/noaux_tc.cu",
         "gpu_ops/noaux_tc_redundant.cu",
+        "gpu_ops/fused_cast_sigmoid_bias.cu",
         "gpu_ops/custom_all_reduce/all_reduce.cu",
         "gpu_ops/merge_prefill_decode_output.cu",
         "gpu_ops/limit_thinking_content_length.cu",
@@ -686,6 +687,7 @@ def find_end_files(directory, end_str):
         "gpu_ops/recover_decode_task.cu",
         "gpu_ops/noaux_tc.cu",
         "gpu_ops/noaux_tc_redundant.cu",
+        "gpu_ops/fused_cast_sigmoid_bias.cu",
         "gpu_ops/fused_rotary_position_encoding.cu",
         "gpu_ops/text_image_gather_scatter.cu",
         "gpu_ops/text_image_index_out.cu",
diff --git a/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py b/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py
new file mode 100644
index 00000000000..44d7e54ae88
--- /dev/null
+++ b/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py
@@ -0,0 +1,73 @@
+"""
+# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import paddle + +_FUSED_CAST_SIGMOID_BIAS_IMPORT_ERROR = None + +try: + from fastdeploy.model_executor.ops.gpu import ( + fused_cast_sigmoid_bias as _fused_cast_sigmoid_bias_cuda, + ) +except ImportError as e: + _fused_cast_sigmoid_bias_cuda = None + _FUSED_CAST_SIGMOID_BIAS_IMPORT_ERROR = e + + +def is_available() -> bool: + """Return whether the fused GPU custom op is available.""" + return _fused_cast_sigmoid_bias_cuda is not None + + +def fused_cast_sigmoid_bias( + gate_out: paddle.Tensor, + e_score_correction_bias: paddle.Tensor, + cast_type: str = "float32", +) -> tuple: + """ + Fused operation: cast gate_out to the specified type, apply sigmoid, and add bias. + + This function fuses the following three separate operations: + 1. gate_out = gate_out.cast(cast_type) + 2. scores = sigmoid(gate_out) + 3. scores_with_bias = scores + e_score_correction_bias + + Args: + gate_out: [num_tokens, num_experts], bf16/fp16/fp32 dtype - raw gate output + e_score_correction_bias: [num_experts], fp32 dtype - correction bias + cast_type: output dtype string, supports "float32", "float16", "bfloat16" + + Returns: + scores: [num_tokens, num_experts], cast_type dtype - result of sigmoid(gate_out) + scores_with_bias: [num_tokens, num_experts], cast_type dtype - scores with bias added + + Precision: + All intermediate computations (cast, sigmoid, bias addition) are performed + in float32 precision; conversion to cast_type happens only at the final store. + When cast_type is "float32", the result is bit-exact with the following + reference implementation: + gate_fp32 = gate_out.cast("float32") + scores = sigmoid(gate_fp32) + scores_with_bias = scores + bias + When cast_type is "float16"/"bfloat16", the only precision loss comes from + the final type conversion, equivalent to calling .cast(cast_type) after + computing in float32. + """ + if _fused_cast_sigmoid_bias_cuda is None: + raise ImportError( + "fused_cast_sigmoid_bias is not available. " "Please ensure the GPU custom ops are compiled." + ) from _FUSED_CAST_SIGMOID_BIAS_IMPORT_ERROR + return _fused_cast_sigmoid_bias_cuda(gate_out, e_score_correction_bias, cast_type) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 760a23734ab..e9e71e9e930 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -346,13 +346,13 @@ def apply_tp( Paddle Cutlass compute Fused MoE. 
""" gate_out = gate(x) - gate_out = gate_out.cast("float32") - - if fc1_latent_proj is not None: - x = fc1_latent_proj(x) - if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16": if layer.topk_method == "noaux_tc": + use_fused = not fastdeploy.envs.FD_ENABLE_RL and current_platform.is_cuda() and not fc1_latent_proj + if not use_fused: + gate_out = gate_out.cast("float32") + if fc1_latent_proj is not None: + x = fc1_latent_proj(x) gate_out, topk_weights, topk_idx = get_moe_scores( gate_out, layer.n_group, @@ -361,8 +361,12 @@ def apply_tp( layer.routed_scaling_factor, layer.gate_correction_bias, getattr(layer, "renormalize", True), + use_fused_cast=use_fused, ) else: + gate_out = gate_out.cast("float32") + if fc1_latent_proj is not None: + x = fc1_latent_proj(x) topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( gate_out, layer.gate_correction_bias, @@ -415,6 +419,11 @@ def apply_tp( return fused_moe_out if layer.topk_method == "noaux_tc": + use_fused = not fastdeploy.envs.FD_ENABLE_RL and current_platform.is_cuda() and not fc1_latent_proj + if not use_fused: + gate_out = gate_out.cast("float32") + if fc1_latent_proj is not None: + x = fc1_latent_proj(x) gate_out, topk_weights, topk_idx = get_moe_scores( gate_out, layer.n_group, @@ -424,6 +433,7 @@ def apply_tp( layer.gate_correction_bias, getattr(layer, "renormalize", True), topk_reduce_func=getattr(layer, "topk_reduce_func", None), + use_fused_cast=use_fused, ) ( @@ -448,6 +458,9 @@ def apply_tp( topk_only_mode=True, ) else: + gate_out = gate_out.cast("float32") + if fc1_latent_proj is not None: + x = fc1_latent_proj(x) ( permute_input, token_nums_per_expert, diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py index cdaa66678fd..22d9588fb67 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py @@ -761,9 +761,11 @@ def apply_tp( below is TP compute method. 
""" gate_out = gate(x) - gate_out = gate_out.cast("float32") if layer.topk_method == "noaux_tc": + use_fused = not fastdeploy.envs.FD_ENABLE_RL and current_platform.is_cuda() + if not use_fused: + gate_out = gate_out.cast("float32") _, topk_weights, topk_ids = fastdeploy.model_executor.layers.moe.moe.get_moe_scores( gate_out, layer.n_group, @@ -773,8 +775,10 @@ def apply_tp( layer.gate_correction_bias, getattr(layer, "renormalize", True), topk_reduce_func=getattr(layer, "topk_reduce_func", None), + use_fused_cast=use_fused, ) else: + gate_out = gate_out.cast("float32") topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( gate_out, layer.gate_correction_bias, diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py index 20e84449fae..ed85b02547f 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py @@ -29,6 +29,7 @@ set_weight_attrs, weight_fully_copied, ) +from fastdeploy.platforms import current_platform from fastdeploy.utils import ceil_div, register_custom_python_op from ..quantization.quant_base import QuantMethodBase @@ -302,7 +303,6 @@ def apply( if token_num == 0: return paddle.zeros([token_num, layer.hidden_size], dtype=x.dtype) gate_out = gate(x) - gate_out = gate_out.cast("float32") top_k = layer.top_k num_local_experts = layer.num_local_experts top_k = layer.top_k @@ -310,6 +310,9 @@ def apply( hidden_size = layer.hidden_size if layer.topk_method == "noaux_tc": + use_fused = not fastdeploy.envs.FD_ENABLE_RL and current_platform.is_cuda() + if not use_fused: + gate_out = gate_out.cast("float32") gate_out, topk_weights, topk_ids = get_moe_scores( gate_out, layer.n_group, @@ -318,8 +321,10 @@ def apply( layer.routed_scaling_factor, layer.gate_correction_bias, getattr(layer, "renormalize", True), + use_fused_cast=use_fused, ) else: + gate_out = gate_out.cast("float32") topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( gate_out, layer.gate_correction_bias, diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 4aa23f67931..2c158478a1d 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -39,8 +39,14 @@ from fastdeploy.model_executor.ops.gpu import noaux_tc, noaux_tc_redundant except: logger.warning("import noaux_tc Failed!") + import numpy as np +if current_platform.is_cuda(): + from fastdeploy.model_executor.layers.moe.fused_cast_sigmoid_bias import ( + fused_cast_sigmoid_bias, + ) + def get_moe_method(layer=None): """ @@ -91,13 +97,17 @@ def get_moe_scores( tokens_per_expert_stats_list: paddle.Tensor = None, redundant_ep_rank_num_plus_one: int = 1, topk_reduce_func: Callable = lambda x: x.sum(axis=-1, keepdim=True) + 1e-20, + use_fused_cast: bool = False, ) -> paddle.Tensor: """ compute moe scores using e_score_correction_bias. """ - scores = paddle.nn.functional.sigmoid(gating_output) assert e_score_correction_bias is not None, "e_score_correction_bias is none!" 
-    scores_with_bias = scores + e_score_correction_bias
+    if use_fused_cast and current_platform.is_cuda():
+        scores, scores_with_bias = fused_cast_sigmoid_bias(gating_output, e_score_correction_bias, cast_type="float32")
+    else:
+        scores = paddle.nn.functional.sigmoid(gating_output)
+        scores_with_bias = scores + e_score_correction_bias
 
     if envs.FD_USE_PHI_MOE_TOPK:
         # calculate renormalize and routed_scaling_factor value outside the noaux_tc
diff --git a/tests/layers/test_deepgemm_fused_moe.py b/tests/layers/test_deepgemm_fused_moe.py
index 5381ee866a3..66910544756 100644
--- a/tests/layers/test_deepgemm_fused_moe.py
+++ b/tests/layers/test_deepgemm_fused_moe.py
@@ -205,6 +205,27 @@ def hook(topk_ids):
         assert "topk_ids" in captured
         assert list(out.shape) == [NUM_TOKENS, HIDDEN_SIZE]
 
+    @requires_deepgemm
+    def test_apply_tp_noaux_tc_with_use_fused_false(self):
+        """noaux_tc path with FD_ENABLE_RL=True: triggers use_fused=False and gate_out.cast('float32')."""
+        layer = DummyLayer()
+        layer.topk_method = "noaux_tc"
+        gate = DummyGate(layer.num_local_experts)
+        method = _make_method()
+
+        x = paddle.randn([NUM_TOKENS, HIDDEN_SIZE], dtype="bfloat16")
+
+        import fastdeploy.envs as fd_envs
+
+        original_fd_enable_rl = fd_envs.FD_ENABLE_RL
+        fd_envs.FD_ENABLE_RL = True
+
+        try:
+            out = method.apply(layer, x, gate)
+            assert list(out.shape) == [NUM_TOKENS, HIDDEN_SIZE]
+        finally:
+            fd_envs.FD_ENABLE_RL = original_fd_enable_rl
+
     @requires_deepgemm
     def test_apply_tp_aux_path(self):
         """Non-noaux_tc: moe_topk_select → fp8_quant_blockwise → moe_permute → deepgemm → moe_unpermute."""
diff --git a/tests/layers/test_fused_cast_sigmoid_bias.py b/tests/layers/test_fused_cast_sigmoid_bias.py
new file mode 100644
index 00000000000..21bfb0901fd
--- /dev/null
+++ b/tests/layers/test_fused_cast_sigmoid_bias.py
@@ -0,0 +1,497 @@
+"""
+# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" + +import importlib +import os +import sys +from unittest import mock + +import paddle +import paddle.nn.functional as F +import pytest + +from fastdeploy.model_executor.layers.moe.fused_cast_sigmoid_bias import ( + fused_cast_sigmoid_bias, + is_available, +) + +DTYPE_MAP = { + "float16": paddle.float16, + "bfloat16": paddle.bfloat16, + "float32": paddle.float32, +} + + +def _ensure_gpu_test_environment(): + """Ensure GPU runtime and required custom ops are available for this test module.""" + if not paddle.is_compiled_with_cuda(): + pytest.skip( + "fused_cast_sigmoid_bias requires CUDA-enabled Paddle.", + allow_module_level=True, + ) + paddle.set_device("gpu") + + +_ensure_gpu_test_environment() + + +def reference_cast_sigmoid_bias(gate_out, bias, cast_type="float32"): + """Reference implementation: compute in fp32, cast output to cast_type.""" + gate_fp32 = gate_out.cast("float32") + scores_fp32 = F.sigmoid(gate_fp32) + scores_with_bias_fp32 = scores_fp32 + bias + scores = scores_fp32.cast(cast_type) + scores_with_bias = scores_with_bias_fp32.cast(cast_type) + return scores, scores_with_bias + + +def test_functionality(): + """Test basic functionality: correct shapes and dtypes (default cast_type=float32).""" + print("=" * 60) + print("Test 1: Functionality (default cast_type=float32)") + print("=" * 60) + + for dtype_name in ["float16", "bfloat16", "float32"]: + for num_tokens in [1, 7, 128, 1024]: + for num_experts in [8, 64, 128, 256]: + gate_out = paddle.randn([num_tokens, num_experts], dtype=dtype_name) + bias = paddle.randn([num_experts], dtype="float32") + + scores, scores_with_bias = fused_cast_sigmoid_bias(gate_out, bias) + + assert scores.shape == [ + num_tokens, + num_experts, + ], f"scores shape mismatch: {scores.shape} vs {[num_tokens, num_experts]}" + assert scores_with_bias.shape == [ + num_tokens, + num_experts, + ], f"scores_with_bias shape mismatch: {scores_with_bias.shape}" + assert scores.dtype == paddle.float32, f"scores dtype mismatch: {scores.dtype}" + assert ( + scores_with_bias.dtype == paddle.float32 + ), f"scores_with_bias dtype mismatch: {scores_with_bias.dtype}" + + # Sigmoid output should be in [0, 1] + assert bool(paddle.all(scores >= 0.0).item()) and bool( + paddle.all(scores <= 1.0).item() + ), "scores out of [0,1] range" + print(f" [PASS] dtype={dtype_name}") + + print(" All functionality tests passed.\n") + + +def test_functionality_cast_types(): + """Test functionality with different cast_type values.""" + print("=" * 60) + print("Test 1b: Functionality with different cast_type") + print("=" * 60) + + for input_dtype in ["float16", "bfloat16", "float32"]: + for cast_type in ["float16", "bfloat16", "float32"]: + expected_paddle_dtype = DTYPE_MAP[cast_type] + for num_tokens in [1, 64, 256]: + for num_experts in [8, 64, 256]: + gate_out = paddle.randn([num_tokens, num_experts], dtype=input_dtype) + bias = paddle.randn([num_experts], dtype="float32") + + scores, scores_with_bias = fused_cast_sigmoid_bias(gate_out, bias, cast_type) + + assert scores.shape == [num_tokens, num_experts], f"scores shape mismatch: {scores.shape}" + assert scores_with_bias.shape == [ + num_tokens, + num_experts, + ], f"scores_with_bias shape mismatch: {scores_with_bias.shape}" + assert ( + scores.dtype == expected_paddle_dtype + ), f"scores dtype mismatch: got {scores.dtype}, expected {expected_paddle_dtype}" + assert ( + scores_with_bias.dtype == expected_paddle_dtype + ), f"scores_with_bias dtype mismatch: got {scores_with_bias.dtype}, expected {expected_paddle_dtype}" + + 
print(f" [PASS] input_dtype={input_dtype}, cast_type={cast_type}") + + print(" All cast_type functionality tests passed.\n") + + +def test_accuracy(): + """Test numerical accuracy against reference implementation (default cast_type=float32).""" + print("=" * 60) + print("Test 2: Accuracy (default cast_type=float32)") + print("=" * 60) + + test_cases = [ + ("float16", 1, 8), + ("float16", 128, 256), + ("float16", 1024, 256), + ("bfloat16", 1, 8), + ("bfloat16", 128, 256), + ("bfloat16", 1024, 256), + ("float32", 1, 8), + ("float32", 128, 256), + ("float32", 1024, 256), + ] + + for dtype_name, num_tokens, num_experts in test_cases: + gate_out = paddle.randn([num_tokens, num_experts], dtype=dtype_name) + bias = paddle.randn([num_experts], dtype="float32") + + # Fused kernel + fused_scores, fused_scores_with_bias = fused_cast_sigmoid_bias(gate_out, bias) + + # Reference + ref_scores, ref_scores_with_bias = reference_cast_sigmoid_bias(gate_out, bias) + + # Compare + scores_diff = paddle.abs(fused_scores - ref_scores).max().item() + scores_bias_diff = paddle.abs(fused_scores_with_bias - ref_scores_with_bias).max().item() + + atol = 1e-6 if dtype_name == "float32" else 1e-3 + passed = scores_diff < atol and scores_bias_diff < atol + + status = "PASS" if passed else "FAIL" + print( + f" [{status}] dtype={dtype_name}, tokens={num_tokens}, experts={num_experts} | " + f"scores_max_diff={scores_diff:.2e}, scores_with_bias_max_diff={scores_bias_diff:.2e}" + ) + + if not passed: + raise AssertionError( + f"Accuracy test failed for dtype={dtype_name}, tokens={num_tokens}, experts={num_experts}. " + f"scores_diff={scores_diff}, scores_bias_diff={scores_bias_diff}, atol={atol}" + ) + + print(" All accuracy tests passed.\n") + + +def test_accuracy_cast_types(): + """Test numerical accuracy with different cast_type values.""" + print("=" * 60) + print("Test 2b: Accuracy with different cast_type") + print("=" * 60) + + # (input_dtype, cast_type, num_tokens, num_experts) + test_cases = [ + # cast to float32 (original behavior) + ("float16", "float32", 128, 256), + ("bfloat16", "float32", 128, 256), + ("float32", "float32", 128, 256), + # cast to float16 + ("float16", "float16", 128, 256), + ("bfloat16", "float16", 128, 256), + ("float32", "float16", 128, 256), + # cast to bfloat16 + ("float16", "bfloat16", 128, 256), + ("bfloat16", "bfloat16", 128, 256), + ("float32", "bfloat16", 128, 256), + # different shapes + ("bfloat16", "float16", 1, 8), + ("bfloat16", "float16", 1024, 256), + ("float16", "bfloat16", 1, 8), + ("float16", "bfloat16", 1024, 256), + ] + + for input_dtype, cast_type, num_tokens, num_experts in test_cases: + gate_out = paddle.randn([num_tokens, num_experts], dtype=input_dtype) + bias = paddle.randn([num_experts], dtype="float32") + + # Fused kernel + fused_scores, fused_scores_with_bias = fused_cast_sigmoid_bias(gate_out, bias, cast_type) + + # Reference + ref_scores, ref_scores_with_bias = reference_cast_sigmoid_bias(gate_out, bias, cast_type) + + # Compare in float32 for stable diff computation + scores_diff = paddle.abs(fused_scores.cast("float32") - ref_scores.cast("float32")).max().item() + scores_bias_diff = ( + paddle.abs(fused_scores_with_bias.cast("float32") - ref_scores_with_bias.cast("float32")).max().item() + ) + + # Tolerance depends on cast_type precision + if cast_type == "float32": + atol = 1e-6 + elif cast_type == "bfloat16": + atol = 1e-2 # bfloat16 has fewer mantissa bits + else: # float16 + atol = 1e-3 + + passed = scores_diff < atol and scores_bias_diff < atol + + status 
= "PASS" if passed else "FAIL" + print( + f" [{status}] input={input_dtype}, cast_type={cast_type}, " + f"tokens={num_tokens}, experts={num_experts} | " + f"scores_diff={scores_diff:.2e}, bias_diff={scores_bias_diff:.2e}" + ) + + if not passed: + raise AssertionError( + f"Accuracy test failed for input={input_dtype}, cast_type={cast_type}, " + f"tokens={num_tokens}, experts={num_experts}. " + f"scores_diff={scores_diff}, bias_diff={scores_bias_diff}, atol={atol}" + ) + + print(" All cast_type accuracy tests passed.\n") + + +def test_accuracy_extreme_values(): + """Test accuracy with extreme input values.""" + print("=" * 60) + print("Test 3: Accuracy with extreme values") + print("=" * 60) + + num_tokens, num_experts = 64, 256 + + for dtype_name in ["float16", "bfloat16"]: + # Large positive values -> sigmoid ~ 1.0 + gate_out = paddle.full([num_tokens, num_experts], 10.0, dtype=dtype_name) + bias = paddle.zeros([num_experts], dtype="float32") + fused_scores, _ = fused_cast_sigmoid_bias(gate_out, bias) + ref_scores, _ = reference_cast_sigmoid_bias(gate_out, bias) + diff = paddle.abs(fused_scores - ref_scores).max().item() + print(f" [{'PASS' if diff < 1e-5 else 'FAIL'}] dtype={dtype_name}, large positive: max_diff={diff:.2e}") + + # Large negative values -> sigmoid ~ 0.0 + gate_out = paddle.full([num_tokens, num_experts], -10.0, dtype=dtype_name) + fused_scores, _ = fused_cast_sigmoid_bias(gate_out, bias) + ref_scores, _ = reference_cast_sigmoid_bias(gate_out, bias) + diff = paddle.abs(fused_scores - ref_scores).max().item() + print(f" [{'PASS' if diff < 1e-5 else 'FAIL'}] dtype={dtype_name}, large negative: max_diff={diff:.2e}") + + # Zero values -> sigmoid = 0.5 + gate_out = paddle.zeros([num_tokens, num_experts], dtype=dtype_name) + fused_scores, _ = fused_cast_sigmoid_bias(gate_out, bias) + ref_scores, _ = reference_cast_sigmoid_bias(gate_out, bias) + diff = paddle.abs(fused_scores - ref_scores).max().item() + assert diff < 1e-6, f"Zero input test failed: diff={diff}" + print(f" [PASS] dtype={dtype_name}, zeros: max_diff={diff:.2e}") + + print(" All extreme value tests passed.\n") + + +def test_accuracy_extreme_values_cast_types(): + """Test accuracy with extreme values across different cast_type values.""" + print("=" * 60) + print("Test 3b: Accuracy with extreme values + different cast_type") + print("=" * 60) + + num_tokens, num_experts = 64, 256 + + for input_dtype in ["float16", "bfloat16"]: + for cast_type in ["float16", "bfloat16", "float32"]: + bias = paddle.zeros([num_experts], dtype="float32") + + # Large positive + gate_out = paddle.full([num_tokens, num_experts], 10.0, dtype=input_dtype) + fused_scores, _ = fused_cast_sigmoid_bias(gate_out, bias, cast_type) + ref_scores, _ = reference_cast_sigmoid_bias(gate_out, bias, cast_type) + diff = paddle.abs(fused_scores.cast("float32") - ref_scores.cast("float32")).max().item() + atol = 1e-2 if cast_type == "bfloat16" else 1e-5 + status = "PASS" if diff < atol else "FAIL" + print(f" [{status}] input={input_dtype}, cast={cast_type}, " f"large positive: diff={diff:.2e}") + + # Zero values + gate_out = paddle.zeros([num_tokens, num_experts], dtype=input_dtype) + fused_scores, _ = fused_cast_sigmoid_bias(gate_out, bias, cast_type) + ref_scores, _ = reference_cast_sigmoid_bias(gate_out, bias, cast_type) + diff = paddle.abs(fused_scores.cast("float32") - ref_scores.cast("float32")).max().item() + atol = 1e-2 if cast_type == "bfloat16" else 1e-5 + assert diff < atol, f"Zero input test failed: input={input_dtype}, cast={cast_type}, 
diff={diff}" + print(f" [PASS] input={input_dtype}, cast={cast_type}, " f"zeros: diff={diff:.2e}") + + print(" All extreme value cast_type tests passed.\n") + + +@pytest.mark.skipif( + os.getenv("RUN_PERFORMANCE_TESTS") != "1", + reason="Performance benchmark is disabled by default. Set RUN_PERFORMANCE_TESTS=1 to enable.", +) +def test_performance(): + """Benchmark fused kernel vs reference implementation using CUDA events.""" + print("=" * 60) + print("Test 4: Performance (CUDA event timing)") + print("=" * 60) + + configs = [ + ("bfloat16", 1, 256), # single token decode + ("bfloat16", 8, 256), # small batch decode + ("bfloat16", 64, 256), # medium batch + ("bfloat16", 256, 256), # typical DeepSeek-V3 config + ("bfloat16", 1024, 256), # large prefill + ("bfloat16", 4096, 256), # very large prefill + ] + + warmup_iters = 100 + bench_iters = 500 + + for dtype_name, num_tokens, num_experts in configs: + gate_out = paddle.randn([num_tokens, num_experts], dtype=dtype_name) + bias = paddle.randn([num_experts], dtype="float32") + + # Warmup fused + for _ in range(warmup_iters): + fused_cast_sigmoid_bias(gate_out, bias) + paddle.device.synchronize() + + # Benchmark fused with CUDA events + start_event = paddle.device.cuda.Event(enable_timing=True) + end_event = paddle.device.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(bench_iters): + fused_cast_sigmoid_bias(gate_out, bias) + end_event.record() + paddle.device.synchronize() + fused_time = start_event.elapsed_time(end_event) / bench_iters * 1e3 # us + + # Warmup reference + for _ in range(warmup_iters): + reference_cast_sigmoid_bias(gate_out, bias) + paddle.device.synchronize() + + # Benchmark reference with CUDA events + start_event = paddle.device.cuda.Event(enable_timing=True) + end_event = paddle.device.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(bench_iters): + reference_cast_sigmoid_bias(gate_out, bias) + end_event.record() + paddle.device.synchronize() + ref_time = start_event.elapsed_time(end_event) / bench_iters * 1e3 # us + + speedup = ref_time / fused_time if fused_time > 0 else float("inf") + print( + f" tokens={num_tokens:5d}, experts={num_experts:3d} | " + f"ref={ref_time:8.1f}us, fused={fused_time:8.1f}us, speedup={speedup:.2f}x" + ) + + print() + print(" Note: The CUDA custom op fuses cast+sigmoid+bias into a single kernel,") + print(" eliminating 2 intermediate tensors and reducing kernel launches from 3 to 1.") + print(" Expected speedup: ~3x over the reference 3-op implementation.") + print(" Performance benchmark complete.\n") + + +def test_is_available(): + """Test is_available() function returns True when GPU ops are available.""" + print("=" * 60) + print("Test: is_available()") + print("=" * 60) + + # In normal GPU test environment, is_available should return True + result = is_available() + assert isinstance(result, bool), f"is_available() should return bool, got {type(result)}" + assert result is True, f"is_available() should return True when GPU ops are compiled, got {result}" + print(f" [PASS] is_available() returned {result}") + print(" is_available() test passed.\n") + + +def test_import_error(): + """Test that ImportError is raised when GPU ops are not available.""" + print("=" * 60) + print("Test 5: Import error handling") + print("=" * 60) + + module_name = "fastdeploy.model_executor.layers.moe.fused_cast_sigmoid_bias" + gpu_ops_module = "fastdeploy.model_executor.ops.gpu" + + # Save original module references + original_module = sys.modules.pop(module_name, 
None) + original_gpu_ops = sys.modules.get(gpu_ops_module) + + try: + # Mock the GPU ops module to raise ImportError on import + with mock.patch.dict(sys.modules, {gpu_ops_module: None}): + # Re-import the module so it picks up the mocked (missing) GPU ops + reloaded = importlib.import_module(module_name) + importlib.reload(reloaded) + + # The module should load successfully, but calling the function + # should raise ImportError because the cuda op is unavailable. + dummy_gate = paddle.randn([1, 8], dtype="float32") + dummy_bias = paddle.randn([8], dtype="float32") + try: + reloaded.fused_cast_sigmoid_bias(dummy_gate, dummy_bias) + raise AssertionError("Expected ImportError was not raised") + except ImportError as e: + assert "fused_cast_sigmoid_bias is not available" in str(e), f"Unexpected error message: {e}" + print(f" [PASS] ImportError raised with correct message: {e}") + finally: + # Restore original modules + sys.modules.pop(module_name, None) + if original_module is not None: + sys.modules[module_name] = original_module + if original_gpu_ops is not None: + sys.modules[gpu_ops_module] = original_gpu_ops + + print(" Import error handling test passed.\n") + + +def test_is_available_when_ops_unavailable(): + """Test is_available() returns False when GPU ops are not available.""" + print("=" * 60) + print("Test: is_available() when ops unavailable") + print("=" * 60) + + module_name = "fastdeploy.model_executor.layers.moe.fused_cast_sigmoid_bias" + gpu_ops_module = "fastdeploy.model_executor.ops.gpu" + + # Save original module references + original_module = sys.modules.pop(module_name, None) + original_gpu_ops = sys.modules.get(gpu_ops_module) + + try: + # Mock the GPU ops module to raise ImportError on import + with mock.patch.dict(sys.modules, {gpu_ops_module: None}): + # Re-import the module so it picks up the mocked (missing) GPU ops + reloaded = importlib.import_module(module_name) + importlib.reload(reloaded) + + # is_available should return False when ops are not available + result = reloaded.is_available() + assert isinstance(result, bool), f"is_available() should return bool, got {type(result)}" + assert result is False, f"is_available() should return False when GPU ops are unavailable, got {result}" + print(f" [PASS] is_available() returned {result} when ops unavailable") + finally: + # Restore original modules + sys.modules.pop(module_name, None) + if original_module is not None: + sys.modules[module_name] = original_module + if original_gpu_ops is not None: + sys.modules[gpu_ops_module] = original_gpu_ops + + print(" is_available() when ops unavailable test passed.\n") + + +if __name__ == "__main__": + print("Running fused_cast_sigmoid_bias tests...\n") + + test_is_available() + test_functionality() + test_functionality_cast_types() + test_accuracy() + test_accuracy_cast_types() + test_accuracy_extreme_values() + test_accuracy_extreme_values_cast_types() + test_import_error() + test_is_available_when_ops_unavailable() + if os.getenv("RUN_PERFORMANCE_TESTS") == "1": + test_performance() + else: + print("Skipping performance benchmark. 
Set RUN_PERFORMANCE_TESTS=1 to enable.\n") + + print("=" * 60) + print("All tests passed!") + print("=" * 60) diff --git a/tests/layers/test_fused_moe_cutlass_backend.py b/tests/layers/test_fused_moe_cutlass_backend.py index 60a5d2ce313..f3854acaf65 100644 --- a/tests/layers/test_fused_moe_cutlass_backend.py +++ b/tests/layers/test_fused_moe_cutlass_backend.py @@ -57,7 +57,7 @@ def name(self): class DummyFDConfig: def __init__(self, load_choices="default_v1"): self.model_config = types.SimpleNamespace(model="dummy", prefix_layer_name="prefix") - self.load_config = types.SimpleNamespace(load_choices=load_choices) + self.load_config = types.SimpleNamespace(load_choices=load_choices, dynamic_load_weight=False) class DummyLayer(paddle.nn.Layer): @@ -394,7 +394,15 @@ def combine(self, ffn_out, topk_idx, topk_weights, handle, quant_group_size=-1): def test_apply_tp_with_dispatch_and_reduce(self, monkeypatch): def fake_get_moe_scores( - gate_out, n_group, topk_group, top_k, routed_scaling_factor, bias, renormalize, topk_reduce_func=None + gate_out, + n_group, + topk_group, + top_k, + routed_scaling_factor, + bias, + renormalize, + topk_reduce_func=None, + use_fused_cast=False, ): return gate_out, paddle.to_tensor([[0.6, 0.4]]), paddle.to_tensor([[0, 1]]) @@ -831,6 +839,74 @@ def spy_permute(*args, **kwargs): assert not paddle.isnan(out).any(), "output contains NaN" assert not paddle.isinf(out).any(), "output contains Inf" + def test_apply_tp_noaux_tc_with_use_fused_false(self, monkeypatch): + fc1_called = {"count": 0} + + class FC1Proj(paddle.nn.Layer): + def forward(self, x): + fc1_called["count"] += 1 + return x * 2 + + fc1_latent_proj = FC1Proj() + + def fake_get_moe_scores( + gate_out, + n_group, + topk_group, + top_k, + routed_scaling_factor, + bias, + renormalize, + topk_reduce_func=None, + use_fused_cast=False, + ): + return gate_out, paddle.to_tensor([[0.6, 0.4]]), paddle.to_tensor([[0, 1]]) + + def fake_dispatch(*args, **kwargs): + return ( + paddle.ones([1, 2]), + paddle.to_tensor([1, 0]), + paddle.to_tensor([0]), + paddle.to_tensor([[0.6, 0.4]]), + paddle.to_tensor([[0, 1]]), + paddle.to_tensor([0]), + None, + None, + ) + + def fake_reduce(*args, **kwargs): + return paddle.ones([1, 2]) * 5 + + def fake_compute_ffn(*args, **kwargs): + return paddle.ones([1, 2]) * 2 + + monkeypatch.setattr(backend, "get_moe_scores", fake_get_moe_scores, raising=False) + monkeypatch.setattr(backend, "moe_expert_dispatch", fake_dispatch, raising=False) + monkeypatch.setattr(backend, "moe_expert_reduce", fake_reduce, raising=False) + + # Mock compute_ffn on the class to avoid real GPU op data type issues + monkeypatch.setattr(backend.CutlassMoEMethod, "compute_ffn", fake_compute_ffn) + + # Set FD_ENABLE_RL=True to trigger use_fused = False + monkeypatch.setattr(backend.fastdeploy.envs, "FD_ENABLE_RL", True) + + layer = DummyLayer(with_bias=False) + layer.topk_method = "noaux_tc" + # Add necessary attributes for compute_ffn access + layer.up_gate_proj_weight = paddle.zeros([2, 2 * 1], dtype="float16") + layer.down_proj_weight = paddle.zeros([2, 2], dtype="float16") + layer.activation = "silu" + + method = backend.CutlassMoEMethod(None) + + x = paddle.ones([1, 2]) + gate = paddle.nn.Identity() + + method.apply(layer, x, gate, fc1_latent_proj=fc1_latent_proj) + + # Verify fc1_latent_proj was called (line 354/425-426 was executed) + assert fc1_called["count"] > 0, "fc1_latent_proj should have been called" + @requires_cuda def test_apply_ep_prefill_moe_permute_real_ops(self, monkeypatch): 
"""FD_USE_PHI_MOE_PERMUTE=True + w16a16: EP prefill uses real moe_permute / @@ -970,7 +1046,15 @@ def forward(self, x): fc2_latent_proj = FC2Proj() def fake_get_moe_scores( - gate_out, n_group, topk_group, top_k, routed_scaling_factor, bias, renormalize, topk_reduce_func=None + gate_out, + n_group, + topk_group, + top_k, + routed_scaling_factor, + bias, + renormalize, + topk_reduce_func=None, + use_fused_cast=False, ): return gate_out, paddle.to_tensor([[0.6, 0.4]]), paddle.to_tensor([[0, 1]]) diff --git a/tests/layers/test_fused_moe_triton_backend.py b/tests/layers/test_fused_moe_triton_backend.py index dbd00c10ce8..3d7b3673275 100644 --- a/tests/layers/test_fused_moe_triton_backend.py +++ b/tests/layers/test_fused_moe_triton_backend.py @@ -703,3 +703,92 @@ def fake_transform_scale_ue8m0(sf, mn, weight_block_size=None): # Verify the quant_weight_ue8m0 branch was executed assert len(quant_calls) > 0, "quant_weight_ue8m0 should have been called" assert len(transform_calls) > 0, "transform_scale_ue8m0 should have been called" + + def test_triton_weight_only_apply_noaux_tc_with_fd_enable_rl(self, fake_ops, monkeypatch): + quant_config = DummyQuantConfig(is_checkpoint_bf16=False) + layer = DummyLayer(quant_config) + layer.topk_method = "noaux_tc" + method = backend.TritonWeightOnlyMoEMethod(quant_config) + method.create_weights(layer, model_format="torch") + + layer._up_weights = [ + paddle.arange(layer.hidden_size * layer.moe_intermediate_size * 2, dtype="float32").reshape( + [layer.hidden_size, layer.moe_intermediate_size * 2] + ) + for _ in range(layer.num_local_experts) + ] + layer._down_weights = [ + paddle.arange(layer.moe_intermediate_size * layer.hidden_size, dtype="float32").reshape( + [layer.moe_intermediate_size, layer.hidden_size] + ) + for _ in range(layer.num_local_experts) + ] + method.process_loaded_weights(layer, state_dict={}) + + kernel = DummyKernel() + monkeypatch.setattr(backend, "fused_moe_kernel_paddle", kernel, raising=False) + + # Set FD_ENABLE_RL=True to trigger use_fused = False at line 313 + # This should trigger gate_out.cast('float32') at line 315 + monkeypatch.setattr(backend.fastdeploy.envs, "FD_ENABLE_RL", True) + + x = paddle.randn([1, layer.hidden_size], dtype="float32") + gate = DummyGate(layer.num_local_experts) + + captured = {} + + def hook(topk_ids): + captured["topk_ids"] = topk_ids + + _ = method.apply(layer, x, gate, topk_ids_hookfunc=hook) + assert "topk_ids" in captured + + def test_triton_weight_only_apply_noaux_tc_with_non_cuda(self, fake_ops, monkeypatch): + quant_config = DummyQuantConfig(is_checkpoint_bf16=False) + layer = DummyLayer(quant_config) + # Ensure topk_method is "noaux_tc" to enter the target branch + layer.topk_method = "noaux_tc" + method = backend.TritonWeightOnlyMoEMethod(quant_config) + method.create_weights(layer, model_format="torch") + + layer._up_weights = [ + paddle.arange(layer.hidden_size * layer.moe_intermediate_size * 2, dtype="float32").reshape( + [layer.hidden_size, layer.moe_intermediate_size * 2] + ) + for _ in range(layer.num_local_experts) + ] + layer._down_weights = [ + paddle.arange(layer.moe_intermediate_size * layer.hidden_size, dtype="float32").reshape( + [layer.moe_intermediate_size, layer.hidden_size] + ) + for _ in range(layer.num_local_experts) + ] + method.process_loaded_weights(layer, state_dict={}) + + kernel = DummyKernel() + monkeypatch.setattr(backend, "fused_moe_kernel_paddle", kernel, raising=False) + + # Mock current_platform.is_cuda() to return False to trigger use_fused = False at line 313 + # 
This should trigger gate_out.cast("float32") at line 315 + monkeypatch.setattr(backend, "current_platform", types.SimpleNamespace(is_cuda=lambda: False)) + + x = paddle.randn([2, layer.hidden_size], dtype="float32") + gate = DummyGate(layer.num_local_experts) + + def fake_get_moe_scores(*args, **kwargs): + gate_out = args[0] + token_num = gate_out.shape[0] + top_k = args[3] + topk_ids = paddle.zeros([token_num, top_k], dtype="int64") + topk_weights = paddle.ones([token_num, top_k], dtype="float32") + return gate_out, topk_weights, topk_ids + + monkeypatch.setattr(backend, "get_moe_scores", fake_get_moe_scores) + + captured = {} + + def hook(topk_ids): + captured["topk_ids"] = topk_ids + + _ = method.apply(layer, x, gate, topk_ids_hookfunc=hook) + assert "topk_ids" in captured
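+
+        # Usage sketch (illustrative only, not executed by this test): with a
+        # CUDA build of the custom ops, the fused path can also be exercised
+        # directly through the Python wrapper:
+        #
+        #     from fastdeploy.model_executor.layers.moe.fused_cast_sigmoid_bias import (
+        #         fused_cast_sigmoid_bias, is_available)
+        #
+        #     if is_available():
+        #         gate_out = paddle.randn([4, 64], dtype="bfloat16")
+        #         bias = paddle.randn([64], dtype="float32")
+        #         scores, scores_with_bias = fused_cast_sigmoid_bias(gate_out, bias, "float32")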