
Commit

Merge e0e0771 into e26958b
SwamyDev committed Oct 27, 2019
2 parents e26958b + e0e0771 commit b98a26f
Showing 14 changed files with 458 additions and 65 deletions.
9 changes: 9 additions & 0 deletions .idea/vcs.xml


39 changes: 38 additions & 1 deletion README.md
@@ -10,6 +10,8 @@ pip install gym-quickcheck
```

## Quick Start

### Random Walk
A random agent navigating the random walk environment, rendering a textual representation to the standard output:

[embedmd]:# (examples/random_walk.py python)
@@ -20,8 +22,8 @@ env = gym.make('gym_quickcheck:random-walk-v0')
done = False
observation = env.reset()
while not done:
    env.render()
    observation, reward, done, info = env.step(env.action_space.sample())
    env.render()
    print(f"Observation: {observation}, Reward: {reward}")
```

@@ -35,9 +37,44 @@ Observation: [0. 0. 0. 0. 0. 1. 0.], Reward: -1
#######
Observation: [0. 0. 0. 0. 0. 0. 1.], Reward: 1
```

### Alternation
A random agent navigating the alternation environment, rendering a textual representation to the standard output:

[embedmd]:# (examples/alternation.py python)
```python
import gym

env = gym.make('gym_quickcheck:alternation-v0')
done = False
observation = env.reset()
while not done:
    observation, reward, done, info = env.step(env.action_space.sample())
    env.render()
    print(f"Observation: {observation}, Reward: {reward}")
```

Running the example should produce an output similar to this:
```
...
(Right)
##
Observation: [0 1], Reward: -0.9959229664071392
(Left)
##
Observation: [1 0], Reward: 0.8693727604523271
```

## Random Walk
This random walk environment is similar to the one described in [Reinforcement Learning: An Introduction](http://incompleteideas.net/book/the-book-2nd.html). It differs in having a maximum episode length instead of terminating at both ends, and in penalizing every step that does not reach the goal.

![random walk graph](assets/random-walk.png)

The agent receives a reward of 1 when it reaches the goal, which is the rightmost cell, and a reward of -1 on reaching any other cell. The environment terminates either upon reaching the goal or after a maximum number of steps. First, this puts an upper bound on the length of an episode, which keeps test runs fast. Second, because the cumulative penalty is bounded from below and that bound is reached quickly, a reasonable baseline estimate should improve learning significantly. Since baselines have such a noticeable effect, this environment is well suited for testing algorithms that rely on baseline estimates.
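
For illustration only, a crude baseline can be obtained by averaging the total reward of a few random rollouts. The sketch below is not part of the package, and the helper name `estimate_baseline` is made up for this example:

```python
import gym


def estimate_baseline(env, n=100):
    """Hypothetical helper: average total reward over n random rollouts."""
    total = 0.0
    for _ in range(n):
        done = False
        env.reset()
        while not done:
            _, reward, done, _ = env.step(env.action_space.sample())
            total += reward
    return total / n


env = gym.make('gym_quickcheck:random-walk-v0')
print(f"Random-rollout baseline: {estimate_baseline(env):.2f}")
```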

## Alternation
The alternation environment is straightforward: the agent only has to alternate between its two possible states to achieve the maximum reward.

![alteration graph](assets/alteration.png)

The agent receives a normally distributed reward with a mean of 1 when switching from one state to the other, and a normally distributed penalty with a mean of -1 when staying in its current state. The environment terminates after a fixed number of steps. The total reward scales linearly with performance: each additional alternation yields, on average, exactly one more unit of reward. This makes it harder for agents to get stuck in local minima, so most agents should learn the optimal policy quickly. A random agent, however, only achieves an average total reward of around zero. This makes the environment well suited for sanity checks, i.e. verifying that an algorithm learns at all. Because the setup is so simple, obvious problems in an algorithm are also easier to spot.
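
For comparison, an agent that always alternates should collect a total reward close to the episode length, whereas the random agent above averages roughly zero. Below is a minimal sketch, not part of the package; it assumes the first entry of the observation marks the current state, matching the rendering shown above:

```python
import gym

env = gym.make('gym_quickcheck:alternation-v0')

observation = env.reset()
done = False
total = 0.0
while not done:
    # Always pick the action that differs from the current state, i.e. alternate.
    action = 1 - int(observation[0])
    observation, reward, done, _ = env.step(action)
    total += reward
print(f"Total reward of the always-alternating agent: {total:.1f}")
```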
Binary file added assets/alteration.png
9 changes: 9 additions & 0 deletions examples/alternation.py
@@ -0,0 +1,9 @@
import gym

env = gym.make('gym_quickcheck:alternation-v0')
done = False
observation = env.reset()
while not done:
    observation, reward, done, info = env.step(env.action_space.sample())
    env.render()
    print(f"Observation: {observation}, Reward: {reward}")
2 changes: 1 addition & 1 deletion examples/random_walk.py
@@ -4,6 +4,6 @@
done = False
observation = env.reset()
while not done:
    env.render()
    observation, reward, done, info = env.step(env.action_space.sample())
    env.render()
    print(f"Observation: {observation}, Reward: {reward}")
5 changes: 5 additions & 0 deletions gym_quickcheck/__init__.py
@@ -8,3 +8,8 @@
    id='random-walk-v0',
    entry_point='gym_quickcheck.envs:RandomWalkEnv',
)

register(
    id='alternation-v0',
    entry_point='gym_quickcheck.envs:AlternationEnv',
)
2 changes: 1 addition & 1 deletion gym_quickcheck/envs/__init__.py
@@ -1,2 +1,2 @@
from gym_quickcheck.envs.alteration_env import AlternationEnv
from gym_quickcheck.envs.random_walk_env import RandomWalkEnv

83 changes: 83 additions & 0 deletions gym_quickcheck/envs/alteration_env.py
@@ -0,0 +1,83 @@
import sys

import gym
import numpy as np
from gym import utils


class NormalDistribution:
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def sample(self):
        return np.random.normal(self.mean, self.std)


class AlternationEnv(gym.Env):
    def __init__(self):
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(2,), dtype=np.uint8)
        self._len_episode = 100
        self._current_step = 0
        self._current_state = None
        self._last_action = None
        self._has_alternated = None
        self._reward = NormalDistribution(1, 0.1)
        self._penalty = NormalDistribution(-1, 0.1)
        self.reward_range = (self.penalty.mean * self.len_episode, self.reward.mean * self.len_episode)

    @property
    def len_episode(self):
        return self._len_episode

    @property
    def reward(self):
        return self._reward

    @property
    def penalty(self):
        return self._penalty

    def step(self, action):
        self._has_alternated = self._is_alternating(action)
        reward = self.reward.sample() if self._has_alternated else self.penalty.sample()
        self._update_state(action)
        self._last_action = action
        return self._current_state, reward, self._is_done(), None

    def _is_alternating(self, action):
        return self._current_state[0] != action

    def _update_state(self, action):
        self._current_state = np.zeros(shape=self.observation_space.shape, dtype=self.observation_space.dtype)
        self._current_state[1 - action] = 1
        self._current_step += 1

    def _is_done(self):
        return self._current_step == self.len_episode

    def reset(self):
        self._current_step = 0
        self._current_state = self.observation_space.sample()
        return self._current_state

    def render(self, mode='human'):
        sys.stdout.write(f"{self._render_action()}\n{self._render_walk()}\n")

    def _render_action(self):
        if self._last_action is None:
            return ""
        return ["(Right)", "(Left)"][self._last_action]

    def _render_walk(self):
        chars = ['#', '#']
        if self._has_alternated is None:
            color = 'gray'
        elif self._has_alternated:
            color = 'green'
        else:
            color = 'red'
        pos = 1 - self._current_state[0]
        chars[pos] = utils.colorize(chars[pos], color=color, highlight=True)
        return "".join(chars)
11 changes: 11 additions & 0 deletions scripts/alternation.dot
@@ -0,0 +1,11 @@
digraph {
    rankdir = "LR";
    overlap = false;
    a
    b

    a -> a [label = "N(-1, 0.1)"];
    b -> b [label = "N(-1, 0.1)"];
    a -> b [label = "N(1, 0.1)"];
    b -> a [label = "N(1, 0.1)"];
}
2 changes: 2 additions & 0 deletions setup.cfg
@@ -0,0 +1,2 @@
[tool:pytest]
norecursedirs = tests/aux
81 changes: 81 additions & 0 deletions tests/aux/__init__.py
@@ -0,0 +1,81 @@
import subprocess

import numpy as np


def assert_that(obj, matcher):
    assert matcher(obj)


class ContractMatcher:
    def __init__(self, interface, properties):
        self._interface = interface
        self._properties = properties
        self._tested_obj = None
        self._missing_functions = []
        self._missing_properties = []

    def __call__(self, obj):
        self._tested_obj = obj
        self._collect_missing_functions()
        self._collect_missing_properties()
        return len(self._missing_functions) == 0 and len(self._missing_properties) == 0

    def _collect_missing_properties(self):
        for p in self._properties:
            if getattr(self._tested_obj, p) is None:
                self._missing_properties.append(p)

    def _collect_missing_functions(self):
        for func, args in self._interface:
            try:
                getattr(self._tested_obj, func)(*args)
            except (AttributeError, NotImplementedError):
                self._missing_functions.append(func)

    def __repr__(self):
        return f"the object {self._tested_obj}, is missing{self._print_missing()}"

    def _print_missing(self):
        message = ""
        if len(self._missing_functions):
            message += f" these functions: {', '.join(self._missing_functions)}"
        if len(self._missing_properties):
            if len(message) != 0:
                message += " and"
            message += f" these properties: {', '.join(self._missing_properties)}"
        return message


def follows_contract(interface=None, properties=None):
    return ContractMatcher(interface or [], properties or [])


def assert_obs_eq(actual, expected):
    np.testing.assert_array_equal(actual, expected)


def unpack_obs(step_tuple):
    return step_tuple[0]


def unpack_reward(step_tuple):
    return step_tuple[1]


def unpack_done(step_tuple):
    return step_tuple[2]


def until_done(env, direction):
    done = False
    while not done:
        a = direction if isinstance(direction, int) else direction()
        o, r, done, _ = env.step(a)
        yield o, r, done, _


def run_example(example):
    r = subprocess.run(['python', example], stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    lines = r.stdout.decode('utf-8').splitlines()
    return lines
49 changes: 49 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,49 @@
import numpy as np
import pytest

from tests.aux import until_done


@pytest.fixture(scope='session')
def gym_interface():
    return [('reset', ()), ('step', (0,)), ('render', ()), ('close', ())]


@pytest.fixture(scope='session')
def gym_properties():
    return ['action_space', 'observation_space']


@pytest.fixture
def make_observation_of():
    def obs_fac(obs_shape, agent_pos):
        obs = np.zeros(shape=obs_shape)
        obs[agent_pos] = 1
        return obs

    return obs_fac


@pytest.fixture
def capstdout(capsys):
    class _CapStdOut:
        def __init__(self, cap):
            self._cap = cap

        def read(self):
            return self._cap.readouterr()[0]

    return _CapStdOut(capsys)


@pytest.fixture
def sample_average_reward():
    def sample_average_reward_func(env, n):
        total = 0
        for _ in range(0, n):
            env.reset()
            total += sum(r for _, r, _, _ in until_done(env, env.action_space.sample))

        return total / n

    return sample_average_reward_func
