Skip to content

Commit

Permalink
Add function to remove checkpoint to allow override for extended clas…
Browse files Browse the repository at this point in the history
…ses (#16067)
  • Loading branch information
SeanNaren authored and awaelchli committed Dec 19, 2022
1 parent aa3dd54 commit bba2bce
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions src/pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[
previous, self.last_model_path = self.last_model_path, filepath
self._save_checkpoint(trainer, filepath)
if previous and previous != filepath:
trainer.strategy.remove_checkpoint(previous)
self._remove_checkpoint(trainer, previous)

def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
assert self.monitor
Expand All @@ -668,7 +668,7 @@ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidate
previous, self.best_model_path = self.best_model_path, filepath
self._save_checkpoint(trainer, filepath)
if self.save_top_k == 1 and previous and previous != filepath:
trainer.strategy.remove_checkpoint(previous)
self._remove_checkpoint(trainer, previous)

def _update_best_and_save(
self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]
Expand Down Expand Up @@ -710,7 +710,7 @@ def _update_best_and_save(
self._save_checkpoint(trainer, filepath)

if del_filepath is not None and filepath != del_filepath:
trainer.strategy.remove_checkpoint(del_filepath)
self._remove_checkpoint(trainer, del_filepath)

def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
"""Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML
Expand All @@ -727,3 +727,7 @@ def file_exists(self, filepath: _PATH, trainer: "pl.Trainer") -> bool:
state to diverge between ranks."""
exists = self._fs.exists(filepath)
return trainer.strategy.broadcast(exists)

def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
"""Calls the strategy to remove the checkpoint file."""
trainer.strategy.remove_checkpoint(filepath)

0 comments on commit bba2bce

Please sign in to comment.