In [1]:
import random
import gym
import numpy as np
from collections import deque
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
import tensorflow as tf

In [2]:
# --- Environment ---------------------------------------------------------------
ENV_NAME = "CartPole-v0"

# --- Learning ------------------------------------------------------------------
GAMMA = 0.95           # discount factor applied to the bootstrapped next-state value
LEARNING_RATE = 0.001  # step size handed to the optimizer

# --- Replay memory -------------------------------------------------------------
MEMORY_SIZE = 1_000_000  # max stored entries; the oldest are dropped first
BATCH_SIZE = 20          # entries sampled per experience-replay call

# --- Exploration (epsilon-greedy) ----------------------------------------------
EXPLORATION_MAX = 1.0
EXPLORATION_MIN = 0.01
EXPLORATION_DECAY = 0.995  # multiplicative per-step factor; the linear schedule does not read it
n_epsilon = 0.1            # fraction of training (t / T) over which epsilon anneals linearly

In [3]:
class DQNSolver:
    """Epsilon-greedy DQN agent that acts on a batch of environments at once.

    ``act`` returns one action per environment. Each ``remember`` call stores one
    *vectorized* transition: arrays with one row per environment, as produced by a
    single step of all environments.
    """

    def __init__(self, observation_space, action_space, num_envs):
        """Build the Q-network, the optimizer and an empty replay memory.

        Args:
            observation_space: length of a single observation vector.
            action_space: number of discrete actions.
            num_envs: number of environments acted on per step.
        """
        self.exploration_rate = EXPLORATION_MAX

        self.action_space = action_space
        self.num_envs = num_envs
        self.memory = deque(maxlen=MEMORY_SIZE)

        # use_bias must be a real boolean. The string 'false' is truthy, so the
        # layers have always been built WITH biases; say so explicitly.
        self.model = Sequential()
        self.model.add(Dense(24, input_shape=(observation_space,), activation="relu",
                             use_bias=True, kernel_initializer='he_uniform'))
        self.model.add(Dense(24, activation="relu",
                             use_bias=True, kernel_initializer='he_uniform'))
        self.model.add(Dense(self.action_space, activation="linear",
                             use_bias=True, kernel_initializer='zeros'))
        self.optimizer = tf.keras.optimizers.SGD(LEARNING_RATE)

    def remember(self, state, action, reward, next_state, done):
        """Append one (vectorized) transition to the bounded replay memory."""
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        """Pick one action per environment with per-environment epsilon-greedy.

        Args:
            state: batch of observations, one row per environment.

        Returns:
            Integer array of shape (num_envs,).
        """
        # Decide explore/exploit independently for each environment, so that
        # exploration is not all-or-nothing across the whole batch.
        explore = np.random.rand(self.num_envs) < self.exploration_rate
        actions = np.random.randint(self.action_space, size=self.num_envs)
        if not explore.all():
            # Only query the network when at least one env acts greedily.
            greedy = np.argmax(np.asarray(self.model(state)), axis=1)
            actions = np.where(explore, actions, greedy)
        return actions

    def decay_epsilon(self, n):
        """Linearly anneal epsilon from EXPLORATION_MAX down to EXPLORATION_MIN.

        Args:
            n: training progress as a fraction (t / T). Epsilon hits its floor once
               ``n >= n_epsilon``.
        """
        self.exploration_rate = max(
            EXPLORATION_MIN,
            EXPLORATION_MAX - (n / n_epsilon) * (EXPLORATION_MAX - EXPLORATION_MIN),
        )

    def experience_replay(self):
        """Take one SGD step per sampled replay entry (BATCH_SIZE steps per call).

        Does nothing until the memory holds at least BATCH_SIZE entries. Each step
        regresses Q(s, a) toward the one-step TD target
        ``r + GAMMA * (1 - done) * max_a' Q(s', a')``.
        """
        if len(self.memory) < BATCH_SIZE:
            return
        batch = random.sample(self.memory, BATCH_SIZE)
        for state, action, reward, state_next, terminal in batch:
            # Rewards arrive as float64 arrays and dones as bool arrays; cast both
            # to the network's float32 explicitly.
            reward = tf.cast(reward, tf.float32)
            not_done = 1.0 - tf.cast(terminal, tf.float32)
            # The bootstrap target is a constant w.r.t. the parameters, so compute it
            # outside the tape instead of recording that forward pass.
            q_next = tf.reduce_max(self.model(state_next), axis=1)
            target = tf.stop_gradient(reward + not_done * GAMMA * q_next)
            with tf.GradientTape() as tape:
                action_mask = tf.one_hot(action, self.action_space, dtype=tf.float32)
                q_pred = tf.reduce_sum(self.model(state) * action_mask, axis=1)
                loss = tf.reduce_mean(0.5 * (target - q_pred) ** 2)
            grads = tape.gradient(loss, self.model.trainable_variables)
            self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))

In [4]:
class VectorizedEnvWrapper(gym.Wrapper):
    """Step ``num_envs`` independent copies of an environment in lockstep.

    The env passed to ``gym.Wrapper`` (an extra ``make_env()`` instance) only
    supplies ``observation_space`` / ``action_space``; all rollouts happen in
    ``self.envs``. Uses the 4-tuple ``step`` API (obs, reward, done, info).
    """

    def __init__(self, make_env, num_envs=1):
        """Create the template env plus ``num_envs`` worker envs from ``make_env``."""
        super().__init__(make_env())
        self.num_envs = num_envs
        self.envs = [make_env() for _ in range(num_envs)]

    def reset(self):
        """Reset every worker env and return the stacked observations."""
        return np.asarray([env.reset() for env in self.envs])

    def reset_at(self, env_index):
        """Reset only worker ``env_index`` and return its observation."""
        return self.envs[env_index].reset()

    def step(self, actions):
        """Apply one action per worker env.

        Returns:
            Tuple (next_states, rewards, dones, infos) of arrays with one entry per env.
        """
        next_states, rewards, dones, infos = [], [], [], []
        for env, action in zip(self.envs, actions):
            next_state, reward, done, info = env.step(action)
            next_states.append(next_state)
            rewards.append(reward)
            dones.append(done)
            infos.append(info)
        return (np.asarray(next_states), np.asarray(rewards),
                np.asarray(dones), np.asarray(infos))

    def close(self):
        """Close every worker env as well as the template env.

        ``gym.Wrapper.close`` only closes the wrapped template env, which would
        otherwise leak all environments in ``self.envs``.
        """
        for env in self.envs:
            env.close()
        return super().close()

In [5]:
def train(T=20000, num_envs=32):
    """Train a DQN agent on ENV_NAME using ``num_envs`` environments in lockstep.

    Args:
        T: number of vectorized steps (each step advances every environment once).
        num_envs: number of parallel environments.

    Returns:
        List of the scores (summed rewards) of every episode that finished during
        training, in completion order.
    """
    env = VectorizedEnvWrapper(lambda: gym.make(ENV_NAME), num_envs)
    try:
        observation_space = env.observation_space.shape[0]
        action_space = env.action_space.n
        dqn_solver = DQNSolver(observation_space, action_space, num_envs)
        rewards = []
        # Running score of the current episode in each environment.
        episode_rewards = np.zeros(env.num_envs)
        state = env.reset()
        for t in range(T):
            action = dqn_solver.act(state)
            state_next, reward, terminal, info = env.step(action)
            dqn_solver.remember(state, action, reward, state_next, terminal)
            # Copy so the per-env resets below do not write into the
            # `state_next` array that was just stored in replay memory.
            state = state_next.copy()
            dqn_solver.experience_replay()
            dqn_solver.decay_epsilon(t / T)
            episode_rewards += reward

            for i in np.flatnonzero(terminal):
                print("exploration: " + str(dqn_solver.exploration_rate) + ", score: " + str(episode_rewards[i]))
                rewards.append(episode_rewards[i])
                episode_rewards[i] = 0
                state[i] = env.reset_at(i)
        # Return the finished-episode scores, not the in-progress accumulator.
        return rewards
    finally:
        env.close()

In [7]:
# Train for 10,000 vectorized steps across 32 parallel environments.
rewards = train(num_envs=32, T=10_000)

exploration: 0.99208, score: 9.0
exploration: 0.9901, score: 11.0
exploration: 0.9901, score: 11.0
exploration: 0.9901, score: 11.0
exploration: 0.9901, score: 11.0
exploration: 0.98911, score: 12.0
exploration: 0.98911, score: 12.0
exploration: 0.98911, score: 12.0
exploration: 0.98812, score: 13.0
exploration: 0.98713, score: 14.0
exploration: 0.98614, score: 15.0
exploration: 0.98515, score: 16.0
exploration: 0.98515, score: 16.0
exploration: 0.98416, score: 17.0
exploration: 0.98317, score: 18.0
exploration: 0.98317, score: 18.0
exploration: 0.98218, score: 19.0
exploration: 0.98218, score: 19.0
exploration: 0.98218, score: 19.0
exploration: 0.98218, score: 19.0
exploration: 0.98119, score: 20.0
exploration: 0.9802, score: 21.0
exploration: 0.97822, score: 23.0
exploration: 0.97822, score: 23.0
exploration: 0.97723, score: 12.0
exploration: 0.97723, score: 11.0
exploration: 0.97624, score: 25.0
exploration: 0.97624, score: 25.0
exploration: 0.97525, score: 12.0
exploration: 0.97525

exploration: 0.83269, score: 9.0
exploration: 0.83269, score: 23.0
exploration: 0.8317, score: 23.0
exploration: 0.82972, score: 22.0
exploration: 0.82972, score: 14.0
exploration: 0.82972, score: 61.0
exploration: 0.82873, score: 46.0
exploration: 0.82873, score: 41.0
exploration: 0.82774, score: 53.0
exploration: 0.82774, score: 12.0
exploration: 0.82576, score: 12.0
exploration: 0.82477, score: 19.0
exploration: 0.82477, score: 16.0
exploration: 0.82378, score: 11.0
exploration: 0.82378, score: 34.0
exploration: 0.82279, score: 14.0
exploration: 0.82279, score: 19.0
exploration: 0.82279, score: 9.0
exploration: 0.82279, score: 11.0
exploration: 0.8218000000000001, score: 14.0
exploration: 0.8218000000000001, score: 13.0
exploration: 0.82081, score: 12.0
exploration: 0.82081, score: 12.0
exploration: 0.82081, score: 25.0
exploration: 0.81982, score: 23.0
exploration: 0.81982, score: 9.0
exploration: 0.81883, score: 11.0
exploration: 0.81883, score: 11.0
exploration: 0.81784, score: 1

exploration: 0.59806, score: 69.0
exploration: 0.59707, score: 36.0
exploration: 0.5960799999999999, score: 22.0
exploration: 0.59212, score: 57.0
exploration: 0.58816, score: 53.0
exploration: 0.58717, score: 18.0
exploration: 0.5861800000000001, score: 46.0
exploration: 0.5812300000000001, score: 42.0
exploration: 0.5812300000000001, score: 43.0
exploration: 0.5802400000000001, score: 72.0
exploration: 0.57826, score: 39.0
exploration: 0.57826, score: 68.0
exploration: 0.57826, score: 35.0
exploration: 0.57826, score: 51.0
exploration: 0.5762800000000001, score: 28.0
exploration: 0.5752900000000001, score: 39.0
exploration: 0.56935, score: 145.0
exploration: 0.5614300000000001, score: 174.0
exploration: 0.56044, score: 62.0
exploration: 0.56044, score: 46.0
exploration: 0.5574700000000001, score: 66.0
exploration: 0.55549, score: 23.0
exploration: 0.55351, score: 70.0
exploration: 0.5515300000000001, score: 81.0
exploration: 0.5515300000000001, score: 128.0
exploration: 0.54559, scor

KeyboardInterrupt: 