From 5ffd00ea94d1499e55282f02ed3689102758f865 Mon Sep 17 00:00:00 2001 From: zhushuang <974198603@qq.com> Date: Wed, 3 Sep 2025 19:48:19 +0800 Subject: [PATCH] feat: implement GEMM with MUBLAS and MUDNN backends in moore gpu --- src/infiniop/ops/gemm/moore/gemm_moore.h | 100 ++++++++- .../ops/gemm/moore/mublas/gemm_mublas.h | 8 + .../{gemm_moore.mu => mublas/gemm_mublas.mu} | 10 +- .../ops/gemm/moore/mudnn/gemm_mudnn.h | 8 + .../ops/gemm/moore/mudnn/gemm_mudnn.mu | 198 ++++++++++++++++++ xmake/moore.lua | 3 + 6 files changed, 320 insertions(+), 7 deletions(-) create mode 100644 src/infiniop/ops/gemm/moore/mublas/gemm_mublas.h rename src/infiniop/ops/gemm/moore/{gemm_moore.mu => mublas/gemm_mublas.mu} (95%) create mode 100644 src/infiniop/ops/gemm/moore/mudnn/gemm_mudnn.h create mode 100644 src/infiniop/ops/gemm/moore/mudnn/gemm_mudnn.mu diff --git a/src/infiniop/ops/gemm/moore/gemm_moore.h b/src/infiniop/ops/gemm/moore/gemm_moore.h index 1fe0e8171..03f399fe6 100644 --- a/src/infiniop/ops/gemm/moore/gemm_moore.h +++ b/src/infiniop/ops/gemm/moore/gemm_moore.h @@ -1,8 +1,104 @@ #ifndef __GEMM_MOORE_H__ #define __GEMM_MOORE_H__ -#include "../gemm.h" +#include "mublas/gemm_mublas.h" +#include "mudnn/gemm_mudnn.h" -DESCRIPTOR(moore) +namespace op::gemm::moore { + +// Descriptor class for GEMM operations on Moore devices. +// This class acts as a wrapper to select either mublas or mudnn backend. +// It encapsulates the backend-specific Descriptor implementation and provides +// a unified interface for workspace query and GEMM calculation. +class Descriptor final : public InfiniopDescriptor { +public: + // Destructor: deletes the backend-specific descriptor. + ~Descriptor() { + if (_backend == Backend::MUBLAS) { + delete reinterpret_cast(_impl); + } else { + delete reinterpret_cast(_impl); + } + } + + // Returns the required workspace size for the GEMM operation. 
+ size_t workspaceSize() const { + if (_backend == Backend::MUBLAS) { + return reinterpret_cast<mublas::Descriptor *>(_impl)->workspaceSize(); + } else { + return reinterpret_cast<mudnn::Descriptor *>(_impl)->workspaceSize(); + } + } + + // Static factory method to create a Descriptor instance. + // This method chooses the backend (mublas or mudnn) and constructs + // the corresponding implementation internally. + static infiniStatus_t create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc) { + auto desc = new Descriptor(handle->device, handle->device_id); + + // Backend selection strategy: + // Currently defaulting to MUDNN. + // Can be modified to choose based on environment variables or runtime parameters. + desc->_backend = Backend::MUDNN; + + if (desc->_backend == Backend::MUBLAS) { + mublas::Descriptor *impl; + auto status = mublas::Descriptor::create(handle, &impl, c_desc, a_desc, b_desc); + if (status != INFINI_STATUS_SUCCESS) { + delete desc; + return status; + } + desc->_impl = impl; + } else { + mudnn::Descriptor *impl; + auto status = mudnn::Descriptor::create(handle, &impl, c_desc, a_desc, b_desc); + if (status != INFINI_STATUS_SUCCESS) { + delete desc; + return status; + } + desc->_impl = impl; + } + + *desc_ptr = desc; + return INFINI_STATUS_SUCCESS; + } + + // Unified GEMM calculation interface. + // Calls the corresponding backend's calculate function internally. + infiniStatus_t calculate( + void *workspace, size_t workspace_size, + void *c, float beta, + const void *a, const void *b, + float alpha, + void *stream) const { + if (_backend == Backend::MUBLAS) { + return reinterpret_cast<mublas::Descriptor *>(_impl) + ->calculate(workspace, workspace_size, c, beta, a, b, alpha, stream); + } else { + return reinterpret_cast<mudnn::Descriptor *>(_impl) + ->calculate(workspace, workspace_size, c, beta, a, b, alpha, stream); + } + } + +private: + // Private constructor: ensures users cannot directly instantiate Descriptor.
+ // Instances must be created via the static create() factory method. + Descriptor(infiniDevice_t device_type, int device_id) + : InfiniopDescriptor{device_type, device_id}, _impl(nullptr) {} + + // Enum to indicate which backend is being used internally. + enum class Backend { MUBLAS, + MUDNN }; + + Backend _backend; // Currently selected MUBLAS/MUDNN backend + void *_impl; // Pointer to backend-specific descriptor (mublas::Descriptor* or mudnn::Descriptor*) +}; + +} // namespace op::gemm::moore #endif // __GEMM_MOORE_H__ diff --git a/src/infiniop/ops/gemm/moore/mublas/gemm_mublas.h b/src/infiniop/ops/gemm/moore/mublas/gemm_mublas.h new file mode 100644 index 000000000..f09c1f909 --- /dev/null +++ b/src/infiniop/ops/gemm/moore/mublas/gemm_mublas.h @@ -0,0 +1,8 @@ +#ifndef __GEMM_MUBLAS_H__ +#define __GEMM_MUBLAS_H__ + +#include "../../gemm.h" + +DESCRIPTOR(mublas) + +#endif // __GEMM_MUBLAS_H__ diff --git a/src/infiniop/ops/gemm/moore/gemm_moore.mu b/src/infiniop/ops/gemm/moore/mublas/gemm_mublas.mu similarity index 95% rename from src/infiniop/ops/gemm/moore/gemm_moore.mu rename to src/infiniop/ops/gemm/moore/mublas/gemm_mublas.mu index 2d22f3720..e52b72acb 100644 --- a/src/infiniop/ops/gemm/moore/gemm_moore.mu +++ b/src/infiniop/ops/gemm/moore/mublas/gemm_mublas.mu @@ -1,8 +1,8 @@ -#include "../../../devices/moore/moore_common.h" -#include "../../../devices/moore/moore_handle.h" -#include "gemm_moore.h" +#include "../../../../devices/moore/moore_common.h" +#include "../../../../devices/moore/moore_handle.h" +#include "gemm_mublas.h" -namespace op::gemm::moore { +namespace op::gemm::mublas { struct Descriptor::Opaque { std::shared_ptr internal; @@ -122,4 +122,4 @@ infiniStatus_t Descriptor::calculate( return INFINI_STATUS_SUCCESS; } -} // namespace op::gemm::moore +} // namespace op::gemm::mublas diff --git a/src/infiniop/ops/gemm/moore/mudnn/gemm_mudnn.h b/src/infiniop/ops/gemm/moore/mudnn/gemm_mudnn.h new file mode 100644 index 000000000..30a2c707d --- 
/dev/null +++ b/src/infiniop/ops/gemm/moore/mudnn/gemm_mudnn.h @@ -0,0 +1,8 @@ +#ifndef __GEMM_MUDNN_H__ +#define __GEMM_MUDNN_H__ + +#include "../../gemm.h" + +DESCRIPTOR(mudnn) + +#endif // __GEMM_MUDNN_H__ diff --git a/src/infiniop/ops/gemm/moore/mudnn/gemm_mudnn.mu b/src/infiniop/ops/gemm/moore/mudnn/gemm_mudnn.mu new file mode 100644 index 000000000..12acd7fc5 --- /dev/null +++ b/src/infiniop/ops/gemm/moore/mudnn/gemm_mudnn.mu @@ -0,0 +1,198 @@ +#include "../../../../devices/moore/moore_common.h" +#include "../../../../devices/moore/moore_handle.h" +#include "gemm_mudnn.h" + +#include + +namespace op::gemm::mudnn { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc) { + auto handle = reinterpret_cast(handle_); + auto dtype = c_desc->dtype(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16); + + auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::ROW_MAJOR); + CHECK_RESULT(result); + + *desc_ptr = new Descriptor( + dtype, result.take(), 0, + new Opaque{handle->internal()}, + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t calculate( + const MatmulInfo &info, + std::shared_ptr &_internal, + void *c, + float beta, + const void *a, + const void *b, + float alpha, + void *stream) +{ + // 0. For muDNN development, refer to the official documentation and the following headers: + // - /usr/local/musa/include/mudnn_base.h + // - /usr/local/musa/include/mudnn_math.h + // - /usr/local/musa/include/mudnn.h + + // 1. Create BatchMatMul operator + auto matmul_operator = std::make_unique<::musa::dnn::BatchMatMul>(); + matmul_operator->SetComputeMode(::musa::dnn::BatchMatMul::ComputeMode::TENSOR); + + // 2. 
Use _internal->useMudnn to manage muDNN handle + return _internal->useMudnn((musaStream_t)stream, [&](::musa::dnn::Handle &mudnn_handle) -> infiniStatus_t { + + // 3. Create BatchMatMul Tensor + ::musa::dnn::Tensor out, left, right; + + if constexpr (std::is_same<Tdata, half>::value) { + out.SetType(::musa::dnn::Tensor::Type::HALF); + left.SetType(::musa::dnn::Tensor::Type::HALF); + right.SetType(::musa::dnn::Tensor::Type::HALF); + } + else if constexpr (std::is_same<Tdata, __mt_bfloat16>::value){ + out.SetType(::musa::dnn::Tensor::Type::BFLOAT16); + left.SetType(::musa::dnn::Tensor::Type::BFLOAT16); + right.SetType(::musa::dnn::Tensor::Type::BFLOAT16); + } + else{ + out.SetType(::musa::dnn::Tensor::Type::FLOAT); + left.SetType(::musa::dnn::Tensor::Type::FLOAT); + right.SetType(::musa::dnn::Tensor::Type::FLOAT); + } + + // 4. Bind BatchMatMul Tensor addr + out.SetAddr(c); + left.SetAddr(a); + right.SetAddr(b); + + // 5. Config Tensor left + std::array<int64_t, 3> a_dims_array; + std::array<int64_t, 3> a_stride_array; + if (info.a_matrix.col_stride != 1) { + a_dims_array = { static_cast<int64_t>(info.batch), + static_cast<int64_t>(info.k), + static_cast<int64_t>(info.m) }; + } else { + a_dims_array = { static_cast<int64_t>(info.batch), + static_cast<int64_t>(info.m), + static_cast<int64_t>(info.k) }; + } + a_stride_array = { static_cast<int64_t>(info.a_matrix.stride), + static_cast<int64_t>(info.a_matrix.ld()), + 1 }; + left.SetNdInfo(static_cast<int>(a_dims_array.size()), a_dims_array.data(), a_stride_array.data()); + + // 6. Config Tensor right + std::array<int64_t, 3> b_dims_array; + std::array<int64_t, 3> b_stride_array; + if (info.b_matrix.col_stride != 1) { + b_dims_array = { static_cast<int64_t>(info.batch), + static_cast<int64_t>(info.n), + static_cast<int64_t>(info.k) }; + } else { + b_dims_array = { static_cast<int64_t>(info.batch), + static_cast<int64_t>(info.k), + static_cast<int64_t>(info.n) }; + } + b_stride_array = { static_cast<int64_t>(info.b_matrix.stride), + static_cast<int64_t>(info.b_matrix.ld()), + 1 }; + right.SetNdInfo(static_cast<int>(b_dims_array.size()), b_dims_array.data(), b_stride_array.data()); + + // 7.
Config Tensor out, muDNN BatchMatMul output only supports row-major tensor + std::array<int64_t, 3> c_dims_array = { static_cast<int64_t>(info.batch), + static_cast<int64_t>(info.m), + static_cast<int64_t>(info.n) }; + std::array<int64_t, 3> c_stride_array = { static_cast<int64_t>(info.c_matrix.stride), + static_cast<int64_t>(info.c_matrix.ld()), + 1 }; + out.SetNdInfo(static_cast<int>(c_dims_array.size()), c_dims_array.data(), c_stride_array.data()); + + // 8. Workspace Memory Handler + ::musa::dnn::MemoryMaintainer maintainer = [](size_t size) -> ::musa::dnn::MemoryHandler { + void* ptr = nullptr; + musaMalloc(&ptr, size); + return ::musa::dnn::MemoryHandler(ptr, [](void* p) { if(p) musaFree(p); }); + }; + + // 9. Tensor left and Tensor right transpose config + if (info.a_matrix.col_stride == 1 && info.b_matrix.col_stride != 1) + matmul_operator->SetTranspose(false, true); + else if (info.a_matrix.col_stride != 1 && info.b_matrix.col_stride == 1) + matmul_operator->SetTranspose(true, false); + else if (info.a_matrix.col_stride != 1 && info.b_matrix.col_stride != 1) + matmul_operator->SetTranspose(true, true); + else + matmul_operator->SetTranspose(false, false); + + // 10. BatchMatMul workspace config + size_t workspace_size_in_bytes = 0; + matmul_operator->GetWorkspaceSize(mudnn_handle, workspace_size_in_bytes, out, left, right); + + // 11. Alpha Beta Gamma + matmul_operator->SetAlpha(static_cast<double>(alpha)); + matmul_operator->SetBeta(static_cast<double>(beta)); + matmul_operator->SetGamma(0.0); + + // 12.
Run + matmul_operator->Run( + mudnn_handle, + out, + left, + right, + static_cast<int64_t>(info.batch), + static_cast<int64_t>(info.m), + static_cast<int64_t>(info.n), + static_cast<int64_t>(info.k), + static_cast<int64_t>(info.a_matrix.ld()), + static_cast<int64_t>(info.b_matrix.ld()), + static_cast<int64_t>(info.c_matrix.ld()), + static_cast<int64_t>(info.a_matrix.stride), + static_cast<int64_t>(info.b_matrix.stride), + static_cast<int64_t>(info.c_matrix.stride), + maintainer + ); + + return INFINI_STATUS_SUCCESS; + }); +} + + +infiniStatus_t Descriptor::calculate(void *workspace, + size_t workspace_size, + void *c, + float beta, + const void *a, + const void *b, + float alpha, + void *stream) const { + switch (_dtype) { + case INFINI_DTYPE_F16: + return mudnn::calculate<half>(_info, _opaque->internal, c, beta, a, b, alpha, stream); + case INFINI_DTYPE_F32: + return mudnn::calculate<float>(_info,_opaque->internal, c, beta, a, b, alpha, stream); + case INFINI_DTYPE_BF16: + return mudnn::calculate<__mt_bfloat16>(_info,_opaque->internal, c, beta, a, b, alpha, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +} // namespace op::gemm::mudnn diff --git a/xmake/moore.lua b/xmake/moore.lua index 64cd13ddd..25eddf522 100644 --- a/xmake/moore.lua +++ b/xmake/moore.lua @@ -44,6 +44,9 @@ target("infiniop-moore") add_cxflags("-lstdc++", "-fPIC", "-Wno-comment") add_files("../src/infiniop/devices/moore/*.cc") add_files("../src/infiniop/ops/*/moore/*.mu", {rule = "mu"}) + + -- Add source files for Moore muBLAS/muDNN GEMM backends. + add_files("../src/infiniop/ops/gemm/moore/*/*.mu", {rule = "mu"}) target_end() target("infinirt-moore")