Parameter Sharing breaks destructure #1767

Closed · avik-pal (Member) opened this issue Nov 17, 2021 · 4 comments · Fixed by #1901

MWE:

using Flux

struct Model{A}
    a::A
    b::A
end

Flux.@functor Model

(m::Model)(x) = m.a(x) .+ m.b(x)

d = Dense(1, 1)
x = rand(Float32, 1, 1)

# Sharing the parameters
model = Model(d, d)

# Works
Flux.gradient(() -> sum(model(x)), Flux.params(model)).grads

p, re = Flux.destructure(model)

# Fails
Flux.gradient(p -> sum(re(p)(x)), p).grads

Stacktrace:

┌ Warning: Expected 2 params, got 3
└ @ Flux ~/.julia/packages/Flux/BPPNj/src/utils.jl:647
ERROR: DimensionMismatch("variable with size(x) == (2,) cannot have a gradient with size(dx) == (3,)")
Stacktrace:
 [1] (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}})(dx::Vector{Float32})
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/7ZiwT/src/projection.jl:226
 [2] _project
   @ ~/.julia/packages/Zygote/AlLTp/src/compiler/chainrules.jl:182 [inlined]
 [3] map(f::typeof(Zygote._project), t::Tuple{Vector{Float32}}, s::Tuple{Vector{Float32}})
   @ Base ./tuple.jl:246
 [4] gradient(f::Function, args::Vector{Float32})
   @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface.jl:77
 [5] top-level scope
   @ REPL[10]:2
 [6] top-level scope
   @ ~/.julia/packages/CUDA/2C5YQ/src/initialization.jl:52
avik-pal changed the title from "Parameter Sharing breaks with destructure" to "Parameter Sharing breaks destructure" on Nov 17, 2021.

ToucheSir (Member) commented Nov 18, 2021:

The key line is https://github.com/FluxML/Flux.jl/blob/master/src/utils.jl#L649. Because Zygote is blissfully unaware of the tying, it will return a separate gradient for each layer. However, the gradients for the biases are Fills, and FillArrays are bits types, so they hash by content rather than by address in an IdDict; thus only 3 params (model.a.weight, model.a.bias and model.b.weight) are retained by fmap.
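
A minimal sketch of that deduplication (the gradient values are made up; the point is the === and IdDict behaviour):

using FillArrays

g_bias_a = Fill(1.0f0, 1)  # stand-in for the gradient of model.a.bias
g_bias_b = Fill(1.0f0, 1)  # stand-in for the gradient of model.b.bias

g_bias_a === g_bias_b      # true: Fill is an immutable bits type, so two Fills
                           # with the same value and axes are indistinguishable

cache = IdDict()
cache[g_bias_a] = true
haskey(cache, g_bias_b)    # true: an IdDict-backed cache (like fmap's) treats
                           # the second bias gradient as already seen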

Fixing this would require a few things. First, passing some additional metadata (e.g. the offset of each param) to _restructure for alias tracking. Second, excluding types which calculate objectid based on value (i.e. non-mutable types) from caching in fmap [1]. And third, accumulating gradients for tied parameters in _restructure (a sketch follows the footnote below). This would ideally be handled by the AD, but because it has a custom @adjoint we've effectively opted out of that assistance.

Footnotes

[1] I've experimented with this as part of (the very experimental) https://github.com/FluxML/Functors.jl/pull/27, but it shouldn't be hard to add to Functors proper.
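
A hypothetical sketch of the accumulation step (the helper name and the offsets/lengths bookkeeping are illustrative, not Flux's actual _restructure code): each leaf gradient is added into its slice of the flat vector, so tied parameters, which share an offset, have their contributions summed.

# Hypothetical helper, not Flux's implementation: accumulate leaf gradients
# into a flat vector of length n, given each leaf's offset and length in p.
function flat_gradient(grads, offsets, lengths, n)
    dp = zeros(Float32, n)
    for (g, off, len) in zip(grads, offsets, lengths)
        g === nothing && continue      # leaves with no gradient contribute nothing
        dp[off+1:off+len] .+= vec(g)   # .+= (not =) is what makes tying work
    end
    return dp
end

# For the MWE above, model.a and model.b would share offsets, so the weight and
# bias gradients from both uses of the Dense layer are summed into p's 2 slots.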

mcabbott (Member) commented:
> the gradients for the biases are Fills and FillArrays are bits types

Note that the problem is worse than this. Even with dense arrays, it's easy for two parameters to get the same gradient: e.g. if they enter the loss as f(x + y), then the same array will be used as the gradient for both. So you can get gradients that are === even when all the parameters are distinct.

Conversely, with shared parameters in the model, the present structure of fmap means that it never visits the later ones, which is wrong for gradients. The gradients from different occurrences of x in the loss need to be added.
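
A small sketch of that caching behaviour, using Functors' fmap directly (the toy NamedTuple stands in for a model with a tied leaf):

using Functors

shared = Float32[1.0, 2.0]
tied = (a = shared, b = shared)   # the same array appears as two leaves

visits = Ref(0)
fmap(x -> (visits[] += 1; x), tied)
visits[]   # 1: the cache keys on object identity, so the second occurrence is
           # never visited -- fine for flattening parameters, but wrong for a
           # gradient walk, where both contributions should be added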

ToucheSir (Member) commented:
Exactly. There's no getting around either closing over the original structure or creating a new auxiliary one for use in co-iterating over the gradients and determining what goes where.

mcabbott (Member) commented:
Maybe also worth noting that the notion of sharing in Functors is a === b, where these are leaf-like AbstractArray{<:Number}s. So it will not notice that W and W' are the same data.
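
For example (just an illustration of the === check on leaves):

W = rand(Float32, 2, 2)
Wt = W'                  # Adjoint wrapper around the same storage

W === Wt                 # false: a different object (and type), so not "shared"
parent(Wt) === W         # true:  even though they alias the same data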
