In [1]:
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import optax
import haiku as hk

from jax import random, lax, jit, vmap, pmap
from functools import partial
from jax_tqdm import loop_tqdm

import sys

sys.path.append("../../../")

from src import CartPole, DQN, EpsilonGreedy, UniformReplayBuffer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Experiment configuration -------------------------------------------
# Every tunable value lives here, each defined exactly once.
SEED = 2  # seed for the initial PRNG key
DISCOUNT = 0.9  # passed to DQN as its first argument
N_ACTIONS = 2  # width of the network's output layer
NEURONS_PER_LAYER = [128, 256, N_ACTIONS]  # MLP layer sizes, output layer last
BUFFER_SIZE = 512  # number of rows pre-allocated in the replay buffer
BATCH_SIZE = 32  # passed to UniformReplayBuffer and used for synthetic batches
TIME_STEPS = 100_000
STATE_SHAPE = 4  # length of one observation vector
# A second, earlier assignment (0.1) used to be shadowed by this one; only
# the effective value is kept.
LEARNING_RATE = 1e-2
EPSILON = 1e-2

In [3]:
# Pre-allocated replay storage: one fixed-size array per transition field,
# each with BUFFER_SIZE rows.
buffer = {
    "states": jnp.empty((BUFFER_SIZE, STATE_SHAPE), dtype=jnp.float32),
    "actions": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    "rewards": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    "next_states": jnp.empty((BUFFER_SIZE, STATE_SHAPE), dtype=jnp.float32),
    "dones": jnp.empty((BUFFER_SIZE,), dtype=jnp.bool_),
}
# `jax.tree_map` is deprecated; `jax.tree_util.tree_map` is the stable
# spelling and works on both old and new JAX releases.
print(jax.tree_util.tree_map(lambda x: x.shape, buffer))

{'actions': (512,), 'dones': (512,), 'next_states': (512, 4), 'rewards': (512,), 'states': (512, 4)}


In [5]:
key = random.PRNGKey(SEED)

env = CartPole()
# NOTE(review): the exploration rate is the literal 0.1, while the config
# constant EPSILON (1e-2) is passed to DQN below. Confirm whether these are
# meant to be the same quantity.
policy = EpsilonGreedy(0.1)


@hk.transform
def model(x):
    """Apply an MLP with layer sizes NEURONS_PER_LAYER to `x`.

    Args:
        x: network input; the init calls below use a vector of length
            STATE_SHAPE.

    Returns:
        The output of the final MLP layer (N_ACTIONS units).
    """
    mlp = hk.nets.MLP(output_sizes=NEURONS_PER_LAYER)
    return mlp(x)


replay_buffer = UniformReplayBuffer(BUFFER_SIZE, BATCH_SIZE)

# Dummy observation used only to trace parameter shapes; sized from the
# config constant instead of a hard-coded 4.
dummy_obs = jnp.zeros((STATE_SHAPE,))
model_params = model.init(key, dummy_obs)
# Same key and same input as above, so the target network starts with
# weights identical to the online network.
target_net_params = model.init(key, dummy_obs)
optimizer = optax.adam(learning_rate=LEARNING_RATE)
opt_state = optimizer.init(model_params)

agent = DQN(DISCOUNT, LEARNING_RATE, model, EPSILON)

In [6]:
# Show the shape of every parameter leaf. `jax.tree_map` is deprecated, so
# the stable `jax.tree_util.tree_map` is used instead.
jax.tree_util.tree_map(lambda x: x.shape, model_params)

{'mlp/~/linear_0': {'b': (128,), 'w': (4, 128)},
 'mlp/~/linear_1': {'b': (256,), 'w': (128, 256)},
 'mlp/~/linear_2': {'b': (2,), 'w': (256, 2)}}

In [7]:
# Show the shape of every leaf in the optimizer state. `jax.tree_map` is
# deprecated, so the stable `jax.tree_util.tree_map` is used instead.
jax.tree_util.tree_map(lambda x: x.shape, opt_state)

(ScaleByAdamState(count=(), mu={'mlp/~/linear_0': {'b': (128,), 'w': (4, 128)}, 'mlp/~/linear_1': {'b': (256,), 'w': (128, 256)}, 'mlp/~/linear_2': {'b': (2,), 'w': (256, 2)}}, nu={'mlp/~/linear_0': {'b': (128,), 'w': (4, 128)}, 'mlp/~/linear_1': {'b': (256,), 'w': (128, 256)}, 'mlp/~/linear_2': {'b': (2,), 'w': (256, 2)}}),
 EmptyState())

In [8]:
key = random.PRNGKey(0)
# One example transition laid out as (state, action, reward, next_state, done).
# Both state slots come from the same key and shape, so they hold the same
# values; the draw is therefore made once and reused.
sampled_state = random.normal(key, (4,))
exp = (sampled_state, 1, 1, sampled_state, False)

In [9]:
# Fill the replay buffer with 100 transitions, then inspect the per-field
# shapes of the first three sampled transitions.
idx, n_experiences = 0, 0
for _ in range(100):
    # NOTE(review): `key` is never split, so every transition added here has
    # the same values (and state == next_state). Split the key per iteration
    # if distinct experiences are wanted.
    exp = (random.normal(key, (4,)), 1, 1, random.normal(key, (4,)), False)
    buffer, idx, n_experiences = replay_buffer.add(buffer, exp, idx, n_experiences)
# `jax.tree_map` is deprecated; `jax.tree_util.tree_map` is the stable spelling.
jax.tree_util.tree_map(lambda x: x.shape, replay_buffer.sample(key, buffer, n_experiences))[:3]

[((), (), (4,), (), (4,)), ((), (), (4,), (), (4,)), ((), (), (4,), (), (4,))]

In [10]:
# Re-initialise the online network from a key derived from `key`; only the
# first of the two split keys is kept.
init_key, _ = random.split(key)
# The same key also generates the example input passed to `init`.
init_obs = random.normal(init_key, (4,))
model_params = model.init(init_key, init_obs)

In [11]:
# Sanity check of the environment API: reset, then take a single step with
# action 1 and display whatever `step` returns as the cell output.
env_state, obs = env.reset(key)
probe_action = jnp.array([1])
env.step(env_state, probe_action)

((Array([ 0.04653214, -0.02748411,  0.01330299, -0.02036182], dtype=float32),
  Array([2718843009, 1272950319], dtype=uint32)),
 Array([ 0.04653214, -0.02748411,  0.01330299, -0.02036182], dtype=float32),
 Array(0, dtype=int32),
 Array(False, dtype=bool))

In [12]:
key, subkey = random.split(key)

# Smoke test: differentiate the agent's batch loss with respect to the online
# parameters (first positional argument) on a synthetic batch.
# NOTE(review): every draw below uses the same `subkey`, so `states` equals
# `next_states`, and `actions`, `rewards` and `dones` share the same underlying
# integers. Split `subkey` per field if an independent batch is wanted.
grad_fn = jax.grad(agent.batch_loss_fn)
grads = grad_fn(
    model_params,
    target_net_params,
    states=random.normal(subkey, (BATCH_SIZE, STATE_SHAPE)),
    actions=random.randint(subkey, (BATCH_SIZE, 1), minval=0, maxval=N_ACTIONS),
    next_states=random.normal(subkey, (BATCH_SIZE, STATE_SHAPE)),
    dones=jnp.bool_(random.randint(subkey, (BATCH_SIZE, 1), minval=0, maxval=2)),
    rewards=random.randint(subkey, (BATCH_SIZE, 1), minval=0, maxval=2),
)
# The result is only displayed as the cell output; it is not assigned, so
# `model_params` and `opt_state` keep their previous values.
optimizer.update(grads, opt_state)

({'mlp/~/linear_0': {'b': Array([-0.00999993,  0.00999993, -0.00999993,  0.00999993,  0.0099999 ,
          -0.00999992, -0.00999993,  0.00999992, -0.00999993,  0.00999992,
           0.00999993,  0.00999992, -0.00999993, -0.00999993, -0.00999993,
           0.00999993, -0.00999993, -0.00999993, -0.00999992,  0.00999991,
          -0.00999993, -0.00999991, -0.00999993,  0.00999993, -0.00999993,
           0.00999992, -0.00999974,  0.00999992, -0.00999993, -0.00999993,
          -0.0099999 , -0.00999992, -0.00999991,  0.00999993,  0.00999992,
          -0.00999991, -0.00999993, -0.00999993, -0.00999993, -0.00999988,
          -0.00999991,  0.00999993,  0.00999993,  0.00999992, -0.00999988,
           0.00999993, -0.00999993,  0.00999993, -0.00999993, -0.00999992,
          -0.00999992, -0.00999993,  0.00999993, -0.00999989, -0.00999993,
          -0.00999991,  0.00999993,  0.00999993, -0.00999993,  0.00999991,
          -0.00999993,  0.00999993, -0.00999993,  0.00999992,  0.0099999 ,
  

In [13]:
# Draw a batch from the replay buffer and display it as the cell output.
# NOTE(review): the third argument is the literal 10, whereas the earlier
# shape-check call passed `n_experiences`; confirm this restriction is
# intentional.
replay_buffer.sample(key, buffer, 10)

[(Array(1, dtype=int32),
  Array(False, dtype=bool),
  Array([ 1.8160863 , -0.75488514,  0.33988908, -0.53483534], dtype=float32),
  Array(1, dtype=int32),
  Array([ 1.8160863 , -0.75488514,  0.33988908, -0.53483534], dtype=float32)),
 (Array(1, dtype=int32),
  Array(False, dtype=bool),
  Array([ 1.8160863 , -0.75488514,  0.33988908, -0.53483534], dtype=float32),
  Array(1, dtype=int32),
  Array([ 1.8160863 , -0.75488514,  0.33988908, -0.53483534], dtype=float32)),
 (Array(1, dtype=int32),
  Array(False, dtype=bool),
  Array([ 1.8160863 , -0.75488514,  0.33988908, -0.53483534], dtype=float32),
  Array(1, dtype=int32),
  Array([ 1.8160863 , -0.75488514,  0.33988908, -0.53483534], dtype=float32)),
 (Array(1, dtype=int32),
  Array(False, dtype=bool),
  Array([ 1.8160863 , -0.75488514,  0.33988908, -0.53483534], dtype=float32),
  Array(1, dtype=int32),
  Array([ 1.8160863 , -0.75488514,  0.33988908, -0.53483534], dtype=float32)),
 (Array(1, dtype=int32),
  Array(False, dtype=bool),
  Array