Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow AssertOutOfBoundsWrapper to be applied to any environment #1046

Merged
merged 2 commits into from
Aug 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions pettingzoo/utils/wrappers/assert_out_of_bounds.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
from __future__ import annotations

from gymnasium.spaces import Discrete

from pettingzoo.utils.env import ActionType, AECEnv
from pettingzoo.utils.wrappers.base import BaseWrapper


class AssertOutOfBoundsWrapper(BaseWrapper):
"""Asserts if the action given to step is outside of the action space. Applied in PettingZoo environments with discrete action spaces."""
"""Asserts if the action given to step is outside of the action space."""

def __init__(self, env: AECEnv):
super().__init__(env)
assert all(
isinstance(self.action_space(agent), Discrete)
for agent in getattr(self, "possible_agents", [])
), "should only use AssertOutOfBoundsWrapper for Discrete spaces"

def step(self, action: ActionType) -> None:
assert (
Expand Down
12 changes: 1 addition & 11 deletions test/unwrapped_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,6 @@ def box_observation(env, agents):
return boxable


def discrete_observation(env, agents):
is_discrete = True
for agent in agents:
is_discrete = is_discrete and (
isinstance(env.observation_space(agent), spaces.Discrete)
)
return is_discrete


@pytest.mark.parametrize(("name", "env_module"), list(all_environments.items()))
def test_unwrapped(name, env_module):
env = env_module.env(render_mode="human")
Expand All @@ -37,8 +28,7 @@ def test_unwrapped(name, env_module):
env.reset()
agents = env.agents

if discrete_observation(env, agents):
env = wrappers.AssertOutOfBoundsWrapper(env)
env = wrappers.AssertOutOfBoundsWrapper(env)
env = wrappers.BaseWrapper(env)
env = wrappers.CaptureStdoutWrapper(env)
if box_observation(env, agents) and box_action(env, agents):
Expand Down