In [1]:
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import optax
import haiku as hk

from jax import random, lax, jit, vmap, pmap
from functools import partial
from jax_tqdm import loop_tqdm

import sys

sys.path.append("../../../")

from src import CartPole, DQN, EpsilonGreedy, UniformReplayBuffer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
SEED = 2
DISCOUNT = 0.9
LEARNING_RATE = 0.1
N_ACTIONS = 2
NEURONS_PER_LAYER = [128, 256, N_ACTIONS]
BUFFER_SIZE = 512
BATCH_SIZE = 32
TIME_STEPS = 100_000
STATE_SHAPE = 4
LEARNING_RATE = 1e-2
EPSILON = 1e-2

In [3]:
buffer_state = {
    "states": jnp.empty((BUFFER_SIZE, STATE_SHAPE), dtype=jnp.float32),
    "actions": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    "rewards": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    "next_states": jnp.empty((BUFFER_SIZE, STATE_SHAPE), dtype=jnp.float32),
    "dones": jnp.empty((BUFFER_SIZE,), dtype=jnp.bool_),
}
print(jax.tree_map(lambda x: x.shape, buffer_state))

{'actions': (512,), 'dones': (512,), 'next_states': (512, 4), 'rewards': (512,), 'states': (512, 4)}


In [4]:
action_key = random.PRNGKey(SEED)

env = CartPole()
policy = EpsilonGreedy(0.1)


@hk.transform
def model(x):
    mlp = hk.nets.MLP(output_sizes=NEURONS_PER_LAYER)
    return mlp(x)


replay_buffer = UniformReplayBuffer(BUFFER_SIZE, BATCH_SIZE)

model_params = model.init(action_key, jnp.zeros((STATE_SHAPE,)))
target_net_params = model.init(action_key, jnp.zeros((STATE_SHAPE,)))
optimizer = optax.adam(learning_rate=LEARNING_RATE)
optimizer_state = optimizer.init(model_params)

agent = DQN(DISCOUNT, LEARNING_RATE, model, EPSILON)

In [5]:
jax.tree_map(lambda x: x.shape, model_params)

{'mlp/~/linear_0': {'b': (128,), 'w': (4, 128)},
 'mlp/~/linear_1': {'b': (256,), 'w': (128, 256)},
 'mlp/~/linear_2': {'b': (2,), 'w': (256, 2)}}

In [6]:
jax.tree_map(lambda x: x.shape, optimizer_state)

(ScaleByAdamState(count=(), mu={'mlp/~/linear_0': {'b': (128,), 'w': (4, 128)}, 'mlp/~/linear_1': {'b': (256,), 'w': (128, 256)}, 'mlp/~/linear_2': {'b': (2,), 'w': (256, 2)}}, nu={'mlp/~/linear_0': {'b': (128,), 'w': (4, 128)}, 'mlp/~/linear_1': {'b': (256,), 'w': (128, 256)}, 'mlp/~/linear_2': {'b': (2,), 'w': (256, 2)}}),
 EmptyState())

# **_Rollout_**

1. Init replay buffer
2. for t steps:
   1. action = agent.act
   2. add experience to replay buffer
   3. sample batch from replay buffer
   4. agent.update
      1. every N steps, update target network


In [7]:
TIMESTEPS = 100
RANDOM_SEED = 0
TARGET_NET_UPDATE_FREQ = 10

In [19]:
def rollout(
    timesteps: int,
    random_seed: int,
    target_net_update_freq: int,
    model: hk.Transformed,
    optimizer: optax.GradientTransformation,
    buffer_state: dict,
):
    @loop_tqdm(timesteps)
    @jit
    def _fori_body(i: int, val: tuple):
        (
            model_params,
            target_net_params,
            optimizer_state,
            buffer_state,
            action_key,
            buffer_key,
            env_state,
            all_obs,
            all_rewards,
            all_done,
        ) = val

        state, _ = env_state
        action, action_key = agent.act(action_key, model_params, state)
        env_state, obs, reward, done = env.step(env_state, action)
        experience = (state, action, reward, obs, done)

        buffer_state = replay_buffer.add(buffer_state, experience, i)
        current_buffer_size = jnp.min(jnp.array([i, BUFFER_SIZE]))
        experiences_batch, buffer_key = replay_buffer.sample(
            buffer_key, buffer_state, current_buffer_size
        )

        model_params, optimizer_state = agent.update(
            model_params,
            target_net_params,
            optimizer,
            optimizer_state,
            experiences_batch,
        )

        target_net_params = lax.cond(
            i % target_net_update_freq,
            lambda _: model_params,
            lambda _: target_net_params,
            operand=None,
        )

        all_obs = all_obs.at[i].set(obs)
        all_rewards = all_rewards.at[i].set(reward)
        all_done = all_done.at[i].set(done)

        val = (
            model_params,
            target_net_params,
            optimizer_state,
            buffer_state,
            action_key,
            buffer_key,
            env_state,
            all_obs,
            all_rewards,
            all_done,
        )

        return val

    init_key, action_key, buffer_key = vmap(random.PRNGKey)(jnp.arange(3) + random_seed)
    env_state, _ = env.reset(init_key)
    all_obs = jnp.zeros([timesteps, STATE_SHAPE])
    all_rewards = jnp.zeros([timesteps], dtype=jnp.int32)
    all_done = jnp.zeros([timesteps], dtype=jnp.bool_)

    model_params = model.init(action_key, jnp.zeros((STATE_SHAPE,)))
    target_net_params = model.init(action_key, jnp.zeros((STATE_SHAPE,)))
    optimizer_state = optimizer.init(model_params)

    val_init = (
        model_params,
        target_net_params,
        optimizer_state,
        buffer_state,
        action_key,
        buffer_key,
        env_state,
        all_obs,
        all_rewards,
        all_done,
    )

    return lax.fori_loop(0, timesteps, _fori_body, val_init)


out = rollout(
    100_000,
    RANDOM_SEED,
    TARGET_NET_UPDATE_FREQ,
    model,
    optimizer,
    buffer_state,
)

Running for 100,000 iterations: 100%|██████████| 100000/100000 [00:38<00:00, 2609.12it/s]


In [20]:
out

({'mlp/~/linear_0': {'b': Array([ 0.        ,  0.        ,  0.        ,  0.        , -0.06005458,
           0.        ,  0.        , -0.04700883,  0.        , -0.06005441,
           0.        ,  0.        ,  0.0963307 , -0.06005458,  0.07454794,
           0.06851127,  0.09773249,  0.        ,  0.        ,  0.01649704,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        , -0.06005459,  0.        , -0.06005457,  0.09156305,
           0.        ,  0.        ,  0.        ,  0.        ,  0.09226978,
           0.10115878,  0.        ,  0.        ,  0.        , -0.06005459,
           0.        , -0.06005457, -0.05608438, -0.08283285,  0.        ,
           0.        ,  0.01948245,  0.        ,  0.        ,  0.        ,
          -0.0505572 ,  0.10953962,  0.        ,  0.        ,  0.        ,
           0.0863216 ,  0.        ,  0.08792508,  0.        ,  0.        ,
           0.        ,  0.        ,  0.09208603,  0.09083539,  0.00946203,
  