
[NCCL][CUDA] Optionally avoid rethrowing CUDA Errors in NCCL Watchdog #126587

Closed · wants to merge 3 commits

Conversation

@eqy (Collaborator) commented May 17, 2024

Doesn't affect current behavior by default, for #126544
I'm not sure what the exact mechanism is here, but CUDA errors appear to already be thrown in the main process, meaning that the watchdog is separately throwing CUDA errors again. However, this rethrown error causes the process to be terminated, as it cannot be handled from user code (which has no visibility into the watchdog thread).
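Roughly, the change guards the rethrow behind an opt-out flag. A minimal sketch of the idea follows; the flag name comes from the diff under review, while the error-string check and the opt-out mechanism shown here are illustrative assumptions rather than the literal change:

```cpp
// Sketch only (not the literal diff): conditional rethrow in the watchdog's
// top-level catch block. rethrowCUDAErrors_ would default to true, so nothing
// changes unless the user opts out (e.g. via an environment variable; the
// exact knob name is an assumption here).
try {
  watchdogHandler();
} catch (const std::exception& e) {
  const auto exitMsg = c10::str(
      "NCCL watchdog thread terminated with exception: ", e.what());
  const bool isCudaError =
      std::string(e.what()).find("CUDA error") != std::string::npos;
  if (rethrowCUDAErrors_ || !isCudaError) {
    // Existing behavior: rethrowing from the watchdog thread terminates the
    // process, since user code cannot catch an exception thrown on a thread
    // it does not control.
    watchDogException_ =
        std::make_exception_ptr(C10_BUILD_ERROR(DistBackendError, exitMsg));
    std::rethrow_exception(watchDogException_);
  } else {
    // Opt-out path: log instead, and let the main thread surface the CUDA
    // error through its own synchronization calls.
    LOG(ERROR) << exitMsg;
  }
}
```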

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k

@eqy requested review from wconstab and kwen2501 on May 17, 2024 23:30
pytorch-bot commented May 17, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126587

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 6a56cf2 with merge base 796dff7:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot added the oncall: distributed and release notes: distributed (c10d) labels on May 17, 2024
@drisspg added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on May 20, 2024
@kwen2501 (Contributor) commented May 21, 2024

> Doesn't affect current behavior by default, for #126544

Do you mean that this PR is a fix for #126544?
If so, a question I have is how it can avoid the process termination mentioned in #126544.
Would appreciate your comment.

@wconstab (Contributor) commented:

Can you explain the mechanism for throwing the CUDA errors in the main thread?

Is it because any current CUDA error on any stream/kernel will cause any future CPU synchronization call to report the error in the current CPU thread? If so, then we could argue that the watchdog does not need to rethrow CUDA errors, because users will discover them unless the user code has stopped issuing new CUDA work. (That's still technically a gap, but probably not an important one?)
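For reference, a standalone illustration (plain CUDA runtime code, not from PyTorch) of the "sticky error" behavior being asked about: once a kernel faults, the error is reported by the next synchronization call on the calling CPU thread, and by every runtime call after that.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

__global__ void bad_kernel(int* p) {
  p[0] = 42;  // device-side fault: write through a null pointer
}

int main() {
  bad_kernel<<<1, 1>>>(nullptr);               // launch returns immediately; the error is asynchronous
  cudaError_t err = cudaDeviceSynchronize();   // fault is reported here, on the calling CPU thread
  std::printf("first sync: %s\n", cudaGetErrorString(err));

  // The error is "sticky": every later runtime call in this process also fails,
  // so user code on the main thread sees it even if no other thread rethrows.
  int* q = nullptr;
  err = cudaMalloc(&q, sizeof(int));
  std::printf("later call: %s\n", cudaGetErrorString(err));
  return 0;
}
```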

Review thread on these lines of the diff:

```cpp
watchDogException_ =
    std::make_exception_ptr(C10_BUILD_ERROR(DistBackendError, exitMsg));
std::rethrow_exception(watchDogException_);
if (C10_LIKELY(rethrowCUDAErrors_) ||
```
Contributor:

Instead of defining a new env var to control the error handling, which makes things more complicated, could we reuse an existing one, e.g. by calling SHOULD_TEAR_DOWN(asyncErrorHandling_) here?

Collaborator (Author) replied:

Would it be expected that SHOULD_TEAR_DOWN(asyncErrorHandling_) would not rethrow CUDA errors? If so, we could consider repurposing that for this case as well.

Contributor replied:

I think asyncErrorHandling_ is supposed to handle any errors, including CUDA errors.

Contributor replied:

One thing to check is whether there are still NCCL errors that we need to handle here, because they are not raised in the main thread. Or, if all NCCL errors would be raised the same way as CUDA errors, the logic is simpler.

@eqy (Collaborator, Author) commented May 21, 2024

@kwen2501 Yes, it does fix the repro in #126544.
It avoids process termination by catching the CUDA error (which the repro tries to do but is unable to if the watchdog rethrows the exception).

@wconstab That matches the observed behavior (the CUDA error is still visible even if the watchdog does not rethrow it), but I'm not sure that's the exact mechanism here. I'll check whether other PyTorch-NV folks have an explanation for this.

@kwen2501 (Contributor) left a comment

Giving an approval per my reasoning in the original issue: #126544 (comment)

It seems to me that this is the solution.

@kwen2501 (Contributor) commented:
FYI @wconstab @eqy @shuqiangzhang
What happened has nothing to do with handleException(), see #126544 (comment).
What killed the process is this line:

std::rethrow_exception(watchDogException_);

Thus, I don't think we need to spend time discussing TORCH_NCCL_ASYNC_ERROR_HANDLING here.
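For context, a standalone illustration (not PyTorch code) of why that line is fatal: an exception that escapes a thread's entry function triggers std::terminate, and no handler in user code on the main thread can intercept it.

```cpp
#include <exception>
#include <stdexcept>
#include <thread>

int main() {
  std::thread watchdog([] {
    auto ep = std::make_exception_ptr(std::runtime_error("CUDA error: ..."));
    // Nothing on this thread catches, so the exception escapes the thread's
    // entry function and std::terminate() is called, aborting the whole
    // process. A try/except in user code on the main thread cannot stop it.
    std::rethrow_exception(ep);
  });
  watchdog.join();  // never reached cleanly; the process aborts first
  return 0;
}
```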

@kwen2501 (Contributor) commented:

> Thus, I don't think we need to spend time discussing TORCH_NCCL_ASYNC_ERROR_HANDLING here.

Edit: unless we decide to pull TORCH_NCCL_ASYNC_ERROR_HANDLING out of handleException(), put it at a higher level (e.g. under ProcessGroupNCCL::ncclCommWatchdog()), and funnel all exceptions there.
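If we did go that route, the funnel might look roughly like the sketch below. This is only an illustration of the shape, assuming SHOULD_TEAR_DOWN and asyncErrorHandling_ keep their current meaning; it is not existing code.

```cpp
// Hypothetical shape only: apply the TORCH_NCCL_ASYNC_ERROR_HANDLING policy
// at the watchdog's top level instead of inside handleException().
void ProcessGroupNCCL::ncclCommWatchdog() {
  try {
    watchdogHandler();  // existing inner loop
  } catch (const std::exception& e) {
    // Single funnel point: decide here, based on asyncErrorHandling_, whether
    // to tear the process down or just log and leave the error visible to the
    // main thread.
    if (SHOULD_TEAR_DOWN(asyncErrorHandling_)) {
      watchDogException_ = std::make_exception_ptr(
          C10_BUILD_ERROR(DistBackendError, e.what()));
      std::rethrow_exception(watchDogException_);
    } else {
      LOG(ERROR) << "NCCL watchdog caught exception: " << e.what();
    }
  }
}
```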

@eqy (Collaborator, Author) commented May 28, 2024

@pytorchmergebot merge

pytorch-bot added the ciflow/trunk label (trigger trunk jobs on your pull request) on May 28, 2024
@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Aidyn-A pushed a commit to tinglvv/pytorch that referenced this pull request May 30, 2024
[NCCL][CUDA] Optionally avoid rethrowing CUDA Errors in NCCL Watchdog (pytorch#126587)


Pull Request resolved: pytorch#126587
Approved by: https://github.com/kwen2501
Labels
ciflow/trunk, Merged, oncall: distributed, open source, release notes: distributed (c10d), triaged

7 participants