Skip to content
10 changes: 10 additions & 0 deletions pytorch_lightning/callbacks/pt_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,16 @@ def __init__(self, filepath, monitor='val_loss', verbose=0,
save_best_only=True, save_weights_only=False,
mode='auto', period=1, prefix=''):
super(ModelCheckpoint, self).__init__()
if (
    save_best_only and
    os.path.isdir(filepath) and
    len(os.listdir(filepath)) > 0
):
    # Warn up front: with save_best_only the checkpoint callback clears the
    # target directory each time it saves, so pre-existing files are at risk.
    warnings.warn(
        f"Checkpoint directory {filepath} exists and is not empty with save_best_only=True."
        " All files in this directory will be deleted when a checkpoint is saved!"
        # NOTE(fix): a leading space was added to the second literal — the two
        # adjacent string literals previously concatenated to "True.All files".
    )

self.monitor = monitor
self.verbose = verbose
self.filepath = filepath
Expand Down
7 changes: 6 additions & 1 deletion pytorch_lightning/logging/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,12 @@ def rank(self, value):
"""Set the process rank"""
self._rank = value

@property
def name(self):
    """Return the experiment name.

    Abstract: every concrete logger subclass is required to override
    this property with its own experiment name.
    """
    raise NotImplementedError("Sub-classes must provide a name property")

@property
def version(self):
    """Return the experiment version.

    Abstract: every concrete logger subclass is required to override
    this property with its own version identifier.
    """
    # Fix: a stray `return None` preceded the raise, making it unreachable
    # and silently hiding subclasses that forgot to implement `version`.
    raise NotImplementedError("Sub-classes must provide a version property")
8 changes: 8 additions & 0 deletions pytorch_lightning/logging/mlflow_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,11 @@ def finalize(self, status="FINISHED"):
if status == 'success':
status = 'FINISHED'
self.experiment.set_terminated(self.run_id, status)

@property
def name(self):
    """Name of the MLflow experiment this logger records into."""
    experiment_name = self.experiment_name
    return experiment_name

@property
def version(self):
    """Version of this logger; the MLflow run id serves as the version."""
    run_identifier = self._run_id
    return run_identifier
11 changes: 9 additions & 2 deletions pytorch_lightning/logging/test_tube_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(
):
super().__init__()
self.save_dir = save_dir
self.name = name
self._name = name
self.description = description
self.debug = debug
self._version = version
Expand All @@ -29,7 +29,7 @@ def experiment(self):

self._experiment = Experiment(
save_dir=self.save_dir,
name=self.name,
name=self._name,
debug=self.debug,
version=self.version,
description=self.description,
Expand Down Expand Up @@ -80,6 +80,13 @@ def rank(self, value):
if self._experiment is not None:
self.experiment.rank = value

@property
def name(self):
    """Return the experiment name.

    Before the lazy `experiment` object has been created, fall back to the
    name passed to the constructor; afterwards, defer to the experiment so
    the two can never disagree.
    """
    # Fix: pasted review-comment text had been interleaved with the method
    # body; restored the clean implementation.
    if self._experiment is None:
        return self._name
    else:
        return self.experiment.name

@property
def version(self):
if self._experiment is None:
Expand Down
15 changes: 9 additions & 6 deletions pytorch_lightning/trainer/callback_config_mixin.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.logging import TestTubeLogger

Expand All @@ -12,14 +14,15 @@ def configure_checkpoint_callback(self):
"""
if self.checkpoint_callback is True:
# init a default one
if isinstance(self.logger, TestTubeLogger):
ckpt_path = '{}/{}/version_{}/{}'.format(
if self.logger is not None:
ckpt_path = os.path.join(
self.default_save_path,
self.logger.experiment.name,
self.logger.experiment.version,
'checkpoints')
self.logger.name,
f'version_{self.logger.version}',
"checkpoints"
)
else:
ckpt_path = self.default_save_path
ckpt_path = os.path.join(self.default_save_path, "checkpoints")

self.checkpoint_callback = ModelCheckpoint(
filepath=ckpt_path
Expand Down
226 changes: 115 additions & 111 deletions tests/test_y_logging.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
import pickle

import numpy as np
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.testing import LightningTestModel
from pytorch_lightning.logging import LightningLoggerBase, rank_zero_only
from . import testing_utils

RANDOM_FILE_PATHS = list(np.random.randint(12000, 19000, 1000))
Expand Down Expand Up @@ -69,117 +71,119 @@ def test_testtube_pickle():
testing_utils.clear_save_dir()


# def test_mlflow_logger():
# """
# verify that basic functionality of mlflow logger works
# """
# reset_seed()
#
# try:
# from pytorch_lightning.logging import MLFlowLogger
# except ModuleNotFoundError:
# return
#
# hparams = get_hparams()
# model = LightningTestModel(hparams)
#
# root_dir = os.path.dirname(os.path.realpath(__file__))
# mlflow_dir = os.path.join(root_dir, "mlruns")
# import pdb
# pdb.set_trace()
#
# logger = MLFlowLogger("test", f"file://{mlflow_dir}")
# logger.log_hyperparams(hparams)
# logger.save()
#
# trainer_options = dict(
# max_nb_epochs=1,
# train_percent_check=0.01,
# logger=logger
# )
#
# trainer = Trainer(**trainer_options)
# result = trainer.fit(model)
#
# print('result finished')
# assert result == 1, "Training failed"
#
# shutil.move(mlflow_dir, mlflow_dir + f'_{n}')


# def test_mlflow_pickle():
# """
# verify that pickling trainer with mlflow logger works
# """
# reset_seed()
#
# try:
# from pytorch_lightning.logging import MLFlowLogger
# except ModuleNotFoundError:
# return
#
# hparams = get_hparams()
# model = LightningTestModel(hparams)
#
# root_dir = os.path.dirname(os.path.realpath(__file__))
# mlflow_dir = os.path.join(root_dir, "mlruns")
#
# logger = MLFlowLogger("test", f"file://{mlflow_dir}")
# logger.log_hyperparams(hparams)
# logger.save()
#
# trainer_options = dict(
# max_nb_epochs=1,
# logger=logger
# )
#
# trainer = Trainer(**trainer_options)
# pkl_bytes = pickle.dumps(trainer)
# trainer2 = pickle.loads(pkl_bytes)
# trainer2.logger.log_metrics({"acc": 1.0})
#
# n = RANDOM_FILE_PATHS.pop()
# shutil.move(mlflow_dir, mlflow_dir + f'_{n}')


# def test_custom_logger():
#
# class CustomLogger(LightningLoggerBase):
# def __init__(self):
# super().__init__()
# self.hparams_logged = None
# self.metrics_logged = None
# self.finalized = False
#
# @rank_zero_only
# def log_hyperparams(self, params):
# self.hparams_logged = params
#
# @rank_zero_only
# def log_metrics(self, metrics, step_num):
# self.metrics_logged = metrics
#
# @rank_zero_only
# def finalize(self, status):
# self.finalized_status = status
#
# hparams = get_hparams()
# model = LightningTestModel(hparams)
#
# logger = CustomLogger()
#
# trainer_options = dict(
# max_nb_epochs=1,
# train_percent_check=0.01,
# logger=logger
# )
#
# trainer = Trainer(**trainer_options)
# result = trainer.fit(model)
# assert result == 1, "Training failed"
# assert logger.hparams_logged == hparams
# assert logger.metrics_logged != {}
# assert logger.finalized_status == "success"
def test_mlflow_logger():
    """
    verify that basic functionality of mlflow logger works
    """
    reset_seed()

    # mlflow is an optional dependency; silently skip when it is absent
    try:
        from pytorch_lightning.logging import MLFlowLogger
    except ModuleNotFoundError:
        return

    hparams = testing_utils.get_hparams()
    model = LightningTestModel(hparams)

    # track runs into an "mlruns" folder that sits next to this test file
    this_dir = os.path.dirname(os.path.realpath(__file__))
    mlflow_dir = os.path.join(this_dir, "mlruns")
    logger = MLFlowLogger("test", f"file://{mlflow_dir}")

    trainer = Trainer(
        max_nb_epochs=1,
        train_percent_check=0.01,
        logger=logger,
    )
    result = trainer.fit(model)

    print('result finished')
    assert result == 1, "Training failed"

    testing_utils.clear_save_dir()


def test_mlflow_pickle():
    """
    verify that pickling trainer with mlflow logger works
    """
    reset_seed()

    # mlflow is an optional dependency; silently skip when it is absent
    try:
        from pytorch_lightning.logging import MLFlowLogger
    except ModuleNotFoundError:
        return

    hparams = testing_utils.get_hparams()
    model = LightningTestModel(hparams)

    # track runs into an "mlruns" folder that sits next to this test file
    this_dir = os.path.dirname(os.path.realpath(__file__))
    mlflow_dir = os.path.join(this_dir, "mlruns")
    logger = MLFlowLogger("test", f"file://{mlflow_dir}")

    trainer = Trainer(max_nb_epochs=1, logger=logger)

    # round-trip the trainer through pickle; the restored logger must still
    # be usable (this guards against un-picklable logger state)
    restored = pickle.loads(pickle.dumps(trainer))
    restored.logger.log_metrics({"acc": 1.0})

    testing_utils.clear_save_dir()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add an assert to check that the output is as you expect

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is designed to detect a very particular failure mode when an exception is thrown on the pickle dump/load. No need to assert. All an assert here would test is whether the standard pickle module works correctly.

The test isn't really the focus of this PR. For context, this test was in master for a while, then a tricky interaction of some new (and mostly unrelated) features caused problems when using the default checkpoint saver, default save path, and anything other than the test tube logger. These tests got commented out as a short term fix. I'm just uncommenting them now, since this PR resolves the underlying problem.

I definitely think there are some improvements that could be made to this and other tests, but lets deal with them in a separate issue / PR.



def test_custom_logger(tmpdir):
    """Verify that a user-defined LightningLoggerBase subclass is driven
    correctly by the Trainer (hparams, metrics and finalize all reach it)."""

    class CustomLogger(LightningLoggerBase):
        # Minimal logger that just records everything the Trainer sends.

        def __init__(self):
            super().__init__()
            self.hparams_logged = None
            self.metrics_logged = None
            self.finalized = False

        @rank_zero_only
        def log_hyperparams(self, params):
            self.hparams_logged = params

        @rank_zero_only
        def log_metrics(self, metrics, step_num):
            self.metrics_logged = metrics

        @rank_zero_only
        def finalize(self, status):
            self.finalized_status = status

        @property
        def name(self):
            return "name"

        @property
        def version(self):
            return "1"

    hparams = testing_utils.get_hparams()
    model = LightningTestModel(hparams)
    logger = CustomLogger()

    trainer = Trainer(
        max_nb_epochs=1,
        train_percent_check=0.01,
        logger=logger,
        default_save_path=tmpdir,
    )
    result = trainer.fit(model)

    assert result == 1, "Training failed"
    assert logger.hparams_logged == hparams
    assert logger.metrics_logged != {}
    assert logger.finalized_status == "success"


def reset_seed():
Expand Down