diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index ea0e8f76461da..ba8bf05f49bb3 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -136,6 +136,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue to keep downscaling the batch size in case there hasn't been even a single successful optimal batch size with `mode="power"` ([#14372](https://github.com/Lightning-AI/lightning/pull/14372)) +- Fixed `WandbLogger` `save_dir` is not set after creation ([#14326](https://github.com/Lightning-AI/lightning/pull/14326)) + + ## [1.7.4] - 2022-08-31 @@ -162,9 +165,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue to avoid the impact of sanity check on `reload_dataloaders_every_n_epochs` for validation ([#13964](https://github.com/Lightning-AI/lightning/pull/13964)) -- Fixed `Trainer.estimated_stepping_batches` when maximum number of epochs is not set ([#14317](https://github.com/Lightning-AI/lightning/pull/14317)) - - ## [1.7.2] - 2022-08-17 ### Added diff --git a/src/pytorch_lightning/loggers/wandb.py b/src/pytorch_lightning/loggers/wandb.py index baf4bc9092774..3198e46b1a586 100644 --- a/src/pytorch_lightning/loggers/wandb.py +++ b/src/pytorch_lightning/loggers/wandb.py @@ -223,7 +223,7 @@ def __init__(self, *args, **kwarg): Args: name: Display name for the run. - save_dir: Path where data is saved (wandb dir by default). + save_dir: Path where data is saved. offline: Run offline (data can be streamed later to wandb servers). id: Sets the version, mainly used to resume a previous run. version: Same as id. @@ -255,7 +255,7 @@ def __init__(self, *args, **kwarg): def __init__( self, name: Optional[str] = None, - save_dir: Optional[str] = None, + save_dir: str = ".", offline: bool = False, id: Optional[str] = None, anonymous: Optional[bool] = None, @@ -300,7 +300,7 @@ def __init__( name=name, project=project, id=version or id, - dir=save_dir, + dir=save_dir or kwargs.pop("dir"), resume="allow", anonymous=("allow" if anonymous else None), ) diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py index 648e1a8f38ec8..b408046c9e5d2 100644 --- a/tests/tests_pytorch/loggers/test_wandb.py +++ b/tests/tests_pytorch/loggers/test_wandb.py @@ -58,9 +58,15 @@ def test_wandb_logger_init(wandb, monkeypatch): wandb.init.reset_mock() WandbLogger(project="test_project").experiment wandb.init.assert_called_once_with( - name=None, dir=None, id=None, project="test_project", resume="allow", anonymous=None + name=None, dir=".", id=None, project="test_project", resume="allow", anonymous=None ) + # test wandb.init set save_dir correctly after created + wandb.run = None + wandb.init.reset_mock() + logger = WandbLogger(project="test_project") + assert logger.save_dir is not None + # test wandb.init and setting logger experiment externally wandb.run = None run = wandb.init()