-
Notifications
You must be signed in to change notification settings - Fork 439
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement Huber loss #1444
Implement Huber loss #1444
Conversation
Instead of using a sign or abs function, uses clamping to compute it outside the bounds. This is better for the autodiff backend.
Submitted a PR for |
Note: I think the method of |
CI failed due to:
Rerunning to see if it's a fluke. Tagging @nathanielsimard and @louisfd since they're currently working in this area. |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1444 +/- ##
==========================================
+ Coverage 85.81% 85.97% +0.15%
==========================================
Files 610 646 +36
Lines 70417 71847 +1430
==========================================
+ Hits 60428 61769 +1341
- Misses 9989 10078 +89 ☔ View full report in Codecov by Sentry. |
Sign tensor op PR is merged. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM @louisfd for further review.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, but can you clarify the math in comments? I think it works out but it's a bit confusing, in particular I think r
, err
and res
are three names for the same thing?
Ah yes, I initially wanted to use "error" for the difference between targets and predictions, then remembered the better term residuals and I guess didn't catch a few things when renaming. Will fix. |
Thanks, we can merge once CI passes |
Checklist
run-checks all
script has been executed.Related Issues/PRs
Closes #1441 as I think the remaining feature request for a sign function is already tracked in #522.
Changes
Implements the Huber Loss function.
Instead of strictly following the definition of using a sign or abs function, the implementation uses clamping, which computes the same value outside the
delta
bounds but is better behaved on the autodiff backend and does not need any extra primitive ops. See also #1441 for my first attempt of implementing this.Testing
Test data should cover all relevant branches of the operation, and critical points on the autodiff backend, i.e. zero residuals and the point where the loss switches between the branches. Test assertions have been generated from executing the equivalent in scipy.
Note: the
test_downsample_interpolation
test innearest_interpolate.rs
is failing locally for me. Not caused by the patch, I've ignored it when running run-checks.