update lr_finder to check for attribute if not running fast_dev_run (#5990)

* ref lr_finder a bit

* chlog

* Apply suggestions from code review

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
4 people committed Feb 17, 2021
1 parent c0ee1f1 commit 99da0d9
Showing 6 changed files with 50 additions and 50 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -196,6 +196,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Refactored `EpochResultStore` ([#5522](https://github.com/PyTorchLightning/pytorch-lightning/pull/5522))


- Update `lr_finder` to check for attribute if not running `fast_dev_run` ([#5990](https://github.com/PyTorchLightning/pytorch-lightning/pull/5990))


- LightningOptimizer manual optimizer is more flexible and exposes `toggle_model` ([#5771](https://github.com/PyTorchLightning/pytorch-lightning/pull/5771))


8 changes: 4 additions & 4 deletions pytorch_lightning/tuner/batch_size_scaling.py
@@ -63,10 +63,10 @@ def scale_batch_size(
It is expected that the user has provided a model or datamodule that has a hyperparameter
with that name. We will look for this attribute name in the following places
- `model`
- `model.hparams`
- `model.datamodule`
- `trainer.datamodule` (the datamodule passed to the tune method)
- ``model``
- ``model.hparams``
- ``model.datamodule``
- ``trainer.datamodule`` (the datamodule passed to the tune method)
**fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader
or datamodule.
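
As a usage note on the lookup order documented above, here is a short, hedged sketch of how the batch-size scaler is typically invoked. The `ToyModel` class and its training setup are illustrative only; the `auto_scale_batch_size` flag and the direct tuner call are the public entry points this module backs.

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset
    from pytorch_lightning import LightningModule, Trainer

    class ToyModel(LightningModule):
        """Hypothetical module; `batch_size` lives on the model itself, the first
        place the tuner looks (then model.hparams, model.datamodule, trainer.datamodule)."""

        def __init__(self, batch_size=2):
            super().__init__()
            self.batch_size = batch_size
            self.layer = nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return nn.functional.cross_entropy(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.01)

        def train_dataloader(self):
            ds = TensorDataset(torch.randn(256, 32), torch.randint(0, 2, (256,)))
            return DataLoader(ds, batch_size=self.batch_size)

    model = ToyModel()
    trainer = Trainer(auto_scale_batch_size='power', max_epochs=1)
    trainer.tune(model)          # overwrites model.batch_size with the largest size found
    # or, calling the tuner directly:
    # new_size = trainer.tuner.scale_batch_size(model)
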
50 changes: 25 additions & 25 deletions pytorch_lightning/tuner/lr_finder.py
@@ -60,21 +60,6 @@ def _determine_lr_attr_name(trainer, model: LightningModule) -> str:
)


def _run_lr_finder_internally(trainer, model: LightningModule):
""" Call lr finder internally during Trainer.fit() """
lr_attr_name = _determine_lr_attr_name(trainer, model)
lr_finder = lr_find(trainer, model)
if lr_finder is None:
return

lr = lr_finder.suggestion()

# TODO: log lr.results to self.logger
lightning_setattr(model, lr_attr_name, lr)

log.info(f'Learning rate set to {lr}')


def lr_find(
trainer,
model: LightningModule,
@@ -86,16 +71,17 @@ def lr_find(
mode: str = 'exponential',
early_stop_threshold: float = 4.0,
datamodule: Optional[LightningDataModule] = None,
update_attr: bool = False,
):
r"""
`lr_find` enables the user to do a range test of good initial learning rates,
``lr_find`` enables the user to do a range test of good initial learning rates,
to reduce the amount of guesswork in picking a good starting learning rate.
Args:
model: Model to do range testing for
train_dataloader: A PyTorch
`DataLoader` with training samples. If the model has
``DataLoader`` with training samples. If the model has
a predefined train_dataloader method, this will be skipped.
min_lr: minimum learning rate to investigate
@@ -104,19 +90,21 @@
num_training: number of learning rates to test
mode: search strategy, either 'linear' or 'exponential'. If set to
'linear' the learning rate will be searched by linearly increasing
after each batch. If set to 'exponential', will increase learning
rate exponentially.
mode: Search strategy to update learning rate after each batch:
- ``'exponential'`` (default): Will increase the learning rate exponentially.
- ``'linear'``: Will increase the learning rate linearly.
early_stop_threshold: threshold for stopping the search. If the
loss at any point is larger than early_stop_threshold*best_loss
then the search is stopped. To disable, set to None.
datamodule: An optional `LightningDataModule` which holds the training
and validation dataloader(s). Note that the `train_dataloader` and
`val_dataloaders` parameters cannot be used at the same time as
this parameter, or a `MisconfigurationException` will be raised.
datamodule: An optional ``LightningDataModule`` which holds the training
and validation dataloader(s). Note that the ``train_dataloader`` and
``val_dataloaders`` parameters cannot be used at the same time as
this parameter, or a ``MisconfigurationException`` will be raised.
update_attr: Whether to update the learning rate attribute or not.
Example::
@@ -144,6 +132,10 @@ def lr_find(
rank_zero_warn('Skipping learning rate finder since fast_dev_run is enabled.', UserWarning)
return

# Determine lr attr
if update_attr:
lr_attr_name = _determine_lr_attr_name(trainer, model)

save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp_model.ckpt')

__lr_finder_dump_params(trainer, model)
@@ -200,6 +192,14 @@ def lr_find(
if trainer.progress_bar_callback:
trainer.progress_bar_callback.enable()

# Update lr attr if required
if update_attr:
lr = lr_finder.suggestion()

# TODO: log lr.results to self.logger
lightning_setattr(model, lr_attr_name, lr)
log.info(f'Learning rate set to {lr}')

return lr_finder


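
To make the new flag concrete, here is a hedged sketch of the two ways `lr_find` can now be driven: the manual flow, where the caller reads the suggestion and applies it, and the `update_attr=True` flow added in this commit, where the tuner writes the suggested rate back to the model's lr attribute itself. It reuses the test helpers touched elsewhere in this commit, so it assumes the repository's test environment.

    from pytorch_lightning import Trainer
    from tests.helpers.datamodules import ClassifDataModule
    from tests.helpers.simple_models import ClassificationModel

    dm = ClassifDataModule()
    model = ClassificationModel()          # defines `self.lr`, so the attribute check passes
    trainer = Trainer(max_steps=2)

    # Manual flow: run the range test, inspect it, apply the suggestion yourself.
    lr_finder = trainer.tuner.lr_find(model, datamodule=dm)
    model.lr = lr_finder.suggestion()

    # Flow added here: lr_find determines the lr attribute name up front and, after
    # the range test, calls lightning_setattr(model, lr_attr_name, suggestion).
    trainer.tuner.lr_find(model, datamodule=dm, update_attr=True)
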
19 changes: 9 additions & 10 deletions pytorch_lightning/tuner/tuning.py
@@ -20,7 +20,7 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size
from pytorch_lightning.tuner.lr_finder import _run_lr_finder_internally, lr_find
from pytorch_lightning.tuner.lr_finder import lr_find


class Tuner:
@@ -53,7 +53,7 @@ def tune(self, model, train_dataloader, val_dataloaders, datamodule):

# Run learning rate finder:
if self.trainer.auto_lr_find:
self.internal_find_lr(model)
self.lr_find(model, update_attr=True)

def scale_batch_size(
self,
@@ -92,10 +92,10 @@ def scale_batch_size(
It is expected that the user has provided a model or datamodule that has a hyperparameter
with that name. We will look for this attribute name in the following places
- `model`
- `model.hparams`
- `model.datamodule`
- `trainer.datamodule` (the datamodule passed to the tune method)
- ``model``
- ``model.hparams``
- ``model.datamodule``
- ``trainer.datamodule`` (the datamodule passed to the tune method)
**fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader
or datamodule.
@@ -122,7 +122,8 @@ def lr_find(
num_training: int = 100,
mode: str = 'exponential',
early_stop_threshold: float = 4.0,
datamodule: Optional[LightningDataModule] = None
datamodule: Optional[LightningDataModule] = None,
update_attr: bool = False,
):
return lr_find(
self.trainer,
@@ -135,10 +136,8 @@
mode,
early_stop_threshold,
datamodule,
update_attr,
)

def internal_find_lr(self, model: LightningModule):
return _run_lr_finder_internally(self.trainer, model)

def pick_multiple_gpus(self, num_gpus: int):
return pick_multiple_gpus(num_gpus)
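
With `internal_find_lr` removed above, the automatic path now routes through the public method: given `auto_lr_find=True`, `Trainer.tune()` ends up calling `Tuner.lr_find(model, update_attr=True)`. A hedged sketch of that path, again using the test helpers changed in this commit (so it assumes the repository's test environment):

    from pytorch_lightning import Trainer
    from tests.helpers.datamodules import ClassifDataModule
    from tests.helpers.simple_models import ClassificationModel

    model = ClassificationModel()              # exposes `self.lr`
    dm = ClassifDataModule()

    trainer = Trainer(auto_lr_find=True, max_steps=2)
    trainer.tune(model, datamodule=dm)         # internally: tuner.lr_find(model, update_attr=True)
    print(model.lr)                            # now holds the suggested learning rate
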
3 changes: 2 additions & 1 deletion tests/helpers/simple_models.py
@@ -22,8 +22,9 @@
class ClassificationModel(LightningModule):

def __init__(self, lr=0.01):
self.lr = lr
super().__init__()

self.lr = lr
for i in range(3):
setattr(self, f"layer_{i}", nn.Linear(32, 32))
setattr(self, f"layer_{i}a", torch.nn.ReLU())
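
A brief note on the one-line reorder above: it brings `ClassificationModel.__init__` in line with the usual rule that `super().__init__()` runs before any attribute assignment on an `nn.Module`. For a plain float such as `lr` the old order happens to work, but assigning a submodule or parameter first does not, as this hedged, standalone sketch illustrates (the class names are illustrative):

    import torch.nn as nn

    class Broken(nn.Module):
        def __init__(self):
            # raises AttributeError: cannot assign module before Module.__init__() call
            self.layer = nn.Linear(32, 32)
            super().__init__()

    class Fixed(nn.Module):
        def __init__(self):
            super().__init__()               # parent constructor first ...
            self.lr = 0.01                   # ... then plain attributes
            self.layer = nn.Linear(32, 32)   # ... and submodules

    Fixed()        # fine
    # Broken()     # fails at the self.layer assignment above
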
17 changes: 7 additions & 10 deletions tests/trainer/test_lr_finder.py
@@ -21,7 +21,8 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from tests.helpers import BoringModel
from tests.helpers.datamodules import TrialMNISTDataModule
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.simple_models import ClassificationModel


def test_error_on_more_than_1_optimizer(tmpdir):
@@ -180,12 +181,10 @@ def test_datamodule_parameter(tmpdir):
""" Test that the datamodule parameter works """

# trial datamodule
dm = TrialMNISTDataModule(tmpdir)
dm = ClassifDataModule()
model = ClassificationModel()

hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)

before_lr = hparams.get('learning_rate')
before_lr = model.lr
# logger file to get meta
trainer = Trainer(
default_root_dir=tmpdir,
@@ -194,7 +193,7 @@

lrfinder = trainer.tuner.lr_find(model, datamodule=dm)
after_lr = lrfinder.suggestion()
model.learning_rate = after_lr
model.lr = after_lr

assert before_lr != after_lr, \
'Learning rate was not altered after running learning rate finder'
@@ -271,8 +270,6 @@ def test_suggestion_with_non_finite_values(tmpdir):

def test_lr_finder_fails_fast_on_bad_config(tmpdir):
""" Test that tune fails if the model does not have a lr BEFORE running lr find """
# note: this did not raise an exception before #5638 because lr_find is skipped
# during fast_dev_run and the lr attribute check was done after lr_find
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, auto_lr_find=True)
trainer = Trainer(default_root_dir=tmpdir, max_steps=2, auto_lr_find=True)
with pytest.raises(MisconfigurationException, match='should have one of these fields'):
trainer.tune(BoringModel())
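
As a usage note on the fail-fast behaviour the test above covers: `BoringModel` defines no `lr`/`learning_rate` field, so `tune()` now raises before the range test even starts, and adding such a field is enough to satisfy the check. A hedged sketch follows (the `BoringModelWithLR` subclass is hypothetical, and the snippet assumes the repository's test environment):

    from pytorch_lightning import Trainer
    from pytorch_lightning.utilities.exceptions import MisconfigurationException
    from tests.helpers import BoringModel

    # Fails fast: the attribute check runs before lr_find.
    trainer = Trainer(max_steps=2, auto_lr_find=True)
    try:
        trainer.tune(BoringModel())
    except MisconfigurationException:
        pass  # "should have one of these fields"

    # Exposing the attribute on the module itself satisfies _determine_lr_attr_name.
    class BoringModelWithLR(BoringModel):
        def __init__(self, lr=1e-3):
            super().__init__()
            self.lr = lr

    trainer = Trainer(max_steps=2, auto_lr_find=True)
    trainer.tune(BoringModelWithLR())    # runs lr_find and writes the suggestion to `lr`
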
