Deprecate Flux.Optimisers and implicit parameters in favour of Optimisers.jl and explicit parameters #1986

Open
1 of 8 tasks
CarloLucibello opened this issue Jun 2, 2022 · 15 comments

Comments

@CarloLucibello
Member

CarloLucibello commented Jun 2, 2022

All the pieces are now in place, so we can move from the current pattern using implicit params:

using Flux
ps = Flux.params(model)
opt = Flux.Optimise.ADAM()
gs = gradient(() -> loss(model(x), y), ps)
Flux.Optimise.update!(opt, ps, gs)

to the one using explicit parameters and Optimisers.jl:

using Flux, Optimisers
opt_state = Optimisers.setup(Optimisers.Adam(), model)
∇model = gradient(m -> loss(m(x), y), model)[1]
opt_state, model = Optimisers.update!(opt_state, model, ∇model)
## or the non-mutating
# opt_state, model = Optimisers.update(opt_state, model, ∇model)
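For readers porting code, here is a minimal end-to-end sketch of the explicit style above, assuming a made-up toy regression problem (`x`, `y`, `loss` and the helper `fit` are illustrative names, not part of either package). The gradient is taken with respect to the model itself, and the updated state and model are threaded back out of `Optimisers.update!`.

```julia
using Flux, Optimisers

# Made-up toy regression data (illustration only): the target is x₁ + x₂.
x = rand(Float32, 2, 64)
y = sum(x; dims = 1)
loss(m) = Flux.Losses.mse(m(x), y)

# Hypothetical helper: an explicit-parameter training loop. No Params object and
# no zero-argument closure; the gradient is taken with respect to `model`.
function fit(model, opt_state; epochs = 200)
    for _ in 1:epochs
        ∇model = gradient(loss, model)[1]
        opt_state, model = Optimisers.update!(opt_state, model, ∇model)
    end
    return model, opt_state
end

model = Dense(2 => 1)
opt_state = Optimisers.setup(Optimisers.Adam(1f-2), model)
loss_before = loss(model)
model, opt_state = fit(model, opt_state)
```

Wrapping the loop in a function avoids Julia's global soft-scope rules when rebinding `model` and `opt_state` inside a top-level `for` loop.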

Code

julia> gradient(m -> (sum(norm, Flux.params(m))), (x=[1,2.0], y=[3.0]))
(nothing,)
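Until `Flux.params` behaves inside an explicit gradient, one hedged workaround for this particular toy input is to reference each array explicitly in the loss, so the gradient flows through ordinary field access. This is a sketch for the NamedTuple above only, not a general replacement for `params`.

```julia
using Zygote, LinearAlgebra

m = (x = [1, 2.0], y = [3.0])

# Sum the norms by touching each field directly instead of collecting them with params.
g = gradient(m -> norm(m.x) + norm(m.y), m)[1]
```

Here `g` is a NamedTuple mirroring `m`, holding `m.x / norm(m.x)` and `m.y / norm(m.y)`.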

Documentation

Examples

  • Port model zoo examples -- tag "update"
  • Help porting downstream libraries and check there are no surprises
    • GraphNeuralNetworks.jl

@mcabbott @ToucheSir @darsnack feel free to add to this

@ToucheSir
Member

Thoughts on copying or moving parts of https://github.com/FluxML/Flux.jl/blob/v0.13.3/docs/src/training/optimisers.md#scheduling-optimisers over to the Optimisers.jl docs?

@darsnack
Member

darsnack commented Jun 3, 2022

@DrChainsaw
Contributor

I think that this issue might be a blocker for explicit parameters.

For example:

julia> gradient(bn -> sum(bn(randn(Float32, 3,1))), BatchNorm(3))
((λ = nothing, β = Fill(1.0f0, 3), γ = Float32[0.0, 0.0, 0.0], μ = nothing, σ² = nothing, ϵ = -0.0f0, momentum = nothing, affine = nothing, track_stats = nothing, active = nothing, chs = nothing),)

julia> gradient(bn -> sum(bn(randn(Float32, 3,1))), Chain(BatchNorm(3)))
(nothing,)

Not sure if related, but it seems that for some models the call hangs indefinitely (I killed the process after 8 hours) if gradients are missing from one of the layers due to the above. The only 100% reproducer I have for this so far is one generated by https://github.com/DrChainsaw/NaiveGAflux.jl for a particular random seed, but if I manage to narrow it down I'll post an issue. If anyone has any tips on what could cause the stall, it would be very helpful.

@ToucheSir
Member

I wonder if this is encountering a similar issue as TuringLang/Turing.jl#1754. Have you tried -O1?

For troubleshooting, a smaller example that doesn't hang but still takes some time to compile (i.e. over a minute) would be good. Bonus if you can provide an inference/total time breakdown from SnoopCompile.@snoopi_deep.

@DrChainsaw
Contributor

Thanks @ToucheSir. Sorry if this is derailing. I think the root issue with missing gradients is quite important regardless as it impacts vanilla Flux.

Putting the discussion of the hangs in a collapsed section to reduce clutter:

From what I can tell there does not seem to be any poor scaling at work. When printing out the forwards and backwards pass for the call gradient(m -> sum(m(x)), model), it seems like the whole backwards pass through the model is completed before it hangs. Using -O1 does not seem to alleviate anything (although I haven't waited for 8 hours this time).

One hypothesis is that it is the evil loopy structure of the CompGraph, where vertices have references to both inputs and outputs, which somehow causes Zygote to get stuck in some infinite loop looking for the missing gradients. If I strip the graph of the outputs (they are not involved in the forward/backward pass), the gradient call terminates correctly after about 100 seconds (although the vast majority of the time is spent after the backwards pass through the model is completed). Maybe it's just a coincidence from making the model a little bit simpler when stripping the outputs, though.
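To make the described structure concrete, here is a hypothetical stand-in (`Vertex`, `connect!`, `evaluate` and `strip_outputs!` are illustrative names, not NaiveGAflux's actual types or API). The `outputs` back-references make the object graph cyclic even though the forward pass only ever follows `inputs`; stripping them leaves a plain input-only DAG. This sketch does not reproduce the hang; it only shows the shape of the data.

```julia
# Hypothetical CompGraph-like vertex: not NaiveGAflux's real type.
struct Vertex
    op::Function
    inputs::Vector{Vertex}
    outputs::Vector{Vertex}
end
Vertex(op) = Vertex(op, Vertex[], Vertex[])

# Add an edge src -> dst, recording it on both ends.
function connect!(src::Vertex, dst::Vertex)
    push!(dst.inputs, src)
    push!(src.outputs, dst)  # back-reference: creates a cycle in the object graph
    return dst
end

# The forward pass only follows `inputs`; `outputs` are never read.
evaluate(v::Vertex, x) =
    isempty(v.inputs) ? v.op(x) : v.op(map(u -> evaluate(u, x), v.inputs)...)

# Remove the back-references so only input edges remain.
function strip_outputs!(v::Vertex)
    empty!(v.outputs)
    foreach(strip_outputs!, v.inputs)
    return v
end
```

For example, chaining `identity`, `x -> 2x` and `x -> x + 1` gives `7` for input `3`, both before and after `strip_outputs!`.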

@ToucheSir
Member

Thanks. This looks like it deserves its own issue, so I've opened FluxML/Zygote.jl#1243.

@mcabbott
Member

While this isn't a complete solution, one way around #1986 (comment) is to simply make the normalisation layers into immutable structs. The only mutation they need is a marker of whether they are active, and this can be a Ref(true) or something.

Is this worth doing?
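A minimal sketch of the immutable-struct idea, using a hypothetical `ImmutableNorm` layer rather than Flux's own `BatchNorm`. The struct itself is immutable and the only mutable bit is a `Ref{Bool}` flag. As a simplification, the inactive branch just returns its input; a real layer would use running statistics there.

```julia
# Hypothetical immutable normalisation layer (not Flux's BatchNorm).
struct ImmutableNorm{V<:AbstractVector{<:Real}}
    β::V
    γ::V
    ϵ::Float32
    active::Base.RefValue{Bool}  # the only mutable state, held in a Ref
end
ImmutableNorm(n::Integer) =
    ImmutableNorm(zeros(Float32, n), ones(Float32, n), 1f-5, Ref(true))

# Toggle train/test mode by mutating the Ref, not the struct.
setactive!(l::ImmutableNorm, flag::Bool) = (l.active[] = flag; l)

function (l::ImmutableNorm)(x::AbstractMatrix)
    l.active[] || return x  # simplification: stands in for using running stats
    n = size(x, 2)
    μ = sum(x; dims = 2) ./ n
    σ² = sum(abs2, x .- μ; dims = 2) ./ n
    return l.γ .* (x .- μ) ./ sqrt.(σ² .+ l.ϵ) .+ l.β
end
```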

@DhairyaLGandhi
Member

Separating out the configuration as in #1509 is also a solution and gives type stability and inferrability as well

@DrChainsaw
Contributor

Another possibility is to refactor into functions which don't use the mutable struct and then add trivial rrules, something like:

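using Flux, ChainRulesCore

# All of BN's fields as a tuple (structs are not iterable, so `BN...` would not work)
bnfields(BN::BatchNorm) = ntuple(i -> getfield(BN, i), fieldcount(typeof(BN)))

# `batchnorm` is the yet-to-be-written function form of the layer
(BN::BatchNorm)(x) = batchnorm(bnfields(BN)..., x)

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, BN::BatchNorm, x)
    res, back = rrule_via_ad(config, batchnorm, bnfields(BN)..., x)
    function BatchNorm_back(Δ)
        δs = back(Δ)  # (∂batchnorm, ∂field₁, …, ∂fieldₙ, ∂x)
        ∂BN = Tangent{typeof(BN)}(; NamedTuple{fieldnames(typeof(BN))}(δs[2:end-1])...)
        return ∂BN, δs[end]
    end
    return res, BatchNorm_back
end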

Not sure if this has any advantages over the above proposals, especially considering that the refactor looks like it could be a bit messy. I suppose one far-fetched argument is that users might consider it a breaking change to make exported structs immutable.

@ToucheSir
Member

ToucheSir commented Jul 30, 2022

The main sources of type instability in norm layer types are getproperty/setproperty! on mutable structs and control flow in the forward pass. I believe the former includes Refs, so just making the outer type immutable is probably insufficient. What would work is writing rrules that hide said field lookups. #2005 is more or less this for norm layers which don't track stats (note that kwarg default values can introduce control flow—Zygote is real finicky). https://github.com/FluxML/Flux.jl/blob/v0.13.4/src/layers/normalise.jl#L245-L255 is the big remaining culprit, but can be worked around in a similar way. https://github.com/FluxML/Flux.jl/blob/v0.13.4/src/layers/normalise.jl#L258 I suspect can be resolved without any rrules.
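One way to apply the "hide the field lookups behind rrules" idea is sketched below for a hypothetical `FlagLayer` (not a Flux type, and not what #2005 actually does). The `active` flag is read through a helper marked with ChainRulesCore's `@non_differentiable`, so AD calls that rule instead of tracing the mutable-struct `getproperty` and the branch on `nothing`.

```julia
using ChainRulesCore

# Hypothetical mutable layer with an `active` flag, loosely modelled on norm layers.
mutable struct FlagLayer{V<:AbstractVector}
    γ::V
    β::V
    active::Union{Bool,Nothing}  # `nothing` means "decide automatically" (treated as active here)
end

# The flag lookup and the `nothing` check live in a helper that AD is told to skip.
_isactive(l::FlagLayer) = something(l.active, true)
ChainRulesCore.@non_differentiable _isactive(::Any)

function (l::FlagLayer)(x::AbstractMatrix)
    _isactive(l) || return x
    return l.γ .* x .+ l.β
end
```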

@ToucheSir
Member

@DrChainsaw saw JuliaLang/julia#44185 today while catching up on JuliaCon talks. Are you willing to try it out on your side and seeing what the profiler spits out?

@DrChainsaw
Contributor

@ToucheSir wow, that looks like a great QOL improvement in general. I only have access to a Windows environment atm, but I will give it a shot when that changes. I'm a bit worried that it seems to require yielding. Ctrl-C does nothing, so I fear that it is stuck in an infinite loop which does not yield anywhere.

@CarloLucibello
Member Author

@mcabbott Can this be closed as done?

@mcabbott
Member

I think the remaining issue is RNNs not working right with the new explicit style. Have not followed closely but e.g. #2258

@darsnack
Member

And regularization e.g. FluxML/Optimisers.jl#57
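For reference, two ways regularisation can be expressed with explicit parameters today, sketched on made-up data (`train_with`, `l2_loss` and `plain_loss` are hypothetical names). The first writes an L2 penalty directly against the model's fields; the second leaves the loss alone and chains Optimisers.jl's `WeightDecay` in front of `Adam`, which adds a decay term to the gradient before Adam sees it. Neither is presented as the resolution of the linked issue.

```julia
using Flux, Optimisers

# Made-up data for illustration only.
x, y = rand(Float32, 3, 16), rand(Float32, 1, 16)

# Option 1: L2 penalty written directly against the explicit model.
l2_loss(m) = Flux.Losses.mse(m(x), y) + 1f-3 * sum(abs2, m.weight)

# Option 2: plain loss, with weight decay handled by the optimiser rule.
plain_loss(m) = Flux.Losses.mse(m(x), y)

function train_with(loss, rule; steps = 10)
    model = Dense(3 => 1)
    opt_state = Optimisers.setup(rule, model)
    for _ in 1:steps
        ∇model = gradient(loss, model)[1]
        opt_state, model = Optimisers.update!(opt_state, model, ∇model)
    end
    return model
end

m1 = train_with(l2_loss, Optimisers.Adam(1f-3))
m2 = train_with(plain_loss,
    Optimisers.OptimiserChain(Optimisers.WeightDecay(1f-4), Optimisers.Adam(1f-3)))
```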


6 participants