Yet another batch normalization PR #3229

Merged
merged 2 commits into from Oct 23, 2015

Contributor

cdoersch commented Oct 21, 2015

This PR squashes together #1965 and #3161 to make sure that proper credit is given. The final functionality is much more like #3161: we ultimately decided that the scale/shift could be implemented as a separate layer (and should hence get its own PR), and the data shuffling, if it gets merged, should also be done as a separate PR (I have not reviewed that code closely enough to say whether it is mergeable). This version includes the global stats computations, and fixes the issue where #3161 was using the biased variance estimate (it took a little while to convince myself that this is indeed the correct estimator to use).
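For reference, the estimator distinction in question, in a quick stdlib Python sketch (illustrative, not Caffe code): the batch is normalized with the biased (divide-by-m) variance, while the unbiased (divide-by-(m-1)) estimate is what gets accumulated for test time.

```python
import random
import statistics

random.seed(0)
xs = [random.gauss(0.0, 1.0) for _ in range(8)]
m = len(xs)
mean = sum(xs) / m

# Biased estimator: divide by m (used to normalize the batch itself).
var_biased = sum((x - mean) ** 2 for x in xs) / m
# Unbiased estimator: multiply by m/(m-1), i.e. divide by m - 1
# (what should be accumulated into the global statistics).
var_unbiased = var_biased * m / (m - 1)

# Sanity check against the stdlib's two variance functions.
assert abs(var_biased - statistics.pvariance(xs)) < 1e-9
assert abs(var_unbiased - statistics.variance(xs)) < 1e-9
```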

It would be great if @ducha-aiki and @jeffdonahue could take a look at this.

@ducha-aiki ducha-aiki commented on the diff Oct 21, 2015

...s/cifar10/cifar10_full_sigmoid_train_test_bn.prototxt
+ }
+ }
+}
+layer {
+ name: "pool1"
+ type: "Pooling"
+ bottom: "conv1"
+ top: "pool1"
+ pooling_param {
+ pool: MAX
+ kernel_size: 3
+ stride: 2
+ }
+}
+
+layer {
@ducha-aiki

ducha-aiki Oct 21, 2015

Contributor

This prototxt does not work, because the params are taken from #1965.
A corrected one is here: https://gist.github.com/ducha-aiki/6457bbd49fea8b7634a7

@cdoersch

cdoersch Oct 21, 2015

Contributor

Good catch; that totally slipped my mind.

@ducha-aiki ducha-aiki commented on the diff Oct 21, 2015

src/caffe/layers/batch_norm_layer.cpp
@@ -0,0 +1,230 @@
+#include <algorithm>
+#include <vector>
+
+#include "caffe/common_layers.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void BatchNormLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ BatchNormParameter param = this->layer_param_.batch_norm_param();
@ducha-aiki

ducha-aiki Oct 21, 2015

Contributor

I'd add
CHECK_NE(top[0], bottom[0]) << this->type() << " Layer does not "
"allow in-place computation.";

@cdoersch

cdoersch Oct 21, 2015

Contributor

I believe that this layer actually should work in-place; it turns out that if you cache the variance and the output, then you don't actually need to use the bottom data to compute the gradient. Note that bottom->data is never actually used in the gradient computation. Or am I missing something?

@ducha-aiki

ducha-aiki Oct 21, 2015

Contributor

I don't know why it doesn't work; I just quickly tested it on the same cifar_bn_sigmoid and it simply doesn't converge at all :)

@cdoersch

cdoersch Oct 21, 2015

Contributor

The tests pass, so it's unlikely to be mathematically wrong. Note that I stripped out the shift and scaling factors, which is probably the main difference. You can try adding those back in using DummyData/Conv/Eltwise operations. In fact, it may be good to have an example of this.

@cdoersch

cdoersch Oct 21, 2015

Contributor

Hang on--did you mean to say that in-place computation doesn't work? Did you test with and without it and get different results?

@cdoersch

cdoersch Oct 21, 2015

Contributor

I'm confused...neither of those logs seem to be using in-place computation.

@ducha-aiki

ducha-aiki Oct 21, 2015

Contributor

conv2-bn
https://gist.github.com/ducha-aiki/297ea977a0b72ad8f8f6
lines 131-135 and 507-511
layer {
name: "bn2"
type: "BatchNorm"
bottom: "conv2"
top: "conv2"
...

I haven't set the computation in-place for pool1, to keep the experiment clean, because I have previously experienced problems when dropout was done in-place on top of pooling.

@cdoersch

cdoersch Oct 21, 2015

Contributor

I see it now. Hmm--this is very strange. I don't really see how this is possible, since the batch norm layer doesn't access its bottom data during backward. Maybe it's triggering bugs in other layers? I'd like to at least try to debug this.

@cdoersch

cdoersch Oct 21, 2015

Contributor

Oh, I think I get it--obviously the bottom_diff and top_diff are shared too, and the backward doesn't take this into account. Will fix.
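The aliasing problem described here can be reproduced in miniature: with in-place computation, top and bottom wrap the same memory, so a backward pass that writes its output while still re-reading its input corrupts itself. A hypothetical stdlib Python sketch (not Caffe code; the toy "backward" just subtracts the mean of the incoming diff):

```python
# Broken backward: re-reads top while writing bottom. With aliased
# buffers (in-place layer: top name == bottom name) the writes corrupt
# later reads of top.
def backward_clobbering(top, bottom):
    for i in range(len(bottom)):
        m = sum(top) / len(top)    # top is already partially overwritten
        bottom[i] = top[i] - m

# Safe backward: snapshot what is needed before writing, analogous to
# the x_norm_ caching in the fixed BatchNorm backward.
def backward_safe(top, bottom):
    snap = list(top)               # defensive copy breaks the aliasing
    m = sum(snap) / len(snap)
    for i in range(len(bottom)):
        bottom[i] = snap[i] - m

a = [1.0, 2.0, 3.0]
backward_clobbering(a, a)          # in-place: top and bottom alias
b = [1.0, 2.0, 3.0]
c = [0.0, 0.0, 0.0]
backward_safe(b, c)                # out-of-place reference result
assert c == [-1.0, 0.0, 1.0]
assert a != c                      # the aliased version went wrong
```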

@cdoersch

cdoersch Oct 21, 2015

Contributor

Ok, I have pushed a fix. Can you re-try?

Contributor

ducha-aiki commented Oct 21, 2015

@cdoersch Looks good to me apart from the things I have commented on.

@jeffdonahue jeffdonahue commented on an outdated diff Oct 21, 2015

src/caffe/layers/batch_norm_layer.cpp
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void BatchNormLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ BatchNormParameter param = this->layer_param_.batch_norm_param();
+ moving_average_fraction_ = param.moving_average_fraction();
+ use_global_stats_ = this->phase_ == TEST;
+ if (param.has_use_global_stats())
+ use_global_stats_ = param.use_global_stats();
+ if (bottom[0]->num_axes() == 1)
+ channels_ = 1;
+ else
+ channels_ = bottom[0]->channels();
@jeffdonahue

jeffdonahue Oct 21, 2015

Contributor

bottom[0]->shape(1) for ND compatibility. (Ideally there'd be an axis parameter with default 1 to specify which axis the stats are computed along, but that can always come later, and a Reshape will work, just less conveniently.)

@jeffdonahue jeffdonahue commented on an outdated diff Oct 21, 2015

src/caffe/layers/batch_norm_layer.cpp
+ if (num_by_chans_.num_axes() == 0 ||
+ num_by_chans_.shape(0) != numbychans) {
+ sz[0] = numbychans;
+ num_by_chans_.Reshape(sz);
+ caffe_set(batch_sum_multiplier_.count(), Dtype(1),
+ batch_sum_multiplier_.mutable_cpu_data());
+ }
+}
+
+template <typename Dtype>
+void BatchNormLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ const Dtype* bottom_data = bottom[0]->cpu_data();
+ Dtype* top_data = top[0]->mutable_cpu_data();
+ int num = bottom[0]->shape(0);
+ int spatial_dim = bottom[0]->height() * bottom[0]->width();
@jeffdonahue

jeffdonahue Oct 21, 2015

Contributor

int spatial_dim = bottom[0]->count(2);

@jeffdonahue jeffdonahue and 1 other commented on an outdated diff Oct 21, 2015

src/caffe/layers/batch_norm_layer.cpp
+ // TODO(cdoersch): The caching is only needed because later in-place layers
+ // might clobber the data. Can we skip this if they won't?
+ caffe_copy(x_norm_.count(), top_data,
+ x_norm_.mutable_cpu_data());
+}
+
+template <typename Dtype>
+void BatchNormLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down,
+ const vector<Blob<Dtype>*>& bottom) {
+ CHECK(!use_global_stats_);
+ const Dtype* top_diff = top[0]->cpu_diff();
+ const Dtype* top_data = x_norm_.cpu_data();
+ Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
+ int num = bottom[0]->shape()[0];
+ int spatial_dim = bottom[0]->height() * bottom[0]->width();
@jeffdonahue

jeffdonahue Oct 21, 2015

Contributor

int spatial_dim = bottom[0]->count(2); as above -- in general, any channels(), height(), width() calls should be removed.

@cdoersch

cdoersch Oct 21, 2015

Contributor

Ah yes, I tried, but apparently I fail at grep :-)

Actually, I'd rather not assume that dimension 1 exists, in case people want to batch normalize a single number per example for whatever reason. Hence I've written it a slightly different way that should be equivalent.

Unfortunately I can't guarantee that case works, since I haven't tested it myself, but I don't want to rule it out.

@jeffdonahue jeffdonahue commented on the diff Oct 21, 2015

src/caffe/layers/batch_norm_layer.cpp
+ caffe_cpu_gemv<Dtype>(CblasTrans, num, channels_, 1.,
+ num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0.,
+ mean_.mutable_cpu_data());
+ caffe_cpu_gemv<Dtype>(CblasNoTrans, channels_ * num, spatial_dim,
+ 1. / (num * spatial_dim), temp_.cpu_data(),
+ spatial_sum_multiplier_.cpu_data(), 0.,
+ num_by_chans_.mutable_cpu_data());
+ caffe_cpu_gemv<Dtype>(CblasTrans, num, channels_, 1.,
+ num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0.,
+ variance_.mutable_cpu_data());
+ this->blobs_[2]->mutable_cpu_data()[0] *= moving_average_fraction_;
+ this->blobs_[2]->mutable_cpu_data()[0] += 1;
+ caffe_cpu_axpby(mean_.count(), Dtype(1), mean_.cpu_data(),
+ moving_average_fraction_, this->blobs_[0]->mutable_cpu_data());
+ Dtype m = Dtype(bottom[0]->count()/channels_);
+ caffe_cpu_axpby(variance_.count(), m/(m-1), variance_.cpu_data(),
@jeffdonahue

jeffdonahue Oct 21, 2015

Contributor

I wonder if the m in the unbiased variance correction scalar m/(m-1) should be the "batch size" rather than the total number of samples, since the "pixels" within a given batch item are not going to be IID?

edit: Looking at the batch norm paper, this seems to be what they do -- in the algorithm box at the bottom of pg 4, step 10 has the unbiased variance computation and uses m/(m-1) where m is the batch size.

@cdoersch

cdoersch Oct 21, 2015

Contributor

So the paper is actually rather unclear about this. Right above the algorithm box, it says "We have m values of this activation in the mini-batch", which doesn't make it sound like it's num*height*width.

I also thought we had some evidence that, at least as far as second-order statistics go, the features within a channel are quite uncorrelated with each other, which would suggest num/(num-1) is too severe a correction. But I'm not sure. Is there a reference implementation anywhere?

@jeffdonahue

jeffdonahue Oct 21, 2015

Contributor

Hm, reading further I think what you have here is indeed what the paper does:

In Alg. 1, we let B be the set of all values in a feature map across both the elements of a mini-batch and spatial locations -- so for a mini-batch of size m and feature maps of size p × q, we use the effective mini-batch of size m′ = |B| = m · p·q.

Intuitively it seems strange to me -- I would think that for many types of filters, neighboring activations within a feature map would be highly correlated. I don't recall seeing evidence to the contrary (or supporting me, this is just my intuition) but would be interested if you or anyone has a reference.

Regardless, I'll retract my suggested change as it seems like this is in fact what the batch norm paper does. Perhaps it could later be made configurable if the alternative is ever found to be useful.
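Concretely, the two candidate correction factors under discussion can be compared with assumed blob shapes (a Python sketch, not Caffe code; the shapes are made up for illustration):

```python
# Effective mini-batch size per the paper: m' = m * p * q, since the
# stats for a channel pool over batch items and spatial locations.
num, channels, height, width = 16, 32, 14, 14     # assumed shapes
count = num * channels * height * width           # bottom[0]->count()
m_eff = count // channels                         # samples per channel
correction = m_eff / (m_eff - 1)                  # biased -> unbiased

# The alternative first suggested in the review: m = batch size only.
correction_batch_only = num / (num - 1)

assert m_eff == num * height * width
assert correction < correction_batch_only         # much milder correction
```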

@ozabluda

ozabluda Oct 23, 2015

@jeffdonahue:

I would think that for many types of filters, neighboring activations within a feature map would be highly correlated. I don't recall seeing evidence to the contrary (or supporting me, this is just my intuition) but would be interested if you or anyone has a reference.

I was also looking for studies of how correlated (how dependent, actually) neighboring results of convolutions (and activations) are within a feature map. My intuition also tells me they should be highly dependent.

@cdoersch

cdoersch Oct 23, 2015

Contributor

I'll admit that I originally thought they would be highly correlated too. However, at one point I was trying to figure out whether I needed to do whitening on conv5 features extracted from patches before plugging them into some classifier, and so I actually measured it. I don't remember the actual numbers, but I remember being amazed at how small the correlations were. Even for neighboring units within the same channel, the correlations were something like R=.05 or .1. Somebody should repeat the experiment to make sure of this, though; this was a year and a half ago, so my memory is pretty faded.

@shelhamer

shelhamer Oct 23, 2015

Owner

When I looked through layers while learning and checking fully convolutional nets I remember seeing results like this too. At some point it'd be nice to go back and investigate the spatial correlations and effective receptive field sizes of different layers.

@bhokaal2k

bhokaal2k Oct 28, 2015

It is not surprising that neighboring units have very little or no correlation. Stacked CNN layers actually behave more like very strong matched filters, which have the property of strongly spatially-localized responses. This strong spatial localization is the basis of the ability to regress the object location with fully-connected layers using CNN feature maps as input. Also, as we move up the layer stack, a one-pixel difference starts corresponding to a multiple-pixel difference in the original image; therefore, small spatial correlation is naturally expected.
Note: this comment is based on observation of some feature maps, regression-based object localization performance, and some intuition developed along the way while working with deep CNNs.

@ozabluda

ozabluda Oct 28, 2015

Intuition kind of says that in the first convolutional layer, neighboring units are highly dependent for real (non-random) filters and real images (if for no other reason than that neighboring image pixels are highly dependent, as are filter weights). As you move up the layer stack, both correlation and dependence should and do decrease (corresponding to the "understanding is a compressed representation" paradigm), although units will still be dependent in homogeneous areas of the image corresponding to the receptive field.

This decrease in correlation/dependence may well be critical (who knows) for FC layers to work and, if measured precisely enough, may tell us when the right time is to transition to an FC layer. Although the key for the transition seems to be sparsity (causing independence as a side-effect), so independence without sparsity probably has comparatively little value anyway.

Contributor

jeffdonahue commented Oct 21, 2015

Thanks for putting this PR together @cdoersch and for the early work @ducha-aiki! Besides the above comment, this looks good to me. I haven't tried the examples, but if they really don't converge, I'm not sure yet how we should handle that... they probably shouldn't be merged as is; broken examples aren't great for most users... Perhaps they should temporarily use the 1x1 convolution hack until we have the dedicated layers merged as well? (ParameterLayer (#2079) and ScalarLayer (#3021) would handle the scaling parameter; a BiasLayer should be PRed at some point...)

Contributor

ducha-aiki commented Oct 21, 2015

@jeffdonahue they successfully converge and actually work (when using the fixed example). What is not working is in-place computation, which I propose for now just to flag with
CHECK_NE(top[0], bottom[0]) << this->type() << " Layer does not allow in-place computation.";

Owner

shelhamer commented Oct 21, 2015

@cdoersch @ducha-aiki my experiments with in-place computation of this layer do not converge either, although I have had convergence with this version of batch norm derived from #1965: https://github.com/HyeonwooNoh/caffe/blob/master/src/caffe/layers/bn_layer.cu. It could be worth a glance before merge.

Contributor

jeffdonahue commented Oct 21, 2015

Ah, I see, thanks for the clarification/reiteration @ducha-aiki. I agree with your suggestion in that case -- there should either be such a check (and ideally the then unnecessary variance caching would be removed), or we should try to fix in-place computation.

@jeffdonahue jeffdonahue and 1 other commented on an outdated diff Oct 21, 2015

...s/cifar10/cifar10_full_sigmoid_train_test_bn.prototxt
+ backend: LMDB
+ }
+}
+layer {
+ name: "conv1"
+ type: "Convolution"
+ bottom: "data"
+ top: "conv1"
+ param {
+ lr_mult: 1
+ }
+ param {
+ lr_mult: 2
+ }
+ convolution_param {
+ num_output: 32
@jeffdonahue

jeffdonahue Oct 21, 2015

Contributor

We should have bias_term: false in the Convolution layers before BatchNorms, since the effect is cancelled out by mean subtraction, no?

@cdoersch

cdoersch Oct 21, 2015

Contributor

good point--wasted computation.

@jeffdonahue

jeffdonahue Oct 21, 2015

Contributor

Thanks for updating -- the second param also needs to be removed though (there's a CHECK failure if not), and preferably the bias_filler should be removed too (though it has no effect).

@jeffdonahue jeffdonahue commented on an outdated diff Oct 21, 2015

...s/cifar10/cifar10_full_sigmoid_train_test_bn.prototxt
+ type: "Pooling"
+ bottom: "conv1"
+ top: "pool1"
+ pooling_param {
+ pool: MAX
+ kernel_size: 3
+ stride: 2
+ }
+}
+
+layer {
+ name: "bn1"
+ type: "BatchNorm"
+ bottom: "pool1"
+ top: "bn1"
+ batch_norm_param {
@jeffdonahue

jeffdonahue Oct 21, 2015

Contributor

The train/test versions could be collapsed into one with no use_global_stats setting, since you now set the use_global_stats setting this way by default, right? I'm okay with leaving in the two separate layer configs though since being explicit is generally good.

@jeffdonahue jeffdonahue commented on an outdated diff Oct 21, 2015

...s/cifar10/cifar10_full_sigmoid_train_test_bn.prototxt
+ top: "pool3"
+ pooling_param {
+ pool: AVE
+ kernel_size: 3
+ stride: 2
+ }
+}
+
+layer {
+ name: "ip1"
+ type: "InnerProduct"
+ bottom: "pool3"
+ top: "ip1"
+ param {
+ lr_mult: 1
+ decay_mult: 250
@jeffdonahue

jeffdonahue Oct 21, 2015

Contributor

This decay_mult setting and the bias lr_mult setting are pretty weird... @ducha-aiki was there a reason for these?

Contributor

ducha-aiki commented Oct 22, 2015

@cdoersch now it works :)
@jeffdonahue thanks for the catch, it was an artifact of search-and-replace. I have cleaned it up (and also set the bias lr to zero for the batch-normalized net).
Corrected definitions and trains logs are at https://gist.github.com/ducha-aiki/c0d1325f0cebe0b05c36

Contributor

cdoersch commented Oct 22, 2015

@jeffdonahue @ducha-aiki I've fixed the lr's and decays. Can you confirm that I've made all the required changes?

Contributor

ducha-aiki commented Oct 22, 2015

@cdoersch LGTM 👍

Contributor

jeffdonahue commented Oct 23, 2015

Thanks again @cdoersch and @ducha-aiki! LGTM as well.

@jeffdonahue jeffdonahue added a commit that referenced this pull request Oct 23, 2015

@jeffdonahue jeffdonahue Merge pull request #3229 from cdoersch/batchnorm2
Yet another batch normalization PR
39f69fb

@jeffdonahue jeffdonahue merged commit 39f69fb into BVLC:master Oct 23, 2015

1 check passed

continuous-integration/travis-ci/pr The Travis CI build passed
Member

ronghanghu commented Oct 23, 2015

Great work 👍

cysin commented Oct 23, 2015

Are there more examples and tutorials on how to use it?

Contributor

ducha-aiki commented Oct 23, 2015

@jeffdonahue #2996 could be used for Scale + bias (in non-channel shared mode).

beniz referenced this pull request in beniz/deepdetect Oct 23, 2015

Open

Support for batch normalization #7

2 of 4 tasks complete

beniz commented Oct 25, 2015

@cdoersch, @ducha-aiki thanks for the work throughout this PR, I can confirm training phase is working great.

However, I'm struggling with inference from a stored model: inference (caffe::TEST) appears to work fine during training, but not from a freshly loaded model.

From investigating the code, my understanding is that mean_ and variance_ are not stored as part of the model and thus cannot be loaded back. From the paper, my understanding is that at the moment this may prevent inference from a stored Caffe model trained with BN. Is this correct?

Contributor

cdoersch commented Oct 25, 2015

@beniz During the testing phase (or whenever use_global_stats_ is true), the mean and variance information should be read from the BatchNormLayer's parameter blobs. Whether you're doing this from a freshly loaded model or a model that you've been training shouldn't matter. If the code is really behaving as you say, then there may be a bug.

Note that switching to global statistics is not guaranteed to work. If the mean and variance statistics are not stationary (and they generally won't be unless the network has converged), then the estimates may be inaccurate. You should add some debugging statements to make sure the code is indeed executing along the paths you think it is, because I think there's a chance that you're not using global stats for your training-time test phases.

beniz commented Oct 25, 2015

All points well understood @cdoersch and thanks for the quick answer. I've re-checked the execution path through global_stats and that mean and variance are in the saved blobs, and I believe there's no bug.

So this leaves the stationary requirement (though my nets do converge and test results during training are just fine, most of the time).

Since we're at it, would it be possible to get more details on this comment from common_layers.hpp:

* By default, during training time, the network is computing global mean/
 * variance statistics via a running average, which is then used at test
 * time to allow deterministic outputs for each input. You can manually
 * toggle whether the network is accumulating or using the statistics via the
 * use_global_stats option. IMPORTANT: for this feature to work, you MUST
 * set the learning rate to zero for all three parameter blobs, i.e.,
 * param {lr_mult: 0} three times in the layer definition.

Thanks again.

beniz commented Oct 25, 2015

I've put the reworked protobuf files for GoogleNet with BN here (modified from https://github.com/lim0606/caffe-googlenet-bn):
https://github.com/beniz/deepdetect/tree/master/templates/caffe/googlenet_bn

Training is fine and fast. Note that the type of the test input layer needs to be modified in order to work with straight Caffe.

Contributor

cdoersch commented Oct 25, 2015

@beniz these statistics are collected through running averages; updating the associated parameters has nothing to do with optimizing the global objective function. Hence we don't want the solver trying to update these parameters, because it will only mess them up.
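For reference, the accumulation scheme behind this (mirroring the blobs_[2] and axpby lines quoted in the diff above) can be sketched in stdlib Python with made-up per-batch statistics:

```python
frac = 0.999                          # moving_average_fraction_
moving_mean = 0.0                     # plays the role of this->blobs_[0]
scale = 0.0                           # plays the role of this->blobs_[2][0]

batch_means = [0.9, 1.1, 1.0, 1.05]   # made-up batch statistics
for m in batch_means:
    scale = scale * frac + 1.0        # blobs_[2] *= frac; blobs_[2] += 1
    moving_mean = m + frac * moving_mean  # axpby(1, mean, frac, blobs_[0])

# At test time the accumulated sum is normalized by the scale factor.
# The solver must never touch these blobs (hence param { lr_mult: 0 }),
# or the running sums would be corrupted by gradient updates.
global_mean = moving_mean / scale
assert 0.9 < global_mean < 1.1        # a sensible average of the inputs
```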

Contributor

erogol commented Oct 26, 2015

@beniz I guess the BN version of GoogleNet does not have intermediate supervised layers. Even if it does, I guess the loss weights are wrong in your configuration file, since their sum is larger than 1.

beniz commented Oct 26, 2015

Even if it does, I guess loss weights are wrong in your configuration file since sum of them is larger than 1.

My understanding is that it doesn't matter whether they sum to 1 or not. Also, they are the same as in https://github.com/BVLC/caffe/blob/master/models/bvlc_googlenet/train_val.prototxt

The original GGNet paper reports that a single intermediate supervised layer is enough, with a 0.6% positive effect in the end. I don't remember what the BN paper says about them.

Contributor

erogol commented Oct 26, 2015

@beniz in the end, you are right, there is no practical effect.

cysin commented Oct 27, 2015

@beniz I tried your googlenet_bn but got the following errors:

I1027 11:19:10.922319  5376 layer_factory.hpp:76] Creating layer inception_3c/output
I1027 11:19:10.922327  5376 net.cpp:115] Creating Layer inception_3c/output
I1027 11:19:10.922333  5376 net.cpp:463] inception_3c/output <- inception_3c/3x3/bn
I1027 11:19:10.922339  5376 net.cpp:463] inception_3c/output <- inception_3c/double3x3b/bn
I1027 11:19:10.922345  5376 net.cpp:463] inception_3c/output <- inception_3c/pool/3x3_s2
I1027 11:19:10.922353  5376 net.cpp:420] inception_3c/output -> inception_3c/output
F1027 11:19:10.922371  5376 concat_layer.cpp:42] Check failed: top_shape[j] == bottom[i]->shape(j) (10 vs. 9) All inputs must have the same shape, except at concat_axis.
*** Check failure stack trace: ***
    @     0x7f6e47a1ee6d  (unknown)
    @     0x7f6e47a20ced  (unknown)
    @     0x7f6e47a1ea5c  (unknown)
    @     0x7f6e47a2163e  (unknown)
    @     0x7f6e483d082d  caffe::ConcatLayer<>::Reshape()
    @     0x7f6e4848d3f6  caffe::Net<>::Init()
    @     0x7f6e4848e4a5  caffe::Net<>::Net()
    @     0x7f6e484e36fa  caffe::Solver<>::InitTrainNet()
    @     0x7f6e484e4cbc  caffe::Solver<>::Init()
    @     0x7f6e484e5009  caffe::Solver<>::Solver()
    @     0x7f6e484ad883  caffe::Creator_SGDSolver<>()
    @           0x411d36  caffe::SolverRegistry<>::CreateSolver()
    @           0x40a1bb  train()
    @           0x407ef1  main
    @     0x7f6e3d630af5  __libc_start_main
    @           0x40868d  (unknown)
Aborted (core dumped)

Am I missing something?

beniz commented Oct 27, 2015

@cysin weird, it's been running on several datasets including ilsvrc on my side without issue. Though note that the inception_3c is an area where the network differs from the bvlc_googlenet.

You can find a picture of the full BN net here: https://raw.githubusercontent.com/lim0606/caffe-googlenet-bn/master/inception_bn.png

Are you running this network with a crop size different than 224 maybe ?

@cysin
The pooling layer's output size is calculated differently from convolution's.
For convolution, it is

const int output_dim = (input_dim + 2 * pad_data[i] - kernel_shape_data[i])
        / stride_data[i] + 1;

Since all variables are positive integers, this is equivalent to floor((input_dim + 2 * pad_data[i] - kernel_shape_data[i]) / stride_data[i]) + 1.

While for pooling, it is somewhat weird

pooled_height_ = static_cast<int>(ceil(static_cast<float>(
      height_ + 2 * pad_h_ - kernel_h_) / stride_h_)) + 1;
  pooled_width_ = static_cast<int>(ceil(static_cast<float>(
      width_ + 2 * pad_w_ - kernel_w_) / stride_w_)) + 1;
  if (pad_h_ || pad_w_) {
    // If we have padding, ensure that the last pooling starts strictly
    // inside the image (instead of at the padding); otherwise clip the last.
    if ((pooled_height_ - 1) * stride_h_ >= height_ + pad_h_) {
      --pooled_height_;
    }
    if ((pooled_width_ - 1) * stride_w_ >= width_ + pad_w_) {
      --pooled_width_;
    }
    CHECK_LT((pooled_height_ - 1) * stride_h_, height_ + pad_h_);
    CHECK_LT((pooled_width_ - 1) * stride_w_, width_ + pad_w_);
  }

They do not always match each other, so we must be careful with the pad parameter. The prototxt definition in https://github.com/lim0606/caffe-googlenet-bn is designed only for an input size of [224, 224]. For other sizes, we should manually add or remove the pad: 1 parameter from the pooling_param of inception_3c/pool and inception_4e/pool.
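The mismatch between the two formulas can be checked directly; a stdlib Python sketch of both (the example dimensions are illustrative):

```python
import math

def conv_out(input_dim, kernel, stride, pad):
    # Convolution: floor((in + 2*pad - kernel) / stride) + 1
    return (input_dim + 2 * pad - kernel) // stride + 1

def pool_out(input_dim, kernel, stride, pad):
    # Pooling: ceil((in + 2*pad - kernel) / stride) + 1, then clip so the
    # last window starts strictly inside the image when padding is used.
    out = math.ceil((input_dim + 2 * pad - kernel) / stride) + 1
    if pad and (out - 1) * stride >= input_dim + pad:
        out -= 1
    return out

# For a 3x3 / stride-2 layer without padding, the formulas agree on odd
# spatial sizes but differ by one on even ones -- exactly the kind of
# shape mismatch the Concat check above complains about.
assert conv_out(27, 3, 2, 0) == 13 and pool_out(27, 3, 2, 0) == 13
assert conv_out(28, 3, 2, 0) == 13 and pool_out(28, 3, 2, 0) == 14
```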

cysin commented Oct 28, 2015

@beniz It turned out to be the size issue. I changed the image size to 224x224 and now it seems all fine.

lukeyeager referenced this pull request in NVIDIA/caffe Oct 28, 2015

Merged

Cherry-pick batch normalization PR #51

cysin commented Oct 30, 2015

Can the bn layer be removed at test phase? If not, can this feature be added?

beniz commented Oct 30, 2015

@cysin not sure who you are asking, but yes, my understanding is that the BN layer is required during the test phase. From the code, the BN layer switches to global stats whenever the caffe::TEST phase is on, so you should not have to add anything to the .prototxt files regarding the BN layer, if that is what you're asking...

beniz commented Nov 9, 2015

This one PR seems to be relevant here: #3299

@Hrant-Khachatrian Hrant-Khachatrian commented on the diff Nov 19, 2015

...ples/cifar10/cifar10_full_sigmoid_train_test.prototxt
+ convolution_param {
+ num_output: 32
+ pad: 2
+ kernel_size: 5
+ stride: 1
+ weight_filler {
+ type: "gaussian"
+ std: 0.01
+ }
+ bias_filler {
+ type: "constant"
+ }
+ }
+}
+
+
@Hrant-Khachatrian

Hrant-Khachatrian Nov 19, 2015

Somewhat unrelated to batch normalization, but is it intentional to use conv -> pooling -> sigmoid in the first layer and conv -> sigmoid -> pooling in the second layer?

@ducha-aiki

ducha-aiki Nov 19, 2015

Contributor

No. The intention was to reduce memory usage via conv -> pooling -> sigmoid, but I missed it in the 2nd layer.

Do I need to define a batch norm layer separately for the train & test phases (i.e., like below; params removed to shorten)? Or, if I define it just once with no phase mentioned, will it automatically switch to global stats during the test phase?
layer {
  name: "bn1"
  type: "BatchNorm"
  bottom: "conv1"
  top: "bn1"
  batch_norm_param {
    use_global_stats: false
  }
  include {
    phase: TRAIN
  }
}
layer {
  name: "bn1"
  type: "BatchNorm"
  bottom: "conv1"
  top: "bn1"
  batch_norm_param {
    use_global_stats: true
  }
  include {
    phase: TEST
  }
}

@beniz @cysin : Would you guys be able to share your GN-BN caffe model?

beniz commented Nov 28, 2015

@siddharthm83 yes I believe it switches automatically. As for the model, I guess you are asking for a trained model.

While I do have one that I can share, I haven't done so yet because I cannot get it to work in pure prediction mode correctly outside the training/testing phase.

The culprit, as far as my short investigation went, is these lines: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/batch_norm_layer.cpp#L143-L149

Somehow, when predicting over a single image for instance, these two operations turn the final softmax into NaN. @cdoersch, to be very honest I'm not sure what you mean by 'replicate variance to input size'.

Contributor

cdoersch commented Nov 28, 2015

@beniz the input is NxCxHxW, but the estimate of the variance (variance_) is 1xCx1x1. Hence, variance needs to be replicated (i.e. tiled) so that it's the same size as the input.

The final line is the one that actually divides the input by the variance (in retrospect, this line should probably have a separate comment). If you have a single input, and you're using fully-connected layers, you're going to be measuring the variance of a single number, which is not defined (I'm pretty sure it will be calculated as 0, which will produce NaNs when you divide).

If you're seeing the NaN's when use_global_stats is true, however, then it's probably a different issue. Note that the original version of the code for accumulating global statistics was incorrect; it was fixed recently, but the fix is likely to cause problems with global statistics that were computed using the old code.
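The degenerate single-sample case described here is easy to see numerically; a stdlib Python sketch (the eps value is illustrative of a typical batch-norm epsilon, not taken from this PR):

```python
import math

xs = [2.5]                        # a single activation value for a channel
m = len(xs)
mean = sum(xs) / m
var = sum((x - mean) ** 2 for x in xs) / m   # exactly 0 for one sample

assert var == 0.0
# Normalizing without any epsilon divides 0 by 0. Python raises here;
# IEEE float arithmetic (as on the GPU) silently yields NaN instead.
try:
    x_hat = (xs[0] - mean) / var ** 0.5
except ZeroDivisionError:
    x_hat = float('nan')

# A small epsilon under the square root keeps the result finite:
eps = 1e-5                        # illustrative value
x_hat_eps = (xs[0] - mean) / (var + eps) ** 0.5
assert x_hat_eps == 0.0
```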

@beniz, yes, I was referring to the GN-BN model trained on Imagenet-1000. In the testing phase, shouldn't the global mean/variance (stored during training) be used? The mean/variance shouldn't be recalculated. I can check on my side too (I'm in the middle of training a smaller network; I'll check when training is finished).

beniz commented Nov 29, 2015

If you're seeing the NaN's when use_global_stats is true, however, then it's probably a different issue. Note that the original version of the code for accumulating global statistics was incorrect; it was fixed recently, but the fix is likely to cause problems with global statistics that were computed using the old code.

@cdoersch thanks, this was very certainly the culprit. Tried again with intermediary snapshot from a new run and it's working fine now.

@siddharthm83 the use_global_stats is set to true in TEST (aka prediction) mode, and thus the mean/variance are reused.

Good to hear @beniz. If you do plan to upload it to the model zoo, let me know (saves some time in training one myself :).

Does anyone have a train_val file for the Inception-BN network? Thank you so much!

Cui

@cuihenggang: @beniz has an example here: https://github.com/beniz/deepdetect/tree/master/templates/caffe/googlenet_bn
I am yet to try it out, but as per the above thread it should work. It's on my to-do list.

@siddharthm83 Great! Thank you @beniz

@beniz ,
I noticed that there is no scale/shift layer in your googlenet_bn implementation. This is distinct from the original paper. Is this a new research advancement?

beniz commented Dec 4, 2015

@happynear where at? I may have missed something.

How can I set a constant learning rate for the batch norm parameters (gamma, beta)? E.g., I want the learning rate for the conv/fully-connected layers to start at 0.01 and cool down by a factor of 10 every n epochs, but I want to keep the batch-norm parameters (gamma, beta as in the paper) constant at, say, 0.01 throughout learning. If this is possible, some info on how to change the prototxt would be much appreciated :) @cdoersch

Contributor

cdoersch commented Dec 5, 2015

@siddharthm83 I feel like I answered this somewhere before, but I can't find it now...

The BatchNorm layer doesn't implement the scale/shift, as it's pretty straightforward to do it with existing layers, and it's not always used. Use a DummyDataLayer followed by a conv layer to get a learned value that's distinct for each channel. Then you can use an eltwise product (for gamma) and an eltwise sum (for beta).
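In effect, that layer composition computes a per-channel affine on top of the normalized output. A rough numpy sketch of the resulting computation (shapes and variable names are illustrative, not Caffe API; the conv-on-DummyData trick just produces learned per-channel constants broadcast to the blob shape):

```python
import numpy as np

N, C, H, W = 2, 3, 4, 4
x_hat = np.random.randn(N, C, H, W)      # stand-in for the BatchNorm output

# Learned per-channel parameters (hypothetical values).
gamma = np.random.randn(C)
beta = np.random.randn(C)

# The DummyData + conv layers yield these constants at full blob shape;
# the Eltwise PROD and SUM then apply them element-wise.
gamma_blob = np.broadcast_to(gamma[None, :, None, None], x_hat.shape)
beta_blob = np.broadcast_to(beta[None, :, None, None], x_hat.shape)
y = x_hat * gamma_blob + beta_blob       # per-channel gamma*x_hat + beta

assert np.allclose(y[:, 0], x_hat[:, 0] * gamma[0] + beta[0])
```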

Thanks @cdoersch . I read the code as well, and I understand that the batch norm layer computes only (x-mean)/sqrt(var).
@beniz , what @happynear is saying is correct: your prototxt does not completely implement batch norm as per the paper (see the comment above); you would need to add additional layers. You will still get some benefit without learning the scale/shift, but performance will be better if you implement them.
From the paper (http://arxiv.org/abs/1502.03167):
Note that simply normalizing each input of a layer may change what the layer can represent. For instance, normalizing the inputs of a sigmoid would constrain them to the linear regime of the nonlinearity. To address this, we make sure that the transformation inserted in the network can represent the identity transform. To accomplish this, we introduce, for each activation x(k), a pair of parameters γ(k), β(k) , which scale and shift the normalized value

@cdoersch , if I understand correctly, it is a bit tedious to implement the scale/shift, since one would need to manually set the size of the DummyDataLayer to match the size of every batch norm layer. I will give it a shot anyway. Perhaps a PR like #2996 could be merged? @ducha-aiki .
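For concreteness, the full per-channel transform from the paper, and the identity-recovery property the quote above refers to, can be sketched in numpy (the eps value is illustrative):

```python
import numpy as np

def batch_norm(x, gamma, beta, eps=1e-5):
    # Normalize over the batch axis, then scale/shift: the paper's
    # y = gamma * x_hat + beta, computed per feature.
    mu = x.mean(axis=0)
    var = x.var(axis=0)              # biased variance, as used for normalization
    x_hat = (x - mu) / np.sqrt(var + eps)
    return gamma * x_hat + beta

x = np.random.randn(8, 4) * 3.0 + 1.0
# With gamma = sqrt(var + eps) and beta = mean, the transform can
# represent the identity, as the paper requires.
g = np.sqrt(x.var(axis=0) + 1e-5)
b = x.mean(axis=0)
assert np.allclose(batch_norm(x, g, b), x)
```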

@siddharthm83 ,
I think the EltwiseAffine layer should be merged and integrated into the BatchNorm layer, like the Softmax and SoftmaxWithLoss layers. The current implementation may already be misleading lots of people who haven't read the original paper.

Contributor

cdoersch commented Dec 7, 2015

@siddharthm83 @happynear I totally agree, the situation right now is pretty awkward. My impression, though, is that there's some high-level philosophical issues here: we don't want the layer catalogue to grow too large, and we don't want to overfit caffe to the way we do neural networks right now. I know that for a long time, the caffe developers considered treating parameters like we treat data: they're just another input to the network, just one that gets updated by SGD rather than being updated through disk reads. If you had such a parameter layer, plus an Eltwise layer that supports expansion of singleton dimensions, then the EltwiseAffine layer would be totally redundant. We probably wouldn't want to support it going forward, and caffe hates deprecating things.

Of course, if the caffe devs have no plans to move forward on that front, then I'm all for merging the EltwiseAffine layer. It's definitely better than the situation we have now.

It's unlikely that the EltwiseAffine layer would be included by default with BatchNorm layers, just for backwards compatibility reasons.

Owner

shelhamer commented Dec 7, 2015

@cdoersch a ChannelScalarLayer for the scale and a BiasLayer for the shift make this less awkward -- especially when defining everything by net spec -- so merging these layers could help and furthermore have other purposes.

beniz commented Dec 7, 2015

@siddharthm83 @happynear thanks for catching this, I re-read the original paper this weekend as well. The ReLU of GoogleNet may mitigate the effect for now, but I'll update the net once there's agreement on the best way to fix the BN. My preference goes to an integrated scale/shift, of course.

Contributor

ducha-aiki commented Dec 7, 2015

@beniz @shelhamer @happynear @cdoersch Let me, as the author of EltwiseAffine, say something against it in the special BatchNorm case :)
First, I implemented it to fully reproduce the results of the original paper. But now:

  1. I don't see why the following convolution layer could not learn the scale or bias itself.
  2. More importantly, I don't see any difference with EltwiseAffine and without it. I am now running a quick experiment on BN-Caffenet-128 (instead of 227 image size, for speed): original, this BN, this BN + EltwiseAffine, and this BN but after ReLU rather than before. As soon as it finishes, I'll put the results online, but so far the variant "this BN, but after ReLU" is leading by a great margin.

ChannelScalarLayer and BiasLayer make sense to me :)

@cdoersch @shelhamer , thanks for the response. A ChannelScalarLayer and a BiasLayer would be great.
For some of the nets that I tested, I don't see any performance difference between adding BatchNormLayer as-is (compared to the same architecture without batchnorm). In addition, there is a lot of noise/variation at higher learning rates. I am not sure if others observed similar behavior.

Contributor

cdoersch commented Dec 7, 2015

@ducha-aiki that's a very interesting result! As a side note, my ICCV paper (http://arxiv.org/abs/1505.05192) doesn't actually use the scale/shift layer, because my goal was to prevent the network activations from collapsing to zero. I haven't tried the network with the scale/shift layer.

In my opinion, this is one of the main advantages of batch normalization: it forces the network to try to learn something even when the problem is extremely hard, rather than giving up and ignoring the input. A few others at CMU working on unsupervised learning have been using it this way as well. Do the scale/shift layers break this property? Not clear, though it seems like they might. If anyone wants to play around with it, the source code for that paper is up now.

@ducha-aiki , do you have an example prototxt? When I use the EltwiseAffineLayer I don't see the net converging. This is how I am using it. Perhaps I am not using it correctly.
layer {
  name: "ea1"
  type: "EltwiseAffine"
  bottom: "bn1"
  top: "ea1"
  eltwise_affine_param {
    slope_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "gaussian"
      std: 0.01
    }
  }
}

Contributor

ducha-aiki commented Dec 7, 2015

@siddharthm83
Currently I don't have a ready build to check, but when I used it, the parameters were something like:

slope_filler {
  type: "constant"
  value: 1
}
bias_filler {
  type: "constant"
  value: 0.0001
}

Because you don't want to decrease the activations ~100x after batch normalization, which is what your fillers do.
Also, set the learning rate and weight decay the same as for the other layers.
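The point about the fillers can be seen numerically: a slope drawn from a std-0.01 Gaussian multiplies the (roughly unit-variance) normalized activations by about 0.01, while a constant-1 slope leaves them unchanged. A quick illustrative check with made-up shapes:

```python
import numpy as np

rng = np.random.default_rng(0)
x_hat = rng.standard_normal((64, 32))       # stand-in for unit-variance BN output

slope_bad = rng.normal(0.0, 0.01, size=32)  # gaussian std-0.01 filler
slope_good = np.ones(32)                    # constant-1 filler

# std-0.01 slopes shrink the activations ~100x; constant-1 slopes preserve them.
assert (x_hat * slope_bad).std() < 0.05
assert abs((x_hat * slope_good).std() - 1.0) < 0.1
```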

@ducha-aiki , That makes sense.
So I changed it as per the above, and I find that the EltwiseAffineLayer only works for me in CPU mode (the loss converges). In GPU mode, the loss is constant at 87.3365 and does not decrease. In fact, this is the same number I got with 2 different net architectures (one was caffenet with batchnorm and eltwiseaffine). I added a batchnorm layer and an eltwiseaffine layer after all ReLUs. Interestingly enough, if I have only 1 batchnorm followed by eltwiseaffine, it works. So perhaps the other layers are compensating for whatever is going wrong.

I will send some sample snippets and prototxt file in your PR page.

@ducha-aiki

Adding BN without scale/shift before ReLU really limits the model, since the output of BN is expected to have zero mean, making the bias term in the convolution layer meaningless.

However, when the BN layer is added after ReLU, there is no need to append a scale/shift layer after it, because the scale/shift is just a linear transformation that the next layer can absorb. In the Batch Normalization paper, they did not do any experiments to analyse where it is best to put the BN layer. They claimed that BN + ReLU produces a stable distribution; I can't understand this. I am looking forward to your results.
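The absorption argument can be checked directly: a per-channel scale/shift followed by a linear layer is equivalent to a linear layer with rescaled weights and an adjusted bias. A sketch with hypothetical shapes:

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.standard_normal((5, 8))   # e.g. BN output after ReLU, 8 channels
gamma = rng.standard_normal(8)    # per-channel scale
beta = rng.standard_normal(8)     # per-channel shift
W = rng.standard_normal((8, 3))   # following fully connected layer
b = rng.standard_normal(3)

# Explicit scale/shift, then the linear layer:
y1 = (x * gamma + beta) @ W + b
# The same computation folded into the linear layer's own parameters:
y2 = x @ (gamma[:, None] * W) + (beta @ W + b)
assert np.allclose(y1, y2)
```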

Contributor

ducha-aiki commented Dec 9, 2015

@happynear @siddharthm83 @cdoersch @beniz @shelhamer
Here results of tests.
https://github.com/ducha-aiki/batchnorm-benchmark
Not yet with EltwiseAffine, because each training takes 48 hours.

Hi, I assume the BatchNormalization layer is pretty much done (is it?). I wonder, has anyone tried training the ILSVRC12 task using the Inception-BN network? What validation accuracies have we gotten here?

Contributor

cdoersch commented Jan 12, 2016

Ok, I chatted with some other BVLC folks and it sounds like we're going to go ahead and merge some kind of per-channel scale/shift layers. What PRs do we currently have that do this?

I'm currently aware of #2996, as well as #2079/#3021. I also vaguely remember someone referencing a PR with separate scaling and shifting layers, but I can't find it now, so maybe I'm imagining things.

Let's first try to come to some consensus about which one should be merged, and then I'll review it. Right now I think #2996 looks pretty straightforward.

Contributor

ducha-aiki commented Jan 13, 2016

@cdoersch I am for separating #2996 into bias and scale, but I'm also OK with it as it is now (rebased to current master). So should I start doing this?
@shelhamer @ronghanghu

Contributor

cdoersch commented Jan 13, 2016

I'm not sure I see the point in separating it into two layers. I think it would be better if there were just options in the prototxt to turn off the bias or scale and save the computation if desired. I think the general use-case for this will involve both of them, and combining them should save memory.

I might pick a different name though--EltwiseAffine makes it sound like something different is happening for each element, when it's really a per-channel operation. Maybe PerChannelAffine?

Contributor

ducha-aiki commented Jan 13, 2016

@cdoersch agree, flags to turn it off/on are better than separation.
ChannelWiseAffine?

Contributor

cdoersch commented Jan 13, 2016

@ducha-aiki I guess to follow the 'Eltwise' pattern, it should be ChanwiseAffine. ChannelwiseAffine would probably be more clear, though.

Contributor

ducha-aiki commented Jan 13, 2016

@cdoersch OK, then I will clean it up, add flags and rebase.

Contributor

jeffdonahue commented Jan 13, 2016

@cdoersch ScalarLayer (#3021) and BiasLayer (in the newly-created #3550) are what I've been using to learn the BN params. I'd be interested to know how the performance and memory use compares with the combined approach in #2996 from @ducha-aiki.

Contributor

ducha-aiki commented Jan 13, 2016

@jeffdonahue I can test both variants.

P.S. The caffenet128 training with the BN-EA layer is almost finished, and it looks like EA helps, at least in the BN-before-nonlinearity setup. We'll see if it helps for BN-after-ReLU, which performs much better.

Contributor

jeffdonahue commented Jan 13, 2016

Thanks @ducha-aiki, that would be great. For reference, this Python function (using NetSpec) should do in-place batch norm with ScalarLayer + BiasLayer:

def batch_norm(x, with_params=True, lr_mult=1, in_place=True):
    param = [dict(lr_mult=0)] * 3
    out = L.BatchNorm(x, param=param, in_place=in_place)
    if with_params:
        param = [dict(lr_mult=lr_mult, decay_mult=0)]
        kwargs = dict(axis=1, num_axes=1, param=param, in_place=in_place)
        out = L.Scalar(out, **kwargs)
        out = L.Bias(out, **kwargs)
    return out
Contributor

ducha-aiki commented Jan 14, 2016

@jeffdonahue
The memory consumption is identical to within 1 MB for BN-caffenet. I just checked nvidia-smi while training the networks.

Speed:
CA:
I0114 18:22:07.118582 11780 caffe.cpp:358] ChannelwiseAffine1 forward: 0.3755 ms.
I0114 18:22:07.118588 11780 caffe.cpp:361] ChannelwiseAffine1 backward: 1.17972 ms.

S+B
I0114 18:23:03.240345 11875 caffe.cpp:358] Scalar1 forward: 0.352228 ms.
I0114 18:23:03.240351 11875 caffe.cpp:361] Scalar1 backward: 0.521535 ms.
I0114 18:23:03.240358 11875 caffe.cpp:358] Bias1 forward: 0.176595 ms.
I0114 18:23:03.240365 11875 caffe.cpp:361] Bias1 backward: 0.72729 ms.
Sum:
forward: 0.528823 ms
backward: 1.248825 ms

So my implementation is a bit faster.

Contributor

ducha-aiki commented Jan 16, 2016

BN+EA

Hi everyone, I am not clear on what the param "moving_average_fraction" means and how to determine its value. Could anyone give me a hint?

jeffdonahue referenced this pull request Jan 23, 2016

Merged

Scale and Bias Layers #3591

classner commented Feb 3, 2016

@nian-liu It is the weight with which the 'old' moving-average part is down-weighted every iteration; i.e., (1 - moving_average_fraction) gives a measure of the speed of decay of the mean (the higher it is, the faster the decay).
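Roughly, the update works like the following sketch (a simplified reading of the layer's accumulation scheme, with a made-up constant batch statistic; the second accumulator normalizes the estimate):

```python
import numpy as np

def update(running, scale, batch_stat, maf=0.999):
    # The old accumulator is down-weighted by moving_average_fraction each
    # iteration; the scale accumulator tracks the total weight so the
    # estimate can be normalized before use.
    running = maf * running + batch_stat
    scale = maf * scale + 1.0
    return running, scale

running, scale = 0.0, 0.0
for _ in range(10000):
    running, scale = update(running, scale, batch_stat=5.0)

# With a constant batch statistic, the normalized estimate converges to it.
assert abs(running / scale - 5.0) < 1e-6
```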

I just observed that this layer does not use the running mean and variance with exponential decay during training. This means that it becomes 'risky to use' (to say the least) with very small batch sizes, in the most extreme case with batch size one (which has the nice advantage for semantic segmentation that the network input can be dynamically resized). It can especially lead to large discrepancies between training and testing performance when the per-batch statistics do not approximate the global statistics well.

What was the reasoning behind the decision to do this? It requires the assumption that the mean and variance over a batch are a reasonable estimate for the mean and variance of the entire dataset to hold. This may or may not be the case, and increasingly not ;) for small batch sizes.
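The small-batch concern can be quantified: the batch mean is an estimator of the dataset mean whose spread shrinks like 1/sqrt(batch_size), so batch-size-1 statistics are very noisy. An illustrative numpy check on synthetic data:

```python
import numpy as np

rng = np.random.default_rng(42)
data = rng.standard_normal(100000)   # stand-in "dataset" with true mean ~0

def spread_of_batch_means(batch_size, n_batches=1000):
    # Sample many batches and measure how much their means scatter
    # around the global mean.
    batches = rng.choice(data, size=(n_batches, batch_size))
    return batches.mean(axis=1).std()

# Larger batches give per-batch statistics much closer to the global ones.
assert spread_of_batch_means(64) < spread_of_batch_means(1) / 4
```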

So do we have a new train_val protobuf for ImageNet that includes scale/shift operations after batch normalization?

Actually while looking at the BatchNormLayer implementation, I find some issues.

I find there is a default option of "use_global_stats_" that let the BatchNormLayer store and use the moving average of the mean and variance values. I think in the original paper, they normalize only inside the mini-batch, without considering the previous samples, which makes their normalized mini-batch white (mean=0 and variance=1). I think with the use of moving average of mean and variance, the output of our BatchNormLayer won't be white, because they are not the real mean and variance of this mini-batch. Will this cause any problems?

Actually, by default during training it does not use the moving average, during testing it does. That may lead to exactly the problem I described above...

ducha-aiki referenced this pull request in smichalowski/google_inception_v3_for_caffe Mar 16, 2016

Closed

Where is ScaleBias layers? #2

Actually, I have a question regarding the BatchNormLayer design. Are there any specific reasons why we chose to implement the scale and bias in a separate ScaleLayer, rather than implementing them inside the BatchNormLayer? Aren't we consuming more memory by adding an extra ScaleLayer after each BatchNormLayer?

d4nst commented Jun 15, 2016

I am currently testing some nets with and without batch normalization and I see that the memory consumption for nets with batch normalization is twice as much. For example, using this resnet I can train with a batch size of 24 images in my GPU. However, if I remove all the BatchNorm layers I can use up to 60 images or so. The problem seems to be in the BatchNorm backward pass.

What is the reason for this? It seems like a very high memory consumption.
