In [1]:
import tensorflow as tf
import gym
import numpy as np

In [2]:
class PolicyAgent:
    def __init__(self, num_acts, num_features, num_units=10, learning_rate=0.01, decay=0.95):
        self.n_actions = num_acts
        self.n_features = num_features
        self.learning_rate = learning_rate
        self.decay = decay
        
        self.ep_obs, self.ep_acts, self.ep_rew = [], [], []
        
        self._build(num_units)
        self.sess = tf.Session()
        
        
        self.sess.run(tf.global_variables_initializer())
    
    def _build(self, num_units):
        # Input Info
        self.obs = tf.placeholder(tf.float32, (None, self.n_features))
        self.acts = tf.placeholder(tf.int32, (None,))
        self.rew = tf.placeholder(tf.float32, (None,))
        
        # Model Layers
        h1 = tf.layers.dense(self.obs, num_units, activation=tf.nn.tanh, 
                             kernel_initializer=tf.random_normal_initializer())
        h2 = tf.layers.dense(h1, num_units, activation=tf.nn.tanh, 
                             kernel_initializer=tf.random_normal_initializer())
        out = tf.layers.dense(h2, self.n_actions, activation=None, 
                             kernel_initializer=tf.random_normal_initializer())
        
        self.probabilities = tf.nn.softmax(out)
      #  self.action_chooser = tf.multinomial(probabilities,1)
        
        neg_log_prob = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.probabilities, labels=self.acts)
        loss = tf.reduce_mean(neg_log_prob * self.rew)
        
        self.train_op = tf.train.AdamOptimizer(self.learning_rate).minimize(loss)
    
    def choose_action(self,obs):
        prob_weights = self.sess.run(self.probabilities, feed_dict={self.obs:np.array(obs).reshape(-1,self.n_features)})
        action = np.random.choice(range(prob_weights.shape[1]), p=prob_weights.ravel())
        return action
    
    def store(self, obs, action, reward):
        self.ep_obs.append(obs)
        self.ep_acts.append(action)
        self.ep_rew.append(reward)
        
    def process_rewards(self):
        discounted_rewards = np.zeros_like(self.ep_rew)
        run_sum = 0
        for i in reversed(range(len(self.ep_rew))):
            run_sum *= self.decay
            run_sum += self.ep_rew[i]
            discounted_rewards[i] = run_sum
        
        discounted_rewards -= np.mean(discounted_rewards)
        if np.std(discounted_rewards) != 0:
            discounted_rewards /= np.std(discounted_rewards)
        return discounted_rewards
    
    def train(self):
        discounted_rewards = self.process_rewards()
        
        self.sess.run(self.train_op,feed_dict={self.obs: np.array(self.ep_obs).reshape(-1,self.n_features),
                                          self.acts: np.array(self.ep_acts),
                                          self.rew: np.array(discounted_rewards)})
        
        self.ep_obs, self.ep_acts, self.ep_rew = [], [], []
        
        return discounted_rewards    

In [3]:
env = gym.make('Breakout-ram-v0')
#env = gym.make('CartPole-v1')
agent = PolicyAgent(num_acts=env.action_space.n,num_features=env.observation_space.shape[0],
                    num_units=30,learning_rate=0.001, decay=0.99)

In [4]:
max_episodes = 1000000000000
#max_steps = 100000000

In [5]:
saver = tf.train.Saver()

In [6]:
saver.restore(agent.sess,'./breakout_models/model.ckpt')

INFO:tensorflow:Restoring parameters from ./breakout_models/model.ckpt


In [9]:
save_path = saver.save(agent.sess, './breakout_models/model.ckpt')

In [None]:
avg = []
for episode_num in range(max_episodes):
    obs_prev = env.reset()
    obs, reward, done, _ = env.step(env.action_space.sample())
    obs_array = [obs_prev,obs]
    done = False
    while not done:
        observation = obs_array[0] - obs_array[1]
        
#         env.render()
        action = agent.choose_action(observation)
        obs, reward, done, _ = env.step(action)
        agent.store(observation, action, reward)
        obs_array[0],obs_array[1] = obs_array[1], obs
        
        if done:
            
            reward_list = agent.train()
            avg.append(np.sum(reward_list))
#             print('{0:7f}'.format(np.sum(reward_list)), end=', ')
            if len(avg) == 100:
                mean = np.mean(avg)
                print(mean, end=', ')
                save_path = saver.save(agent.sess, './breakout_models/model.ckpt')
                avg = []
            break
            
#         obs_prev = obs

-1.9850787680297798e-15, 2.091660178393795e-15, 9.201528428093297e-15, 2.1316282072803005e-15, 3.3661962106634747e-15, -4.884981308350689e-16, 7.922551503725117e-15, 9.312550730555813e-15, -4.263256414560601e-15, 5.737632591262809e-15, -2.3270274596143282e-15, 2.5934809855243656e-15, 4.387601393318619e-15, 1.4210854715202005e-15, -1.6875389974302379e-15, -2.3092638912203257e-16, -3.3750779948604757e-15, 1.070254995738651e-14, 7.105427357601002e-16, 3.783640067922533e-15, 1.609823385706477e-15, -3.002043058586423e-15, 6.4170890823334045e-15, 2.3092638912203257e-15, 1.5987211554602254e-16, 3.375077994860476e-16, 4.0323300254385686e-15, 5.897504706808831e-15, -6.217248937900877e-17, 5.17363929475323e-15, 1.5987211554602254e-16, 4.0323300254385686e-15, 4.913847106990943e-15, 8.117950756059145e-15, 1.838529328779259e-15, 8.748557434046233e-15, 4.796163466380677e-16, 4.334310688136611e-15, -4.520828156273637e-15, -7.01660951563099e-16, -1.865174681370263e-15, 1.1013412404281553e-15, 1.083577

-3.0642155479654322e-15, 1.7008616737257399e-15, 5.935252289646087e-15, -3.9790393202565614e-15, 6.719069745031448e-15, 5.402345237826011e-15, 3.241851231905457e-15, 1.7319479184152442e-15, 8.517631044924201e-15, -9.725553695716371e-16, -1.1546319456101628e-16, -1.9895196601282807e-15, -4.587441537751147e-15, 6.23945339839338e-15, 3.783640067922533e-15, -7.349676423018536e-16, -2.5468516184901092e-15, -2.5224267119483555e-15, 6.0129679013698474e-15, -1.6298074001497298e-15, -1.9895196601282807e-15, -5.782041512247815e-15, -9.769962616701378e-17, 1.6431300764452318e-15, 6.039613253960852e-16, -2.993161274389422e-15, -1.9539925233402755e-16, 7.123190925995004e-15, 3.836930773104541e-15, 3.907985046680551e-16, 6.092903959142859e-15, 1.7497114868092467e-15, -3.75699471533153e-15, 1.0524914273446484e-14, 4.991562718714704e-15, -8.881784197001252e-16, 3.552713678800501e-16, 1.6875389974302379e-15, 5.062616992290714e-15, 4.7717385598389225e-15, 3.788080960021034e-15, -4.4009240696141206e-15, 

3.1086244689504383e-15, 5.062616992290714e-15, 4.369837824924616e-15, -2.284838984678572e-15, -3.175237850427948e-15, -9.769962616701378e-16, -4.169997680492088e-15, -8.16235967704415e-15, 2.0250467969162854e-15, -8.659739592076221e-16, -1.616484723854228e-15, -6.217248937900877e-17, -1.3500311979441904e-15, -1.389999226830696e-15, -2.469136006766348e-15, 1.723066134218243e-15, 1.0551559626037488e-14, 4.616307336391401e-15, 