
PyTorchClassifier: use a new optimizer for the cloned classifier #1580

Merged: 4 commits merged into Trusted-AI:dev_1.10.2 on May 16, 2022

Conversation

@chao1995 (Contributor) commented on Mar 10, 2022

Description

Fixes #1579

Type of change

Please check all relevant options.

  • Improvement (non-breaking)
  • Bug fix (non-breaking)
  • New feature (non-breaking)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Test Configuration:

  • OS: macOS
  • Python version: 3.8.10
  • ART version or commit number: 1.9.1
  • PyTorch version: 1.10.2

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@beat-buesser beat-buesser self-assigned this Mar 10, 2022
@codecov-commenter commented on Mar 10, 2022

Codecov Report

Merging #1580 (d9969d3) into dev_1.10.2 (8f8039b) will decrease coverage by 0.06%.
The diff coverage is 100.00%.

Impacted file tree graph

@@              Coverage Diff               @@
##           dev_1.10.2    #1580      +/-   ##
==============================================
- Coverage       88.10%   88.04%   -0.07%     
==============================================
  Files             259      259              
  Lines           21349    21398      +49     
  Branches         3789     3800      +11     
==============================================
+ Hits            18809    18839      +30     
- Misses           1597     1608      +11     
- Partials          943      951       +8     
| Impacted Files | Coverage Δ |
|---|---|
| art/estimators/classification/pytorch.py | 88.32% <100.00%> (+0.02%) ⬆️ |
| art/estimators/certification/abstain.py | 90.90% <0.00%> (-9.10%) ⬇️ |
| ...s/certification/randomized_smoothing/tensorflow.py | 91.37% <0.00%> (-6.35%) ⬇️ |
| art/estimators/poison_mitigation/strip/strip.py | 94.44% <0.00%> (-5.56%) ⬇️ |
| art/defences/preprocessor/mp3_compression.py | 84.05% <0.00%> (-5.01%) ⬇️ |
| ...tors/certification/randomized_smoothing/pytorch.py | 92.10% <0.00%> (-4.13%) ⬇️ |
| ...mators/certification/randomized_smoothing/numpy.py | 85.00% <0.00%> (-2.10%) ⬇️ |
| art/attacks/evasion/boundary.py | 92.77% <0.00%> (-1.21%) ⬇️ |
| ...ation/randomized_smoothing/randomized_smoothing.py | 97.33% <0.00%> (-0.11%) ⬇️ |
| ...ion/imperceptible_asr/imperceptible_asr_pytorch.py | 86.79% <0.00%> (ø) |

@beat-buesser changed the base branch from main to dev_1.10.0 on March 10, 2022 12:11
@abigailgold (Collaborator) left a review comment


Please see my comment on not loading the optimizer's state and changing the reset method to also create a new optimizer.


# create a new optimizer that binds to cloned model's parameters
new_optimizer = type(self._optimizer)(model.parameters())
new_optimizer.load_state_dict(self._optimizer.state_dict())
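
For context, a small standalone sketch (assumed toy setup, not the ART code itself) of what the quoted load_state_dict call does: it copies the source optimizer's hyperparameters and its accumulated per-parameter state onto the new optimizer, matching parameters by position.

```python
import torch
from torch import nn

# Assumed toy setup, for illustration only (not ART code).
model = nn.Linear(3, 1)
opt = torch.optim.Adam(model.parameters(), lr=0.1)
model(torch.randn(4, 3)).sum().backward()
opt.step()  # opt now holds per-parameter state (exp_avg, exp_avg_sq, step)

clone = nn.Linear(3, 1)
new_opt = type(opt)(clone.parameters())     # fresh optimizer bound to the clone
new_opt.load_state_dict(opt.state_dict())   # copies lr=0.1 *and* the old moments

print(new_opt.param_groups[0]["lr"])        # 0.1 (inherited hyperparameter)
print(len(new_opt.state))                   # 2 (old state now attached to the clone's params)
```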
@abigailgold (Collaborator) commented:

Actually I think this line should be removed. Since we are going to refit the model, the optimizer should also be reset.
Moreover, it's a good idea to add optimizer resetting to the PyTorchClassifier.reset() method as well...
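
A minimal sketch of the kind of optimizer reset being suggested (illustrative only; the helper name below is hypothetical and not the ART API): re-create an optimizer of the same type, bound to the same parameters, from its constructor defaults, so the user's hyperparameters survive but the accumulated state does not.

```python
import torch
from torch import nn
from torch.optim import Optimizer

def reset_optimizer(model: nn.Module, optimizer: Optimizer) -> Optimizer:
    """Return a fresh optimizer of the same type and hyperparameters.

    `defaults` holds the constructor arguments (lr, betas, weight_decay, ...),
    so the new instance keeps the user-provided settings while dropping the
    accumulated per-parameter state (e.g. Adam's moment estimates).
    """
    return type(optimizer)(model.parameters(), **optimizer.defaults)

# Example usage with an assumed toy model.
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1, weight_decay=1e-4)
optimizer = reset_optimizer(model, optimizer)
print(optimizer.defaults["lr"], optimizer.defaults["weight_decay"])  # 0.1 0.0001
```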

@chao1995 (Contributor, PR author) replied:

Hi, thanks for the review. This line is intended to inherit any user-provided parameters such as the learning rate, but I'll check whether it has any unintended side effects and whether there is a more explicit way to inherit those parameters (param_groups?) without loading the whole state_dict.
As for the resetting, I'm wondering whether the zero_grad() call already used during training serves that purpose; I'll check that as well. It may take me some time to get back to this, apologies in advance.

@abigailgold (Collaborator) replied:

I think param_groups is probably the right way to go here. There may be other elements in the state besides the gradients (e.g., adaptive learning rate, weight decay, etc.) which would not be reset by zero_grad().
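
A quick illustration of that point (toy setup, not from the PR): after a training step, Adam keeps per-parameter state such as the moment estimates, and zero_grad() clears only the gradients, not that state.

```python
import torch
from torch import nn

model = nn.Linear(3, 1)
opt = torch.optim.Adam(model.parameters(), lr=0.1)

model(torch.randn(8, 3)).sum().backward()
opt.step()

opt.zero_grad()                          # clears the .grad tensors only
print(len(opt.state))                    # 2: per-parameter state still present
print(sorted(opt.state[model.weight]))   # ['exp_avg', 'exp_avg_sq', 'step']
```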

@chao1995 (Contributor, PR author) replied:

Hi @abigailgold, thanks. I removed load_state_dict and instead create the new optimizer from optimizer.defaults, which includes all the parameters provided by the user, such as the learning rate and weight decay. I suspect the defaults stay unchanged during training, whereas the parameters in param_groups may change (e.g. for an adaptive learning rate).
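
A quick check of that suspicion under an assumed setup (not ART code): a learning-rate scheduler mutates param_groups but leaves defaults at the constructor values, which is what makes defaults a safe source for the user-provided hyperparameters.

```python
import torch
from torch import nn

model = nn.Linear(5, 1)
opt = torch.optim.Adam(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.5)

for _ in range(3):
    model(torch.randn(2, 5)).sum().backward()
    opt.step()
    opt.zero_grad()
    scheduler.step()

print(opt.param_groups[0]["lr"])  # 0.0125: changed by the scheduler
print(opt.defaults["lr"])         # 0.1: unchanged constructor value
```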

raise ValueError("An optimizer is needed to train the model, but none was provided.")

# create a new optimizer that binds to the cloned model's parameters
new_optimizer = type(self._optimizer)(model.parameters(), lr=self._optimizer.defaults["lr"]) # type: ignore
@chao1995 (Contributor, PR author) commented:

Check what lr self._optimizer.defaults["lr"] gives here; we expect it to be the actual lr that was used.

@chao1995 (Contributor, PR author) replied:

Yes, defaults contains the parameters that were passed in:

>>> import torch
>>> from torch import nn
>>> model = nn.Linear(100, 10)
>>> opt = torch.optim.Adam(model.parameters(), lr=0.1)
>>> opt.defaults
{'lr': 0.1, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}
>>>
>>> opt2 = torch.optim.Adam(model.parameters())
>>> opt2.defaults
{'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}
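
For completeness, a self-contained sketch of the behaviour the merged change relies on (simplified, not the actual ART implementation): the cloned model gets its own optimizer instance of the same type, constructed with the lr taken from defaults, so training the clone never updates the original model's parameters.

```python
import copy
import torch
from torch import nn

original_model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(original_model.parameters(), lr=0.05)

cloned_model = copy.deepcopy(original_model)
# New optimizer of the same type, bound to the clone, reusing the original lr.
new_optimizer = type(optimizer)(cloned_model.parameters(), lr=optimizer.defaults["lr"])

# The new optimizer updates the clone's parameters, not the original's.
assert new_optimizer.param_groups[0]["params"][0] is cloned_model.weight
assert new_optimizer.param_groups[0]["params"][0] is not original_model.weight
```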

@chao1995 changed the base branch from dev_1.10.0 to dev_1.11.0 on April 28, 2022 06:06
ignored mypy error because lr is a required parameter for SGD

Signed-off-by: chao1995 <huangchao0825@gmail.com>
@beat-buesser added the labels "bug (Something isn't working)" and "improvement (Improve implementation)" on May 10, 2022
@beat-buesser added this to "Pull request open" in the ART 1.10.2 project via automation on May 10, 2022
@beat-buesser added this to the ART 1.10.2 milestone on May 10, 2022
@beat-buesser changed the base branch from dev_1.11.0 to dev_1.10.2 on May 14, 2022 22:28
Signed-off-by: Beat Buesser <beat.buesser@ie.ibm.com>
@beat-buesser (Collaborator) left a comment:

Hi @chao1995, thank you very much for your pull request fixing the cloning of PyTorch classifiers, and congratulations on your first contribution to ART!

@beat-buesser merged commit af9bf9c into Trusted-AI:dev_1.10.2 on May 16, 2022
The ART 1.10.2 project automation moved this from "Pull request open" to "Pull request done" on May 16, 2022
@chao1995 (Contributor, PR author) replied:
Thank you @beat-buesser!

@chao1995 deleted the new-optimizer-for-clone branch on May 17, 2022 09:18
Labels: bug (Something isn't working), improvement (Improve implementation)
Projects: ART 1.10.2 (Pull request done)
Participants: 4