6 changes: 5 additions & 1 deletion src/lightning/pytorch/tuner/lr_finder.py
@@ -393,7 +393,11 @@ def on_train_batch_end(
         if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
             return
 
-        if outputs is None:
+        # _AutomaticOptimization.run turns None STEP_OUTPUT into an empty dict
+        if not outputs:
+            # need to add an element, because we also added one element to lrs in on_train_batch_start
+            # so add nan, because they are not considered when computing the suggestion
+            self.losses.append(float("nan"))
             return
 
         if self.progress_bar:
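
The comment in the hunk is the whole story: on_train_batch_start appends to the lrs list unconditionally, so on_train_batch_end must append to the losses list on every counted batch or the two lists drift apart. A minimal, self-contained sketch of that invariant (assumed standalone names, not the actual _LRCallback code):

    lrs = []
    losses = []

    def on_train_batch_start(current_lr):
        lrs.append(current_lr)

    def on_train_batch_end(outputs):
        # _AutomaticOptimization.run turns a None STEP_OUTPUT into an empty
        # dict, so `not outputs` covers both None and {}
        if not outputs:
            # nan placeholder keeps the lists aligned; nans are ignored
            # when computing the suggestion
            losses.append(float("nan"))
            return
        losses.append(outputs["loss"])

    on_train_batch_start(1e-4)
    on_train_batch_end({})  # a skipped step still records exactly one loss
    assert len(lrs) == len(losses)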
2 changes: 1 addition & 1 deletion src/lightning/pytorch/tuner/tuning.py
@@ -113,7 +113,7 @@ def lr_find(
         max_lr: float = 1,
         num_training: int = 100,
         mode: str = "exponential",
-        early_stop_threshold: float = 4.0,
+        early_stop_threshold: Optional[float] = 4.0,
         update_attr: bool = True,
         attr_name: str = "",
     ) -> Optional["pl.tuner.lr_finder._LRFinder"]:
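
The widened annotation matches what the runtime already accepts: passing None disables the early-stop check, as the updated test below relies on. A minimal usage sketch (BoringModel has no lr attribute to write back, hence update_attr=False):

    from lightning.pytorch import Trainer
    from lightning.pytorch.demos.boring_classes import BoringModel
    from lightning.pytorch.tuner.tuning import Tuner

    trainer = Trainer()
    tuner = Tuner(trainer)
    # early_stop_threshold=None now satisfies the Optional[float] annotation
    lr_finder = tuner.lr_find(BoringModel(), early_stop_threshold=None, update_attr=False)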
37 changes: 36 additions & 1 deletion tests/tests_pytorch/tuner/test_lr_finder.py
@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import math
 import os
 from copy import deepcopy
+from typing import Any
 from unittest import mock
 
 import pytest
@@ -26,6 +28,7 @@
 from lightning.pytorch.tuner.lr_finder import _LRFinder
 from lightning.pytorch.tuner.tuning import Tuner
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
+from lightning.pytorch.utilities.types import STEP_OUTPUT
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.runif import RunIf
 from tests_pytorch.helpers.simple_models import ClassificationModel
@@ -229,7 +232,7 @@ def __init__(self):
     lr_finder = tuner.lr_find(model, early_stop_threshold=None)
 
     assert lr_finder.suggestion() != 1e-3
-    assert len(lr_finder.results["lr"]) == 100
+    assert len(lr_finder.results["lr"]) == len(lr_finder.results["loss"]) == 100
     assert lr_finder._total_batch_idx == 199

@@ -503,3 +506,35 @@ def configure_optimizers(self):

     assert trainer.num_val_batches[0] == len(trainer.val_dataloaders)
     assert trainer.num_val_batches[0] != num_lr_tuner_training_steps
+
+
+def test_lr_finder_training_step_none_output(tmpdir):
+    # add some nans into the skipped steps (first 10) but also into the steps used to compute the lr
+    none_steps = [5, 12, 17]
+
+    class CustomBoringModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.lr = 0.123
+
+        def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
+            if self.trainer.global_step in none_steps:
+                return None
+
+            return super().training_step(batch, batch_idx)
+
+    seed_everything(1)
+    model = CustomBoringModel()
+
+    trainer = Trainer(default_root_dir=tmpdir)
+
+    tuner = Tuner(trainer)
+    # restrict number of steps for faster test execution
+    # and disable early stopping to easily check expected number of lrs and losses
+    lr_finder = tuner.lr_find(model=model, update_attr=True, num_training=20, early_stop_threshold=None)
+    assert len(lr_finder.results["lr"]) == len(lr_finder.results["loss"]) == 20
+    assert torch.isnan(torch.tensor(lr_finder.results["loss"])[none_steps]).all()
+
+    suggested_lr = lr_finder.suggestion()
+    assert math.isfinite(suggested_lr)
+    assert math.isclose(model.lr, suggested_lr)
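
The final asserts lean on suggestion() ignoring the non-finite placeholders. A hypothetical simplification of such a heuristic (steepest drop in the loss curve; assumed names, not the library's implementation):

    import math

    def suggest_lr(lrs, losses, skip_begin=10, skip_end=1):
        # drop the warm-up head and the tail, then mask nan placeholders so
        # skipped steps never influence the result
        window = list(zip(lrs, losses))[skip_begin:-skip_end]
        finite = [(lr, loss) for lr, loss in window if math.isfinite(loss)]
        # pick the lr where the loss falls most steeply between neighbors
        drops = [finite[i + 1][1] - finite[i][1] for i in range(len(finite) - 1)]
        steepest = min(range(len(drops)), key=drops.__getitem__)
        return finite[steepest][0]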