diff --git a/pytorch_lightning/callbacks/pt_callbacks.py b/pytorch_lightning/callbacks/pt_callbacks.py
index 2c1553d493974..388fca69490e7 100644
--- a/pytorch_lightning/callbacks/pt_callbacks.py
+++ b/pytorch_lightning/callbacks/pt_callbacks.py
@@ -179,6 +179,16 @@ def __init__(self, filepath, monitor='val_loss', verbose=0,
                  save_best_only=True, save_weights_only=False,
                  mode='auto', period=1, prefix=''):
         super(ModelCheckpoint, self).__init__()
+        if (
+            save_best_only and
+            os.path.isdir(filepath) and
+            len(os.listdir(filepath)) > 0
+        ):
+            warnings.warn(
+                f"Checkpoint directory {filepath} exists and is not empty with save_best_only=True. "
+                "All files in this directory will be deleted when a checkpoint is saved!"
+            )
+
         self.monitor = monitor
         self.verbose = verbose
         self.filepath = filepath
diff --git a/pytorch_lightning/logging/base.py b/pytorch_lightning/logging/base.py
index 351b945e856f7..a72e072be6cc9 100644
--- a/pytorch_lightning/logging/base.py
+++ b/pytorch_lightning/logging/base.py
@@ -65,7 +65,12 @@ def rank(self, value):
         """Set the process rank"""
         self._rank = value
 
+    @property
+    def name(self):
+        """Return the experiment name"""
+        raise NotImplementedError("Sub-classes must provide a name property")
+
     @property
     def version(self):
         """Return the experiment version"""
-        return None
+        raise NotImplementedError("Sub-classes must provide a version property")
diff --git a/pytorch_lightning/logging/mlflow_logger.py b/pytorch_lightning/logging/mlflow_logger.py
index 6c0d460a7b185..e9b2abd5bcea9 100644
--- a/pytorch_lightning/logging/mlflow_logger.py
+++ b/pytorch_lightning/logging/mlflow_logger.py
@@ -60,3 +60,11 @@ def finalize(self, status="FINISHED"):
         if status == 'success':
             status = 'FINISHED'
         self.experiment.set_terminated(self.run_id, status)
+
+    @property
+    def name(self):
+        return self.experiment_name
+
+    @property
+    def version(self):
+        return self._run_id
diff --git a/pytorch_lightning/logging/test_tube_logger.py b/pytorch_lightning/logging/test_tube_logger.py
index 0d74307bded30..23e8806f0fc30 100644
--- a/pytorch_lightning/logging/test_tube_logger.py
+++ b/pytorch_lightning/logging/test_tube_logger.py
@@ -15,7 +15,7 @@ def __init__(
     ):
         super().__init__()
         self.save_dir = save_dir
-        self.name = name
+        self._name = name
         self.description = description
         self.debug = debug
         self._version = version
@@ -29,7 +29,7 @@ def experiment(self):
 
         self._experiment = Experiment(
             save_dir=self.save_dir,
-            name=self.name,
+            name=self._name,
             debug=self.debug,
             version=self.version,
             description=self.description,
@@ -80,6 +80,13 @@ def rank(self, value):
         if self._experiment is not None:
             self.experiment.rank = value
 
+    @property
+    def name(self):
+        if self._experiment is None:
+            return self._name
+        else:
+            return self.experiment.name
+
     @property
     def version(self):
         if self._experiment is None:
diff --git a/pytorch_lightning/trainer/callback_config_mixin.py b/pytorch_lightning/trainer/callback_config_mixin.py
index 1b36bd469156b..58e6ec05c011e 100644
--- a/pytorch_lightning/trainer/callback_config_mixin.py
+++ b/pytorch_lightning/trainer/callback_config_mixin.py
@@ -1,3 +1,5 @@
+import os
+
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 from pytorch_lightning.logging import TestTubeLogger
 
@@ -12,14 +14,15 @@ def configure_checkpoint_callback(self):
         """
         if self.checkpoint_callback is True:
             # init a default one
-            if isinstance(self.logger, TestTubeLogger):
-                ckpt_path = '{}/{}/version_{}/{}'.format(
+            if self.logger is not None:
+                ckpt_path = os.path.join(
                     self.default_save_path,
-                    self.logger.experiment.name,
-                    self.logger.experiment.version,
-                    'checkpoints')
+                    self.logger.name,
+                    f'version_{self.logger.version}',
+                    "checkpoints"
+                )
             else:
-                ckpt_path = self.default_save_path
+                ckpt_path = os.path.join(self.default_save_path, "checkpoints")
 
             self.checkpoint_callback = ModelCheckpoint(
                 filepath=ckpt_path
diff --git a/tests/test_y_logging.py b/tests/test_y_logging.py
index 39b7edfc80f6a..c44a9e2c8b987 100644
--- a/tests/test_y_logging.py
+++ b/tests/test_y_logging.py
@@ -1,3 +1,4 @@
+import os
 import pickle
 
 import numpy as np
@@ -5,6 +6,7 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.testing import LightningTestModel
+from pytorch_lightning.logging import LightningLoggerBase, rank_zero_only
 from . import testing_utils
 
 RANDOM_FILE_PATHS = list(np.random.randint(12000, 19000, 1000))
 
@@ -69,117 +71,119 @@ def test_testtube_pickle():
     testing_utils.clear_save_dir()
 
 
-# def test_mlflow_logger():
-#     """
-#     verify that basic functionality of mlflow logger works
-#     """
-#     reset_seed()
-#
-#     try:
-#         from pytorch_lightning.logging import MLFlowLogger
-#     except ModuleNotFoundError:
-#         return
-#
-#     hparams = get_hparams()
-#     model = LightningTestModel(hparams)
-#
-#     root_dir = os.path.dirname(os.path.realpath(__file__))
-#     mlflow_dir = os.path.join(root_dir, "mlruns")
-#     import pdb
-#     pdb.set_trace()
-#
-#     logger = MLFlowLogger("test", f"file://{mlflow_dir}")
-#     logger.log_hyperparams(hparams)
-#     logger.save()
-#
-#     trainer_options = dict(
-#         max_nb_epochs=1,
-#         train_percent_check=0.01,
-#         logger=logger
-#     )
-#
-#     trainer = Trainer(**trainer_options)
-#     result = trainer.fit(model)
-#
-#     print('result finished')
-#     assert result == 1, "Training failed"
-#
-#     shutil.move(mlflow_dir, mlflow_dir + f'_{n}')
-
-
-# def test_mlflow_pickle():
-#     """
-#     verify that pickling trainer with mlflow logger works
-#     """
-#     reset_seed()
-#
-#     try:
-#         from pytorch_lightning.logging import MLFlowLogger
-#     except ModuleNotFoundError:
-#         return
-#
-#     hparams = get_hparams()
-#     model = LightningTestModel(hparams)
-#
-#     root_dir = os.path.dirname(os.path.realpath(__file__))
-#     mlflow_dir = os.path.join(root_dir, "mlruns")
-#
-#     logger = MLFlowLogger("test", f"file://{mlflow_dir}")
-#     logger.log_hyperparams(hparams)
-#     logger.save()
-#
-#     trainer_options = dict(
-#         max_nb_epochs=1,
-#         logger=logger
-#     )
-#
-#     trainer = Trainer(**trainer_options)
-#     pkl_bytes = pickle.dumps(trainer)
-#     trainer2 = pickle.loads(pkl_bytes)
-#     trainer2.logger.log_metrics({"acc": 1.0})
-#
-#     n = RANDOM_FILE_PATHS.pop()
-#     shutil.move(mlflow_dir, mlflow_dir + f'_{n}')
-
-
-# def test_custom_logger():
-#
-#     class CustomLogger(LightningLoggerBase):
-#         def __init__(self):
-#             super().__init__()
-#             self.hparams_logged = None
-#             self.metrics_logged = None
-#             self.finalized = False
-#
-#         @rank_zero_only
-#         def log_hyperparams(self, params):
-#             self.hparams_logged = params
-#
-#         @rank_zero_only
-#         def log_metrics(self, metrics, step_num):
-#             self.metrics_logged = metrics
-#
-#         @rank_zero_only
-#         def finalize(self, status):
-#             self.finalized_status = status
-#
-#     hparams = get_hparams()
-#     model = LightningTestModel(hparams)
-#
-#     logger = CustomLogger()
-#
-#     trainer_options = dict(
-#         max_nb_epochs=1,
-#         train_percent_check=0.01,
-#         logger=logger
-#     )
-#
-#     trainer = Trainer(**trainer_options)
-#     result = trainer.fit(model)
-#     assert result == 1, "Training failed"
-#     assert logger.hparams_logged == hparams
-#     assert logger.metrics_logged != {}
-#     assert logger.finalized_status == "success"
+def test_mlflow_logger():
+    """
+    verify that basic functionality of mlflow logger works
+    """
+    reset_seed()
+
+    try:
+        from pytorch_lightning.logging import MLFlowLogger
+    except ModuleNotFoundError:
+        return
+
+    hparams = testing_utils.get_hparams()
+    model = LightningTestModel(hparams)
+
+    root_dir = os.path.dirname(os.path.realpath(__file__))
+    mlflow_dir = os.path.join(root_dir, "mlruns")
+
+    logger = MLFlowLogger("test", f"file://{mlflow_dir}")
+
+    trainer_options = dict(
+        max_nb_epochs=1,
+        train_percent_check=0.01,
+        logger=logger
+    )
+
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+
+    print('result finished')
+    assert result == 1, "Training failed"
+
+    testing_utils.clear_save_dir()
+
+
+def test_mlflow_pickle():
+    """
+    verify that pickling trainer with mlflow logger works
+    """
+    reset_seed()
+
+    try:
+        from pytorch_lightning.logging import MLFlowLogger
+    except ModuleNotFoundError:
+        return
+
+    hparams = testing_utils.get_hparams()
+    model = LightningTestModel(hparams)
+
+    root_dir = os.path.dirname(os.path.realpath(__file__))
+    mlflow_dir = os.path.join(root_dir, "mlruns")
+
+    logger = MLFlowLogger("test", f"file://{mlflow_dir}")
+
+    trainer_options = dict(
+        max_nb_epochs=1,
+        logger=logger
+    )
+
+    trainer = Trainer(**trainer_options)
+    pkl_bytes = pickle.dumps(trainer)
+    trainer2 = pickle.loads(pkl_bytes)
+    trainer2.logger.log_metrics({"acc": 1.0})
+
+    testing_utils.clear_save_dir()
+
+
+def test_custom_logger(tmpdir):
+
+    class CustomLogger(LightningLoggerBase):
+        def __init__(self):
+            super().__init__()
+            self.hparams_logged = None
+            self.metrics_logged = None
+            self.finalized = False
+
+        @rank_zero_only
+        def log_hyperparams(self, params):
+            self.hparams_logged = params
+
+        @rank_zero_only
+        def log_metrics(self, metrics, step_num):
+            self.metrics_logged = metrics
+
+        @rank_zero_only
+        def finalize(self, status):
+            self.finalized_status = status
+
+        @property
+        def name(self):
+            return "name"
+
+        @property
+        def version(self):
+            return "1"
+
+    hparams = testing_utils.get_hparams()
+    model = LightningTestModel(hparams)
+
+    logger = CustomLogger()
+
+    trainer_options = dict(
+        max_nb_epochs=1,
+        train_percent_check=0.01,
+        logger=logger,
+        default_save_path=tmpdir
+    )
+
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+    assert result == 1, "Training failed"
+    assert logger.hparams_logged == hparams
+    assert logger.metrics_logged != {}
+    assert logger.finalized_status == "success"
 
 
 def reset_seed():