Power Series adjoints are subtly wrong in a way gradcheck can't catch #608
Comments
I agree that the ChainRulesTestUtils approach is better than |
Fixed it so inputs to ChainRules were conjugated, which correct results for |
I'm not certain if this is the cause, but it's a little tricky to test these power series functions using FiniteDifferences. Here's a simple example:
julia> using LinearAlgebra, Random, Zygote, FiniteDifferences
julia> Random.seed!(42);
julia> _fdm = central_fdm(5, 1; adapt=5);
julia> seed = randn(3, 3)
3×3 Array{Float64,2}:
-0.556027 -0.299484 -0.468606
-0.444383 1.77786 0.156143
0.0271553 -1.1449 -2.64199
julia> A = Symmetric(randn(3, 3))
3×3 Symmetric{Float64,Array{Float64,2}}:
1.00331 0.518149 -0.886205
0.518149 1.49138 0.684565
-0.886205 0.684565 -1.59058
julia> _, zpb = Zygote.pullback(exp, A);
julia> zg = zpb(seed)[1]
3×3 Array{Float64,2}:
-2.11602 0.648903 0.634618
-0.680315 7.49855 0.822502
0.54197 -1.24868 -1.01659
julia> ng = conj.(FiniteDifferences.j′vp(_fdm, exp, seed, A))[1]
3×3 Symmetric{Float64,Array{Float64,2}}:
-2.11602 -0.0314117 1.17659
-0.0314117 7.49855 -0.426182
1.17659 -0.426182 -1.01659
The problem here is that FiniteDifferences constrains the j′vp to have the same type as the input, hence it makes it symmetric (see JuliaDiff/FiniteDifferences.jl#76 (comment)). Zygote has no such constraint (I think the FD approach essentially gives us Zygote's intended adjoint followed by a projection to the tangent space to the manifold defined by the constraint, though it's not generally the case that elements of the tangent space can be represented as points on the manifold). If you continue pulling back the adjoint through a
julia> _, zpb = Zygote.pullback(exp ∘ Symmetric, collect(A));
julia> zg = zpb(seed)[1]
3×3 Array{Float64,2}:
-2.11602 -0.0314117 1.17659
0.0 7.49855 -0.426182
0.0 0.0 -1.01659
julia> ng = conj.(FiniteDifferences.j′vp(_fdm, exp ∘ Symmetric, seed, collect(A)))[1]
3×3 Array{Float64,2}:
-2.11602 -0.0314117 1.17659
-1.02408e-15 7.49855 -0.426182
-1.02408e-15 -1.02408e-15 -1.01659 |
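The relationship between the two results in the first example (exp applied directly to the Symmetric matrix) can be checked from the printed numbers alone. The sketch below is only a reading of that output, not a statement about FiniteDifferences internals: `fold_to_symmetric` is a hypothetical helper, and the values are the rounded ones copied from the transcript. It suggests the finite-difference result matches Zygote's unconstrained adjoint with each off-diagonal entry added to its mirror image, rather than averaged with it.

```julia
using LinearAlgebra

# Zygote's unconstrained adjoint, copied from the first example above.
zg = [-2.11602   0.648903   0.634618;
      -0.680315  7.49855    0.822502;
       0.54197  -1.24868   -1.01659]

# FiniteDifferences result for the same seed, copied from the first example above.
ng = [-2.11602   -0.0314117  1.17659;
      -0.0314117  7.49855   -0.426182;
       1.17659   -0.426182  -1.01659]

# Hypothetical helper: sum each off-diagonal entry with its mirror image,
# leaving the diagonal untouched.
fold_to_symmetric(G) = G + G' - Diagonal(G)

# Summing the mirrored entries reproduces the finite-difference result
# to within the rounding of the printed values.
println(isapprox(fold_to_symmetric(zg), ng; atol=1e-4))   # true

# Averaging the mirrored entries does not.
println(isapprox((zg + zg') / 2, ng; atol=1e-4))          # false
```

The upper triangle of the folded matrix also matches the output of the second example, where the adjoint is pulled back to a plain Array.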
right, so do you think that means all is well? |
So far so good: https://gist.github.com/sethaxen/fa67e541c4a2a5e773b475349ed87fb9. Still a couple of edge cases to check. |
This came up in #366.
And it has proved tricky to debug.
AFAICT there were two issues: one from ChainRules directly (I forgot to put the adjoint for abs(::Complex) back in, fixed now), and a second that I think is an error in the actual definition of the power series adjoints.
That second one is what this issue is about.
It is a subtle thing that gradcheck does not catch, because gradcheck sums things, which results in the errors canceling, because they are almost symmetric (I think they are symmetric up to floating point errors).
But something in the ChainRules PR breaks the symmetry just enough to register.
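To make this kind of blind spot concrete, here is a minimal sketch. It assumes the check reduces the output with a plain sum, so the pullback is only ever probed with an all-ones seed. `pullback_bug` is a deliberately wrong, hypothetical pullback for f(A) = A * A invented for illustration; it is not the actual error in the power series adjoints.

```julia
using LinearAlgebra

A = [1.0 2.0; 3.0 4.0]

# Correct pullback of f(A) = A * A for real matrices.
pullback_ok(Δ) = Δ * A' + A' * Δ

# Hypothetical buggy pullback: Δ is wrongly transposed in the first term.
pullback_bug(Δ) = Δ' * A' + A' * Δ

# Seed seen by a check that sums every entry of the output.
sum_seed = ones(2, 2)

# An asymmetric seed, as a full finite-difference test would also try.
asym_seed = [0.0 1.0; 0.0 0.0]

# The all-ones seed is symmetric, so the bug is invisible to it.
println(pullback_ok(sum_seed) ≈ pullback_bug(sum_seed))     # true

# The asymmetric seed exposes the bug.
println(pullback_ok(asym_seed) ≈ pullback_bug(asym_seed))   # false
```

Any error that vanishes on symmetric seeds passes the first comparison, which is why testing against arbitrary seeds is the more complete check.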
Further testing by using FiniteDifferences to look at the whole gradient of A^p, as in https://gist.github.com/oxinabox/fad7bae6dc11a7f0de31eff8666656cd, shows Zygote#master giving identical outputs to the ChainRules branch.
(Which means the symmetry break is not in the actual output we are testing but somewhere else)
(However, I am not 100% trusting of FiniteDifferences with complex numbers right now. In particular, as of today, the JuliaDiff/FiniteDifferences.jl#76 branch gets the answer wrong in the same way as shown in JuliaDiff/FiniteDifferences.jl#76 (comment).)
More broadly, this is a great example of things that can go wrong with Zygote's current method of gradient checks.
Moving forward, ChainRules has ChainRulesTestUtils,
which uses more complete finite difference based tests.
And adjoints implemented based on ChainRulesCore can use that.
It is probably still worth fixing Zygote's grad-tests.
Quoting relevant bits from #366
@oxinabox
@sethaxen
@antoine-levitt