diff --git a/CMakeLists.txt b/CMakeLists.txt
index 87433aa6b..daae12628 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -129,6 +129,7 @@ else()
   set(CU_SOURCES
     src/collectives/device/all_reduce.cu
     src/collectives/device/all_gather.cu
+    src/collectives/device/alltoall_pivot.cu
     src/collectives/device/reduce.cu
     src/collectives/device/broadcast.cu
     src/collectives/device/reduce_scatter.cu
diff --git a/src/collectives/all_to_all_api.cc b/src/collectives/all_to_all_api.cc
index 562f451a3..d5f83e189 100644
--- a/src/collectives/all_to_all_api.cc
+++ b/src/collectives/all_to_all_api.cc
@@ -7,20 +7,30 @@
 #include "enqueue.h"
 #include "collectives.h"
+#include "graph/topo.h"
 
 NCCL_API(ncclResult_t, ncclAllToAll, const void* sendbuff, void* recvbuff, size_t count,
     ncclDataType_t datatype, ncclComm_t comm, hipStream_t stream);
 ncclResult_t ncclAllToAll(const void* sendbuff, void* recvbuff, size_t count,
     ncclDataType_t datatype, ncclComm_t comm, hipStream_t stream) {
-  int nRanks;
-  NCCLCHECK(ncclCommCount(comm, &nRanks));
-  size_t rankOffset = count * ncclTypeSize(datatype);
-  if (count == 0) return ncclSuccess;
-  NCCLCHECK(ncclGroupStart());
-  for (int r=0; r<nRanks; r++) {
-    NCCLCHECK(ncclSend(((char*)sendbuff)+r*rankOffset, count, datatype, r, comm, stream));
-    NCCLCHECK(ncclRecv(((char*)recvbuff)+r*rankOffset, count, datatype, r, comm, stream));
-  }
-  NCCLCHECK(ncclGroupEnd());
-  return ncclSuccess;
+  // Determine Pivot A2A support now that the number of channels is known
+  comm->topo->pivotA2AEnabled = comm->topo->pivotA2AEnabled && comm->nChannels >= comm->topo->pivotA2ANumBiRings * 2;
+  if (comm->topo->pivotA2AEnabled) {
+    struct ncclInfo info = { ncclFuncAllToAllPivot, "AllToAllPivot",
+      sendbuff, recvbuff, count, datatype, ncclSum, 0, comm, stream, /* Args */
+      ALLTOALL_PIVOT_CHUNKSTEPS, ALLTOALL_PIVOT_SLICESTEPS };
+    return ncclEnqueueCheck(&info);
+  } else {
+    int nRanks;
+    NCCLCHECK(ncclCommCount(comm, &nRanks));
+    size_t rankOffset = count * ncclTypeSize(datatype);
+    if (count == 0) return ncclSuccess;
+    NCCLCHECK(ncclGroupStart());
+    for (int r=0; r<nRanks; r++) {
+      NCCLCHECK(ncclSend(((char*)sendbuff)+r*rankOffset, count, datatype, r, comm, stream));
+      NCCLCHECK(ncclRecv(((char*)recvbuff)+r*rankOffset, count, datatype, r, comm, stream));
+    }
+    NCCLCHECK(ncclGroupEnd());
+    return ncclSuccess;
+  }
 }
 
diff --git a/src/collectives/device/alltoall_pivot.cu b/src/collectives/device/alltoall_pivot.cu
new file mode 100644
--- /dev/null
+++ b/src/collectives/device/alltoall_pivot.cu
@@ -0,0 +1,11 @@
+/*************************************************************************
+ * Copyright (c) 2021 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * See LICENSE.txt for license information
+ ************************************************************************/
+
+#include "alltoall_pivot.h"
+#include "common.h"
+#include "collectives.h"
+
+IMPL_COLL_ALLTOALL_PIVOT(AllToAllPivot);
diff --git a/src/collectives/device/alltoall_pivot.h b/src/collectives/device/alltoall_pivot.h
new file mode 100644
--- /dev/null
+++ b/src/collectives/device/alltoall_pivot.h
@@ -0,0 +1,78 @@
+/*************************************************************************
+ * Copyright (c) 2021 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * See LICENSE.txt for license information
+ ************************************************************************/
+
+#include "devcomm.h"
+#include "primitives.h"
+#include "collectives.h"
+
+namespace {
+  template <typename T, typename RedOp, typename Proto>
+  __device__ __attribute__((noinline)) void runRing(ncclWorkElem *args) {
+    const int tid = threadIdx.x;
+    const int nthreads = args->nThreads;
+    const int bid = args->coll.bid;
+    const int nranks = ncclShmem->comm.nRanks;
+    const ncclRing *ring = &ncclShmem->channel.ring;
+    const int num_bi_rings = args->coll.pivotA2ANumBiRings;
+    const int num_uni_rings = num_bi_rings * 2;
+    const int num_chunks = args->coll.nChannels / 2;
+    const int chunk_id = (bid % num_bi_rings) + (bid / num_uni_rings * num_bi_rings);
+    const int elem_size = args->coll.count % 256 ? 1 : 256;
+    const ssize_t num_elems = args->coll.count / elem_size;
+    const int num_padding_chunks = num_elems % num_chunks;
+    const ssize_t chunk_offset = elem_size * (num_elems / num_chunks * chunk_id + (chunk_id < num_padding_chunks ? chunk_id : num_padding_chunks));
+    const ssize_t chunk_size = elem_size * (num_elems / num_chunks + (chunk_id < num_padding_chunks ? 1 : 0));
+    const int pivot_direction = (bid % num_uni_rings) / num_bi_rings;
+    const ssize_t prims_size = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? ALLTOALL_PIVOT_CHUNKSTEPS : 1));
+
+    Primitives<T, RedOp, FanSymmetric<1>, 0, Proto> prims
+      (tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, /*redOpArg(ignored)=*/0, args->coll.connIndex << 16);
+
+    for (int num_hops = 0; num_hops <= nranks / 2; num_hops++) {
+      const int src_rank = ring->devUserRanks[(nranks - num_hops) % nranks];
+      const int dst_rank = ring->devUserRanks[num_hops];
+      const ssize_t send_offset =
+          dst_rank * num_elems * elem_size + chunk_offset +
+          (src_rank == dst_rank ? pivot_direction * chunk_size / 2 : 0);
+      const ssize_t recv_offset =
+          src_rank * num_elems * elem_size + chunk_offset +
+          (src_rank == dst_rank ? pivot_direction * chunk_size / 2 : 0);
+      const ssize_t send_recv_size =
+          src_rank == dst_rank ?
+          (pivot_direction == 0 ? chunk_size / 2 : chunk_size - chunk_size / 2) : chunk_size;
+
+      if (num_hops == 0 && args->sendbuff != args->recvbuff) {
+        const T* sendbuff = (const T*)args->sendbuff + send_offset;
+        T* recvbuff = (T *)args->recvbuff + recv_offset;
+        ReduceOrCopyMulti<COLL_UNROLL, RedOp, T, 1, 1, 1, 1, 0>(
+            tid, nthreads, nullptr, false, 1, &sendbuff, 1, &recvbuff, send_recv_size);
+      } else {
+        for (ssize_t prims_offset = 0; prims_offset < send_recv_size; prims_offset += prims_size) {
+          const int prims_nelem = min(prims_size, send_recv_size - prims_offset);
+
+          // step 0: send
+          prims.send(send_offset + prims_offset, prims_nelem);
+
+          // num_hops - 1 steps: recv and copy to next gpu
+          for (int i = 0; i < num_hops - 1; i++) {
+            prims.recvSend(prims_nelem);
+          }
+
+          // final step: recv
+          prims.directRecv(recv_offset + prims_offset, prims_nelem);
+        }
+      }
+    }
+  }
+}
+
+template <typename T, typename RedOp>
+struct RunWorkElement<ncclFuncAllToAllPivot, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
+  __device__ __attribute__((noinline)) void run(ncclWorkElem *args) {
+    using Proto = ProtoSimple<ALLTOALL_PIVOT_CHUNKSTEPS/ALLTOALL_PIVOT_SLICESTEPS, ALLTOALL_PIVOT_SLICESTEPS>;
+    runRing<T, RedOp, Proto>(args);
+  }
+};
diff --git a/src/collectives/device/common.h b/src/collectives/device/common.h
index 543a3ca03..b484e0e3c 100644
--- a/src/collectives/device/common.h
+++ b/src/collectives/device/common.h
@@ -128,6 +128,7 @@ static const __device__ constexpr ncclKernelFunc_t ncclFuncs[]{
   NCCL_ONERANK_REDUCE_NAME(PreMulSum, rccl_bfloat16),
 #endif
   NCCL_FUNC_NAME(SendRecv, RING, SIMPLE, Sum, int8_t),
+  NCCL_FUNC_NAME(AllToAllPivot, RING, SIMPLE, Sum, int8_t),
 #endif
 #endif
 };
@@ -150,6 +151,7 @@ struct Caller{
 };
 
 static_assert(FUNC_INDEX_P2P == 2710, "Wrong P2P function index");
+static_assert(FUNC_INDEX_ALLTOALL_PIVOT == 2711, "Wrong AllToAllPivot function index");
 
 inline
 __device__
@@ -231,6 +233,9 @@ void NCCL_CALL_FUNCTIONS(struct ncclWorkElem* const c) noexcept {
     case 10:
      ncclFunction_SendRecv_RING_SIMPLE_Sum_int8_t(c);
      break;
+    case 11:
+     ncclFunction_AllToAllPivot_RING_SIMPLE_Sum_int8_t(c);
+     break;
     default:
       break;
   }
@@ -618,4 +623,8 @@ __device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, dev
IMPL_COLL_FUNC(func, RING, SIMPLE, Sum, int8_t); \ IMPL_COLL_KERN(func, RING, SIMPLE, Sum, int8_t, FUNC_INDEX_P2P); +// AllToAll Pivot primitive only has one function. +#define IMPL_COLL_ALLTOALL_PIVOT(func) \ + IMPL_COLL_FUNC(func, RING, SIMPLE, Sum, int8_t); + #endif diff --git a/src/collectives/device/functions.cu b/src/collectives/device/functions.cu index 7acb80be6..72f5509b4 100644 --- a/src/collectives/device/functions.cu +++ b/src/collectives/device/functions.cu @@ -89,12 +89,13 @@ __device__ struct ncclShmemData* ncclShmem; NCCL_FUNCS3B(func, Sum) // Must be consistent with the ncclFuncSet enum -__device__ ncclKern_t ncclFuncs[1+ncclNumTypes+NCCL_NUM_FUNCTIONS*ncclNumDevRedOps*ncclNumTypes*NCCL_NUM_ALGORITHMS*NCCL_NUM_PROTOCOLS] = { +__device__ ncclKern_t ncclFuncs[2+ncclNumTypes+NCCL_NUM_FUNCTIONS*ncclNumDevRedOps*ncclNumTypes*NCCL_NUM_ALGORITHMS*NCCL_NUM_PROTOCOLS] = { // Don't try to initialize the host shadow copy of this device-side global // variable. There is no host pointer to a device-side function, which // confuses clang. This will be fixed in the next clang release. 
#if __CUDA_ARCH__ NCCL_FUNC_NAME(SendRecv, RING, SIMPLE, Sum, int8_t), + NCCL_FUNC_NAME(AllToAllPivot, RING, SIMPLE, Sum, int8_t), NCCL_ONERANK_REDUCE_NAME(PreMulSum, int8_t), NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint8_t), NCCL_ONERANK_REDUCE_NAME(PreMulSum, int32_t), diff --git a/src/collectives/device/prims_ll.h b/src/collectives/device/prims_ll.h index 8b3f84ac6..d6e76d698 100644 --- a/src/collectives/device/prims_ll.h +++ b/src/collectives/device/prims_ll.h @@ -454,4 +454,7 @@ class Primitives: __device__ void recvReduceCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) { return LLGenericOp<1, 1, Input, Output>(inpIx, outIx, eltN, postOp); } + __device__ void recvSend(int eltN) { + return LLGenericOp<1, 1, -1, -1>(-1, -1, eltN, false); + } }; diff --git a/src/collectives/device/prims_ll128.h b/src/collectives/device/prims_ll128.h index e1247022f..4dc0c2175 100644 --- a/src/collectives/device/prims_ll128.h +++ b/src/collectives/device/prims_ll128.h @@ -424,4 +424,7 @@ class Primitives: __device__ void recvReduceCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) { return GenericOp<1, 1, Input, Output>(inpIx, outIx, eltN, postOp); } + __device__ void recvSend(int eltN) { + return GenericOp<1, 1, -1, -1>(-1, -1, eltN, false); + } }; diff --git a/src/collectives/device/prims_simple.h b/src/collectives/device/prims_simple.h index 4b4fd227d..7160cc44f 100644 --- a/src/collectives/device/prims_simple.h +++ b/src/collectives/device/prims_simple.h @@ -635,4 +635,8 @@ class Primitives< directGather(intptr_t outIx, int totalElem, int peerElem, int skip, int shift) { ScatterGatherOp<1, 0, 1, 0>(-1, outIx, totalElem, peerElem, skip, shift, /*postOp=*/false); } + + __device__ __forceinline__ void recvSend(int eltN) { + genericOp<0, 0, 1, 1, -1, -1>(-1, -1, -1, eltN, /*postOp=*/false); + } }; diff --git a/src/enqueue.cc b/src/enqueue.cc index 4223c0541..a1fca5edf 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -361,7 +361,7 @@ static 
inline ncclResult_t getCollNetSupport(struct ncclInfo* info, int* collNet static ncclResult_t getAlgoInfo(struct ncclInfo* info, int collNetTypeSupport, int numPipeOps) { struct ncclComm* comm = info->comm; - if (comm->nRanks == 1) { + if (comm->nRanks == 1 || info->coll == ncclFuncAllToAllPivot) { info->algorithm = NCCL_ALGO_RING; info->protocol = NCCL_PROTO_SIMPLE; } @@ -423,7 +423,12 @@ static ncclResult_t getAlgoInfo(struct ncclInfo* info, int collNetTypeSupport, i if (info->algorithm == NCCL_ALGO_COLLNET) nt += 3*WARP_SIZE; } #endif - info->nChannels = nc; + if (info->coll == ncclFuncAllToAllPivot) { + int pivotA2ANumUniRings = comm->topo->pivotA2ANumBiRings * 2; + info->nChannels = comm->nChannels / pivotA2ANumUniRings * pivotA2ANumUniRings; + } else { + info->nChannels = nc; + } info->nThreads = nt; return ncclSuccess; } @@ -436,6 +441,7 @@ static ncclResult_t getPatternInfo(struct ncclInfo* info) { info->pattern = info->algorithm == NCCL_ALGO_TREE ? ncclPatternTreeUp : ncclPatternPipelineTo; break; case ncclFuncReduceScatter: case ncclFuncAllGather: + case ncclFuncAllToAllPivot: info->pattern = ncclPatternRing; break; case ncclFuncAllReduce: info->pattern = info->algorithm == NCCL_ALGO_COLLNET ? ncclPatternCollTreeUpDown : info->algorithm == NCCL_ALGO_TREE ? 
ncclPatternTreeUpDown : ncclPatternRingTwice; break; @@ -497,10 +503,12 @@ static ncclResult_t computeColl(struct ncclInfo* info /* input */, struct ncclWo // one-rank reduce index work->funcIndex = FUNC_INDEX_P2P - ncclNumTypes + int(info->datatype); return ncclSuccess; + } else if (info->coll == ncclFuncAllToAllPivot) { + work->funcIndex = FUNC_INDEX_ALLTOALL_PIVOT; + } else { + work->funcIndex = FUNC_INDEX(info->coll, info->opFull.op, info->datatype, info->algorithm, info->protocol); } - work->funcIndex = FUNC_INDEX(info->coll, info->opFull.op, info->datatype, info->algorithm, info->protocol); - work->coll.connIndex = 0; proxyArgs->connIndex = 0; if (info->protocol == NCCL_PROTO_SIMPLE && info->algorithm == NCCL_ALGO_RING) { @@ -599,6 +607,12 @@ static ncclResult_t computeColl(struct ncclInfo* info /* input */, struct ncclWo TRACE(NCCL_COLL,"opCount %lx slicesteps %d spl %d cpl %d nbytes %zi -> protocol %d nchannels %d nthreads %d, nloops %d nsteps %d chunksize %d comm %p", proxyArgs->opCount, sliceSteps, info->nstepsPerLoop, info->nchunksPerLoop, info->nBytes, info->protocol, info->nChannels, info->nThreads, nLoops, proxyArgs->subs[0].nsteps, chunkSize, info->comm); + + // For Pivot A2A, lastChunkSize is not needed, set pivotA2ANumBiRings instead + if (info->coll == ncclFuncAllToAllPivot) { + work->coll.pivotA2ANumBiRings = info->comm->topo->pivotA2ANumBiRings; + } + return ncclSuccess; } @@ -760,7 +774,12 @@ ncclResult_t ncclSetupAsyncKernels(ncclComm_t comm) { int allCollNetSupport = comm->collNetSupport; for (int c = 0; c < comm->asyncOpCount; c++) { struct ncclInfo* info = comm->asyncOps+c; - info->nChannels = std::min(std::max(1, (int)DIVUP(info->nBytes, channelSize)), comm->nChannels); // assign number of channels + if (info->coll == ncclFuncAllToAllPivot) { + int pivotA2ANumUniRings = comm->topo->pivotA2ANumBiRings * 2; + info->nChannels = comm->nChannels / pivotA2ANumUniRings * pivotA2ANumUniRings; + } else { + info->nChannels = std::min(std::max(1, 
(int)DIVUP(info->nBytes, channelSize)), comm->nChannels); // assign number of channels + } channelUsed += info->nChannels; // We can use fast path if all collectives are the same homogeneous &= info->coll == comm->asyncOps[0].coll && diff --git a/src/graph/rome_models.cc b/src/graph/rome_models.cc index 2bbe7bb94..a76b25e35 100644 --- a/src/graph/rome_models.cc +++ b/src/graph/rome_models.cc @@ -367,7 +367,7 @@ static struct rcclRomeModel rome_model_56 = { .gdrLevel = { }, .pattern = "40404040", .ringBase = "0 1 3 2 6 7 15 14 10 11 9 8 12 13 5 4|0 1 2 3 7 6 13 12 8 9 10 11 15 14 5 4|0 2 3 7 6 14 15 11 10 8 9 13 12 4 5 1|4 5 13 12 8 9 11 10 14 15 7 6 2 3 1 0|4 5 14 15 11 10 9 8 12 13 6 7 3 2 1 0|1 5 4 12 13 9 8 10 11 15 14 6 7 3 2 0", - .options = "", + .options = "pivotA2AEnabled=1,pivotA2ANumBiRings=3", }; static struct rcclRomeModel rome_model_58 = { @@ -629,7 +629,7 @@ ncclResult_t parseGraph(const char* str, struct ncclTopoSystem* system, struct n static void parseOptions(struct ncclTopoSystem* system, const char *options) { if (strcmp(options, "")) { - char *str_temp = (char *)malloc(sizeof(options)); + char *str_temp = (char *)malloc(strlen(options) + 1); strcpy(str_temp, options); char* tokens[MAX_OPT_TOKENS]; int numTokens = 0; @@ -640,6 +640,10 @@ static void parseOptions(struct ncclTopoSystem* system, const char *options) { for (int i = 0; i < numTokens/2; i++) { if (strcmp(tokens[i*2], "netGdrLevel") == 0) { system->netGdrLevel = atol(tokens[i*2+1]); + } else if (strcmp(tokens[i*2], "pivotA2AEnabled") == 0) { + system->pivotA2AEnabled = (bool)atol(tokens[i*2+1]); + } else if (strcmp(tokens[i*2], "pivotA2ANumBiRings") == 0) { + system->pivotA2ANumBiRings = atol(tokens[i*2+1]); } } free(str_temp); diff --git a/src/graph/topo.h b/src/graph/topo.h index ba1ed11f7..fbd12a80d 100644 --- a/src/graph/topo.h +++ b/src/graph/topo.h @@ -137,6 +137,9 @@ struct ncclTopoSystem { int type; int nRanks; int netGdrLevel; + + bool pivotA2AEnabled; + int pivotA2ANumBiRings; 
}; ncclResult_t ncclTopoGetNode(struct ncclTopoSystem* system, struct ncclTopoNode** node, int type, uint64_t id); diff --git a/src/include/collectives.h b/src/include/collectives.h index 10be4501a..d52d59706 100644 --- a/src/include/collectives.h +++ b/src/include/collectives.h @@ -20,6 +20,7 @@ struct ncclDevRedOpFull { }; #define FUNC_INDEX_P2P (ncclNumTypes+NCCL_NUM_FUNCTIONS*NCCL_NUM_ALGORITHMS*NCCL_NUM_PROTOCOLS*ncclNumTypes*ncclNumDevRedOps) +#define FUNC_INDEX_ALLTOALL_PIVOT (FUNC_INDEX_P2P+1) #define FUNC_INDEX(func, devredop, ncclType, al, pr) ((((((func)*ncclNumDevRedOps + (devredop))*ncclNumTypes) + (ncclType))*NCCL_NUM_ALGORITHMS+(al))*NCCL_NUM_PROTOCOLS+(pr)) #define NCCL_FUNC_NAME(func, algo, proto, devredop, type) \ @@ -93,6 +94,7 @@ DECL2(AllGather, Sum, /*undefForFloat=*/0) DECL(ReduceScatter) DECL(AllReduce) DECL5(SendRecv, RING, SIMPLE, Sum, int8_t) +DECL5(AllToAllPivot, RING, SIMPLE, Sum, int8_t) extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, int8_t)(struct ncclWorkElem* args); extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint8_t)(struct ncclWorkElem* args); @@ -126,5 +128,7 @@ extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, double)(struct ncclWo #define REDUCE_CHUNKSTEPS 2 #define SENDRECV_SLICEFACTOR 1 #define NCCL_MAX_SLICE_PER_CHUNK 2 // max value for CHUNKSTEPS/SLICESTEPS, must accord with above +#define ALLTOALL_PIVOT_SLICESTEPS 2 +#define ALLTOALL_PIVOT_CHUNKSTEPS 4 #endif diff --git a/src/include/devcomm.h b/src/include/devcomm.h index e69a71c33..26936f530 100644 --- a/src/include/devcomm.h +++ b/src/include/devcomm.h @@ -26,9 +26,9 @@ #endif -#define NCCL_NUM_FUNCTIONS 5 // SendRecv not included for now -typedef enum { ncclFuncBroadcast, ncclFuncReduce, ncclFuncAllGather, ncclFuncReduceScatter, ncclFuncAllReduce, ncclFuncSendRecv, ncclNumFuncs} ncclFunc_t; -extern const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+1]; +#define NCCL_NUM_FUNCTIONS 5 // SendRecv and AllToAllPivot not included for now +typedef 
enum { ncclFuncBroadcast, ncclFuncReduce, ncclFuncAllGather, ncclFuncReduceScatter, ncclFuncAllReduce, ncclFuncSendRecv, ncclFuncAllToAllPivot, ncclNumFuncs} ncclFunc_t; +extern const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+2]; #define NCCL_NUM_ALGORITHMS 3 // Tree/Ring/CollNet #define NCCL_ALGO_TREE 0 @@ -202,7 +202,12 @@ struct ncclWorkElem { union { struct { size_t count; - size_t lastChunkSize; + union { + size_t lastChunkSize; + // Pivot A2A kernel computes chunk size itself. + // Instead, it needs the number of bidirectional rings. + size_t pivotA2ANumBiRings; + }; uint64_t redOpArg; uint16_t root; uint8_t bid; diff --git a/src/init.cc b/src/init.cc index 23cf33035..53ac54a1f 100644 --- a/src/init.cc +++ b/src/init.cc @@ -47,7 +47,7 @@ std::chrono::high_resolution_clock::time_point ncclEpoch; #define NCCL_GROUP_CUDA_STREAM 1 // CGMD: CUDA 9.0,9.1 Need to use an internal CUDA stream #endif -const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+1] = { "Broadcast", "Reduce", "AllGather", "ReduceScatter", "AllReduce", "SendRecv" }; +const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+2] = { "Broadcast", "Reduce", "AllGather", "ReduceScatter", "AllReduce", "SendRecv", "AllToAllPivot" }; const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS] = { "Tree", "Ring", "CollNet" }; const char* ncclProtoStr[NCCL_NUM_PROTOCOLS] = { "LL", "LL128", "Simple" }; const char* ncclDevRedOpStr[ncclNumDevRedOps] = { "Sum", "Prod", "Max", "Min", "PreMulSum", "SumPostDiv" }; @@ -870,6 +870,9 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm comm->topo->nRanks = comm->nRanks; // init netGdrLevel comm->topo->netGdrLevel = -2; + // init Pivot A2A related fields + comm->topo->pivotA2AEnabled = false; + comm->topo->pivotA2ANumBiRings = 0; // Compute paths between GPUs and NICs NCCLCHECK(ncclTopoComputePaths(comm->topo, comm->peerInfo)); // Remove inaccessible GPUs and unused NICs @@ -996,6 +999,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm struct 
ncclGraphInfo ring; struct ncclGraphInfo collNet; struct ncclTopoRanks topoRanks; + bool pivotA2AEnabled; } *allGather3Data; NCCLCHECK(ncclCalloc(&allGather3Data, nranks)); @@ -1036,6 +1040,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm allGather3Data[rank].collNet.typeIntra = collNetGraph.typeIntra; allGather3Data[rank].collNet.typeInter = collNetGraph.typeInter; allGather3Data[rank].collNetSupport = comm->collNetSupport; + allGather3Data[rank].pivotA2AEnabled = comm->topo->pivotA2AEnabled; comm->nChannels = (comm->topo->nodes[GPU].count != comm->topo->nRanks && comm->topo->nodes[NET].count) ? std::min(treeGraph.nChannels, ringGraph.nChannels) : ringGraph.nChannels; @@ -1089,6 +1094,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm collNetGraph.typeIntra = std::min(allGather3Data[i].collNet.typeIntra, collNetGraph.typeIntra); collNetGraph.typeInter = std::min(allGather3Data[i].collNet.typeInter, collNetGraph.typeInter); comm->collNetSupport = std::min(allGather3Data[i].collNetSupport, comm->collNetSupport); + comm->topo->pivotA2AEnabled = comm->topo->pivotA2AEnabled && allGather3Data[i].pivotA2AEnabled; } comm->nChannels = treeGraph.nChannels = ringGraph.nChannels = diff --git a/src/misc/argcheck.cc b/src/misc/argcheck.cc index d1aabec47..c39b06de6 100644 --- a/src/misc/argcheck.cc +++ b/src/misc/argcheck.cc @@ -44,9 +44,9 @@ ncclResult_t ArgsCheck(struct ncclInfo* info) { WARN("%s : invalid type %d", info->opName, info->datatype); return ncclInvalidArgument; } - // Type is OK, compute nbytes. Convert Allgather/Broadcast/P2P calls to chars. + // Type is OK, compute nbytes. Convert Allgather/Broadcast/P2P/AllToAllPivot calls to chars. 
info->nBytes = info->count * ncclTypeSize(info->datatype); - if (info->coll == ncclFuncAllGather || info->coll == ncclFuncBroadcast) { + if (info->coll == ncclFuncAllGather || info->coll == ncclFuncBroadcast || info->coll == ncclFuncAllToAllPivot) { info->count = info->nBytes; info->datatype = ncclInt8; } diff --git a/tools/topo_expl/utils.cpp b/tools/topo_expl/utils.cpp index 938a3f935..2d6e3550a 100644 --- a/tools/topo_expl/utils.cpp +++ b/tools/topo_expl/utils.cpp @@ -31,7 +31,7 @@ #include "utils.h" #include "rocm_smi/rocm_smi.h" -const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+1] = { "Broadcast", "Reduce", "AllGather", "ReduceScatter", "AllReduce", "SendRecv" }; +const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+2] = { "Broadcast", "Reduce", "AllGather", "ReduceScatter", "AllReduce", "SendRecv", "AllToAllPivot" }; const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS] = { "Tree", "Ring", "CollNet" }; const char* ncclProtoStr[NCCL_NUM_PROTOCOLS] = { "LL", "LL128", "Simple" };