diff --git a/src/infiniccl/cambricon/infiniccl_cambricon.cc b/src/infiniccl/cambricon/infiniccl_cambricon.cc index cc5b677cf..f5ea5923f 100644 --- a/src/infiniccl/cambricon/infiniccl_cambricon.cc +++ b/src/infiniccl/cambricon/infiniccl_cambricon.cc @@ -25,6 +25,8 @@ inline cnclDataType_t getCnclDtype(infiniDtype_t datatype) { return cnclFloat32; case INFINI_DTYPE_F16: return cnclFloat16; + case INFINI_DTYPE_BF16: + return cnclBfloat16; default: std::cerr << "Unsupported data type: " << datatype << std::endl; std::abort(); @@ -89,9 +91,7 @@ infiniStatus_t allReduce( infinicclComm_t comm, infinirtStream_t stream) { - if (datatype != INFINI_DTYPE_F32 && datatype != INFINI_DTYPE_F16) { - return INFINI_STATUS_BAD_PARAM; - } + CHECK_DTYPE(datatype, INFINI_DTYPE_F32, INFINI_DTYPE_F16, INFINI_DTYPE_BF16); CHECK_CNCL(cnclAllReduce(sendbuf, recvbuf, count, getCnclDtype(datatype), getCnclRedOp(op), getCnclComm(comm), @@ -99,4 +99,5 @@ infiniStatus_t allReduce( return INFINI_STATUS_SUCCESS; } + } // namespace infiniccl::cambricon