From 530c574cda562f10ee9bc25bc535d9612da4c9a7 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Thu, 19 Sep 2019 13:21:29 -0700 Subject: [PATCH 1/3] WIP cleanup loading --- ml-agents/mlagents/trainers/learn.py | 22 ++-------- ml-agents/mlagents/trainers/tests/test_sac.py | 2 +- .../trainers/tests/test_trainer_util.py | 42 +++++++++++++++++++ ml-agents/mlagents/trainers/trainer_util.py | 26 +++++++++++- 4 files changed, 71 insertions(+), 21 deletions(-) diff --git a/ml-agents/mlagents/trainers/learn.py b/ml-agents/mlagents/trainers/learn.py index c49bcaa2dd..3b8563ce6e 100644 --- a/ml-agents/mlagents/trainers/learn.py +++ b/ml-agents/mlagents/trainers/learn.py @@ -7,7 +7,7 @@ import glob import shutil import numpy as np -import yaml + from docopt import docopt from typing import Any, Callable, Dict, Optional @@ -15,10 +15,10 @@ from mlagents.trainers.trainer_controller import TrainerController from mlagents.trainers.exception import TrainerError from mlagents.trainers.meta_curriculum import MetaCurriculumError, MetaCurriculum -from mlagents.trainers.trainer_util import initialize_trainers +from mlagents.trainers.trainer_util import initialize_trainers, load_config from mlagents.envs.environment import UnityEnvironment from mlagents.envs.sampler_class import SamplerManager -from mlagents.envs.exception import UnityEnvironmentException, SamplerException +from mlagents.envs.exception import SamplerException from mlagents.envs.base_unity_environment import BaseUnityEnvironment from mlagents.envs.subprocess_env_manager import SubprocessEnvManager @@ -200,22 +200,6 @@ def prepare_for_docker_run(docker_target_name, env_path): return env_path -def load_config(trainer_config_path: str) -> Dict[str, Any]: - try: - with open(trainer_config_path) as data_file: - trainer_config = yaml.safe_load(data_file) - return trainer_config - except IOError: - raise UnityEnvironmentException( - "Parameter file could not be found " "at {}.".format(trainer_config_path) - ) - except UnicodeDecodeError: - raise UnityEnvironmentException( - "There was an error decoding " - "Trainer Config from this path : {}".format(trainer_config_path) - ) - - def create_environment_factory( env_path: str, docker_target_name: str, diff --git a/ml-agents/mlagents/trainers/tests/test_sac.py b/ml-agents/mlagents/trainers/tests/test_sac.py index 1c7af068fa..bdaa1cf10d 100644 --- a/ml-agents/mlagents/trainers/tests/test_sac.py +++ b/ml-agents/mlagents/trainers/tests/test_sac.py @@ -15,7 +15,7 @@ @pytest.fixture def dummy_config(): - return yaml.load( + return yaml.safe_load( """ trainer: sac batch_size: 32 diff --git a/ml-agents/mlagents/trainers/tests/test_trainer_util.py b/ml-agents/mlagents/trainers/tests/test_trainer_util.py index fd8f3d230a..3efdf7b66e 100644 --- a/ml-agents/mlagents/trainers/tests/test_trainer_util.py +++ b/ml-agents/mlagents/trainers/tests/test_trainer_util.py @@ -1,9 +1,12 @@ import pytest import yaml import os +import io +import tempfile from unittest.mock import patch import mlagents.trainers.trainer_util as trainer_util +from mlagents.trainers.trainer_util import load_config, _load_config from mlagents.trainers.trainer_metrics import TrainerMetrics from mlagents.trainers.ppo.trainer import PPOTrainer from mlagents.trainers.bc.offline_trainer import OfflineBCTrainer @@ -313,3 +316,42 @@ def test_initialize_invalid_trainer_raises_exception(BrainParametersMock): load_model=load_model, seed=seed, ) + + +def test_load_config_missing_file(): + with pytest.raises(UnityEnvironmentException): + load_config("thisFileDefinitelyDoesNotExist.yaml") + + +def test_load_config_valid_yaml(): + file_contents = """ +this: + - is fine + """ + fp = io.StringIO(file_contents) + res = _load_config(fp) + assert res == {"this": ["is fine"]} + + +def test_load_config_invalid_yaml(): + file_contents = """ +you: + - will +- not + - parse + """ + with pytest.raises(UnityEnvironmentException): + fp = io.StringIO(file_contents) + _load_config(fp) + + +def test_load_config_unicode_yaml(): + file_contents = """ +thís: + - 😡 + """ + fp = io.StringIO(file_contents) + res = _load_config(fp) + assert res == {"thís": ["😡"]} + + with tempfile.mktemp("unit") diff --git a/ml-agents/mlagents/trainers/trainer_util.py b/ml-agents/mlagents/trainers/trainer_util.py index 133896848b..3491cd2883 100644 --- a/ml-agents/mlagents/trainers/trainer_util.py +++ b/ml-agents/mlagents/trainers/trainer_util.py @@ -1,4 +1,5 @@ -from typing import Any, Dict +import yaml +from typing import Any, Dict, TextIO from mlagents.trainers.meta_curriculum import MetaCurriculum from mlagents.envs.exception import UnityEnvironmentException @@ -108,3 +109,26 @@ def initialize_trainers( "brain {}".format(brain_name) ) return trainers + + +def load_config(trainer_config_path: str) -> Dict[str, Any]: + try: + with open(trainer_config_path) as data_file: + return _load_config(data_file) + except IOError: + raise UnityEnvironmentException( + "Parameter file could not be found " "at {}.".format(trainer_config_path) + ) + except UnicodeDecodeError: + raise UnityEnvironmentException( + "There was an error decoding " + "Trainer Config from this path : {}".format(trainer_config_path) + ) + + +def _load_config(fp: TextIO) -> Dict[str, Any]: + try: + return yaml.safe_load(fp) + except yaml.parser.ParserError: + # TODO better message + raise UnityEnvironmentException("Error parsing yaml file. ") From ee84640efbf5c8cd4a7c9828eaacb0684275b892 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Fri, 20 Sep 2019 15:03:31 -0700 Subject: [PATCH 2/3] better exceptions for parser errors - refer to online lint tools --- ml-agents/mlagents/trainers/curriculum.py | 40 ++++++++++++++----- ml-agents/mlagents/trainers/exception.py | 16 ++++++++ .../trainers/tests/test_curriculum.py | 33 ++++++++++++++- .../trainers/tests/test_trainer_util.py | 13 ------ ml-agents/mlagents/trainers/trainer_util.py | 17 +++++--- 5 files changed, 87 insertions(+), 32 deletions(-) diff --git a/ml-agents/mlagents/trainers/curriculum.py b/ml-agents/mlagents/trainers/curriculum.py index 6962c55ff6..75ddb04a13 100644 --- a/ml-agents/mlagents/trainers/curriculum.py +++ b/ml-agents/mlagents/trainers/curriculum.py @@ -2,7 +2,7 @@ import json import math -from .exception import CurriculumError +from .exception import CurriculumConfigError, CurriculumLoadingError import logging @@ -23,14 +23,8 @@ def __init__(self, location, default_reset_parameters): # The name of the brain should be the basename of the file without the # extension. self._brain_name = os.path.basename(location).split(".")[0] + self.data = Curriculum.load_curriculum_file(location) - try: - with open(location) as data_file: - self.data = json.load(data_file) - except IOError: - raise CurriculumError("The file {0} could not be found.".format(location)) - except UnicodeDecodeError: - raise CurriculumError("There was an error decoding {}".format(location)) self.smoothing_value = 0 for key in [ "parameters", @@ -40,7 +34,7 @@ def __init__(self, location, default_reset_parameters): "signal_smoothing", ]: if key not in self.data: - raise CurriculumError( + raise CurriculumConfigError( "{0} does not contain a " "{1} field.".format(location, key) ) self.smoothing_value = 0 @@ -51,12 +45,12 @@ def __init__(self, location, default_reset_parameters): parameters = self.data["parameters"] for key in parameters: if key not in default_reset_parameters: - raise CurriculumError( + raise CurriculumConfigError( "The parameter {0} in Curriculum {1} is not present in " "the Environment".format(key, location) ) if len(parameters[key]) != self.max_lesson_num + 1: - raise CurriculumError( + raise CurriculumConfigError( "The parameter {0} in Curriculum {1} must have {2} values " "but {3} were found".format( key, location, self.max_lesson_num + 1, len(parameters[key]) @@ -117,3 +111,27 @@ def get_config(self, lesson=None): for key in parameters: config[key] = parameters[key][lesson] return config + + @staticmethod + def load_curriculum_file(location): + try: + with open(location) as data_file: + return Curriculum._load_curriculum(data_file) + except IOError: + raise CurriculumLoadingError( + "The file {0} could not be found.".format(location) + ) + except UnicodeDecodeError: + raise CurriculumLoadingError( + "There was an error decoding {}".format(location) + ) + + @staticmethod + def _load_curriculum(fp): + try: + return json.load(fp) + except json.decoder.JSONDecodeError as e: + raise CurriculumLoadingError( + "Error parsing JSON file. Please check for formatting errors. " + "A tool such as https://jsonlint.com/ can be helpful with this." + ) from e diff --git a/ml-agents/mlagents/trainers/exception.py b/ml-agents/mlagents/trainers/exception.py index d9b9921081..8dcb9dca50 100644 --- a/ml-agents/mlagents/trainers/exception.py +++ b/ml-agents/mlagents/trainers/exception.py @@ -19,6 +19,22 @@ class CurriculumError(TrainerError): pass +class CurriculumLoadingError(CurriculumError): + """ + Any error related to loading the Curriculum config file. + """ + + pass + + +class CurriculumConfigError(CurriculumError): + """ + Any error related to processing the Curriculum config file. + """ + + pass + + class MetaCurriculumError(TrainerError): """ Any error related to the configuration of a metacurriculum. diff --git a/ml-agents/mlagents/trainers/tests/test_curriculum.py b/ml-agents/mlagents/trainers/tests/test_curriculum.py index 7759710334..84f0c2dd88 100644 --- a/ml-agents/mlagents/trainers/tests/test_curriculum.py +++ b/ml-agents/mlagents/trainers/tests/test_curriculum.py @@ -1,7 +1,9 @@ +import io +import json import pytest from unittest.mock import patch, mock_open -from mlagents.trainers.exception import CurriculumError +from mlagents.trainers.exception import CurriculumConfigError, CurriculumLoadingError from mlagents.trainers.curriculum import Curriculum @@ -60,7 +62,7 @@ def test_init_curriculum_happy_path(mock_file, location, default_reset_parameter def test_init_curriculum_bad_curriculum_raises_error( mock_file, location, default_reset_parameters ): - with pytest.raises(CurriculumError): + with pytest.raises(CurriculumConfigError): Curriculum(location, default_reset_parameters) @@ -93,3 +95,30 @@ def test_get_config(mock_file): curriculum.lesson_num = 2 assert curriculum.get_config() == {"param1": 0.3, "param2": 20, "param3": 0.7} assert curriculum.get_config(0) == {"param1": 0.7, "param2": 100, "param3": 0.2} + + +# Test json loading and error handling. These examples don't need to valid config files. + + +def test_curriculum_load_good(): + expected = {"x": 1} + value = json.dumps(expected) + fp = io.StringIO(value) + assert expected == Curriculum._load_curriculum(fp) + + +def test_curriculum_load_missing_file(): + with pytest.raises(CurriculumLoadingError): + Curriculum.load_curriculum_file("notAValidFile.json") + + +def test_curriculum_load_invalid_json(): + # This isn't valid json because of the trailing comma + contents = """ +{ + "x": [1, 2, 3,] +} +""" + fp = io.StringIO(contents) + with pytest.raises(CurriculumLoadingError): + Curriculum._load_curriculum(fp) diff --git a/ml-agents/mlagents/trainers/tests/test_trainer_util.py b/ml-agents/mlagents/trainers/tests/test_trainer_util.py index 3efdf7b66e..61b3d910d5 100644 --- a/ml-agents/mlagents/trainers/tests/test_trainer_util.py +++ b/ml-agents/mlagents/trainers/tests/test_trainer_util.py @@ -2,7 +2,6 @@ import yaml import os import io -import tempfile from unittest.mock import patch import mlagents.trainers.trainer_util as trainer_util @@ -343,15 +342,3 @@ def test_load_config_invalid_yaml(): with pytest.raises(UnityEnvironmentException): fp = io.StringIO(file_contents) _load_config(fp) - - -def test_load_config_unicode_yaml(): - file_contents = """ -thís: - - 😡 - """ - fp = io.StringIO(file_contents) - res = _load_config(fp) - assert res == {"thís": ["😡"]} - - with tempfile.mktemp("unit") diff --git a/ml-agents/mlagents/trainers/trainer_util.py b/ml-agents/mlagents/trainers/trainer_util.py index 3491cd2883..bb8380a28f 100644 --- a/ml-agents/mlagents/trainers/trainer_util.py +++ b/ml-agents/mlagents/trainers/trainer_util.py @@ -117,18 +117,23 @@ def load_config(trainer_config_path: str) -> Dict[str, Any]: return _load_config(data_file) except IOError: raise UnityEnvironmentException( - "Parameter file could not be found " "at {}.".format(trainer_config_path) + f"Config file could not be found at {trainer_config_path}." ) except UnicodeDecodeError: raise UnityEnvironmentException( - "There was an error decoding " - "Trainer Config from this path : {}".format(trainer_config_path) + f"There was an error decoding Config file from {trainer_config_path}. " + f"Make sure your file is save using UTF-8" ) def _load_config(fp: TextIO) -> Dict[str, Any]: + """ + Load the yaml config from the file-like object. + """ try: return yaml.safe_load(fp) - except yaml.parser.ParserError: - # TODO better message - raise UnityEnvironmentException("Error parsing yaml file. ") + except yaml.parser.ParserError as e: + raise UnityEnvironmentException( + "Error parsing yaml file. Please check for formatting errors. " + "A tool such as http://www.yamllint.com/ can be helpful with this." + ) from e From 138fdc692a01b9944a99e570948173cd611e7b2b Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Mon, 23 Sep 2019 14:55:13 -0700 Subject: [PATCH 3/3] feedback - rename variable --- ml-agents/mlagents/trainers/trainer_util.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ml-agents/mlagents/trainers/trainer_util.py b/ml-agents/mlagents/trainers/trainer_util.py index bb8380a28f..dff7c6b416 100644 --- a/ml-agents/mlagents/trainers/trainer_util.py +++ b/ml-agents/mlagents/trainers/trainer_util.py @@ -111,17 +111,17 @@ def initialize_trainers( return trainers -def load_config(trainer_config_path: str) -> Dict[str, Any]: +def load_config(config_path: str) -> Dict[str, Any]: try: - with open(trainer_config_path) as data_file: + with open(config_path) as data_file: return _load_config(data_file) except IOError: raise UnityEnvironmentException( - f"Config file could not be found at {trainer_config_path}." + f"Config file could not be found at {config_path}." ) except UnicodeDecodeError: raise UnityEnvironmentException( - f"There was an error decoding Config file from {trainer_config_path}. " + f"There was an error decoding Config file from {config_path}. " f"Make sure your file is save using UTF-8" )