In [1]:
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

from jax import random, lax, jit, vmap, pmap
from functools import partial
from jax_tqdm import loop_tqdm

import sys

sys.path.append("../../../")

from src import CartPole, DQN, EpsilonGreedy, MLP, UniformReplayBuffer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
SEED = 2
DISCOUNT = 0.9
LEARNING_RATE = 0.1
N_ACTIONS = 2
NEURONS_PER_LAYER = [4, 256, 1]
BUFFER_SIZE = 9
BATCH_SIZE = 5
TIME_STEPS = 100_000
STATE_SHAPE = 4

In [3]:
buffer = {
    "states": jnp.empty((BUFFER_SIZE, STATE_SHAPE), dtype=jnp.float32),
    "actions": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    "rewards": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    "next_states": jnp.empty((BUFFER_SIZE, STATE_SHAPE), dtype=jnp.float32),
    "dones": jnp.empty((BUFFER_SIZE,), dtype=jnp.bool_),
}
buffer

{'states': Array([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]], dtype=float32),
 'actions': Array([0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32),
 'rewards': Array([0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32),
 'next_states': Array([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]], dtype=float32),
 'dones': Array([False, False, False, False, False, False, False, False, False],      dtype=bool)}

In [4]:
key = random.PRNGKey(0)
exp = (random.normal(key, (4,)), 1, 1, random.normal(key, (4,)), False)

In [5]:
replay_buffer = UniformReplayBuffer(BUFFER_SIZE, BATCH_SIZE)

In [6]:
random.split(key)

Array([[4146024105,  967050713],
       [2718843009, 1272950319]], dtype=uint32)

In [7]:
idx = 0
for _ in range(20):
    exp = (random.normal(key, (4,)), 1, 1, random.normal(key, (4,)), False)
    buffer, idx = replay_buffer.add(buffer, exp, idx)
replay_buffer.sample(key, buffer)

[(Array(1, dtype=int32),
  Array(False, dtype=bool),
  Array([ 1.8160863 , -0.75488514,  0.33988908, -0.53483534], dtype=float32),
  Array(1, dtype=int32),
  Array([ 1.8160863 , -0.75488514,  0.33988908, -0.53483534], dtype=float32)),
 (Array(1, dtype=int32),
  Array(False, dtype=bool),
  Array([ 1.8160863 , -0.75488514,  0.33988908, -0.53483534], dtype=float32),
  Array(1, dtype=int32),
  Array([ 1.8160863 , -0.75488514,  0.33988908, -0.53483534], dtype=float32)),
 (Array(1, dtype=int32),
  Array(False, dtype=bool),
  Array([ 1.8160863 , -0.75488514,  0.33988908, -0.53483534], dtype=float32),
  Array(1, dtype=int32),
  Array([ 1.8160863 , -0.75488514,  0.33988908, -0.53483534], dtype=float32)),
 (Array(1, dtype=int32),
  Array(False, dtype=bool),
  Array([ 1.8160863 , -0.75488514,  0.33988908, -0.53483534], dtype=float32),
  Array(1, dtype=int32),
  Array([ 1.8160863 , -0.75488514,  0.33988908, -0.53483534], dtype=float32)),
 (Array(1, dtype=int32),
  Array(False, dtype=bool),
  Array

In [8]:
assert

SyntaxError: invalid syntax (2389114725.py, line 1)

In [None]:
@partial(vmap, in_axes=(0, None))
def sample_batch(indexes, buffer):
    return jax.tree_map(lambda buff: buff[indexes], buffer)

samples = sample_batch(jnp.arange(4), buffer)
[exp for exp in zip(*samples.values())]

[(Array(0, dtype=int32),
  Array(False, dtype=bool),
  Array([0., 0., 0., 0.], dtype=float32),
  Array(0., dtype=float32),
  Array([0., 0., 0., 0.], dtype=float32)),
 (Array(0, dtype=int32),
  Array(False, dtype=bool),
  Array([0., 0., 0., 0.], dtype=float32),
  Array(0., dtype=float32),
  Array([0., 0., 0., 0.], dtype=float32)),
 (Array(0, dtype=int32),
  Array(False, dtype=bool),
  Array([0., 0., 0., 0.], dtype=float32),
  Array(0., dtype=float32),
  Array([0., 0., 0., 0.], dtype=float32)),
 (Array(0, dtype=int32),
  Array(False, dtype=bool),
  Array([0., 0., 0., 0.], dtype=float32),
  Array(0., dtype=float32),
  Array([0., 0., 0., 0.], dtype=float32))]

In [None]:
key = random.PRNGKey(0)
state, action, reward, next_state, done = random.normal(key, (4,)), 1.0, 1, random.normal(key, (4,)), False
state, action, reward, next_state, done

(Array([ 1.8160863 , -0.75488514,  0.33988908, -0.53483534], dtype=float32),
 1.0,
 1,
 Array([ 1.8160863 , -0.75488514,  0.33988908, -0.53483534], dtype=float32),
 False)

In [None]:
assert

SyntaxError: invalid syntax (2389114725.py, line 1)

In [None]:
key = random.PRNGKey(SEED)

env = CartPole()
policy = EpsilonGreedy(0.1)
model = MLP(NEURONS_PER_LAYER)
agent = DQN(DISCOUNT, LEARNING_RATE, N_ACTIONS, model)
replay_buffer = UniformReplayBuffer(BUFFER_SIZE, BATCH_SIZE)

In [None]:
init_key = random.split(key)[0]
params = model.init(init_key, random.normal(init_key, (4,)))

In [None]:
env_state, obs = env.reset(key)
env.step(env_state, jnp.array([1]))

((Array([ 0.01658618, -0.03144887,  0.0064795 , -0.04463173], dtype=float32),
  Array([2425776485,  230565590], dtype=uint32)),
 Array([ 0.01658618, -0.03144887,  0.0064795 , -0.04463173], dtype=float32),
 Array(0, dtype=int32),
 Array(False, dtype=bool))