PyTorchClassifier: use a new optimizer for the cloned classifier #1580
Conversation
Codecov Report
@@ Coverage Diff @@
## dev_1.10.2 #1580 +/- ##
==============================================
- Coverage 88.10% 88.04% -0.07%
==============================================
Files 259 259
Lines 21349 21398 +49
Branches 3789 3800 +11
==============================================
+ Hits 18809 18839 +30
- Misses 1597 1608 +11
- Partials 943 951 +8
Please see my comment on not loading the optimizer's state and changing the reset method to also create a new optimizer.
# create a new optimizer that binds to cloned model's parameters
new_optimizer = type(self._optimizer)(model.parameters())
new_optimizer.load_state_dict(self._optimizer.state_dict())
Actually I think this line should be removed. Since we are going to refit the model, the optimizer should also be reset.
Moreover, it's a good idea to add optimizer resetting to the PyTorchClassifier.reset() method as well...
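For illustration, a minimal sketch of what re-creating the optimizer inside a reset-style method could look like. The attribute names self._model and self._optimizer mirror the diff above, but the surrounding PyTorchClassifier internals and the weight re-initialization loop are assumptions for this sketch, not the actual ART implementation:

def reset(self) -> None:
    # re-initialize the model weights
    for module in self._model.modules():
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()
    # re-create the optimizer so stale state (momentum buffers, Adam moments,
    # step counters) does not carry over into the next fit
    if self._optimizer is not None:
        self._optimizer = type(self._optimizer)(
            self._model.parameters(), **self._optimizer.defaults
        )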
Hi, thanks for the review. This line is intended to inherit any user-provided parameters like the learning rate, but I'll check whether there are any unintended side effects and whether there is a more explicit way to inherit the params (param_groups?) without loading the whole state_dict.
For the resetting, I wonder whether zero_grad() used in training already serves the purpose; will check that out too. It may take me some time to get back on this though, apologies in advance.
I think param_groups is probably the right way to go here. There may be other elements in the state besides the gradients (e.g., adaptive learning rate, weight decay, etc.) which would not be reset by zero_grad().
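A quick way to see this (a standalone PyTorch snippet, not part of the PR diff): after a step, Adam keeps running-moment buffers in opt.state, and zero_grad() leaves them untouched.

import torch
from torch import nn

model = nn.Linear(4, 1)
opt = torch.optim.Adam(model.parameters(), lr=0.1)

loss = model(torch.randn(8, 4)).sum()
loss.backward()
opt.step()       # populates opt.state with exp_avg / exp_avg_sq buffers
opt.zero_grad()  # clears the gradients ...

print(any(len(s) > 0 for s in opt.state.values()))  # ... but not the state: True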
Hi @abigailgold, thanks. Removed load_state_dict and instead created the new optimizer with optimizer.defaults, which includes all the parameters provided by the user, such as the learning rate, weight decay, etc. I suspect the defaults stay unchanged during training, but the parameters in param_groups may change (e.g. with an adaptive learning rate).
raise ValueError("An optimizer is needed to train the model, but none was provided.")

# create a new optimizer that binds to the cloned model's parameters
new_optimizer = type(self._optimizer)(model.parameters(), lr=self._optimizer.defaults["lr"])  # type: ignore
Check what lr is from self._optimizer.defaults["lr"]. We are expecting the actual lr used here.
Yes, defaults contains the parameters passed in:
>>> import torch
>>> from torch import nn
>>> model = nn.Linear(100, 10)
>>> opt = torch.optim.Adam(model.parameters(), lr=0.1)
>>> opt.defaults
{'lr': 0.1, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}
>>>
>>> opt2 = torch.optim.Adam(model.parameters())
>>> opt2.defaults
{'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}
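As a follow-up illustration (a standalone snippet, not part of the diff): a learning-rate scheduler updates the lr in param_groups, while defaults keeps the value originally passed in, which is the behaviour assumed above.

import torch
from torch import nn

model = nn.Linear(100, 10)
opt = torch.optim.Adam(model.parameters(), lr=0.1)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.5)

opt.step()
sched.step()

print(opt.param_groups[0]["lr"])  # 0.05 -- updated by the scheduler
print(opt.defaults["lr"])         # 0.1  -- unchanged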
ignored mypy error because lr is a required parameter for SGD Signed-off-by: chao1995 <huangchao0825@gmail.com>
Hi @chao1995 Thank you very much for your pull request fixing the cloning of PyTorch classifiers, and congratulations on your first contribution to ART!
Thank you @beat-buesser!
Description
Fixes #1579
Type of change
Please check all relevant options.
Test Configuration:
Checklist