In [4]:
import gym
import numpy as np

from keras import layers
from keras.models import Model
from keras import backend as K
from keras import utils as np_utils
from keras import optimizers

In [5]:
# Create the CartPole-v0 environment. generate_trajectory reads this
# module-level `env` directly instead of taking it as an argument.
env = gym.make('CartPole-v0')

In [6]:
# Policy network: 4 input features -> two hidden blocks (Dense(32) + ReLU)
# -> Dense(2) + softmax, so the output is a probability per action.
inputs = layers.Input(shape=(4,))
net = inputs
for hidden_units in (32, 32):
    dense_out = layers.Dense(hidden_units)(net)
    net = layers.Activation("relu")(dense_out)
action_scores = layers.Dense(2)(net)
net = layers.Activation("softmax")(action_scores)

In [7]:
# Wrap the input tensor and the softmax output tensor built above into a
# Keras Model so it can be used for prediction and saved to disk.
model = Model(inputs=inputs, outputs=net)

In [8]:
# Symbolic tensors from which the training loss is built.
# Policy output: the model's softmax probabilities, one row per input state.
action_prob_ph = model.output
# One-hot encoding of the action taken at each step; 2 columns, leading
# dimension left unspecified.
action_onehot_ph = K.placeholder(shape=(None, 2))
# Reward-to-go for each step; 1-D, length left unspecified.
rtg_ph = K.placeholder(shape=(None,))

In [9]:
# Policy-gradient surrogate loss: pick out the probability assigned to the
# action actually taken, weight its log by the reward-to-go, and negate the
# batch mean so that minimizing the loss increases the weighted log-probability.
taken_action_prob = K.sum(action_prob_ph * action_onehot_ph, axis=1)
weighted_log_prob = K.log(taken_action_prob) * rtg_ph
loss = -K.mean(weighted_log_prob)

In [11]:
# Adam optimizer constructed with no arguments (library defaults).
# `updates` is what get_updates returns for minimizing `loss` with respect
# to the model's trainable weights; it is passed to K.function below.
adam = optimizers.Adam()
updates = adam.get_updates(params=model.trainable_weights,
                           loss=loss)

In [12]:
# Build the training callable. It is fed [states, one-hot actions,
# rewards-to-go], returns no values (outputs=[]), and is used only for the
# optimizer `updates` attached to it.
train_fn = K.function(inputs=[model.input,
                              action_onehot_ph,
                              rtg_ph],
                      outputs=[],
                      updates=updates)

In [13]:
def generate_trajectory(model):
    """Roll out one episode in the module-level ``env``, sampling from ``model``.

    At every step the current observation is passed to ``model.predict`` as a
    batch of one, and an action index is sampled from the returned
    probability vector. The number of actions is taken from the length of
    that vector instead of being hard-coded, so the one-hot width always
    matches the policy head.

    Args:
        model: Object with a ``predict`` method that maps a batch containing
            one observation to one row of action probabilities.

    Returns:
        A 3-tuple ``(states, actions_onehot, rtg)``:

        * ``states`` -- array of the observations the actions were chosen
          from, one row per step.
        * ``actions_onehot`` -- one-hot encoding of the sampled actions, one
          row per step and one column per action.
        * ``rtg`` -- reward-to-go per step: the sum of the rewards from that
          step through the end of the episode.
    """
    states = []; actions = []; rewards = []
    n_actions = None

    s = env.reset()
    done = False

    while not done:
        # predict expects a batch axis; flatten the single returned row.
        # reshape(-1) (rather than squeeze) keeps a 1-D vector even for a
        # one-action policy head.
        probs = np.reshape(model.predict(np.array(s)[np.newaxis, :]), -1)
        n_actions = probs.shape[0]

        # An integer first argument is sampled as if it were np.arange(n).
        a = np.random.choice(n_actions, p=probs)

        s2, r, done, _ = env.step(a)

        # Record the observation the action was chosen from, not its successor.
        states.append(s)
        actions.append(a)
        rewards.append(r)

        s = s2

    # Reverse, accumulate, reverse back: a suffix sum of the rewards.
    rtg = np.cumsum(rewards[::-1])[::-1]
    return np.array(states), np_utils.to_categorical(actions, num_classes=n_actions), rtg

In [15]:
# Training loop: one policy-gradient update per episode, with a progress
# report (mean episode return) after every REPORT_EVERY episodes. The policy
# is checkpointed the first time the windowed mean exceeds 40 and again when
# it exceeds 80, at which point training stops.
REPORT_EVERY = 10       # episodes averaged per progress report
MAX_EPISODES = 10000    # upper bound on training episodes

stored_rtgs = []        # episode returns collected since the last report
model_40_saved = False; model_80_saved = False
for i in range(MAX_EPISODES):
    states, actions, rtg = generate_trajectory(model)
    # rtg[0] is the reward-to-go from the first step, i.e. the episode's total reward.
    stored_rtgs.append(rtg[0])
    # Report only once a full window has been collected, so that a single
    # lucky rollout can never be mistaken for a windowed mean and trigger a
    # checkpoint.
    if len(stored_rtgs) == REPORT_EVERY:
        mean_return = np.mean(stored_rtgs)
        print(mean_return)

        if mean_return > 40 and not model_40_saved:
            model.save('/tmp/model_40_policy.h5')
            model_40_saved = True
        if mean_return > 80 and not model_80_saved:
            model.save('/tmp/model_80_policy.h5')
            model_80_saved = True
            # Target reached: stop without training on this last episode.
            break
        stored_rtgs = []
    train_fn([states, actions, rtg])

20.0
34.1
38.8
37.8
43.1
62.6
42.0
55.0
42.5
69.4
94.1


In [16]:
# Inspect the symbolic tensor holding the policy's action probabilities.
action_prob_ph

<tf.Tensor 'activation_3/Softmax:0' shape=(?, 2) dtype=float32>

In [None]:
# Shape of the one-hot action array from the last trajectory the training loop generated.
actions.shape

In [None]:
# Shape of the reward-to-go array from the last trajectory the training loop generated.
rtg.shape