In [0]:
import gym
gym.logger.set_level(40) # suppress warnings (please remove if gives error)
import numpy as onp
from collections import deque
import matplotlib.pyplot as plt
%matplotlib inline
from jax import jacobian, lax
import jax
import jax.numpy as np
from jax.ops import index, index_add, index_update
from jax import jit, grad, vmap, random, jacrev, jacobian, jacfwd, value_and_grad
from functools import partial
from jax.experimental import stax # neural network library
from jax.experimental.stax import GeneralConv, Conv, ConvTranspose, Dense, MaxPool, Relu, Flatten, LogSoftmax, LeakyRelu, Dropout, Tanh, Sigmoid, BatchNorm, Softmax # neural network layers
from jax.nn import softmax, sigmoid
from jax.nn.initializers import zeros
from jax.experimental import optimizers
from jax.tree_util import tree_multimap  # Element-wise manipulation of collections of numpy arrays

In [0]:
from collections import deque
class Memory():
    def __init__(self, max_size = 1000):
        self.buffer = deque(maxlen=max_size)
    
    def add(self, experience):
        self.buffer.append(experience)
            
    def sample(self, batch_size):
        idx = onp.random.choice(onp.arange(len(self.buffer)), 
                               size=batch_size, 
                               replace=False)
        return [self.buffer[ii] for ii in idx]

In [0]:
train_episodes = 1000          # max number of episodes to learn from
max_steps = 200                # max steps in an episode
gamma = 0.99                   # future reward discount

# Exploration parameters
explore_start = 1.0            # exploration probability at start
explore_stop = 0.01            # minimum exploration probability 
decay_rate = 0.0001            # exponential decay rate for exploration prob

# Network parameters
hidden_size = 64               # number of units in each Q-network hidden layer
learning_rate = 0.0001         # Q-network learning rate

# Memory parameters
memory_size = 10000            # memory capacity
batch_size = 20                # experience mini-batch size
pretrain_length = batch_size   # number experiences to pretrain the memory

In [0]:
#Define Q-network
class QNetwork:
    def __init__(self, rng, env, learning_rate=0.01, state_size=4, 
                 action_size=2, hidden_size=10, 
                 name='QNetwork'):
        self.key = rng
        self.init_fun, self.apply_fun = stax.serial(
            Dense(hidden_size), Relu,
            Dense(hidden_size), Relu,
            Dense(action_size)
        )
        self.in_shape = (-1, state_size)
        _, self.net_params = self.init_fun(self.key, self.in_shape)
        self.opt_init, self.opt_update, self.get_params = optimizers.adam(step_size=learning_rate)
        self.opt_state = self.opt_init(self.net_params)
        self.env = env
        self.loss = np.inf

    def loss_fun(self, params, inputs, actions, targets):
        output = self.apply_fun(params, inputs)
        selectedq = np.sum(actions*output, axis=-1)
        # print(inputs.shape, actions.shape, output.shape, selectedq.shape, targets.shape)
        return np.mean(np.square(selectedq - targets))

    def output(self, inputs):
        return self.apply_fun(self.net_params, inputs)

    def update_key(self):
        self.key, _ = jax.random.split(self.key)
        return self.key

    def act(self, state, explore_p):
        uf = jax.random.uniform(self.update_key(), (1,), minval=0.0, maxval=1.0)[0]
        # print(self.key, uf)
        if explore_p > uf:
            # Make a random action
            action = self.env.action_space.sample()
        else:
            # Get action from Q-network
            action = np.argmax(self.output(state)).item()
        return action

    def step(self, i, inputs, actions, targets):
        params = self.get_params(self.opt_state)
        self.loss, g = value_and_grad(self.loss_fun)(params, inputs, actions, targets)
        self.opt_state = self.opt_update(i, g, self.opt_state)
        self.net_params = self.get_params(self.opt_state)

In [0]:
def init_memory(env):
    # Initialize the simulation
    env.reset()
    # Take one random step to get the pole and cart moving
    state, reward, done, _ = env.step(env.action_space.sample())

    memory = Memory(max_size=memory_size)

    # Make a bunch of random actions and store the experiences
    for ii in range(pretrain_length):
        # Uncomment the line below to watch the simulation
        # env.render()

        # Make a random action
        action = env.action_space.sample()
        next_state, reward, done, _ = env.step(action)

        if done:
            # The simulation fails so no next state
            next_state = np.zeros(state.shape)
            # Add experience to memory
            memory.add((state, action, reward, next_state))
            
            # Start new episode
            env.reset()
            # Take one random step to get the pole and cart moving
            state, reward, done, _ = env.step(env.action_space.sample())
        else:
            # Add experience to memory
            memory.add((state, action, reward, next_state))
            state = next_state
    return memory, state

In [0]:
# Now train with experiences
def one_hot(x, k, dtype=np.float32):
  """Create a one-hot encoding of x of size k."""
  return np.array(x[:, None] == np.arange(k), dtype)


def train(env, mainQN):
    rewards_list = []    
    step = 0
    memory, state = init_memory(env)
    for ep in range(1, train_episodes):
        total_reward = 0
        t = 0
        while t < max_steps:
            step += 1
            # Uncomment this next line to watch the training
            # env.render() 
            
            # Explore or Exploit
            explore_p = explore_stop + (explore_start - explore_stop)*np.exp(-decay_rate*step) 
            action = mainQN.act(state, explore_p)
            
            # Take action, get new state and reward
            next_state, reward, done, _ = env.step(action)

            total_reward += reward
            
            if done:
                # the episode ends so no next state
                next_state = np.zeros(state.shape)
                t = max_steps
                
                print('Episode: {}'.format(ep),
                    'Total reward: {}'.format(total_reward),
                    'Training loss: {:.4f}'.format(loss),
                    'Explore P: {:.4f}'.format(explore_p))
                rewards_list.append((ep, total_reward))
                
                # Add experience to memory
                memory.add((state, action, reward, next_state))
                
                # Start new episode
                env.reset()
                # Take one random step to get the pole and cart moving
                state, reward, done, _ = env.step(env.action_space.sample())

            else:
                # Add experience to memory
                memory.add((state, action, reward, next_state))
                state = next_state
                t += 1
            
            # Sample mini-batch from memory
            batch = memory.sample(batch_size)
            states = np.array([each[0] for each in batch])
            actions = one_hot(np.array([each[1] for each in batch]), 2)
            rewards = np.array([each[2] for each in batch])
            next_states = np.array([each[3] for each in batch])
            
            # Train network
            target_Qs = mainQN.output(next_states)
            
            # Set target_Qs to 0 for states where episode ends
            episode_ends = (next_states == np.zeros(states[0].shape)).all(axis=1)
            new_target_Qs = index_update(target_Qs, index[episode_ends], (0, 0))
            target_Qs = new_target_Qs
            
            targets = rewards + gamma * np.max(target_Qs, axis=1)
            # print(states.shape, targets.shape, targets)
            mainQN.step((ep-1)*train_episodes+step-1, states, actions, targets)

            loss = mainQN.loss

    return rewards_list

def plot_scores(rewards_list):
    def running_mean(x, N):
        cumsum = np.cumsum(np.insert(x, 0, 0)) 
        return (cumsum[N:] - cumsum[:-N]) / N
    eps, rews = np.array(rewards_list).T
    smoothed_rews = running_mean(rews, 10)
    plt.plot(eps[-len(smoothed_rews):], smoothed_rews)
    plt.plot(eps, rews, color='grey', alpha=0.3)
    plt.xlabel('Episode')
    plt.ylabel('Total Reward')
    plt.show()

In [0]:
def main():
    seed = 0
    env = gym.make('CartPole-v0')
    env.seed(seed)
    print('observation space:', env.observation_space)
    print('action space:', env.action_space.n)
    rng = jax.random.PRNGKey(seed)
    onp.random.seed(seed)
    mainQN = QNetwork(rng, env, name='main', hidden_size=hidden_size, learning_rate=learning_rate)
    rewards_list = train(env, mainQN)
    plot_scores(rewards_list)

if __name__ == '__main__':
    main()

observation space: Box(4,)
action space: 2
Episode: 1 Total reward: 29.0 Training loss: 1.0450 Explore P: 0.9971
Episode: 2 Total reward: 30.0 Training loss: 1.0738 Explore P: 0.9942
Episode: 3 Total reward: 7.0 Training loss: 0.9976 Explore P: 0.9935
Episode: 4 Total reward: 18.0 Training loss: 1.1242 Explore P: 0.9917
Episode: 5 Total reward: 30.0 Training loss: 1.2784 Explore P: 0.9888
Episode: 6 Total reward: 23.0 Training loss: 1.1642 Explore P: 0.9865
Episode: 7 Total reward: 11.0 Training loss: 1.2198 Explore P: 0.9855
Episode: 8 Total reward: 11.0 Training loss: 1.7343 Explore P: 0.9844
Episode: 9 Total reward: 15.0 Training loss: 3.4449 Explore P: 0.9829
Episode: 10 Total reward: 43.0 Training loss: 1.9305 Explore P: 0.9787
Episode: 11 Total reward: 38.0 Training loss: 2.1615 Explore P: 0.9751
Episode: 12 Total reward: 11.0 Training loss: 1.9250 Explore P: 0.9740
Episode: 13 Total reward: 52.0 Training loss: 2.7796 Explore P: 0.9690
Episode: 14 Total reward: 13.0 Training loss