
Support batched training #18

Closed
wants to merge 0 commits into from

Conversation

Contributor

@ylxdzsw ylxdzsw commented Apr 9, 2017

This PR adds batched training by:

  • adding a Batched type, which wraps an iterator and yields its elements in batches (see the usage sketch after this list)
  • adding a keyword argument all to update! in the MXNet backend, which copies the final params to all execs
  • making mse! support batched arguments
  • altering the training process of the TensorFlow backend
  • fixing a small bug in CatMat
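A rough usage sketch of the pieces above; Batched and the all keyword of update! come from this PR, while data, model and η are placeholder names:

for (x, y) in Batched(data, 100)      # yields (x, y) pairs in batches of 100 samples
    # forward/backward pass on the batch goes here
    update!(model, η, all = true)     # MXNet backend: copy the final params to all execs (done periodically in the PR)
end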
@ylxdzsw ylxdzsw force-pushed the master branch 2 times, most recently from fae946d to 5d02d9f Apr 13, 2017
Member

@MikeInnes MikeInnes left a comment

This looks great, thanks 👍. The batch iterator approach is really nice. I've left a couple of comments on small things but I think this will be easy to polish up.

@@ -127,7 +127,17 @@ function Flux.back!(m::Model, Δ, xs...)
end
end

Flux.update!(m::Model, η) = (update!(m.last, η); m)
Flux.update!(m::Model, η; all=false) = begin
Member

@MikeInnes MikeInnes Apr 17, 2017

We can't modify the signature of update! like this, because it exposes an implementation detail of MXNet models, which means training processes etc. can't be written generically any more.

I think it'd be sufficient to keep the old update!, but be careful to call loadparams! whenever .last changes.

Contributor Author

@ylxdzsw ylxdzsw Apr 18, 2017

Yes, I saw your note in the file. I will try to implement it that way. Just another idea: is it possible to use exactly the same NDArray among Execs? I think moving the params local variable in the mxparams function to the Model or Graph object, and not reinitializing it every time, should do the trick. This would completely get rid of this problem and, in addition, save precious GPU memory.

Member

@MikeInnes MikeInnes Apr 18, 2017

I don't know that it won't work, so I guess it's worth a shot :)

Y => batchone(convertel(Float32, y))))
if i % 5000 == 0
Dict(m.inputs[1] => rebatch(convertel(Float32, rawbatch(x))),
Y => rebatch(convertel(Float32, rawbatch(y)))))
Member

@MikeInnes MikeInnes Apr 17, 2017

convertel should already work on batches without needing to convert to an array and back

Contributor Author

@ylxdzsw ylxdzsw Apr 18, 2017

Yes, it does. I didn't notice that.

src/cost.jl Outdated
@@ -1,7 +1,7 @@
export mse, mse!

function mse!(Δ, pred, target)
map!(-, Δ, pred, target)
map!(-, Δ, rawbatch(pred), rawbatch(target))
Member

@MikeInnes MikeInnes Apr 17, 2017

It seems wrong that mse! needs to know about batching; in general, we should be able to write functions over generic arrays and have batching work for free. In this case, perhaps we can implement broadcast! for Batch so that mse! works as is? If it's complicated to do this it doesn't have to block the PR.
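A minimal sketch of that suggestion, assuming Batch exposes its underlying storage through rawbatch as used elsewhere in this PR; the method itself is illustrative, not the package's actual definition:

import Base: broadcast!

# Unwrap the Batches and broadcast over their raw storage, so generic
# array code such as mse! keeps working without special-casing batches.
function broadcast!(f, dest::Batch, srcs::Batch...)
    broadcast!(f, rawbatch(dest), map(rawbatch, srcs)...)
    return dest
end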

Contributor Author

@ylxdzsw ylxdzsw Apr 18, 2017

I'm afraid that overloading broadcast! is not enough, as this is map! in Julia 0.6:

function map!{F}(f::F, dest::AbstractArray, A::AbstractArray, B::AbstractArray)
    for (i, j, k) in zip(eachindex(dest), eachindex(A), eachindex(B))
        dest[i] = f(A[j], B[k])
    end
    return dest
end

We treat Batch as a Vector of Arrays, so the dimensions of those eachindexes don't match. What about just specializing mse! for Batches?

Member

@MikeInnes MikeInnes Apr 18, 2017

We can just use broadcast! instead of map! here.
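For reference, a sketch of mse! with that swap; the cost expression at the end is an assumption about the rest of the original function body:

function mse!(Δ, pred, target)
    broadcast!(-, Δ, pred, target)   # works for plain arrays, and for Batches given a broadcast! method like the sketch above
    sum(abs2, Δ) / 2
end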

end

function Batched(iter::T, batch::Integer) where T
    batch >= 1 || throw(ArgumentError("batch size must >= 1"))
Member

@MikeInnes MikeInnes Apr 17, 2017

Is there a particular reason to disable batch size 1?

Contributor Author

@ylxdzsw ylxdzsw Apr 18, 2017

No, there isn't. batch == 1 is allowed and I tested it.

Member

@MikeInnes MikeInnes Apr 18, 2017

Obviously – reviewing PRs too late :)


for ibatch in 1:x.batch
    if done(x.iter, x.i)
        warn("cannot perfectly divide data by batch size, remainder will be discarded")
Member

@MikeInnes MikeInnes Apr 17, 2017

In principle, we could just return a smaller final batch. I'm not sure if we should or not, but it might be a nicer default.

Contributor Author

@ylxdzsw ylxdzsw Apr 18, 2017

AFAIK most frameworks simply discard the remainder, but yes, it's worth implementing, at least as an option for the user to choose.

Member

@MikeInnes MikeInnes Apr 18, 2017

Ok, I'm happy to follow precedent here.
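For illustration, an eager version of the agreed default: split a collection into full batches and discard the remainder with a warning. The lazy Batched iterator in the PR does the equivalent while iterating; batches here is a hypothetical helper, not part of the PR:

function batches(xs, n)
    length(xs) % n == 0 ||
        warn("cannot perfectly divide data by batch size, remainder will be discarded")
    [xs[(i - 1) * n + 1 : i * n] for i in 1:div(length(xs), n)]
end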

src/utils.jl Outdated
for _ in 1:epoch
@progress for (x, y) in train
@progress for (x, y) in Batched(train, batch)
Member

@MikeInnes MikeInnes Apr 17, 2017

I think perhaps train! should stay generic and continue to accept an iterator of either raw arrays or batches, rather than trying to handle batching entirely itself. That gives you the control to batch things ahead of time, for example.

Contributor Author

@ylxdzsw ylxdzsw Apr 18, 2017

How do we know whether the elements of an iterator are Batches or raw Arrays? For example, if the iterator yields Arrays of size (5, 28, 28, 3), is that a Batch of (28, 28, 3) Arrays or just a single 4-dimensional sample? Or should we export the rebatch function and tell users to pack Batches themselves in the first case?

Contributor Author

@ylxdzsw ylxdzsw Apr 18, 2017

Here's the Keras way: Keras has two methods, fit and fit_generator. fit accepts one big array (which can be treated as a single batch of the entire dataset) and splits it by batch_size, while fit_generator accepts an iterator that generates batches and has no batch_size argument. I think we can do the same thing: train! still accepts a batch argument and packs the elements of the iterator according to it, and we add a new train_batched! that lets the user yield the batches, where we just map rebatch over the iterator.
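A rough sketch of the shape of that split; the bodies and keyword arguments are illustrative assumptions, only the names train!, train_batched! and Batched come from this discussion, and the rebatch mapping mentioned above is omitted:

# train! keeps a batch keyword and packs plain samples itself ...
train!(m, data; batch = 1, kws...) =
    train_batched!(m, Batched(data, batch); kws...)

# ... while train_batched! accepts an iterator that already yields batches.
function train_batched!(m, batches; epoch = 1, η = 0.1)
    for _ in 1:epoch, (x, y) in batches
        # forward/backward pass and update!(m, η) for this batch go here
    end
end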

Member

@MikeInnes MikeInnes Apr 18, 2017

Yes, Batch is supposed to be user facing. The idea is that we then don't ever have to worry about the question of "is this a batch?", making things less error prone. We should also be able to write functions that work in a generic way for batches and samples, avoiding duplication.

The model interface will accept either a sample or a batch, so in theory train! need not know about batching at all. We can then call either train!(m, xs) or train!(m, batched(xs, 10)). (This doesn't prevent train! from having a batch_size argument for convenience / familiarity of course.)

Member

@MikeInnes MikeInnes Apr 19, 2017

#21 will invalidate some of what I said above. So let's just get this working in a simple way, as you've done it, and worry about making things more general later.

src/utils.jl Outdated
i % 1000 == 0 && @show accuracy(m, test)
if i % 100 == 0
    update!(m, η, all=true)
    @show accuracy(m, test)
Member

@MikeInnes MikeInnes Apr 17, 2017

When I last tried this I had problems with accuracy not taking batches. Are the mse! changes enough to fix that?

Contributor Author

@ylxdzsw ylxdzsw Apr 18, 2017

Umm, in fact I didn't touch it; it still evaluates the model one sample at a time. It shouldn't be hard to support batches, though.

Member

@MikeInnes MikeInnes Apr 18, 2017

Whoops, my bad, I was thinking of when I tried train! with an iterator of batches, which almost worked aside from this line. Of course, if you expect test to be samples then this will work. We can discuss that assumption elsewhere.

Contributor Author

@ylxdzsw ylxdzsw commented Apr 20, 2017

Rebased, plus all execs in the same mx model now share the same set of args and grads - we don't need to worry about last anymore.

@ylxdzsw ylxdzsw force-pushed the master branch 2 times, most recently from fae946d to 94d9937 May 3, 2017
Contributor Author

@ylxdzsw ylxdzsw commented May 3, 2017

Rebased on the latest master again. Will this get merged in the near future? Keeping it in sync with master is somewhat error prone.

Member

@MikeInnes MikeInnes commented May 11, 2017

Err, whoops. I tried to push a change to your branch but it ended up deleting the commits.

Rest assured I have these locally and I'm rebasing/merging them right now.

@MikeInnes MikeInnes mentioned this pull request May 11, 2017
Contributor Author

@ylxdzsw ylxdzsw commented May 12, 2017

Yeah, I have a local backup too. I'm a little surprised that GitHub allows pushing to forked repos directly.

Member

@MikeInnes MikeInnes commented May 12, 2017

Maintainers can push to branches used for PRs, which is pretty handy for patching things up before a merge without going back and forth. Of course, as soon as the PR closed itself that no longer applied, which meant I couldn't fix my mistake :P
