In [5]:
import numpy as np
import tensorflow as tf
import gym
from collections import deque

In [15]:
env = gym.make('CartPole-v0')

# 하이퍼 파라미터
LEARNING_RATE = 0.005
INPUT = env.observation_space.shape[0] # 4
OUTPUT = env.action_space.n # 2
DISCOUNT = 0.99


def discount_rewards(r):
    '''Discounted reward를 구하기 위한 함수
    
    Args:
         r(np.array): reward 값이 저장된 array
    
    Returns:
        discounted_r(np.array): Discounted 된 reward가 저장된 array
    '''
    discounted_r = np.zeros_like(r, dtype=np.float32)
    running_add = 0
    for t in reversed(range(len(r))):
        running_add = running_add * DISCOUNT + r[t]
        discounted_r[t] = running_add

    return discounted_r


def train_episodic(PGagent, x, y, adv):
    '''에피소드당 학습을 하기위한 함수
    
    Args:
        PGagent(PolicyGradient): 학습될 네트워크
        x(np.array): State가 저장되어있는 array
        y(np.array): Action(one_hot)이 저장되어있는 array
        adv(np.array) : Discounted reward가 저장되어있는 array
        
    Returns:
        l(float): 네트워크에 의한 loss
    '''
    l,_ = PGagent.sess.run([PGagent.loss, PGagent.train], feed_dict={PGagent.X: x, PGagent.Y: y, PGagent.adv : adv})
    return l

def play_cartpole(PGagent):
    '''학습된 네트워크로 Play하기 위한 함수
    
    Args:
         PGagent(PolicyGradient): 학습된 네트워크
    '''
    print("Play Cartpole!")
    episode = 0
    while True:
        #s = env.reset()
        s = np.array([env.reset()])
        done = False
        rall = 0
        episode += 1
        while not done:
            env.render()
            action_p = PGagent.sess.run(PGagent.a_pre, feed_dict={PGagent.X : s})
            s1, reward, done, _ = env.step(np.argmax(action_p))
            s = np.array([s1])
            rall += reward
        #print("[Episode {0:6f}] Reward: {1:4f} ".format(episode, rall))

class PolicyGradient:
    def __init__(self, sess, input_size, output_size):
        self.sess = sess
        self.input_size = input_size
        self.output_size = output_size

        self.build_network()

    def build_network(self):
        self.X = tf.placeholder('float',[None, self.input_size])
        self.Y = tf.placeholder('float', [None, self.output_size])
        self.adv = tf.placeholder('float')

        w1 = tf.Variable(tf.truncated_normal([self.input_size, 128], stddev=0.1))
        w2 = tf.Variable(tf.truncated_normal([128, self.output_size], stddev=0.1))
       
        l1 = tf.nn.relu(tf.matmul(self.X, w1))
        self.a_pre = tf.nn.softmax(tf.matmul(l1,w2))

        self.log_p = self.Y * tf.log(self.a_pre)
        self.log_lik = self.log_p * self.adv
        self.loss = tf.reduce_mean(tf.reduce_sum(-self.log_lik, axis=1))
        self.train = tf.train.AdamOptimizer(LEARNING_RATE).minimize(self.loss)

    def get_action(self, state):
        state_t = np.reshape(state, [1, self.input_size])
        action_p = self.sess.run(self.a_pre, feed_dict={self.X : state_t})

        # 각 액션의 확률로 액션을 결정
        action = np.random.choice(np.arange(self.output_size), p=action_p[0])

        return action

def main():
    with tf.Session() as sess:
        PGagent = PolicyGradient(sess, INPUT, OUTPUT)

        sess.run(tf.global_variables_initializer())
        episode = 0
        recent_rlist = deque(maxlen=100)
        recent_rlist.append(0)

        # 최근 100개의 점수가 195점 넘을 때까지 학습
        while np.mean(recent_rlist) <= 105:
            episode += 1
            episode_memory = deque()
            rall = 0 # 한 에피소드 동안 받은 reward들의 합
            s = env.reset()
            done = False

            while not done:
                env.render()
                # 액션 선택
                action = PGagent.get_action(s)

                # action을 one_hot으로 표현
                y = np.zeros(OUTPUT)
                y[action] = 1

                s1, reward, done, _ = env.step(action)
                rall += reward

                # 에피소드 메모리에 저장
                episode_memory.append([s, y, reward])
                s = s1

                # 에피소드가 끝났을때 학습
                if done:
                    episode_memory = np.array(episode_memory)

                    discounted_rewards = discount_rewards(np.vstack(episode_memory[:,2]))

                    discounted_rewards = (discounted_rewards - discounted_rewards.mean()) / (discounted_rewards.std() +1e-7)
                    
                    print("np.vstack(episode_memory[:,0]) : ",np.vstack(episode_memory[:,0]).shape)
                    print(np.vstack(episode_memory[:,0]),"\n")
                    print("np.vstack(episode_memory[:,1]) : ",np.vstack(episode_memory[:,1]).shape)
                    print(np.vstack(episode_memory[:,1]),"\n")
                    print("discounted_rewards : ",discounted_rewards.shape)
                    print(discounted_rewards,"\n")
                          
                    l = train_episodic(PGagent, np.vstack(episode_memory[:,0]), np.vstack(episode_memory[:,1]),
                                      discounted_rewards)

                    recent_rlist.append(rall) 

            #print("[Episode {0:6f}] Reward: {1:4f} Loss: {2:5.5f} Recent Reward: {3:4f}".format(episode, rall, l,np.mean(recent_rlist)))

        print("\n")
        play_cartpole(PGagent)

In [16]:
main()

np.vstack(episode_memory[:,0]) :  (18, 4)
[[-0.03771463  0.01276108 -0.01571571 -0.04669918]
 [-0.03745941  0.20810482 -0.0166497  -0.34429883]
 [-0.03329731  0.01322363 -0.02353567 -0.05691235]
 [-0.03303284  0.208675   -0.02467392 -0.35692705]
 [-0.02885934  0.01391237 -0.03181246 -0.07212543]
 [-0.02858109 -0.1807394  -0.03325497  0.21035321]
 [-0.03219588 -0.37537047 -0.02904791  0.49236324]
 [-0.03970329 -0.57007092 -0.01920064  0.77575182]
 [-0.05110471 -0.76492359 -0.0036856   1.06233231]
 [-0.06640318 -0.56975304  0.01756104  0.7684949 ]
 [-0.07779824 -0.76511226  0.03293094  1.06665123]
 [-0.09310048 -0.57044114  0.05426396  0.78448282]
 [-0.10450931 -0.76626508  0.06995362  1.09373214]
 [-0.11983461 -0.57213092  0.09182826  0.8237927 ]
 [-0.13127723 -0.76838097  0.10830412  1.14388709]
 [-0.14664485 -0.96473818  0.13118186  1.46847638]
 [-0.16593961 -1.16119881  0.16055139  1.79909228]
 [-0.18916359 -0.96819613  0.19653323  1.56031107]] 

np.vstack(episode_memory[:,1]) :  (18

np.vstack(episode_memory[:,0]) :  (21, 4)
[[  9.66592699e-03  -2.89724920e-03  -3.51109920e-02   2.07546300e-02]
 [  9.60798201e-03  -1.97498539e-01  -3.46958994e-02   3.02156147e-01]
 [  5.65801122e-03  -3.92109255e-01  -2.86527765e-02   5.83697985e-01]
 [ -2.18417387e-03  -1.96597870e-01  -1.69788168e-02   2.82128519e-01]
 [ -6.11613127e-03  -1.23790309e-03  -1.13362464e-02  -1.58607256e-02]
 [ -6.14088933e-03   1.94044774e-01  -1.16534609e-02  -3.12098704e-01]
 [ -2.25999386e-03  -9.09235359e-04  -1.78954350e-02  -2.31135868e-02]
 [ -2.27817857e-03  -1.95770038e-01  -1.83577067e-02   2.63869839e-01]
 [ -6.19357933e-03  -3.90625215e-01  -1.30803099e-02   5.50706492e-01]
 [ -1.40060836e-02  -5.85561027e-01  -2.06618010e-03   8.39239694e-01]
 [ -2.57173042e-02  -3.90410924e-01   1.47186138e-02   5.45907709e-01]
 [ -3.35255226e-02  -1.95498839e-01   2.56367680e-02   2.57898366e-01]
 [ -3.74354994e-02  -3.90977236e-01   3.07947353e-02   5.58555993e-01]
 [ -4.52550441e-02  -5.86517614e-01

np.vstack(episode_memory[:,0]) :  (31, 4)
[[  2.64652691e-02   2.47318800e-02   4.28956203e-02  -1.91072426e-03]
 [  2.69599067e-02   2.19213232e-01   4.28574058e-02  -2.80756958e-01]
 [  3.13441714e-02   2.35069758e-02   3.72422666e-02   2.51290885e-02]
 [  3.18143109e-02  -1.72128722e-01   3.77448484e-02   3.29325846e-01]
 [  2.83717364e-02   2.24361469e-02   4.43313653e-02   4.87807623e-02]
 [  2.88204594e-02  -1.73292518e-01   4.53069805e-02   3.55114465e-01]
 [  2.53546090e-02  -3.69028397e-01   5.24092698e-02   6.61732688e-01]
 [  1.79740411e-02  -1.74673342e-01   6.56439236e-02   3.86001667e-01]
 [  1.44805742e-02  -3.70662779e-01   7.33639569e-02   6.98638098e-01]
 [  7.06731863e-03  -1.76630577e-01   8.73367189e-02   4.29922354e-01]
 [  3.53470710e-03   1.71530460e-02   9.59351660e-02   1.65999173e-01]
 [  3.87776802e-03   2.10780153e-01   9.92551494e-02  -9.49442742e-02]
 [  8.09337108e-03   4.04349765e-01   9.73562640e-02  -3.54736517e-01]
 [  1.61803664e-02   2.07988075e-01

np.vstack(episode_memory[:,0]) :  (24, 4)
[[ -3.25128971e-02  -1.39840401e-02  -9.94862405e-04   2.19886799e-03]
 [ -3.27925779e-02   1.81152165e-01  -9.50885045e-04  -2.90797784e-01]
 [ -2.91695346e-02  -1.39562149e-02  -6.76684072e-03   1.58509311e-03]
 [ -2.94486589e-02   1.81262126e-01  -6.73513886e-03  -2.93225149e-01]
 [ -2.58234164e-02   3.76479453e-01  -1.25996419e-02  -5.88024614e-01]
 [ -1.82938273e-02   5.71775564e-01  -2.43601341e-02  -8.84649725e-01]
 [ -6.85831607e-03   3.76992692e-01  -4.20531286e-02  -5.99723276e-01]
 [  6.81537776e-04   1.82483537e-01  -5.40475942e-02  -3.20577467e-01]
 [  4.33120851e-03  -1.18287257e-02  -6.04591435e-02  -4.54169379e-02]
 [  4.09463399e-03  -2.06033958e-01  -6.13674822e-02   2.27594502e-01]
 [ -2.60451695e-05  -4.00227673e-01  -5.68155922e-02   5.00306032e-01]
 [ -8.03059862e-03  -5.94504586e-01  -4.68094716e-02   7.74556384e-01]
 [ -1.99206903e-02  -3.98771045e-01  -3.13183439e-02   4.67520712e-01]
 [ -2.78961112e-02  -5.93436850e-01

np.vstack(episode_memory[:,0]) :  (33, 4)
[[ 0.0332779   0.01319528  0.00829618 -0.00717683]
 [ 0.0335418  -0.18204466  0.00815264  0.28811205]
 [ 0.02990091  0.01296008  0.01391488 -0.0019885 ]
 [ 0.03016011  0.20787974  0.01387511 -0.29024884]
 [ 0.03431771  0.01256272  0.00807013  0.00677765]
 [ 0.03456896  0.20756801  0.00820569 -0.28334817]
 [ 0.03872032  0.40257196  0.00253872 -0.57343181]
 [ 0.04677176  0.20741451 -0.00892991 -0.27995019]
 [ 0.05092005  0.4026627  -0.01452892 -0.57543616]
 [ 0.0589733   0.59798528 -0.02603764 -0.87266046]
 [ 0.07093301  0.40322689 -0.04349085 -0.5882761 ]
 [ 0.07899755  0.2087401  -0.05525637 -0.30960404]
 [ 0.08317235  0.01444719 -0.06144845 -0.0348466 ]
 [ 0.08346129 -0.17974221 -0.06214538  0.23783326]
 [ 0.07986645 -0.37392382 -0.05738872  0.51028442]
 [ 0.07238797 -0.56819234 -0.04718303  0.78434446]
 [ 0.06102412 -0.76263528 -0.03149614  1.0618176 ]
 [ 0.04577142 -0.56711076 -0.01025979  0.75941794]
 [ 0.0344292  -0.37184895  0.00492857  0

np.vstack(episode_memory[:,0]) :  (104, 4)
[[ -1.76832088e-02   2.31775503e-02  -1.05721008e-02   4.81543086e-02]
 [ -1.72196578e-02  -1.71791222e-01  -9.60901460e-03   3.37482984e-01]
 [ -2.06554822e-02   2.34661409e-02  -2.85935491e-03   4.17854541e-02]
 [ -2.01861594e-02   2.18628978e-01  -2.02364583e-03  -2.51798253e-01]
 [ -1.58135798e-02   4.13779767e-01  -7.05961090e-03  -5.45118790e-01]
 [ -7.53798451e-03   2.18757722e-01  -1.79619867e-02  -2.54668521e-01]
 [ -3.16283006e-03   2.38967800e-02  -2.30553571e-02   3.22952026e-02]
 [ -2.68489446e-03   2.19341638e-01  -2.24094531e-02  -2.67571846e-01]
 [  1.70193830e-03   4.14776121e-01  -2.77608900e-02  -5.67237793e-01]
 [  9.99746071e-03   2.20054361e-01  -3.91056458e-02  -2.83428349e-01]
 [  1.43985479e-02   2.55113640e-02  -4.47742128e-02  -3.33108330e-03]
 [  1.49087752e-02   2.21245892e-01  -4.48408345e-02  -3.09797849e-01]
 [  1.93336931e-02   2.67905550e-02  -5.10367915e-02  -3.15868259e-02]
 [  1.98695042e-02  -1.67563764e-0

np.vstack(episode_memory[:,0]) :  (62, 4)
[[  4.27323548e-02   2.34486613e-02   3.47702142e-02   1.11122099e-03]
 [  4.32013281e-02  -1.72154248e-01   3.47924386e-02   3.04558627e-01]
 [  3.97582431e-02   2.24550492e-02   4.08836111e-02   2.30482605e-02]
 [  4.02073441e-02  -1.73228639e-01   4.13445763e-02   3.28344951e-01]
 [  3.67427713e-02  -3.68914048e-01   4.79114754e-02   6.33774068e-01]
 [  2.93644904e-02  -1.74492029e-01   6.05869567e-02   3.56556284e-01]
 [  2.58746498e-02   1.97185516e-02   6.77180824e-02   8.35765975e-02]
 [  2.62690208e-02   2.13807728e-01   6.93896144e-02  -1.86995988e-01]
 [  3.05451754e-02   1.77651423e-02   6.56496946e-02   1.26744408e-01]
 [  3.09004782e-02   2.11888182e-01   6.81845828e-02  -1.44525742e-01]
 [  3.51382419e-02   4.05970762e-01   6.52940679e-02  -4.14942398e-01]
 [  4.32576571e-02   2.09987068e-01   5.69952200e-02  -1.02410471e-01]
 [  4.74573984e-02   1.40965846e-02   5.49470105e-02   2.07695651e-01]
 [  4.77393301e-02   2.08391532e-01

KeyboardInterrupt: 

In [4]:
env = gym.make('CartPole-v0')

In [8]:
env.reset()

array([ 0.02181441,  0.03331934, -0.00995298,  0.04239949])

In [9]:
action = np.array([1,0])

In [10]:
action

array([1, 0])

In [17]:
env.step(0)

(array([ 0.02741247, -0.41747462, -0.0054751 ,  0.5219073 ]), 1.0, False, {})