Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions custom_ops/gpu_ops/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

#pragma once

#include <cuda_fp8.h>

#ifndef PADDLE_WITH_COREX
#include "glog/logging.h"
#endif
Expand Down
181 changes: 181 additions & 0 deletions custom_ops/metax_ops/fused_moe.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


#pragma once

#include "helper.h"
#include "mc_fused_moe_helper.h"
#include "fused_moe_op.h"

__global__ void compute_total_rows_before_expert_kernel(
int* sorted_experts,
const int64_t sorted_experts_len,
const int64_t num_experts,
int32_t* total_rows_before_expert) {
const int expert = blockIdx.x * blockDim.x + threadIdx.x;
if (expert >= num_experts) return;

total_rows_before_expert[expert] =
find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert);
}

void compute_total_rows_before_expert(int* sorted_indices,
const int64_t total_indices,
const int64_t num_experts,
int32_t* total_rows_before_expert,
cudaStream_t stream) {
const int threads = std::min(int64_t(1024), num_experts);
const int blocks = (num_experts + threads - 1) / threads;

compute_total_rows_before_expert_kernel<<<blocks, threads, 0, stream>>>(
sorted_indices, total_indices, num_experts, total_rows_before_expert);
}

template <paddle::DataType T, typename ElementA, typename ElementB, typename ElementC>
void FusedMoeKernel(const paddle::Tensor& input,
const paddle::Tensor& gate_weight,
const paddle::Tensor& ffn1_weight,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::optional<paddle::Tensor>& ffn2_bias,
const std::string& quant_method,
const int moe_topk,
const bool group_moe,
const bool norm_topk_prob,
paddle::Tensor* output) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;

auto* output_data = output->data<data_t>();

auto moe_compute = McMoeHelper<data_t, ElementA, ElementB, ElementC>(quant_method);

moe_compute.computeFFN(
&input,
&gate_weight,
&ffn1_weight,
ffn1_scale ? ffn1_scale.get_ptr() : nullptr,
ffn1_bias ? ffn1_bias.get_ptr() : nullptr,
&ffn2_weight,
ffn2_scale ? ffn2_scale.get_ptr() : nullptr,
ffn2_bias ? ffn2_bias.get_ptr() : nullptr,
nullptr,
moe_topk,
group_moe,
norm_topk_prob,
1.0, // ComputeFFN
"ffn",
output);
}


std::vector<paddle::Tensor> FusedExpertMoe(
const paddle::Tensor& input,
const paddle::Tensor& gate_weight,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_bias,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const std::string& quant_method,
const int moe_topk,
const bool norm_topk_prob,
const bool group_moe) {
const auto input_type = input.dtype();
auto output = paddle::empty_like(input);

switch (input_type) {
case paddle::DataType::BFLOAT16:
FusedMoeKernel<paddle::DataType::BFLOAT16, maca_bfloat16, int8_t, maca_bfloat16>(input,
gate_weight,
ffn1_weight,
ffn1_scale,
ffn1_bias,
ffn2_weight,
ffn2_scale,
ffn2_bias,
quant_method,
moe_topk,
group_moe,
norm_topk_prob,
&output);
break;
// case paddle::DataType::FLOAT16:
// FusedMoeKernel<paddle::DataType::FLOAT16>(input,
// gate_weight,
// ffn1_weight,
// ffn1_scale,
// ffn1_bias,
// ffn2_weight,
// ffn2_scale,
// ffn2_bias,
// quant_method,
// moe_topk,
// group_moe,
// norm_topk_prob,
// &output);
// break;
default:
PD_THROW("Only support bf16 for FusedMoeKernel");
}
return {output};
}

std::vector<std::vector<int64_t>> FusedExpertMoeInferShape(
const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& gate_weight_shape,
const std::vector<int64_t>& ffn1_weight_shape,
const std::vector<int64_t>& ffn2_weight_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_bias_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_bias_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_scale_shape) {
return {input_shape};
}

std::vector<paddle::DataType> FusedExpertMoeInferDtype(
const paddle::DataType& input_dtype,
const paddle::DataType& gate_weight_dtype,
const paddle::DataType& ffn1_weight_dtype,
const paddle::DataType& ffn2_weight_dtype,
const paddle::optional<paddle::DataType>& ffn1_bias_dtype,
const paddle::optional<paddle::DataType>& ffn1_scale_dtype,
const paddle::optional<paddle::DataType>& ffn2_bias_dtype,
const paddle::optional<paddle::DataType>& ffn2_scale_dtype) {
return {input_dtype};
}


PD_BUILD_OP(fused_expert_moe)
.Inputs({"input",
"gate_weight",
"ffn1_weight",
"ffn2_weight",
paddle::Optional("ffn1_bias"),
paddle::Optional("ffn1_scale"),
paddle::Optional("ffn2_bias"),
paddle::Optional("ffn2_scale")})
.Outputs({"output"})
.Attrs({"quant_method:std::string",
"moe_topk:int",
"norm_topk_prob:bool",
"group_moe:bool"})
.SetKernelFn(PD_KERNEL(FusedExpertMoe))
.SetInferShapeFn(PD_INFER_SHAPE(FusedExpertMoeInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(FusedExpertMoeInferDtype));
53 changes: 53 additions & 0 deletions custom_ops/metax_ops/fused_moe_helper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h"
#include "fused_moe_op.h"

using namespace phi;

template <typename T, int VecSize>
__global__ void moe_token_type_ids_kernel(T *gating_output,
const int *moe_token_type_ids_out,
const int num_rows,
const int num_experts,
const int k) {
const int moe_token_index = blockIdx.x * blockDim.x + threadIdx.x;

if (moe_token_index >= num_rows) {
return;
}

gating_output[moe_token_index * 2] =
gating_output[moe_token_index * 2] +
(moe_token_type_ids_out[moe_token_index]) * -1e10;
gating_output[moe_token_index * 2 + 1] =
gating_output[moe_token_index * 2 + 1] +
(1 - moe_token_type_ids_out[moe_token_index]) * -1e10;
}

template <typename T>
void moe_token_type_ids_kernelLauncher(T *gating_output,
const int *moe_token_type_ids_out,
const int num_rows,
const int num_experts,
const int k,
cudaStream_t stream) {
const int blocks = num_rows * k / 512 + 1;
const int threads = 512;
moe_token_type_ids_kernel<T, 1><<<blocks, 512, 0, stream>>>(
gating_output, moe_token_type_ids_out, num_rows, num_experts, k);
}
123 changes: 123 additions & 0 deletions custom_ops/metax_ops/fused_moe_imp_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include <string>
#include <sstream>
#include "cub/cub.cuh"

static const float HALF_FLT_MAX = 65504.F;
static const float HALF_FLT_MIN = -65504.F;
static inline size_t AlignTo16(const size_t& input) {
static constexpr int ALIGNMENT = 16;
return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT);
}

class CubKeyValueSorter {
public:
CubKeyValueSorter() : num_experts_(0), num_bits_(sizeof(int) * 8) {}

explicit CubKeyValueSorter(const int num_experts)
: num_experts_(num_experts),
num_bits_(static_cast<int>(log2(num_experts)) + 1) {}

void update_num_experts(const int num_experts) {
num_experts_ = num_experts;
num_bits_ = static_cast<int>(log2(num_experts)) + 1;
}

size_t getWorkspaceSize(const size_t num_key_value_pairs,
bool descending = false) {
num_key_value_pairs_ = num_key_value_pairs;
size_t required_storage = 0;
int* null_int = nullptr;
if (descending) {
cub::DeviceRadixSort::SortPairsDescending(NULL,
required_storage,
null_int,
null_int,
null_int,
null_int,
num_key_value_pairs,
0,
32);
} else {
cub::DeviceRadixSort::SortPairs(NULL,
required_storage,
null_int,
null_int,
null_int,
null_int,
num_key_value_pairs,
0,
num_bits_);
}
return required_storage;
}

template <typename KeyT>
void run(void* workspace,
const size_t workspace_size,
const KeyT* keys_in,
KeyT* keys_out,
const int* values_in,
int* values_out,
const size_t num_key_value_pairs,
bool descending,
cudaStream_t stream) {
size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs);
size_t actual_ws_size = workspace_size;

if (expected_ws_size > workspace_size) {
std::stringstream err_ss;
err_ss << "[Error][CubKeyValueSorter::run]\n";
err_ss << "Error. The allocated workspace is too small to run this "
"problem.\n";
err_ss << "Expected workspace size of at least " << expected_ws_size
<< " but got problem size " << workspace_size << "\n";
throw std::runtime_error(err_ss.str());
}
if (descending) {
cub::DeviceRadixSort::SortPairsDescending(workspace,
actual_ws_size,
keys_in,
keys_out,
values_in,
values_out,
num_key_value_pairs,
0,
32,
stream);
} else {
cub::DeviceRadixSort::SortPairs(workspace,
actual_ws_size,
keys_in,
keys_out,
values_in,
values_out,
num_key_value_pairs,
0,
num_bits_,
stream);
}
}

private:
size_t num_key_value_pairs_;
int num_experts_;
int num_bits_;
};
Loading
Loading