Merge pull request #1580 from chao1995/new-optimizer-for-clone
PyTorchClassifier: use a new optimizer for the cloned classifier
beat-buesser committed May 16, 2022
2 parents 8549e0e + d9969d3 commit af9bf9c
Showing 1 changed file with 11 additions and 3 deletions: art/estimators/classification/pytorch.py
@@ -497,13 +497,21 @@ def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwargs
 
     def clone_for_refitting(self) -> "PyTorchClassifier":  # lgtm [py/inheritance/incorrect-overridden-signature]
         """
-        Create a copy of the classifier that can be refit from scratch. Will inherit same architecture, optimizer and
-        initialization as cloned model, but without weights.
+        Create a copy of the classifier that can be refit from scratch. Will inherit same architecture, same type of
+        optimizer and initialization as the original classifier, but without weights.
+
         :return: new estimator
         """
         model = copy.deepcopy(self.model)
-        clone = type(self)(model, self._loss, self.input_shape, self.nb_classes, optimizer=self._optimizer)
+
+        if self._optimizer is None:  # pragma: no cover
+            raise ValueError("An optimizer is needed to train the model, but none is provided.")
+
+        # create a new optimizer that binds to the cloned model's parameters and uses original optimizer's defaults
+        new_optimizer = type(self._optimizer)(model.parameters(), **self._optimizer.defaults)  # type: ignore
+
+        clone = type(self)(model, self._loss, self.input_shape, self.nb_classes, optimizer=new_optimizer)
 
         # reset weights
         clone.reset()
         params = self.get_params()
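The core of the change is an optimizer re-binding pattern: rather than reusing self._optimizer (whose parameter groups still point at the original model's tensors), the clone gets a fresh optimizer of the same class, built from optimizer.defaults and bound to the copied model's parameters. A minimal standalone sketch of that pattern follows; the nn.Linear model, Adam optimizer, and hyperparameter values are illustrative assumptions, not code from ART.

import copy

import torch
from torch import nn

# Illustrative original model/optimizer pair; any nn.Module and torch.optim optimizer work the same way.
original_model = nn.Linear(4, 2)
original_optimizer = torch.optim.Adam(original_model.parameters(), lr=1e-3, weight_decay=1e-4)

# Deep-copy the architecture, then build a *new* optimizer of the same type.
# optimizer.defaults holds the constructor hyperparameters (lr, betas, weight_decay, ...),
# so the clone keeps the same training settings, but its optimizer state starts empty
# and its parameter groups reference the cloned model's tensors, not the original's.
cloned_model = copy.deepcopy(original_model)
cloned_optimizer = type(original_optimizer)(cloned_model.parameters(), **original_optimizer.defaults)

# Sanity check: the new optimizer tracks only the clone's parameters.
cloned_param_ids = {id(p) for group in cloned_optimizer.param_groups for p in group["params"]}
assert not any(id(p) in cloned_param_ids for p in original_model.parameters())

Note that this pattern only copies constructor defaults: per-parameter-group overrides (e.g. layer-specific learning rates) and accumulated optimizer state (momentum buffers, Adam moments) are not carried over, which is consistent with the "refit from scratch" intent of clone_for_refitting.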
