Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/infiniop/devices/handle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include "ascend/ascend_handle.h"
#endif
#ifdef ENABLE_MOORE_API
#include "musa/musa_handle.h"
#include "moore/moore_handle.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/kunlun_handle.h"
Expand Down Expand Up @@ -54,7 +54,7 @@ __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) {
CREATE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, musa);
CREATE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
Expand Down Expand Up @@ -94,7 +94,7 @@ __C infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle) {
DELETE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_MOORE_API
DELETE(INFINI_DEVICE_MOORE, musa);
DELETE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "../../../utils.h"
#include "../pool.h"
#include "musa_handle.h"
#include "moore_handle.h"
#include <mublas.h>
#include <mudnn.h>
#include <musa.h>
Expand All @@ -10,7 +10,7 @@
#define CHECK_MUBLAS(API) CHECK_INTERNAL(API, MUBLAS_STATUS_SUCCESS)
#define CHECK_MUDNN(API) CHECK_INTERNAL((int)API, (int)::musa::dnn::Status::SUCCESS)

namespace device::musa {
namespace device::moore {

class Handle::Internal {
Pool<std::unique_ptr<mublasHandle_t>> mublas_handles;
Expand Down Expand Up @@ -39,4 +39,4 @@ class Handle::Internal {
int gridSizeZ() const;
};

} // namespace device::musa
} // namespace device::moore
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "common_musa.h"
#include "moore_common.h"

namespace device::musa {
namespace device::moore {
Handle::Handle(infiniDevice_t device, int device_id)
: InfiniopHandle{device, device_id},
_internal(std::make_shared<Handle::Internal>(device_id)) {}
Expand Down Expand Up @@ -67,4 +67,4 @@ infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
return INFINI_STATUS_SUCCESS;
}

} // namespace device::musa
} // namespace device::moore
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#ifndef __INFINIOP_MUSA_HANDLE_H__
#define __INFINIOP_MUSA_HANDLE_H__
#ifndef __INFINIOP_MOORE_HANDLE_H__
#define __INFINIOP_MOORE_HANDLE_H__

#include "../../handle.h"
#include <memory>

namespace device::musa {
namespace device::moore {
struct Handle : public InfiniopHandle {
Handle(int device_id);
class Internal;
Expand All @@ -20,6 +20,6 @@ struct Handle : public InfiniopHandle {
std::shared_ptr<Internal> _internal;
};

} // namespace device::musa
} // namespace device::moore

#endif // __INFINIOP_MUSA_HANDLE_H__
#endif // __INFINIOP_MOORE_HANDLE_H__
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
#define INFINIOP_MUSA_KERNEL __global__ void
#define INFINIOP_MOORE_KERNEL __global__ void

#include <musa_bf16.h>
#include <musa_fp16.h>

// Possible maximum number of threads per block for MUSA architectures
// Used for picking correct kernel launch configuration
#define MUSA_BLOCK_SIZE_2048 2048
#define MUSA_BLOCK_SIZE_1024 1024
#define MUSA_BLOCK_SIZE_512 512
#define MOORE_BLOCK_SIZE_2048 2048
#define MOORE_BLOCK_SIZE_1024 1024
#define MOORE_BLOCK_SIZE_512 512

#define CHECK_MUSA(API) CHECK_INTERNAL(API, musaSuccess)
#define CHECK_MOORE(API) CHECK_INTERNAL(API, musaSuccess)

using musa_bfloat16 = mt_bfloat16;
using musa_bfloat162 = mt_bfloat162;

namespace device::musa {
namespace device::moore {

// return the memory offset of original tensor, given the flattened index of broadcasted tensor
__forceinline__ __device__ __host__ size_t
Expand Down Expand Up @@ -45,7 +45,7 @@ indexToOffset(
}
return res;
}
} // namespace device::musa
} // namespace device::moore

__forceinline__ __device__ float
exp_(const float val) {
Expand Down
8 changes: 8 additions & 0 deletions src/infiniop/ops/gemm/moore/gemm_moore.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __GEMM_MOORE_H__
#define __GEMM_MOORE_H__

// GEMM operator header for the Moore (MUSA) backend.
#include "../gemm.h"

// Declares op::gemm::moore::Descriptor via the shared DESCRIPTOR macro
// (presumably defined in ../gemm.h — it is not visible here; confirm).
DESCRIPTOR(moore)

#endif // __GEMM_MOORE_H__
125 changes: 125 additions & 0 deletions src/infiniop/ops/gemm/moore/gemm_moore.mu
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
#include "../../../devices/moore/moore_common.h"
#include "../../../devices/moore/moore_handle.h"
#include "gemm_moore.h"

namespace op::gemm::moore {

// Backend-private state of the GEMM descriptor.
// Holds a shared reference to the device handle's internals (which pool the
// mublas handles), keeping them alive for the descriptor's lifetime.
struct Descriptor::Opaque {
    std::shared_ptr<device::moore::Handle::Internal> internal;
};

// Releases the backend-private Opaque state allocated in create().
Descriptor::~Descriptor() {
    delete _opaque;
}

/**
 * Creates a GEMM descriptor for the Moore backend.
 *
 * Validates the output dtype (F16 / F32 / BF16 only), builds the matmul
 * shape/stride info from the three tensor descriptors, and allocates the
 * Descriptor with the backend-private Opaque state.
 *
 * @param handle_  Generic infiniop handle; must actually be a
 *                 device::moore::Handle (cast unchecked).
 * @param desc_ptr Out: receives the newly allocated Descriptor.
 * @param c_desc   Output tensor descriptor (its dtype drives the dispatch).
 * @param a_desc   Left input tensor descriptor.
 * @param b_desc   Right input tensor descriptor.
 * @return INFINI_STATUS_SUCCESS, or an error from the CHECK_* macros.
 */
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t c_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc) {
    auto handle = reinterpret_cast<device::moore::Handle *>(handle_);
    auto dtype = c_desc->dtype();

    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);

    // COL_MAJOR presumably matches muBLAS's native column-major convention
    // (mirroring cuBLAS) — confirm against MatmulInfo's layout handling.
    auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
    CHECK_RESULT(result);

    // Third argument appears to be the workspace size (0: calculate() never
    // touches its workspace parameter) — confirm against the Descriptor ctor.
    *desc_ptr = new Descriptor(
        dtype, result.take(), 0,
        new Opaque{handle->internal()},
        handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}

/**
 * Runs the batched GEMM  C = alpha * op(A) * op(B) + beta * C  via
 * mublasGemmStridedBatchedEx on the given MUSA stream.
 *
 * @param workspace      Unused by this backend (descriptor is created with
 *                       workspace size 0).
 * @param workspace_size Unused.
 * @param c      Output matrix batch; also read when beta != 0.
 * @param beta   Scale applied to the existing contents of C.
 * @param a      Left input matrix batch.
 * @param b      Right input matrix batch.
 * @param alpha  Scale applied to the A*B product.
 * @param stream MUSA stream (cast to musaStream_t) the call is enqueued on.
 * @return INFINI_STATUS_SUCCESS, INFINI_STATUS_BAD_TENSOR_DTYPE for an
 *         unsupported dtype, or an error surfaced by the CHECK_* macros.
 */
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *c,
    float beta,
    const void *a,
    const void *b,
    float alpha,
    void *stream) const {

    musaDataType a_type, b_type, c_type;
    mublasComputeType_t compute_type;

    // MUSA's GEMM operations require that the scalar values alpha and beta have the same data type as the matrices.
    // This ensures correct computation during the muBLAS GEMM operation.
    // Declare half-precision variables to handle F16 types.
    // NOTE: these must outlive the muBLAS call below, which is why they are
    // declared at function scope rather than inside the F16 case.
    half alpha_h, beta_h;

    // Initialize generic void pointers for alpha and beta.
    // They point to the original float values
    // It will be used directly when the GEMM operation is performed with F32 data.
    const void *p_alpha = &alpha;
    const void *p_beta = &beta;

    // Map the descriptor dtype to muBLAS storage and compute types.
    switch (_dtype) {
    case INFINI_DTYPE_F16:
        a_type = b_type = c_type = MUSA_R_16F;
        compute_type = MUBLAS_COMPUTE_16F;

        // Convert alpha/beta to half-precision and update the pointers.
        alpha_h = __float2half(alpha);
        beta_h = __float2half(beta);
        p_alpha = &alpha_h;
        p_beta = &beta_h;

        break;
    case INFINI_DTYPE_BF16:
        // BF16 storage with F32 accumulation; scalars stay float.
        a_type = b_type = c_type = MUSA_R_16BF;
        compute_type = MUBLAS_COMPUTE_32F;
        break;
    case INFINI_DTYPE_F32:
        // FAST_TF32 presumably permits TF32-accelerated F32 math — confirm
        // precision requirements against the muBLAS documentation.
        a_type = b_type = c_type = MUSA_R_32F;
        compute_type = MUBLAS_COMPUTE_32F_FAST_TF32;
        break;

    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    // NOTE(review): swapping a/b presumably compensates for a transposed
    // problem recorded in MatmulInfo — confirm against MatmulInfo::create.
    if (_info.is_transed) {
        std::swap(a, b);
    }

    // row_stride == 1 means the matrix is already laid out the way muBLAS
    // expects (no transpose); otherwise request an on-the-fly transpose.
    auto op_a = _info.a_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T;
    auto op_b = _info.b_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T;

    // Borrow a pooled mublas handle bound to this stream and issue the
    // strided-batched GEMM (one call covers all _info.batch matrices).
    CHECK_STATUS(_opaque->internal->useMublas(
        (musaStream_t)stream,
        [&](mublasHandle_t handle) {
            CHECK_MUBLAS(
                mublasGemmStridedBatchedEx(
                    handle,
                    op_a,
                    op_b,
                    static_cast<int>(_info.m),
                    static_cast<int>(_info.n),
                    static_cast<int>(_info.k),
                    p_alpha,
                    a,
                    a_type,
                    static_cast<int>(_info.a_matrix.ld()),
                    _info.a_matrix.stride,
                    b,
                    b_type,
                    static_cast<int>(_info.b_matrix.ld()),
                    _info.b_matrix.stride,
                    p_beta,
                    c,
                    c_type,
                    static_cast<int>(_info.c_matrix.ld()),
                    _info.c_matrix.stride,
                    static_cast<int>(_info.batch),
                    compute_type,
                    MUBLAS_GEMM_DEFAULT));
            return INFINI_STATUS_SUCCESS;
        }));
    return INFINI_STATUS_SUCCESS;
}

} // namespace op::gemm::moore
8 changes: 0 additions & 8 deletions src/infiniop/ops/gemm/musa/gemm_musa.h

This file was deleted.

121 changes: 0 additions & 121 deletions src/infiniop/ops/gemm/musa/gemm_musa.mu

This file was deleted.

Loading
Loading