Skip to content

Commit

Permalink
add test for model hooks (#4010)
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Oct 20, 2020
1 parent 9edef40 commit 3777988
Show file tree
Hide file tree
Showing 2 changed files with 217 additions and 5 deletions.
4 changes: 1 addition & 3 deletions pytorch_lightning/utilities/model_utils.py
Expand Up @@ -23,9 +23,7 @@ def is_overridden(method_name: str, model: Union[LightningModule, LightningDataM
# TODO - refector this function to accept model_name, instance, parent so it makes more sense
super_object = LightningModule if not isinstance(model, LightningDataModule) else LightningDataModule

# assert model, 'no model passes'

if not hasattr(model, method_name):
if not hasattr(model, method_name) or not hasattr(super_object, method_name):
# in case of calling deprecated method
return False

Expand Down
218 changes: 216 additions & 2 deletions tests/models/test_hooks.py
Expand Up @@ -11,14 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock
import inspect

import pytest
import torch
from unittest.mock import MagicMock

from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator
from tests.base import EvalModelTemplate
from tests.base import EvalModelTemplate, BoringModel


@pytest.mark.parametrize('max_steps', [1, 2, 3])
Expand Down Expand Up @@ -142,3 +143,216 @@ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
else:
assert trainer.batch_idx == batch_idx_
assert trainer.global_step == (batch_idx_ + 1) * max_epochs


def test_trainer_model_hook_system(tmpdir):
"""Test the hooks system."""

class HookedModel(BoringModel):
def __init__(self):
super().__init__()
self.called = []

def on_after_backward(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_after_backward()

def on_before_zero_grad(self, optimizer):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_before_zero_grad(optimizer)

def on_epoch_start(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_epoch_start()

def on_epoch_end(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_epoch_end()

def on_fit_start(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_fit_start()

def on_fit_end(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_fit_end()

def on_hpc_load(self, checkpoint):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_hpc_load(checkpoint)

def on_hpc_save(self, checkpoint):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_hpc_save(checkpoint)

def on_load_checkpoint(self, checkpoint):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_load_checkpoint(checkpoint)

def on_save_checkpoint(self, checkpoint):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_save_checkpoint(checkpoint)

def on_pretrain_routine_start(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_pretrain_routine_start()

def on_pretrain_routine_end(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_pretrain_routine_end()

def on_train_start(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_train_start()

def on_train_end(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_train_end()

def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_train_batch_start(batch, batch_idx, dataloader_idx)

def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)

def on_train_epoch_start(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_train_epoch_start()

def on_train_epoch_end(self, outputs):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_train_epoch_end(outputs)

def on_validation_start(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_validation_start()

def on_validation_end(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_validation_end()

def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_validation_batch_start(batch, batch_idx, dataloader_idx)

def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx)

def on_validation_epoch_start(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_validation_epoch_start()

def on_validation_epoch_end(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_validation_epoch_end()

def on_test_start(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_test_start()

def on_test_end(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_test_end()

def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_test_batch_start(batch, batch_idx, dataloader_idx)

def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_test_batch_end(outputs, batch, batch_idx, dataloader_idx)

def on_test_epoch_start(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_test_epoch_start()

def on_test_epoch_end(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_test_epoch_end()

def on_validation_model_eval(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_validation_model_eval()

def on_validation_model_train(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_validation_model_train()

def on_test_model_eval(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_test_model_eval()

def on_test_model_train(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_test_model_train()

model = HookedModel()

assert model.called == []

# fit model
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_val_batches=1,
limit_train_batches=2,
limit_test_batches=1,
progress_bar_refresh_rate=0,
)

assert model.called == []

trainer.fit(model)

assert model.called == [
'on_fit_start',
'on_pretrain_routine_start',
'on_pretrain_routine_end',
'on_validation_model_eval',
'on_validation_epoch_start',
'on_validation_batch_start',
'on_validation_batch_end',
'on_validation_epoch_end',
'on_validation_model_train',
'on_train_start',
'on_epoch_start',
'on_train_epoch_start',
'on_train_batch_start',
'on_after_backward',
'on_before_zero_grad',
'on_train_batch_end',
'on_train_batch_start',
'on_after_backward',
'on_before_zero_grad',
'on_train_batch_end',
'on_validation_model_eval',
'on_validation_epoch_start',
'on_validation_batch_start',
'on_validation_batch_end',
'on_validation_epoch_end',
'on_validation_model_train',
'on_save_checkpoint',
'on_epoch_end',
'on_train_epoch_end',
'on_train_end',
'on_fit_end',
]

model2 = HookedModel()
trainer.test(model2)

assert model2.called == [
'on_fit_start',
'on_pretrain_routine_start',
'on_pretrain_routine_end',
'on_test_model_eval',
'on_test_epoch_start',
'on_test_batch_start',
'on_test_batch_end',
'on_test_epoch_end',
'on_test_model_train',
'on_fit_end',
]

0 comments on commit 3777988

Please sign in to comment.