
Added in Large-Batch SGD with a warmup, and a LARS strategy. Also add… #8918

Merged
merged 5 commits into from
Jan 29, 2018

Conversation

ashokei
Contributor

@ashokei ashokei commented Dec 2, 2017

Large-Batch SGD with a warmup, and a LARS strategy.

Description

Added in Large-Batch SGD with a warmup, and a LARS strategy. Also added in a Polynomial Decay learning rate scheduler. Modified the example image fit code to allow these options to be selectable.

Checklist

Essentials

  • [x] Passed code style checking (make lint)
  • [x] Changes are complete (i.e. I finished coding on this PR)
  • [x] All changes have test coverage
  • [ ] For user-facing API changes, the API doc string has been updated. For new C++ functions in header files, their functionalities and arguments are well-documented.
  • [x] To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Added in Large-Batch SGD with a warmup, and a LARS strategy. Also added in a Polynomial Decay learning rate scheduler. Modified the example image fit code to allow these options to be selectable.

Comments

  • If this change is a backward incompatible change, why must this change be made.
  • Interesting edge cases to note here

@piiswrong
Contributor

@zhreshold

@piiswrong piiswrong assigned zhreshold and unassigned zhreshold Dec 12, 2017
@eric-haibin-lin
Member

@zhreshold could you help review?

elif (strategy == 'power2'):
    mult = 1.0 + (maxmult - 1) * (nup * nup) / (nwup * nwup)
elif (strategy == 'power3'):
    mult = 1.0 + (maxmult - 1) * (nup * nup) / (nwup * nwup)
Member

Power3 is wrong
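
For reference, both branches in the snippet above use the square of the warmup fraction, which is why 'power3' is flagged. A minimal sketch of the presumed intent; the helper below is hypothetical, and the names nup, nwup, and maxmult follow the snippet under review (assumed to mean updates so far, total warmup updates, and the target multiplier):

def warmup_multiplier(strategy, nup, nwup, maxmult):
    """Hypothetical helper; not the PR's code."""
    frac = float(nup) / float(nwup)
    if strategy == 'linear':
        return 1.0 + (maxmult - 1) * frac
    elif strategy == 'power2':
        return 1.0 + (maxmult - 1) * frac * frac
    elif strategy == 'power3':
        # the branch flagged in review: cube instead of square
        return 1.0 + (maxmult - 1) * frac * frac * frac
    return 1.0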

@zhreshold
Member

@ashokei See comments. Could you rebase and run make pylint as well?

@ashokei
Contributor Author

ashokei commented Jan 12, 2018

@zhreshold thank you, will make requested changes.

@ashokei
Contributor Author

ashokei commented Jan 18, 2018

@zhreshold I made the requested changes, please review. Thank you.

Member

@zhreshold zhreshold left a comment

Please see comments. The rest LGTM and is good to be merged once this issue is resolved.

self.max_update = max_update
self.power = pwr
self.count = num_update
if num_update <= max_update:
Member

This duplicates line 173. I understand it is for resuming training, but that should be handled in __call__; see the example in MultiFactorScheduler. Therefore, num_update is not necessary in __init__.
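
A minimal sketch of the structure being suggested (assuming the LRScheduler base class from mxnet.lr_scheduler; this is not the PR's final code): the polynomial decay is driven entirely by the num_update passed to __call__, so resuming training needs no extra state in __init__.

from mxnet.lr_scheduler import LRScheduler

class PolySchedulerSketch(LRScheduler):
    """Hypothetical polynomial-decay scheduler, resume-friendly by design."""
    def __init__(self, max_update, base_lr=0.01, pwr=2):
        super(PolySchedulerSketch, self).__init__(base_lr)
        self.base_lr_orig = base_lr
        self.max_update = max_update
        self.power = pwr

    def __call__(self, num_update):
        # num_update comes from the trainer, so a resumed run simply starts
        # calling with larger values and the decay picks up where it left off
        if num_update <= self.max_update:
            self.base_lr = self.base_lr_orig * pow(
                1.0 - float(num_update) / float(self.max_update), self.power)
        return self.base_lr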

Contributor Author

@zhreshold I removed the duplicate num_update line, can you please check? Thanks.

@ashokei ashokei requested a review from szha as a code owner January 27, 2018 02:04

"""

def __init__(self, num_update, max_update, base_lr=0.01, pwr=2):
Member

num_update is not needed here

self.base_lr_orig = self.base_lr
self.max_update = max_update
self.power = pwr
self.count = num_update
Member

same for self.count

if num_update <= self.max_update:
    self.base_lr = self.base_lr_orig * pow(1.0 - float(num_update) / float(self.max_update),
                                            self.power)
self.count += 1
Member

and here

Contributor Author

@ashokei ashokei Jan 27, 2018

@zhreshold thanks. I see "count" is not being used, so I removed it. Though MultiFactorScheduler does seem to track count.

@zhreshold
Member

@szha test_io.test_LibSVMIter fails occasionally; we should host the data (http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/news20.t.bz2) to avoid that. Add it to #9412?

@zhreshold
Member

@ashokei Try rebasing and trigger the CI once more. We can merge once it passes.

chaseadams509 and others added 5 commits January 28, 2018 18:57
…ed in a Polynomial Decay learning rate scheduler. Modified the example image fit code to allow these options to be selectable.
@ashokei
Contributor Author

ashokei commented Jan 29, 2018

@zhreshold all done, thanks!

@zhreshold zhreshold merged commit 785690c into apache:master Jan 29, 2018
state = momentum * state + lr * rescale_grad * clip(grad, clip_gradient) + wd * weight
weight = weight - state

For details of the update algorithm see :class:`~mxnet.ndarray.lbsgd_update` and
Contributor

@ashokei @zhreshold
Where is lbsgd_update defined? I don't see it.

Please add a proper reference to the relevant paper.

Contributor Author

It is the update method in the LBSGD class; we can fix that to resolve to the right method, the '_' is misleading.
The paper is here:
https://arxiv.org/pdf/1708.03888.pdf
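
For context, the core of that paper (LARS) is a layer-wise "trust ratio" that rescales each layer's learning rate by the ratio of the weight norm to the gradient norm. A minimal NumPy sketch of that ratio, not the PR's implementation:

import numpy as np

def lars_local_lr(weight, grad, wd, eta=0.001, eps=1e-9):
    """Layer-wise LR scaling from You et al., arXiv:1708.03888 (sketch only).

    eta is the LARS trust coefficient and wd is the weight decay; eps guards
    against division by zero.
    """
    w_norm = np.linalg.norm(weight)
    g_norm = np.linalg.norm(grad)
    if w_norm > 0 and g_norm > 0:
        # trust ratio: ||w|| / (||g|| + wd * ||w||), scaled by eta
        return eta * w_norm / (g_norm + wd * w_norm + eps)
    return 1.0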

self.adaptive = False
self.admult = 1 # adaptation constant

def create_state(self, index, weight):
Contributor

Is this copied from SGD?
Why not inherit SGD instead?

Member

@ashokei As suggested, could you change it to inherit from SGD and override create_state_multi_precision, create_state, update, and update_multi_precision only where necessary? It looks like you are mixing the multi_precision part into the normal one.
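
A minimal sketch of the structure being suggested (the class name and constructor arguments here are hypothetical, not the PR's code): subclass SGD so its state creation and update paths are reused, and only layer the warmup / LARS logic on top.

from mxnet.optimizer import SGD, register

@register
class LargeBatchSGDSketch(SGD):
    """Hypothetical subclass illustrating the reviewer's suggestion."""
    def __init__(self, warmup_strategy='linear', **kwargs):
        super(LargeBatchSGDSketch, self).__init__(**kwargs)
        self.warmup_strategy = warmup_strategy

    def update(self, index, weight, grad, state):
        # apply any warmup / LARS rescaling to the learning rate here,
        # then defer to SGD for the actual (possibly multi-precision) update
        super(LargeBatchSGDSketch, self).update(index, weight, grad, state)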

warmup_strategy: string ('linear', 'power2', 'sqrt', 'lars'), default: 'linear'
warmup_epochs: unsigned, default: 5
batch_scale: unsigned, default: 1 (same as batch size*numworkers)
updates_per_epoch: updates_per_epoch (default: 32, Default might not reflect true number batches per epoch. Used for warmup.)
Contributor

Why use warmup_epochs and updates_per_epoch? Why not just warmup_updates?
Why should it have a default value?

Member

I guess it requires the epoch number to stop warming up, which does not depend on the number of updates.
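
If that is the intent, the warmup window in updates is presumably just the product of the two arguments; a tiny sketch of the assumed relation (the values are only the documented defaults):

warmup_epochs = 5          # documented default
updates_per_epoch = 32     # caller's estimate of batches per epoch
warmup_updates = warmup_epochs * updates_per_epoch   # warmup lasts ~160 updates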

warmup_epochs: unsigned, default: 5
batch_scale: unsigned, default: 1 (same as batch size*numworkers)
updates_per_epoch: updates_per_epoch (default: 32, Default might not reflect true number batches per epoch. Used for warmup.)
begin_epoch: unsigned, default 0, starting epoch.
Contributor

What's the starting epoch? What would it do before the start epoch?

Member

@ashokei please add more details describing the strategy.

@piiswrong, the begin_epoch flag is there because our data scientist saw that it was possible to stop training partway, save it, and then resume it again. We wanted to have that option, which required passing in the starting epoch so the learning rate decay calculations are done correctly.
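
Based on the parameters listed in the docstring above, a resumed run might construct the optimizer roughly like this (a sketch only; the values are illustrative):

import mxnet as mx

# resuming from epoch 10: warmup (5 epochs) is already over, and the
# decay schedule continues from the correct point instead of restarting
opt = mx.optimizer.create(
    'lbsgd',
    learning_rate=0.1,
    warmup_epochs=5,
    updates_per_epoch=32,   # caller's estimate of batches per epoch
    begin_epoch=10,
)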

larroy pushed a commit to larroy/mxnet that referenced this pull request Jan 31, 2018
apache#8918)

* Added in Large-Batch SGD with a warmup, and a LARS strategy. Also added in a Polynomial Decay learning rate scheduler. Modified the example image fit code to allow these options to be selectable.

* Fix pylint issues

* pylint fixes

* remove duplicate num_update

* remove unused count
@chenchao50

Have you tested the accuracy of ResNet-50 (or AlexNet-BN) on ImageNet using the 'lars' method?

@rahul003
Member

rahul003 commented May 10, 2018

@ashokei Similar question as above. How do I use this? I'm unable to train resnet50 with lbsgd. What configuration works? I have an effective batch size of about 20k across 20 worker nodes

initializer = mx.init.Xavier(
    rnd_type='gaussian', factor_type="in", magnitude=2)
# A limited number of optimizers have a warmup period
has_warmup = {'lbsgd', 'lbnag'}
Member

It looks like LBNAG doesn't exist?

Contributor Author

@chaseadams509 can you please address the above comments? Thanks.


Correct, lbnag was an optimizer we were experimenting with, incorporating the large-batch algorithm with Nesterov accelerated gradient. LBNAG was not showing the desired improvements, so we haven't pushed it yet.

Member

@chaseadams509 How large were the batch sizes you experimented with?

@ashokei ashokei deleted the lbsgd-optimizer branch May 10, 2018 23:12
rahul003 pushed a commit to rahul003/mxnet that referenced this pull request Jun 4, 2018
zheng-da pushed a commit to zheng-da/incubator-mxnet that referenced this pull request Jun 28, 2018
@ThomasDelteil
Contributor

ThomasDelteil commented Jul 5, 2018

@ashokei were the comments above ever addressed in a follow-up pull request?
What is this comment about: https://github.com/apache/incubator-mxnet/blame/master/python/mxnet/optimizer.py#L734 ?

Has multiply been defined anywhere: https://github.com/apache/incubator-mxnet/blame/master/python/mxnet/optimizer.py#L769 ?

See:
#11278
