Skip gradient updates when gradients contain NaN or Inf #1081
Conversation
- Add check before `Optimisers.update` to detect NaN/Inf in gradients
- Skip update but still increment iteration counter when detected
- Add warning message (`maxlog=10`) to inform users
- Fixes issue where NaN gradients corrupt all subsequent updates

Addresses: https://discourse.julialang.org/t/how-to-ignore-minibatches-with-nan-gradients-optimizing-a-hybrid-lux-model-using-optimization-jl/132615
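A minimal sketch of the guard this commit describes, using assumed names (`guarded_step`, plain-array parameters) rather than the package's actual internals: check the gradient before `Optimisers.update`, warn at most a few times, and skip the update while still counting the iteration.

```julia
using Optimisers

# Sketch only: skip the parameter update when the gradient is not finite,
# but keep incrementing the iteration counter.
function guarded_step(state, θ, G, iterations)
    if !all(isfinite, G)                           # catches both NaN and Inf
        @warn "Gradient contains NaN or Inf; skipping this update." maxlog=10
        return state, θ, iterations + 1            # θ and optimiser state untouched
    end
    state, θ = Optimisers.update(state, θ, G)      # normal Optimisers.jl update
    return state, θ, iterations + 1
end

# Usage sketch with a plain vector of parameters:
θ = randn(3)
state = Optimisers.setup(Optimisers.Adam(0.01), θ)
state, θ, iters = guarded_step(state, θ, [NaN, 1.0, 2.0], 0)      # skipped, iters == 1
state, θ, iters = guarded_step(state, θ, [0.1, 0.2, 0.3], iters)  # applied, iters == 2
```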
- Test with custom gradient function that injects NaN periodically
- Test with custom gradient function that injects Inf periodically
- Verify iterations complete and parameters remain finite
- Verify optimizer doesn't crash when encountering bad gradients
Use `any(isnan, G) || any(isinf, G)` instead of a lambda function to correctly handle array elements and hierarchical structures
Use `any(.!(isfinite.(G)))` to properly handle arrays with broadcasting
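For reference, a quick illustration (with assumed example arrays) of how these predicates behave on a plain gradient vector; `!all(isfinite, G)` is an equivalent form that avoids the intermediate broadcast array.

```julia
G = [0.5, NaN, 2.0]
any(isnan, G) || any(isinf, G)   # true: element-wise predicates, no lambda needed
any(.!(isfinite.(G)))            # true: broadcasting form from this commit
!all(isfinite, G)                # true: equivalent, with no temporary array

H = [0.5, Inf, 2.0]
any(isnan, H) || any(isinf, H)   # true: the second clause catches Inf
!all(isfinite, H)                # true
```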
- Add `has_nan_or_inf()` helper function
- Uses `Functors.fmap` to recursively check all elements
- Handles arbitrary nested structures (arrays, ComponentArrays, etc.)
- Checks if any element is not finite (catches both NaN and Inf)
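One way such a helper could be written (an assumed sketch, not necessarily the PR's exact code): `Functors.fmap` visits every leaf of a nested gradient structure, and a `Ref` accumulates whether any leaf holds a non-finite value.

```julia
using Functors

# Assumed implementation of the helper described above.
function has_nan_or_inf(grad)
    found = Ref(false)
    fmap(grad) do leaf
        if leaf isa AbstractArray{<:Number}
            found[] |= !all(isfinite, leaf)
        elseif leaf isa Number
            found[] |= !isfinite(leaf)
        end
        leaf   # fmap rebuilds the structure; only the side effect matters here
    end
    return found[]
end

has_nan_or_inf([1.0, 2.0, NaN])                              # true
has_nan_or_inf((weight = [1.0 2.0; 3.0 4.0], bias = [Inf]))  # true
has_nan_or_inf((weight = ones(2, 2), bias = zeros(2)))       # false
```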
- Use Zygote for gradient computation
- Inject NaN/Inf via callback that modifies `state.grad`
- This better simulates real-world scenarios where autodiff produces NaN
- Avoids issues with custom gradient function signatures
- Apply SciMLStyle formatting to OptimizationOptimisers.jl
- Remove JuliaFormatter from runtime dependencies
- Remove Functors dependency and use simple `all(isfinite, G)` check
- Make warning conditional on `cache.progress` flag
- Rewrite tests to use functions that return NaN/Inf in certain regions instead of callback-based approach
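A minimal illustration of the revised check and the progress-gated warning; `cache` here is a stand-in NamedTuple, not the package's actual cache type.

```julia
# Stand-in for the solver cache's progress flag.
cache = (progress = true,)
G = [1.0, NaN]

if !all(isfinite, G)   # simple check; no Functors needed for a flat gradient array
    cache.progress &&
        @warn "Gradient contains NaN or Inf; skipping this update." maxlog=10
end
```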
- Use sqrt and max to produce NaN when x goes negative
- Use 1/x pattern to produce Inf gradients
- Functions naturally produce problematic gradients during optimization
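Illustrative objectives in the spirit of this commit (not necessarily the test suite's exact functions): under reverse-mode AD the chain rule for `sqrt(max(x, 0.0))` evaluates to `Inf * 0 = NaN` once `x` goes negative, and the derivative `-1/x^2` of `1/x` overflows to `-Inf` for tiny `x`.

```julia
using Zygote

f_nan(x) = sqrt(max(x, 0.0)) + x^2   # NaN gradient for x < 0: sqrt'(0) = Inf times the zero derivative of max
f_inf(x) = 1 / x + x^2               # gradient -1/x^2 overflows for very small x

Zygote.gradient(f_nan, -0.1)     # (NaN,)
Zygote.gradient(f_inf, 1e-200)   # (-Inf,)
```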
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
This should not be here 😅
ugh
yeah
#1083 will fix it
Summary
This PR addresses the issue reported in the Julia Discourse thread about handling NaN gradients during optimization with minibatches. When training neural networks with minibatches, some batches may produce NaN or Inf gradients due to numerical instability. Previously, applying these invalid gradients would corrupt all subsequent parameter updates, preventing the optimizer from converging.
Changes
Implementation
- Added `has_nan_or_inf()` helper function that uses `Functors.fmap` to recursively check all gradient elements
- Updated `OptimizationOptimisers.jl` to skip parameter updates when NaN/Inf is detected

Dependencies

- Added `Functors` as a dependency for robust NaN/Inf checking across arbitrary nested structures

Tests

- Added tests covering gradients that contain NaN or Inf, verifying that optimization completes and parameters stay finite
Fixes
Addresses: https://discourse.julialang.org/t/how-to-ignore-minibatches-with-nan-gradients-optimizing-a-hybrid-lux-model-using-optimization-jl/132615
Testing
All existing tests pass. New tests verify:

- Optimization runs to completion when some gradients contain NaN or Inf
- Parameters remain finite and the optimizer does not crash on bad gradients
🤖 Generated with Claude Code
Co-Authored-By: Claude noreply@anthropic.com