-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
458 additions
and
65 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import gym | ||
|
||
# Run one episode of the alternation environment with randomly sampled
# actions, rendering and printing every step until the episode reports done.
env = gym.make('gym_quickcheck:alternation-v0')
observation = env.reset()
while True:
    action = env.action_space.sample()
    observation, reward, done, info = env.step(action)
    env.render()
    print(f"Observation: {observation}, Reward: {reward}")
    if done:
        break
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
from gym_quickcheck.envs.alteration_env import AlternationEnv | ||
from gym_quickcheck.envs.random_walk_env import RandomWalkEnv | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import sys | ||
|
||
import gym | ||
import numpy as np | ||
from gym import utils | ||
|
||
|
||
class NormalDistribution:
    """Normal distribution described by its mean and standard deviation."""

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def sample(self):
        """Draw a single value via ``np.random.normal`` with the stored parameters."""
        drawn = np.random.normal(loc=self.mean, scale=self.std)
        return drawn
|
||
|
||
class AlternationEnv(gym.Env):
    """Two-state environment that rewards switching the action on every step.

    The observation is a length-2 ``uint8`` vector. After a step, the entry at
    index ``1 - action`` is set to 1 and the other entry is 0. An action counts
    as "alternating" when it differs from the first entry of the current
    observation; the step reward is then drawn from ``reward``, otherwise from
    ``penalty``. An episode ends after ``len_episode`` steps.
    """

    def __init__(self):
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(2,), dtype=np.uint8)
        self._len_episode = 100
        self._current_step = 0
        self._current_state = None
        self._last_action = None
        self._has_alternated = None
        self._reward = NormalDistribution(1, 0.1)
        self._penalty = NormalDistribution(-1, 0.1)
        # Bounds are the mean penalty / mean reward accumulated over a whole episode.
        self.reward_range = (self.penalty.mean * self.len_episode, self.reward.mean * self.len_episode)

    @property
    def len_episode(self):
        """Number of steps after which an episode is done."""
        return self._len_episode

    @property
    def reward(self):
        """Distribution sampled when the action alternates."""
        return self._reward

    @property
    def penalty(self):
        """Distribution sampled when the action does not alternate."""
        return self._penalty

    def step(self, action):
        """Apply ``action`` and return ``(observation, reward, done, info)``.

        ``reset`` must have been called first so that a current state exists.
        """
        self._has_alternated = self._is_alternating(action)
        reward = self.reward.sample() if self._has_alternated else self.penalty.sample()
        self._update_state(action)
        self._last_action = action
        # Gym's step contract expects ``info`` to be a dict, so return an
        # empty one instead of None.
        return self._current_state, reward, self._is_done(), {}

    def _is_alternating(self, action):
        # The first observation entry equals the previously taken action once
        # a step has been made, so a mismatch means the agent switched.
        return self._current_state[0] != action

    def _update_state(self, action):
        self._current_state = np.zeros(shape=self.observation_space.shape, dtype=self.observation_space.dtype)
        self._current_state[1 - action] = 1
        self._current_step += 1

    def _is_done(self):
        return self._current_step == self.len_episode

    def reset(self):
        """Start a new episode and return the initial observation.

        The initial observation is sampled from ``observation_space``.
        """
        self._current_step = 0
        # Clear per-step render state so that rendering right after a reset
        # does not show the action and colour of the previous episode.
        self._last_action = None
        self._has_alternated = None
        self._current_state = self.observation_space.sample()
        return self._current_state

    def render(self, mode='human'):
        """Write the last action label and the two-cell state to stdout."""
        sys.stdout.write(f"{self._render_action()}\n{self._render_walk()}\n")

    def _render_action(self):
        if self._last_action is None:
            return ""
        return ["(Right)", "(Left)"][self._last_action]

    def _render_walk(self):
        chars = ['#', '#']
        if self._has_alternated is None:
            color = 'gray'
        elif self._has_alternated:
            color = 'green'
        else:
            color = 'red'
        # Convert to a plain int before subtracting so the list index does not
        # depend on unsigned numpy arithmetic.
        pos = 1 - int(self._current_state[0])
        chars[pos] = utils.colorize(chars[pos], color=color, highlight=True)
        return "".join(chars)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
// State graph of the alternation environment: two states, a and b.
// Edge labels give the reward distribution as N(mean, std).
digraph { | ||
rankdir = "LR"; | ||
overlap = false; | ||
a | ||
b | ||
|
||
// Self-loops (staying in the same state) are penalised; switching is rewarded.
a -> a [label = "N(-1, 0.1)"]; | ||
b -> b [label = "N(-1, 0.1)"]; | ||
a -> b [label = "N(1, 0.1)"]; | ||
b -> a [label = "N(1, 0.1)"]; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
[tool:pytest] | ||
norecursedirs = tests/aux |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
import subprocess | ||
|
||
import numpy as np | ||
|
||
|
||
def assert_that(obj, matcher):
    """Assert that ``matcher(obj)`` is truthy.

    On failure the raised ``AssertionError`` carries ``repr(matcher)`` as its
    message, so matchers that describe the mismatch in ``__repr__`` show up in
    the test report. Raising explicitly (instead of a bare ``assert``) keeps
    the check active under ``python -O``.

    :param obj: the object under test.
    :param matcher: a callable taking ``obj`` and returning a truthy value on success.
    :raises AssertionError: when the matcher rejects ``obj``.
    """
    if not matcher(obj):
        raise AssertionError(repr(matcher))
|
||
|
||
class ContractMatcher:
    """Callable matcher checking that an object provides functions and properties.

    ``interface`` is an iterable of ``(function_name, args)`` pairs; each
    function is looked up on the tested object and called with ``*args``. A
    function counts as missing when the lookup or call raises
    ``AttributeError`` or ``NotImplementedError``.

    ``properties`` is an iterable of attribute names; an attribute counts as
    missing when it does not exist on the tested object or is ``None``.

    After a call, ``repr(matcher)`` describes what the last tested object lacks.
    """

    def __init__(self, interface, properties):
        self._interface = interface
        self._properties = properties
        self._tested_obj = None
        self._missing_functions = []
        self._missing_properties = []

    def __call__(self, obj):
        """Check ``obj`` and return ``True`` when nothing is missing."""
        self._tested_obj = obj
        # Start from a clean slate so that reusing the matcher does not carry
        # over findings from a previously tested object.
        self._missing_functions = []
        self._missing_properties = []
        self._collect_missing_functions()
        self._collect_missing_properties()
        return not self._missing_functions and not self._missing_properties

    def _collect_missing_properties(self):
        for p in self._properties:
            # Default to None so an absent attribute is reported as missing
            # instead of aborting the check with an AttributeError.
            if getattr(self._tested_obj, p, None) is None:
                self._missing_properties.append(p)

    def _collect_missing_functions(self):
        for func, args in self._interface:
            try:
                getattr(self._tested_obj, func)(*args)
            except (AttributeError, NotImplementedError):
                self._missing_functions.append(func)

    def __repr__(self):
        return f"the object {self._tested_obj}, is missing{self._print_missing()}"

    def _print_missing(self):
        message = ""
        if self._missing_functions:
            message += f" these functions: {', '.join(self._missing_functions)}"
        if self._missing_properties:
            if message:
                message += " and"
            message += f" these properties: {', '.join(self._missing_properties)}"
        return message
|
||
|
||
def follows_contract(interface=None, properties=None):
    """Build a ``ContractMatcher``; omitted or empty arguments become empty lists."""
    if not interface:
        interface = []
    if not properties:
        properties = []
    return ContractMatcher(interface, properties)
|
||
|
||
def assert_obs_eq(actual, expected):
    """Assert that two observations are equal element-wise.

    Delegates to ``np.testing.assert_array_equal``, which raises an
    ``AssertionError`` when the arrays differ.
    """
    check_equal = np.testing.assert_array_equal
    check_equal(actual, expected)
|
||
|
||
def unpack_obs(step_tuple):
    """Return the observation, i.e. the first element of a step result."""
    observation = step_tuple[0]
    return observation
|
||
|
||
def unpack_reward(step_tuple):
    """Return the reward, i.e. the second element of a step result."""
    reward = step_tuple[1]
    return reward
|
||
|
||
def unpack_done(step_tuple):
    """Return the done flag, i.e. the third element of a step result."""
    done = step_tuple[2]
    return done
|
||
|
||
def until_done(env, direction):
    """Step ``env`` until it reports done, yielding every step result.

    :param env: object whose ``step(action)`` returns a 4-tuple
        ``(observation, reward, done, info)``.
    :param direction: either a fixed action, or a zero-argument callable that
        is invoked before every step to produce the next action.
    :return: generator of the 4-tuples returned by ``env.step``; the last one
        yielded is the one whose done flag is truthy.
    """
    done = False
    while not done:
        # Dispatch on callability rather than ``isinstance(..., int)`` so that
        # non-``int`` fixed actions (e.g. NumPy integers) are passed through
        # instead of being called.
        action = direction() if callable(direction) else direction
        obs, reward, done, info = env.step(action)
        yield obs, reward, done, info
|
||
|
||
def run_example(example):
    """Run a Python script in a subprocess and return its output lines.

    The script is executed with the interpreter that is running the current
    process, so it sees the same environment and installed packages; a bare
    ``python`` on ``PATH`` may be missing or point at another installation.

    :param example: path of the script to execute.
    :return: list of output lines (stderr merged into stdout), decoded as UTF-8.
    """
    import sys  # local import keeps this helper self-contained

    result = subprocess.run(
        [sys.executable, example],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    return result.stdout.decode('utf-8').splitlines()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from tests.aux import until_done | ||
|
||
|
||
@pytest.fixture(scope='session')
def gym_interface():
    """Function names paired with the argument tuples used to call them."""
    calls = [
        ('reset', ()),
        ('step', (0,)),
        ('render', ()),
        ('close', ()),
    ]
    return calls
|
||
|
||
@pytest.fixture(scope='session')
def gym_properties():
    """Attribute names an environment is expected to expose."""
    names = ['action_space', 'observation_space']
    return names
|
||
|
||
@pytest.fixture
def make_observation_of():
    """Provide a factory building a zero array with ``agent_pos`` set to 1."""

    def obs_fac(obs_shape, agent_pos):
        observation = np.zeros(shape=obs_shape)
        observation[agent_pos] = 1
        return observation

    return obs_fac
|
||
|
||
@pytest.fixture
def capstdout(capsys):
    """Wrap ``capsys`` in an object whose ``read()`` returns only captured stdout."""

    class _CapStdOut:
        def __init__(self, cap):
            self._cap = cap

        def read(self):
            # readouterr() yields (out, err); only the stdout part is exposed.
            captured = self._cap.readouterr()
            return captured[0]

    return _CapStdOut(capsys)
|
||
|
||
@pytest.fixture
def sample_average_reward():
    """Provide a function estimating the average episode return under random actions.

    The returned function resets ``env`` and plays it to completion ``n``
    times with actions from ``env.action_space.sample``, then returns the sum
    of all collected rewards divided by ``n``.
    """

    def sample_average_reward_func(env, n):
        total = 0
        for _ in range(n):
            env.reset()
            episode = until_done(env, env.action_space.sample)
            total += sum(step[1] for step in episode)

        return total / n

    return sample_average_reward_func
Oops, something went wrong.