
Leak in FIFO queue #1251

Closed
samsamoa opened this issue Apr 13, 2024 · 7 comments · Fixed by huggingface/text-generation-inference#2099

samsamoa commented Apr 13, 2024

We are experiencing an issue where 8 processes, each controlling one GPU on a node, all lock up at the same time. It appears to be deterministic, though we don't yet know exactly which operation is causing trouble; the behavior is something like "after N graph executions, all 8 processes stall at the same time".

We're relatively confident that this is a leak: increasing NCCL_WORK_FIFO_DEPTH seems to increase the number of graphs that can be executed before stalling, and decreasing it makes the stall happen sooner.
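For reference, the FIFO depth is controlled via an environment variable that NCCL reads when the communicator is created; a minimal sketch of overriding it per process (values and names here are arbitrary examples, not our production setup):

import os

# Must be set before the NCCL communicator is created to take effect.
os.environ["NCCL_WORK_FIFO_DEPTH"] = "131072"  # larger depth -> more graphs before the stall

import torch
import torch.distributed as dist

def _init(rank: int, world_size: int) -> None:
    torch.cuda.set_device(rank)
    dist.init_process_group(
        backend="nccl",
        rank=rank,
        world_size=world_size,
        init_method="tcp://localhost:2379",  # placeholder rendezvous address
    )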

Here is the stack trace for where we get stuck:

sched_yield (libc.so.6)
waitWorkFifoAvailable (enqueue.cc:745)
uploadWork (enqueue.cc:777)
ncclLaunchKernelBefore_NoUncapturedCuda (enqueue.cc:1008)
doLaunches (group.cc:172)
groupLaunch (group.cc:342)
ncclGroupEndInternal (group.cc:423)
ncclGroupEndInternal (group.cc:378)
ncclGroupEnd (group.cc:106)

This is on an older NCCL commit, a8511ca, but we have verified that the same issue is present on master as of yesterday. The issue affects both H100 and A100.

We'll update the issue as we gather more info on the exact operation that is causing trouble.

samsamoa (Author)

We now suspect https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-graph-mixing-support is at issue here. We had turned it off to get a significant speedup, but we may be misusing that feature.
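For context, that setting is the NCCL_GRAPH_MIXING_SUPPORT environment variable from the linked docs; "turning it off" means roughly the following (a sketch; it has to be set before the communicator is created):

import os

# Disable support for mixing graph-captured and non-captured collectives
# on the same communicator (this is what we had set to get the speedup).
os.environ["NCCL_GRAPH_MIXING_SUPPORT"] = "0"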

samsamoa commented Apr 13, 2024

We were actually still able to reproduce with graph mixing support turned on, and adding a synchronize between usages somehow doesn't help either. We're working on a more minimal reproducer, but it will take some time.

WhiteFangBuck

@sjeaugey @KaimingOuyang

jbachan (Collaborator) commented Apr 13, 2024 via email

samsamoa commented Apr 14, 2024

Reproducer (on 2 H100s):

import random

import torch


def _test(rank):
    torch.cuda.set_device(rank)
    torch.distributed.init_process_group(
        backend="nccl", rank=rank, world_size=2, init_method="tcp://localhost:2379"
    )

    size = 100_000
    t = torch.zeros(size, dtype=torch.bfloat16, device="cuda")
    torch.distributed.all_reduce(t)
    torch.distributed.all_reduce(t)
    with torch.cuda.graphs.graph(torch.cuda.graphs.CUDAGraph()):
        torch.distributed.all_reduce(t)

    # Uncommenting this will fix the hang
    # torch.distributed.all_reduce(t)
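    # (The graph captured above is never replayed; the capture alone appears
    # to be enough to trigger the eventual hang in the loop below.)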

    random.seed(0)
    for i in range(100_000):
        if i % 100 == 0 and rank == 0:
            print(i)
        size = 49_000
        t = torch.zeros(size, dtype=torch.bfloat16, device="cuda")
        torch.distributed.all_reduce(t)
        torch.cuda.synchronize()


if __name__ == "__main__":
    torch.multiprocessing.start_processes(fn=_test, nprocs=2)

samsamoa commented Apr 15, 2024

Here's a C++ version (thanks, Claude):

#include <iostream>
#include <nccl.h>
#include <mpi.h>

void test(int rank) {
//    setenv("NCCL_WORK_FIFO_DEPTH", "128", 1);
//    if (rank == 0) {
//        setenv("NCCL_DEBUG", "TRACE", 1);
//        setenv("NCCL_DEBUG_SUBSYS", "ALL", 1);
//    }

    cudaSetDevice(rank);

    ncclComm_t comm;
    ncclUniqueId id;
    if (rank == 0) {
        ncclGetUniqueId(&id);
    }
    MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
    ncclCommInitRank(&comm, 2, id, rank);

    int size = 100000;
    ncclDataType_t dataType = ncclBfloat16;
    size_t elemSize = sizeof(uint16_t);

    uint16_t* d_data;
    cudaMalloc(&d_data, size * elemSize);
    cudaMemset(d_data, 0, size * elemSize);

    ncclAllReduce(d_data, d_data, size, dataType, ncclSum, comm, cudaStreamDefault);
    ncclAllReduce(d_data, d_data, size, dataType, ncclSum, comm, cudaStreamDefault);


    cudaStream_t stream;
    cudaStreamCreate(&stream);

    // Create CUDA graph
    cudaGraph_t graph;
    cudaGraphCreate(&graph, 0);

    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    ncclAllReduce(d_data, d_data, size, dataType, ncclSum, comm, stream);
    cudaStreamEndCapture(stream, &graph);

    cudaGraphExec_t graphExec;
    cudaGraphInstantiate(&graphExec, graph, NULL, NULL, 0);
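    // As in the Python reproducer above, the graph is instantiated but never
    // launched; the capture alone appears to be enough to trigger the hang below.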


    for (int i = 0; i < 10000; ++i) {
        if (i % 100 == 0 && rank == 0) {
            std::cout << i << std::endl;
        }
        size = 49000;
        cudaMemset(d_data, 0, size * elemSize);
        ncclAllReduce(d_data, d_data, size, dataType, ncclSum, comm, cudaStreamDefault);
        cudaStreamSynchronize(cudaStreamDefault);
    }

    cudaGraphExecDestroy(graphExec);
    cudaGraphDestroy(graph);
    cudaFree(d_data);
    ncclCommDestroy(comm);
}

int main(int argc, char* argv[]) {
    int rank, world_size;
    MPI_Init(&argc, &argv);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &world_size);

    test(rank);

    MPI_Finalize();
    return 0;
}

nvcc -o nccl_test repro.cc -lnccl -lmpi
mpirun -np 2 ./nccl_test

samsamoa (Author)

Resolved by this commit (I assume it will be added to master soon): ee3d92b
