Commit

Add an additional attribute to ModelCheckpoint to keep track of the best model's path

Currently, only the best metric value is directly tracked. This new attribute will help in use cases where the trained model needs to be used or tracked right after training.
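A minimal usage sketch of the new attribute (LitModel is a placeholder LightningModule; assumes the Trainer and ModelCheckpoint APIs as they were at the time of this commit):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(filepath='checkpoints/{epoch}', monitor='val_loss', save_top_k=3)
trainer = Trainer(checkpoint_callback=checkpoint_callback)
trainer.fit(LitModel())

# Existing attribute: the best monitored value seen during training
print(checkpoint_callback.best)
# New attribute added by this commit: path of the checkpoint with that best value
print(checkpoint_callback.best_model)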
kepler committed May 19, 2020
1 parent ac76dfc commit ed3e6f3
Showing 2 changed files with 7 additions and 3 deletions.
4 changes: 3 additions & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
@@ -112,6 +112,7 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
         # {filename: monitor}
         self.kth_best_model = ''
         self.best = 0
+        self.best_model = ''
         self.save_function = None

         torch_inf = torch.tensor(np.Inf)
@@ -265,7 +266,8 @@ def _do_check_save(self, filepath, current, epoch):
             self.kth_value = self.best_k_models[self.kth_best_model]

         _op = min if self.mode == 'min' else max
-        self.best = _op(self.best_k_models.values())
+        self.best_model = _op(self.best_k_models, key=self.best_k_models.get)
+        self.best = self.best_k_models[self.best_model]

         if self.verbose > 0:
             log.info(
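For reference, a standalone sketch (not Lightning code) of the selection idiom used in the added lines above: calling min/max over the best_k_models dict with key=dict.get returns the checkpoint path whose monitored value is extremal, and indexing back into the dict recovers the value itself.

best_k_models = {'epoch=1.ckpt': 0.42, 'epoch=3.ckpt': 0.31, 'epoch=5.ckpt': 0.37}
mode = 'min'  # e.g. when monitoring val_loss

_op = min if mode == 'min' else max
best_model = _op(best_k_models, key=best_k_models.get)  # path with the best value
best = best_k_models[best_model]                        # the best value itself
print(best_model, best)  # epoch=3.ckpt 0.31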
6 changes: 4 additions & 2 deletions pytorch_lightning/trainer/training_io.py
@@ -315,6 +315,7 @@ def dump_checkpoint(self, weights_only: bool = False):
         if not weights_only:
             if self.checkpoint_callback:
                 checkpoint['checkpoint_callback_best'] = self.checkpoint_callback.best
+                checkpoint['checkpoint_callback_best_model'] = self.checkpoint_callback.best_model

             if self.early_stop_callback:
                 checkpoint['early_stop_callback_wait'] = self.early_stop_callback.wait
@@ -398,10 +399,11 @@ def restore_training_state(self, checkpoint):
                 ' This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.'
             )

-        if self.checkpoint_callback is not None and self.checkpoint_callback is not False:
+        if self.checkpoint_callback:
             self.checkpoint_callback.best = checkpoint['checkpoint_callback_best']
+            self.checkpoint_callback.best_model = checkpoint['checkpoint_callback_best_model']

-        if self.early_stop_callback is not None and self.early_stop_callback is not False:
+        if self.early_stop_callback:
             self.early_stop_callback.wait = checkpoint['early_stop_callback_wait']
             self.early_stop_callback.patience = checkpoint['early_stop_callback_patience']

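A hedged sketch of the effect on a saved checkpoint file (the path is a placeholder; the key names come from the diff above):

import torch

checkpoint = torch.load('checkpoints/epoch=3.ckpt', map_location='cpu')

# dump_checkpoint() now stores the best model's path alongside the best value:
print(checkpoint['checkpoint_callback_best'])
print(checkpoint['checkpoint_callback_best_model'])

# restore_training_state() reads both keys back onto the callback when training resumes.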
