100 changes: 98 additions & 2 deletions src/infiniop/ops/gemm/moore/gemm_moore.h
@@ -1,8 +1,104 @@
#ifndef __GEMM_MOORE_H__
#define __GEMM_MOORE_H__

#include "../gemm.h"
#include "mublas/gemm_mublas.h"
#include "mudnn/gemm_mudnn.h"

DESCRIPTOR(moore)
namespace op::gemm::moore {

// Descriptor class for GEMM operations on Moore devices.
// This class acts as a wrapper that selects either the mublas or the mudnn backend.
// It encapsulates the backend-specific Descriptor implementation and provides
// a unified interface for workspace query and GEMM calculation.
class Descriptor final : public InfiniopDescriptor {
public:
// Destructor: deletes the backend-specific descriptor.
~Descriptor() {
if (_backend == Backend::MUBLAS) {
delete reinterpret_cast<mublas::Descriptor *>(_impl);
} else {
delete reinterpret_cast<mudnn::Descriptor *>(_impl);
}
}

// Returns the required workspace size for the GEMM operation.
size_t workspaceSize() const {
if (_backend == Backend::MUBLAS) {
return reinterpret_cast<mublas::Descriptor *>(_impl)->workspaceSize();
} else {
return reinterpret_cast<mudnn::Descriptor *>(_impl)->workspaceSize();
}
}

// Static factory method to create a Descriptor instance.
// This method chooses the backend (mublas or mudnn) and constructs
// the corresponding implementation internally.
static infiniStatus_t create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) {
auto desc = new Descriptor(handle->device, handle->device_id);

// Backend selection strategy:
// Currently defaults to MUDNN.
// This could be extended to choose based on an environment variable or runtime
// parameters (a sketch of that idea appears after this file's diff).
desc->_backend = Backend::MUDNN;

if (desc->_backend == Backend::MUBLAS) {
mublas::Descriptor *impl;
auto status = mublas::Descriptor::create(handle, &impl, c_desc, a_desc, b_desc);
if (status != INFINI_STATUS_SUCCESS) {
delete desc;
return status;
}
desc->_impl = impl;
} else {
mudnn::Descriptor *impl;
auto status = mudnn::Descriptor::create(handle, &impl, c_desc, a_desc, b_desc);
if (status != INFINI_STATUS_SUCCESS) {
delete desc;
return status;
}
desc->_impl = impl;
}

*desc_ptr = desc;
return INFINI_STATUS_SUCCESS;
}

// Unified GEMM calculation interface.
// Calls the corresponding backend's calculate function internally.
infiniStatus_t calculate(
void *workspace, size_t workspace_size,
void *c, float beta,
const void *a, const void *b,
float alpha,
void *stream) const {
if (_backend == Backend::MUBLAS) {
return reinterpret_cast<mublas::Descriptor *>(_impl)
->calculate(workspace, workspace_size, c, beta, a, b, alpha, stream);
} else {
return reinterpret_cast<mudnn::Descriptor *>(_impl)
->calculate(workspace, workspace_size, c, beta, a, b, alpha, stream);
}
}

private:
// Private constructor: ensures users cannot directly instantiate Descriptor.
// Instances must be created via the static create() factory method.
Descriptor(infiniDevice_t device_type, int device_id)
: InfiniopDescriptor{device_type, device_id}, _impl(nullptr) {}

// Enum to indicate which backend is being used internally.
enum class Backend { MUBLAS, MUDNN };

Backend _backend; // Currently selected MUBLAS/MUDNN backend
void *_impl; // Pointer to backend-specific descriptor (mublas::Descriptor* or mudnn::Descriptor*)
};

} // namespace op::gemm::moore

#endif // __GEMM_MOORE_H__
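The backend-selection comment in create() notes that the choice could be driven by environment variables or runtime parameters instead of the hard-coded MUDNN default. A minimal sketch of that idea, not part of this PR (the environment variable name and the helper are hypothetical):

#include <cstdlib>
#include <cstring>

// Hypothetical helper: returns true when the (assumed) environment variable
// INFINIOP_GEMM_MOORE_BACKEND is set to "MUBLAS"; otherwise the muDNN
// default would be kept.
static bool envRequestsMublas() {
    const char *env = std::getenv("INFINIOP_GEMM_MOORE_BACKEND");
    return env != nullptr && std::strcmp(env, "MUBLAS") == 0;
}

Inside create(), the selection line would then become desc->_backend = envRequestsMublas() ? Backend::MUBLAS : Backend::MUDNN;.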
8 changes: 8 additions & 0 deletions src/infiniop/ops/gemm/moore/mublas/gemm_mublas.h
@@ -0,0 +1,8 @@
#ifndef __GEMM_MUBLAS_H__
#define __GEMM_MUBLAS_H__

#include "../../gemm.h"

DESCRIPTOR(mublas)

#endif // __GEMM_MUBLAS_H__
src/infiniop/ops/gemm/moore/mublas/gemm_mublas.mu
@@ -1,8 +1,8 @@
#include "../../../devices/moore/moore_common.h"
#include "../../../devices/moore/moore_handle.h"
#include "gemm_moore.h"
#include "../../../../devices/moore/moore_common.h"
#include "../../../../devices/moore/moore_handle.h"
#include "gemm_mublas.h"

namespace op::gemm::moore {
namespace op::gemm::mublas {

struct Descriptor::Opaque {
std::shared_ptr<device::moore::Handle::Internal> internal;
@@ -122,4 +122,4 @@ infiniStatus_t Descriptor::calculate(
return INFINI_STATUS_SUCCESS;
}

} // namespace op::gemm::moore
} // namespace op::gemm::mublas
8 changes: 8 additions & 0 deletions src/infiniop/ops/gemm/moore/mudnn/gemm_mudnn.h
@@ -0,0 +1,8 @@
#ifndef __GEMM_MUDNN_H__
#define __GEMM_MUDNN_H__

#include "../../gemm.h"

DESCRIPTOR(mudnn)

#endif // __GEMM_MUDNN_H__
198 changes: 198 additions & 0 deletions src/infiniop/ops/gemm/moore/mudnn/gemm_mudnn.mu
@@ -0,0 +1,198 @@
#include "../../../../devices/moore/moore_common.h"
#include "../../../../devices/moore/moore_handle.h"
#include "gemm_mudnn.h"

#include <musa_bf16.h>

namespace op::gemm::mudnn {

struct Descriptor::Opaque {
std::shared_ptr<device::moore::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
delete _opaque;
}

infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) {
auto handle = reinterpret_cast<device::moore::Handle *>(handle_);
auto dtype = c_desc->dtype();

CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);

auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::ROW_MAJOR);
CHECK_RESULT(result);

*desc_ptr = new Descriptor(
dtype, result.take(), 0,
new Opaque{handle->internal()},
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}

template <typename Tdata>
infiniStatus_t calculate(
const MatmulInfo &info,
std::shared_ptr<device::moore::Handle::Internal> &_internal,
void *c,
float beta,
const void *a,
const void *b,
float alpha,
void *stream)
{
// 0. For muDNN development, refer to the official documentation and the following headers:
// - /usr/local/musa/include/mudnn_base.h
// - /usr/local/musa/include/mudnn_math.h
// - /usr/local/musa/include/mudnn.h

// 1. Create BatchMatMul operator
auto matmul_operator = std::make_unique<::musa::dnn::BatchMatMul>();
matmul_operator->SetComputeMode(::musa::dnn::BatchMatMul::ComputeMode::TENSOR);

// 2. Use _internal->useMudnn to manage muDNN handle
return _internal->useMudnn((musaStream_t)stream, [&](::musa::dnn::Handle &mudnn_handle) -> infiniStatus_t {

// 3. Create BatchMatMul Tensor
::musa::dnn::Tensor out, left, right;

if constexpr (std::is_same<Tdata, half>::value) {
out.SetType(::musa::dnn::Tensor::Type::HALF);
left.SetType(::musa::dnn::Tensor::Type::HALF);
right.SetType(::musa::dnn::Tensor::Type::HALF);
}
else if constexpr (std::is_same<Tdata, __mt_bfloat16>::value) {
out.SetType(::musa::dnn::Tensor::Type::BFLOAT16);
left.SetType(::musa::dnn::Tensor::Type::BFLOAT16);
right.SetType(::musa::dnn::Tensor::Type::BFLOAT16);
}
else {
out.SetType(::musa::dnn::Tensor::Type::FLOAT);
left.SetType(::musa::dnn::Tensor::Type::FLOAT);
right.SetType(::musa::dnn::Tensor::Type::FLOAT);
}

// 4. Bind BatchMatMul Tensor addr
out.SetAddr(c);
left.SetAddr(a);
right.SetAddr(b);

// 5. Config Tensor left
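// When A is stored column-major (col_stride != 1), it is described to muDNN as a
// (batch, k, m) tensor; the matching transpose flag is set in step 9.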
std::array<int64_t, 3> a_dims_array;
std::array<int64_t, 3> a_stride_array;
if (info.a_matrix.col_stride != 1) {
a_dims_array = { static_cast<int64_t>(info.batch),
static_cast<int64_t>(info.k),
static_cast<int64_t>(info.m) };
} else {
a_dims_array = { static_cast<int64_t>(info.batch),
static_cast<int64_t>(info.m),
static_cast<int64_t>(info.k) };
}
a_stride_array = { static_cast<int64_t>(info.a_matrix.stride),
static_cast<int64_t>(info.a_matrix.ld()),
1 };
left.SetNdInfo(static_cast<int>(a_dims_array.size()), a_dims_array.data(), a_stride_array.data());

// 6. Config Tensor right
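// Likewise, a column-major B (col_stride != 1) is described as (batch, n, k);
// step 9 sets the corresponding transpose flag.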
std::array<int64_t, 3> b_dims_array;
std::array<int64_t, 3> b_stride_array;
if (info.b_matrix.col_stride != 1) {
b_dims_array = { static_cast<int64_t>(info.batch),
static_cast<int64_t>(info.n),
static_cast<int64_t>(info.k) };
} else {
b_dims_array = { static_cast<int64_t>(info.batch),
static_cast<int64_t>(info.k),
static_cast<int64_t>(info.n) };
}
b_stride_array = { static_cast<int64_t>(info.b_matrix.stride),
static_cast<int64_t>(info.b_matrix.ld()),
1 };
right.SetNdInfo(static_cast<int>(b_dims_array.size()), b_dims_array.data(), b_stride_array.data());

// 7. Config Tensor out; muDNN BatchMatMul output only supports row-major tensors
std::array<int64_t, 3> c_dims_array = { static_cast<int64_t>(info.batch),
static_cast<int64_t>(info.m),
static_cast<int64_t>(info.n) };
std::array<int64_t, 3> c_stride_array = { static_cast<int64_t>(info.c_matrix.stride),
static_cast<int64_t>(info.c_matrix.ld()),
1 };
out.SetNdInfo(static_cast<int>(c_dims_array.size()), c_dims_array.data(), c_stride_array.data());

// 8. Workspace Memory Handler
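// The maintainer allocates device scratch memory on demand with musaMalloc and
// registers musaFree as the deleter on the returned MemoryHandler.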
::musa::dnn::MemoryMaintainer maintainer = [](size_t size) -> ::musa::dnn::MemoryHandler {
void* ptr = nullptr;
musaMalloc(&ptr, size);
return ::musa::dnn::MemoryHandler(ptr, [](void* p) { if(p) musaFree(p); });
};

// 9. Tensor left and Tensor right transpose config
if (info.a_matrix.col_stride == 1 && info.b_matrix.col_stride != 1)
matmul_operator->SetTranspose(false, true);
else if (info.a_matrix.col_stride != 1 && info.b_matrix.col_stride == 1)
matmul_operator->SetTranspose(true, false);
else if (info.a_matrix.col_stride != 1 && info.b_matrix.col_stride != 1)
matmul_operator->SetTranspose(true, true);
else
matmul_operator->SetTranspose(false, false);

// 10. BatchMatMul workspace config
size_t workspace_size_in_bytes = 0;
matmul_operator->GetWorkspaceSize(mudnn_handle, workspace_size_in_bytes, out, left, right);

// 11. Alpha Beta Gamma
matmul_operator->SetAlpha(static_cast<double>(alpha));
matmul_operator->SetBeta(static_cast<double>(beta));
matmul_operator->SetGamma(0.0);

// 12. Run
matmul_operator->Run(
mudnn_handle,
out,
left,
right,
static_cast<int64_t>(info.batch),
static_cast<int64_t>(info.m),
static_cast<int64_t>(info.n),
static_cast<int64_t>(info.k),
static_cast<int64_t>(info.a_matrix.ld()),
static_cast<int64_t>(info.b_matrix.ld()),
static_cast<int64_t>(info.c_matrix.ld()),
static_cast<int64_t>(info.a_matrix.stride),
static_cast<int64_t>(info.b_matrix.stride),
static_cast<int64_t>(info.c_matrix.stride),
maintainer
);

return INFINI_STATUS_SUCCESS;
});
}


infiniStatus_t Descriptor::calculate(void *workspace,
size_t workspace_size,
void *c,
float beta,
const void *a,
const void *b,
float alpha,
void *stream) const {
switch (_dtype) {
case INFINI_DTYPE_F16:
return mudnn::calculate<half>(_info, _opaque->internal, c, beta, a, b, alpha, stream);
case INFINI_DTYPE_F32:
return mudnn::calculate<float>(_info, _opaque->internal, c, beta, a, b, alpha, stream);
case INFINI_DTYPE_BF16:
return mudnn::calculate<__mt_bfloat16>(_info, _opaque->internal, c, beta, a, b, alpha, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}

} // namespace op::gemm::mudnn
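For reference, a minimal usage sketch of the unified moore dispatcher declared in gemm_moore.h. The run_gemm wrapper and its parameters are illustrative only; it assumes the handle, tensor descriptors, device buffers, workspace, and stream have already been set up through the InfiniOP API, which is outside this diff.

#include "gemm_moore.h"

infiniStatus_t run_gemm(infiniopHandle_t handle,
                        infiniopTensorDescriptor_t c_desc,
                        infiniopTensorDescriptor_t a_desc,
                        infiniopTensorDescriptor_t b_desc,
                        void *c, const void *a, const void *b,
                        void *workspace, size_t workspace_size,
                        void *stream) {
    op::gemm::moore::Descriptor *desc = nullptr;
    // create() selects the backend (currently MUDNN) and builds its descriptor.
    auto status = op::gemm::moore::Descriptor::create(handle, &desc, c_desc, a_desc, b_desc);
    if (status != INFINI_STATUS_SUCCESS) {
        return status;
    }
    // workspaceSize() and calculate() forward to whichever backend was selected;
    // the caller is assumed to have allocated at least workspaceSize() bytes.
    status = desc->calculate(workspace, workspace_size, c, /*beta=*/0.0f,
                             a, b, /*alpha=*/1.0f, stream);
    delete desc;
    return status;
}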
3 changes: 3 additions & 0 deletions xmake/moore.lua
@@ -44,6 +44,9 @@ target("infiniop-moore")
add_cxflags("-lstdc++", "-fPIC", "-Wno-comment")
add_files("../src/infiniop/devices/moore/*.cc")
add_files("../src/infiniop/ops/*/moore/*.mu", {rule = "mu"})

-- Add source files for Moore muBLAS/muDNN GEMM backends.
add_files("../src/infiniop/ops/gemm/moore/*/*.mu", {rule = "mu"})
target_end()

target("infinirt-moore")