Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Lamb optimizer update #16715

Merged
merged 3 commits into from Nov 24, 2019
Merged

Lamb optimizer update #16715

merged 3 commits into from Nov 24, 2019

Conversation

access2rohit
Copy link
Contributor

@access2rohit access2rohit commented Nov 4, 2019

Description

adding to new operators:

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Code is well-documented:
  • To the my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • lamb_update, tests, (and when applicable, API doc)

Testing

[DEBUG] 1000 of 1000: Setting test np/mx/python random seeds, use MXNET_TEST_SEED=2052238159 to reproduce.
ok

----------------------------------------------------------------------
Ran 1 test in 5085.147s

OK

@access2rohit access2rohit changed the title Lamb optimizer update [WIP]Lamb optimizer update Nov 4, 2019
@access2rohit access2rohit changed the title [WIP]Lamb optimizer update Lamb optimizer update Nov 13, 2019
@access2rohit
Copy link
Contributor Author

access2rohit commented Nov 13, 2019

@mxnet-label-bot add [pr-awaiting-review]

@lanking520 lanking520 added the pr-awaiting-review PR is waiting for code review label Nov 13, 2019

@register
class LAMB(Optimizer):
"""LAMB Optimizer.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pls add doc

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

working on it now

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name is clashing with the GLuon one, can we give it a different name?

python/mxnet/optimizer/optimizer.py Outdated Show resolved Hide resolved
src/operator/optimizer_op.cc Show resolved Hide resolved
src/operator/optimizer_op-inl.h Outdated Show resolved Hide resolved
tests/python/unittest/test_optimizer.py Show resolved Hide resolved
tests/python/unittest/test_optimizer.py Show resolved Hide resolved
src/operator/optimizer_op.cc Outdated Show resolved Hide resolved
tests/python/unittest/test_optimizer.py Outdated Show resolved Hide resolved
tests/python/unittest/test_optimizer.py Show resolved Hide resolved
src/operator/optimizer_op-inl.h Show resolved Hide resolved
@larroy
Copy link
Contributor

larroy commented Nov 14, 2019

Please add description and reference to paper in the PR.

@larroy
Copy link
Contributor

larroy commented Nov 14, 2019

I see a crash in the way gluon trainer is calling the optimizer...

+ python3 run_pretraining.py '--data=/home/piotr/mxnet-data/bert-pretraining/datasets/book-corpus/book-corpus-large-split/*.train,/home/piotr/mxnet-data/bert-pretraining/datasets/enwiki/enwiki-feb-doc-split/*.train' '--data_eval=/home/piotr/mxnet-data/bert-pretraining/datasets/book-corpus/book-corpus-large-split/*.test,/home/piotr/mxnet-data/bert-pretraining/datasets/enwiki/enwiki-feb-doc-split/*.test' --optimizer lamb3 --warmup_ratio 0.2 --num_steps 200 --ckpt_interval 300000000 --dtype float16 --ckpt_dir ./test-ckpt --lr 0.0001 --total_batch_size 32 --total_batch_size_eval 32 --accumulate 1 --model bert_24_1024_16 --max_seq_length 128 --max_predictions_per_seq 20 --num_data_workers 1 --eval_interval 100000000 --verbose --no_compute_acc --raw --comm_backend horovod --log_interval 10 --verbose --synthetic_data --raw --eval_use_npz
[22:42:38] ../src/storage/storage.cc:110: Using GPUPooledRoundedStorageManager.
Traceback (most recent call last):
  File "run_pretraining.py", line 574, in <module>
    train(data_train, data_eval, model)
  File "run_pretraining.py", line 457, in train
    num_ctxs=len(ctxs) * num_workers)
  File "/home/piotr/gluon-nlp/scripts/bert/fp16_utils.py", line 433, in step
    self.fp32_trainer.update(step_size)
  File "/home/piotr/mxnet_lamb/python/mxnet/gluon/trainer.py", line 397, in update
    self._update(ignore_stale_grad)
  File "/home/piotr/mxnet_lamb/python/mxnet/gluon/trainer.py", line 434, in _update
    updater(i, w, g)
  File "/home/piotr/mxnet_lamb/python/mxnet/optimizer/optimizer.py", line 1777, in __call__
    self.optimizer.update_multi_precision(i, w, g, self.states[i])
  File "/home/piotr/mxnet_lamb/python/mxnet/optimizer/optimizer.py", line 291, in update_multi_precision
    self.update(index, weight_master_copy, grad32, original_state)
  File "/home/piotr/mxnet_lamb/python/mxnet/optimizer/optimizer.py", line 1012, in update
    g = lamb_update(weight, grad, mean, var, wd=wd, **kwargs)
  File "<string>", line 88, in lamb_update
  File "/home/piotr/mxnet_lamb/python/mxnet/_ctypes/ndarray.py", line 107, in _imperative_invoke
    ctypes.byref(out_stypes)))
  File "/home/piotr/mxnet_lamb/python/mxnet/base.py", line 254, in check_call
    raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: Some trailing characters could not be parsed: '[0.0015625]
<NDArray 1 @gpu(0)>', in operator lamb_update(name="", rescale_grad="
[0.0015625]
<NDArray 1 @gpu(0)>", wd="0.01", bias_correction="True", t="1", epsilon="1e-06", beta2="0.999", beta1="0.9")

@access2rohit
Copy link
Contributor Author

Please add description and reference to paper in the PR.

Done

@larroy
Copy link
Contributor

larroy commented Nov 20, 2019

@access2rohit could you answer Sam's comments?

@access2rohit
Copy link
Contributor Author

@access2rohit could you answer Sam's comments?

Done

@eric-haibin-lin eric-haibin-lin merged commit 85d3ef3 into apache:master Nov 24, 2019
float beta1;
float beta2;
float epsilon;
float t;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@eric-haibin-lin @access2rohit I find this issue when reading the code. Here, the t should be the number of updates and should not be stored as float, which will lose the precision. I think we need to store it as index_t.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we are using float here for integer data type. @sxjscience can you explain how we will loses precision for the operation beta^t ?


if (bias_correction) {
DType mean_hat = mean_data[i] / (1. - power::Map(beta1, t));
DType var_hat = var_data[i] / (1 - power::Map(beta2, t));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, in apex, it uses a float32 to calculate the power and then switch to float16:
https://github.com/NVIDIA/apex/blob/325f5a0bec542701edba1628ad34f3b2ea47c556/csrc/multi_tensor_lamb.cu#L231-L249

@leezu leezu removed the pr-awaiting-review PR is waiting for code review label Nov 27, 2019
ptrendx pushed a commit to ptrendx/mxnet that referenced this pull request Dec 10, 2019
* initial commit lamb optimizer

* fixing base lamb optimizer

* adding API doc for Lamb Phase 1 and 2
eric-haibin-lin pushed a commit that referenced this pull request Dec 10, 2019
* initial commit lamb optimizer

* fixing base lamb optimizer

* adding API doc for Lamb Phase 1 and 2
eric-haibin-lin pushed a commit to eric-haibin-lin/mxnet that referenced this pull request Dec 14, 2019
* initial commit lamb optimizer

* fixing base lamb optimizer

* adding API doc for Lamb Phase 1 and 2
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

8 participants