Skip to content

Commit

Permalink
Deprecate callback hooks on_pretrain_routine_{start,end} (#11794)
Browse files Browse the repository at this point in the history
Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
  • Loading branch information
krishnakalyan3 and rohitgr7 committed Feb 24, 2022
1 parent 00211c1 commit 29d5afb
Show file tree
Hide file tree
Showing 9 changed files with 78 additions and 33 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -418,6 +418,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `pytorch_lightning.utilities.warnings.LightningDeprecationWarning` in favor of `pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning`


- Deprecated `on_pretrain_routine_start` and `on_pretrain_routine_end` callback hooks in favor of `on_fit_start` ([#11794](https://github.com/PyTorchLightning/pytorch-lightning/pull/11794))


- Deprecated `agg_key_funcs` and `agg_default_func` parameters from `LightningLoggerBase` ([#11871](https://github.com/PyTorchLightning/pytorch-lightning/pull/11871))


Expand Down
12 changes: 0 additions & 12 deletions docs/source/extensions/callbacks.rst
Expand Up @@ -375,18 +375,6 @@ on_train_end
.. automethod:: pytorch_lightning.callbacks.Callback.on_train_end
:noindex:

on_pretrain_routine_start
~~~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.callbacks.Callback.on_pretrain_routine_start
:noindex:

on_pretrain_routine_end
~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.callbacks.Callback.on_pretrain_routine_end
:noindex:

on_validation_start
~~~~~~~~~~~~~~~~~~~

Expand Down
16 changes: 14 additions & 2 deletions pytorch_lightning/callbacks/base.py
Expand Up @@ -248,10 +248,22 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
"""Called when the train ends."""

def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the pretrain routine begins."""
r"""
.. deprecated:: v1.6
This callback hook was deprecated in v1.6 and will be removed in v1.8. Use ``on_fit_start`` instead.
Called when the pretrain routine begins.
"""

def on_pretrain_routine_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the pretrain routine ends."""
r"""
.. deprecated:: v1.6
This callback hook was deprecated in v1.6 and will be removed in v1.8. Use ``on_fit_start`` instead.
Called when the pretrain routine ends.
"""

def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the validation loop begins."""
Expand Down
10 changes: 4 additions & 6 deletions pytorch_lightning/callbacks/model_checkpoint.py
Expand Up @@ -248,19 +248,17 @@ def state_key(self) -> str:
)

def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
self.__resolve_ckpt_dir(trainer)
if trainer.is_global_zero and stage == "fit":
self.__warn_if_dir_not_empty(self.dirpath)

# NOTE: setting these attributes needs to happen as early as possible BEFORE reloading callback states,
# because the attributes are part of the state_key which needs to be fully defined before reloading.
if self._save_on_train_epoch_end is None:
# if the user runs validation multiple times per training epoch or multiple training epochs without
# validation, then we run after validation instead of on train epoch end
self._save_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1

def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""When pretrain routine starts we build the ckpt dir on the fly."""
self.__resolve_ckpt_dir(trainer)
if trainer.is_global_zero:
self.__warn_if_dir_not_empty(self.dirpath)

def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._last_time_checked = time.monotonic()

Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/model_summary.py
Expand Up @@ -49,7 +49,7 @@ class ModelSummary(Callback):
def __init__(self, max_depth: int = 1) -> None:
self._max_depth: int = max_depth

def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if not self._max_depth:
return None

Expand Down
16 changes: 8 additions & 8 deletions pytorch_lightning/core/hooks.py
Expand Up @@ -182,19 +182,19 @@ def on_predict_model_eval(self) -> None:
self.trainer.model.eval()

def on_epoch_start(self) -> None:
r"""
.. deprecated:: v1.6 This hook was deprecated in v1.6 and will be removed in v1.8. Use
``on_<train/validation/test>_epoch_start`` instead.
"""Called when either of train/val/test epoch begins.
Called when either of train/val/test epoch begins.
.. deprecated:: v1.6
:meth:`on_epoch_start` has been deprecated in v1.6 and will be removed in v1.8.
Use ``on_<train/validation/test>_epoch_start`` instead.
"""

def on_epoch_end(self) -> None:
r"""
.. deprecated:: v1.6 This hook was deprecated in v1.6 and will be removed in v1.8. Use
``on_<train/validation/test>_epoch_end`` instead.
"""Called when either of train/val/test epoch ends.
Called when either of train/val/test epoch ends.
.. deprecated:: v1.6
:meth:`on_epoch_end` has been deprecated in v1.6 and will be removed in v1.8.
Use ``on_<train/validation/test>_epoch_end`` instead.
"""

def on_train_epoch_start(self) -> None:
Expand Down
7 changes: 6 additions & 1 deletion pytorch_lightning/trainer/configuration_validator.py
Expand Up @@ -348,7 +348,6 @@ def _check_deprecated_callback_hooks(trainer: "pl.Trainer") -> None:
f"The `Callback.{hook}` hook was deprecated in v1.6 and"
f" will be removed in v1.8. Please use `Callback.{alternative_hook}` instead."
)

for hook, alternative_hook in (
["on_epoch_start", "on_<train/validation/test>_epoch_start"],
["on_epoch_end", "on_<train/validation/test>_epoch_end"],
Expand All @@ -358,3 +357,9 @@ def _check_deprecated_callback_hooks(trainer: "pl.Trainer") -> None:
f"The `Callback.{hook}` hook was deprecated in v1.6 and"
f" will be removed in v1.8. Please use `Callback.{alternative_hook}` instead."
)
for hook in ("on_pretrain_routine_start", "on_pretrain_routine_end"):
if is_overridden(method_name=hook, instance=callback):
rank_zero_deprecation(
f"The `Callback.{hook}` hook has been deprecated in v1.6 and"
f" will be removed in v1.8. Please use `Callback.on_fit_start` instead."
)
38 changes: 37 additions & 1 deletion tests/deprecated_api/test_remove_1-8.py
Expand Up @@ -249,7 +249,7 @@ def test_v1_8_0_deprecate_trainer_callback_hook_mixin():
)
model = BoringModel()
# need to attach model to trainer for testing of `on_pretrain_routine_start`
trainer.fit(model)
trainer.strategy.connect(model)
for method_name in methods_with_self:
fn = getattr(trainer, method_name, None)
with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8"):
Expand Down Expand Up @@ -572,3 +572,39 @@ def agg_and_log_metrics(self, metrics, step):
Trainer(logger=[logger, logger3])
# Should have no deprecation warning
Trainer(logger=[logger2, logger3])


def test_v1_8_0_callback_on_pretrain_routine_start_end(tmpdir):
class TestCallback(Callback):
def on_pretrain_routine_start(self, trainer, pl_module):
print("on_pretrain_routine_start called.")

model = BoringModel()

trainer = Trainer(
callbacks=[TestCallback()],
fast_dev_run=True,
enable_progress_bar=False,
default_root_dir=tmpdir,
)
with pytest.deprecated_call(
match="The `Callback.on_pretrain_routine_start` hook has been deprecated in v1.6" " and will be removed in v1.8"
):
trainer.fit(model)

class TestCallback(Callback):
def on_pretrain_routine_end(self, trainer, pl_module):
print("on_pretrain_routine_end called.")

model = BoringModel()

trainer = Trainer(
callbacks=[TestCallback()],
fast_dev_run=True,
enable_progress_bar=False,
default_root_dir=tmpdir,
)
with pytest.deprecated_call(
match="The `Callback.on_pretrain_routine_end` hook has been deprecated in v1.6" " and will be removed in v1.8"
):
trainer.fit(model)
7 changes: 5 additions & 2 deletions tests/models/test_restore.py
Expand Up @@ -325,12 +325,15 @@ def get_trainer_args():

# initial training
trainer = Trainer(**get_trainer_args())
trainer.fit(model, datamodule=dm)
with pytest.deprecated_call(match="`Callback.on_pretrain_routine_end` hook has been deprecated in v1.6"):
trainer.fit(model, datamodule=dm)

callbacks_before_resume = deepcopy(trainer.callbacks)

# resumed training
trainer = Trainer(**get_trainer_args())
trainer.fit(model, datamodule=dm, ckpt_path=str(tmpdir / "last.ckpt"))
with pytest.deprecated_call(match="`Callback.on_pretrain_routine_end` hook has been deprecated in v1.6"):
trainer.fit(model, datamodule=dm, ckpt_path=str(tmpdir / "last.ckpt"))

assert len(callbacks_before_resume) == len(callback_capture.callbacks)

Expand Down

0 comments on commit 29d5afb

Please sign in to comment.