Keep track of the best model's path saved by ModelCheckpoint (#1799)
* Add an additional attribute to ModelCheckpoint to keep track of the best model's path

Currently, only the best metric value is directly tracked. This new attribute will help in use cases where the trained model needs to be used or tracked right after training.
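A minimal sketch of the intended workflow (the `LitModel` class and the `my/path/` directory are hypothetical placeholders, not part of this PR):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# keep the single best checkpoint, judged by validation loss
checkpoint_callback = ModelCheckpoint(filepath='my/path/', monitor='val_loss', save_top_k=1)
trainer = Trainer(checkpoint_callback=checkpoint_callback)

model = LitModel()  # placeholder LightningModule subclass
trainer.fit(model)

# new in this PR: the best checkpoint's path is available right after training
best_model = LitModel.load_from_checkpoint(checkpoint_callback.best_model_path)
```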

* Add small description and usage example to docs

* Fix PEP8 issues

* Fix doctest example

* Fix expected output in doctest

* Apply suggestions from code review

* Show example as code block instead of doctest

* Apply suggestions from code review

* Update CHANGELOG.md

* Rename `ModelCheckpoint.best` to `ModelCheckpoint.best_model_score`

Also rename `ModelCheckpoint.best_model` (added in this PR) to `ModelCheckpoint.best_model_path`, for consistency, and `kth_best_model` to `kth_best_model_path`.
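For illustration, a rough sketch of how the old attribute names relate to the new ones after this rename (accessing the old names should emit a `DeprecationWarning` rather than fail):

```python
from pytorch_lightning.callbacks import ModelCheckpoint

ckpt = ModelCheckpoint(filepath='my/path/')

# new canonical attributes
ckpt.best_model_score     # best value of the monitored metric (formerly `best`)
ckpt.best_model_path      # path to the best checkpoint file (added by this PR)
ckpt.kth_best_model_path  # path to the k-th best checkpoint (formerly `kth_best_model`)

# deprecated aliases still resolve to the same values, but warn
assert ckpt.best == ckpt.best_model_score                # DeprecationWarning
assert ckpt.kth_best_model == ckpt.kth_best_model_path   # DeprecationWarning
```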

* Update pytorch_lightning/trainer/training_io.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Add warning when loading checkpoint from an old version

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
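As a rough sketch of what the backward-compatibility handling above means for an on-disk checkpoint (the file name is illustrative; the key names follow the `training_io.py` diff below):

```python
import torch

# a checkpoint written by an older Lightning version (illustrative path)
checkpoint = torch.load('old_run.ckpt', map_location='cpu')

# newer checkpoints carry the renamed key; older ones only have the legacy
# 'checkpoint_callback_best' key, which the restore path still reads while warning
best_score = checkpoint.get(
    'checkpoint_callback_best_model_score',
    checkpoint.get('checkpoint_callback_best'),
)
best_path = checkpoint.get('checkpoint_callback_best_model_path')
```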
2 people authored and justusschock committed Jun 29, 2020
1 parent e3bcafa commit c33daaf
Showing 3 changed files with 56 additions and 16 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Early stopping checks `on_validation_end` ([#1458](https://github.com/PyTorchLightning/pytorch-lightning/pull/1458))

+ - Attribute `best_model_path` to `ModelCheckpoint` for storing and later retrieving the path to the best saved model file ([#1799](https://github.com/PyTorchLightning/pytorch-lightning/pull/1799))

### Changed

- Allow user to select individual TPU core to train on ([#1729](https://github.com/PyTorchLightning/pytorch-lightning/pull/1729))
@@ -26,10 +28,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Allow passing model hyperparameters as complete kwarg list ([#1896](https://github.com/PyTorchLightning/pytorch-lightning/pull/1896))

+ - Renamed `ModelCheckpoint`'s attributes `best` to `best_model_score` and `kth_best_model` to `kth_best_model_path` ([#1799](https://github.com/PyTorchLightning/pytorch-lightning/pull/1799))

- Re-Enable Logger's `ImportError`s ([#1938](https://github.com/PyTorchLightning/pytorch-lightning/pull/1938))

### Deprecated

+ - Deprecated `ModelCheckpoint`'s attributes `best` and `kth_best_model` ([#1799](https://github.com/PyTorchLightning/pytorch-lightning/pull/1799))

- Dropped official support/testing for older PyTorch versions <1.3 ([#1917](https://github.com/PyTorchLightning/pytorch-lightning/pull/1917))

### Removed
46 changes: 35 additions & 11 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -20,7 +20,10 @@

class ModelCheckpoint(Callback):
r"""
- Save the model after every epoch.
+ Save the model after every epoch if it improves.
+ After training finishes, use :attr:`best_model_path` to retrieve the path to the
+ best checkpoint file and :attr:`best_model_score` to retrieve its score.
Args:
filepath: path to save the model file.
@@ -81,6 +84,13 @@ class ModelCheckpoint(Callback):
... filepath='my/path/sample-mnist_{epoch:02d}-{val_loss:.2f}'
... )
+ # retrieve the best checkpoint after training
+ checkpoint_callback = ModelCheckpoint(filepath='my/path/')
+ trainer = Trainer(checkpoint_callback=checkpoint_callback)
+ model = ...
+ trainer.fit(model)
+ checkpoint_callback.best_model_path
"""

def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', verbose: bool = False,
Expand Down Expand Up @@ -112,8 +122,9 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
self.prefix = prefix
self.best_k_models = {}
# {filename: monitor}
- self.kth_best_model = ''
- self.best = 0
+ self.kth_best_model_path = ''
+ self.best_model_score = 0
+ self.best_model_path = ''
self.save_function = None

torch_inf = torch.tensor(np.Inf)
@@ -131,6 +142,18 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve

self.kth_value, self.mode = mode_dict[mode]

+ @property
+ def best(self):
+     rank_zero_warn("Attribute `best` has been renamed to `best_model_score` since v0.8.0"
+                    " and will be removed in v0.10.0", DeprecationWarning)
+     return self.best_model_score
+
+ @property
+ def kth_best_model(self):
+     rank_zero_warn("Attribute `kth_best_model` has been renamed to `kth_best_model_path` since v0.8.0"
+                    " and will be removed in v0.10.0", DeprecationWarning)
+     return self.kth_best_model_path

def _del_model(self, filepath):
if os.path.isfile(filepath):
os.remove(filepath)
@@ -162,7 +185,7 @@ def check_monitor_top_k(self, current):
"max": torch.gt,
}[self.mode]

- return monitor_op(current, self.best_k_models[self.kth_best_model])
+ return monitor_op(current, self.best_k_models[self.kth_best_model_path])

def format_checkpoint_name(self, epoch, metrics, ver=None):
"""Generate a filename according to the defined template.
@@ -258,25 +281,26 @@ def _do_check_save(self, filepath, current, epoch):

del_list = []
if len(self.best_k_models) == self.save_top_k and self.save_top_k > 0:
- delpath = self.kth_best_model
- self.best_k_models.pop(self.kth_best_model)
+ delpath = self.kth_best_model_path
+ self.best_k_models.pop(self.kth_best_model_path)
del_list.append(delpath)

self.best_k_models[filepath] = current
if len(self.best_k_models) == self.save_top_k:
# monitor dict has reached k elements
_op = max if self.mode == 'min' else min
- self.kth_best_model = _op(self.best_k_models,
-                           key=self.best_k_models.get)
- self.kth_value = self.best_k_models[self.kth_best_model]
+ self.kth_best_model_path = _op(self.best_k_models,
+                                key=self.best_k_models.get)
+ self.kth_value = self.best_k_models[self.kth_best_model_path]

_op = min if self.mode == 'min' else max
- self.best = _op(self.best_k_models.values())
+ self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)
+ self.best_model_score = self.best_k_models[self.best_model_path]

if self.verbose > 0:
log.info(
f'\nEpoch {epoch:05d}: {self.monitor} reached'
- f' {current:0.5f} (best {self.best:0.5f}), saving model to'
+ f' {current:0.5f} (best {self.best_model_score:0.5f}), saving model to'
f' {filepath} as top {self.save_top_k}')
self._save_model(filepath)

20 changes: 15 additions & 5 deletions pytorch_lightning/trainer/training_io.py
@@ -330,7 +330,8 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:

if not weights_only:
if self.checkpoint_callback:
- checkpoint['checkpoint_callback_best'] = self.checkpoint_callback.best
+ checkpoint['checkpoint_callback_best_model_score'] = self.checkpoint_callback.best_model_score
+ checkpoint['checkpoint_callback_best_model_path'] = self.checkpoint_callback.best_model_path

if self.early_stop_callback:
checkpoint['early_stop_callback_wait'] = self.early_stop_callback.wait
@@ -401,10 +402,19 @@ def restore_training_state(self, checkpoint):
' This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.'
)

- if self.checkpoint_callback is not None and self.checkpoint_callback is not False:
-     self.checkpoint_callback.best = checkpoint['checkpoint_callback_best']
-
- if self.early_stop_callback is not None and self.early_stop_callback is not False:
+ if self.checkpoint_callback:
+     if 'checkpoint_callback_best_model_score' in checkpoint:
+         self.checkpoint_callback.best_model_score = checkpoint['checkpoint_callback_best_model_score']
+     else:
+         # Old naming until version 0.7.6
+         rank_zero_warn(
+             'Loading a checkpoint created with an old version of Lightning; '
+             'this will not be supported in the future.'
+         )
+         self.checkpoint_callback.best_model_score = checkpoint['checkpoint_callback_best']
+     self.checkpoint_callback.best_model_path = checkpoint['checkpoint_callback_best_model_path']

+ if self.early_stop_callback:
self.early_stop_callback.wait = checkpoint['early_stop_callback_wait']
self.early_stop_callback.patience = checkpoint['early_stop_callback_patience']
