Gradient over implicit parameters returns nothing #692
Closed · cossio opened this issue Jun 17, 2020 · 13 comments

@cossio (Contributor) commented Jun 17, 2020

I have encountered this issue several times. This is the smallest example I was able to find to reproduce it.

using Flux, Zygote
using Zygote: @adjoint
struct S
    W::Array{Float64}
end
Flux.@functor S
s = S(randn(4,4))
ps = params(s)
fun(s::S) = sum(s.W)
# Custom adjoint for fun: returns a structural (NamedTuple) gradient for S.
@adjoint function fun(s::S)
    fun(s), Δ -> ((; W = similar(s.W) .= Δ),)
end
gs = gradient(ps) do
    fun(s)
end
gs[s.W] # nothing
gradient(w -> fun(S(w)), randn(2,2)) # correct gradient

I noticed that gs is storing the correct gradients for W under another key, which equals :(Main.s), but I'm not even sure what that key is and I cannot access it.

julia> gs.grads
IdDict{Any,Any} with 2 entries:
  [-0.453576 0.131353 0.0619522 0.126699; -0.172607 0.306845 0.522566 1.4498; 0.82781 -0.222564 -0.104318 0.0206807; -0.47… => nothing
  :(Main.s)  => (W = [1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0],)

But the correct key s.W is populated with nothing, which is wrong.

What is going on here?

@cossio changed the title from "gradients are nothing" to "Grads stores gradients in wrong key" on Jun 17, 2020
@cossio changed the title from "Grads stores gradients in wrong key" to "Grads stores gradients in wrong key, and correct key is populated with nothing" on Jun 17, 2020
@cossio (Contributor, Author) commented Jun 18, 2020

If I remove the custom adjoint, it works fine. So it has to be something related to the interaction between params and a custom @adjoint.

struct S
    W::Array{Float64}
end
Flux.@functor S
s = S(randn(4,4))
ps = params(s)
foo(s::S) = sum(s.W)
gs = gradient(ps) do
    foo(s)
end
gs[s.W] # correct gradient

@cossio changed the title from "Grads stores gradients in wrong key, and correct key is populated with nothing" to "Gradient over implicit parameters returns nothing" on Jun 18, 2020
@cossio (Contributor, Author) commented Jun 18, 2020

A workaround is to write an intermediate function that takes only array inputs:

struct S
    W::Array{Float64}
end
Flux.@functor S
s = S(randn(4,4))
ps = params(s)
fff(s::S) = _fff(s.W)
_fff(w) = sum(sin.(w))
# Custom adjoint defined on the array-level helper rather than on S itself.
@adjoint function _fff(w)
    _fff(w), Δ -> (similar(w) .= Δ .* cos.(w),)
end
gs = gradient(ps) do
    fff(s)
end
gs[s.W] # correct gradients
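
As a quick sanity check (an illustrative addition, not from the original comment; it assumes the definitions above are still in scope), the custom adjoint should reproduce the analytic derivative of sum(sin.(w)), which is cos.(w):

w = randn(3, 3)
gradient(_fff, w)[1] ≈ cos.(w)  # true: the custom adjoint matches cos.(w)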

@cossio (Contributor, Author) commented Jun 19, 2020

@MikeInnes Any idea what is happening here? This issue is producing wrong gradients silently, and it took me a while just to figure out the bug originated in Zygote.

@ToucheSir (Member) commented Nov 4, 2021

I don't believe Zygote can track implicit params usage in adjoint functions (by design I assume, otherwise there'd be no way to avoid AD in custom adjoints). So if s.W doesn't show up in a place that the AD has visibility over, it won't have a gradient. Is there any reason you can't work with a structural gradient for s?

@darsnack (Member) commented Nov 5, 2021

To clarify a bit more, your custom adjoint means that the computation graph the AD system works with looks like:

s -> fun(s) -> output

In other words, it never "sees" the array s.W. The intermediate function avoids this by:

s -> getproperty(s, :W) -> _fff(w) -> output

Here, Zygote uses your custom adjoint for the gradient w.r.t. w in _fff(w); then the call to getproperty(s, :W) is where the AD "sees" s.W as an implicit array and accumulates your custom adjoint's output into it.

Similarly, without the custom adjoint, the AD has

s -> getproperty(s, :W) -> sum(w) -> output

Here, sum(w) plays a role similar to _fff(w). Basically, the key is that the AD actually "sees" the array returned by getproperty.
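
A minimal sketch of the difference (hypothetical, not from the thread; it assumes s, ps, fun with its struct-level adjoint, and fff/_fff from the comments above are all in scope):

# Struct-level adjoint: the AD never executes getproperty, so s.W is
# invisible to it and the implicit param gets nothing.
_, back1 = Zygote.pullback(() -> fun(s), ps)
back1(1.0)[s.W]  # nothing

# Array-level adjoint: getproperty(s, :W) runs inside traced code, so
# the custom gradient is accumulated onto s.W.
_, back2 = Zygote.pullback(() -> fff(s), ps)
back2(1.0)[s.W]  # ≈ cos.(s.W)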

@darsnack closed this as completed on Nov 6, 2021
@cossio (Contributor, Author) commented Nov 6, 2021

@darsnack You consider this resolved?

@darsnack (Member) commented Nov 6, 2021

Resolved == can't fix? (anyone feel free to reopen in case I'm wrong)

Yes, my understanding is that this is by design for adjoints. Writing a custom rule forces the AD to look away, and I don't think we would merge a change that breaks that fundamental assumption.

The only alternative fix I see on Zygote's end would be to do post-pullback accumulation into implicit params, since Zygote computes the structural gradient anyway. This would require recursively traversing all the values. Maybe @mcabbott can comment on the correctness/feasibility of this.
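
Very roughly, that traversal might look like the following (a hypothetical, untested sketch of the idea; accum_implicit! is a made-up name, not an actual Zygote API):

# Hypothetical helper: after the pullback has produced a structural
# gradient dx for a value x, walk the two in parallel and copy array
# gradients into the implicit-params IdDict, accumulating on repeat visits.
function accum_implicit!(grads::IdDict, x, dx)
    dx === nothing && return grads
    if x isa AbstractArray{<:Number}
        old = get(grads, x, nothing)
        grads[x] = old === nothing ? dx : old .+ dx
    elseif dx isa NamedTuple
        for f in propertynames(dx)
            accum_implicit!(grads, getfield(x, f), getfield(dx, f))
        end
    end
    return grads
end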

The recommended fixes here are to:

  • Write the code with indirection like _fff so that the struct serves as syntactic sugar in your program.
  • Use structural gradients, meaning take the gradient w.r.t. s directly. This will give you the Main.s result as the gradient.

@darsnack (Member) commented Nov 6, 2021

Structural gradient for reference:

gradient(fun, S(rand(2, 2)))

@cossio (Contributor, Author) commented Nov 6, 2021

OK, thanks! I will also put this example from @ToucheSir (posted on Slack) here for future reference.

using Zygote: @adjoint, gradient

struct S
    W::Array{Float64}
end
s = S(randn(4,4))
fun(s::S) = sum(s.W)
@adjoint function fun(s::S)
    fun(s), Δ -> ((; W = similar(s.W) .= Δ),)
end

julia> gs = gradient(s) do s
           fun(s)
       end
((W = [1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0],),)

@cossio (Contributor, Author) commented Nov 7, 2021

@darsnack One more question. Is there an alternative way I could have written the above explicit adjoint so that this example works fine?

@ChrisRackauckas (Member) commented

> I don't believe Zygote can track implicit params usage in adjoint functions (by design I assume, otherwise there'd be no way to avoid AD in custom adjoints). So if s.W doesn't show up in a place that the AD has visibility over, it won't have a gradient. Is there any reason you can't work with a structural gradient for s?

I would go even further and say Zygote should remove the implicit parameter system, or at least Flux should. It seemed like a good idea, but three years later I think we've all learned that it only causes pain. The main gain was syntactic sugar, but the system underlying it was never really that solid. This is just one of many unsolvable issues that arise from it; others are performance- or compile-time-related, along with other weird correctness edge cases. Instead, we should all explore different ways to give explicit parameters similarly nice syntax, and that would be the best of all worlds.

@DhairyaLGandhi (Member) commented

We have the explicit form, and that would be good to use. Elsewhere, implicit gradients are tracked through the same rules as explicit ones; there's no design constraint that makes one work and the other not.

@ToucheSir (Member) commented

@ChrisRackauckas I would love nothing more, but unfortunately there hasn't been a big push behind figuring out how to bring explicit params to parity, let alone a migration plan. This includes non-syntactic issues such as how to do tied weights and how to exclude certain params from optimization. I myself have at least a couple of pages of design notes on various aspects/challenges, and looking at what others are doing, it's clear this is not a trivial task!

Anyhow, I just created a tracking project at https://github.com/orgs/FluxML/projects/2. Please add new issues/tasks as you encounter them; it would be great to record all this disparate discussion about implicit vs. explicit params in one place.
