
Add support for AD backends and explicit optimizers #2083

Closed

Conversation

@darsnack (Member) commented Oct 14, 2022

This is another approach in a similar vein to #2029 and #2082. The primary goal of this PR is to focus on the use of alternate AD backends, since this is necessary for explicit-mode training. Note that while this PR does add explicit-mode support to train!, it does not tackle transitioning the optimizers to using only Optimisers.jl. As such, this PR could be merged and #2082 put on top of it (so it is not a complete replacement for the other PRs).

The changes allow Flux to be used with explicit mode by passing Flux.Optimise.ZygoteExplicitBackend() to train!:

using Flux, Optimisers

model = Chain(Dense(2 => 1, relu), Dense(1 => 1))
opt = Optimisers.Momentum(1e0)
optstate = Optimisers.setup(opt, model)
ad = Flux.Optimise.ZygoteExplicitBackend()

# `data` is assumed to be an iterator of (x, y) batches, e.g. a DataLoader.
for epoch in 1:10
    @info "Epoch $epoch"
    optstate, model = Flux.train!(ad, model, data, optstate) do model, x, y
        Flux.Losses.logitbinarycrossentropy(model(x), y)
    end
end

From the user's perspective, the approach taken here is to explicitly (pun not intended) require the correct things to be passed to train! (the AD backend and the optimizer state tree). If there is a mismatch, an error is thrown. The default backend is Zygote's implicit mode. Since AbstractDifferentiation.jl already has backends for other ADs, a user can load the corresponding AD package and run the code they want:

using Flux, Optimisers, Tracker
import AbstractDifferentiation as AD

model = Chain(Dense(2 => 1, relu), Dense(1 => 1))
opt = Optimisers.Momentum(1e0)
optstate = Optimisers.setup(opt, model)
ad = AD.TrackerBackend()

# As above, `data` is assumed to be an iterator of (x, y) batches.
for epoch in 1:10
    @info "Epoch $epoch"
    optstate, model = Flux.train!(ad, model, data, optstate) do model, x, y
        Flux.Losses.logitbinarycrossentropy(model(x), y)
    end
end

The main change here is to stop using Zygote.gradient and use AD.gradient (where AD === AbstractDifferentiation) instead. Even though AbstractDifferentiation.jl supports Zygote, it really only supports the explicit mode, as a special case of ChainRules.jl-compatible reverse-mode ADs. This PR wraps that and instead does the following:

  1. Define ZygoteImplicitBackend and ZygoteExplicitBackend as AbstractDifferentiation backends with the appropriate primitives defined. Both wrap AD.ZygoteBackend, but having a distinct type for each mode allows Flux to specialize dispatch (a rough sketch of these definitions follows this list).
  2. Define AD.gradient for both of the above to get around an AD failure where plain Zygote succeeds (JuliaDiff/AbstractDifferentiation.jl#63 (comment)).
  3. Add Optimisers.update and Optimisers.update! methods for Flux.AbstractOptimiser (removing the old Flux.update!).
  4. Make minor changes to train! so it uses everything above correctly, namely constructing the loss correctly for implicit vs. explicit gradients.
  5. Define the following for Tracker.jl support:
     # this works around AD.TrackerBackend only supporting vectors of params
     AD.gradient(::AD.TrackerBackend, f, xs...) = Tracker.withgradient(f, xs...).grad
     This could be submitted as a PR to AbstractDifferentiation.jl (or Tracker.jl could officially adopt AbstractDifferentiation.jl). I tried adding it directly to Flux, but I seemed to run into issues with Tracker being @require-d in AbstractDifferentiation.jl.
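For reference, items 1 and 2 could look roughly like the sketch below. The type names match the description above, but the bodies are illustrative rather than the actual PR diff; it assumes Zygote is loaded so that AD.ZygoteBackend is available.

import AbstractDifferentiation as AD
using Zygote

# Item 1: thin wrappers around AD.ZygoteBackend so Flux can dispatch on the Zygote mode.
struct ZygoteImplicitBackend{T} <: AD.AbstractReverseMode
    core_backend::T
end
ZygoteImplicitBackend() = ZygoteImplicitBackend(AD.ZygoteBackend())

struct ZygoteExplicitBackend{T} <: AD.AbstractReverseMode
    core_backend::T
end
ZygoteExplicitBackend() = ZygoteExplicitBackend(AD.ZygoteBackend())

# Item 2: route gradient calls straight to Zygote instead of going through
# AbstractDifferentiation's own primitives.
AD.gradient(::ZygoteImplicitBackend, f, x::Zygote.Params) = Zygote.gradient(f, x)
AD.gradient(::ZygoteExplicitBackend, f, xs...) = Zygote.gradient(f, xs...)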

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable

@darsnack marked this pull request as draft on October 14, 2022
@darsnack (Member Author)

My main motivation here was to make the AD-agnostic piece of the puzzle simpler. One concern here is AbstractDifferentiation.jl needs more time in the oven. I think since we provide our own backends for Zygote implicit/explicit, and since we don't actually rely on any code in AbstractDifferentiation.jl to take gradients (see Step 2 above), the main feature that it provides here is a smooth transition for when it is ready.


# this is a hack to get around
# https://github.com/JuliaDiff/AbstractDifferentiation.jl/issues/63#issuecomment-1225959150
AD.gradient(::ZygoteImplicitBackend, f, x::Zygote.Params) = Zygote.gradient(f, x)
Member:
Could this be value_and_gradient to support changes like #2070?

@darsnack (Member Author):

Not quite, because it runs into the issue you mentioned in the link above the code. I could define both gradient and value_and_gradient to essentially block out AbstractDifferentiation until they sort out the primitives issues.
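In addition to the AD.gradient methods sketched earlier, defining the value-and-gradient primitive directly could look roughly like this (an illustrative sketch, not the PR diff; it uses Zygote.withgradient, which returns a named tuple with val and grad fields):

# Sketch: define AD.value_and_gradient directly as well, so AbstractDifferentiation's
# own machinery is bypassed for both primitives.
function AD.value_and_gradient(::ZygoteImplicitBackend, f, x::Zygote.Params)
    res = Zygote.withgradient(f, x)
    return res.val, res.grad
end

function AD.value_and_gradient(::ZygoteExplicitBackend, f, xs...)
    res = Zygote.withgradient(f, xs...)
    return res.val, res.grad
end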

@darsnack (Member Author):

Alternatively, it might make sense to have Flux.gradient and Flux.withgradient methods that default to AD.gradient and AD.value_and_gradient. Right now, Flux.gradient(f, xs...) wouldn't default to ZygoteImplicitBackend; defining our own method would allow us to do this.
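One hypothetical shape for that idea (not part of this PR), as it might appear inside Flux itself, routing Params calls through the implicit backend and everything else through the explicit one:

# Hypothetical sketch: Flux-level entry points that route through the
# AbstractDifferentiation mechanism with a default backend per call style.
gradient(f, ps::Zygote.Params) = AD.gradient(ZygoteImplicitBackend(), f, ps)
gradient(f, xs...) = AD.gradient(ZygoteExplicitBackend(), f, xs...)

withgradient(f, ps::Zygote.Params) = AD.value_and_gradient(ZygoteImplicitBackend(), f, ps)
withgradient(f, xs...) = AD.value_and_gradient(ZygoteExplicitBackend(), f, xs...)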

@lorenzoh marked this pull request as ready for review on October 16, 2022
@darsnack force-pushed the explitcit-abstractdifferentiation branch from f261979 to 37c9759 on November 1, 2022
@mcabbott (Member) commented Nov 1, 2022

I wrote some comments which I should post:

Complicated Flux models currently use gradient and update!, and I think we agree this will still be the case in the future. The role of train! is presently to cover fairly simple uses where you don't want to customise much.

Without train!, at present:

  • To select a different AD, you can use a different package's gradient. You have to load (say) Diffractor, and surely you need to know that Diffractor.gradient or Yota.grad exist -- a module and a function.

  • To select a different mode of Zygote, you call a different method: gradient(fn, args...) vs. gradient(f0, ::Params). (The corresponding methods update!(opt, m, g) vs. update!(opt, ::Params, ::Grads) are not yet the same function, but could be.) A sketch contrasting the two call styles follows this list.
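For concreteness, the two call styles look roughly like this (an illustrative sketch; model, loss, x, y, opt, and state are assumed to be defined already):

# Explicit mode: differentiate with respect to the model itself.
grads = Zygote.gradient(m -> loss(m(x), y), model)
state, model = Optimisers.update!(state, model, grads[1])

# Implicit mode: differentiate with respect to a Params collection.
ps = Flux.params(model)
gs = Zygote.gradient(() -> loss(model(x), y), ps)
Flux.Optimise.update!(opt, ps, gs)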

If I understand right, this PR wants to introduce a different mechanism for changing what AD is used, and what mode in the case of Zygote. You would do both by passing a token to train!. My objections are:

  • 4 positional arguments is already a lot, easy to get wrong. 5 in a particular order is... one more.

  • And one more object you need to import from somewhere, give a name to, and keep track of. Now three: a module, a function, and a special token.

  • It means train! has a different mechanism for choosing the AD (and style) than gradient (etc.). You could fix this by also extending the same mechanism to a new Flux.gradient, or hope everyone starts using Abst.Diff.'s function. But it's unlikely that you won't have to know about Yota.grad, so it's still an additional thing to know, another layer of API.

  • With Zygote alone, this mechanism lets you use the same signature train!(ad, loss3, model, data, state) for both implicit and explicit mode. That might be a less weird way to use implicit mode than the present train!(loss0, ::Params, data, opt), and would enable us to remove the old ::Params signature from the docs. But why use it at all? If you are changing the syntax, why not move fully to explicit mode in one go? Or is there a limitation to Zygote's explicit mode that makes it desirable to keep using Params, internally?

@darsnack (Member Author) commented Nov 1, 2022

We want three things (in order): (1) to use rules from Optimisers.jl, (2) to use explicit gradients from Zygote, and (3) to swap AD backends. If train! never existed, all three objectives would be "just" documentation. Now, train! does exist, so we could take the following approach: swap fully to explicit mode, which gets us (1) and (2), and then the answer for (3) is to write a custom loop calling the X.gradient of your choice. Possible issues with this are:

  • We don't really have (3), since swapping ADs becomes an involved task for the end user, and packages like FluxTraining.jl or FastAI.jl need to invent their own logic for swapping ADs.
  • We don't do enough performance benchmarking or testing across the gamut of use cases to know whether Zygote's implicit mode needs to stick around.

What this PR aims to do is (a) make implicit vs. explicit Zygote equivalent to using two different AD backends, and (b) "pick" a mechanism for (3) that can be ecosystem-wide (in this case using AbstractDifferentiation.jl).

AD has always been the thorn in our side, so a robust solution to (3) is, in my mind, something we want now. Instead of inventing our own, this PR wants to use and improve AbstractDifferentiation.jl. At the same time, doing (a) means that we don't have to also juggle Zygote's mode as a separate axis. The fact that #2082 has explicit_withgradient says to me that we need something like this.

@darsnack (Member Author) commented Nov 1, 2022

Answering the more specific comments separately:

4 positional arguments is already a lot, easy to get wrong. 5 in a particular order is... one more.

Alternatively, train! could skip accepting the backend as an argument and simply construct ZygoteImplicitBackend() internally, and behave the same. Or the AD backend could be made a non-positional argument. I only wrote it this way for consistency with AD.gradient(ad, ...).
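A hypothetical sketch of the keyword-argument alternative (not code from this PR; it just forwards to the positional method used in the examples above):

# Hypothetical: the AD backend becomes a keyword argument with an implicit-Zygote default.
train!(loss, model, data, optstate; ad = ZygoteImplicitBackend()) =
    train!(loss, ad, model, data, optstate)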

And one more object you need to import from somewhere, give a name to, and keep track of. Now three: a module, a function, and a special token.

Well, if you don't care to switch ADs, then there is nothing to know or remember. I think this version is easier to learn and use than knowing how to overload explicit_withgradient correctly.

It means train! has a different mechanism for choosing the AD (and style) than gradient (etc.). You could fix this by also extending the same mechanism to a new Flux.gradient, or hope everyone starts using Abst.Diff.'s function. But it's unlikely that you won't have to know about Yota.grad, so it's still an additional thing to know, another layer of API.

The point is to coalesce around AD.gradient, which is really the only thing that will work consistently across packages and users. FluxML owns Zygote.jl (implicit and explicit) and Tracker.jl. Just having those two buy into AbstractDifferentiation.jl would cover most ADs that people use in ML. Any ChainRules-compatible reverse-mode AD gets supported by default, so Yota and Diffractor should work too. Explicit buy-in from those two would be nice, though.

With Zygote alone, this mechanism lets you use the same signature train!(ad, loss3, model, data, state) for both implicit and explicit mode. That might be a less weird way to use implicit mode than the present train!(loss0, ::Params, data, opt), and would enable us to remove the old ::Params signature from the docs. But why use it at all? If you are changing the syntax, why not move fully to explicit mode in one go? Or is there a limitation to Zygote's explicit mode that makes it desirable to keep using Params, internally?

Unless we plan on killing implicit mode within Zygote itself, I don't see a reason to forbid it. There doesn't need to be a train! API for it, but the beauty of this PR is that train! can support it without us ever mentioning or promoting it.

@ToucheSir (Member)
@darsnack and I had a productive but unfortunately too short discussion on this last ML call, so putting some follow-up thoughts here.

Both #2082 and #2083 seek to have training with implicit and explicit params use roughly the same interface. Same function arity, same number of return values, etc. Given we're planning on removing train!(loss, ps::Params, data, opt::AbstractOptimiser) though, is there any reason to do this? It seems some of the contention around e.g. new vs old Optimizer compatibility would be resolved by adding a deprecation warning in the implicit params train! like so:

train!(loss, ps::Params, data, opt::AbstractOptimiser) is deprecated and will be removed in Flux v0.14.

Going forward, please use the new train! interface:

using Flux.Optimisers: Optimisers

optstate = Optimisers.setup(Optimisers.Descent(), model) 
optstate, model = Flux.train!((model, x, y) -> loss_fn(model(x), y), model, data, optstate)

This means no back-and-forth conversions. The train! methods with old and new optimizers (and implicit vs. explicit params) are now mutually exclusive. We could even consider adding deprecation warnings to the old optimizer constructors themselves, pushing users towards new-style train! with Optimisers.jl rules. As long as the new-style train! can handle any workflow that the old train! could, we could change it to have whatever interface we want.
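For illustration, such a deprecation warning might be attached to the implicit-params method roughly as below (a sketch as it might appear inside Flux.Optimise, not actual Flux code; the existing loop body is elided):

function train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
    Base.depwarn("""
        train!(loss, ps::Params, data, opt::AbstractOptimiser) is deprecated and will be
        removed in Flux v0.14. Please use the explicit-parameter train! together with an
        Optimisers.jl state tree from Optimisers.setup.""", :train!)
    # ... the existing implicit-params training loop stays unchanged ...
end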

@mcabbott (Member) commented Nov 21, 2022

The train! methods with old and new optimizers (and implicit vs explicit params) are now mutually exclusive. We could even consider adding dep warnings

Yes, this is now the present state after #2082. There are doc notes but no deprecation warnings. They could be added in the last version of 0.13? They should go in train! & update!, I think.

Unless we plan on killing implicit mode within Zygote itself, I don't see a reason to forbid it

The reason to drop it entirely from 0.14 is that this lets us delete Flux.Optimise, and all of its duplicate code for how Adam works, etc.

Keeping a path for implicit parameters while having only one Adam definition (in Optimisers.jl) means writing some new code to make corresponding IdDicts. Nobody much liked this in #2029. Maybe it could be done in other ways, using the same state tree from setup.

Other uses of Zygote can of course use it as they wish. If ripping implicit parameters out of Zygote completely led to some substantial improvement, then that could be considered as a major change. But realistically nobody is going to get around to trying.

this version is easier to learn and use than knowing to overload explicit_withgradient correctly.

Note that such overloading was never a proposed API. The initial proposal in #2082 (now removed) was that there be a macro @train_ad Tracker which globally changes what Flux uses, by some internal mechanism. This is similar to what Turing does. Global switches have their downsides, but they do avoid this one-more-named-thing problem with passing special tokens around.

packages like FluxTraining.jl or FastAI.jl need to invent their own logic for swapping ADs

Yes. If they do settle on some nice high-level scheme, then one future for train! could be that it should be a junior version of those, which follows the same scheme.

The other possible future is that train! is a legacy function, only for very simple training loops & to make upgrading old code easier. The question of whether it should take no/the same/more callbacks is related.

@darsnack (Member Author)
Yeah, I think we are all in complete agreement, so I will close this one now. Most of what is here is already done or belongs in the AD packages, whenever the high-level interface comes together.

@darsnack closed this on Nov 21, 2022