
Fix and rename custom policy names (#444)
* fix and rename custom policy names

* Update changelog
eavelardev authored and araffin committed Aug 29, 2019
1 parent 51285cc commit 647f5bc
Showing 4 changed files with 9 additions and 8 deletions.
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
@@ -43,6 +43,7 @@ Documentation:
 ^^^^^^^^^^^^^^
 - Add WaveRL project (@jaberkow)
 - Add Fenics-DRL project (@DonsetPG)
+- Fix and rename custom policy names (@eavelardev)


 Release 2.7.0 (2019-07-31)
@@ -453,4 +454,4 @@ In random order...
 Thanks to @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk @JohannesAck
 @EliasHasle @mrakgr @Bleyddyn @antoine-galataud @junhyeokahn @AdamGleave @keshaviyengar @tperol
 @XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs
-@Miffyli @dwiel @miguelrass @qxcv @jaberkow
+@Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev
6 changes: 3 additions & 3 deletions docs/modules/ddpg.rst
@@ -152,9 +152,9 @@ You can easily define a custom architecture for the policy network:
 from stable_baselines import DDPG

 # Custom MLP policy of two layers of size 16 each
-class CustomPolicy(FeedForwardPolicy):
+class CustomDDPGPolicy(FeedForwardPolicy):
     def __init__(self, *args, **kwargs):
-        super(CustomPolicy, self).__init__(*args, **kwargs,
+        super(CustomDDPGPolicy, self).__init__(*args, **kwargs,
                                            layers=[16, 16],
                                            layer_norm=False,
                                            feature_extraction="mlp")
@@ -163,6 +163,6 @@ You can easily define a custom architecture for the policy network:
 env = gym.make('Pendulum-v0')
 env = DummyVecEnv([lambda: env])

-model = DDPG(CustomPolicy, env, verbose=1)
+model = DDPG(CustomDDPGPolicy, env, verbose=1)
 # Train the agent
 model.learn(total_timesteps=100000)
6 changes: 3 additions & 3 deletions docs/modules/dqn.rst
@@ -148,9 +148,9 @@ You can easily define a custom architecture for the policy network:
 from stable_baselines import DQN

 # Custom MLP policy of two layers of size 32 each
-class CustomPolicy(FeedForwardPolicy):
+class CustomDQNPolicy(FeedForwardPolicy):
     def __init__(self, *args, **kwargs):
-        super(CustomPolicy, self).__init__(*args, **kwargs,
+        super(CustomDQNPolicy, self).__init__(*args, **kwargs,
                                            layers=[32, 32],
                                            layer_norm=False,
                                            feature_extraction="mlp")
@@ -159,6 +159,6 @@ You can easily define a custom architecture for the policy network:
 env = gym.make('LunarLander-v2')
 env = DummyVecEnv([lambda: env])

-model = DQN(CustomPolicy, env, verbose=1)
+model = DQN(CustomDQNPolicy, env, verbose=1)
 # Train the agent
 model.learn(total_timesteps=100000)
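
Why the rename matters: DDPG, DQN, and SAC each ship their own FeedForwardPolicy base class, so two doc snippets that both define a class named CustomPolicy cannot be combined in one script without the second definition shadowing the first. A minimal sketch of the renamed classes coexisting (the aliased imports are this sketch's convention, not part of the diff):

from stable_baselines.ddpg.policies import FeedForwardPolicy as DDPGFeedForwardPolicy
from stable_baselines.deepq.policies import FeedForwardPolicy as DQNFeedForwardPolicy

# With distinct names, both custom policies can live in the same module.
class CustomDDPGPolicy(DDPGFeedForwardPolicy):
    def __init__(self, *args, **kwargs):
        super(CustomDDPGPolicy, self).__init__(*args, **kwargs,
                                               layers=[16, 16],
                                               layer_norm=False,
                                               feature_extraction="mlp")

class CustomDQNPolicy(DQNFeedForwardPolicy):
    def __init__(self, *args, **kwargs):
        super(CustomDQNPolicy, self).__init__(*args, **kwargs,
                                              layers=[32, 32],
                                              layer_norm=False,
                                              feature_extraction="mlp")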
2 changes: 1 addition & 1 deletion docs/modules/sac.rst
@@ -144,7 +144,7 @@ You can easily define a custom architecture for the policy network:
 # Custom MLP policy of three layers of size 128 each
 class CustomSACPolicy(FeedForwardPolicy):
     def __init__(self, *args, **kwargs):
-        super(CustomPolicy, self).__init__(*args, **kwargs,
+        super(CustomSACPolicy, self).__init__(*args, **kwargs,
                                            layers=[128, 128, 128],
                                            layer_norm=False,
                                            feature_extraction="mlp")
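
The sac.rst hunk is the actual bug fix: the class was already named CustomSACPolicy, but its super() call still referenced the old CustomPolicy name, which would raise a NameError when the snippet is run standalone. For reference, a minimal runnable version of the corrected example (the gym/DummyVecEnv scaffolding mirrors the DDPG and DQN examples above and is assumed here, since the diff does not show it):

import gym

from stable_baselines.sac.policies import FeedForwardPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import SAC

# Custom MLP policy of three layers of size 128 each
class CustomSACPolicy(FeedForwardPolicy):
    def __init__(self, *args, **kwargs):
        super(CustomSACPolicy, self).__init__(*args, **kwargs,
                                              layers=[128, 128, 128],
                                              layer_norm=False,
                                              feature_extraction="mlp")

env = gym.make('Pendulum-v0')
env = DummyVecEnv([lambda: env])

model = SAC(CustomSACPolicy, env, verbose=1)
# Train the agent
model.learn(total_timesteps=100000)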
