In [1]:
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import optax
import haiku as hk

from jax import random, lax, jit, vmap, pmap
from functools import partial
from jax_tqdm import loop_tqdm

import sys

sys.path.append("../../../")

from src import CartPole, DQN, EpsilonGreedy, UniformReplayBuffer, DeepRlRollout

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Environment ---
RANDOM_SEED = 0
N_ACTIONS = 2  # also used as the width of the network's last layer
STATE_SHAPE = 4  # length of a single state vector

# --- Network / optimisation ---
NEURONS_PER_LAYER = [64, 128, N_ACTIONS]
LEARNING_RATE = 1e-3
DISCOUNT = 0.9

# --- Training loop ---
TIMESTEPS = 100_000
TARGET_NET_UPDATE_FREQ = 10  # forwarded to the rollout as `target_net_update_freq`

# --- Replay buffer ---
BUFFER_SIZE = 1024
BATCH_SIZE = 64

# --- Exploration schedule (arguments of `inverse_scaling_decay`) ---
EPSILON_START = 0.3
EPSILON_END = 0
DECAY_RATE = 1e-3

In [3]:
# Pre-allocated replay storage: one array per transition field, each with
# BUFFER_SIZE rows. States are vectors of length STATE_SHAPE; the other
# fields hold one scalar per stored transition.
buffer_state = {
    "states": jnp.empty((BUFFER_SIZE, STATE_SHAPE), dtype=jnp.float32),
    "actions": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    "rewards": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    "next_states": jnp.empty((BUFFER_SIZE, STATE_SHAPE), dtype=jnp.float32),
    "dones": jnp.empty((BUFFER_SIZE,), dtype=jnp.bool_),
}
# `jax.tree_map` is a deprecated alias that newer JAX releases no longer
# provide; `jax.tree_util.tree_map` is available on every JAX version.
print(jax.tree_util.tree_map(lambda leaf: leaf.shape, buffer_state))

{'actions': (1024,), 'dones': (1024,), 'next_states': (1024, 4), 'rewards': (1024,), 'states': (1024, 4)}


In [4]:
# Build two PRNG keys from the consecutive integer seeds RANDOM_SEED and
# RANDOM_SEED + 1: one for each network initialisation.
seeds = jnp.arange(2) + RANDOM_SEED
model_key, target_key = vmap(random.PRNGKey)(seeds)

env = CartPole()
# NOTE(review): `policy` is not forwarded to the rollout in the cells shown
# here — confirm whether it is still needed.
policy = EpsilonGreedy(0.1)


@hk.transform
def model(x):
    """Apply an MLP whose layer widths are NEURONS_PER_LAYER to ``x``.

    Wrapped by ``hk.transform``, so callers use ``model.init`` /
    ``model.apply`` rather than calling this function directly.
    """
    return hk.nets.MLP(output_sizes=NEURONS_PER_LAYER)(x)


def inverse_scaling_decay(epsilon_start, epsilon_end, current_step, decay_rate):
    """Return epsilon for ``current_step`` under an inverse-scaling schedule.

    Computes ``epsilon_end + (epsilon_start - epsilon_end) / (1 + decay_rate * current_step)``.
    At step 0 this equals ``epsilon_start``; the gap to ``epsilon_end``
    shrinks as the step count grows (for a positive ``decay_rate``).
    """
    gap = epsilon_start - epsilon_end
    scale = 1 + decay_rate * current_step
    return epsilon_end + gap / scale



# Dummy input of the state's shape, used only to initialise the network parameters.
dummy_state = jnp.zeros((STATE_SHAPE,))
model_params = model.init(model_key, dummy_state)
target_net_params = model.init(target_key, dummy_state)

# Adam optimiser and its state, initialised from the online network's parameters.
optimizer = optax.adam(learning_rate=LEARNING_RATE)
optimizer_state = optimizer.init(model_params)

replay_buffer = UniformReplayBuffer(BUFFER_SIZE, BATCH_SIZE)
agent = DQN(DISCOUNT, LEARNING_RATE, model)

In [5]:
# Evaluate the exploration schedule at every training step and plot it.
epsilon_schedule = [
    inverse_scaling_decay(EPSILON_START, EPSILON_END, step, DECAY_RATE)
    for step in range(TIMESTEPS)
]
px.line(epsilon_schedule, title="Epsilon Decay")

In [6]:
# Shape of every parameter array in the network.
# `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
jax.tree_util.tree_map(lambda leaf: leaf.shape, model_params)

{'mlp/~/linear_0': {'b': (64,), 'w': (4, 64)},
 'mlp/~/linear_1': {'b': (128,), 'w': (64, 128)},
 'mlp/~/linear_2': {'b': (2,), 'w': (128, 2)}}

In [7]:
# Shape of every array held in the optimizer state.
# `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
jax.tree_util.tree_map(lambda leaf: leaf.shape, optimizer_state)

(ScaleByAdamState(count=(), mu={'mlp/~/linear_0': {'b': (64,), 'w': (4, 64)}, 'mlp/~/linear_1': {'b': (128,), 'w': (64, 128)}, 'mlp/~/linear_2': {'b': (2,), 'w': (128, 2)}}, nu={'mlp/~/linear_0': {'b': (64,), 'w': (4, 64)}, 'mlp/~/linear_1': {'b': (128,), 'w': (64, 128)}, 'mlp/~/linear_2': {'b': (2,), 'w': (128, 2)}}),
 EmptyState())

# **_Rollout_**

1. Initialize all variables and pack them into the initial loop value:
   ```python
   val_init = (
        model_params,
        target_net_params,
        optimizer_state,
        buffer_state,
        action_key,
        buffer_key,
        env_state,
        all_obs,
        all_rewards,
        all_done,
        losses,
    )
   ```
2. For ``timesteps`` steps:
   1. Compute the decayed epsilon
   2. ``action`` = agent.act
   3. ``new_state``, ``reward``, ``done`` = env.step
   4. Add the experience to the replay buffer
   5. Sample a batch from the replay buffer
   6. Run gradient descent on the batch = agent.update
      * Every N steps, update the target network
   7. Pack the variables and continue


In [8]:
# Everything the rollout needs, gathered as keyword arguments.
rollout_params = dict(
    # run length and seeding
    timesteps=TIMESTEPS,
    random_seed=RANDOM_SEED,
    target_net_update_freq=TARGET_NET_UPDATE_FREQ,
    # learning components
    model=model,
    optimizer=optimizer,
    buffer_state=buffer_state,
    agent=agent,
    env=env,
    replay_buffer=replay_buffer,
    state_shape=STATE_SHAPE,
    buffer_size=BUFFER_SIZE,
    # exploration schedule
    epsilon_decay_fn=inverse_scaling_decay,
    epsilon_start=EPSILON_START,
    epsilon_end=EPSILON_END,
    decay_rate=DECAY_RATE,
)

out = DeepRlRollout(**rollout_params)

Running for 100,000 iterations: 100%|██████████| 100000/100000 [00:10<00:00, 9185.72it/s]


In [9]:
# Plot the loss values returned by the rollout.
loss_history = out["losses"]
px.line(loss_history, title="Loss during training")

In [12]:
# One row per rollout step. The running count of `all_done` flags serves as
# an episode id for each step.
df = pd.DataFrame(
    data={
        "episode": out["all_done"].cumsum(),
        "reward": out["all_rewards"],
    },
)
# Shift the ids down by one step so that the step on which `all_done` is set
# stays in the episode it terminates. `fill_value=0` keeps the column
# integer-typed; `.shift().fillna(0)` would silently upcast the ids to float.
df["episode"] = df["episode"].shift(fill_value=0)
# Total reward per episode; only the last 1000 episodes are plotted.
px.bar(df.groupby("episode").sum().tail(1000), title="Reward Per Episode")

In [11]:
# Largest total reward obtained in any single episode.
df.groupby("episode").sum().max()

reward    136
dtype: int32