
Gluon trainer updates: add learning_rate and lr_scheduler properties and add setter for learning rate #7659

Closed
wants to merge 39 commits

Conversation

astonzhang (Member)

No description provided.

Zhang added 21 commits on July 21, 2017
@@ -113,6 +113,36 @@ def _init_kvstore(self):

        self._kv_initialized = True


    @property
    def learning_rate(self):
piiswrong (Contributor) on Aug 29, 2017

Document this as

Properties
----------

in the init doc.

Contributor

Also report learning_rate when using lr_scheduler.

astonzhang (Member, Author)

Resolved.



    @property
    def lr_scheduler(self):
Contributor

Don't expose this for now.

astonzhang (Member, Author)

Resolved.


    @property
    def learning_rate(self):
        return self._optimizer.lr
Member

This implementation couples the two classes: the Trainer must know the optimizer's internal structure. Instead, add accessors on the optimizer and make reading the learning rate the optimizer's concern.
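
A self-contained sketch of the suggested decoupling (class and attribute names are simplified stand-ins, not the actual mxnet classes):

class Optimizer:
    def __init__(self, lr=0.01):
        self.lr = lr

    @property
    def learning_rate(self):
        # Reading the rate is the optimizer's own concern.
        return self.lr


class Trainer:
    def __init__(self, optimizer):
        self._optimizer = optimizer

    @property
    def learning_rate(self):
        # The trainer delegates instead of reaching into
        # optimizer internals such as `lr`.
        return self._optimizer.learning_rate


trainer = Trainer(Optimizer(lr=0.1))
print(trainer.learning_rate)  # 0.1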

astonzhang (Member, Author)

Resolved.

"learning rate only when the LRScheduler of"
"the optimizer is undefined.")
else:
self._optimizer.lr = lr
Member

Same as the accessor comment: try making setting the learning rate the optimizer's concern.

astonzhang (Member, Author)

Resolved.

astonzhang (Member, Author)

Thank you for your comments; the code is updated. Let me know if you spot any further issues.

@@ -191,6 +204,24 @@ def update(self, index, weight, grad, state):
        """
        raise NotImplementedError()

    def set_learning_rate(self, lr):
Member

@learning_rate.setter
def set_learning_rate

and then you can do

optimizer.learning_rate = 0.5
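
A runnable sketch of that suggestion. One caveat: for optimizer.learning_rate = 0.5 to work, the method under @learning_rate.setter must itself be named learning_rate; the scheduler guard mirrors the warning in the earlier diff (simplified, illustrative code):

class Optimizer:
    def __init__(self, lr=0.01, lr_scheduler=None):
        self.lr = lr
        self.lr_scheduler = lr_scheduler

    @property
    def learning_rate(self):
        return self.lr

    @learning_rate.setter
    def learning_rate(self, lr):
        # Refuse to mutate the rate when a scheduler owns it.
        if self.lr_scheduler is not None:
            raise UserWarning("learning_rate can be set only when the "
                              "LRScheduler of the optimizer is undefined.")
        self.lr = lr


opt = Optimizer()
opt.learning_rate = 0.5   # the syntax suggested above
print(opt.learning_rate)  # 0.5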

astonzhang (Member, Author)

Resolved.

astonzhang (Member, Author)

Resolved.

            raise UserWarning("Optimizer has to be defined before its "
                              "learning rate is mutated.")
        else:
            self._optimizer.set_learning_rate(lr)
Member

Update according to the comment below.

astonzhang (Member, Author)

Resolved.

----------
learning_rate : float
    The learning rate of the optimizer, or the current learning rate of the
    optimizer's LRScheduler if one is defined.
Contributor

Explain that it can be read with trainer.learning_rate and set with trainer.learning_rate = xxx.
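
A sketch of how that docstring entry could read once the access pattern is spelled out (wording illustrative, not the merged text):

Properties
----------
learning_rate : float
    The current learning rate of the optimizer, or of the optimizer's
    LRScheduler if one is defined. Read it with trainer.learning_rate
    and set it with trainer.learning_rate = new_lr.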

astonzhang (Member, Author)

Resolved.

@astonzhang astonzhang changed the title Gluon trainer updates: add learning_rate and lr_scheduler properties and add setter for learning rate Gluon trainer updates: add learning_rate and lr_scheduler properties and add setter for learning rate, and patch for issues of hard to pass L1loss unittest Aug 31, 2017
szha (Member) on Aug 31, 2017

All tests passed except R GPU; @thirdwing is fixing this in #7686.

@@ -70,15 +70,14 @@ def test_ce_loss():
    label = mx.nd.array(np.random.randint(0, nclass, size=(N,)), dtype='int32')
Contributor

Submit the test fix in a separate PR and revert the file-mode change (100644 → 100755).

astonzhang (Member, Author)

Resolved.

@astonzhang astonzhang changed the title Gluon trainer updates: add learning_rate and lr_scheduler properties and add setter for learning rate, and patch for issues of hard to pass L1loss unittest Gluon trainer updates: add learning_rate and lr_scheduler properties and add setter for learning rate Sep 1, 2017
astonzhang (Member, Author)

The unit test patch is in a separate PR, as suggested by @piiswrong: #7693

astonzhang (Member, Author)

Closing this PR in favor of the new PR at #7760.

astonzhang closed this on Sep 6, 2017.