
Can two AllReduce operations run concurrently using different communicators and different streams? #315

Closed
zpcalan opened this issue Mar 30, 2020 · 10 comments

zpcalan commented Mar 30, 2020

Hi NCCL developers, I've recently been working with the ncclAllReduce operation.
I found that two AllReduce operations can't run concurrently even though I pass different communicators and different streams.

I've read similar issues such as #217 and #195, but none of them matches my scenario exactly.
So here's my code (a slight modification of the example on your website):

int myRank = 0, nRanks = 0;
MPI_Init(&argc, &argv);
MPI_Comm_rank(MPI_COMM_WORLD, &myRank);
MPI_Comm_size(MPI_COMM_WORLD, &nRanks);

cudaSetDevice(myRank);
cudaStream_t stream_main = nullptr;
cudaStreamCreateWithFlags(&stream_main, cudaStreamNonBlocking);
cudaStream_t stream_ar1 = nullptr;
cudaStreamCreateWithFlags(&stream_ar1, cudaStreamNonBlocking);
cudaStream_t stream_ar2 = nullptr;
cudaStreamCreateWithFlags(&stream_ar2, cudaStreamNonBlocking);

ncclUniqueId id1;
ncclComm_t comm1;
if (myRank == 0) ncclGetUniqueId(&id1);
MPI_Bcast((void *)&id1, sizeof(id1), MPI_BYTE, 0, MPI_COMM_WORLD);
ncclCommInitRank(&comm1, nRanks, id1, myRank);
ncclUniqueId id2;
ncclComm_t comm2;
if (myRank == 0) ncclGetUniqueId(&id2);
MPI_Bcast((void *)&id2, sizeof(id2), MPI_BYTE, 0, MPI_COMM_WORLD);
ncclCommInitRank(&comm2, nRanks, id2, myRank);

size_t size = 128*1024*1024;
void *buffer1 = nullptr, *buffer2 = nullptr;
cudaMalloc(&buffer1, size*sizeof(size_t));
cudaMalloc(&buffer2, size*sizeof(size_t));

//Some of my cuda kernels using stream_main.
cudaEvent_t event1 = nullptr;
cudaEventCreate(&event1);
cudaEventRecord(event1, stream_main);
cudaStreamWaitEvent(stream_ar1, event1, 0);
//The first AllReduce. 
ncclAllReduce((const void*)buffer1, (void*)buffer1, size, ncclUint64, ncclSum, comm1, stream_ar1);
//Some of my cuda kernels using stream_main.
cudaEvent_t event2 = nullptr;
cudaEventCreate(&event2);
cudaEventRecord(event2, stream_main);
cudaStreamWaitEvent(stream_ar2, event2, 0);
//The second AllReduce, on its own communicator and stream. I thought this could execute concurrently with the first AllReduce.
ncclAllReduce((const void*)buffer2, (void*)buffer2, size, ncclUint64, ncclSum, comm2, stream_ar2);

//To ensure the execution order is what I want, I use cuda events to wait until both AllReduce operations have finished.
cudaEvent_t event3 = NULL;
cudaEventCreate(&event3);
cudaEventRecord(event3, stream_ar1);
cudaEvent_t event4 = NULL;
cudaEventCreate(&event4);
cudaEventRecord(event4, stream_ar2);

cudaStreamWaitEvent(stream_main, event3, 0);
//Some of my cuda kernels using stream_main.
cudaStreamWaitEvent(stream_main, event4, 0);
//Some of my cuda kernels using stream_main.

//Get start_time
cudaStreamSynchronize(stream_main);
//Get end_time
//print(end_time - start_time);

This is how I run the executable: mpirun -n 8 ./file, i.e. 8 identical processes that communicate with each other.

Each AllReduce operation processes a 128MB buffer. I assume each AllReduce costs T milliseconds. I ensure that the execution order of my cuda kernels and NCCL operations is identical in each process, using cudaEventRecord and cudaStreamWaitEvent.
So I think the output of print(end_time - start_time); should be quite close to T ms, since my cuda kernels cost little time and the two AllReduces run concurrently. But the result turned out to be 2*T ms.

So here's my question: can two AllReduce operations run concurrently using different communicators and different streams in the same process? I've found many of @sjeaugey's comments very helpful, especially this one: #239 (comment). But I still can't figure out why these two AllReduce operations can't run concurrently when the execution order is fixed.
Please correct me if my code has mistakes or my thinking is wrong. Thank you very much.

NCCL + CUDA version: NCCL 2.4.8 + CUDA 10.1
Operating system: Ubuntu 16.04
GPU: Tesla V100-SXM2
OpenMPI version: 3.1.5

Best regards,
zpcalan


zpcalan commented Mar 30, 2020

Do cudaEventRecord or cudaStreamWaitEvent have anything to do with this behavior, i.e. with the two operations not running concurrently?

kwen2501 (Contributor) commented:

Hi, you are likely getting 2*T because the two AllReduces are sharing (competing for) the same bandwidth. 128 MB is a large enough message size that each AllReduce alone could already consume the full bandwidth.


zpcalan commented Mar 30, 2020

Thanks for your quick reply! @kwen2501

Well, I just picked 128MB for testing because the gradients of some simple networks like ResNet50 are nearly 100MB in total, if I'm not mistaken.

  • Do you mean that slicing the whole 128MB buffer into multiple small buffers and multiple AllReduce operations would help (roughly as sketched below the topology)? But could that introduce extra cost from calling ncclAllReduce many times instead of twice?

And here's my GPU topology:

        GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    mlx5_0  mlx5_1  CPU Affinity
GPU0     X      NV2     NV2     NV1     NV1     SYS     SYS     SYS     NODE    NODE    0-23,48-71
GPU1    NV2      X      NV1     NV2     SYS     NV1     SYS     SYS     NODE    NODE    0-23,48-71
GPU2    NV2     NV1      X      NV1     SYS     SYS     NV2     SYS     PIX     PIX     0-23,48-71
GPU3    NV1     NV2     NV1      X      SYS     SYS     SYS     NV2     PIX     PIX     0-23,48-71
GPU4    NV1     SYS     SYS     SYS      X      NV2     NV2     NV1     SYS     SYS     24-47,72-95
GPU5    SYS     NV1     SYS     SYS     NV2      X      NV1     NV2     SYS     SYS     24-47,72-95
GPU6    SYS     SYS     NV2     SYS     NV2     NV1      X      NV1     SYS     SYS     24-47,72-95
GPU7    SYS     SYS     SYS     NV2     NV1     NV2     NV1      X      SYS     SYS     24-47,72-95
mlx5_0  NODE    NODE    PIX     PIX     SYS     SYS     SYS     SYS      X      PIX
mlx5_1  NODE    NODE    PIX     PIX     SYS     SYS     SYS     SYS     PIX      X
  • Isn't NVLink fast enough for such a big buffer communication? Or is the bottleneck the communication over SYS?
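
For reference, this is roughly what I mean by slicing (a minimal sketch only: NCHUNKS and the comms/streams arrays are made up for illustration, built from the communicators and streams created in the code above):

// Hypothetical sketch: split one large AllReduce into NCHUNKS smaller ones,
// alternating between the two communicators/streams so they could overlap.
const int NCHUNKS = 4;
size_t chunk = size / NCHUNKS;                     // elements per chunk
ncclComm_t comms[2] = {comm1, comm2};
cudaStream_t streams[2] = {stream_ar1, stream_ar2};
for (int i = 0; i < NCHUNKS; ++i) {
  size_t *ptr = (size_t *)buffer1 + i * chunk;     // offset into the same buffer
  ncclAllReduce(ptr, ptr, chunk, ncclUint64, ncclSum,
                comms[i % 2], streams[i % 2]);
}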


zpcalan commented Mar 30, 2020

Actually, when I use two AllReduces, the time costs are as follows:
Process5 total cost time 4558us
Process7 total cost time 6762us
Process4 total cost time 3081us
Process6 total cost time 4330us
Process3 total cost time 4836us
Process2 total cost time 3450us
Process0 total cost time 2839us
Process1 total cost time 3588us

After I delete one AllReduce, the time cost is:
Process5 total cost time 2510us
Process4 total cost time 2100us
Process6 total cost time 2992us
Process2 total cost time 1096us
Process7 total cost time 3277us
Process0 total cost time 847us
Process3 total cost time 1229us
Process1 total cost time 5872us

kwen2501 (Contributor) commented:

You have a DGX-1-like system, so all intra-node communication goes via NVLink. However, even with NVLink's large bandwidth, 128MB is still big enough to reach almost the peak.

I don't see how you record the start time in the code. Do you have a barrier before it (to sync all the processes)?


zpcalan commented Mar 30, 2020

//Get start_time
cudaStreamSynchronize(stream_main);
//Get end_time
//print(end_time - start_time);

I record start and end time right before and after cudaStreamSynchronize.

kwen2501 (Contributor) commented:

You need to record the start time before calling ncclAllReduce. Like this:

MPI_Barrier(MPI_COMM_WORLD);
Record start time; 
ncclAllReduce(...);
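
Spelled out a little more, this might look like the following (a minimal sketch; MPI_Wtime is just one convenient host-side clock, and the buffer/stream/communicator names follow the code posted above):

MPI_Barrier(MPI_COMM_WORLD);                  // line up all ranks before timing
double start = MPI_Wtime();                   // start the clock on the host
ncclAllReduce((const void *)buffer1, (void *)buffer1, size,
              ncclUint64, ncclSum, comm1, stream_ar1);
cudaStreamSynchronize(stream_ar1);            // wait for the collective to finish
double elapsed_ms = (MPI_Wtime() - start) * 1e3;
printf("AllReduce took %.3f ms\n", elapsed_ms);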


zpcalan commented Mar 30, 2020

Thanks for the tip! I've now moved the time-recording code before the first AllReduce and added MPI_Barrier() before it. The time costs do look more accurate.

I thought about your explanation of the bandwidth, so I reduced the buffer size to 40MB, but the time is still near 2T ms.
Then I reduced the buffer to 4MB, and it seems that the two AllReduces do run concurrently, because the time is close to T ms. I will test some other buffer sizes later to see what happens. Thank you very much!

Before that, I noticed you mentioned that all 8 GPUs communicate via NVLink, which can reach up to 62GB/s according to this benchmark. But I found the speed is not that fast in my test. Does this mean some GPUs are not using NVLink to communicate, like GPU0 and GPU7 via SYS as shown in the topology? I'm new to the hardware, so I'm confused. It would help a lot if you could explain. (^__^)

kwen2501 (Contributor) commented:

The bandwidth reported on that web page is what we call Bus Bandwidth. You can find the difference between the definitions of Bus Bandwidth and Algorithm Bandwidth here. Simply put, when you are doing an AllReduce with a large enough number of ranks, BusBw is about 2x AlgoBw.
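
Concretely, the conversion used in the nccl-tests performance notes for AllReduce is busBw = algoBw * 2*(n-1)/n, where n is the number of ranks; with n = 8 ranks the factor is 2*7/8 = 1.75, and it approaches 2 as n grows.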

It looks like you are using 8x V100 GPUs; in that case, you should see a peak BusBw of ~130 GB/s. Of course, that's achieved only when the message size is large enough (e.g. >= 128 MB).

If you suspect your system has a performance issue, I would recommend running the NCCL perf tests. You may also turn on NCCL_DEBUG=INFO to see what channels NCCL uses on your system.
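
For example (assuming the nccl-tests repo is built under ./build; the MPI variant also assumes the tests were built with MPI support):

# Sweep AllReduce sizes from 8 bytes to 128 MB on all 8 GPUs in one process
./build/all_reduce_perf -b 8 -e 128M -f 2 -g 8

# Or one rank per GPU under MPI, with NCCL debug logging enabled
NCCL_DEBUG=INFO mpirun -n 8 ./build/all_reduce_perf -b 8 -e 128M -f 2 -g 1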


zpcalan commented Apr 10, 2020

Issue closed.
