# DQN implementation in JAX

I will implement the DQN algorithm as proposed by the following paper [1]. I will try to make the model as general as possible, so that it can be used for any type of problem. The model will be implemented in JAX, so that fast training and testing will be possible with the model.

**References:**

[1] M Roderick et al. 2017, Implementing the Deep Q-Network

In [43]:
# Import jax
import optax
import jax.numpy as jnp
from jax.nn import relu
from jax import grad, value_and_grad, jit
from jax.random import normal, uniform, PRNGKey, split
from functools import partial

In [13]:
# Neural network hard-coded version for Inverted Pendulum case

# class InvPendulumNN:
#     def __init__(self, in_size=5, out_size=1, hidden_size=10, a_func=relu, seed=42):
#         prng_key = PRNGKey(seed)
#         self.W1 = normal(prng_key,shape=(hidden_size,in_size))
#         self.W2 = normal(prng_key,shape=(out_size,hidden_size))
#         self.params = {"W1": self.W1, "W2": self.W2}
#         self.dW1 = jnp.zeros_like(self.W1)
#         self.dW2 = jnp.zeros_like(self.W2)
#         self.a_funcs = [a_func, a_func]
#     def predict_loss(self, params, x):
#         x_h = self.a_funcs[0](params["W1"] @ x)
#         return self.a_funcs[1](params["W2"] @ x_h)
#     def predict(self, x):
#         x_h = self.a_funcs[0](self.params["W1"] @ x)
#         return self.a_funcs[1](self.params["W2"] @ x_h)
#     def loss(self, params, batch):
#         x,y = batch
#         return jnp.square(y - self.predict_loss(params, x)).mean()
#     def parameters(self):
#         return self.params
#     def set_parameters(self, params):
#         self.params = params

prng_key = PRNGKey(42)

class InvPendulumNNv2:
    def predict(params, x):
        x_h = relu(params["W1"] @ x)
        return relu(params["W2"] @ x_h)
    def loss(params, batch):
        x,y = batch
        return jnp.square(y - InvPendulumNNv2.predict(params, x)).mean()
    def generate_params(in_size=5, hidden_size=10, out_size=1):
        return {
            "W1":normal(prng_key,shape=(hidden_size,in_size)),
            "W2":normal(prng_key,shape=(out_size,hidden_size))
        }

In [14]:
# This cell block contains the general update mechanics for the parameters of a model
def update(parameters, batch, opt_state, optimizer=optax.adam, loss_func=InvPendulumNNv2.loss):
    params = parameters
    loss, grads = value_and_grad(loss_func)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state)
    return optax.apply_updates(params, updates), opt_state

# Assumes the following data structure: data = (num_samples, batch_size)
# Should only be given a single mini-batch of data to train on at the time
partial(jit, static_argnums=(3))
def train(X_train, y_train, params, lr, optimizer=optax.adam):
    optimizer = optimizer(lr)
    opt_state = optimizer.init(params)
    for i in range(X_train.shape[1]):
        params, opt_state = update(params, (X_train[:,i], y_train[:,i]), opt_state, optimizer=optimizer)
    return params

In [26]:
# Example to test the train method
X_train = jnp.array([[1.,-1.2,-1.], [2.,3.2,-2.], [3.,6.1,-3.], [4.,1.,-4.], [-5.,-3., -5.]])
y_train = jnp.array([[10., 6., 3.]])
print(X_train.shape)
print(y_train.shape)
params = InvPendulumNNv2.generate_params(hidden_size=21)
print(InvPendulumNNv2.predict(params, X_train))
params = train(X_train, y_train, params, 1e-3)
print(InvPendulumNNv2.predict(params, X_train))

(5, 3)
(1, 3)
[[ 0.46015662  0.         30.585041  ]]
[[ 0.7819718  0.        30.496143 ]]


### Notes

Okay so far I've figured out how to use the gradient method of jax as well as how to use the parameters and an optimizer to update the model iteratively over episodes. Next step will be to check it's performance one sample data from the OpenAI gym model. We will sample an episode of 500 steps in the environment, and use the retrieved data, to check if we can learn Q-values with this neural network. If we are able to do so, the next step will be to actually implement the DQN algorithm fully, with the neural network as a component of it. 

## Retrieve Inverted Pendulum Environment

In [30]:
import gymnasium as gym
env = gym.make('InvertedPendulum-v4')
env.reset()
env.step(jnp.array([-2]))

(array([-0.01079811,  0.03859825, -0.67176235,  1.52850572]),
 1.0,
 False,
 False,
 {})

## The full DQN Algorithm for the Inverted Pendulum Environment

In [133]:
def get_action(state, Q_params, key, eta=0.1, val_range=[-3,3], action_res=1000):
    key, *subkeys = split(key, num=3) # Make a different key w.r.t the provided key (otherwise no random numbers are provided)
    choice = uniform(subkeys[0], minval=0, maxval=1)
    if choice <= eta:
        return jnp.array([uniform(subkeys[1], minval=val_range[0], maxval=val_range[1])])
    else: 
        state = jnp.expand_dims(state, axis=0) # Needed to ensure correct concatenation
        states = jnp.repeat(state, action_res, axis=0)
        actions = jnp.linspace(val_range[0], val_range[1], action_res)
        actions = jnp.expand_dims(actions, axis=1) # Needed to ensure correct concatenation
        nn_input = jnp.concatenate((states, actions), axis=1).transpose() # With transpose we ensure (num_samples, batch_size) shape required for the NN
        print(nn_input.shape)
        ys = InvPendulumNNv2.predict(Q_params, nn_input)
        return actions[jnp.argmax(ys)] # Gets the best action out of our discretized space run through the network

def add_to_stack(item, stack, stack_size=100):
    if(len(stack) >= stack_size):
        stack.pop(0)
    stack.append(item)

def DQN(env, num_episodes, lr=1e-3, memory_size=100):
    prng_key = PRNGKey(42)
    D = list()
    Q_params = InvPendulumNNv2.generate_params()
    Q_params_target = InvPendulumNNv2.generate_params() # Same rng key as Q_params, so weights are the same :)
    for e in range(num_episodes):
        e_done = False
        state,info = env.reset()
        while not e_done:
            # Do episode stuf
            action = get_action(state, Q_params, prng_key)
            next_state, r, terminated, truncated = env.step(action)
            add_to_stack(D, (state, action, r, next_state), stack_size=memory_size)
            state = next_state
            
    return Q_params # Make sure to return the weights of our model
        
     

In [136]:
# Test action function
env.reset()
state = env.step(jnp.array([-2]))[0] # Only need the state values of the step
Q_params = InvPendulumNNv2.generate_params()
prng_key = PRNGKey(42)
print(get_action(state, Q_params, prng_key, eta=1))

[-2.335619]
