Correct chainrules for abs2, abs, conj and angle #196

Merged
merged 34 commits into from Jun 28, 2020

Conversation

MasonProtter
Contributor

@MasonProtter MasonProtter commented May 20, 2020

Closes #195.

Pending some more thoughts in https://discourse.julialang.org/t/taking-complex-autodiff-seriously-in-chainrules/39317/49 and/or an issue here, we should consider adding something along the lines of

@scalar_rule abs2(z::Complex) (z', z)

Current state of the PR is described here: #196 (comment)

@MasonProtter
Contributor Author

MasonProtter commented May 20, 2020

Seth Axen pointed out that the Zygote chainrules PR does this: https://github.com/FluxML/Zygote.jl/blob/bf913a2a8ed616242e2f5378fbe598b289dd550a/src/lib/number.jl#L26-L30 to get correct answers. I think this is a reasonable way to go about it rather than using Wirtinger definitions.

Member

@sethaxen sethaxen left a comment

Looks good to me.

I think these functions are a case where for Complex we need to define frule and rrule explicitly instead of using @scalar_rule. The adjoint rules @oxinabox defined in that PR look right to me:

# we intentionally define these here rather than falling back on ChainRules.jl
# because ChainRules doesn't really handle nonanalytic complex functions
@adjoint abs(x::Real) = abs(x), Δ -> (real(Δ)*sign(x),)
@adjoint abs(x::Complex) = abs(x), Δ -> (real(Δ)*x/abs(x),)
@adjoint abs2(x::Number) = abs2(x), Δ -> (real(Δ)*(x + x),)

However, I'm pretty sure the frules for abs(x::Complex) and abs2(x::Complex) are

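# The first tangent belongs to the function itself and is ignored here.
function frule((_, Δx), ::typeof(abs2), x::Complex)
    return abs2(x), 2 * (real(x) * real(Δx) + imag(x) * imag(Δx))
end

function frule((_, Δx), ::typeof(abs), x::Complex)
    Ω = abs(x)
    # Directional derivative of abs along Δx; this is 0/0 at x == 0.
    return Ω, (real(x) * real(Δx) + imag(x) * imag(Δx)) / Ω
end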

(confirmed by finite differences), and I don't see a good way to generate both these frules and the rrules from a single scalar rule.
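The finite-difference check mentioned above can be reproduced with a minimal sketch in plain Julia. The helper names (`realdot`, `pushforward_abs2`, `pullback_abs2`, `fd`, and so on) are hypothetical and are not part of ChainRules; the pushforwards restate the frule tangents from this comment and the pullbacks restate the Zygote adjoints quoted earlier. The last two lines check that each pullback is the adjoint of the matching pushforward under the real inner product on the complex plane.

```julia
# Hypothetical helpers; plain Julia, no ChainRules dependency.
realdot(a, b) = real(a) * real(b) + imag(a) * imag(b)

# Pushforwards (frule tangents) as written in the comment above.
pushforward_abs2(x, Δx) = 2 * realdot(x, Δx)
pushforward_abs(x, Δx) = realdot(x, Δx) / abs(x)

# Pullbacks (rrule cotangents) following the Zygote adjoints quoted above.
pullback_abs2(x, Δ) = real(Δ) * (x + x)
pullback_abs(x, Δ) = real(Δ) * x / abs(x)

# Central finite difference of f at x along the direction Δx.
fd(f, x, Δx; h=1e-6) = (f(x + h * Δx) - f(x - h * Δx)) / (2h)

x, Δx, Δ = 1.0 + 2.0im, 0.3 - 0.7im, 1.5

@show pushforward_abs2(x, Δx) ≈ fd(abs2, x, Δx)
@show pushforward_abs(x, Δx) ≈ fd(abs, x, Δx)
# Adjoint relation: Δ * pushforward(Δx) == realdot(pullback(Δ), Δx)
@show Δ * pushforward_abs2(x, Δx) ≈ realdot(pullback_abs2(x, Δ), Δx)
@show Δ * pushforward_abs(x, Δx) ≈ realdot(pullback_abs(x, Δ), Δx)
```

The adjoint check is what ties the two kinds of rule together: neither one can be read off the other by a single complex multiplication, which is consistent with the difficulty of generating both from one scalar rule.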

MasonProtter and others added 2 commits May 22, 2020 21:04
Co-authored-by: Seth Axen <seth.axen@gmail.com>
Co-authored-by: Seth Axen <seth.axen@gmail.com>
Member

@YingboMa YingboMa left a comment

Could you add a comment on what it returns for the complex case?

@YingboMa
Member

Given that ChainRules.jl does not support complex rules, maybe it is better to remove or comment out the complex rules.

@willtebbutt
Member

> Given that ChainRules.jl does not support complex rules, maybe it is better to remove or comment out the complex rules.

This is incorrect. We definitely support complex rules.

@YingboMa
Member

> This is incorrect. We definitely support complex rules.

I don't see any Wirtinger definitions at all.

@willtebbutt
Member

We just don't use Wirtinger derivatives. There's nothing stopping you from writing rules that work with complex numbers though.

@YingboMa
Member

But are abs and abs2 complex analytic?

@willtebbutt
Member

> But are abs and abs2 complex analytic?

I don't believe so, but that's not the point. I'm not arguing that our @scalar_rule won't do the correct thing in the non-analytic case. It was a more general comment, not meant to be controversial.

@MasonProtter
Contributor Author

Does anyone know what the corresponding rrules should be?

@sethaxen
Member

> Does anyone know what the corresponding rrules should be?

They should be the ones you linked to above in Zygote, assuming Zygote's conjugation conventions of course.

@YingboMa
Member

> I don't believe so, but that's not the point. I'm not arguing that our @scalar_rule won't do the correct thing in the non-analytic case. It was a more general comment, not meant to be controversial.

I don't understand how you can define complex rules without using a 2x2 matrix or Wirtinger derivatives if the function is not complex analytic.

@sethaxen
Member

> > I don't believe so, but that's not the point. I'm not arguing that our @scalar_rule won't do the correct thing in the non-analytic case. It was a more general comment, not meant to be controversial.

> I don't understand how you can define complex rules without using a 2x2 matrix or Wirtinger derivatives if the function is not complex analytic.

I believe all of these rules in ChainRules assume that tangents are derivatives of the primal with respect to a real scalar, and that cotangents are derivatives of a real scalar with respect to the primal, which is why no Wirtinger derivatives are needed. If complex differentiation is what is needed, you just call the pushforward/pullback twice to fill the Jacobian. This is basically what Zygote does: https://fluxml.ai/Zygote.jl/latest/complex/
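One way to make this convention concrete is the following sketch. It assumes that f = u + iv is treated as a map from R² to R² of z = x + iy, and it assumes Zygote-style conjugation conventions for the pullback; the symbols J, u, v, Δz and ΔΩ are notation introduced here for illustration, not taken from the ChainRules source.

```latex
\begin{aligned}
J &= \begin{pmatrix} u_x & u_y \\ v_x & v_y \end{pmatrix}, \\
\text{pushforward:}\quad
\begin{pmatrix} \operatorname{Re}\Delta z \\ \operatorname{Im}\Delta z \end{pmatrix}
&\mapsto
J \begin{pmatrix} \operatorname{Re}\Delta z \\ \operatorname{Im}\Delta z \end{pmatrix}, \\
\text{pullback:}\quad
\begin{pmatrix} \operatorname{Re}\Delta\Omega \\ \operatorname{Im}\Delta\Omega \end{pmatrix}
&\mapsto
J^{\top} \begin{pmatrix} \operatorname{Re}\Delta\Omega \\ \operatorname{Im}\Delta\Omega \end{pmatrix}.
\end{aligned}
```

For abs2, u = x² + y² and v = 0, so the pushforward is 2(x Re Δz + y Im Δz) and the pullback is Re(ΔΩ) · 2z, which agrees with the frule and the Zygote adjoint quoted earlier in this thread.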

@YingboMa
Member

I see. That makes sense. But then how can we warn the user that the function is not analytic? I don't think silently giving the wrong answer is a good idea.

@MasonProtter
Contributor Author

MasonProtter commented May 29, 2020

> But then how can we warn the user that the function is not analytic? I don't think silently giving the wrong answer is a good idea.

It's not the wrong answer. It's just that naively asking for pullback(1) doesn't allow you to derive the whole Jacobian, unlike in the holomorphic case.

It's like how in ForwardDiff2 if you had f(::Vector)::Vector and did

D(f)(v) * [1, 0]

you wouldn't know the full Jacobian. You need D(f)(v) * [0, 1] as well to be able to derive it, unless f had a special structure.

In this case, the whole Jacobian can be obtained from pullback(1), pullback(im).

It would however be a good idea to make this clearer in the docs somehow.
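The recovery of the whole Jacobian from pullback(1) and pullback(im) can be sketched in plain Julia as follows. The names `pullback_abs2`, `pullback_holomorphic` and `real_jacobian` are hypothetical, and the holomorphic pullback Δ -> conj(f'(z)) * Δ is an assumption matching Zygote's documented convention rather than code from this PR.

```julia
# Hypothetical pullback for abs2, following the Zygote adjoint quoted earlier
# in this thread: the cotangent Δ is mapped to real(Δ) * 2z.
pullback_abs2(z) = Δ -> real(Δ) * (z + z)

# Pullback of a holomorphic function with derivative df, assuming the
# convention Δ -> conj(f'(z)) * Δ.
pullback_holomorphic(df, z) = Δ -> conj(df(z)) * Δ

# Assemble the 2x2 real Jacobian of (x, y) -> (u, v) from two pullback calls.
# pullback(1) encodes the gradient of u; pullback(im) encodes the gradient of v.
function real_jacobian(pullback)
    gu = pullback(1.0 + 0.0im)
    gv = pullback(0.0 + 1.0im)
    return [real(gu) imag(gu); real(gv) imag(gv)]
end

z = 1.0 + 2.0im
J_abs2 = real_jacobian(pullback_abs2(z))                 # f(w) = abs2(w)
J_sq = real_jacobian(pullback_holomorphic(w -> 2w, z))   # f(w) = w^2

@show J_abs2
@show J_sq
```

In the holomorphic case the second row is determined by the first, so one pullback call is enough; for abs2 the second row is zero and cannot be inferred from pullback(1) alone.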

@YingboMa
Member

YingboMa commented May 29, 2020

It's not a documentation issue. The information that the function is not holomorphic is never forwarded. So an AD system doesn't know it needs special handling for the non-holomorphic case.

Though, I definitely like this approach of handling complex AD.

@sethaxen
Member

> It's not a documentation issue. The information that the function is not holomorphic is never forwarded. So an AD system doesn't know it needs special handling for the non-holomorphic case.

That's fair. In this approach, you need to check if the function is holomorphic by checking the Cauchy-Riemann equations. And it'll be a bit wasteful (although you can reuse the pullback when not mutating). But on the upside, the rules are simpler, and I don't think there's anything preventing future implementation of Wirtinger derivatives or something equivalent.
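A minimal sketch of the Cauchy-Riemann check described above, in plain Julia. The function name `is_cauchy_riemann` and the tolerance are hypothetical, and the Jacobians are written out by hand here rather than produced by an AD system.

```julia
# A 2x2 real Jacobian [u_x u_y; v_x v_y] satisfies the Cauchy-Riemann
# equations when u_x == v_y and u_y == -v_x.
is_cauchy_riemann(J; atol=1e-12) =
    isapprox(J[1, 1], J[2, 2]; atol=atol) && isapprox(J[1, 2], -J[2, 1]; atol=atol)

# Hand-written Jacobians at z = x + iy.
x, y = 1.0, 2.0
J_square = [2x -2y; 2y 2x]   # f(z) = z^2, holomorphic
J_abs2 = [2x 2y; 0.0 0.0]    # f(z) = abs2(z), not holomorphic

@show is_cauchy_riemann(J_square)
@show is_cauchy_riemann(J_abs2)
```

The cost mentioned above shows up here as the need for the full Jacobian, that is, two pushforward or pullback evaluations, before the check can run.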

In case you didn't see it, this discussion is relevant: https://discourse.julialang.org/t/taking-complex-autodiff-seriously-in-chainrules

@MasonProtter MasonProtter changed the title from "restrict abs2 to ::Real" to "Correct chainrules for complex abs2 and abs" on Jun 1, 2020
@MasonProtter MasonProtter changed the title from "Correct chainrules for complex abs2 and abs" to "Correct chainrules for ::Complex abs2 and abs" on Jun 1, 2020
Contributor

@nickrobinson251 nickrobinson251 left a comment

👍

MasonProtter and others added 3 commits June 1, 2020 14:10
Co-authored-by: Nick Robinson <npr251@gmail.com>
Co-authored-by: Nick Robinson <npr251@gmail.com>
MasonProtter and others added 2 commits June 24, 2020 18:39
Co-authored-by: Seth Axen <seth.axen@gmail.com>
Co-authored-by: Seth Axen <seth.axen@gmail.com>
@MasonProtter
Contributor Author

Okay, so this PR now adopts the subgradient convention: in situations where the gradient of functions like angle and abs would give NaN for non-NaN inputs (i.e. z = 0), we instead return an appropriate zero.
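As an illustration of this convention for abs, here is a minimal sketch in plain Julia. The function names are hypothetical and this is not necessarily the code merged in the PR; it only shows one way to avoid the 0/0 at z = 0.

```julia
# Pullback of abs that returns zero at the singular point instead of NaN.
function pullback_abs(z::Complex, ΔΩ::Real)
    Ω = abs(z)
    # Avoid 0/0: divide by one when z == 0, so the result is zero, not NaN.
    signz = z / ifelse(iszero(z), one(Ω), Ω)
    return ΔΩ * signz
end

# Naive version for comparison; it divides by abs(z) unconditionally.
naive_pullback_abs(z::Complex, ΔΩ::Real) = ΔΩ * z / abs(z)

@show pullback_abs(0.0 + 0.0im, 1.0)
@show naive_pullback_abs(0.0 + 0.0im, 1.0)
@show pullback_abs(3.0 + 4.0im, 1.0)
```

Away from z = 0 the two versions agree; they differ only at the origin, where zero is one valid element of the subdifferential of abs.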

It has also changed its point of view on angle: it no longer treats the reals as embedded in the complex plane, so the frule and rrule of angle now give Zero() when x::Real, Δx::Real, ΔΩ::Real.
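A short derivation of why zero is the natural answer for real inputs, as a sketch under the assumption that angle(z) is the two-argument arctangent of z = a + ib with z ≠ 0. The symbols a, b, Δz and the dotted tangent are notation introduced here.

```latex
\begin{aligned}
\operatorname{angle}(z) &= \operatorname{atan2}(b, a), \qquad
\frac{\partial \operatorname{angle}}{\partial a} = -\frac{b}{a^2 + b^2}, \qquad
\frac{\partial \operatorname{angle}}{\partial b} = \frac{a}{a^2 + b^2}, \\
\dot{\Omega} &= \frac{a \operatorname{Im}\Delta z - b \operatorname{Re}\Delta z}{a^2 + b^2}
\;=\; 0 \quad \text{when } b = 0 \text{ and } \operatorname{Im}\Delta z = 0 .
\end{aligned}
```

For a real input with a real perturbation both terms in the numerator vanish, so the tangent is zero. The corresponding cotangent ΔΩ(−b + ia)/(a² + b²) is purely imaginary at b = 0, so its real part is zero as well.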

MasonProtter and others added 7 commits June 26, 2020 13:11
Co-authored-by: Simon Etter <ettersi@users.noreply.github.com>
Co-authored-by: Simon Etter <ettersi@users.noreply.github.com>
Co-authored-by: Simon Etter <ettersi@users.noreply.github.com>
Co-authored-by: Simon Etter <ettersi@users.noreply.github.com>
Co-authored-by: Simon Etter <ettersi@users.noreply.github.com>
Co-authored-by: Simon Etter <ettersi@users.noreply.github.com>
Member

@sethaxen sethaxen left a comment

LGTM! Can you increment the version number to v0.7.0-DEV? I think we can merge this soon but hold on a release until the coming PR with compatibility for ChainRulesCore v0.9 is merged.

Development

Successfully merging this pull request may close these issues.

Incorrect chain rule for abs2
7 participants