In [1]:
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import optax
import haiku as hk

from jax import random, lax, jit, vmap, pmap
from functools import partial
from jax_tqdm import loop_tqdm

import sys

sys.path.append("../../../")

from src import CartPole, DQN, EpsilonGreedy, UniformReplayBuffer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Reproducibility ---------------------------------------------------------
RANDOM_SEED = 0

# --- Environment (CartPole) --------------------------------------------------
N_ACTIONS = 2    # number of discrete actions (width of the Q-network output)
STATE_SHAPE = 4  # length of the observation vector

# --- Agent / training --------------------------------------------------------
DISCOUNT = 0.9
# Hidden-layer widths followed by the output layer (one value per action).
NEURONS_PER_LAYER = [64, 128, N_ACTIONS]
LEARNING_RATE = 0.01
EPSILON = 1e-2  # passed to DQN; NOTE(review): semantics defined in src — document
TARGET_NET_UPDATE_FREQ = 100  # steps between target-network syncs
TIME_STEPS = 100_000

# --- Replay buffer -----------------------------------------------------------
BUFFER_SIZE = 512
BATCH_SIZE = 32

In [3]:
# Pre-allocated replay-buffer storage: one array per transition field, each
# with BUFFER_SIZE rows. Contents are uninitialised until written by the buffer.
buffer_state = {
    "states": jnp.empty((BUFFER_SIZE, STATE_SHAPE), dtype=jnp.float32),
    "actions": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    "rewards": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    "next_states": jnp.empty((BUFFER_SIZE, STATE_SHAPE), dtype=jnp.float32),
    "dones": jnp.empty((BUFFER_SIZE,), dtype=jnp.bool_),
}
# `jax.tree_map` is deprecated (removed in recent JAX); use the tree_util API.
print(jax.tree_util.tree_map(lambda x: x.shape, buffer_state))

{'actions': (512,), 'dones': (512,), 'next_states': (512, 4), 'rewards': (512,), 'states': (512, 4)}


In [18]:
# One PRNG key per network, derived from consecutive seeds.
model_key, target_key = vmap(random.PRNGKey)(RANDOM_SEED + jnp.arange(2))

env = CartPole()
policy = EpsilonGreedy(0.1)
replay_buffer = UniformReplayBuffer(BUFFER_SIZE, BATCH_SIZE)


@hk.transform
def model(x):
    """Q-network: an MLP whose layer widths are given by NEURONS_PER_LAYER."""
    return hk.nets.MLP(output_sizes=NEURONS_PER_LAYER)(x)


# Initialise both networks from a dummy state to materialise their parameters.
model_params = model.init(model_key, jnp.zeros((STATE_SHAPE,)))
target_net_params = model.init(target_key, jnp.zeros((STATE_SHAPE,)))

optimizer = optax.adam(learning_rate=LEARNING_RATE)
optimizer_state = optimizer.init(model_params)

agent = DQN(DISCOUNT, LEARNING_RATE, model, EPSILON)

In [19]:
# Parameter shapes of the online network (`jax.tree_map` is deprecated).
jax.tree_util.tree_map(lambda x: x.shape, model_params)

{'mlp/~/linear_0': {'b': (64,), 'w': (4, 64)},
 'mlp/~/linear_1': {'b': (128,), 'w': (64, 128)},
 'mlp/~/linear_2': {'b': (2,), 'w': (128, 2)}}

In [20]:
# Shapes of the Adam optimizer state (`jax.tree_map` is deprecated).
jax.tree_util.tree_map(lambda x: x.shape, optimizer_state)

(ScaleByAdamState(count=(), mu={'mlp/~/linear_0': {'b': (64,), 'w': (4, 64)}, 'mlp/~/linear_1': {'b': (128,), 'w': (64, 128)}, 'mlp/~/linear_2': {'b': (2,), 'w': (128, 2)}}, nu={'mlp/~/linear_0': {'b': (64,), 'w': (4, 64)}, 'mlp/~/linear_1': {'b': (128,), 'w': (64, 128)}, 'mlp/~/linear_2': {'b': (2,), 'w': (128, 2)}}),
 EmptyState())

# **_Rollout_**

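1. Initialise the replay buffer, the environment and the online/target network parameters
2. For each of the `t` time steps:
   1. select an action with `agent.act`
   2. step the environment and add the resulting experience to the replay buffer
   3. sample a batch from the replay buffer
   4. update the online network with `agent.update`
   5. every `N` steps, copy the online parameters into the target network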


In [21]:
def rollout(
    timesteps: int,
    random_seed: int,
    target_net_update_freq: int,
    model: hk.Transformed,
    optimizer: optax.GradientTransformation,
    buffer_state: dict,
):
    """Train the DQN agent online for `timesteps` environment steps.

    Each step selects an action with ``agent.act``, steps ``env``, stores the
    transition in ``replay_buffer``, samples a batch, applies ``agent.update``
    and, every ``target_net_update_freq`` steps, copies the online parameters
    into the target network. The loop body runs inside ``lax.fori_loop``.

    Note: ``agent``, ``env``, ``replay_buffer``, ``BUFFER_SIZE`` and
    ``STATE_SHAPE`` are read from the notebook's global scope.

    Args:
        timesteps: number of environment steps to run.
        random_seed: base seed; PRNG keys are built from consecutive seeds.
        target_net_update_freq: period, in steps, of the target-network sync.
        model: Haiku-transformed Q-network.
        optimizer: optax optimizer handed to ``agent.update``.
        buffer_state: initial replay-buffer storage (dict of arrays).

    Returns:
        The final loop carry: (model_params, target_net_params,
        optimizer_state, buffer_state, action_key, buffer_key, env_state,
        all_obs, all_rewards, all_done, losses).
    """

    @loop_tqdm(timesteps)
    @jit
    def _fori_body(i: int, val: tuple):
        (
            model_params,
            target_net_params,
            optimizer_state,
            buffer_state,
            action_key,
            buffer_key,
            env_state,
            all_obs,
            all_rewards,
            all_done,
            losses
        ) = val

        state, _ = env_state
        action, action_key = agent.act(action_key, model_params, state)
        env_state, obs, reward, done = env.step(env_state, action)
        experience = (state, action, reward, obs, done)

        buffer_state = replay_buffer.add(buffer_state, experience, i)
        # After this add, `i + 1` transitions are stored (capped at capacity).
        # Using `i` would exclude the newest transition and give an empty range
        # at i == 0. Assumes `add` writes slot `i % BUFFER_SIZE` — TODO confirm.
        current_buffer_size = jnp.minimum(i + 1, BUFFER_SIZE)
        experiences_batch, buffer_key = replay_buffer.sample(
            buffer_key, buffer_state, current_buffer_size
        )

        model_params, optimizer_state, loss = agent.update(
            model_params,
            target_net_params,
            optimizer,
            optimizer_state,
            experiences_batch,
        )

        # Sync the target network only on multiples of the update period.
        # A bare `i % freq` is truthy on every *other* step and would copy
        # the online parameters almost every iteration.
        target_net_params = lax.cond(
            i % target_net_update_freq == 0,
            lambda _: model_params,
            lambda _: target_net_params,
            operand=None,
        )

        all_obs = all_obs.at[i].set(obs)
        all_rewards = all_rewards.at[i].set(reward)
        all_done = all_done.at[i].set(done)
        losses = losses.at[i].set(loss)

        val = (
            model_params,
            target_net_params,
            optimizer_state,
            buffer_state,
            action_key,
            buffer_key,
            env_state,
            all_obs,
            all_rewards,
            all_done,
            losses
        )

        return val

    # A fourth key is drawn for parameter init so `action_key` is never reused.
    # The first three keys are the same as with `jnp.arange(3)`.
    init_key, action_key, buffer_key, params_key = vmap(random.PRNGKey)(
        jnp.arange(4) + random_seed
    )
    env_state, _ = env.reset(init_key)
    all_obs = jnp.zeros([timesteps, STATE_SHAPE])
    all_rewards = jnp.zeros([timesteps], dtype=jnp.int32)
    all_done = jnp.zeros([timesteps], dtype=jnp.bool_)
    losses = jnp.zeros([timesteps], dtype=jnp.float32)

    model_params = model.init(params_key, jnp.zeros((STATE_SHAPE,)))
    # The target network starts as an exact copy of the online network.
    # JAX arrays are immutable, so sharing them is safe.
    target_net_params = model_params
    optimizer_state = optimizer.init(model_params)

    val_init = (
        model_params,
        target_net_params,
        optimizer_state,
        buffer_state,
        action_key,
        buffer_key,
        env_state,
        all_obs,
        all_rewards,
        all_done,
        losses,
    )

    return lax.fori_loop(0, timesteps, _fori_body, val_init)


out = rollout(
    TIME_STEPS,
    RANDOM_SEED,
    TARGET_NET_UPDATE_FREQ,
    model,
    optimizer,
    buffer_state,
)

Running for 100,000 iterations: 100%|██████████| 100000/100000 [00:07<00:00, 13229.86it/s]


In [24]:
# Shapes of every entry returned by `rollout` (`jax.tree_map` is deprecated).
jax.tree_util.tree_map(lambda x: x.shape, out)

({'mlp/~/linear_0': {'b': (64,), 'w': (4, 64)},
  'mlp/~/linear_1': {'b': (128,), 'w': (64, 128)},
  'mlp/~/linear_2': {'b': (2,), 'w': (128, 2)}},
 {'mlp/~/linear_0': {'b': (64,), 'w': (4, 64)},
  'mlp/~/linear_1': {'b': (128,), 'w': (64, 128)},
  'mlp/~/linear_2': {'b': (2,), 'w': (128, 2)}},
 (ScaleByAdamState(count=(), mu={'mlp/~/linear_0': {'b': (64,), 'w': (4, 64)}, 'mlp/~/linear_1': {'b': (128,), 'w': (64, 128)}, 'mlp/~/linear_2': {'b': (2,), 'w': (128, 2)}}, nu={'mlp/~/linear_0': {'b': (64,), 'w': (4, 64)}, 'mlp/~/linear_1': {'b': (128,), 'w': (64, 128)}, 'mlp/~/linear_2': {'b': (2,), 'w': (128, 2)}}),
  EmptyState()),
 {'actions': (512,),
  'dones': (512,),
  'next_states': (512, 4),
  'rewards': (512,),
  'states': (512, 4)},
 (2,),
 (2,),
 ((4,), (2,)),
 (100000, 4),
 (100000,),
 (100000,),
 (100000,))

In [None]:
# Names of the entries of the tuple returned by `rollout`, in carry order.
outputs = [
    "model_params",
    "target_net_params",
    "optimizer_state",
    "buffer_state",
    "action_key",
    "buffer_key",
    "env_state",
    "all_obs",
    "all_rewards",
    "all_done",
    "losses",
]
assert len(outputs) == len(out), "output names out of sync with rollout's carry"
# Access rollout results by name instead of position, e.g. d["losses"].
d = dict(zip(outputs, out))

In [22]:
# Training loss recorded at every time step (last entry of the rollout carry).
loss_df = pd.DataFrame({"time_step": range(len(out[-1])), "loss": out[-1]})
px.line(loss_df, x="time_step", y="loss", title="Training Loss Per Time Step")

In [23]:
# out[-2] is `all_done`, out[-3] is `all_rewards`. The running count of `done`
# flags labels episodes. Shifting it by one step keeps each terminal step in
# the episode it ends.
df = pd.DataFrame(data={"episode": out[-2].cumsum(), "reward": out[-3]})
df = df.assign(episode=df["episode"].shift().fillna(0))
px.bar(
    df.groupby("episode").agg("sum").tail(500),
    title="Reward Per Episode",
)

import math

def _reset(key):
    new_state = random.uniform(
        key,
        shape=(4,),
        minval=-reset_bounds,
        maxval=reset_bounds,
    )
    key, sub_key = random.split(key)

    return new_state, sub_key


def _reset_if_done(env_state, done):
    """Return a freshly reset env state when `done` is true, else `env_state`.

    The reset branch calls ``_reset`` with the key stored at ``env_state[1]``.
    The other branch returns ``env_state`` unchanged.
    """

    def _keep_current(_unused_key):
        return env_state

    # NOTE(review): lax.cond requires both branches to return the same pytree
    # structure/dtypes; assumes `_reset` returns a (state, key) pair like
    # `env_state`.
    return lax.cond(done, _reset, _keep_current, operand=env_state[1])

gravity = 9.8
masscart = 1.0
masspole = 0.1
total_mass = masspole + masscart
length = 0.5  # half the pole's length
polemass_length = masspole * length
force_mag = 10.0
tau = 0.02  # seconds between state updates
reset_bounds = 0.05

# Limits defining episode termination
x_limit = 2.4
theta_limit_rads = 12 * 2 * math.pi / 360

# Explicit debug inputs, so this cell does not depend on variables left over
# in the kernel from other cells: a freshly reset environment and action 1.
env_state, _ = env.reset(random.PRNGKey(RANDOM_SEED))
action = 1

state, key = env_state
x, x_dot, theta, theta_dot = state

# Action 1 pushes with +force_mag; any other action pushes with -force_mag.
force = jnp.where(action == 1, force_mag, -force_mag)
cos_theta, sin_theta = jnp.cos(theta), jnp.sin(theta)

temp = (force + polemass_length * jnp.square(theta_dot) * sin_theta) / total_mass
theta_accel = (gravity * sin_theta - cos_theta * temp) / (
    length * (4.0 / 3.0 - masspole * jnp.square(cos_theta) / total_mass)
)
x_accel = temp - polemass_length * theta_accel * cos_theta / total_mass

# Explicit Euler step: positions advance with the pre-update velocities.
x += tau * x_dot
x_dot += tau * x_accel
theta += tau * theta_dot  # was `theta += theta + ...`, which doubled theta
theta_dot += tau * theta_accel

new_state = jnp.array([x, x_dot, theta, theta_dot])
assert new_state.shape == (4,)

done = (
    (x < -x_limit)
    | (x > x_limit)
    | (theta > theta_limit_rads)
    | (theta < -theta_limit_rads)
)
reward = jnp.int32(jnp.invert(done))

print(done, reward, new_state, sep="\n")

env_state = new_state, key
_reset_if_done(env_state, done)