Add missing val/test hooks in LightningModule (#5467)
* add missing val/test hooks

* chlog

* None

Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>
2 people authored and Borda committed Jan 13, 2021
1 parent f989558 commit 00de705
Showing 3 changed files with 54 additions and 15 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -47,6 +47,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `IoU` class interface ([#4704](https://github.com/PyTorchLightning/pytorch-lightning/pull/4704))


- Added missing val/test hooks in `LightningModule` ([#5467](https://github.com/PyTorchLightning/pytorch-lightning/pull/5467))


### Changed

- Changed `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))
41 changes: 32 additions & 9 deletions pytorch_lightning/core/hooks.py
@@ -17,14 +17,15 @@
from typing import Any, Dict, List, Optional, Union

import torch
from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn


class ModelHooks:
"""Hooks to be used in LightningModule."""
def setup(self, stage: str):
def setup(self, stage: str) -> None:
"""
Called at the beginning of fit and test.
This is a good hook when you need to build models dynamically or adjust something about them.
@@ -52,29 +53,29 @@ def setup(stage):
"""

def teardown(self, stage: str):
def teardown(self, stage: str) -> None:
"""
Called at the end of fit and test.
Args:
stage: either 'fit' or 'test'
"""

def on_fit_start(self):
def on_fit_start(self) -> None:
"""
Called at the very beginning of fit.
If on DDP it is called on every process
"""

def on_fit_end(self):
def on_fit_end(self) -> None:
"""
Called at the very end of fit.
If on DDP it is called on every process
"""

def on_train_start(self) -> None:
"""
Called at the beginning of training before sanity check.
Called at the beginning of training after sanity check.
"""
# do something at the start of training

@@ -84,6 +85,18 @@ def on_train_end(self) -> None:
"""
# do something at the end of training

def on_validation_start(self) -> None:
"""
Called at the beginning of validation.
"""
# do something at the start of validation

def on_validation_end(self) -> None:
"""
Called at the end of validation.
"""
# do something at the end of validation

def on_pretrain_routine_start(self) -> None:
"""
Called at the beginning of the pretrain routine (between fit and train start).
@@ -108,9 +121,7 @@ def on_pretrain_routine_end(self) -> None:
"""
# do something at the end of the pretrain routine

def on_train_batch_start(
self, batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
"""
Called in the training loop before anything happens for that batch.
@@ -253,6 +264,18 @@ def on_test_epoch_end(self) -> None:
"""
# do something when the epoch ends

def on_test_start(self) -> None:
"""
Called at the beginning of testing.
"""
# do something at the start of testing

def on_test_end(self) -> None:
"""
Called at the end of testing.
"""
# do something at the end of testing

def on_before_zero_grad(self, optimizer: Optimizer) -> None:
"""
Called after optimizer.step() and before optimizer.zero_grad().
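For context, here is a minimal sketch of how the hooks introduced above might be overridden in a user's LightningModule. The class name and print statements are hypothetical, and a complete module would also define training_step, configure_optimizers, and the usual dataloaders.

import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    """Hypothetical module illustrating where the newly added hooks fire."""

    def on_validation_start(self) -> None:
        # runs once, just before the validation loop begins
        print("validation starting")

    def on_validation_end(self) -> None:
        # runs once, just after the validation loop finishes
        print("validation finished")

    def on_test_start(self) -> None:
        # runs once, just before the test loop begins
        print("testing starting")

    def on_test_end(self) -> None:
        # runs once, just after the test loop finishes
        print("testing finished")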
25 changes: 19 additions & 6 deletions tests/models/test_hooks.py
@@ -12,15 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from unittest.mock import MagicMock

import pytest
import torch
from unittest.mock import MagicMock

from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator
from pytorch_lightning.trainer.states import TrainerState
from tests.base import EvalModelTemplate, BoringModel
from tests.base import BoringModel, EvalModelTemplate


@pytest.mark.parametrize('max_steps', [1, 2, 3])
@@ -254,10 +254,6 @@ def on_test_start(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_test_start()

def on_test_end(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_test_end()

def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_test_batch_start(batch, batch_idx, dataloader_idx)
@@ -290,6 +286,14 @@ def on_test_model_train(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_test_model_train()

def on_test_end(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_test_end()

def teardown(self, stage: str):
self.called.append(inspect.currentframe().f_code.co_name)
super().teardown(stage)

model = HookedModel()

assert model.called == []
@@ -313,10 +317,12 @@ def on_test_model_train(self):
'on_pretrain_routine_start',
'on_pretrain_routine_end',
'on_validation_model_eval',
'on_validation_start',
'on_validation_epoch_start',
'on_validation_batch_start',
'on_validation_batch_end',
'on_validation_epoch_end',
'on_validation_end',
'on_validation_model_train',
'on_train_start',
'on_epoch_start',
@@ -330,16 +336,19 @@
'on_before_zero_grad',
'on_train_batch_end',
'on_validation_model_eval',
'on_validation_start',
'on_validation_epoch_start',
'on_validation_batch_start',
'on_validation_batch_end',
'on_validation_epoch_end',
'on_save_checkpoint',
'on_validation_end',
'on_validation_model_train',
'on_epoch_end',
'on_train_epoch_end',
'on_train_end',
'on_fit_end',
'teardown',
]

assert model.called == expected
@@ -352,12 +361,16 @@ def on_test_model_train(self):
'on_pretrain_routine_start',
'on_pretrain_routine_end',
'on_test_model_eval',
'on_test_start',
'on_test_epoch_start',
'on_test_batch_start',
'on_test_batch_end',
'on_test_epoch_end',
'on_test_end',
'on_test_model_train',
'on_fit_end',
'teardown', # for 'fit'
'teardown', # for 'test'
]

assert model2.called == expected

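As a reference for the pattern exercised in the test above, this is a condensed, illustrative sketch of how a model can record which hooks fire and in what order. It assumes the repository's BoringModel test helper is importable; the TrackedModel name and the final assertion are examples, not the test's exact expected list.

import inspect

from pytorch_lightning import Trainer
from tests.base import BoringModel


class TrackedModel(BoringModel):
    def __init__(self):
        super().__init__()
        self.called = []

    def on_test_start(self):
        # record the name of the hook currently executing
        self.called.append(inspect.currentframe().f_code.co_name)
        super().on_test_start()

    def on_test_end(self):
        self.called.append(inspect.currentframe().f_code.co_name)
        super().on_test_end()


model = TrackedModel()
trainer = Trainer(fast_dev_run=True)
trainer.test(model)
# with the new hooks wired into the test loop, start should be recorded before end
assert model.called.index('on_test_start') < model.called.index('on_test_end')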