In [1]:
import random
import gym
import numpy as np
from collections import deque
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
import tensorflow as tf

In [2]:
# Gym environment id that train() instantiates for every parallel copy.
ENV_NAME = "CartPole-v0"

# Replay-buffer capacity, counted in stored entries (train() stores one entry
# per vectorized step, i.e. one transition for each parallel environment).
MEMORY_SIZE = 1_000_000
# Number of replay entries sampled by each Agent.experience_replay() call.
BATCH_SIZE = 20

In [3]:
class Agent:
    """Deep Q-learning agent that acts on a batch of lock-stepped environments.

    The Q-function is a small MLP mapping an observation vector to one value per
    discrete action. Transitions are kept in a replay buffer (capacity
    MEMORY_SIZE entries) and the network is trained with one-step TD targets
    bootstrapped from the same network (there is no separate target network).

    Args:
        observation_space: Length of the flat observation vector.
        num_actions: Number of discrete actions.
        num_envs: Number of environments stepped together; ``act`` returns one
            action per environment.
        alpha: SGD learning rate.
        gamma: Discount factor used in the TD target.
        epsilon_i: Initial exploration rate.
        epsilon_f: Final (floor) exploration rate.
        n_epsilon: Value of the progress argument ``n`` of ``decay_epsilon`` at
            which epsilon reaches ``epsilon_f`` (linear decay before that).
    """

    def __init__(self, observation_space, num_actions, num_envs, alpha=0.001, gamma=0.95, epsilon_i=1.0, epsilon_f=0.01, n_epsilon=0.1):
        self.epsilon_i = epsilon_i
        self.epsilon_f = epsilon_f
        self.n_epsilon = n_epsilon
        self.epsilon = epsilon_i
        self.gamma = gamma

        self.num_actions = num_actions
        self.num_envs = num_envs
        # Each entry is one vectorized step: arrays covering all num_envs envs.
        self.memory = deque(maxlen=MEMORY_SIZE)

        # use_bias must be a real bool: a string such as 'false' is truthy and
        # silently enables biases. Biases are kept on (the behavior this
        # notebook was trained with).
        self.Q = Sequential()
        self.Q.add(Dense(24, input_shape=(observation_space,), activation="relu", use_bias=True, kernel_initializer='he_uniform'))
        self.Q.add(Dense(24, activation="relu", use_bias=True, kernel_initializer='he_uniform'))
        self.Q.add(Dense(self.num_actions, activation="linear", use_bias=True, kernel_initializer='zeros'))
        self.optimizer = tf.keras.optimizers.SGD(alpha)

    def remember(self, s_t, a_t, r_t, s_t_next, done):
        """Append one vectorized transition to the replay buffer.

        The arrays are stored by reference, so callers must not mutate them
        afterwards.
        """
        self.memory.append((s_t, a_t, r_t, s_t_next, done))

    def act(self, s_t):
        """Pick an epsilon-greedy action independently for every environment.

        Each environment explores with probability ``self.epsilon``; the others
        take ``argmax`` of ``self.Q``.

        Args:
            s_t: Batch of observations, one row per environment (assumed to
                have ``num_envs`` rows so it lines up with the random draw).

        Returns:
            Integer array of shape ``(num_envs,)``.
        """
        explore = np.random.rand(self.num_envs) < self.epsilon
        random_actions = np.random.randint(self.num_actions, size=self.num_envs)
        if explore.all():
            # No environment needs the greedy action; skip the forward pass.
            return random_actions
        greedy_actions = np.argmax(np.asarray(self.Q(s_t)), axis=1)
        return np.where(explore, random_actions, greedy_actions)

    def decay_epsilon(self, n):
        """Linearly anneal epsilon from epsilon_i to epsilon_f as n goes 0 -> n_epsilon.

        Args:
            n: Training progress (``train`` passes ``t / T``).
        """
        if self.n_epsilon <= 0:
            # Degenerate schedule: decay immediately instead of dividing by zero.
            self.epsilon = self.epsilon_f
            return
        self.epsilon = max(
            self.epsilon_f,
            self.epsilon_i - (n / self.n_epsilon) * (self.epsilon_i - self.epsilon_f))

    def experience_replay(self):
        """Run one SGD step on each of BATCH_SIZE sampled replay entries.

        Does nothing until the buffer holds at least BATCH_SIZE entries. Since
        every entry already covers all parallel environments, each call performs
        BATCH_SIZE gradient updates, each on a batch of ``num_envs`` transitions.
        """
        if len(self.memory) < BATCH_SIZE:
            return
        batch = random.sample(self.memory, BATCH_SIZE)
        for s_t, a_t, r_t, s_t_next, d_t in batch:
            # Make dtypes explicit instead of relying on numpy/TF promotion of
            # float64 rewards and boolean done flags.
            reward = tf.cast(r_t, tf.float32)
            not_done = 1.0 - tf.cast(d_t, tf.float32)
            action_mask = tf.one_hot(a_t, self.num_actions, dtype=tf.float32)
            with tf.GradientTape() as tape:
                # Bootstrapped target is treated as a constant for this update.
                Q_next = tf.stop_gradient(tf.reduce_max(self.Q(s_t_next), axis=1))
                Q_pred = tf.reduce_sum(self.Q(s_t) * action_mask, axis=1)
                td_error = reward + not_done * self.gamma * Q_next - Q_pred
                loss = tf.reduce_mean(0.5 * td_error ** 2)
            grads = tape.gradient(loss, self.Q.trainable_variables)
            self.optimizer.apply_gradients(zip(grads, self.Q.trainable_variables))

In [4]:
class VectorizedEnvWrapper(gym.Wrapper):
    """Step ``num_envs`` independent copies of an environment in lockstep.

    Uses the classic gym API (``reset() -> obs`` and
    ``step(a) -> (obs, reward, done, info)``) and stacks the per-environment
    results into NumPy arrays. Finished environments are not reset
    automatically; call ``reset_at`` for each one.

    Args:
        make_env: Zero-argument factory returning a new environment.
        num_envs: Number of environment copies to create (must be >= 1).

    Raises:
        ValueError: If ``num_envs`` is less than 1.
    """

    def __init__(self, make_env, num_envs=1):
        if num_envs < 1:
            raise ValueError("num_envs must be at least 1, got {}".format(num_envs))
        envs = [make_env() for _ in range(num_envs)]
        # Wrap the first copy (for observation_space / action_space) instead of
        # building an extra environment that would never be stepped or closed.
        super().__init__(envs[0])
        self.num_envs = num_envs
        self.envs = envs

    def reset(self):
        """Reset every copy and return the stacked initial observations."""
        return np.asarray([env.reset() for env in self.envs])

    def reset_at(self, env_index):
        """Reset only copy ``env_index`` and return its initial observation."""
        return self.envs[env_index].reset()

    def step(self, actions):
        """Apply ``actions[k]`` to copy ``k``; return stacked (obs, rewards, dones, infos)."""
        next_states, rewards, dones, infos = [], [], [], []
        for env, action in zip(self.envs, actions):
            next_state, reward, done, info = env.step(action)
            next_states.append(next_state)
            rewards.append(reward)
            dones.append(done)
            infos.append(info)
        return np.asarray(next_states), np.asarray(rewards), \
            np.asarray(dones), np.asarray(infos)

    def close(self):
        """Close every wrapped copy, not just the one exposed as ``self.env``."""
        for env in self.envs:
            env.close()

In [5]:
def train(T=20000, num_envs=32, scores=None):
    """Train a DQN ``Agent`` on ``num_envs`` parallel copies of ``ENV_NAME``.

    Prints the current exploration rate and score whenever an episode finishes.

    Args:
        T: Number of vectorized environment steps to run.
        num_envs: Number of environments stepped in lockstep.
        scores: Optional list; the score of every completed episode is appended
            to it. Because it is filled in place, the scores survive an
            interrupted run (e.g. KeyboardInterrupt).

    Returns:
        Float array of shape ``(num_envs,)`` holding the reward accumulated in
        each environment's current, still-unfinished episode.
    """
    env = VectorizedEnvWrapper(lambda: gym.make(ENV_NAME), num_envs)
    try:
        observation_space = env.observation_space.shape[0]
        num_actions = env.action_space.n
        agent = Agent(observation_space, num_actions, num_envs)
        if scores is None:
            scores = []
        # Real float array from the start, so per-index resets are well defined.
        episode_rewards = np.zeros(num_envs)
        s_t = env.reset()
        for t in range(T):
            a_t = agent.act(s_t)
            s_t_next, r_t, d_t, _ = env.step(a_t)
            agent.remember(s_t, a_t, r_t, s_t_next, d_t)
            # s_t_next is now referenced by the replay buffer; copy it so that
            # resetting finished environments below does not rewrite the stored
            # transition in place.
            s_t = s_t_next.copy()
            agent.experience_replay()
            agent.decay_epsilon(t/T)
            episode_rewards += r_t

            for i in np.flatnonzero(d_t):
                print("exploration: " + str(agent.epsilon) + ", score: " + str(episode_rewards[i]))
                scores.append(episode_rewards[i])
                episode_rewards[i] = 0
                s_t[i] = env.reset_at(i)
        return episode_rewards
    finally:
        # Release the environments even when training is interrupted.
        env.close()

In [6]:
# Run 10,000 vectorized steps across 32 parallel environments.
r_t = train(num_envs=32, T=10000)

W0610 17:01:20.609395 4556862912 deprecation.py:323] From /anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow/python/ops/math_grad.py:1205: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


exploration: 0.99109, score: 10.0
exploration: 0.99109, score: 10.0
exploration: 0.9901, score: 11.0
exploration: 0.98911, score: 12.0
exploration: 0.98812, score: 13.0
exploration: 0.98812, score: 13.0
exploration: 0.98812, score: 13.0
exploration: 0.98614, score: 15.0
exploration: 0.98614, score: 15.0
exploration: 0.98614, score: 15.0
exploration: 0.98416, score: 17.0
exploration: 0.98218, score: 19.0
exploration: 0.98218, score: 19.0
exploration: 0.98119, score: 20.0
exploration: 0.98119, score: 20.0
exploration: 0.98119, score: 20.0
exploration: 0.97921, score: 22.0
exploration: 0.97723, score: 24.0
exploration: 0.97723, score: 24.0
exploration: 0.97723, score: 14.0
exploration: 0.97723, score: 24.0
exploration: 0.97723, score: 24.0
exploration: 0.97723, score: 24.0
exploration: 0.97624, score: 12.0
exploration: 0.97525, score: 15.0
exploration: 0.97327, score: 16.0
exploration: 0.97327, score: 28.0
exploration: 0.97327, score: 28.0
exploration: 0.97228, score: 9.0
exploration: 0.9

exploration: 0.8406100000000001, score: 17.0
exploration: 0.8376399999999999, score: 40.0
exploration: 0.8356600000000001, score: 22.0
exploration: 0.8356600000000001, score: 25.0
exploration: 0.8356600000000001, score: 17.0
exploration: 0.83368, score: 17.0
exploration: 0.83368, score: 39.0
exploration: 0.83368, score: 30.0
exploration: 0.83368, score: 52.0
exploration: 0.83368, score: 31.0
exploration: 0.83269, score: 32.0
exploration: 0.8317, score: 12.0
exploration: 0.8307100000000001, score: 39.0
exploration: 0.82774, score: 22.0
exploration: 0.82774, score: 31.0
exploration: 0.82576, score: 41.0
exploration: 0.82477, score: 34.0
exploration: 0.82477, score: 26.0
exploration: 0.82279, score: 46.0
exploration: 0.82081, score: 25.0
exploration: 0.81883, score: 15.0
exploration: 0.81883, score: 30.0
exploration: 0.81883, score: 48.0
exploration: 0.81784, score: 18.0
exploration: 0.81784, score: 16.0
exploration: 0.8168500000000001, score: 26.0
exploration: 0.8168500000000001, score: 

exploration: 0.4525300000000001, score: 23.0
exploration: 0.4475800000000001, score: 100.0
exploration: 0.4406500000000001, score: 81.0
exploration: 0.43867, score: 77.0
exploration: 0.42976000000000003, score: 110.0
exploration: 0.4287700000000001, score: 190.0
exploration: 0.4238200000000001, score: 90.0
exploration: 0.42184, score: 84.0
exploration: 0.41095000000000004, score: 200.0
exploration: 0.41095000000000004, score: 148.0
exploration: 0.40501, score: 104.0
exploration: 0.39907000000000004, score: 107.0
exploration: 0.39214000000000004, score: 109.0
exploration: 0.38521000000000005, score: 200.0
exploration: 0.38322999999999996, score: 28.0
exploration: 0.37432, score: 92.0
exploration: 0.36937, score: 157.0
exploration: 0.3624400000000001, score: 129.0
exploration: 0.36046, score: 107.0
exploration: 0.35056, score: 162.0
exploration: 0.35056, score: 154.0
exploration: 0.34957000000000005, score: 80.0
exploration: 0.3485800000000001, score: 200.0
exploration: 0.333730000000000

exploration: 0.01, score: 200.0
exploration: 0.01, score: 200.0
exploration: 0.01, score: 200.0
exploration: 0.01, score: 200.0
exploration: 0.01, score: 200.0
exploration: 0.01, score: 200.0
exploration: 0.01, score: 168.0
exploration: 0.01, score: 194.0
exploration: 0.01, score: 200.0
exploration: 0.01, score: 200.0
exploration: 0.01, score: 200.0
exploration: 0.01, score: 200.0
exploration: 0.01, score: 179.0
exploration: 0.01, score: 200.0
exploration: 0.01, score: 200.0
exploration: 0.01, score: 200.0
exploration: 0.01, score: 200.0
exploration: 0.01, score: 200.0
exploration: 0.01, score: 200.0
exploration: 0.01, score: 200.0
exploration: 0.01, score: 200.0
exploration: 0.01, score: 200.0
exploration: 0.01, score: 182.0
exploration: 0.01, score: 200.0
exploration: 0.01, score: 200.0
exploration: 0.01, score: 200.0
exploration: 0.01, score: 200.0
exploration: 0.01, score: 200.0
exploration: 0.01, score: 200.0
exploration: 0.01, score: 200.0
exploration: 0.01, score: 200.0
explorat

KeyboardInterrupt: 