NCCL Hang with CUDA_LAUNCH_BLOCKING=1 #750

Closed
xw285cornell opened this issue Nov 15, 2022 · 21 comments

@xw285cornell

This seems to be a regression in v2.13 (2.10 works): running a simple MLP model (with all-reduce) hangs when CUDA_LAUNCH_BLOCKING is present. The hang happens on all ranks, with the GPUs spinning at 100% on something. The backtrace from gdb is below. ncclStrongStreamLaunchKernel looks new in v2.13, so I'm wondering if this is a new regression.

#0 0x00007fec9b630c50 in ?? () from lib/libcuda.so.1
#1 0x00007fec9b49dcf7 in ?? () from lib/libcuda.so.1
#2 0x00007fec9b3f8b9f in ?? () from lib/libcuda.so.1
#3 0x00007fec9b3fa1b8 in ?? () from lib/libcuda.so.1
#4 0x00007fec9b3e301f in ?? () from /lib/libcuda.so.1
#5 0x00007fec9b636830 in ?? () from lib/libcuda.so.1
#6 0x00007fec9b370bd6 in ?? () from lib/libcuda.so.1
#7 0x00007fec9b372bc2 in ?? () from libcuda.so.1
#8 0x00007fec9b40d4d5 in ?? () from lib/libcuda.so.1
#9 0x00007fec9ae1405c in ?? () from /lib/libcudart.so.11.4.108
#10 0x00007fec9ae67f06 in cudaLaunchKernel () from /lib/libcudart.so.11.4.108
#11 0x00007fecdf7d6cf9 in ncclStrongStreamLaunchKernel(ncclCudaGraph, ncclStrongStream*, void*, dim3, dim3, void**, unsigned long) ()
#12 0x00007fecdf7a56c8 in ncclLaunchKernel(ncclComm*, ncclKernelPlan*) ()
#13 0x00007fecdf7c494d in ncclGroupEndInternal() ()
#14 0x00007fecdf7c5301 in ncclGroupEnd () from
#15 0x00007fecdf794f4c in torch::cuda::nccl::AutoNcclGroup::~AutoNcclGroup() ()
#16 0x00007fecdf725655 in c10d::ProcessGroupNCCL::allgather(std::vector<std::vector<at::Tensor, std::allocator<at::Tensor> >, std::allocator<std::vector<at::Tensor, std::allocator<at::Tensor> > > >&, std::vector<at::Tensor, std::allocator<at::Tensor> >&, c10d::AllgatherOptions const&) ()
#17 0x00007feccf159640 in c10d::(anonymous namespace)::allgather_(std::vector<std::vector<at::Tensor, std::allocator<at::Tensor> >, std::allocator<std::vector<at::Tensor, std::allocator<at::Tensor> > > > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, c10::intrusive_ptr<c10d::ProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroup> > const&, long) ()
#18 0x00007feccf163e8c in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoRuntimeFunctor_<std::tuple<std::vector<std::vector<at::Tensor, std::allocator<at::Tensor> >, std::allocator<std::vector<at::Tensor, std::allocator<at::Tensor> > > >, c10::intrusive_ptr<c10d::Work, c10::detail::intrusive_target_default_null_type<c10d::Work> > > (*)(std::vector<std::vector<at::Tensor, std::allocator<at::Tensor> >, std::allocator<std::vector<at::Tensor, std::allocator<at::Tensor> > > > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, c10::intrusive_ptr<c10d::ProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroup> > const&, long), std::tuple<std::vector<std::vector<at::Tensor, std::allocator<at::Tensor> >, std::allocator<std::vector<at::Tensor, std::allocator<at::Tensor> > > >, c10::intrusive_ptr<c10d::Work, c10::detail::intrusive_target_default_null_type<c10d::Work> > >, c10::guts::typelist::typelist<std::vector<std::vector<at::Tensor, std::allocator<at::Tensor> >, std::allocator<std::vector<at::Tensor, std::allocator<at::Tensor> > > > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, c10::intrusive_ptr<c10d::ProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroup> > const&, long> >, std::tuple<std::vector<std::vector<at::Tensor, std::allocator<at::Tensor> >, std::allocator<std::vector<at::Tensor, std::allocator<at::Tensor> > > >, c10::intrusive_ptr<c10d::Work, c10::detail::intrusive_target_default_null_type<c10d::Work> > > (std::vector<std::vector<at::Tensor, std::allocator<at::Tensor> >, std::allocator<std::vector<at::Tensor, std::allocator<at::Tensor> > > > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, c10::intrusive_ptr<c10d::ProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroup> > const&, long)>::call(c10::OperatorKernel*, c10::DispatchKeySet, std::vector<std::vector<at::Tensor, std::allocator<at::Tensor> >, std::allocator<std::vector<at::Tensor, std::allocator<at::Tensor> > > > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, c10::intrusive_ptr<c10d::ProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroup> > const&, long) ()

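For reference, a minimal reproducer in this spirit might look like the sketch below (hypothetical, not the original script from the report): a tiny MLP wrapped in DistributedDataParallel, one GPU per process, launched with something like `CUDA_LAUNCH_BLOCKING=1 torchrun --nproc_per_node=8 repro.py`.

```python
# Hypothetical minimal reproducer: DDP with one GPU per process, gradient
# all-reduce over NCCL happens inside loss.backward().
import os
import torch
import torch.distributed as dist
import torch.nn as nn

def main():
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
    torch.cuda.set_device(local_rank)
    model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 10)).cuda()
    model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
    opt = torch.optim.SGD(model.parameters(), lr=0.01)
    for _ in range(10):
        x = torch.randn(32, 1024, device="cuda")
        loss = model(x).sum()
        opt.zero_grad()
        loss.backward()  # NCCL collective; with 2.13 + CUDA_LAUNCH_BLOCKING=1 this is where the hang shows up
        opt.step()
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```
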
@xw285cornell
Author

cc. @kwen2501

@xw285cornell
Author

Confirmed working with 2.10 + CUDA_LAUNCH_BLOCKING. Also confirmed that 2.13 + CUDA_LAUNCH_BLOCKING + NCCL_BLOCKING_WAIT doesn't work.

@kwen2501
Contributor

kwen2501 commented Dec 1, 2022

I just noticed this issue. Thanks @xw285cornell for reporting it.

@sjeaugey This may seem like a corner case but:

CUDA_LAUNCH_BLOCKING is often used for debugging other CUDA (compute) kernels. For example, if a CUDA error is hit in the middle of a run, PyTorch prints a message when the CPU catches it, like the one below:

terminate called after throwing an instance of 'c10::Error'
what(): CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

If CUDA_LAUNCH_BLOCKING=1 does not work with NCCL, then debugging (of those other kernels) cannot proceed.

You can ignore NCCL_BLOCKING_WAIT though -- it is a PyTorch flag (with bad naming) rather than a CUDA flag. But if CUDA_LAUNCH_BLOCKING=1 works, NCCL_BLOCKING_WAIT should work too.
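
As an illustration of that debugging workflow (a toy example, not taken from this issue): an out-of-range embedding index triggers a device-side assert, and running the script with CUDA_LAUNCH_BLOCKING=1 makes the error surface at the launching line instead of at some later, unrelated call.

```python
# Run as: CUDA_LAUNCH_BLOCKING=1 python assert_demo.py   (file name is illustrative)
import torch

emb = torch.nn.Embedding(10, 4).cuda()      # valid indices are 0..9
idx = torch.tensor([3, 42], device="cuda")  # 42 is out of range -> device-side assert
out = emb(idx)                              # kernel launch is asynchronous by default
print(out.sum().item())                     # without the env var, the error may only surface here
```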

@sjeaugey
Member

sjeaugey commented Dec 1, 2022

We're looking into the regression, but in general I'm not sure what we can guarantee with respect to CUDA_LAUNCH_BLOCKING.

In particular, I don't see how it could work in the case of one process managing multiple GPUs and having to launch a kernel on each GPU: the first launch would block and we'd never be able to launch the second one. Basically, given that our kernels block waiting for the other GPUs, as soon as any call related to another GPU blocks, we have a deadlock.
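
A pure-Python analogy of that failure mode (illustration only; threading.Barrier stands in for the NCCL kernels waiting on their peers, and join() for a blocking launch):

```python
import threading

NGPUS = 2
barrier = threading.Barrier(NGPUS)  # each "kernel" waits here for all of its peers

def kernel(rank):
    barrier.wait()  # returns only once all NGPUS "kernels" are running

for rank in range(NGPUS):
    t = threading.Thread(target=kernel, args=(rank,), daemon=True)
    t.start()
    # CUDA_LAUNCH_BLOCKING analogue: the launch does not return until the kernel finishes.
    t.join(timeout=2)  # with no timeout, rank 0 would block here forever
    print(f"rank {rank} done: {not t.is_alive()}")  # rank 0 prints False -> deadlock
```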

@kwen2501
Contributor

kwen2501 commented Dec 1, 2022

We (PyTorch) are dropping one-thread multi-GPU support in c10d. So no need to support CUDA_LAUNCH_BLOCKING in that case indeed.

@xw285cornell
Author

@sjeaugey yeah, if you can guarantee CUDA_LAUNCH_BLOCKING works with a single GPU per process, that will already provide huge value to us.

@jbachan
Collaborator

jbachan commented Dec 6, 2022

In 2.13 we moved away from cudaLaunchCooperativeKernelMultiDevice, as it is being deprecated, to a loop around cudaLaunchKernel. This could explain the hang if there are multiple devices per process. It isn't clear from the initial report: are we dealing with multiple devices per process?

@xw285cornell
Author

xw285cornell commented Dec 6, 2022

@jbachan no, it's one device (GPU) per process + multiple processes, the pretty standard PyTorch DistributedDataParallel. As @kwen2501 mentioned, we're moving away from multi-device per process in c10d. I wonder if you can try to reproduce it on your side.

@jbachan
Collaborator

jbachan commented Dec 6, 2022

Can you verify whether the hang goes away on a single node, still multi-GPU, and most importantly with no network? I believe the hang is due to our not pushing work to our proxy thread (which manages NIC traffic) until after cudaLaunchKernel returns, which creates a circular dependency between GPU and network progress.

@xw285cornell
Author

@jbachan how do we disable network? NCCL_SOCKET_IFNAME=lo?

@AddyLaddy
Collaborator

Just running a multi-process, multi-GPU job on a single node should not invoke the network proxies.
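
A hypothetical driver for the single-node configurations discussed above (repro.py stands for the reproducer sketched earlier in the thread; the 600-second timeout is arbitrary, just long enough to tell a hang apart from a pass):

```python
import os
import subprocess

for nproc in (1, 8):  # one GPU, then eight GPUs, all on one node (no network proxies)
    env = dict(os.environ, CUDA_LAUNCH_BLOCKING="1")
    cmd = ["torchrun", "--standalone", f"--nproc_per_node={nproc}", "repro.py"]
    print("running:", " ".join(cmd))
    try:
        subprocess.run(cmd, env=env, timeout=600, check=True)
        print("passes")
    except subprocess.TimeoutExpired:
        print("stuck")
```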

@xw285cornell
Author

Thanks @jbachan @AddyLaddy! We have done the experiments and they confirm what you said:

single host:

no env variables: passes
CUDA_LAUNCH_BLOCKING, one GPU: passes
CUDA_LAUNCH_BLOCKING, 8 GPUs: passes

two hosts:

no env variables: passes
CUDA_LAUNCH_BLOCKING, one GPU per host: gets stuck as before

NCCL_P2P_DISABLE=1 + one host: stuck

@xw285cornell
Author

@jbachan @AddyLaddy @sjeaugey any update on this?

@stas00

stas00 commented Dec 27, 2022

Whoah! I was just directed here after trying to get help with a crashing setup that I was trying to debug with CUDA_LAUNCH_BLOCKING=1.

So how do we debug missing tracebacks now? That was the whole point of CUDA_LAUNCH_BLOCKING=1 - this is a huge problem for complex, large code bases. I hope you now have an alternative solution for getting the traceback of the real problem. (1 GPU per process)

thank you!

@xw285cornell
Author

The easiest thing to do for now might be to revert to 2.10 or earlier (not sure about 2.11-2.12) until this issue is fixed.

@jbachan
Collaborator

jbachan commented Dec 28, 2022

Thank you @xw285cornell, that clearly implicates launching the kernel before launching the proxy thread as the culprit. I plan to make a fix available in early January.

@jbachan
Collaborator

jbachan commented Jan 20, 2023

@xw285cornell please try PR #774

@xw285cornell
Author

Thanks @jbachan ! Let me try it out!

@spotluri
Collaborator

@xw285cornell were you able to confirm that PR #774 helps? Thanks.

@xw285cornell
Author

> @xw285cornell were you able to confirm that PR #774 helps? Thanks.

@spotluri tried and confirmed it works!

@spotluri
Collaborator

Closing as fixed. 2.17.1 has the fixes required to support CUDA_LAUNCH_BLOCKING with a single GPU per process, as confirmed in #750 (comment).
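
For anyone verifying the fix, a quick way to check which NCCL version a PyTorch build links against (on recent PyTorch builds torch.cuda.nccl.version() returns a version tuple):

```python
# Prints the NCCL version PyTorch was built with, e.g. (2, 17, 1) on newer builds.
import torch

print(torch.cuda.nccl.version())
```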
