Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix to use float64 actions for off policy algorithms #1572

Merged
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ Bug Fixes:
for ``Inf`` and ``NaN`` (@lutogniew)
- Fixed HER ``truncate_last_trajectory()`` (@lbergmann1)
- Fixed HER desired and achieved goal order in reward computation (@JonathanKuelz)
- Fixed off-policy algorithms with continuous float64 actions (see #1145) (@tobirohrer)

Deprecations:
^^^^^^^^^^^^^
Expand Down
12 changes: 10 additions & 2 deletions stable_baselines3/common/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,9 @@ def __init__(
else:
self.next_observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=observation_space.dtype)

self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype)
self.actions = np.zeros(
(self.buffer_size, self.n_envs, self.action_dim), dtype=self._get_action_data_type(action_space)
)

self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
Expand Down Expand Up @@ -311,6 +313,10 @@ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = Non
)
return ReplayBufferSamples(*tuple(map(self.to_torch, data)))

@staticmethod
def _get_action_data_type(action_space):
araffin marked this conversation as resolved.
Show resolved Hide resolved
return np.float32 if action_space.dtype == np.float64 else action_space.dtype


class RolloutBuffer(BaseBuffer):
"""
Expand Down Expand Up @@ -543,7 +549,9 @@ def __init__(
for key, _obs_shape in self.obs_shape.items()
}

self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype)
self.actions = np.zeros(
(self.buffer_size, self.n_envs, self.action_dim), dtype=self._get_action_data_type(action_space)
)
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

Expand Down
74 changes: 74 additions & 0 deletions tests/test_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import pytest
from gymnasium import spaces
from gymnasium.spaces.space import Space

from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.env_checker import check_env
Expand Down Expand Up @@ -56,6 +57,21 @@ def step(self, action):
return self.observation_space.sample(), 0.0, False, False, {}


class DummyEnv(gym.Env):
araffin marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, action_space: Space, observation_space: Space):
super().__init__()
self.action_space = action_space
self.observation_space = observation_space

def step(self, action):
return self.observation_space.sample(), 0.0, False, False, {}

def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
if seed is not None:
super().reset(seed=seed)
return self.observation_space.sample(), {}


@pytest.mark.parametrize(
"env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8), DummyMultiBinary((3, 2)), DummyMultidimensionalAction()]
)
Expand Down Expand Up @@ -127,3 +143,61 @@ def test_discrete_obs_space(model_class, env):
else:
kwargs = dict(n_steps=256)
model_class("MlpPolicy", env, **kwargs).learn(256)


@pytest.fixture()
def dummy_env_float64_action_float64_observation():
return DummyEnv(
action_space=spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float64),
observation_space=spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float64),
)


@pytest.fixture()
def dummy_env_float64_action_float32_observation():
return DummyEnv(
action_space=spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float64),
observation_space=spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32),
)


@pytest.fixture()
def dummy_env_float64_action_float32_dict_observation():
space = spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float64)
return DummyEnv(
action_space=spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float32),
observation_space=spaces.Dict({"a": space, "b": space}),
)


@pytest.fixture()
def dummy_env_float64_action_float64_dict_observation():
space = spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float64)
return DummyEnv(
action_space=spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float64),
observation_space=spaces.Dict({"a": space, "b": space}),
)


@pytest.mark.parametrize(
"env_fixture",
[
"dummy_env_float64_action_float32_observation",
"dummy_env_float64_action_float64_observation",
"dummy_env_float64_action_float32_dict_observation",
"dummy_env_float64_action_float64_dict_observation",
],
)
@pytest.mark.parametrize("model_class", [SAC, TD3, PPO, DDPG, A2C])
def test_float64_action_space_support(env_fixture, model_class, request):
env = request.getfixturevalue(env_fixture)
env = gym.wrappers.TimeLimit(env, max_episode_steps=200)
if isinstance(env.observation_space, spaces.Dict):
policy = "MultiInputPolicy"
else:
policy = "MlpPolicy"
model = model_class(policy, env)
model.learn(20)
initial_obs, _ = env.reset()
action, _ = model.predict(initial_obs)
assert action.dtype == env.action_space.dtype