In [1]:
import os
import random
import sys
from collections import deque

import gym
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

In [2]:
class A2C(tf.keras.Model):
    """Actor-critic network whose two heads share one hidden layer.

    Args:
        state_size: observation length. No layer is built from it here
            (Keras infers input shapes on the first call); it is accepted
            so callers can pass it.
        action_size: number of discrete actions, i.e. the width of the
            softmax policy head.

    Calling the model returns ``(policy, value)``: the softmax output of the
    actor head and the single linear output of the critic head.
    """

    def __init__(self, state_size, action_size):
        super().__init__()

        def _tiny_uniform():
            # A fresh initializer instance per head; kernels drawn from
            # [-1e-3, 1e-3] keep both heads' initial outputs small.
            return tf.keras.initializers.RandomUniform(minval=-1e-3, maxval=1e-3)

        self.common = tf.keras.layers.Dense(128, activation='tanh')
        self.actor = tf.keras.layers.Dense(
            action_size,
            activation='softmax',
            kernel_initializer=_tiny_uniform(),
        )
        self.critic = tf.keras.layers.Dense(1, kernel_initializer=_tiny_uniform())

    def call(self, x):
        features = self.common(x)
        return self.actor(features), self.critic(features)

In [3]:
class A2CAgent:
    """One-step (TD(0)) advantage actor-critic agent for discrete actions.

    Owns an ``A2C`` network and an Adam optimizer; ``train_model`` performs a
    single gradient step on one transition.

    Args:
        state_size: length of the observation vector.
        action_size: number of discrete actions.
        discount_factor: return discount gamma (default 0.99).
        learning_rate: Adam step size (default 0.001).
        actor_loss_weight: weight of the policy loss relative to the critic
            loss in the combined objective (default 0.2).
    """

    def __init__(self, state_size, action_size, discount_factor=0.99,
                 learning_rate=0.001, actor_loss_weight=0.2):
        self.state_size = state_size
        self.action_size = action_size

        # Hyper params for learning
        self.discount_factor = discount_factor
        self.learning_rate = learning_rate
        self.actor_loss_weight = actor_loss_weight

        self.model = A2C(self.state_size, self.action_size)
        # `lr` is a deprecated alias that newer Keras versions reject;
        # `learning_rate` is accepted by every TF2 release.
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)

    def get_action(self, state):
        """Sample an action index from the current policy.

        Args:
            state: a batch holding one observation (the training script passes
                shape ``[1, state_size]``); only row 0 of the policy is used.

        Returns:
            int action index in ``range(action_size)``.
        """
        policy, _ = self.model(state)
        # Renormalise in float64: np.random.choice raises if the probability
        # vector's sum drifts from 1 beyond a small tolerance.
        probs = np.asarray(policy[0], dtype=np.float64)
        probs = probs / probs.sum()
        return int(np.random.choice(self.action_size, 1, p=probs)[0])

    def train_model(self, state, action, reward, next_state, done):
        """Apply one actor-critic gradient step for a single transition.

        The critic regresses V(state) toward the TD target
        ``reward + gamma * V(next_state)`` (no bootstrap when ``done``); the
        actor's negative log-probability of ``action`` is weighted by the
        advantage ``target - V(state)``.

        Args:
            state: single-observation batch for the current state.
            action: integer index of the action taken.
            reward: scalar reward for the transition.
            next_state: single-observation batch for the resulting state.
            done: truthy when the episode ended at ``next_state``.

        Returns:
            The combined loss as a numpy array.
        """
        # The bootstrap value only enters the losses through stop_gradient,
        # so it does not need to be recorded on the tape.
        _, next_value = self.model(next_state)
        not_done = 0.0 if done else 1.0
        target = reward + not_done * self.discount_factor * next_value[0]

        with tf.GradientTape() as tape:
            policy, value = self.model(state)

            # For policy network: -log pi(a|s) weighted by a fixed advantage.
            one_hot_action = tf.one_hot([action], self.action_size)
            action_prob = tf.reduce_sum(one_hot_action * policy, axis=1)
            cross_entropy = - tf.math.log(action_prob + 1e-5)  # 1e-5 guards log(0)
            advantage = tf.stop_gradient(target - value[0])
            actor_loss = tf.reduce_mean(cross_entropy * advantage)

            # For value network
            critic_loss = 0.5 * tf.square(tf.stop_gradient(target) - value[0])
            critic_loss = tf.reduce_mean(critic_loss)

            # integrate losses
            loss = self.actor_loss_weight * actor_loss + critic_loss

        # Read the variables after the forward pass so they exist even if the
        # model had not been built before this call.
        model_params = self.model.trainable_variables
        grads = tape.gradient(loss, model_params)
        self.optimizer.apply_gradients(zip(grads, model_params))
        return np.array(loss)

In [4]:
%matplotlib tk
if __name__ == "__main__":
    env = gym.make('CartPole-v1')
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n

    agent = A2CAgent(state_size, action_size)

    # Run configuration
    num_episode = 2000
    save_dir = './save_model'
    target_score_avg = 400   # stop once the moving-average score exceeds this
    max_score = 500          # CartPole-v1 step limit; reaching it is not penalised
    # savefig / save_weights / np.save all fail if the folder does not exist.
    os.makedirs(save_dir, exist_ok=True)

    scores, episodes, losses = [], [], []
    score_avg = 0

    end = False

    # Build the figure once and update line data each episode. Calling
    # plt.plot every episode would stack a new full-history line on the axes
    # each time, so memory and redraw cost would grow quadratically.
    fig = plt.figure(1)
    fig.clf()
    ax_score = fig.add_subplot(211)
    score_line, = ax_score.plot([], [], 'b')
    ax_score.set_xlabel('episode')
    ax_score.set_ylabel('average score')
    ax_score.set_title('cartpole A2C')
    ax_loss = fig.add_subplot(212)
    loss_line, = ax_loss.plot([], [], 'b')
    ax_loss.set_xlabel('episode')
    ax_loss.set_ylabel('loss')

    for e in range(num_episode):
        done = False
        score = 0
        loss_list = []

        state = env.reset()
        if isinstance(state, tuple):  # gym>=0.26 returns (obs, info)
            state = state[0]
        state = np.reshape(state, [1, state_size])

        while not done:
            env.render()

            action = agent.get_action(state)

            step_result = env.step(action)
            if len(step_result) == 5:
                # gym>=0.26: (obs, reward, terminated, truncated, info)
                next_state, reward, terminated, truncated, _ = step_result
                done = terminated or truncated
            else:
                next_state, reward, done, _ = step_result
            next_state = np.reshape(next_state, [1, state_size])

            score += reward
            # Shaped reward: 0.1 per surviving step, -1 for an early failure.
            reward = 0.1 if not done or score == max_score else -1

            loss = agent.train_model(state, action, reward, next_state, done)
            loss_list.append(loss)

            state = next_state

        # Episode finished: exponential moving average of the episode score.
        score_avg = 0.9 * score_avg + 0.1 * score if score_avg != 0 else score
        mean_loss = np.mean(loss_list)
        print('epi: {:3d} | score avg {:3.2f} | loss: {:.4f}'.format(e, score_avg, mean_loss))

        scores.append(score_avg)
        episodes.append(e)
        losses.append(mean_loss)

        score_line.set_data(episodes, scores)
        loss_line.set_data(episodes, losses)
        for ax in (ax_score, ax_loss):
            ax.relim()
            ax.autoscale_view()
        fig.canvas.draw_idle()
        fig.canvas.flush_events()
        fig.savefig(os.path.join(save_dir, 'cartpole_a2c.png'))

        if score_avg > target_score_avg:
            agent.model.save_weights(os.path.join(save_dir, 'cartpole_a2c'), save_format='tf')
            np.save(os.path.join(save_dir, 'cartpole_a2c_epi'), episodes)
            np.save(os.path.join(save_dir, 'cartpole_a2c_score'), scores)
            np.save(os.path.join(save_dir, 'cartpole_a2c_loss'), losses)
            env.close()
            print("End")
            end = True
            break

epi:   0 | score avg 47.00 | loss: 0.0317
epi:   1 | score avg 43.40 | loss: 0.0757
epi:   2 | score avg 40.26 | loss: 0.0472
epi:   3 | score avg 37.73 | loss: 0.0576
epi:   4 | score avg 36.06 | loss: 0.0371
epi:   5 | score avg 36.85 | loss: 0.0327
epi:   6 | score avg 35.37 | loss: 0.0532
epi:   7 | score avg 34.43 | loss: 0.0380
epi:   8 | score avg 32.59 | loss: 0.0387
epi:   9 | score avg 31.63 | loss: 0.0322
epi:  10 | score avg 30.47 | loss: 0.0383
epi:  11 | score avg 30.32 | loss: 0.0792
epi:  12 | score avg 28.69 | loss: 0.0428
epi:  13 | score avg 27.22 | loss: 0.1112
epi:  14 | score avg 26.10 | loss: 0.0757
epi:  15 | score avg 25.79 | loss: 0.0578
epi:  16 | score avg 26.51 | loss: 0.0506
epi:  17 | score avg 25.46 | loss: 0.0793
epi:  18 | score avg 25.11 | loss: 0.0475
epi:  19 | score avg 25.10 | loss: 0.0410
epi:  20 | score avg 23.89 | loss: 0.0672
epi:  21 | score avg 23.00 | loss: 0.0565
epi:  22 | score avg 22.00 | loss: 0.0327
epi:  23 | score avg 22.00 | loss:

epi: 195 | score avg 23.94 | loss: 0.0430
epi: 196 | score avg 23.25 | loss: 0.0780
epi: 197 | score avg 22.52 | loss: 0.0936
epi: 198 | score avg 21.87 | loss: 0.0705
epi: 199 | score avg 21.18 | loss: 0.0546
epi: 200 | score avg 21.16 | loss: 0.0467
epi: 201 | score avg 20.25 | loss: 0.0574
epi: 202 | score avg 19.82 | loss: 0.2048
epi: 203 | score avg 21.04 | loss: 0.1030
epi: 204 | score avg 21.34 | loss: 0.1039
epi: 205 | score avg 20.80 | loss: 0.0789
epi: 206 | score avg 20.12 | loss: 0.1568
epi: 207 | score avg 19.41 | loss: 0.0789
epi: 208 | score avg 18.57 | loss: 0.1500
epi: 209 | score avg 18.71 | loss: 0.0984
epi: 210 | score avg 18.84 | loss: 0.0737
epi: 211 | score avg 18.66 | loss: 0.0984
epi: 212 | score avg 18.89 | loss: 0.0747
epi: 213 | score avg 18.40 | loss: 0.0872
epi: 214 | score avg 17.86 | loss: 0.0777
epi: 215 | score avg 17.38 | loss: 0.0911
epi: 216 | score avg 16.74 | loss: 0.0888
epi: 217 | score avg 15.96 | loss: 0.1053
epi: 218 | score avg 16.07 | loss:

epi: 391 | score avg 40.12 | loss: 0.1736
epi: 392 | score avg 40.31 | loss: 0.0855
epi: 393 | score avg 38.38 | loss: 0.1563
epi: 394 | score avg 38.44 | loss: 0.0797
epi: 395 | score avg 37.20 | loss: 0.1153
epi: 396 | score avg 43.78 | loss: 0.0834
epi: 397 | score avg 44.50 | loss: 0.1420
epi: 398 | score avg 44.95 | loss: 0.1433
epi: 399 | score avg 42.55 | loss: 0.2739
epi: 400 | score avg 40.70 | loss: 0.2107
epi: 401 | score avg 39.03 | loss: 0.1932
epi: 402 | score avg 39.13 | loss: 0.1134
epi: 403 | score avg 39.41 | loss: 0.1275
epi: 404 | score avg 39.47 | loss: 0.1129
epi: 405 | score avg 39.12 | loss: 0.1292
epi: 406 | score avg 38.11 | loss: 0.1592
epi: 407 | score avg 36.60 | loss: 0.1607
epi: 408 | score avg 36.24 | loss: 0.1031
epi: 409 | score avg 35.72 | loss: 0.1172
epi: 410 | score avg 37.05 | loss: 0.0727
epi: 411 | score avg 36.34 | loss: 0.0866
epi: 412 | score avg 36.91 | loss: 0.0713
epi: 413 | score avg 40.22 | loss: 0.1060
epi: 414 | score avg 44.09 | loss:

epi: 586 | score avg 54.45 | loss: 0.0088
epi: 587 | score avg 54.01 | loss: 0.0147
epi: 588 | score avg 53.01 | loss: 0.0016
epi: 589 | score avg 51.21 | loss: -0.0011
epi: 590 | score avg 49.99 | loss: -0.0002
epi: 591 | score avg 49.59 | loss: 0.0043
epi: 592 | score avg 49.33 | loss: 0.0014
epi: 593 | score avg 49.10 | loss: 0.0079
epi: 594 | score avg 48.39 | loss: 0.0413
epi: 595 | score avg 47.15 | loss: -0.0009
epi: 596 | score avg 45.53 | loss: 0.0036
epi: 597 | score avg 44.58 | loss: 0.0016
epi: 598 | score avg 43.22 | loss: 0.0033
epi: 599 | score avg 42.40 | loss: 0.0111
epi: 600 | score avg 41.26 | loss: -0.0017
epi: 601 | score avg 41.03 | loss: 0.0214
epi: 602 | score avg 40.73 | loss: 0.0290
epi: 603 | score avg 39.46 | loss: 0.0113
epi: 604 | score avg 38.51 | loss: 0.0048
epi: 605 | score avg 37.36 | loss: 0.0143
epi: 606 | score avg 37.22 | loss: 0.0192
epi: 607 | score avg 36.40 | loss: 0.0003
epi: 608 | score avg 36.36 | loss: 0.0051
epi: 609 | score avg 36.23 | l

epi: 778 | score avg 158.49 | loss: 0.0048
epi: 779 | score avg 159.74 | loss: 0.0059
epi: 780 | score avg 154.27 | loss: 0.0002
epi: 781 | score avg 151.44 | loss: 0.0033
epi: 782 | score avg 144.40 | loss: 0.0004
epi: 783 | score avg 139.26 | loss: -0.0011
epi: 784 | score avg 133.63 | loss: 0.0316
epi: 785 | score avg 129.07 | loss: 0.0199
epi: 786 | score avg 121.36 | loss: 0.0451
epi: 787 | score avg 117.93 | loss: 0.0030
epi: 788 | score avg 112.83 | loss: -0.0000
epi: 789 | score avg 108.25 | loss: 0.0475
epi: 790 | score avg 103.83 | loss: 0.0044
epi: 791 | score avg 100.04 | loss: 0.0015
epi: 792 | score avg 94.14 | loss: -0.0123
epi: 793 | score avg 94.52 | loss: 0.0555
epi: 794 | score avg 94.87 | loss: 0.0326
epi: 795 | score avg 91.79 | loss: 0.0259
epi: 796 | score avg 87.91 | loss: 0.0135
epi: 797 | score avg 85.42 | loss: 0.0069
epi: 798 | score avg 81.17 | loss: 0.0015
epi: 799 | score avg 80.26 | loss: 0.0168
epi: 800 | score avg 79.93 | loss: 0.0021
epi: 801 | score 

epi: 971 | score avg 140.12 | loss: 0.0027
epi: 972 | score avg 137.21 | loss: 0.0015
epi: 973 | score avg 134.09 | loss: 0.0084
epi: 974 | score avg 131.08 | loss: 0.0238
epi: 975 | score avg 131.87 | loss: 0.0334
epi: 976 | score avg 130.48 | loss: 0.0147
epi: 977 | score avg 128.84 | loss: 0.0216
epi: 978 | score avg 128.85 | loss: 0.0474
epi: 979 | score avg 130.27 | loss: 0.0422
epi: 980 | score avg 132.14 | loss: 0.0108
epi: 981 | score avg 133.13 | loss: 0.0170
epi: 982 | score avg 131.91 | loss: 0.0148
epi: 983 | score avg 130.72 | loss: 0.0221
epi: 984 | score avg 133.55 | loss: 0.0771
epi: 985 | score avg 139.60 | loss: 0.0719
epi: 986 | score avg 144.44 | loss: 0.0553
epi: 987 | score avg 148.69 | loss: 0.0497
epi: 988 | score avg 148.32 | loss: 0.0201
epi: 989 | score avg 148.09 | loss: 0.0158
epi: 990 | score avg 152.58 | loss: 0.0301
epi: 991 | score avg 161.42 | loss: 0.0317
epi: 992 | score avg 183.48 | loss: 0.0689
epi: 993 | score avg 215.13 | loss: 0.0838
epi: 994 | 

In [None]:
    # Continuation cell: keep training the agent/env created by the previous
    # cell. Episode numbers resume right after the last recorded episode
    # instead of a hard-coded 1000, so the plotted x-axis has no gaps or
    # duplicated episode indices.
    num_episode = 1000
    start_episode = episodes[-1] + 1 if episodes else 0
    os.makedirs('./save_model', exist_ok=True)

    # Rebuild the figure with one persistent line per panel (full history);
    # re-calling plt.plot each episode would stack another full-history line
    # on top of every earlier one.
    fig = plt.figure(1)
    fig.clf()
    ax_score = fig.add_subplot(211)
    score_line, = ax_score.plot(episodes, scores, 'b')
    ax_score.set_xlabel('episode')
    ax_score.set_ylabel('average score')
    ax_score.set_title('cartpole A2C')
    ax_loss = fig.add_subplot(212)
    loss_line, = ax_loss.plot(episodes, losses, 'b')
    ax_loss.set_xlabel('episode')
    ax_loss.set_ylabel('loss')

    for e in range(start_episode, start_episode + num_episode):
        done = False
        score = 0
        loss_list = []

        state = env.reset()
        if isinstance(state, tuple):  # gym>=0.26 returns (obs, info)
            state = state[0]
        state = np.reshape(state, [1, state_size])

        while not done:
            env.render()

            action = agent.get_action(state)

            step_result = env.step(action)
            if len(step_result) == 5:
                # gym>=0.26: (obs, reward, terminated, truncated, info)
                next_state, reward, terminated, truncated, _ = step_result
                done = terminated or truncated
            else:
                next_state, reward, done, _ = step_result
            next_state = np.reshape(next_state, [1, state_size])

            score += reward
            # Shaped reward: 0.1 per surviving step, -1 for an early failure
            # (500 is the CartPole-v1 step limit).
            reward = 0.1 if not done or score == 500 else -1

            loss = agent.train_model(state, action, reward, next_state, done)
            loss_list.append(loss)

            state = next_state

        # Episode finished: exponential moving average of the episode score.
        score_avg = 0.9 * score_avg + 0.1 * score if score_avg != 0 else score
        mean_loss = np.mean(loss_list)
        print('epi: {:3d} | score avg {:3.2f} | loss: {:.4f}'.format(e, score_avg, mean_loss))

        scores.append(score_avg)
        episodes.append(e)
        losses.append(mean_loss)

        score_line.set_data(episodes, scores)
        loss_line.set_data(episodes, losses)
        for ax in (ax_score, ax_loss):
            ax.relim()
            ax.autoscale_view()
        fig.canvas.draw_idle()
        fig.canvas.flush_events()
        fig.savefig('./save_model/cartpole_a2c.png')

        if score_avg > 400:
            agent.model.save_weights('./save_model/cartpole_a2c', save_format='tf')
            np.save('./save_model/cartpole_a2c_epi', episodes)
            np.save('./save_model/cartpole_a2c_score', scores)
            np.save('./save_model/cartpole_a2c_loss', losses)
            env.close()
            print("End")
            end = True
            break

epi: 1001 | score avg 308.70 | loss: 0.0611
epi: 1002 | score avg 299.43 | loss: 0.2241
epi: 1003 | score avg 280.78 | loss: 0.2205
epi: 1004 | score avg 302.71 | loss: 0.0805
epi: 1005 | score avg 322.44 | loss: 0.0863
epi: 1006 | score avg 340.19 | loss: 0.0855
