Skip to content

Commit

Permalink
LinghtningCLI now will not allow setting a class instance as a default.
Browse files Browse the repository at this point in the history
  • Loading branch information
mauvilsa committed Oct 19, 2023
1 parent c68ff64 commit 6caf334
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 5 deletions.
2 changes: 1 addition & 1 deletion requirements/pytorch/extra.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
matplotlib>3.1, <3.9.0
omegaconf >=2.0.5, <2.4.0
hydra-core >=1.0.5, <1.4.0
jsonargparse[signatures] >=4.18.0, <4.26.0
jsonargparse[signatures] >=4.26.0, <4.27.0
rich >=12.3.0, <13.6.0
tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute
2 changes: 1 addition & 1 deletion src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

-
- `LightningCLI` no longer allows setting a normal class instance as default. A `lazy_instance` can be used instead ([#?????](https://github.com/Lightning-AI/lightning/pull/?????))


### Deprecated
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/pytorch/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_warn

_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.18.0")
_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.26.0")

if _JSONARGPARSE_SIGNATURES_AVAILABLE:
import docstring_parser
Expand Down
9 changes: 7 additions & 2 deletions tests/tests_pytorch/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@
else:
from argparse import Namespace

def lazy_instance(*args, **kwargs):
return None


@contextmanager
def mock_subclasses(baseclass, *subclasses):
Expand Down Expand Up @@ -176,7 +179,9 @@ def on_fit_start(self):
self.trainer.ran_asserts = True

with mock.patch("sys.argv", ["any.py", "fit", f"--trainer.callbacks={json.dumps(callbacks)}"]):
cli = LightningCLI(TestModel, trainer_defaults={"fast_dev_run": True, "logger": CSVLogger(".")})
cli = LightningCLI(
TestModel, trainer_defaults={"fast_dev_run": True, "logger": lazy_instance(CSVLogger, save_dir=".")}
)

assert cli.trainer.ran_asserts

Expand Down Expand Up @@ -592,7 +597,7 @@ def on_fit_start(self):

# mps not yet supported by distributed
@RunIf(skip_windows=True, mps=False)
@pytest.mark.parametrize("logger", [False, TensorBoardLogger(".")])
@pytest.mark.parametrize("logger", [False, lazy_instance(TensorBoardLogger, save_dir=".")])
@pytest.mark.parametrize("strategy", ["ddp_spawn", "ddp"])
def test_cli_distributed_save_config_callback(cleandir, logger, strategy):
from torch.multiprocessing import ProcessRaisedException
Expand Down

0 comments on commit 6caf334

Please sign in to comment.