In [1]:
import numpy as np
import tensorflow as tf
import random
from collections import deque
import gym

# Single Gym environment shared by the training loop (main) and the greedy
# evaluation rollout (bot_play).
env = gym.make('CartPole-v1')

# Network dimensions are read from the environment: the first dimension of the
# observation shape is the input width, and the number of discrete actions is
# the output width.
input_size = env.observation_space.shape[0]
output_size = env.action_space.n

# Discount factor multiplied into the next-state value in simple_replay_train.
dis = 0.9
# Capacity (maxlen) of the replay deque created in main().
REPLAY_MEMORY = 50000

class DQN:
    """Small fully connected Q-network together with its optimizer.

    The network maps a batch of ``input_size``-wide rows to ``output_size``
    values per row through two hidden ReLU layers of 24 units each.
    """

    def __init__(self, input_size, output_size):
        """Store the layer widths, create the Adam optimizer and build the model.

        Args:
            input_size: width of one input row.
            output_size: number of outputs produced per input row.
        """
        self.input_size = input_size
        self.output_size = output_size
        # The optimizer must exist before build_model(), which compiles with it.
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
        self.model = self.build_model()

    def build_model(self):
        """Assemble and compile the Keras model.

        Returns:
            A compiled ``tf.keras.Sequential`` model: input layer, two
            24-unit ReLU layers, and a final layer with no activation.
        """
        network = tf.keras.Sequential()
        network.add(tf.keras.layers.InputLayer(input_shape=(self.input_size,)))
        network.add(tf.keras.layers.Dense(24, activation='relu'))
        network.add(tf.keras.layers.Dense(24, activation='relu'))
        network.add(tf.keras.layers.Dense(self.output_size))
        network.compile(optimizer=self.optimizer, loss='mse')
        return network

    def predict(self, state):
        """Run the model on one state or a batch of states.

        Args:
            state: array-like reshaped to ``(-1, input_size)`` before the
                forward pass, so a single flat state becomes a batch of one.

        Returns:
            The model output for the reshaped batch.
        """
        batch = np.reshape(state, (-1, self.input_size))
        return self.model(batch)

    @tf.function
    def update(self, x_stack, y_stack):
        """Apply one gradient step on the mean squared error.

        Args:
            x_stack: batch of inputs fed to the model.
            y_stack: target values compared element-wise with the model output.

        Returns:
            The mean squared error computed before the weights were updated.
        """
        weights = self.model.trainable_variables
        with tf.GradientTape() as tape:
            predicted = self.model(x_stack)
            error = y_stack - predicted
            mse = tf.reduce_mean(tf.square(error))
        gradients = tape.gradient(mse, weights)
        self.optimizer.apply_gradients(zip(gradients, weights))
        return mse

def simple_replay_train(DQN, train_batch, discount=None):
    """Build Q-learning targets for a batch of transitions and train on them.

    For every transition the network's current prediction for ``state`` is
    used as the target, except for the entry of the action that was taken:
    that entry becomes ``reward`` when the transition is terminal and
    ``reward + discount * max(Q(next_state))`` otherwise.

    All states (and all next states) are evaluated in a single batched
    ``predict`` call each, instead of one call per transition, and the arrays
    are built once rather than grown with repeated ``np.vstack``.

    Args:
        DQN: agent exposing ``input_size``, ``predict(states)`` (whose result
            supports ``.numpy()``) and ``update(x_stack, y_stack)``. The
            parameter name shadows the ``DQN`` class inside this function and
            is kept for caller compatibility.
        train_batch: iterable of ``(state, action, reward, next_state, done)``
            tuples.
        discount: discount factor for the bootstrapped term. When ``None``
            the module-level ``dis`` is used.

    Returns:
        Whatever ``DQN.update`` returns (the training loss).

    Raises:
        ValueError: if ``train_batch`` contains no transitions.
    """
    gamma = dis if discount is None else discount

    transitions = list(train_batch)
    if not transitions:
        raise ValueError("train_batch must contain at least one transition")

    states, actions, rewards, next_states, dones = zip(*transitions)

    x_stack = np.asarray(states, dtype=np.float32).reshape(-1, DQN.input_size)
    next_stack = np.asarray(next_states, dtype=np.float32).reshape(-1, DQN.input_size)
    actions = np.asarray(actions, dtype=np.int64)
    rewards = np.asarray(rewards, dtype=np.float32)
    dones = np.asarray(dones, dtype=bool)

    # np.array(...) forces a private, writable copy of the predictions so the
    # in-place target assignment below cannot touch the network's own output.
    y_stack = np.array(DQN.predict(x_stack).numpy(), dtype=np.float32)
    best_next = np.max(DQN.predict(next_stack).numpy(), axis=1)

    # np.where keeps terminal targets equal to the raw reward even if the
    # next-state estimate is non-finite.
    targets = np.where(dones, rewards, rewards + gamma * best_next)
    y_stack[np.arange(len(transitions)), actions] = targets

    return DQN.update(x_stack, y_stack.astype(np.float32))

def bot_play(mainDQN, play_env=None):
    """Play one greedy episode and print the total reward.

    The episode ends when the environment reports either ``terminated`` or
    ``truncated``. Previously only the third value returned by ``step()`` was
    checked, so an episode cut off by a time limit kept being stepped.

    Args:
        mainDQN: agent exposing ``predict(state)``; at every step the action
            with the largest predicted value is taken.
        play_env: optional environment to play in. Defaults to the
            module-level ``env``. It must provide ``reset()`` returning
            ``(observation, info)`` and ``step()`` returning
            ``(observation, reward, terminated, truncated, info)``.

    Returns:
        The sum of rewards collected during the episode.
    """
    active_env = env if play_env is None else play_env

    s = active_env.reset()[0]
    reward_sum = 0

    while True:
        # NOTE(review): the module-level env is created without a render_mode,
        # so this call may not draw anything - confirm against the installed
        # gym version.
        active_env.render()
        a = int(np.argmax(mainDQN.predict(s)))
        s, reward, terminated, truncated, _ = active_env.step(a)
        reward_sum += reward

        if terminated or truncated:
            print("Total score: {}".format(reward_sum))
            return reward_sum

In [2]:
def main(max_episodes=5000):
    """Train a DQN agent on the module-level ``env`` using experience replay.

    Each episode is played with an epsilon-greedy policy whose epsilon decays
    as ``1 / (episode / 10 + 1)``. Transitions are stored in a bounded replay
    buffer. On every episode whose index is 1 modulo 10, fifty minibatches of
    ten transitions are sampled and trained on. Every 50th episode a greedy
    rollout is played with ``bot_play``.

    Changes compared with the earlier version:
      * An episode now also ends when the environment reports ``truncated``;
        that flag used to be discarded, so the loop kept stepping after a
        time-limit cut-off.
      * ``env.close()`` runs in a ``finally`` block, so the environment is
        released when training is interrupted.

    Args:
        max_episodes: number of training episodes to run.
    """
    max_steps_per_episode = 10000  # safety cap on the length of one episode
    train_interval = 10            # train on episodes where episode % 10 == 1
    train_iterations = 50          # minibatches per training round
    batch_size = 10                # transitions per minibatch
    eval_interval = 50             # greedy rollout every 50th episode
    failure_reward = -100          # replaces the reward of a terminating step

    replay_buffer = deque(maxlen=REPLAY_MEMORY)
    mainDQN = DQN(input_size, output_size)

    try:
        for episode in range(max_episodes):
            e = 1. / ((episode / 10) + 1)
            terminated = False
            truncated = False
            step_count = 0

            state = env.reset()[0]

            while not (terminated or truncated):
                if np.random.rand() < e:
                    action = env.action_space.sample()
                else:
                    action = int(np.argmax(mainDQN.predict(state)))

                next_state, reward, terminated, truncated, _ = env.step(action)

                # Only a true termination is penalised and stored as a terminal
                # transition; a truncated episode still bootstraps from
                # next_state during replay.
                if terminated:
                    reward = failure_reward

                replay_buffer.append((state, action, reward, next_state, terminated))

                state = next_state
                step_count += 1
                if step_count > max_steps_per_episode:
                    break

            print(f"Episode: {episode} steps: {step_count}")

            if episode % train_interval == 1 and len(replay_buffer) >= batch_size:
                for _ in range(train_iterations):
                    minibatch = random.sample(replay_buffer, batch_size)
                    loss = simple_replay_train(mainDQN, minibatch)
                print("Loss: ", loss.numpy())

            if episode % eval_interval == 0:
                bot_play(mainDQN)
    finally:
        # Release the environment even if training is interrupted.
        env.close()

In [3]:
# Run the training loop; one progress line is printed per episode.
main()

  if not isinstance(terminated, (bool, np.bool8)):
  gym.logger.warn(


Episode: 0 steps: 30
Total score: 9.0
Episode: 1 steps: 10
Loss:  491.50177
Episode: 2 steps: 24
Episode: 3 steps: 27
Episode: 4 steps: 18
Episode: 5 steps: 11
Episode: 6 steps: 14
Episode: 7 steps: 19
Episode: 8 steps: 10
Episode: 9 steps: 13
Episode: 10 steps: 28
Episode: 11 steps: 10
Loss:  0.5737263
Episode: 12 steps: 13
Episode: 13 steps: 14
Episode: 14 steps: 13
Episode: 15 steps: 12
Episode: 16 steps: 9
Episode: 17 steps: 11
Episode: 18 steps: 15
Episode: 19 steps: 13
Episode: 20 steps: 12
Episode: 21 steps: 10
Loss:  942.6343
Episode: 22 steps: 10
Episode: 23 steps: 12
Episode: 24 steps: 14
Episode: 25 steps: 9
Episode: 26 steps: 12
Episode: 27 steps: 12
Episode: 28 steps: 11
Episode: 29 steps: 10
Episode: 30 steps: 14
Episode: 31 steps: 11
Loss:  1.4555897
Episode: 32 steps: 13
Episode: 33 steps: 12
Episode: 34 steps: 10
Episode: 35 steps: 10
Episode: 36 steps: 14
Episode: 37 steps: 10
Episode: 38 steps: 10
Episode: 39 steps: 9
Episode: 40 steps: 10
Episode: 41 steps: 10
Loss:

KeyboardInterrupt: 