Does NCCL all-reduce on two streams block each other? #217
My setup: 16 GPUs on 2 nodes with 8 GPUs each, 16 processes with 1 GPU and two streams each (one default stream and one non-blocking stream). I do all-reduce on the two streams concurrently; the reduce order within each stream is fixed across all the processes, but the order between the two streams may interleave. I expected that the NCCL kernels would not take many GPU resources and could run concurrently; however, I observe hangs in this setting and ncclAllReduce blocks. Is that normal?
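For concreteness, a minimal sketch of the pattern just described (buffer names, sizes, and the use of one communicator per stream are assumptions, not code from the report):

```c
// Minimal sketch of the reported pattern; names and the one-communicator-
// per-stream layout are assumptions. Each process owns one GPU.
#include <nccl.h>
#include <cuda_runtime.h>

void concurrent_allreduce(ncclComm_t comm_a, ncclComm_t comm_b,
                          float *buf_a, float *buf_b, size_t count) {
  cudaStream_t stream_nb;
  // A non-blocking stream does not implicitly synchronize with stream 0.
  cudaStreamCreateWithFlags(&stream_nb, cudaStreamNonBlocking);

  // One all-reduce on the legacy default stream (stream 0)...
  ncclAllReduce(buf_a, buf_a, count, ncclFloat, ncclSum, comm_a, 0);
  // ...and one on the non-blocking stream, with no ordering between them.
  ncclAllReduce(buf_b, buf_b, count, ncclFloat, ncclSum, comm_b, stream_nb);

  // Waiting on both streams is where the reported hang shows up.
  cudaStreamSynchronize(0);
  cudaStreamSynchronize(stream_nb);
  cudaStreamDestroy(stream_nb);
}
```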
Without the per-thread default stream option, the default stream is a special stream that implicitly synchronizes with all other streams on the device. Please refer to "The Default Stream" section of the CUDA documentation.
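A small standalone illustration of that implicit synchronization (a sketch, not code from this issue):

```cuda
// Sketch: the legacy default stream (stream 0) implicitly synchronizes
// with regular ("blocking") streams, but not with non-blocking ones.
#include <cuda_runtime.h>

__global__ void work(void) { /* placeholder kernel */ }

int main(void) {
  cudaStream_t blocking, non_blocking;
  cudaStreamCreate(&blocking);
  cudaStreamCreateWithFlags(&non_blocking, cudaStreamNonBlocking);

  work<<<1, 1, 0, blocking>>>();      // (1)
  work<<<1, 1, 0, 0>>>();             // (2) default stream: runs after (1)
  work<<<1, 1, 0, blocking>>>();      // (3) runs after (2)
  work<<<1, 1, 0, non_blocking>>>();  // (4) unordered w.r.t. (1)-(3)

  cudaDeviceSynchronize();
  cudaStreamDestroy(blocking);
  cudaStreamDestroy(non_blocking);
  return 0;
}
```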
Also, are you sure all the ncclAllReduce calls (those on the default stream and those on the non-blocking stream, viewed as a whole) are issued in the same order across all the processes? You only mentioned that "the reduce order on each stream is fixed." The inter-stream order also matters here.
Also, please refer to my comment here: #195 (comment). There is no guarantee that two NCCL operations will be able to make progress concurrently, so you need to make sure things still work even if one blocks the other.
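One way to satisfy that is to impose an explicit order between the two streams with an event, so that every rank issues the collectives in the same effective order and neither depends on the other making progress concurrently. A sketch (communicator, buffer, and stream names are assumed):

```c
// Sketch: serialize the two all-reduces across streams with an event.
#include <nccl.h>
#include <cuda_runtime.h>

void ordered_allreduce(ncclComm_t comm_a, ncclComm_t comm_b,
                       float *buf_a, float *buf_b, size_t count,
                       cudaStream_t stream_a, cudaStream_t stream_b) {
  cudaEvent_t done_a;
  cudaEventCreateWithFlags(&done_a, cudaEventDisableTiming);

  ncclAllReduce(buf_a, buf_a, count, ncclFloat, ncclSum, comm_a, stream_a);
  cudaEventRecord(done_a, stream_a);
  // stream_b will not start its all-reduce until stream_a's has completed,
  // so the inter-stream order is identical on all ranks.
  cudaStreamWaitEvent(stream_b, done_a, 0);
  ncclAllReduce(buf_b, buf_b, count, ncclFloat, ncclSum, comm_b, stream_b);

  cudaEventDestroy(done_a);
}
```

This gives up the overlap between the two all-reduces, but it makes the effective launch order identical on every rank, which removes the cross-rank ordering hazard discussed in the following comments.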
Thanks for the reply. @kwen2501 I thought the non-blocking stream would be an exception here, right? See the More Tips section in your link.
And the inter-stream order is not the same: since the two streams do not block each other, I thought they would make progress independently. @sjeaugey that helps, though I am still confused about why it does not work. Could you elaborate on what "no guarantee" means?
You are right. I missed that the other stream is created with the cudaStreamNonBlocking flag. In that case (the two streams are non-blocking with respect to each other), CUDA makes no guarantee about the order of execution of operations issued to those independent streams. Then, even if all the NCCL calls (on both streams) are coded in the same order on all the processes, there is still no guarantee against a hang. For example, GPU 0 may launch AllReduce_a on stream_a first, and then find there are no free compute resources left on the device to launch AllReduce_b on stream_b; whereas GPU 1 somehow launches AllReduce_b first, and finds no free resources to launch AllReduce_a. The result can be a deadlock in which each GPU waits for the other to join the operation it launched first.
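Rendered as code, the hazard is just two unordered launches (a sketch; all names are assumed, with one communicator per stream):

```c
// Sketch of the hazardous pattern: nothing orders the two streams, so the
// GPU may schedule the two grids in either order, and different ranks may
// pick different orders:
//   rank 0: AllReduce_a fills the SMs first; AllReduce_b cannot start.
//   rank 1: AllReduce_b fills the SMs first; AllReduce_a cannot start.
// Each rank then waits forever for its peer to join the collective it
// launched first.
#include <nccl.h>
#include <cuda_runtime.h>

void unordered_launch(ncclComm_t comm_a, ncclComm_t comm_b,
                      float *a, float *b, size_t n,
                      cudaStream_t stream_a, cudaStream_t stream_b) {
  ncclAllReduce(a, a, n, ncclFloat, ncclSum, comm_a, stream_a);
  ncclAllReduce(b, b, n, ncclFloat, ncclSum, comm_b, stream_b);
}
```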
Well, I understand the case you gave in the example, but that depends on one AllReduce occupying the whole GPU's resources, which seems odd to me: how many resources does an NCCL kernel need? BTW, what do you mean by the following in #208 (comment)?
The number of blocks NCCL launches depends on the platform and the bandwidth we are trying to achieve. Currently, it is 1-16 blocks of 64-256 threads each. I agree that if there are enough resources on the GPU for two NCCL operations to execute concurrently, one would expect things to work, but the CUDA programming model does not guarantee it. That means it is not supported: it might stop working at any moment, or work on only some GPU types.
You can see how many rings NCCL creates by setting `NCCL_DEBUG=INFO`.
I was just giving an extreme example; it may not take many kernels of other types to occupy a GPU. One can also launch many NCCL operations (on concurrent streams) to achieve that effect.
Thank you for the detailed explanation.