Skip to content

Commit

Permalink
Update tests to support pytest 5.x
Browse files Browse the repository at this point in the history
Our tests were using pytest fixtures by actually calling the fixture
methods, but in newer 5.x versions of pytest this causes test failures.
The recommended method for using fixtures is dependency injection.

This change updates the relevant test fixtures to either not use
`pytest.fixture` or to use dependency injection to pass the fixture.
The version range requirements in `test_requirements.txt` were also
updated accordingly.
  • Loading branch information
Jonathan Harper committed Dec 17, 2019
1 parent ebad90f commit 4b660ad
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 65 deletions.
1 change: 0 additions & 1 deletion ml-agents/mlagents/trainers/tests/test_learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from mlagents.trainers.learn import parse_command_line


@pytest.fixture
def basic_options(extra_args=None):
extra_args = extra_args or {}
args = ["basic_path"]
Expand Down
1 change: 0 additions & 1 deletion ml-agents/mlagents/trainers/tests/test_rl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from mlagents.trainers.buffer import AgentBuffer


@pytest.fixture
def dummy_config():
return yaml.safe_load(
"""
Expand Down
4 changes: 2 additions & 2 deletions ml-agents/mlagents/trainers/tests/test_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def test_sac_model_cc_vector_rnn():
sess.run(run_list, feed_dict=feed_dict)


def test_sac_save_load_buffer(tmpdir):
def test_sac_save_load_buffer(tmpdir, dummy_config):
env, mock_brain, _ = mb.setup_mock_env_and_brains(
mock.Mock(),
False,
Expand All @@ -335,7 +335,7 @@ def test_sac_save_load_buffer(tmpdir):
vector_obs_space=VECTOR_OBS_SPACE,
discrete_action_space=DISCRETE_ACTION_SPACE,
)
trainer_params = dummy_config()
trainer_params = dummy_config
trainer_params["summary_path"] = str(tmpdir)
trainer_params["model_path"] = str(tmpdir)
trainer_params["save_replay_buffer"] = True
Expand Down
67 changes: 21 additions & 46 deletions ml-agents/mlagents/trainers/tests/test_trainer_controller.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,12 @@
from unittest.mock import MagicMock, Mock, patch

from mlagents.tf_utils import tf

import yaml
import pytest

from mlagents.tf_utils import tf
from mlagents.trainers.trainer_controller import TrainerController
from mlagents.trainers.subprocess_env_manager import EnvironmentStep
from mlagents.trainers.sampler_class import SamplerManager


@pytest.fixture
def dummy_config():
return yaml.safe_load(
"""
default:
trainer: ppo
batch_size: 32
beta: 5.0e-3
buffer_size: 512
epsilon: 0.2
gamma: 0.99
hidden_units: 128
lambd: 0.95
learning_rate: 3.0e-4
max_steps: 5.0e4
normalize: true
num_epoch: 5
num_layers: 2
time_horizon: 64
sequence_length: 64
summary_freq: 1000
use_recurrent: false
memory_size: 8
use_curiosity: false
curiosity_strength: 0.0
curiosity_enc_size: 1
"""
)


@pytest.fixture
def basic_trainer_controller():
return TrainerController(
Expand Down Expand Up @@ -76,14 +43,15 @@ def test_initialization_seed(numpy_random_seed, tensorflow_set_seed):
tensorflow_set_seed.assert_called_with(seed)


def trainer_controller_with_start_learning_mocks():
@pytest.fixture
def trainer_controller_with_start_learning_mocks(basic_trainer_controller):
trainer_mock = MagicMock()
trainer_mock.get_step = 0
trainer_mock.get_max_steps = 5
trainer_mock.parameters = {"some": "parameter"}
trainer_mock.write_tensorboard_text = MagicMock()

tc = basic_trainer_controller()
tc = basic_trainer_controller
tc.initialize_trainers = MagicMock()
tc.trainers = {"testbrain": trainer_mock}
tc.advance = MagicMock()
Expand All @@ -103,8 +71,10 @@ def take_step_sideeffect(env):


@patch.object(tf, "reset_default_graph")
def test_start_learning_trains_forever_if_no_train_model(tf_reset_graph):
tc, trainer_mock = trainer_controller_with_start_learning_mocks()
def test_start_learning_trains_forever_if_no_train_model(
tf_reset_graph, trainer_controller_with_start_learning_mocks
):
tc, trainer_mock = trainer_controller_with_start_learning_mocks
tc.train_model = False

tf_reset_graph.return_value = None
Expand All @@ -123,8 +93,10 @@ def test_start_learning_trains_forever_if_no_train_model(tf_reset_graph):


@patch.object(tf, "reset_default_graph")
def test_start_learning_trains_until_max_steps_then_saves(tf_reset_graph):
tc, trainer_mock = trainer_controller_with_start_learning_mocks()
def test_start_learning_trains_until_max_steps_then_saves(
tf_reset_graph, trainer_controller_with_start_learning_mocks
):
tc, trainer_mock = trainer_controller_with_start_learning_mocks
tf_reset_graph.return_value = None

brain_info_mock = MagicMock()
Expand All @@ -140,21 +112,24 @@ def test_start_learning_trains_until_max_steps_then_saves(tf_reset_graph):
tc._save_model.assert_called_once()


def trainer_controller_with_take_step_mocks():
@pytest.fixture
def trainer_controller_with_take_step_mocks(basic_trainer_controller):
trainer_mock = MagicMock()
trainer_mock.get_step = 0
trainer_mock.get_max_steps = 5
trainer_mock.parameters = {"some": "parameter"}
trainer_mock.write_tensorboard_text = MagicMock()

tc = basic_trainer_controller()
tc = basic_trainer_controller
tc.trainers = {"testbrain": trainer_mock}

return tc, trainer_mock


def test_take_step_adds_experiences_to_trainer_and_trains():
tc, trainer_mock = trainer_controller_with_take_step_mocks()
def test_take_step_adds_experiences_to_trainer_and_trains(
trainer_controller_with_take_step_mocks
):
tc, trainer_mock = trainer_controller_with_take_step_mocks

brain_name = "testbrain"
action_info_dict = {brain_name: MagicMock()}
Expand Down Expand Up @@ -184,8 +159,8 @@ def test_take_step_adds_experiences_to_trainer_and_trains():
trainer_mock.increment_step.assert_called_once()


def test_take_step_if_not_training():
tc, trainer_mock = trainer_controller_with_take_step_mocks()
def test_take_step_if_not_training(trainer_controller_with_take_step_mocks):
tc, trainer_mock = trainer_controller_with_take_step_mocks
tc.train_model = False

brain_name = "testbrain"
Expand Down
30 changes: 16 additions & 14 deletions ml-agents/mlagents/trainers/tests/test_trainer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def dummy_config():


@pytest.fixture
def dummy_config_with_override():
base = dummy_config()
def dummy_config_with_override(dummy_config):
base = dummy_config
base["testbrain"] = {}
base["testbrain"]["normalize"] = False
return base
Expand Down Expand Up @@ -83,7 +83,9 @@ def dummy_bad_config():


@patch("mlagents.trainers.brain.BrainParameters")
def test_initialize_trainer_parameters_override_defaults(BrainParametersMock):
def test_initialize_trainer_parameters_override_defaults(
BrainParametersMock, dummy_config_with_override
):
summaries_dir = "test_dir"
run_id = "testrun"
model_path = "model_dir"
Expand All @@ -93,7 +95,7 @@ def test_initialize_trainer_parameters_override_defaults(BrainParametersMock):
seed = 11
expected_reward_buff_cap = 1

base_config = dummy_config_with_override()
base_config = dummy_config_with_override
expected_config = base_config["default"]
expected_config["summary_path"] = summaries_dir + f"/{run_id}_testbrain"
expected_config["model_path"] = model_path + "/testbrain"
Expand Down Expand Up @@ -146,7 +148,7 @@ def mock_constructor(


@patch("mlagents.trainers.brain.BrainParameters")
def test_initialize_ppo_trainer(BrainParametersMock):
def test_initialize_ppo_trainer(BrainParametersMock, dummy_config):
brain_params_mock = BrainParametersMock()
BrainParametersMock.return_value.brain_name = "testbrain"
external_brains = {"testbrain": BrainParametersMock()}
Expand All @@ -159,7 +161,7 @@ def test_initialize_ppo_trainer(BrainParametersMock):
seed = 11
expected_reward_buff_cap = 1

base_config = dummy_config()
base_config = dummy_config
expected_config = base_config["default"]
expected_config["summary_path"] = summaries_dir + f"/{run_id}_testbrain"
expected_config["model_path"] = model_path + "/testbrain"
Expand Down Expand Up @@ -205,15 +207,17 @@ def mock_constructor(


@patch("mlagents.trainers.brain.BrainParameters")
def test_initialize_invalid_trainer_raises_exception(BrainParametersMock):
def test_initialize_invalid_trainer_raises_exception(
BrainParametersMock, dummy_bad_config
):
summaries_dir = "test_dir"
run_id = "testrun"
model_path = "model_dir"
keep_checkpoints = 1
train_model = True
load_model = False
seed = 11
bad_config = dummy_bad_config()
bad_config = dummy_bad_config
BrainParametersMock.return_value.brain_name = "testbrain"
external_brains = {"testbrain": BrainParametersMock()}

Expand All @@ -233,13 +237,12 @@ def test_initialize_invalid_trainer_raises_exception(BrainParametersMock):
trainers[brain_name] = trainer_factory.generate(brain_parameters)


def test_handles_no_default_section():
def test_handles_no_default_section(dummy_config):
"""
Make sure the trainer setup handles a missing "default" in the config.
"""
brain_name = "testbrain"
config = dummy_config()
no_default_config = {brain_name: config["default"]}
no_default_config = {brain_name: dummy_config["default"]}
brain_parameters = BrainParameters(
brain_name=brain_name,
vector_observation_space_size=1,
Expand All @@ -262,14 +265,13 @@ def test_handles_no_default_section():
trainer_factory.generate(brain_parameters)


def test_raise_if_no_config_for_brain():
def test_raise_if_no_config_for_brain(dummy_config):
"""
Make sure the trainer setup raises a friendlier exception if both "default" and the brain name
are missing from the config.
"""
brain_name = "testbrain"
config = dummy_config()
bad_config = {"some_other_brain": config["default"]}
bad_config = {"some_other_brain": dummy_config["default"]}
brain_parameters = BrainParameters(
brain_name=brain_name,
vector_observation_space_size=1,
Expand Down
2 changes: 1 addition & 1 deletion test_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Test-only dependencies should go here, not in setup.py
pytest>=3.2.2,<4.0.0
pytest>=3.2.2,<6.0.0
pytest-cov==2.6.1

0 comments on commit 4b660ad

Please sign in to comment.