Skip to content

Commit

Permalink
Update README + style fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin committed May 15, 2020
1 parent d8c5431 commit f068ada
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 34 deletions.
13 changes: 9 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
<img src="docs/\_static/img/logo.png" align="right" width="40%"/>

[![pipeline status](https://gitlab.com/araffin/stable-baselines3/badges/master/pipeline.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master) [![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master)
[![pipeline status](https://gitlab.com/araffin/stable-baselines3/badges/sde/pipeline.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/sde) [![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/sde/coverage.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/sde)


**WARNING: Stable Baselines3 is currently in a beta version, breaking changes may occur before 1.0 is released**
# Generalized State-Dependent Exploration (gSDE) for Deep Reinforcement Learning in Robotics

Note: most of the documentation of [Stable Baselines](https://github.com/hill-a/stable-baselines) should still be valid though.
This branch contains the code for reproducing the results in the paper "Generalized State-Dependent Exploration for Deep Reinforcement Learning in Robotics" by Antonin Raffin and Freek Stulp.

# Stable Baselines3
arXiv: https://arxiv.org/abs/2005.05719

The main difference from the master branch is that this branch adds gSDE support to TD3.


## Stable Baselines3

Stable Baselines3 is a set of improved implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines).

Expand Down
60 changes: 30 additions & 30 deletions stable_baselines3/td3/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.policies import (BasePolicy, register_policy, create_mlp,
create_sde_features_extractor, NatureCNN,
BaseFeaturesExtractor, FlattenExtractor)
create_sde_features_extractor, NatureCNN,
BaseFeaturesExtractor, FlattenExtractor)
from stable_baselines3.common.distributions import StateDependentNoiseDistribution


Expand Down Expand Up @@ -109,19 +109,19 @@ def __init__(self,
def _get_data(self) -> Dict[str, Any]:
data = super()._get_data()

data.update(dict(
net_arch=self.net_arch,
features_dim=self.features_dim,
activation_fn=self.activation_fn,
use_sde=self.use_sde,
log_std_init=self.log_std_init,
clip_noise=self.clip_noise,
lr_sde=self.lr_sde,
full_std=self.full_std,
sde_net_arch=self.sde_net_arch,
use_expln=self.use_expln,
features_extractor=self.features_extractor
))
data.update(dict(net_arch=self.net_arch,
features_dim=self.features_dim,
activation_fn=self.activation_fn,
use_sde=self.use_sde,
log_std_init=self.log_std_init,
clip_noise=self.clip_noise,
lr_sde=self.lr_sde,
full_std=self.full_std,
sde_net_arch=self.sde_net_arch,
use_expln=self.use_expln,
features_extractor=self.features_extractor
)
)
return data

def get_std(self) -> th.Tensor:
Expand Down Expand Up @@ -389,21 +389,21 @@ def _build(self, lr_schedule: Callable) -> None:
def _get_data(self) -> Dict[str, Any]:
data = super()._get_data()

data.update(dict(
net_arch=self.net_args['net_arch'],
activation_fn=self.net_args['activation_fn'],
use_sde=self.actor_kwargs['use_sde'],
log_std_init=self.actor_kwargs['log_std_init'],
clip_noise=self.actor_kwargs['clip_noise'],
lr_sde=self.actor_kwargs['lr_sde'],
sde_net_arch=self.actor_kwargs['sde_net_arch'],
use_expln=self.actor_kwargs['use_expln'],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
optimizer_class=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
features_extractor_class=self.features_extractor_class,
features_extractor_kwargs=self.features_extractor_kwargs
))
data.update(dict(net_arch=self.net_args['net_arch'],
activation_fn=self.net_args['activation_fn'],
use_sde=self.actor_kwargs['use_sde'],
log_std_init=self.actor_kwargs['log_std_init'],
clip_noise=self.actor_kwargs['clip_noise'],
lr_sde=self.actor_kwargs['lr_sde'],
sde_net_arch=self.actor_kwargs['sde_net_arch'],
use_expln=self.actor_kwargs['use_expln'],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
optimizer_class=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
features_extractor_class=self.features_extractor_class,
features_extractor_kwargs=self.features_extractor_kwargs
)
)
return data

def reset_noise(self) -> None:
Expand Down

0 comments on commit f068ada

Please sign in to comment.