-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix hypot
, return Wirtinger
appropriately
#55
Conversation
Bump. Can we get this merged? |
src/rules/base.jl
Outdated
|
||
function frule(::typeof(hypot), x::Real...) | ||
Ω = hypot(x...) | ||
return Ω, Rule((Δ...) -> sum(Δ .* x) * inv(Ω)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return Ω, Rule((Δ...) -> sum(Δ .* x) * inv(Ω)) | |
return Ω, Rule((Δ...) -> sum(Δ .* x) / Ω) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The problem is that Zero() / x
and One() / x
doesn't get overdubbed, so this throws errors if all Δ
s are Zero()
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is that because only multiplications and additions are overdubbed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, exactly
src/rules/base.jl
Outdated
function frule(::typeof(hypot), x...) | ||
Ω = hypot(x...) | ||
return Ω, WirtingerRule( | ||
Rule((Δ...) -> sum(Δ .* conj.(x) * inv(2Ω))), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be moved flush with the opening parenthesis with the second line aligned below it.
return Ω, WirtingerRule(Rule((Δ...) -> sum(Δ .* conj.(x) * inv(2Ω))),
Rule((Δ...) -> sum(Δ .* x) * inv(2Ω)))
Again in this case, why multiply by inv
instead of just dividing?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See my comment above. It might make sense to overload /
for Zero()
and One()
though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I think that would probably make sense.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be ok if I left it for now? I'm not too experienced with Cassette and there would be some edge cases to think about, so I think this deserves its own PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had a go at it. Could maybe be done more elegantly, but works for now.
@@ -1,10 +1,27 @@ | |||
_isapprox(x, y; kwargs...) = isapprox(x, y; kwargs...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same comments as in the other PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll delete it once #64 is merged
Turns out the derivative for
hypot
was wrong, even in the real case. Now supports an arbitrary number of real or complex arguments. Also in the process discovered that therrule
for*
was wrong for complex arguments.Still needs tests.