[CLI] Add support for ReduceLROnPlateau #10860

Merged: 4 commits, Dec 1, 2021
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -34,7 +34,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/issues/10719))


-
- Added support for `--lr_scheduler=ReduceLROnPlateau` to the `LightningCLI` ([#10860](https://github.com/PyTorchLightning/pytorch-lightning/issues/10860))


- Added `LightningCLI.configure_optimizers` to override the `configure_optimizers` return value ([#10860](https://github.com/PyTorchLightning/pytorch-lightning/issues/10860))

### Changed

85 changes: 41 additions & 44 deletions docs/source/common/lightning_cli.rst
@@ -822,42 +822,20 @@ Furthermore, you can register your own optimizers and/or learning rate scheduler

$ python trainer.py fit --optimizer=CustomAdam --optimizer.lr=0.01 --lr_scheduler=CustomCosineAnnealingLR

If you need to customize the key names or link arguments together, you can choose from all available optimizers and
learning rate schedulers by accessing the registries.

.. code-block::

    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.add_optimizer_args(
                OPTIMIZER_REGISTRY.classes,
                nested_key="gen_optimizer",
                link_to="model.optimizer1_init"
            )
            parser.add_optimizer_args(
                OPTIMIZER_REGISTRY.classes,
                nested_key="gen_discriminator",
                link_to="model.optimizer2_init"
            )
The :class:`torch.optim.lr_scheduler.ReduceLROnPlateau` scheduler requires an additional monitor argument:

.. code-block:: bash

$ python trainer.py fit \
--gen_optimizer=Adam \
--gen_optimizer.lr=0.01 \
--gen_discriminator=AdamW \
--gen_discriminator.lr=0.0001
$ python trainer.py fit --optimizer=Adam --lr_scheduler=ReduceLROnPlateau --lr_scheduler.monitor=metric_to_track

You can also pass the class path directly, for example, if the optimizer hasn't been registered to the
``OPTIMIZER_REGISTRY``:
If you need to customize the learning rate scheduler configuration, you can do so by overriding
:meth:`~pytorch_lightning.utilities.cli.LightningCLI.configure_optimizers`:

.. code-block:: bash
.. testcode::

$ python trainer.py fit \
--gen_optimizer.class_path=torch.optim.Adam \
--gen_optimizer.init_args.lr=0.01 \
--gen_discriminator.class_path=torch.optim.AdamW \
--gen_discriminator.init_args.lr=0.0001
    class MyLightningCLI(LightningCLI):
        @staticmethod
        def configure_optimizers(lightning_module, optimizer, lr_scheduler=None):
            return ...
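
For illustration, a minimal sketch of such an override (not part of this change; the ``"interval"`` value is just
one of the scheduler configuration keys accepted by ``LightningModule.configure_optimizers``):

.. code-block:: python

    class MyLightningCLI(LightningCLI):
        @staticmethod
        def configure_optimizers(lightning_module, optimizer, lr_scheduler=None):
            # no scheduler was configured on the command line: return only the optimizer
            if lr_scheduler is None:
                return optimizer
            # return the dictionary format understood by Lightning, stepping the scheduler once per epoch
            return {
                "optimizer": optimizer,
                "lr_scheduler": {"scheduler": lr_scheduler, "interval": "epoch"},
            }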

If you will not be changing the class, you can manually add the arguments for specific optimizers and/or
learning rate schedulers by subclassing the CLI. This has the advantage of providing the proper help message for those
@@ -892,45 +870,64 @@ Where the arguments can be passed directly through command line without specifying
$ python trainer.py fit --optimizer.lr=0.01 --lr_scheduler.gamma=0.2
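
The command above assumes a CLI subclass along these lines (a sketch only; the concrete classes, ``Adam`` and
``ExponentialLR``, are illustrative and any optimizer/scheduler pair could be used):

.. code-block:: python

    import torch


    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            # expose the arguments of these specific classes under the default
            # ``optimizer`` and ``lr_scheduler`` keys
            parser.add_optimizer_args(torch.optim.Adam)
            parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR)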

The automatic implementation of :code:`configure_optimizers` can be disabled by linking the configuration group. An
example is :code:`ReduceLROnPlateau`, which requires specifying a monitor. This would be:
example is when one wants to add support for multiple optimizers:

.. testcode::
.. code-block:: python

    from pytorch_lightning.utilities.cli import instantiate_class


    class MyModel(LightningModule):
        def __init__(self, optimizer_init: dict, lr_scheduler_init: dict):
        def __init__(self, optimizer1_init: dict, optimizer2_init: dict):
            super().__init__()
            self.optimizer_init = optimizer_init
            self.lr_scheduler_init = lr_scheduler_init
            self.optimizer1_init = optimizer1_init
            self.optimizer2_init = optimizer2_init

        def configure_optimizers(self):
            optimizer = instantiate_class(self.parameters(), self.optimizer_init)
            scheduler = instantiate_class(optimizer, self.lr_scheduler_init)
            return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}
            optimizer1 = instantiate_class(self.parameters(), self.optimizer1_init)
            optimizer2 = instantiate_class(self.parameters(), self.optimizer2_init)
            return [optimizer1, optimizer2]


    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.add_optimizer_args(
                torch.optim.Adam,
                link_to="model.optimizer_init",
                OPTIMIZER_REGISTRY.classes, nested_key="gen_optimizer", link_to="model.optimizer1_init"
            )
            parser.add_lr_scheduler_args(
                torch.optim.lr_scheduler.ReduceLROnPlateau,
                link_to="model.lr_scheduler_init",
            parser.add_optimizer_args(
                OPTIMIZER_REGISTRY.classes, nested_key="gen_discriminator", link_to="model.optimizer2_init"
            )


    cli = MyLightningCLI(MyModel)

The value given to :code:`optimizer_init` will always be a dictionary including :code:`class_path` and
The value given to :code:`optimizer*_init` will always be a dictionary including :code:`class_path` and
:code:`init_args` entries. The function :func:`~pytorch_lightning.utilities.cli.instantiate_class`
takes care of importing the class defined in :code:`class_path` and instantiating it using some positional arguments,
in this case :code:`self.parameters()`, and the :code:`init_args`.
Any number of optimizers and learning rate schedulers can be added when using :code:`link_to`.
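
As a small illustration of the mechanics (a sketch using a stand-in parameter list, not part of this change):

.. code-block:: python

    import torch

    from pytorch_lightning.utilities.cli import instantiate_class

    # the shape of the dictionary that ``--gen_optimizer=Adam --gen_optimizer.lr=0.01``
    # links into ``model.optimizer1_init``
    optimizer_init = {"class_path": "torch.optim.Adam", "init_args": {"lr": 0.01}}

    params = torch.nn.Linear(4, 4).parameters()  # stand-in for ``self.parameters()``
    # imports ``torch.optim.Adam`` and calls ``Adam(params, lr=0.01)``
    optimizer = instantiate_class(params, optimizer_init)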

With shorthand notation:

.. code-block:: bash

$ python trainer.py fit \
--gen_optimizer=Adam \
--gen_optimizer.lr=0.01 \
--gen_discriminator=AdamW \
--gen_discriminator.lr=0.0001

You can also pass the class path directly, for example, if the optimizer hasn't been registered to the
``OPTIMIZER_REGISTRY``:

.. code-block:: bash

$ python trainer.py fit \
--gen_optimizer.class_path=torch.optim.Adam \
--gen_optimizer.init_args.lr=0.01 \
--gen_discriminator.class_path=torch.optim.AdamW \
--gen_discriminator.init_args.lr=0.0001


Notes related to reproducibility
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
60 changes: 43 additions & 17 deletions pytorch_lightning/utilities/cli.py
@@ -14,6 +14,7 @@
import inspect
import os
import sys
from functools import partial, update_wrapper
from types import MethodType, ModuleType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from unittest import mock
@@ -28,7 +29,7 @@
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import LRSchedulerType, LRSchedulerTypeTuple
from pytorch_lightning.utilities.types import LRSchedulerType, LRSchedulerTypeTuple, LRSchedulerTypeUnion

if _JSONARGPARSE_AVAILABLE:
    from jsonargparse import ActionConfigFile, ArgumentParser, class_from_function, Namespace, set_config_read_mode
@@ -84,6 +85,15 @@ def __str__(self) -> str:
LR_SCHEDULER_REGISTRY = _Registry()
LR_SCHEDULER_REGISTRY.register_classes(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler)


class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
    def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any) -> None:
        super().__init__(optimizer, *args, **kwargs)
        self.monitor = monitor


LR_SCHEDULER_REGISTRY(cls=ReduceLROnPlateau)

CALLBACK_REGISTRY = _Registry()
CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.callbacks.Callback)

@@ -173,7 +183,7 @@ def add_optimizer_args(
"""Adds arguments from an optimizer class to a nested key of the parser.

Args:
optimizer_class: Any subclass of torch.optim.Optimizer.
optimizer_class: Any subclass of :class:`torch.optim.Optimizer`.
nested_key: Name of the nested namespace to store arguments.
link_to: Dot notation of a parser key to set arguments or AUTOMATIC.
"""
@@ -690,12 +700,31 @@ def _parser(self, subcommand: Optional[str]) -> LightningArgumentParser:
        action_subcommand = self.parser._subcommands_action
        return action_subcommand._name_parser_map[subcommand]

    def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None:
        """Adds to the model an automatically generated ``configure_optimizers`` method.
    @staticmethod
    def configure_optimizers(
        lightning_module: LightningModule, optimizer: Optimizer, lr_scheduler: Optional[LRSchedulerTypeUnion] = None
    ) -> Any:
        """Override to customize the :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_optimizers`
        method.

        If a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC', then a
        `configure_optimizers` method is automatically implemented in the model class.
        Args:
            lightning_module: A reference to the model.
            optimizer: The optimizer.
            lr_scheduler: The learning rate scheduler (if used).
        """
        if lr_scheduler is None:
            return optimizer
        if isinstance(lr_scheduler, ReduceLROnPlateau):
            return {
                "optimizer": optimizer,
                "lr_scheduler": {"scheduler": lr_scheduler, "monitor": lr_scheduler.monitor},
            }
        return [optimizer], [lr_scheduler]

    def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None:
        """Overrides the model's :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_optimizers`
        method if a single optimizer and optionally a scheduler argument groups are added to the parser as
        'AUTOMATIC'."""
        parser = self._parser(subcommand)

        def get_automatic(
@@ -739,21 +768,18 @@ def get_automatic(
        if not isinstance(lr_scheduler_class, tuple):
            lr_scheduler_init = _global_add_class_path(lr_scheduler_class, lr_scheduler_init)

        def configure_optimizers(
            self: LightningModule,
        ) -> Union[Optimizer, Tuple[List[Optimizer], List[LRSchedulerType]]]:
            optimizer = instantiate_class(self.parameters(), optimizer_init)
            if not lr_scheduler_init:
                return optimizer
            lr_scheduler = instantiate_class(optimizer, lr_scheduler_init)
            return [optimizer], [lr_scheduler]

        if is_overridden("configure_optimizers", self.model):
            warnings._warn(
                f"`{self.model.__class__.__name__}.configure_optimizers` will be overridden by "
                f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model`."
                f"`{self.__class__.__name__}.configure_optimizers`."
            )
        self.model.configure_optimizers = MethodType(configure_optimizers, self.model)

        optimizer = instantiate_class(self.model.parameters(), optimizer_init)
        lr_scheduler = instantiate_class(optimizer, lr_scheduler_init) if lr_scheduler_init else None
        fn = partial(self.configure_optimizers, optimizer=optimizer, lr_scheduler=lr_scheduler)
        update_wrapper(fn, self.model.configure_optimizers)  # necessary for `is_overridden`
        # override the existing method
        self.model.configure_optimizers = MethodType(fn, self.model)

    def _get(self, config: Dict[str, Any], key: str, default: Optional[Any] = None) -> Any:
        """Utility to get a config value which might be inside a subcommand."""
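
For readers skimming the diff, the override mechanism above boils down to the following self-contained sketch
(plain Python with toy stand-ins for the model and optimizer; not Lightning code itself):

.. code-block:: python

    from functools import partial, update_wrapper
    from types import MethodType


    class Model:
        def configure_optimizers(self):
            return "user-defined"


    def configure_optimizers(model, optimizer, lr_scheduler=None):
        # stand-in for ``LightningCLI.configure_optimizers``: choose a return format
        return optimizer if lr_scheduler is None else ([optimizer], [lr_scheduler])


    model = Model()
    fn = partial(configure_optimizers, optimizer="adam", lr_scheduler=None)
    update_wrapper(fn, model.configure_optimizers)  # keep the original name, needed for ``is_overridden``
    model.configure_optimizers = MethodType(fn, model)  # bind so the model passes itself as the first argument
    assert model.configure_optimizers() == "adam"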
46 changes: 42 additions & 4 deletions tests/utilities/test_cli.py
@@ -28,6 +28,8 @@
import torch
import yaml
from packaging import version
from torch.optim import SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
@@ -626,10 +628,7 @@ class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.add_optimizer_args(torch.optim.Adam)

    match = (
        "BoringModel.configure_optimizers` will be overridden by "
        "`MyLightningCLI.add_configure_optimizers_method_to_model`"
    )
    match = "BoringModel.configure_optimizers` will be overridden by " "`MyLightningCLI.configure_optimizers`"
    argv = ["fit", f"--trainer.default_root_dir={tmpdir}", "--trainer.fast_dev_run=1"] if run else []
    with mock.patch("sys.argv", ["any.py"] + argv), pytest.warns(UserWarning, match=match):
        cli = MyLightningCLI(BoringModel, run=run)
@@ -878,6 +877,7 @@ def test_registries():
assert "CosineAnnealingLR" in LR_SCHEDULER_REGISTRY.names
assert "CosineAnnealingWarmRestarts" in LR_SCHEDULER_REGISTRY.names
assert "CustomCosineAnnealingLR" in LR_SCHEDULER_REGISTRY.names
assert "ReduceLROnPlateau" in LR_SCHEDULER_REGISTRY.names

assert "EarlyStopping" in CALLBACK_REGISTRY.names
assert "CustomCallback" in CALLBACK_REGISTRY.names
@@ -1384,3 +1384,41 @@ def test_cli_help_message():
    assert shorthand_help.getvalue() == classpath_help.getvalue()
    # make sure it's not empty
    assert "Implements Adam" in shorthand_help.getvalue()


def test_cli_reducelronplateau():
    with mock.patch(
        "sys.argv", ["any.py", "--optimizer=Adam", "--lr_scheduler=ReduceLROnPlateau", "--lr_scheduler.monitor=foo"]
    ):
        cli = LightningCLI(BoringModel, run=False)
    config = cli.model.configure_optimizers()
    assert isinstance(config["lr_scheduler"]["scheduler"], ReduceLROnPlateau)
    assert config["lr_scheduler"]["scheduler"].monitor == "foo"


def test_cli_configureoptimizers_can_be_overridden():
    class MyCLI(LightningCLI):
        def __init__(self):
            super().__init__(BoringModel, run=False)

        @staticmethod
        def configure_optimizers(self, optimizer, lr_scheduler=None):
            assert isinstance(self, BoringModel)
            assert lr_scheduler is None
            return 123

    with mock.patch("sys.argv", ["any.py", "--optimizer=Adam"]):
        cli = MyCLI()
    assert cli.model.configure_optimizers() == 123

    # with no optimization config, we don't override
    with mock.patch("sys.argv", ["any.py"]):
        cli = MyCLI()
    [optimizer], [scheduler] = cli.model.configure_optimizers()
    assert isinstance(optimizer, SGD)
    assert isinstance(scheduler, StepLR)
    with mock.patch("sys.argv", ["any.py", "--lr_scheduler=StepLR"]):
        cli = MyCLI()
    [optimizer], [scheduler] = cli.model.configure_optimizers()
    assert isinstance(optimizer, SGD)
    assert isinstance(scheduler, StepLR)