In [1]:
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import optax
import haiku as hk

from jax import random, lax, jit, vmap, pmap
from functools import partial
from jax_tqdm import loop_tqdm

import sys

sys.path.append("../../../")

from src import CartPole, DQN, EpsilonGreedy, UniformReplayBuffer, DeepRlRollout

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Env parameters
RANDOM_SEED = 0
N_ACTIONS = 2
STATE_SHAPE = 4

# Hyperparameters
DISCOUNT = 0.9
NEURONS_PER_LAYER = [64, 64, 64, N_ACTIONS]
TIMESTEPS = 10_000
TARGET_NET_UPDATE_FREQ = 30
BUFFER_SIZE = 512
BATCH_SIZE = 32
LEARNING_RATE = 0.001
EPSILON_START = 0.3
EPSILON_END = 0
DECAY_RATE = 1e-2

In [3]:
buffer_state = {
    "states": jnp.empty((BUFFER_SIZE, STATE_SHAPE), dtype=jnp.float32),
    "actions": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    "rewards": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    "next_states": jnp.empty((BUFFER_SIZE, STATE_SHAPE), dtype=jnp.float32),
    "dones": jnp.empty((BUFFER_SIZE,), dtype=jnp.bool_),
}
print(jax.tree_map(lambda x: x.shape, buffer_state))

{'actions': (512,), 'dones': (512,), 'next_states': (512, 4), 'rewards': (512,), 'states': (512, 4)}


In [4]:
model_key, target_key = vmap(random.PRNGKey)(jnp.arange(2) + RANDOM_SEED)
env = CartPole()


@hk.transform
def model(x):
    mlp = hk.nets.MLP(output_sizes=NEURONS_PER_LAYER)
    return mlp(x)


def inverse_scaling_decay(epsilon_start, epsilon_end, current_step, decay_rate):
    return epsilon_end + (epsilon_start - epsilon_end) / (1 + decay_rate * current_step)


replay_buffer = UniformReplayBuffer(BUFFER_SIZE, BATCH_SIZE)
model_params = model.init(model_key, jnp.zeros((STATE_SHAPE,)))
target_net_params = model.init(target_key, jnp.zeros((STATE_SHAPE,)))
optimizer = optax.adam(learning_rate=LEARNING_RATE)
optimizer_state = optimizer.init(model_params)
agent = DQN(model, DISCOUNT, LEARNING_RATE, N_ACTIONS)

In [5]:
px.line([inverse_scaling_decay(EPSILON_START, EPSILON_END, i, DECAY_RATE) for i in range(TIMESTEPS)], title="Epsilon Decay")

In [6]:
jax.tree_map(lambda x: x.shape, model_params)

{'mlp/~/linear_0': {'b': (64,), 'w': (4, 64)},
 'mlp/~/linear_1': {'b': (64,), 'w': (64, 64)},
 'mlp/~/linear_2': {'b': (64,), 'w': (64, 64)},
 'mlp/~/linear_3': {'b': (2,), 'w': (64, 2)}}

In [7]:
jax.tree_map(lambda x: x.shape, optimizer_state)

(ScaleByAdamState(count=(), mu={'mlp/~/linear_0': {'b': (64,), 'w': (4, 64)}, 'mlp/~/linear_1': {'b': (64,), 'w': (64, 64)}, 'mlp/~/linear_2': {'b': (64,), 'w': (64, 64)}, 'mlp/~/linear_3': {'b': (2,), 'w': (64, 2)}}, nu={'mlp/~/linear_0': {'b': (64,), 'w': (4, 64)}, 'mlp/~/linear_1': {'b': (64,), 'w': (64, 64)}, 'mlp/~/linear_2': {'b': (64,), 'w': (64, 64)}, 'mlp/~/linear_3': {'b': (2,), 'w': (64, 2)}}),
 EmptyState())

# **_Rollout_**

1. Init all variables and obtain:
   ```python
   val_init = (
        model_params,
        target_net_params,
        optimizer_state,
        buffer_state,
        action_key,
        buffer_key,
        env_state,
        all_obs,
        all_rewards,
        all_done,
        losses,
    )
   ```
2. for ``timesteps`` steps:
   1. Compute decayed epsilon
   2. ``action`` = agent.act
   3. ``new_state``, ``reward``, ``done`` env.step 
   4. add experience to replay buffer
   5. sample batch from replay buffer
   6. gradient descent on batch = agent.update
      * Every N steps, update target network
   7. Pack variables and continue


In [8]:
# initialize the replay buffer with random samples ?
init_key, action_key, buffer_key = vmap(random.PRNGKey)(jnp.arange(3) + RANDOM_SEED)
env_state, _ = env.reset(init_key)

for i in range(BUFFER_SIZE):
    state, _ = env_state
    epsilon = inverse_scaling_decay(EPSILON_START, EPSILON_END, i, DECAY_RATE)
    action, action_key = agent.act(action_key, model_params, state, epsilon)
    env_state, new_state, reward, done = env.step(env_state, action)
    experience = (state, action, reward, new_state, done)

    buffer_state = replay_buffer.add(buffer_state, experience, i)


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=int32 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



In [9]:
rollout_params = {
    "timesteps": TIMESTEPS,
    "random_seed": RANDOM_SEED,
    "target_net_update_freq": TARGET_NET_UPDATE_FREQ,
    "model": model,
    "optimizer": optimizer,
    "buffer_state": buffer_state,
    "agent": agent,
    "env": env,
    "replay_buffer": replay_buffer,
    "state_shape": STATE_SHAPE,
    "buffer_size": BUFFER_SIZE,
    "batch_size": BATCH_SIZE,
    "epsilon_decay_fn": inverse_scaling_decay,
    "epsilon_start": EPSILON_START,
    "epsilon_end": EPSILON_END,
    "decay_rate": DECAY_RATE,
}

out = DeepRlRollout(**rollout_params)

Running for 10,000 iterations: 100%|██████████| 10000/10000 [00:00<00:00, 14511.74it/s]


In [10]:
px.line(out["losses"], title="Loss during training")

In [11]:
df = pd.DataFrame(
    data={
        "episode":out["all_done"].cumsum(),
        "reward": out["all_rewards"],
    },
)
df["episode"] = df["episode"].shift().fillna(0)
px.bar(df.groupby("episode").agg("sum").tail(200), title="Reward Per Episode")

In [12]:
df.groupby("episode").agg("sum").max()

reward    68.0
dtype: float32

In [13]:
model.apply(out["target_net_params"], None, out["all_obs"][-1])

Array([9.275506, 8.210668], dtype=float32)

In [14]:
assert

SyntaxError: invalid syntax (2389114725.py, line 1)

In [None]:
for i in range(20):
    print("-"*10 + f"{i}" + "-"*10)
    state_init = out["all_obs"][i]
    key = random.PRNGKey(10)

    # print(model.apply(out["model_params"], None, state_init))
    state = state_init
    done=False
    while not done:
        action, key = agent.act(key, out["model_params"], state, 0)
        env_state = (state, key)

        env_state, state, reward, done = env.step(env_state, action)
        print(action, done)

----------0----------
0 False
0 True
----------1----------
0 True
----------2----------
1 False
1 False
1 False
1 False
1 False
1 True
----------3----------
1 False
1 False
1 False
0 True
----------4----------
0 False
0 False
0 True
----------5----------
0 False
0 True
----------6----------
0 True
----------7----------
1 False
1 False
1 False
1 False
1 False
1 False
1 False
1 True
----------8----------
1 False
1 False
1 False
0 True
----------9----------
1 False
0 False
0 True
----------10----------
0 False
0 True
----------11----------
0 True
----------12----------
1 False
1 False
1 False
1 False
1 False
0 True
----------13----------
1 False
1 False
0 False
0 True
----------14----------
1 False
0 False
0 True
----------15----------
0 True
----------16----------
1 False
0 False
0 False
1 False
0 False
1 False
1 False
1 True
----------17----------
1 False
1 False
0 False
1 False
1 False
1 False
0 True
----------18----------
1 False
1 False
1 False
0 True
----------19----------
1 False
0