Make MPI dependency optional (#433)
* Make mpi optional

* Remove PR note

* Update docs

* Update docs and changelog

* Fix testing name error

* Add mpi4py-not-installed test

* Add import stable_baselines mpi test

* Remove redundant pass statement

* Use FQN `sys.modules` for clarity

* Move import into disabled context for simplicity

* Linting
shwang authored and AdamGleave committed Aug 5, 2019
1 parent 8ceda3b commit 9a76054
Showing 14 changed files with 147 additions and 27 deletions.
16 changes: 16 additions & 0 deletions docs/guide/install.rst
@@ -55,6 +55,22 @@ Stable Release
pip install stable-baselines
.. _openmpi:

Stable Release with OpenMPI
~~~~~~~~~~~~~~~~~~~~~~~~~~~
GAIL, DDPG, TRPO, and PPO1 parallelize training using OpenMPI. OpenMPI has had
problematic interactions with TensorFlow in the past (see
`Issue #430 <https://github.com/hill-a/stable-baselines/issues/430>`_) and is
therefore disabled by default.

.. code-block:: bash

    pip install stable-baselines[mpi]

To disable OpenMPI, uninstall ``mpi4py`` with ``pip uninstall mpi4py``.
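
As a quick sanity check (a sketch relying only on the import guard this commit
adds; nothing here is a dedicated stable-baselines API), you can verify whether
the MPI-dependent algorithms were picked up:

.. code-block:: python

    # These names exist in the package only when mpi4py is installed.
    try:
        from stable_baselines import DDPG, GAIL, PPO1, TRPO
        print("OpenMPI support enabled")
    except ImportError:
        print("mpi4py missing; DDPG, GAIL, PPO1 and TRPO are unavailable")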


Bleeding-edge version
---------------------

31 changes: 31 additions & 0 deletions docs/misc/changelog.rst
@@ -5,6 +5,37 @@ Changelog

For download links, please look at `Github release page <https://github.com/hill-a/stable-baselines/releases>`_.

Pre-Release 2.7.1a0 (WIP)
-------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^
- OpenMPI-dependent algorithms (PPO1, TRPO, GAIL, DDPG) are disabled in the
  default installation of stable_baselines. ``mpi4py`` is now installed as an
  extra. When ``mpi4py`` is not available, stable-baselines skips imports of
  OpenMPI-dependent algorithms.
  See :ref:`installation notes <openmpi>` and
  `Issue #430 <https://github.com/hill-a/stable-baselines/issues/430>`_.
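
  For downstream code that must branch on MPI availability without importing
  it, one option is a standard-library check (a sketch; ``find_spec`` is
  plain Python, not a stable-baselines helper):

  .. code-block:: python

      import importlib.util

      # True when the ``mpi`` extra (and hence mpi4py) is installed.
      MPI_AVAILABLE = importlib.util.find_spec("mpi4py") is not None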

New Features:
^^^^^^^^^^^^^

Bug Fixes:
^^^^^^^^^^
- Skip automatic imports of OpenMPI-dependent algorithms to avoid an issue
  where OpenMPI would cause stable-baselines to hang on Ubuntu installs.
  See :ref:`installation notes <openmpi>` and
  `Issue #430 <https://github.com/hill-a/stable-baselines/issues/430>`_.

Deprecations:
^^^^^^^^^^^^^

Others:
^^^^^^^

Documentation:
^^^^^^^^^^^^^^


Release 2.7.0 (2019-07-31)
--------------------------
4 changes: 4 additions & 0 deletions docs/modules/ddpg.rst
@@ -7,6 +7,10 @@ DDPG
====
`Deep Deterministic Policy Gradient (DDPG) <https://arxiv.org/abs/1509.02971>`_

.. note::

    DDPG requires :ref:`OpenMPI <openmpi>`. If OpenMPI isn't enabled, then DDPG isn't
    imported into the ``stable_baselines`` module.

.. warning::

5 changes: 5 additions & 0 deletions docs/modules/gail.rst
@@ -13,6 +13,11 @@ Learning a cost function from expert demonstrations is called Inverse Reinforcem
The connection between GAIL and Generative Adversarial Networks (GANs) is that it uses a discriminator that tries
to separate expert trajectories from trajectories of the learned policy, which plays the role of the generator here.

.. note::

    GAIL requires :ref:`OpenMPI <openmpi>`. If OpenMPI isn't enabled, then GAIL isn't
    imported into the ``stable_baselines`` module.


Notes
-----
11 changes: 8 additions & 3 deletions docs/modules/ppo1.rst
@@ -9,13 +9,18 @@ PPO1
The `Proximal Policy Optimization <https://arxiv.org/abs/1707.06347>`_ algorithm combines ideas from A2C (having multiple workers)
and TRPO (it uses a trust region to improve the actor).

The main idea is that after an update, the new policy should be not too far form the `old` policy.
The main idea is that after an update, the new policy should be not too far from the `old` policy.
For that, PPO uses clipping to avoid too large updates.

.. note::

    PPO2 is the implementation of OpenAI made for GPU. For multiprocessing, it uses vectorized environments
    compared to PPO1 which uses MPI.
    PPO1 requires :ref:`OpenMPI <openmpi>`. If OpenMPI isn't enabled, then PPO1 isn't
    imported into the ``stable_baselines`` module.

.. note::

    PPO1 uses MPI for multiprocessing, unlike PPO2, which uses vectorized environments.
    PPO2 is the implementation OpenAI made for GPU.
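
To make the contrast concrete, here is a sketch of the vectorized style PPO2
uses (the environment choice, worker count and step budget are illustrative
placeholders); PPO1 would instead be launched under ``mpirun``:

.. code-block:: python

    import gym

    from stable_baselines import PPO2
    from stable_baselines.common.vec_env import SubprocVecEnv

    # Four copies of the environment stepped in parallel worker processes
    env = SubprocVecEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
    model = PPO2("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=10000)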

Notes
-----
5 changes: 5 additions & 0 deletions docs/modules/trpo.rst
@@ -9,6 +9,11 @@ TRPO
`Trust Region Policy Optimization (TRPO) <https://arxiv.org/abs/1502.05477>`_
is an iterative approach for optimizing policies with guaranteed monotonic improvement.

.. note::

    TRPO requires :ref:`OpenMPI <openmpi>`. If OpenMPI isn't enabled, then TRPO isn't
    imported into the ``stable_baselines`` module.

Notes
-----

4 changes: 3 additions & 1 deletion setup.py
@@ -109,14 +109,16 @@
        'gym[atari,classic_control]>=0.10.9',
        'scipy',
        'joblib',
        'mpi4py',
        'cloudpickle>=0.5.5',
        'opencv-python',
        'numpy',
        'pandas',
        'matplotlib'
    ] + tf_dependency,
    extras_require={
        'mpi': [
            'mpi4py',
        ],
        'tests': [
            'pytest',
            'pytest-cov',
18 changes: 14 additions & 4 deletions stable_baselines/__init__.py
@@ -1,14 +1,24 @@
from stable_baselines.a2c import A2C
from stable_baselines.acer import ACER
from stable_baselines.acktr import ACKTR
from stable_baselines.ddpg import DDPG
from stable_baselines.deepq import DQN
from stable_baselines.her import HER
from stable_baselines.gail import GAIL
from stable_baselines.ppo1 import PPO1
from stable_baselines.ppo2 import PPO2
from stable_baselines.td3 import TD3
from stable_baselines.trpo_mpi import TRPO
from stable_baselines.sac import SAC


# Load mpi4py-dependent algorithms only if mpi4py is installed.
try:
    import mpi4py
except ImportError:
    mpi4py = None

if mpi4py is not None:
    from stable_baselines.ddpg import DDPG
    from stable_baselines.gail import GAIL
    from stable_baselines.ppo1 import PPO1
    from stable_baselines.trpo_mpi import TRPO
del mpi4py  # only needed for the availability check above

__version__ = "2.7.0"
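
A consequence of the guard above, as a hedged usage sketch (the fallback logic
is illustrative, not part of the library):

# Downstream feature detection: the MPI names are simply absent
# from the package namespace when mpi4py is not installed.
try:
    from stable_baselines import TRPO
except ImportError:
    TRPO = None  # mpi4py absent, so the name was never exported
if TRPO is None:
    print("falling back to a non-MPI algorithm such as PPO2")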
5 changes: 2 additions & 3 deletions stable_baselines/common/cmd_util.py
@@ -4,14 +4,14 @@

import os

from mpi4py import MPI
import gym
from gym.wrappers import FlattenDictWrapper

from stable_baselines import logger
from stable_baselines.bench import Monitor
from stable_baselines.common import set_global_seeds
from stable_baselines.common.atari_wrappers import make_atari, wrap_deepmind
from stable_baselines.common.misc_util import mpi_rank_or_zero
from stable_baselines.common.vec_env import DummyVecEnv, SubprocVecEnv


@@ -60,8 +60,7 @@ def make_mujoco_env(env_id, seed, allow_early_resets=True):
    :param allow_early_resets: (bool) allows early reset of the environment
    :return: (Gym Environment) The mujoco environment
    """
    rank = MPI.COMM_WORLD.Get_rank()
    set_global_seeds(seed + 10000 * rank)
    rank = mpi_rank_or_zero()  # keep ``rank`` bound for the Monitor log path below
    set_global_seeds(seed + 10000 * rank)
    env = gym.make(env_id)
    env = Monitor(env, os.path.join(logger.get_dir(), str(rank)), allow_early_resets=allow_early_resets)
    env.seed(seed)
10 changes: 10 additions & 0 deletions stable_baselines/common/misc_util.py
@@ -244,3 +244,13 @@ def pickle_load(path, compression=False):
    else:
        with open(path, "rb") as file_handler:
            return pickle.load(file_handler)


def mpi_rank_or_zero():
    """
    Return the MPI rank if mpi4py is installed. Otherwise, return 0.
    """
    try:
        from mpi4py import MPI
        return MPI.COMM_WORLD.Get_rank()
    except ImportError:
        return 0
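
A hedged usage sketch for the helper above (the seed offset mirrors
make_mujoco_env; the variable names are illustrative only):

from stable_baselines.common.misc_util import mpi_rank_or_zero

rank = mpi_rank_or_zero()  # 0 whenever mpi4py is not installed
seed = 42 + 10000 * rank   # distinct seed per MPI worker, as in make_mujoco_env
log_suffix = "-rank%03i" % rank if rank > 0 else ""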
10 changes: 4 additions & 6 deletions stable_baselines/logger.py
@@ -13,6 +13,8 @@
from tensorflow.core.util import event_pb2
from tensorflow.python.util import compat

from stable_baselines.common.misc_util import mpi_rank_or_zero

DEBUG = 10
INFO = 20
WARN = 30
@@ -579,15 +581,11 @@ def configure(folder=None, format_strs=None):
    os.makedirs(folder, exist_ok=True)

    log_suffix = ''
    from mpi4py import MPI
    rank = MPI.COMM_WORLD.Get_rank()
    if rank > 0:
        log_suffix = "-rank%03i" % rank

    if format_strs is None:
        if rank == 0:
        if mpi_rank_or_zero() == 0:
            format_strs = os.getenv('OPENAI_LOG_FORMAT', 'stdout,log,csv').split(',')
        else:
            log_suffix = "-rank%03i" % mpi_rank_or_zero()  # ``rank`` is no longer defined here, so query it directly
            format_strs = os.getenv('OPENAI_LOG_FORMAT_MPI', 'log').split(',')
    format_strs = filter(None, format_strs)
    output_formats = [make_output_format(f, folder, log_suffix) for f in format_strs]
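
A usage sketch of the configure() behaviour above (assumed workflow; the folder
and the format override are hypothetical): rank 0 writes stdout,log,csv by
default, while non-zero ranks write only a rank-suffixed log unless
OPENAI_LOG_FORMAT_MPI says otherwise.

import os

os.environ["OPENAI_LOG_FORMAT_MPI"] = "log,csv"  # hypothetical worker override
from stable_baselines import logger

logger.configure(folder="/tmp/sb_logs")
logger.logkv("mpi_rank", 0)
logger.dumpkvs()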
26 changes: 26 additions & 0 deletions tests/test_common.py
@@ -1,6 +1,32 @@
from contextlib import contextmanager
import sys


def _assert_eq(left, right):
    assert left == right, '{} != {}'.format(left, right)


def _assert_neq(left, right):
    assert left != right, '{} == {}'.format(left, right)


@contextmanager
def _maybe_disable_mpi(mpi_disabled):
    """A context manager that can temporarily hide the mpi4py import.

    Useful for testing whether non-MPI algorithms work as intended when
    mpi4py isn't installed.

    Args:
        mpi_disabled (bool): If True, this context temporarily replaces
            the mpi4py entry in `sys.modules` with None, so that imports
            of mpi4py raise ImportError.
    """
    if mpi_disabled and "mpi4py" in sys.modules:
        temp = sys.modules["mpi4py"]
        try:
            sys.modules["mpi4py"] = None
            yield
        finally:
            sys.modules["mpi4py"] = temp
    else:
        yield
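
Why assigning None works (a standalone sketch; the module name is fictional):
Python treats a None entry in sys.modules as an import that must fail, so
masking mpi4py this way makes any `import mpi4py` raise ImportError.

import sys

sys.modules["fictional_module"] = None
try:
    import fictional_module
except ImportError:
    print("import blocked by the None entry in sys.modules")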
23 changes: 13 additions & 10 deletions tests/test_logger.py
@@ -2,6 +2,7 @@
import numpy as np

from stable_baselines.logger import make_output_format, read_tb, read_csv, read_json, _demo
from .test_common import _maybe_disable_mpi


KEY_VALUES = {
@@ -24,21 +25,23 @@ def test_main():


@pytest.mark.parametrize('_format', ['tensorboard', 'stdout', 'log', 'json', 'csv'])
def test_make_output(_format):
@pytest.mark.parametrize('mpi_disabled', [False, True])
def test_make_output(_format, mpi_disabled):
    """
    test make output

    :param _format: (str) output format
    :param mpi_disabled: (bool) whether to run with the mpi4py import masked
    """
    writer = make_output_format(_format, LOG_DIR)
    writer.writekvs(KEY_VALUES)
    if _format == 'tensorboard':
        read_tb(LOG_DIR)
    elif _format == "csv":
        read_csv(LOG_DIR + 'progress.csv')
    elif _format == 'json':
        read_json(LOG_DIR + 'progress.json')
    writer.close()
    with _maybe_disable_mpi(mpi_disabled):
        writer = make_output_format(_format, LOG_DIR)
        writer.writekvs(KEY_VALUES)
        if _format == 'tensorboard':
            read_tb(LOG_DIR)
        elif _format == "csv":
            read_csv(LOG_DIR + 'progress.csv')
        elif _format == 'json':
            read_json(LOG_DIR + 'progress.json')
        writer.close()


def test_make_output_fail():
6 changes: 6 additions & 0 deletions tests/test_no_mpi.py
@@ -0,0 +1,6 @@
from .test_common import _maybe_disable_mpi


def test_no_mpi_no_crash():
    with _maybe_disable_mpi(True):
        import stable_baselines
        del stable_baselines  # keep Codacy happy
