Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 29 additions & 11 deletions ml-agents/mlagents/trainers/curriculum.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import math

from .exception import CurriculumError
from .exception import CurriculumConfigError, CurriculumLoadingError

import logging

Expand All @@ -23,14 +23,8 @@ def __init__(self, location, default_reset_parameters):
# The name of the brain should be the basename of the file without the
# extension.
self._brain_name = os.path.basename(location).split(".")[0]
self.data = Curriculum.load_curriculum_file(location)

try:
with open(location) as data_file:
self.data = json.load(data_file)
except IOError:
raise CurriculumError("The file {0} could not be found.".format(location))
except UnicodeDecodeError:
raise CurriculumError("There was an error decoding {}".format(location))
self.smoothing_value = 0
for key in [
"parameters",
Expand All @@ -40,7 +34,7 @@ def __init__(self, location, default_reset_parameters):
"signal_smoothing",
]:
if key not in self.data:
raise CurriculumError(
raise CurriculumConfigError(
"{0} does not contain a " "{1} field.".format(location, key)
)
self.smoothing_value = 0
Expand All @@ -51,12 +45,12 @@ def __init__(self, location, default_reset_parameters):
parameters = self.data["parameters"]
for key in parameters:
if key not in default_reset_parameters:
raise CurriculumError(
raise CurriculumConfigError(
"The parameter {0} in Curriculum {1} is not present in "
"the Environment".format(key, location)
)
if len(parameters[key]) != self.max_lesson_num + 1:
raise CurriculumError(
raise CurriculumConfigError(
"The parameter {0} in Curriculum {1} must have {2} values "
"but {3} were found".format(
key, location, self.max_lesson_num + 1, len(parameters[key])
Expand Down Expand Up @@ -117,3 +111,27 @@ def get_config(self, lesson=None):
for key in parameters:
config[key] = parameters[key][lesson]
return config

@staticmethod
def load_curriculum_file(location):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kept these in the same file, but I can move them to trainer_util.py if that's cleaner. It's currently the only place we load json.

try:
with open(location) as data_file:
return Curriculum._load_curriculum(data_file)
except IOError:
raise CurriculumLoadingError(
"The file {0} could not be found.".format(location)
)
except UnicodeDecodeError:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably be handled in _load_curriculum instead, but I couldn't reproduce it to confirm :/

raise CurriculumLoadingError(
"There was an error decoding {}".format(location)
)

@staticmethod
def _load_curriculum(fp):
try:
return json.load(fp)
except json.decoder.JSONDecodeError as e:
raise CurriculumLoadingError(
"Error parsing JSON file. Please check for formatting errors. "
"A tool such as https://jsonlint.com/ can be helpful with this."
) from e
16 changes: 16 additions & 0 deletions ml-agents/mlagents/trainers/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,22 @@ class CurriculumError(TrainerError):
pass


class CurriculumLoadingError(CurriculumError):
"""
Any error related to loading the Curriculum config file.
"""

pass


class CurriculumConfigError(CurriculumError):
"""
Any error related to processing the Curriculum config file.
"""

pass


class MetaCurriculumError(TrainerError):
"""
Any error related to the configuration of a metacurriculum.
Expand Down
24 changes: 4 additions & 20 deletions ml-agents/mlagents/trainers/learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@
import glob
import shutil
import numpy as np
import yaml
from typing import Any, Callable, Dict, Optional, List, NamedTuple

from typing import Any, Callable, Optional, List, NamedTuple


from mlagents.trainers.trainer_controller import TrainerController
from mlagents.trainers.exception import TrainerError
from mlagents.trainers.meta_curriculum import MetaCurriculumError, MetaCurriculum
from mlagents.trainers.trainer_util import initialize_trainers
from mlagents.trainers.trainer_util import initialize_trainers, load_config
from mlagents.envs.environment import UnityEnvironment
from mlagents.envs.sampler_class import SamplerManager
from mlagents.envs.exception import UnityEnvironmentException, SamplerException
from mlagents.envs.exception import SamplerException
from mlagents.envs.base_unity_environment import BaseUnityEnvironment
from mlagents.envs.subprocess_env_manager import SubprocessEnvManager

Expand Down Expand Up @@ -323,22 +323,6 @@ def prepare_for_docker_run(docker_target_name, env_path):
return env_path


def load_config(trainer_config_path: str) -> Dict[str, Any]:
try:
with open(trainer_config_path) as data_file:
trainer_config = yaml.safe_load(data_file)
return trainer_config
except IOError:
raise UnityEnvironmentException(
"Parameter file could not be found " "at {}.".format(trainer_config_path)
)
except UnicodeDecodeError:
raise UnityEnvironmentException(
"There was an error decoding "
"Trainer Config from this path : {}".format(trainer_config_path)
)


def create_environment_factory(
env_path: str,
docker_target_name: Optional[str],
Expand Down
33 changes: 31 additions & 2 deletions ml-agents/mlagents/trainers/tests/test_curriculum.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import io
import json
import pytest
from unittest.mock import patch, mock_open

from mlagents.trainers.exception import CurriculumError
from mlagents.trainers.exception import CurriculumConfigError, CurriculumLoadingError
from mlagents.trainers.curriculum import Curriculum


Expand Down Expand Up @@ -60,7 +62,7 @@ def test_init_curriculum_happy_path(mock_file, location, default_reset_parameter
def test_init_curriculum_bad_curriculum_raises_error(
mock_file, location, default_reset_parameters
):
with pytest.raises(CurriculumError):
with pytest.raises(CurriculumConfigError):
Curriculum(location, default_reset_parameters)


Expand Down Expand Up @@ -93,3 +95,30 @@ def test_get_config(mock_file):
curriculum.lesson_num = 2
assert curriculum.get_config() == {"param1": 0.3, "param2": 20, "param3": 0.7}
assert curriculum.get_config(0) == {"param1": 0.7, "param2": 100, "param3": 0.2}


# Test json loading and error handling. These examples don't need to valid config files.


def test_curriculum_load_good():
expected = {"x": 1}
value = json.dumps(expected)
fp = io.StringIO(value)
assert expected == Curriculum._load_curriculum(fp)


def test_curriculum_load_missing_file():
with pytest.raises(CurriculumLoadingError):
Curriculum.load_curriculum_file("notAValidFile.json")


def test_curriculum_load_invalid_json():
# This isn't valid json because of the trailing comma
contents = """
{
"x": [1, 2, 3,]
}
"""
fp = io.StringIO(contents)
with pytest.raises(CurriculumLoadingError):
Curriculum._load_curriculum(fp)
2 changes: 1 addition & 1 deletion ml-agents/mlagents/trainers/tests/test_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@pytest.fixture
def dummy_config():
return yaml.load(
return yaml.safe_load(
"""
trainer: sac
batch_size: 32
Expand Down
29 changes: 29 additions & 0 deletions ml-agents/mlagents/trainers/tests/test_trainer_util.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import pytest
import yaml
import os
import io
from unittest.mock import patch

import mlagents.trainers.trainer_util as trainer_util
from mlagents.trainers.trainer_util import load_config, _load_config
from mlagents.trainers.trainer_metrics import TrainerMetrics
from mlagents.trainers.ppo.trainer import PPOTrainer
from mlagents.trainers.bc.offline_trainer import OfflineBCTrainer
Expand Down Expand Up @@ -313,3 +315,30 @@ def test_initialize_invalid_trainer_raises_exception(BrainParametersMock):
load_model=load_model,
seed=seed,
)


def test_load_config_missing_file():
with pytest.raises(UnityEnvironmentException):
load_config("thisFileDefinitelyDoesNotExist.yaml")


def test_load_config_valid_yaml():
file_contents = """
this:
- is fine
"""
fp = io.StringIO(file_contents)
res = _load_config(fp)
assert res == {"this": ["is fine"]}


def test_load_config_invalid_yaml():
file_contents = """
you:
- will
- not
- parse
"""
with pytest.raises(UnityEnvironmentException):
fp = io.StringIO(file_contents)
_load_config(fp)
31 changes: 30 additions & 1 deletion ml-agents/mlagents/trainers/trainer_util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Dict
import yaml
from typing import Any, Dict, TextIO

from mlagents.trainers.meta_curriculum import MetaCurriculum
from mlagents.envs.exception import UnityEnvironmentException
Expand Down Expand Up @@ -108,3 +109,31 @@ def initialize_trainers(
"brain {}".format(brain_name)
)
return trainers


def load_config(config_path: str) -> Dict[str, Any]:
try:
with open(config_path) as data_file:
return _load_config(data_file)
except IOError:
raise UnityEnvironmentException(
f"Config file could not be found at {config_path}."
)
except UnicodeDecodeError:
raise UnityEnvironmentException(
f"There was an error decoding Config file from {config_path}. "
f"Make sure your file is save using UTF-8"
)


def _load_config(fp: TextIO) -> Dict[str, Any]:
"""
Load the yaml config from the file-like object.
"""
try:
return yaml.safe_load(fp)
except yaml.parser.ParserError as e:
raise UnityEnvironmentException(
"Error parsing yaml file. Please check for formatting errors. "
"A tool such as http://www.yamllint.com/ can be helpful with this."
) from e