Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hanabi Integration #63

Closed
wants to merge 13 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "pettingzoo/classic/hanabi/env"]
path = pettingzoo/classic/hanabi/env
url = https://github.com/deepmind/hanabi-learning-environment
1 change: 1 addition & 0 deletions pettingzoo/classic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
from .uno import uno as uno_v0
from .dou_dizhu import dou_dizhu as dou_dizhu_v0
from .gin_rummy import gin_rummy as gin_rummy_v0
from .hanabi import hanabi
14 changes: 14 additions & 0 deletions pettingzoo/classic/hanabi/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
## Hanabi learning environment

### How to use Hanabi within pettingZoo
* If Hanabi is not pulled, execute the following script from within this directory:

```bash
bash pull_prepare_and_setup_hanabi.sh
```
The bash script does the following steps:
* Git install submodule
* Install dependencies and build package by running installation.sh

### Codebase information
* Codebase cloned from official google repo: https://github.com/deepmind/hanabi-learning-environment
1 change: 1 addition & 0 deletions pettingzoo/classic/hanabi/env
Submodule env added at 5df6a7
348 changes: 348 additions & 0 deletions pettingzoo/classic/hanabi/hanabi.py

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions pettingzoo/classic/hanabi/pull_prepare_and_setup_hanabi.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#!/bin/bash

echo "Pulling hanabi submodule, if not existent"
git pull --recurse-submodules

echo "Installing gcc. First trying apt-get (linux). If not possible, try homebrew (macOS)"
apt-get install g++ || brew install gcc

cd env
pip install .
cd ..

echo "Verify hanabi wrapper is working"
python hanabi.py --test
149 changes: 149 additions & 0 deletions pettingzoo/classic/hanabi/test_hanabi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from unittest import TestCase
from pettingzoo.classic.hanabi.hanabi import env
import pettingzoo.tests.api_test as api_test
import numpy as np


class HanabiTest(TestCase):

@classmethod
def setUpClass(cls):
cls.preset_name = "Hanabi-Small"
cls.player_count = 4
cls.full_config: dict = {
"colors": 2,
"ranks": 5,
"players": 3,
"hand_size": 2,
"max_information_tokens": 3,
"max_life_tokens": 1,
"observation_type": 0,
'seed': 1,
"random_start_player": 1
}

cls.incomplete_config: dict = {
"colors": 5,
"ranks": 5,
"players": 3,
"max_information_tokens": 8,
}

cls.config_values_out_of_reach: dict = {
"colors": 20,
"ranks": 5,
"players": 3,
"hand_size": 2,
"max_information_tokens": 3,
"max_life_tokens": 1,
"observation_type": 0,
'seed': 1,
"random_start_player": 1
}

def test_preset(self):
test = env(preset_name=self.preset_name)
self.assertEqual(test.hanabi_env.__class__.__name__, 'HanabiEnv')

def test_preset_with_players(self):
test = env(preset_name=self.preset_name, players=self.player_count)
self.assertEqual(test.hanabi_env.__class__.__name__, 'HanabiEnv')

def test_full_dictionary(self):
test = env(**self.full_config)
self.assertEqual(test.hanabi_env.__class__.__name__, 'HanabiEnv')

def test_incomplete_dictionary(self):
self.assertRaises(KeyError, env, **self.incomplete_config)

def test_config_values_out_of_range(self):
self.assertRaises(ValueError, env, **self.config_values_out_of_reach)

def test_reset(self):
test_env = env(**self.full_config)

obs = test_env.reset()
self.assertIsInstance(obs, np.ndarray)
self.assertEqual(obs.size, test_env.hanabi_env.vectorized_observation_shape()[0])

obs = test_env.reset(observe=False)
self.assertIsNone(obs)

old_state = test_env.hanabi_env.state
test_env.reset(observe=False)
new_state = test_env.hanabi_env.state

self.assertNotEqual(old_state, new_state)

def test_get_legal_moves(self):
test_env = env(**self.full_config)
self.assertIs(set(test_env.legal_moves).issubset(set(test_env.all_moves)), True)

def test_observe(self):
# Tested within test_step
pass

def test_step(self):
test_env = env(**self.full_config)

# Get current player
old_player = test_env.agent_selection

# Pick a legal move
legal_moves = test_env.legal_moves

# Assert return value
new_obs = test_env.step(action=legal_moves[0])
self.assertIsInstance(test_env.infos, dict)
self.assertIsInstance(new_obs, np.ndarray)
self.assertEqual(new_obs.size, test_env.hanabi_env.vectorized_observation_shape()[0])

# Get new_player
new_player = test_env.agent_selection
# Assert player shifted
self.assertNotEqual(old_player, new_player)

# Assert legal moves have changed
new_legal_moves = test_env.legal_moves
self.assertNotEqual(legal_moves, new_legal_moves)

# Assert return not as vector:
new_obs = test_env.step(action=new_legal_moves[0], as_vector=False)
self.assertIsInstance(new_obs, dict)

# Assert no return
new_legal_moves = test_env.legal_moves
new_obs = test_env.step(action=new_legal_moves[0], observe=False)
self.assertIsNone(new_obs)

# Assert raises error if wrong input
new_legal_moves = test_env.legal_moves
illegal_move = list(set(test_env.all_moves) - set(new_legal_moves))[0]
self.assertRaises(ValueError, test_env.step, illegal_move)

def test_legal_moves(self):
test_env = env(**self.full_config)
legal_moves = test_env.legal_moves

self.assertIsInstance(legal_moves, list)
self.assertIsInstance(legal_moves[0], int)
self.assertLessEqual(len(legal_moves), len(test_env.all_moves))
test_env.step(legal_moves[0])

def test_run_whole_game(self):
test_env = env(**self.full_config)

while not all(test_env.dones.values()):
self.assertIs(all(test_env.dones.values()), False)
test_env.step(test_env.legal_moves[0], observe=False)

test_env.reset(observe=False)

while not all(test_env.dones.values()):
self.assertIs(all(test_env.dones.values()), False)
test_env.step(test_env.legal_moves[0], observe=False)

self.assertIs(all(test_env.dones.values()), True)

def test_api(self):
api_test.api_test(env(**self.full_config))
10 changes: 7 additions & 3 deletions pettingzoo/utils/env.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import numpy as np
from typing import Optional


class AECEnv(object):
def __init__(self):
pass

def step(self, action, observe=True):
def step(self, action, observe=True) -> Optional[np.ndarray]:
raise NotImplementedError

def reset(self, observe=True):
def reset(self, observe=True) -> Optional[np.ndarray]:
raise NotImplementedError

def observe(self, agent):
def observe(self, agent) -> Optional[np.ndarray]:
raise NotImplementedError

def last(self):
Expand Down