In [1]:
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

from jax import random, lax, jit, vmap, pmap
from functools import partial
from jax_tqdm import loop_tqdm

import sys

sys.path.append("../../../")

from src import CartPole, DQN, EpsilonGreedy, MLP, UniformReplayBuffer, Experience

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
SEED = 2
DISCOUNT = 0.9
LEARNING_RATE = 0.1
N_ACTIONS = 2
NEURONS_PER_LAYER = [4, 256, 1]
BUFFER_SIZE = 9
TIME_STEPS = 100_000

In [3]:
key = random.PRNGKey(SEED)

env = CartPole()
policy = EpsilonGreedy(0.1)
model = MLP(NEURONS_PER_LAYER)
agent = DQN(DISCOUNT, LEARNING_RATE, N_ACTIONS, model)
replay_buffer = UniformReplayBuffer(BUFFER_SIZE)

In [4]:
exp = Experience(random.normal(key, (4,)), 0, 1, random.normal(key, (4,)), False)
for _ in range(10):
    print(replay_buffer.idx, replay_buffer.buffer.keys())
    replay_buffer.add(exp)

0 dict_keys([])
1 dict_keys([0])
2 dict_keys([0, 1])
3 dict_keys([0, 1, 2])
4 dict_keys([0, 1, 2, 3])
5 dict_keys([0, 1, 2, 3, 4])
6 dict_keys([0, 1, 2, 3, 4, 5])
7 dict_keys([0, 1, 2, 3, 4, 5, 6])
8 dict_keys([0, 1, 2, 3, 4, 5, 6, 7])
0 dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8])


In [5]:
replay_buffer.sample(key)

Experience(state=Array([ 0.42851502, -0.8945591 ,  0.16313784, -1.6101485 ], dtype=float32), action=0, reward=1, next_state=Array([ 0.42851502, -0.8945591 ,  0.16313784, -1.6101485 ], dtype=float32), done=False)

In [6]:
init_key = random.split(key)[0]
params = model.init(init_key, random.normal(init_key, (4,)))

In [7]:
env_state, obs = env.reset(key)
env.step(env_state, jnp.array([1]))

((Array([ 0.01658618, -0.03144887,  0.0064795 , -0.04463173], dtype=float32),
  Array([2425776485,  230565590], dtype=uint32)),
 Array([ 0.01658618, -0.03144887,  0.0064795 , -0.04463173], dtype=float32),
 Array(0, dtype=int32),
 Array(False, dtype=bool))