Support batched training #18

Closed

@ylxdzsw
Contributor

ylxdzsw commented Apr 9, 2017

This PR adds batched training by:

  • adding a Batched type, which wraps an Iterator and outputs its elements in batches (see the sketch after this list)
  • adding a keyword argument all to update! in the MXNet backend, which copies the final params to all execs
  • making mse! support batched arguments
  • altering the training process of the TensorFlow backend
  • fixing a small bug in CatMat
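
For reference, here is a rough sketch of the kind of wrapper the Batched type provides; the struct fields follow the PR, but the eager collectbatches helper is purely illustrative (the real type iterates lazily and wraps each group in a Batch):

struct Batched{T}
    iter::T      # any iterator of individual samples
    batch::Int   # number of samples per emitted batch
end

# Eager, simplified equivalent of what the lazy iterator does: group
# consecutive samples into vectors of `batch` elements.
function collectbatches(b::Batched)
    batches = Vector{Any}[]
    current = Any[]
    for item in b.iter
        push!(current, item)
        if length(current) == b.batch
            push!(batches, current)
            current = Any[]
        end
    end
    return batches   # any remainder left in `current` is dropped here
end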

@ylxdzsw ylxdzsw force-pushed the ylxdzsw:master branch 2 times, most recently from fae946d to 5d02d9f Apr 13, 2017

@MikeInnes

This looks great, thanks 👍. The batch iterator approach is really nice. I've left a couple of comments on small things but I think this will be easy to polish up.

@@ -127,7 +127,17 @@ function Flux.back!(m::Model, Δ, xs...)
end
end
Flux.update!(m::Model, η) = (update!(m.last, η); m)
Flux.update!(m::Model, η; all=false) = begin

@MikeInnes

MikeInnes Apr 17, 2017

Member

We can't modify the signature of update! like this, because it exposes an implementation detail of MXNet models, which means training processes etc. can no longer be written generically.

I think it'd be sufficient to keep the old update!, but to be careful to call loadparams! whenever .last changes.

@ylxdzsw

ylxdzsw Apr 18, 2017

Contributor

Yes, I saw your note in the file. I will try to implement it that way. Just another idea: is it possible to use exactly the same NDArrays among Execs? I think moving the params local variable in the mxparams function to the Model or Graph object, and not reinitializing it every time, should do the trick. That would get rid of this problem completely and, in addition, save precious GPU memory.
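
Roughly, the idea sketched below: allocate the parameter NDArrays once, cache them on the model, and hand the same objects to every executor. Everything except mxparams (named in the comment above) is hypothetical:

# Hypothetical sketch; `getparams!` stands in for wherever the backend
# currently calls mxparams while building an executor.
function getparams!(model)
    # Allocate the parameter NDArrays only once and cache them on the model,
    # instead of rebuilding them for every new executor.
    if model.params === nothing          # hypothetical cache field
        model.params = mxparams(model.graph)
    end
    # Every executor then binds against the very same NDArrays, so an update
    # through any exec is immediately visible to all of them, and no duplicate
    # parameter copies sit on the GPU.
    return model.params
end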

@MikeInnes

MikeInnes Apr 18, 2017

Member

I don't know that it won't work, so I guess it's worth a shot :)

Y => batchone(convertel(Float32, y))))
if i % 5000 == 0
Dict(m.inputs[1] => rebatch(convertel(Float32, rawbatch(x))),
Y => rebatch(convertel(Float32, rawbatch(y)))))

@MikeInnes

MikeInnes Apr 17, 2017

Member

convertel should already work on batches without needing to convert to an array and back

@ylxdzsw

ylxdzsw Apr 18, 2017

Contributor

Yes, it does. I didn't notice that.

src/cost.jl Outdated
@@ -1,7 +1,7 @@
export mse, mse!
function mse!(Δ, pred, target)
map!(-, Δ, pred, target)
map!(-, Δ, rawbatch(pred), rawbatch(target))

@MikeInnes

MikeInnes Apr 17, 2017

Member

It seems wrong that mse! needs to know about batching; in general, we should be able to write functions over generic arrays and have batching work for free. In this case, perhaps we can implement broadcast! for Batch so that mse! works as is? If it's complicated to do this, it doesn't have to block the PR.

@ylxdzsw

ylxdzsw Apr 18, 2017

Contributor

I'm afraid that overloading broadcast! is not enough, as this is map! in Julia 0.6:

function map!{F}(f::F, dest::AbstractArray, A::AbstractArray, B::AbstractArray)
    for (i, j, k) in zip(eachindex(dest), eachindex(A), eachindex(B))
        dest[i] = f(A[j], B[k])
    end
    return dest
end

We treat Batch as a Vector of Arrays, so the dimensions of those eachindex calls don't match. What about just specializing mse! for Batches?

@MikeInnes

MikeInnes Apr 18, 2017

Member

We can just use broadcast! instead of map! here.
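
Putting the two suggestions together, a minimal sketch, assuming rawbatch returns the underlying storage of a Batch (the exact mse! body in Flux may differ):

# Hypothetical: forward broadcast! on Batches to their raw storage, so that
# generic array code works on batches without special-casing.
Base.broadcast!(f, dest::Batch, args::Batch...) =
    (broadcast!(f, rawbatch(dest), map(rawbatch, args)...); dest)

# mse! then stays generic; the rest of its body is unchanged and omitted here.
function mse!(Δ, pred, target)
    broadcast!(-, Δ, pred, target)   # works for plain arrays and for Batches
end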

end
function Batched(iter::T, batch::Integer) where T
batch >= 1 || throw(ArgumentError("batch size must >= 1"))

@MikeInnes

MikeInnes Apr 17, 2017

Member

Is there a particular reason to disable batch size 1?

@ylxdzsw

ylxdzsw Apr 18, 2017

Contributor

No, it isn't disabled: batch == 1 is allowed, and I have tested it.

@MikeInnes

MikeInnes Apr 18, 2017

Member

Obviously – reviewing PRs too late :)

for ibatch in 1:x.batch
if done(x.iter, x.i)
warn("cannot perfectly divide data by batch size, remainder will be discarded")

@MikeInnes

MikeInnes Apr 17, 2017

Member

In principle, we could just return a smaller final batch. I'm not sure if we should or not, but it might be a nicer default.

@ylxdzsw

ylxdzsw Apr 18, 2017

Contributor

AFAIK most frameworks simply discard the remainder, but yes, it's worth implementing at least as an option for the user to choose.

@MikeInnes

MikeInnes Apr 18, 2017

Member

Ok, I'm happy to follow precedent here.
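
As a sketch of that option (purely hypothetical, not part of this PR), the remainder handling quoted above could be gated on a flag:

# Hypothetical `partial` flag: keep the trailing samples as a smaller final
# batch instead of warning and discarding them (the PR's current behaviour).
function takebatch(iter, state, batchsize; partial::Bool = false)
    items = []
    while length(items) < batchsize && !done(iter, state)
        item, state = next(iter, state)
        push!(items, item)
    end
    if 0 < length(items) < batchsize && !partial
        warn("cannot perfectly divide data by batch size, remainder will be discarded")
        empty!(items)
    end
    return items, state
end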

for _ in 1:epoch
@progress for (x, y) in train
@progress for (x, y) in Batched(train, batch)

@MikeInnes

MikeInnes Apr 17, 2017

Member

I think perhaps train! should stay generic and continue to accept an iterator of either raw arrays or batches, rather than trying to handle batching entirely itself. That gives you the control to batch things ahead of time, for example.

@ylxdzsw

ylxdzsw Apr 18, 2017

Contributor

How do we know whether the elements of an iterator are Batches or raw Arrays? For example, if the iterator yields Arrays of size (5, 28, 28, 3), is each one a Batch of (28, 28, 3) Arrays or just a single 4-dimensional array? Or should we export the rebatch function and tell users to pack Batches themselves in the first case?

@ylxdzsw

ylxdzsw Apr 18, 2017

Contributor

Here's the Keras way: Keras has two methods, fit and fit_generator. fit accepts one big array (which can be treated as a single batch of all the data) and splits it by batch_size, while fit_generator accepts an iterator that generates batches and has no batch_size argument. I think we can do the same thing: train! still accepts batch and packs the elements of the iterator according to it, and we add a new train_batched! that lets the user yield the batches while we just map rebatch over the iterator.

@MikeInnes

MikeInnes Apr 18, 2017

Member

Yes, Batch is supposed to be user facing. The idea is that we then don't ever have to worry about the question of "is this a batch?", making things less error prone. We should also be able to write functions that work in a generic way for batches and samples, avoiding duplication.

The model interface will accept either a sample or a batch, so in theory train! need not know about batching at all. We can then call either train!(m, xs) or train!(m, batched(xs, 10)). (This doesn't prevent train! from having a batch_size argument for convenience / familiarity of course.)
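
In usage terms, taking the two call forms directly from the comment above (batched being the user-facing constructor under discussion):

train!(m, xs)                # iterator of raw samples
train!(m, batched(xs, 10))   # iterator of Batches, packed ahead of time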

@MikeInnes

MikeInnes Apr 19, 2017

Member

#21 will invalidate some of what I said above. So let's just get this working in a simple way, as you've done it, and worry about making things more general later.

i % 1000 == 0 && @show accuracy(m, test)
if i % 100 == 0
update!(m, η, all=true)
@show accuracy(m, test)

@MikeInnes

MikeInnes Apr 17, 2017

Member

When I last tried this I had problems with accuracy not taking batches. Are the mse! changes enough to fix that?

@ylxdzsw

ylxdzsw Apr 18, 2017

Contributor

Umm, in fact I didn't touch it; it still evaluates the model one sample at a time. It shouldn't be hard to support batches, though.

@MikeInnes

MikeInnes Apr 18, 2017

Member

Whoops, my bad, I was thinking of when I tried train! with an iterator of batches, which almost worked aside from this line. Of course, if you expect test to be samples then this will work. We can discuss that assumption elsewhere.

@ylxdzsw ylxdzsw force-pushed the ylxdzsw:master branch from 5d02d9f to 2575f4a Apr 20, 2017

@ylxdzsw

Contributor

ylxdzsw commented Apr 20, 2017

Rebased; plus, all execs in the same mx model now share the same set of args and grads, so we don't need to worry about .last anymore.

@ylxdzsw ylxdzsw force-pushed the ylxdzsw:master branch 2 times, most recently from fae946d to 94d9937 May 3, 2017

@ylxdzsw

Contributor

ylxdzsw commented May 3, 2017

Rebased onto the latest master again. Will this get merged in the near future? Keeping it in sync with master is somewhat error-prone.

@ylxdzsw ylxdzsw force-pushed the ylxdzsw:master branch from 94d9937 to 2343b7d May 3, 2017

@MikeInnes MikeInnes closed this May 11, 2017

@MikeInnes MikeInnes force-pushed the ylxdzsw:master branch from 2343b7d to 019e341 May 11, 2017

@MikeInnes

Member

MikeInnes commented May 11, 2017

Err, whoops. I tried to push a change to your branch but it ended up deleting the commits.

Rest assured I have these locally and I'm rebasing/merging them right now.

@ylxdzsw

Contributor

ylxdzsw commented May 12, 2017

Yeah I have a local backup too. I'm a little surprised that GitHub allows pushing to forked repos directly.

@MikeInnes

Member

MikeInnes commented May 12, 2017

Maintainers can push to branches used for PRs, which is pretty handy for patching things up before a merge without going back and forth. Of course, as soon as the PR closed itself that no longer applied, which meant I couldn't fix my mistake :P
