
Commit

E2E tested, API tested, finished documentation.
clemens4321 committed Apr 22, 2020
1 parent b889f9f commit ddc7252
Showing 2 changed files with 47 additions and 35 deletions.
58 changes: 33 additions & 25 deletions pettingzoo/classic/hanabi/hanabi.py
@@ -113,15 +113,7 @@ def __init__(self, preset_name: str = None, **kwargs):
         # List of agent names
         self.agents = ["player_{}".format(i) for i in range(self.hanabi_env.players)]
 
-        # Rearrange self.agents as pyhanabi starts with player 1 not 0
-        self.agents = self.agents[1:] + [self.agents[0]]
-
-        # Agent order list, on which the agent selector operates on.
-        self.agent_order = list(self.agents)
-        self._agent_selector = agent_selector(self.agent_order)
-
-        # Set initial agent
-        self.agent_selection = self._agent_selector.reset()
+        self.agent_selection: str
 
         # Sets hanabi game to clean state and updates all internal dictionaries
         self.reset(observe=False)
@@ -168,12 +160,11 @@ def observation_vector_dim(self):
 
     @property
     def legal_moves(self) -> List[int]:
-        obs = self.latest_observations['player_observations']
-        return obs[self._offset(self.agents.index(self.agent_selection))]['legal_moves']
+        return self.infos[self.agent_selection]['legal_moves']
 
     @property
     def all_moves(self) -> List[int]:
-        return range(0, self.hanabi_env.num_moves())
+        return list(range(0, self.hanabi_env.num_moves()))
 
     # ToDo: Fix Return value
     def reset(self, observe=True) -> Optional[List[int]]:
@@ -187,7 +178,10 @@ def reset(self, observe=True) -> Optional[List[int]]:
         # Reset underlying hanabi reinforcement learning environment
         obs = self.hanabi_env.reset()
 
-        # Update internal state
+        # Reset agent and agent_selection
+        self._reset_agents(player_number=obs['current_player'])
+
+        # Reset internal state
         self._process_latest_observations(obs=obs)
 
         # If specified, return observation of current agent
@@ -196,7 +190,24 @@ def reset(self, observe=True) -> Optional[List[int]]:
         else:
             return None
 
-    def step(self, action: int, observe: bool = True, as_vector: bool = True) -> Optional[Union[List[int],
+    def _reset_agents(self, player_number: int):
+        """ Rearrange self.agents as pyhanabi starts a different player after each reset(). """
+
+        # Shifts self.agents list as long order starting player is not according to player_number
+        while not self.agents[0] == 'player_' + str(player_number):
+            self.agents = self.agents[1:] + [self.agents[0]]
+
+        # Agent order list, on which the agent selector operates on.
+        self.agent_order = list(self.agents)
+        self._agent_selector = agent_selector(self.agent_order)
+
+        # Reset agent_selection
+        self.agent_selection = self._agent_selector.reset()
+
+    def _step_agents(self):
+        self.agent_selection = self._agent_selector.next()
+
+    def step(self, action: int, observe: bool = True, as_vector: bool = True) -> Optional[Union[np.ndarray,
                                                                                                 List[List[dict]]]]:
         """ Advances the environment by one step. Action must be within self.legal_moves, otherwise throws error.
@@ -213,7 +224,7 @@ def step(self, action: int, observe: bool = True, as_vector: bool = True) -> Opt
 
         else:
             # Iterate agent_selection
-            self.agent_selection = self._agent_selector.next()
+            self._step_agents()
 
             # Apply action
             all_observations, reward, done, _ = self.hanabi_env.step(action=action)
@@ -225,9 +236,9 @@ def step(self, action: int, observe: bool = True, as_vector: bool = True) -> Opt
         if observe:
             return self.observe(agent_name=agent_on_turn, as_vector=as_vector)
 
-    def observe(self, agent_name: str, as_vector: bool = True) -> List:
+    def observe(self, agent_name: str, as_vector: bool = True) -> Union[np.ndarray, List]:
         if as_vector:
-            return self.infos[agent_name]['observations_vectorized']
+            return np.array([[self.infos[agent_name]['observations_vectorized']]], np.int32)
         else:
             return self.infos[agent_name]['observations']
 
@@ -247,14 +258,11 @@ def _process_latest_observations(self, obs: Dict, reward: Optional[float] = 0, d
 
         # Here we have to deal with the player index with offset = 1
         self.infos = {player_name: dict(legal_moves=self.latest_observations['player_observations']
-                                        [self._offset(player_index)]['legal_moves_as_int'],
+                                        [int(player_name[-1])]['legal_moves_as_int'],
                                         legal_moves_as_dict=self.latest_observations['player_observations']
-                                        [self._offset(player_index)]['legal_moves'],
+                                        [int(player_name[-1])]['legal_moves'],
                                         observations_vectorized=self.latest_observations['player_observations']
-                                        [self._offset(player_index)]['vectorized'],
+                                        [int(player_name[-1])]['vectorized'],
                                         observations=self.latest_observations['player_observations']
-                                        [self._offset(player_index)])
-                      for player_index, player_name in enumerate(self.agents)}
-
-    def _offset(self, index: int) -> int:
-        return (index + 1) % len(self.agents)
+                                        [int(player_name[-1])])
+                      for player_name in self.agents}
24 changes: 14 additions & 10 deletions pettingzoo/classic/hanabi/test_hanabi.py
@@ -1,6 +1,7 @@
 from unittest import TestCase
 from pettingzoo.classic.hanabi.hanabi import env
 import pettingzoo.tests.api_test as api_test
+import numpy as np
 
 
 class HanabiTest(TestCase):
@@ -62,8 +63,8 @@ def test_reset(self):
         test_env = env(**self.full_config)
 
         obs = test_env.reset()
-        self.assertIsInstance(obs, list)
-        self.assertIsInstance(obs[0], int)
+        self.assertIsInstance(obs, np.ndarray)
+        self.assertEqual(obs.size, test_env.hanabi_env.vectorized_observation_shape()[0])
 
         obs = test_env.reset(observe=False)
         self.assertIsNone(obs)
@@ -74,9 +75,9 @@ def test_reset(self):
 
         self.assertNotEqual(old_state, new_state)
 
-    # ToDo: Implement and test this, so that internal properties of class do not have to get queried.
     def test_get_legal_moves(self):
-        pass
+        test_env = env(**self.full_config)
+        self.assertIs(set(test_env.legal_moves).issubset(set(test_env.all_moves)), True)
 
     def test_observe(self):
         # Tested within test_step
@@ -88,17 +89,14 @@ def test_step(self):
         # Get current player
         old_player = test_env.agent_selection
 
-        # Get range of moves
-        all_moves = test_env.all_moves
-
         # Pick a legal move
         legal_moves = test_env.legal_moves
 
         # Assert return value
         new_obs = test_env.step(action=legal_moves[0])
         self.assertIsInstance(test_env.infos, dict)
-        self.assertIsInstance(new_obs, list)
-        self.assertIsInstance(new_obs[0], int)
+        self.assertIsInstance(new_obs, np.ndarray)
+        self.assertEqual(new_obs.size, test_env.hanabi_env.vectorized_observation_shape()[0])
 
         # Get new_player
         new_player = test_env.agent_selection
@@ -120,7 +118,7 @@
 
         # Assert raises error if wrong input
         new_legal_moves = test_env.legal_moves
-        illegal_move = list(set(all_moves) - set(new_legal_moves))[0]
+        illegal_move = list(set(test_env.all_moves) - set(new_legal_moves))[0]
         self.assertRaises(ValueError, test_env.step, illegal_move)
 
     def test_legal_moves(self):
@@ -135,6 +133,12 @@ def test_legal_moves(self):
     def test_run_whole_game(self):
         test_env = env(**self.full_config)
 
         while not all(test_env.dones.values()):
             self.assertIs(all(test_env.dones.values()), False)
             test_env.step(test_env.legal_moves[0], observe=False)
+
+        test_env.reset(observe=False)
+
+        while not all(test_env.dones.values()):
+            self.assertIs(all(test_env.dones.values()), False)
+            test_env.step(test_env.legal_moves[0], observe=False)
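
For reference, a minimal usage sketch of the loop `test_run_whole_game` exercises, assuming the PettingZoo Hanabi environment and its dependencies are installed; the `players` keyword is a hypothetical stand-in for the test's `**self.full_config`:

    from pettingzoo.classic.hanabi.hanabi import env

    test_env = env(players=4)  # hypothetical config
    test_env.reset(observe=False)

    # Keep playing the first legal move of the agent on turn until every agent is done.
    while not all(test_env.dones.values()):
        test_env.step(test_env.legal_moves[0], observe=False)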
