.. testsetup:: *

    from pytorch_lightning.callbacks.early_stopping import EarlyStopping

Early Stopping
==============

Stopping an Epoch Early
-----------------------

You can stop and skip the rest of the current epoch early by overriding :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_train_batch_start` to return -1 when some condition is met.

If you do this for every remaining epoch you had originally requested, the entire training run stops.
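
For example, a minimal sketch, assuming a hypothetical ``running_loss`` attribute tracked by the model (the exact hook signature may vary across Lightning versions):

.. code-block:: python

    class LitModel(LightningModule):
        def on_train_batch_start(self, batch, batch_idx):
            # Hypothetical condition: once the illustrative running_loss
            # attribute is low enough, skip the rest of this epoch.
            if self.running_loss < 0.05:
                return -1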

EarlyStopping Callback
----------------------

The :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` callback can be used to monitor a metric and stop the training when no improvement is observed.

To enable it:

.. code-block:: python

    from pytorch_lightning.callbacks.early_stopping import EarlyStopping


    class LitModel(LightningModule):
        def validation_step(self, batch, batch_idx):
            loss = ...
            self.log("val_loss", loss)


    model = LitModel()
    trainer = Trainer(callbacks=[EarlyStopping(monitor="val_loss", mode="min")])
    trainer.fit(model)

You can customize the callback's behaviour by changing its parameters.

.. testcode::

    early_stop_callback = EarlyStopping(monitor="val_accuracy", min_delta=0.00, patience=3, verbose=False, mode="max")
    trainer = Trainer(callbacks=[early_stop_callback])


Additional parameters that stop training at extreme points (see the sketch after this list):

- ``stopping_threshold``: Stops training immediately once the monitored quantity reaches this threshold. It is useful when we know that going beyond a certain optimal value does not further benefit us.
- ``divergence_threshold``: Stops training as soon as the monitored quantity becomes worse than this threshold. When the metric reaches a value this bad, we believe the model cannot recover, and it is better to stop early and run with different initial conditions.
- ``check_finite``: When turned on, stops training if the monitored metric becomes NaN or infinite.
- ``check_on_train_epoch_end``: When turned on, checks the metric at the end of a training epoch. Use this only when you are monitoring a metric logged at epoch level within training-specific hooks.
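
A sketch combining these options (the threshold values are illustrative):

.. testcode::

    early_stop_callback = EarlyStopping(
        monitor="val_loss",
        mode="min",
        stopping_threshold=0.02,  # stop once val_loss becomes this good
        divergence_threshold=5.0,  # give up once val_loss becomes this bad
        check_finite=True,  # stop if val_loss turns NaN or infinite
    )
    trainer = Trainer(callbacks=[early_stop_callback])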

In case you need early stopping in a different part of training, subclass :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` and change where it is called:

.. testcode::

    class MyEarlyStopping(EarlyStopping):
        def on_validation_end(self, trainer, pl_module):
            # override this to disable early stopping at the end of val loop
            pass

        def on_train_end(self, trainer, pl_module):
            # instead, do it at the end of training loop
            self._run_early_stopping_check(trainer)
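
The subclass is then passed to the trainer like the built-in callback, for example:

.. testcode::

    trainer = Trainer(callbacks=[MyEarlyStopping(monitor="val_loss", mode="min")])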

.. note::

    The :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` callback runs at the end of every validation epoch by default. However, the frequency of validation can be modified by setting various parameters in the :class:`~pytorch_lightning.trainer.trainer.Trainer`, for example :paramref:`~pytorch_lightning.trainer.trainer.Trainer.check_val_every_n_epoch` and :paramref:`~pytorch_lightning.trainer.trainer.Trainer.val_check_interval`. Note that the ``patience`` parameter counts the number of validation checks with no improvement, and not the number of training epochs. Therefore, with ``check_val_every_n_epoch=10`` and ``patience=3``, the trainer will perform at least 40 training epochs before stopping.
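
For instance, a minimal sketch of the configuration described above:

.. testcode::

    # patience counts validation checks, not epochs: with a check every
    # 10 epochs, patience=3 allows 3 checks without improvement after the
    # best one, so training runs for at least 40 epochs before stopping
    early_stop_callback = EarlyStopping(monitor="val_loss", patience=3)
    trainer = Trainer(callbacks=[early_stop_callback], check_val_every_n_epoch=10)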