Commit: Polish code

araffin committed Mar 29, 2023
1 parent 621f64f commit 5e1f507
Showing 6 changed files with 23 additions and 15 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -83,7 +83,7 @@ filterwarnings = [
     # Tensorboard warnings
     "ignore::DeprecationWarning:tensorboard",
     # Gymnasium warnings
-    # "ignore::UserWarning:gym",
+    "ignore::UserWarning:gymnasium",
     # "ignore::DeprecationWarning:.*passive_env_checker.*",
 ]
 markers = [
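Each `filterwarnings` entry uses pytest's `action::category:module` syntax, so the new entry silences UserWarnings originating from the `gymnasium` package during the test run. A minimal sketch of the roughly equivalent plain-`warnings` call, for illustration only (not part of this commit):

    import warnings

    # Roughly what pytest does with "ignore::UserWarning:gymnasium":
    # the module field is a regex matched against the warning's module of origin.
    warnings.filterwarnings("ignore", category=UserWarning, module="gymnasium")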
7 changes: 5 additions & 2 deletions stable_baselines3/common/env_checker.py
@@ -250,7 +250,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action
     assert isinstance(reset_returns, tuple), "`reset()` must return a tuple (obs, info)"
     assert len(reset_returns) == 2, f"`reset()` must return a tuple of size 2 (obs, info), not {len(reset_returns)}"
     obs, info = reset_returns
-    assert isinstance(info, dict), "The second element of the tuple return by `reset()` must be a dictionary"
+    assert isinstance(info, dict), f"The second element of the tuple returned by `reset()` must be a dictionary, not {info}"

     if _is_goal_env(env):
         # Make mypy happy, already checked
@@ -277,7 +277,10 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action
     action = action_space.sample()
     data = env.step(action)

-    assert len(data) == 5, "The `step()` method must return four values: obs, reward, terminated, truncated, info"
+    assert len(data) == 5, (
+        "The `step()` method must return five values: "
+        f"obs, reward, terminated, truncated, info. Actual: {len(data)} values returned."
+    )

     # Unpack
     obs, reward, terminated, truncated, info = data
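For reference, these checks enforce the Gymnasium API: `reset()` returns a 2-tuple `(obs, info)` and `step()` returns a 5-tuple `(obs, reward, terminated, truncated, info)`. A minimal sketch of an env that passes both assertions (spaces and values are illustrative, not from this commit):

    import gymnasium as gym
    import numpy as np
    from gymnasium import spaces

    class MinimalEnv(gym.Env):
        observation_space = spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
        action_space = spaces.Discrete(2)

        def reset(self, *, seed=None, options=None):
            super().reset(seed=seed)
            return np.zeros(2, dtype=np.float32), {}  # (obs, info): info must be a dict

        def step(self, action):
            obs = np.zeros(2, dtype=np.float32)
            # terminated: a true terminal state was reached;
            # truncated: an external limit (e.g. a time limit) cut the episode short.
            return obs, 0.0, False, False, {}  # five values, as asserted above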
6 changes: 3 additions & 3 deletions stable_baselines3/common/preprocessing.py
@@ -202,9 +202,9 @@ def get_action_dim(action_space: spaces.Space) -> int:
         return int(len(action_space.nvec))
     elif isinstance(action_space, spaces.MultiBinary):
         # Number of binary actions
-        assert isinstance(action_space.n, int), (
-            "Multi-dimensional MultiBinary action space is not supported. " "You can flatten it instead."
-        )
+        assert isinstance(
+            action_space.n, int
+        ), "Multi-dimensional MultiBinary action space is not supported. You can flatten it instead."
         return int(action_space.n)
     else:
         raise NotImplementedError(f"{action_space} action space is not supported")
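The reworked assertion rejects multi-dimensional `MultiBinary` action spaces, where `n` is a sequence rather than an int. A short sketch of the difference and of the suggested workaround of declaring the space flat (shapes are illustrative):

    from gymnasium import spaces

    multi_dim = spaces.MultiBinary([2, 3])  # 2 x 3 grid of bits: rejected by get_action_dim()
    flat = spaces.MultiBinary(6)            # 6 independent binary actions: accepted
    assert isinstance(flat.n, int)          # the check performed above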
6 changes: 4 additions & 2 deletions tests/test_buffers.py
@@ -34,7 +34,8 @@ def step(self, action):
         self._t += 1
         index = self._t % len(self._observations)
         obs = self._observations[index]
-        terminated = truncated = self._t >= self._ep_length
+        terminated = False
+        truncated = self._t >= self._ep_length
         reward = self._rewards[index]
         return obs, reward, terminated, truncated, {}

@@ -63,7 +64,8 @@ def step(self, action):
         self._t += 1
         index = self._t % len(self._observations)
         obs = {key: self._observations[index] for key in self.observation_space.spaces.keys()}
-        terminated = truncated = self._t >= self._ep_length
+        terminated = False
+        truncated = self._t >= self._ep_length
         reward = self._rewards[index]
         return obs, reward, terminated, truncated, {}

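The test envs above now report their time limit through `truncated` while keeping `terminated` False, which mirrors how the two flags affect learning: only a true terminal state should zero out the bootstrapped next value. A hedged sketch of a one-step TD target to make the distinction concrete (not code from this commit):

    # Only termination ends the return; a truncated episode is still bootstrapped.
    def td_target(reward, next_value, terminated, gamma=0.99):
        return reward + gamma * next_value * (1.0 - float(terminated))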
6 changes: 4 additions & 2 deletions tests/test_vec_envs.py
@@ -42,7 +42,8 @@ def step(self, action):
         reward = float(np.random.rand())
         self._choose_next_state()
         self.current_step += 1
-        terminated = truncated = self.current_step >= self.ep_length
+        terminated = False
+        truncated = self.current_step >= self.ep_length
         return self.state, reward, terminated, truncated, {}

     def _choose_next_state(self):
@@ -178,7 +179,8 @@ def reset(self):
     def step(self, action):
         prev_step = self.current_step
         self.current_step += 1
-        terminated = truncated = self.current_step >= self.max_steps
+        terminated = False
+        truncated = self.current_step >= self.max_steps
         return np.array([prev_step], dtype="int"), 0.0, terminated, truncated, {}


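These envs are exercised through SB3's VecEnv layer, which folds the two flags back into a single `done` per env and auto-resets. A minimal usage sketch; the env class name stands in for the test fixture above, and the info keys are how SB3 reports truncation and the pre-reset observation to the best of my knowledge, so treat them as assumptions:

    from stable_baselines3.common.vec_env import DummyVecEnv

    venv = DummyVecEnv([lambda: CustomGymEnv()])  # fixture name assumed
    obs = venv.reset()
    obs, rewards, dones, infos = venv.step([venv.action_space.sample()])
    # dones[i] == terminated or truncated; on an episode end, infos[i] is expected
    # to carry "TimeLimit.truncated" and "terminal_observation" (assumed keys).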
11 changes: 6 additions & 5 deletions tests/test_vec_normalize.py
@@ -35,7 +35,8 @@ def step(self, action):
         self.t += 1
         index = (self.t + self.return_reward_idx) % len(self.returned_rewards)
         returned_value = self.returned_rewards[index]
-        terminated = truncated = self.t == len(self.returned_rewards)
+        terminated = False
+        truncated = self.t == len(self.returned_rewards)
         return np.array([returned_value]), returned_value, terminated, truncated, {}

     def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
@@ -69,8 +70,8 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
     def step(self, action):
         obs = self.observation_space.sample()
         reward = self.compute_reward(obs["achieved_goal"], obs["desired_goal"], {})
-        done = np.random.rand() > 0.8
-        return obs, reward, done, False, {}
+        terminated = np.random.rand() > 0.8
+        return obs, reward, terminated, False, {}

     def compute_reward(self, achieved_goal: np.ndarray, desired_goal: np.ndarray, _info) -> np.float32:
         distance = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
@@ -100,8 +101,8 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):

     def step(self, action):
         obs = self.observation_space.sample()
-        done = np.random.rand() > 0.8
-        return obs, 0.0, done, False, {}
+        terminated = np.random.rand() > 0.8
+        return obs, 0.0, terminated, False, {}


 def allclose(obs_1, obs_2):
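For completeness, these fixtures exist to exercise `VecNormalize`, which keeps running statistics to normalize observations and rewards on top of a VecEnv. A minimal usage sketch; `DummyRewardEnv` is taken to be the fixture defined earlier in this file (treat the name as an assumption), and the hyperparameters shown are illustrative:

    from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

    venv = DummyVecEnv([lambda: DummyRewardEnv()])
    venv = VecNormalize(venv, norm_obs=True, norm_reward=True, clip_obs=10.0)
    obs = venv.reset()  # observations are now normalized with running mean/std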
