Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add missing val/test hooks in LightningModule #5467

Merged
merged 8 commits into from Jan 13, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -47,6 +47,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `IoU` class interface ([#4704](https://github.com/PyTorchLightning/pytorch-lightning/pull/4704))


- Added missing val/test hooks in `LightningModule` ([#5467](https://github.com/PyTorchLightning/pytorch-lightning/pull/5467))


### Changed

- Changed `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))
Expand Down
41 changes: 32 additions & 9 deletions pytorch_lightning/core/hooks.py
Expand Up @@ -17,14 +17,15 @@
from typing import Any, Dict, List, Optional, Union

import torch
from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn


class ModelHooks:
"""Hooks to be used in LightningModule."""
def setup(self, stage: str):
def setup(self, stage: str) -> None:
"""
Called at the beginning of fit and test.
This is a good hook when you need to build models dynamically or adjust something about them.
Expand Down Expand Up @@ -52,29 +53,29 @@ def setup(stage):

"""

def teardown(self, stage: str):
def teardown(self, stage: str) -> None:
"""
Called at the end of fit and test.

Args:
stage: either 'fit' or 'test'
"""

def on_fit_start(self):
def on_fit_start(self) -> None:
"""
Called at the very beginning of fit.
If on DDP it is called on every process
"""

def on_fit_end(self):
def on_fit_end(self) -> None:
"""
Called at the very end of fit.
If on DDP it is called on every process
"""

def on_train_start(self) -> None:
"""
Called at the beginning of training before sanity check.
Called at the beginning of training after sanity check.
"""
# do something at the start of training

Expand All @@ -84,6 +85,18 @@ def on_train_end(self) -> None:
"""
# do something at the end of training

def on_validation_start(self) -> None:
"""
Called at the beginning of validation.
"""
# do something at the start of validation

def on_validation_end(self) -> None:
"""
Called at the end of validation.
"""
# do something at the end of validation

def on_pretrain_routine_start(self) -> None:
"""
Called at the beginning of the pretrain routine (between fit and train start).
Expand All @@ -108,9 +121,7 @@ def on_pretrain_routine_end(self) -> None:
"""
# do something at the end of the pretrain routine

def on_train_batch_start(
self, batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
"""
Called in the training loop before anything happens for that batch.

Expand Down Expand Up @@ -253,6 +264,18 @@ def on_test_epoch_end(self) -> None:
"""
# do something when the epoch ends

def on_test_start(self) -> None:
"""
Called at the beginning of testing.
"""
# do something at the start of testing

def on_test_end(self) -> None:
"""
Called at the end of testing.
"""
# do something at the end of testing

def on_before_zero_grad(self, optimizer: Optimizer) -> None:
"""
Called after optimizer.step() and before optimizer.zero_grad().
Expand Down
25 changes: 19 additions & 6 deletions tests/models/test_hooks.py
Expand Up @@ -12,15 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from unittest.mock import MagicMock

import pytest
import torch
from unittest.mock import MagicMock

from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator
from pytorch_lightning.trainer.states import TrainerState
from tests.base import EvalModelTemplate, BoringModel
from tests.base import BoringModel, EvalModelTemplate


@pytest.mark.parametrize('max_steps', [1, 2, 3])
Expand Down Expand Up @@ -254,10 +254,6 @@ def on_test_start(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_test_start()

def on_test_end(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_test_end()

def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_test_batch_start(batch, batch_idx, dataloader_idx)
Expand Down Expand Up @@ -290,6 +286,14 @@ def on_test_model_train(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_test_model_train()

def on_test_end(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_test_end()

def teardown(self, stage: str):
self.called.append(inspect.currentframe().f_code.co_name)
super().teardown(stage)

model = HookedModel()

assert model.called == []
Expand All @@ -313,10 +317,12 @@ def on_test_model_train(self):
'on_pretrain_routine_start',
'on_pretrain_routine_end',
'on_validation_model_eval',
'on_validation_start',
'on_validation_epoch_start',
'on_validation_batch_start',
'on_validation_batch_end',
'on_validation_epoch_end',
'on_validation_end',
'on_validation_model_train',
'on_train_start',
'on_epoch_start',
Expand All @@ -330,16 +336,19 @@ def on_test_model_train(self):
'on_before_zero_grad',
'on_train_batch_end',
'on_validation_model_eval',
'on_validation_start',
'on_validation_epoch_start',
'on_validation_batch_start',
'on_validation_batch_end',
'on_validation_epoch_end',
'on_save_checkpoint',
'on_validation_end',
'on_validation_model_train',
'on_epoch_end',
'on_train_epoch_end',
'on_train_end',
'on_fit_end',
'teardown',
]

assert model.called == expected
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall this be updated the way it's done with callback using mock?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that would be great (in a follow up pr). let me know if you need help

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure 馃憤

Expand All @@ -352,12 +361,16 @@ def on_test_model_train(self):
'on_pretrain_routine_start',
'on_pretrain_routine_end',
'on_test_model_eval',
'on_test_start',
'on_test_epoch_start',
'on_test_batch_start',
'on_test_batch_end',
'on_test_epoch_end',
'on_test_end',
'on_test_model_train',
'on_fit_end',
'teardown', # for 'fit'
'teardown', # for 'test'
]

assert model2.called == expected