This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
3 changes: 2 additions & 1 deletion example/image-classification/symbol_resnet.py
@@ -36,7 +36,8 @@ def get_conv(
fix_gamma=False,
momentum=bn_momentum,
# Same with https://github.com/soumith/cudnn.torch/blob/master/BatchNormalization.lua
eps=1e-5
# cuDNN v5 doesn't allow an eps as small as 1e-5
Review comment (Contributor): change to 2e-5 instead of removing

eps=2e-5
)
return (
# It's better to remove ReLU here
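The eps in question only guards the denominator of the normalization. A small NumPy sketch (not MXNet's implementation; the function and names are illustrative) shows why moving eps from 1e-5 to 2e-5 is harmless for unit-scale activations:

```python
import numpy as np

def batch_norm(x, gamma, beta, eps):
    """Per-channel batch normalization over axis 0.
    A plain NumPy sketch, not MXNet's BatchNorm operator."""
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

rng = np.random.RandomState(0)
x = rng.randn(256, 8).astype(np.float32)  # roughly unit-variance activations

y1 = batch_norm(x, 1.0, 0.0, eps=1e-5)
y2 = batch_norm(x, 1.0, 0.0, eps=2e-5)

# With var ~= 1, the denominator changes by about 5e-6 relative,
# so the two outputs are numerically almost identical.
print(np.max(np.abs(y1 - y2)))
```

Since eps is added to the variance before the square root, its effect only becomes visible for channels whose variance is comparable to eps itself.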
31 changes: 31 additions & 0 deletions example/image-classification/train_cifar10_resnet.py
@@ -2,6 +2,37 @@
Reproducing https://github.com/gcr/torch-residual-networks
For image size of 32x32

Test accuracies are 0.9309 in this patch and 0.9303 in Kaiming He's paper, respectively.
The reported accuracy is the best of the last 3 epochs (0.930288, 0.930889 and 0.929587),
while the original paper selects the best of 5 runs.
The Dockerfile and log are at: https://gist.github.com/Answeror/f9160145e1c64bb509f52c00014bdb77

The only differences between this patch and the reference Torch implementations
(https://github.com/gcr/torch-residual-networks and https://github.com/facebook/fb.resnet.torch) are:

1. The kernel of the shortcut with downsampling is 2x2 rather than 1x1.
   I can't reproduce this accuracy with a 1x1 kernel. Note that the shortcut contains no learnable parameters.
2. I use a BatchNorm after the data layer to simulate z-score normalization,
   although subtracting (127, 127, 127) and dividing by 60 works equally well.
3. BatchNorm uses an eps of 2e-5 instead of 1e-5 because cuDNN v5 doesn't allow such a small eps.

Some details affect the accuracy:

1. Z-score normalization of the input.
2. Weight decay on all parameters (weight, bias, gamma, beta). See the comments in `train_cifar10_resnet.py` for details.
3. Nesterov momentum.
4. `fix_gamma=False` in BatchNorm (gamma is necessary because of the weight decay on the conv weights).
5. Initialization.
6. 4-pixel padding.
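The z-score normalization in detail 1 can be sketched in plain NumPy; the (127, 60) constants come from the text above, while the array shapes and random data are purely illustrative:

```python
import numpy as np

rng = np.random.RandomState(0)
# A fake CIFAR-10-like batch: NHWC uint8-range images in [0, 255].
images = rng.randint(0, 256, size=(64, 32, 32, 3)).astype(np.float32)

# Exact z-score normalization: zero mean, unit variance.
# This is roughly what a BatchNorm after the data layer converges to.
zscore = (images - images.mean()) / images.std()

# The fixed approximation mentioned above: subtract 127, divide by 60.
# For real CIFAR-10 statistics this lands close to the exact z-score.
approx = (images - 127.0) / 60.0

print(zscore.mean(), zscore.std())
```

The point of either variant is the same: centering and rescaling the input so the first convolution sees values on a unit scale.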

Thanks to #1230 (@freesouls) and #1041 (@shuokay) for providing preliminary implementations.

## update@2016-06-08

With #2366 and a batch size of 64, I got an accuracy of 0.939704 after 200 epochs on 2 GPUs.
Note that **the accuracy is strongly affected by the batch size**: the more GPUs you use, the smaller the batch size should be.
See https://gist.github.com/Answeror/f9160145e1c64bb509f52c00014bdb77#file-resnet-dual-gpu-log for the full log.

References:

Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Deep Residual Learning for Image Recognition"
14 changes: 9 additions & 5 deletions python/mxnet/model.py
@@ -743,11 +743,6 @@ def fit(self, X, y=None, eval_data=None, eval_metric='acc',

arg_names, param_names, aux_names = \
self._init_params(dict(data.provide_data+data.provide_label))
param_idx2name = {}
for i, n in enumerate(param_names):
for k in range(len(self.ctx)):
param_idx2name[i*len(self.ctx)+k] = n
self.kwargs["param_idx2name"] = param_idx2name

# setup metric
if not isinstance(eval_metric, metric.EvalMetric):
@@ -757,6 +752,15 @@ def fit(self, X, y=None, eval_data=None, eval_metric='acc',
(kvstore, update_on_kvstore) = _create_kvstore(
kvstore, len(self.ctx), self.arg_params)

param_idx2name = {}
if update_on_kvstore:
param_idx2name.update(enumerate(param_names))
else:
for i, n in enumerate(param_names):
for k in range(len(self.ctx)):
param_idx2name[i*len(self.ctx)+k] = n
self.kwargs["param_idx2name"] = param_idx2name

# init optmizer
if isinstance(self.optimizer, str):
batch_size = data.batch_size
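The relocated block in the diff above builds one of two index-to-name mappings depending on `update_on_kvstore`: with a kvstore update there is one gradient slot per parameter, otherwise the local updater sees one slot per (parameter, device) pair. The logic can be sketched as a standalone function (a hypothetical helper for illustration, not part of mxnet's API):

```python
def build_param_idx2name(param_names, num_devices, update_on_kvstore):
    """Sketch of the mapping built inside model.fit: maps the index
    an updater receives back to the parameter's name."""
    if update_on_kvstore:
        # One slot per parameter; indices are simply 0..len-1.
        return dict(enumerate(param_names))
    # One slot per (parameter, device): parameter i on device k
    # occupies index i * num_devices + k.
    idx2name = {}
    for i, n in enumerate(param_names):
        for k in range(num_devices):
            idx2name[i * num_devices + k] = n
    return idx2name

print(build_param_idx2name(['w', 'b'], 2, False))
# {0: 'w', 1: 'w', 2: 'b', 3: 'b'}
```

Moving this construction below the `_create_kvstore` call is what lets the mapping depend on `update_on_kvstore` at all, since that flag is only known once the kvstore has been created.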