Skip to content

Commit

Permalink
Fix bug in SB3 tutorial ActionMask (#1203)
Browse files Browse the repository at this point in the history
  • Loading branch information
dm-ackerman committed May 3, 2024
1 parent 6f9df27 commit 38e2520
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
18 changes: 16 additions & 2 deletions tutorials/SB3/connect_four/sb3_connect_four_action_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,23 @@ def reset(self, seed=None, options=None):
return self.observe(self.agent_selection), {}

def step(self, action):
    """Gymnasium-like step function, returning observation, reward, termination, truncation, info.

    The observation is for the *next* agent (used to determine the next
    action), while the remaining items are for the agent that just acted
    (used to understand what just happened).
    """
    # Remember who is acting now: the AEC base class advances
    # ``agent_selection`` inside ``step``, but reward/termination/truncation
    # must be reported for the agent that actually took this action.
    current_agent = self.agent_selection

    super().step(action)

    # After stepping, ``agent_selection`` points at the next agent to act;
    # its observation is what the caller needs to pick the next action.
    # (Returning ``super().last()`` here would instead pair the next
    # agent's observation with the next agent's reward — the bug this
    # rewrite fixes.)
    next_agent = self.agent_selection
    return (
        self.observe(next_agent),
        self._cumulative_rewards[current_agent],
        self.terminations[current_agent],
        self.truncations[current_agent],
        self.infos[current_agent],
    )

def observe(self, agent):
"""Return only raw observation, removing action mask."""
Expand Down
9 changes: 4 additions & 5 deletions tutorials/SB3/test/test_sb3_action_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@
# Environments a maskable agent should learn to beat random play on with
# comparatively little training.
# NOTE(review): this span comes from a rendered diff — texas_holdem_v4,
# tictactoe_v3 and leduc_holdem_v4 appear to be pre-commit (removed) lines
# that were moved into MEDIUM_ENVS; confirm against the final file before
# relying on this list's contents.
EASY_ENVS = [
gin_rummy_v4,
texas_holdem_no_limit_v6,  # texas holdem human rendered game ends instantly, but with random actions it works fine
texas_holdem_v4,
tictactoe_v3,
leduc_holdem_v4,
]

# More difficult environments which will likely take more training time
# NOTE(review): rendered-diff span — per the commit, leduc_holdem_v4,
# tictactoe_v3 and texas_holdem_v4 look like lines moved in from
# EASY_ENVS; confirm against the final file before relying on this list.
MEDIUM_ENVS = [
leduc_holdem_v4,  # with 10x as many steps it gets higher total rewards (9 vs -9), 0.52 winrate, and 0.92 vs 0.83 total scores
hanabi_v5,  # even with 10x as many steps, total score seems to always be tied between the two agents
tictactoe_v3,  # even with 10x as many steps, agent still loses every time (most likely an error somewhere)
texas_holdem_v4,  # this performs poorly with updates to SB3 wrapper
chess_v6,  # difficult to train because games take so long, performance varies heavily
]

Expand All @@ -50,8 +50,7 @@ def test_action_mask_easy(env_fn):

env_kwargs = {}

# Leduc Hold`em takes slightly longer to outperform random
steps = 8192 if env_fn != leduc_holdem_v4 else 8192 * 4
steps = 8192 * 4

# Train a model against itself (takes ~2 minutes on GPU)
train_action_mask(env_fn, steps=steps, seed=0, **env_kwargs)
Expand Down

0 comments on commit 38e2520

Please sign in to comment.