<a href="https://colab.research.google.com/github/anirbanl/jax-code/blob/master/rlflax/jax_double_dqn_cartpole.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Based on sources:
1. https://arxiv.org/pdf/1509.06461.pdf
2. https://github.com/higgsfield/RL-Adventure/blob/master/2.double%20dqn.ipynb
3. https://medium.com/@parsa_h_m/deep-reinforcement-learning-dqn-double-dqn-dueling-dqn-noisy-dqn-and-dqn-with-prioritized-551f621a9823

In [1]:
!pip install jax jaxlib flax

Collecting flax
[?25l  Downloading https://files.pythonhosted.org/packages/f6/21/21ca1f4831ac24646578d2545c4db9a8369b9da4a4b7dcf067feee312b45/flax-0.3.4-py3-none-any.whl (183kB)
[K     |████████████████████████████████| 184kB 7.3MB/s 
Collecting optax
[?25l  Downloading https://files.pythonhosted.org/packages/ca/04/464fa1d12562d191196f2f7f8112d65e22eaaa9a7e2b599f298aeba2ce27/optax-0.0.8-py3-none-any.whl (113kB)
[K     |████████████████████████████████| 122kB 26.4MB/s 
Collecting chex>=0.0.4
[?25l  Downloading https://files.pythonhosted.org/packages/f5/b9/445eb59ec23249acffc5322c79b07e20b12dbff45b9c1da6cdae9e947685/chex-0.0.7-py3-none-any.whl (52kB)
[K     |████████████████████████████████| 61kB 9.0MB/s 
Installing collected packages: chex, optax, flax
Successfully installed chex-0.0.7 flax-0.3.4 optax-0.0.8


In [2]:
import gym
gym.logger.set_level(40) # suppress warnings (please remove if gives error)
import numpy as np
from collections import deque
import matplotlib.pyplot as plt
%matplotlib inline
import jax
import jax.numpy as jp
from jax.ops import index, index_add, index_update
from jax import jit, grad, vmap, random, jacrev, jacobian, jacfwd, value_and_grad
from functools import partial
from jax.tree_util import tree_multimap  # Element-wise manipulation of collections of numpy arrays
from flax import linen as nn           # The Linen API
from flax.training import train_state  # Useful dataclass to keep train state
import optax                           # Optimizers
from typing import Sequence
import copy

In [3]:
from collections import deque
class Memory():
    def __init__(self, max_size = 1000):
        self.buffer = deque(maxlen=max_size)
    
    def add(self, experience):
        self.buffer.append(experience)
            
    def sample(self, key, batch_size):
        key, _ = jax.random.split(key)
        idx = jax.random.choice(key,
                               jp.arange(len(self.buffer)), 
                               shape=(batch_size, ))
        return [self.buffer[ii] for ii in idx]

In [4]:
train_episodes = 1000          # max number of episodes to learn from
max_steps = 200                # max steps in an episode
gamma = 0.99                   # future reward discount
update_target_every = 20       # Update target Q model every this episodes

# Exploration parameters
explore_start = 1.0            # exploration probability at start
explore_stop = 0.01            # minimum exploration probability 
decay_rate = 0.0001            # exponential decay rate for exploration prob

# Network parameters
hidden_size = 64               # number of units in each Q-network hidden layer
learning_rate = 1e-3         # Q-network learning rate

# Memory parameters
memory_size = 10000            # memory capacity
batch_size = 20                # experience mini-batch size
pretrain_length = batch_size   # number experiences to pretrain the memory

In [5]:
#Define Q-network
class QNetwork:
    def __init__(self, rng, env, learning_rate=0.01, state_size=4, 
                 action_size=2, hidden_size=10, 
                 name='QNetwork'):
        self.key = rng
        self.env = env

        class Model(nn.Module):
            features: Sequence[int]

            @nn.compact
            def __call__(self, x):
                x = nn.relu(nn.Dense(self.features[0])(x))
                x = nn.relu(nn.Dense(self.features[1])(x))
                x = nn.Dense(self.features[2])(x)
                return x

        def create_train_state(rng, learning_rate, s_size, h_size, a_size):
            """Creates initial `TrainState`."""
            model = Model(features=[hidden_size, hidden_size, a_size])
            params = model.init(rng, jp.ones((s_size, )))#['params']
            tx = optax.adam(learning_rate)
            return train_state.TrainState.create(
                apply_fn=model.apply, params=params, tx=tx)

        self.ts = create_train_state(rng, learning_rate, state_size, hidden_size, action_size)

        @jit
        def train_step(ts, inputs, actions, targets):

            def loss_fun(params, inputs, actions, targets):
                output = ts.apply_fn(params, inputs)
                selectedq = jp.sum(actions*output, axis=-1)
                diff = selectedq - jax.lax.stop_gradient(targets)
                return jp.mean(diff**2)

            loss, g = value_and_grad(loss_fun)(ts.params, inputs, actions, targets)
            return ts.apply_gradients(grads=g), loss

        self.train_fn = train_step


    def act(self, state, explore_p):
        self.key, _ = jax.random.split(self.key)
        uf = jax.random.uniform(self.key, (1,), minval=0.0, maxval=1.0)[0]
        if explore_p > uf:
            # Make a random action
            action = self.env.action_space.sample()
        else:
            # Get action from Q-network
            qvalues = self.ts.apply_fn(self.ts.params, state)
            action = jp.argmax(qvalues).item()
        return action



In [6]:
def init_memory(env):
    # Initialize the simulation
    env.reset()
    # Take one random step to get the pole and cart moving
    state, reward, done, _ = env.step(env.action_space.sample())

    memory = Memory(max_size=memory_size)

    # Make a bunch of random actions and store the experiences
    for ii in range(pretrain_length):
        # Uncomment the line below to watch the simulation
        # env.render()

        # Make a random action
        action = env.action_space.sample()
        next_state, reward, done, _ = env.step(action)

        if done:
            # The simulation fails so no next state
            next_state = jp.zeros(state.shape)
            # Add experience to memory
            memory.add((state, action, reward, next_state))
            
            # Start new episode
            env.reset()
            # Take one random step to get the pole and cart moving
            state, reward, done, _ = env.step(env.action_space.sample())
        else:
            # Add experience to memory
            memory.add((state, action, reward, next_state))
            state = next_state
    return memory, state

In [7]:
# Now train with experiences
def one_hot(x, k, dtype=jp.float32):
  """Create a one-hot encoding of x of size k."""
  return jp.array(x[:, None] == jp.arange(k), dtype)


def train(rng, env, mainQN):
    rewards_list = []    
    step = 0
    memory, state = init_memory(env)
    current_params = mainQN.ts.params
    target_params = copy.deepcopy(mainQN.ts.params)
    for ep in range(1, train_episodes):
        total_reward = 0
        t = 0
        while t < max_steps:
            step += 1
            # Uncomment this next line to watch the training
            # env.render() 
            
            # Explore or Exploit
            explore_p = explore_stop + (explore_start - explore_stop)*jp.exp(-decay_rate*step) 
            action = mainQN.act(state, explore_p)
            
            # Take action, get new state and reward
            next_state, reward, done, _ = env.step(action)

            total_reward += reward
            
            if done:
                # the episode ends so no next state
                next_state = jp.zeros(state.shape)
                t = max_steps
                
                print('Episode: {}'.format(ep),
                    'Total reward: {}'.format(total_reward),
                    'Training loss: {:.4f}'.format(loss),
                    'Explore P: {:.4f}'.format(explore_p))
                rewards_list.append((ep, total_reward))
                
                # Add experience to memory
                memory.add((state, action, reward, next_state))
                
                # Start new episode
                env.reset()
                # Take one random step to get the pole and cart moving
                state, reward, done, _ = env.step(env.action_space.sample())

            else:
                # Add experience to memory
                memory.add((state, action, reward, next_state))
                state = next_state
                t += 1
            
            # Sample mini-batch from memory
            batch = memory.sample(rng, batch_size)
            states = jp.array([each[0] for each in batch])
            actions = one_hot(jp.array([each[1] for each in batch]), 2)
            rewards = jp.array([each[2] for each in batch])
            next_states = jp.array([each[3] for each in batch])
            
            # Train network
            current_Qs = mainQN.ts.apply_fn(current_params, next_states)
            target_Qs = mainQN.ts.apply_fn(target_params, next_states)
            
            # Set target_Qs to 0 for states where episode ends
            episode_ends = (next_states == jp.zeros(states[0].shape)).all(axis=1)
            new_target_Qs = index_update(target_Qs, index[episode_ends], (0, 0))
            target_Qs = new_target_Qs
            
            max_current_Qs_indices = jp.argmax(current_Qs, axis=-1)
            targets = rewards + gamma * jp.take_along_axis(target_Qs, max_current_Qs_indices[..., None], axis=-1)
            # print(states.shape, targets.shape, targets)
            mainQN.ts, loss = mainQN.train_fn(mainQN.ts, states, actions, targets)
            current_params = mainQN.ts.params

        if ep % update_target_every == 0:
            target_params = copy.deepcopy(mainQN.ts.params)
            print('***** Updated target QNetwork *****')

    return rewards_list

def plot_scores(rewards_list):
    def running_mean(x, N):
        cumsum = np.cumsum(np.insert(x, 0, 0)) 
        return (cumsum[N:] - cumsum[:-N]) / N
    eps, rews = np.array(rewards_list).T
    smoothed_rews = running_mean(rews, 10)
    plt.plot(eps[-len(smoothed_rews):], smoothed_rews)
    plt.plot(eps, rews, color='grey', alpha=0.3)
    plt.xlabel('Episode')
    plt.ylabel('Total Reward')
    plt.show()

In [8]:
def main():
    seed = 0
    env = gym.make('CartPole-v0')
    env.seed(seed)
    env.action_space.seed(seed)
    print('observation space:', env.observation_space)
    print('action space:', env.action_space.n)
    rng = jax.random.PRNGKey(seed)
    mainQN = QNetwork(rng, env, name='main', hidden_size=hidden_size, learning_rate=learning_rate)
    rewards_list = train(rng, env, mainQN)
    plot_scores(rewards_list)

if __name__ == '__main__':
    main()

observation space: Box(-3.4028234663852886e+38, 3.4028234663852886e+38, (4,), float32)
action space: 2
Episode: 1 Total reward: 31.0 Training loss: 0.8628 Explore P: 0.9969
Episode: 2 Total reward: 54.0 Training loss: 0.2319 Explore P: 0.9916
Episode: 3 Total reward: 28.0 Training loss: 0.3584 Explore P: 0.9889
Episode: 4 Total reward: 21.0 Training loss: 0.1620 Explore P: 0.9868
Episode: 5 Total reward: 8.0 Training loss: 0.2605 Explore P: 0.9860
Episode: 6 Total reward: 13.0 Training loss: 0.4306 Explore P: 0.9848
Episode: 7 Total reward: 20.0 Training loss: 0.2264 Explore P: 0.9828
Episode: 8 Total reward: 35.0 Training loss: 0.1583 Explore P: 0.9794
Episode: 9 Total reward: 27.0 Training loss: 0.1401 Explore P: 0.9768
Episode: 10 Total reward: 22.0 Training loss: 0.1885 Explore P: 0.9747
Episode: 11 Total reward: 15.0 Training loss: 0.2600 Explore P: 0.9732
Episode: 12 Total reward: 28.0 Training loss: 0.1787 Explore P: 0.9705
Episode: 13 Total reward: 10.0 Training loss: 0.2290 Ex

KeyboardInterrupt: ignored