Skip to content

Commit

Permalink
Add more extensive tests for BC trainer (#2506)
Browse files Browse the repository at this point in the history
* Add more extensive tests for BC trainer
* Break up tests for BC trainer
  • Loading branch information
Ervin T committed Sep 9, 2019
1 parent d2ceb9f commit cc6c22f
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 9 deletions.
11 changes: 6 additions & 5 deletions ml-agents/mlagents/trainers/tests/mock_brain.py
Expand Up @@ -91,12 +91,13 @@ def setup_mock_unityenvironment(mock_env, mock_brain, mock_braininfo):
:Mock mock_brain: A mock Brain object that specifies the params of this environment.
:Mock mock_braininfo: A mock BrainInfo object that will be returned at each step and reset.
"""
brain_name = mock_brain.brain_name
mock_env.return_value.academy_name = "MockAcademy"
mock_env.return_value.brains = {"MockBrain": mock_brain}
mock_env.return_value.external_brain_names = ["MockBrain"]
mock_env.return_value.brain_names = ["MockBrain"]
mock_env.return_value.reset.return_value = {"MockBrain": mock_braininfo}
mock_env.return_value.step.return_value = {"MockBrain": mock_braininfo}
mock_env.return_value.brains = {brain_name: mock_brain}
mock_env.return_value.external_brain_names = [brain_name]
mock_env.return_value.brain_names = [brain_name]
mock_env.return_value.reset.return_value = {brain_name: mock_braininfo}
mock_env.return_value.step.return_value = {brain_name: mock_braininfo}


def simulate_rollout(env, policy, buffer_init_samples):
Expand Down
50 changes: 46 additions & 4 deletions ml-agents/mlagents/trainers/tests/test_bc.py
Expand Up @@ -18,9 +18,9 @@
def dummy_config():
return yaml.safe_load(
"""
hidden_units: 128
hidden_units: 32
learning_rate: 3.0e-4
num_layers: 2
num_layers: 1
use_recurrent: false
sequence_length: 32
memory_size: 32
Expand All @@ -32,8 +32,8 @@ def dummy_config():
)


@mock.patch("mlagents.envs.UnityEnvironment")
def test_bc_trainer(mock_env, dummy_config):
def create_bc_trainer(dummy_config):
mock_env = mock.Mock()
mock_brain = mb.create_mock_3dball_brain()
mock_braininfo = mb.create_mock_braininfo(num_agents=12, num_vector_observations=8)
mb.setup_mock_unityenvironment(mock_env, mock_brain, mock_braininfo)
Expand All @@ -49,12 +49,54 @@ def test_bc_trainer(mock_env, dummy_config):
mock_brain, trainer_parameters, training=True, load=False, seed=0, run_id=0
)
trainer.demonstration_buffer = mb.simulate_rollout(env, trainer.policy, 100)
return trainer, env


def test_bc_trainer_step(dummy_config):
    """Check the step counter and one policy update on a BC trainer."""
    bc_trainer, _ = create_bc_trainer(dummy_config)

    # The trainer returned by create_bc_trainer is expected to start at step 0.
    assert bc_trainer.get_step == 0

    # A policy update must record at least one cloning-loss entry in stats.
    bc_trainer.update_policy()
    cloning_losses = bc_trainer.stats["Losses/Cloning Loss"]
    assert len(cloning_losses) > 0

    # Incrementing by one should leave the step counter at 1.
    bc_trainer.increment_step(1)
    assert bc_trainer.step == 1


def test_bc_trainer_add_proc_experiences(dummy_config):
    """Exercise add_experiences, then process_experiences with all agents done."""
    bc_trainer, mock_env = create_bc_trainer(dummy_config)
    step_result = mock_env.step()

    # The same step result is passed as both the current and the next info;
    # the take-action outputs argument is not used, so an empty dict suffices.
    bc_trainer.add_experiences(step_result, step_result, {})
    ball_info = step_result["Ball3DBrain"]
    for agent in ball_info.agents:
        assert bc_trainer.evaluation_buffer[agent].last_brain_info is not None
        assert bc_trainer.episode_steps[agent] > 0
        assert bc_trainer.cumulative_rewards[agent] > 0

    # Flag every one of the 12 agents as done, then process the experiences;
    # the per-agent step counts and cumulative rewards should be back at 0.
    ball_info.local_done = [True] * 12
    bc_trainer.process_experiences(step_result, step_result)
    for agent in ball_info.agents:
        assert bc_trainer.episode_steps[agent] == 0
        assert bc_trainer.cumulative_rewards[agent] == 0


def test_bc_trainer_end_episode(dummy_config):
    """Check that end_episode zeroes per-agent step counts and rewards."""
    bc_trainer, mock_env = create_bc_trainer(dummy_config)
    step_result = mock_env.step()

    # Feed one step through the trainer; the take-action outputs argument
    # is not used, so an empty dict is passed.
    bc_trainer.add_experiences(step_result, step_result, {})
    bc_trainer.process_experiences(step_result, step_result)

    # Ending the episode should set every per-agent counter to 0.
    bc_trainer.end_episode()
    for agent in step_result["Ball3DBrain"].agents:
        assert bc_trainer.episode_steps[agent] == 0
        assert bc_trainer.cumulative_rewards[agent] == 0


@mock.patch("mlagents.envs.UnityEnvironment.executable_launcher")
@mock.patch("mlagents.envs.UnityEnvironment.get_communicator")
def test_bc_policy_evaluate(mock_communicator, mock_launcher, dummy_config):
Expand Down

0 comments on commit cc6c22f

Please sign in to comment.