Power Series adjoints are subtly wrong in a way gradcheck can't catch #608

Open
oxinabox opened this issue Apr 22, 2020 · 5 comments


@oxinabox
Member

oxinabox commented Apr 22, 2020

This came up in #366, and it has proved tricky to debug.
AFAICT there were two issues. The first came from ChainRules directly (I forgot to put the adjoint for abs(::Complex) back in; that is fixed now).
The second is what I think is an error in the actual definition of the power series adjoints, which is what this issue is about.

It is a subtle error that gradcheck does not catch.
gradcheck sums things, which causes the errors to cancel, because they are almost symmetric (I think they are symmetric up to floating-point error).
But something in the ChainRules PR breaks the symmetry just enough to register.
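A minimal illustration of how a check that only compares sums can hide an error, assuming (hypothetically) that the error pattern is antisymmetric (E == -E'), so that its entries cancel in pairs. This is a simplified stand-in, not Zygote's actual gradcheck, and G_true is just random data:

```julia
using Random

Random.seed!(0)
G_true = randn(3, 3)           # stand-in for a correct gradient (hypothetical data)
B = randn(3, 3)
E = 1e-2 .* (B .- B')          # hypothetical antisymmetric error: E == -E'
G_wrong = G_true .+ E          # a gradient that is wrong entry by entry

# A check that only compares sums cannot see E, because its entries cancel in pairs
summed_ok = isapprox(sum(G_wrong), sum(G_true); atol=1e-10)

# Comparing every entry exposes the error
entrywise_ok = isapprox(G_wrong, G_true; atol=1e-10)
```

Here summed_ok is true while entrywise_ok is false, which is why comparing the whole gradient entry by entry is the more robust test.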

Further testing, using FiniteDifferences to look at the whole gradient of A^p as in https://gist.github.com/oxinabox/fad7bae6dc11a7f0de31eff8666656cd, shows Zygote#master giving identical outputs to the ChainRules branch.
(This means the symmetry break is not in the output we are actually testing, but somewhere else.)
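As a rough sketch of what checking the whole gradient means: build the full Jacobian of vec(f(A)) with respect to vec(A) by central differences, then compare directional derivatives in arbitrary directions instead of a single summed scalar. The helper fd_jacobian is hypothetical (it is not the FiniteDifferences API), and f(X) = X^2 is used only because its derivative A*H + H*A is known exactly:

```julia
# Hypothetical helper, not part of FiniteDifferences.jl:
# Jacobian of vec(f(A)) with respect to vec(A), by central differences.
function fd_jacobian(f, A::AbstractMatrix{Float64}; h=1e-6)
    y0 = vec(f(A))
    J = zeros(length(y0), length(A))
    for k in eachindex(A)
        Δ = zeros(size(A))
        Δ[k] = h                                   # perturb one entry of A
        J[:, k] = vec(f(A .+ Δ) .- f(A .- Δ)) ./ (2h)
    end
    return J
end

A = [2.0 0.5; 0.5 1.0]
H = [0.1 -0.3; 0.2 0.4]                            # an arbitrary (non-symmetric) direction
J = fd_jacobian(X -> X^2, A)

# Directional derivative from the Jacobian vs the exact one, d(A^2)[H] = A*H + H*A
fd_dir = J * vec(H)
exact_dir = vec(A * H + H * A)
```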

(However, I don't 100% trust FiniteDifferences with complex numbers right now. In particular, as of today, the JuliaDiff/FiniteDifferences.jl#76 branch gets the answer wrong in the same way shown in JuliaDiff/FiniteDifferences.jl#76 (comment).)

More broadly, this is a great example of what can go wrong with Zygote's current method of gradient checks.
Moving forward, ChainRules has ChainRulesTestUtils, which uses more complete finite-difference-based tests,
and adjoints implemented with ChainRulesCore can use it.
It is probably still worth fixing Zygote's gradient tests.


Quoting relevant bits from #366

@oxinabox

Current issue: A^p for A a symmetric matrix and p = -0.5.
See this reproducer:
https://gist.github.com/oxinabox/fad7bae6dc11a7f0de31eff8666656cd
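For context, for a real symmetric A the matrix power can be computed from the eigendecomposition A = U Λ U' as A^p = U Λ^p U', which is why the adjoint code quoted below works on the eigenvalues λ. A small sketch with an arbitrary positive-definite matrix (not the one from the reproducer):

```julia
using LinearAlgebra

# Arbitrary symmetric positive-definite example (strictly diagonally dominant, positive diagonal)
A = Symmetric([2.0 0.5 0.1;
               0.5 1.5 0.3;
               0.1 0.3 1.0])
p = -0.5

λ, U = eigen(A)                        # A == U * Diagonal(λ) * U'
Ap_eig = U * Diagonal(λ .^ p) * U'     # apply the scalar power to each eigenvalue
Ap_builtin = A^p                       # LinearAlgebra's own A^p for Symmetric input
```

Since p = -0.5, Ap_eig * A * Ap_eig should be the identity up to rounding.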

@sethaxen

I'm not certain what's going on with the matrix. If there's a bug, it'll probably be here:

Zygote.jl/src/lib/array.jl, lines 619 to 627 at commit 713715f:

function _pullback_series_func_scalar(f::typeof(^), λ, p)
    compλ = _process_series_eigvals(f, λ)
    r, powλ = isinteger(p) ? (Integer(p), λ) : (p, compλ)
    fλ = powλ .^ r
    return (fλ,
            ()->conj.(r .* powλ .^ (r - 1)),
            ()->conj.((r * (r - 1)) .* powλ .^ (r - 2)),
            f̄λ -> (dot(fλ .* log.(compλ), f̄λ),))
end

The adjoint for the power is handled by

ārgs = hasargs ? argsback(diag(f̄Λ)) : ()

and may be off by a conj.
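One way to tell whether a pullback is off by a conj is to compare it with finite differences of a real-valued loss. A minimal scalar sketch, assuming the convention used by the scalar ^ adjoint quoted in this thread (for holomorphic f, the pullback is Δ -> Δ * conj(f'(z))); sin is used only as a convenient holomorphic example, and the helper names are hypothetical:

```julia
# Assumed convention: for holomorphic f, the pullback is Δ -> Δ * conj(f'(z))
pullback_sin(z, Δ) = Δ * conj(cos(z))

# Finite-difference gradient of the real scalar L(z) = real(conj(Δ) * f(z)),
# packed as dL/dx + im * dL/dy (hypothetical helper)
function fd_gradient(f, z, Δ; h=1e-6)
    L(w) = real(conj(Δ) * f(w))
    dLdx = (L(z + h) - L(z - h)) / (2h)
    dLdy = (L(z + im * h) - L(z - im * h)) / (2h)
    return complex(dLdx, dLdy)
end

z, Δ = 0.7 + 0.3im, 1.2 - 0.4im
g_fd = fd_gradient(sin, z, Δ)
g_conj = pullback_sin(z, Δ)
g_noconj = Δ * cos(z)          # the same formula with the conj dropped
```

Under this convention g_fd agrees with g_conj but not with g_noconj, because cos(z) is not real at this z.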

I probably won't have a chance to check the math until the end of the week.

By comparison with the second adjoint in

@adjoint Base.:^(x::Number, p::Number) = x^p,
    Δ -> (Δ * conj(p * x^(p-1)), Δ * conj(x^p * log(complex(x))))

This line does look to be missing a conj:

f̄λ -> (dot(fλ .* log.(compλ), f̄λ),))
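Whether a conj is really missing there depends on two details, sketched below with arbitrary example values: LinearAlgebra.dot conjugates its first argument (dot(u, v) == sum(conj.(u) .* v)), and the factor fλ .* log.(compλ) is the derivative of λ^p with respect to p, checked here for a single real eigenvalue. This only checks the ingredients; it does not by itself settle what the full matrix adjoint needs.

```julia
using LinearAlgebra

# 1. dot conjugates its FIRST argument
u = [1.0 + 2.0im, -0.5im]
v = [3.0 - 1.0im, 2.0 + 0.5im]
d_dot   = dot(u, v)            # 0.75 - 6.0im
d_plain = sum(u .* v)          # 5.25 + 4.0im (no conjugation)

# 2. d(λ^p)/dp == λ^p * log(λ), checked by central differences for one eigenvalue
λ, p, h = 2.5, -0.5, 1e-6
fd_dp = (λ^(p + h) - λ^(p - h)) / (2h)
analytic_dp = λ^p * log(λ)
```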

If this can't wait until the end of the week, I dropped my notes on the derivation here, in case someone else wants to check the math or check that the code is consistent with it: https://gist.github.com/sethaxen/000d164e515014fdda70601be1ecfb56.

Disclaimer: I don't know if this is the final version; it's just what I found on my machine.

@antoine-levitt

Sorry, I can't resist nitpicking, since I've run into this exact problem recently: using a Taylor series to accurately compute terms like (f(a) - f(b)) / (a - b) doesn't get you to O(eps) accuracy; with a first-order expansion it gets you sqrt(eps). You can increase the order to n, but then you get something like eps^(1-1/n) in the worst case. I don't know of any method that gives full accuracy here for a general function f.

I don't have anything better: I ended up with https://github.com/JuliaMolSim/DFTK.jl/blob/master/src/Smearing.jl#L44, which has O(sqrt(eps)) accuracy using the first derivative.
(Sorry for derailing this topic.)
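To make the accuracy trade-off concrete, here is a minimal sketch of a divided difference that switches to the first-order Taylor estimate f'(a) when a and b are close. The function name and the sqrt(eps) switching threshold are illustrative assumptions, not DFTK's or Zygote's code. Near the threshold, both the cancellation error (about eps/|a - b|) and the truncation error (about |f''| |a - b| / 2) are of order sqrt(eps), which is the worst case described above.

```julia
# Hypothetical helper: (f(a) - f(b)) / (a - b), falling back to f'(a)
# when a and b are closer than about sqrt(eps) in relative terms.
function divided_difference(f, df, a::Float64, b::Float64)
    tol = sqrt(eps(Float64)) * max(1.0, abs(a), abs(b))
    if abs(a - b) > tol
        return (f(a) - f(b)) / (a - b)   # cancellation error ≈ eps / |a - b|
    else
        return df(a)                     # truncation error ≈ |f''| * |a - b| / 2
    end
end

a, b = 1.0, 1.0 + 1e-10
approx = divided_difference(exp, exp, a, b)
# Stable reference: (e^a - e^b) / (a - b) == e^b * expm1(a - b) / (a - b)
reference = exp(b) * expm1(a - b) / (a - b)
relerr = abs(approx - reference) / reference
```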

@sethaxen
Contributor

I agree that the ChainRulesTestUtils approach is better than gradcheck (I have adapted rrule_test for Zygote in a package before). To help diagnose this, what ended up being the fix in #366 that got the power series tests to pass again?

@oxinabox
Member Author

I fixed it so that inputs to ChainRules are conjugated, which gives correct results for sin(::Complex).

@sethaxen
Contributor

I'm not certain if this is the cause, but it's a little tricky to test these power series functions using FiniteDifferences. Here's a simple example:

julia> using LinearAlgebra, Random, Zygote, FiniteDifferences

julia> Random.seed!(42);

julia> _fdm = central_fdm(5, 1; adapt=5);

julia> seed = randn(3, 3)
3×3 Array{Float64,2}:
 -0.556027   -0.299484  -0.468606
 -0.444383    1.77786    0.156143
  0.0271553  -1.1449    -2.64199

julia> A = Symmetric(randn(3, 3))
3×3 Symmetric{Float64,Array{Float64,2}}:
  1.00331   0.518149  -0.886205
  0.518149  1.49138    0.684565
 -0.886205  0.684565  -1.59058

julia> _, zpb = Zygote.pullback(exp, A);

julia> zg = zpb(seed)[1]
3×3 Array{Float64,2}:
 -2.11602    0.648903   0.634618
 -0.680315   7.49855    0.822502
  0.54197   -1.24868   -1.01659

julia> ng = conj.(FiniteDifferences.j′vp(_fdm, exp, seed, A))[1]
3×3 Symmetric{Float64,Array{Float64,2}}:
 -2.11602    -0.0314117   1.17659
 -0.0314117   7.49855    -0.426182
  1.17659    -0.426182   -1.01659

The problem here is that FiniteDifferences constrains the j′vp to have the same type as the input, so it makes it symmetric (see JuliaDiff/FiniteDifferences.jl#76 (comment)). Zygote has no such constraint. (I think the FD approach essentially gives us Zygote's intended adjoint followed by a projection onto the tangent space of the manifold defined by the constraint, though it's not true in general that elements of the tangent space can be represented as points on the manifold.) If you continue pulling the adjoint back through a Symmetric call, the two agree.

julia> _, zpb = Zygote.pullback(exp ∘ Symmetric, collect(A));

julia> zg = zpb(seed)[1]
3×3 Array{Float64,2}:
 -2.11602  -0.0314117   1.17659
  0.0       7.49855    -0.426182
  0.0       0.0        -1.01659

julia> ng = conj.(FiniteDifferences.j′vp(_fdm, exp ∘ Symmetric, seed, collect(A)))[1]
3×3 Array{Float64,2}:
 -2.11602      -0.0314117     1.17659
 -1.02408e-15   7.49855      -0.426182
 -1.02408e-15  -1.02408e-15  -1.01659
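The relationship between the two sets of numbers above can be reproduced by hand. Pulling back through Symmetric (which by default reads only the upper triangle of its argument) folds each lower-triangle entry of Zygote's unconstrained gradient onto its upper-triangle partner. A sketch using the rounded values printed in the first REPL session:

```julia
using LinearAlgebra

# Zygote's unconstrained gradient zg, copied (rounded) from the REPL output above
zg = [-2.11602   0.648903  0.634618;
      -0.680315  7.49855   0.822502;
       0.54197  -1.24868  -1.01659]

# Keep the diagonal and upper triangle, and add each strictly-lower entry to its
# transposed position: folded[i, j] = zg[i, j] + zg[j, i] for i < j, zero below.
folded = triu(zg) + transpose(tril(zg, -1))
```

To the printed precision, folded matches the second Zygote result, and Symmetric(folded) matches the first FiniteDifferences j′vp.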

@oxinabox
Member Author

Right, so do you think that means all is well?

@sethaxen
Contributor


So far so good: https://gist.github.com/sethaxen/fa67e541c4a2a5e773b475349ed87fb9.

Still a couple of edge cases to check.
