Decouple the computational batch size and minibatch size by accumulating gradients #1977

Merged

merged 8 commits into BVLC:master from shelhamer:accum-grad May 30, 2015

Conversation

7 participants
Owner

shelhamer commented Feb 26, 2015

Accumulate gradients across batches through the iter_size solver field. With this setting, batch_size: 16 with iter_size: 1 and batch_size: 4 with iter_size: 4 are equivalent.
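The equivalence can be sketched numerically. This uses a toy quadratic loss (an assumption for illustration, not Caffe code) to show that the gradient of one batch of 16 matches four accumulated batches of 4 normalized by iter_size:

```python
import random

random.seed(0)
examples = [random.uniform(-1, 1) for _ in range(16)]

def batch_grad(batch):
    # per-example gradient of a toy loss L(w) = mean((w - x)^2) at w = 0
    return sum(2 * (0.0 - x) for x in batch) / len(batch)

# batch_size: 16, iter_size: 1
g_big = batch_grad(examples)

# batch_size: 4, iter_size: 4 -- accumulate, then normalize by iter_size
iter_size = 4
g_accum = sum(batch_grad(examples[i*4:(i+1)*4])
              for i in range(iter_size)) / iter_size

assert abs(g_big - g_accum) < 1e-12
```

The two only coincide because the accumulated gradient is normalized by iter_size, which is exactly what the solver changes in this PR provide.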

master edition of #1663.

  • deduplicate solver logic: done by #2518
  • normalize gradients by iter_size (supersedes adjusting local_rate and local_decay according to iter_size)
  • test equality of updates for batch size equivalents

Historical context:
From @longjon

This PRs the gradient accumulation branch living at https://github.com/shelhamer/caffe/tree/accum-grad. I took a lighter approach here than the one there: parameter gradients are always accumulated, there is no other option. The gradient checker is made correct by zero-initing parameter diffs.

Issues:

  • This changes the behavior of Backward. External code that used Backward is likely to break, if there is any.
  • I think this breaks solvers other than SGDSolver, but haven't thought carefully about that yet.

From @jeffdonahue

Have we thought about how to handle the case when we're sharing parameters but using different learning rates? I would be okay with simply disallowing that case since it would probably be a pretty weird thing to do. Otherwise the only other way I can think to handle it is pretty messy -- we could have a special case where, e.g., if blobs_lr is 2 in one layer but 1 in all others, the Net could prescale the top_diff for the layer with blobs_lr 2 by a factor of 2... Actually, even that wouldn't work if the layer has other shared param blobs that don't also have the same relative LR...

From @shelhamer

Always accumulating is simple and good, but let's review the weight sharing and solvers issues before merging.

@shelhamer shelhamer added a commit to shelhamer/caffe that referenced this pull request Feb 26, 2015

@shelhamer shelhamer Merge pull request #1977 from shelhamer/accum-grad
Decouple the computational batch size and minibatch size by accumulating gradients

* shelhamer/accum-grad:
  accumulate gradients in cudnn conv layer
  accumulate gradients in (de)conv layers
  accumulate gradients in inner product layer
  zero-init param diffs in gradient checker
  zero-init param diffs and accumulate gradients
4f56f71

@shelhamer shelhamer added a commit to shelhamer/caffe that referenced this pull request Feb 26, 2015

@shelhamer shelhamer Merge pull request #1977 from shelhamer/accum-grad
Decouple the computational batch size and minibatch size by accumulating gradients

* shelhamer/accum-grad:
  accumulate gradients in cudnn conv layer
  accumulate gradients in (de)conv layers
  accumulate gradients in inner product layer
  zero-init param diffs in gradient checker
  zero-init param diffs and accumulate gradients
dc3479e

@tnarihi tnarihi and 2 others commented on an outdated diff Feb 26, 2015

src/caffe/solver.cpp
@@ -477,7 +502,8 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
case Caffe::CPU:
for (int param_id = 0; param_id < net_params.size(); ++param_id) {
// Compute the value to history, and then copy them to the blob's diff.
- Dtype local_rate = rate * net_params_lr[param_id];
+ Dtype local_rate = rate * net_params_lr[param_id]
+ / this->param_.iter_size();
@tnarihi

tnarihi Feb 26, 2015

Contributor

I think this does not work correctly. Dividing by iter_size should be applied before accumulating the parameter decays.

@tnarihi

tnarihi Feb 26, 2015

Contributor

Multiplying local_decay by iter_size should be okay?

Dtype local_decay = weight_decay * net_params_weight_decay[param_id] * this->param_.iter_size();
@longjon

longjon Feb 26, 2015

Contributor

Ah... good point.

@shelhamer

shelhamer May 14, 2015

Owner

To clarify: the local_decay needs to be multiplied by the iter_size because the update will include the product of local_rate and local_decay. That is, the update by https://github.com/BVLC/caffe/blob/master/src/caffe/solver.cpp#L497-L499 is computed after weight decay is included on https://github.com/BVLC/caffe/blob/master/src/caffe/solver.cpp#L479-L483. As is, weight decay is defined per iteration so should not be scaled by the effective batch size of batch_size * iter_size.
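This interplay can be sketched numerically (a toy update rule with made-up numbers, not Caffe code): with update = local_rate * (accumulated_grad + local_decay * w), dividing local_rate by iter_size normalizes the accumulated gradient but also shrinks the decay term, so local_decay must be multiplied by iter_size to keep weight decay defined per iteration:

```python
rate, weight_decay, w = 0.1, 0.005, 2.0
iter_size = 4
per_batch_grad = 0.3                      # same gradient from each small batch
accumulated_grad = per_batch_grad * iter_size

# reference: no accumulation (iter_size = 1)
ref_update = rate * (per_batch_grad + weight_decay * w)

# accumulated version with both corrections applied
local_rate = rate / iter_size             # normalizes the accumulated gradient
local_decay = weight_decay * iter_size    # counters the division of the rate
acc_update = local_rate * (accumulated_grad + local_decay * w)

assert abs(ref_update - acc_update) < 1e-12

# without scaling local_decay, the decay term comes out iter_size times too small
wrong_update = local_rate * (accumulated_grad + weight_decay * w)
assert abs(ref_update - wrong_update) > 1e-6
```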

@tnarihi tnarihi commented on an outdated diff Feb 26, 2015

src/caffe/solver.cpp
@@ -513,7 +539,8 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
#ifndef CPU_ONLY
for (int param_id = 0; param_id < net_params.size(); ++param_id) {
// Compute the value to history, and then copy them to the blob's diff.
- Dtype local_rate = rate * net_params_lr[param_id];
+ Dtype local_rate = rate * net_params_lr[param_id]
+ / this->param_.iter_size();
@tnarihi

tnarihi Feb 26, 2015

Contributor

The same as above.

Contributor

tnarihi commented Feb 26, 2015

Commented on the diff.

By the way, I don't quite understand what @jeffdonahue mentions. Is there any relation between this PR and weight sharing? Gradient accumulation among shared parameters is computed independently (since diffs are not shared), and applying learning rates is also independent (then accumulated to the owners). If my understanding is correct, accumulating gradients does not matter for weight sharing. I may be missing something...
Actually, for my purposes, I very recently implemented gradient accumulation in a simple way, but it needs additional memory (it doesn't change the behavior of Backward): tnarihi/caffe@d016dbd
If my understanding is incorrect, my branch also doesn't work in the special case Jeff mentions.

Anyway, the idea of always accumulating is very good (less memory) if the issues are solved. Neither issue matters to me since I usually use SGDSolver and will notice the behavior change of Backward, but it might matter to others :P I will move to this when the PR is done.

Contributor

jeffdonahue commented Feb 26, 2015

Oops, I haven't actually stepped through this myself, but I think you're totally right @tnarihi -- there shouldn't be an issue with weight sharing in this implementation. I was confusing it with my version -- I had rebased my RNN PR (#1873) on this, and then just threw the additional changes to net.cpp I had made in my recurrent branch into a single non-descript commit (jeffdonahue/caffe@8983e6f). That commit should really probably have been 3 separate commits with good clarifying messages, as they were pretty substantial now that I look back on it... The first is that I indeed added ShareDiff to the parameter sharing behavior, for the exact reason you mention -- it's essential for memory efficiency in the RNN/LSTM layers added there (which are implemented as unrolled networks). The other is that I removed the Net's shared parameter gradient accumulation since they're sharing diffs, which gives a bit of a speedup but probably doesn't really matter much because it's just an axpy (but O(T) of them for unrolled recurrent nets...). Those changes do assume lr_mult (the new blobs_lr) is the same for all shared parameters (I should probably have added that that PR is not currently mergeable...)

Besides the other issues Takuya mentioned, I now think this is strictly good (i.e. it doesn't break anything that works now) and should be merged. Maybe I'll write a new PR based on this, or a commit to append to this one, that does ShareDiff (and skips the Net's accumulation) for any shared parameters where all lr_mults match. Not the prettiest solution, but I think the speed/memory savings are worth the additional complexity in this case.

Contributor

tnarihi commented Feb 26, 2015

I see. Thanks Jeff! Sharing diffs for weight sharing is nice for memory consumption. I think restricting all lr_mults to match should be okay.

The other thing, I think, is notifying developers (especially those working on PRs for layers that have parameter updates) that the Backward behavior has changed. Adding a test that checks all layers work with gradient accumulation would be better, but it seems difficult...

Contributor

jeffdonahue commented Feb 26, 2015

The other thing, I think, is notifying developers (especially those working on PRs for layers that have parameter updates) that the Backward behavior has changed. Adding a test that checks all layers work with gradient accumulation would be better, but it seems difficult...

Another good point. At some point I had modified the gradient checker to check accumulation (by adding some random noise to the param blob diffs, calling Backward, then subtracting the noise before checking the result) -- I can try to dig that up to add to this.
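The noise trick described above might be sketched like this (a toy stand-in for the gradient checker, with a hypothetical backward_accumulate; not Caffe's actual checker code): seed the parameter diff with known noise, run a backward pass that accumulates into it, then subtract the noise before comparing against the analytic gradient:

```python
import random

random.seed(1)

def backward_accumulate(diff, grad):
    # a layer whose Backward *adds* its gradient into the existing diff
    return [d + g for d, g in zip(diff, grad)]

analytic_grad = [0.5, -1.25, 3.0]
noise = [random.uniform(-1, 1) for _ in analytic_grad]

# run backward on noise-seeded diffs, then remove the noise
diff = backward_accumulate(list(noise), analytic_grad)
recovered = [d - n for d, n in zip(diff, noise)]

# a layer that overwrote instead of accumulating would fail this check
assert all(abs(r - g) < 1e-12 for r, g in zip(recovered, analytic_grad))
```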

Contributor

tnarihi commented Feb 26, 2015

That sounds like a nice idea. Abstracting the Backward testing could solve this issue; currently the gradient checker seems to play that role.

Contributor

longjon commented Feb 26, 2015

Actually I have thought of a simpler way to implement this that is independent of gradient accumulation. Maybe it is too tricky, maybe not. Will update.

(I am still mildly in favor of always accumulating gradients, disallowing different lr_mults, and simplifying weight sharing. Ideally one implements lr_mult by having a backward net that is slightly different from one's forward net.)

This was referenced Mar 4, 2015

@longjon longjon added a commit to longjon/caffe that referenced this pull request Mar 10, 2015

@longjon longjon Merge pull request #1977 from shelhamer/accum-grad
Decouple the computational batch size and minibatch size by accumulating gradients
be026fc

@longjon longjon added a commit to longjon/caffe that referenced this pull request Mar 10, 2015

@longjon longjon Merge pull request #1977 from shelhamer/accum-grad
Decouple the computational batch size and minibatch size by accumulating gradients
ae12045

@longjon longjon added a commit to longjon/caffe that referenced this pull request Mar 10, 2015

@longjon longjon Merge pull request #1977 from shelhamer/accum-grad
Decouple the computational batch size and minibatch size by accumulating gradients
10c133a

@weiliu89 weiliu89 added a commit to weiliu89/caffe that referenced this pull request Apr 1, 2015

@weiliu89 weiliu89 Merge pull request #1977 from shelhamer/accum-grad
Decouple the computational batch size and minibatch size by accumulating gradients
ad6fede

@weiliu89 weiliu89 added a commit to weiliu89/caffe that referenced this pull request Apr 14, 2015

@weiliu89 weiliu89 fix a bug in #1977 (accum-grad) suggested by narihi 0710648
Contributor

sguada commented Apr 19, 2015

@shelhamer @jeffdonahue @longjon what is happening with this PR? I think we need to find a solution and merge it as soon as possible. Actually, I thought it was already merged since the solution has been around for a while.

shelhamer referenced this pull request Apr 27, 2015

Closed

dynamic input #2355

@elleryrussell elleryrussell added a commit to elleryrussell/caffe that referenced this pull request May 1, 2015

@elleryrussell elleryrussell Merge pull request #1977 from shelhamer/accum-grad
Decouple the computational batch size and minibatch size by accumulating gradients
d4ad090
Owner

shelhamer commented May 14, 2015

Accumulating gradients involves subtleties in scaling gradients and hyperparameters with respect to the effective batch size (batch_size * iter_size) vs. the computational batch size batch_size.

For merge, this needs a test that compares the updates computed by batch_size: 16 and iter_size: 1 with batch_size: 4 and iter_size: 4 for instance. We have merely reasoned our way to correctness.
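Such a test might look like the following sketch (a toy SGD loop on shared data, not Caffe's solver test): solve with one big batch per iteration and with accumulated quarter batches, then check the parameters agree:

```python
import random

random.seed(2)
data = [[random.uniform(-1, 1) for _ in range(16)] for _ in range(3)]
lr = 0.1

def grad(w, batch):
    # toy loss L(w) = mean((w - x)^2)
    return sum(2 * (w - x) for x in batch) / len(batch)

def solve(batch_size, iter_size):
    w = 1.0
    for iteration_data in data:
        acc = 0.0
        for i in range(iter_size):
            acc += grad(w, iteration_data[i*batch_size:(i+1)*batch_size])
        w -= lr * acc / iter_size  # normalize the accumulated gradient
    return w

w_big = solve(16, 1)   # batch_size: 16, iter_size: 1
w_acc = solve(4, 4)    # batch_size: 4,  iter_size: 4

assert abs(w_big - w_acc) < 1e-12
```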

Contributor

longjon commented May 15, 2015

Right, this should be fine after @shelhamer's list. My idea for a simpler implementation did not pan out; it would only have worked for SGD with momentum.

longjon and others added some commits Aug 12, 2014

@longjon @shelhamer longjon zero-init param diffs and accumulate gradients
(With layers whose backward accumulates gradients), this effectively
decouples the computational batch from the SGD minibatch. Each
iteration accumulates gradients over iter_size batches, then parameters
are updated.
41cf06c
@sguada @shelhamer sguada accumulate gradients in inner product layer 3262e46
@longjon @shelhamer longjon accumulate gradients in (de)conv layers 8cc9af0
@longjon @shelhamer longjon zero-init param diffs in gradient checker 539f879
@longjon @shelhamer longjon accumulate gradients in cudnn conv layer 67b1ff3
@shelhamer shelhamer adjust local learning rate and decay according to gradient accumulation
Divide local rate by `iter_size` to normalize the gradient according to
the full minibatch size and not only the computational batch size.

Multiply the local decay by `iter_size` to counter the division of the
local learning rate since the decay is multiplied by the rate in the
update equation.
55585f5
Owner

shelhamer commented May 28, 2015

@jeffdonahue @longjon this should finally be ready. I had to set the gradient-based solver test net to constant data to avoid a tricky issue with random number draws, but I think this is fine -- the regression targets are still random so this does give a sequence of gradients to check.

@shelhamer shelhamer test equivalence of solving with accumulating gradients
Compare the parameters after solving with a given batch size against the
equivalent halved batch size + two-iteration accumulation of gradients.

Note: the test net dummy data layer now makes constant data and random
gaussian targets. This assures the standard and gradient accumulation
cases check the same data. Otherwise the difference in batch sizes
causes different orders of random number draws.
92ab737
Owner

shelhamer commented May 28, 2015

Accumulation is now checked for SGDSolver as well as NesterovSolver and AdaGradSolver, but AdaGrad fails at 1e-2 precision, although I can't find a mistake. I could reduce the precision for that test, but that might be irresponsible...

I suspect the issue is due to how the history is recorded. The gradient might not be normalized by iter_size when it's accumulated into the history.

@shelhamer shelhamer directly normalize accumulated gradients
`SGDSolver::Normalize()` normalizes accumulated gradients by scaling
inversely to the accumulation as `1 / iter_size`.

This fixes accumulation for AdaGrad and is more obvious than fooling
with rates and decays in 55585f5.
0e7a078
Owner

shelhamer commented May 28, 2015

0e7a078 makes the normalization for accumulation more obvious and fixes the issue with AdaGrad by normalizing the gradient before the update and history are computed.

However, when gradients are accumulated there's overhead for this separate scaling step. The time to update CaffeNet parameters for batch_size: 128 and iter_size: 2 rises 1.3x, from 2.3 ms to 3.1 ms. If we truly care about the 1 ms / iter, AdaGradSolver::ComputeUpdateValue() could be hacked instead.
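Why normalizing before the history update matters can be sketched with a toy AdaGrad step (made-up numbers, not Caffe code): for plain SGD, dividing the rate by iter_size is equivalent to dividing the gradient, but AdaGrad's history accumulates g^2, so an unnormalized gradient corrupts the history in a way a scaled rate cannot undo:

```python
def adagrad_step(w, g, hist, lr, eps=1e-8):
    hist += g * g                            # history accumulates g^2
    return w - lr * g / (hist ** 0.5 + eps), hist

g_sum, iter_size, lr = 0.8, 4, 0.1

# normalize the accumulated gradient *before* the history update
w_norm, _ = adagrad_step(1.0, g_sum / iter_size, 0.0, lr)

# only scale the rate: the history then sees the unnormalized gradient
w_rate, _ = adagrad_step(1.0, g_sum, 0.0, lr / iter_size)

assert abs(w_norm - w_rate) > 1e-6  # the two are NOT equivalent for AdaGrad
```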

Owner

shelhamer commented May 30, 2015

Merging for control of memory usage now that this is simple and tested.

@sguada sorry for the wait!

@shelhamer shelhamer added a commit that referenced this pull request May 30, 2015

@shelhamer shelhamer Merge pull request #1977 from shelhamer/accum-grad
Decouple the computational batch size and minibatch size by accumulating gradients
aeef453

@shelhamer shelhamer merged commit aeef453 into BVLC:master May 30, 2015

1 check passed

continuous-integration/travis-ci/pr The Travis CI build passed
Details

shelhamer deleted the shelhamer:accum-grad branch May 30, 2015

hli2020 commented Jun 1, 2015

I think the PReLU layer also needs to accumulate the gradients. @shelhamer

Contributor

tnarihi commented Jun 1, 2015

Here is my implementation of PReLU gradient accumulation: tnarihi/caffe@4d3fbd5
@shelhamer You can cherry-pick it (code may not be clean).

Owner

shelhamer commented Jun 1, 2015

@tnarihi sorry, I missed the PReLU accumulation when merging accumulating gradients. Could you send a PR for that particular patch? Thanks.

Thanks for the comment @hli2020

Owner

shelhamer commented Jun 1, 2015

Oh sorry, I missed the cherry pick + commit ID. That'll work fine.

@gcr gcr commented on the diff Jan 12, 2016

src/caffe/solver.cpp
@@ -469,6 +495,32 @@ void SGDSolver<Dtype>::ApplyUpdate() {
}
template <typename Dtype>
+void SGDSolver<Dtype>::Normalize(int param_id) {
+ if (this->param_.iter_size() == 1) { return; }
+ // Scale gradient to counterbalance accumulation.
+ const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
+ const Dtype accum_normalization = Dtype(1.) / this->param_.iter_size();
@gcr

gcr Jan 12, 2016

Is this normalization correct?

Doing this will reduce the gradient by a factor of iter_size compared to computing the gradient over an entire batch. If I'm interpreting this correctly, learning rates should be multiplied by iter_size to compensate for this existing code.

Or: is the learning rate automatically scaled by the batch size elsewhere, so that this code is necessary to account for the effective increase in the batch size?

@shelhamer

shelhamer Jan 12, 2016

Owner

It is done this way due to the separation of Net and Solver, but it is correct: the Net normalizes by the (computational) batch size, while only the Solver knows about iter_size, so it does the portion of the normalization needed to handle accumulation.
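This division of labor can be sketched with assumed numbers (not Caffe internals): the net's gradient is already an average over the computational batch, and Solver::Normalize supplies the remaining 1 / iter_size so that the total is an average over the effective batch:

```python
batch_size, iter_size = 4, 4
per_example_grads = [float(i) for i in range(batch_size * iter_size)]

# reference: gradient averaged over the full effective batch of 16
full_mean = sum(per_example_grads) / (batch_size * iter_size)

# the net averages each computational batch; the solver accumulates,
# then applies the remaining 1 / iter_size normalization
accum = 0.0
for i in range(iter_size):
    batch = per_example_grads[i*batch_size:(i+1)*batch_size]
    accum += sum(batch) / batch_size   # Net's per-batch normalization
accum /= iter_size                     # Solver's Normalize()

assert abs(full_mean - accum) < 1e-12
```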
