
EarlyStopping + min_epochs bug after upgrade to 1.8.x #15708

@cschell

Bug description

After upgrading to 1.8.x I noticed that some of my runs started to behave strangely. This happens whenever the EarlyStopping callback triggers before min_epochs is reached: the time per epoch increases significantly, and the progress bar for the training step vanishes.

Below is a figure from one of these runs (wall time on the x axis vs. epochs on the y axis). You can clearly see where EarlyStopping kicked in at epoch 59: after that point the time per epoch starts to increase drastically.

[Screenshot from 2022-11-17: epochs vs. wall time for the affected run]

The example below provokes the same strange behaviour as soon as dummy_metric crosses the divergence threshold before min_epochs is reached.

How to reproduce the bug

import os

import torch
from pytorch_lightning.callbacks import EarlyStopping
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.dummy_metric = 0

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        # dummy_metric grows by one per training step, so it quickly exceeds the
        # divergence_threshold of 10 and makes EarlyStopping request a stop long
        # before min_epochs is reached
        self.dummy_metric += 1
        self.log("dummy_metric", self.dummy_metric)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()

    early_stopper = EarlyStopping(
        mode="min",
        min_delta=10,
        monitor="dummy_metric",
        patience=1,
        divergence_threshold=10,
        verbose=True,
        strict=True,
    )

    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_val_batches=1,
        check_val_every_n_epoch=1,
        limit_test_batches=1,
        log_every_n_steps=1,
        num_sanity_val_steps=0,
        max_epochs=1000,
        min_epochs=100,
        enable_model_summary=False,
        callbacks=[early_stopper],
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


if __name__ == "__main__":
    run()
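
To make the effect easier to inspect, here is a small diagnostic variant of run() (a sketch of my own on top of the repro, not part of the minimal example): it prints the epoch at which the callback requested the stop (early_stopper.stopped_epoch) next to the epoch at which training actually ended (trainer.current_epoch). Dropping the min_epochs argument is only an assumed workaround based on the observation above, not a fix for the underlying bug.

def run_with_diagnostics():
    # Same data and callback setup as in run() above.
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    early_stopper = EarlyStopping(
        mode="min",
        min_delta=10,
        monitor="dummy_metric",
        patience=1,
        divergence_threshold=10,
        verbose=True,
        strict=True,
    )
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_val_batches=1,
        log_every_n_steps=1,
        num_sanity_val_steps=0,
        max_epochs=1000,
        min_epochs=100,  # removing this argument (assumed workaround) avoids the interaction
        enable_model_summary=False,
        callbacks=[early_stopper],
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)

    # stopped_epoch records when EarlyStopping asked the trainer to stop;
    # current_epoch is where training actually ended.
    print(f"stop requested at epoch: {early_stopper.stopped_epoch}")
    print(f"training ended at epoch: {trainer.current_epoch}")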

Error messages and logs

No response

Environment


* CUDA:
        - GPU:
                - NVIDIA GeForce RTX 2080 Ti
                - NVIDIA GeForce RTX 2070 SUPER
        - available:         True
        - version:           11.7
* Packages:
        - numpy:             1.23.4
        - pyTorch_debug:     False
        - pyTorch_version:   1.13.0+cu117
        - pytorch-lightning: 1.8.1
        - tqdm:              4.64.1
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - ELF
        - processor:         x86_64
        - python:            3.9.12
        - version:           #58-Ubuntu SMP Thu Oct 13 08:03:55 UTC 2022


More info

No response

cc @carmocca @awaelchli @justusschock @Borda
