diff --git a/src/infiniop/ops/topkrouter/cuda/kernel.cuh b/src/infiniop/ops/topkrouter/cuda/kernel.cuh index 0832c5b93..0e1578b50 100644 --- a/src/infiniop/ops/topkrouter/cuda/kernel.cuh +++ b/src/infiniop/ops/topkrouter/cuda/kernel.cuh @@ -6,16 +6,16 @@ #include #include #include -#include -#include -#include +// #include +// #include +// #include template inline __device__ float exp_func(T x) { float data; if constexpr (std::is_same_v) { data = x; - } else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v) { data = __bfloat162float(x); } else if constexpr (std::is_same_v) { data = __half2float(x); diff --git a/src/infiniop/ops/topkrouter/metax/topkrouter_metax.h b/src/infiniop/ops/topkrouter/metax/topkrouter_metax.h new file mode 100644 index 000000000..62f17dc6c --- /dev/null +++ b/src/infiniop/ops/topkrouter/metax/topkrouter_metax.h @@ -0,0 +1,8 @@ +#ifndef __TOPKROUTER_METAX_H__ +#define __TOPKROUTER_METAX_H__ + +#include "../topkrouter.h" + +DESCRIPTOR(metax) + +#endif diff --git a/src/infiniop/ops/topkrouter/metax/topkrouter_metax.maca b/src/infiniop/ops/topkrouter/metax/topkrouter_metax.maca new file mode 100644 index 000000000..71c2d37d6 --- /dev/null +++ b/src/infiniop/ops/topkrouter/metax/topkrouter_metax.maca @@ -0,0 +1,93 @@ +#include "../../../devices/metax/metax_common.h" +#include "../../../devices/metax/metax_kernel_common.h" +#include "../cuda/kernel.cuh" +#include "topkrouter_metax.h" +#include + +namespace op::topkrouter::metax { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t correction_bias_desc) { + auto result = TopkrouterInfo::create(x_desc); + CHECK_RESULT(result); + auto info = result.take(); + + if (info.x_strides[1] != 1) { + return INFINI_STATUS_BAD_TENSOR_STRIDES; + } + + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + std::move(info), + 0, + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +namespace { + +template +infiniStatus_t launch_topkrouter(float *d_values_out, int *d_indices_out, const void *d_input, const float *d_correction_bias, + const float routed_scaling_factor, const size_t N, const size_t width, const size_t topk, infiniDtype_t xtype, + hcStream_t stream) { + const int block_threads = BLOCK_SIZE; + dim3 blocks(N); + dim3 threads(block_threads); + + if (xtype == INFINI_DTYPE_F32) { + topkrouter_kernel<<>>(d_values_out, d_indices_out, (float *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk); + } else if (xtype == INFINI_DTYPE_F16) { + topkrouter_kernel<<>>(d_values_out, d_indices_out, (half *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk); + } else if (xtype == INFINI_DTYPE_BF16) { + topkrouter_kernel<<>>(d_values_out, d_indices_out, (cuda_bfloat16 *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} + +}; // namespace + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + float *values, + int *indices, + const void *x, + const float *correction_bias, + const float routed_scaling_factor, + const size_t topk, + void *stream) const { + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + size_t N = _info.N; + size_t width = _info.width; // 256 + + // size_t n_routed_experts = 256; + // size_t n_group = 8; + // size_t topk_group = 4; + auto cuda_stream = reinterpret_cast(stream); + + if (256 == width) { + launch_topkrouter<256>(values, indices, x, correction_bias, routed_scaling_factor, N, width, topk, _info.xtype, cuda_stream); + } else { + return INFINI_STATUS_BAD_PARAM; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::topkrouter::metax diff --git a/src/infiniop/ops/topkrouter/operator.cc b/src/infiniop/ops/topkrouter/operator.cc index 89555e9f9..73b6e9bcf 100644 --- a/src/infiniop/ops/topkrouter/operator.cc +++ b/src/infiniop/ops/topkrouter/operator.cc @@ -8,6 +8,9 @@ #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) #include "nvidia/topkrouter_nvidia.cuh" #endif +#ifdef ENABLE_METAX_API +#include "metax/topkrouter_metax.h" +#endif #ifdef ENABLE_KUNLUN_API #include "kunlun/topkrouter_kunlun.h" #endif @@ -30,6 +33,9 @@ __C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, i #ifdef ENABLE_QY_API CREATE(INFINI_DEVICE_QY, nvidia); #endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); +#endif #ifdef ENABLE_KUNLUN_API CREATE(INFINI_DEVICE_KUNLUN, kunlun); #endif @@ -56,6 +62,9 @@ __C infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescript #ifdef ENABLE_QY_API GET(INFINI_DEVICE_QY, nvidia); #endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); +#endif #ifdef ENABLE_KUNLUN_API GET(INFINI_DEVICE_KUNLUN, kunlun); #endif @@ -85,6 +94,9 @@ __C infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc, void #ifdef ENABLE_QY_API CALCULATE(INFINI_DEVICE_QY, nvidia); #endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); +#endif #ifdef ENABLE_KUNLUN_API CALCULATE(INFINI_DEVICE_KUNLUN, kunlun); #endif @@ -111,6 +123,9 @@ __C infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescrip #ifdef ENABLE_QY_API DESTROY(INFINI_DEVICE_QY, nvidia); #endif +#ifdef ENABLE_METAX_API + DESTROY(INFINI_DEVICE_METAX, metax); +#endif #ifdef ENABLE_KUNLUN_API DESTROY(INFINI_DEVICE_KUNLUN, kunlun); #endif