# DQN implementation in JAX

I will implement the DQN algorithm as described in [1]. I will try to keep the model general enough to be reused on other problems. The model is written in JAX, so that training and testing are fast.

**References:**

[1] M. Roderick et al., "Implementing the Deep Q-Network", 2017.

In [77]:
# Import JAX and Optax
import optax
import jax.numpy as jnp
from jax.nn import relu
from jax import grad, value_and_grad, jit
from jax.random import normal, PRNGKey
from functools import partial

In [78]:
# Neural network hard-coded version for Inverted Pendulum case

# class InvPendulumNN:
#     def __init__(self, in_size=5, out_size=1, hidden_size=10, a_func=relu, seed=42):
#         prng_key = PRNGKey(seed)
#         self.W1 = normal(prng_key,shape=(hidden_size,in_size))
#         self.W2 = normal(prng_key,shape=(out_size,hidden_size))
#         self.params = {"W1": self.W1, "W2": self.W2}
#         self.dW1 = jnp.zeros_like(self.W1)
#         self.dW2 = jnp.zeros_like(self.W2)
#         self.a_funcs = [a_func, a_func]
#     def predict_loss(self, params, x):
#         x_h = self.a_funcs[0](params["W1"] @ x)
#         return self.a_funcs[1](params["W2"] @ x_h)
#     def predict(self, x):
#         x_h = self.a_funcs[0](self.params["W1"] @ x)
#         return self.a_funcs[1](self.params["W2"] @ x_h)
#     def loss(self, params, batch):
#         x,y = batch
#         return jnp.square(y - self.predict_loss(params, x)).mean()
#     def parameters(self):
#         return self.params
#     def set_parameters(self, params):
#         self.params = params

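# Module-level key used by generate_params below
prng_key = PRNGKey(42)

class InvPendulumNNv2:
    """Stateless two-layer network; the parameters are passed in explicitly."""

    @staticmethod
    def predict(params, x):
        # x has shape (in_size, batch_size)
        x_h = relu(params["W1"] @ x)
        # ReLU on the output layer restricts predictions to non-negative values
        return relu(params["W2"] @ x_h)

    @staticmethod
    def loss(params, batch):
        # Mean squared error between the targets y and the predictions for x
        x, y = batch
        return jnp.square(y - InvPendulumNNv2.predict(params, x)).mean()

    @staticmethod
    def generate_params(in_size=5, hidden_size=10, out_size=1):
        # NOTE: W1 and W2 are both drawn from the same prng_key; JAX recommends
        # a fresh subkey (jax.random.split) for every random draw.
        return {
            "W1": normal(prng_key, shape=(hidden_size, in_size)),
            "W2": normal(prng_key, shape=(out_size, hidden_size)),
        }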
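
`generate_params` above draws both weight matrices from the same module-level key. In JAX, the same key with the same shape always yields the same samples, so the usual pattern is to split a key into one subkey per draw. A minimal sketch of that pattern, assuming a hypothetical variant `generate_params_split` that takes the key as an argument (this function is not part of the notebook):

```python
import jax.numpy as jnp
from jax.random import PRNGKey, normal, split

def generate_params_split(key, in_size=5, hidden_size=10, out_size=1):
    """Hypothetical variant of generate_params: one subkey per weight matrix."""
    k1, k2 = split(key)
    return {
        "W1": normal(k1, shape=(hidden_size, in_size)),
        "W2": normal(k2, shape=(out_size, hidden_size)),
    }

key = PRNGKey(42)
params = generate_params_split(key, hidden_size=21)
print(params["W1"].shape, params["W2"].shape)

# Reusing one key with the same shape reproduces the same samples ...
same_a = normal(key, shape=(3,))
same_b = normal(key, shape=(3,))

# ... whereas the two subkeys give different samples.
k1, k2 = split(key)
split_a = normal(k1, shape=(3,))
split_b = normal(k2, shape=(3,))
```

Passing the key in explicitly also makes the initialisation reproducible per call instead of depending on a global variable.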

In [79]:
# This cell contains the general update mechanics for the parameters of a model.
# `optimizer` must be an optax GradientTransformation such as optax.adam(lr),
# not the factory optax.adam itself, because only the former has .update().
def update(params, batch, opt_state, optimizer, loss_func=InvPendulumNNv2.loss):
    _, grads = value_and_grad(loss_func)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state)
    return optax.apply_updates(params, updates), opt_state

# Assumes the following data layout:
#   X_train = (num_batches, in_size, batch_size)
#   y_train = (num_batches, out_size, batch_size)
# lr and n_epochs are marked static so they stay plain Python values under jit.
@partial(jit, static_argnums=(3, 4))
def train(X_train, y_train, params, lr, n_epochs, optimizer=optax.adam):
    optimizer = optimizer(lr)
    for e in range(n_epochs):
        # NOTE: the optimizer state is re-created at the start of every epoch,
        # which resets Adam's moment estimates between epochs.
        opt_state = optimizer.init(params)
        for i in range(X_train.shape[0]):
            params, opt_state = update(params, (X_train[i, :, :], y_train[i, :]), opt_state, optimizer=optimizer)
    return params

In [124]:
# Example to test the train method
X_train = jnp.array([[[1.], [2.], [3.], [4.], [-5.]]])
y_train = jnp.array([[[10.]]])
print(X_train.shape)
print(y_train.shape)
params = InvPendulumNNv2.generate_params(hidden_size=21)
params = train(X_train, y_train, params, 2e-3, 30)
print(InvPendulumNNv2.predict(params, X_train))

(1, 5, 1)
(1, 1, 1)
[[[9.671792]]]


### Notes

So far I've figured out how to use JAX's gradient functions, and how to combine the parameters with an optimizer to update the model iteratively over epochs. The next step is to check its performance on sample data from an OpenAI Gym environment. We will sample an episode of 500 steps in the environment and use the collected data to check whether this neural network can learn Q-values. If it can, the step after that is to implement the full DQN algorithm, with the neural network as one of its components.
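
To fit Q-values on recorded transitions, the network needs regression targets. One common choice is the one-step Q-learning target, r + γ · max over a' of Q(s', a'), with the bootstrap term dropped at terminal states. The sketch below rests on assumptions that the notebook does not state: the 5 network inputs are taken to be a 4-dimensional state plus one action value, the action set is taken to be {0, 1}, and `q_value` and `td_targets` are hypothetical helper names. The toy weights are hand-picked so that the result can be checked by hand.

```python
import jax.numpy as jnp
from jax.nn import relu

def q_value(params, x):
    # Same two-layer ReLU structure as InvPendulumNNv2.predict; x is (5, batch)
    return relu(params["W2"] @ relu(params["W1"] @ x))

def td_targets(params, rewards, next_states, dones, actions=(0.0, 1.0), gamma=0.99):
    """One-step Q-learning targets.

    rewards, dones: shape (batch,); next_states: shape (state_dim, batch).
    The action set and gamma are assumptions made for this sketch.
    """
    batch_size = next_states.shape[1]
    q_next = []
    for a in actions:
        # Append the candidate action as an extra input row
        a_row = jnp.full((1, batch_size), a)
        x = jnp.concatenate([next_states, a_row], axis=0)
        q_next.append(q_value(params, x)[0])
    max_q_next = jnp.max(jnp.stack(q_next), axis=0)
    # No bootstrapping from terminal states (dones == 1)
    return rewards + gamma * (1.0 - dones) * max_q_next

# Toy weights: Q(s, a) = 2 * relu(a), so max over a in {0, 1} is 2
toy_params = {
    "W1": jnp.array([[0., 0., 0., 0., 1.]]),
    "W2": jnp.array([[2.]]),
}
rewards = jnp.array([1., 1.])
dones = jnp.array([0., 1.])
next_states = jnp.zeros((4, 2))

targets = td_targets(toy_params, rewards, next_states, dones)
print(targets)  # approximately [2.98, 1.0]
```

The resulting targets have shape `(batch,)`; reshaped to `(1, batch)` they could serve as the `y` half of a training batch for the loss defined earlier.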