Skip to content

Commit

Permalink
Add LightningCLI(save_config_overwrite=False|True) (#8059)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed Jun 21, 2021
1 parent d1efae2 commit d9bf975
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 6 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `save_config_filename` init argument to `LightningCLI` to ease resolving name conflicts ([#7741](https://github.com/PyTorchLightning/pytorch-lightning/pull/7741))


- Added `save_config_overwrite` init argument to `LightningCLI` to ease overwriting existing config files ([#8059](https://github.com/PyTorchLightning/pytorch-lightning/pull/8059))


- Added reset dataloader hooks to Training Plugins and Accelerators ([#7861](https://github.com/PyTorchLightning/pytorch-lightning/pull/7861))


Expand Down
22 changes: 16 additions & 6 deletions pytorch_lightning/utilities/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,20 +89,24 @@ def __init__(
parser: LightningArgumentParser,
config: Union[Namespace, Dict[str, Any]],
config_filename: str,
overwrite: bool = False,
) -> None:
self.parser = parser
self.config = config
self.config_filename = config_filename
self.overwrite = overwrite

def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
log_dir = trainer.log_dir or trainer.default_root_dir
config_path = os.path.join(log_dir, self.config_filename)
if os.path.isfile(config_path):
if not self.overwrite and os.path.isfile(config_path):
raise RuntimeError(
f'{self.__class__.__name__} expected {config_path} to not exist. '
'Aborting to avoid overwriting results of a previous run.'
f'{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting'
' results of a previous run. You can delete the previous config file,'
' set `LightningCLI(save_config_callback=None)` to disable config saving,'
' or set `LightningCLI(save_config_overwrite=True)` to overwrite the config file.'
)
self.parser.save(self.config, config_path, skip_none=False)
self.parser.save(self.config, config_path, skip_none=False, overwrite=self.overwrite)


class LightningCLI:
Expand All @@ -112,8 +116,9 @@ def __init__(
self,
model_class: Type[LightningModule],
datamodule_class: Type[LightningDataModule] = None,
save_config_callback: Type[SaveConfigCallback] = SaveConfigCallback,
save_config_callback: Optional[Type[SaveConfigCallback]] = SaveConfigCallback,
save_config_filename: str = 'config.yaml',
save_config_overwrite: bool = False,
trainer_class: Type[Trainer] = Trainer,
trainer_defaults: Dict[str, Any] = None,
seed_everything_default: int = None,
Expand Down Expand Up @@ -150,6 +155,8 @@ def __init__(
model_class: :class:`~pytorch_lightning.core.lightning.LightningModule` class to train on.
datamodule_class: An optional :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class.
save_config_callback: A callback class to save the training config.
save_config_filename: Filename for the config file.
save_config_overwrite: Whether to overwrite an existing config file.
trainer_class: An optional subclass of the :class:`~pytorch_lightning.trainer.trainer.Trainer` class.
trainer_defaults: Set to override Trainer defaults or add persistent callbacks.
seed_everything_default: Default value for the :func:`~pytorch_lightning.utilities.seed.seed_everything`
Expand All @@ -173,6 +180,7 @@ def __init__(
self.datamodule_class = datamodule_class
self.save_config_callback = save_config_callback
self.save_config_filename = save_config_filename
self.save_config_overwrite = save_config_overwrite
self.trainer_class = trainer_class
self.trainer_defaults = {} if trainer_defaults is None else trainer_defaults
self.seed_everything_default = seed_everything_default
Expand Down Expand Up @@ -246,7 +254,9 @@ def instantiate_trainer(self) -> None:
else:
self.config_init['trainer']['callbacks'].append(self.trainer_defaults['callbacks'])
if self.save_config_callback and not self.config_init['trainer']['fast_dev_run']:
config_callback = self.save_config_callback(self.parser, self.config, self.save_config_filename)
config_callback = self.save_config_callback(
self.parser, self.config, self.save_config_filename, overwrite=self.save_config_overwrite
)
self.config_init['trainer']['callbacks'].append(config_callback)
self.trainer = self.trainer_class(**self.config_init['trainer'])

Expand Down
11 changes: 11 additions & 0 deletions tests/utilities/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,3 +603,14 @@ def add_arguments_to_parser(self, parser):

assert cli.model.batch_size == 8
assert cli.model.num_classes == 5


def test_cli_config_overwrite(tmpdir):
trainer_defaults = {'default_root_dir': str(tmpdir), 'logger': False, 'max_steps': 1, 'max_epochs': 1}

with mock.patch('sys.argv', ['any.py']):
LightningCLI(BoringModel, trainer_defaults=trainer_defaults)
with mock.patch('sys.argv', ['any.py']), pytest.raises(RuntimeError, match='Aborting to avoid overwriting'):
LightningCLI(BoringModel, trainer_defaults=trainer_defaults)
with mock.patch('sys.argv', ['any.py']):
LightningCLI(BoringModel, save_config_overwrite=True, trainer_defaults=trainer_defaults)

0 comments on commit d9bf975

Please sign in to comment.