Commit

test: save hparams to yaml (#2198)
* save hparams to yaml

* import

* resolves

* req

* Update requirements/base.txt

Co-authored-by: Omry Yadan <omry@fb.com>

Borda and omry committed Jun 16, 2020
1 parent f94b919 commit e289e45
Showing 7 changed files with 40 additions and 18 deletions.
18 changes: 16 additions & 2 deletions pytorch_lightning/core/saving.py
@@ -10,14 +10,14 @@

 from pytorch_lightning import _logger as log
 from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
-from pytorch_lightning.utilities.io import load as pl_load
+from pytorch_lightning.utilities.cloud_io import load as pl_load
 
 PRIMITIVE_TYPES = (bool, int, float, str)
 ALLOWED_CONFIG_TYPES = (AttributeDict, dict, Namespace)
 try:
     from omegaconf import Container
 except ImportError:
-    pass
+    Container = None
 else:
     ALLOWED_CONFIG_TYPES = ALLOWED_CONFIG_TYPES + (Container, )

@@ -332,11 +332,25 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:


 def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
+    """
+    Args:
+        config_yaml: path to new YAML file
+        hparams: parameters to be saved
+    """
+    if not os.path.isdir(os.path.dirname(config_yaml)):
+        raise RuntimeError(f'Missing folder: {os.path.dirname(config_yaml)}.')
+
+    if Container is not None and isinstance(hparams, Container):
+        from omegaconf import OmegaConf
+        OmegaConf.save(hparams, config_yaml, resolve=True)
+        return
+
+    # saving the standard way
     if isinstance(hparams, Namespace):
         hparams = vars(hparams)
+    elif isinstance(hparams, AttributeDict):
+        hparams = dict(hparams)
     assert isinstance(hparams, dict)
 
     with open(config_yaml, 'w', newline='') as fp:
         yaml.dump(hparams, fp)
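For context, a minimal sketch of how the updated helper behaves for each accepted hparams type; the temporary path and hparams values here are illustrative, not part of the commit:

import os
import tempfile
from argparse import Namespace

from omegaconf import OmegaConf

from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml
from pytorch_lightning.utilities import AttributeDict

hparams = {'batch_size': 32, 'learning_rate': 0.001}
path = os.path.join(tempfile.mkdtemp(), 'hparams.yaml')

# dict, Namespace, and AttributeDict are normalized to a plain dict and written with yaml.dump
save_hparams_to_yaml(path, Namespace(**hparams))
save_hparams_to_yaml(path, AttributeDict(hparams))

# an OmegaConf Container is delegated to OmegaConf.save instead
save_hparams_to_yaml(path, OmegaConf.create(hparams))

# every variant round-trips to the same plain dict
assert load_hparams_from_yaml(path) == hparams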
14 changes: 1 addition & 13 deletions pytorch_lightning/loggers/tensorboard.py
@@ -17,11 +17,6 @@
 from pytorch_lightning.loggers.base import LightningLoggerBase
 from pytorch_lightning.utilities import rank_zero_only
 
-try:
-    from omegaconf import Container
-except ImportError:
-    Container = None
-
 
 class TensorBoardLogger(LightningLoggerBase):
     r"""
@@ -156,14 +151,7 @@ def save(self) -> None:
         hparams_file = os.path.join(dir_path, self.NAME_HPARAMS_FILE)
 
         # save the metatags file
-        if Container is not None:
-            if isinstance(self.hparams, Container):
-                from omegaconf import OmegaConf
-                OmegaConf.save(self.hparams, hparams_file, resolve=True)
-            else:
-                save_hparams_to_yaml(hparams_file, self.hparams)
-        else:
-            save_hparams_to_yaml(hparams_file, self.hparams)
+        save_hparams_to_yaml(hparams_file, self.hparams)
 
     @rank_zero_only
     def finalize(self, status: str) -> None:
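The net effect of this hunk: the OmegaConf special-casing now lives in one place (core.saving), so the logger delegates unconditionally. A hedged usage sketch — the save_dir, name, and hparams below are illustrative:

from omegaconf import OmegaConf

from pytorch_lightning.loggers import TensorBoardLogger

logger = TensorBoardLogger(save_dir='/tmp/tb_logs', name='demo')
# dicts and OmegaConf containers now take the same path through save()
logger.log_hyperparams(OmegaConf.create({'learning_rate': 0.001}))
logger.save()  # writes the NAME_HPARAMS_FILE metatags file via save_hparams_to_yaml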
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/training_io.py
@@ -101,7 +101,7 @@
     LightningDataParallel,
 )
 from pytorch_lightning.utilities import rank_zero_warn
-from pytorch_lightning.utilities.io import load as pl_load
+from pytorch_lightning.utilities.cloud_io import load as pl_load
 
 try:
     import torch_xla
pytorch_lightning/utilities/io.py → pytorch_lightning/utilities/cloud_io.py: file renamed without changes.
3 changes: 2 additions & 1 deletion requirements/base.txt
@@ -5,4 +5,5 @@ tqdm>=4.41.0
 torch>=1.3
 tensorboard>=1.14
 future>=0.17.1 # required for builtins in setup.py
-pyyaml>=3.13
+# pyyaml>=3.13
+PyYAML>=5.1 # OmegaConf requirement
19 changes: 19 additions & 0 deletions tests/models/test_hparams.py
@@ -8,6 +8,7 @@
 from omegaconf import OmegaConf, Container
 
 from pytorch_lightning import Trainer, LightningModule
+from pytorch_lightning.core.saving import save_hparams_to_yaml, load_hparams_from_yaml
 from pytorch_lightning.utilities import AttributeDict
 from tests.base import EvalModelTemplate
 
@@ -407,3 +408,21 @@ def test_hparams_pickle(tmpdir):
     assert ad == pickle.loads(pkl)
     pkl = cloudpickle.dumps(ad)
     assert ad == pickle.loads(pkl)
+
+
+def test_hparams_save_yaml(tmpdir):
+    hparams = dict(batch_size=32, learning_rate=0.001, data_root='./any/path/here',
+                   nasted=dict(any_num=123, anystr='abcd'))
+    path_yaml = os.path.join(tmpdir, 'testing-hparams.yaml')
+
+    save_hparams_to_yaml(path_yaml, hparams)
+    assert load_hparams_from_yaml(path_yaml) == hparams
+
+    save_hparams_to_yaml(path_yaml, Namespace(**hparams))
+    assert load_hparams_from_yaml(path_yaml) == hparams
+
+    save_hparams_to_yaml(path_yaml, AttributeDict(hparams))
+    assert load_hparams_from_yaml(path_yaml) == hparams
+
+    save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
+    assert load_hparams_from_yaml(path_yaml) == hparams
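One behavior the new test does not exercise: because the Container branch saves with resolve=True, OmegaConf interpolations are materialized before writing, so the file on disk round-trips as a plain dict. A small hedged sketch (keys and paths are illustrative):

import os
import tempfile

from omegaconf import OmegaConf

from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml

conf = OmegaConf.create({'data_root': './data', 'train_path': '${data_root}/train'})
path_yaml = os.path.join(tempfile.mkdtemp(), 'hparams.yaml')
save_hparams_to_yaml(path_yaml, conf)

# the '${data_root}' interpolation is resolved in the saved YAML
assert load_hparams_from_yaml(path_yaml)['train_path'] == './data/train'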
2 changes: 1 addition & 1 deletion tests/trainer/test_trainer.py
@@ -19,7 +19,7 @@
 from pytorch_lightning.loggers import TensorBoardLogger
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.io import load as pl_load
+from pytorch_lightning.utilities.cloud_io import load as pl_load
 from tests.base import EvalModelTemplate
 

