# Use a Closed-Form Policy to Play PongNoFrameskip-v4

In [1]:
import sys
import logging
import imp
import itertools

import numpy as np
np.random.seed(0)
import gym

# ``imp`` is deprecated since Python 3.4 and removed in 3.12; ``importlib``
# offers the same ``reload``. Reloading ``logging`` discards any root-logger
# handlers installed earlier (presumably by the notebook kernel), because
# ``basicConfig`` silently does nothing when the root logger already has
# handlers.
import importlib
importlib.reload(logging)
logging.basicConfig(level=logging.DEBUG,
        format='%(asctime)s [%(levelname)s] %(message)s',
        stream=sys.stdout, datefmt='%H:%M:%S')

In [2]:
env = gym.make('PongNoFrameskip-v4')
env.seed(0)
# Log every instance attribute of the created environment object so its
# spaces and limits are visible in the notebook output.
for attr_name, attr_value in vars(env).items():
    logging.info('%s: %s', attr_name, attr_value)

16:58:43 [INFO] env: <AtariEnv<PongNoFrameskip-v4>>
16:58:43 [INFO] action_space: Discrete(6)
16:58:43 [INFO] observation_space: Box(0, 255, (210, 160, 3), uint8)
16:58:43 [INFO] reward_range: (-inf, inf)
16:58:43 [INFO] metadata: {'render.modes': ['human', 'rgb_array']}
16:58:43 [INFO] _max_episode_steps: 400000
16:58:43 [INFO] _elapsed_steps: None


In [3]:
class CloseFormAgent:
    def __init__(self, _):
        pass
    
    def reset(self, mode=None):
        pass

    def obs2loc(self, observation):
        colors = {'racket': 92, 'ball': 236}
        heights = {'racket': 16, 'ball': 4}
        ymin, ymax = 34, 193
        locations = {}
        for obj in colors:
            match = observation[ymin:ymax, :, 0] == colors[obj]
            xx = np.where(match.any(axis=0))[0]
            yy = np.where(match.any(axis=1))[0]
            if yy.size and yy.min() == 0:
                yy = np.arange(yy.max()-heights[obj]+1, yy.max()+1, 1)
            if yy.size and yy.max() == ymax - ymin:
                yy = np.arange(yy.min(), yy.min()+heights[obj], 1)
            locations[obj + 'x'] = xx.mean() if xx.size else np.nan
            locations[obj + 'y'] = yy.mean() if yy.size else np.nan
        return locations

    def step(self, observation, _reward, _done):
        locations = self.obs2loc(observation)
        if locations['bally'] < locations['rackety']:
            action = 2 # move up
        elif locations['bally'] > locations['rackety']:
            action = 3 # move down
        else:
            action = 0
        return action

    def close(self):
        pass


agent = CloseFormAgent(env)

In [4]:
def play_episode(env, agent, max_episode_steps=None, mode=None, render=False):
    """Run one episode of ``env`` driven by ``agent``.

    After the environment reports ``done``, the agent is queried one final
    time with ``done`` set before the loop exits.

    Parameters
    ----------
    env : environment whose ``step`` returns a 4-tuple
        ``(observation, reward, done, info)``.
    agent : object providing ``reset(mode=...)``,
        ``step(observation, reward, done)`` and ``close()``.
    max_episode_steps : int, optional
        Stop after this many environment steps; ``None`` or 0 means no limit.
    mode : forwarded to ``agent.reset``.
    render : bool
        If true, call ``env.render()`` before each check for termination.

    Returns
    -------
    tuple
        ``(accumulated reward, number of environment steps taken)``.
    """
    observation = env.reset()
    agent.reset(mode=mode)
    reward, done = 0., False
    total_reward, steps = 0., 0
    limit_reached = False
    while not limit_reached:
        action = agent.step(observation, reward, done)
        if render:
            env.render()
        if done:
            break
        observation, reward, done, _ = env.step(action)
        total_reward += reward
        steps += 1
        limit_reached = bool(max_episode_steps) and steps >= max_episode_steps
    agent.close()
    return total_reward, steps


# Evaluate the policy over 100 episodes, logging each episode and then the
# mean and standard deviation of the episode rewards.
logging.info('==== test ====')
episode_rewards = []
for episode in range(100):
    episode_reward, elapsed_steps = play_episode(env, agent)
    logging.debug('test episode %d: reward = %.2f, steps = %d',
            episode, episode_reward, elapsed_steps)
    episode_rewards.append(episode_reward)
reward_mean = np.mean(episode_rewards)
reward_std = np.std(episode_rewards)
logging.info('average episode reward = %.2f ± %.2f', reward_mean, reward_std)

16:58:43 [INFO] ==== test ====
16:59:18 [DEBUG] test episode 0: reward = 21.00, steps = 27445
16:59:52 [DEBUG] test episode 1: reward = 21.00, steps = 27445
17:00:27 [DEBUG] test episode 2: reward = 21.00, steps = 27445
17:01:01 [DEBUG] test episode 3: reward = 21.00, steps = 27445
17:01:36 [DEBUG] test episode 4: reward = 21.00, steps = 27445
17:02:10 [DEBUG] test episode 5: reward = 21.00, steps = 27445
17:02:44 [DEBUG] test episode 6: reward = 21.00, steps = 27445
17:03:18 [DEBUG] test episode 7: reward = 21.00, steps = 27445
17:03:52 [DEBUG] test episode 8: reward = 21.00, steps = 27445
17:04:27 [DEBUG] test episode 9: reward = 21.00, steps = 27445
17:05:02 [DEBUG] test episode 10: reward = 21.00, steps = 27445
17:05:35 [DEBUG] test episode 11: reward = 21.00, steps = 27445
17:06:11 [DEBUG] test episode 12: reward = 21.00, steps = 27445
17:06:45 [DEBUG] test episode 13: reward = 21.00, steps = 27445
17:07:20 [DEBUG] test episode 14: reward = 21.00, steps = 27445
17:07:54 [DEBUG] te

In [None]:
# Close the environment created at the start of the notebook now that
# evaluation has finished.
env.close()