Skip to content

Commit

Permalink
bump: transmission to use neptune only, drop neptune-client (#19265)
Browse files Browse the repository at this point in the history
* bump: min version `neptune>=1.0.0`
* Apply suggestions from code review

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
  • Loading branch information
3 people committed Feb 16, 2024
1 parent 5998dd1 commit 6497e36
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 34 deletions.
2 changes: 1 addition & 1 deletion requirements/pytorch/loggers.info
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# all supported loggers. this list is here as a reference, but they are not installed in CI

neptune
neptune >=1.0.0
comet-ml >=3.31.0
mlflow >=1.0.0
wandb >=0.12.10
Expand Down
49 changes: 16 additions & 33 deletions src/lightning/pytorch/loggers/neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,12 @@

log = logging.getLogger(__name__)

# neptune is available with two names on PyPI : `neptune` and `neptune-client`
# Neptune is available with two names on PyPI : `neptune` and `neptune-client`
# `neptune` was introduced as a name transition of neptune-client and the long-term target is to get
# rid of Neptune-client package completely someday. It was introduced as a part of breaking-changes with a release
# of neptune-client==1.0. neptune-client>=1.0 is just an alias of neptune package and have some breaking-changes
# in compare to neptune-client<1.0.0.
_NEPTUNE_AVAILABLE = RequirementCache("neptune>=1.0")
_NEPTUNE_CLIENT_AVAILABLE = RequirementCache("neptune-client")
_INTEGRATION_VERSION_KEY = "source_code/integrations/pytorch-lightning"


Expand Down Expand Up @@ -224,8 +227,9 @@ def __init__(
prefix: str = "training",
**neptune_run_kwargs: Any,
):
if not _NEPTUNE_AVAILABLE and not _NEPTUNE_CLIENT_AVAILABLE:
if not _NEPTUNE_AVAILABLE:
raise ModuleNotFoundError(str(_NEPTUNE_AVAILABLE))

# verify if user passed proper init arguments
self._verify_input_arguments(api_key, project, name, run, neptune_run_kwargs)
super().__init__()
Expand Down Expand Up @@ -254,10 +258,7 @@ def __init__(
root_obj[_INTEGRATION_VERSION_KEY] = pl.__version__

def _retrieve_run_data(self) -> None:
if _NEPTUNE_AVAILABLE:
from neptune.handler import Handler
else:
from neptune.new.handler import Handler
from neptune.handler import Handler

assert self._run_instance is not None
root_obj = self._run_instance
Expand Down Expand Up @@ -310,12 +311,9 @@ def _verify_input_arguments(
run: Optional[Union["Run", "Handler"]],
neptune_run_kwargs: dict,
) -> None:
if _NEPTUNE_AVAILABLE:
from neptune import Run
from neptune.handler import Handler
else:
from neptune.new import Run
from neptune.new.handler import Handler
from neptune import Run
from neptune.handler import Handler

# check if user passed the client `Run`/`Handler` object
if run is not None and not isinstance(run, (Run, Handler)):
raise ValueError("Run parameter expected to be of type `neptune.Run`, or `neptune.handler.Handler`.")
Expand All @@ -335,10 +333,7 @@ def __getstate__(self) -> Dict[str, Any]:
return state

def __setstate__(self, state: Dict[str, Any]) -> None:
if _NEPTUNE_AVAILABLE:
import neptune
else:
import neptune.new as neptune
import neptune

self.__dict__ = state
self._run_instance = neptune.init_run(**self._neptune_init_args)
Expand Down Expand Up @@ -376,10 +371,7 @@ def training_step(self, batch, batch_idx):
@property
@rank_zero_experiment
def run(self) -> "Run":
if _NEPTUNE_AVAILABLE:
import neptune
else:
import neptune.new as neptune
import neptune

if not self._run_instance:
self._run_instance = neptune.init_run(**self._neptune_init_args)
Expand Down Expand Up @@ -426,10 +418,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
neptune_logger.log_hyperparams(PARAMS)
"""
if _NEPTUNE_AVAILABLE:
from neptune.utils import stringify_unsupported
else:
from neptune.new.utils import stringify_unsupported
from neptune.utils import stringify_unsupported

params = _convert_params(params)
params = _sanitize_callable_params(params)
Expand Down Expand Up @@ -485,10 +474,7 @@ def save_dir(self) -> Optional[str]:

@rank_zero_only
def log_model_summary(self, model: "pl.LightningModule", max_depth: int = -1) -> None:
if _NEPTUNE_AVAILABLE:
from neptune.types import File
else:
from neptune.new.types import File
from neptune.types import File

model_str = str(ModelSummary(model=model, max_depth=max_depth))
self.run[self._construct_path_with_prefix("model/summary")] = File.from_content(
Expand All @@ -507,10 +493,7 @@ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
if not self._log_model_checkpoints:
return

if _NEPTUNE_AVAILABLE:
from neptune.types import File
else:
from neptune.new.types import File
from neptune.types import File

file_names = set()
checkpoints_namespace = self._construct_path_with_prefix("model/checkpoints")
Expand Down

0 comments on commit 6497e36

Please sign in to comment.