Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

introduce gradient update handler to the base estimator #16900

Merged
merged 10 commits into from Dec 9, 2019

Conversation

liuzh47
Copy link
Contributor

@liuzh47 liuzh47 commented Nov 25, 2019

Description

This change add default gradient update handler to the base estimator class. The purpose of this change is to decouple the forward/backward computation with the gradient apply process. In some use cases such as gradient accumulation, gradient updates are not executed during each training batch. See issue description #16869.

@ptrendx
Copy link
Member

ptrendx commented Nov 25, 2019

@leezu What is the rationale behind marking this PR as 1.6? This looks like adding a new functionality.

Copy link
Member

@roywei roywei left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the improvement! 2 concerns.
Also could you point an example which require custom gradient handler? (gradient clipping or aggregation)

python/mxnet/gluon/contrib/estimator/estimator.py Outdated Show resolved Hide resolved
tests/python/unittest/test_gluon_estimator.py Outdated Show resolved Hide resolved
@leezu
Copy link
Contributor

leezu commented Nov 26, 2019

@ptrendx the estimator API is currently experimental and is not yet up to more complex real-world use-cases. @liuzh91 is currently converting GluonNLP training scripts to make use of the Estimator API and ran into several remaining shortcomings of the API.
We shouldn't release 1.6 with a known insufficient version of Estimator if fixes are available. Thus if this PR is approved and merged to master in time for the 1.6 release, it may be good to backport. Given the scope of the estimator API and the current shortcomings we can consider these PRs as bugfixes. If we don't fix the API with the 1.6 release, we may be bound to the current API due to backwards compatibility commitments.

Do you think backporting the fix would be reasonable or do you have concerns? When do you plan to tag a release candidate?

These extends the fixes commited started November 2019 https://github.com/apache/incubator-mxnet/commits/master/python/mxnet/gluon/contrib/estimator.

@liuzh47
Copy link
Contributor Author

liuzh47 commented Nov 26, 2019

Thank you for the improvement! 2 concerns.
Also could you point an example which require custom gradient handler? (gradient clipping or aggregation)

Thank u for the review.

For the gradient update example, one use case of using gradient accumulation appears when training a transformer. (https://github.com/dmlc/gluon-nlp/blob/master/scripts/machine_translation/train_transformer.py#L320) Because the size of parameters in the transformer network is too large, we can compute gradient for a small batch of data examples during each iteration. In this case, the gradient is updated periodically on the weight parameters.

@codecov-io
Copy link

codecov-io commented Nov 26, 2019

Codecov Report

Merging #16900 into master will decrease coverage by 0.11%.
The diff coverage is 80.95%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master   #16900      +/-   ##
==========================================
- Coverage   67.06%   66.95%   -0.12%     
==========================================
  Files         271      264       -7     
  Lines       30173    29804     -369     
  Branches     4477     4400      -77     
==========================================
- Hits        20237    19954     -283     
+ Misses       8648     8563      -85     
+ Partials     1288     1287       -1
Impacted Files Coverage Δ
python/mxnet/gluon/contrib/estimator/estimator.py 92.46% <80%> (-0.33%) ⬇️
...hon/mxnet/gluon/contrib/estimator/event_handler.py 71.96% <81.25%> (+0.1%) ⬆️
python/mxnet/rtc.py 29.03% <0%> (-58.07%) ⬇️
python/mxnet/ndarray/_internal.py 52.17% <0%> (-39.14%) ⬇️
python/mxnet/symbol/_internal.py 56% <0%> (-36%) ⬇️
python/mxnet/gluon/rnn/rnn_layer.py 72.78% <0%> (-10.66%) ⬇️
python/mxnet/context.py 85.71% <0%> (-7.15%) ⬇️
python/mxnet/module/bucketing_module.py 73.66% <0%> (-6.7%) ⬇️
python/mxnet/util.py 68.18% <0%> (-6.07%) ⬇️
python/mxnet/tvmop.py 84.21% <0%> (-5.27%) ⬇️
... and 21 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 9b49cfe...d325e87. Read the comment docs.

@leezu leezu requested a review from roywei November 27, 2019 06:31
Copy link
Member

@eric-haibin-lin eric-haibin-lin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Does this require the user to specify grad_req = 'add' for gradient accumulation? When is the gradient reset to zeros?
  2. Now GradientUpdateHandler is added to the list of default handlers. If users explicitly create estimator with event_handlers=[metric, logging], will GradientUpdateHandler still be included in the estimator and performing updates correctly?

@liuzh47
Copy link
Contributor Author

liuzh47 commented Dec 2, 2019

1. Does this require the user to specify grad_req = 'add' for gradient accumulation? When is the gradient reset to zeros?

2. Now GradientUpdateHandler is added to the list of default handlers. If users explicitly create estimator with `event_handlers=[metric, logging]`, will `GradientUpdateHandler` still be included in the estimator and performing updates correctly?
  1. The gradient update handler is for general gradient update. When we implement gradient accumulation handler in the future, we will inherit this class.

  2. Yes, it is added as a default handler. Please refer to (https://github.com/apache/incubator-mxnet/blob/5878451c4b476d8771d93000ff5d51b81071997d/tests/python/unittest/test_gluon_estimator.py#L369).

Copy link
Member

@roywei roywei left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, re-triggered CI

@ptrendx ptrendx merged commit 986a902 into apache:master Dec 9, 2019
ptrendx pushed a commit to ptrendx/mxnet that referenced this pull request Dec 9, 2019
* introduce  gradient update handler to the  base estimator

* Modify the gradient update handler to include the batch size

* Remove unrelated gradient update handler.

* Modify gradient update handler to take the current batch size.

* Remove white space to avoid the sanity check failure

* add small tweak to the handler code

* Modify the documentation of priority parameter of relevant handlers.

* small modification on the documentation.

* Add small modification on the documentation.

* Remove unnecessary list check
ptrendx added a commit that referenced this pull request Dec 10, 2019
* Fix ndarray indexing bug (#16895)

* Fix indexing bug

* More test cases

* Add test from 16647

* [Gluon] Update contrib.Estimator LoggingHandler to support logging per batch interval (#16922)

* Update LoggingHandler to support logging per interval

* Fix the constant variable issue in the logging handler

* Remove the constant variable hack in the logging handler.

* 1) replace LOG_PER_BATCH with LOG_PER_INTERVAL 2) add test case

* Improve the test script for LoggingHandler

* small fix on the test script

* logging handler test case bug fix

* remove parameter verbose from LoggingHandler

* move log_interval to the first argument

* resolve unittest mistakes

* Add micro averaging strategy to pearsonr metric (#16878)

        Strategy to be used for aggregating across mini-batches.
            "macro": average the pearsonr scores for each batch.
            "micro": compute a single pearsonr score across all batches.

* [Bugfix] [Numpy] Add `kAddTo` and kNullOp to Transpose (#16979)

* update

Check for repeated axes

enable addto to transpose

fix

fix

fix

fix

remove unused ndim

Update pseudo2DTranspose_op-inl.cuh

Update pseudo2DTranspose_op-inl.cuh

Update pseudo2DTranspose_op-inl.cuh

fix

Update pseudo2DTranspose_op-inl.cuh

try to fix

Update pseudo2DTranspose_op-inl.cuh

Update pseudo2DTranspose_op-inl.cuh

Update pseudo2DTranspose_op-inl.cuh

fix

Update np_matrix_op.cc

Update test_numpy_op.py

update test case

fix implementation

fix bug

update

fix bug

Update pseudo2DTranspose_op-inl.cuh

fix

fix

Update test_numpy_op.py

* Fix bug

* fix docstring

* try to address comment

* no need to change this line

* Fix bug

* address comments

* address comment

* introduce  gradient update handler to the  base estimator (#16900)

* introduce  gradient update handler to the  base estimator

* Modify the gradient update handler to include the batch size

* Remove unrelated gradient update handler.

* Modify gradient update handler to take the current batch size.

* Remove white space to avoid the sanity check failure

* add small tweak to the handler code

* Modify the documentation of priority parameter of relevant handlers.

* small modification on the documentation.

* Add small modification on the documentation.

* Remove unnecessary list check
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants