Bug description
After upgrading to 1.8.x, I noticed that some of my runs started to behave strangely. This happens when the `EarlyStopping` callback triggers before `min_epochs` is reached. From then on, the time per epoch increases significantly and the progress bar for the training step disappears.
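For comparison, the Trainer documentation describes `min_epochs` as forcing training for at least that many epochs. A stop requested earlier should therefore only take effect once `min_epochs` is reached. The following is a simplified, hypothetical model of that expected control flow, not Lightning's actual fit loop. `expected_epochs_run` and its parameters are illustrative names only.

```python
def expected_epochs_run(stop_requested_at, min_epochs, max_epochs):
    """Number of epochs a simplified training loop would run.

    stop_requested_at: 0-based index of the epoch at whose end a stop is first
        requested (e.g. by early stopping), or None if no stop is requested.
    In this hypothetical model a stop request is honoured only once at least
    `min_epochs` epochs have completed, and `max_epochs` is always a hard cap.
    """
    epochs_done = 0
    should_stop = False
    while epochs_done < max_epochs:
        # ... one ordinary training epoch (plus validation) would run here ...
        if stop_requested_at is not None and epochs_done >= stop_requested_at:
            # In Lightning, EarlyStopping signals this via trainer.should_stop.
            should_stop = True
        epochs_done += 1
        if should_stop and epochs_done >= min_epochs:
            break
    return epochs_done


print(expected_epochs_run(stop_requested_at=0, min_epochs=100, max_epochs=1000))
```

Suppose a stop is requested after the very first epoch and `min_epochs=100`, as in the reproduction. This model then expects 100 epochs that all look like ordinary training epochs. The report is that in 1.8.x, the epochs after the stop request instead slow down and lose the training progress bar.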
Here is a figure from one of these runs, plotting wall time (x axis) against epochs (y axis). You can clearly see where EarlyStopping triggered at epoch 59: after that point, the time per epoch increases drastically.
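In a plot like this, the slowdown shows up as a change in slope. One way to quantify it is to difference consecutive timestamps and flag the first epoch whose duration exceeds some multiple of the typical earlier duration. This assumes you have the cumulative wall time (in seconds) at the end of each epoch, which is the quantity on the figure's x axis. `epoch_durations`, `first_slow_epoch`, and the `factor=2.0` / `warmup=3` defaults are hypothetical choices for illustration, not part of Lightning.

```python
from statistics import median


def epoch_durations(cumulative_times):
    """Convert cumulative wall times (seconds since the start of fit, recorded
    at the end of each epoch) into per-epoch durations."""
    durations = []
    previous = 0.0
    for t in cumulative_times:
        durations.append(t - previous)  # epoch 0 lasts cumulative_times[0]
        previous = t
    return durations


def first_slow_epoch(cumulative_times, factor=2.0, warmup=3):
    """Return the index of the first epoch that took more than `factor` times
    the median duration of all earlier epochs, or None if there is none.

    The first `warmup` epochs (at least one) only build the baseline.
    """
    durations = epoch_durations(cumulative_times)
    for i in range(max(warmup, 1), len(durations)):
        baseline = median(durations[:i])
        if durations[i] > factor * baseline:
            return i
    return None
```

For the run in the figure, the flagged index would be expected to land shortly after the epoch where EarlyStopping triggered.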
The reproduction script in the next section shows the same behaviour: `dummy_metric` reaches the early-stopping threshold before `min_epochs` is reached.
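The sketch below shows why the reproduction requests a stop almost immediately. It is a simplified, hypothetical re-implementation of the checks configured on its `EarlyStopping` callback (`mode="min"`, `min_delta=10`, `patience=1`, `divergence_threshold=10`). It is not Lightning's code; it only models the documented meaning of those arguments. It also assumes the callback sees `dummy_metric` as its latest step value, because the metric is logged from `training_step`. That value grows by one per training batch, and 64 samples with `batch_size=2` give 32 batches per epoch.

```python
import math


def first_stop_check(metric_values, min_delta, patience, divergence_threshold):
    """Return (check_index, reason) for the first validation check at which a
    simplified mode="min" early-stopping rule would request a stop, or None.

    metric_values[i] is the monitored value seen at validation check i.
    """
    best = math.inf
    wait = 0
    for i, current in enumerate(metric_values):
        # Divergence: in "min" mode, any value above the threshold is "worse".
        if divergence_threshold is not None and current > divergence_threshold:
            return i, "divergence"
        # An improvement must beat the best value by more than min_delta.
        if current < best - min_delta:
            best = current
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return i, "patience"
    return None


# Assumed values for the reproduction: dummy_metric grows by 1 per batch,
# 32 batches per epoch, and the callback sees the value after the last batch.
BATCHES_PER_EPOCH = 64 // 2
dummy_values = [BATCHES_PER_EPOCH * (epoch + 1) for epoch in range(5)]
print(dummy_values)
print(first_stop_check(dummy_values, min_delta=10, patience=1, divergence_threshold=10))
```

Under these assumptions, the divergence check fires at the first validation (check index 0, value 32 > 10). Even without `divergence_threshold`, the patience rule would fire at the second check. Either way, the stop is requested well before `min_epochs=100`.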
How to reproduce the bug
```python
import os
import torch
from pytorch_lightning.callbacks import EarlyStopping
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        # Grows by 1 per training batch; monitored by EarlyStopping below.
        self.dummy_metric = 0

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        self.dummy_metric += 1
        self.log("dummy_metric", self.dummy_metric)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    # 64 samples with batch_size=2 -> 32 training batches per epoch.
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    # dummy_metric only increases, so it exceeds divergence_threshold (and never
    # improves by min_delta) long before min_epochs=100 is reached.
    early_stopper = EarlyStopping(
        mode="min",
        min_delta=10,
        monitor="dummy_metric",
        patience=1,
        divergence_threshold=10,
        verbose=True,
        strict=True,
    )
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_val_batches=1,
        check_val_every_n_epoch=1,
        limit_test_batches=1,
        log_every_n_steps=1,
        num_sanity_val_steps=0,
        max_epochs=1000,
        min_epochs=100,
        enable_model_summary=False,
        callbacks=[early_stopper],
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


if __name__ == "__main__":
    run()
```
Error messages and logs
No response
Environment
* CUDA:
  - GPU:
    - NVIDIA GeForce RTX 2080 Ti
    - NVIDIA GeForce RTX 2070 SUPER
  - available: True
  - version: 11.7
* Packages:
  - numpy: 1.23.4
  - pyTorch_debug: False
  - pyTorch_version: 1.13.0+cu117
  - pytorch-lightning: 1.8.1
  - tqdm: 4.64.1
* System:
  - OS: Linux
  - architecture:
    - 64bit
    - ELF
  - processor: x86_64
  - python: 3.9.12
  - version: #58-Ubuntu SMP Thu Oct 13 08:03:55 UTC 2022
More info
No response
