[2/n] add Stateful functionality support for Callbacks #12232

Merged
merged 10 commits on Mar 19, 2022
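This PR gives callbacks dedicated `state_dict()` / `load_state_dict()` methods so they can persist and restore their own state, while the older pattern of returning state from `on_save_checkpoint` and receiving it in `on_load_checkpoint` remains supported (and takes precedence) until v1.8. A minimal sketch of a callback written against the new API; the `CounterCallback` name and its counter are illustrative only, not part of this PR:

```python
import pytorch_lightning as pl


class CounterCallback(pl.Callback):
    """Hypothetical callback that persists a running batch count via the new API."""

    def __init__(self):
        self.batches_seen = 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.batches_seen += 1

    def state_dict(self):
        # saved under checkpoint["callbacks"][self.state_key] when a checkpoint is written
        return {"batches_seen": self.batches_seen}

    def load_state_dict(self, state_dict):
        # called with a copy of the stored dict when resuming from a checkpoint
        self.batches_seen = state_dict["batches_seen"]
```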
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -149,6 +149,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added optional `storage_options` argument to `Trainer.save_checkpoint()` to pass to custom `CheckpointIO` implementations ([#11891](https://github.com/PyTorchLightning/pytorch-lightning/pull/11891))


- Added `Callback.state_dict()` and `Callback.load_state_dict()` methods ([#12232](https://github.com/PyTorchLightning/pytorch-lightning/pull/12232))


### Changed

- Drop PyTorch 1.7 support ([#12191](https://github.com/PyTorchLightning/pytorch-lightning/pull/12191))
pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -223,6 +223,7 @@ def restore_callbacks(self) -> None:
return

self.trainer._call_callbacks_on_load_checkpoint(self._loaded_checkpoint)
self.trainer._call_callbacks_load_state_dict(self._loaded_checkpoint)

def restore_loops(self) -> None:
"""Restores the loop progress from the pre-loaded checkpoint.
@@ -344,7 +345,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:

if not weights_only:
# dump callbacks
checkpoint["callbacks"] = self.trainer._call_callbacks_on_save_checkpoint(checkpoint)
checkpoint["callbacks"] = self.trainer._call_callbacks_state_dict()

optimizer_states = []
for i, optimizer in enumerate(self.trainer.optimizers):
@@ -386,6 +387,12 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
checkpoint[datamodule.__class__.__qualname__] = datamodule_state_dict

# on_save_checkpoint hooks
if not weights_only:
# if state is returned from callback's on_save_checkpoint
# it overrides the returned state from callback's state_dict
# support for returning state in on_save_checkpoint
# will be removed in v1.8
self.trainer._call_callbacks_on_save_checkpoint(checkpoint)
model.on_save_checkpoint(checkpoint)
if datamodule is not None:
datamodule.on_save_checkpoint(checkpoint)
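As the comment above notes, a dict returned from a callback's `on_save_checkpoint` still wins over that callback's `state_dict()` entry until the v1.8 removal. A small sketch of that precedence, using a hypothetical callback that implements both:

```python
import pytorch_lightning as pl


class BothStylesCallback(pl.Callback):
    """Hypothetical callback implementing the new method and the deprecated hook."""

    def state_dict(self):
        return {"value": "from_state_dict"}

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # Returning a dict here is deprecated (support is planned for removal in v1.8);
        # until then it replaces the state_dict() entry for this callback's state_key.
        return {"value": "from_on_save_checkpoint"}


# With callback = BothStylesCallback(), a checkpoint written by dump_checkpoint() would hold
# checkpoint["callbacks"][callback.state_key] == {"value": "from_on_save_checkpoint"}
```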
36 changes: 29 additions & 7 deletions pytorch_lightning/trainer/trainer.py
@@ -1661,19 +1661,28 @@ def _on_train_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader
else:
callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx)

def _call_callbacks_on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:
"""Called when saving a model checkpoint.
def _call_callbacks_state_dict(self) -> Dict[str, dict]:
"""Called when saving a model checkpoint, calls and returns every callback's `state_dict`, keyed by
`Callback.state_key`."""
callback_state_dicts = {}
for callback in self.callbacks:
state_dict = callback.state_dict()
if state_dict:
callback_state_dicts[callback.state_key] = state_dict
return callback_state_dicts

def _call_callbacks_on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
"""Called when saving a model checkpoint, calls every callback's `on_save_checkpoint` hook.

Calls every callback's `on_save_checkpoint` hook. We have a dedicated function for this rather than using
`_call_callback_hooks` because we have special logic for returning callback_states.
Will be removed in v1.8: If state is returned, we insert the callback state into
``checkpoint["callbacks"][Callback.state_key]``. It overrides ``state_dict`` if already present.
"""
callback_states = {}
for callback in self.callbacks:
# TODO: Add profiling for on_save_checkpoint hook
state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
if state:
callback_states[callback.state_key] = state
return callback_states
# TODO: Add deprecation warning if state is returned (see reference PR #11887)
checkpoint["callbacks"][callback.state_key] = state

def _call_callbacks_on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
"""Called when loading a model checkpoint.
@@ -1703,6 +1712,19 @@ def _call_callbacks_on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None
# TODO: Add profiling for on_load_checkpoint hook
callback.on_load_checkpoint(self, self.lightning_module, state)

def _call_callbacks_load_state_dict(self, checkpoint: Dict[str, Any]) -> None:
"""Called when loading a model checkpoint, calls every callback's `load_state_dict`."""
callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks")

if callback_states is None:
return

for callback in self.callbacks:
state = callback_states.get(callback.state_key, callback_states.get(callback._legacy_state_key))
if state:
state = deepcopy(state)
callback.load_state_dict(state)

def _call_strategy_hook(
self,
hook_name: str,
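On the restore path above, the stored entry is matched by `state_key` (falling back to the legacy key), deep-copied, and handed to `load_state_dict`. The stored mapping can also be inspected straight from a checkpoint file; a rough sketch in which the path and the printed contents are purely illustrative:

```python
import torch

ckpt = torch.load("path/to/last.ckpt", map_location="cpu")
# checkpoint["callbacks"] maps each callback's state_key to its saved state,
# e.g. {"CounterCallback": {"batches_seen": 5}} for the hypothetical callback above.
print(ckpt["callbacks"])
```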
95 changes: 90 additions & 5 deletions tests/callbacks/test_callbacks.py
@@ -115,15 +115,16 @@ def __init__(self, state):
def state_key(self):
return type(self)

def on_save_checkpoint(self, *args):
def state_dict(self):
return {"state": self.state}

def on_load_checkpoint(self, trainer, pl_module, callback_state):
self.state = callback_state["state"]
def load_state_dict(self, state_dict) -> None:
self.state = state_dict["state"]


def test_resume_callback_state_saved_by_type(tmpdir):
"""Test that a legacy checkpoint that didn't use a state key before can still be loaded."""
def test_resume_callback_state_saved_by_type_stateful(tmpdir):
"""Test that a legacy checkpoint that didn't use a state key before can still be loaded, using
state_dict/load_state_dict."""
model = BoringModel()
callback = OldStatefulCallback(state=111)
trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback])
@@ -137,6 +138,44 @@ def test_resume_callback_state_saved_by_type(tmpdir):
assert callback.state == 111


class OldStatefulCallbackHooks(Callback):
def __init__(self, state):
self.state = state

@property
def state_key(self):
return type(self)

def on_save_checkpoint(self, trainer, pl_module, checkpoint):
return {"state": self.state}

def on_load_checkpoint(self, trainer, pl_module, callback_state):
self.state = callback_state["state"]


def test_resume_callback_state_saved_by_type_hooks(tmpdir):
"""Test that a legacy checkpoint that didn't use a state key before can still be loaded, using deprecated
on_save/load_checkpoint signatures."""
# TODO: remove old on_save/load_checkpoint signature support in v1.8
# in favor of Stateful and new on_save/load_checkpoint signatures
# on_save_checkpoint() -> dict, on_load_checkpoint(callback_state)
# will become
# on_save_checkpoint() -> None and on_load_checkpoint(checkpoint)
model = BoringModel()
callback = OldStatefulCallbackHooks(state=111)
trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback])
# TODO: catch deprecated call after deprecations introduced (see reference PR #11887)
trainer.fit(model)
ckpt_path = Path(trainer.checkpoint_callback.best_model_path)
assert ckpt_path.exists()

callback = OldStatefulCallbackHooks(state=222)
trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback])
# TODO: catch deprecated call after deprecations introduced (see reference PR #11887)
trainer.fit(model, ckpt_path=ckpt_path)
assert callback.state == 111


def test_resume_incomplete_callbacks_list_warning(tmpdir):
model = BoringModel()
callback0 = ModelCheckpoint(monitor="epoch")
@@ -164,3 +203,49 @@ def test_resume_incomplete_callbacks_list_warning(tmpdir):
)
with no_warning_call(UserWarning, match="Please add the following callbacks:"):
trainer.fit(model, ckpt_path=ckpt_path)


class AllStatefulCallback(Callback):
def __init__(self, state):
self.state = state

@property
def state_key(self):
return type(self)

def state_dict(self):
return {"new_state": self.state}

def load_state_dict(self, state_dict):
assert state_dict == {"old_state_precedence": 10}
self.state = state_dict["old_state_precedence"]

def on_save_checkpoint(self, trainer, pl_module, checkpoint):
return {"old_state_precedence": 10}

def on_load_checkpoint(self, trainer, pl_module, callback_state):
assert callback_state == {"old_state_precedence": 10}
self.old_state_precedence = callback_state["old_state_precedence"]


def test_resume_callback_state_all(tmpdir):
"""Test on_save/load_checkpoint state precedence over state_dict/load_state_dict until v1.8 removal."""
# TODO: remove old on_save/load_checkpoint signature support in v1.8
# in favor of Stateful and new on_save/load_checkpoint signatures
# on_save_checkpoint() -> dict, on_load_checkpoint(callback_state)
# will become
# on_save_checkpoint() -> None and on_load_checkpoint(checkpoint)
model = BoringModel()
callback = AllStatefulCallback(state=111)
trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback])
# TODO: catch deprecated call after deprecations introduced (see reference PR #11887)
trainer.fit(model)
ckpt_path = Path(trainer.checkpoint_callback.best_model_path)
assert ckpt_path.exists()

callback = AllStatefulCallback(state=222)
trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback])
# TODO: catch deprecated call after deprecations introduced (see reference PR #11887)
trainer.fit(model, ckpt_path=ckpt_path)
assert callback.state == 10
assert callback.old_state_precedence == 10
3 changes: 3 additions & 0 deletions tests/models/test_hooks.py
@@ -553,6 +553,7 @@ def training_step(self, batch, batch_idx):
dict(name="training_epoch_end", args=([dict(loss=ANY)] * train_batches,)),
dict(name="Callback.on_train_epoch_end", args=(trainer, model)),
# `ModelCheckpoint.save_checkpoint` is called here from `Callback.on_train_epoch_end`
dict(name="Callback.state_dict"),
dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt)),
dict(name="on_save_checkpoint", args=(saved_ckpt,)),
dict(name="on_train_epoch_end"),
@@ -626,6 +627,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
dict(name="setup", kwargs=dict(stage="fit")),
dict(name="on_load_checkpoint", args=(loaded_ckpt,)),
dict(name="Callback.on_load_checkpoint", args=(trainer, model, {"foo": True})),
dict(name="Callback.load_state_dict", args=({"foo": True},)),
dict(name="configure_sharded_model"),
dict(name="Callback.on_configure_sharded_model", args=(trainer, model)),
dict(name="configure_optimizers"),
@@ -647,6 +649,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
*model._train_batch(trainer, model, steps_after_reload, current_batch=1, current_epoch=1),
dict(name="training_epoch_end", args=([dict(loss=ANY)] * train_batches,)),
dict(name="Callback.on_train_epoch_end", args=(trainer, model)),
dict(name="Callback.state_dict"),
dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt)),
dict(name="on_save_checkpoint", args=(saved_ckpt,)),
dict(name="on_train_epoch_end"),