Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/infiniccl/infiniccl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

#include "./ascend/infiniccl_ascend.h"
#include "./cuda/infiniccl_cuda.h"
#include "./maca/infiniccl_maca.h"
#include "./metax/infiniccl_metax.h"

__C infiniStatus_t infinicclCommInitAll(
infiniDevice_t device_type,
Expand All @@ -17,7 +17,7 @@ __C infiniStatus_t infinicclCommInitAll(
switch (device_type) {
COMM_INIT_ALL(INFINI_DEVICE_NVIDIA, cuda)
COMM_INIT_ALL(INFINI_DEVICE_ASCEND, ascend)
COMM_INIT_ALL(INFINI_DEVICE_METAX, maca)
COMM_INIT_ALL(INFINI_DEVICE_METAX, metax)
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
Expand All @@ -37,7 +37,7 @@ __C infiniStatus_t infinicclCommDestroy(infinicclComm_t comm) {
switch (comm->device_type) {
COMM_DESTROY(INFINI_DEVICE_NVIDIA, cuda)
COMM_DESTROY(INFINI_DEVICE_ASCEND, ascend)
COMM_DESTROY(INFINI_DEVICE_METAX, maca)
COMM_DESTROY(INFINI_DEVICE_METAX, metax)

default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand Down Expand Up @@ -65,7 +65,7 @@ __C infiniStatus_t infinicclAllReduce(
switch (comm->device_type) {
ALL_REDUCE(INFINI_DEVICE_NVIDIA, cuda)
ALL_REDUCE(INFINI_DEVICE_ASCEND, ascend)
ALL_REDUCE(INFINI_DEVICE_METAX, maca)
ALL_REDUCE(INFINI_DEVICE_METAX, metax)

default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand Down
12 changes: 0 additions & 12 deletions src/infiniccl/maca/infiniccl_maca.h

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "infiniccl_maca.h"
#include "infiniccl_metax.h"

#include "../../utils.h"

Expand Down Expand Up @@ -51,7 +51,7 @@ inline hcclComm_t getHcclComm(infinicclComm_t comm) {
return static_cast<hcclComm_t>(comm->comm);
}

namespace infiniccl::maca {
namespace infiniccl::metax {

infiniStatus_t commInitAll(
infinicclComm_t *comms,
Expand Down Expand Up @@ -92,4 +92,4 @@ infiniStatus_t allReduce(

return INFINI_STATUS_SUCCESS;
}
} // namespace infiniccl::maca
} // namespace infiniccl::metax
12 changes: 12 additions & 0 deletions src/infiniccl/metax/infiniccl_metax.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#ifndef INFINICCL_METAX_H_
#define INFINICCL_METAX_H_

#include "../infiniccl_impl.h"

#if defined(ENABLE_METAX_API) && defined(ENABLE_CCL)
INFINICCL_DEVICE_API_IMPL(metax)
#else
INFINICCL_DEVICE_API_NOOP(metax)
#endif

#endif /* INFINICCL_METAX_H_ */
6 changes: 3 additions & 3 deletions src/infiniop/devices/handle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#include "kunlun/kunlun_handle.h"
#endif
#ifdef ENABLE_METAX_API
#include "maca/maca_handle.h"
#include "metax/metax_handle.h"
#endif

__C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) {
Expand Down Expand Up @@ -57,7 +57,7 @@ __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) {
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, maca);
CREATE(INFINI_DEVICE_METAX, metax);
#endif

default:
Expand Down Expand Up @@ -94,7 +94,7 @@ __C infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle) {
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, maca);
DELETE(INFINI_DEVICE_METAX, metax);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
#include "../../../utils.h"
#include "../pool.h"
#include "maca_handle.h"
#include "metax_handle.h"
#include <hcblas/hcblas.h>
#include <hcdnn/hcdnn.h>
#include <memory>

#define CHECK_MCBLAS(API) CHECK_INTERNAL(API, HCBLAS_STATUS_SUCCESS)
#define CHECK_MCDNN(API) CHECK_INTERNAL(API, HCDNN_STATUS_SUCCESS)

namespace device::maca {
namespace device::metax {

class Handle::Internal {
Pool<hcblasHandle_t> mcblas_handles;
Expand Down Expand Up @@ -39,4 +39,4 @@ class Handle::Internal {

hcdnnDataType_t getHcdnnDtype(infiniDtype_t dt);

} // namespace device::maca
} // namespace device::metax
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "common_maca.h"
#include "metax_common.h"

namespace device::maca {
namespace device::metax {
Handle::Handle(infiniDevice_t device, int device_id)
: InfiniopHandle{device, device_id},
_internal(std::make_shared<Handle::Internal>(device_id)) {}
Expand Down Expand Up @@ -83,4 +83,4 @@ infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
return INFINI_STATUS_SUCCESS;
}

} // namespace device::maca
} // namespace device::metax
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#ifndef __INFINIOP_MACA_HANDLE_H__
#define __INFINIOP_MACA_HANDLE_H__
#ifndef __INFINIOP_METAX_HANDLE_H__
#define __INFINIOP_METAX_HANDLE_H__

#include "../../handle.h"
#include <memory>

namespace device::maca {
namespace device::metax {
struct Handle : public InfiniopHandle {
Handle(int device_id);
class Internal;
Expand All @@ -20,6 +20,6 @@ struct Handle : public InfiniopHandle {
std::shared_ptr<Internal> _internal;
};

} // namespace device::maca
} // namespace device::metax

#endif // __INFINIOP_MACA_HANDLE_H__
#endif // __INFINIOP_METAX_HANDLE_H__
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
#define INFINIOP_MACA_KERNEL __global__ void
#define INFINIOP_METAX_KERNEL __global__ void

// Posible maximum number of threads per block for MACA architectures
// Posible maximum number of threads per block for METAX architectures
// Used for picking correct kernel launch configuration
#define MACA_BLOCK_SIZE_1024 1024
#define MACA_BLOCK_SIZE_512 512
#define METAX_BLOCK_SIZE_1024 1024
#define METAX_BLOCK_SIZE_512 512

#define CHECK_MACA(API) CHECK_INTERNAL(API, hcSuccess)
#define CHECK_METAX(API) CHECK_INTERNAL(API, hcSuccess)

using cuda_bfloat16 = hpcc_bfloat16;
using cuda_bfloat162 = hpcc_bfloat162;

namespace device::maca {
namespace device::metax {

// return the memory offset of original tensor, given the flattened index of broadcasted tensor
__forceinline__ __device__ __host__ size_t
Expand Down Expand Up @@ -41,7 +41,7 @@ indexToOffset(
}
return res;
}
} // namespace device::maca
} // namespace device::metax

__forceinline__ __device__ float
exp_(const float val) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
#ifndef __INFINIOP_ELEMENTWISE_MACA_H__
#define __INFINIOP_ELEMENTWISE_MACA_H__
#ifndef __INFINIOP_ELEMENTWISE_METAX_H__
#define __INFINIOP_ELEMENTWISE_METAX_H__

#include "../../../utils.h"
#include "../../devices/maca/common_maca.h"
#include "../../devices/maca/maca_kernel_common.h"
#include "elementwise_maca_api.h"
#include "../../devices/metax/metax_common.h"
#include "../../devices/metax/metax_kernel_common.h"
#include "elementwise_metax_api.h"

namespace op::elementwise::maca {
namespace op::elementwise::metax {
template <typename T>
__device__ __forceinline__ const T *typedInputPtr(const void *ptr) {
return reinterpret_cast<const T *>(ptr);
}

__device__ __forceinline__ size_t getOutputIndex(size_t idx, bool is_contiguous, size_t ndim,
const size_t *shape, const ptrdiff_t *strides) {
return is_contiguous ? idx : device::maca::indexToOffset(idx, ndim, shape, strides);
return is_contiguous ? idx : device::metax::indexToOffset(idx, ndim, shape, strides);
}

struct InputIndexer {
Expand All @@ -30,8 +30,8 @@ struct InputIndexer {
return input_contiguous[input_id]
? idx
: (input_broadcasted[input_id]
? device::maca::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
: device::maca::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
? device::metax::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
: device::metax::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
}
};

Expand All @@ -41,7 +41,7 @@ __device__ __forceinline__ void unpackInputsAndApply(F &&f, std::index_sequence<
}

template <size_t N, typename Op, typename Tdata, typename... Args>
INFINIOP_MACA_KERNEL elementwiseKernel(
INFINIOP_METAX_KERNEL elementwiseKernel(
size_t output_size,
size_t ndim,
bool output_contiguous,
Expand Down Expand Up @@ -72,7 +72,7 @@ INFINIOP_MACA_KERNEL elementwiseKernel(
}

template <typename Op, typename Tout, typename... Tin>
INFINIOP_MACA_KERNEL elementwiseKernel(
INFINIOP_METAX_KERNEL elementwiseKernel(
size_t output_size,
size_t ndim,
bool output_contiguous,
Expand Down Expand Up @@ -102,9 +102,9 @@ INFINIOP_MACA_KERNEL elementwiseKernel(
}

struct DeviceImpl::Opaque {
std::shared_ptr<device::maca::Handle::Internal> internal;
std::shared_ptr<device::metax::Handle::Internal> internal;

Opaque(const std::shared_ptr<device::maca::Handle::Internal> &internal)
Opaque(const std::shared_ptr<device::metax::Handle::Internal> &internal)
: internal(internal) {}

template <uint32_t BLOCK_SIZE, size_t N, typename Op, typename Tdata, typename... Args>
Expand Down Expand Up @@ -159,8 +159,8 @@ struct DeviceImpl::Opaque {
const int8_t *d_meta_start = reinterpret_cast<int8_t *>(workspace) + input_arr_size;

// copy the input pointer array and meta to device
CHECK_MACA(hcMemcpyAsync(workspace, h_inputs_arr, input_arr_size, hcMemcpyHostToDevice, stream));
CHECK_MACA(hcMemcpyAsync((void *)d_meta_start, info_meta_start, info.getMetaMemSize(), hcMemcpyHostToDevice, stream));
CHECK_METAX(hcMemcpyAsync(workspace, h_inputs_arr, input_arr_size, hcMemcpyHostToDevice, stream));
CHECK_METAX(hcMemcpyAsync((void *)d_meta_start, info_meta_start, info.getMetaMemSize(), hcMemcpyHostToDevice, stream));

// offset/assign the pointers
d_inputs_arr = reinterpret_cast<const void **>(workspace);
Expand Down Expand Up @@ -259,6 +259,6 @@ infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &inf
std::forward<Args>(args)...);
}

} // namespace op::elementwise::maca
} // namespace op::elementwise::metax

#endif
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#ifndef __INFINIOP_ELEMENTWISE_MACA_API_H__
#define __INFINIOP_ELEMENTWISE_MACA_API_H__
#ifndef __INFINIOP_ELEMENTWISE_METAX_API_H__
#define __INFINIOP_ELEMENTWISE_METAX_API_H__

#include "../elementwise.h"

namespace op::elementwise::maca {
namespace op::elementwise::metax {

class DeviceImpl final {
struct Opaque;
Expand Down Expand Up @@ -37,23 +37,23 @@ class DeviceImpl final {
void *stream,
Args &&...args);
};
} // namespace op::elementwise::maca
#define CREATE_ELEMENTWISE_MACA_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \
\
auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \
CHECK_RESULT(info_result); \
auto info = info_result.take(); \
auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \
\
auto device_impl_result = op::elementwise::maca::DeviceImpl::create(HANDLE->internal()); \
CHECK_RESULT(device_impl_result); \
\
*desc_ptr = new Descriptor( \
DTYPE, \
std::move(info), \
std::move(device_impl_result.take()), \
workspace_size, \
HANDLE->device, \
} // namespace op::elementwise::metax
#define CREATE_ELEMENTWISE_METAX_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \
\
auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \
CHECK_RESULT(info_result); \
auto info = info_result.take(); \
auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \
\
auto device_impl_result = op::elementwise::metax::DeviceImpl::create(HANDLE->internal()); \
CHECK_RESULT(device_impl_result); \
\
*desc_ptr = new Descriptor( \
DTYPE, \
std::move(info), \
std::move(device_impl_result.take()), \
workspace_size, \
HANDLE->device, \
HANDLE->device_id);

#endif // __INFINIOP_ELEMENTWISE_MACA_API_H__
#endif // __INFINIOP_ELEMENTWISE_METAX_API_H__
18 changes: 9 additions & 9 deletions src/infiniop/ops/causal_softmax/metax/causal_softmax_metax.maca
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
#include "../../../devices/maca/common_maca.h"
#include "../../../devices/metax/metax_common.h"
#include "causal_softmax_metax.h"

#include <hccub/block/block_reduce.cuh>
#include "../../../devices/maca/maca_kernel_common.h"
#include "../../../devices/metax/metax_kernel_common.h"

#include "../../../reduce/cuda/reduce.cuh"

#include "../cuda/kernel.cuh"

template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
INFINIOP_MACA_KERNEL causalSoftmax(
INFINIOP_METAX_KERNEL causalSoftmax(
Tdata *y, const Tdata *x,
size_t batch, size_t height, size_t width,
ptrdiff_t y_stride_b, ptrdiff_t y_stride_h,
Expand All @@ -20,7 +20,7 @@ INFINIOP_MACA_KERNEL causalSoftmax(
namespace op::causal_softmax::metax {

struct Descriptor::Opaque {
std::shared_ptr<device::maca::Handle::Internal> internal;
std::shared_ptr<device::metax::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
Expand All @@ -35,7 +35,7 @@ infiniStatus_t Descriptor::create(
auto info = CausalSoftmaxInfo::create(y_desc, x_desc);
CHECK_RESULT(info);
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::maca::Handle *>(handle)->internal()},
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
Expand Down Expand Up @@ -76,12 +76,12 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
const void *x,
void *stream_) const {
hcStream_t stream = (hcStream_t)stream_;
if (_opaque->internal->maxThreadsPerBlock() == MACA_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<MACA_BLOCK_SIZE_1024>(
if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_1024>(
y, x, _info.dtype, _info.batch_size, _info.seq_len, _info.total_seq_len,
_info.y_stride_b, _info.y_stride_i, _info.x_stride_b, _info.x_stride_i, stream));
} else if (_opaque->internal->maxThreadsPerBlock() == MACA_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<MACA_BLOCK_SIZE_512>(
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_512>(
y, x, _info.dtype, _info.batch_size, _info.seq_len, _info.total_seq_len,
_info.y_stride_b, _info.y_stride_i, _info.x_stride_b, _info.x_stride_i, stream));
} else {
Expand Down
Loading