In [1]:
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px

from jax import random, lax, jit, vmap, pmap
from tqdm.auto import tqdm
from functools import partial
from jax_tqdm import loop_tqdm
import plotly.graph_objects as go
from plotly.subplots import make_subplots


from src import GridWorld, Q_learning, EpsilonGreedy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Experiment configuration ------------------------------------------------
SEED = 2

# Grid geometry and the start / goal cells handed to GridWorld.
GRID_SIZE = jnp.array([8, 12])
# NOTE(review): INITIAL_STATE has the same value as GRID_SIZE. If GridWorld
# uses 0-based cell indices this start lies outside the grid -- confirm how
# GridWorld treats it.
INITIAL_STATE = jnp.array([8, 12])
GOAL_STATE = jnp.array([0, 0])

# Sizes of the state and action spaces.
N_STATES = jnp.prod(GRID_SIZE)  # product of the two grid dimensions
N_ACTIONS = 4

# Q-learning hyper-parameters.
DISCOUNT = 0.9
LEARNING_RATE = 0.1

key = random.PRNGKey(SEED)

# Environment, behaviour policy and learner shared by the cells below.
env = GridWorld(INITIAL_STATE, GOAL_STATE, GRID_SIZE)
policy = EpsilonGreedy(0.1)  # presumably epsilon, the exploration rate -- confirm
agent = Q_learning(key, N_STATES, N_ACTIONS, DISCOUNT, LEARNING_RATE, policy)



No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


def train_n_episodes(n_steps=100):
    """Train one tabular Q-learning agent in ``env`` and plot the outcome.

    Args:
        n_steps: Number of episodes to run (the outer loop count). The name is
            kept for backward compatibility; it is *not* a step budget.

    Side effects:
        Prints the mean Q-value and the total number of environment steps,
        then shows a heatmap of the maximal Q-value per cell and a line plot
        of the number of steps taken in each episode.
    """
    key = random.PRNGKey(0)
    q_values = jnp.zeros([GRID_SIZE[0], GRID_SIZE[1], N_ACTIONS], dtype=jnp.float32)
    # NOTE(review): the environment is reset only once, before the first
    # episode. Later episodes continue from whatever env.step returned with
    # the terminal transition -- confirm that env.step handles the reset.
    env_state, _ = env.reset(key)

    steps = []  # number of environment steps taken in each episode

    for _ in tqdm(range(n_steps)):
        done = False
        # Separate counter so the `n_steps` argument is not overwritten.
        episode_steps = 0

        while not done:
            state, _ = env_state
            action, key = policy(key, N_ACTIONS, state, q_values)
            env_state, obs, reward, done = env.step(env_state, action)

            q_values = agent.update(state, action, reward, done, obs, q_values)
            episode_steps += 1
        steps.append(episode_steps)

    print(f"Average Q-values: {q_values.mean():.3f}")
    print(f"Total Number of Steps : {sum(steps)}")
    fig = px.imshow(
        pd.DataFrame(jnp.max(q_values, axis=2)).round(2),
        title="Maximal Q-value for each state",
    )
    fig.show()
    fig = px.line(steps)
    fig.show()


train_n_episodes(1000)

In [4]:
# Batched, jit-compiled versions of the environment, agent and policy
# functions. Every vmap uses the same named axis, "batch_axis".

# Reset: outputs ((env_state pair), obs) are stacked along axis 0.
v_reset = jit(
    vmap(env.reset, out_axes=((0, 0), 0), axis_name="batch_axis")
)

# Step: inputs ((env_state pair), action) are mapped over axis 0;
# outputs ((env_state pair), obs, reward, done) are stacked along axis 0.
v_step = jit(
    vmap(
        env.step,
        in_axes=((0, 0), 0),
        out_axes=((0, 0), 0, 0, 0),
        axis_name="batch_axis",
    )
)

# Update: the first five arguments are mapped over axis 0, while the sixth
# (the Q-values at the call sites in this notebook) is mapped over its last
# axis. The result is stacked along the last axis as well.
v_update = jit(
    vmap(
        agent.update,
        in_axes=(0, 0, 0, 0, 0, -1),
        out_axes=-1,
        axis_name="batch_axis",
    )
)

# Policy: arguments are (keys, n_actions, state, q_values). n_actions is
# broadcast (in_axes None) and marked static for jit; q_values is mapped
# over its last axis.
v_policy = jit(
    vmap(
        policy.call,
        in_axes=(0, None, 0, -1),
        axis_name="batch_axis",
    ),
    static_argnums=(1,),
)

In [17]:
def plot_performances(all_q_values: jnp.array, time_steps, n_env):
    """Show the final Q-values as a 3D surface next to a heatmap.

    The last entry along the first axis of ``all_q_values`` is averaged over
    its axes 2 and 3. For the array returned by ``rollout`` in this notebook
    those are the action and environment axes, which leaves one mean value
    per grid cell.

    Args:
        all_q_values: Stacked Q-value tables; only ``all_q_values[-1]`` is used.
        time_steps: Number of training steps, used in the figure title only.
        n_env: Number of agents/environments, used in the figure title only.
    """
    final_q_values = all_q_values[-1]
    mean_q = pd.DataFrame(jnp.mean(final_q_values, axis=[2, 3]))

    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=("Q-values as a surface", "Q-values as a heatmap"),
        column_widths=[0.5, 0.5],
        specs=[[{'type': 'surface'}, {'type': 'heatmap'}]],
    )

    # Left panel: surface. Right panel: heatmap of the same values.
    fig.add_trace(go.Surface(z=mean_q), row=1, col=1)
    fig.add_trace(go.Heatmap(z=mean_q), row=1, col=2)

    fig.update_layout(
        title_text=f"Q-values averaged over {n_env} agents, {time_steps} individual training steps",
        height=600,
    )
    # Reverse the heatmap's y axis (right-hand panel only).
    fig.update_yaxes(autorange="reversed", row=1, col=2)

    fig.show()

In [20]:
# Settings for the batched run: number of parallel environments and number of
# loop iterations passed to the rollout below.
N_ENV = 30
TIME_STEPS = 100_000

# One PRNG key per environment, all derived from the notebook-wide SEED.
key = random.PRNGKey(SEED)
keys = random.split(key, num=N_ENV)

@loop_tqdm(TIME_STEPS)
def fori_body(i:int, val:tuple):
    """Advance every environment by one step and record the results at index ``i``.

    Args:
        i: Current loop index.
        val: Carry tuple ``(env_states, action_keys, all_obs, all_rewards,
            all_done, all_q_values)``.

    Returns:
        The carry tuple in the same order, with slot ``i`` of the observation,
        reward and done buffers filled and slot ``i + 1`` of the Q-value
        buffer set to the updated Q-values.
    """
    env_states, action_keys, obs_log, reward_log, done_log, q_log = val
    states, _ = env_states

    # Act with the Q-values stored for this iteration.
    q_now = q_log[i]
    actions, action_keys = v_policy(action_keys, N_ACTIONS, states, q_now)

    env_states, obs, rewards, done = v_step(env_states, actions)
    q_next = v_update(states, actions, rewards, done, obs, q_now)

    return (
        env_states,
        action_keys,
        obs_log.at[i].set(obs),
        reward_log.at[i].set(rewards),
        done_log.at[i].set(done),
        # The update result belongs to the next slot of the Q-value buffer.
        q_log.at[i + 1].set(q_next),
    )


def rollout(keys, timesteps, n_env):
    """Run ``fori_body`` for ``timesteps`` iterations over ``n_env`` environments.

    Args:
        keys: PRNG keys passed to ``v_reset`` (one per environment).
        timesteps: Number of ``lax.fori_loop`` iterations.
        n_env: Number of parallel environments; sizes the buffers and the
            action keys.

    Returns:
        ``(all_obs, all_reward, all_done, all_q_values)`` as left in the loop
        carry after the final iteration.

    NOTE(review): ``fori_body`` is decorated with ``loop_tqdm(TIME_STEPS)``,
    i.e. the module-level constant rather than ``timesteps`` -- confirm the
    two agree when calling with a different value.
    """
    env_states, _ = v_reset(keys)
    # Action-selection keys use a fixed seed, independent of `keys`.
    action_keys = random.split(random.PRNGKey(0), n_env)

    obs_log = jnp.zeros([timesteps, n_env, 2])
    reward_log = jnp.zeros([timesteps, n_env], dtype=jnp.int32)
    done_log = jnp.zeros([timesteps, n_env], dtype=jnp.bool_)
    # One extra slot along the first axis: iteration t writes its update to t + 1.
    q_log = jnp.zeros(
        [timesteps + 1, GRID_SIZE[0], GRID_SIZE[1], N_ACTIONS, n_env],
        dtype=jnp.float32,
    )

    carry = (env_states, action_keys, obs_log, reward_log, done_log, q_log)
    carry = lax.fori_loop(0, timesteps, fori_body, carry)
    _, _, obs_log, reward_log, done_log, q_log = carry

    return obs_log, reward_log, done_log, q_log

# Run the batched rollout, then plot the Q-values of the last time step.
all_obs, all_reward, all_done, all_q_values = rollout(
    keys=keys, timesteps=TIME_STEPS, n_env=N_ENV
)
plot_performances(all_q_values=all_q_values, time_steps=TIME_STEPS, n_env=N_ENV)

Running for 100,000 iterations: 100%|██████████| 100000/100000 [00:02<00:00, 46200.90it/s]


In [60]:
# Scratch check of the `.at[...].set(...)` update: the expression evaluates to
# an array with slot 0 filled with ones. The result is only displayed, it is
# not assigned to anything.
# NOTE(review): `q_t` is created in a cell that appears further down in this
# notebook, so this cell fails on a fresh top-to-bottom run.
q_t.at[0].set(
    jnp.ones([GRID_SIZE[0], GRID_SIZE[1], N_ACTIONS, N_ENV])
)

Array([[[[[1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.]],

         [[1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.]],

         [[1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.]],

         ...,

         [[1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.]],

         [[1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.]],

         [[1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 

In [58]:
# Scratch buffer: ten zero-filled float32 Q-value tables, with axes
# (slot, grid dim 0, grid dim 1, action, environment).
q_t = jnp.zeros(
    [10, GRID_SIZE[0], GRID_SIZE[1], N_ACTIONS, N_ENV],
    dtype=jnp.float32,
)


In [48]:
# Debug cell: call `fori_body` once, outside `lax.fori_loop`, on small buffers
# so that shape errors surface immediately.
DEBUG_STEPS = 10  # length of the first axis of the debug buffers

# Create the keys before they are used, so the cell does not depend on a
# `keys` value left over from another cell.
key = random.PRNGKey(SEED)
keys = random.split(key, N_ENV)

# v_reset returns ((env_state), obs); the observations are not needed here and
# must not be unpacked into `keys`.
env_states, _ = v_reset(keys)

# Buffer shapes mirror the ones allocated in `rollout` (observations carry a
# trailing dimension of 2).
all_obs = jnp.zeros([DEBUG_STEPS, N_ENV, 2])
all_rewards = jnp.zeros([DEBUG_STEPS, N_ENV], dtype=jnp.int32)
all_done = jnp.zeros([DEBUG_STEPS, N_ENV], dtype=jnp.bool_)
all_q_values = jnp.zeros(
    [DEBUG_STEPS, GRID_SIZE[0], GRID_SIZE[1], N_ACTIONS, N_ENV],
    dtype=jnp.float32,
)

dummy_val = (
    env_states, keys, all_obs, all_rewards, all_done, all_q_values
)  # initial carry, in the order fori_body unpacks it
i = 0  # index to debug; i + 1 must stay below DEBUG_STEPS
fori_body(i, dummy_val)

(10, 8, 8, 4)


ValueError: Incompatible shapes for broadcasting: (10, 8, 8, 4) and requested shape (8, 8, 4, 10)