diff --git a/src/lightning/pytorch/tuner/lr_finder.py b/src/lightning/pytorch/tuner/lr_finder.py
index ed326b4c8f384..f83c390d271c3 100644
--- a/src/lightning/pytorch/tuner/lr_finder.py
+++ b/src/lightning/pytorch/tuner/lr_finder.py
@@ -393,7 +393,11 @@ def on_train_batch_end(
         if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
             return

-        if outputs is None:
+        # _AutomaticOptimization.run turns None STEP_OUTPUT into an empty dict
+        if not outputs:
+            # need to add an element, because we also added one element to lrs in on_train_batch_start
+            # so add nan, because they are not considered when computing the suggestion
+            self.losses.append(float("nan"))
             return

         if self.progress_bar:
diff --git a/src/lightning/pytorch/tuner/tuning.py b/src/lightning/pytorch/tuner/tuning.py
index 64065b9576faa..53b7b45210ef9 100644
--- a/src/lightning/pytorch/tuner/tuning.py
+++ b/src/lightning/pytorch/tuner/tuning.py
@@ -113,7 +113,7 @@ def lr_find(
         max_lr: float = 1,
         num_training: int = 100,
         mode: str = "exponential",
-        early_stop_threshold: float = 4.0,
+        early_stop_threshold: Optional[float] = 4.0,
         update_attr: bool = True,
         attr_name: str = "",
     ) -> Optional["pl.tuner.lr_finder._LRFinder"]:
diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py
index bc8e529def2cb..87a7412396b91 100644
--- a/tests/tests_pytorch/tuner/test_lr_finder.py
+++ b/tests/tests_pytorch/tuner/test_lr_finder.py
@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import math
 import os
 from copy import deepcopy
+from typing import Any
 from unittest import mock

 import pytest
@@ -26,6 +28,7 @@
 from lightning.pytorch.tuner.lr_finder import _LRFinder
 from lightning.pytorch.tuner.tuning import Tuner
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
+from lightning.pytorch.utilities.types import STEP_OUTPUT
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.runif import RunIf
 from tests_pytorch.helpers.simple_models import ClassificationModel
@@ -229,7 +232,7 @@ def __init__(self):
     lr_finder = tuner.lr_find(model, early_stop_threshold=None)

     assert lr_finder.suggestion() != 1e-3
-    assert len(lr_finder.results["lr"]) == 100
+    assert len(lr_finder.results["lr"]) == len(lr_finder.results["loss"]) == 100
     assert lr_finder._total_batch_idx == 199


@@ -503,3 +506,35 @@ def configure_optimizers(self):

     assert trainer.num_val_batches[0] == len(trainer.val_dataloaders)
     assert trainer.num_val_batches[0] != num_lr_tuner_training_steps
+
+
+def test_lr_finder_training_step_none_output(tmpdir):
+    # add some nans into the skipped steps (first 10) but also into the steps used to compute the lr
+    none_steps = [5, 12, 17]
+
+    class CustomBoringModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.lr = 0.123
+
+        def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
+            if self.trainer.global_step in none_steps:
+                return None
+
+            return super().training_step(batch, batch_idx)
+
+    seed_everything(1)
+    model = CustomBoringModel()
+
+    trainer = Trainer(default_root_dir=tmpdir)
+
+    tuner = Tuner(trainer)
+    # restrict number of steps for faster test execution
+    # and disable early stopping to easily check expected number of lrs and losses
+    lr_finder = tuner.lr_find(model=model, update_attr=True, num_training=20, early_stop_threshold=None)
+    assert len(lr_finder.results["lr"]) == len(lr_finder.results["loss"]) == 20
+    assert torch.isnan(torch.tensor(lr_finder.results["loss"])[none_steps]).all()
+
+    suggested_lr = lr_finder.suggestion()
+    assert math.isfinite(suggested_lr)
+    assert math.isclose(model.lr, suggested_lr)
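
For context, a minimal usage sketch of the behavior this patch enables: a `training_step` that occasionally returns `None` no longer desynchronizes `Tuner.lr_find`, because the skipped steps are recorded as NaN losses and ignored by `suggestion()`. This is not part of the patch; the model name and the `batch_idx % 7` skip rule are hypothetical, chosen only to mirror the test above.

```python
# Sketch (assumptions noted above): exercise lr_find with a model whose
# training_step sometimes returns None. Skipped steps show up as NaN losses.
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.tuner.tuning import Tuner


class SometimesNoneModel(BoringModel):
    def __init__(self):
        super().__init__()
        self.lr = 0.123  # updated in place by lr_find when update_attr=True

    def training_step(self, batch, batch_idx):
        # Occasionally skip the optimization step; lr_find records NaN for these.
        if batch_idx % 7 == 0:
            return None
        return super().training_step(batch, batch_idx)


if __name__ == "__main__":
    seed_everything(1)
    model = SometimesNoneModel()
    trainer = Trainer(default_root_dir="lr_find_demo")
    tuner = Tuner(trainer)
    result = tuner.lr_find(model, num_training=20, early_stop_threshold=None)
    # lr and loss lists stay aligned; NaN entries are excluded from the suggestion
    print(result.suggestion())
```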