Fix ModelCheckpoint default paths #413
```diff
@@ -1,10 +1,12 @@
 import os
 import pickle

 import numpy as np
 import torch

 from pytorch_lightning import Trainer
 from pytorch_lightning.testing import LightningTestModel
 from pytorch_lightning.logging import LightningLoggerBase, rank_zero_only
 from . import testing_utils

 RANDOM_FILE_PATHS = list(np.random.randint(12000, 19000, 1000))
```
@@ -69,117 +71,119 @@ def test_testtube_pickle(): | |
| testing_utils.clear_save_dir() | ||
|
|
||
|
|
||
| # def test_mlflow_logger(): | ||
| # """ | ||
| # verify that basic functionality of mlflow logger works | ||
| # """ | ||
| # reset_seed() | ||
| # | ||
| # try: | ||
| # from pytorch_lightning.logging import MLFlowLogger | ||
| # except ModuleNotFoundError: | ||
| # return | ||
| # | ||
| # hparams = get_hparams() | ||
| # model = LightningTestModel(hparams) | ||
| # | ||
| # root_dir = os.path.dirname(os.path.realpath(__file__)) | ||
| # mlflow_dir = os.path.join(root_dir, "mlruns") | ||
| # import pdb | ||
| # pdb.set_trace() | ||
| # | ||
| # logger = MLFlowLogger("test", f"file://{mlflow_dir}") | ||
| # logger.log_hyperparams(hparams) | ||
| # logger.save() | ||
| # | ||
| # trainer_options = dict( | ||
| # max_nb_epochs=1, | ||
| # train_percent_check=0.01, | ||
| # logger=logger | ||
| # ) | ||
| # | ||
| # trainer = Trainer(**trainer_options) | ||
| # result = trainer.fit(model) | ||
| # | ||
| # print('result finished') | ||
| # assert result == 1, "Training failed" | ||
| # | ||
| # shutil.move(mlflow_dir, mlflow_dir + f'_{n}') | ||
|
|
||
|
|
||
| # def test_mlflow_pickle(): | ||
| # """ | ||
| # verify that pickling trainer with mlflow logger works | ||
| # """ | ||
| # reset_seed() | ||
| # | ||
| # try: | ||
| # from pytorch_lightning.logging import MLFlowLogger | ||
| # except ModuleNotFoundError: | ||
| # return | ||
| # | ||
| # hparams = get_hparams() | ||
| # model = LightningTestModel(hparams) | ||
| # | ||
| # root_dir = os.path.dirname(os.path.realpath(__file__)) | ||
| # mlflow_dir = os.path.join(root_dir, "mlruns") | ||
| # | ||
| # logger = MLFlowLogger("test", f"file://{mlflow_dir}") | ||
| # logger.log_hyperparams(hparams) | ||
| # logger.save() | ||
| # | ||
| # trainer_options = dict( | ||
| # max_nb_epochs=1, | ||
| # logger=logger | ||
| # ) | ||
| # | ||
| # trainer = Trainer(**trainer_options) | ||
| # pkl_bytes = pickle.dumps(trainer) | ||
| # trainer2 = pickle.loads(pkl_bytes) | ||
| # trainer2.logger.log_metrics({"acc": 1.0}) | ||
| # | ||
| # n = RANDOM_FILE_PATHS.pop() | ||
| # shutil.move(mlflow_dir, mlflow_dir + f'_{n}') | ||
|
|
||
|
|
||
| # def test_custom_logger(): | ||
| # | ||
| # class CustomLogger(LightningLoggerBase): | ||
| # def __init__(self): | ||
| # super().__init__() | ||
| # self.hparams_logged = None | ||
| # self.metrics_logged = None | ||
| # self.finalized = False | ||
| # | ||
| # @rank_zero_only | ||
| # def log_hyperparams(self, params): | ||
| # self.hparams_logged = params | ||
| # | ||
| # @rank_zero_only | ||
| # def log_metrics(self, metrics, step_num): | ||
| # self.metrics_logged = metrics | ||
| # | ||
| # @rank_zero_only | ||
| # def finalize(self, status): | ||
| # self.finalized_status = status | ||
| # | ||
| # hparams = get_hparams() | ||
| # model = LightningTestModel(hparams) | ||
| # | ||
| # logger = CustomLogger() | ||
| # | ||
| # trainer_options = dict( | ||
| # max_nb_epochs=1, | ||
| # train_percent_check=0.01, | ||
| # logger=logger | ||
| # ) | ||
| # | ||
| # trainer = Trainer(**trainer_options) | ||
| # result = trainer.fit(model) | ||
| # assert result == 1, "Training failed" | ||
| # assert logger.hparams_logged == hparams | ||
| # assert logger.metrics_logged != {} | ||
| # assert logger.finalized_status == "success" | ||
| def test_mlflow_logger(): | ||
| """ | ||
| verify that basic functionality of mlflow logger works | ||
| """ | ||
| reset_seed() | ||
|
|
||
| try: | ||
| from pytorch_lightning.logging import MLFlowLogger | ||
| except ModuleNotFoundError: | ||
| return | ||
|
|
||
| hparams = testing_utils.get_hparams() | ||
| model = LightningTestModel(hparams) | ||
|
|
||
| root_dir = os.path.dirname(os.path.realpath(__file__)) | ||
| mlflow_dir = os.path.join(root_dir, "mlruns") | ||
|
|
||
| logger = MLFlowLogger("test", f"file://{mlflow_dir}") | ||
|
|
||
| trainer_options = dict( | ||
| max_nb_epochs=1, | ||
| train_percent_check=0.01, | ||
| logger=logger | ||
| ) | ||
|
|
||
| trainer = Trainer(**trainer_options) | ||
| result = trainer.fit(model) | ||
|
|
||
| print('result finished') | ||
| assert result == 1, "Training failed" | ||
|
|
||
| testing_utils.clear_save_dir() | ||
|
|
||
|
|
||
| def test_mlflow_pickle(): | ||
| """ | ||
| verify that pickling trainer with mlflow logger works | ||
| """ | ||
| reset_seed() | ||
|
|
||
| try: | ||
| from pytorch_lightning.logging import MLFlowLogger | ||
| except ModuleNotFoundError: | ||
| return | ||
|
|
||
| hparams = testing_utils.get_hparams() | ||
| model = LightningTestModel(hparams) | ||
|
|
||
| root_dir = os.path.dirname(os.path.realpath(__file__)) | ||
| mlflow_dir = os.path.join(root_dir, "mlruns") | ||
|
|
||
| logger = MLFlowLogger("test", f"file://{mlflow_dir}") | ||
|
|
||
| trainer_options = dict( | ||
| max_nb_epochs=1, | ||
| logger=logger | ||
| ) | ||
|
|
||
| trainer = Trainer(**trainer_options) | ||
| pkl_bytes = pickle.dumps(trainer) | ||
| trainer2 = pickle.loads(pkl_bytes) | ||
| trainer2.logger.log_metrics({"acc": 1.0}) | ||
|
|
||
| testing_utils.clear_save_dir() | ||
|
Collaborator: Add an assert to check that the output is as you expect.

Author: This test is designed to detect a very particular failure mode: an exception thrown during the pickle dump/load, so no assert is needed. All an assert here would test is whether the standard pickle module works correctly. The test isn't really the focus of this PR. For context, this test was in master for a while; then a tricky interaction of some new (and mostly unrelated) features caused problems when using the default checkpoint saver, the default save path, and any logger other than the test-tube logger. These tests were commented out as a short-term fix, and I'm just uncommenting them now, since this PR resolves the underlying problem. I definitely think there are improvements that could be made to this and other tests, but let's deal with them in a separate issue / PR.
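For illustration only, here is a minimal sketch of the kind of assertion the reviewer suggested. The `DictLogger` stub is hypothetical (it is not part of pytorch_lightning); the point is just that the restored object still works after the pickle round-trip:

```python
import pickle


class DictLogger:
    """Hypothetical stand-in for a logger; not part of pytorch_lightning."""

    def __init__(self):
        self.metrics_logged = {}

    def log_metrics(self, metrics, step_num=0):
        self.metrics_logged.update(metrics)


logger = DictLogger()
pkl_bytes = pickle.dumps(logger)   # raises if the logger holds unpicklable state
logger2 = pickle.loads(pkl_bytes)

logger2.log_metrics({"acc": 1.0})
assert logger2.metrics_logged == {"acc": 1.0}  # restored logger still records metrics
```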
```diff
+def test_custom_logger(tmpdir):
+
+    class CustomLogger(LightningLoggerBase):
+        def __init__(self):
+            super().__init__()
+            self.hparams_logged = None
+            self.metrics_logged = None
+            self.finalized = False
+
+        @rank_zero_only
+        def log_hyperparams(self, params):
+            self.hparams_logged = params
+
+        @rank_zero_only
+        def log_metrics(self, metrics, step_num):
+            self.metrics_logged = metrics
+
+        @rank_zero_only
+        def finalize(self, status):
+            self.finalized_status = status
+
+        @property
+        def name(self):
+            return "name"
+
+        @property
+        def version(self):
+            return "1"
+
+    hparams = testing_utils.get_hparams()
+    model = LightningTestModel(hparams)
+
+    logger = CustomLogger()
+
+    trainer_options = dict(
+        max_nb_epochs=1,
+        train_percent_check=0.01,
+        logger=logger,
+        default_save_path=tmpdir
+    )
+
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+    assert result == 1, "Training failed"
+    assert logger.hparams_logged == hparams
+    assert logger.metrics_logged != {}
+    assert logger.finalized_status == "success"


 def reset_seed():
```
Review comment (on `reset_seed`): rather:

Reply: IMHO the current version is easier to read.
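The reviewer's suggested alternative did not survive on this page, and the body of `reset_seed` falls outside the diff hunk. For context, a sketch of one plausible implementation of such a helper is below; `RANDOM_SEEDS` is an assumed name that mirrors the `RANDOM_FILE_PATHS` pattern used earlier in the file, not code from this PR:

```python
import numpy as np
import torch

# Assumed pool of pre-generated seeds, following the RANDOM_FILE_PATHS pattern.
RANDOM_SEEDS = list(np.random.randint(0, 10000, 1000))


def reset_seed():
    # Pop a fresh seed so consecutive tests start from known RNG state.
    seed = RANDOM_SEEDS.pop()
    torch.manual_seed(seed)
    np.random.seed(seed)
```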