Keep track of the best model's path saved by ModelCheckpoint #1799

Merged

Changes from 4 commits
16 changes: 14 additions & 2 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -20,7 +20,10 @@

class ModelCheckpoint(Callback):
r"""
Save the model after every epoch.
Save the model after every epoch if it improves.

After training finishes, use :attr:`best_model` to retrieve the path to the
best checkpoint file.

Args:
filepath: path to save the model file.
@@ -80,6 +83,13 @@ class ModelCheckpoint(Callback):
... filepath='my/path/sample-mnist_{epoch:02d}-{val_loss:.2f}'
... )

# retrieve the best checkpoint after training
>>> checkpoint_callback = ModelCheckpoint(filepath='my/path/')
>>> trainer = Trainer(checkpoint_callback=checkpoint_callback)
>>> # model = ...
>>> # trainer.fit(model)
>>> print(checkpoint_callback.best_model)

"""

def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', verbose: bool = False,
@@ -112,6 +122,7 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', verbose: bool = False,
# {filename: monitor}
self.kth_best_model = ''
self.best = 0
self.best_model = ''
self.save_function = None

torch_inf = torch.tensor(np.Inf)
@@ -265,7 +276,8 @@ def _do_check_save(self, filepath, current, epoch):
self.kth_value = self.best_k_models[self.kth_best_model]

_op = min if self.mode == 'min' else max
self.best = _op(self.best_k_models.values())
self.best_model = _op(self.best_k_models, key=self.best_k_models.get)
self.best = self.best_k_models[self.best_model]
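For context, a small standalone sketch of the min/max-with-key idiom used in the two added lines above; the dictionary contents are illustrative only. best_k_models maps checkpoint paths to the monitored values they were saved with, so taking the key with the smallest (or largest) value yields the path of the best checkpoint:

# Illustrative values only; best_k_models maps checkpoint paths to monitored scores.
best_k_models = {
    'epoch=1.ckpt': 0.42,
    'epoch=2.ckpt': 0.31,
    'epoch=3.ckpt': 0.37,
}

mode = 'min'  # 'min' when lower is better (e.g. val_loss), 'max' otherwise
_op = min if mode == 'min' else max

# The key with the best value, i.e. the path of the best checkpoint so far.
best_model = _op(best_k_models, key=best_k_models.get)
best = best_k_models[best_model]
print(best_model, best)  # -> epoch=2.ckpt 0.31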
Member:
Here, does best mean best_score? Otherwise it is a bit confusing to have both best and best_model.
Please also add type hints to the init 🐰

Contributor Author:
Yes, self.best is the best "monitored quantity", but that's the existing naming, and changing it will break backwards compatibility. If that's OK, I can certainly rename it. Otherwise, best_model could become best_model_path. I used best_model to stay consistent with the existing kth_best_model attribute.

As for type hints, there are no changes to the arguments. Where specifically would I add them?

Member:
@PyTorchLightning/core-contributors are we fine with renaming best to best_score?

Contributor:
I would prefer it (best_score), but if we do that we need to provide our users with a utility to upgrade their checkpoints (basically switching best to best_score) and temporarily add a warning that catches old checkpoints at load time and points users to the conversion utility. It's a very simple utility, but we should provide it for our users if we move forward with this.
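A minimal sketch of what such an upgrade utility could look like; the function name and the new key name are assumptions for illustration, not something shipped by this PR:

import torch

def upgrade_checkpoint(path):
    # Hypothetical helper: rename the deprecated 'checkpoint_callback_best'
    # entry to the proposed new key, rewriting the file in place.
    # The new key name here is assumed, not taken from the library.
    checkpoint = torch.load(path, map_location='cpu')
    if 'checkpoint_callback_best' in checkpoint:
        checkpoint['checkpoint_callback_best_score'] = checkpoint.pop('checkpoint_callback_best')
    torch.save(checkpoint, path)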

Member:
@jeremyjordan you mean keeping compatibility with previously saved checkpoints that contain best? That should be simple: at load time we can map best to best_score... and yes, we should wrap the deprecated best with a warning :]

Contributor Author:
OK, I renamed best to best_model_score and best_model to best_model_path, since having best_score and best_model would still be a bit confusing (IMHO). To keep consistency, I also renamed kth_best_model to kth_best_model_path.

I added properties for best and kth_best_model that log a deprecation warning and return the correct value.

When loading a checkpoint that is in the old format, the value of best is simply assigned to best_model_score. In my opinion, adding a warning at this point would not really help the user, as there is not much they can do about it.

Let me know if any further changes are needed.
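A minimal standalone sketch of the deprecation-property pattern described above (an illustration of the idea, not the exact code from this PR):

import warnings

class CheckpointCallbackSketch:
    # Illustration only: expose a renamed attribute under its old name.
    def __init__(self):
        self.best_model_score = 0.0  # formerly `best`
        self.best_model_path = ''    # formerly `best_model`

    @property
    def best(self):
        warnings.warn(
            'Attribute `best` was renamed to `best_model_score` and will be removed.',
            DeprecationWarning,
        )
        return self.best_model_score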


if self.verbose > 0:
log.info(
6 changes: 4 additions & 2 deletions pytorch_lightning/trainer/training_io.py
@@ -315,6 +315,7 @@ def dump_checkpoint(self, weights_only: bool = False):
if not weights_only:
if self.checkpoint_callback:
checkpoint['checkpoint_callback_best'] = self.checkpoint_callback.best
checkpoint['checkpoint_callback_best_model'] = self.checkpoint_callback.best_model

if self.early_stop_callback:
checkpoint['early_stop_callback_wait'] = self.early_stop_callback.wait
@@ -398,10 +399,11 @@ def restore_training_state(self, checkpoint):
' This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.'
)

if self.checkpoint_callback is not None and self.checkpoint_callback is not False:
if self.checkpoint_callback:
self.checkpoint_callback.best = checkpoint['checkpoint_callback_best']
self.checkpoint_callback.best_model = checkpoint['checkpoint_callback_best_model']

if self.early_stop_callback is not None and self.early_stop_callback is not False:
if self.early_stop_callback:
self.early_stop_callback.wait = checkpoint['early_stop_callback_wait']
self.early_stop_callback.patience = checkpoint['early_stop_callback_patience']

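For checkpoints written before the new entry existed, the restore step discussed in the review thread could fall back to defaults instead of raising a KeyError. A minimal sketch, assuming the key names from this diff; the .get() fallbacks are an assumption, not necessarily what the merged code does:

def restore_checkpoint_callback_state(checkpoint_callback, checkpoint):
    # Hypothetical helper: restore best-model bookkeeping while tolerating
    # checkpoints saved before 'checkpoint_callback_best_model' was added.
    checkpoint_callback.best = checkpoint.get('checkpoint_callback_best', 0)
    checkpoint_callback.best_model = checkpoint.get('checkpoint_callback_best_model', '')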