diff --git a/ml-agents/mlagents/trainers/curriculum.py b/ml-agents/mlagents/trainers/curriculum.py index 6962c55ff6..75ddb04a13 100644 --- a/ml-agents/mlagents/trainers/curriculum.py +++ b/ml-agents/mlagents/trainers/curriculum.py @@ -2,7 +2,7 @@ import json import math -from .exception import CurriculumError +from .exception import CurriculumConfigError, CurriculumLoadingError import logging @@ -23,14 +23,8 @@ def __init__(self, location, default_reset_parameters): # The name of the brain should be the basename of the file without the # extension. self._brain_name = os.path.basename(location).split(".")[0] + self.data = Curriculum.load_curriculum_file(location) - try: - with open(location) as data_file: - self.data = json.load(data_file) - except IOError: - raise CurriculumError("The file {0} could not be found.".format(location)) - except UnicodeDecodeError: - raise CurriculumError("There was an error decoding {}".format(location)) self.smoothing_value = 0 for key in [ "parameters", @@ -40,7 +34,7 @@ def __init__(self, location, default_reset_parameters): "signal_smoothing", ]: if key not in self.data: - raise CurriculumError( + raise CurriculumConfigError( "{0} does not contain a " "{1} field.".format(location, key) ) self.smoothing_value = 0 @@ -51,12 +45,12 @@ def __init__(self, location, default_reset_parameters): parameters = self.data["parameters"] for key in parameters: if key not in default_reset_parameters: - raise CurriculumError( + raise CurriculumConfigError( "The parameter {0} in Curriculum {1} is not present in " "the Environment".format(key, location) ) if len(parameters[key]) != self.max_lesson_num + 1: - raise CurriculumError( + raise CurriculumConfigError( "The parameter {0} in Curriculum {1} must have {2} values " "but {3} were found".format( key, location, self.max_lesson_num + 1, len(parameters[key]) @@ -117,3 +111,27 @@ def get_config(self, lesson=None): for key in parameters: config[key] = parameters[key][lesson] return config + + @staticmethod + def load_curriculum_file(location): + try: + with open(location) as data_file: + return Curriculum._load_curriculum(data_file) + except IOError: + raise CurriculumLoadingError( + "The file {0} could not be found.".format(location) + ) + except UnicodeDecodeError: + raise CurriculumLoadingError( + "There was an error decoding {}".format(location) + ) + + @staticmethod + def _load_curriculum(fp): + try: + return json.load(fp) + except json.decoder.JSONDecodeError as e: + raise CurriculumLoadingError( + "Error parsing JSON file. Please check for formatting errors. " + "A tool such as https://jsonlint.com/ can be helpful with this." + ) from e diff --git a/ml-agents/mlagents/trainers/exception.py b/ml-agents/mlagents/trainers/exception.py index d9b9921081..8dcb9dca50 100644 --- a/ml-agents/mlagents/trainers/exception.py +++ b/ml-agents/mlagents/trainers/exception.py @@ -19,6 +19,22 @@ class CurriculumError(TrainerError): pass +class CurriculumLoadingError(CurriculumError): + """ + Any error related to loading the Curriculum config file. + """ + + pass + + +class CurriculumConfigError(CurriculumError): + """ + Any error related to processing the Curriculum config file. + """ + + pass + + class MetaCurriculumError(TrainerError): """ Any error related to the configuration of a metacurriculum. diff --git a/ml-agents/mlagents/trainers/learn.py b/ml-agents/mlagents/trainers/learn.py index 82c923be82..3376462779 100644 --- a/ml-agents/mlagents/trainers/learn.py +++ b/ml-agents/mlagents/trainers/learn.py @@ -8,17 +8,17 @@ import glob import shutil import numpy as np -import yaml -from typing import Any, Callable, Dict, Optional, List, NamedTuple + +from typing import Any, Callable, Optional, List, NamedTuple from mlagents.trainers.trainer_controller import TrainerController from mlagents.trainers.exception import TrainerError from mlagents.trainers.meta_curriculum import MetaCurriculumError, MetaCurriculum -from mlagents.trainers.trainer_util import initialize_trainers +from mlagents.trainers.trainer_util import initialize_trainers, load_config from mlagents.envs.environment import UnityEnvironment from mlagents.envs.sampler_class import SamplerManager -from mlagents.envs.exception import UnityEnvironmentException, SamplerException +from mlagents.envs.exception import SamplerException from mlagents.envs.base_unity_environment import BaseUnityEnvironment from mlagents.envs.subprocess_env_manager import SubprocessEnvManager @@ -323,22 +323,6 @@ def prepare_for_docker_run(docker_target_name, env_path): return env_path -def load_config(trainer_config_path: str) -> Dict[str, Any]: - try: - with open(trainer_config_path) as data_file: - trainer_config = yaml.safe_load(data_file) - return trainer_config - except IOError: - raise UnityEnvironmentException( - "Parameter file could not be found " "at {}.".format(trainer_config_path) - ) - except UnicodeDecodeError: - raise UnityEnvironmentException( - "There was an error decoding " - "Trainer Config from this path : {}".format(trainer_config_path) - ) - - def create_environment_factory( env_path: str, docker_target_name: Optional[str], diff --git a/ml-agents/mlagents/trainers/tests/test_curriculum.py b/ml-agents/mlagents/trainers/tests/test_curriculum.py index 7759710334..84f0c2dd88 100644 --- a/ml-agents/mlagents/trainers/tests/test_curriculum.py +++ b/ml-agents/mlagents/trainers/tests/test_curriculum.py @@ -1,7 +1,9 @@ +import io +import json import pytest from unittest.mock import patch, mock_open -from mlagents.trainers.exception import CurriculumError +from mlagents.trainers.exception import CurriculumConfigError, CurriculumLoadingError from mlagents.trainers.curriculum import Curriculum @@ -60,7 +62,7 @@ def test_init_curriculum_happy_path(mock_file, location, default_reset_parameter def test_init_curriculum_bad_curriculum_raises_error( mock_file, location, default_reset_parameters ): - with pytest.raises(CurriculumError): + with pytest.raises(CurriculumConfigError): Curriculum(location, default_reset_parameters) @@ -93,3 +95,30 @@ def test_get_config(mock_file): curriculum.lesson_num = 2 assert curriculum.get_config() == {"param1": 0.3, "param2": 20, "param3": 0.7} assert curriculum.get_config(0) == {"param1": 0.7, "param2": 100, "param3": 0.2} + + +# Test json loading and error handling. These examples don't need to valid config files. + + +def test_curriculum_load_good(): + expected = {"x": 1} + value = json.dumps(expected) + fp = io.StringIO(value) + assert expected == Curriculum._load_curriculum(fp) + + +def test_curriculum_load_missing_file(): + with pytest.raises(CurriculumLoadingError): + Curriculum.load_curriculum_file("notAValidFile.json") + + +def test_curriculum_load_invalid_json(): + # This isn't valid json because of the trailing comma + contents = """ +{ + "x": [1, 2, 3,] +} +""" + fp = io.StringIO(contents) + with pytest.raises(CurriculumLoadingError): + Curriculum._load_curriculum(fp) diff --git a/ml-agents/mlagents/trainers/tests/test_sac.py b/ml-agents/mlagents/trainers/tests/test_sac.py index 1c7af068fa..bdaa1cf10d 100644 --- a/ml-agents/mlagents/trainers/tests/test_sac.py +++ b/ml-agents/mlagents/trainers/tests/test_sac.py @@ -15,7 +15,7 @@ @pytest.fixture def dummy_config(): - return yaml.load( + return yaml.safe_load( """ trainer: sac batch_size: 32 diff --git a/ml-agents/mlagents/trainers/tests/test_trainer_util.py b/ml-agents/mlagents/trainers/tests/test_trainer_util.py index fd8f3d230a..61b3d910d5 100644 --- a/ml-agents/mlagents/trainers/tests/test_trainer_util.py +++ b/ml-agents/mlagents/trainers/tests/test_trainer_util.py @@ -1,9 +1,11 @@ import pytest import yaml import os +import io from unittest.mock import patch import mlagents.trainers.trainer_util as trainer_util +from mlagents.trainers.trainer_util import load_config, _load_config from mlagents.trainers.trainer_metrics import TrainerMetrics from mlagents.trainers.ppo.trainer import PPOTrainer from mlagents.trainers.bc.offline_trainer import OfflineBCTrainer @@ -313,3 +315,30 @@ def test_initialize_invalid_trainer_raises_exception(BrainParametersMock): load_model=load_model, seed=seed, ) + + +def test_load_config_missing_file(): + with pytest.raises(UnityEnvironmentException): + load_config("thisFileDefinitelyDoesNotExist.yaml") + + +def test_load_config_valid_yaml(): + file_contents = """ +this: + - is fine + """ + fp = io.StringIO(file_contents) + res = _load_config(fp) + assert res == {"this": ["is fine"]} + + +def test_load_config_invalid_yaml(): + file_contents = """ +you: + - will +- not + - parse + """ + with pytest.raises(UnityEnvironmentException): + fp = io.StringIO(file_contents) + _load_config(fp) diff --git a/ml-agents/mlagents/trainers/trainer_util.py b/ml-agents/mlagents/trainers/trainer_util.py index 133896848b..dff7c6b416 100644 --- a/ml-agents/mlagents/trainers/trainer_util.py +++ b/ml-agents/mlagents/trainers/trainer_util.py @@ -1,4 +1,5 @@ -from typing import Any, Dict +import yaml +from typing import Any, Dict, TextIO from mlagents.trainers.meta_curriculum import MetaCurriculum from mlagents.envs.exception import UnityEnvironmentException @@ -108,3 +109,31 @@ def initialize_trainers( "brain {}".format(brain_name) ) return trainers + + +def load_config(config_path: str) -> Dict[str, Any]: + try: + with open(config_path) as data_file: + return _load_config(data_file) + except IOError: + raise UnityEnvironmentException( + f"Config file could not be found at {config_path}." + ) + except UnicodeDecodeError: + raise UnityEnvironmentException( + f"There was an error decoding Config file from {config_path}. " + f"Make sure your file is save using UTF-8" + ) + + +def _load_config(fp: TextIO) -> Dict[str, Any]: + """ + Load the yaml config from the file-like object. + """ + try: + return yaml.safe_load(fp) + except yaml.parser.ParserError as e: + raise UnityEnvironmentException( + "Error parsing yaml file. Please check for formatting errors. " + "A tool such as http://www.yamllint.com/ can be helpful with this." + ) from e