
Issues when using multiple communicators in one distributed job #195

Open
lowintelligence opened this issue Mar 15, 2019 · 8 comments

lowintelligence commented Mar 15, 2019

Hi everyone, I'm an engineer on the Alibaba PAI team. Recently I have been working on enabling distributed NCCL support in our internal TF-based training framework.

Although I noticed the TF community has already committed work on a collective strategy and an NCCL-manager-based implementation, I attempted another way: simply building different rings with different ncclUniqueIds (which means using more ports). As a result, a single process may issue multiple CommInitRank and AllReduce operations on the same device through different rings concurrently. Since they use different ports, CPU threads and GPU streams, I don't need to manage the ordering of operations the way the TF collective or Horovod implementations do.
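
For concreteness, here is a minimal sketch of that pattern, assuming two rings over the same device inside one process; the helper names (runRing, runTwoRings) are illustrative and error checking is omitted. (As the replies below explain, this concurrent use of two communicators turned out not to be deadlock-free.)

```cpp
#include <cuda_runtime.h>
#include <nccl.h>
#include <thread>

// One ring: its own communicator, its own host thread, its own stream.
void runRing(ncclUniqueId id, int nranks, int rank, float* buf, size_t count) {
  ncclComm_t comm;
  cudaStream_t stream;
  cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);
  ncclCommInitRank(&comm, nranks, id, rank);  // bootstraps on this ring's own port
  ncclAllReduce(buf, buf, count, ncclFloat, ncclSum, comm, stream);
  cudaStreamSynchronize(stream);
  ncclCommDestroy(comm);
  cudaStreamDestroy(stream);
}

// In a single process: two rings driven concurrently by two threads.
// idA and idB come from two separate ncclGetUniqueId() calls, distributed
// out of band (e.g. via the framework's own RPC layer).
void runTwoRings(ncclUniqueId idA, ncclUniqueId idB, int nranks, int rank,
                 float* bufA, float* bufB, size_t count) {
  std::thread tA(runRing, idA, nranks, rank, bufA, count);
  std::thread tB(runRing, idB, nranks, rank, bufB, count);
  tA.join();
  tB.join();
}
```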

It may look unusual, but I have almost finished the work. Along the way I met some issues, which I list here. For some of them I propose solutions as PRs; for others I can only offer workarounds.

1. Hang between CommA's CommInitRank and CommB's AllReduce.
It should be caused by the implicit stream/device synchronization, which blocks the AllReduce kernel launch. I found two scenarios in my cases.

  1. cudaMemcpy blocking in the CommInitRankSync phase, i.e. devCommSetup().
    Here cudaMemcpy uses the default stream and therefore waits on an implicit synchronization with other threads. At that moment an AllReduce kernel is probably running, but on the remote side the corresponding AllReduce kernel might itself be blocked by other cudaMemcpy calls, and the program hangs.
    The impact of this sync can be avoided by using a non-blocking CUDA stream for the AllReduce kernels, but TF's stream creation function hard-codes the cudaStream creation flag to 0. So in my work I simply build with '--default-stream per-thread' as a workaround (see the sketch after this list).
  2. cudaFree blocking in ncclIbGdrSupport().
    This function detects GDR support during the init phase, but cudaFree has a very strong implicit synchronization behavior even with a non-blocking stream or 'per-thread' compilation. Since it only needs 4 bytes of memory, my solution is to keep the allocation in global scope and reuse it, protected by a mutex, for the duration of a single job run. I sent out a PR for this change, here.
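
As a rough illustration of the stream-side workaround mentioned in scenario 1 above (this is not TF's actual stream-creation code, just the shape of the fix):

```cpp
#include <cuda_runtime.h>

// A stream created with cudaStreamNonBlocking does not synchronize with the
// legacy default stream, so an AllReduce enqueued on it is not held back by a
// cudaMemcpy issued on the default stream from another thread.
cudaStream_t makeNcclStream() {
  cudaStream_t s;
  cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking);
  return s;
}

// TF's stream creation hard-codes the flag to 0 (cudaStreamDefault), so the
// workaround used here is to build with `nvcc --default-stream per-thread`,
// which gives every host thread its own default stream instead of the
// synchronizing legacy one.
```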

2. Crash during shared memory creation (like #185)
Although a related commit landed several days ago, it only fixes the case of a single communicator ring. When controlling multiple communicators in one process, the shared memory keys of CommA and CommB collide because they have the same 'pidHash', 'hostHash' and ranks. To solve this, I added a commHash during the init phase (see the sketch below). The PR is here.
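
The idea behind that PR, as a rough sketch with hypothetical names (the real NCCL code and segment name format differ):

```cpp
#include <cinttypes>
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Mix a per-communicator hash into the shared memory segment name so that two
// communicators in the same process (same pidHash/hostHash/ranks) no longer
// produce colliding keys.
void shmSegmentName(char* buf, size_t len, uint64_t hostHash, uint64_t pidHash,
                    uint64_t commHash /* e.g. a hash of the ncclUniqueId */,
                    int sendRank, int recvRank) {
  std::snprintf(buf, len, "nccl-shm-%" PRIx64 "-%" PRIx64 "-%" PRIx64 "-%d-%d",
                hostHash, pidHash, commHash, sendRank, recvRank);
}
```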

3. Hang in GROUP launch mode
My use case is recognized as GROUP launch mode when running multiple communicators concurrently, even though the CPU threads and CUDA streams of each comm and rank are all distinct. In this mode the kernel launches are enqueued in a grouped manner, but this doesn't work and the job hangs there. I didn't investigate further and simply set 'NCCL_LAUNCH_MODE=PARALLEL' as a workaround.

4. Hang in NET communication between 2 non-NVLink GPUs within one node
When building a ring across 2 GPUs on the same node without NVLink and with shared memory disabled, NCCL tries to use the NET interface to connect them, but it hangs with both the Socket and IB transports. No debug information or stack traces could be observed. I think something is wrong in the implementation of this case. However, since I have fixed the shared memory collision and made SHM usable, this issue does not affect me at the moment, but it is still there.
[screenshot showing the "WARN Could not create rings" message]

sjeaugey commented Mar 15, 2019

Thanks for this detailed explanation of your experiments!

"Since they use different ports, CPU threads and GPU streams, I don't need to manage the ordering of operations the way the TF collective or Horovod implementations do."
Unfortunately, that is not true. There is no safe way to use NCCL on two communicators without a global lock and careful synchronization. Even if you manage to create the two communicators and even if you use two asynchronous streams, there is no guarantee that the two NCCL operations will progress concurrently and not cause a hang. That's why most frameworks and Horovod pipeline all operations through a side module which takes care of ordering the operations and ensures there is no deadlock.

Right now the only safe way of using two NCCL communicators is to not use them simultaneously, i.e. make sure they are used in different epochs of your application. For example, use communicator A during the forward/backward phase of your DL training, then use communicator B during the gradient update, and make sure all operations on one communicator are complete before launching an operation on the other one.
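
A minimal sketch of that usage pattern, with illustrative names and error checking omitted (ncclBroadcast here just stands in for whatever the gradient-update phase needs):

```cpp
#include <cuda_runtime.h>
#include <nccl.h>

// Use commA and commB strictly one after the other: all work on commA has
// completed on this rank before anything is launched on commB.
void trainStep(ncclComm_t commA, ncclComm_t commB, cudaStream_t stream,
               float* grads, float* params, size_t count) {
  // Phase 1: gradient allreduce on communicator A (forward/backward phase).
  ncclAllReduce(grads, grads, count, ncclFloat, ncclSum, commA, stream);
  cudaStreamSynchronize(stream);  // commA's work is done before touching commB

  // Phase 2: gradient-update traffic on communicator B.
  ncclBroadcast(params, params, count, ncclFloat, /*root=*/0, commB, stream);
  cudaStreamSynchronize(stream);
}
```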

More on the points above:

  1. Eliminating cudaFree is not enough. We would have to eliminate all CUDA calls, since any of them might block in some situations. Obviously we cannot do that; we need at least some CUDA calls.
  2. Makes sense, we need to have different shared memory segments for different communicators.
  3. The GROUP mode is actually the only way to launch a kernel on multiple GPUs with no chance of a cudaFree or other CUDA call causing a sync in the middle, hence a hang. By disabling it, you might experience other types of hangs.
  4. The message WARN Could not create rings is indeed a sign that something is not working properly, but it might be because you disabled SHM.

lowintelligence commented Mar 18, 2019

Thank you for the quick response, Sylvain.

In the beginning I also noticed some discussions with the same viewpoint you mentioned. But after I did some experiments on my platform, I found that multiple communicators seemed to work well even when running together. That's why I built my work on NCCL this way. At least during the past couple of days, it never hung during the gradient allreduce phase in any of the tests. :(

However, after getting your reply, I repeated my experiments under extremely high pressure: repeatedly doing allreduce on 16 cards with 128 communicators/uniqueIds, each communicator using its own thread/stream, with a thread sync after every allreduce call. It finally hung after several rounds, which proved your point. 👍

So may I ask for more detail about why the reduce kernels hang in this allreduce-only case? Does it mean some kernels get stuck because too many reduce kernels are being scheduled or queued? Or is there some hardware-level limit that makes this concurrent kernel execution unreliable? (At the moment I only focus on multi-node clusters of 8xV100 with NVLink.) I hope to gain some deeper knowledge here. Thanks. :)

Additionally, since a V100 has 80 SMs, if I set the limit of channels to 1 (thus only 1 block per kernel grid) and use fewer than 80 communicators, with a sync after each round of communication, should all AllReduce kernels from the different comms get scheduled and launched?

@pritamdamania87

Unfortunately, that is not true. There is no safe way to use NCCL on two communicators without a global lock and careful synchronization. Even if you manage to create the two communicators and even if you use two asynchronous streams, there is no guarantee that the two NCCL operations will progress concurrently and not cause a hang. That's why most frameworks and Horovod pipeline all operations through a side module which takes care of ordering the operations and ensures there is no deadlock.

Right now the only safe way of using two NCCL communicators is to not use them simultaneously, i.e. make sure they are used in different epochs of your application. For example, use communicator A during the forward/backward phase of your DL training, then use communicator B during the gradient update, and make sure all operations on one communicator are complete before launching an operation on the other one.

@sjeaugey Thanks for your detailed explanation above. I had a few clarifying questions to make sure I completely understood what is safe and what is not. My use case is 4 GPUs with one process per GPU: one communicator does an allreduce across all 4 GPUs, and there are two other communicators, one doing an allreduce between GPU0 and GPU2 and the other between GPU1 and GPU3. Based on this setup, I had the following questions:

  1. Is it safe to run these 3 allreduce operations concurrently as long as the launch order of the allreduce ops is the same across all processes?
  2. Is it safe to launch concurrent allreduces on communicators with distinct GPUs? For example, let's say we launch the allreduce on all 4 GPUs and wait for it to complete. Then we launch the allreduce between GPU0 and GPU2 and the one between GPU1 and GPU3 concurrently. Would this be safe since the GPUs used are distinct?
  3. Or is our only option to wait for each allreduce to finish and only then launch the subsequent allreduce?

@sjeaugey
Member

  1. No, unfortunately it is not. To be safe you need to have a stream dependency so that the order they execute is the order they were submitted in.
  2. Yes, that is safe.
  3. You can have the 4 GPU allreduce, then 2x2 in one dimension, then 2x2 in the other dimension, for a total of 3 steps.

pritamdamania87 commented Sep 24, 2021

You can have the 4 GPU allreduce, then 2x2 in one dimension, then 2x2 in the other dimension, for a total of 3 steps.

Thanks for the clarification @sjeaugey! I didn't quite understand this fully: what does "2x2 in one dimension" refer to? Do you mean do the 4 GPU allreduce, then the GPU0 <-> GPU2 allreduce, then the GPU1 <-> GPU3 allreduce, for a total of 3 steps where we wait for each allreduce to complete before proceeding? If so, isn't option 2 in my comment above more efficient, since it is only 2 steps and the GPU0 <-> GPU2 and GPU1 <-> GPU3 allreduces can proceed in parallel?

@sjeaugey
Member

Sorry, I misunderstood. Yes, you can do the 4 GPU allreduce, then the 0<->2 and 1<->3 allreduces concurrently as a second step, so just two steps in total.

I misread and thought you had a third 0<->1 and 2<->3 step in the orthogonal dimension. Sorry about that.
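
A sketch of this two-step scheme from the point of view of one process/GPU, assuming each process holds comm0 (all 4 GPUs) plus its pair communicator (comm1 for GPUs 0/2, comm2 for GPUs 1/3); error checking omitted:

```cpp
#include <cuda_runtime.h>
#include <nccl.h>

void twoStepAllReduce(ncclComm_t comm0, ncclComm_t pairComm,
                      cudaStream_t stream, float* buf, size_t count) {
  // Step 1: allreduce across all 4 GPUs; wait for it to complete before
  // launching anything on another communicator.
  ncclAllReduce(buf, buf, count, ncclFloat, ncclSum, comm0, stream);
  cudaStreamSynchronize(stream);

  // Step 2: the pair allreduces (0<->2 and 1<->3) run concurrently across
  // processes; this is safe because the two pair communicators use disjoint
  // sets of GPUs.
  ncclAllReduce(buf, buf, count, ncclFloat, ncclSum, pairComm, stream);
  cudaStreamSynchronize(stream);
}
```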

@manauref

Dear @sjeaugey,
Regarding your answer 1 above:
"No, unfortunately it is not. To be safe you need to have a stream dependency so that the order they execute is the order they were submitted in."
In the context of the example posed by @pritamdamania87, there are 3 communicators:

  • comm0 for allreduce GPU0, GPU1, GPU2, GPU3
  • comm1 for allreduce GPU0, GPU2
  • comm2 for allreduce GPU1, GPU3

You say that launching the comm0 allreduce, waiting for it to complete, and then launching the comm1 and comm2 allreduces would be safe. I have 2 questions:

  1. Do the comm1 and comm2 allreduces have to be launched on distinct streams?
  2. In your answer you said that to be safe you have to have a stream dependency. Does this mean it would be safe to launch all three allreduces using distinct comms but one single stream? If not, do you mean launching the three allreduces on separate streams but indicating a dependency between them? Is that done through CUDA graphs, or is there a more basic way?

sjeaugey commented Oct 14, 2022

To avoid deadlocks, you need to make sure operations start executing on the GPU in the same order across ranks.
To guarantee that, currently the only way is to make sure the stream dependencies enforce the serialization of operations in the same order across all ranks.
The easiest is to post all 3 operations on the same stream. That way they will naturally be ordered. If you want to use 3 different streams (because that's how your code works and you need ordering between the kernels producing data and those consuming it), then you need to use CUDA events to enforce the ordering between streams.
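
A sketch of the event-based variant, assuming this rank participates in all three communicators (error checking omitted):

```cpp
#include <cuda_runtime.h>
#include <nccl.h>

// Three allreduces on three different streams, with CUDA events forcing them
// to start in the same order on every rank.
void orderedAllReduces(ncclComm_t c0, ncclComm_t c1, ncclComm_t c2,
                       cudaStream_t s0, cudaStream_t s1, cudaStream_t s2,
                       float* b0, float* b1, float* b2, size_t count) {
  cudaEvent_t e0, e1;
  cudaEventCreateWithFlags(&e0, cudaEventDisableTiming);
  cudaEventCreateWithFlags(&e1, cudaEventDisableTiming);

  ncclAllReduce(b0, b0, count, ncclFloat, ncclSum, c0, s0);
  cudaEventRecord(e0, s0);
  cudaStreamWaitEvent(s1, e0, 0);  // s1 cannot start its work before s0's allreduce
  ncclAllReduce(b1, b1, count, ncclFloat, ncclSum, c1, s1);
  cudaEventRecord(e1, s1);
  cudaStreamWaitEvent(s2, e1, 0);  // s2 cannot start its work before s1's allreduce
  ncclAllReduce(b2, b2, count, ncclFloat, ncclSum, c2, s2);

  // Safe to destroy: resources are released once the recorded work completes.
  cudaEventDestroy(e0);
  cudaEventDestroy(e1);
}
```

Posting all three operations on a single stream achieves the same ordering without any events.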
