Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LightningCLI(save_config_overwrite=False|True) #8059

Merged
merged 4 commits on Jun 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -71,6 +71,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `save_config_filename` init argument to `LightningCLI` to ease resolving name conflicts ([#7741](https://github.com/PyTorchLightning/pytorch-lightning/pull/7741))


- Added `save_config_overwrite` init argument to `LightningCLI` to ease overwriting existing config files ([#8059](https://github.com/PyTorchLightning/pytorch-lightning/pull/8059))


- Added reset dataloader hooks to Training Plugins and Accelerators ([#7861](https://github.com/PyTorchLightning/pytorch-lightning/pull/7861))


Expand Down
22 changes: 16 additions & 6 deletions pytorch_lightning/utilities/cli.py
Expand Up @@ -89,20 +89,24 @@ def __init__(
parser: LightningArgumentParser,
config: Union[Namespace, Dict[str, Any]],
config_filename: str,
overwrite: bool = False,
) -> None:
self.parser = parser
self.config = config
self.config_filename = config_filename
self.overwrite = overwrite
carmocca marked this conversation as resolved.
Show resolved Hide resolved

def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
    """Save the parsed CLI config to the trainer's log directory at the start of training.

    Raises:
        RuntimeError: If a config file already exists at the target path and
            ``overwrite`` is ``False``, to avoid clobbering a previous run's config.
    """
    # NOTE(review): the scraped diff fused the pre- and post-merge lines of this
    # method (duplicate `if`, concatenated old+new error messages, two `save` calls);
    # this is the reconstructed post-merge version.
    log_dir = trainer.log_dir or trainer.default_root_dir
    config_path = os.path.join(log_dir, self.config_filename)
    if not self.overwrite and os.path.isfile(config_path):
        raise RuntimeError(
            f'{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting'
            ' results of a previous run. You can delete the previous config file,'
            ' set `LightningCLI(save_config_callback=None)` to disable config saving,'
            ' or set `LightningCLI(save_config_overwrite=True)` to overwrite the config file.'
        )
    self.parser.save(self.config, config_path, skip_none=False, overwrite=self.overwrite)


class LightningCLI:
Expand All @@ -112,8 +116,9 @@ def __init__(
self,
model_class: Type[LightningModule],
datamodule_class: Type[LightningDataModule] = None,
save_config_callback: Type[SaveConfigCallback] = SaveConfigCallback,
save_config_callback: Optional[Type[SaveConfigCallback]] = SaveConfigCallback,
save_config_filename: str = 'config.yaml',
save_config_overwrite: bool = False,
carmocca marked this conversation as resolved.
Show resolved Hide resolved
trainer_class: Type[Trainer] = Trainer,
trainer_defaults: Dict[str, Any] = None,
seed_everything_default: int = None,
Expand Down Expand Up @@ -150,6 +155,8 @@ def __init__(
model_class: :class:`~pytorch_lightning.core.lightning.LightningModule` class to train on.
datamodule_class: An optional :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class.
save_config_callback: A callback class to save the training config.
save_config_filename: Filename for the config file.
save_config_overwrite: Whether to overwrite an existing config file.
trainer_class: An optional subclass of the :class:`~pytorch_lightning.trainer.trainer.Trainer` class.
trainer_defaults: Set to override Trainer defaults or add persistent callbacks.
seed_everything_default: Default value for the :func:`~pytorch_lightning.utilities.seed.seed_everything`
Expand All @@ -173,6 +180,7 @@ def __init__(
self.datamodule_class = datamodule_class
self.save_config_callback = save_config_callback
self.save_config_filename = save_config_filename
self.save_config_overwrite = save_config_overwrite
self.trainer_class = trainer_class
self.trainer_defaults = {} if trainer_defaults is None else trainer_defaults
self.seed_everything_default = seed_everything_default
Expand Down Expand Up @@ -246,7 +254,9 @@ def instantiate_trainer(self) -> None:
else:
self.config_init['trainer']['callbacks'].append(self.trainer_defaults['callbacks'])
if self.save_config_callback and not self.config_init['trainer']['fast_dev_run']:
config_callback = self.save_config_callback(self.parser, self.config, self.save_config_filename)
config_callback = self.save_config_callback(
self.parser, self.config, self.save_config_filename, overwrite=self.save_config_overwrite
)
self.config_init['trainer']['callbacks'].append(config_callback)
self.trainer = self.trainer_class(**self.config_init['trainer'])

Expand Down
11 changes: 11 additions & 0 deletions tests/utilities/test_cli.py
Expand Up @@ -603,3 +603,14 @@ def add_arguments_to_parser(self, parser):

assert cli.model.batch_size == 8
assert cli.model.num_classes == 5


def test_cli_config_overwrite(tmpdir):
    """Config saving refuses to overwrite an existing file unless ``save_config_overwrite=True``."""
    defaults = {'default_root_dir': str(tmpdir), 'logger': False, 'max_steps': 1, 'max_epochs': 1}

    # First run: writes the config file into tmpdir.
    with mock.patch('sys.argv', ['any.py']):
        LightningCLI(BoringModel, trainer_defaults=defaults)
    # Second run: must abort rather than clobber the previous run's config.
    with mock.patch('sys.argv', ['any.py']), pytest.raises(RuntimeError, match='Aborting to avoid overwriting'):
        LightningCLI(BoringModel, trainer_defaults=defaults)
    # Third run: explicitly opting in allows the overwrite.
    with mock.patch('sys.argv', ['any.py']):
        LightningCLI(BoringModel, save_config_overwrite=True, trainer_defaults=defaults)