Checkify exception only reports error from one device (not all) #21246
Comments
An additional clarification. The code looks essentially like the following:
|
Thanks for the report! This code is a few years old now, and the author is no longer working on the JAX project. I took a look and found that the extra errors are removed in the _reduce_any_error helper function (lines 445 to 457 in a820387).
If I turn this function into an identity:
def _reduce_any_error(error: Error):
    return error
then the helper no longer drops the extra errors. I'm not sure of the other implications of that, but if someone wants to explore this further, understanding the intent of that helper function is probably the place to start. |
Thanks, Jake! @sharadmv and I worked closely with Lena on checkify, so I think we can debug this. I think this was the intended behavior at one point; indeed that's what @billmark for prioritization purposes: is this blocking your work in some way? Or just a preference you wanted to surface? |
It is no longer blocking my work, but I think it is critical to address this issue for other users, either by a code change or a documentation change. In the meantime, to help others: the idiom implicitly recommended by the JAX documentation is the following:
That idiom doesn't work properly: it ignores errors from all but one device. Instead, I use a variant of the following:
|
For what it's worth, I think the current behavior is defensible: e.g. if you have 64 shards that all error, it's not terrible to only see one copy of the error in the traceback. |
I respectfully disagree. The current behavior is terrible, and it just caused me to waste an enormous amount of time. When trying to track down a NaN with checkify's "float" check, one device typically has the "original" error, but all devices have a NaN error, since the NaNs later propagate via collective operations. I was seeing the error from the collective operation without realizing that there was a "hidden" error from another device that was the original cause of the NaN.
Solution A: At a bare minimum, the exception needs to state that it is reporting, e.g., only one out of 64 errors and that the other errors may be different. The checkify documentation here would also need to discuss this case.
Solution B (the best solution): De-dup the errors, so that the message would say something like: 64 errors on 64 devices, of two different types. Error #1 (devices 0, 2, 3): XXX. Error #2 (device 1): YYY.
Solution C: Just dump all 64 of the errors.
Solution D (currently implemented): Arbitrarily choose one of the errors to display, without any indication that others are being suppressed. This solution is terrible, particularly since it is not documented.
I would rank the solutions as follows: |
To clarify my example above: All devices had a NaN error, but there were two different source-code locations for the error. This is critical information! |
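For illustration only, solution B could be sketched in plain Python roughly as follows. This assumes the per-device error messages have already been pulled out into a list with one entry per device (None meaning no error on that device). The function name summarize_device_errors and the placeholder messages are hypothetical; none of this is part of checkify.

```python
from collections import defaultdict
from typing import Optional, Sequence


def summarize_device_errors(messages: Sequence[Optional[str]]) -> str:
    """Group identical per-device error messages so distinct errors stay visible.

    messages[i] is the error message for device i, or None if that device
    reported no error. This input format is an assumption for the sketch,
    not something checkify provides directly.
    """
    groups = defaultdict(list)  # message -> list of device indices
    for device, message in enumerate(messages):
        if message is not None:
            groups[message].append(device)

    if not groups:
        return "no errors"

    n_errors = sum(len(devices) for devices in groups.values())
    lines = [
        f"{n_errors} errors on {len(messages)} devices, "
        f"{len(groups)} distinct messages"
    ]
    # Dicts keep insertion order, so errors are listed by first device seen.
    for i, (message, devices) in enumerate(groups.items(), start=1):
        label = "device" if len(devices) == 1 else "devices"
        ids = ", ".join(str(d) for d in devices)
        lines.append(f"Error #{i} ({label} {ids}): {message}")
    return "\n".join(lines)


# Placeholder messages: every device has a NaN error, but from two
# different source locations.
example = [
    "NaN at location A",
    "NaN at location B",
    "NaN at location A",
    "NaN at location A",
]
print(summarize_device_errors(example))
```

Grouping by the full message string keeps errors from different source-code locations apart even when they are the same kind of error, which is the case described above.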
Description
When running with four hosts and four devices on each host, I see an "errs" returned by pmap of checkify that looks like the following:
However, calling errs.throw() (as recommended in the JAX docs) only shows one of these four errors:
I consider this behavior to be a bug. No reasonable person would expect the exception string to omit the errors from three out of four devices on that host. The exception string should contain all four errors.
System info (python version, jaxlib version, accelerator, etc.)
HEAD at google as of May 15, 2024. Running on TPU. (Four hosts, four devices per host).