small fix in the backward rule of norm
#131
Conversation
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

```
@@           Coverage Diff           @@
##           master     #131   +/-   ##
=======================================
  Coverage   81.94%   81.94%
=======================================
  Files          42       42
  Lines        5666     5667    +1
=======================================
+ Hits         4643     4644    +1
  Misses       1023     1023
```

View full report in Codecov by Sentry.
ext/TensorKitChainRulesCoreExt.jl (Outdated)

```diff
@@ -172,7 +172,9 @@ end
 function ChainRulesCore.rrule(::typeof(norm), a::AbstractTensorMap, p::Real=2)
     p == 2 || error("currently only implemented for p = 2")
     n = norm(a, p)
-    norm_pullback(Δn) = NoTangent(), a * (Δn' + Δn) / (n * 2), NoTangent()
+    function norm_pullback(Δn)
+        return NoTangent(), a * (Δn' + Δn) / (n * 2 + eps(real(eltype(a)))), NoTangent()
```
Could you change this to `a * (Δn' + Δn) / 2 / hypot(n, eps(one(n)))`? I think that is slightly nicer, in that, if `n == 1`, then `n + eps()` is no longer exactly one, but `hypot(1., eps())` is still exactly `1.` due to machine precision.
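For concreteness, a quick REPL check of the two guards at `n == 1` in `Float64` (illustrative only, not part of the PR):

```julia
julia> 1.0 + eps()        # n + eps() is no longer exactly one
1.0000000000000002

julia> hypot(1.0, eps())  # sqrt(1 + eps()^2) still rounds to exactly 1.0
1.0
```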
Yes, this is indeed better. I have modified this line according to the suggestion.
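For reference, a sketch of the rule with the suggestion applied (the closing `end` and the final `return n, norm_pullback` are assumed from the usual ChainRulesCore rrule pattern; they are not shown in the diff above, so this is not necessarily the exact committed code):

```julia
function ChainRulesCore.rrule(::typeof(norm), a::AbstractTensorMap, p::Real=2)
    p == 2 || error("currently only implemented for p = 2")
    n = norm(a, p)
    function norm_pullback(Δn)
        # hypot(n, eps(one(n))) keeps the denominator nonzero when n == 0,
        # while being indistinguishable from n whenever n is not tiny
        return NoTangent(), a * (Δn' + Δn) / 2 / hypot(n, eps(one(n))), NoTangent()
    end
    return n, norm_pullback
end
```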
Thanks; that's an important fix. I made one suggestion in the code.
Previously, the backward rule of `norm` would return `NaN` if the norm was zero.
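To make the failure mode concrete, here is a scalar stand-in for the tensor `a` (illustrative only; the actual rule acts on an `AbstractTensorMap`):

```julia
a, n, Δn = 0.0, 0.0, 1.0                          # zero tensor, zero norm

old = a * (Δn' + Δn) / (n * 2)                    # 0 / 0 == NaN
new = a * (Δn' + Δn) / 2 / hypot(n, eps(one(n)))  # 0.0, as expected
```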