In [1]:
import gym
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense

In [2]:
class DQN(tf.keras.Model):
    """Fully connected Q-network: two 24-unit ReLU layers followed by a linear head.

    The head produces one output per action (the estimated Q-value). Its
    kernel is drawn from U(-1e-3, 1e-3), so the untrained network's outputs
    start small.

    NOTE: the attribute names ``fc1``, ``fc2`` and ``out`` form the keys of the
    object-based checkpoint restored by ``load_weights``. Renaming them breaks
    loading of previously saved weights.
    """

    def __init__(self, action_size):
        super().__init__()
        self.fc1 = Dense(24, activation='relu')
        self.fc2 = Dense(24, activation='relu')
        head_initializer = tf.keras.initializers.RandomUniform(-1e-3, 1e-3)
        self.out = Dense(action_size, kernel_initializer=head_initializer)

    def call(self, x):
        """Map a batch of observations ``x`` to per-action Q-values."""
        hidden = self.fc2(self.fc1(x))
        return self.out(hidden)

In [3]:
class DQNAgent:
    """Greedy (inference-only) agent backed by a pre-trained ``DQN``.

    Args:
        state_size: length of the observation vector (stored for reference).
        action_size: number of discrete actions; sets the width of the Q head.
        weights_path: checkpoint prefix passed to ``load_weights``. Defaults to
            the path the notebook has always used, so existing callers are
            unaffected.
    """

    def __init__(self, state_size, action_size,
                 weights_path="./save_model/cartpole_dqn_TF"):
        self.state_size = state_size
        self.action_size = action_size

        self.model = DQN(action_size)
        self.model.load_weights(weights_path)

    def get_action(self, state):
        """Return the index of the action with the highest predicted Q-value.

        ``state`` is a single observation. It is wrapped in a list to form a
        batch of one before being fed to the model.
        """
        state = tf.convert_to_tensor([state], dtype=tf.float32)
        q_value = self.model(state)
        # Plain int rather than a NumPy scalar, so any env action check accepts it.
        return int(np.argmax(q_value[0]))

In [4]:
%matplotlib tk

ENV_NAME = 'CartPole-v1'
EPISODES = 10
# END_SCORE = 400

if __name__ == "__main__":
    env = gym.make(ENV_NAME)
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n

    agent = DQNAgent(state_size, action_size)
    print('Env Name : ',ENV_NAME)
    print('States {}, Actions {}'
            .format(state_size, action_size))

    scores, episodes = [], []
    score_avg = 0

    for e in range(EPISODES):
        # Episode initialization
        done = False
        score = 0

        state = env.reset()
        
        while not done:
            env.render()

            # Interact with env.
            action = agent.get_action(state)
            next_state, reward, done, info = env.step(action)
            state = next_state

            # 
            score += reward
            if done:
                print('epi: {:3d} | score {:3.2f}'.format(e+1, score))
                scores = np.append(scores,score)
    print('Avg. score {:4.2f}'.format(tf.reduce_mean(scores)))
    env.close()

Env Name :  CartPole-v1
States 4, Actions 2
epi:   1 | score 500.00
epi:   2 | score 500.00
epi:   3 | score 500.00
epi:   4 | score 500.00
epi:   5 | score 500.00
epi:   6 | score 500.00
epi:   7 | score 500.00
epi:   8 | score 500.00
epi:   9 | score 500.00
epi:  10 | score 500.00
Avg. score 500.00
