[2 / 3] improvements to saving and loading callback state #7187

Merged
merged 67 commits on Aug 24, 2021
Changes from 60 commits
67 commits
89131c2
class name as key
awaelchli Apr 8, 2021
63fb983
string state identifier
awaelchli Apr 8, 2021
7dc218a
add legacy state loading
awaelchli Apr 8, 2021
04b588b
update test
awaelchli Apr 8, 2021
bb11e28
update tests
awaelchli Apr 8, 2021
271360c
Merge branch 'master' into bugfix/callback-state
awaelchli Apr 15, 2021
20b66f0
Merge branch 'master' into bugfix/callback-state
awaelchli Apr 16, 2021
f585a28
Merge branch 'master' into bugfix/callback-state
awaelchli Apr 16, 2021
e1d518b
Merge branch 'master' into bugfix/callback-state
awaelchli Apr 17, 2021
880066b
Merge branch 'master' into bugfix/callback-state
awaelchli Apr 21, 2021
0259ecb
flake8
awaelchli Apr 21, 2021
d56e5e4
add test
awaelchli Apr 21, 2021
24a2cc8
Merge branch 'master' into bugfix/callback-state
awaelchli Apr 22, 2021
81b5e36
Merge branch 'master' into bugfix/callback-state
carmocca Apr 22, 2021
72ba440
Apply suggestions from code review
awaelchli Apr 22, 2021
79d8568
improve test
awaelchli Apr 22, 2021
d9ea8c1
flake
awaelchli Apr 22, 2021
98f7fe6
Merge branch 'master' into bugfix/callback-state
awaelchli Apr 22, 2021
68f571c
Merge branch 'master' into bugfix/callback-state
awaelchli Jul 26, 2021
0851f0d
fix merge
awaelchli Jul 26, 2021
82d5658
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 26, 2021
334fd4a
use qualname
awaelchli Jul 26, 2021
090b169
Merge remote-tracking branch 'origin/bugfix/callback-state' into bugf…
awaelchli Jul 26, 2021
f144fd1
rename state_id
awaelchli Jul 26, 2021
6154986
fix diff
awaelchli Jul 26, 2021
2c0c707
unique identifiers
awaelchli Apr 23, 2021
9f9a76d
update tests
awaelchli Apr 23, 2021
31a7737
unused import
awaelchli Apr 23, 2021
8eec798
rename state_id
awaelchli Jul 26, 2021
291c9fe
rename state_id
awaelchli Jul 26, 2021
2308c5d
rename state_id
awaelchli Jul 26, 2021
4cda723
fix merge error
awaelchli Jul 26, 2021
a92110e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 26, 2021
40e8827
remove print statements
awaelchli Jul 26, 2021
73ee9b9
Merge remote-tracking branch 'origin/feature/callback-state/part3' in…
awaelchli Jul 26, 2021
9472d1f
update fx validator
awaelchli Jul 26, 2021
3c1c8c0
Merge branch 'master' into feature/callback-state/part3
awaelchli Jul 28, 2021
7579e92
formatting, updates from master
awaelchli Jul 28, 2021
2fa8ff9
fix test
awaelchli Jul 28, 2021
ad9ba3e
remove redundant change to logger connector
awaelchli Jul 28, 2021
c136662
use helper function for get members
awaelchli Jul 28, 2021
7f9aa47
update with master
awaelchli Jul 28, 2021
c4a0f15
rename state_identifier -> state id
awaelchli Jul 28, 2021
f741bcd
add changelog
awaelchli Jul 28, 2021
ace9c1d
add docs for persisting state
awaelchli Jul 28, 2021
6574796
update test
awaelchli Jul 29, 2021
87ce420
repr string representation
awaelchli Jul 31, 2021
eac1921
Revert "repr string representation"
awaelchli Jul 31, 2021
7ff6cc0
Merge branch 'master' into feature/callback-state/part3
awaelchli Aug 12, 2021
9023a1f
add mode
awaelchli Aug 12, 2021
76a81e3
update mode=min in test
awaelchli Aug 12, 2021
5bb4a48
repr everywhere
awaelchli Aug 12, 2021
8314c62
adapt tests to repr
awaelchli Aug 12, 2021
6e7c4b3
adjust test
awaelchli Aug 12, 2021
8ea9dea
add every_n_train_steps, every_n_epochs, train_time_interval
awaelchli Aug 13, 2021
62864ec
update docs
awaelchli Aug 13, 2021
15a6492
add save_on_train_epoch_end
awaelchli Aug 13, 2021
49f376b
update docs with tip
awaelchli Aug 13, 2021
77e7027
black formatting with line break
awaelchli Aug 13, 2021
333b1b3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 13, 2021
f0cc8b4
Apply suggestions from code review
awaelchli Aug 16, 2021
0498577
fix syntax error
awaelchli Aug 16, 2021
0f56fea
Merge branch 'master' into feature/callback-state/part3
awaelchli Aug 16, 2021
182ea63
Merge branch 'master' into feature/callback-state/part3
Borda Aug 19, 2021
9260725
rename state_id -> state_key
awaelchli Aug 19, 2021
ebb2a73
fix blacken-docs precommit complaints
awaelchli Aug 19, 2021
184f412
Merge branch 'master' into feature/callback-state/part3
awaelchli Aug 24, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366))


- Added support for saving and loading state of multiple callbacks of the same type ([#7187](https://github.com/PyTorchLightning/pytorch-lightning/pull/7187))


### Changed

- Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770))
59 changes: 56 additions & 3 deletions docs/source/extensions/callbacks.rst
@@ -113,16 +113,69 @@ Lightning has a few built-in callbacks.

----------

.. _Persisting Callback State:

Persisting State
----------------

Some callbacks require internal state in order to function properly. You can optionally persist your
callback's state as part of the model checkpoint files using the callback hooks
:meth:`~pytorch_lightning.callbacks.Callback.on_save_checkpoint` and :meth:`~pytorch_lightning.callbacks.Callback.on_load_checkpoint`.
Note that the returned state must be picklable.

When your callback is meant to be used as a single instance only, implementing the above two hooks is enough
to persist its state effectively. However, if passing multiple instances of the callback to the Trainer is supported,
the callback must define a :attr:`~pytorch_lightning.callbacks.Callback.state_id` property so that Lightning
can distinguish the different states when loading the callback state. This concept is best illustrated by
the following example.

.. testcode::

    class Counter(Callback):
        def __init__(self, what="epochs", verbose=True):
            self.what = what
            self.verbose = verbose
            self.state = {"epochs": 0, "batches": 0}

        @property
        def state_id(self):
            # note: we do not include `verbose` here on purpose
            return self._generate_state_id(what=self.what)

        def on_train_epoch_end(self, *args, **kwargs):
            if self.what == "epochs":
                self.state["epochs"] += 1

        def on_train_batch_end(self, *args, **kwargs):
            if self.what == "batches":
                self.state["batches"] += 1

        def on_load_checkpoint(self, trainer, pl_module, callback_state):
            self.state.update(callback_state)

        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            return self.state.copy()


    # two callbacks of the same type are being used
    trainer = Trainer(callbacks=[Counter(what="epochs"), Counter(what="batches")])

A Lightning checkpoint from this Trainer with the two stateful callbacks will include the following information:

.. code-block:: python

    {
        "state_dict": ...,
        "callbacks": {
            "Counter{'what': 'batches'}": {"batches": 32, "epochs": 0},
            "Counter{'what': 'epochs'}": {"batches": 0, "epochs": 2},
            ...
        }
    }

The implementation of a :attr:`~pytorch_lightning.callbacks.Callback.state_id` is essential here. If it were missing,
Lightning would not be able to disambiguate the state of these two callbacks, because by default
:attr:`~pytorch_lightning.callbacks.Callback.state_id` only uses the class name as the key, e.g., here ``Counter``.
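
The resulting state keys for the two ``Counter`` instances above follow the format implemented by
``_generate_state_id`` in this PR (the class name followed by the ``repr`` of the given keyword arguments).
A small illustration:

.. code-block:: python

    Counter(what="epochs").state_id  # "Counter{'what': 'epochs'}"
    Counter(what="batches").state_id  # "Counter{'what': 'batches'}"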


Best Practices
10 changes: 10 additions & 0 deletions pytorch_lightning/callbacks/base.py
@@ -48,6 +48,16 @@ def _legacy_state_id(self) -> Type["Callback"]:
        """State identifier for checkpoints saved prior to version 1.5.0."""
        return type(self)

    def _generate_state_id(self, **kwargs: Any) -> str:
        """
        Formats a set of key-value pairs into a state id string with the callback class name prefixed.
        Useful for defining a :attr:`state_id`.

        Args:
            **kwargs: A set of key-value pairs. Must be serializable to :class:`str`.
        """
        return f"{self.__class__.__qualname__}{repr(kwargs)}"

    def on_configure_sharded_model(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called before configure sharded model"""

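The string produced by ``_generate_state_id`` becomes the key under ``checkpoint["callbacks"]``, while
``_legacy_state_id`` (the class itself) remains the key used by checkpoints written before 1.5.0. A minimal sketch of
how a loader could resolve the two (a hypothetical helper for illustration, not the connector code from this PR):

    from typing import Any, Dict, Optional

    def _get_callback_state(callback_states: Dict[Any, dict], callback: "Callback") -> Optional[dict]:
        # Prefer the new string key, e.g. "Counter{'what': 'epochs'}", then fall back to the
        # pre-1.5.0 key, which is the callback class itself (see `_legacy_state_id` above).
        state = callback_states.get(callback.state_id)
        if state is None:
            state = callback_states.get(callback._legacy_state_id)
        return state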
11 changes: 11 additions & 0 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -75,6 +75,13 @@ class EarlyStopping(Callback):
        >>> from pytorch_lightning.callbacks import EarlyStopping
        >>> early_stopping = EarlyStopping('val_loss')
        >>> trainer = Trainer(callbacks=[early_stopping])

    .. tip:: Saving and restoring multiple early stopping callbacks at the same time is supported, provided the
        instances differ in the following arguments:

        *monitor, mode*

        Read more: :ref:`Persisting Callback State`
    """
    mode_dict = {"min": torch.lt, "max": torch.gt}

@@ -120,6 +127,10 @@ def __init__(
            )
        self.monitor = monitor or "early_stop_on"

    @property
    def state_id(self) -> str:
        return self._generate_state_id(monitor=self.monitor, mode=self.mode)

    def _validate_condition_metric(self, logs):
        monitor_val = logs.get(self.monitor)

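To make the tip above concrete, two early stopping callbacks can now be used and checkpointed together as long as
they differ in ``monitor`` and/or ``mode``; a short usage sketch (the key format matches the tests in this PR):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    # saved under "EarlyStopping{'monitor': 'train_loss', 'mode': 'min'}" and
    # "EarlyStopping{'monitor': 'val_loss', 'mode': 'min'}" respectively
    trainer = Trainer(callbacks=[EarlyStopping(monitor="train_loss"), EarlyStopping(monitor="val_loss")])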
17 changes: 17 additions & 0 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -194,6 +194,12 @@ class ModelCheckpoint(Callback):
        trainer.fit(model)
        checkpoint_callback.best_model_path

    .. tip:: Saving and restoring multiple checkpoint callbacks at the same time is supported, provided the
        instances differ in the following arguments:

        *monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval, save_on_train_epoch_end*

        Read more: :ref:`Persisting Callback State`
    """

    CHECKPOINT_JOIN_CHAR = "-"
@@ -248,6 +254,17 @@ def __init__(
        self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval, period)
        self.__validate_init_configuration()

    @property
    def state_id(self) -> str:
        return self._generate_state_id(
            monitor=self.monitor,
            mode=self.mode,
            every_n_train_steps=self._every_n_train_steps,
            every_n_epochs=self._every_n_epochs,
            train_time_interval=self._train_time_interval,
            save_on_train_epoch_end=self._save_on_train_epoch_end,
        )

    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """
        When pretrain routine starts we build the ckpt dir on the fly
10 changes: 9 additions & 1 deletion tests/callbacks/test_early_stopping.py
@@ -33,6 +33,11 @@
_logger = logging.getLogger(__name__)


def test_early_stopping_state_id():
    early_stopping = EarlyStopping(monitor="val_loss")
    assert early_stopping.state_id == "EarlyStopping{'monitor': 'val_loss', 'mode': 'min'}"


class EarlyStoppingTestRestore(EarlyStopping):
    # this class has to be defined outside the test function, otherwise we get pickle error
    def __init__(self, expected_state, *args, **kwargs):
@@ -77,7 +82,10 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
    # the checkpoint saves "epoch + 1"
    early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1]
    assert 4 == len(early_stop_callback.saved_states)
    assert (
        checkpoint["callbacks"]["EarlyStoppingTestRestore{'monitor': 'train_loss', 'mode': 'min'}"]
        == early_stop_callback_state
    )

    # ensure state is reloaded properly (assertion in the callback)
    early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, monitor="train_loss")
4 changes: 2 additions & 2 deletions tests/callbacks/test_lambda_function.py
@@ -11,12 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import Callback, LambdaCallback
from tests.helpers.boring_model import BoringModel
from tests.models.test_hooks import get_members


def test_lambda_call(tmpdir):
@@ -32,7 +32,7 @@ def on_train_epoch_start(self):
    def call(hook, *_, **__):
        checker.add(hook)

    hooks = get_members(Callback)
    hooks_args = {h: partial(call, h) for h in hooks}
    hooks_args["on_save_checkpoint"] = lambda *_: [checker.add("on_save_checkpoint")]

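For context, ``get_members`` is a small helper imported from ``tests/models/test_hooks.py``. A plausible
implementation (an assumption for illustration; the actual helper may differ slightly) is essentially the
inspect-based set comprehension it replaces here:

    import inspect

    def get_members(cls):
        # names of all functions defined on the class, i.e. every Callback hook
        return {name for name, _ in inspect.getmembers(cls, predicate=inspect.isfunction)}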
34 changes: 30 additions & 4 deletions tests/checkpointing/test_model_checkpoint.py
@@ -43,6 +43,15 @@
from tests.helpers.runif import RunIf


def test_model_checkpoint_state_id():
    model_checkpoint = ModelCheckpoint(monitor="val_loss")
    assert (
        model_checkpoint.state_id
        == "ModelCheckpoint{'monitor': 'val_loss', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
        " 'train_time_interval': None, 'save_on_train_epoch_end': None}"
    )


class LogInTwoMethods(BoringModel):
    def training_step(self, batch, batch_idx):
        out = super().training_step(batch, batch_idx)
@@ -148,7 +157,10 @@ def on_validation_epoch_end(self):
        assert chk["epoch"] == epoch + 1
        assert chk["global_step"] == limit_train_batches * (epoch + 1)

        mc_specific_data = chk["callbacks"][
            f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
            " 'train_time_interval': None, 'save_on_train_epoch_end': True}"
        ]
        assert mc_specific_data["dirpath"] == checkpoint.dirpath
        assert mc_specific_data["monitor"] == monitor
        assert mc_specific_data["current_score"] == score
@@ -259,7 +271,10 @@ def _make_assertions(epoch, ix, version=""):
        expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch_num)
        assert chk["global_step"] == expected_global_step

        mc_specific_data = chk["callbacks"][
            f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
            " 'train_time_interval': None, 'save_on_train_epoch_end': False}"
        ]
        assert mc_specific_data["dirpath"] == checkpoint.dirpath
        assert mc_specific_data["monitor"] == monitor
        assert mc_specific_data["current_score"] == score
@@ -857,7 +872,12 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):

    assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"]
    assert ckpt_last_epoch["global_step"] == ckpt_last["global_step"]

    ckpt_id = (
        "ModelCheckpoint{'monitor': 'early_stop_on', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
        " 'train_time_interval': None, 'save_on_train_epoch_end': True}"
    )
    assert ckpt_last["callbacks"][ckpt_id] == ckpt_last_epoch["callbacks"][ckpt_id]

    # it is easier to load the model objects than to iterate over the raw dict of tensors
    model_last_epoch = LogInTwoMethods.load_from_checkpoint(path_last_epoch)
@@ -1095,7 +1115,13 @@ def training_step(self, *args):
    trainer.fit(TestModel())
    assert model_checkpoint.current_score == 0.3
    ckpts = [torch.load(str(ckpt)) for ckpt in tmpdir.listdir()]
    ckpts = [
        ckpt["callbacks"][
            "ModelCheckpoint{'monitor': 'foo', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
            " 'train_time_interval': None, 'save_on_train_epoch_end': True}"
        ]
        for ckpt in ckpts
    ]
    assert sorted(ckpt["current_score"] for ckpt in ckpts) == [0.1, 0.2, 0.3]


41 changes: 34 additions & 7 deletions tests/trainer/connectors/test_callback_connector.py
@@ -61,28 +61,55 @@ def on_save_checkpoint(self, *args):


class StatefulCallback1(Callback):
    def __init__(self, unique=None, other=None):
        self._unique = unique
        self._other = other

    @property
    def state_id(self):
        return self._generate_state_id(unique=self._unique)

    def on_save_checkpoint(self, *args):
        return {"content1": self._unique}


def test_all_callback_states_saved_before_checkpoint_callback(tmpdir):
    """
    Test that all callback states get saved even if the ModelCheckpoint is not given last
    and when there are multiple callbacks of the same type.
    """

    callback0 = StatefulCallback0()
    callback1 = StatefulCallback1(unique="one")
    callback2 = StatefulCallback1(unique="two", other=2)
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename="all_states")
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=1,
        limit_val_batches=1,
        callbacks=[
            callback0,
            # checkpoint callback does not have to be at the end
            checkpoint_callback,
            # callback1 and callback2 have the same type
            callback1,
            callback2,
        ],
    )
    trainer.fit(model)

    ckpt = torch.load(str(tmpdir / "all_states.ckpt"))
    state0 = ckpt["callbacks"]["StatefulCallback0"]
    state1 = ckpt["callbacks"]["StatefulCallback1{'unique': 'one'}"]
    state2 = ckpt["callbacks"]["StatefulCallback1{'unique': 'two'}"]
    assert "content0" in state0 and state0["content0"] == 0
    assert "content1" in state1 and state1["content1"] == "one"
    assert "content1" in state2 and state2["content1"] == "two"
    assert (
        "ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
        " 'train_time_interval': None, 'save_on_train_epoch_end': True}" in ckpt["callbacks"]
    )
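
Note how ``other`` is intentionally left out of ``StatefulCallback1.state_id``: only the arguments passed to
``_generate_state_id`` participate in the key. As a consequence (illustrative values only):

    # `other` does not influence the key, so these two instances would collide under
    # the same entry "StatefulCallback1{'unique': 'two'}" in the checkpoint
    assert StatefulCallback1(unique="two", other=2).state_id == StatefulCallback1(unique="two", other=99).state_id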


def test_attach_model_callbacks():