Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate callback hooks on_pretrain_routine_start and on_pretrain_routine_end #11794

Merged
merged 43 commits into from Feb 24, 2022
Merged
Show file tree
Hide file tree
Changes from 41 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
32d7c2b
init commit
krishnakalyan3 Feb 7, 2022
5d33ba2
feedback based changes
krishnakalyan3 Feb 7, 2022
1142c92
adress PR comments
krishnakalyan3 Feb 9, 2022
d3d743e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 9, 2022
9014b11
trainer changes
krishnakalyan3 Feb 9, 2022
e6d28f5
Merge branch 'deprecate-pretrain_routine' of github.com:krishnakalyan…
krishnakalyan3 Feb 9, 2022
f0282db
update changelog
krishnakalyan3 Feb 9, 2022
efa6ef6
validations
krishnakalyan3 Feb 9, 2022
ad1751b
alternative hooks update
krishnakalyan3 Feb 9, 2022
c1c2eba
revert changes
krishnakalyan3 Feb 9, 2022
3e480b7
revert hooks
krishnakalyan3 Feb 9, 2022
f6c2f05
revert hook
krishnakalyan3 Feb 9, 2022
8779cc5
revert changes for hooks
krishnakalyan3 Feb 9, 2022
e7f3ad5
remove from logging
krishnakalyan3 Feb 9, 2022
ca74752
remove comments
krishnakalyan3 Feb 9, 2022
72c5f2a
init commit
krishnakalyan3 Feb 7, 2022
9d8ae14
feedback based changes
krishnakalyan3 Feb 7, 2022
53887ee
adress PR comments
krishnakalyan3 Feb 9, 2022
fe5efa2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 9, 2022
e2172fc
trainer changes
krishnakalyan3 Feb 9, 2022
71b14aa
rebase and commit
krishnakalyan3 Feb 10, 2022
437710b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 10, 2022
53535f8
rebased validator
krishnakalyan3 Feb 10, 2022
7e9451b
Merge branch 'deprecate-pretrain_routine' of github.com:krishnakalyan…
krishnakalyan3 Feb 10, 2022
e741632
rebase again
krishnakalyan3 Feb 10, 2022
0311369
fix ci error by importing optional
krishnakalyan3 Feb 15, 2022
d779a25
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 15, 2022
47b3e82
remove bc breaking changes
krishnakalyan3 Feb 15, 2022
1be6952
Merge branch 'deprecate-pretrain_routine' of github.com:krishnakalyan…
krishnakalyan3 Feb 15, 2022
216ac95
changes according to suggestions
krishnakalyan3 Feb 15, 2022
1b2763a
update unit tests
krishnakalyan3 Feb 16, 2022
a163793
unit tests updated
krishnakalyan3 Feb 16, 2022
73e6c68
update test restore
krishnakalyan3 Feb 16, 2022
b7f5252
remove to fix unit test
krishnakalyan3 Feb 16, 2022
f843d87
Merge branch 'master' of github.com:krishnakalyan3/pytorch-lightning
krishnakalyan3 Feb 19, 2022
3c40f9c
update merge
krishnakalyan3 Feb 19, 2022
fe14dec
fix the deprecations
rohitgr7 Feb 24, 2022
86574e2
fix the deprecations
rohitgr7 Feb 24, 2022
88dc0ab
fix the deprecations
rohitgr7 Feb 24, 2022
92586bc
add deprecation test
rohitgr7 Feb 24, 2022
1175c8a
Merge branch 'master' into deprecate-pretrain_routine
rohitgr7 Feb 24, 2022
75da9fa
Apply suggestions from code review
rohitgr7 Feb 24, 2022
41f5489
Update pytorch_lightning/callbacks/base.py
rohitgr7 Feb 24, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -415,6 +415,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `pytorch_lightning.utilities.warnings.LightningDeprecationWarning` in favor of `pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning`


- Deprecated `on_pretrain_routine_start` and `on_pretrain_routine_end` callback hooks in favor of `on_fit_start` ([#11794](https://github.com/PyTorchLightning/pytorch-lightning/pull/11794))


- Deprecated `agg_key_funcs` and `agg_default_func` parameters from `LightningLoggerBase` ([#11871](https://github.com/PyTorchLightning/pytorch-lightning/pull/11871))


Expand Down
12 changes: 0 additions & 12 deletions docs/source/extensions/callbacks.rst
Expand Up @@ -375,18 +375,6 @@ on_train_end
.. automethod:: pytorch_lightning.callbacks.Callback.on_train_end
:noindex:

on_pretrain_routine_start
~~~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.callbacks.Callback.on_pretrain_routine_start
:noindex:

on_pretrain_routine_end
~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.callbacks.Callback.on_pretrain_routine_end
:noindex:

on_validation_start
~~~~~~~~~~~~~~~~~~~

Expand Down
16 changes: 14 additions & 2 deletions pytorch_lightning/callbacks/base.py
Expand Up @@ -248,10 +248,22 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
"""Called when the train ends."""

def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the pretrain routine begins."""
r"""
.. deprecated:: v1.6
This callback hook was deprecated in v1.6 and will be removed in v1.8. Use
``on_fit_start`` instead.
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved

Called when the pretrain routine begins.
"""

def on_pretrain_routine_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the pretrain routine ends."""
r"""
.. deprecated:: v1.6
This callback hook was deprecated in v1.6 and will be removed in v1.8. Use
``on_fit_start`` instead.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@daniellepintz On fit end is not correct. That happens at the very end of the call

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, sorry, I was confused.

rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved

Called when the pretrain routine ends.
"""

def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the validation loop begins."""
Expand Down
10 changes: 4 additions & 6 deletions pytorch_lightning/callbacks/model_checkpoint.py
Expand Up @@ -248,19 +248,17 @@ def state_key(self) -> str:
)

def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
self.__resolve_ckpt_dir(trainer)
if trainer.is_global_zero and stage == "fit":
self.__warn_if_dir_not_empty(self.dirpath)

# NOTE: setting these attributes needs to happen as early as possible BEFORE reloading callback states,
# because the attributes are part of the state_key which needs to be fully defined before reloading.
if self._save_on_train_epoch_end is None:
# if the user runs validation multiple times per training epoch or multiple training epochs without
# validation, then we run after validation instead of on train epoch end
self._save_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1

def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""When pretrain routine starts we build the ckpt dir on the fly."""
self.__resolve_ckpt_dir(trainer)
if trainer.is_global_zero:
self.__warn_if_dir_not_empty(self.dirpath)

def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._last_time_checked = time.monotonic()

Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/model_summary.py
Expand Up @@ -49,7 +49,7 @@ class ModelSummary(Callback):
def __init__(self, max_depth: int = 1) -> None:
self._max_depth: int = max_depth

def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if not self._max_depth:
return None

Expand Down
16 changes: 8 additions & 8 deletions pytorch_lightning/core/hooks.py
Expand Up @@ -182,19 +182,19 @@ def on_predict_model_eval(self) -> None:
self.trainer.model.eval()

def on_epoch_start(self) -> None:
r"""
.. deprecated:: v1.6 This hook was deprecated in v1.6 and will be removed in v1.8. Use
``on_<train/validation/test>_epoch_start`` instead.
"""Called when either of train/val/test epoch begins.

Called when either of train/val/test epoch begins.
.. deprecated:: v1.6
:meth:`on_epoch_start` has been deprecated in v1.6 and will be removed in v1.8.
Use ``on_<train/validation/test>_epoch_start`` instead.
"""

def on_epoch_end(self) -> None:
r"""
.. deprecated:: v1.6 This hook was deprecated in v1.6 and will be removed in v1.8. Use
``on_<train/validation/test>_epoch_end`` instead.
"""Called when either of train/val/test epoch ends.

Called when either of train/val/test epoch ends.
.. deprecated:: v1.6
:meth:`on_epoch_end` has been deprecated in v1.6 and will be removed in v1.8.
Use ``on_<train/validation/test>_epoch_end`` instead.
"""

def on_train_epoch_start(self) -> None:
Expand Down
7 changes: 6 additions & 1 deletion pytorch_lightning/trainer/configuration_validator.py
Expand Up @@ -348,7 +348,6 @@ def _check_deprecated_callback_hooks(trainer: "pl.Trainer") -> None:
f"The `Callback.{hook}` hook was deprecated in v1.6 and"
f" will be removed in v1.8. Please use `Callback.{alternative_hook}` instead."
)

for hook, alternative_hook in (
["on_epoch_start", "on_<train/validation/test>_epoch_start"],
["on_epoch_end", "on_<train/validation/test>_epoch_end"],
Expand All @@ -358,3 +357,9 @@ def _check_deprecated_callback_hooks(trainer: "pl.Trainer") -> None:
f"The `Callback.{hook}` hook was deprecated in v1.6 and"
f" will be removed in v1.8. Please use `Callback.{alternative_hook}` instead."
)
for hook in ("on_pretrain_routine_start", "on_pretrain_routine_end"):
if is_overridden(method_name=hook, instance=callback):
rank_zero_deprecation(
f"The `Callback.{hook}` hook has been deprecated in v1.6 and"
f" will be removed in v1.8. Please use `Callback.on_fit_start` instead."
)
38 changes: 37 additions & 1 deletion tests/deprecated_api/test_remove_1-8.py
Expand Up @@ -249,7 +249,7 @@ def test_v1_8_0_deprecate_trainer_callback_hook_mixin():
)
model = BoringModel()
# need to attach model to trainer for testing of `on_pretrain_routine_start`
trainer.fit(model)
trainer.strategy.connect(model)
for method_name in methods_with_self:
fn = getattr(trainer, method_name, None)
with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8"):
Expand Down Expand Up @@ -572,3 +572,39 @@ def agg_and_log_metrics(self, metrics, step):
Trainer(logger=[logger, logger3])
# Should have no deprecation warning
Trainer(logger=[logger2, logger3])


def test_v1_8_0_callback_on_pretrain_routune(tmpdir):
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
class TestCallback(Callback):
def on_pretrain_routine_start(self, trainer, pl_module):
print("on_pretrain_routune_start called.")
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved

model = BoringModel()

trainer = Trainer(
callbacks=[TestCallback()],
fast_dev_run=True,
enable_progress_bar=False,
default_root_dir=tmpdir,
)
with pytest.deprecated_call(
match="The `Callback.on_pretrain_routine_start` hook has been deprecated in v1.6" " and will be removed in v1.8"
):
trainer.fit(model)

class TestCallback(Callback):
def on_pretrain_routine_end(self, trainer, pl_module):
print("on_pretrain_routune_end called.")
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved

model = BoringModel()

trainer = Trainer(
callbacks=[TestCallback()],
fast_dev_run=True,
enable_progress_bar=False,
default_root_dir=tmpdir,
)
with pytest.deprecated_call(
match="The `Callback.on_pretrain_routine_end` hook has been deprecated in v1.6" " and will be removed in v1.8"
):
trainer.fit(model)
7 changes: 5 additions & 2 deletions tests/models/test_restore.py
Expand Up @@ -325,12 +325,15 @@ def get_trainer_args():

# initial training
trainer = Trainer(**get_trainer_args())
trainer.fit(model, datamodule=dm)
with pytest.deprecated_call(match="`Callback.on_pretrain_routine_end` hook has been deprecated in v1.6"):
trainer.fit(model, datamodule=dm)

callbacks_before_resume = deepcopy(trainer.callbacks)

# resumed training
trainer = Trainer(**get_trainer_args())
trainer.fit(model, datamodule=dm, ckpt_path=str(tmpdir / "last.ckpt"))
with pytest.deprecated_call(match="`Callback.on_pretrain_routine_end` hook has been deprecated in v1.6"):
trainer.fit(model, datamodule=dm, ckpt_path=str(tmpdir / "last.ckpt"))

assert len(callbacks_before_resume) == len(callback_capture.callbacks)

Expand Down