In [0]:
import random
import gym
import numpy as np
from collections import deque
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam

In [0]:
# Gym environment id handed to gym.make().
ENV_NAME = "CartPole-v1"

# Discount factor applied to the best next-state Q-value when building
# the training target for a non-terminal transition.
GAMMA = 0.95
# Step size passed to the Adam optimizer when the Q-network is compiled.
LEARNING_RATE = 0.001

# Maximum number of transitions kept in the replay deque; once full,
# appending a new transition drops the oldest one.
MEMORY_SIZE = 1_000_000
# Number of transitions sampled from memory on each experience-replay call.
BATCH_SIZE = 20

# Exploration (epsilon) schedule: the rate starts at EXPLORATION_MAX, is
# multiplied by EXPLORATION_DECAY after every replay pass, and is clamped
# so it never drops below EXPLORATION_MIN.
EXPLORATION_MAX = 1.0
EXPLORATION_MIN = 0.01
EXPLORATION_DECAY = 0.995

In [0]:
class Solver:
    """Deep Q-learning agent built on a small fully connected Keras network.

    The agent stores transitions in a bounded replay memory, picks actions
    epsilon-greedily, and trains the network on randomly sampled
    transitions, one `fit` call per sampled transition.
    """

    def __init__(self, observation_space, action_space):
        """Create the Q-network and initialise memory and exploration rate.

        Args:
            observation_space: Number of features in one observation; used
                as the input width of the first Dense layer.
            action_space: Number of discrete actions; used as the width of
                the linear output layer and as the bound for random actions.
        """
        self.action_space = action_space
        self.exploration_rate = EXPLORATION_MAX
        self.memory = deque(maxlen=MEMORY_SIZE)

        # Two hidden ReLU layers of 24 units, then one linear output per action.
        network = Sequential([
            Dense(24, input_shape=(observation_space,), activation="relu"),
            Dense(24, activation="relu"),
            Dense(self.action_space, activation="linear"),
        ])
        network.compile(loss="mse", optimizer=Adam(lr=LEARNING_RATE))
        self.model = network

    def remember(self, state, action, reward, next_state, done):
        """Append one transition tuple to the replay memory."""
        transition = (state, action, reward, next_state, done)
        self.memory.append(transition)

    def act(self, state):
        """Return an action index for `state` using an epsilon-greedy rule.

        With probability `exploration_rate` a uniformly random action is
        returned; otherwise the action with the largest predicted Q-value.
        """
        explore = np.random.rand() < self.exploration_rate
        if explore:
            return random.randrange(self.action_space)
        predicted = self.model.predict(state)
        return np.argmax(predicted[0])

    def experience_replay(self):
        """Train on a random batch from memory, then decay exploration.

        Does nothing until the memory holds at least BATCH_SIZE transitions.
        """
        if len(self.memory) < BATCH_SIZE:
            return

        for transition in random.sample(self.memory, BATCH_SIZE):
            state, action, reward, state_next, terminal = transition
            if terminal:
                # No bootstrapping from the next state on terminal transitions.
                target = reward
            else:
                best_next = np.amax(self.model.predict(state_next)[0])
                target = (reward + GAMMA * best_next)
            # Only the taken action's Q-value is overwritten with the target;
            # the other outputs keep the network's own predictions.
            targets = self.model.predict(state)
            targets[0][action] = target
            self.model.fit(state, targets, verbose=0)

        # Decay once per replay pass, never dropping below the floor.
        self.exploration_rate = max(
            EXPLORATION_MIN, self.exploration_rate * EXPLORATION_DECAY
        )

In [0]:
def cartpole(env=None, solver=None, max_runs=None):
    """Train a DQN solver on CartPole until the score target is reached.

    Runs episodes in a loop, printing one line per episode. Once at least
    100 episodes have finished, training stops as soon as the mean score
    (number of steps) of the last 100 episodes is >= 195.

    Args:
        env: Optional pre-built environment. When None, one is created with
            `gym.make(ENV_NAME)` and closed before returning. An environment
            supplied by the caller is left open.
        solver: Optional agent exposing `act`, `remember` and
            `experience_replay`. When None, a new `Solver` is built from the
            environment's observation and action spaces.
        max_runs: Optional upper bound on the number of episodes. None (the
            default) keeps the original behaviour of running until solved.

    Returns:
        None. Progress is reported through `print`.
    """
    owns_env = env is None
    if owns_env:
        env = gym.make(ENV_NAME)

    try:
        observation_space = env.observation_space.shape[0]
        if solver is None:
            solver = Solver(observation_space, env.action_space.n)

        run = 0
        steps = []
        solved = False
        while not solved and (max_runs is None or run < max_runs):
            run += 1

            state = env.reset()
            # Newer Gym/Gymnasium releases return (observation, info) from
            # reset(); older ones return the observation alone.
            if isinstance(state, tuple):
                state = state[0]
            state = np.reshape(state, [1, observation_space])

            step = 0
            terminal = False
            while not terminal:
                step += 1
                action = solver.act(state)

                outcome = env.step(action)
                if len(outcome) == 5:
                    # Newer API: (obs, reward, terminated, truncated, info).
                    state_next, reward, terminated, truncated, _info = outcome
                    terminal = terminated or truncated
                else:
                    # Older API: (obs, reward, done, info).
                    state_next, reward, terminal, _info = outcome

                # The step that ends the episode gets its reward negated.
                reward = reward if not terminal else -reward
                state_next = np.reshape(state_next, [1, observation_space])
                solver.remember(state, action, reward, state_next, terminal)
                state = state_next

                # No replay pass after the final step of an episode.
                if not terminal:
                    solver.experience_replay()

            print("Run: " + str(run) + ", score: " + str(step), end="")
            steps.append(step)
            if run >= 100:
                last100mean = np.mean(steps[-100:])
                solved = last100mean >= 195
                print(", last 100 runs mean: {:.2f}".format(last100mean))
            else:
                print()
    finally:
        # Release the environment only if this function created it.
        if owns_env:
            env.close()

In [36]:
# Start training. One line is printed per episode; the call returns once the
# mean score of the last 100 episodes reaches 195 (see cartpole above).
cartpole()

Run: 1, score: 28
Run: 2, score: 43
Run: 3, score: 10
Run: 4, score: 60
Run: 5, score: 15
Run: 6, score: 10
Run: 7, score: 11
Run: 8, score: 19
Run: 9, score: 10
Run: 10, score: 10
Run: 11, score: 12
Run: 12, score: 11
Run: 13, score: 13
Run: 14, score: 11
Run: 15, score: 12
Run: 16, score: 10
Run: 17, score: 10
Run: 18, score: 10
Run: 19, score: 10
Run: 20, score: 9
Run: 21, score: 9
Run: 22, score: 10
Run: 23, score: 10
Run: 24, score: 14
Run: 25, score: 8
Run: 26, score: 12
Run: 27, score: 11
Run: 28, score: 10
Run: 29, score: 9
Run: 30, score: 9
Run: 31, score: 9
Run: 32, score: 9
Run: 33, score: 9
Run: 34, score: 10
Run: 35, score: 9
Run: 36, score: 12
Run: 37, score: 10
Run: 38, score: 14
Run: 39, score: 10
Run: 40, score: 9
Run: 41, score: 11
Run: 42, score: 10
Run: 43, score: 10
Run: 44, score: 10
Run: 45, score: 10
Run: 46, score: 10
Run: 47, score: 10
Run: 48, score: 9
Run: 49, score: 10
Run: 50, score: 12
Run: 51, score: 9
Run: 52, score: 14
Run: 53, score: 19
Run: 54, score