diff --git a/CHANGELOG.md b/CHANGELOG.md
index 26761ab9c195d..e74f9787ed0ab 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Early stopping checks `on_validation_end` ([#1458](https://github.com/PyTorchLightning/pytorch-lightning/pull/1458))
 
+- Attribute `best_model_path` to `ModelCheckpoint` for storing and later retrieving the path to the best saved model file ([#1799](https://github.com/PyTorchLightning/pytorch-lightning/pull/1799))
+
 ### Changed
 
 - Allow user to select individual TPU core to train on ([#1729](https://github.com/PyTorchLightning/pytorch-lightning/pull/1729))
@@ -26,10 +28,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Allow passing model hyperparameters as complete kwarg list ([#1896](https://github.com/PyTorchLightning/pytorch-lightning/pull/1896))
 
+- Renamed `ModelCheckpoint`'s attributes `best` to `best_model_score` and `kth_best_model` to `kth_best_model_path` ([#1799](https://github.com/PyTorchLightning/pytorch-lightning/pull/1799))
+
 - Re-Enable Logger's `ImportError`s ([#1938](https://github.com/PyTorchLightning/pytorch-lightning/pull/1938))
 
 ### Deprecated
 
+- Deprecated `ModelCheckpoint`'s attributes `best` and `kth_best_model` ([#1799](https://github.com/PyTorchLightning/pytorch-lightning/pull/1799))
+
 - Dropped official support/testing for older PyTorch versions <1.3 ([#1917](https://github.com/PyTorchLightning/pytorch-lightning/pull/1917))
 
 ### Removed
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index dcce9d23d9054..9336fe309889a 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -20,7 +20,10 @@ class ModelCheckpoint(Callback):
     r"""
-    Save the model after every epoch.
+    Save the model after every epoch if it improves.
+
+    After training finishes, use :attr:`best_model_path` to retrieve the path to the
+    best checkpoint file and :attr:`best_model_score` to retrieve its score.
 
     Args:
         filepath: path to save the model file.
@@ -81,6 +84,13 @@ class ModelCheckpoint(Callback):
         ...     filepath='my/path/sample-mnist_{epoch:02d}-{val_loss:.2f}'
         ... )
 
+        # retrieve the best checkpoint after training
+        checkpoint_callback = ModelCheckpoint(filepath='my/path/')
+        trainer = Trainer(checkpoint_callback=checkpoint_callback)
+        model = ...
+        trainer.fit(model)
+        checkpoint_callback.best_model_path
+
     """
 
     def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', verbose: bool = False,
@@ -112,8 +122,9 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
         self.prefix = prefix
         self.best_k_models = {}  # {filename: monitor}
-        self.kth_best_model = ''
-        self.best = 0
+        self.kth_best_model_path = ''
+        self.best_model_score = 0
+        self.best_model_path = ''
         self.save_function = None
 
         torch_inf = torch.tensor(np.Inf)
@@ -131,6 +142,18 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
 
         self.kth_value, self.mode = mode_dict[mode]
 
+    @property
+    def best(self):
+        rank_zero_warn("Attribute `best` has been renamed to `best_model_score` since v0.8.0"
+                       " and will be removed in v0.10.0", DeprecationWarning)
+        return self.best_model_score
+
+    @property
+    def kth_best_model(self):
+        rank_zero_warn("Attribute `kth_best_model` has been renamed to `kth_best_model_path` since v0.8.0"
+                       " and will be removed in v0.10.0", DeprecationWarning)
+        return self.kth_best_model_path
+
     def _del_model(self, filepath):
         if os.path.isfile(filepath):
             os.remove(filepath)
@@ -162,7 +185,7 @@ def check_monitor_top_k(self, current):
             "max": torch.gt,
         }[self.mode]
 
-        return monitor_op(current, self.best_k_models[self.kth_best_model])
+        return monitor_op(current, self.best_k_models[self.kth_best_model_path])
 
     def format_checkpoint_name(self, epoch, metrics, ver=None):
         """Generate a filename according to the defined template.
@@ -258,25 +281,26 @@ def _do_check_save(self, filepath, current, epoch):
         del_list = []
         if len(self.best_k_models) == self.save_top_k and self.save_top_k > 0:
-            delpath = self.kth_best_model
-            self.best_k_models.pop(self.kth_best_model)
+            delpath = self.kth_best_model_path
+            self.best_k_models.pop(self.kth_best_model_path)
             del_list.append(delpath)
 
         self.best_k_models[filepath] = current
         if len(self.best_k_models) == self.save_top_k:
             # monitor dict has reached k elements
             _op = max if self.mode == 'min' else min
-            self.kth_best_model = _op(self.best_k_models,
-                                      key=self.best_k_models.get)
-            self.kth_value = self.best_k_models[self.kth_best_model]
+            self.kth_best_model_path = _op(self.best_k_models,
+                                           key=self.best_k_models.get)
+            self.kth_value = self.best_k_models[self.kth_best_model_path]
 
         _op = min if self.mode == 'min' else max
-        self.best = _op(self.best_k_models.values())
+        self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)
+        self.best_model_score = self.best_k_models[self.best_model_path]
 
         if self.verbose > 0:
             log.info(
                 f'\nEpoch {epoch:05d}: {self.monitor} reached'
-                f' {current:0.5f} (best {self.best:0.5f}), saving model to'
+                f' {current:0.5f} (best {self.best_model_score:0.5f}), saving model to'
                 f' {filepath} as top {self.save_top_k}')
             self._save_model(filepath)
diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index e55fb105051e6..fd0385cde4b34 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -330,7 +330,8 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
 
         if not weights_only:
             if self.checkpoint_callback:
-                checkpoint['checkpoint_callback_best'] = self.checkpoint_callback.best
+                checkpoint['checkpoint_callback_best_model_score'] = self.checkpoint_callback.best_model_score
+                checkpoint['checkpoint_callback_best_model_path'] = self.checkpoint_callback.best_model_path
 
             if self.early_stop_callback:
                 checkpoint['early_stop_callback_wait'] = self.early_stop_callback.wait
@@ -401,10 +402,19 @@ def restore_training_state(self, checkpoint):
                 ' This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.'
             )
 
-        if self.checkpoint_callback is not None and self.checkpoint_callback is not False:
-            self.checkpoint_callback.best = checkpoint['checkpoint_callback_best']
-
-        if self.early_stop_callback is not None and self.early_stop_callback is not False:
+        if self.checkpoint_callback:
+            if 'checkpoint_callback_best_model_score' in checkpoint:
+                self.checkpoint_callback.best_model_score = checkpoint['checkpoint_callback_best_model_score']
+            else:
+                # Old naming until version 0.7.6
+                rank_zero_warn(
+                    'Loading a checkpoint created with an old version of Lightning; '
+                    'this will not be supported in the future.'
+                )
+                self.checkpoint_callback.best_model_score = checkpoint['checkpoint_callback_best']
+            # the path key is likewise absent from checkpoints written before this change
+            self.checkpoint_callback.best_model_path = checkpoint.get('checkpoint_callback_best_model_path', '')
+
+        if self.early_stop_callback:
             self.early_stop_callback.wait = checkpoint['early_stop_callback_wait']
             self.early_stop_callback.patience = checkpoint['early_stop_callback_patience']
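
For reference, a minimal usage sketch of the renamed attributes after this patch. The `ModelCheckpoint`/`Trainer` wiring and attribute names come from the diff above; `LitModel` is a hypothetical `LightningModule` subclass, and the `monitor`/`save_top_k` arguments are illustrative, not part of this change.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# keep the 3 best checkpoints by validation loss (illustrative settings)
checkpoint_callback = ModelCheckpoint(filepath='my/path/', monitor='val_loss', save_top_k=3)
trainer = Trainer(checkpoint_callback=checkpoint_callback)
trainer.fit(LitModel())  # LitModel: hypothetical LightningModule subclass

# new attribute names introduced by this change
print(checkpoint_callback.best_model_path)   # filesystem path of the best checkpoint
print(checkpoint_callback.best_model_score)  # monitored metric value of that checkpoint

# old names still resolve via the deprecation properties added above,
# but emit a DeprecationWarning (removal scheduled for v0.10.0)
score = checkpoint_callback.best               # alias for best_model_score
kth_path = checkpoint_callback.kth_best_model  # alias for kth_best_model_path
```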