Merge pull request #19 from hill-a/fixes
Several Fixes
hill-a committed Sep 15, 2018
2 parents bd852c8 + 6244bf0 commit 0a3948f
Showing 11 changed files with 32 additions and 14 deletions.
16 changes: 13 additions & 3 deletions README.md
@@ -1,3 +1,5 @@
<img src="docs/\_static/img/logo.png" align="right" width="40%"/>

[![Build Status](https://travis-ci.com/hill-a/stable-baselines.svg?branch=master)](https://travis-ci.com/hill-a/stable-baselines) [![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines.readthedocs.io/en/master/?badge=master) [![Codacy Badge](https://api.codacy.com/project/badge/Grade/3bcb4cd6d76a4270acb16b5fe6dd9efa)](https://www.codacy.com/app/baselines_janitors/stable-baselines?utm_source=github.com&amp;utm_medium=referral&amp;utm_content=hill-a/stable-baselines&amp;utm_campaign=Badge_Grade) [![Codacy Badge](https://api.codacy.com/project/badge/Coverage/3bcb4cd6d76a4270acb16b5fe6dd9efa)](https://www.codacy.com/app/baselines_janitors/stable-baselines?utm_source=github.com&utm_medium=referral&utm_content=hill-a/stable-baselines&utm_campaign=Badge_Coverage)

# Stable Baselines
@@ -27,16 +29,18 @@ This toolset is a fork of OpenAI Baselines, with a major structural refactoring,
| Tensorboard support | :heavy_check_mark: | :heavy_minus_sign: <sup>(4)</sup> |
| Ipython / Notebook friendly | :heavy_check_mark: | :x: |
| PEP8 code style | :heavy_check_mark: | :heavy_minus_sign: <sup>(5)</sup> |
| Custom callback | :heavy_check_mark: | :heavy_minus_sign: <sup>(6)</sup> |

<sup><sup>(1): Forked from a previous version of OpenAI Baselines; the refactoring for HER is still missing.</sup></sup><br>
<sup><sup>(2): Currently not available for DDPG, and only from the run script. </sup></sup><br>
<sup><sup>(3): Only via the run script.</sup></sup><br>
<sup><sup>(4): Rudimentary logging of training information (no loss or graph).</sup></sup><br>
<sup><sup>(5): WIP on OpenAI's side (you can do it OpenAI! :cat:)</sup></sup><br>
<sup><sup>(6): Passing a callback function is only available for DQN</sup></sup><br>

## Documentation

Documentation is available online: [http://stable-baselines.readthedocs.io/](http://stable-baselines.readthedocs.io/)
Documentation is available online: [https://stable-baselines.readthedocs.io/](https://stable-baselines.readthedocs.io/)

## Installation

@@ -63,7 +67,7 @@ Using pip from pypi:
pip install stable-baselines
```

Please read the [documentation](http://stable-baselines.readthedocs.io/) for more details and alternatives (from source, using docker).
Please read the [documentation](https://stable-baselines.readthedocs.io/) for more details and alternatives (from source, using docker).


## Example
@@ -99,7 +103,7 @@ from stable_baselines import PPO2
model = PPO2('MlpPolicy', 'CartPole-v1').learn(10000)
```
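For readers skimming the diff, the one-liner above leans on defaults (it builds the env from its string id). A more explicit sketch of the same workflow, using the API this README documents, with env wrapping, training, and a prediction loop:

```python
import gym

from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO2

# wrap a single gym env in a (single-process) vectorized env
env = DummyVecEnv([lambda: gym.make('CartPole-v1')])

model = PPO2(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=10000)

# run the trained policy
obs = env.reset()
for _ in range(1000):
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()
```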

Please read the [documentation](http://stable-baselines.readthedocs.io/) for more examples.
Please read the [documentation](https://stable-baselines.readthedocs.io/) for more examples.


## Try it online with Colab Notebooks!
@@ -174,3 +178,9 @@ If you want to contribute, please open an issue first and then propose your pull
Nice to have (for the future):
- [ ] Continuous actions support for ACER
- [ ] Continuous actions support for ACKTR

## Acknowledgments

Stable Baselines was created in the [robotics lab U2IS](http://u2is.ensta-paristech.fr/index.php?lang=en) ([INRIA Flowers](https://flowers.inria.fr/) team) at [ENSTA ParisTech](http://www.ensta-paristech.fr/en).

Logo credits: L.M. Tenkes
Binary file added docs/_static/img/logo.png
3 changes: 3 additions & 0 deletions docs/conf.py
@@ -108,6 +108,9 @@ def __getattr__(cls, name):
else:
html_theme = 'sphinx_rtd_theme'

html_logo = '_static/img/logo.png'


def setup(app):
app.add_stylesheet("css/baselines_theme.css")

3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
@@ -5,7 +5,7 @@ Changelog

For download links, please look at `Github release page <https://github.com/hill-a/stable-baselines/releases>`_.

Master version 1.0.8.rc0 (TO BE RELEASED SOON)
Master version 1.0.8.rc1 (TO BE RELEASED SOON)
-----------------------------------------------

**Tensorboard and bug fixes**
@@ -30,6 +30,7 @@ Master version 1.0.8.rc0 (TO BE RELEASED SOON)
- fixed DummyVecEnv not copying the observation array when stepping and resetting
- added pre-built docker images + installation instructions
- added ``deterministic`` argument in the predict function
- added assert in PPO2 for recurrent policies


Release 1.0.7 (2018-08-29)
2 changes: 1 addition & 1 deletion stable_baselines/__init__.py
@@ -8,4 +8,4 @@
from stable_baselines.ppo2 import PPO2
from stable_baselines.trpo_mpi import TRPO

__version__ = "1.0.8.rc0"
__version__ = "1.0.8.rc2"
3 changes: 1 addition & 2 deletions stable_baselines/common/identity_env.py
@@ -55,7 +55,6 @@ def __init__(self, low=-1, high=1, eps=0.05, ep_length=100):
super(IdentityEnvBox, self).__init__(1, ep_length)
self.action_space = Box(low=low, high=high, shape=(1,), dtype=np.float32)
self.observation_space = self.action_space
# TODO: test with epsilon instead of just pos/neg actions
self.eps = eps
self.reset()

@@ -75,7 +74,7 @@ def _choose_next_state(self):
self.state = self.observation_space.sample()

def _get_reward(self, action):
return 1 if action * self.state > 0 else 0
return 1 if (self.state - self.eps) <= action <= (self.state + self.eps) else 0
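In words: the Box identity env now rewards any action within `eps` of the target state, instead of only checking that the action has the right sign. A minimal sanity check of the new rule, with hypothetical values:

```python
# hypothetical values illustrating the epsilon-based reward above
state, eps = 0.3, 0.05

def reward(action):
    return 1 if (state - eps) <= action <= (state + eps) else 0

assert reward(0.32) == 1  # within eps of the state -> rewarded
assert reward(0.50) == 0  # right sign (old criterion would pass), but outside eps
```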


class IdentityEnvMultiDiscrete(IdentityEnv):
7 changes: 5 additions & 2 deletions stable_baselines/common/vec_env/subproc_vec_env.py
@@ -39,7 +39,7 @@ class SubprocVecEnv(VecEnv):
:param env_fns: ([Gym Environment]) Environments to run in subprocesses
"""

def __init__(self, env_fns):
self.waiting = False
self.closed = False
@@ -87,8 +87,11 @@ def close(self):

def render(self, mode='human', **kwargs):
for pipe in self.remotes:
pipe.send(('render', kwargs))
# gather images from subprocesses
# `mode` will be taken into account later
pipe.send(('render', {'mode':'rgb_array', **kwargs}))
imgs = [pipe.recv() for pipe in self.remotes]
# Create a big image by tiling images from subprocesses
bigimg = tile_images(imgs)
if mode == 'human':
import cv2
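The hunk above is truncated, but the idea is visible: every worker now renders as `rgb_array`, the parent tiles the frames, and `mode` only controls how the tiled image is displayed (`human` shows it via cv2). A usage sketch, assuming, as in the upstream implementation, that `mode='rgb_array'` returns the tiled frame:

```python
import gym

from stable_baselines.common.vec_env import SubprocVecEnv

# four workers, each rendering off-screen; the parent tiles the frames
# (on platforms that spawn processes, guard this with `if __name__ == '__main__':`)
env = SubprocVecEnv([lambda: gym.make('CartPole-v1') for _ in range(4)])
env.reset()
big_frame = env.render(mode='rgb_array')  # one image containing all 4 envs
```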
4 changes: 2 additions & 2 deletions stable_baselines/ddpg/ddpg.py
@@ -837,7 +837,7 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_
step = (int(t_train * (self.nb_rollout_steps / self.nb_train_steps)) +
total_steps - self.nb_rollout_steps)

critic_loss, actor_loss = self._train_step(t_train, writer, log=t_train == 0)
critic_loss, actor_loss = self._train_step(step, writer, log=t_train == 0)
epoch_critic_losses.append(critic_loss)
epoch_actor_losses.append(actor_loss)
self._update_target_net()
@@ -929,7 +929,7 @@ def as_scalar(scalar):
with open(os.path.join(logdir, 'eval_env_state.pkl'), 'wb') as file_handler:
pickle.dump(self.eval_env.get_state(), file_handler)

def predict(self, observation, state=None, mask=None, deterministic=False):
def predict(self, observation, state=None, mask=None, deterministic=True):
observation = np.array(observation).reshape(self.observation_space.shape)

action, _ = self._policy(observation, apply_noise=not deterministic, compute_q=False)
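Note the default flip in `predict` above: DDPG now acts deterministically (no action noise) unless asked otherwise. A sketch of the resulting call sites, with a hypothetical Pendulum setup:

```python
import gym

from stable_baselines.ddpg.policies import MlpPolicy
from stable_baselines import DDPG

env = gym.make('Pendulum-v0')
model = DDPG(MlpPolicy, env)

obs = env.reset()
action, _ = model.predict(obs)                       # noise-free action (new default)
action, _ = model.predict(obs, deterministic=False)  # opt back in to exploration noise
```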
2 changes: 1 addition & 1 deletion stable_baselines/ddpg/policies.py
@@ -23,7 +23,7 @@ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256
super(DDPGPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=n_lstm, reuse=reuse,
scale=scale, add_action_ph=True)
assert isinstance(ac_space, Box), "Error: the action space must be of type gym.spaces.Box"
assert np.abs(ac_space.low) == ac_space.high, "Error: the action space low and high must be symetric"
assert (np.abs(ac_space.low) == ac_space.high).all(), "Error: the action space low and high must be symmetric"
self.value_fn = None
self.policy = None

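The added `.all()` matters because `ac_space.low` is an array: the elementwise comparison yields a boolean array, and the old scalar-style `assert` would raise a `ValueError` for any action space with more than one dimension. A quick illustration:

```python
import numpy as np
from gym.spaces import Box

# a 2-D symmetric action space: low == -high elementwise
ac_space = Box(low=np.array([-1.0, -2.0]), high=np.array([1.0, 2.0]), dtype=np.float32)

comparison = np.abs(ac_space.low) == ac_space.high  # array([ True,  True])
assert comparison.all()  # the fixed form; `assert comparison` would raise ValueError here
```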
4 changes: 3 additions & 1 deletion stable_baselines/ppo2/ppo2.py
@@ -29,7 +29,7 @@ class PPO2(BaseRLModel):
:param max_grad_norm: (float) The maximum value for the gradient clipping
:param lam: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator
:param nminibatches: (int) Number of training minibatches per update. For recurrent policies,
should be smaller or equal than number of environments run in parallel.
the number of environments run in parallel should be a multiple of nminibatches.
:param noptepochs: (int) Number of epoch when optimizing the surrogate
:param cliprange: (float or callable) Clipping parameter, it can be a function
:param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
@@ -114,6 +114,8 @@ def setup_model(self):
n_batch_step = None
n_batch_train = None
if issubclass(self.policy, LstmPolicy):
assert self.n_envs % self.nminibatches == 0, "For recurrent policies, "\
"the number of environments run in parallel should be a multiple of nminibatches."
n_batch_step = self.n_envs
n_batch_train = self.n_batch // self.nminibatches

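The new assert encodes the docstring requirement: with a recurrent (LSTM) policy, `n_envs` must be divisible by `nminibatches` so each minibatch holds whole per-environment sequences. A configuration sketch that satisfies it:

```python
import gym

from stable_baselines.common.policies import MlpLstmPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO2

# 8 parallel envs, 4 minibatches: 8 % 4 == 0, so the recurrent-policy assert passes
env = DummyVecEnv([lambda: gym.make('CartPole-v1') for _ in range(8)])
model = PPO2(MlpLstmPolicy, env, nminibatches=4)
```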
2 changes: 1 addition & 1 deletion tests/test_continuous.py
@@ -18,7 +18,7 @@


N_TRIALS = 1000
NUM_TIMESTEPS = 10000
NUM_TIMESTEPS = 15000

MODEL_LIST = [
A2C,
