-
Notifications
You must be signed in to change notification settings - Fork 60
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Make mpi optional * Remove PR note * Update docs * Update docs and changelog * Fix testing name error * Add mpi4py-not-installed test * Add import stable_baselines mpi test * Remove redundant pass statement * Use FQN `sys.modules` for clarity * Move import into disabled context for simplicity * Linting
- Loading branch information
1 parent
8ceda3b
commit 9a76054
Showing
14 changed files
with
147 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,24 @@ | ||
from stable_baselines.a2c import A2C | ||
from stable_baselines.acer import ACER | ||
from stable_baselines.acktr import ACKTR | ||
from stable_baselines.ddpg import DDPG | ||
from stable_baselines.deepq import DQN | ||
from stable_baselines.her import HER | ||
from stable_baselines.gail import GAIL | ||
from stable_baselines.ppo1 import PPO1 | ||
from stable_baselines.ppo2 import PPO2 | ||
from stable_baselines.td3 import TD3 | ||
from stable_baselines.trpo_mpi import TRPO | ||
from stable_baselines.sac import SAC | ||
|
||
|
||
# Load mpi4py-dependent algorithms only if mpi is installed. | ||
try: | ||
import mpi4py | ||
except ImportError: | ||
mpi4py = None | ||
|
||
if mpi4py is not None: | ||
from stable_baselines.ddpg import DDPG | ||
from stable_baselines.gail import GAIL | ||
from stable_baselines.ppo1 import PPO1 | ||
from stable_baselines.trpo_mpi import TRPO | ||
del mpi4py | ||
|
||
__version__ = "2.7.0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,32 @@ | ||
from contextlib import contextmanager | ||
import sys | ||
|
||
|
||
def _assert_eq(left, right): | ||
assert left == right, '{} != {}'.format(left, right) | ||
|
||
|
||
def _assert_neq(left, right): | ||
assert left != right, '{} == {}'.format(left, right) | ||
|
||
|
||
@contextmanager | ||
def _maybe_disable_mpi(mpi_disabled): | ||
"""A context that can temporarily remove the mpi4py import. | ||
Useful for testing whether non-MPI algorithms work as intended when | ||
mpi4py isn't installed. | ||
Args: | ||
disable_mpi (bool): If True, then this context temporarily removes | ||
the mpi4py import from `sys.modules` | ||
""" | ||
if mpi_disabled and "mpi4py" in sys.modules: | ||
temp = sys.modules["mpi4py"] | ||
try: | ||
sys.modules["mpi4py"] = None | ||
yield | ||
finally: | ||
sys.modules["mpi4py"] = temp | ||
else: | ||
yield |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from .test_common import _maybe_disable_mpi | ||
|
||
def test_no_mpi_no_crash(): | ||
with _maybe_disable_mpi(True): | ||
import stable_baselines | ||
del stable_baselines # keep Codacy happy |