In [1]:
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px

from jax import random, lax, jit, vmap, pmap
from tqdm.auto import tqdm
from functools import partial

from src import GridWorld, Q_learning, EpsilonGreedy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
SEED = 2
GRID_DIM = jnp.array([8, 8])
INITIAL_STATE = jnp.array([8, 8])
GOAL_STATE = jnp.array([2, 0])
GRID_SIZE = jnp.array([8, 8])
N_STATES = jnp.prod(GRID_DIM)
N_ACTIONS = 4
DISCOUNT = 0.9
LEARNING_RATE = 0.1

key = random.PRNGKey(SEED)

env = GridWorld(INITIAL_STATE, GOAL_STATE, GRID_SIZE)
policy = EpsilonGreedy(0.1)
agent = Q_learning(
    key,
    N_STATES,
    N_ACTIONS,
    DISCOUNT,
    LEARNING_RATE,
    policy,
)

movements = jnp.array([[0, 1], [1, 0], [0, -1], [-1, 0]])


No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


```python
def train_n_steps(n_steps=1000):    
    
    key = random.PRNGKey(0)
    q_values = jnp.zeros([GRID_SIZE[0], GRID_SIZE[1], N_ACTIONS], dtype=jnp.float32)
    env_state, obs = env.reset(key) 
    done = False

    steps = []

    for _ in tqdm(range(n_steps)):
        done = False
        n_steps = 0

        while not done:
            state, _ = env_state
            action, key = policy(key, N_ACTIONS, state, q_values)
            movement = movements[action]
            env_state, obs, reward, done = env.step(env_state, movement)

            q_values = agent.update(state, action, reward, done, obs, q_values)
            n_steps+=1
        steps.append(n_steps)


    fig=px.imshow(pd.DataFrame(jnp.max(q_values, axis=2)).round(2), title="Maximal Q-value for each state")
    fig.show()
    fig=px.line(steps)
    fig.show()
    
train_n_steps()```

In [3]:
v_reset = jit(
            vmap(
                env.reset,
                out_axes=((0, 0), 0),  # ((env_state), obs)
                axis_name="batch_axis",
            )
        )

v_step = jit(
            vmap(
                env.step,
                in_axes=((0, 0), 0),  # ((env_state), action)
                out_axes=((0, 0), 0, 0, 0),  # ((env_state), obs, reward, done)
                axis_name="batch_axis",
            )
        )
v_update = jit(
            vmap(
                agent.update,
                in_axes=(0, 0, 0, 0, 0, -1),
                axis_name="batch_axis",
                ),
            )

v_policy = jit(
            vmap(
                policy.call,
                in_axes=(0, None, 0, -1),  # (keys, n_actions, state, q_values)
                axis_name="batch_axis",
                ),
                static_argnums=(1,),
            )

In [17]:
N_ENV = 10

key = random.PRNGKey(SEED)
keys = random.split(key, N_ENV)
action_keys = random.split(random.PRNGKey(0), N_ENV) 

env_states, obs = v_reset(keys)
states, keys = env_states
q_values = jnp.zeros([GRID_SIZE[0], GRID_SIZE[1], N_ACTIONS, N_ENV], dtype=jnp.float32)

actions, action_keys = v_policy(action_keys, N_ACTIONS, states, q_values)

env_state, obs, reward, done = v_step(env_states, actions)
states, keys = env_states

q_values = v_update(states, actions, reward, done, obs, q_values)

In [19]:
env_states

(Array([[8, 8],
        [8, 8],
        [8, 8],
        [8, 8],
        [8, 8],
        [8, 8],
        [8, 8],
        [8, 8],
        [8, 8],
        [8, 8]], dtype=int32),
 Array([[3643052944, 4056903599],
        [   3955054, 2850452153],
        [2199943636,  945034629],
        [1737963661, 1193160294],
        [  99651555, 3503963802],
        [ 938951844, 1349476278],
        [ 730919943, 1121179547],
        [1284802347, 3461820139],
        [4102267121, 3690616328],
        [3366893866, 3437845683]], dtype=uint32))

In [27]:
def fori_body(i:int, val:tuple):
    env_states, action_keys, all_obs, all_rewards, all_done, all_q_values = val
    states, _ = env_states
    q_values = all_q_values[i, :, :, :, :]

    actions, action_keys = v_policy(action_keys, N_ACTIONS, states, q_values)
    env_states, obs, rewards, done = v_step(env_states, actions)
    q_values = v_update(states, actions, rewards, done, obs, q_values)
    jax.debug.print("{x}", x=q_values.shape)

    # all_obs = all_obs.at[i].set(obs)
    all_rewards = all_rewards.at[i].set(rewards)
    all_done = all_done.at[i].set(done)
    all_q_values = all_q_values.at[i, :, :, :, :].set(q_values)

    val = (env_states, action_keys, all_obs, all_rewards, all_done, all_q_values)
    return val

def rollout(keys, timesteps):
    all_obs = jnp.zeros([timesteps, N_ENV])
    all_rewards = jnp.zeros([timesteps, N_ENV], dtype=jnp.int32)
    all_done = jnp.zeros([timesteps, N_ENV], dtype=jnp.bool_)
    all_q_values = jnp.zeros([timesteps, GRID_SIZE[0], GRID_SIZE[1], N_ACTIONS, N_ENV], dtype=jnp.float32)

    action_keys = random.split(random.PRNGKey(0), N_ENV)
    env_states, _ = v_reset(keys)
    
    val = (env_states, action_keys, all_obs, all_rewards, all_done, all_q_values)
    val = lax.fori_loop(0, timesteps, fori_body, val)
    env_states, action_keys, all_obs, all_reward, all_done, all_q_values = val
    return all_obs, all_reward, all_done, all_q_values

key = random.PRNGKey(SEED)
keys = random.split(key, N_ENV)
rollout(keys, 10)

ValueError: Incompatible shapes for broadcasting: (10, 8, 8, 4) and requested shape (8, 8, 4, 10)

In [60]:
q_t.at[0].set(jnp.ones([GRID_SIZE[0], GRID_SIZE[1], N_ACTIONS, N_ENV]))

Array([[[[[1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.]],

         [[1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.]],

         [[1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.]],

         ...,

         [[1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.]],

         [[1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.]],

         [[1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 1., 1.],
          [1., 1., 1., ..., 1., 

In [58]:
q_t=jnp.zeros([10, GRID_SIZE[0], GRID_SIZE[1], N_ACTIONS, N_ENV], dtype=jnp.float32)


In [48]:
all_obs = jnp.zeros([10, N_ENV])
all_rewards = jnp.zeros([10, N_ENV], dtype=jnp.int32)
all_done = jnp.zeros([10, N_ENV], dtype=jnp.bool_)
all_q_values = jnp.zeros([10, GRID_SIZE[0], GRID_SIZE[1], N_ACTIONS, N_ENV], dtype=jnp.float32)
env_states, keys = v_reset(keys)

key = random.PRNGKey(SEED)
keys = random.split(key, N_ENV)

dummy_val = (
    env_states, keys, all_obs, all_rewards, all_done, all_q_values
)  # These should be dummy/initial values in the actual shapes you are using
i = 0  # or whatever index you want to debug
fori_body(i, dummy_val)

(10, 8, 8, 4)


ValueError: Incompatible shapes for broadcasting: (10, 8, 8, 4) and requested shape (8, 8, 4, 10)