[CLI] Shorthand notation to instantiate models (#9588)
carmocca committed Sep 22, 2021
1 parent 8f1c855 commit 3f7872d
Showing 3 changed files with 97 additions and 19 deletions.
59 changes: 48 additions & 11 deletions docs/source/common/lightning_cli.rst
@@ -415,22 +415,59 @@ as described above:
$ python ... --trainer.callbacks=CustomCallback ...
This callback will be included in the generated config:

.. note::

    This shorthand notation is only supported in the shell and not inside a configuration file. The configuration file
    generated by calling the previous command with ``--print_config`` will have the ``class_path`` notation.

.. code-block:: yaml

    trainer:
      callbacks:
        - class_path: your_class_path.CustomCallback
          init_args:
            ...

Multiple models and/or datasets
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
In the previous examples :class:`~pytorch_lightning.utilities.cli.LightningCLI` works only for a single model and
datamodule class. However, there are many cases in which the objective is to easily be able to run many experiments for
multiple models and datasets.

The model argument can be left unset if a model has been registered first. This is particularly useful for library
authors who want to provide their users a range of models to choose from:

.. code-block:: python

    import flash.image
    from pytorch_lightning import LightningModule
    from pytorch_lightning.utilities.cli import LightningCLI, MODEL_REGISTRY


    @MODEL_REGISTRY
    class MyModel(LightningModule):
        ...


    # register all `LightningModule` subclasses from a package
    MODEL_REGISTRY.register_classes(flash.image, LightningModule)
    # print(MODEL_REGISTRY)
    # >>> Registered objects: ['MyModel', 'ImageClassifier', 'ObjectDetector', 'StyleTransfer', ...]

    cli = LightningCLI()

.. code-block:: bash

    $ python trainer.py fit --model=MyModel --model.feat_dim=64

.. note::

    This shorthand notation is only supported in the shell and not inside a configuration file. The configuration file
    generated by calling the previous command with ``--print_config`` will have the ``class_path`` notation described
    below.
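
For instance, a sketch of what ``--print_config`` would emit for the command above (the import path of ``MyModel`` is
an assumption here):

.. code-block:: yaml

    model:
      class_path: mymodule.MyModel
      init_args:
        feat_dim: 64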
Additionally, the tool can be configured such that a model and/or a datamodule is
specified by an import path and init arguments. For example, with a tool implemented as:
.. code-block:: python
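
    # a minimal sketch; `MyModelBaseClass` and `MyDataModuleBaseClass` are assumed
    # user-defined base classes, not names taken from this commit
    cli = LightningCLI(
        MyModelBaseClass,
        MyDataModuleBaseClass,
        subclass_mode_model=True,
        subclass_mode_data=True,
    )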
@@ -750,7 +787,7 @@ A corresponding example of the config file would be:
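
A sketch, assuming the ``--optimizer=Adam --optimizer.lr=0.01`` shorthand used in this part of the docs:

.. code-block:: yaml

    optimizer:
      class_path: torch.optim.Adam
      init_args:
        lr: 0.01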

.. note::

    This shorthand notation is only supported in the shell and not inside a configuration file. The configuration file
    generated by calling the previous command with ``--print_config`` will have the ``class_path`` notation.
Furthermore, you can register your own optimizers and/or learning rate schedulers as follows:
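
A sketch of what such a registration can look like, mirroring the ``MODEL_REGISTRY`` pattern above (the subclass names
are illustrative):

.. code-block:: python

    import torch
    from pytorch_lightning.utilities.cli import LR_SCHEDULER_REGISTRY, OPTIMIZER_REGISTRY


    @OPTIMIZER_REGISTRY
    class CustomAdam(torch.optim.Adam):
        ...


    @LR_SCHEDULER_REGISTRY
    class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR):
        ...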
@@ -894,8 +931,8 @@ Notes related to reproducibility
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The topic of reproducibility is complex and it is impossible to guarantee reproducibility by just providing a class that
people can use in unexpected ways. Nevertheless, the :class:`~pytorch_lightning.utilities.cli.LightningCLI` tries to
give a framework and recommendations to make reproducibility simpler.
When an experiment is run, it is good practice to use a stable version of the source code, either being a released
package or at least a commit of some version controlled repository. For each run of a CLI the config file is
25 changes: 18 additions & 7 deletions pytorch_lightning/utilities/cli.py
@@ -89,6 +89,8 @@ def __str__(self) -> str:
CALLBACK_REGISTRY = _Registry()
CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.callbacks.Callback)

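# models registered here become selectable from the command line via --model=<ClassName>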
MODEL_REGISTRY = _Registry()


class LightningArgumentParser(ArgumentParser):
"""Extension of jsonargparse's ArgumentParser for pytorch-lightning."""
@@ -147,7 +149,7 @@ def add_lightning_class_args(
        if issubclass(lightning_class, Callback):
            self.callback_keys.append(nested_key)
        if subclass_mode:
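            # fail_untyped=False: do not reject classes whose __init__ parameters lack type annotations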
            return self.add_subclass_arguments(lightning_class, nested_key, fail_untyped=False, required=True)
        return self.add_class_arguments(
            lightning_class, nested_key, fail_untyped=False, instantiate=not issubclass(lightning_class, Trainer)
        )
@@ -385,7 +387,7 @@ class LightningCLI:

    def __init__(
        self,
        model_class: Optional[Union[Type[LightningModule], Callable[..., LightningModule]]] = None,
        datamodule_class: Optional[Union[Type[LightningDataModule], Callable[..., LightningDataModule]]] = None,
        save_config_callback: Optional[Type[SaveConfigCallback]] = SaveConfigCallback,
        save_config_filename: str = "config.yaml",
@@ -413,8 +415,9 @@ def __init__(
        .. warning:: ``LightningCLI`` is in beta and subject to change.

        Args:
            model_class: An optional :class:`~pytorch_lightning.core.lightning.LightningModule` class to train on or a
                callable which returns a :class:`~pytorch_lightning.core.lightning.LightningModule` instance when
                called. If ``None``, you can pass a registered model with ``--model=MyModel``.
            datamodule_class: An optional :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class or a
                callable which returns a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` instance when
                called.
@@ -439,17 +442,20 @@ def __init__(
            run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer`
                method. If set to ``False``, the trainer and model classes will be instantiated only.
        """
        self.datamodule_class = datamodule_class
        self.save_config_callback = save_config_callback
        self.save_config_filename = save_config_filename
        self.save_config_overwrite = save_config_overwrite
        self.trainer_class = trainer_class
        self.trainer_defaults = trainer_defaults or {}
        self.seed_everything_default = seed_everything_default
        self.subclass_mode_data = subclass_mode_data

        self.model_class = model_class
        # used to differentiate between the original value and the processed value
        self._model_class = model_class or LightningModule
        self.subclass_mode_model = (model_class is None) or subclass_mode_model
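        # e.g. a bare `LightningCLI()` is treated as subclass mode over `LightningModule`, so any
        # registered model can be selected with `--model=<ClassName>` (see `add_core_arguments_to_parser`)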

        main_kwargs, subparser_kwargs = self._setup_parser_kwargs(
            parser_kwargs or {},  # type: ignore  # github.com/python/mypy/issues/6463
            {"description": description, "env_prefix": env_prefix, "default_env": env_parse},
@@ -509,7 +515,12 @@ def add_core_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        parser.set_choices("trainer.callbacks", CALLBACK_REGISTRY.classes, is_list=True)
        trainer_defaults = {"trainer." + k: v for k, v in self.trainer_defaults.items() if k != "callbacks"}
        parser.set_defaults(trainer_defaults)

        parser.add_lightning_class_args(self._model_class, "model", subclass_mode=self.subclass_mode_model)
        if self.model_class is None and MODEL_REGISTRY:
            # did not pass a model and there are models registered
            parser.set_choices("model", MODEL_REGISTRY.classes)

        if self.datamodule_class is not None:
            parser.add_lightning_class_args(self.datamodule_class, "data", subclass_mode=self.subclass_mode_data)
32 changes: 31 additions & 1 deletion tests/utilities/test_cli.py
@@ -22,6 +22,7 @@
from io import StringIO
from typing import List, Optional, Union
from unittest import mock
from unittest.mock import ANY

import pytest
import torch
@@ -39,6 +40,7 @@
    LightningArgumentParser,
    LightningCLI,
    LR_SCHEDULER_REGISTRY,
    MODEL_REGISTRY,
    OPTIMIZER_REGISTRY,
    SaveConfigCallback,
)
@@ -888,6 +890,32 @@ def test_registries(tmpdir):
    assert isinstance(CustomCallback(), CustomCallback)


@MODEL_REGISTRY
class TestModel(BoringModel):
    def __init__(self, foo, bar=5):
        super().__init__()
        self.foo = foo
        self.bar = bar


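# a class can also be registered by calling the registry directly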
MODEL_REGISTRY(cls=BoringModel)


def test_lightning_cli_model_choices():
    with mock.patch("sys.argv", ["any.py", "fit", "--model=BoringModel"]), mock.patch(
        "pytorch_lightning.Trainer._fit_impl"
    ) as run:
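        # Trainer._fit_impl is mocked out, so the subcommand is parsed and objects instantiated without training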
        cli = LightningCLI(trainer_defaults={"fast_dev_run": 1})
        assert isinstance(cli.model, BoringModel)
        run.assert_called_once_with(cli.model, ANY, ANY, ANY)

    with mock.patch("sys.argv", ["any.py", "--model=TestModel", "--model.foo", "123"]):
        cli = LightningCLI(run=False)
        assert isinstance(cli.model, TestModel)
        assert cli.model.foo == 123
        assert cli.model.bar == 5


@pytest.mark.parametrize("use_class_path_callbacks", [False, True])
def test_registries_resolution(use_class_path_callbacks):
"""This test validates registries are used when simplified command line are being used."""
@@ -899,6 +927,7 @@ def test_registries_resolution(use_class_path_callbacks):
"--trainer.callbacks=LearningRateMonitor",
"--trainer.callbacks.logging_interval=epoch",
"--trainer.callbacks.log_momentum=True",
"--model=BoringModel",
"--trainer.callbacks=ModelCheckpoint",
"--trainer.callbacks.monitor=loss",
"--lr_scheduler",
@@ -916,8 +945,9 @@
    extras = [Callback, Callback]

    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = LightningCLI(run=False)

    assert isinstance(cli.model, BoringModel)
    optimizers, lr_scheduler = cli.model.configure_optimizers()
    assert isinstance(optimizers[0], torch.optim.Adam)
    assert optimizers[0].param_groups[0]["lr"] == 0.0001
