Force ModelCheckpoint callback to run last (#5731)
awaelchli committed Feb 3, 2021
1 parent 630a88a commit 9555043
Showing 5 changed files with 85 additions and 5 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -113,6 +113,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Made `LightningModule.global_rank`, `LightningModule.local_rank` and `LightningModule.logger` read-only properties ([#5730](https://github.com/PyTorchLightning/pytorch-lightning/pull/5730))


- Forced `ModelCheckpoint` callbacks to run after all others to guarantee all states are saved to the checkpoint ([#5731](https://github.com/PyTorchLightning/pytorch-lightning/pull/5731))


- Refactored Accelerators and Plugins
* Added base classes for plugins ([#5715](https://github.com/PyTorchLightning/pytorch-lightning/pull/5715))
* Added parallel plugins for DP, DDP, DDPSpawn, DDP2 and Horovod ([#5714](https://github.com/PyTorchLightning/pytorch-lightning/pull/5714))
4 changes: 3 additions & 1 deletion docs/source/common/trainer.rst
@@ -515,7 +515,9 @@ callbacks
|
Add a list of :class:`~pytorch_lightning.callbacks.Callback`.
Add a list of :class:`~pytorch_lightning.callbacks.Callback`. Callbacks run sequentially in the order defined here
with the exception of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks, which run
after all others to ensure all states are saved to the checkpoints.
.. code-block:: python
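A minimal sketch of the documented ordering (not part of this commit; it assumes the post-#5731 Trainer, with EarlyStopping standing in for any non-checkpoint callback):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# ModelCheckpoint is passed first, but the callback connector moves it to the
# end of trainer.callbacks, so all other callbacks save their state before
# the checkpoint file is written.
checkpoint = ModelCheckpoint(dirpath="checkpoints/")
early_stopping = EarlyStopping(monitor="val_loss")

trainer = Trainer(callbacks=[checkpoint, early_stopping])
assert trainer.callbacks[-1] is checkpoint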
24 changes: 22 additions & 2 deletions pytorch_lightning/trainer/connectors/callback_connector.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Union
from typing import List, Union

from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase
from pytorch_lightning.utilities import rank_zero_warn
@@ -46,13 +46,16 @@ def on_trainer_init(
        self.trainer.callbacks = callbacks or []

        # configure checkpoint callback
        # it is important that this is the last callback to run
        # pass through the required args to figure out defaults
        self.configure_checkpoint_callbacks(checkpoint_callback)

        # init progress bar
        self.trainer._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position)

        # push all checkpoint callbacks to the end
        # it is important that these are the last callbacks to run
        self.trainer.callbacks = self._reorder_callbacks(self.trainer.callbacks)

    def configure_checkpoint_callbacks(self, checkpoint_callback: Union[ModelCheckpoint, bool]):
        if isinstance(checkpoint_callback, ModelCheckpoint):
            # TODO: deprecated, remove this block in v1.3.0
@@ -104,3 +107,20 @@ def attach_model_logging_functions(self, model):
        for callback in self.trainer.callbacks:
            callback.log = model.log
            callback.log_dict = model.log_dict

    @staticmethod
    def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
        """
        Moves all ModelCheckpoint callbacks to the end of the list. The sequential order within the group of
        checkpoint callbacks is preserved, as well as the order of all other callbacks.

        Args:
            callbacks: A list of callbacks.

        Return:
            A new list in which the last elements are ModelCheckpoints if there were any present in the
            input.
        """
        checkpoints = [c for c in callbacks if isinstance(c, ModelCheckpoint)]
        not_checkpoints = [c for c in callbacks if not isinstance(c, ModelCheckpoint)]
        return not_checkpoints + checkpoints
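For illustration, a quick sketch of the stable partition this helper performs (not part of the commit; it assumes the surrounding class is CallbackConnector and uses hypothetical callback instances):

from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector

ckpt_a = ModelCheckpoint("ckpts_a/")
ckpt_b = ModelCheckpoint("ckpts_b/")
lr_monitor = LearningRateMonitor()

# Non-checkpoint callbacks keep their relative order and stay in front;
# checkpoint callbacks keep their relative order and move to the end.
reordered = CallbackConnector._reorder_callbacks([ckpt_a, lr_monitor, ckpt_b])
assert reordered == [lr_monitor, ckpt_a, ckpt_b]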
4 changes: 2 additions & 2 deletions tests/models/test_restore.py
@@ -141,7 +141,7 @@ def test_callbacks_references_resume_from_checkpoint(tmpdir):
    # initial training
    checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True)
    trainer = Trainer(**args, callbacks=[checkpoint])
    assert checkpoint is trainer.callbacks[0] is trainer.checkpoint_callback
    assert checkpoint is trainer.callbacks[-1] is trainer.checkpoint_callback
    trainer.fit(model)

# resumed training
@@ -150,7 +150,7 @@ def test_callbacks_references_resume_from_checkpoint(tmpdir):
    # precedence over the one in the last.ckpt file
    trainer = Trainer(**args, callbacks=[new_checkpoint], resume_from_checkpoint=str(tmpdir / "last.ckpt"))
    assert checkpoint is not new_checkpoint
    assert new_checkpoint is trainer.callbacks[0] is trainer.checkpoint_callback
    assert new_checkpoint is trainer.callbacks[-1] is trainer.checkpoint_callback
    trainer.fit(model)


55 changes: 55 additions & 0 deletions tests/trainer/connectors/test_callback_connector.py
@@ -0,0 +1,55 @@
from unittest.mock import Mock

import torch

from pytorch_lightning import Trainer, Callback
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, ProgressBar
from tests.base import BoringModel


def test_checkpoint_callbacks_are_last(tmpdir):
    """ Test that checkpoint callbacks always get moved to the end of the list, with preserved order. """
    checkpoint1 = ModelCheckpoint(tmpdir)
    checkpoint2 = ModelCheckpoint(tmpdir)
    lr_monitor = LearningRateMonitor()
    progress_bar = ProgressBar()

    model = Mock()
    model.configure_callbacks.return_value = []
    trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, checkpoint2])
    assert trainer.callbacks == [progress_bar, lr_monitor, checkpoint1, checkpoint2]


class StatefulCallback0(Callback):

    def on_save_checkpoint(self, trainer, pl_module):
        return {"content0": 0}


class StatefulCallback1(Callback):

    def on_save_checkpoint(self, trainer, pl_module):
        return {"content1": 1}


def test_all_callback_states_saved_before_checkpoint_callback(tmpdir):
    """ Test that all callback states get saved even if the ModelCheckpoint is not given as last. """

    callback0 = StatefulCallback0()
    callback1 = StatefulCallback1()
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename="all_states")
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=1,
        limit_val_batches=1,
        callbacks=[callback0, checkpoint_callback, callback1]
    )
    trainer.fit(model)

    ckpt = torch.load(str(tmpdir / "all_states.ckpt"))
    state0 = ckpt["callbacks"][type(callback0)]
    state1 = ckpt["callbacks"][type(callback1)]
    assert "content0" in state0 and state0["content0"] == 0
    assert "content1" in state1 and state1["content1"] == 1
    assert type(checkpoint_callback) in ckpt["callbacks"]
