
Optimizers #26
Closed · wants to merge 6 commits into from

Conversation

@ylxdzsw (Contributor) commented May 4, 2017

I tried to add some optimizers. Regarding the definition of a model in src/model.jl, line 8:

A "model" is a function with state.

Optimizers are essentially models: they have the signature (param, grad) -> delta and carry internal state, so we can implement them just like models, using @net. This PR does exactly that. The only problem is that we cannot use the same back! and update! to update optimizers. For now I introduced two APIs, axpy! and zero!, and found them sufficient to update almost all common optimizers. Maybe there is a better way to design this.

A definition of an optimizer looks like this:

@net type _NesterovMomentum # <: Optimizer
  μ        # momentum coefficient
  α        # scale applied to the gradient term
  momentum # internal state: the accumulated update
  (param, grad) -> begin
    v = grad .+ μ * momentum
    μ * v .+ α * grad
  end
end

NesterovMomentum(μ = .9, α = 1) = param -> _NesterovMomentum(μ, α, zero!(similar(param)))

update!(o::_NesterovMomentum, param, grad) = begin
  axpy!(o.μ, o.momentum, o.momentum)
  axpy!(1, grad, o.momentum)
  o
end

and training a model using optimizers (SGD with momentum and decay) looks like this:

train!(model, data, η = .01, optimizers=[SGD(), Momentum(.9), Decay(1e-4)])

Currently these work on the MXNet backend. What's your opinion of this approach?

@MikeInnes (Member):

My first thought is, this is putting a fairly large burden on the backend to implement much of the optimisation process; I was expecting that a backend wouldn't have to do much more than forward update!(optimiser, model) to update!(optimiser, ::Param), and the optimiser can hold the necessary state for all params. This will be important if we want the optimisers to work with Knet or other frameworks.

Being able to implement optimisers with @net is a nice feature, especially if we can eventually compile optimisers to run on the backend. However, we shouldn't base the design on what @net can currently do; instead it's better to come up with a good design and then throw @net in.

I think it would be a good idea to implement the most straightforward version of this that works in pure Julia mode; e.g. when calling update!(optimiser, Affine(...)) without any backend. Then we can figure out how to make that work with MXNet.
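Concretely, a minimal pure-Julia version of that shape might look something like the sketch below (Param, its x/Δ fields, and params(model) are placeholders for illustration, not the actual Flux API):

struct Param{T}
  x::T   # current value
  Δ::T   # accumulated gradient
end

struct SGD
  η::Float64
end

# Per-parameter update: take a gradient step, then clear the gradient.
function update!(o::SGD, p::Param)
  p.x .-= o.η .* p.Δ
  fill!(p.Δ, 0)
  return p
end

# Model-level update just forwards to every parameter.
update!(o::SGD, model) = foreach(p -> update!(o, p), params(model))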

Hopefully you should find that straightforward to do, but let me know if it's not clear.

@ylxdzsw (Contributor, Author) commented May 12, 2017

I'm not very clear about the semantics of Param{T} and the "conversion" of a model. From the code, I guess T should always be an array on the host, right? When we do m2 = mxnet(m1), m2 should still operate on the exact Params of m1, which basically means each operation on m2 involves transferring the contents of all Params between GPU and host. In that case, if every update!(optimiser, model) ends up in update!(optimiser, ::Param), does it mean we have to transfer all weights and grads back to the host every batch? That would be a big overhead.

To address this, we need to clearly define what m2 = mxnet(m1) means.

  1. m2 is a wrapper of m1, which means training m2 will update m1 too.
  2. m2 is a deep copy of m1, which means m1 and m2 are totally independent.
  3. the conversion transfers ownership, like Rust, which means m1 is no longer valid.

If we go the first way, there should be a way to "turn off" the synchronization of params during training, otherwise transferring weights and grads every batch is inevitable. This will make things very complex.

If we choose 2 or 3, that will make things much easier since we can convert Param{Array} to Param{MXArray} when mxnet(m) is called, and training should run just fine and fast. However, we may lose some features; for example, we cannot simply save the models as before. To re-enable this, we can:

  1. add a new method to convert a backend model back to a plain model, so the user can save it.
  2. if we choose the second semantic (i.e. deep copy), we can add a sync! which operates on two models having exactly the same (==) graph. Since we can keep the original graph object on backend models, we can ensure that models converted from the same origin are synchronizable.

As for where to put the state, yes, I can have the optimizer keep a WeakKeyDict{Param, State}; we just need to ensure that Params always compare and hash by their object ID, which is already the case if we don't make Param an AbstractArray.
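For instance, a rough sketch of per-parameter optimizer state could look like this (the p.x/p.Δ fields and the update! signature are assumptions for illustration; an IdDict is used here since it compares keys by object identity, while a WeakKeyDict would additionally let the state be collected along with its Param):

struct Momentum
  η::Float64
  ρ::Float64
  state::IdDict{Any,Any}   # Param => velocity buffer
end

Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())

function update!(o::Momentum, p)
  v = get!(o.state, p, zero(p.x))   # fetch or create the velocity for p
  @. v = o.ρ * v + p.Δ              # update the velocity
  @. p.x -= o.η * v                 # take the step
  fill!(p.Δ, 0)
  return p
end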

As for the @net approach: yes, they already compile and run on the GPU in this commit. Currently I compile them at the first call to update!(m::Model, optimizers, η). I can move that to update!(optimizer, ::Params{MXArray}) if we use Param{MXArray} instead of the args and grads fields in the model.

@MikeInnes (Member):

Great questions. I'll try to explain my current thinking on this as much as I can.

In general, functions of models have "wrapper semantics"; think Chain or TLP. My thinking so far has been that treating mxnet or tf as another kind of wrapper is the most consistent approach (as well as enabling some nice usage patterns). This might not be realistic though, in which case some functions can have copy semantics and explicit syncing. Another option is to make this opt-in, e.g. detach!(mxnet(m)).

I don't think any of those options makes things significantly more or less complex, then; it's just a question of when load/storeparams! is called, and by whom.

Params don't necessarily need to be on the host. You can imagine writing cuda(Affine(10, 20)) to convert the weights to CUDAArrays and then training it. In general we should treat the pure-Julia mode as a first-class citizen and use it as a starting point for designs, even if it's not fully functional yet.

Keep the thoughts coming, especially if that's not clear.

@MikeInnes (Member) left a review:

Great, this looks much improved.

Although I started this out with a recursive update!, I'm wondering if it would just be cleaner to grab all params up front – like params(Affine(10,5)) == [Param(10,5), Param(1,5)], opt = SGD(params(model)). Then you could call update!(opt) to carry out the update. What do you think?

That would also make it easier to get rid of the nested closures and just have an SGD object with the appropriate state, and an update! method.
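Roughly, the shape I have in mind is something like this (a sketch only; Param and params here stand in for whatever we end up using):

struct Param{T}
  x::T   # value
  Δ::T   # gradient
end

struct SGD
  η::Float64
  ps::Vector{Param}
end

SGD(ps; η = 0.01) = SGD(η, collect(Param, ps))

function update!(o::SGD)
  for p in o.ps
    p.x .-= o.η .* p.Δ   # gradient step
    fill!(p.Δ, 0)        # reset the gradient
  end
  return o
end

# usage: opt = SGD(params(model)); update!(opt)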

I really like the way you can compose optimisers together, but it would be nice if that was built on top of the basic framework, rather than special cased at the bottom. For example, with the tweaks above:

struct Multi
  fs
end

update!(m::Multi) = foreach(update!, m.fs)

That would also avoid the need to repeat the decay stuff many times. If the user wants a decay they can easily just compose the basic optimiser with a decay themselves.
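e.g., hypothetically (assuming SGD and Decay each take their params up front and define their own update!):

opt = Multi([SGD(params(model)), Decay(params(model), 1e-4)])
update!(opt)   # applies each optimiser in turn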

Inline review thread on this hunk:

decay::Real=0,
nesterov::Bool=false)

@restrict_range lr "[0, ∞)"

@MikeInnes (Member):
Why not just @assert 0 < lr < ∞?

@oxinabox (Member):
@MikeInnes you mean 0 <= lr || throw(ArgumentError("lr must be >= 0")).

The square bracket indicates inclusive in this notation.
And ArgumentError is semantically different from @assert.
In particular, one day I hope to see the ability to disable asserts in optimised mode,
and it is code using @assert to check user logic that is stopping that from happening.
(I'm sure you've seen the issues/PRs in the JuliaLang repo.)

@ylxdzsw (Contributor, Author):

@assert is not for this purpose, as @oxinabox indicated. However, https://github.com/jw3126/ArgCheck.jl could be a good option. I just feel it's still not clean enough when many checks appear together:

@argcheck lr >= 0
@argcheck 0 <= beta1 <= 1
@argcheck 0 <= beta2 <= 1
@argcheck epsilon > 0
@argcheck decay >= 0

compared with this, which I think is more aligned and mathematical:

@restrict_range lr      "[0, ∞)"
@restrict_range beta1   "[0, 1]"
@restrict_range beta2   "[0, 1]"
@restrict_range epsilon "(0, ∞)"
@restrict_range decay   "[0, ∞)"
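For reference, a @restrict_range along these lines can be written as a small macro that parses the interval string; the following is only a sketch of the idea, not the implementation in this PR:

macro restrict_range(var, interval::String)
  m = match(r"^([\[\(])\s*([^,\s]+)\s*,\s*([^\]\)\s]+)\s*([\]\)])$", interval)
  m === nothing && error("invalid interval: $interval")
  lbr, lo_s, hi_s, rbr = m.captures
  bound(s) = s == "∞" ? Inf : s == "-∞" ? -Inf : parse(Float64, s)
  lo, hi = bound(lo_s), bound(hi_s)
  lower = lbr == "[" ? :($(esc(var)) >= $lo) : :($(esc(var)) > $lo)
  upper = rbr == "]" ? :($(esc(var)) <= $hi) : :($(esc(var)) < $hi)
  msg = "$var must be in $interval"
  :(($lower && $upper) || throw(ArgumentError($msg)))
end

# @restrict_range lr "[0, ∞)" expands to roughly:
# (lr >= 0.0 && lr < Inf) || throw(ArgumentError("lr must be in [0, ∞)"))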

@MikeInnes (Member):

I merged this in 97ecb26 (though it's not active yet, given that I need to make some tweaks for the big refactor).
