In [1]:
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import optax
import haiku as hk

from jax import random, lax, jit, vmap, pmap
from functools import partial
from jax_tqdm import loop_tqdm

import sys

sys.path.append("../../../")

from src import CartPole, DQN, EpsilonGreedy, UniformReplayBuffer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
SEED = 2
DISCOUNT = 0.9
LEARNING_RATE = 0.1
N_ACTIONS = 2
NEURONS_PER_LAYER = [128, 256, N_ACTIONS]
BUFFER_SIZE = 512
BATCH_SIZE = 32
TIME_STEPS = 100_000
STATE_SHAPE = 4
LEARNING_RATE = 1e-2
EPSILON = 1e-2

In [3]:
buffer = {
    "states": jnp.empty((BUFFER_SIZE, STATE_SHAPE), dtype=jnp.float32),
    "actions": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    "rewards": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    "next_states": jnp.empty((BUFFER_SIZE, STATE_SHAPE), dtype=jnp.float32),
    "dones": jnp.empty((BUFFER_SIZE,), dtype=jnp.bool_),
}
print(jax.tree_map(lambda x: x.shape, buffer))

{'actions': (512,), 'dones': (512,), 'next_states': (512, 4), 'rewards': (512,), 'states': (512, 4)}


In [4]:
key = random.PRNGKey(SEED)

env = CartPole()
policy = EpsilonGreedy(0.1)

@hk.transform
def model(x):
    mlp = hk.nets.MLP(output_sizes=NEURONS_PER_LAYER)
    return mlp(x)

replay_buffer = UniformReplayBuffer(BUFFER_SIZE, BATCH_SIZE)

model_params = model.init(key, jnp.zeros((4,)))
target_net_params = model.init(key, jnp.zeros((4,)))
optimizer = optax.adam(learning_rate=LEARNING_RATE)
opt_state = optimizer.init(model_params)

agent = DQN(DISCOUNT, LEARNING_RATE, model, EPSILON)

In [5]:
jax.tree_map(lambda x: x.shape, model_params)

{'mlp/~/linear_0': {'b': (128,), 'w': (4, 128)},
 'mlp/~/linear_1': {'b': (256,), 'w': (128, 256)},
 'mlp/~/linear_2': {'b': (2,), 'w': (256, 2)}}

In [6]:
jax.tree_map(lambda x: x.shape, opt_state)

(ScaleByAdamState(count=(), mu={'mlp/~/linear_0': {'b': (128,), 'w': (4, 128)}, 'mlp/~/linear_1': {'b': (256,), 'w': (128, 256)}, 'mlp/~/linear_2': {'b': (2,), 'w': (256, 2)}}, nu={'mlp/~/linear_0': {'b': (128,), 'w': (4, 128)}, 'mlp/~/linear_1': {'b': (256,), 'w': (128, 256)}, 'mlp/~/linear_2': {'b': (2,), 'w': (256, 2)}}),
 EmptyState())

In [7]:
key = random.PRNGKey(0)
exp = (random.normal(key, (4,)), 1, 1, random.normal(key, (4,)), False)

In [8]:
idx, n_experiences = 0, 0
for _ in range(100):
    exp = (random.normal(key, (4,)), 1, 1, random.normal(key, (4,)), False)
    buffer, idx, n_experiences = replay_buffer.add(buffer, exp, idx, n_experiences)
# jax.tree_map(lambda x: x.shape, replay_buffer.sample(key, buffer, n_experiences))[:3]

In [9]:
init_key = random.split(key)[0]
model_params = model.init(init_key, random.normal(init_key, (4,)))

In [10]:
env_state, obs = env.reset(key)
env.step(env_state, jnp.array([1]))

((Array([ 0.04653214, -0.02748411,  0.01330299, -0.02036182], dtype=float32),
  Array([2718843009, 1272950319], dtype=uint32)),
 Array([ 0.04653214, -0.02748411,  0.01330299, -0.02036182], dtype=float32),
 Array(0, dtype=int32),
 Array(False, dtype=bool))

In [20]:
key, subkey = random.split(key)
experiences = replay_buffer.sample(key, buffer, 10)
# agent.batch_loss_fn(model_params, target_net_params, **experiences)

In [31]:
# @partial(jit, static_argnames=("optimizer"))
# def update(
#     model_params: dict,
#     target_net_params: dict,
#     optimizer: optax.GradientTransformation,
#     optimizer_state: jnp.ndarray,
#     # loss_fn: Callable,
#     experiences: dict[str : jnp.ndarray],
# ):
#     @jit
#     def batch_loss_fn(
#         model_params, target_net_params, states, actions, next_states, dones, rewards
#     ):
#         @partial(vmap, in_axes=(None, None, 0, 0, 0, 0, 0))
#         def _loss_fn(
#             model_params, target_net_params, state, action, next_state, done, reward
#         ):
#             target = lax.cond(
#                 jnp.all(done is True),
#                 lambda _: 0.0,
#                 lambda _: 0.9
#                 * jnp.max(
#                     model.apply(target_net_params, None, next_state),
#                 ),
#                 operand=None,
#             )
#             prediction = model.apply(model_params, None, state)[action]
#             return jnp.square(reward + target - prediction)

#         return jnp.mean(
#             _loss_fn(
#                 model_params,
#                 target_net_params,
#                 states,
#                 actions,
#                 next_states,
#                 dones,
#                 rewards,
#             ),
#             axis=0,
#         )

#     grad_fn = jax.grad(batch_loss_fn)
#     grads = grad_fn(model_params, target_net_params, **experiences)
#     updates, optimizer_state = optimizer.update(grads, optimizer_state)
#     model_params = optax.apply_updates(model_params, updates)

#     return model_params, optimizer_state


agent.update(model_params, target_net_params, optimizer, opt_state, experiences)

({'mlp/~/linear_0': {'b': Array([ 0.        ,  0.        ,  0.        ,  0.        , -0.00999993,
           0.        ,  0.00999993,  0.00999993,  0.00999993,  0.        ,
           0.        ,  0.        ,  0.00999993,  0.        , -0.00999993,
           0.        ,  0.        ,  0.00999993,  0.        , -0.00999993,
          -0.00999993,  0.        ,  0.00999993,  0.00999993,  0.00999991,
           0.        ,  0.00999993, -0.00999993,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          -0.00999993, -0.00999993,  0.00999982,  0.        ,  0.        ,
           0.00999993,  0.        ,  0.        ,  0.        ,  0.        ,
           0.00999993,  0.        ,  0.00999993,  0.        , -0.00999993,
           0.00999993, -0.00999993,  0.        , -0.00999993, -0.00999993,
           0.00999993,  0.        ,  0.00999993,  0.        ,  0.        ,
          -0.00999993,  0.        , -0.00999993,  0.00999993,  0.        ,
  

In [None]:
assert

In [None]:
key, subkey = random.split(key)

experiences = replay_buffer.sample(key, buffer, 10)

grad_fn = jax.grad(agent.batch_loss_fn)
grads = grad_fn(
    model_params,
    target_net_params,
    **experiences
    # states,
    # actions,
    # next_states,
    # dones,
    # rewards,
)
# print(loss_val)
optimizer.update(grads, opt_state)

({'mlp/~/linear_0': {'b': Array([-0.        , -0.        , -0.        , -0.        , -0.00999993,
          -0.        ,  0.00999993,  0.00999993,  0.00999993, -0.        ,
          -0.        , -0.        ,  0.00999993, -0.        , -0.00999993,
          -0.        , -0.        ,  0.00999993, -0.        , -0.00999993,
          -0.00999993, -0.        ,  0.00999993,  0.00999993,  0.00999991,
          -0.        ,  0.00999993, -0.00999993, -0.        , -0.        ,
          -0.        , -0.        , -0.        , -0.        , -0.        ,
          -0.00999993, -0.00999993,  0.00999982, -0.        , -0.        ,
           0.00999993, -0.        , -0.        , -0.        , -0.        ,
           0.00999993, -0.        ,  0.00999993, -0.        , -0.00999993,
           0.00999993, -0.00999993, -0.        , -0.00999993, -0.00999993,
           0.00999993, -0.        ,  0.00999993, -0.        , -0.        ,
          -0.00999993, -0.        , -0.00999993,  0.00999993, -0.        ,
  