Skip to content

Commit

Permalink
load_from_checkpoint support for LightningCLI when using dependency…
Browse files Browse the repository at this point in the history
… injection (#18105)
  • Loading branch information
mauvilsa committed Feb 23, 2024
1 parent a6273d1 commit 623ec58
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 9 deletions.
10 changes: 10 additions & 0 deletions docs/source-pytorch/cli/lightning_cli_advanced_3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ Since the init parameters of the model have as a type hint a class, in the confi
decoder: Instance of a module for decoding
"""
super().__init__()
self.save_hyperparameters()
self.encoder = encoder
self.decoder = decoder

Expand All @@ -216,6 +217,13 @@ If the CLI is implemented as ``LightningCLI(MyMainModel)`` the configuration wou
It is also possible to combine ``subclass_mode_model=True`` and submodules, thereby having two levels of ``class_path``.

.. tip::

By having ``self.save_hyperparameters()`` it becomes possible to load the model from a checkpoint. Simply do
``ModelClass.load_from_checkpoint("path/to/checkpoint.ckpt")``. In the case of using ``subclass_mode_model=True``,
then load it like ``LightningModule.load_from_checkpoint("path/to/checkpoint.ckpt")``. ``save_hyperparameters`` is
optional and can be safely removed if there is no need to load from a checkpoint.


Fixed optimizer and scheduler
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down Expand Up @@ -279,6 +287,7 @@ An example of a model that uses two optimizers is the following:
class MyModel(LightningModule):
def __init__(self, optimizer1: OptimizerCallable, optimizer2: OptimizerCallable):
super().__init__()
self.save_hyperparameters()
self.optimizer1 = optimizer1
self.optimizer2 = optimizer2
Expand Down Expand Up @@ -318,6 +327,7 @@ that uses dependency injection for an optimizer and a learning scheduler is:
scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
):
super().__init__()
self.save_hyperparameters()
self.optimizer = optimizer
self.scheduler = scheduler
Expand Down
2 changes: 1 addition & 1 deletion requirements/pytorch/extra.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
matplotlib>3.1, <3.9.0
omegaconf >=2.0.5, <2.4.0
hydra-core >=1.0.5, <1.4.0
jsonargparse[signatures] >=4.26.1, <4.28.0
jsonargparse[signatures] >=4.27.5, <4.28.0
rich >=12.3.0, <13.6.0
tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute
bitsandbytes ==0.41.0 # strict
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- The `ModelSummary` and `RichModelSummary` callbacks now display the training mode of each layer in the column "Mode" ([#19468](https://github.com/Lightning-AI/lightning/pull/19468))

- Added `load_from_checkpoint` support for `LightningCLI` when using dependency injection ([#18105](https://github.com/Lightning-AI/lightning/pull/18105))

-

-
Expand Down Expand Up @@ -64,6 +66,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added shortcut name `strategy='deepspeed_stage_1_offload'` to the strategy registry ([#19075](https://github.com/Lightning-AI/lightning/pull/19075))
- Added support for non-strict state-dict loading in Trainer via the new `LightningModule.strict_loading = True | False` attribute ([#19404](https://github.com/Lightning-AI/lightning/pull/19404))


### Changed

- `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846))
Expand Down
56 changes: 54 additions & 2 deletions src/lightning/pytorch/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
import sys
from functools import partial, update_wrapper
from types import MethodType
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, TypeVar, Union

import torch
import yaml
from lightning_utilities.core.imports import RequirementCache
from lightning_utilities.core.rank_zero import _warn
from torch.optim import Optimizer
Expand All @@ -27,11 +29,12 @@
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER
from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer, seed_everything
from lightning.pytorch.core.mixins.hparams_mixin import _given_hyperparameters_context
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_warn

_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.26.1")
_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.27.5")

if _JSONARGPARSE_SIGNATURES_AVAILABLE:
import docstring_parser
Expand All @@ -50,6 +53,8 @@
locals()["ArgumentParser"] = object
locals()["Namespace"] = object

ModuleType = TypeVar("ModuleType")


class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any) -> None:
Expand Down Expand Up @@ -381,6 +386,7 @@ def __init__(

self._set_seed()

self._add_instantiators()
self.before_instantiate_classes()
self.instantiate_classes()

Expand Down Expand Up @@ -527,6 +533,22 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No
else:
self.config = parser.parse_args(args)

def _add_instantiators(self) -> None:
self.config_dump = yaml.safe_load(self.parser.dump(self.config, skip_link_targets=False))
if "subcommand" in self.config:
self.config_dump = self.config_dump[self.config.subcommand]

self.parser.add_instantiator(
_InstantiatorFn(cli=self, key="model"),
_get_module_type(self._model_class),
subclasses=self.subclass_mode_model,
)
self.parser.add_instantiator(
_InstantiatorFn(cli=self, key="data"),
_get_module_type(self._datamodule_class),
subclasses=self.subclass_mode_data,
)

def before_instantiate_classes(self) -> None:
"""Implement to run some code before instantiating the classes."""

Expand Down Expand Up @@ -755,3 +777,33 @@ def _get_short_description(component: object) -> Optional[str]:
return docstring.short_description
except (ValueError, docstring_parser.ParseError) as ex:
rank_zero_warn(f"Failed parsing docstring for {component}: {ex}")


def _get_module_type(value: Union[Callable, type]) -> type:
if callable(value) and not isinstance(value, type):
return inspect.signature(value).return_annotation
return value


class _InstantiatorFn:
def __init__(self, cli: LightningCLI, key: str) -> None:
self.cli = cli
self.key = key

def __call__(self, class_type: Type[ModuleType], *args: Any, **kwargs: Any) -> ModuleType:
with _given_hyperparameters_context(
hparams=self.cli.config_dump.get(self.key, {}),
instantiator="lightning.pytorch.cli.instantiate_module",
):
return class_type(*args, **kwargs)


def instantiate_module(class_type: Type[ModuleType], config: Dict[str, Any]) -> ModuleType:
parser = ArgumentParser(exit_on_error=False)
if "class_path" in config:
parser.add_subclass_arguments(class_type, "module")
else:
parser.add_class_arguments(class_type, "module")
cfg = parser.parse_object({"module": config})
init = parser.instantiate_classes(cfg)
return init.module
23 changes: 20 additions & 3 deletions src/lightning/pytorch/core/mixins/hparams_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
import inspect
import types
from argparse import Namespace
from typing import Any, List, MutableMapping, Optional, Sequence, Union
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Iterator, List, MutableMapping, Optional, Sequence, Union

from lightning.fabric.utilities.data import AttributeDict
from lightning.pytorch.utilities.parsing import save_hyperparameters
Expand All @@ -24,6 +26,20 @@
_ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)


_given_hyperparameters: ContextVar = ContextVar("_given_hyperparameters", default=None)


@contextmanager
def _given_hyperparameters_context(hparams: dict, instantiator: str) -> Iterator[None]:
hparams = hparams.copy()
hparams["_instantiator"] = instantiator
token = _given_hyperparameters.set(hparams)
try:
yield
finally:
_given_hyperparameters.reset(token)


class HyperparametersMixin:
__jit_unused_properties__: List[str] = ["hparams", "hparams_initial"]

Expand Down Expand Up @@ -105,12 +121,13 @@ class ``__init__`` to be ignored
"""
self._log_hyperparams = logger
given_hparams = _given_hyperparameters.get()
# the frame needs to be created in this file.
if not frame:
if given_hparams is None and not frame:
current_frame = inspect.currentframe()
if current_frame:
frame = current_frame.f_back
save_hyperparameters(self, *args, ignore=ignore, frame=frame)
save_hyperparameters(self, *args, ignore=ignore, frame=frame, given_hparams=given_hparams)

def _set_hparams(self, hp: Union[MutableMapping, Namespace, str]) -> None:
hp = self._to_hparams_dict(hp)
Expand Down
9 changes: 8 additions & 1 deletion src/lightning/pytorch/core/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,18 @@ def _load_state(
_cls_kwargs.update(cls_kwargs_loaded)
_cls_kwargs.update(cls_kwargs_new)

instantiator = None
instantiator_path = _cls_kwargs.pop("_instantiator", None)
if instantiator_path is not None:
# import custom instantiator
module_path, name = instantiator_path.rsplit(".", 1)
instantiator = getattr(__import__(module_path, fromlist=[name]), name)

if not cls_spec.varkw:
# filter kwargs according to class init unless it allows any argument via kwargs
_cls_kwargs = {k: v for k, v in _cls_kwargs.items() if k in cls_init_args_name}

obj = cls(**_cls_kwargs)
obj = instantiator(cls, _cls_kwargs) if instantiator else cls(**_cls_kwargs)

if isinstance(obj, pl.LightningDataModule):
if obj.__class__.__qualname__ in checkpoint:
Expand Down
10 changes: 8 additions & 2 deletions src/lightning/pytorch/utilities/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,11 @@ def collect_init_args(


def save_hyperparameters(
obj: Any, *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None
obj: Any,
*args: Any,
ignore: Optional[Union[Sequence[str], str]] = None,
frame: Optional[types.FrameType] = None,
given_hparams: Optional[Dict[str, Any]] = None,
) -> None:
"""See :meth:`~lightning.pytorch.LightningModule.save_hyperparameters`"""

Expand All @@ -156,7 +160,9 @@ def save_hyperparameters(
if not isinstance(frame, types.FrameType):
raise AttributeError("There is no `frame` available while being required.")

if is_dataclass(obj):
if given_hparams is not None:
init_args = given_hparams
elif is_dataclass(obj):
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
else:
init_args = {}
Expand Down
78 changes: 78 additions & 0 deletions tests/tests_pytorch/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,6 +833,84 @@ def configure_optimizers(self):
assert init[1]["lr_scheduler"].gamma == 0.3


class TestModelSaveHparams(BoringModel):
def __init__(
self,
optimizer: OptimizerCallable = torch.optim.Adam,
scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
activation: torch.nn.Module = lazy_instance(torch.nn.LeakyReLU, negative_slope=0.05),
):
super().__init__()
self.save_hyperparameters()
self.optimizer = optimizer
self.scheduler = scheduler
self.activation = activation

def configure_optimizers(self):
optimizer = self.optimizer(self.parameters())
scheduler = self.scheduler(optimizer)
return {"optimizer": optimizer, "lr_scheduler": scheduler}


def test_lightning_cli_load_from_checkpoint_dependency_injection(cleandir):
with mock.patch("sys.argv", ["any.py", "--trainer.max_epochs=1"]):
cli = LightningCLI(TestModelSaveHparams, run=False, auto_configure_optimizers=False)
cli.trainer.fit(cli.model)

hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
assert hparams_path.is_file()
hparams = yaml.safe_load(hparams_path.read_text())
expected = {
"_instantiator": "lightning.pytorch.cli.instantiate_module",
"optimizer": "torch.optim.Adam",
"scheduler": "torch.optim.lr_scheduler.ConstantLR",
"activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}},
}
assert hparams == expected

checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None)
assert checkpoint_path.is_file()
ckpt = torch.load(checkpoint_path)
assert ckpt["hyper_parameters"] == expected

model = TestModelSaveHparams.load_from_checkpoint(checkpoint_path)
assert isinstance(model, TestModelSaveHparams)
assert isinstance(model.activation, torch.nn.LeakyReLU)
assert model.activation.negative_slope == 0.05
optimizer, lr_scheduler = model.configure_optimizers().values()
assert isinstance(optimizer, torch.optim.Adam)
assert isinstance(lr_scheduler, torch.optim.lr_scheduler.ConstantLR)


def test_lightning_cli_load_from_checkpoint_dependency_injection_subclass_mode(cleandir):
with mock.patch("sys.argv", ["any.py", "--trainer.max_epochs=1", "--model=TestModelSaveHparams"]):
cli = LightningCLI(TestModelSaveHparams, run=False, auto_configure_optimizers=False, subclass_mode_model=True)
cli.trainer.fit(cli.model)

expected = {
"_instantiator": "lightning.pytorch.cli.instantiate_module",
"class_path": f"{__name__}.TestModelSaveHparams",
"init_args": {
"optimizer": "torch.optim.Adam",
"scheduler": "torch.optim.lr_scheduler.ConstantLR",
"activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}},
},
}

checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None)
assert checkpoint_path.is_file()
ckpt = torch.load(checkpoint_path)
assert ckpt["hyper_parameters"] == expected

model = LightningModule.load_from_checkpoint(checkpoint_path)
assert isinstance(model, TestModelSaveHparams)
assert isinstance(model.activation, torch.nn.LeakyReLU)
assert model.activation.negative_slope == 0.05
optimizer, lr_scheduler = model.configure_optimizers().values()
assert isinstance(optimizer, torch.optim.Adam)
assert isinstance(lr_scheduler, torch.optim.lr_scheduler.ConstantLR)


@pytest.mark.parametrize("fn", [fn.value for fn in TrainerFn])
def test_lightning_cli_trainer_fn(fn):
class TestCLI(LightningCLI):
Expand Down

0 comments on commit 623ec58

Please sign in to comment.