Skip to content

Commit 64b51e1

Browse files
carmoccarohitgr7
authored andcommitted
Surface Neptune installation problems to the user (#14715)
1 parent 24c9de8 commit 64b51e1

File tree

3 files changed

+11
-6
lines changed

3 files changed

+11
-6
lines changed

src/pytorch_lightning/loggers/neptune.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -231,11 +231,8 @@ def __init__(
231231
agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
232232
**neptune_run_kwargs: Any,
233233
):
234-
if neptune is None:
235-
raise ModuleNotFoundError(
236-
"You want to use the `Neptune` logger which is not installed yet, install it with"
237-
" `pip install neptune-client`."
238-
)
234+
if not _NEPTUNE_AVAILABLE:
235+
raise ModuleNotFoundError(str(_NEPTUNE_AVAILABLE))
239236
# verify if user passed proper init arguments
240237
self._verify_input_arguments(api_key, project, name, run, neptune_run_kwargs)
241238
super().__init__(agg_key_funcs=agg_key_funcs, agg_default_func=agg_default_func)

tests/tests_pytorch/loggers/test_all.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
mock.patch("pytorch_lightning.loggers.mlflow.mlflow"),
4444
mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient"),
4545
mock.patch("pytorch_lightning.loggers.neptune.neptune", new_callable=create_neptune_mock),
46+
mock.patch("pytorch_lightning.loggers.neptune._NEPTUNE_AVAILABLE", return_value=True),
4647
mock.patch("pytorch_lightning.loggers.wandb.wandb"),
4748
mock.patch("pytorch_lightning.loggers.wandb.Run", new=mock.Mock),
4849
)
@@ -290,7 +291,9 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
290291
logger.experiment.log_metric.assert_called_once_with(ANY, "tmp-test", 1.0, ANY, 0)
291292

292293
# Neptune
293-
with mock.patch("pytorch_lightning.loggers.neptune.neptune"):
294+
with mock.patch("pytorch_lightning.loggers.neptune.neptune"), mock.patch(
295+
"pytorch_lightning.loggers.neptune._NEPTUNE_AVAILABLE", return_value=True
296+
):
294297
logger = _instantiate_logger(NeptuneLogger, api_key="test", project="project", save_dir=tmpdir, prefix=prefix)
295298
assert logger.experiment.__getitem__.call_count == 2
296299
logger.log_metrics({"test": 1.0}, step=0)

tests/tests_pytorch/loggers/test_neptune.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import pickle
1616
import unittest
1717
from collections import namedtuple
18+
from unittest import mock
1819
from unittest.mock import call, MagicMock, patch
1920

2021
import pytest
@@ -78,6 +79,10 @@ def tmpdir_unittest_fixture(request, tmpdir):
7879

7980
@patch("pytorch_lightning.loggers.neptune.neptune", new_callable=create_neptune_mock)
8081
class TestNeptuneLogger(unittest.TestCase):
82+
def run(self, *args, **kwargs):
83+
with mock.patch("pytorch_lightning.loggers.neptune._NEPTUNE_AVAILABLE", return_value=True):
84+
super().run(*args, **kwargs)
85+
8186
def test_neptune_online(self, neptune):
8287
logger = NeptuneLogger(api_key="test", project="project")
8388
created_run_mock = logger.run

0 commit comments

Comments
 (0)