In [1]:
import cupy as cp
from cupy.cuda import nccl

# Number of NCCL ranks. All ranks are driven from this single process, and
# rank r is bound to the CUDA device with index r.
WORLD = 1
DEVS = list(range(WORLD))

# One unique id shared by every rank identifies the communicator clique.
unique_id = nccl.get_unique_id()
comms = []      # comms[r]   -> NCCL communicator of rank r
streams = []    # streams[r] -> CUDA stream used for rank r's NCCL traffic

# Communicator creation is itself a collective: outside of a group the first
# NcclCommunicator(...) call would wait for the remaining ranks, which this
# same thread has not created yet (deadlock as soon as WORLD > 1). Wrapping
# the loop in a group defers the rendezvous to groupEnd().
nccl.groupStart()
try:
    for r in DEVS:
        cp.cuda.Device(r).use()
        stream = cp.cuda.Stream()
        comm = nccl.NcclCommunicator(WORLD, unique_id, r)
        comms.append(comm)
        streams.append(stream)
finally:
    # Always close the group so a failure here does not leave the kernel
    # with a dangling open NCCL group.
    nccl.groupEnd()

In [2]:
def all_to_all():
    """Exchange equal-sized chunks between all ranks with NCCL send/recv.

    Every rank owns a float32 send buffer of ``elems_total`` elements split
    into ``WORLD`` chunks. Chunk ``peer`` of rank ``r`` is sent to ``peer``,
    and the chunk received from ``peer`` is stored in chunk ``peer`` of rank
    ``r``'s receive buffer. After the exchange the result is checked on the
    host before the success message is printed.

    Uses the module-level ``WORLD``, ``DEVS``, ``comms`` and ``streams``.

    Raises:
        ValueError: if the buffer cannot be split evenly across ``WORLD``.
        RuntimeError: if a received chunk differs from what the peer sent.
    """
    elems_total = 1 << 16
    if elems_total % WORLD:
        raise ValueError(
            f"elems_total={elems_total} is not divisible by WORLD={WORLD}"
        )
    elems_chunk = elems_total // WORLD

    sendbuf, recvbuf = [], []
    for r in DEVS:
        cp.cuda.Device(r).use()
        # Offset by rank so every rank sends distinguishable values.
        send = cp.arange(elems_total, dtype=cp.float32) + r * 1000
        recv = cp.zeros_like(send)
        sendbuf.append(send)
        recvbuf.append(recv)

    # A single group must span ALL ranks: this one thread drives every
    # device, so closing the group per rank would submit rank r's
    # sends/recvs before its peers have posted the matching calls.
    nccl.groupStart()
    try:
        for r in DEVS:
            cp.cuda.Device(r).use()
            comm = comms[r]
            stream_ptr = streams[r].ptr
            for peer in DEVS:
                off = peer * elems_chunk
                s_view = sendbuf[r][off : off + elems_chunk]
                r_view = recvbuf[r][off : off + elems_chunk]
                # send/recv(ptr, count, dtype, peer, stream)
                comm.send(s_view.data.ptr, elems_chunk, nccl.NCCL_FLOAT32, peer, stream_ptr)
                comm.recv(r_view.data.ptr, elems_chunk, nccl.NCCL_FLOAT32, peer, stream_ptr)
    finally:
        nccl.groupEnd()

    for r in DEVS:
        cp.cuda.Device(r).use()
        streams[r].synchronize()

    # Verify on the host. Each buffer is copied while its own device is
    # current, so no cross-device access is needed.
    host_send, host_recv = [], []
    for r in DEVS:
        cp.cuda.Device(r).use()
        host_send.append(sendbuf[r].get())
        host_recv.append(recvbuf[r].get())
    for r in DEVS:
        mine = slice(r * elems_chunk, (r + 1) * elems_chunk)
        for peer in DEVS:
            theirs = slice(peer * elems_chunk, (peer + 1) * elems_chunk)
            # Rank r's chunk `peer` must equal what `peer` addressed to r.
            if not (host_recv[r][theirs] == host_send[peer][mine]).all():
                raise RuntimeError(
                    f"all-to-all mismatch: rank {r} received wrong data from rank {peer}"
                )
    print("All-to-all OK")

In [3]:
def gemm_reducescatter():
    """Compute a full (M, N) product per rank, then reduce-scatter it by rows.

    Every rank multiplies its own random ``A`` (M, K) and ``B`` (K, N). The
    partial products are summed across ranks with ``ncclReduceScatter`` and
    each rank keeps one contiguous strip of ``M // WORLD`` rows of the sum.
    The numerical result is not verified; the function only checks that the
    collective completes.

    Uses the module-level ``WORLD``, ``DEVS``, ``comms`` and ``streams``.

    Raises:
        ValueError: if ``M`` cannot be split evenly across ``WORLD``.
    """
    M, K, N = 1024, 512, 1024
    if M % WORLD:
        # Otherwise the trailing rows would silently be left out of the sum.
        raise ValueError(f"M={M} is not divisible by WORLD={WORLD}")
    rows_per_rank = M // WORLD

    Cpartial, Cstrip = [], []
    for r in DEVS:
        cp.cuda.Device(r).use()
        A_r = cp.random.rand(M, K, dtype=cp.float32)
        B_r = cp.random.rand(K, N, dtype=cp.float32)
        # NCCL reads raw memory, so the buffer must be one dense block.
        # ascontiguousarray is a no-op when the product is already C-ordered.
        Cpartial.append(cp.ascontiguousarray(A_r @ B_r))      # (M, N)
        Cstrip.append(cp.zeros((rows_per_rank, N), dtype=cp.float32))

    # One thread drives every device, so the collective calls of all ranks
    # have to be issued inside a single NCCL group to avoid a hang.
    nccl.groupStart()
    try:
        for r in DEVS:
            cp.cuda.Device(r).use()
            # Pointers are taken from the stored arrays themselves (not from
            # a ravel() temporary), so they always refer to live buffers.
            comms[r].reduceScatter(
                Cpartial[r].data.ptr,                     # send buffer (size = recvcount*WORLD)
                Cstrip[r].data.ptr,                       # recv buffer
                rows_per_rank * N,                        # recvcount (elements)
                nccl.NCCL_FLOAT32,
                nccl.NCCL_SUM,
                streams[r].ptr
            )
    finally:
        nccl.groupEnd()

    for r in DEVS:
        cp.cuda.Device(r).use()
        streams[r].synchronize()
    print("GEMM + ReduceScatter OK")

In [5]:
def allgather_gemm():
    """All-gather column shards of ``B`` and multiply with a local ``A``.

    Every rank holds a random ``A`` (M, K) and a column shard of ``B`` with
    shape (K, N // WORLD). The shards are collected on every rank with
    ``ncclAllGather``, reassembled into the full (K, N) matrix, and used in
    a local GEMM producing an (M, N) result per rank. The numerical result
    is not verified.

    Uses the module-level ``WORLD``, ``DEVS``, ``comms`` and ``streams``.

    Raises:
        ValueError: if ``N`` cannot be split evenly across ``WORLD``.
    """
    M, K, N = 1024, 512, 1024
    if N % WORLD:
        raise ValueError(f"N={N} is not divisible by WORLD={WORLD}")
    n_per_rank = N // WORLD

    Arows, Bshard, Bgathered, C = [], [], [], []
    for r in DEVS:
        cp.cuda.Device(r).use()
        Arows.append(cp.random.rand(M, K, dtype=cp.float32))
        Bshard.append(cp.random.rand(K, n_per_rank, dtype=cp.float32))
        # AllGather stores the contribution of rank i at element offset
        # i * sendcount, i.e. the receive buffer is rank-major:
        # (WORLD, K, n_per_rank) -- NOT a row-major (K, N) matrix.
        Bgathered.append(cp.zeros((WORLD, K, n_per_rank), dtype=cp.float32))

    # One thread drives every device, so all ranks' collective calls must be
    # issued inside a single NCCL group to avoid a hang.
    nccl.groupStart()
    try:
        for r in DEVS:
            cp.cuda.Device(r).use()
            comms[r].allGather(
                Bshard[r].data.ptr,                       # send (K*n_per_rank)
                Bgathered[r].data.ptr,                    # recv (WORLD*K*n_per_rank)
                K * n_per_rank,                           # sendcount (elements)
                nccl.NCCL_FLOAT32,
                streams[r].ptr
            )
    finally:
        nccl.groupEnd()

    # The gathered data must be complete before the GEMM consumes it.
    for r in DEVS:
        cp.cuda.Device(r).use()
        streams[r].synchronize()

    for r in DEVS:
        cp.cuda.Device(r).use()
        # Put the rank axis next to the column axis so that rank i's shard
        # becomes columns [i*n_per_rank, (i+1)*n_per_rank) of the full B.
        B_all = Bgathered[r].transpose(1, 0, 2).reshape(K, N)
        C.append(Arows[r] @ B_all)
        # The GEMM runs on the device's default stream; wait for it so the
        # success message is only printed once the work has finished.
        cp.cuda.Stream.null.synchronize()
    print("AllGather + GEMM OK")

In [6]:
if __name__ == "__main__":
    # Run the three collective demos in order; each prints its own status line.
    for demo in (all_to_all, gemm_reducescatter, allgather_gemm):
        demo()

All-to-all OK
GEMM + ReduceScatter OK
AllGather + GEMM OK
