Skip to content

Commit

Permalink
Skip tests that cause CLI argparse errors on Python 3.11.9 (#19756)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Apr 11, 2024
1 parent 76b691d commit 316cc71
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 5 deletions.
3 changes: 2 additions & 1 deletion tests/tests_pytorch/loggers/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def mlflow_mock(monkeypatch):
mlflow.tracking = mlflow_tracking
mlflow.entities = mlflow_entities

(monkeypatch.setattr("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", True),)
monkeypatch.setattr("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", True)
monkeypatch.setattr("lightning.pytorch.loggers.mlflow._MLFLOW_SYNCHRONOUS_AVAILABLE", True)
return mlflow


Expand Down
3 changes: 0 additions & 3 deletions tests/tests_pytorch/loggers/test_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loggers.mlflow import (
_MLFLOW_AVAILABLE,
_MLFLOW_SYNCHRONOUS_AVAILABLE,
MLFlowLogger,
_get_resolve_tags,
)
Expand Down Expand Up @@ -269,8 +268,6 @@ def test_mlflow_logger_experiment_calls(mlflow_mock, tmp_path):
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_logger_experiment_calls_with_synchronous(mlflow_mock, tmp_path, synchronous):
"""Test that the logger calls methods on the mlflow experiment with the specified synchronous flag."""
if not _MLFLOW_SYNCHRONOUS_AVAILABLE:
pytest.skip("this test requires mlflow>=2.8.0")

time = mlflow_mock.entities.time
metric = mlflow_mock.entities.Metric
Expand Down
3 changes: 3 additions & 0 deletions tests/tests_pytorch/loggers/test_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning_utilities.test.warning import no_warning_call

from tests_pytorch.test_cli import _xfail_python_ge_3_11_9


def test_wandb_project_name(wandb_mock):
with mock.patch.dict(os.environ, {}):
Expand Down Expand Up @@ -548,6 +550,7 @@ def test_wandb_logger_download_artifact(wandb_mock, tmp_path):
wandb_mock.Api().artifact.assert_called_once_with("test_artifact", type="model")


@_xfail_python_ge_3_11_9
@pytest.mark.parametrize(("log_model", "expected"), [("True", True), ("False", False), ("all", "all")])
def test_wandb_logger_cli_integration(log_model, expected, wandb_mock, monkeypatch, tmp_path):
"""Test that the WandbLogger can be used with the LightningCLI."""
Expand Down
27 changes: 26 additions & 1 deletion tests/tests_pytorch/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
from lightning_utilities import compare_version
from lightning_utilities.test.warning import no_warning_call
from packaging.version import Version
from tensorboard.backend.event_processing import event_accumulator
from tensorboard.plugins.hparams.plugin_data_pb2 import HParamsPluginData
from torch.optim import SGD
Expand All @@ -64,6 +65,14 @@ def lazy_instance(*args, **kwargs):
return None


_xfail_python_ge_3_11_9 = pytest.mark.xfail(
# https://github.com/omni-us/jsonargparse/issues/484
Version(f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}") >= Version("3.11.9"),
strict=False,
reason="jsonargparse + Python 3.11.9 compatibility issue",
)


@contextmanager
def mock_subclasses(baseclass, *subclasses):
"""Mocks baseclass so that it only has the given child subclasses."""
Expand Down Expand Up @@ -347,6 +356,7 @@ def test_save_to_log_dir_false_error():
)


@_xfail_python_ge_3_11_9
def test_lightning_cli_logger_save_config(cleandir):
class LoggerSaveConfigCallback(SaveConfigCallback):
def __init__(self, *args, **kwargs) -> None:
Expand Down Expand Up @@ -736,6 +746,7 @@ def add_arguments_to_parser(self, parser):
assert cli.trainer.lr_scheduler_configs[0].scheduler.step_size == 50


@_xfail_python_ge_3_11_9
@pytest.mark.parametrize("use_generic_base_class", [False, True])
def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(use_generic_base_class):
class MyLightningCLI(LightningCLI):
Expand Down Expand Up @@ -782,7 +793,7 @@ def __init__(self, optim1: dict, optim2: dict, scheduler: dict):
assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR)


@pytest.mark.skipif(compare_version("jsonargparse", operator.lt, "4.21.3"), reason="vulnerability with failing imports")
@_xfail_python_ge_3_11_9
def test_lightning_cli_optimizers_and_lr_scheduler_with_callable_type():
class TestModel(BoringModel):
def __init__(
Expand Down Expand Up @@ -1031,6 +1042,7 @@ def __init__(self, foo, bar=5):
self.bar = bar


@_xfail_python_ge_3_11_9
def test_lightning_cli_model_short_arguments():
with mock.patch("sys.argv", ["any.py", "fit", "--model=BoringModel"]), mock.patch(
"lightning.pytorch.Trainer._fit_impl"
Expand All @@ -1055,6 +1067,7 @@ def __init__(self, foo, bar=5):
self.bar = bar


@_xfail_python_ge_3_11_9
def test_lightning_cli_datamodule_short_arguments():
# with set model
with mock.patch("sys.argv", ["any.py", "fit", "--data=BoringDataModule"]), mock.patch(
Expand Down Expand Up @@ -1100,6 +1113,7 @@ def test_lightning_cli_datamodule_short_arguments():
assert cli.parser.groups["data"].group_class is BoringDataModule


@_xfail_python_ge_3_11_9
@pytest.mark.parametrize("use_class_path_callbacks", [False, True])
def test_callbacks_append(use_class_path_callbacks):
"""This test validates registries are used when simplified command line are being used."""
Expand Down Expand Up @@ -1143,6 +1157,7 @@ def test_callbacks_append(use_class_path_callbacks):
assert all(t in callback_types for t in expected)


@_xfail_python_ge_3_11_9
def test_optimizers_and_lr_schedulers_reload(cleandir):
base = ["any.py", "--trainer.max_epochs=1"]
input = base + [
Expand Down Expand Up @@ -1174,6 +1189,7 @@ def test_optimizers_and_lr_schedulers_reload(cleandir):
LightningCLI(BoringModel, run=False)


@_xfail_python_ge_3_11_9
def test_optimizers_and_lr_schedulers_add_arguments_to_parser_implemented_reload(cleandir):
class TestLightningCLI(LightningCLI):
def __init__(self, *args):
Expand Down Expand Up @@ -1427,6 +1443,7 @@ def test_cli_help_message():
assert "Implements Adam" in shorthand_help.getvalue()


@_xfail_python_ge_3_11_9
def test_cli_reducelronplateau():
with mock.patch(
"sys.argv", ["any.py", "--optimizer=Adam", "--lr_scheduler=ReduceLROnPlateau", "--lr_scheduler.monitor=foo"]
Expand All @@ -1437,6 +1454,7 @@ def test_cli_reducelronplateau():
assert config["lr_scheduler"]["scheduler"].monitor == "foo"


@_xfail_python_ge_3_11_9
def test_cli_configureoptimizers_can_be_overridden():
class MyCLI(LightningCLI):
def __init__(self):
Expand Down Expand Up @@ -1481,6 +1499,7 @@ def __init__(self, activation: torch.nn.Module = lazy_instance(torch.nn.LeakyReL
assert cli.model.activation is not model.activation


@_xfail_python_ge_3_11_9
def test_ddpstrategy_instantiation_and_find_unused_parameters(mps_count_0):
strategy_default = lazy_instance(DDPStrategy, find_unused_parameters=True)
with mock.patch("sys.argv", ["any.py", "--trainer.strategy.process_group_backend=group"]):
Expand All @@ -1496,6 +1515,7 @@ def test_ddpstrategy_instantiation_and_find_unused_parameters(mps_count_0):
assert strategy_default is not cli.config_init.trainer.strategy


@_xfail_python_ge_3_11_9
def test_cli_logger_shorthand():
with mock.patch("sys.argv", ["any.py"]):
cli = LightningCLI(TestModel, run=False, trainer_defaults={"logger": False})
Expand Down Expand Up @@ -1526,6 +1546,7 @@ def _test_logger_init_args(logger_name, init, unresolved=None):
assert data["dict_kwargs"] == unresolved


@_xfail_python_ge_3_11_9
def test_comet_logger_init_args():
_test_logger_init_args(
"CometLogger",
Expand All @@ -1541,6 +1562,7 @@ def test_comet_logger_init_args():
strict=False,
reason="TypeError on Windows when parsing",
)
@_xfail_python_ge_3_11_9
def test_neptune_logger_init_args():
_test_logger_init_args(
"NeptuneLogger",
Expand All @@ -1549,6 +1571,7 @@ def test_neptune_logger_init_args():
)


@_xfail_python_ge_3_11_9
def test_tensorboard_logger_init_args():
_test_logger_init_args(
"TensorBoardLogger",
Expand All @@ -1560,6 +1583,7 @@ def test_tensorboard_logger_init_args():
)


@_xfail_python_ge_3_11_9
def test_wandb_logger_init_args():
_test_logger_init_args(
"WandbLogger",
Expand Down Expand Up @@ -1644,6 +1668,7 @@ def __init__(self, a_func: Callable = torch.nn.Softmax):
assert "a_func: torch.nn.Softmax" in out.getvalue()


@_xfail_python_ge_3_11_9
def test_pytorch_profiler_init_args():
from lightning.pytorch.profilers import Profiler, PyTorchProfiler

Expand Down

0 comments on commit 316cc71

Please sign in to comment.