Merge pull request #54 from hill-a/ppo1-mpi-test
Add test for PPO1 to avoid regression
hill-a committed Oct 12, 2018
2 parents 7ab842c + 971f69b commit 5f3c4b6
Showing 5 changed files with 26 additions and 3 deletions.
6 changes: 6 additions & 0 deletions docs/misc/changelog.rst
@@ -5,6 +5,12 @@ Changelog

For download links, please look at `Github release page <https://github.com/hill-a/stable-baselines/releases>`_.

Pre Release 2.1.1.a0 (WIP)
--------------------------

- fixed MpiAdam synchronization issue in PPO1 (thanks to @brendenpetersen) issue #50


Release 2.1.0 (2018-10-2)
-------------------------

2 changes: 1 addition & 1 deletion stable_baselines/__init__.py
@@ -11,7 +11,7 @@
from stable_baselines.ppo2 import PPO2
from stable_baselines.trpo_mpi import TRPO

-__version__ = "2.1.0"
+__version__ = "2.1.1.a0"


# patch Gym spaces to add equality functions, if not implemented
9 changes: 9 additions & 0 deletions stable_baselines/ppo1/experiments/train_cartpole.py
@@ -0,0 +1,9 @@
"""
Simple test to check that PPO1 is running with no errors (see issue #50)
"""
from stable_baselines import PPO1


if __name__ == '__main__':
    model = PPO1('MlpPolicy', 'CartPole-v1', schedule='linear', verbose=0)
    model.learn(total_timesteps=10000)
4 changes: 2 additions & 2 deletions tests/test_continuous.py
@@ -78,8 +78,8 @@ def test_model_manipulation(model_class):
obs, reward, _, _ = env.step(action)
loaded_acc_reward += reward
loaded_acc_reward = sum(loaded_acc_reward) / N_TRIALS
-# assert <10% diff
-assert abs(acc_reward - loaded_acc_reward) / max(acc_reward, loaded_acc_reward) < 0.1, \
+# assert <15% diff
+assert abs(acc_reward - loaded_acc_reward) / max(acc_reward, loaded_acc_reward) < 0.15, \
"Error: the prediction seems to have changed between loading and saving"

# learn post loading
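
The effect of loosening the tolerance from 10% to 15% can be illustrated with a small standalone sketch. This is not code from the repository: `relative_diff` is a hypothetical helper that mirrors the expression in the assertion, and the reward values are made up for illustration.

```python
def relative_diff(a, b):
    """Relative difference |a - b| / max(a, b), mirroring the assertion above.

    Assumption: max(a, b) is positive. The expression is not meaningful
    (or divides by zero) when both rewards are zero or negative.
    """
    return abs(a - b) / max(a, b)


# Hypothetical accumulated rewards before saving and after loading.
reward_before, reward_after = 100.0, 88.0

diff = relative_diff(reward_before, reward_after)
passes_old_threshold = diff < 0.10   # the previous, stricter check
passes_new_threshold = diff < 0.15   # the relaxed check from this commit
```

With these illustrative numbers the difference is 12%, so the run would fail under the old threshold and pass under the new one.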
8 changes: 8 additions & 0 deletions tests/test_mpi_adam.py
@@ -8,3 +8,11 @@ def test_mpi_adam():
    return_code = subprocess.call(['mpirun', '--allow-run-as-root', '-np', '2',
                                   'python', '-m', 'stable_baselines.common.mpi_adam'])
    _assert_eq(return_code, 0)


def test_mpi_adam_ppo1():
    """Running test for ppo1"""
    return_code = subprocess.call(['mpirun', '--allow-run-as-root', '-np', '2',
                                   'python', '-m',
                                   'stable_baselines.ppo1.experiments.train_cartpole'])
    _assert_eq(return_code, 0)
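
Both tests rely on the same pattern: launch a module in a child process and assert that it exits with return code 0. The sketch below shows that pattern without MPI. It is a minimal illustration under stated assumptions, not the repository's code: `run_and_check` is a hypothetical helper, and it launches the current Python interpreter directly instead of going through `mpirun`.

```python
import subprocess
import sys


def run_and_check(args):
    """Run a command and return its exit code (hypothetical helper).

    stderr is discarded so that the deliberately failing child below
    does not print a traceback.
    """
    return subprocess.call(args, stderr=subprocess.DEVNULL)


# A child process that finishes normally exits with return code 0.
ok_code = run_and_check([sys.executable, "-c", "pass"])

# An uncaught exception in the child produces a non-zero return code,
# which is what makes this kind of smoke test catch crashes.
fail_code = run_and_check([sys.executable, "-c", "raise RuntimeError('boom')"])
```

A check of this kind only verifies that the script runs to completion; it says nothing about the quality of the trained policy.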
