Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Aug 5, 2020
1 parent 2c66a4e commit 259af6a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pytorch_lightning/core/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
return {}

with open(config_yaml) as fp:
tags = yaml.load(fp, Loader=yaml.SafeLoader)
tags = yaml.load(fp)

return tags

Expand Down
12 changes: 9 additions & 3 deletions tests/loggers/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,11 @@ def test_file_logger_log_metrics(tmpdir, step_idx):
logger.log_metrics(metrics, step_idx)
logger.save()

path_yaml = os.path.join(logger.log_dir, ExperimentWriter.NAME_HPARAMS_FILE)
params = load_hparams_from_yaml(path_yaml)
assert all([n in params for n in metrics])
path_csv = os.path.join(logger.log_dir, ExperimentWriter.NAME_METRICS_FILE)
with open(path_csv, 'r') as fp:
lines = fp.readlines()
assert len(lines) == 2
assert all([n in lines[0] for n in metrics])


def test_file_logger_log_hyperparams(tmpdir):
Expand All @@ -89,3 +91,7 @@ def test_file_logger_log_hyperparams(tmpdir):
}
logger.log_hyperparams(hparams)
logger.save()

path_yaml = os.path.join(logger.log_dir, ExperimentWriter.NAME_HPARAMS_FILE)
params = load_hparams_from_yaml(path_yaml)
assert all([n in params for n in hparams])

0 comments on commit 259af6a

Please sign in to comment.