Concurrent initialization of communicators? #239

Closed
pietern opened this issue Jul 10, 2019 · 10 comments

Comments

@pietern commented Jul 10, 2019

Per pytorch/pytorch#18300 it looks like concurrent initialization of multiple NCCL communicators is not possible. The communicators are completely isolated from one another and each uses an independent ncclUniqueId obtained from ncclGetUniqueId. Those IDs are all generated on the same rank, though. Is there anything in NCCL that prevents concurrent initialization?

Thanks!

@kwen2501 (Contributor)

NCCL does not prevent concurrent initialization of communicators. The ncclUniqueId for each of those communicators must be different, though. Generating the ncclUniqueIds on the same rank does not make them identical, so that should be okay.
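For illustration, here is a minimal sketch of that pattern, under assumptions not stated in this thread: one process per rank, one GPU per process, and hypothetical publishUniqueId/fetchUniqueId helpers standing in for whatever out-of-band exchange (e.g. a key/value store) is used to share each ID. Note that, as the rest of the thread explains, issuing inits and collectives concurrently like this is not guaranteed to be safe.

```cpp
// Sketch only, not the PyTorch implementation. Error checking omitted.
#include <nccl.h>
#include <cuda_runtime.h>
#include <cstdlib>
#include <thread>
#include <vector>

// Hypothetical helpers: rank 0 publishes the id for communicator `commIndex`,
// the other ranks block until they can read it. Implementation not shown.
void publishUniqueId(int commIndex, const ncclUniqueId& id);
ncclUniqueId fetchUniqueId(int commIndex);

int main() {
  const int nComms = 7;   // e.g. the 7 communicators from the log below
  const int nRanks = 2;
  const char* rankEnv = std::getenv("RANK");  // assumed launcher-provided rank
  const int myRank = rankEnv ? std::atoi(rankEnv) : 0;

  std::vector<std::thread> workers;
  for (int i = 0; i < nComms; ++i) {
    workers.emplace_back([=] {
      cudaSetDevice(0);             // one GPU per process in this scenario
      ncclUniqueId id;
      if (myRank == 0) {
        ncclGetUniqueId(&id);       // a distinct id for every communicator
        publishUniqueId(i, id);
      } else {
        id = fetchUniqueId(i);      // all ranks of comm i read the same id
      }
      ncclComm_t comm;
      ncclCommInitRank(&comm, nRanks, id, myRank);  // collective init of comm i
      // ... run collectives on this comm, each comm on its own stream ...
      ncclCommDestroy(comm);
    });
  }
  for (auto& t : workers) t.join();
  return 0;
}
```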

Looking at the NCCL log in pytorch/pytorch#18300, it seems multiple threads are trying to claim they are rank 0. How many communicators does the user actually want to create? Indeed 7 as shown in the log?

v08:1707:1719 [0] NCCL INFO rank 0 nranks 2
v08:1707:1720 [0] NCCL INFO rank 0 nranks 2
v08:1707:1722 [0] NCCL INFO rank 0 nranks 2
v08:1707:1721 [0] NCCL INFO rank 0 nranks 2
v08:1707:1728 [0] NCCL INFO rank 0 nranks 2
v08:1707:1730 [0] NCCL INFO comm 0x7f4bec028280 rank 0 nranks 2
v08:1707:1735 [0] NCCL INFO comm 0x7f4bdc0014f0 rank 0 nranks 2
v08:1707:1732 [0] NCCL INFO comm 0x7f4be4000de0 rank 0 nranks 2
v08:1707:1733 [0] NCCL INFO comm 0x7f4be0000de0 rank 0 nranks 2
v08:1707:1734 [0] NCCL INFO comm 0x7f4be8000de0 rank 0 nranks 2
v08:1707:1731 [0] NCCL INFO comm 0x7f4bf0000de0 rank 0 nranks 2
v08:1707:1736 [0] NCCL INFO rank 0 nranks 2
v08:1707:1738 [0] NCCL INFO comm 0x7f4bc8000de0 rank 0 nranks 2

@pietern (Author) commented Jul 10, 2019

Thanks, Ken.

They want to create 7 communicators and the same process is rank 0 for all of them.

@kwen2501 (Contributor)

Thanks @pietern for the confirmation.
I am not sure the locking mechanism added in one of the replies in pytorch/pytorch#18300 is sufficient. What's needed is that all ranks in the same communicator read the same ncclUniqueId.

@pietern (Author) commented Jul 10, 2019

That's taken care of: each communicator uses a different prefix in the key/value store to share its unique ID with the other ranks. The locking mechanism in that reply doesn't fix the issue, and I confirmed the issue exists even without that patch applied.
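For concreteness, the per-communicator prefix idea could look roughly like this. This is a sketch with a hypothetical Store interface and key names invented for illustration; the concrete store PyTorch uses is not shown here.

```cpp
#include <nccl.h>
#include <cstring>
#include <string>
#include <vector>

// Hypothetical process-shared key/value store; get() blocks until the key is set.
struct Store {
  virtual void set(const std::string& key, const std::vector<char>& value) = 0;
  virtual std::vector<char> get(const std::string& key) = 0;
  virtual ~Store() = default;
};

// Each communicator uses its own key prefix, so concurrent initializations
// never read each other's ncclUniqueId.
ncclUniqueId shareUniqueId(Store& store, int commIndex, int myRank) {
  const std::string key = "comm_" + std::to_string(commIndex) + "/nccl_id";
  ncclUniqueId id;
  if (myRank == 0) {
    ncclGetUniqueId(&id);
    const char* p = reinterpret_cast<const char*>(&id);
    store.set(key, std::vector<char>(p, p + sizeof(id)));
  } else {
    std::vector<char> bytes = store.get(key);
    std::memcpy(&id, bytes.data(), sizeof(id));
  }
  return id;
}
```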

@kwen2501 (Contributor)

Hi Pieter, is there any ordering guarantee for the operations issued by those helper threads? To use multiple communicators, one needs to at least make sure that collective calls are executed in the same order on the different GPUs (processes, in this case). (That's why we don't recommend this practice in general.) NCCL comm init is a collective call too.

If PyTorch initializes NCCL communicators lazily, the init is immediately followed by a collective operation. If multiple helper threads issue these two operations without any ordering, the operations can interleave differently on the two GPUs (processes), causing a hang.

@pietern (Author) commented Jul 11, 2019

There is no ordering guarantee, only an isolation guarantee. The unique ID will always be generated once and propagated to the right counterparts on the other processes. It's as if the communicators are created by different processes on the same set of GPUs.

Re: your comment on lazy initialization, even though the init is followed by a collective operation, it will be a collective against a particular communicator. We've established that this is fine, as long as they are different communicators, use different streams, etc. If the helper threads issue these operations out of order, shouldn't the fact that they use different NCCL unique IDs to initialize ensure that they don't interfere?

@sjeaugey (Member)

« We've established that this is fine, as long as they are different communicators, use different streams, etc. »
I would disagree, as I already explained here: #195 (comment)
Any CUDA call (and NCCL functions all do CUDA calls) could block until all other NCCL kernels complete. Or said differently, there is no guarantee they will not block.

In practice, it seems the cudaEvent operations and the cudaLaunchKernel we do during ncclAllreduce do not block currently (it could change), so indeed launching multiple NCCL kernels in parallel seems to work provided you use multiple streams and it fits in the GPU. But that is not guaranteed to work.

However, the operations we do during ncclCommInit, and in particular cudaMalloc, effectively wait for all kernels to complete, so this is a guaranteed hang if a thread on one process launches an ncclAllreduce while a thread on another process is still in the init phase.
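One way to avoid that interleaving, sketched below under assumptions (a hypothetical cross-process barrier() and a hypothetical sharedUniqueId() helper; this is not what PyTorch does), is to finish all communicator initializations, in the same order on every process, before any collective is launched.

```cpp
#include <nccl.h>
#include <cuda_runtime.h>
#include <vector>

void barrier();                                          // hypothetical: across all processes
ncclUniqueId sharedUniqueId(int commIndex, int myRank);  // hypothetical, e.g. via a store

void initAllThenUse(int nComms, int nRanks, int myRank) {
  std::vector<ncclComm_t> comms(nComms);
  std::vector<cudaStream_t> streams(nComms);

  // Phase 1: initialize every communicator sequentially, in the same order on
  // every process, so no ncclAllReduce kernel is in flight while a peer is
  // still blocked in ncclCommInitRank's cudaMalloc.
  for (int i = 0; i < nComms; ++i) {
    ncclCommInitRank(&comms[i], nRanks, sharedUniqueId(i, myRank), myRank);
    cudaStreamCreate(&streams[i]);
  }
  barrier();  // ensure every process has left the init phase

  // Phase 2: only now launch collectives, one dedicated stream per communicator,
  // e.g. ncclAllReduce(sendbuf, recvbuf, count, ncclFloat, ncclSum,
  //                    comms[i], streams[i]);
}
```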

@pietern (Author) commented Jul 12, 2019

I would disagree, as I already explained here: #195 (comment)
Any CUDA call (and NCCL functions all do CUDA calls) could block until all other NCCL kernels complete. Or said differently, there is no guarantee they will not block.

Thanks for the reminder @sjeaugey, I conveniently forgot about that... So if it works, good for you, but it may stop working in some future version, on different hardware (e.g. fewer SMs), etc.

However, the operations we do during ncclCommInit and in particular cudaMalloc effectively wait for all kernels to complete, so this is a guaranteed hang if different threads on different processes launch a ncclAllreduce while the others are still in the init phase.

This perfectly explains the hang described in the PyTorch issue. I'm familiar with cudaMalloc effectively acting as a device-wide synchronization, and I see how this prevents concurrent (and out-of-order) initialization of multiple communicators.

Thanks for walking through the underlying problem with me here.

@ckmufeng commented Feb 7, 2021

@sjeaugey @pietern
I'm still confused. Why does the hang happen when we use GROUP launch mode? In my tests, cudaMalloc does not wait for kernels launched by other processes on the same device (the one on which we call cudaMalloc) to complete.
So would you explain the hang more clearly?

@Forsworns

So would you explain the hang more clearly?

@ckmufeng I guess they are talking about the "implicit synchronization".

I'm wondering if cudaMallocAsync in CUDA 11.2 helps in this scenario. I mean, does the "implicit synchronization" still exist with cudaMallocAsync?
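For reference, the stream-ordered allocator added in CUDA 11.2 is used as below; whether it changes anything for NCCL's init depends on whether NCCL itself allocates through it, which this thread does not establish.

```cpp
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  void* buf = nullptr;
  // Unlike cudaMalloc, the allocation is ordered with respect to `stream`
  // rather than synchronizing with the rest of the device.
  cudaMallocAsync(&buf, 1 << 20, stream);
  // ... enqueue kernels using `buf` on `stream` ...
  cudaFreeAsync(buf, stream);

  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  std::printf("done\n");
  return 0;
}
```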
