In [1]:
import tensorflow as tf
import gym
import numpy as np

In [2]:
class PolicyAgent:
    def __init__(self, num_acts, num_features, num_units=10, learning_rate=0.01, decay=0.95):
        self.n_actions = num_acts
        self.n_features = num_features
        self.learning_rate = learning_rate
        self.decay = decay
        
        self.ep_obs, self.ep_acts, self.ep_rew = [], [], []
        
        self._build(num_units)
        self.sess = tf.Session()
        
        
        self.sess.run(tf.global_variables_initializer())
    
    def _build(self, num_units):
        # Input Info
        self.obs = tf.placeholder(tf.float32, (None, self.n_features))
        self.acts = tf.placeholder(tf.int32, (None,))
        self.rew = tf.placeholder(tf.float32, (None,))
        
        # Model Layers
        h1 = tf.layers.dense(self.obs, num_units, activation=tf.nn.relu, 
                             kernel_initializer=tf.random_normal_initializer())
        h2 = tf.layers.dense(h1, num_units, activation=tf.nn.relu, 
                             kernel_initializer=tf.random_normal_initializer())
        h3 = tf.layers.dense(h2, num_units, activation=tf.nn.relu, 
                             kernel_initializer=tf.random_normal_initializer())
        out = tf.layers.dense(h3, self.n_actions, activation=None, 
                             kernel_initializer=tf.random_normal_initializer())
        
        self.probabilities = tf.nn.softmax(out)
      #  self.action_chooser = tf.multinomial(probabilities,1)
        
        neg_log_prob = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.probabilities, labels=self.acts)
        loss = tf.reduce_mean(neg_log_prob * self.rew)
        
        self.train_op = tf.train.AdamOptimizer(self.learning_rate).minimize(loss)
    
    def choose_action(self,obs):
        prob_weights = self.sess.run(self.probabilities, feed_dict={self.obs:np.array(obs).reshape(-1,self.n_features)})
        action = np.random.choice(range(prob_weights.shape[1]), p=prob_weights.ravel())
        return action
    
    def store(self, obs, action, reward):
        self.ep_obs.append(obs)
        self.ep_acts.append(action)
        self.ep_rew.append(reward)
        
    def process_rewards(self):
        discounted_rewards = np.zeros_like(self.ep_rew)
        run_sum = 0
        for i in reversed(range(len(self.ep_rew))):
            run_sum *= self.decay
            run_sum += self.ep_rew[i]
            discounted_rewards[i] = run_sum
        
        discounted_rewards -= np.mean(discounted_rewards)
        if np.std(discounted_rewards) != 0:
            discounted_rewards /= np.std(discounted_rewards)
        return discounted_rewards
    
    def train(self):
        discounted_rewards = self.process_rewards()
        
        self.sess.run(self.train_op,feed_dict={self.obs: np.array(self.ep_obs).reshape(-1,self.n_features),
                                          self.acts: np.array(self.ep_acts),
                                          self.rew: np.array(discounted_rewards)})
        
        self.ep_obs, self.ep_acts, self.ep_rew = [], [], []
        
        return discounted_rewards    

In [4]:
env = gym.make('Breakout-ram-v0')
#env = gym.make('CartPole-v1')
agent = PolicyAgent(num_acts=env.action_space.n,num_features=env.observation_space.shape[0],
                    num_units=30,learning_rate=0.005, decay=0.99)

In [4]:
max_episodes = 1000000000000
max_steps = 1000

In [5]:
saver = tf.train.Saver()

In [6]:
saver.restore(agent.sess,'./breakout_models_crazy/model.ckpt')

INFO:tensorflow:Restoring parameters from ./breakout_models_crazy/model.ckpt


In [9]:
save_path = saver.save(agent.sess, './breakout_models_crazy/model.ckpt')

In [None]:
avg = []
for episode_num in range(max_episodes):
    obs_prev = env.reset()
    obs, reward, done, _ = env.step(env.action_space.sample())
    obs_array = [obs_prev,obs]
    done = False
    for i in range(max_steps):
        
        observation = obs_array[0] - obs_array[1]
#         env.render()
        action = agent.choose_action(observation)
        obs, reward, done, _ = env.step(action)
        agent.store(observation, action, reward)
        obs_array[0],obs_array[1] = obs_array[1], obs
        
        if done or i==max_steps-1:
            
            reward_list = agent.train()
            avg.append(np.sum(reward_list))
#             print('{0:7f}'.format(np.sum(reward_list)), end=', ')
            if len(avg) == 100:
                mean = np.mean(avg)
                print(mean, end=', ')
                save_path = saver.save(agent.sess, './breakout_models_crazy/model.ckpt')
                avg = []
            break
            
#         obs_prev = obs

-2.9842794901924207e-15, 4.085620730620576e-15, 3.907985046680551e-16, 1.4566126083082053e-15, 2.0961010704922954e-15, 8.526512829121202e-16, 4.263256414560601e-16, 3.730349362740526e-16, -4.618527782440651e-16, -2.3447910280083307e-15, 4.085620730620576e-15, 1.1368683772161603e-15, 9.769962616701378e-16, -1.865174681370263e-15, -7.638334409421077e-16, -3.1796787425264484e-15, -1.2612133559741778e-15, -1.4921397450962103e-15, 3.943512183468556e-15, -6.394884621840901e-16, 8.171241461241152e-16, -1.9895196601282807e-15, 3.348432642269472e-15, -3.517186542012496e-15, 6.394884621840901e-16, 3.659295089164516e-15, -4.618527782440651e-16, 3.481659405224491e-15, 1.0835776720341528e-15, 2.806643806252396e-15, -3.907985046680551e-16, 8.526512829121202e-16, 1.6342482922482305e-15, -1.1368683772161603e-15, 2.7267077484793843e-15, 1.3500311979441904e-15, -8.171241461241152e-16, -1.0658141036401502e-16, 1.5987211554602255e-15, -4.440892098500626e-16, -8.526512829121202e-16, 0.0, -8.348877145181177

4.618527782440651e-16, -1.6342482922482305e-15, -2.0961010704922954e-15, 9.947598300641404e-16, -9.592326932761353e-16, 1.3322676295501878e-15, -1.8118839761882555e-15, -2.7533531010703883e-15, -1.5276668818842154e-15, 1.5276668818842154e-15, 1.5276668818842154e-15, 2.877698079828406e-15, -1.900701818158268e-15, -1.1013412404281553e-15, 3.161915174132446e-15, -1.7763568394002506e-16, 8.526512829121202e-16, -8.881784197001252e-16, -8.881784197001253e-17, -2.486899575160351e-16, -1.3677947663381929e-15, -1.2079226507921704e-15, -1.2789769243681803e-15, -2.3270274596143282e-15, -9.414691248821328e-16, 3.0730973321624333e-15, 5.329070518200751e-16, -1.4566126083082053e-15, 2.5224267119483555e-15, -4.973799150320702e-16, -9.414691248821328e-16, -6.217248937900876e-16, 2.646771690706373e-15, 3.3750779948604757e-15, -1.4210854715202004e-16, 4.6895820560166614e-15, -2.140509991477302e-15, 6.394884621840901e-16, 2.895461648222408e-15, 3.3750779948604757e-15, 2.220446049250313e-15, -5.0626169922

-1.2612133559741778e-15, 4.085620730620576e-16, -9.237055564881303e-16, 1.2612133559741778e-15, -9.592326932761353e-16, -6.927791673660977e-16, -1.4566126083082053e-15, 1.9184653865522706e-15, -1.2079226507921704e-15, -4.423128530106624e-15, -2.486899575160351e-16, 1.6342482922482305e-15, -9.769962616701378e-16, 1.5454304502782179e-15, 1.6253665080512292e-15, -7.815970093361102e-16, 1.0302869668521452e-15, 9.237055564881303e-16, 6.750155989720952e-16, -3.197442310920451e-15, 3.552713678800501e-16, 6.394884621840901e-16, 6.750155989720952e-16, -2.1316282072803005e-16, 1.4743761767022078e-15, 3.552713678800501e-16, 2.4868995751603505e-15, -6.394884621840901e-16, 9.237055564881303e-16, 1.2434497875801752e-15, 2.895461648222408e-15, -2.1316282072803005e-16, -3.268496584496461e-15, -1.0658141036401502e-16, -6.217248937900876e-16, 1.936228954946273e-15, 8.171241461241152e-16, -7.105427357601002e-17, -1.0658141036401502e-16, 2.220446049250313e-15, -7.815970093361102e-16, 3.730349362740526e-16