|
10 | 10 |
|
11 | 11 | from pytorch_lightning import _logger as log
|
12 | 12 | from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
|
13 |
| -from pytorch_lightning.utilities.io import load as pl_load |
| 13 | +from pytorch_lightning.utilities.cloud_io import load as pl_load |
14 | 14 |
|
15 | 15 | PRIMITIVE_TYPES = (bool, int, float, str)
|
16 | 16 | ALLOWED_CONFIG_TYPES = (AttributeDict, dict, Namespace)
|
17 | 17 | try:
|
18 | 18 | from omegaconf import Container
|
19 | 19 | except ImportError:
|
20 |
| - pass |
| 20 | + Container = None |
21 | 21 | else:
|
22 | 22 | ALLOWED_CONFIG_TYPES = ALLOWED_CONFIG_TYPES + (Container, )
|
23 | 23 |
|
@@ -332,11 +332,25 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
|
332 | 332 |
|
333 | 333 |
|
334 | 334 | def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
|
| 335 | + """ |
| 336 | + Args: |
| 337 | + config_yaml: path to new YAML file |
| 338 | + hparams: parameters to be saved |
| 339 | + """ |
335 | 340 | if not os.path.isdir(os.path.dirname(config_yaml)):
|
336 | 341 | raise RuntimeError(f'Missing folder: {os.path.dirname(config_yaml)}.')
|
337 | 342 |
|
| 343 | + if Container is not None and isinstance(hparams, Container): |
| 344 | + from omegaconf import OmegaConf |
| 345 | + OmegaConf.save(hparams, config_yaml, resolve=True) |
| 346 | + return |
| 347 | + |
| 348 | + # saving the standard way |
338 | 349 | if isinstance(hparams, Namespace):
|
339 | 350 | hparams = vars(hparams)
|
| 351 | + elif isinstance(hparams, AttributeDict): |
| 352 | + hparams = dict(hparams) |
| 353 | + assert isinstance(hparams, dict) |
340 | 354 |
|
341 | 355 | with open(config_yaml, 'w', newline='') as fp:
|
342 | 356 | yaml.dump(hparams, fp)
|
|
0 commit comments