Sometimes you need custom behavior at a particular point in the training or validation loop. To enable a hook, override the corresponding method in your LightningModule; the trainer will call it at the correct time.
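The dispatch mechanism can be illustrated with a toy trainer (a simplified stand-in, not Lightning's actual implementation): default hooks are no-ops, and overriding one is all it takes for the trainer's call to become visible.

```python
class ModelHooks:
    """Default hooks are no-ops; override only the ones you need."""
    def on_epoch_start(self): pass
    def on_batch_start(self, batch): pass
    def on_batch_end(self): pass
    def on_epoch_end(self): pass


class ToyTrainer:
    """Calls each hook at the matching point of the loop."""
    def fit(self, model, data, epochs=1):
        for _ in range(epochs):
            model.on_epoch_start()
            for batch in data:
                model.on_batch_start(batch)
                # ... training_step / backward / optimizer.step() ...
                model.on_batch_end()
            model.on_epoch_end()


class MyModel(ModelHooks):
    def __init__(self):
        self.calls = []

    def on_epoch_start(self):  # overridden hook
        self.calls.append("on_epoch_start")

    def on_batch_end(self):  # overridden hook
        self.calls.append("on_batch_end")


model = MyModel()
ToyTrainer().fit(model, data=[0, 1])
print(model.calls)  # ['on_epoch_start', 'on_batch_end', 'on_batch_end']
```

Hooks that are not overridden (here ``on_batch_start`` and ``on_epoch_end``) fall through to the no-op defaults, which is why a LightningModule only needs to implement the hooks it actually cares about.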
Contributing
------------

If there's a hook you'd like to add, simply:

- Fork PyTorchLightning.
- Add the hook to :class:`pytorch_lightning.core.hooks.ModelHooks`.
- Add it in the correct place in :mod:`pytorch_lightning.trainer` where it should be called.

Hooks lifecycle
---------------

Training set-up
^^^^^^^^^^^^^^^
- :meth:`~pytorch_lightning.core.lightning.LightningModule.prepare_data`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.setup`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.init_ddp_connection`
- :meth:`~pytorch_lightning.trainer.optimizers.TrainerOptimizersMixin.init_optimizers`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_apex`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.train_dataloader`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.test_dataloader`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.val_dataloader`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.summarize`
- :meth:`~pytorch_lightning.trainer.training_io.TrainerIOMixin.restore_weights`
.. warning::

    ``prepare_data`` is called only from ``global_rank=0``. Don't assign state there (``self.something``); use ``setup`` for that.
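A sketch of how to divide work between ``prepare_data`` and ``setup``. A stub base class stands in for ``pytorch_lightning.LightningModule`` so the snippet is self-contained, and the dataset is a placeholder:

```python
class LightningModule:
    """Stub standing in for pytorch_lightning.LightningModule."""


class LitModel(LightningModule):
    def prepare_data(self):
        # Runs on global_rank=0 only (e.g. download/tokenize once per node).
        # Do NOT assign state here: the other processes never see it.
        pass  # e.g. download the dataset to disk

    def setup(self, stage):
        # Runs on every process: state assigned here exists everywhere.
        self.train_set = list(range(8))  # placeholder for the real dataset


model = LitModel()
model.prepare_data()
model.setup(stage="fit")
print(len(model.train_set))  # 8
```

The split exists because in distributed training only one process should touch the disk or network, while every process needs its own copy of the in-memory state.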
Training loop
^^^^^^^^^^^^^

- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_epoch_start`
- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_batch_start`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.tbptt_split_batch`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step_end` (optional)
- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_before_zero_grad`
- :meth:`~pytorch_lightning.core.hooks.ModelHooks.backward`
- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_after_backward`
- ``optimizer.step()``
- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_batch_end`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.training_epoch_end`
- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_epoch_end`
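The ``tbptt_split_batch`` step above splits a long sequence batch into truncated-BPTT chunks before ``training_step`` runs on each chunk. A simplified list-based sketch of that idea (the real hook operates on tensors and nested batch structures):

```python
def tbptt_split_batch(batch, split_size):
    """Split each sequence in the batch along the time dimension.

    Simplified list-based stand-in for the real hook, which handles
    tensors and nested batch structures.
    """
    time_steps = len(batch[0])
    return [
        [seq[t:t + split_size] for seq in batch]
        for t in range(0, time_steps, split_size)
    ]


batch = [[1, 2, 3, 4, 5, 6],       # two sequences, six time steps each
         [7, 8, 9, 10, 11, 12]]
splits = tbptt_split_batch(batch, split_size=2)
print(splits[0])  # [[1, 2], [7, 8]] -- first chunk of both sequences
```

Each chunk then goes through its own ``training_step``/``backward`` pass, which bounds how far gradients flow back through time.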
Validation loop
^^^^^^^^^^^^^^^

- ``model.zero_grad()``
- ``model.eval()``
- ``torch.set_grad_enabled(False)``
- :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step_end`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`
- ``model.train()``
- ``torch.set_grad_enabled(True)``
- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_post_performance_check`
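Per-batch ``validation_step`` outputs are collected and handed to ``validation_epoch_end`` once the loop finishes. A minimal sketch of that data flow (stand-alone class, hypothetical metric name ``val_loss``):

```python
class LitSketch:
    """Stand-alone sketch of the validation data flow, not a real LightningModule."""

    def validation_step(self, batch, batch_idx):
        # one dict per batch; 'val_loss' is a hypothetical metric name
        return {"val_loss": float(batch)}

    def validation_epoch_end(self, outputs):
        # outputs: the list of everything validation_step returned this epoch
        avg = sum(o["val_loss"] for o in outputs) / len(outputs)
        return {"avg_val_loss": avg}


m = LitSketch()
outputs = [m.validation_step(b, i) for i, b in enumerate([1.0, 2.0, 3.0])]
print(m.validation_epoch_end(outputs))  # {'avg_val_loss': 2.0}
```

The same pattern applies to ``test_step``/``test_epoch_end`` in the test loop below.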
Test loop
^^^^^^^^^

- ``model.zero_grad()``
- ``model.eval()``
- ``torch.set_grad_enabled(False)``
- :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step_end`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.test_epoch_end`
- ``model.train()``
- ``torch.set_grad_enabled(True)``
- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_post_performance_check`
General hooks
-------------

.. automodule:: pytorch_lightning.core.hooks
    :noindex: