Skip to content

Commit e289e45

Browse files
Borda and omry authored
test: save hparams to yaml (Lightning-AI#2198)
* save hparams to yaml * import * resolves * req * Update requirements/base.txt Co-authored-by: Omry Yadan <omry@fb.com> Co-authored-by: Omry Yadan <omry@fb.com>
1 parent f94b919 commit e289e45

File tree

7 files changed

+40
-18
lines changed

7 files changed

+40
-18
lines changed

pytorch_lightning/core/saving.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@
1010

1111
from pytorch_lightning import _logger as log
1212
from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
13-
from pytorch_lightning.utilities.io import load as pl_load
13+
from pytorch_lightning.utilities.cloud_io import load as pl_load
1414

1515
PRIMITIVE_TYPES = (bool, int, float, str)
1616
ALLOWED_CONFIG_TYPES = (AttributeDict, dict, Namespace)
1717
try:
1818
from omegaconf import Container
1919
except ImportError:
20-
pass
20+
Container = None
2121
else:
2222
ALLOWED_CONFIG_TYPES = ALLOWED_CONFIG_TYPES + (Container, )
2323

@@ -332,11 +332,25 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
332332

333333

334334
def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
335+
"""
336+
Args:
337+
config_yaml: path to new YAML file
338+
hparams: parameters to be saved
339+
"""
335340
if not os.path.isdir(os.path.dirname(config_yaml)):
336341
raise RuntimeError(f'Missing folder: {os.path.dirname(config_yaml)}.')
337342

343+
if Container is not None and isinstance(hparams, Container):
344+
from omegaconf import OmegaConf
345+
OmegaConf.save(hparams, config_yaml, resolve=True)
346+
return
347+
348+
# saving the standard way
338349
if isinstance(hparams, Namespace):
339350
hparams = vars(hparams)
351+
elif isinstance(hparams, AttributeDict):
352+
hparams = dict(hparams)
353+
assert isinstance(hparams, dict)
340354

341355
with open(config_yaml, 'w', newline='') as fp:
342356
yaml.dump(hparams, fp)

pytorch_lightning/loggers/tensorboard.py

+1-13
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,6 @@
1717
from pytorch_lightning.loggers.base import LightningLoggerBase
1818
from pytorch_lightning.utilities import rank_zero_only
1919

20-
try:
21-
from omegaconf import Container
22-
except ImportError:
23-
Container = None
24-
2520

2621
class TensorBoardLogger(LightningLoggerBase):
2722
r"""
@@ -156,14 +151,7 @@ def save(self) -> None:
156151
hparams_file = os.path.join(dir_path, self.NAME_HPARAMS_FILE)
157152

158153
# save the metatags file
159-
if Container is not None:
160-
if isinstance(self.hparams, Container):
161-
from omegaconf import OmegaConf
162-
OmegaConf.save(self.hparams, hparams_file, resolve=True)
163-
else:
164-
save_hparams_to_yaml(hparams_file, self.hparams)
165-
else:
166-
save_hparams_to_yaml(hparams_file, self.hparams)
154+
save_hparams_to_yaml(hparams_file, self.hparams)
167155

168156
@rank_zero_only
169157
def finalize(self, status: str) -> None:

pytorch_lightning/trainer/training_io.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@
101101
LightningDataParallel,
102102
)
103103
from pytorch_lightning.utilities import rank_zero_warn
104-
from pytorch_lightning.utilities.io import load as pl_load
104+
from pytorch_lightning.utilities.cloud_io import load as pl_load
105105

106106
try:
107107
import torch_xla
File renamed without changes.

requirements/base.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ tqdm>=4.41.0
55
torch>=1.3
66
tensorboard>=1.14
77
future>=0.17.1 # required for builtins in setup.py
8-
pyyaml>=3.13
8+
# pyyaml>=3.13
9+
PyYAML>=5.1 # OmegaConf requirement

tests/models/test_hparams.py

+19
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from omegaconf import OmegaConf, Container
99

1010
from pytorch_lightning import Trainer, LightningModule
11+
from pytorch_lightning.core.saving import save_hparams_to_yaml, load_hparams_from_yaml
1112
from pytorch_lightning.utilities import AttributeDict
1213
from tests.base import EvalModelTemplate
1314

@@ -407,3 +408,21 @@ def test_hparams_pickle(tmpdir):
407408
assert ad == pickle.loads(pkl)
408409
pkl = cloudpickle.dumps(ad)
409410
assert ad == pickle.loads(pkl)
411+
412+
413+
def test_hparams_save_yaml(tmpdir):
414+
hparams = dict(batch_size=32, learning_rate=0.001, data_root='./any/path/here',
415+
nasted=dict(any_num=123, anystr='abcd'))
416+
path_yaml = os.path.join(tmpdir, 'testing-hparams.yaml')
417+
418+
save_hparams_to_yaml(path_yaml, hparams)
419+
assert load_hparams_from_yaml(path_yaml) == hparams
420+
421+
save_hparams_to_yaml(path_yaml, Namespace(**hparams))
422+
assert load_hparams_from_yaml(path_yaml) == hparams
423+
424+
save_hparams_to_yaml(path_yaml, AttributeDict(hparams))
425+
assert load_hparams_from_yaml(path_yaml) == hparams
426+
427+
save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
428+
assert load_hparams_from_yaml(path_yaml) == hparams

tests/trainer/test_trainer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from pytorch_lightning.loggers import TensorBoardLogger
2020
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
2121
from pytorch_lightning.utilities.exceptions import MisconfigurationException
22-
from pytorch_lightning.utilities.io import load as pl_load
22+
from pytorch_lightning.utilities.cloud_io import load as pl_load
2323
from tests.base import EvalModelTemplate
2424

2525

0 commit comments

Comments (0)