Add Pivot AllToAll algorithm for Rome model #503

Merged: 8 commits, Feb 21, 2022
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -129,6 +129,7 @@ else()
set(CU_SOURCES
src/collectives/device/all_reduce.cu
src/collectives/device/all_gather.cu
src/collectives/device/alltoall_pivot.cu
src/collectives/device/reduce.cu
src/collectives/device/broadcast.cu
src/collectives/device/reduce_scatter.cu
30 changes: 20 additions & 10 deletions src/collectives/all_to_all_api.cc
@@ -7,20 +7,30 @@

#include "enqueue.h"
#include "collectives.h"
#include "graph/topo.h"

NCCL_API(ncclResult_t, ncclAllToAll, const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
ncclComm_t comm, hipStream_t stream);
ncclResult_t ncclAllToAll(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
ncclComm_t comm, hipStream_t stream) {
int nRanks;
NCCLCHECK(ncclCommCount(comm, &nRanks));
size_t rankOffset = count * ncclTypeSize(datatype);
if (count == 0) return ncclSuccess;
NCCLCHECK(ncclGroupStart());
for (int r=0; r<nRanks; r++) {
NCCLCHECK(ncclSend(((char*)sendbuff)+r*rankOffset, count, datatype, r, comm, stream));
NCCLCHECK(ncclRecv(((char*)recvbuff)+r*rankOffset, count, datatype, r, comm, stream));
// Determine Pivot A2A support now that we know number of channels
comm->topo->pivotA2AEnabled = comm->topo->pivotA2AEnabled && comm->nChannels >= comm->topo->pivotA2ANumBiRings * 2;
if (comm->topo->pivotA2AEnabled) {
struct ncclInfo info = { ncclFuncAllToAllPivot, "AllToAllPivot",
sendbuff, recvbuff, count, datatype, ncclSum, 0, comm, stream, /* Args */
ALLTOALL_PIVOT_CHUNKSTEPS, ALLTOALL_PIVOT_SLICESTEPS };
return ncclEnqueueCheck(&info);
} else {
int nRanks;
NCCLCHECK(ncclCommCount(comm, &nRanks));
size_t rankOffset = count * ncclTypeSize(datatype);
if (count == 0) return ncclSuccess;
NCCLCHECK(ncclGroupStart());
for (int r=0; r<nRanks; r++) {
NCCLCHECK(ncclSend(((char*)sendbuff)+r*rankOffset, count, datatype, r, comm, stream));
NCCLCHECK(ncclRecv(((char*)recvbuff)+r*rankOffset, count, datatype, r, comm, stream));
}
NCCLCHECK(ncclGroupEnd());
return ncclSuccess;
}
NCCLCHECK(ncclGroupEnd());
return ncclSuccess;
}
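
The public entry point keeps its signature; only the dispatch changes. Below is a minimal caller sketch, not part of the PR, assuming an already-initialized ncclComm_t, its rank count, and a HIP stream:

#include <hip/hip_runtime.h>
#include <rccl.h>

// Illustrative only: communicator setup, error checking and data
// initialization are elided.
void runAllToAll(ncclComm_t comm, int nRanks, hipStream_t stream) {
  size_t count = 1 << 20;  // elements exchanged with each peer
  float *sendbuff = nullptr, *recvbuff = nullptr;
  hipMalloc((void**)&sendbuff, count * nRanks * sizeof(float));
  hipMalloc((void**)&recvbuff, count * nRanks * sizeof(float));
  // Takes the pivot path when pivotA2AEnabled survives the channel check
  // above; otherwise falls back to the pairwise send/recv loop.
  ncclAllToAll(sendbuff, recvbuff, count, ncclFloat, comm, stream);
  hipStreamSynchronize(stream);
  hipFree(sendbuff);
  hipFree(recvbuff);
}
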
2 changes: 1 addition & 1 deletion src/collectives/device/Makefile
@@ -10,7 +10,7 @@ include ../../../makefiles/version.mk
BUILDDIR ?= $(abspath ../../../build)
OBJDIR := $(BUILDDIR)/obj/collectives/device

LIBSRCFILES := all_reduce.cu broadcast.cu reduce.cu all_gather.cu reduce_scatter.cu sendrecv.cu onerank_reduce.cu
LIBSRCFILES := all_reduce.cu broadcast.cu reduce.cu all_gather.cu reduce_scatter.cu sendrecv.cu onerank_reduce.cu alltoall_pivot.cu

LIBSRCFILES += functions.cu

11 changes: 11 additions & 0 deletions src/collectives/device/alltoall_pivot.cu
@@ -0,0 +1,11 @@
/*************************************************************************
* Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/

#include "alltoall_pivot.h"
#include "common.h"
#include "collectives.h"

IMPL_COLL_ALLTOALL_PIVOT(AllToAllPivot);
78 changes: 78 additions & 0 deletions src/collectives/device/alltoall_pivot.h
@@ -0,0 +1,78 @@
/*************************************************************************
* Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/

#include "devcomm.h"
#include "collectives.h"
#include "primitives.h"

namespace {
template<typename T, typename RedOp, typename Proto>
__device__ __attribute__((noinline)) void runRing(ncclWorkElem *args) {
const int tid = threadIdx.x;
const int nthreads = args->nThreads;
const int bid = args->coll.bid;
const int nranks = ncclShmem->comm.nRanks;
const ncclRing *ring = &ncclShmem->channel.ring;
const int num_bi_rings = args->coll.pivotA2ANumBiRings;
const int num_uni_rings = num_bi_rings * 2;
const int num_chunks = args->coll.nChannels / 2;
const int chunk_id = (bid % num_bi_rings) + (bid / num_uni_rings * num_bi_rings);
const int elem_size = args->coll.count % 256 ? 1 : 256;
const ssize_t num_elems = args->coll.count / elem_size;
const int num_padding_chunks = num_elems % num_chunks;
const ssize_t chunk_offset = elem_size * (num_elems / num_chunks * chunk_id + (chunk_id < num_padding_chunks ? chunk_id : num_padding_chunks));
const ssize_t chunk_size = elem_size * (num_elems / num_chunks + (chunk_id < num_padding_chunks ? 1 : 0));
const int pivot_direction = (bid % num_uni_rings) / num_bi_rings;
const ssize_t prims_size = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? ALLTOALL_PIVOT_CHUNKSTEPS : 1));

Primitives<T, RedOp, FanSymmetric<1>, 0, Proto> prims
(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, /*redOpArg(ignored)=*/0, args->coll.connIndex << 16);

for (int num_hops = 0; num_hops <= nranks / 2; num_hops++) {
const int src_rank = ring->devUserRanks[(nranks - num_hops) % nranks];
const int dst_rank = ring->devUserRanks[num_hops];
const ssize_t send_offset =
dst_rank * num_elems * elem_size + chunk_offset +
(src_rank == dst_rank ? pivot_direction * chunk_size / 2 : 0);
const ssize_t recv_offset =
src_rank * num_elems * elem_size + chunk_offset +
(src_rank == dst_rank ? pivot_direction * chunk_size / 2 : 0);
const ssize_t send_recv_size =
src_rank == dst_rank ?
(pivot_direction == 0 ? chunk_size / 2 : chunk_size - chunk_size / 2) : chunk_size;

if (num_hops == 0 && args->sendbuff != args->recvbuff) {
const T* sendbuff = (const T*)args->sendbuff + send_offset;
T* recvbuff = (T *)args->recvbuff + recv_offset;
ReduceOrCopyMulti<COLL_UNROLL, RedOp, T, 1, 1, 1, 1, 0>(
tid, nthreads, nullptr, false, 1, &sendbuff, 1, &recvbuff, send_recv_size);
} else {
for (ssize_t prims_offset = 0; prims_offset < send_recv_size; prims_offset += prims_size) {
const int prims_nelem = min(prims_size, send_recv_size - prims_offset);

// step 0: send
prims.send(send_offset + prims_offset, prims_nelem);

// num_hops - 1 steps: recv and copy to next gpu
for (int i = 0; i < num_hops - 1; i++) {
prims.recvSend(prims_nelem);
}

// final step: recv
prims.directRecv(recv_offset + prims_offset, prims_nelem);
}
}
}
}
}

template<typename T, typename RedOp>
struct RunWorkElement<ncclFuncAllToAllPivot, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
__device__ __attribute__((noinline)) void run(ncclWorkElem *args) {
using Proto = ProtoSimple<ALLTOALL_PIVOT_CHUNKSTEPS/ALLTOALL_PIVOT_SLICESTEPS, ALLTOALL_PIVOT_SLICESTEPS>;
runRing<T, RedOp, Proto>(args);
}
};
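
The kernel derives each channel's share of a rank-pair block from its block id: channels are paired across the two ring directions, the element size is widened to 256 when the count allows, and any remainder is spread over the first chunks. A standalone host-side sketch of that indexing, mirroring the expressions in runRing above (the configuration values are illustrative):

#include <cstdio>
#include <cstddef>

int main() {
  // Assumed example configuration: 12 channels, 3 bidirectional rings,
  // 1<<20 elements per rank pair (divisible by 256).
  const int nChannels = 12, num_bi_rings = 3;
  const ptrdiff_t count = 1 << 20;

  const int num_uni_rings = num_bi_rings * 2;
  const int num_chunks = nChannels / 2;
  const int elem_size = count % 256 ? 1 : 256;   // vectorize when possible
  const ptrdiff_t num_elems = count / elem_size;
  const int num_padding_chunks = num_elems % num_chunks;

  for (int bid = 0; bid < nChannels; bid++) {
    // Channels bid and bid+num_bi_rings own the same chunk, one per direction.
    const int chunk_id = (bid % num_bi_rings) + (bid / num_uni_rings) * num_bi_rings;
    const int pivot_direction = (bid % num_uni_rings) / num_bi_rings;
    const ptrdiff_t chunk_offset = elem_size *
        (num_elems / num_chunks * chunk_id +
         (chunk_id < num_padding_chunks ? chunk_id : num_padding_chunks));
    const ptrdiff_t chunk_size = elem_size *
        (num_elems / num_chunks + (chunk_id < num_padding_chunks ? 1 : 0));
    printf("bid %2d -> chunk %d dir %d offset %td size %td\n",
           bid, chunk_id, pivot_direction, chunk_offset, chunk_size);
  }
  return 0;
}
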
8 changes: 8 additions & 0 deletions src/collectives/device/common.h
@@ -128,6 +128,7 @@ static const __device__ constexpr ncclKernelFunc_t ncclFuncs[]{
NCCL_ONERANK_REDUCE_NAME(PreMulSum, rccl_bfloat16),
#endif
NCCL_FUNC_NAME(SendRecv, RING, SIMPLE, Sum, int8_t),
NCCL_FUNC_NAME(AllToAllPivot, RING, SIMPLE, Sum, int8_t),
#endif
#endif
};
@@ -150,6 +151,7 @@ struct Caller<f, f + 1>{
};

static_assert(FUNC_INDEX_P2P == 2710, "Wrong P2P function index");
static_assert(FUNC_INDEX_ALLTOALL_PIVOT == 2711, "Wrong AllToAllPivot function index");

inline
__device__
@@ -231,6 +233,8 @@ void NCCL_CALL_FUNCTIONS(struct ncclWorkElem* const c) noexcept {
case 10:
ncclFunction_SendRecv_RING_SIMPLE_Sum_int8_t(c);
break;
case 11:
ncclFunction_AllToAllPivot_RING_SIMPLE_Sum_int8_t(c);
default:
break;
}
@@ -618,4 +622,8 @@ __device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, dev
IMPL_COLL_FUNC(func, RING, SIMPLE, Sum, int8_t); \
IMPL_COLL_KERN(func, RING, SIMPLE, Sum, int8_t, FUNC_INDEX_P2P);

// AllToAll Pivot primitive only has one function.
#define IMPL_COLL_ALLTOALL_PIVOT(func) \
IMPL_COLL_FUNC(func, RING, SIMPLE, Sum, int8_t);

#endif
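
The two asserts pin the layout of the device function table: FUNC_INDEX_P2P lands right after the full collective x reduction-op x type x algorithm x protocol grid, and the pivot kernel takes the next slot. The arithmetic below reproduces the checked values, assuming the table sizes those values imply (5 collectives, 6 device reduction ops, 10 datatypes, 3 algorithms, 3 protocols), which are inferred here rather than stated in this diff:

// Worked arithmetic behind the asserts above (sizes inferred from the
// checked constants).
static_assert(10 + 5 * 6 * 10 * 3 * 3 == 2710, "FUNC_INDEX_P2P");
static_assert(10 + 5 * 6 * 10 * 3 * 3 + 1 == 2711, "FUNC_INDEX_ALLTOALL_PIVOT");
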
3 changes: 2 additions & 1 deletion src/collectives/device/functions.cu
@@ -89,12 +89,13 @@ __device__ struct ncclShmemData* ncclShmem;
NCCL_FUNCS3B(func, Sum)

// Must be consistent with the ncclFuncSet enum
__device__ ncclKern_t ncclFuncs[1+ncclNumTypes+NCCL_NUM_FUNCTIONS*ncclNumDevRedOps*ncclNumTypes*NCCL_NUM_ALGORITHMS*NCCL_NUM_PROTOCOLS] = {
__device__ ncclKern_t ncclFuncs[2+ncclNumTypes+NCCL_NUM_FUNCTIONS*ncclNumDevRedOps*ncclNumTypes*NCCL_NUM_ALGORITHMS*NCCL_NUM_PROTOCOLS] = {
// Don't try to initialize the host shadow copy of this device-side global
// variable. There is no host pointer to a device-side function, which
// confuses clang. This will be fixed in the next clang release.
#if __CUDA_ARCH__
NCCL_FUNC_NAME(SendRecv, RING, SIMPLE, Sum, int8_t),
NCCL_FUNC_NAME(AllToAllPivot, RING, SIMPLE, Sum, int8_t),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, int8_t),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint8_t),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, int32_t),
3 changes: 3 additions & 0 deletions src/collectives/device/prims_ll.h
@@ -454,4 +454,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL>:
__device__ void recvReduceCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
return LLGenericOp<1, 1, Input, Output>(inpIx, outIx, eltN, postOp);
}
__device__ void recvSend(int eltN) {
return LLGenericOp<1, 1, -1, -1>(-1, -1, eltN, false);
}
};
3 changes: 3 additions & 0 deletions src/collectives/device/prims_ll128.h
@@ -424,4 +424,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128>:
__device__ void recvReduceCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
return GenericOp<1, 1, Input, Output>(inpIx, outIx, eltN, postOp);
}
__device__ void recvSend(int eltN) {
return GenericOp<1, 1, -1, -1>(-1, -1, eltN, false);
}
};
4 changes: 4 additions & 0 deletions src/collectives/device/prims_simple.h
@@ -635,4 +635,8 @@ class Primitives<
directGather(intptr_t outIx, int totalElem, int peerElem, int skip, int shift) {
ScatterGatherOp<1, 0, 1, 0>(-1, outIx, totalElem, peerElem, skip, shift, /*postOp=*/false);
}

__device__ __forceinline__ void recvSend(int eltN) {
genericOp<0, 0, 1, 1, -1, -1>(-1, -1, -1, eltN, /*postOp=*/false);
}
};
29 changes: 24 additions & 5 deletions src/enqueue.cc
@@ -361,7 +361,7 @@ static inline ncclResult_t getCollNetSupport(struct ncclInfo* info, int* collNet

static ncclResult_t getAlgoInfo(struct ncclInfo* info, int collNetTypeSupport, int numPipeOps) {
struct ncclComm* comm = info->comm;
if (comm->nRanks == 1) {
if (comm->nRanks == 1 || info->coll == ncclFuncAllToAllPivot) {
info->algorithm = NCCL_ALGO_RING;
info->protocol = NCCL_PROTO_SIMPLE;
}
@@ -423,7 +423,12 @@ static ncclResult_t getAlgoInfo(struct ncclInfo* info, int collNetTypeSupport, i
if (info->algorithm == NCCL_ALGO_COLLNET) nt += 3*WARP_SIZE;
}
#endif
info->nChannels = nc;
if (info->coll == ncclFuncAllToAllPivot) {
int pivotA2ANumUniRings = comm->topo->pivotA2ANumBiRings * 2;
info->nChannels = comm->nChannels / pivotA2ANumUniRings * pivotA2ANumUniRings;
} else {
info->nChannels = nc;
}
info->nThreads = nt;
return ncclSuccess;
}
@@ -436,6 +441,7 @@ static ncclResult_t getPatternInfo(struct ncclInfo* info) {
info->pattern = info->algorithm == NCCL_ALGO_TREE ? ncclPatternTreeUp : ncclPatternPipelineTo; break;
case ncclFuncReduceScatter:
case ncclFuncAllGather:
case ncclFuncAllToAllPivot:
info->pattern = ncclPatternRing; break;
case ncclFuncAllReduce:
info->pattern = info->algorithm == NCCL_ALGO_COLLNET ? ncclPatternCollTreeUpDown : info->algorithm == NCCL_ALGO_TREE ? ncclPatternTreeUpDown : ncclPatternRingTwice; break;
@@ -497,10 +503,12 @@ static ncclResult_t computeColl(struct ncclInfo* info /* input */, struct ncclWo
// one-rank reduce index
work->funcIndex = FUNC_INDEX_P2P - ncclNumTypes + int(info->datatype);
return ncclSuccess;
} else if (info->coll == ncclFuncAllToAllPivot) {
work->funcIndex = FUNC_INDEX_ALLTOALL_PIVOT;
} else {
work->funcIndex = FUNC_INDEX(info->coll, info->opFull.op, info->datatype, info->algorithm, info->protocol);
}

work->funcIndex = FUNC_INDEX(info->coll, info->opFull.op, info->datatype, info->algorithm, info->protocol);

work->coll.connIndex = 0;
proxyArgs->connIndex = 0;
if (info->protocol == NCCL_PROTO_SIMPLE && info->algorithm == NCCL_ALGO_RING) {
@@ -599,6 +607,12 @@ static ncclResult_t computeColl(struct ncclInfo* info /* input */, struct ncclWo
TRACE(NCCL_COLL,"opCount %lx slicesteps %d spl %d cpl %d nbytes %zi -> protocol %d nchannels %d nthreads %d, nloops %d nsteps %d chunksize %d comm %p",
proxyArgs->opCount, sliceSteps, info->nstepsPerLoop, info->nchunksPerLoop, info->nBytes, info->protocol, info->nChannels, info->nThreads,
nLoops, proxyArgs->subs[0].nsteps, chunkSize, info->comm);

// For Pivot A2A, lastChunkSize is not needed, set pivotA2ANumBiRings instead
if (info->coll == ncclFuncAllToAllPivot) {
work->coll.pivotA2ANumBiRings = info->comm->topo->pivotA2ANumBiRings;
}

return ncclSuccess;
}

@@ -760,7 +774,12 @@ ncclResult_t ncclSetupAsyncKernels(ncclComm_t comm) {
int allCollNetSupport = comm->collNetSupport;
for (int c = 0; c < comm->asyncOpCount; c++) {
struct ncclInfo* info = comm->asyncOps+c;
info->nChannels = std::min(std::max(1, (int)DIVUP(info->nBytes, channelSize)), comm->nChannels); // assign number of channels
if (info->coll == ncclFuncAllToAllPivot) {
int pivotA2ANumUniRings = comm->topo->pivotA2ANumBiRings * 2;
info->nChannels = comm->nChannels / pivotA2ANumUniRings * pivotA2ANumUniRings;
} else {
info->nChannels = std::min(std::max(1, (int)DIVUP(info->nBytes, channelSize)), comm->nChannels); // assign number of channels
}
channelUsed += info->nChannels;
// We can use fast path if all collectives are the same
homogeneous &= info->coll == comm->asyncOps[0].coll &&
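
Both call sites round the channel count down to a multiple of the number of unidirectional rings so the rings split the channels evenly. A worked example with hypothetical counts:

// Assumed values for illustration only.
int pivotA2ANumBiRings  = 3;                                // e.g. the Rome model below
int pivotA2ANumUniRings = pivotA2ANumBiRings * 2;           // 6
int commChannels        = 14;                               // hypothetical comm->nChannels
int pivotChannels       = commChannels / pivotA2ANumUniRings
                        * pivotA2ANumUniRings;              // 14 -> 12
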
8 changes: 6 additions & 2 deletions src/graph/rome_models.cc
@@ -367,7 +367,7 @@ static struct rcclRomeModel rome_model_56 = {
.gdrLevel = { },
.pattern = "40404040",
.ringBase = "0 1 3 2 6 7 15 14 10 11 9 8 12 13 5 4|0 1 2 3 7 6 13 12 8 9 10 11 15 14 5 4|0 2 3 7 6 14 15 11 10 8 9 13 12 4 5 1|4 5 13 12 8 9 11 10 14 15 7 6 2 3 1 0|4 5 14 15 11 10 9 8 12 13 6 7 3 2 1 0|1 5 4 12 13 9 8 10 11 15 14 6 7 3 2 0",
.options = "",
.options = "pivotA2AEnabled=1,pivotA2ANumBiRings=3",
};

static struct rcclRomeModel rome_model_58 = {
@@ -629,7 +629,7 @@ ncclResult_t parseGraph(const char* str, struct ncclTopoSystem* system, struct n

static void parseOptions(struct ncclTopoSystem* system, const char *options) {
if (strcmp(options, "")) {
char *str_temp = (char *)malloc(sizeof(options));
char *str_temp = (char *)malloc(strlen(options) + 1);
strcpy(str_temp, options);
char* tokens[MAX_OPT_TOKENS];
int numTokens = 0;
@@ -640,6 +640,10 @@ static void parseOptions(struct ncclTopoSystem* system, const char *options) {
for (int i = 0; i < numTokens/2; i++) {
if (strcmp(tokens[i*2], "netGdrLevel") == 0) {
system->netGdrLevel = atol(tokens[i*2+1]);
} else if (strcmp(tokens[i*2], "pivotA2AEnabled") == 0) {
system->pivotA2AEnabled = (bool)atol(tokens[i*2+1]);
} else if (strcmp(tokens[i*2], "pivotA2ANumBiRings") == 0) {
system->pivotA2ANumBiRings = atol(tokens[i*2+1]);
}
}
free(str_temp);
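
rome_model_56 now carries its pivot settings in the options string, which parseOptions copies (the strlen fix replaces the old sizeof-of-pointer allocation) and splits into alternating key/value tokens. A standalone sketch of that parse; the delimiter handling of the elided tokenizer is an assumption here:

#include <cstdio>
#include <cstdlib>
#include <cstring>

int main() {
  const char* options = "pivotA2AEnabled=1,pivotA2ANumBiRings=3";
  char* str_temp = (char*)malloc(strlen(options) + 1);  // the PR's fix: allocate the string length, not sizeof(pointer)
  strcpy(str_temp, options);

  char* tokens[16];
  int numTokens = 0;
  for (char* tok = strtok(str_temp, "=,"); tok && numTokens < 16; tok = strtok(nullptr, "=,"))
    tokens[numTokens++] = tok;                           // even slots: keys, odd slots: values

  bool pivotA2AEnabled = false;
  int pivotA2ANumBiRings = 0;
  for (int i = 0; i < numTokens / 2; i++) {
    if (strcmp(tokens[i * 2], "pivotA2AEnabled") == 0)
      pivotA2AEnabled = (bool)atol(tokens[i * 2 + 1]);
    else if (strcmp(tokens[i * 2], "pivotA2ANumBiRings") == 0)
      pivotA2ANumBiRings = atol(tokens[i * 2 + 1]);
  }
  printf("pivotA2AEnabled=%d pivotA2ANumBiRings=%d\n", pivotA2AEnabled, pivotA2ANumBiRings);
  free(str_temp);
  return 0;
}
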
3 changes: 3 additions & 0 deletions src/graph/topo.h
@@ -137,6 +137,9 @@ struct ncclTopoSystem {
int type;
int nRanks;
int netGdrLevel;

bool pivotA2AEnabled;
int pivotA2ANumBiRings;
};

ncclResult_t ncclTopoGetNode(struct ncclTopoSystem* system, struct ncclTopoNode** node, int type, uint64_t id);
4 changes: 4 additions & 0 deletions src/include/collectives.h
@@ -20,6 +20,7 @@ struct ncclDevRedOpFull {
};

#define FUNC_INDEX_P2P (ncclNumTypes+NCCL_NUM_FUNCTIONS*NCCL_NUM_ALGORITHMS*NCCL_NUM_PROTOCOLS*ncclNumTypes*ncclNumDevRedOps)
#define FUNC_INDEX_ALLTOALL_PIVOT (FUNC_INDEX_P2P+1)
#define FUNC_INDEX(func, devredop, ncclType, al, pr) ((((((func)*ncclNumDevRedOps + (devredop))*ncclNumTypes) + (ncclType))*NCCL_NUM_ALGORITHMS+(al))*NCCL_NUM_PROTOCOLS+(pr))

#define NCCL_FUNC_NAME(func, algo, proto, devredop, type) \
@@ -93,6 +94,7 @@ DECL2(AllGather, Sum, /*undefForFloat=*/0)
DECL(ReduceScatter)
DECL(AllReduce)
DECL5(SendRecv, RING, SIMPLE, Sum, int8_t)
DECL5(AllToAllPivot, RING, SIMPLE, Sum, int8_t)

extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, int8_t)(struct ncclWorkElem* args);
extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint8_t)(struct ncclWorkElem* args);
@@ -126,5 +128,7 @@ extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, double)(struct ncclWo
#define REDUCE_CHUNKSTEPS 2
#define SENDRECV_SLICEFACTOR 1
#define NCCL_MAX_SLICE_PER_CHUNK 2 // max value for CHUNKSTEPS/SLICESTEPS, must accord with above
#define ALLTOALL_PIVOT_SLICESTEPS 2
#define ALLTOALL_PIVOT_CHUNKSTEPS 4

#endif
13 changes: 9 additions & 4 deletions src/include/devcomm.h
@@ -26,9 +26,9 @@
#endif


#define NCCL_NUM_FUNCTIONS 5 // SendRecv not included for now
typedef enum { ncclFuncBroadcast, ncclFuncReduce, ncclFuncAllGather, ncclFuncReduceScatter, ncclFuncAllReduce, ncclFuncSendRecv, ncclNumFuncs} ncclFunc_t;
extern const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+1];
#define NCCL_NUM_FUNCTIONS 5 // SendRecv and AllToAllPivot not included for now
typedef enum { ncclFuncBroadcast, ncclFuncReduce, ncclFuncAllGather, ncclFuncReduceScatter, ncclFuncAllReduce, ncclFuncSendRecv, ncclFuncAllToAllPivot, ncclNumFuncs} ncclFunc_t;
extern const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+2];

#define NCCL_NUM_ALGORITHMS 3 // Tree/Ring/CollNet
#define NCCL_ALGO_TREE 0
@@ -202,7 +202,12 @@ struct ncclWorkElem {
union {
struct {
size_t count;
size_t lastChunkSize;
union {
size_t lastChunkSize;
// Pivot A2A kernel computes chunk size itself.
// Instead, it needs the number of bidirectional rings.
size_t pivotA2ANumBiRings;
};
uint64_t redOpArg;
uint16_t root;
uint8_t bid;