diff --git a/pytorch_lightning/utilities/model_utils.py b/pytorch_lightning/utilities/model_utils.py index 876f546312902..993d9e11e1491 100644 --- a/pytorch_lightning/utilities/model_utils.py +++ b/pytorch_lightning/utilities/model_utils.py @@ -23,9 +23,7 @@ def is_overridden(method_name: str, model: Union[LightningModule, LightningDataM # TODO - refector this function to accept model_name, instance, parent so it makes more sense super_object = LightningModule if not isinstance(model, LightningDataModule) else LightningDataModule - # assert model, 'no model passes' - - if not hasattr(model, method_name): + if not hasattr(model, method_name) or not hasattr(super_object, method_name): # in case of calling deprecated method return False diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 66a1f6ac783f2..886e0db4e7854 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -11,14 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import MagicMock +import inspect import pytest import torch +from unittest.mock import MagicMock from pytorch_lightning import Trainer from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator -from tests.base import EvalModelTemplate +from tests.base import EvalModelTemplate, BoringModel @pytest.mark.parametrize('max_steps', [1, 2, 3]) @@ -142,3 +143,216 @@ def on_train_batch_start(self, batch, batch_idx, dataloader_idx): else: assert trainer.batch_idx == batch_idx_ assert trainer.global_step == (batch_idx_ + 1) * max_epochs + + +def test_trainer_model_hook_system(tmpdir): + """Test the hooks system.""" + + class HookedModel(BoringModel): + def __init__(self): + super().__init__() + self.called = [] + + def on_after_backward(self): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_after_backward() + + def on_before_zero_grad(self, optimizer): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_before_zero_grad(optimizer) + + def on_epoch_start(self): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_epoch_start() + + def on_epoch_end(self): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_epoch_end() + + def on_fit_start(self): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_fit_start() + + def on_fit_end(self): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_fit_end() + + def on_hpc_load(self, checkpoint): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_hpc_load(checkpoint) + + def on_hpc_save(self, checkpoint): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_hpc_save(checkpoint) + + def on_load_checkpoint(self, checkpoint): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_load_checkpoint(checkpoint) + + def on_save_checkpoint(self, checkpoint): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_save_checkpoint(checkpoint) + + def on_pretrain_routine_start(self): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_pretrain_routine_start() + + def on_pretrain_routine_end(self): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_pretrain_routine_end() + + def on_train_start(self): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_train_start() + + def on_train_end(self): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_train_end() + + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_train_batch_start(batch, batch_idx, dataloader_idx) + + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx) + + def on_train_epoch_start(self): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_train_epoch_start() + + def on_train_epoch_end(self, outputs): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_train_epoch_end(outputs) + + def on_validation_start(self): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_validation_start() + + def on_validation_end(self): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_validation_end() + + def on_validation_batch_start(self, batch, batch_idx, dataloader_idx): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_validation_batch_start(batch, batch_idx, dataloader_idx) + + def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx) + + def on_validation_epoch_start(self): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_validation_epoch_start() + + def on_validation_epoch_end(self): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_validation_epoch_end() + + def on_test_start(self): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_test_start() + + def on_test_end(self): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_test_end() + + def on_test_batch_start(self, batch, batch_idx, dataloader_idx): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_test_batch_start(batch, batch_idx, dataloader_idx) + + def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_test_batch_end(outputs, batch, batch_idx, dataloader_idx) + + def on_test_epoch_start(self): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_test_epoch_start() + + def on_test_epoch_end(self): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_test_epoch_end() + + def on_validation_model_eval(self): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_validation_model_eval() + + def on_validation_model_train(self): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_validation_model_train() + + def on_test_model_eval(self): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_test_model_eval() + + def on_test_model_train(self): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_test_model_train() + + model = HookedModel() + + assert model.called == [] + + # fit model + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=1, + limit_train_batches=2, + limit_test_batches=1, + progress_bar_refresh_rate=0, + ) + + assert model.called == [] + + trainer.fit(model) + + assert model.called == [ + 'on_fit_start', + 'on_pretrain_routine_start', + 'on_pretrain_routine_end', + 'on_validation_model_eval', + 'on_validation_epoch_start', + 'on_validation_batch_start', + 'on_validation_batch_end', + 'on_validation_epoch_end', + 'on_validation_model_train', + 'on_train_start', + 'on_epoch_start', + 'on_train_epoch_start', + 'on_train_batch_start', + 'on_after_backward', + 'on_before_zero_grad', + 'on_train_batch_end', + 'on_train_batch_start', + 'on_after_backward', + 'on_before_zero_grad', + 'on_train_batch_end', + 'on_validation_model_eval', + 'on_validation_epoch_start', + 'on_validation_batch_start', + 'on_validation_batch_end', + 'on_validation_epoch_end', + 'on_validation_model_train', + 'on_save_checkpoint', + 'on_epoch_end', + 'on_train_epoch_end', + 'on_train_end', + 'on_fit_end', + ] + + model2 = HookedModel() + trainer.test(model2) + + assert model2.called == [ + 'on_fit_start', + 'on_pretrain_routine_start', + 'on_pretrain_routine_end', + 'on_test_model_eval', + 'on_test_epoch_start', + 'on_test_batch_start', + 'on_test_batch_end', + 'on_test_epoch_end', + 'on_test_model_train', + 'on_fit_end', + ]