From b8609df3449bfab3224e91a3aa12c2b453e10dcf Mon Sep 17 00:00:00 2001 From: Ceng <441651826@qq.com> Date: Tue, 9 Sep 2025 14:35:22 +0800 Subject: [PATCH 1/2] issue/434 hccl support bf16 Signed-off-by: Ceng <441651826@qq.com> --- src/infiniccl/metax/infiniccl_metax.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/infiniccl/metax/infiniccl_metax.cc b/src/infiniccl/metax/infiniccl_metax.cc index 04b91dea9..373bc36ba 100644 --- a/src/infiniccl/metax/infiniccl_metax.cc +++ b/src/infiniccl/metax/infiniccl_metax.cc @@ -23,6 +23,8 @@ inline hcclDataType_t getHcclDtype(infiniDtype_t datatype) { return hcclFloat; case INFINI_DTYPE_F16: return hcclHalf; + case INFINI_DTYPE_BF16: + return hcclBfloat16; default: std::abort(); return hcclHalf; @@ -83,9 +85,7 @@ infiniStatus_t allReduce( infinicclComm_t comm, infinirtStream_t stream) { - if (datatype != INFINI_DTYPE_F32 && datatype != INFINI_DTYPE_F16) { - return INFINI_STATUS_BAD_PARAM; - } + CHECK_DTYPE(datatype, INFINI_DTYPE_F32, INFINI_DTYPE_F16, INFINI_DTYPE_BF16); CHECK_HCCL(hcclAllReduce(sendbuf, recvbuf, count, getHcclDtype(datatype), getHcclRedOp(op), getHcclComm(comm), getMacaStream(stream))); From 3bb0c93099d7d9d6ffb70962425c59cc7a69e883 Mon Sep 17 00:00:00 2001 From: Ceng2333 <441651826@qq.com> Date: Wed, 10 Sep 2025 16:42:41 +0800 Subject: [PATCH 2/2] fix rope_v2 compiling && update infiniccl_test Signed-off-by: Ceng <441651826@qq.com> --- src/infiniccl-test/infiniccl_test.cpp | 16 ++++- .../ops/softplus/metax/softplus_metax.h | 8 +++ .../ops/softplus/metax/softplus_metax.maca | 60 +++++++++++++++++++ 3 files changed, 83 insertions(+), 1 deletion(-) create mode 100644 src/infiniop/ops/softplus/metax/softplus_metax.h create mode 100644 src/infiniop/ops/softplus/metax/softplus_metax.maca diff --git a/src/infiniccl-test/infiniccl_test.cpp b/src/infiniccl-test/infiniccl_test.cpp index 892465a39..f8566cc17 100644 --- a/src/infiniccl-test/infiniccl_test.cpp +++ 
b/src/infiniccl-test/infiniccl_test.cpp @@ -11,6 +11,7 @@ #define TEST_INFINI_THREAD(API__) CHECK_API_OR(API__, INFINI_STATUS_SUCCESS, return nullptr) const size_t MAX_COUNT = 8ULL * 1024 * 1024; +// const size_t MAX_COUNT = 512 * 1024; // for metax const size_t TEST_COUNTS[] = { 128, @@ -19,7 +20,7 @@ const size_t TEST_COUNTS[] = { MAX_COUNT, }; -const infiniDtype_t TEST_DTYPES[] = {INFINI_DTYPE_F32, INFINI_DTYPE_F16}; +const infiniDtype_t TEST_DTYPES[] = {INFINI_DTYPE_F32, INFINI_DTYPE_F16, INFINI_DTYPE_BF16}; const size_t WARM_UPS = 10; @@ -51,6 +52,11 @@ void setData(infiniDtype_t dtype, void *data, size_t count, float val) { ((fp16_t *)data)[i] = utils::cast<fp16_t>(val); } break; + case INFINI_DTYPE_BF16: + for (size_t i = 0; i < count; i++) { + ((bf16_t *)data)[i] = utils::cast<bf16_t>(val); + } + break; default: std::abort(); break; @@ -67,6 +73,12 @@ int checkData(const T *actual_, const T *expected_, size_t count) { if (std::abs(actual - expected) > 1e-4) { failed += 1; } + } else if constexpr (std::is_same<T, bf16_t>::value) { + float actual = utils::cast<float>(actual_[i]); + float expected = utils::cast<float>(expected_[i]); + if (std::abs(actual - expected) > 1e-4) { + failed += 1; + } } else { if (std::abs(actual_[i] - expected_[i]) > 1e-4) { failed += 1; @@ -82,6 +94,8 @@ int checkData(const void *actual, const void *expected, infiniDtype_t dtype, siz return checkData((const float *)actual, (const float *)expected, count); case INFINI_DTYPE_F16: return checkData((const fp16_t *)actual, (const fp16_t *)expected, count); + case INFINI_DTYPE_BF16: + return checkData((const bf16_t *)actual, (const bf16_t *)expected, count); default: std::abort(); return 1; diff --git a/src/infiniop/ops/softplus/metax/softplus_metax.h b/src/infiniop/ops/softplus/metax/softplus_metax.h new file mode 100644 index 000000000..8da2b4d76 --- /dev/null +++ b/src/infiniop/ops/softplus/metax/softplus_metax.h @@ -0,0 +1,8 @@ +#ifndef __SOFTPLUS_METAX_API_H__ +#define __SOFTPLUS_METAX_API_H__ + +#include 
"../../../elementwise/metax/elementwise_metax_api.h" + +ELEMENTWISE_DESCRIPTOR(softplus, metax) + +#endif // __SOFTPLUS_METAX_API_H__ diff --git a/src/infiniop/ops/softplus/metax/softplus_metax.maca b/src/infiniop/ops/softplus/metax/softplus_metax.maca new file mode 100644 index 000000000..5744f8c04 --- /dev/null +++ b/src/infiniop/ops/softplus/metax/softplus_metax.maca @@ -0,0 +1,60 @@ +#include "softplus_metax.h" + +#include "../../../elementwise/metax/elementwise_metax.h" + +#include "../cuda/kernel.cuh" + +namespace op::softplus::metax { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector<infiniopTensorDescriptor_t> input_desc_vec) { + + auto handle = reinterpret_cast<device::metax::Handle *>(handle_); + auto dtype = out_desc->dtype(); + + const auto &x_desc = input_desc_vec.at(0); + const auto &y_shape = out_desc->shape(); + const auto &x_shape = x_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); + + CHECK_SAME_SHAPE(y_shape, x_shape); + + // create METAX elementwise descriptor + CREATE_ELEMENTWISE_METAX_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector<const void *> inputs, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + switch (_dtype) { + case INFINI_DTYPE_F16: + return _device_info->calculate<256, cuda::SoftplusOp, half>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_BF16: + return _device_info->calculate<256, cuda::SoftplusOp, cuda_bfloat16>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate<256, cuda::SoftplusOp, float>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate<256, 
cuda::SoftplusOp, double>(_info, workspace, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::softplus::metax