-
Notifications
You must be signed in to change notification settings - Fork 60
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix action scaling for warmup exploration (SAC/DDPG/TD3) (#584)
* Adding action scaling to and from tanh co-domain as a generic utility. * Formating * Adding action squashing to tanh co-domain for DDPG, TD3 and SAC whenever sampled at random from action_space. * Unifying other instances of action scaling withing SAC, TD3 and DDPG. Adding a test. * Adding info on fix to changelog. * Flipping action scaling/unscaling due to confusion by parameter naming. Adding test checking involved algorithms. * Changing names of local variables for actions, in order to follow naming of used action scaling methods. * Adding check on scaling inferred actions as well * Considering learning_starts parameter of SAC and TD3 when checking action scaling. * Removing misclick addition * Adding to changelog * Adding nick to bugfix. * Removing asserts enforcing symmetric action space (DDPG, TD3, SAC). * Changelog: non-symmetric action spaces info. * Test Action Scaling: remove unnecessary wrapping of environment, make action space asymmetric. * Adding comments * Missing line break * Removing unused import.
- Loading branch information
Showing
10 changed files
with
174 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import pytest | ||
import numpy as np | ||
|
||
from stable_baselines import DDPG, TD3, SAC | ||
from stable_baselines.common.identity_env import IdentityEnvBox | ||
|
||
ROLLOUT_STEPS = 100 | ||
|
||
MODEL_LIST = [ | ||
(DDPG, dict(nb_train_steps=0, nb_rollout_steps=ROLLOUT_STEPS)), | ||
(TD3, dict(train_freq=ROLLOUT_STEPS + 1, learning_starts=0)), | ||
(SAC, dict(train_freq=ROLLOUT_STEPS + 1, learning_starts=0)), | ||
(TD3, dict(train_freq=ROLLOUT_STEPS + 1, learning_starts=ROLLOUT_STEPS)), | ||
(SAC, dict(train_freq=ROLLOUT_STEPS + 1, learning_starts=ROLLOUT_STEPS)) | ||
] | ||
|
||
|
||
@pytest.mark.parametrize("model_class, model_kwargs", MODEL_LIST) | ||
def test_buffer_actions_scaling(model_class, model_kwargs): | ||
""" | ||
Test if actions are scaled to tanh co-domain before being put in a buffer | ||
for algorithms that use tanh-squashing, i.e., DDPG, TD3, SAC | ||
:param model_class: (BaseRLModel) A RL Model | ||
:param model_kwargs: (dict) Dictionary containing named arguments to the given algorithm | ||
""" | ||
|
||
# check random and inferred actions as they possibly have different flows | ||
for random_coeff in [0.0, 1.0]: | ||
|
||
env = IdentityEnvBox(-2000, 1000) | ||
|
||
model = model_class("MlpPolicy", env, seed=1, random_exploration=random_coeff, **model_kwargs) | ||
model.learn(total_timesteps=ROLLOUT_STEPS) | ||
|
||
assert hasattr(model, 'replay_buffer') | ||
|
||
buffer = model.replay_buffer | ||
|
||
assert buffer.can_sample(ROLLOUT_STEPS) | ||
|
||
_, actions, _, _, _ = buffer.sample(ROLLOUT_STEPS) | ||
|
||
assert not np.any(actions > np.ones_like(actions)) | ||
assert not np.any(actions < -np.ones_like(actions)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters