From 81093e0b2fd9ab6172d0a131f391f4e75831c9b9 Mon Sep 17 00:00:00 2001 From: PanZezhong1725 Date: Tue, 9 Sep 2025 11:32:20 +0800 Subject: [PATCH] issue/434 nccl support bf16 --- src/infiniccl/cuda/infiniccl_cuda.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/infiniccl/cuda/infiniccl_cuda.cu b/src/infiniccl/cuda/infiniccl_cuda.cu index df699ba45..2f8869c18 100644 --- a/src/infiniccl/cuda/infiniccl_cuda.cu +++ b/src/infiniccl/cuda/infiniccl_cuda.cu @@ -22,6 +22,8 @@ inline ncclDataType_t getNcclDtype(infiniDtype_t datatype) { return ncclFloat; case INFINI_DTYPE_F16: return ncclHalf; + case INFINI_DTYPE_BF16: + return ncclBfloat16; default: std::abort(); return ncclHalf; @@ -82,9 +84,7 @@ infiniStatus_t allReduce( infinicclComm_t comm, infinirtStream_t stream) { - if (datatype != INFINI_DTYPE_F32 && datatype != INFINI_DTYPE_F16) { - return INFINI_STATUS_BAD_PARAM; - } + CHECK_DTYPE(datatype, INFINI_DTYPE_F32, INFINI_DTYPE_F16, INFINI_DTYPE_BF16); CHECK_NCCL(ncclAllReduce(sendbuf, recvbuf, count, getNcclDtype(datatype), getNcclRedOp(op), getNcclComm(comm), getCudaStream(stream)));