Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion src/infiniccl-test/infiniccl_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#define TEST_INFINI_THREAD(API__) CHECK_API_OR(API__, INFINI_STATUS_SUCCESS, return nullptr)

const size_t MAX_COUNT = 8ULL * 1024 * 1024;
// const size_t MAX_COUNT = 512 * 1024; // for metax

const size_t TEST_COUNTS[] = {
128,
Expand All @@ -19,7 +20,7 @@ const size_t TEST_COUNTS[] = {
MAX_COUNT,
};

const infiniDtype_t TEST_DTYPES[] = {INFINI_DTYPE_F32, INFINI_DTYPE_F16};
const infiniDtype_t TEST_DTYPES[] = {INFINI_DTYPE_F32, INFINI_DTYPE_F16, INFINI_DTYPE_BF16};

const size_t WARM_UPS = 10;

Expand Down Expand Up @@ -51,6 +52,11 @@ void setData(infiniDtype_t dtype, void *data, size_t count, float val) {
((fp16_t *)data)[i] = utils::cast<fp16_t>(val);
}
break;
case INFINI_DTYPE_BF16:
for (size_t i = 0; i < count; i++) {
((bf16_t *)data)[i] = utils::cast<bf16_t>(val);
}
break;
default:
std::abort();
break;
Expand All @@ -67,6 +73,12 @@ int checkData(const T *actual_, const T *expected_, size_t count) {
if (std::abs(actual - expected) > 1e-4) {
failed += 1;
}
} else if constexpr (std::is_same<T, bf16_t>::value) {
float actual = utils::cast<float>(actual_[i]);
float expected = utils::cast<float>(expected_[i]);
if (std::abs(actual - expected) > 1e-4) {
failed += 1;
}
} else {
if (std::abs(actual_[i] - expected_[i]) > 1e-4) {
failed += 1;
Expand All @@ -82,6 +94,8 @@ int checkData(const void *actual, const void *expected, infiniDtype_t dtype, siz
return checkData((const float *)actual, (const float *)expected, count);
case INFINI_DTYPE_F16:
return checkData((const fp16_t *)actual, (const fp16_t *)expected, count);
case INFINI_DTYPE_BF16:
return checkData((const bf16_t *)actual, (const bf16_t *)expected, count);
default:
std::abort();
return 1;
Expand Down
6 changes: 3 additions & 3 deletions src/infiniccl/metax/infiniccl_metax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ inline hcclDataType_t getHcclDtype(infiniDtype_t datatype) {
return hcclFloat;
case INFINI_DTYPE_F16:
return hcclHalf;
case INFINI_DTYPE_BF16:
return hcclBfloat16;
default:
std::abort();
return hcclHalf;
Expand Down Expand Up @@ -83,9 +85,7 @@ infiniStatus_t allReduce(
infinicclComm_t comm,
infinirtStream_t stream) {

if (datatype != INFINI_DTYPE_F32 && datatype != INFINI_DTYPE_F16) {
return INFINI_STATUS_BAD_PARAM;
}
CHECK_DTYPE(datatype, INFINI_DTYPE_F32, INFINI_DTYPE_F16, INFINI_DTYPE_BF16);

CHECK_HCCL(hcclAllReduce(sendbuf, recvbuf, count, getHcclDtype(datatype),
getHcclRedOp(op), getHcclComm(comm), getMacaStream(stream)));
Expand Down
8 changes: 8 additions & 0 deletions src/infiniop/ops/softplus/metax/softplus_metax.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __SOFTPLUS_METAX_API_H__
#define __SOFTPLUS_METAX_API_H__

#include "../../../elementwise/metax/elementwise_metax_api.h"

ELEMENTWISE_DESCRIPTOR(softplus, metax)

#endif // __SOFTPLUS_METAX_API_H__
60 changes: 60 additions & 0 deletions src/infiniop/ops/softplus/metax/softplus_metax.maca
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#include "softplus_metax.h"

#include "../../../elementwise/metax/elementwise_metax.h"

#include "../cuda/kernel.cuh"

namespace op::softplus::metax {

Descriptor::~Descriptor() = default;

infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {

auto handle = reinterpret_cast<device::metax::Handle *>(handle_);
auto dtype = out_desc->dtype();

const auto &x_desc = input_desc_vec.at(0);
const auto &y_shape = out_desc->shape();
const auto &x_shape = x_desc->shape();

CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16);

CHECK_SAME_SHAPE(y_shape, x_shape);

// create METAX elementwise descriptor
CREATE_ELEMENTWISE_METAX_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)

return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output,
std::vector<const void *> inputs,
void *stream) const {

if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}

switch (_dtype) {
case INFINI_DTYPE_F16:
return _device_info->calculate<256, cuda::SoftplusOp, half>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_BF16:
return _device_info->calculate<256, cuda::SoftplusOp, cuda_bfloat16>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F32:
return _device_info->calculate<256, cuda::SoftplusOp, float>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F64:
return _device_info->calculate<256, cuda::SoftplusOp, double>(_info, workspace, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}

return INFINI_STATUS_SUCCESS;
}
} // namespace op::softplus::metax
Loading