diff --git a/src/infiniccl/cuda/infiniccl_cuda.cu b/src/infiniccl/cuda/infiniccl_cuda.cu index df699ba45..2f8869c18 100644 --- a/src/infiniccl/cuda/infiniccl_cuda.cu +++ b/src/infiniccl/cuda/infiniccl_cuda.cu @@ -22,6 +22,8 @@ inline ncclDataType_t getNcclDtype(infiniDtype_t datatype) { return ncclFloat; case INFINI_DTYPE_F16: return ncclHalf; + case INFINI_DTYPE_BF16: + return ncclBfloat16; default: std::abort(); return ncclHalf; @@ -82,9 +84,7 @@ infiniStatus_t allReduce( infinicclComm_t comm, infinirtStream_t stream) { - if (datatype != INFINI_DTYPE_F32 && datatype != INFINI_DTYPE_F16) { - return INFINI_STATUS_BAD_PARAM; - } + CHECK_DTYPE(datatype, INFINI_DTYPE_F32, INFINI_DTYPE_F16, INFINI_DTYPE_BF16); CHECK_NCCL(ncclAllReduce(sendbuf, recvbuf, count, getNcclDtype(datatype), getNcclRedOp(op), getNcclComm(comm), getCudaStream(stream)));