In [None]:
import numpy as np
from PIL import Image, ImageDraw
from IPython import display
from io import BytesIO

import gym
from gym import spaces

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import logging
import os


def setup_file_logging(filename='eye-on-stick.log', level=logging.DEBUG):
    """Attach a single file handler for ``filename`` to the root logger.

    Safe to call repeatedly (e.g. when the notebook cell is re-run): any
    ``FileHandler`` already attached to the root logger for the same file is
    removed and closed first, so every record is written exactly once and no
    file handles are leaked.

    Args:
        filename: Path of the log file; it is truncated (``mode='w'``) on
            every call.
        level: Level set on the root logger.

    Returns:
        Tuple ``(root_logger, file_handler)``.
    """
    root = logging.getLogger()
    # FileHandler stores the absolute path in ``baseFilename``.
    target = os.path.abspath(filename)
    for stale in list(root.handlers):
        if isinstance(stale, logging.FileHandler) and stale.baseFilename == target:
            root.removeHandler(stale)
            stale.close()

    handler = logging.FileHandler(filename=filename, mode='w')
    handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s - %(message)s'))
    root.addHandler(handler)
    root.setLevel(level)
    return root, handler


logger, fhandler = setup_file_logging()
formatter = fhandler.formatter

In [None]:
def showarray(img_array):
    """Show an image array inline in the notebook.

    The array is cast to ``uint8``, encoded as PNG into an in-memory buffer
    and passed to IPython's display machinery.
    """
    pixels = np.uint8(img_array)
    png_buffer = BytesIO()
    Image.fromarray(pixels).save(png_buffer, format='png')
    png_image = display.Image(data=png_buffer.getvalue())
    display.display(png_image)

In [None]:
# Rectangle from which reset() samples the target position.
X_LOW, X_HIGH = 2, 3
Y_LOW, Y_HIGH = -2, 2

# Rendering: square image side in pixels, pixels per world unit, marker radius
# (in world units) and RGB colours of the individual markers.
SCREEN_SIZE = 500
SCREEN_SCALE = SCREEN_SIZE / 7
CIRCLE_SIZE = 0.05
TARGET_CIRCLE_COLOR = (255, 0, 0)
EYE_CIRCLE_COLOR = (0, 0, 255)
BASE_CIRCLE_COLOR = (0, 255, 0)
BG_COLOR = (0, 0, 0)

# Stick angle limits (radians) and the angle covered per step by one unit of
# angular velocity.
PHI_MAX = np.pi/2
PHI_MIN = -PHI_MAX
DPHI = np.pi/360

# The goal is reached when |alpha| is at most this many radians (10 degrees).
ALPHA_GOAL = np.pi/180 * 10

# Step budget of one evaluation run in the evaluate/train loop.
MAX_STEPS = 100

In [None]:
class EyeOnStickEnv(gym.Env):
    """Gym environment in which an "eye" on the tip of a stick must face a target.

    The stick has length ``stick_len``, is anchored at ``(base_x, base_y)`` and
    makes the angle ``phi`` (radians, kept within ``[PHI_MIN, PHI_MAX]``) with
    the x axis. The eye sits at the free end of the stick. ``alpha`` is the
    direction from the eye to the target measured relative to the stick
    direction, wrapped into ``[-pi, pi]``.

    Actions (``Discrete(3)``): ``ACC_MINUS`` / ``ACC_ZERO`` / ``ACC_PLUS``
    change the angular velocity ``dphi`` by -1 / 0 / +1. Every step rotates
    the stick by ``dphi * DPHI``.

    Observation: ``[sin(alpha), cos(alpha)]`` as ``float32``.

    An episode ends, with reward 1, as soon as ``|alpha| <= ALPHA_GOAL``;
    every other step yields reward 0. The environment itself imposes no step
    limit.
    """
    metadata = {'render.modes': ['rgb_array']}

    # Action codes: how much the angular velocity changes in one step.
    ACC_PLUS = 2
    ACC_ZERO = 1
    ACC_MINUS = 0

    def __init__(self):
        """Set up the geometry and the action/observation spaces.

        The dynamic state (target, ``phi``, ``dphi``, ``alpha``) only exists
        after ``reset()`` has been called.
        """
        super(EyeOnStickEnv, self).__init__()
        self.base_x = 0
        self.base_y = 0
        self.stick_len = 1.0

        self.action_space = spaces.Discrete(3)
        self.observation_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)

        # Flipped by every render() call, see there.
        self.draw_center = True

    def reset(self):
        """Start a new episode and return the first observation.

        The target is placed uniformly at random inside
        ``[X_LOW, X_HIGH] x [Y_LOW, Y_HIGH]``; the stick gets a random angle
        in ``[PHI_MIN, PHI_MAX]`` and zero angular velocity.
        """
        # set random target location
        self.target_x = np.random.uniform(low=X_LOW, high=X_HIGH)
        self.target_y = np.random.uniform(low=Y_LOW, high=Y_HIGH)

        # the stick is randomly oriented, but stationary
        self.phi = np.random.uniform(low=PHI_MIN, high=PHI_MAX)
        self.dphi = 0

        self._recalc()

        return self.get_obs()

    def _recalc(self):
        """Recompute the eye position and ``alpha`` from ``phi`` and the target."""
        # The eye sits at the tip of the stick, which is anchored at the base.
        self.eye_x = self.base_x + self.stick_len * np.cos(self.phi)
        self.eye_y = self.base_y + self.stick_len * np.sin(self.phi)

        dx = self.target_x - self.eye_x
        dy = self.target_y - self.eye_y
        alpha = np.arctan2(dy, dx) - self.phi
        # arctan2 yields values in [-pi, pi]; after subtracting phi the result
        # may leave that interval. Wrap it back so that equivalent angles
        # compare equal in the goal test of step(). In-range values are left
        # untouched.
        if not -np.pi <= alpha <= np.pi:
            alpha = (alpha + np.pi) % (2 * np.pi) - np.pi
        self.alpha = alpha

    def get_obs(self):
        """Return the observation ``[sin(alpha), cos(alpha)]`` as ``float32``."""
        return np.array([np.sin(self.alpha), np.cos(self.alpha)]).astype(np.float32)

    def step(self, action):
        """Apply ``action`` and advance the simulation by one step.

        Args:
            action: One of ``ACC_PLUS``, ``ACC_ZERO``, ``ACC_MINUS``.

        Returns:
            Tuple ``(observation, reward, done, info)``: ``done`` is True when
            ``|alpha| <= ALPHA_GOAL``, ``reward`` is 1 in that case and 0
            otherwise, ``info`` is an empty dict.

        Raises:
            ValueError: If ``action`` is not one of the three action codes.
        """
        if action == self.ACC_PLUS:
            self.dphi += 1
        elif action == self.ACC_MINUS:
            self.dphi -= 1
        elif action != self.ACC_ZERO:
            raise ValueError("Received invalid action={} which is not part of the action space".format(action))

        self.phi += self.dphi * DPHI
        # Hitting an angle limit clamps the stick there and stops its motion.
        if self.phi > PHI_MAX:
            self.phi = PHI_MAX
            self.dphi = 0
        elif self.phi < PHI_MIN:
            self.phi = PHI_MIN
            self.dphi = 0

        self._recalc()

        done = bool(np.abs(self.alpha) <= ALPHA_GOAL)
        reward = 1 if done else 0
        info = {}

        return self.get_obs(), reward, done, info

    def render(self, mode='rgb_array'):
        """Draw the current state and return it as an image array.

        The base marker is drawn only on every other call, because
        ``draw_center`` flips on each render; it therefore blinks across
        successive frames. The stick itself is not drawn.

        Args:
            mode: Only ``'rgb_array'`` is supported.

        Returns:
            The ``SCREEN_SIZE`` x ``SCREEN_SIZE`` RGB image converted with
            ``np.asarray``.

        Raises:
            NotImplementedError: For any other ``mode``.
        """
        if mode != 'rgb_array':
            raise NotImplementedError()

        image = Image.new('RGB', (SCREEN_SIZE, SCREEN_SIZE), BG_COLOR)
        draw = ImageDraw.Draw(image)

        def draw_circle(x, y, r, fill):
            # World -> pixel coordinates, with the world origin at the image
            # centre. World y is added to the pixel row as is, i.e. not flipped.
            px = int(SCREEN_SIZE / 2 + x * SCREEN_SCALE)
            py = int(SCREEN_SIZE / 2 + y * SCREEN_SCALE)
            pr = int(r * SCREEN_SCALE)
            draw.ellipse((px - pr, py - pr, px + pr, py + pr), fill=fill)

        if self.draw_center:
            draw_circle(self.base_x, self.base_y, CIRCLE_SIZE, BASE_CIRCLE_COLOR)
            self.draw_center = False
        else:
            self.draw_center = True

        # TODO: draw the stick as a line from the base to the eye.
        draw_circle(self.eye_x, self.eye_y, CIRCLE_SIZE, EYE_CIRCLE_COLOR)
        draw_circle(self.target_x, self.target_y, CIRCLE_SIZE, TARGET_CIRCLE_COLOR)

        return np.asarray(image)

    def close(self):
        """Nothing to release; present to satisfy the Gym interface."""
        pass

In [None]:
#eos = EyeOnStickEnv()
#eos.reset()
#eos.get_obs()
#eos.step(1)
#plt.imshow(eos.render())

In [None]:
from stable_baselines import DQN, PPO2, A2C, ACKTR
from stable_baselines.common.cmd_util import make_vec_env
from stable_baselines.common.env_checker import check_env

# Validate a throw-away instance against the Gym API before training.
check_env(EyeOnStickEnv(), warn=True)

# Pass the environment class itself: make_vec_env then builds a fresh instance
# per sub-environment. The previous `lambda: env` closed over a global that was
# rebound to the vectorized env on the same line and would have shared a
# single instance between all sub-environments for n_envs > 1.
env = make_vec_env(EyeOnStickEnv, n_envs=1)

In [None]:
# ACKTR agent using the 'MlpLstmPolicy' policy on the vectorized environment.
model = ACKTR(policy='MlpLstmPolicy', env=env, verbose=1)

In [None]:
# Alternate between watching the current policy and training it further.
# The loop has no exit condition; stop it by interrupting the kernel.
while True:
    obs = env.reset()

    # The policy is recurrent: its LSTM state is not stored inside the model
    # but has to be passed back into predict() on every call. Start each
    # evaluation run from the initial state (None) with an all-False
    # "episode just ended" mask.
    lstm_state = None
    done = [False] * env.num_envs

    n_victories = 0
    n_steps = 0
    while n_steps < MAX_STEPS:
        display.clear_output(wait=True)
        showarray(env.render(mode='rgb_array'))

        # Passing `mask=done` lets the policy reset its memory right after an
        # episode has ended and the vectorized env has auto-reset.
        action, lstm_state = model.predict(obs, state=lstm_state, mask=done, deterministic=True)
        obs, reward, done, info = env.step(action)
        if done:
            n_victories += 1
            logger.debug("goal reached after %d steps (%d victories)" % (n_steps, n_victories))
            n_steps = 0
            # env reset is done automatically, we stay inside while loop
        else:
            n_steps += 1

    logger.debug("goal not reached (after %d victories), back to the school ..." % (n_victories))
    model.learn(2000)

# Only reached if the loop above is ever given an exit condition.
env.close()