
Decouple the computational batch size and minibatch size by accumulating gradients #1663

Closed
longjon wants to merge 5 commits

Conversation

longjon
Contributor

@longjon commented Dec 31, 2014

Built on top of #1615, so that this code already supports the deconvolution layer. (The actual diff is just +37/-40 lines.)

This is a PR of the gradient accumulation branch living at https://github.com/shelhamer/caffe/tree/accum-grad. I took a lighter approach here than the one there: parameter gradients are always accumulated; there is no other option. The gradient checker is made correct by zero-initializing parameter diffs.
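A minimal sketch of the always-accumulate behavior described above, using plain C++ stand-ins rather than Caffe's actual Blob/Layer types:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Under the always-accumulate convention, a backward pass adds its parameter
// gradient into the diff buffer (+=) instead of overwriting it (=).
void BackwardAccumulate(const std::vector<float>& param_grad,
                        std::vector<float>* param_diff) {
  for (std::size_t i = 0; i < param_diff->size(); ++i) {
    (*param_diff)[i] += param_grad[i];  // accumulate; never reset here
  }
}

// Because backward never clears diffs, the caller (the solver, or the gradient
// checker mentioned above) is responsible for zeroing them beforehand.
void ZeroDiff(std::vector<float>* param_diff) {
  std::fill(param_diff->begin(), param_diff->end(), 0.0f);
}
```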

Issues:

  • This changes the behavior of Backward. External code that used Backward is likely to break, if there is any.
  • I think this breaks solvers other than SGDSolver, but haven't thought carefully about that yet.

longjon added commits to longjon/caffe that referenced this pull request, Dec 31, 2014 – Jan 3, 2015: Decouple the computational batch size and minibatch size by accumulating gradients
@longjon force-pushed the accum-grad branch 4 times, most recently from a4d2e6d to d76653a on December 31, 2014 at 22:53
longjon and others added 5 commits January 11, 2015 00:31
With layers whose backwards accumulate gradients, this effectively decouples the computational batch from the SGD minibatch: each iteration accumulates gradients over iter_size batches, then the parameters are updated.
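A self-contained sketch of the accumulate-then-update loop these commits describe, again with plain C++ stand-ins rather than the actual Solver code; averaging the accumulated gradient by iter_size is an assumption about the normalization convention, not something stated here:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Toy stand-in for a network's parameters and their gradient buffer.
struct Params {
  std::vector<float> data;  // weights
  std::vector<float> diff;  // accumulated gradient
};

// One SGD iteration over an effective minibatch of batch_size * iter_size
// examples: zero the diffs, accumulate the gradient of each of the iter_size
// computational batches, then apply a single parameter update.
void SgdStep(const std::vector<std::vector<float> >& batch_grads,  // iter_size entries
             float learning_rate, Params* p) {
  std::fill(p->diff.begin(), p->diff.end(), 0.0f);
  for (std::size_t k = 0; k < batch_grads.size(); ++k) {
    for (std::size_t i = 0; i < p->diff.size(); ++i) {
      p->diff[i] += batch_grads[k][i];  // backward accumulates, as above
    }
  }
  const float scale = learning_rate / batch_grads.size();  // averaging assumed
  for (std::size_t i = 0; i < p->data.size(); ++i) {
    p->data[i] -= scale * p->diff[i];
  }
}
```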
@jeffdonahue
Contributor

Have we thought about how to handle the case where we're sharing parameters but using different learning rates? I would be okay with simply disallowing that case, since it would probably be a pretty weird thing to do. Otherwise, the only other way I can think to handle it is pretty messy -- we could have a special case where, e.g., if blobs_lr is 2 in one layer but 1 in all the others, the Net could prescale that layer's top_diff by a factor of 2... Actually, even that wouldn't work if the layer has other shared param blobs that don't also have the same relative LR...
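For concreteness, a worked version of the arithmetic in the comment above (the notation is mine): suppose a blob $W$ is shared by layer $A$ with blobs_lr 2 and layer $B$ with blobs_lr 1, contributing gradients $g_A$ and $g_B$ to its diff. The intended semantics of the per-layer rate multipliers would be

$$\Delta W = -\eta\,(2\,g_A + g_B),$$

but with pure accumulation the shared diff only ever holds $g_A + g_B$, so a single multiplier $\lambda$ can only produce

$$\Delta W = -\eta\,\lambda\,(g_A + g_B).$$

The two agree only if $g_A$ is prescaled, e.g. by scaling layer $A$'s top diff by 2 as suggested above, and no single prescaling of $A$'s top diff exists once $A$ has several shared param blobs with different relative rates.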

philkr added a commit to philkr/caffe that referenced this pull request, Jan 25, 2015: Decouple the computational batch size and minibatch size by accumulating gradients
@shelhamer
Member

Always accumulating is simple and good, but let's review the weight-sharing and solver issues before merging.

@shelhamer
Member

Replaced by #1977.

@shelhamer closed this Feb 26, 2015

4 participants