diff --git a/environment.yml b/environment.yml
index d6f885f00bc72..f26e93031770e 100644
--- a/environment.yml
+++ b/environment.yml
@@ -50,5 +50,5 @@ dependencies:
   - test-tube>=0.7.5
   - mlflow>=1.0.0
   - comet_ml>=3.1.12
-  - wandb>=0.8.21
+  - wandb>=0.10.22
   - neptune-client>=0.10.0
diff --git a/pyproject.toml b/pyproject.toml
index 50b67dec5a04b..eb9b025e36811 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -59,7 +59,6 @@ module = [
     "pytorch_lightning.demos.mnist_datamodule",
     "pytorch_lightning.loggers.comet",
     "pytorch_lightning.loggers.neptune",
-    "pytorch_lightning.loggers.wandb",
     "pytorch_lightning.profilers.advanced",
     "pytorch_lightning.profilers.base",
     "pytorch_lightning.profilers.pytorch",
diff --git a/requirements/pytorch/extra.txt b/requirements/pytorch/extra.txt
index 7d13068d814a8..440a80594d5d9 100644
--- a/requirements/pytorch/extra.txt
+++ b/requirements/pytorch/extra.txt
@@ -3,6 +3,6 @@ matplotlib>3.1, <3.5.3
 torchtext>=0.10.*, <=0.12.0
 omegaconf>=2.0.5, <2.3.0
 hydra-core>=1.0.5, <1.3.0
-jsonargparse[signatures]>=4.10.0, <=4.10.0
+jsonargparse[signatures]>=4.10.2, <=4.10.2
 gcsfs>=2021.5.0, <2022.6.0
 rich>=10.14.0, !=10.15.0.a, <13.0.0
diff --git a/requirements/pytorch/loggers.txt b/requirements/pytorch/loggers.txt
index 2abcb4b2df31f..a857ab5660d54 100644
--- a/requirements/pytorch/loggers.txt
+++ b/requirements/pytorch/loggers.txt
@@ -4,4 +4,4 @@ neptune-client>=0.10.0, <0.16.4
 comet-ml>=3.1.12, <3.31.6
 mlflow>=1.0.0, <1.27.0
 test_tube>=0.7.5, <=0.7.5
-wandb>=0.8.21, <0.12.20
+wandb>=0.10.22, <0.12.20
diff --git a/src/pytorch_lightning/loggers/wandb.py b/src/pytorch_lightning/loggers/wandb.py
index 88439cd9435db..53fbd2b1097f8 100644
--- a/src/pytorch_lightning/loggers/wandb.py
+++ b/src/pytorch_lightning/loggers/wandb.py
@@ -32,10 +32,11 @@

 try:
     import wandb
+    from wandb.sdk.lib import RunDisabled
     from wandb.wandb_run import Run
 except ModuleNotFoundError:
     # needed for test mocks, these tests shall be updated
-    wandb, Run = None, None
+    wandb, Run, RunDisabled = None, None, None  # type: ignore


 class WandbLogger(Logger):
@@ -251,18 +252,18 @@ def __init__(
         self,
         name: Optional[str] = None,
         save_dir: Optional[str] = None,
-        offline: Optional[bool] = False,
+        offline: bool = False,
         id: Optional[str] = None,
         anonymous: Optional[bool] = None,
         version: Optional[str] = None,
         project: Optional[str] = None,
         log_model: Union[str, bool] = False,
-        experiment=None,
-        prefix: Optional[str] = "",
+        experiment: Union[Run, RunDisabled, None] = None,
+        prefix: str = "",
         agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
         agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
-        **kwargs,
-    ):
+        **kwargs: Any,
+    ) -> None:
         if wandb is None:
             raise ModuleNotFoundError(
                 "You want to use `wandb` logger which is not installed yet,"
@@ -288,17 +289,16 @@ def __init__(
         self._log_model = log_model
         self._prefix = prefix
         self._experiment = experiment
-        self._logged_model_time = {}
-        self._checkpoint_callback = None
+        self._logged_model_time: Dict[str, float] = {}
+        self._checkpoint_callback: Optional["ReferenceType[Checkpoint]"] = None
         # set wandb init arguments
-        anonymous_lut = {True: "allow", False: None}
-        self._wandb_init = dict(
+        self._wandb_init: Dict[str, Any] = dict(
             name=name or project,
             project=project,
             id=version or id,
             dir=save_dir,
             resume="allow",
-            anonymous=anonymous_lut.get(anonymous, anonymous),
+            anonymous=("allow" if anonymous else None),
         )
         self._wandb_init.update(**kwargs)
         # extract parameters
@@ -310,7 +310,7 @@ def __init__(
             wandb.require("service")
             _ = self.experiment

-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, Any]:
         state = self.__dict__.copy()
         # args needed to reload correct experiment
         if self._experiment is not None:
@@ -322,7 +322,7 @@ def __getstate__(self):
         state["_experiment"] = None
         return state

-    @property
+    @property  # type: ignore[misc]
     @rank_zero_experiment
     def experiment(self) -> Run:
         r"""
@@ -357,13 +357,14 @@ def experiment(self) -> Run:
                 self._experiment = wandb.init(**self._wandb_init)

                 # define default x-axis
-                if getattr(self._experiment, "define_metric", None):
+                if isinstance(self._experiment, Run) and getattr(self._experiment, "define_metric", None):
                     self._experiment.define_metric("trainer/global_step")
                     self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True)

+        assert isinstance(self._experiment, Run)
         return self._experiment

-    def watch(self, model: nn.Module, log: str = "gradients", log_freq: int = 100, log_graph: bool = True):
+    def watch(self, model: nn.Module, log: str = "gradients", log_freq: int = 100, log_graph: bool = True) -> None:
         self.experiment.watch(model, log=log, log_freq=log_freq, log_graph=log_graph)

     @rank_zero_only
@@ -379,7 +380,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->

         metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
         if step is not None:
-            self.experiment.log({**metrics, "trainer/global_step": step})
+            self.experiment.log(dict(metrics, **{"trainer/global_step": step}))
         else:
             self.experiment.log(metrics)

@@ -417,7 +418,7 @@ def log_text(
         self.log_table(key, columns, data, dataframe, step)

     @rank_zero_only
-    def log_image(self, key: str, images: List[Any], step: Optional[int] = None, **kwargs: str) -> None:
+    def log_image(self, key: str, images: List[Any], step: Optional[int] = None, **kwargs: Any) -> None:
         """Log images (tensors, numpy arrays, PIL Images or file paths).

         Optional kwargs are lists passed to each image (ex: caption, masks, boxes).
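
For orientation, a minimal usage sketch of the `WandbLogger` API whose annotations the hunks above tighten. This is not part of the patch; the project name, prefix, and model are illustrative placeholders, and it assumes `wandb` is installed (offline mode avoids network access):

```python
import torch.nn as nn

from pytorch_lightning.loggers import WandbLogger

# `offline` is now a plain `bool` and `prefix` a plain `str`, per the typed signature.
# "demo-project" and the tiny model below are illustrative only.
logger = WandbLogger(project="demo-project", offline=True, prefix="train")

model = nn.Linear(4, 2)
# `watch` and `log_metrics` carry the signatures annotated in the diff.
logger.watch(model, log="gradients", log_freq=100, log_graph=True)
# With a step, the metrics dict is merged with "trainer/global_step"; the prefix
# is joined with "-", so this records {"train-loss": 0.25, "trainer/global_step": 0}.
logger.log_metrics({"loss": 0.25}, step=0)
```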
diff --git a/src/pytorch_lightning/utilities/cli.py b/src/pytorch_lightning/utilities/cli.py
index f9d3375a6c6d8..a66cd6c0899cd 100644
--- a/src/pytorch_lightning/utilities/cli.py
+++ b/src/pytorch_lightning/utilities/cli.py
@@ -31,7 +31,7 @@
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_deprecation, rank_zero_warn

-_JSONARGPARSE_SIGNATURES_AVAILABLE = _RequirementAvailable("jsonargparse[signatures]>=4.10.0")
+_JSONARGPARSE_SIGNATURES_AVAILABLE = _RequirementAvailable("jsonargparse[signatures]>=4.10.2")

 if _JSONARGPARSE_SIGNATURES_AVAILABLE:
     import docstring_parser
diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py
index 96d1016cc612b..d613296abccf5 100644
--- a/tests/tests_pytorch/loggers/test_all.py
+++ b/tests/tests_pytorch/loggers/test_all.py
@@ -45,6 +45,7 @@
     mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient"),
     mock.patch("pytorch_lightning.loggers.neptune.neptune", new_callable=create_neptune_mock),
     mock.patch("pytorch_lightning.loggers.wandb.wandb"),
+    mock.patch("pytorch_lightning.loggers.wandb.Run", new=mock.Mock),
 )
 ALL_LOGGER_CLASSES = (
     CometLogger,
@@ -363,7 +364,9 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
     logger.experiment.add_scalar.assert_called_once_with("tmp-test", 1.0, 0)

     # WandB
-    with mock.patch("pytorch_lightning.loggers.wandb.wandb") as wandb:
+    with mock.patch("pytorch_lightning.loggers.wandb.wandb") as wandb, mock.patch(
+        "pytorch_lightning.loggers.wandb.Run", new=mock.Mock
+    ):
         logger = _instantiate_logger(WandbLogger, save_dir=tmpdir, prefix=prefix)
         wandb.run = None
         wandb.init().step = 0
diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py
index f62ebff9e719a..48162e6d9d2e2 100644
--- a/tests/tests_pytorch/loggers/test_wandb.py
+++ b/tests/tests_pytorch/loggers/test_wandb.py
@@ -24,6 +24,7 @@
 from tests_pytorch.helpers.utils import no_warning_call


+@mock.patch("pytorch_lightning.loggers.wandb.Run", new=mock.Mock)
 @mock.patch("pytorch_lightning.loggers.wandb.wandb")
 def test_wandb_logger_init(wandb, monkeypatch):
     """Verify that basic functionality of wandb logger works.
@@ -111,20 +112,21 @@ class Experiment:
         def name(self):
             return "the_run_name"

-    wandb.run = None
-    wandb.init.return_value = Experiment()
-    logger = WandbLogger(id="the_id", offline=True)
+    with mock.patch("pytorch_lightning.loggers.wandb.Run", new=Experiment):
+        wandb.run = None
+        wandb.init.return_value = Experiment()
+        logger = WandbLogger(id="the_id", offline=True)

-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, logger=logger)
-    # Access the experiment to ensure it's created
-    assert trainer.logger.experiment, "missing experiment"
-    assert trainer.log_dir == logger.save_dir
-    pkl_bytes = pickle.dumps(trainer)
-    trainer2 = pickle.loads(pkl_bytes)
+        trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, logger=logger)
+        # Access the experiment to ensure it's created
+        assert trainer.logger.experiment, "missing experiment"
+        assert trainer.log_dir == logger.save_dir
+        pkl_bytes = pickle.dumps(trainer)
+        trainer2 = pickle.loads(pkl_bytes)

-    assert os.environ["WANDB_MODE"] == "dryrun"
-    assert trainer2.logger.__class__.__name__ == WandbLogger.__name__
-    assert trainer2.logger.experiment, "missing experiment"
+        assert os.environ["WANDB_MODE"] == "dryrun"
+        assert trainer2.logger.__class__.__name__ == WandbLogger.__name__
+        assert trainer2.logger.experiment, "missing experiment"

     wandb.init.assert_called()
     assert "id" in wandb.init.call_args[1]
@@ -133,6 +135,7 @@ def name(self):
     del os.environ["WANDB_MODE"]


+@mock.patch("pytorch_lightning.loggers.wandb.Run", new=mock.Mock)
 @mock.patch("pytorch_lightning.loggers.wandb.wandb")
 def test_wandb_logger_dirs_creation(wandb, monkeypatch, tmpdir):
     """Test that the logger creates the folders and files in the right place."""
@@ -169,6 +172,7 @@ def test_wandb_logger_dirs_creation(wandb, monkeypatch, tmpdir):
     assert trainer.log_dir == logger.save_dir


+@mock.patch("pytorch_lightning.loggers.wandb.Run", new=mock.Mock)
 @mock.patch("pytorch_lightning.loggers.wandb.wandb")
 def test_wandb_log_model(wandb, monkeypatch, tmpdir):
     """Test that the logger creates the folders and files in the right place."""
@@ -234,6 +238,7 @@
     )


+@mock.patch("pytorch_lightning.loggers.wandb.Run", new=mock.Mock)
 @mock.patch("pytorch_lightning.loggers.wandb.wandb")
 def test_wandb_log_media(wandb, tmpdir):
     """Test that the logger creates the folders and files in the right place."""
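
Why every test above now patches `Run`: the `experiment` property ends with `assert isinstance(self._experiment, Run)`, so when the `wandb` module itself is mocked, `Run` must resolve to a class that the mock experiment is an instance of, and `mock.Mock` fits. A condensed sketch of the pattern (illustrative, not part of the patch):

```python
from unittest import mock

from pytorch_lightning.loggers import WandbLogger

with mock.patch("pytorch_lightning.loggers.wandb.wandb") as wandb, mock.patch(
    "pytorch_lightning.loggers.wandb.Run", new=mock.Mock
):
    wandb.run = None                  # force the logger down the wandb.init() path
    experiment = mock.Mock()          # isinstance(experiment, Run) holds, since Run is mock.Mock
    wandb.init.return_value = experiment

    logger = WandbLogger(offline=True)
    assert logger.experiment is experiment
```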