
Code for reproducing cifar-10 examples in "Deep Residual Learning for Image Recognition" #38

Merged (8 commits into Lasagne:master on Feb 1, 2016)

Conversation

@auduno (Contributor) commented Dec 26, 2015

Here's the code for reproducing the cifar-10 examples in "Deep Residual Learning for Image Recognition". The code is based on the MNIST example, feel free to reformat it. Note that it also depends on batch-normalization code in PR #467.

Training a 32-layer network (n=5) with learning parameters similar to the description in the paper, I got a validation error of 6.88%, which is actually slightly better than the error in the paper. I wanted to try a 56-layer network as well, but currently this fails with the error "RuntimeError: maximum recursion depth exceeded while calling a Python object".

There still seem to be some differences between this model and the one used in the paper, as my model seems to have slightly more parameters than what is described in the paper, e.g. the 1202-layer version has 19.6M parameters versus 19.4M in the paper. This model also seems to learn a bit more slowly and less stably than what I see in figure 6 in the paper, though the final accuracy is similar. It's actually not entirely clear to me what an iteration is in the paper, since they say "trained with a minibatch size of 128 on two GPUs", so I don't know if an iteration is equal to one or two minibatches of size 128. I've assumed an iteration is a single minibatch of 128, but two minibatches of 128 would make the learning speed more similar.

Let me know if I should upload the weights for the trained 32-layer model as well, and of course if you discover any discrepancies between this model and what is described in the paper.

@ebenolson (Member)

Thanks! I look forward to trying it out; I've been meaning to read that paper.

About how long does it take to train?

When you say it learns unstably, do you mean the loss is noisy, or that on some runs it diverges/fails to learn? If the latter, it might be good to seed the RNG with a known good value.

You might be able to fix the RuntimeError with something like

import sys
sys.setrecursionlimit(10000)

@bobchennan

I'm running it now. Let me try it.

@auduno (Contributor Author) commented Dec 26, 2015

Training a 32-layer network takes 8-9 hours on an EC2 g2.2 instance with CuDNN3.

By unstable, I meant that the loss is noisier than it looks in figure 6 in the paper. The training always seems to converge nicely, though I've only done a handful of runs so far.

I've started training a 56-layer network after adjusting the recursion limit as you mentioned, so I'm going to run it overnight and will report the result.

@ebenolson (Member)

Great, I'll wait to merge this then, in case you want to add that.

It would be nice to have pretrained weights available. If the file is small (<10M), you can add it directly to the repo. Otherwise I can send you access credentials for the Recipes S3 bucket.

@auduno (Contributor Author) commented Dec 26, 2015

Yes, merging can wait. Also, I don't know whether this merge should wait until PR #467 is merged?

The pretrained weights for the 32-layer network are only 1.9MB, so I can add them directly to the repo. Where should I put them? Any specific format you want them in?

@bobchennan

Result of running it on our computer:

Using gpu device 0: GeForce GTX 970 (CNMeM is disabled)
Loading data...
Building model and compiling functions...
number of parameters in model: 468666
Starting training...
......
Epoch 80 of 82 took 255.581s
  training loss:        0.183317
  validation loss:      0.334439
  validation accuracy:      92.22 %
Epoch 81 of 82 took 255.608s
  training loss:        0.182729
  validation loss:      0.335890
  validation accuracy:      92.18 %
Epoch 82 of 82 took 255.650s
  training loss:        0.181868
  validation loss:      0.334563
  validation accuracy:      92.11 %
Final results:
  test loss:            0.334563
  test accuracy:        92.11 %

real    351m31.140s
user    260m59.073s
sys 90m53.047s

@auduno (Contributor Author) commented Dec 27, 2015

@bobchennan were these results for a 56-layer network? I seem to have gotten similar results, with 92.17% test accuracy, i.e. a bit worse than in the paper.

@bobchennan

No, it is the result for 32 layers (without any parameters specified, so n=5):

92.34% on validation set
92.11% on test set

I slightly changed the structure and got a better result:

92.64% on validation set
92.38% on test set

I think you'd better specify the random seed. It helps to get a stable result.


@auduno (Contributor Author) commented Dec 27, 2015

Yes, the final accuracy seems to vary a bit, but I'm not sure if it's within what is expected; in the paper they report a standard error of 0.16 on the 110-layer model. It's probably a good idea to specify the random seed, so I'll look into that. I've also noticed I have a batch-normalization layer both before and after summing the shortcuts, which is probably unnecessary, so I'll remove it.
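For reference, here is a minimal sketch of the kind of seeding I have in mind (assuming the randomness comes from numpy for cropping/shuffling and from Lasagne's RNG for weight initialization; this is not code from this PR):

import numpy as np
import lasagne

# Seed numpy's global RNG, assumed to drive the random cropping/flipping
# and the shuffling of training batches.
np.random.seed(1234)

# Seed Lasagne's RNG, which is used by the weight initializers (HeNormal etc.).
lasagne.random.set_rng(np.random.RandomState(1234))

# Note: cuDNN's non-deterministic algorithms can still make runs differ,
# even with all seeds fixed.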

How did you change the structure to get better results?

@bobchennan

Used 10 filter maps before the GlobalPool (as mentioned in the NIN paper):

# average pooling
l = GlobalPoolLayer(ConvLayer(l,num_filters=10,filter_size=1,nonlinearity=rectify,W=lasagne.init.HeNormal(gain='relu')))
#cnx

Removed the shortcuts when the dimension increased:

if increase_dim:
    return NonlinearityLayer(stack_2, nonlinearity=rectify) #cnx

# identity shortcut, as option A in paper
# we use a pooling layer to get identity with strides, since identity layers with stride don't exist in Lasagne
identity = Pool2DLayer(l, pool_size=1, stride=(2,2), mode='average_exc_pad')
padding = PadLayer(identity, [out_num_filters/4,0,0], batch_ndim=1)
Member (inline review comment):

Should probably be out_num_filters // 4, so the result is always an int.

@auduno (Contributor Author) commented Jan 1, 2016

Ok, I think this is ready for merge if it looks good.

After removing the superfluous batch normalizations, the network seems to learn just as fast and stably as in the paper, and with similar final accuracy. For the 56-layer network I didn't manage to reach an error of 6.97%, only 7.23%, but there is some variance in the final accuracy (as reported in the paper for the 110-layer network), so I might reach it if I run it more times. The model still seems to have slightly more parameters than reported in the paper (19.5M for the 1202-layer network, versus 19.4M reported in the paper), but I can't quite figure out where the difference is.

I can upload the trained 32-layer and 56-layer models, but I'm not sure where I should put them. Also, I tried setting up seeding, but didn't manage to get consistent results even though I disabled cuDNN and set the seed for cropping and shuffling and in Lasagne, so I'm not sure what the issue is.

    padding = PadLayer(identity, [out_num_filters//4,0,0], batch_ndim=1)
    block = NonlinearityLayer(ElemwiseSumLayer([stack_2, padding]),nonlinearity=rectify)
else:
    block = NonlinearityLayer(ElemwiseSumLayer([stack_2, l]),nonlinearity=rectify)
(inline review comment):

Hello! I would be extremely interested to see whether performance improves if you remove the ReLU layer just after each block.

In an experiment on my own Torch implementation of residual networks, removing the ReLU layer after the building block noticeably improves initial convergence. I think this is because the ReLU layer at the end of each block mutates the input, making identity connections no longer possible.

I'm curious to see whether this has an effect on this project as well. Or perhaps there's some other bug in my own implementation.

Contributor Author (inline review comment):

Hi, I tried deleting the ReLU layer at the end of the block, and indeed it seems to converge faster in the beginning, though it settles into about the same convergence speed after a while. I didn't fully train the model, though, so I don't know if the final accuracy is better or the same as the one described in the paper.

(inline review comment):

Thank you for rerunning your experiments! Good to know it might have some effect.

@ebenolson (Member)

I can upload the trained 32-layer and 56-layer models, but I'm not sure where I should put them.

Can you make a new subdirectory under papers for this project, and move the files there?

@f0k (Member) commented Jan 3, 2016

The model still seems to have slightly more parameters than reported in the paper (19.5M for the 1202-layer network, versus 19.4M reported in the paper), but I can't quite figure out where the difference is.

It would be good to scrutinize this further, to ensure the model is identical to the one in the paper. I don't have the time to check the full paper right now; can you give us some pointers to the pages/sections where the model you're trying to replicate is specified, so we can compare it to your code? Thanks!

@auduno (Contributor Author) commented Jan 3, 2016

The examples I'm trying to reproduce are from section 4.2 of the paper. However, I wonder if the culprit for the increased number of parameters is the way the parameters are counted. The batch-normalization layers in Lasagne count four parameters (mean, std, beta and gamma) for each feature, i.e. for the 1202-layer model (n=200) we introduce 89600 parameters ((16+32+64) * 4 * 200) per BN in the residual blocks. I was wondering, though, whether it's common to count the mean and std parameters of the BN as parameters of the model, as they're not "learned", but rather "given" by the dataset. If they did not count these in the paper, then the parameter count would be the same, since we lose 89600 parameters for the 1202-layer model.
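Just to make the arithmetic explicit (using only the numbers quoted above), a quick check in Python:

# Three stages with 16/32/64 feature maps, 4 BN parameters per feature,
# and n = 200 blocks per stage for the 1202-layer model.
feature_maps = (16, 32, 64)
n = 200
print(sum(feature_maps) * 4 * n)  # -> 89600, the figure mentioned above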

@ebenolson : It turns out I'd copied the wrong model files from my EC2 instance, so I'll have to retrain the models again. I'll upload them once I've got them trained.

@f0k (Member) commented Jan 4, 2016

I was wondering though, if it's common to count the mean and std parameters of the BN as parameters of the model, as they're not "learned", but rather "given" by the dataset.

That's a good explanation for the difference. Note that you can do lasagne.layers.count_params(network, trainable=True) to only count trainable parameters. Nice that you've been able to figure it out!
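For example (a small sketch assuming network is the output layer of an already-built model; the printed numbers will of course depend on the model):

import lasagne

# Count every parameter, including the BN mean/inv_std statistics ...
n_all = lasagne.layers.count_params(network)
# ... versus only the trainable ones (weights, biases, BN beta/gamma).
n_trainable = lasagne.layers.count_params(network, trainable=True)
print("all parameters:       %d" % n_all)
print("trainable parameters: %d" % n_trainable)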

@erfannoury

Hi,
I tried running this code (with Conv2DDNNLayer replaced by the vanilla Conv2DLayer), but unfortunately I get lots of complaints from Theano. Lots of errors like this:

ERROR (theano.gof.opt): Optimization failure due to: LocalOptGroup(local_abstractconv_cudnn,local_conv_dnn,local_abstractconv_gemm,local_abstractconv_gradinputs_gemm,local_abstractconv_gradweight_gemm,local_conv_gemm)
(...)
ValueError: shape must be given if subsample != (1, 1) or border_mode == "half"

NotImplementedError: AbstractConv2d theano optimization failed. Did you exclude both "conv_dnn" and "conv_gemm" from the optimizer?
Apply node that caused the error: AbstractConv2d{border_mode='half', subsample=(1, 1), filter_flip=True, imshp=(None, 128, 112, 112), kshp=(128, 128, 3, 3)}(GpuElemwise{Composite{((i0 * i1) + (i2 * Abs(i1)))}}[(0, 1)].0, W)

What is the problem? It's weird, since I have been using Conv2DLayer without any issues.

@benanne (Member) commented Jan 8, 2016

I believe this may be a Theano issue: Theano/Theano#3845

@erfannoury

@benanne, I'm using Conv2DLayer instead of Conv2DDNNLayer, but the issue you referenced is about the half mode in Conv2DDNNLayer.

@f0k (Member) commented Jan 9, 2016

But the issue you referenced is about the half mode in Conv2DDNNLayer.

Not exactly. The issue is about the "half" mode in Theano's cuDNN convolution wrapper. This is not only used when you explicitly use Lasagne's Conv2DDNNLayer, but also when you use the Conv2DLayer and cuDNN is available -- Theano will automatically use cuDNN then. You can circumvent the problem by passing THEANO_FLAGS=optimizer_excluding=dnn, or by using Conv2DMMLayer, or by merging the referenced PR into your Theano installation.
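For concreteness, a minimal sketch of the first two workarounds (the idea of setting the flag from Python rather than on the command line, and the corrmm import path, are assumptions on my part):

# Workaround 1: exclude the cuDNN convolution optimizer. The flag has to be
# set before Theano is imported; this is equivalent to running the script
# with THEANO_FLAGS=optimizer_excluding=dnn.
import os
os.environ["THEANO_FLAGS"] = "optimizer_excluding=dnn"
import theano

# Workaround 2: use the gemm-based convolution explicitly, so the graph
# never goes through the cuDNN path at all.
from lasagne.layers.corrmm import Conv2DMMLayer as ConvLayer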

@erfannoury

Oh, I see. As you said, I first set the Theano flag to exclude dnn; however, it complained about cuDNN again. Afterwards I manually replaced the Conv2DLayer with the Conv2DMMLayer and it finally compiled and worked.
However, I fear that this code, despite having fewer parameters (~30 million vs. >130 million for VGG-16), consumes much more memory. I can't run the code due to memory errors. How can I reduce the memory footprint of this network so that I can run the code? Does using Conv2DMMLayer instead of the Conv2DLayer make that much of a difference in terms of memory usage?

@erfannoury

I now understand that the memory used isn't completely correlated with the number of parameters. After profiling the program I found that the ElemwiseSumLayer portions of the network consume most of the memory, since two (huge) matrices are added together. I think I have no other solution but to use a larger GPU (my current GPU has 6GB of memory). However, I'm interested in tips and tricks for decreasing the memory consumption of this particular network.

@benanne (Member) commented Jan 11, 2016

If nothing helps, you could split up the batches into two parts passed separately and add up the gradients...

Surely then it would be more sensible to just halve the batch size and do twice as many updates :)

@erfannoury

Currently I can't use cuDNN, that's why I have to use Conv2DMMLayer.
This is my current Theano configuration:

[global]
device=gpu0
floatX=float32
optimizer_including=conv_gemm
optimizer_excluding=dnn
profile=True
optimizer=fast_run
profile_memory=True
exception_verbosity=high
allow_gc=1
[nvcc]
fastmath = True
[lib]
cnmem=0.985

I haven't even been able to run the code with a minibatch size of 1! However, there might be errors in my code. Specifically, I intended to implement a model inspired by ResNet with residual blocks of 3 convolutional layers. I have recently found out about what seems to be a reference implementation of the 152-layer ResNet. I will try to implement a residual network based on this implementation by this weekend, and I may be able to run it on a bigger GPU. I will be using cuDNN as well. I have now managed to run a very small network similar to the current one, with an initial filter size of 16 and 4 blocks of increasing dimensionality (9 conv layers in total) and a very small number of parameters (~457k).

Another observation from the profiling output was that the GlobalPoolingLayer takes almost all of the processing time, since a Python implementation of this layer is being run.

Ops
---
<% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Op name>
  99.0%    99.0%     291.007s       1.21e+01s     Py      24        3   AveragePoolGrad{ds=(1, 1), ignore_border=True, st=(2, 2), padding=(0, 0), mode='average_exc_pad'}

Is there no C implementation for this layer, or does it fall back to the Python implementation for some reason?

@f0k (Member) commented Jan 12, 2016

Another observation from the profiling output was that the GlobalPoolingLayer takes almost all of the processing time, since a Python implementation of this layer is being run.

Ops
---
<% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Op name>
  99.0%    99.0%     291.007s       1.21e+01s     Py      24        3   AveragePoolGrad{ds=(1, 1), ignore_border=True, st=(2, 2), padding=(0, 0), mode='average_exc_pad'}

This is not the GlobalPoolingLayer, but an average pooling layer of size (1, 1) and stride (2, 2). This is available in cuDNN, and it's possible that Theano doesn't have an alternative GPU implementation of it -- until recently, it only supported max-pooling, and average pooling was just added for compatibility with the corresponding cuDNN function. If you set the mode to 'max', it will be a lot faster. Using a pooling layer for this operation is overkill, though. The fastest option will probably be an ExpressionLayer(layer, lambda X: X[:, :, ::2, ::2], lambda s: (s[0], s[1], s[2]//2, s[3]//2)).
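Concretely, replacing the Pool2DLayer shortcut from the snippet earlier in this thread could look roughly like this (a sketch only; l and out_num_filters are the names used in that snippet):

from lasagne.layers import ExpressionLayer, PadLayer

# strided identity shortcut: take every second row/column instead of
# average-pooling with a (1, 1) window
identity = ExpressionLayer(l, lambda X: X[:, :, ::2, ::2],
                           lambda s: (s[0], s[1], s[2] // 2, s[3] // 2))
padding = PadLayer(identity, [out_num_filters // 4, 0, 0], batch_ndim=1)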

Surely then it would be more sensible to just halve the batch size and do twice as many updates :)

This can lead to worse results, though, and if the idea is to perfectly follow the paper, the batch size should probably stay the same. But yeah, it would definitely be easier to implement and could work as well.

I haven't even been able to run the code with minibatch size of 1!

Well, in that case there's not a lot you can do, I fear... I don't see anything wrong with your Theano profile either (except that I wouldn't leave profiling on by default).

@erfannoury

This is not the GlobalPoolingLayer, but an average pooling layer of size (1, 1) and stride (2, 2).

Yes, you're right. It's the Pool2DLayer used as the identity layer here. I hadn't thought this line through before you mentioned it, so now I wonder why it was used in the first place. @auduno, can you elaborate on this? I think what @f0k suggested would be a better replacement for the current implementation.

Well, in that case there's not a lot you can do, I fear... I don't see anything wrong with your Theano profile either (except that I wouldn't leave profiling on by default).

Yeah, I think I'll move to a bigger GPU and make do with shallower networks. BTW, profiling isn't usually on; that was my "debugging" config file 😄.

Update: I replaced that line with your code, and it significantly reduced the minibatch time. Thank you @f0k!

@gcr commented Jan 12, 2016

This can lead to worse results, though, and if the idea is to perfectly follow the paper, the batch size should probably stay the same. But yeah, it would definitely be easier to implement and could work as well.

To avoid this, you can just accumulate several gradients for several batches and then only update the weights when you're ready. This is similar to Caffe's iter_size parameter.

In theory, if you can fit a single image on your GPU, you'll be able to train by setting an appropriate iteration size.

Here's a code sketch:

    local loss_val = 0
    local N = 2 -- How many "sub-batches" to accumulate per batch
    local inputs, labels
    gradients:zero() -- Only done here, at the beginning
    for i=1,N do
        inputs, labels = dataTrain:getBatch()
        local y = model:forward(inputs)
        loss_val = loss_val + loss:forward(y, labels)
        local df_dw = loss:backward(y, labels) -- Accumulates gradients
        model:backward(inputs, df_dw)
    end
    loss_val = loss_val / N
    gradients:mul( 1.0 / N )
    optim.sgd(...) -- NOW update weights!

Important caveat: I'm not sure how this interacts with batch normalization layers. You may have to adjust your batch norm momentum to get equivalent results.

@auduno (Contributor Author) commented Jan 13, 2016

@erfannoury: I replaced the poolinglayer with the expressionlayer as suggested by f0k now. The poolinglayer was really just the first thing I could think of to implement an identity layer with strides; I wasn't aware of expressionlayer. :)

@f0k: I noticed that there have been some changes in the final implementation of batch normalization. I trained the models using this batch-normalization code; is there any easy way to convert the models to work with the new batch-normalization?

@f0k (Member) commented Jan 13, 2016

To avoid this, you can just accumulate several gradients for several batches and then only update the weights when you're ready.

Yep, that's what I suggested at first: #38 (comment). The implementation will look a bit different in Theano since it's non-imperative. If you're lucky, it's enough to just specify your loss as a sum of the losses of two half batches so that Theano propagates the two halves separately.
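A sketch of that idea (with hypothetical names input_var, target_var and network, and a batch size of 128 split into two halves; whether this actually reduces peak memory depends on how Theano schedules the graph):

import theano
import lasagne

# Compute the network output separately for the two halves of the batch and
# define the loss as the average of the two half-batch losses.
out_a = lasagne.layers.get_output(network, input_var[:64])
out_b = lasagne.layers.get_output(network, input_var[64:])
loss = 0.5 * (lasagne.objectives.categorical_crossentropy(out_a, target_var[:64]).mean() +
              lasagne.objectives.categorical_crossentropy(out_b, target_var[64:]).mean())

params = lasagne.layers.get_all_params(network, trainable=True)
updates = lasagne.updates.momentum(loss, params, learning_rate=0.1, momentum=0.9)
train_fn = theano.function([input_var, target_var], loss, updates=updates)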

I replaced the poolinglayer with the expressionlayer as suggested by f0k now. The poolinglayer was really just the first thing I could think of to implement an identity layer with strides

Note that this is only relevant for what they term option (A), while they use option (B) for most experiments (p.6 right column and following). But if it's about reproducing all of the paper, it's good to have everything as efficient as possible!

is there any easy way to convert the models to work with the new batch-normalization?

Yes, I've sketched it here: Lasagne/Lasagne#467 (comment)

@erfannoury

Can backward and forward be performed on different GPUs? I think the memory requirement for the forward pass is less than that for the backward pass.

@f0k (Member) commented Jan 13, 2016

Can backward and forward be performed on different GPUs?

The backward pass usually relies on information from the forward pass (e.g., computing the gradients wrt. the weights of a dense or convolutional layer requires its original inputs from the forward pass). You could transfer that information over, but then you won't have any memory benefits over computing everything on a single GPU. You could have one GPU access the memory of the other, but that will be slow. In either case, Theano currently doesn't support using multiple GPUs from the same process.

@okvol commented Jan 20, 2016

@auduno
Could you post a plot of the training and validation errors over all epochs? Just for reference.
Thanks!

@auduno (Contributor Author) commented Jan 20, 2016

Hi @okvol, I don't have the training and validation errors per epoch anymore, but I can do a run overnight and give you a plot tomorrow.

@okvol commented Jan 21, 2016

@auduno
Thanks a lot!

@auduno (Contributor Author) commented Jan 21, 2016

Here are the training and validation errors for a 56-layer network; training error is the dashed line and validation error is the solid line:

[plot: training (dashed) and validation (solid) error curves for the 56-layer network]

The validation error looks quite a lot noisier than what can be seen in the paper (note that before commit 550c781 I incorrectly calculated the validation error using the mini-batch mean and std in batch normalization, so it looked much more similar to the paper). I'm curious whether this means there is something wrong in my code, or whether this is because the batch-normalization layers use a moving average mean/std instead of a properly calculated mean/std over the entire dataset. I'd calculate the validation/training error using a properly calculated mean/std for the BN as well, but I couldn't find any recipes for doing this quickly, and I don't have time to put together code for this myself right now. Does anyone know if this exists anywhere?

@auduno (Contributor Author) commented Jan 31, 2016

Is there anything I need to do to get this PR merged?

@benanne (Member) commented Jan 31, 2016

Good question! @ebenolson, any work left to do?

This is worth a read by the way, some interesting new results!
https://github.com/gcr/torch-residual-networks

@ebenolson (Member)

Looks good to me. I will merge this evening if there are no more comments.

@ebenolson (Member)

Merging, thanks again!

ebenolson added a commit that referenced this pull request on Feb 1, 2016:
Code for reproducing cifar-10 examples in "Deep Residual Learning for Image Recognition"
@ebenolson merged commit e4d9a4b into Lasagne:master on Feb 1, 2016
@auduno (Contributor Author) commented Feb 1, 2016

Awesome!

Yes, I had a look at the results from @gcr, very interesting! I also noticed his test error curve looks pretty similar to what I'm seeing (i.e. a bit dissimilar from the figure in the paper), so at least it seems there's nothing wrong with my code. I'm guessing there is some undocumented detail of how the test error is evaluated which might account for the difference. I'll update the code if anyone figures out the cause!

@gcr commented Feb 1, 2016

If you're talking about instability in the test error (i.e. if it looks too jittery compared to the version in the paper), I have some notes about that.

In my case, I was doing a few things wrong that caused artificial noise to seep into the testing error curve at first (the graphs in the current readme should have the following two problems corrected):

  • I accidentally sampled only the first 2k test images instead of 10k (typo)
  • I was sampling with replacement during evaluation, which also leads to instability

Kaiming He from MSRA sent me the following email regarding instability, which I'm pasting with permission (thanks, Kaiming!)

I noticed that you observed the results are unstable. This may be caused by your way of doing BN (I am sorry that I have not checked the code as I am not familiar with Torch). In particular, the way of doing BN at test-time seems important. In our implementation, the BN statistics (mean/var) used for testing are computed on (virtually) all training images, and in practice, a very large batch of training image will provide statistically the same BN statistics. If you compute BN mean/var for testing on a mini-batch that is too small, the results might be unstable. In my experiments, the variations of stopping at different final epochs of a single run is within 0.1-0.2%; but "more than half of a percent" sounds too much for me. Besides, the training-time implementation of BN also seems to matter.

Torch uses an exponentially reweighted batch normalization that uses a running average, so setting a lower momentum should have a similar effect to computing over a larger fraction of the training set. Kaiming has some comments about this (emphasis mine):

We did not use running average, and I do not recommend to do so. To obtain a reasonable test/val result, we always compute the BN stats (mean/var) on a very large training batch. Every point in the test/val curves in the paper is plotted in this way. I do not know how much the running average may impact; but if we use a very small batch for computing BN stats, the results are very unstable and are slightly worse. Nevertheless, because I did not use running average, so I am not 100% sure this is the reason.

I'm not sure how Lasagne does it, but it could be relevant if you want to get an exact reproduction (I did not bother to use this strategy)

@auduno (Contributor Author) commented Feb 1, 2016

That's very interesting, @gcr, thanks for sharing! I have been suspecting that the way the BN statistics are calculated might have something to do with the jittery test error. I also had the same issue with relatively large error variation between runs (around 0.5%, as you report); the BN statistics might explain this as well. I'll try calculating the BN mean/var over the entire set and see if that results in more stable results!

@f0k (Member) commented Feb 2, 2016

I'm not sure how Lasagne does it

We're also doing an exponential moving average, using the same default momentum as Torch (and some other libraries).

I'll try calculating the BN mean/var over the entire set and see if that results in more stable results!

You can use the trick of setting the momentum term to 1/(1+n), with n being the number of batches seen; this gives you the mean over all batches (using Welford's algorithm). But note that you'll need to do this layer by layer, since updating the mean/var of a BN layer will affect what's propagated to higher layers (@gcr, that would be another variant to try!). lasagne.layers.get_all_layers(output_layer) gives you the layers in the correct order for this, topologically sorted.
Using a single very large batch (with momentum 1) would avoid this route of doing multiple passes.
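As a sanity check that the 1/(1+n) schedule really gives the plain mean (a toy example, not Lasagne code):

# new = (1 - alpha) * old + alpha * batch_mean with alpha = 1/(1+n), where n is
# the number of batches already seen, reduces to the average of all batch means.
batch_means = [0.2, 0.5, 0.8, 0.1]   # made-up per-batch statistics
running = 0.0
for n, m in enumerate(batch_means):
    alpha = 1.0 / (1 + n)
    running = (1 - alpha) * running + alpha * m
assert abs(running - sum(batch_means) / len(batch_means)) < 1e-12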

@gcr commented Feb 2, 2016

Aha! Thanks for the hint, @f0k. I had tried setting the BN momentum to 1/(1+n), but it really didn't work because I applied it to all layers at the same time, which seems incorrect now. (I also tried a variant where I reset n to 0 at the beginning of each epoch, which still didn't help.)

If I understand correctly, if your model has k batch normalization layers, this will make training k times slower, because you must send a large set of training examples through each BN layer before moving on to the next one?

@f0k (Member) commented Feb 2, 2016

if I understand correctly, if your model has k batch normalization layers, this will make training k times slower because you must send a large set of training examples through each BN layer before moving on to the next one?

The idea in the original BN paper was to do this for the final model only, so there wouldn't be any impact on training, just a one-time cost afterwards. The BN authors advocated computing the exponential moving average during training to have something for validation (so you can do early stopping somewhat reliably). Since we're now discussing making the validation error more robust, using a large training batch once will be a better idea than passing all the training data k times each time you want to compute a validation error. This could even make things faster, since you can disable updating the moving averages during training then.

@auduno (Contributor Author) commented Feb 2, 2016

@f0k : So if I understand you correctly, when evaluating the validation error every epoch, it is sufficient to do a training pass with large batch size, updating only the batch norm parameters (with momentum 1), then evaluating the validation error as usual? Or in our case something like this:

# set learning rate to 0 to not update parameters in training pass
old_lr = sh_lr.get_value()
sh_lr.set_value(lasagne.utils.floatX(0.))
# set momentum to 1 in all BN layers
for l in lasagne.layers.get_all_layers(network):
    if l.__class__.__name__ == "BatchNormLayer":
        l.alpha = 1.
# do training pass over a large batch of 5000 samples or so
indices = np.arange(100000)
np.random.shuffle(indices)
train_fn(X_train[indices[0:5000],:,:,:], Y_train[indices[0:5000]])
# revert learning rate and BN momentum
sh_lr.set_value(lasagne.utils.floatX(old_lr))
for l in lasagne.layers.get_all_layers(network):
    if l.__class__.__name__ == "BatchNormLayer":
        l.alpha = 1e-4

@f0k (Member) commented Feb 3, 2016

So if I understand you correctly, when evaluating the validation error every epoch, it is sufficient to do a training pass with large batch size, updating only the batch norm parameters (with momentum 1), then evaluating the validation error as usual?

That's how I interpret @gcr's quotation of Kaiming's email, yes, and it sounds plausible to me!

Or in our case something like this:

No, this won't work. train_fn was compiled before you changed the alpha values, so it won't be affected (the alphas are not shared variables, but Python floats, so they have been compiled into the graph as constants). Instead of this trickery, you will need to compile two different functions (see the sketch after this list):

  • train_fn() as usual, except that you pass batch_norm_update_averages=False in the get_output() call
  • update_bn_fn(), which just does a forward pass with deterministic=True, batch_norm_use_averages=False, batch_norm_update_averages=True (i.e., no computation of the loss, no updates dictionary).
    You can set all the alpha=1 then, since they're not going to be used in train_fn() anyway.
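A minimal sketch of what those two functions could look like (assuming the usual input_var, target_var, network and sh_lr variables from the training script; this is just my reading of the get_output() flags, not code from this PR):

import theano
import lasagne

# 1) Training function: never let training batches touch the stored averages.
train_out = lasagne.layers.get_output(network, batch_norm_update_averages=False)
loss = lasagne.objectives.categorical_crossentropy(train_out, target_var).mean()
params = lasagne.layers.get_all_params(network, trainable=True)
updates = lasagne.updates.momentum(loss, params, learning_rate=sh_lr, momentum=0.9)
train_fn = theano.function([input_var, target_var], loss, updates=updates)

# 2) BN update function: a deterministic forward pass that normalizes with the
#    batch statistics and writes them into the stored averages (with alpha=1,
#    the stored values become exactly the statistics of this one large batch).
bn_out = lasagne.layers.get_output(network, deterministic=True,
                                   batch_norm_use_averages=False,
                                   batch_norm_update_averages=True)
update_bn_fn = theano.function([input_var], bn_out)

# Per epoch: call update_bn_fn(large_training_batch), then run validation as usual.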

@f0k (Member) commented Mar 17, 2016

There's a follow-up paper now: http://arxiv.org/abs/1603.05027
"In this paper, we analyze the propagation formulations behind the residual building blocks, which suggest that the forward and backward signals can be directly propagated from one block to any other block, when using identity mappings as the skip connections and after-addition activation."
They explore a lot of additional variants (including those tried by @gcr in https://github.com/gcr/torch-residual-networks#cifar-effect-of-model-architecture, see Fig. 4) and end up with a new proposal for the residual blocks.
