
how to selectively take structural gradient #1042

Open
CarloLucibello opened this issue Jul 26, 2021 · 15 comments

@CarloLucibello
Member

CarloLucibello commented Jul 26, 2021

In Flux, we typically apply @functor to a type for 2 purposes:

  1. for recursively traversing structs and mapping leaves, as done by gpu
  2. collecting parameters in a Zygote.Params for gradient calculation (this is done by Flux.params(model)).
    When we want to distinguish the two behaviors, we use Flux.trainable for the parameter collection.

Here is an example:
using Flux, Zygote
using Flux: @functor

struct B
   b1::Array
   b2::Array
end
@functor B

struct A
   a1::Array
   eps::Number
   b::B
end
@functor A
Flux.trainable(a::A) = (a.a1,)

a = A(rand(3),0.1,B(rand(2), rand(2)))

Flux.params(a) 
#Params([[0.2755365528802143, 0.7419122552485184, 0.048976872406773175]])

loss(a) = a.eps + sum(a.a1) + sum(a.b.b1)

Now when one computes the gradient in the implicit form, supposedly only the gradient with respect to a.a1 should be computed. This does not appear to be exactly true at the moment: every gradient seems to be computed, but at least only the one with respect to a.a1 is exposed:

julia> g = gradient(() -> loss(a), Flux.params(a))
Grads(...)

julia> g[a.a1]
3-element Fill{Float64}: entries equal to 1.0

julia> g[a.b.b1]
ERROR: KeyError: key [0.7037661100448469, 0.34941543792301455] not found
Stacktrace:
 [1] getindex
   @ ./iddict.jl:93 [inlined]
 [2] getindex(gs::Zygote.Grads, x::Vector{Float64})
   @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:279
 [3] top-level scope
   @ REPL[42]:1
 [4] top-level scope
   @ ~/.julia/packages/CUDA/lwSps/src/initialization.jl:52

julia> g = gradient(() -> loss(a), Flux.params(a)).grads
IdDict{Any, Any} with 2 entries:
  [0.275537, 0.741912, 0.0489769] => 3-element Fill{Float64}: entries equal to 1.0
  :(Main.a)                       => (a1 = nothing, eps = 1.0, b = (b1 = 2-element Fill{Float64}: entries equal to 1.0, b2 = nothing))

With the explicit gradient instead, everything is computed and exposed:

julia> gradient(a -> loss(a), a)
((a1 = 3-element Fill{Float64}: entries equal to 1.0, eps = 1.0, b = (b1 = 2-element Fill{Float64}: entries equal to 1.0, b2 = nothing)),)

This is bad, since we would like to feed this to an update! function, and it is also inefficient. How do we tell Zygote to drop some model parts from the gradient computation? I would like the following:

julia> gradient(a -> loss(a), a)
((a1 = 3-element Fill{Float64}: entries equal to 1.0),)

I see two possibilities:

  • we make gradient @functor/trainable aware
  • we pass to gradient a keyword argument for the gradient masking
@DhairyaLGandhi
Member

What do you mean by gradient masking here?

Also, this is needed because the gradient for an argument may not be explicitly asked for, yet still be required to compute the gradient of a different argument. Forcing it to nothing would still work with accum but give incorrect results.

@ToucheSir
Member

Preventing updates during optimization could be accomplished with a helper like https://optax.readthedocs.io/en/latest/api.html?highlight=mask#optax.masked. That said, it doesn't account for the scenario where you want to save memory by not holding onto gradients for certain parameters that won't be updated.
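
For reference, a rough Julia analogue of such a mask might look like the sketch below. Note that masked_update! and the mask predicate are hypothetical names, not an existing Flux or Optimisers API, and that this only skips updates; it does not avoid computing the gradients in the first place.

using Flux

# Hypothetical helper, loosely analogous to optax.masked: apply the optimiser
# step only to the parameters that the mask predicate marks as trainable.
function masked_update!(opt, ps::Flux.Params, gs, mask)
    for p in ps
        mask(p) || continue                # frozen parameter: skip the update
        haskey(gs, p) || continue          # no gradient entry for this parameter
        g = gs[p]
        g === nothing && continue
        Flux.Optimise.update!(opt, p, g)   # standard in-place update of one array
    end
end

# With the example from this issue, updating only a.a1:
# masked_update!(Descent(0.1), Flux.params(a), g, p -> p === a.a1)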

@mcabbott
Member

Functors.trainable and ChainRulesCore.ProjectTo have quite a bit in common; it's possible they should get to know each other better. I'm not precisely sure why this doesn't work today, but with #1044 it might go something like:

julia> import Zygote.ChainRulesCore: ProjectTo

julia> ProjectTo(a::A) = dA::NamedTuple -> ((; a1 = dA.a1, eps=nothing, b=nothing),);

julia> gradient(loss, a)
((a1 = Fill(1.0, 3), eps = 1.0, b = (b1 = Fill(1.0, 2), b2 = nothing)),)
((a1 = Fill(1.0, 3), eps = nothing, b = nothing),)  # is what I hoped for

@ToucheSir
Member

ToucheSir commented Jul 30, 2021

The wrench in the works is that Functors doesn't have a trainable method and isn't even involved when taking explicit gradients. Perhaps it could take a dep on ChainRulesCore?

@darsnack
Member

> I'm not precisely sure why this doesn't work today

Is it because _project is only defined for Numeric?

@mcabbott
Member

Oh right, thanks both. I guess there are many details of this union I don't see yet. But Functors/Optimisers are interested in AD with nested structs, and there might be a nice ChainRules-level way to encode things.

@darsnack
Member

darsnack commented Jul 30, 2021

There are two pieces here: (1) not updating non-trainable parameters, and (2) not computing gradients for non-trainable parameters.

For (1), Optimisers.jl uses Functors.jl to walk over the structure and the nested gradient tuple to apply updates. Thanks to FluxML/Functors.jl#14, we can now limit that walk to the parameters defined by trainable. I think that pretty much takes care of (1).
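
As a rough illustration of that walk (a minimal sketch, not the actual Optimisers.jl code; drop_untrainable is a made-up name), one could prune a structural gradient against Flux.trainable like this:

using Flux, Functors

# Walk the model and its structural gradient in parallel, replacing the
# gradient of every non-trainable child with `nothing`.
function drop_untrainable(model, grad)
    grad === nothing && return nothing
    children = Functors.functor(model)[1]   # NamedTuple of children, or () for leaves
    isempty(children) && return grad        # leaf (e.g. an Array): keep its gradient
    tr = Flux.trainable(model)              # the subset marked as trainable
    kept = map(keys(children)) do k
        child = children[k]
        any(t -> t === child, tr) ? drop_untrainable(child, grad[k]) : nothing
    end
    return NamedTuple{keys(children)}(kept)
end

# With the A/B example from this issue:
# drop_untrainable(a, gradient(loss, a)[1])
# (a1 = Fill(1.0, 3), eps = nothing, b = nothing)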

For (2), if f outputs a Foo and df operates on all the fields of dFoo, then I don't think you can selectively drop gradients for any fields. A more concrete example:

function make_model(some_hyperparams...)
    # do stuff with hyper-params to make W and b
    return Dense(W, b) # let's suppose only W is trainable
end

gradient(ps -> loss(make_model(ps...)(x)), ps)

Here it wouldn't make sense for the pullback of (::Dense)(x) to drop the gradients w.r.t. b. We'd basically need something like ProjectTo but dynamic to each gradient call.

Also, if this is somewhere in the middle of the computation, then I would hope the memory gets reused once that unnecessary gradient is no longer needed by the following pullbacks. I think this is really a concern only for the inputs to the full computation.

@ToucheSir
Member

This is the blessing and curse of Zygote supporting differentiation of arbitrary structs. AFAIK, there is no way to provide it additional information about which fields should be accumulated into the final gradient tuple and which can be omitted (excepting intermediate calculations which require them). I'm not sure what a general solution for this would look like; could we make use of ChainRulesCore.Tangent somehow?

@darsnack
Member

> blessing and curse of Zygote supporting differentiation of arbitrary structs

Right, PyTorch autograd's requires_grad does (2), but it also prevents PyTorch's layers from being as flexible as ours. I feel like any general solution needs to be non-static, meaning that the masking info is introduced at the gradient call.

@ToucheSir
Member

ToucheSir commented Jul 30, 2021

Yup. Now if we had a function like trainable that returned a set of property/field names instead, I wonder if we could dynamically generate tangents with only those fields when in AD. ref. https://juliadiff.org/ChainRulesCore.jl/stable/converting_zygoterules.html
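
Something along those lines might look like the following purely hypothetical sketch (trainable_fields and restricted_tangent are made-up names, and it glosses over how Zygote would actually hook into this):

using ChainRulesCore

# Hypothetical: field names (rather than values) that should receive gradients.
trainable_fields(::Type) = ()
trainable_fields(::Type{A}) = (:a1,)   # for the struct A from this issue

# Build a Tangent carrying only those fields; everything else is implicitly zero.
function restricted_tangent(x::T, dx::NamedTuple) where {T}
    names = trainable_fields(T)
    vals = map(n -> getfield(dx, n), names)
    return Tangent{T}(; NamedTuple{names}(vals)...)
end

# restricted_tangent(a, (a1 = ones(3), eps = 1.0, b = (b1 = ones(2), b2 = nothing)))
# should give a Tangent{A} containing only the a1 component.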

@ToucheSir
Member

Another (possibly complementary) approach more in line with requires_grad would be some kind of wrapper type that instructs Zygote to always insert nothing when creating the gradient tuple. This of course has all the issues commonly associated with array wrappers.
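
For concreteness, one variant of that idea as a sketch (NoGrad and unwrap are made-up names, and this relies on an explicit unwrap call with a gradient-dropping rrule rather than real Zygote integration):

using ChainRulesCore

# A wrapper marking an array as "no gradient wanted".
struct NoGrad{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    data::A
end
Base.size(x::NoGrad) = size(x.data)
Base.getindex(x::NoGrad, i...) = getindex(x.data, i...)

# Unwrapping is where the gradient gets discarded: the pullback returns
# NoTangent() for the wrapped array, so nothing is ever accumulated for it.
unwrap(x) = x
unwrap(x::NoGrad) = x.data
function ChainRulesCore.rrule(::typeof(unwrap), x::NoGrad)
    nograd_pullback(dy) = (NoTangent(), NoTangent())
    return x.data, nograd_pullback
end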

@oschulz

oschulz commented Aug 16, 2021

> supposedly only the gradient with respect to a.a1 should be computed. This does not appear to be exactly true at the moment: every gradient seems to be computed

Maybe #966 could bring some improvements in that regard.

@ToucheSir
Member

For my own edification, do thunks help with deeply nested struct or tangent fields? I can wrap my head around how an entire argument might be excluded from evaluation, but not a piece of one.

@oschulz

oschulz commented Aug 16, 2021

> For my own edification, do thunks help with deeply nested struct or tangent fields?

I have to admit I'm not entirely sure myself whether it will be possible to make the pullback(s) for the struct creation smart enough.

Maybe we will need some kind of hinting procedure at some point, so the user can specify what quantities they want the gradient for, like Enzyme has.

@oxinabox
Member

@willtebbutt
