From 9aaef1e2f9a5ac0940442533c71d19753ea6bdca Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Thu, 25 Jun 2020 15:07:23 -0700 Subject: [PATCH 1/2] Add test for settings export --- .../mlagents/trainers/tests/test_settings.py | 118 +++++++++++++++++- 1 file changed, 117 insertions(+), 1 deletion(-) diff --git a/ml-agents/mlagents/trainers/tests/test_settings.py b/ml-agents/mlagents/trainers/tests/test_settings.py index 6cdcd9440a..b9ea071181 100644 --- a/ml-agents/mlagents/trainers/tests/test_settings.py +++ b/ml-agents/mlagents/trainers/tests/test_settings.py @@ -1,7 +1,8 @@ import attr import pytest +import yaml -from typing import Dict +from typing import Dict, List, Optional from mlagents.trainers.settings import ( RunOptions, @@ -32,6 +33,32 @@ def check_if_different(testobj1: object, testobj2: object) -> None: check_if_different(val, attr.asdict(testobj2, recurse=False)[key]) +def check_dict_is_at_least( + testdict1: Dict, testdict2: Dict, exceptions: Optional[List[str]] = None +) -> None: + """ + Check if everything present in the 1st dict is the same in the second dict. + Excludes things that the second dict has but is not present in the heirarchy of the + 1st dict. Used to compare an underspecified config dict structure (e.g. as + would be provded by a user) with a complete one (e.g. as exported by RunOptions). + """ + for key, val in testdict1.items(): + if exceptions is not None and key in exceptions: + continue + assert key in testdict2 + if isinstance(val, dict): + check_dict_is_at_least(val, testdict2[key]) + elif isinstance(val, list): + assert isinstance(testdict2[key], list) + for _el0, _el1 in zip(val, testdict2[key]): + if isinstance(_el0, dict): + check_dict_is_at_least(_el0, _el1) + else: + assert val == testdict2[key] + else: # If not a dict, don't recurse into it + assert val == testdict2[key] + + def test_is_new_instance(): """ Verify that every instance of RunOptions() and its subclasses @@ -289,3 +316,92 @@ def test_env_parameter_structure(): EnvironmentParameterSettings.structure( invalid_curriculum_dict, Dict[str, EnvironmentParameterSettings] ) + + +def test_exportable_settings(): + """ + Test that structuring and unstructuring a RunOptions object results in the same + configuration representation. + """ + # Try to enable as many features as possible in this test YAML to hit all the + # edge cases. Set as much as possible as non-default values to ensure no flukes. + # TODO: Add back in environment_parameters + test_yaml = """ + behaviors: + 3DBall: + trainer_type: sac + hyperparameters: + learning_rate: 0.0004 + learning_rate_schedule: constant + batch_size: 64 + buffer_size: 200000 + buffer_init_steps: 100 + tau: 0.006 + steps_per_update: 10.0 + save_replay_buffer: true + init_entcoef: 0.5 + reward_signal_steps_per_update: 10.0 + network_settings: + normalize: false + hidden_units: 256 + num_layers: 3 + vis_encode_type: nature_cnn + memory: + memory_size: 1288 + sequence_length: 12 + reward_signals: + extrinsic: + gamma: 0.999 + strength: 1.0 + curiosity: + gamma: 0.999 + strength: 1.0 + keep_checkpoints: 5 + max_steps: 500000 + time_horizon: 1000 + summary_freq: 12000 + checkpoint_interval: 1 + threaded: true + env_settings: + env_path: test_env_path + env_args: + - test_env_args1 + - test_env_args2 + base_port: 12345 + num_envs: 8 + seed: 12345 + engine_settings: + width: 12345 + height: 12345 + quality_level: 12345 + time_scale: 12345 + target_frame_rate: 12345 + capture_frame_rate: 12345 + no_graphics: true + checkpoint_settings: + run_id: test_run_id + initialize_from: test_directory + load_model: false + resume: true + force: true + train_model: false + inference: false + debug: true + """ + loaded_yaml = yaml.safe_load(test_yaml) + run_options = RunOptions.from_dict(yaml.safe_load(test_yaml)) + dict_export = run_options.as_dict() + check_dict_is_at_least(loaded_yaml, dict_export) + + # Re-import and verify has same elements + run_options2 = RunOptions.from_dict(dict_export) + second_export = run_options2.as_dict() + + check_dict_is_at_least( + dict_export, second_export, exceptions=["environment_parameters"] + ) + # Should be able to use equality instead of back-and-forth once environment_parameters + # is working + check_dict_is_at_least( + second_export, dict_export, exceptions=["environment_parameters"] + ) From 60bd5ec43dbd06d452c861d81d6da56ae0e0a914 Mon Sep 17 00:00:00 2001 From: Ervin T Date: Thu, 25 Jun 2020 16:29:48 -0700 Subject: [PATCH 2/2] Update ml-agents/mlagents/trainers/tests/test_settings.py Co-authored-by: Vincent-Pierre BERGES --- ml-agents/mlagents/trainers/tests/test_settings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml-agents/mlagents/trainers/tests/test_settings.py b/ml-agents/mlagents/trainers/tests/test_settings.py index b9ea071181..33271a1c14 100644 --- a/ml-agents/mlagents/trainers/tests/test_settings.py +++ b/ml-agents/mlagents/trainers/tests/test_settings.py @@ -40,7 +40,7 @@ def check_dict_is_at_least( Check if everything present in the 1st dict is the same in the second dict. Excludes things that the second dict has but is not present in the heirarchy of the 1st dict. Used to compare an underspecified config dict structure (e.g. as - would be provded by a user) with a complete one (e.g. as exported by RunOptions). + would be provided by a user) with a complete one (e.g. as exported by RunOptions). """ for key, val in testdict1.items(): if exceptions is not None and key in exceptions: