In [1]:
import os
import random
import sys
from collections import deque

import gym
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

In [2]:
class A2C(tf.keras.Model):
    """Actor-critic network: one shared hidden layer feeding a policy head and a value head.

    The policy head outputs a softmax distribution over ``action_size`` actions; the value
    head outputs a single unbounded scalar. ``state_size`` is accepted for interface
    symmetry but not used -- Keras Dense layers infer their input width on first call.
    """

    def __init__(self, state_size, action_size):
        super().__init__()

        def small_uniform():
            # Build a fresh initializer for each head rather than sharing one instance.
            return tf.keras.initializers.RandomUniform(-1e-3, 1e-3)

        # Attribute names stay fixed: TF-format checkpoints key variables by them.
        self.common = tf.keras.layers.Dense(128, activation='tanh')
        self.actor = tf.keras.layers.Dense(
            action_size, activation='softmax', kernel_initializer=small_uniform())
        self.critic = tf.keras.layers.Dense(1, kernel_initializer=small_uniform())

    def call(self, x):
        """Return the ``(policy, value)`` pair computed from the shared features of ``x``."""
        shared = self.common(x)
        return self.actor(shared), self.critic(shared)

In [3]:
class A2CAgent:
    """Online advantage actor-critic (A2C) agent for discrete action spaces.

    Owns an ``A2C`` network (policy + value heads) and an Adam optimizer, and performs
    one gradient step per environment transition via ``train_model``.
    """

    def __init__(self, state_size, action_size, discount_factor=0.99,
                 learning_rate=0.001, actor_loss_weight=0.2):
        """Build the network and optimizer.

        Args:
            state_size: observation length (forwarded to ``A2C``).
            action_size: number of discrete actions.
            discount_factor: gamma used in the one-step TD target.
            learning_rate: Adam step size.
            actor_loss_weight: weight of the policy loss relative to the critic loss
                in the combined objective.
        """
        self.state_size = state_size
        self.action_size = action_size

        # Hyper params for learning
        self.discount_factor = discount_factor
        self.learning_rate = learning_rate
        self.actor_loss_weight = actor_loss_weight

        self.model = A2C(self.state_size, self.action_size)
        # `learning_rate=` is the supported keyword; the legacy `lr=` alias is
        # deprecated and rejected by newer Keras releases.
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)

    def get_action(self, state):
        """Sample one action index from the current policy.

        Args:
            state: observation batch; only the first row's policy is used (the training
                loop passes a single state reshaped to ``[1, state_size]``).

        Returns:
            The sampled action index, a NumPy integer in ``[0, action_size)``.
        """
        policy, _ = self.model(state)
        probs = np.array(policy[0])
        return np.random.choice(self.action_size, 1, p=probs)[0]

    def train_model(self, state, action, reward, next_state, done):
        """Run one A2C gradient step on a single transition.

        Args:
            state: observation the action was taken in, shaped ``[1, state_size]``.
            action: index of the action taken.
            reward: scalar reward for the transition.
            next_state: resulting observation, shaped ``[1, state_size]``.
            done: truthy only when ``next_state`` is a genuine terminal state; the TD
                target then drops the bootstrapped ``next_state`` value.

        Returns:
            The combined scalar loss as a NumPy array.
        """
        with tf.GradientTape() as tape:
            policy, value = self.model(state)
            _, next_value = self.model(next_state)
            # One-step TD target, treated as a constant for both losses.
            not_done = 1.0 - float(done)
            target = tf.stop_gradient(
                reward + not_done * self.discount_factor * next_value[0])

            # For policy network: -log pi(a|s) weighted by the (constant) advantage.
            one_hot_action = tf.one_hot([action], self.action_size)
            action_prob = tf.reduce_sum(one_hot_action * policy, axis=1)
            cross_entropy = -tf.math.log(action_prob + 1e-5)  # 1e-5 guards against log(0)
            advantage = tf.stop_gradient(target - value[0])
            actor_loss = tf.reduce_mean(cross_entropy * advantage)

            # For value network: regress V(s) toward the fixed TD target.
            critic_loss = tf.reduce_mean(0.5 * tf.square(target - value[0]))

            # integrate losses
            loss = self.actor_loss_weight * actor_loss + critic_loss

        # Read the variables only after the forward pass so they exist even if the
        # model had not been built before this call.
        model_params = self.model.trainable_variables
        grads = tape.gradient(loss, model_params)
        self.optimizer.apply_gradients(zip(grads, model_params))
        return np.array(loss)

In [4]:
%matplotlib tk
if __name__ == "__main__":
    env = gym.make('CartPole-v1')
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n

    agent = A2CAgent(state_size, action_size)

    scores, episodes, losses = [], [], []
    score_avg = 0
    
    end = False
    
    fig = plt.figure(1)
    fig.clf()
    
    num_episode = 2000
    for e in range(num_episode):
        done = False
        score = 0
        loss_list = []
        
        state = env.reset()
        state = np.reshape(state, [1, state_size])
        
        while not done:
            env.render()

            action = agent.get_action(state)

            next_state, reward, done, info = env.step(action)
            next_state = np.reshape(next_state, [1, state_size])

            score += reward
            reward = 0.1 if not done or score == 500 else -1

            loss = agent.train_model(state, action, reward, next_state, done)
            loss_list.append(loss)

            state = next_state
            if done:
                
                score_avg = 0.9 * score_avg + 0.1 * score if score_avg != 0 else score
                print('epi: {:3d} | score avg {:3.2f} | loss: {:.4f}'.format(e, score_avg, np.mean(loss_list)))

                scores.append(score_avg)
                episodes.append(e)
                losses.append(np.mean(loss_list))
                plt.subplot(211)
                plt.plot(episodes, scores, 'b')
                plt.xlabel('episode')
                plt.ylabel('average score')
                plt.title('cartpole A2C')
                
                plt.subplot(212)
                plt.plot(episodes, losses, 'b')
                plt.xlabel('episode')
                plt.ylabel('loss')
                
                plt.savefig('./save_model/cartpole_a2c.png')

                if score_avg > 400:
                    agent.model.save_weights('./save_model/cartpole_a2c', save_format='tf')
                    end = True
                    break;
        if end == True:
            np.save('./save_model/cartpole_a2c_epi',episodes)
            np.save('./save_model/cartpole_a2c_score',scores)
            np.save('./save_model/cartpole_a2c_loss',losses)
            env.close()
            print("End")
            break;

epi:   0 | score avg 24.00 | loss: 0.0500
epi:   1 | score avg 23.70 | loss: 0.0628
epi:   2 | score avg 22.53 | loss: 0.0830
epi:   3 | score avg 21.78 | loss: 0.0350
epi:   4 | score avg 21.10 | loss: 0.0641
epi:   5 | score avg 20.39 | loss: 0.0617
epi:   6 | score avg 20.55 | loss: 0.0367
epi:   7 | score avg 19.90 | loss: 0.0446
epi:   8 | score avg 19.11 | loss: 0.0754
epi:   9 | score avg 18.90 | loss: 0.0479
epi:  10 | score avg 18.61 | loss: 0.0381
epi:  11 | score avg 19.85 | loss: 0.0404
epi:  12 | score avg 20.36 | loss: 0.0441
epi:  13 | score avg 20.22 | loss: 0.0390
epi:  14 | score avg 20.20 | loss: 0.0374
epi:  15 | score avg 19.18 | loss: 0.0421
epi:  16 | score avg 18.36 | loss: 0.0104
epi:  17 | score avg 18.23 | loss: 0.1236
epi:  18 | score avg 17.90 | loss: 0.0345
epi:  19 | score avg 17.61 | loss: 0.0368
epi:  20 | score avg 17.85 | loss: 0.0863
epi:  21 | score avg 17.97 | loss: 0.0256
epi:  22 | score avg 18.07 | loss: 0.0392
epi:  23 | score avg 17.76 | loss:

epi: 196 | score avg 10.56 | loss: 0.0272
epi: 197 | score avg 10.60 | loss: 0.0494
epi: 198 | score avg 10.34 | loss: 0.0279
epi: 199 | score avg 10.31 | loss: 0.1425
epi: 200 | score avg 10.18 | loss: 0.1409
epi: 201 | score avg 10.16 | loss: 0.1212
epi: 202 | score avg 10.04 | loss: 0.1160
epi: 203 | score avg 10.04 | loss: 0.0399
epi: 204 | score avg 10.23 | loss: 0.0369
epi: 205 | score avg 11.51 | loss: 0.0570
epi: 206 | score avg 11.26 | loss: 0.0448
epi: 207 | score avg 11.13 | loss: 0.1056
epi: 208 | score avg 11.32 | loss: 0.0782
epi: 209 | score avg 12.19 | loss: 0.0426
epi: 210 | score avg 13.57 | loss: 0.0517
epi: 211 | score avg 13.21 | loss: 0.0788
epi: 212 | score avg 12.79 | loss: 0.0761
epi: 213 | score avg 13.11 | loss: 0.0620
epi: 214 | score avg 12.80 | loss: 0.0644
epi: 215 | score avg 12.82 | loss: 0.0545
epi: 216 | score avg 12.54 | loss: 0.0907
epi: 217 | score avg 12.98 | loss: 0.0618
epi: 218 | score avg 13.19 | loss: 0.0498
epi: 219 | score avg 12.77 | loss:

epi: 392 | score avg 48.12 | loss: 0.0385
epi: 393 | score avg 46.01 | loss: 0.2237
epi: 394 | score avg 43.41 | loss: 0.2613
epi: 395 | score avg 40.37 | loss: 0.3338
epi: 396 | score avg 38.53 | loss: 0.1752
epi: 397 | score avg 38.58 | loss: 0.1173
epi: 398 | score avg 42.72 | loss: 0.0722
epi: 399 | score avg 56.85 | loss: 0.0487
epi: 400 | score avg 60.86 | loss: 0.0628
epi: 401 | score avg 67.78 | loss: 0.0810
epi: 402 | score avg 65.80 | loss: 0.2240
epi: 403 | score avg 63.32 | loss: 0.3025
epi: 404 | score avg 71.99 | loss: 0.0660
epi: 405 | score avg 71.29 | loss: 0.1341
epi: 406 | score avg 71.26 | loss: 0.1080
epi: 407 | score avg 67.23 | loss: 0.2564
epi: 408 | score avg 63.51 | loss: 0.2393
epi: 409 | score avg 59.46 | loss: 0.2915
epi: 410 | score avg 58.31 | loss: 0.1105
epi: 411 | score avg 63.58 | loss: 0.1377
epi: 412 | score avg 63.52 | loss: 0.0906
epi: 413 | score avg 71.87 | loss: 0.1658
epi: 414 | score avg 71.58 | loss: 0.2327
epi: 415 | score avg 67.63 | loss:

epi: 587 | score avg 63.42 | loss: 0.0244
epi: 588 | score avg 62.48 | loss: 0.0398
epi: 589 | score avg 61.73 | loss: 0.0288
epi: 590 | score avg 61.06 | loss: 0.3494
epi: 591 | score avg 60.85 | loss: 0.2705
epi: 592 | score avg 61.27 | loss: 0.0324
epi: 593 | score avg 60.24 | loss: 0.3080
epi: 594 | score avg 66.52 | loss: 0.1354
epi: 595 | score avg 67.86 | loss: 0.0552
epi: 596 | score avg 64.28 | loss: 0.5022
epi: 597 | score avg 61.55 | loss: 0.3736
epi: 598 | score avg 60.30 | loss: 0.3049
epi: 599 | score avg 57.67 | loss: 0.3604
epi: 600 | score avg 58.90 | loss: 0.0660
epi: 601 | score avg 56.91 | loss: 0.2898
epi: 602 | score avg 55.12 | loss: 0.3131
epi: 603 | score avg 58.51 | loss: 0.0600
epi: 604 | score avg 69.46 | loss: 0.0389
epi: 605 | score avg 70.21 | loss: 0.1618
epi: 606 | score avg 69.09 | loss: 0.2567
epi: 607 | score avg 71.48 | loss: 0.1303
epi: 608 | score avg 73.23 | loss: 0.1139
epi: 609 | score avg 70.31 | loss: 0.1927
epi: 610 | score avg 67.18 | loss:

epi: 781 | score avg 109.92 | loss: 0.0007
epi: 782 | score avg 112.13 | loss: 0.0066
epi: 783 | score avg 114.01 | loss: 0.0105
epi: 784 | score avg 114.01 | loss: 0.0101
epi: 785 | score avg 113.61 | loss: 0.0254
epi: 786 | score avg 115.05 | loss: 0.0220
epi: 787 | score avg 114.75 | loss: 0.0063
epi: 788 | score avg 115.37 | loss: 0.0157
epi: 789 | score avg 115.83 | loss: 0.0092
epi: 790 | score avg 116.35 | loss: 0.0449
epi: 791 | score avg 118.02 | loss: 0.0206
epi: 792 | score avg 117.51 | loss: 0.0022
epi: 793 | score avg 116.76 | loss: 0.0068
epi: 794 | score avg 116.59 | loss: 0.0062
epi: 795 | score avg 116.83 | loss: 0.0050
epi: 796 | score avg 117.24 | loss: 0.0063
epi: 797 | score avg 117.42 | loss: 0.0086
epi: 798 | score avg 119.28 | loss: 0.0166
epi: 799 | score avg 119.25 | loss: 0.0347
epi: 800 | score avg 121.13 | loss: 0.0152
epi: 801 | score avg 122.11 | loss: 0.0283
epi: 802 | score avg 133.40 | loss: 0.1398
epi: 803 | score avg 136.46 | loss: 0.1725
epi: 804 | 

epi: 972 | score avg 173.60 | loss: 0.0177
epi: 973 | score avg 174.44 | loss: 0.0415
epi: 974 | score avg 178.49 | loss: 0.0266
epi: 975 | score avg 176.25 | loss: 0.0024
epi: 976 | score avg 174.72 | loss: 0.0012
epi: 977 | score avg 169.65 | loss: 0.0015
epi: 978 | score avg 167.38 | loss: 0.0008
epi: 979 | score avg 163.45 | loss: 0.0028
epi: 980 | score avg 162.40 | loss: -0.0001
epi: 981 | score avg 161.96 | loss: 0.0068
epi: 982 | score avg 167.46 | loss: 0.0032
epi: 983 | score avg 171.22 | loss: 0.0089
epi: 984 | score avg 182.70 | loss: 0.0491
epi: 985 | score avg 204.33 | loss: 0.0478
epi: 986 | score avg 219.69 | loss: 0.1347
epi: 987 | score avg 238.92 | loss: 0.0849
epi: 988 | score avg 239.23 | loss: 0.0945
epi: 989 | score avg 246.41 | loss: 0.0681
epi: 990 | score avg 246.37 | loss: 0.0561
epi: 991 | score avg 236.13 | loss: 0.0434
epi: 992 | score avg 226.52 | loss: 0.0174
epi: 993 | score avg 216.77 | loss: 0.0065
epi: 994 | score avg 207.09 | loss: 0.0001
epi: 995 |