Doc fix and add Stable-Baselines3 Jax (SBX) page (#1566)
* Fix custom policy example

* Add RL Zoo doc link

* Add changelog to pypi

* Add SBX doc page

* Fix small mistake in docstring

---------

Co-authored-by: Peter Elmers <peter.elmers@yahoo.com>
araffin and pelmers committed Jun 21, 2023
1 parent f667f08 commit 4fdb65e
Showing 8 changed files with 77 additions and 5 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -73,7 +73,7 @@ Goals of this repository:

Github repo: https://github.com/DLR-RM/rl-baselines3-zoo

-Documentation: https://stable-baselines3.readthedocs.io/en/master/guide/rl_zoo.html
+Documentation: https://rl-baselines3-zoo.readthedocs.io/en/master/

## SB3-Contrib: Experimental RL Features

6 changes: 3 additions & 3 deletions docs/guide/custom_policy.rst
@@ -371,7 +371,8 @@ If your task requires even more granular control over the policy/value architect
         *args,
         **kwargs,
     ):
+        # Disable orthogonal initialization
+        kwargs["ortho_init"] = False
         super().__init__(
             observation_space,
             action_space,
@@ -380,8 +381,7 @@ If your task requires even more granular control over the policy/value architect
             *args,
             **kwargs,
         )
-        # Disable orthogonal initialization
-        self.ortho_init = False
 
     def _build_mlp_extractor(self) -> None:
         self.mlp_extractor = CustomNetwork(self.features_dim)
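
Why the placement of the ``ortho_init`` override matters can be illustrated with a minimal, self-contained sketch. The classes below (`BasePolicy`, `LatePolicy`, `EarlyPolicy`) are hypothetical stand-ins, not SB3's actual implementation; the only assumption is that the parent constructor reads the flag while building the network, which is consistent with the changelog entry saying ``ortho_init`` was ignored.

```python
class BasePolicy:
    """Simplified stand-in for a policy base class (hypothetical, not SB3 code)."""

    def __init__(self, ortho_init: bool = True) -> None:
        self.ortho_init = ortho_init
        self.init_scheme = None
        # The flag is consumed here, while the constructor is still running.
        self._build()

    def _build(self) -> None:
        self.init_scheme = "orthogonal" if self.ortho_init else "default"


class LatePolicy(BasePolicy):
    """Old pattern: overrides the attribute after the parent constructor."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Too late: _build() has already run with ortho_init=True.
        self.ortho_init = False


class EarlyPolicy(BasePolicy):
    """Fixed pattern: overrides the keyword argument before the parent constructor."""

    def __init__(self, *args, **kwargs) -> None:
        kwargs["ortho_init"] = False
        super().__init__(*args, **kwargs)


late = LatePolicy()
early = EarlyPolicy()
print(late.init_scheme)   # orthogonal (the override had no effect)
print(early.init_scheme)  # default
```

In the late variant the attribute ends up `False`, but the build step has already used the original value, so the setting is silently ignored.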
2 changes: 2 additions & 0 deletions docs/guide/rl_zoo.rst
@@ -17,6 +17,8 @@ Goals of this repository:
3. Provide tuned hyperparameters for each environment and RL algorithm
4. Have fun with the trained agents!

Documentation is available online: https://rl-baselines3-zoo.readthedocs.io/

Installation
------------

66 changes: 66 additions & 0 deletions docs/guide/sbx.rst
@@ -0,0 +1,66 @@
.. _sbx:

==========================
Stable Baselines Jax (SBX)
==========================

`Stable Baselines Jax (SBX) <https://github.com/araffin/sbx>`_ is a proof-of-concept version of Stable-Baselines3 in Jax.

It provides a minimal number of features compared to SB3 but can be much faster (up to 20x!): https://twitter.com/araffin2/status/1590714558628253698

Implemented algorithms:

- Soft Actor-Critic (SAC) and SAC-N
- Truncated Quantile Critics (TQC)
- Dropout Q-Functions for Doubly Efficient Reinforcement Learning (DroQ)
- Proximal Policy Optimization (PPO)
- Deep Q Network (DQN)


Because SBX follows the SB3 API, it is also compatible with the `RL Zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_.
To use it with the RL Zoo, you will need to create two files:

``train_sbx.py``:

.. code-block:: python

  import rl_zoo3
  import rl_zoo3.train
  from rl_zoo3.train import train
  from sbx import DQN, PPO, SAC, TQC, DroQ

  # Register the SBX implementations under the names used by the RL Zoo CLI
  rl_zoo3.ALGOS["tqc"] = TQC
  rl_zoo3.ALGOS["droq"] = DroQ
  rl_zoo3.ALGOS["sac"] = SAC
  rl_zoo3.ALGOS["ppo"] = PPO
  rl_zoo3.ALGOS["dqn"] = DQN
  # Point the train module and the experiment manager at the updated mapping
  rl_zoo3.train.ALGOS = rl_zoo3.ALGOS
  rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS

  if __name__ == "__main__":
      train()

Then you can call ``python train_sbx.py --algo sac --env Pendulum-v1`` and use the RL Zoo CLI.


``enjoy_sbx.py``:

.. code-block:: python

  import rl_zoo3
  import rl_zoo3.enjoy
  from rl_zoo3.enjoy import enjoy
  from sbx import DQN, PPO, SAC, TQC, DroQ

  # Register the SBX implementations under the names used by the RL Zoo CLI
  rl_zoo3.ALGOS["tqc"] = TQC
  rl_zoo3.ALGOS["droq"] = DroQ
  rl_zoo3.ALGOS["sac"] = SAC
  rl_zoo3.ALGOS["ppo"] = PPO
  rl_zoo3.ALGOS["dqn"] = DQN
  # Point the enjoy module and the experiment manager at the updated mapping
  rl_zoo3.enjoy.ALGOS = rl_zoo3.ALGOS
  rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS

  if __name__ == "__main__":
      enjoy()
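
The two scripts above rely on a name-to-class mapping that a command-line option selects from. A minimal sketch of that registry pattern follows; every name in it (`TorchSAC`, `JaxSAC`, `make_model`, and this standalone `ALGOS` dict) is hypothetical and only illustrates the idea, not the RL Zoo's own code.

```python
from typing import Dict, Type


class BaseAlgo:
    """Hypothetical common base class for the illustration."""

    backend = "none"


class TorchSAC(BaseAlgo):
    backend = "torch"


class JaxSAC(BaseAlgo):
    backend = "jax"


# Registry mapping CLI names to algorithm classes (illustrative only).
ALGOS: Dict[str, Type[BaseAlgo]] = {"sac": TorchSAC}


def make_model(name: str) -> BaseAlgo:
    """Look up the class at call time, so later overrides are picked up."""
    return ALGOS[name]()


before = make_model("sac").backend

# Swap in the alternative implementation under the same name.
ALGOS["sac"] = JaxSAC

after = make_model("sac").backend
print(before, after)  # torch jax
```

Because the lookup happens when `make_model` is called, replacing the entry before the entry point runs is enough for the new class to be used.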
1 change: 1 addition & 0 deletions docs/index.rst
@@ -51,6 +51,7 @@ Main Features
guide/integrations
guide/rl_zoo
guide/sb3_contrib
guide/sbx
guide/imitation
guide/migration
guide/checking_nan
2 changes: 2 additions & 0 deletions docs/misc/changelog.rst
@@ -80,6 +80,8 @@ Documentation:
- Added ``EvalCallback`` example (@sidney-tio)
- Update custom env documentation
- Added `pink-noise-rl` to projects page
- Fix custom policy example, ``ortho_init`` was ignored
- Added SBX page


Release 1.8.0 (2023-04-07)
1 change: 1 addition & 0 deletions setup.py
@@ -159,6 +159,7 @@
project_urls={
"Code": "https://github.com/DLR-RM/stable-baselines3",
"Documentation": "https://stable-baselines3.readthedocs.io/",
"Changelog": "https://stable-baselines3.readthedocs.io/en/master/misc/changelog.html",
"SB3-Contrib": "https://github.com/Stable-Baselines-Team/stable-baselines3-contrib",
"RL-Zoo": "https://github.com/DLR-RM/rl-baselines3-zoo",
},
2 changes: 1 addition & 1 deletion stable_baselines3/common/monitor.py
@@ -202,7 +202,7 @@ def __init__(

     def write_row(self, epinfo: Dict[str, float]) -> None:
         """
-        Close the file handler
+        Write row of monitor data to csv log file.
 
         :param epinfo: the information on episodic return, length, and time
         """