diff --git a/custom_ops/gpu_ops/helper.h b/custom_ops/gpu_ops/helper.h index 2f276174bd..eaa5e3f09c 100644 --- a/custom_ops/gpu_ops/helper.h +++ b/custom_ops/gpu_ops/helper.h @@ -14,6 +14,8 @@ #pragma once +#include + #ifndef PADDLE_WITH_COREX #include "glog/logging.h" #endif diff --git a/custom_ops/metax_ops/fused_moe.cu b/custom_ops/metax_ops/fused_moe.cu new file mode 100644 index 0000000000..c3f2169d4f --- /dev/null +++ b/custom_ops/metax_ops/fused_moe.cu @@ -0,0 +1,181 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#pragma once + +#include "helper.h" +#include "mc_fused_moe_helper.h" +#include "fused_moe_op.h" + +__global__ void compute_total_rows_before_expert_kernel( + int* sorted_experts, + const int64_t sorted_experts_len, + const int64_t num_experts, + int32_t* total_rows_before_expert) { + const int expert = blockIdx.x * blockDim.x + threadIdx.x; + if (expert >= num_experts) return; + + total_rows_before_expert[expert] = + find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert); +} + +void compute_total_rows_before_expert(int* sorted_indices, + const int64_t total_indices, + const int64_t num_experts, + int32_t* total_rows_before_expert, + cudaStream_t stream) { + const int threads = std::min(int64_t(1024), num_experts); + const int blocks = (num_experts + threads - 1) / threads; + + compute_total_rows_before_expert_kernel<<>>( + sorted_indices, total_indices, num_experts, total_rows_before_expert); +} + +template +void FusedMoeKernel(const paddle::Tensor& input, + const paddle::Tensor& gate_weight, + const paddle::Tensor& ffn1_weight, + const paddle::optional& ffn1_scale, + const paddle::optional& ffn1_bias, + const paddle::Tensor& ffn2_weight, + const paddle::optional& ffn2_scale, + const paddle::optional& ffn2_bias, + const std::string& quant_method, + const int moe_topk, + const bool group_moe, + const bool norm_topk_prob, + paddle::Tensor* output) { + typedef PDTraits traits_; + typedef typename traits_::DataType DataType_; + typedef typename traits_::data_t data_t; + + auto* output_data = output->data(); + + auto moe_compute = McMoeHelper(quant_method); + + moe_compute.computeFFN( + &input, + &gate_weight, + &ffn1_weight, + ffn1_scale ? ffn1_scale.get_ptr() : nullptr, + ffn1_bias ? ffn1_bias.get_ptr() : nullptr, + &ffn2_weight, + ffn2_scale ? ffn2_scale.get_ptr() : nullptr, + ffn2_bias ? 
ffn2_bias.get_ptr() : nullptr, + nullptr, + moe_topk, + group_moe, + norm_topk_prob, + 1.0, // ComputeFFN + "ffn", + output); +} + + +std::vector FusedExpertMoe( + const paddle::Tensor& input, + const paddle::Tensor& gate_weight, + const paddle::Tensor& ffn1_weight, + const paddle::Tensor& ffn2_weight, + const paddle::optional& ffn1_bias, + const paddle::optional& ffn1_scale, + const paddle::optional& ffn2_bias, + const paddle::optional& ffn2_scale, + const std::string& quant_method, + const int moe_topk, + const bool norm_topk_prob, + const bool group_moe) { + const auto input_type = input.dtype(); + auto output = paddle::empty_like(input); + + switch (input_type) { + case paddle::DataType::BFLOAT16: + FusedMoeKernel(input, + gate_weight, + ffn1_weight, + ffn1_scale, + ffn1_bias, + ffn2_weight, + ffn2_scale, + ffn2_bias, + quant_method, + moe_topk, + group_moe, + norm_topk_prob, + &output); + break; + // case paddle::DataType::FLOAT16: + // FusedMoeKernel(input, + // gate_weight, + // ffn1_weight, + // ffn1_scale, + // ffn1_bias, + // ffn2_weight, + // ffn2_scale, + // ffn2_bias, + // quant_method, + // moe_topk, + // group_moe, + // norm_topk_prob, + // &output); + // break; + default: + PD_THROW("Only support bf16 for FusedMoeKernel"); + } + return {output}; +} + +std::vector> FusedExpertMoeInferShape( + const std::vector& input_shape, + const std::vector& gate_weight_shape, + const std::vector& ffn1_weight_shape, + const std::vector& ffn2_weight_shape, + const paddle::optional>& ffn1_bias_shape, + const paddle::optional>& ffn1_scale_shape, + const paddle::optional>& ffn2_bias_shape, + const paddle::optional>& ffn2_scale_shape) { + return {input_shape}; +} + +std::vector FusedExpertMoeInferDtype( + const paddle::DataType& input_dtype, + const paddle::DataType& gate_weight_dtype, + const paddle::DataType& ffn1_weight_dtype, + const paddle::DataType& ffn2_weight_dtype, + const paddle::optional& ffn1_bias_dtype, + const paddle::optional& ffn1_scale_dtype, + const paddle::optional& ffn2_bias_dtype, + const paddle::optional& ffn2_scale_dtype) { + return {input_dtype}; +} + + +PD_BUILD_OP(fused_expert_moe) + .Inputs({"input", + "gate_weight", + "ffn1_weight", + "ffn2_weight", + paddle::Optional("ffn1_bias"), + paddle::Optional("ffn1_scale"), + paddle::Optional("ffn2_bias"), + paddle::Optional("ffn2_scale")}) + .Outputs({"output"}) + .Attrs({"quant_method:std::string", + "moe_topk:int", + "norm_topk_prob:bool", + "group_moe:bool"}) + .SetKernelFn(PD_KERNEL(FusedExpertMoe)) + .SetInferShapeFn(PD_INFER_SHAPE(FusedExpertMoeInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(FusedExpertMoeInferDtype)); diff --git a/custom_ops/metax_ops/fused_moe_helper.h b/custom_ops/metax_ops/fused_moe_helper.h new file mode 100644 index 0000000000..67c616ce4f --- /dev/null +++ b/custom_ops/metax_ops/fused_moe_helper.h @@ -0,0 +1,53 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
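+//
+// Descriptive note: this header defines moe_token_type_ids_kernel, which
+// biases each token's pair of gating logits by token type. For row i it adds
+// -1e10 to logit 0 when moe_token_type_ids_out[i] == 1 and -1e10 to logit 1
+// when it is 0, so the subsequent softmax effectively masks out the expert
+// group that does not match the token's type.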
+ +#pragma once + +#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h" +#include "fused_moe_op.h" + +using namespace phi; + +template +__global__ void moe_token_type_ids_kernel(T *gating_output, + const int *moe_token_type_ids_out, + const int num_rows, + const int num_experts, + const int k) { + const int moe_token_index = blockIdx.x * blockDim.x + threadIdx.x; + + if (moe_token_index >= num_rows) { + return; + } + + gating_output[moe_token_index * 2] = + gating_output[moe_token_index * 2] + + (moe_token_type_ids_out[moe_token_index]) * -1e10; + gating_output[moe_token_index * 2 + 1] = + gating_output[moe_token_index * 2 + 1] + + (1 - moe_token_type_ids_out[moe_token_index]) * -1e10; +} + +template +void moe_token_type_ids_kernelLauncher(T *gating_output, + const int *moe_token_type_ids_out, + const int num_rows, + const int num_experts, + const int k, + cudaStream_t stream) { + const int blocks = num_rows * k / 512 + 1; + const int threads = 512; + moe_token_type_ids_kernel<<>>( + gating_output, moe_token_type_ids_out, num_rows, num_experts, k); +} diff --git a/custom_ops/metax_ops/fused_moe_imp_op.h b/custom_ops/metax_ops/fused_moe_imp_op.h new file mode 100644 index 0000000000..547b4cacc0 --- /dev/null +++ b/custom_ops/metax_ops/fused_moe_imp_op.h @@ -0,0 +1,123 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#include +#include +#include "cub/cub.cuh" + +static const float HALF_FLT_MAX = 65504.F; +static const float HALF_FLT_MIN = -65504.F; +static inline size_t AlignTo16(const size_t& input) { + static constexpr int ALIGNMENT = 16; + return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT); +} + +class CubKeyValueSorter { + public: + CubKeyValueSorter() : num_experts_(0), num_bits_(sizeof(int) * 8) {} + + explicit CubKeyValueSorter(const int num_experts) + : num_experts_(num_experts), + num_bits_(static_cast(log2(num_experts)) + 1) {} + + void update_num_experts(const int num_experts) { + num_experts_ = num_experts; + num_bits_ = static_cast(log2(num_experts)) + 1; + } + + size_t getWorkspaceSize(const size_t num_key_value_pairs, + bool descending = false) { + num_key_value_pairs_ = num_key_value_pairs; + size_t required_storage = 0; + int* null_int = nullptr; + if (descending) { + cub::DeviceRadixSort::SortPairsDescending(NULL, + required_storage, + null_int, + null_int, + null_int, + null_int, + num_key_value_pairs, + 0, + 32); + } else { + cub::DeviceRadixSort::SortPairs(NULL, + required_storage, + null_int, + null_int, + null_int, + null_int, + num_key_value_pairs, + 0, + num_bits_); + } + return required_storage; + } + + template + void run(void* workspace, + const size_t workspace_size, + const KeyT* keys_in, + KeyT* keys_out, + const int* values_in, + int* values_out, + const size_t num_key_value_pairs, + bool descending, + cudaStream_t stream) { + size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs); + size_t actual_ws_size = workspace_size; + + if (expected_ws_size > workspace_size) { + std::stringstream err_ss; + err_ss << "[Error][CubKeyValueSorter::run]\n"; + err_ss << "Error. The allocated workspace is too small to run this " + "problem.\n"; + err_ss << "Expected workspace size of at least " << expected_ws_size + << " but got problem size " << workspace_size << "\n"; + throw std::runtime_error(err_ss.str()); + } + if (descending) { + cub::DeviceRadixSort::SortPairsDescending(workspace, + actual_ws_size, + keys_in, + keys_out, + values_in, + values_out, + num_key_value_pairs, + 0, + 32, + stream); + } else { + cub::DeviceRadixSort::SortPairs(workspace, + actual_ws_size, + keys_in, + keys_out, + values_in, + values_out, + num_key_value_pairs, + 0, + num_bits_, + stream); + } + } + + private: + size_t num_key_value_pairs_; + int num_experts_; + int num_bits_; +}; diff --git a/custom_ops/metax_ops/fused_moe_op.h b/custom_ops/metax_ops/fused_moe_op.h new file mode 100644 index 0000000000..b53df12bf1 --- /dev/null +++ b/custom_ops/metax_ops/fused_moe_op.h @@ -0,0 +1,990 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include "fused_moe_imp_op.h" +#include "fused_moe_helper.h" +#include "mctlass/numeric_conversion.h" // BUILD_MARK +// Ignore mctlass warnings about type punning +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#pragma GCC diagnostic ignored "-Wunused-function" + +// #include "paddle/phi/backends/gpu/gpu_info.h" +#pragma GCC diagnostic pop + +#include "helper.h" + +#define WARP_SIZE 32 + +struct GpuLaunchConfig { + dim3 block_per_grid; + dim3 thread_per_block; +}; + +inline GpuLaunchConfig Get1DBlocksAnd2DGridsMoe(const int64_t cols) { + int blocks_x = cols; + int blocks_y = 1; + int blocks_z = 1; + if (blocks_x > 1024) { + blocks_y = 256; + blocks_x = (blocks_x + blocks_y - 1) / blocks_y; + } + + GpuLaunchConfig config; + config.block_per_grid.x = blocks_x; + config.block_per_grid.y = blocks_y; + config.block_per_grid.z = blocks_z; + return config; +} + +// ====================== Softmax things =============================== +// We have our own implementation of softmax here so we can support transposing +// the output in the softmax kernel when we extend this module to support +// expert-choice routing. +template +__launch_bounds__(TPB) __global__ + void group_moe_softmax(const T* input, + T* output, + T* softmax_max_prob, + const int64_t num_cols, + const int64_t softmax_num_rows) { + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + __shared__ float normalizing_factor; + __shared__ float float_max; + __shared__ float max_out; + + int globalIdx = blockIdx.x + blockIdx.y * gridDim.x; + if (globalIdx >= softmax_num_rows) { + return; + } + const int64_t thread_row_offset = globalIdx * num_cols; + + cub::Sum sum; + float threadData(-FLT_MAX); + + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + threadData = max(static_cast(input[idx]), threadData); + } + + const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); + if (threadIdx.x == 0) { + float_max = maxElem; + } + __syncthreads(); + + threadData = 0; + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + threadData += exp((static_cast(input[idx]) - float_max)); + } + + const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum); + + if (threadIdx.x == 0) { + normalizing_factor = 1.f / Z; + } + __syncthreads(); + + threadData = 0; + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + const float val = + exp((static_cast(input[idx]) - float_max)) * normalizing_factor; + output[idx] = T(val); + threadData = max(static_cast(T(val)), threadData); + } + + const float maxOut = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); + if (threadIdx.x == 0) { + // group max probs + max_out = 1.f / maxOut; + softmax_max_prob[globalIdx] = T(max_out); + } + __syncthreads(); + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + // group softmax normalization + output[idx] = output[idx] * static_cast(max_out); + } +} + +template +__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, + T* output, + int* indices, + int* source_rows, + T* softmax_max_prob, + const int64_t num_experts, + const int64_t k, + const int64_t num_rows) { + using cub_kvp = cub::KeyValuePair; + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + cub_kvp 
thread_kvp; + cub::ArgMax arg_max; + + const int block_row = blockIdx.x + blockIdx.y * gridDim.x; + if (block_row >= num_rows) { + return; + } + + const bool should_process_row = true; + const int thread_read_offset = block_row * num_experts; + + for (int k_idx = 0; k_idx < k; ++k_idx) { + thread_kvp.key = 0; + thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities + + cub_kvp inp_kvp; + for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { + const int idx = thread_read_offset + expert; + inp_kvp.key = expert; + inp_kvp.value = inputs_after_softmax[idx]; + + for (int prior_k = 0; prior_k < k_idx; ++prior_k) { + const int prior_winning_expert = indices[k * block_row + prior_k]; + + if (prior_winning_expert == expert) { + inp_kvp = thread_kvp; + } + } + + thread_kvp = arg_max(inp_kvp, thread_kvp); + } + + const cub_kvp result_kvp = + BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); + if (threadIdx.x == 0) { + const int idx = k * block_row + k_idx; + // restore normalized probes + output[idx] = result_kvp.value / T(softmax_max_prob[idx]); + indices[idx] = should_process_row ? result_kvp.key : num_experts; + source_rows[idx] = k_idx * num_rows + block_row; + } + __syncthreads(); + } +} + +template +__launch_bounds__(TPB) __global__ void moe_softmax(const T* input, + T* output, + const int64_t num_cols, + const int64_t num_rows) { + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + __shared__ float normalizing_factor; + __shared__ float float_max; + + int globalIdx = blockIdx.x + blockIdx.y * gridDim.x; + if (globalIdx >= num_rows) { + return; + } + const int64_t thread_row_offset = globalIdx * num_cols; + + cub::Sum sum; + float threadData(-FLT_MAX); + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + threadData = max(static_cast(input[idx]), threadData); + } + + const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); + if (threadIdx.x == 0) { + float_max = maxElem; + } + __syncthreads(); + + threadData = 0; + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + threadData += exp((static_cast(input[idx]) - float_max)); + } + + const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum); + + if (threadIdx.x == 0) { + normalizing_factor = 1.f / Z; + } + __syncthreads(); + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + const float val = + exp((static_cast(input[idx]) - float_max)) * normalizing_factor; + output[idx] = T(val); + } +} + +template +__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, + T* output, + int* indices, + int* source_rows, + const int64_t num_experts, + const int64_t k, + const int64_t num_rows) { + using cub_kvp = cub::KeyValuePair; + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + cub_kvp thread_kvp; + cub::ArgMax arg_max; + + const int block_row = blockIdx.x + blockIdx.y * gridDim.x; + if (block_row >= num_rows) { + return; + } + + const bool should_process_row = true; + const int thread_read_offset = block_row * num_experts; + + for (int k_idx = 0; k_idx < k; ++k_idx) { + thread_kvp.key = 0; + thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities + + cub_kvp inp_kvp; + for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { + const int idx = thread_read_offset + expert; + inp_kvp.key = expert; 
+ inp_kvp.value = inputs_after_softmax[idx]; + + for (int prior_k = 0; prior_k < k_idx; ++prior_k) { + const int prior_winning_expert = indices[k * block_row + prior_k]; + + if (prior_winning_expert == expert) { + inp_kvp = thread_kvp; + } + } + + thread_kvp = arg_max(inp_kvp, thread_kvp); + } + + const cub_kvp result_kvp = + BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); + if (threadIdx.x == 0) { + const int idx = k * block_row + k_idx; + output[idx] = result_kvp.value; + indices[idx] = should_process_row ? result_kvp.key : num_experts; + source_rows[idx] = k_idx * num_rows + block_row; + } + __syncthreads(); + } +} + +// ====================== TopK softmax things =============================== + +/* + A Top-K gating softmax written to exploit when the number of experts in the + MoE layers are a small power of 2. This allows us to cleanly share the rows + among the threads in a single warp and eliminate communication between warps + (so no need to use shared mem). + + It fuses the softmax, max and argmax into a single kernel. + + Limitations: + 1) This implementation is intended for when the number of experts is a small + power of 2. 2) This implementation assumes k is small, but will work for any + k. +*/ + +template +__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ + void topk_gating_softmax(const T* input, + T* output, + const int64_t num_rows, + int* indices, + int* source_rows, + const int64_t k) { + // We begin by enforcing compile time assertions and setting up compile time + // constants. + static_assert(VPT == (VPT & -VPT), "VPT must be power of 2"); + static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), + "NUM_EXPERTS must be power of 2"); + static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), + "BYTES_PER_LDG must be power of 2"); + static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16"); + + // Number of bytes each thread pulls in per load + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T); + static constexpr int ELTS_PER_ROW = NUM_EXPERTS; + static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT; + static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG; + + // Restrictions based on previous section. + static_assert( + VPT % ELTS_PER_LDG == 0, + "The elements per thread must be a multiple of the elements per ldg"); + static_assert(WARP_SIZE % THREADS_PER_ROW == 0, + "The threads per row must cleanly divide the threads per warp"); + static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), + "THREADS_PER_ROW must be power of 2"); + static_assert(THREADS_PER_ROW <= WARP_SIZE, + "THREADS_PER_ROW can be at most warp size"); + + // We have NUM_EXPERTS elements per row. We specialize for small #experts + static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT; + static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW; + static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP; + + // Restrictions for previous section. + static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, + "The elts per row must cleanly divide the total elt per warp"); + + // ===================== From this point, we finally start computing run-time + // variables. ======================== + + // Compute CTA and warp rows. We pack multiple rows into a single warp, and a + // block contains WARPS_PER_CTA warps. This, each block processes a chunk of + // rows. We start by computing the start row for each block. 
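+ // Illustrative example (not from the original code; assumes the 64-expert,
+ // bf16 configuration where ROWS_PER_WARP == 4, WARPS_PER_CTA == 4 and
+ // THREADS_PER_ROW == 8): ROWS_PER_CTA == 16, so block b starts at row
+ // 16 * b, warp w of that block starts at row 16 * b + 4 * w, and each of
+ // the four 8-lane thread sub-groups of the warp then owns one of those
+ // four rows.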
+ const int cta_base_row = blockIdx.x * ROWS_PER_CTA; + + // Now, using the base row per thread block, we compute the base row per warp. + const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP; + + // The threads in a warp are split into sub-groups that will work on a row. + // We compute row offset for each thread sub-group + const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW; + const int thread_row = warp_base_row + thread_row_in_warp; + + // Threads with indices out of bounds should early exit here. + if (thread_row >= num_rows) return; + const bool should_process_row = true; + + // We finally start setting up the read pointers for each thread. First, each + // thread jumps to the start of the row it will read. + const T* thread_row_ptr = input + thread_row * ELTS_PER_ROW; + + // Now, we compute the group each thread belong to in order to determine the + // first column to start loads. + const int thread_group_idx = threadIdx.x % THREADS_PER_ROW; + const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG; + const T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; + + // Determine the pointer type to use to read in the data depending on the + // BYTES_PER_LDG template param. In theory, this can support all powers of 2 + // up to 16. + using AccessType = mctlass::AlignedArray; + + // Finally, we pull in the data from global mem + mctlass::Array row_chunk_input; + AccessType* row_chunk_vec_ptr = + reinterpret_cast(&row_chunk_input); + const AccessType* vec_thread_read_ptr = + reinterpret_cast(thread_read_ptr); +#pragma unroll + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { + row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + } + + using ComputeType = float; + using Converter = mctlass::NumericArrayConverter; + Converter compute_type_converter; + mctlass::Array row_chunk = + compute_type_converter(row_chunk_input); + + // First, we perform a max reduce within the thread. We can do the max in fp16 + // safely (I think) and just convert to float afterwards for the exp + sum + // reduction. + ComputeType thread_max = row_chunk[0]; +#pragma unroll + for (int ii = 1; ii < VPT; ++ii) { + thread_max = max(thread_max, row_chunk[ii]); + } + +// Now, we find the max within the thread group and distribute among the +// threads. We use a butterfly reduce. +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + thread_max = + max(thread_max, + __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW)); + } + + // From this point, thread max in all the threads have the max within the row. + // Now, we subtract the max from each element in the thread and take the exp. + // We also compute the thread local sum. + float row_sum = 0; +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) { + row_chunk[ii] = expf(row_chunk[ii] - thread_max); + row_sum += row_chunk[ii]; + } + +// Now, we perform the sum reduce within each thread group. Similar to the max +// reduce, we use a bufferfly pattern. +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW); + } + + // From this point, all threads have the max and the sum for their rows in the + // thread_max and thread_sum variables respectively. Finally, we can scale the + // rows for the softmax. Technically, for top-k gating we don't need to + // compute the entire softmax row. We can likely look at the maxes and only + // compute for the top-k values in the row. 
However, this kernel will likely + // not be a bottle neck and it seems better to closer match torch and find the + // argmax after computing the softmax. + const float reciprocal_row_sum = 1.f / row_sum; + +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) { + row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum; + } + + // Now, softmax_res contains the softmax of the row chunk. Now, I want to find + // the topk elements in each row, along with the max index.​ + int start_col = first_elt_read_by_thread; + static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; + + for (int k_idx = 0; k_idx < k; ++k_idx) { + // First, each thread does the local argmax + float max_val = row_chunk[0]; + int expert = start_col; +#pragma unroll + for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; + ++ldg, col += COLS_PER_GROUP_LDG) { +#pragma unroll + for (int ii = 0; ii < ELTS_PER_LDG; ++ii) { + float val = row_chunk[ldg * ELTS_PER_LDG + ii]; + + // No check on the experts here since columns with the smallest index + // are processed first and only updated if > (not >=) + if (val > max_val) { + max_val = val; + expert = col + ii; + } + } + } + +// Now, we perform the argmax reduce. We use the butterfly pattern so threads +// reach consensus about the max. This will be useful for K > 1 so that the +// threads can agree on "who" had the max value. That thread can then blank out +// their max with -inf and the warp can run more iterations... +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + float other_max = + __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW); + int other_expert = + __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW); + + // We want lower indices to "win" in every thread so we break ties this + // way + if (other_max > max_val || + (other_max == max_val && other_expert < expert)) { + max_val = other_max; + expert = other_expert; + } + } + + // Write the max for this k iteration to global memory. + if (thread_group_idx == 0) { + // The lead thread from each sub-group will write out the final results to + // global memory. (This will be a single) thread per row of the + // input/output matrices. + const int idx = k * thread_row + k_idx; + output[idx] = T(max_val); + indices[idx] = should_process_row ? expert : NUM_EXPERTS; + source_rows[idx] = k_idx * num_rows + thread_row; + } + + // Finally, we clear the value in the thread with the current max if there + // is another iteration to run. + if (k_idx + 1 < k) { + const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG; + const int thread_to_clear_in_group = + (expert / ELTS_PER_LDG) % THREADS_PER_ROW; + + // Only the thread in the group which produced the max will reset the + // "winning" value to -inf. + if (thread_group_idx == thread_to_clear_in_group) { + const int offset_for_expert = expert % ELTS_PER_LDG; + // Safe to set to any negative value since row_chunk values must be + // between 0 and 1. + row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = + ComputeType(-10000.f); + } + } + } +} + +namespace detail { +// Constructs some constants needed to partition the work across threads at +// compile time. 
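+// Illustrative numbers (assuming bf16 input, EXPERTS == 64 and
+// BYTES_PER_LDG == 16): ELTS_PER_LDG == 16 / 2 == 8 and VECs_PER_THREAD == 1,
+// so VPT == 8 and each thread reads its 8 experts with a single 16-byte load;
+// THREADS_PER_ROW == 64 / 8 == 8, and a 32-lane warp therefore covers
+// ROWS_PER_WARP == 32 / 8 == 4 gating rows.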
+template +struct TopkConstants { + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T); + static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || + EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, + ""); + static constexpr int VECs_PER_THREAD = + std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); + static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; + static constexpr int THREADS_PER_ROW = EXPERTS / VPT; + static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; +}; +} // namespace detail + +template +void topk_gating_softmax_launcher_helper(const T* input, + T* output, + int* indices, + int* source_row, + const int64_t num_rows, + const int64_t num_experts, + const int64_t k, + cudaStream_t stream) { + static constexpr uint64_t MAX_BYTES_PER_LDG = 16; + static constexpr int BYTES_PER_LDG = + std::min(MAX_BYTES_PER_LDG, sizeof(T) * EXPERTS); + using Constants = detail::TopkConstants; + static constexpr int VPT = Constants::VPT; + static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; + const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; + const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; + + dim3 block_dim(WARP_SIZE, WARPS_PER_TB); + topk_gating_softmax + <<>>( + input, output, num_rows, indices, source_row, k); +} + +template +void topk_gating_softmax_kernelLauncher(const T* input, + T* output, + T* softmax, + int* indices, + int* source_row, + T* softmax_max_prob, + const int64_t num_rows, + const int64_t num_experts, + const int64_t k, + const bool group_moe, + cudaStream_t stream, + const bool topk_only_mode = false) { + if (topk_only_mode) { + static constexpr int TPB = 256; + const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows); + moe_top_k<<>>( + input, output, indices, source_row, num_experts, k, num_rows); + return; + } + static constexpr int WARPS_PER_TB = 4; + + #define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \ + case N: { \ + topk_gating_softmax_launcher_helper( \ + input, output, indices, source_row, num_rows, num_experts, k, stream); \ + break; \ + } + switch (num_experts) { + LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2) + LAUNCH_TOPK_GATING_SOFTMAX_HELPER(4) + LAUNCH_TOPK_GATING_SOFTMAX_HELPER(8) + LAUNCH_TOPK_GATING_SOFTMAX_HELPER(16) + LAUNCH_TOPK_GATING_SOFTMAX_HELPER(32) + LAUNCH_TOPK_GATING_SOFTMAX_HELPER(64) + LAUNCH_TOPK_GATING_SOFTMAX_HELPER(128) + LAUNCH_TOPK_GATING_SOFTMAX_HELPER(256) + + default: { + static constexpr int TPB = 256; + if (group_moe) { + const int group_experts = num_experts / k; + const int softmax_num_rows = num_rows * k; + const auto config_softmax = Get1DBlocksAnd2DGridsMoe(softmax_num_rows); + group_moe_softmax + <<>>( + input, + softmax, + softmax_max_prob, + group_experts, + softmax_num_rows); + const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows); + moe_top_k + <<>>(softmax, + output, + indices, + source_row, + softmax_max_prob, + num_experts, + k, + num_rows); + } else { + const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows); + moe_softmax<<>>( + input, softmax, num_experts, num_rows); + moe_top_k + <<>>(softmax, + output, + indices, + source_row, + num_experts, + k, + num_rows); + } + } + } +} + +// ========================== Permutation things +// ======================================= + +// Duplicated and permutes rows for MoE. In addition, reverse the permutation +// map to help with finalizing routing. + +// "expanded_x_row" simply means that the number of values is num_rows x k. 
It +// is "expanded" since we will have to duplicate some rows in the input matrix +// to match the dimensions. Duplicates will always get routed to separate +// experts in the end. + +// Note that the expanded_dest_row_to_expanded_source_row map referred to here +// has indices in the range (0, k*rows_in_input - 1). However, it is set up so +// that index 0, rows_in_input, 2*rows_in_input ... (k-1)*rows_in_input all map +// to row 0 in the original matrix. Thus, to know where to read in the source +// matrix, we simply take the modulus of the expanded index. + +template +__global__ void initialize_moe_routing_kernel( + const T* unpermuted_input, + T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + const int64_t num_rows, + const int64_t active_rows, + const int64_t cols, + const int64_t num_rows_k) { + using LoadT = AlignedVector; + LoadT src_vec; + + // Reverse permutation map. + // I do this so that later, we can use the source -> dest map to do the k-way + // reduction and unpermuting. I need the reverse map for that reduction to + // allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1 + // thread block will be responsible for all k summations. + const int expanded_dest_row = blockIdx.x + blockIdx.y * gridDim.x; + if (expanded_dest_row >= num_rows_k) return; + const int expanded_source_row = + expanded_dest_row_to_expanded_source_row[expanded_dest_row]; + if (threadIdx.x == 0) { + expanded_source_row_to_expanded_dest_row[expanded_source_row] = + expanded_dest_row; + } + + if ((blockIdx.x + blockIdx.y * gridDim.x) < active_rows) { + // Duplicate and permute rows + const int source_row = expanded_source_row % num_rows; + + const T* source_row_ptr = unpermuted_input + source_row * cols; + T* dest_row_ptr = permuted_output + expanded_dest_row * cols; + + for (int tid = threadIdx.x * VecSize; tid < cols; + tid += blockDim.x * VecSize) { + // dest_row_ptr[tid] = source_row_ptr[tid]; + Load(&source_row_ptr[tid], &src_vec); + Store(src_vec, &dest_row_ptr[tid]); + } + } +} + +template +void initialize_moe_routing_kernelLauncher( + const T* unpermuted_input, + T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + const int64_t num_rows, + const int64_t active_rows, + const int64_t cols, + const int64_t k, + cudaStream_t stream) { + const int threads = std::min(cols, int64_t(1024)); + constexpr int max_pack_size = 16 / sizeof(T); + const auto config_initialize = Get1DBlocksAnd2DGridsMoe(num_rows * k); + if (cols % max_pack_size == 0) { + initialize_moe_routing_kernel + <<>>( + unpermuted_input, + permuted_output, + expanded_dest_row_to_expanded_source_row, + expanded_source_row_to_expanded_dest_row, + num_rows, + k * active_rows, + cols, + num_rows * k); + } else { + initialize_moe_routing_kernel + <<>>( + unpermuted_input, + permuted_output, + expanded_dest_row_to_expanded_source_row, + expanded_source_row_to_expanded_dest_row, + num_rows, + k * active_rows, + cols, + num_rows * k); + } +} + +// ============================== Infer GEMM sizes +// ================================= +__device__ inline int find_total_elts_leq_target(int* sorted_indices, + const int64_t arr_length, + const int64_t target) { + int64_t low = 0, high = arr_length - 1, target_location = -1; + while (low <= high) { + int64_t mid = (low + high) / 2; + + if (sorted_indices[mid] > target) { + high = mid - 1; + } else { + low = mid + 1; + target_location = 
mid; + } + } + return target_location + 1; +} + +void compute_total_rows_before_expert(int* sorted_indices, + const int64_t total_indices, + const int64_t num_experts, + int32_t* total_rows_before_expert, + cudaStream_t stream); + +// Final kernel to unpermute and scale +// This kernel unpermutes the original data, does the k-way reduction and +// performs the final skip connection. +template +__global__ void finalize_moe_routing_kernel( + const T* expanded_permuted_rows, + T* reduced_unpermuted_output, + const T* bias, + const float* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, + const int64_t cols, + const int64_t k, + const int64_t compute_bias, + const bool norm_topk_prob, + const float routed_scaling_factor, + const int64_t num_rows) { + const int original_row = blockIdx.x + blockIdx.y * gridDim.x; + // const int original_row = blockIdx.x; + // const int num_rows = gridDim.x; + if (original_row >= num_rows) return; + T* reduced_row_ptr = reduced_unpermuted_output + original_row * cols; + + for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) { + T thread_output{0.f}; + float row_rescale{0.f}; + for (int k_idx = 0; k_idx < k; ++k_idx) { + const int expanded_original_row = original_row + k_idx * num_rows; + const int expanded_permuted_row = + expanded_source_row_to_expanded_dest_row[expanded_original_row]; + + const int64_t k_offset = original_row * k + k_idx; + const float row_scale = scales[k_offset]; + row_rescale = row_rescale + row_scale; + + const T* expanded_permuted_rows_row_ptr = + expanded_permuted_rows + expanded_permuted_row * cols; + + const int expert_idx = expert_for_source_row[k_offset]; + const T* bias_ptr = bias ? bias + expert_idx * cols : nullptr; + const T bias_value = bias_ptr ? bias_ptr[tid] : T{0.f}; + + thread_output = + static_cast(thread_output) + + row_scale * static_cast( + expanded_permuted_rows_row_ptr[tid] + + bias_value * + static_cast(static_cast(compute_bias))); + } + + thread_output = static_cast(thread_output) / + (norm_topk_prob ? 
row_rescale : 1.0f) * + routed_scaling_factor; + reduced_row_ptr[tid] = thread_output; + } +} + +template +void finalize_moe_routing_kernelLauncher( + const T* expanded_permuted_rows, + T* reduced_unpermuted_output, + const T* bias, + const float* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, + const int64_t num_rows, + const int64_t cols, + const int64_t k, + const int64_t compute_bias, + const bool norm_topk_prob, + const float routed_scaling_factor, + cudaStream_t stream) { + const int threads = std::min(cols, int64_t(1024)); + const auto config_final = Get1DBlocksAnd2DGridsMoe(num_rows); + + finalize_moe_routing_kernel + <<>>( + expanded_permuted_rows, + reduced_unpermuted_output, + bias, + scales, + expanded_source_row_to_expanded_dest_row, + expert_for_source_row, + cols, + k, + compute_bias, + norm_topk_prob, + routed_scaling_factor, + num_rows); +} + +// ========================= TopK Softmax specializations +// =========================== +template void topk_gating_softmax_kernelLauncher(const float*, + float*, + float*, + int*, + int*, + float*, + const int64_t, + const int64_t, + const int64_t, + const bool, + cudaStream_t, + const bool); +template void topk_gating_softmax_kernelLauncher(const half*, + half*, + half*, + int*, + int*, + half*, + const int64_t, + const int64_t, + const int64_t, + const bool, + cudaStream_t, + const bool); +#ifdef PADDLE_CUDA_BF16 +template void topk_gating_softmax_kernelLauncher(const __nv_bfloat16*, + __nv_bfloat16*, + __nv_bfloat16*, + int*, + int*, + __nv_bfloat16*, + const int64_t, + const int64_t, + const int64_t, + const bool, + cudaStream_t, + const bool); +#endif +// ===================== Specializations for init routing +// ========================= +template void initialize_moe_routing_kernelLauncher(const float*, + float*, + const int*, + int*, + const int64_t, + const int64_t, + const int64_t, + const int64_t, + cudaStream_t); +template void initialize_moe_routing_kernelLauncher(const half*, + half*, + const int*, + int*, + const int64_t, + const int64_t, + const int64_t, + const int64_t, + cudaStream_t); +#ifdef PADDLE_CUDA_BF16 +template void initialize_moe_routing_kernelLauncher(const __nv_bfloat16*, + __nv_bfloat16*, + const int*, + int*, + const int64_t, + const int64_t, + const int64_t, + const int64_t, + cudaStream_t); +#endif +// ==================== Specializations for final routing +// =================================== +template void finalize_moe_routing_kernelLauncher(const float*, + float*, + const float*, + const float*, + const int*, + const int*, + const int64_t, + const int64_t, + const int64_t, + const int64_t, + const bool, + const float, + cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const half*, + half*, + const half*, + const float*, + const int*, + const int*, + const int64_t, + const int64_t, + const int64_t, + const int64_t, + const bool, + const float, + cudaStream_t); +#ifdef PADDLE_CUDA_BF16 +template void finalize_moe_routing_kernelLauncher(const __nv_bfloat16*, + __nv_bfloat16*, + const __nv_bfloat16*, + const float*, + const int*, + const int*, + const int64_t, + const int64_t, + const int64_t, + const int64_t, + const bool, + const float, + cudaStream_t); +#endif diff --git a/custom_ops/metax_ops/mc_fused_moe_helper.h b/custom_ops/metax_ops/mc_fused_moe_helper.h new file mode 100644 index 0000000000..e235ec5e52 --- /dev/null +++ b/custom_ops/metax_ops/mc_fused_moe_helper.h @@ -0,0 +1,417 @@ +// Copyright (c) 2025 PaddlePaddle 
Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include "mctlass/numeric_conversion.h" +#include "mctlassEx/mctlassEx.h" +#include "fused_moe_helper.h" + + +template +void mc_grouped_gemm_basic_kernel( + const ElementA* ptrA, + mctlassExOrder_t majorA, + const ElementB* ptrB, + mctlassExOrder_t majorB, + const ElementA* ptrScale, + const ElementA* ptrBias, + ElementC* ptrC, + mctlassExOrder_t majorC, + const int *ptrSegInd, + int numExperts, + int m, // expanded_active_expert_rows + int n, // inter_dim + int k, // hidden_size + mcStream_t stream) { + mctlassExHandle_t handle; + mctlassExHandleCreate(&handle); + + int* ptrMNumTilesInd; + mcMallocAsync((void**)&ptrMNumTilesInd, sizeof(int) * numExperts, stream); + + mctlassExMatrixLayout_t matLayoutA; + mctlassExMatrixLayout_t matLayoutB; + mctlassExMatrixLayout_t matLayoutC; + + // mat A: (m, k) + mctlassExMatrixLayoutCreate(&matLayoutA, mctlassExDataType::MCTLASS_EX_BF16, m, k, k); + mctlassExMatrixLayoutSetAttribute(matLayoutA, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER, + &majorA, sizeof(mctlassExOrder_t)); + // mat B: (num_experts, n, k) + mctlassExMatrixLayoutCreate(&matLayoutB, mctlassExDataType::MCTLASS_EX_INT8, k, n, k); + mctlassExMatrixLayoutSetAttribute(matLayoutB, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER, + &majorB, sizeof(mctlassExOrder_t)); + mctlassExMatrixLayoutSetAttribute(matLayoutB, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT, + &numExperts, sizeof(int)); + // mat C: (m, n) + mctlassExMatrixLayoutCreate(&matLayoutC, mctlassExDataType::MCTLASS_EX_BF16, m, n, n); + mctlassExMatrixLayoutSetAttribute(matLayoutC, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER, + &majorC, sizeof(mctlassExOrder_t)); + // bias: (num_experts, n) + // scale: (num, n) + + mctlassExDesc_t mctlass_desc; + mctlassExCreateDesc(&mctlass_desc); + mctlassExDataType input_type = mctlassExDataType::MCTLASS_EX_BF16; + mctlassExDataType scale_type = mctlassExDataType::MCTLASS_EX_INT8; + mctlassExDataType compute_type = mctlassExDataType::MCTLASS_EX_FP32; + mctlassExEpilogueType epilogue_type = mctlassExEpilogueType::MCTLASS_EX_GEMM_DEFAULT; + if (ptrBias) { + epilogue_type = mctlassExEpilogueType::MCTLASS_EX_GEMM_BIAS_PERGROUP; + } + // set scale + mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_B_SCALE_POINTER, + &ptrScale, sizeof(ptrScale)); + mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_B_SCALE_TYPE, + &scale_type, sizeof(mctlassExDataType)); + // set bias + if (ptrBias) { + mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_BIAS_POINTER, + &ptrBias, sizeof(ptrBias)); + } + // set coumpute type + mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_COMPUTE_TYPE, + &compute_type, sizeof(mctlassExDataType)); + // set epilogue type + mctlassExDescSetAttribute(mctlass_desc, 
mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_EPILOGUE_TYPE, + &epilogue_type, sizeof(mctlassExEpilogueType)); + + const mctlassExContiguousGroupedGemmAlgo_t algo = mctlassExContiguousGroupedGemmAlgo_t::MCTLASS_EX_CONTIGUOUS_GROUPED_ALGO_SEGPTR; + int blocksizeM = mctlassExContiguousGroupedGemmGetBlocksizeM(handle, mctlass_desc, matLayoutA, matLayoutB, matLayoutC, &algo); + mctlassExContiguousGroupedGemmComputeMNumTilesIndptr(handle, mctlass_desc, matLayoutA, matLayoutB, matLayoutC, &algo, ptrSegInd, ptrMNumTilesInd, numExperts, blocksizeM); + + mctlassExContiguousGroupedGemmBasic(handle, mctlass_desc, + ptrA, matLayoutA, + ptrB, matLayoutB, + ptrC, matLayoutC, + ptrSegInd, nullptr, ptrMNumTilesInd, + &algo, nullptr, 0, stream); + + mctlassExHandleDestroy(handle); + mctlassExMatrixLayoutDestroy(matLayoutA); + mctlassExMatrixLayoutDestroy(matLayoutB); + mctlassExMatrixLayoutDestroy(matLayoutC); + mctlassExDestroyDesc(mctlass_desc); + mcFreeAsync(ptrMNumTilesInd, stream); +} + +template +class McMoeHelper { + public: + McMoeHelper(const std::string gemm_method): gemm_method_(gemm_method) {} + + // -------- getWorkspaceSize -------- // + template + size_t getWorkspaceSize(const int64_t num_rows, + const int64_t hidden_size, + const int64_t inter_size, + const int64_t num_experts, + const int64_t k) { + const size_t buf_size = AlignTo16(k * num_rows * hidden_size); + const size_t interbuf_size = AlignTo16(k * num_rows * inter_size); + const size_t padded_experts = AlignTo16(num_experts); + const size_t num_moe_inputs = AlignTo16(k * num_rows); + // softmax output, permuted_rows and permuted_experts have moved to outside + // of moe kernel, allocate them in Encoder or Decoder before invoking + // FfnLayer forward. + size_t total_ws_bytes = + 5 * num_moe_inputs * + sizeof(int); // source_rows_, permuted_rows_, permuted_experts_ + total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data + total_ws_bytes += + padded_experts * sizeof(int32_t); // Hold total_rows_before_expert_ + + const size_t bytes_for_fc1_result = interbuf_size * sizeof(KeyT); + const size_t sorter_ws_size_bytes = + AlignTo16(sorter_.getWorkspaceSize(num_rows)); + sorter_.update_num_experts(num_experts); + + int64_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result; + if (sorter_ws_size_bytes > bytes_for_fc1_result) { + int64_t remaining_bytes = + AlignTo16(sorter_ws_size_bytes - bytes_for_fc1_result); + bytes_for_intermediate_and_sorting += remaining_bytes; + } + + total_ws_bytes += + bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub + // sorting workspace + + int64_t num_softmax_outs = 0; + const bool is_pow_2 = + (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + if (!is_pow_2 || num_experts > 256) { + num_softmax_outs = AlignTo16(num_rows * num_experts); + } + + total_ws_bytes += num_softmax_outs * sizeof(float); + + return total_ws_bytes; + } + + void computeFFN(const paddle::Tensor *input, + const paddle::Tensor *gate_weight, + const paddle::Tensor *ffn1_weight, + const paddle::Tensor *ffn1_scale, + const paddle::Tensor *ffn1_bias, + const paddle::Tensor *ffn2_weight, + const paddle::Tensor *ffn2_scale, + const paddle::Tensor *ffn2_bias, + const paddle::Tensor *moe_token_type_ids, + const int moe_topk, + const bool group_moe, + const bool norm_topk_prob, + const float routed_scaling_factor, + const std::string moe_type, + paddle::Tensor *output) { + auto *input_activations = input->data(); + auto *gating_weights = gate_weight->data(); + const T *fc1_expert_biases = ffn1_bias 
? ffn1_bias->data() : nullptr; + const T *fc2_expert_biases = ffn2_bias ? ffn2_bias->data() : nullptr; + + auto *output_ = output->data(); + auto stream = input->stream(); + auto place = input->place(); + auto input_type = input->dtype(); + + auto input_dims = input->dims(); + auto ffn1_dims = ffn1_weight->dims(); + int64_t token_num = 0; + if (input_dims.size() == 3) { + token_num = input_dims[0] * input_dims[1]; + } else { + token_num = input_dims[0]; + } + const int64_t num_rows = token_num; + + const int64_t hidden_size = ffn1_dims[2]; + int64_t inter_dim = 0; + if (moe_type == "qkv") { + inter_dim = ffn1_dims[2] * ffn1_dims[3] * ffn1_dims[4]; + } else { + inter_dim = ffn1_dims[1]; + } + + // if (gemm_method == "weight_only_int4") { + // inter_dim = inter_dim * 2; + // } + + const int64_t inter_size = inter_dim; + const int64_t num_experts = ffn1_dims[0]; + const int64_t k = moe_topk; + + + int64_t bytes = + getWorkspaceSize(num_rows, hidden_size, inter_size, num_experts, k); + + // Pointers + int *expert_for_source_row; + int *source_rows_; + int *permuted_rows_; + int *permuted_experts_; + int *expanded_source_row_to_expanded_dest_row; + + T *permuted_data_; + int32_t *total_rows_before_expert_; + T *fc1_result_; + float *softmax_out_; + + paddle::Tensor ws_ptr_tensor = + GetEmptyTensor({bytes}, paddle::DataType::INT8, place); + int8_t *ws_ptr = ws_ptr_tensor.data(); + + const int64_t buf_size = AlignTo16(k * num_rows * hidden_size); + const int64_t interbuf_size = AlignTo16(k * num_rows * inter_size); + const int64_t padded_experts = AlignTo16(num_experts); + const int64_t num_moe_inputs = AlignTo16(k * num_rows); + + expert_for_source_row = reinterpret_cast(ws_ptr); + source_rows_ = expert_for_source_row + num_moe_inputs; + permuted_rows_ = source_rows_ + num_moe_inputs; + permuted_experts_ = permuted_rows_ + num_moe_inputs; + expanded_source_row_to_expanded_dest_row = + permuted_experts_ + num_moe_inputs; + permuted_data_ = reinterpret_cast( + expanded_source_row_to_expanded_dest_row + num_moe_inputs); + total_rows_before_expert_ = + reinterpret_cast(permuted_data_ + buf_size); + fc1_result_ = + reinterpret_cast(total_rows_before_expert_ + padded_experts); + + const bool is_pow_2 = + (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + if (!is_pow_2 || num_experts > 256) { + softmax_out_ = reinterpret_cast(fc1_result_ + interbuf_size); + } else { + softmax_out_ = nullptr; + } + + paddle::Tensor expert_scales_float_tensor = + GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place); + float *expert_scales_float = expert_scales_float_tensor.data(); + + float *softmax_max_prob = nullptr; + if (group_moe) { + paddle::Tensor softmax_max_prob_tensor = GetEmptyTensor( + {num_rows, moe_topk}, paddle::DataType::FLOAT32, place); + // (TODO: check fill success ?) 
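+ // softmax_max_prob is zero-initialized here and then overwritten by
+ // group_moe_softmax with 1 / max(prob) for every (row, group) pair; the
+ // group variant of moe_top_k later divides each selected probability by
+ // this value to undo the per-group max normalization.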
+ paddle::experimental::fill(softmax_max_prob_tensor, 0.f); + softmax_max_prob = softmax_max_prob_tensor.data(); + } + + paddle::Tensor fc1_out_tensor = + GetEmptyTensor({num_rows * k, inter_size}, input_type, place); + T *fc1_out = fc1_out_tensor.data(); + + auto input_cast_tensor = + paddle::experimental::cast(*input, paddle::DataType::FLOAT32); + auto gate_tensor = + paddle::experimental::matmul(input_cast_tensor, *gate_weight); + float *gating_output = gate_tensor.data(); + + if (moe_token_type_ids) { + auto *moe_token_type_ids_out = moe_token_type_ids->data(); + moe_token_type_ids_kernelLauncher(gating_output, + moe_token_type_ids_out, + num_rows, + num_experts, + k, + stream); + } + + topk_gating_softmax_kernelLauncher(gating_output, + expert_scales_float, + softmax_out_, + expert_for_source_row, + source_rows_, + softmax_max_prob, + num_rows, + num_experts, + k, + group_moe, + stream); + + const int64_t sorter_ws_size_bytes = + AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows))); + + sorter_.run(fc1_result_, + sorter_ws_size_bytes, + expert_for_source_row, + permuted_experts_, + source_rows_, + permuted_rows_, + k * num_rows, + false, + stream); + + initialize_moe_routing_kernelLauncher( + input_activations, + permuted_data_, + permuted_rows_, + expanded_source_row_to_expanded_dest_row, + num_rows, + num_rows, + hidden_size, + k, + stream); + + const int64_t expanded_active_expert_rows = k * num_rows; + + compute_total_rows_before_expert(permuted_experts_, + expanded_active_expert_rows, + num_experts, + total_rows_before_expert_, + stream); + + mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ROWMAJOR_ORDER; + mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_COLUMNMAJOR_ORDER; + + mc_grouped_gemm_basic_kernel( + reinterpret_cast(permuted_data_), + row_major, + reinterpret_cast(ffn1_weight->data()), + column_major, + reinterpret_cast(ffn1_scale->data()), + reinterpret_cast(fc1_expert_biases), + reinterpret_cast(fc1_out), + row_major, + total_rows_before_expert_, + num_experts, + expanded_active_expert_rows, + inter_size, + hidden_size, + stream); + + if (moe_type == "ffn") { + auto act_out_tensor = + paddle::experimental::swiglu(fc1_out_tensor, nullptr); + auto act_out = act_out_tensor.data(); + + paddle::Tensor fc2_output_tensor = + GetEmptyTensor({k * num_rows, hidden_size}, input_type, place); + T *fc2_result = fc2_output_tensor.data(); + + mc_grouped_gemm_basic_kernel( + reinterpret_cast(act_out), + row_major, + reinterpret_cast(ffn2_weight->data()), + column_major, + reinterpret_cast(ffn2_scale->data()), + nullptr, + reinterpret_cast(fc2_result), + row_major, + total_rows_before_expert_, + num_experts, + expanded_active_expert_rows, + hidden_size, + inter_size / 2, + stream); + + finalize_moe_routing_kernelLauncher( + fc2_result, + output_, + fc2_expert_biases, + reinterpret_cast(expert_scales_float), + expanded_source_row_to_expanded_dest_row, + expert_for_source_row, + num_rows, + hidden_size, + k, + static_cast(1), + norm_topk_prob, + routed_scaling_factor, + stream); + } else { + finalize_moe_routing_kernelLauncher( + // fc2_result, + fc1_out, + output_, + fc1_expert_biases, // fc2_expert_biases, + reinterpret_cast(expert_scales_float), + expanded_source_row_to_expanded_dest_row, + expert_for_source_row, + num_rows, + inter_size, + k, + static_cast(0), + norm_topk_prob, + routed_scaling_factor, + stream); + } + } + +private: + std::string gemm_method_; + CubKeyValueSorter sorter_; +}; diff --git a/custom_ops/metax_ops/moe_dispatch.cu 
b/custom_ops/metax_ops/moe_dispatch.cu new file mode 100644 index 0000000000..e855666e00 --- /dev/null +++ b/custom_ops/metax_ops/moe_dispatch.cu @@ -0,0 +1,274 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#pragma GCC diagnostic ignored "-Wunused-function" +#pragma once + +#include "fused_moe_helper.h" +#include "fused_moe_op.h" +#pragma GCC diagnostic pop + +#include "helper.h" + + +template +void MoeDispatchKernel(const paddle::Tensor& input, + const paddle::Tensor& gating_output, + const int moe_topk, + const bool group_moe, + const bool topk_only_mode, + const int num_rows, + const int hidden_size, + const int expert_num, + paddle::Tensor* permute_input, + paddle::Tensor* tokens_expert_prefix_sum, + paddle::Tensor* permute_indices_per_token, + paddle::Tensor* top_k_weight, + paddle::Tensor* top_k_indices) { + typedef PDTraits traits_; + typedef typename traits_::DataType DataType_; + typedef typename traits_::data_t data_t; + + auto stream = input.stream(); + auto place = input.place(); + + if (group_moe) { + // Check if expert_num is divisible by moe_topk, else throw an error + PADDLE_ENFORCE_EQ(expert_num % moe_topk, + 0, + common::errors::InvalidArgument( + "The number of experts (expert_num) " + "must be divisible by moe_topk. 
" + "Got expert_num = %d and moe_topk = %d.", + expert_num, + moe_topk)); + } + + const int num_moe_inputs = AlignTo16(num_rows * moe_topk); + const int bytes = num_moe_inputs * sizeof(int); + + CubKeyValueSorter sorter_; + sorter_.update_num_experts(expert_num); + + const int sorter_ws_size_bytes = + AlignTo16(sorter_.getWorkspaceSize(moe_topk * num_rows)); + const int sort_tmp_in_out_size = num_moe_inputs * 2 * sizeof(int); + + paddle::Tensor ws_ptr_tensor = + GetEmptyTensor({bytes + sorter_ws_size_bytes + sort_tmp_in_out_size}, + paddle::DataType::INT8, + place); + + int8_t* ws_ptr = ws_ptr_tensor.data(); + int* source_rows_ = reinterpret_cast(ws_ptr); + int8_t* sorter_ws_ptr = reinterpret_cast(ws_ptr + bytes); + int* permuted_experts_ = + reinterpret_cast(sorter_ws_ptr + sorter_ws_size_bytes); + int* permuted_rows_ = permuted_experts_ + num_moe_inputs; + + int* expert_for_source_row = top_k_indices->data(); + + float* softmax_max_prob = nullptr; + if (group_moe) { + paddle::Tensor softmax_max_prob_tensor = + GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place); + paddle::experimental::fill(softmax_max_prob_tensor, 0.f); + softmax_max_prob = softmax_max_prob_tensor.data(); + } + + float* softmax_out_; + + const bool is_pow_2 = + (expert_num != 0) && ((expert_num & (expert_num - 1)) == 0); + + paddle::Tensor softmax_buffer; + + if (!is_pow_2 || expert_num > 256 || group_moe) { + softmax_buffer = GetEmptyTensor( + {num_rows * expert_num}, paddle::DataType::FLOAT32, place); + softmax_out_ = softmax_buffer.data(); + } else { + softmax_out_ = nullptr; + } + + topk_gating_softmax_kernelLauncher(gating_output.data(), + top_k_weight->data(), + softmax_out_, + expert_for_source_row, + source_rows_, + softmax_max_prob, + num_rows, + expert_num, + moe_topk, + group_moe, + stream, + topk_only_mode); + + sorter_.run(reinterpret_cast(sorter_ws_ptr), + sorter_ws_size_bytes, + expert_for_source_row, + permuted_experts_, + source_rows_, + permuted_rows_, + moe_topk * num_rows, + false, + stream); + + + initialize_moe_routing_kernelLauncher( + input.data(), + permute_input->data(), + permuted_rows_, + permute_indices_per_token->data(), + num_rows, + num_rows, + hidden_size, + moe_topk, + stream); + + + compute_total_rows_before_expert( + permuted_experts_, + moe_topk * num_rows, + expert_num, + tokens_expert_prefix_sum->data(), + stream); +} + + +std::vector MoeExpertDispatch( + const paddle::Tensor& input, + const paddle::Tensor& gating_output, + const int moe_topk, + const bool group_moe, + const bool topk_only_mode) { + const auto input_type = input.dtype(); + auto place = input.place(); + int token_rows = 0; + auto input_dims = input.dims(); + auto gating_dims = gating_output.dims(); + const int expert_num = gating_dims[gating_dims.size() - 1]; + + if (input_dims.size() == 3) { + token_rows = input_dims[0] * input_dims[1]; + } else { + token_rows = input_dims[0]; + } + const int num_rows = token_rows; + const int hidden_size = input.dims()[input_dims.size() - 1]; + + auto permute_input = + GetEmptyTensor({moe_topk * num_rows, hidden_size}, input_type, place); + // correspond to the weighted coefficients of the results from each expert. 
+ auto top_k_weight = + GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place); + auto top_k_indices = + GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::INT32, place); + + auto tokens_expert_prefix_sum = + GetEmptyTensor({expert_num}, paddle::DataType::INT32, place); + auto permute_indices_per_token = + GetEmptyTensor({moe_topk, num_rows}, paddle::DataType::INT32, place); + + + switch (input_type) { + case paddle::DataType::BFLOAT16: + MoeDispatchKernel(input, + gating_output, + moe_topk, + group_moe, + topk_only_mode, + num_rows, + hidden_size, + expert_num, + &permute_input, + &tokens_expert_prefix_sum, + &permute_indices_per_token, + &top_k_weight, + &top_k_indices); + break; + // case paddle::DataType::FLOAT16: + // MoeDispatchKernel(input, + // gating_output, + // moe_topk, + // group_moe, + // topk_only_mode, + // num_rows, + // hidden_size, + // expert_num, + // &permute_input, + // &tokens_expert_prefix_sum, + // &permute_indices_per_token, + // &top_k_weight, + // &top_k_indices); + // break; + default: + PD_THROW("Only support bf16 for MoeDispatchKernel"); + } + return {permute_input, + tokens_expert_prefix_sum, + permute_indices_per_token, + top_k_weight, + top_k_indices}; +} + + +std::vector> MoeExpertDispatchInferShape( + const std::vector& input_shape, + const std::vector& gating_output_shape, + const int moe_topk) { + int token_rows = -1; + + if (input_shape.size() == 3) { + token_rows = input_shape[0] * input_shape[1]; + } else { + token_rows = input_shape[0]; + } + const int expert_num = gating_output_shape[gating_output_shape.size() - 1]; + const int num_rows = token_rows; + const int hidden_size = input_shape[input_shape.size() - 1]; + + return {{moe_topk * num_rows, hidden_size}, + {expert_num}, + {moe_topk, num_rows}, + {num_rows, moe_topk}, + {num_rows, moe_topk}}; +} + +std::vector MoeExpertDispatchInferDtype( + const paddle::DataType& input_dtype, + const paddle::DataType& gating_output_dtype, + const int moe_topk) { + return {input_dtype, + paddle::DataType::INT64, + paddle::DataType::INT32, + paddle::DataType::FLOAT32, + paddle::DataType::INT32}; +} + + +PD_BUILD_OP(moe_expert_dispatch) + .Inputs({"input", "gating_output"}) + .Outputs({"permute_input", + "tokens_expert_prefix_sum", + "permute_indices_per_token", + "top_k_weight", + "top_k_indices"}) + .Attrs({"moe_topk:int", "group_moe:bool", "topk_only_mode:bool"}) + .SetKernelFn(PD_KERNEL(MoeExpertDispatch)) + .SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertDispatchInferDtype)); diff --git a/custom_ops/metax_ops/moe_ffn.cu b/custom_ops/metax_ops/moe_ffn.cu new file mode 100644 index 0000000000..8f11912060 --- /dev/null +++ b/custom_ops/metax_ops/moe_ffn.cu @@ -0,0 +1,173 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
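+// moe_expert_ffn applies the expert FFN to tokens that moe_expert_dispatch has
+// already permuted: a weight-only int8 grouped GEMM through ffn1_weight, a SwiGLU
+// activation, then a second grouped GEMM through ffn2_weight, with
+// tokens_expert_prefix_sum supplying the per-expert row boundaries for both GEMMs.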
+ + +#pragma once +#include "mc_fused_moe_helper.h" +#include "helper.h" + +template +void McMoeFFNKernel(const paddle::Tensor& permute_input, + const paddle::Tensor& tokens_expert_prefix_sum, + const paddle::Tensor& ffn1_weight, + const paddle::Tensor& ffn2_weight, + const paddle::optional& ffn1_bias, + const paddle::optional& ffn1_scale, + const paddle::optional& ffn2_scale, + const std::string& quant_method, + paddle::Tensor ffn_out) { + typedef PDTraits traits_; + typedef typename traits_::DataType DataType_; + typedef typename traits_::data_t data_t; + + auto ffn_out_ptr = ffn_out.data(); + auto permuted_input_ptr = permute_input.data(); + auto place = permute_input.place(); + auto input_type = permute_input.dtype(); + auto stream = permute_input.stream(); + + const int expanded_active_expert_rows = permute_input.dims()[0]; // permute_input.dims(): m, k + const int num_experts = ffn1_weight.dims()[0]; // batchsize + const int hidden_size = ffn1_weight.dims()[2]; // n + int inter_dim = ffn1_weight.dims()[1]; // k + + const int64_t inter_size = inter_dim; // since weight_only_int_8 + paddle::Tensor fc1_out_tensor = GetEmptyTensor( + {expanded_active_expert_rows, inter_size}, input_type, place); + auto fc1_out_ptr = fc1_out_tensor.data(); + + mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ROWMAJOR_ORDER; + mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_COLUMNMAJOR_ORDER; + + // ffn1 + auto fc1_expert_biases = + ffn1_bias + ? const_cast(ffn1_bias.get_ptr())->data() + : nullptr; + auto fc1_expert_scales = const_cast(ffn1_scale.get_ptr())->data(); + mc_grouped_gemm_basic_kernel( + reinterpret_cast(permuted_input_ptr), + row_major, + reinterpret_cast(ffn1_weight.data()), + column_major, + reinterpret_cast(fc1_expert_scales), + reinterpret_cast(fc1_expert_biases), + reinterpret_cast(fc1_out_ptr), + row_major, + tokens_expert_prefix_sum.data(), + num_experts, + expanded_active_expert_rows, + inter_dim, + hidden_size, + stream); + + // swiglu + auto act_out_tensor = paddle::experimental::swiglu(fc1_out_tensor, nullptr); + auto act_out = act_out_tensor.data(); + + auto fc2_expert_scales = const_cast(ffn2_scale.get_ptr())->data(); + mc_grouped_gemm_basic_kernel( + reinterpret_cast(act_out), + row_major, + reinterpret_cast(ffn2_weight.data()), + column_major, + reinterpret_cast(fc2_expert_scales), + nullptr, + reinterpret_cast(ffn_out_ptr), + row_major, + tokens_expert_prefix_sum.data(), + num_experts, + expanded_active_expert_rows, + hidden_size, + inter_dim / 2, + stream); +} + +std::vector MoeExpertFFN( + const paddle::Tensor& permute_input, + const paddle::Tensor& tokens_expert_prefix_sum, + const paddle::Tensor& ffn1_weight, + const paddle::Tensor& ffn2_weight, + const paddle::optional& ffn1_bias, + const paddle::optional& ffn1_scale, + const paddle::optional& ffn2_scale, + const std::string& quant_method) { + assert(quant_method == "weight_only_int8"); + const auto input_type = permute_input.dtype(); + auto ffn_out = paddle::empty_like(permute_input); + + switch (input_type) { + case paddle::DataType::BFLOAT16: + McMoeFFNKernel(permute_input, + tokens_expert_prefix_sum, + ffn1_weight, + ffn2_weight, + ffn1_bias, + ffn1_scale, + ffn2_scale, + quant_method, + ffn_out); + break; + // case paddle::DataType::FLOAT16: + // MoeFFNKernel(permute_input, + // tokens_expert_prefix_sum, + // ffn1_weight, + // ffn2_weight, + // ffn1_bias, + // ffn1_scale, + // ffn2_scale, + // quant_method, + // ffn_out); + // break; + default: + PD_THROW("Only support bf16 for MoeExpertFFN"); + 
} + return {ffn_out}; +} + +std::vector> MoeExpertFFNInferShape( + const std::vector& permute_input_shape, + const std::vector& tokens_expert_prefix_sum_shape, + const std::vector& ffn1_weight_shape, + const std::vector& ffn2_weight_shape, + const paddle::optional>& ffn1_bias_shape, + const paddle::optional>& ffn1_scale_shape, + const paddle::optional>& ffn2_scale_shape) { + return {permute_input_shape}; +} + +std::vector MoeExpertFFNInferDtype( + const paddle::DataType& permute_input_dtype, + const paddle::DataType& tokens_expert_prefix_sum_dtype, + const paddle::DataType& ffn1_weight_dtype, + const paddle::DataType& ffn2_weight_dtype, + const paddle::optional& ffn1_bias_dtype, + const paddle::optional& ffn1_scale_dtype, + const paddle::optional& ffn2_scale_dtype) { + return {permute_input_dtype}; +} + +PD_BUILD_OP(moe_expert_ffn) + .Inputs({"permute_input", + "tokens_expert_prefix_sum", + "ffn1_weight", + "ffn2_weight", + paddle::Optional("ffn1_bias"), + paddle::Optional("ffn1_scale"), + paddle::Optional("ffn2_scale")}) + .Outputs({"output_tensor"}) + .Attrs({"quant_method:std::string"}) + .SetKernelFn(PD_KERNEL(MoeExpertFFN)) + .SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNInferDtype)); diff --git a/custom_ops/metax_ops/moe_reduce.cu b/custom_ops/metax_ops/moe_reduce.cu new file mode 100644 index 0000000000..be9e84ce77 --- /dev/null +++ b/custom_ops/metax_ops/moe_reduce.cu @@ -0,0 +1,143 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#pragma once + +#include "helper.h" +#include "fused_moe_helper.h" +#include "fused_moe_op.h" + +template +void MoeReduceKernel(const paddle::Tensor& ffn_out, + const paddle::Tensor& top_k_weight, + const paddle::Tensor& permute_indices_per_token, + const paddle::Tensor& top_k_indices, + const paddle::optional& ffn2_bias, + const bool norm_topk_prob, + const float routed_scaling_factor, + const int num_rows, + const int hidden_size, + const int topk, + paddle::Tensor* output) { + typedef PDTraits traits_; + typedef typename traits_::DataType DataType_; + typedef typename traits_::data_t data_t; + auto stream = ffn_out.stream(); + + finalize_moe_routing_kernelLauncher( + ffn_out.data(), + output->data(), + ffn2_bias ? 
ffn2_bias->data() : nullptr, + top_k_weight.data(), + permute_indices_per_token.data(), + top_k_indices.data(), + num_rows, + hidden_size, + topk, + static_cast(1), + norm_topk_prob, + routed_scaling_factor, + stream); +} + + +std::vector MoeExpertReduce( + const paddle::Tensor& ffn_out, + const paddle::Tensor& top_k_weight, + const paddle::Tensor& permute_indices_per_token, + const paddle::Tensor& top_k_indices, + const paddle::optional& ffn2_bias, + const bool norm_topk_prob, + const float routed_scaling_factor) { + const auto input_type = ffn_out.dtype(); + auto place = ffn_out.place(); + + const int topk = top_k_indices.dims()[1]; + const int num_rows = ffn_out.dims()[0] / topk; + const int hidden_size = ffn_out.dims()[1]; + + auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place); + + // Avoids ‘invalid configuration argument’ when we launch the kernel. + if (ffn_out.dims()[0] == 0) return {output}; + + switch (input_type) { + case paddle::DataType::BFLOAT16: + MoeReduceKernel(ffn_out, + top_k_weight, + permute_indices_per_token, + top_k_indices, + ffn2_bias, + norm_topk_prob, + routed_scaling_factor, + num_rows, + hidden_size, + topk, + &output); + break; + // case paddle::DataType::FLOAT16: + // MoeReduceKernel(ffn_out, + // top_k_weight, + // permute_indices_per_token, + // top_k_indices, + // ffn2_bias, + // norm_topk_prob, + // routed_scaling_factor, + // num_rows, + // hidden_size, + // topk, + // &output); + // break; + default: + PD_THROW("Only support bf16 for MoeDispatchKernel"); + } + return {output}; +} + + +std::vector> MoeExpertReduceInferShape( + const std::vector& ffn_out_shape, + const std::vector& top_k_weight_shape, + const std::vector& permute_indices_per_token_shape, + const std::vector& top_k_indices_shape, + const paddle::optional>& ffn2_bias_shape) { + const int topk = top_k_indices_shape[1]; + std::vector fused_moe_out_shape = {ffn_out_shape[0] / topk, + ffn_out_shape[1]}; + + return {fused_moe_out_shape}; +} + +std::vector MoeExpertReduceInferDtype( + const paddle::DataType& ffn_out_dtype, + const paddle::DataType& top_k_weight_dtype, + const paddle::DataType& permute_indices_per_token_dtype, + const paddle::DataType& top_k_indices_dtype, + const paddle::optional& ffn2_bias_dtype) { + return {ffn_out_dtype}; +} + + +PD_BUILD_OP(moe_expert_reduce) + .Inputs({"ffn_out", + "top_k_weight", + "permute_indices_per_token", + "top_k_indices", + paddle::Optional("ffn2_bias")}) + .Outputs({"output"}) + .Attrs({"norm_topk_prob:bool", "routed_scaling_factor:float"}) + .SetKernelFn(PD_KERNEL(MoeExpertReduce)) + .SetInferShapeFn(PD_INFER_SHAPE(MoeExpertReduceInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertReduceInferDtype)); diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 3ca8c3c3f3..bb3eeeed91 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -595,6 +595,10 @@ def find_end_files(directory, end_str): "gpu_ops/moe/tritonmoe_preprocess.cu", "gpu_ops/moe/moe_topk_select.cu", "gpu_ops/recover_decode_task.cu", + "metax_ops/moe_dispatch.cu", + "metax_ops/moe_ffn.cu", + "metax_ops/moe_reduce.cu", + "metax_ops/fused_moe.cu", ] sources += find_end_files("gpu_ops/speculate_decoding", ".cu") @@ -615,7 +619,7 @@ def find_end_files(directory, end_str): ], }, library_dirs=[os.path.join(maca_path, "lib")], - extra_link_args=["-lruntime_cu"], + extra_link_args=["-lruntime_cu", "-lmctlassEx"], include_dirs=[ os.path.join(maca_path, "include"), os.path.join(maca_path, "include/mcr"), diff --git 
a/docs/get_started/installation/metax_gpu.md b/docs/get_started/installation/metax_gpu.md index a6c71e58d7..eb4ea84c94 100644 --- a/docs/get_started/installation/metax_gpu.md +++ b/docs/get_started/installation/metax_gpu.md @@ -19,8 +19,8 @@ docker login --username=cr_temp_user --password=eyJpbnN0YW5jZUlkIjoiY3JpLXpxYTIz ## 2. paddlepaddle and custom device installation ```shell -1)pip install paddlepaddle==3.0.0.dev20250729 -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/ -2)pip install paddle-metax-gpu==3.0.0.dev20250807 -i https://www.paddlepaddle.org.cn/packages/nightly/maca/ +1)pip install paddlepaddle==3.0.0.dev20250825 -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/ +2)pip install paddle-metax-gpu==3.0.0.dev20250826 -i https://www.paddlepaddle.org.cn/packages/nightly/maca/ ``` ## 3. Build Wheel from Source @@ -47,6 +47,8 @@ from fastdeploy.model_executor.ops.gpu import beam_search_softmax If the above code executes successfully, the environment is ready. ## 5. Demo + +```python from fastdeploy import LLM, SamplingParams prompts = [ @@ -68,7 +70,9 @@ for output in outputs: print(prompt) print(generated_text) print("-" * 50) +``` +``` Output: INFO 2025-08-18 10:54:18,455 416822 engine.py[line:202] Waiting worker processes ready... Loading Weights: 100%|█████████████████████████████████████████████████████████████████████████| 100/100 [03:33<00:00, 2.14s/it] @@ -81,3 +85,4 @@ Generated 1 outputs Hello. My name is Alice and I'm here to help you. What can I do for you today? Hello Alice! I'm trying to organize a small party +``` diff --git a/fastdeploy/model_executor/layers/backends/metax/__init__.py b/fastdeploy/model_executor/layers/backends/metax/__init__.py index 365e50e8b6..568c7d9972 100644 --- a/fastdeploy/model_executor/layers/backends/metax/__init__.py +++ b/fastdeploy/model_executor/layers/backends/metax/__init__.py @@ -13,9 +13,11 @@ # limitations under the License. from .attention.flash_attn_backend import FlashAttentionBackend +from .moe.fused_moe_cutlass_metax_backend import MetaxCutlassWeightOnlyMoEMethod from .moe.fused_moe_triton_metax_backend import MetaxTritonWeightOnlyMoEMethod __all__ = [ "FlashAttentionBackend", "MetaxTritonWeightOnlyMoEMethod", + "MetaxCutlassWeightOnlyMoEMethod", ] diff --git a/fastdeploy/model_executor/layers/backends/metax/attention/flash_attention_interface.py b/fastdeploy/model_executor/layers/backends/metax/attention/flash_attention_interface.py index f7520d2382..c1480170e6 100644 --- a/fastdeploy/model_executor/layers/backends/metax/attention/flash_attention_interface.py +++ b/fastdeploy/model_executor/layers/backends/metax/attention/flash_attention_interface.py @@ -1,3 +1,17 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
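+# Interface helpers wrapping the flash-attention functions used by the Metax
+# attention backend (FlashAttentionBackend).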
+ import os from typing import Optional, Tuple, Union diff --git a/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py index a4993d165b..8b673d23f4 100644 --- a/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py @@ -1,4 +1,3 @@ -""" # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" from __future__ import annotations @@ -261,27 +259,9 @@ def get_splited_qkv( forward_meta: ForwardMeta, cu_seqlens_q: paddle.Tensor, batch_ids=None, - is_decode=False, ): - q_end = self.num_heads * self.head_dim - k_end = q_end + self.kv_num_heads * self.head_dim - v_end = k_end + self.kv_num_heads * self.head_dim - assert v_end == qkv.shape[-1], f"Shape mismatch: {v_end} vs {qkv.shape[-1]}" - assert qkv.shape[0] == cu_seqlens_q[-1], f"Shape mismatch: {qkv.shape[0]} vs {cu_seqlens_q[-1]}" - - if batch_ids is None: - batch_ids = list(range(forward_meta.seq_lens_this_time.shape[0])) - - q = qkv[..., 0:q_end] - k = qkv[..., q_end:k_end] - v = qkv[..., k_end:v_end] - - q = q.view([-1, self.num_heads, self.head_dim]) - k = k.view([-1, self.kv_num_heads, self.head_dim]) - v = v.view([-1, self.kv_num_heads, self.head_dim]) - - if is_decode: - return q, k, v + qkv = qkv.view([-1, self.num_heads + self.kv_num_heads * 2, self.head_dim]) + q, k, v = qkv.split(num_or_sections=[self.num_heads, self.kv_num_heads, self.kv_num_heads], axis=-2) for idx in range(len(cu_seqlens_q) - 1): batch_idx = batch_ids[idx] @@ -375,41 +355,6 @@ def update_kv_cache( cache_start += self.block_size tensor_start = tensor_end - def merge_output(self, prefill_out, decode_out, forward_meta: ForwardMeta): - assert not (prefill_out is None and decode_out is None), "prefill and decode output cannot both be None" - if prefill_out is None: - return decode_out - elif decode_out is None: - return prefill_out - else: - prefill_out = prefill_out - decode_out = decode_out - - merged_output = [] - prefill_tensor_start = 0 - decode_tensor_start = 0 - for seq_lens_this_time in forward_meta.seq_lens_this_time: - if seq_lens_this_time == 0: - continue - if seq_lens_this_time > 1: - tensor_end = prefill_tensor_start + seq_lens_this_time - merged_output.append(prefill_out[prefill_tensor_start:tensor_end, :, :]) - prefill_tensor_start = tensor_end - else: - assert seq_lens_this_time == 1 - tensor_end = decode_tensor_start + seq_lens_this_time - merged_output.append(decode_out[decode_tensor_start:tensor_end, :, :]) - decode_tensor_start = tensor_end - - assert ( - prefill_tensor_start == prefill_out.shape[0] - ), f"prefill merged unfinished: {prefill_tensor_start} vs {prefill_out.shape[0]}" - assert ( - decode_tensor_start == decode_out.shape[0] - ), f"decode merged unfinished: {decode_tensor_start} vs {decode_out.shape[0]}" - merged_output = paddle.concat(merged_output, axis=0) - return merged_output - def forward_prefill(self, prefill_qkv, layer_id, k_cache_id, v_cache_id, forward_meta: ForwardMeta): prefill_q, prefill_k, prefill_v = self.get_splited_qkv( @@ -438,23 +383,17 @@ def forward_prefill(self, prefill_qkv, layer_id, k_cache_id, v_cache_id, forward return 
prefill_out def forward_decode(self, decode_qkv, k_cache_id, v_cache_id, forward_meta: ForwardMeta): - cache_k = forward_meta.caches[k_cache_id] - cache_v = forward_meta.caches[v_cache_id] - cu_seq_lens = list(range(self.decode_len + 1)) - - q, k, v = self.get_splited_qkv(decode_qkv, forward_meta, cu_seq_lens, self.batch_ids_decode, is_decode=True) - decoder_q = q.view([self.decode_len, 1, self.num_heads, self.head_dim]) - decoder_k_ = k.view([self.decode_len, 1, self.kv_num_heads, self.head_dim]) - decoder_v_ = v.view([self.decode_len, 1, self.kv_num_heads, self.head_dim]) + qkv = decode_qkv.view([-1, 1, self.num_heads + self.kv_num_heads * 2, self.head_dim]) + q, k, v = qkv.split(num_or_sections=[self.num_heads, self.kv_num_heads, self.kv_num_heads], axis=-2) decode_out = flash_attn_kvcache_func( - decoder_q, - cache_k, - cache_v, + q, + forward_meta.caches[k_cache_id], + forward_meta.caches[v_cache_id], self.seq_lens_dec, self.block_table_dec, - decoder_k_, - decoder_v_, + k, + v, rotary_cos=forward_meta.rotary_embs[0, 0, :, 0, :].astype("bfloat16"), rotary_sin=forward_meta.rotary_embs[1, 0, :, 0, :].astype("bfloat16"), causal=self.causal, diff --git a/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_cutlass_metax_backend.py b/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_cutlass_metax_backend.py new file mode 100644 index 0000000000..19b2ba8f8d --- /dev/null +++ b/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_cutlass_metax_backend.py @@ -0,0 +1,370 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle import nn +from paddle.nn.quant import weight_quantize + +import fastdeploy +from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce +from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase +from fastdeploy.model_executor.layers.utils import get_tensor +from fastdeploy.model_executor.ops.gpu import fused_expert_moe +from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs + + +class MetaxCutlassMoEMethod(MoEMethodBase): + """ + Use Cutlass Group Gemm to compute Fused MoE. + This method is the oldest way to compute MoE in Paddle. 
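+
+    apply_tp runs the whole MoE through the fused_expert_moe custom op, while
+    compute_ffn exposes only the grouped-GEMM expert FFN (moe_expert_ffn) for
+    callers that dispatch and reduce tokens themselves.
+
+    Rough sketch of that split pipeline, for orientation only; the ops are the
+    ones registered in custom_ops/metax_ops, but the tensor names (x, gate_weight,
+    ffn1_w, ffn2_w, ffn1_scale, ffn2_scale) are placeholders:
+
+        import paddle
+        from fastdeploy.model_executor.ops.gpu import (
+            moe_expert_dispatch, moe_expert_ffn, moe_expert_reduce)
+
+        gate_logits = paddle.matmul(x.cast("float32"), gate_weight)
+        permute_x, prefix_sum, permute_idx, topk_w, topk_ids = moe_expert_dispatch(
+            x, gate_logits, moe_topk, False, False)  # group_moe, topk_only_mode
+        ffn_out = moe_expert_ffn(permute_x, prefix_sum, ffn1_w, ffn2_w,
+                                 None, ffn1_scale, ffn2_scale, "weight_only_int8")
+        out = moe_expert_reduce(ffn_out, topk_w, permute_idx, topk_ids,
+                                None, True, 1.0)  # norm_topk_prob, routed_scaling_factor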
+ """ + + def process_loaded_weights(self, layer: nn.Layer, state_dict): + up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = ( + layer.extract_moe_ffn_weights(state_dict) + ) + stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0) + stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0) + + layer.up_gate_proj_weight.set_value(stacked_up_gate_proj_weights) + layer.down_proj_weight.set_value(stacked_down_proj_weights) + + def compute_ffn( + self, + layer: nn.Layer, + permute_input: paddle.Tensor, + token_nums_per_expert: paddle.Tensor, + expert_idx_per_token: paddle.Tensor, + used_in_ep_low_latency: bool = False, + estimate_total_token_nums: int = -1, + ): + """ + Paddle Cutlass compute Fused MoE. + """ + return fastdeploy.model_executor.ops.gpu.moe_expert_ffn( + permute_input, + token_nums_per_expert, + getattr(layer, self.added_weight_attrs[0]), + getattr(layer, self.added_weight_attrs[1]), + None, + (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None), + (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None), + "weight_only_int8", + ) + + def apply_ep_prefill( + self, + layer: nn.Layer, + x: paddle.Tensor, + gate: nn.Layer, + ) -> paddle.Tensor: + """ + Apply the EP prefill method. + """ + raise NotImplementedError + + def apply_ep_decode( + self, + layer: nn.Layer, + x: paddle.Tensor, + gate: nn.Layer, + ) -> paddle.Tensor: + """ + Apply the EP decoder method. + """ + raise NotImplementedError + + def apply_tp( + self, + layer: nn.Layer, + x: paddle.Tensor, + gate: nn.Layer, + ) -> paddle.Tensor: + """ + Paddle Cutlass compute Fused MoE. + """ + + fused_moe_out = fused_expert_moe( + x, + gate.weight, + getattr(layer, self.added_weight_attrs[0]), + getattr(layer, self.added_weight_attrs[1]), + None, + (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None), + None, + (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None), + "weight_only_int8", + layer.top_k, + True, + False, + ) + if layer.reduce_results and layer.tp_size > 1: + tensor_model_parallel_all_reduce(fused_moe_out) + + return fused_moe_out + + +class MetaxCutlassWeightOnlyMoEMethod(MetaxCutlassMoEMethod): + """ + weight only for moe + """ + + def __init__(self, quant_config=None): + """ + weight only for moe + """ + super().__init__(quant_config) + # print(f"[DEBUG] quant_config: {quant_config}") + self.quant_config = quant_config + self.moe_quant_type = self.quant_config.algo + self.pack_num = 1 + + def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False): + """ + Paddle cutlass process prequanted weights. 
+ """ + up_gate_proj_expert_weight_key = layer.weight_key_map.get("up_gate_proj_expert_weight_key", None) + down_proj_expert_weight_key = layer.weight_key_map.get("down_proj_expert_weight_key", None) + up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None) + down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None) + + up_gate_proj_weights, down_proj_weights, logical_expert_ids, _ = layer.load_experts_weight( + state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key, is_rearrange + ) + # self.check(layer, up_gate_proj_weights, down_proj_weights) + up_gate_proj_weight_scale = [] + down_proj_weight_scale = [] + for expert_idx in logical_expert_ids: + up_gate_proj_weight_scale.append( + get_tensor(state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx))) + ) + down_proj_weight_scale.append( + get_tensor(state_dict.pop(down_proj_expert_weight_scale_key.format(expert_idx))) + ) + + up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0) + down_proj_weight = paddle.stack(down_proj_weights, axis=0) + up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0) + down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0) + + name_tensor_map = { + "up_gate_proj_weight": up_gate_proj_weight, + "down_proj_weight": down_proj_weight, + "up_gate_proj_weight_scale": up_gate_proj_weight_scale, + "down_proj_weight_scale": down_proj_weight_scale, + } + for name, tensor in name_tensor_map.items(): + getattr(layer, name).set_value(tensor) + + def create_weights(self, layer: nn.Layer, **extra_weight_attrs): + """ + Paddle cutlass create weight process. + """ + self.default_dtype = layer._helper.get_default_dtype() + if self.moe_quant_type == "weight_only_int4": + self.up_gate_proj_weight_shape = [ + layer.num_local_experts, + layer.moe_intermediate_size, + layer.hidden_size, + ] + else: + self.up_gate_proj_weight_shape = [ + layer.num_local_experts, + layer.moe_intermediate_size * 2, + layer.hidden_size, + ] + if self.moe_quant_type == "weight_only_int4": + self.down_proj_weight_shape = [ + layer.num_local_experts, + layer.hidden_size // 2, + layer.moe_intermediate_size, + ] + else: + self.down_proj_weight_shape = [ + layer.num_local_experts, + layer.hidden_size, + layer.moe_intermediate_size, + ] + self.up_gate_proj_scale_shape = [layer.num_local_experts, layer.moe_intermediate_size * 2] + self.down_proj_scale_shape = [layer.num_local_experts, layer.hidden_size] + + if layer.fd_config.load_config.load_choices == "default_v1": + layer.up_gate_proj_weight = layer.create_parameter( + shape=[layer.num_experts, layer.hidden_size, layer.moe_intermediate_size * 2], + dtype=layer.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ) + + layer.down_proj_weight = layer.create_parameter( + shape=[layer.num_experts, layer.moe_intermediate_size, layer.hidden_size], + dtype=layer.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ) + + set_weight_attrs( + layer.up_gate_proj_weight, + { + **extra_weight_attrs, + "tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True), + }, + ) + set_weight_attrs( + layer.down_proj_weight, + { + **extra_weight_attrs, + "tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False), + }, + ) + else: + self.weight_dtype = "int8" + + up_gate_proj_weight_name = self.added_weight_attrs[0] + down_proj_weight_name = self.added_weight_attrs[1] 
+ up_gate_proj_scale_name = self.added_scale_attrs[0] + down_proj_scale_name = self.added_scale_attrs[1] + + setattr( + layer, + up_gate_proj_weight_name, + layer.create_parameter( + shape=self.up_gate_proj_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + down_proj_weight_name, + layer.create_parameter( + shape=self.down_proj_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + # weight_scale + setattr( + layer, + up_gate_proj_scale_name, + layer.create_parameter( + shape=self.up_gate_proj_scale_shape, + dtype=self.default_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + down_proj_scale_name, + layer.create_parameter( + shape=self.down_proj_scale_shape, + dtype=self.default_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + + moe_extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}} + set_weight_attrs(layer.up_gate_proj_weight, moe_extra_weight_attrs) + set_weight_attrs(layer.down_proj_weight, moe_extra_weight_attrs) + scale_extra_weight_attrs = { + **extra_weight_attrs, + "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "up": 0, "down": None}, + } + set_weight_attrs(layer.up_gate_proj_weight_scale, scale_extra_weight_attrs) + set_weight_attrs(layer.down_proj_weight_scale, scale_extra_weight_attrs) + + def process_weights_after_loading(self, layer): + """ """ + if not layer.fd_config.load_config.load_choices == "default_v1": + return + weight_id_map = {"gate_up": 0, "down": 1} + if ( + hasattr(layer.up_gate_proj_weight, "tensor_track") + and layer.up_gate_proj_weight.tensor_track is not None + and layer.up_gate_proj_weight.tensor_track.is_fully_copied() + ): + weight_type = "gate_up" + else: + weight_type = "down" + + # 1.init shape and type + # weight + weight_name = self.added_weight_attrs[weight_id_map[weight_type]] + unquantized_weight_name = weight_name.replace("quant_weight", "weight") + weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape + weight_dtype = "int8" + # scale + scale_name = self.added_scale_attrs[weight_id_map[weight_type]] + scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape + scale_dtype = self.default_dtype + + # 2.crate tmp tensor + + weight = paddle.empty(weight_shape, dtype=weight_dtype) + scale = paddle.empty(scale_shape, dtype=scale_dtype) + + # 3.quantize weight + + for expert_id in range(layer.num_experts): + weight[expert_id], scale[expert_id] = weight_quantize( + getattr(layer, unquantized_weight_name)[expert_id], algo=self.moe_quant_type, arch=80, group_size=-1 + ) + + free_tensor(getattr(layer, unquantized_weight_name)) + + # create weight + setattr( + layer, + weight_name, + layer.create_parameter( + shape=weight_shape, + dtype=weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + # create scale + setattr( + layer, + scale_name, + layer.create_parameter( + shape=scale_shape, + dtype=scale_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + getattr(layer, weight_name).copy_(weight, False) + getattr(layer, scale_name).copy_(scale, False) + + def process_loaded_weights(self, layer: nn.Layer, state_dict): + """ + Paddle cutlass load weight process. 
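+
+        Quantizes each local expert's weights with paddle.nn.quant.weight_quantize
+        using the configured weight-only algorithm, transposes them into the layout
+        expected by the grouped GEMM, and stores the stacked weights and scales on
+        the layer.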
+ """ + up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict) + self.check(layer, up_gate_proj_weights, down_proj_weights) + for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]): + weight_name = self.added_weight_attrs[idx] + scale_name = self.added_scale_attrs[idx] + + weight_list = [] + weight_scale_list = [] + for i in range(layer.num_local_experts): + quant_weight, scale = weight_quantize( + weight_tensor[i], algo=self.moe_quant_type, arch=80, group_size=-1 + ) + quant_weight = paddle.transpose(quant_weight, [1, 0]) + weight_list.append(quant_weight) + weight_scale_list.append(scale) + quanted_weight = paddle.stack(weight_list, axis=0) + getattr(layer, weight_name).set_value(quanted_weight) + + quanted_weight_scale = paddle.stack(weight_scale_list, axis=0) + getattr(layer, scale_name).set_value(quanted_weight_scale) diff --git a/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_triton_metax_backend.py b/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_triton_metax_backend.py index 907ddff65f..4e4e867899 100644 --- a/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_triton_metax_backend.py +++ b/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_triton_metax_backend.py @@ -1,4 +1,3 @@ -""" # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,12 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" import paddle from paddle import nn import fastdeploy +from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess from fastdeploy.utils import ceil_div @@ -153,7 +152,6 @@ def apply( Triton compute Fused MoE. """ token_num = x.shape[0] - top_k = layer.top_k num_local_experts = layer.num_local_experts top_k = layer.top_k moe_intermediate_size = layer.moe_intermediate_size @@ -172,21 +170,12 @@ def apply( dtype=x.dtype, ) - if self.quant_config is not None: - config = { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - } - else: - config = { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - } - + config = { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + } sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess( topk_ids, num_local_experts, config["BLOCK_SIZE_M"] ) @@ -292,4 +281,6 @@ def apply( down_proj_out.reshape_([token_num, top_k, hidden_size]) out = down_proj_out.sum(axis=1) + if layer.tp_size > 1: + tensor_model_parallel_all_reduce(out) return out diff --git a/fastdeploy/model_executor/layers/backends/metax/moe/triton_moe_kernels.py b/fastdeploy/model_executor/layers/backends/metax/moe/triton_moe_kernels.py index e859e7ce45..a359330c55 100644 --- a/fastdeploy/model_executor/layers/backends/metax/moe/triton_moe_kernels.py +++ b/fastdeploy/model_executor/layers/backends/metax/moe/triton_moe_kernels.py @@ -1,5 +1,4 @@ -""" -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" import triton import triton.language as tl diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 58c87cf339..661031f513 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -50,12 +50,7 @@ def get_moe_method(): from fastdeploy.model_executor.layers.backends import GCUFusedMoeMethod return GCUFusedMoeMethod(None) - elif current_platform.is_maca(): - from fastdeploy.model_executor.layers.backends import ( - MetaxTritonWeightOnlyMoEMethod, - ) - return MetaxTritonWeightOnlyMoEMethod(None) elif current_platform.is_intel_hpu(): from fastdeploy.model_executor.layers.backends import HpuMoEMethod diff --git a/fastdeploy/model_executor/layers/quantization/weight_only.py b/fastdeploy/model_executor/layers/quantization/weight_only.py index 070d0fbf41..e2ea626612 100644 --- a/fastdeploy/model_executor/layers/quantization/weight_only.py +++ b/fastdeploy/model_executor/layers/quantization/weight_only.py @@ -122,10 +122,18 @@ def get_quant_method(self, layer) -> Optional[QuantMethodBase]: elif current_platform.is_maca(): if isinstance(layer, FusedMoE): from fastdeploy.model_executor.layers.backends import ( + MetaxCutlassWeightOnlyMoEMethod, MetaxTritonWeightOnlyMoEMethod, ) - return MetaxTritonWeightOnlyMoEMethod(self) + if layer.use_method == "cutlass": + + return MetaxCutlassWeightOnlyMoEMethod(self) + elif layer.use_method == "triton": + + return MetaxTritonWeightOnlyMoEMethod(self) + else: + raise ValueError(f"Unsupported MOE backend {layer.use_method}") else: return GPUWeightOnlyLinearMethod(self) diff --git a/requirements_metaxgpu.txt b/requirements_metaxgpu.txt index 7aa310fa23..26f6de0954 100644 --- a/requirements_metaxgpu.txt +++ b/requirements_metaxgpu.txt @@ -8,9 +8,9 @@ aiozmq openai>=1.93.0 tqdm pynvml -uvicorn +uvicorn==0.29.0 fastapi -paddleformers +paddleformers>=0.2 redis etcd3 httpx @@ -30,11 +30,12 @@ use-triton-in-paddle crcmod fastsafetensors==0.1.14 msgpack +modelscope opentelemetry-api>=1.24.0 opentelemetry-sdk>=1.24.0 opentelemetry-instrumentation-redis opentelemetry-instrumentation-mysql -opentelemetry-distro  +opentelemetry-distro opentelemetry-exporter-otlp opentelemetry-instrumentation-fastapi partial_json_parser