diff --git a/ml-agents/mlagents/trainers/tests/test_settings.py b/ml-agents/mlagents/trainers/tests/test_settings.py index 6cdcd9440a..33271a1c14 100644 --- a/ml-agents/mlagents/trainers/tests/test_settings.py +++ b/ml-agents/mlagents/trainers/tests/test_settings.py @@ -1,7 +1,8 @@ import attr import pytest +import yaml -from typing import Dict +from typing import Dict, List, Optional from mlagents.trainers.settings import ( RunOptions, @@ -32,6 +33,32 @@ def check_if_different(testobj1: object, testobj2: object) -> None: check_if_different(val, attr.asdict(testobj2, recurse=False)[key]) +def check_dict_is_at_least( + testdict1: Dict, testdict2: Dict, exceptions: Optional[List[str]] = None +) -> None: + """ + Check if everything present in the 1st dict is the same in the second dict. + Excludes things that the second dict has but is not present in the heirarchy of the + 1st dict. Used to compare an underspecified config dict structure (e.g. as + would be provided by a user) with a complete one (e.g. as exported by RunOptions). + """ + for key, val in testdict1.items(): + if exceptions is not None and key in exceptions: + continue + assert key in testdict2 + if isinstance(val, dict): + check_dict_is_at_least(val, testdict2[key]) + elif isinstance(val, list): + assert isinstance(testdict2[key], list) + for _el0, _el1 in zip(val, testdict2[key]): + if isinstance(_el0, dict): + check_dict_is_at_least(_el0, _el1) + else: + assert val == testdict2[key] + else: # If not a dict, don't recurse into it + assert val == testdict2[key] + + def test_is_new_instance(): """ Verify that every instance of RunOptions() and its subclasses @@ -289,3 +316,92 @@ def test_env_parameter_structure(): EnvironmentParameterSettings.structure( invalid_curriculum_dict, Dict[str, EnvironmentParameterSettings] ) + + +def test_exportable_settings(): + """ + Test that structuring and unstructuring a RunOptions object results in the same + configuration representation. + """ + # Try to enable as many features as possible in this test YAML to hit all the + # edge cases. Set as much as possible as non-default values to ensure no flukes. + # TODO: Add back in environment_parameters + test_yaml = """ + behaviors: + 3DBall: + trainer_type: sac + hyperparameters: + learning_rate: 0.0004 + learning_rate_schedule: constant + batch_size: 64 + buffer_size: 200000 + buffer_init_steps: 100 + tau: 0.006 + steps_per_update: 10.0 + save_replay_buffer: true + init_entcoef: 0.5 + reward_signal_steps_per_update: 10.0 + network_settings: + normalize: false + hidden_units: 256 + num_layers: 3 + vis_encode_type: nature_cnn + memory: + memory_size: 1288 + sequence_length: 12 + reward_signals: + extrinsic: + gamma: 0.999 + strength: 1.0 + curiosity: + gamma: 0.999 + strength: 1.0 + keep_checkpoints: 5 + max_steps: 500000 + time_horizon: 1000 + summary_freq: 12000 + checkpoint_interval: 1 + threaded: true + env_settings: + env_path: test_env_path + env_args: + - test_env_args1 + - test_env_args2 + base_port: 12345 + num_envs: 8 + seed: 12345 + engine_settings: + width: 12345 + height: 12345 + quality_level: 12345 + time_scale: 12345 + target_frame_rate: 12345 + capture_frame_rate: 12345 + no_graphics: true + checkpoint_settings: + run_id: test_run_id + initialize_from: test_directory + load_model: false + resume: true + force: true + train_model: false + inference: false + debug: true + """ + loaded_yaml = yaml.safe_load(test_yaml) + run_options = RunOptions.from_dict(yaml.safe_load(test_yaml)) + dict_export = run_options.as_dict() + check_dict_is_at_least(loaded_yaml, dict_export) + + # Re-import and verify has same elements + run_options2 = RunOptions.from_dict(dict_export) + second_export = run_options2.as_dict() + + check_dict_is_at_least( + dict_export, second_export, exceptions=["environment_parameters"] + ) + # Should be able to use equality instead of back-and-forth once environment_parameters + # is working + check_dict_is_at_least( + second_export, dict_export, exceptions=["environment_parameters"] + )