Skip to content

Commit

Permalink
LightningCLI changes for jsonargparse>=4.0.0 (#10426)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: thomas chaton <thomas@grid.ai>
  • Loading branch information
3 people committed Nov 19, 2021
1 parent ff8ac6e commit 5d748e5
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 17 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520))


- Changes in `LightningCLI` required for the new major release of jsonargparse v4.0.0 ([#10426](https://github.com/PyTorchLightning/pytorch-lightning/pull/10426))


- Renamed `refresh_rate_per_second` parameter to `referesh_rate` for `RichProgressBar` signature ([#10497](https://github.com/PyTorchLightning/pytorch-lightning/pull/10497))


Expand Down
27 changes: 14 additions & 13 deletions pytorch_lightning/utilities/cli.py
Expand Up @@ -14,7 +14,6 @@
import inspect
import os
import sys
from argparse import Namespace
from types import MethodType, ModuleType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from unittest import mock
Expand All @@ -32,13 +31,12 @@
from pytorch_lightning.utilities.types import LRSchedulerType, LRSchedulerTypeTuple

if _JSONARGPARSE_AVAILABLE:
from jsonargparse import ActionConfigFile, ArgumentParser, class_from_function, set_config_read_mode
from jsonargparse.actions import _ActionSubCommands
from jsonargparse import ActionConfigFile, ArgumentParser, class_from_function, Namespace, set_config_read_mode
from jsonargparse.optionals import import_docstring_parse

set_config_read_mode(fsspec_enabled=True)
else:
ArgumentParser = object
ArgumentParser = Namespace = object


class _Registry(dict):
Expand Down Expand Up @@ -100,7 +98,7 @@ class LightningArgumentParser(ArgumentParser):
# use class attribute because `parse_args` is only called on the main parser
_choices: Dict[str, Tuple[Tuple[Type, ...], bool]] = {}

def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None:
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize argument parser that supports configuration file input.
For full details of accepted arguments see `ArgumentParser.__init__
Expand All @@ -109,9 +107,9 @@ def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> Non
if not _JSONARGPARSE_AVAILABLE:
raise ModuleNotFoundError(
"`jsonargparse` is not installed but it is required for the CLI."
" Install it with `pip install jsonargparse[signatures]`."
" Install it with `pip install -U jsonargparse[signatures]`."
)
super().__init__(*args, parse_as_dict=parse_as_dict, **kwargs)
super().__init__(*args, **kwargs)
self.add_argument(
"--config", action=ActionConfigFile, help="Path to a configuration file in json or yaml format."
)
Expand Down Expand Up @@ -363,7 +361,7 @@ class SaveConfigCallback(Callback):
def __init__(
self,
parser: LightningArgumentParser,
config: Union[Namespace, Dict[str, Any]],
config: Namespace,
config_filename: str,
overwrite: bool = False,
multifile: bool = False,
Expand Down Expand Up @@ -671,8 +669,7 @@ def _parser(self, subcommand: Optional[str]) -> LightningArgumentParser:
if subcommand is None:
return self.parser
# return the subcommand parser for the subcommand passed
action_subcommands = [a for a in self.parser._actions if isinstance(a, _ActionSubCommands)]
action_subcommand = action_subcommands[0]
action_subcommand = self.parser._subcommands_action
return action_subcommand._name_parser_map[subcommand]

def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None:
Expand Down Expand Up @@ -772,12 +769,16 @@ def _prepare_subcommand_kwargs(self, subcommand: str) -> Dict[str, Any]:
return fn_kwargs


def _global_add_class_path(class_type: Type, init_args: Dict[str, Any] = None) -> Dict[str, Any]:
def _global_add_class_path(
class_type: Type, init_args: Optional[Union[Namespace, Dict[str, Any]]] = None
) -> Dict[str, Any]:
if isinstance(init_args, Namespace):
init_args = init_args.as_dict()
return {"class_path": class_type.__module__ + "." + class_type.__name__, "init_args": init_args or {}}


def _add_class_path_generator(class_type: Type) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
def add_class_path(init_args: Dict[str, Any]) -> Dict[str, Any]:
def _add_class_path_generator(class_type: Type) -> Callable[[Namespace], Dict[str, Any]]:
def add_class_path(init_args: Namespace) -> Dict[str, Any]:
return _global_add_class_path(class_type, init_args)

return add_class_path
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/utilities/imports.py
Expand Up @@ -85,7 +85,7 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version:
_HOROVOD_AVAILABLE = _module_available("horovod.torch")
_HYDRA_AVAILABLE = _module_available("hydra")
_HYDRA_EXPERIMENTAL_AVAILABLE = _module_available("hydra.experimental")
_JSONARGPARSE_AVAILABLE = _module_available("jsonargparse")
_JSONARGPARSE_AVAILABLE = _module_available("jsonargparse") and _compare_version("jsonargparse", operator.ge, "4.0.0")
_KINETO_AVAILABLE = _TORCH_GREATER_EQUAL_1_8_1 and torch.profiler.kineto_available()
_NEPTUNE_AVAILABLE = _module_available("neptune")
_NEPTUNE_GREATER_EQUAL_0_9 = _NEPTUNE_AVAILABLE and _compare_version("neptune", operator.ge, "0.9.0")
Expand Down
2 changes: 1 addition & 1 deletion requirements/extra.txt
Expand Up @@ -5,6 +5,6 @@ horovod>=0.21.2 # no need to install with [pytorch] as pytorch is already insta
torchtext>=0.8.*
omegaconf>=2.0.5
hydra-core>=1.0.5
jsonargparse[signatures]>=3.19.3
jsonargparse[signatures]>=4.0.0
gcsfs>=2021.5.0
rich>=10.2.2
4 changes: 2 additions & 2 deletions tests/utilities/test_cli.py
Expand Up @@ -348,7 +348,7 @@ def test_lightning_cli_args(tmpdir):
loaded_config = yaml.safe_load(f.read())

loaded_config = loaded_config["fit"]
cli_config = cli.config["fit"]
cli_config = cli.config["fit"].as_dict()

assert cli_config["seed_everything"] == 1234
assert "model" not in loaded_config and "model" not in cli_config # no arguments to include
Expand Down Expand Up @@ -404,7 +404,7 @@ def test_lightning_cli_config_and_subclass_mode(tmpdir):
loaded_config = yaml.safe_load(f.read())

loaded_config = loaded_config["fit"]
cli_config = cli.config["fit"]
cli_config = cli.config["fit"].as_dict()

assert loaded_config["model"] == cli_config["model"]
assert loaded_config["data"] == cli_config["data"]
Expand Down

0 comments on commit 5d748e5

Please sign in to comment.