[CLI] Shorthand notation to instantiate models #9588

Merged 6 commits on Sep 22, 2021
Changes from 2 commits
59 changes: 48 additions & 11 deletions docs/source/common/lightning_cli.rst
@@ -415,22 +415,59 @@ as described above:

$ python ... --trainer.callbacks=CustomCallback ...

.. note::

    This shorthand notation is only supported in the shell and not inside a configuration file. The configuration file
    generated by calling the previous command with ``--print_config`` will have the ``class_path`` notation.

    .. code-block:: yaml

        trainer:
          callbacks:
            - class_path: your_class_path.CustomCallback
              init_args:
                ...
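The generated config can be inspected without running anything. A minimal sketch, assuming a ``trainer.py`` entry point
as in the surrounding examples:

.. code-block:: bash

    $ python trainer.py fit --trainer.callbacks=CustomCallback --print_config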

Multiple models and/or datasets
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In the previous examples :class:`~pytorch_lightning.utilities.cli.LightningCLI` works only for a single model and
datamodule class. However, there are many cases in which the objective is to easily run many experiments for multiple
models and datasets.

The model argument can be left unset if a model has been registered first. This is particularly useful for library
authors who want to provide their users a range of models to choose from:

.. code-block:: python

    import flash.image
    from pytorch_lightning import LightningModule
    from pytorch_lightning.utilities.cli import LightningCLI, MODEL_REGISTRY


    @MODEL_REGISTRY
    class SimCLR(LightningModule):
        ...


    # register all `LightningModule` subclasses from a package
    MODEL_REGISTRY.register_classes(flash.image, LightningModule)
    # print(MODEL_REGISTRY)
    # >>> Registered objects: ['SimCLR', 'ImageClassifier', 'ObjectDetector', 'StyleTransfer', ...]

    cli = LightningCLI()

.. code-block:: bash

    $ python trainer.py fit --model=SimCLR --model.feat_dim=64

.. note::

    This shorthand notation is only supported in the shell and not inside a configuration file. The configuration file
    generated by calling the previous command with ``--print_config`` will have the ``class_path`` notation described
    below.
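For illustration, a sketch of the ``model`` section that ``--print_config`` would produce for the command above
(``flash.image.SimCLR`` is an assumed import path; the real ``class_path`` depends on where the class is defined):

.. code-block:: yaml

    model:
      class_path: flash.image.SimCLR
      init_args:
        feat_dim: 64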

Additionally, the tool can be configured such that a model and/or a datamodule is
specified by an import path and init arguments. For example, with a tool implemented as:

.. code-block:: python
@@ -750,7 +787,7 @@ A corresponding example of the config file would be:

.. note::

    This shorthand notation is only supported in the shell and not inside a configuration file. The configuration file
    generated by calling the previous command with ``--print_config`` will have the ``class_path`` notation.

Furthermore, you can register your own optimizers and/or learning rate schedulers as follows:
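A minimal sketch of such a registration, mirroring the ``MODEL_REGISTRY`` decorator pattern shown earlier
(``CustomAdam`` is a hypothetical example, not part of this diff):

.. code-block:: python

    import torch
    from pytorch_lightning.utilities.cli import OPTIMIZER_REGISTRY


    @OPTIMIZER_REGISTRY
    class CustomAdam(torch.optim.Adam):
        ...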
@@ -894,8 +931,8 @@ Notes related to reproducibility
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The topic of reproducibility is complex and it is impossible to guarantee reproducibility by just providing a class that
people can use in unexpected ways. Nevertheless, the :class:`~pytorch_lightning.utilities.cli.LightningCLI` tries to
give a framework and recommendations to make reproducibility simpler.

When an experiment is run, it is good practice to use a stable version of the source code, either a released
package or at least a commit of some version-controlled repository. For each run of a CLI the config file is
31 changes: 23 additions & 8 deletions pytorch_lightning/utilities/cli.py
@@ -42,7 +42,7 @@


class _Registry(dict):
    def __call__(self, cls: Type, key: Optional[str] = None, override: bool = False) -> Type:
        """Registers a class mapped to a name.

        Args:
@@ -58,6 +58,7 @@ def __call__(self, cls: Type, key: Optional[str] = None, override: bool = False)
        if key in self and not override:
            raise MisconfigurationException(f"'{key}' is already present in the registry. HINT: Use `override=True`.")
        self[key] = cls
        return cls

    def register_classes(self, module: ModuleType, base_cls: Type, override: bool = False) -> None:
        """This function is a utility to register all classes from a module."""
@@ -88,6 +89,8 @@ def __str__(self) -> str:
CALLBACK_REGISTRY = _Registry()
CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.callbacks.Callback)

MODEL_REGISTRY = _Registry()
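# A minimal usage sketch (`MyModel` is hypothetical, not part of this diff):
#
#     @MODEL_REGISTRY
#     class MyModel(LightningModule):
#         ...
#
# Because `__call__` now returns the class, the decorator form registers `MyModel`
# under its name while leaving the name bound to the class itself.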


class LightningArgumentParser(ArgumentParser):
    """Extension of jsonargparse's ArgumentParser for pytorch-lightning."""
@@ -146,7 +149,7 @@ def add_lightning_class_args(
        if issubclass(lightning_class, Callback):
            self.callback_keys.append(nested_key)
        if subclass_mode:
            return self.add_subclass_arguments(lightning_class, nested_key, fail_untyped=False, required=True)
        return self.add_class_arguments(
            lightning_class, nested_key, fail_untyped=False, instantiate=not issubclass(lightning_class, Trainer)
        )
@@ -384,7 +387,7 @@ class LightningCLI:

    def __init__(
        self,
        model_class: Optional[Union[Type[LightningModule], Callable[..., LightningModule]]] = None,
        datamodule_class: Optional[Union[Type[LightningDataModule], Callable[..., LightningDataModule]]] = None,
        save_config_callback: Optional[Type[SaveConfigCallback]] = SaveConfigCallback,
        save_config_filename: str = "config.yaml",
@@ -412,8 +415,9 @@ def __init__(
        .. warning:: ``LightningCLI`` is in beta and subject to change.

        Args:
            model_class: An optional :class:`~pytorch_lightning.core.lightning.LightningModule` class to train on or a
                callable which returns a :class:`~pytorch_lightning.core.lightning.LightningModule` instance when
                called. If ``None``, you can pass a registered model with ``--model=AModel``.
            datamodule_class: An optional :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class or a
                callable which returns a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` instance when
                called.
@@ -438,17 +442,23 @@
            run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer`
                method. If set to ``False``, the trainer and model classes will be instantiated only.
        """
        self.datamodule_class = datamodule_class
        self.save_config_callback = save_config_callback
        self.save_config_filename = save_config_filename
        self.save_config_overwrite = save_config_overwrite
        self.trainer_class = trainer_class
        self.trainer_defaults = trainer_defaults or {}
        self.seed_everything_default = seed_everything_default
        self.subclass_mode_data = subclass_mode_data

        self.model_class = model_class
        self._model_class = model_class
        self.subclass_mode_model = subclass_mode_model
        if model_class is None:
            # used to differentiate between the original value and the processed value
            self._model_class = LightningModule
            self.subclass_mode_model = True

        main_kwargs, subparser_kwargs = self._setup_parser_kwargs(
            parser_kwargs or {},  # type: ignore  # github.com/python/mypy/issues/6463
            {"description": description, "env_prefix": env_prefix, "default_env": env_parse},
@@ -508,7 +518,12 @@ def add_core_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
parser.set_choices("trainer.callbacks", CALLBACK_REGISTRY.classes, is_list=True)
trainer_defaults = {"trainer." + k: v for k, v in self.trainer_defaults.items() if k != "callbacks"}
parser.set_defaults(trainer_defaults)
parser.add_lightning_class_args(self.model_class, "model", subclass_mode=self.subclass_mode_model)

parser.add_lightning_class_args(self._model_class, "model", subclass_mode=self.subclass_mode_model)
if self.model_class is None and MODEL_REGISTRY:
# did not pass a model and there are models registered
parser.set_choices("model", MODEL_REGISTRY.classes)

if self.datamodule_class is not None:
parser.add_lightning_class_args(self.datamodule_class, "data", subclass_mode=self.subclass_mode_data)

35 changes: 34 additions & 1 deletion tests/utilities/test_cli.py
@@ -22,6 +22,7 @@
from io import StringIO
from typing import List, Optional, Union
from unittest import mock
from unittest.mock import ANY

import pytest
import torch
@@ -39,6 +40,7 @@
    LightningArgumentParser,
    LightningCLI,
    LR_SCHEDULER_REGISTRY,
    MODEL_REGISTRY,
    OPTIMIZER_REGISTRY,
    SaveConfigCallback,
)
@@ -883,6 +885,35 @@ def test_registries(tmpdir):
    OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer)
    OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer, override=True)

    # test `_Registry.__call__` returns the class
    assert isinstance(CustomCallback(), CustomCallback)


@MODEL_REGISTRY
class TestModel(BoringModel):
    def __init__(self, foo, bar=5):
        super().__init__()
        self.foo = foo
        self.bar = bar


MODEL_REGISTRY(cls=BoringModel)


def test_lightning_cli_model_choices():
    with mock.patch("sys.argv", ["any.py", "fit", "--model=BoringModel"]), mock.patch(
        "pytorch_lightning.Trainer._fit_impl"
    ) as run:
        cli = LightningCLI(trainer_defaults={"fast_dev_run": 1})
        assert isinstance(cli.model, BoringModel)
        run.assert_called_once_with(cli.model, ANY, ANY, ANY)

    with mock.patch("sys.argv", ["any.py", "--model=TestModel", "--model.foo", "123"]):
        cli = LightningCLI(run=False)
        assert isinstance(cli.model, TestModel)
        assert cli.model.foo == 123
        assert cli.model.bar == 5


@pytest.mark.parametrize("use_class_path_callbacks", [False, True])
def test_registries_resolution(use_class_path_callbacks):
@@ -895,6 +926,7 @@ def test_registries_resolution(use_class_path_callbacks):
"--trainer.callbacks=LearningRateMonitor",
"--trainer.callbacks.logging_interval=epoch",
"--trainer.callbacks.log_momentum=True",
"--model=BoringModel",
"--trainer.callbacks=ModelCheckpoint",
"--trainer.callbacks.monitor=loss",
"--lr_scheduler",
@@ -912,8 +944,9 @@
    extras = [Callback, Callback]

    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = LightningCLI(run=False)

    assert isinstance(cli.model, BoringModel)
    optimizers, lr_scheduler = cli.model.configure_optimizers()
    assert isinstance(optimizers[0], torch.optim.Adam)
    assert optimizers[0].param_groups[0]["lr"] == 0.0001