In [8]:
%matplotlib inline

import gym
import itertools
import numpy as np
import random
import tensorflow as tf
import os

from gym.wrappers import Monitor
from collections import deque, namedtuple

In [9]:
# Environment shared by the cells below; the visible code only uses
# `ENV.action_space.n` and `ENV.reset()`.
# NOTE(review): StateProcessor declares a [210, 160, 3] uint8 input and
# Estimator a [None, 84, 84, 4] input -- confirm that this environment's
# observations really have that shape before passing them to
# `StateProcessor.process`.
ENV = gym.make('CartPole-v0')

In [10]:
class StateProcessor:
    """Small TensorFlow sub-graph that preprocesses a single RGB frame.

    The ops are created under the ``state_processor`` variable scope and are
    chained as: grayscale conversion -> crop to a 160x160 box whose top-left
    corner is at row 34, column 0 -> nearest-neighbour resize to 84x84 ->
    squeeze (drops every size-1 dimension).

    Attributes:
        input_state: uint8 placeholder of shape [210, 160, 3] fed by `process`.
        output: final tensor of the chain described above.
    """
    TF_SCOPE_NAME = 'state_processor'

    def __init__(self):
        """Build the preprocessing ops in the current default graph."""
        with tf.variable_scope(StateProcessor.TF_SCOPE_NAME):
            frame = tf.placeholder(dtype=tf.uint8, shape=[210, 160, 3])
            gray = tf.image.rgb_to_grayscale(frame)
            # Arguments: offset_height, offset_width, target_height, target_width.
            cropped = tf.image.crop_to_bounding_box(gray, 34, 0, 160, 160)
            resized = tf.image.resize_images(
                cropped, [84, 84], tf.image.ResizeMethod.NEAREST_NEIGHBOR
            )

            self.input_state = frame
            self.output = tf.squeeze(resized)

    def process(self, session, state):
        """Run the preprocessing graph on one frame.

        Args:
            session: object whose ``run(fetches, feed_dict)`` evaluates the graph
                (a ``tf.Session`` in this notebook).
            state: value fed into the ``input_state`` placeholder.

        Returns:
            Whatever ``session.run`` returns for the ``output`` tensor.
        """
        feed = {self.input_state: state}
        return session.run(self.output, feed)

In [11]:
class Estimator:
    """Convolutional Q-value network together with its loss, train op and summaries.

    All ops are created under ``tf.variable_scope(scope)``. The training op
    increments the graph's global step, which is looked up with
    ``tf.train.get_global_step()``; the caller must create that variable
    before constructing an Estimator.

    Attributes:
        valid_actions: sequence of action ids; its length sets the number of outputs.
        scope: variable scope name the graph was built under.
        summary_writer: ``tf.summary.FileWriter`` or None when no
            ``summaries_dir`` was given.
        X_pl, Y_pl, actions_pl: input, target and action placeholders.
        predictions: tensor with one value per action for each input row.
        loss: scalar mean squared error between targets and the predictions
            of the chosen actions.
    """

    def __init__(self, valid_actions, scope='estimator', summaries_dir=None):
        """Build the model graph and, optionally, a summary writer.

        Args:
            valid_actions: either a sequence of action ids, or an integer
                action count (e.g. ``env.action_space.n``) which is expanded
                to ``[0, 1, ..., n - 1]``.
            scope: name of the variable scope to build the graph in.
            summaries_dir: if truthy, summaries are written to
                ``<summaries_dir>/summaries_<scope>`` (created if missing).
        """
        # The graph needs len(valid_actions); a bare integer count would make
        # len() raise TypeError, so normalise it to a list up front.
        self.valid_actions = self._as_action_list(valid_actions)
        self.scope = scope
        self.summary_writer = None
        with tf.variable_scope(scope):
            self._build_model()
            if summaries_dir:
                summary_dir = os.path.join(summaries_dir, 'summaries_{}'.format(scope))
                if not os.path.exists(summary_dir):
                    os.makedirs(summary_dir)
                self.summary_writer = tf.summary.FileWriter(summary_dir)

    @staticmethod
    def _as_action_list(valid_actions):
        """Return a sequence of action ids; expand an integer count to ``range(n)``."""
        if isinstance(valid_actions, (int, np.integer)):
            return list(range(int(valid_actions)))
        return valid_actions

    def _build_model(self):
        """Create placeholders, the network, the loss, the train op and the summaries."""
        num_actions = len(self.valid_actions)

        # uint8 input batch, float targets, and the index of the action taken per row.
        self.X_pl = tf.placeholder(shape=[None, 84, 84, 4], dtype=tf.uint8, name='X')
        self.Y_pl = tf.placeholder(shape=[None], dtype=tf.float32, name='Y')
        self.actions_pl = tf.placeholder(shape=[None], dtype=tf.int32, name='actions')

        # Rescale the uint8 input from [0, 255] to [0, 1].
        X = tf.to_float(self.X_pl) / 255.
        batch_size = tf.shape(self.X_pl)[0]

        # Three ReLU conv layers (filters, kernel, stride) followed by a dense layer.
        conv1 = tf.contrib.layers.conv2d(X, 32, 8, 4, activation_fn=tf.nn.relu)
        conv2 = tf.contrib.layers.conv2d(conv1, 64, 4, 2, activation_fn=tf.nn.relu)
        conv3 = tf.contrib.layers.conv2d(conv2, 64, 3, 1, activation_fn=tf.nn.relu)
        flattened = tf.contrib.layers.flatten(conv3)
        fc1 = tf.contrib.layers.fully_connected(flattened, 512)

        # The output layer must be linear: fully_connected defaults to ReLU,
        # which would clamp every predicted value to be non-negative.
        self.predictions = tf.contrib.layers.fully_connected(
            fc1, num_actions, activation_fn=None
        )

        # Select the prediction of the chosen action in each row by indexing
        # into the flattened [batch * num_actions] prediction vector.
        gather_indices = tf.range(batch_size) * num_actions + self.actions_pl
        self.action_predictions = tf.gather(tf.reshape(self.predictions, [-1]), gather_indices)

        self.losses = tf.squared_difference(self.Y_pl, self.action_predictions)
        self.loss = tf.reduce_mean(self.losses)

        # Positional arguments: learning_rate, decay, momentum, epsilon.
        self.optimizer = tf.train.RMSPropOptimizer(0.00025, 0.99, 0.0, 1e-6)
        self.train_optimization = self.optimizer.minimize(self.loss, tf.train.get_global_step())

        self.summaries = tf.summary.merge([
            tf.summary.scalar('loss', self.loss),
            tf.summary.histogram('loss_hist', self.losses),
            tf.summary.histogram('q_values_hist', self.predictions),
            tf.summary.scalar('max_q_value', tf.reduce_max(self.predictions))
        ])

    def predict(self, session, s):
        """Evaluate the network on a batch.

        Args:
            session: session used to run the graph.
            s: value fed to the input placeholder (declared [None, 84, 84, 4] uint8).

        Returns:
            The result of running ``self.predictions`` (one value per action per row).
        """
        return session.run(self.predictions, {self.X_pl: s})

    def update(self, session, s, a, targets):
        """Run one optimisation step and record summaries if a writer exists.

        Args:
            session: session used to run the graph.
            s: value fed to the input placeholder.
            a: value fed to the int32 actions placeholder (one action per row).
            targets: value fed to the float32 targets placeholder (one per row).

        Returns:
            The scalar loss computed for this batch.
        """
        feed_dict = {self.X_pl: s, self.Y_pl: targets, self.actions_pl: a}
        summaries, global_step, _, loss = session.run(
            [self.summaries, tf.train.get_global_step(), self.train_optimization, self.loss],
            feed_dict
        )

        if self.summary_writer:
            self.summary_writer.add_summary(summaries, global_step)

        # Previously fetched but dropped; expose it so callers can monitor training.
        return loss


In [None]:
# Smoke test: rebuild the graph from scratch, then push one observation from
# the environment through the preprocessing op.
tf.reset_default_graph()
# Estimator looks this variable up through tf.train.get_global_step().
global_step = tf.Variable(0, name='global_step', trainable=False)

# Estimator calls len() on `valid_actions`, so pass the explicit list of
# action ids rather than the bare integer count from `action_space.n`.
e = Estimator(list(range(ENV.action_space.n)), scope='test')
sp = StateProcessor()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    s = ENV.reset()
    # NOTE(review): `sp` feeds `s` into a [210, 160, 3] uint8 placeholder --
    # confirm that ENV.reset() returns an observation of that shape.
    s_p = sp.process(sess, s)
    
    