From f746a2d264a0e2f72ac4c303c0f4fc75a1d4d5d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 15 Sep 2022 12:43:00 +0200 Subject: [PATCH 1/2] Surface Neptune installation problems to the user --- src/pytorch_lightning/loggers/neptune.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/pytorch_lightning/loggers/neptune.py b/src/pytorch_lightning/loggers/neptune.py index 0c1ab35cf58ee..26e225b9a74a1 100644 --- a/src/pytorch_lightning/loggers/neptune.py +++ b/src/pytorch_lightning/loggers/neptune.py @@ -37,10 +37,7 @@ from pytorch_lightning.utilities.rank_zero import rank_zero_only _NEPTUNE_AVAILABLE = RequirementCache("neptune-client") -_NEPTUNE_GREATER_EQUAL_0_9 = RequirementCache("neptune-client>=0.9.0") - - -if _NEPTUNE_AVAILABLE and _NEPTUNE_GREATER_EQUAL_0_9: +if _NEPTUNE_AVAILABLE: try: from neptune import new as neptune from neptune.new.exceptions import NeptuneLegacyProjectException, NeptuneOfflineModeFetchException @@ -272,13 +269,10 @@ def __init__( agg_default_func: Optional[Callable[[Sequence[float]], float]] = None, **neptune_run_kwargs: Any, ): + if not _NEPTUNE_AVAILABLE: + raise ModuleNotFoundError(str(_NEPTUNE_AVAILABLE)) # verify if user passed proper init arguments self._verify_input_arguments(api_key, project, name, run, neptune_run_kwargs) - if neptune is None: - raise ModuleNotFoundError( - "You want to use the `Neptune` logger which is not installed yet, install it with" - " `pip install neptune-client`." - ) super().__init__(agg_key_funcs=agg_key_funcs, agg_default_func=agg_default_func) self._log_model_checkpoints = log_model_checkpoints From 1fe642a3199140872d4f66ee2b011385ba69078d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 15 Sep 2022 16:02:26 +0200 Subject: [PATCH 2/2] Fix some mocks, blocked by deprecated API tests --- tests/tests_pytorch/loggers/test_all.py | 5 ++++- tests/tests_pytorch/loggers/test_neptune.py | 8 +++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py index 279a1aeab7e69..8d79442e68b73 100644 --- a/tests/tests_pytorch/loggers/test_all.py +++ b/tests/tests_pytorch/loggers/test_all.py @@ -43,6 +43,7 @@ mock.patch("pytorch_lightning.loggers.mlflow.mlflow"), mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient"), mock.patch("pytorch_lightning.loggers.neptune.neptune", new_callable=create_neptune_mock), + mock.patch("pytorch_lightning.loggers.neptune._NEPTUNE_AVAILABLE", return_value=True), mock.patch("pytorch_lightning.loggers.wandb.wandb"), mock.patch("pytorch_lightning.loggers.wandb.Run", new=mock.Mock), ) @@ -290,7 +291,9 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch): logger.experiment.log_metric.assert_called_once_with(ANY, "tmp-test", 1.0, ANY, 0) # Neptune - with mock.patch("pytorch_lightning.loggers.neptune.neptune"): + with mock.patch("pytorch_lightning.loggers.neptune.neptune"), mock.patch( + "pytorch_lightning.loggers.neptune._NEPTUNE_AVAILABLE", return_value=True + ): logger = _instantiate_logger(NeptuneLogger, api_key="test", project="project", save_dir=tmpdir, prefix=prefix) assert logger.experiment.__getitem__.call_count == 2 logger.log_metrics({"test": 1.0}, step=0) diff --git a/tests/tests_pytorch/loggers/test_neptune.py b/tests/tests_pytorch/loggers/test_neptune.py index de3017a33a472..856d82babec1a 100644 --- a/tests/tests_pytorch/loggers/test_neptune.py +++ b/tests/tests_pytorch/loggers/test_neptune.py @@ -15,6 +15,7 @@ import pickle import unittest from collections import namedtuple +from unittest import mock from unittest.mock import call, MagicMock, patch import pytest @@ -78,6 +79,10 @@ def tmpdir_unittest_fixture(request, tmpdir): @patch("pytorch_lightning.loggers.neptune.neptune", new_callable=create_neptune_mock) class TestNeptuneLogger(unittest.TestCase): + def run(self, *args, **kwargs): + with mock.patch("pytorch_lightning.loggers.neptune._NEPTUNE_AVAILABLE", return_value=True): + super().run(*args, **kwargs) + def test_neptune_online(self, neptune): logger = NeptuneLogger(api_key="test", project="project") created_run_mock = logger.run @@ -354,10 +359,11 @@ def test_legacy_kwargs(self): for legacy_kwarg in legacy_neptune_kwargs: self._assert_legacy_usage(NeptuneLogger, **{legacy_kwarg: None}) + @patch("pytorch_lightning.loggers.neptune._NEPTUNE_AVAILABLE", return_value=True) @patch("pytorch_lightning.loggers.neptune.warnings") @patch("pytorch_lightning.loggers.neptune.NeptuneFile") @patch("pytorch_lightning.loggers.neptune.neptune") - def test_legacy_functions(self, neptune, neptune_file_mock, warnings_mock): + def test_legacy_functions(self, _, neptune, neptune_file_mock, warnings_mock): logger = NeptuneLogger(api_key="test", project="project") # test deprecated functions which will be shut down in pytorch-lightning 1.7.0