In [1]:
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px

from jax import random, lax, jit, vmap, pmap
from tqdm.auto import tqdm
from functools import partial

from policies import BasePolicy
from agents import Q_learning
from envs import GridWorld
from policies import EpsilonGreedy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Experiment configuration and object construction ---------------------
# Every name defined in this cell is a notebook global that later cells read.
SEED = 2
GRID_DIM = jnp.array([8, 8])
# NOTE(review): the Q-table built below has shape (8, 8, N_ACTIONS), so valid
# row/column indices are 0..7 and [8, 8] is out of range for it. JAX does not
# raise on out-of-bounds reads, so this would go unnoticed - confirm how
# GridWorld bounds positions and whether [7, 7] was intended.
INITIAL_STATE = jnp.array([8, 8])
GOAL_STATE = jnp.array([0, 0])
# Holds the same value as GRID_DIM; GRID_SIZE is what the environment and the
# Q-table use, GRID_DIM is only used to derive N_STATES.
GRID_SIZE = jnp.array([8, 8])
N_STATES = jnp.prod(GRID_DIM)  # 0-d JAX array (product of the grid dims), not a Python int
N_ACTIONS = 4
DISCOUNT = 0.9
LEARNING_RATE = 0.1

# Single root PRNG key; the same key object is passed to the agent constructor
# and to env.reset below, and is later threaded through the policy.
key = random.PRNGKey(SEED)

env = GridWorld(INITIAL_STATE, GOAL_STATE, GRID_SIZE)
policy = EpsilonGreedy(0.1)
agent = Q_learning(
    key,
    N_STATES,
    N_ACTIONS,
    DISCOUNT,
    LEARNING_RATE,
    policy,
)

# Q-table indexed as q_values[row, col, action], initialised to zero.
q_values = jnp.zeros([GRID_SIZE[0], GRID_SIZE[1], N_ACTIONS], dtype=jnp.float32)
# One 2-D offset per action index (N_ACTIONS rows); looked up as movements[action].
movements = jnp.array([[0, 1], [1, 0], [0, -1], [-1, 0]])

env_state, obs = env.reset(key) 
done = False

# Filled by the training loop with the number of steps taken in each episode.
steps = []

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


@jit
def train_one_episode(key, q_values, env_state):
    """Run one episode to termination inside a single compiled `lax.while_loop`.

    Relies on the notebook globals `policy`, `env`, `agent`, `movements` and
    `N_ACTIONS`.

    Args:
        key: PRNG key handed to the policy; the key the policy returns is
            carried into the next iteration.
        q_values: Q-table, indexed with the coordinates of the current state.
        env_state: environment state, a pair whose first element is the
            agent's current state.

    Returns:
        The tuple `(key, q_values, env_state)` as it stands once the
        environment reports that the episode is done.
    """
    def _episode_running(carry):
        # The last slot of the carry is the termination flag.
        return ~carry[3]

    def _advance(carry):
        rng, q_table, world, _ = carry
        position, _ = world
        chosen, rng = policy(rng, N_ACTIONS, q_table[tuple(position)])
        # The environment is stepped with the movement offset, not the action index.
        world, observation, reward, finished = env.step(world, movements[chosen])
        q_table = agent.update(position, chosen, reward, finished, observation, q_table)
        return rng, q_table, world, finished

    # Start with the flag cleared so the loop body runs at least once.
    carry0 = (key, q_values, env_state, jnp.array(False))
    key_out, q_out, env_out, _ = lax.while_loop(_episode_running, _advance, carry0)
    return key_out, q_out, env_out

# Smoke-run of the compiled episode. The returned (key, q_values, env_state)
# tuple is not assigned, so the notebook globals keep their pre-call values.
# NOTE(review): confirm whether the result was meant to be kept.
train_one_episode(key, q_values, env_state)


In [3]:

# Plain-Python training loop: 400 episodes of tabular Q-learning, one
# environment step per inner iteration. It reads and rebinds the notebook
# globals key, q_values and env_state, and leaves state/action/obs/reward/done
# behind for later cells.
# NOTE(review): env.reset is never called inside the loop, so every episode
# after the first starts from whatever env_state the previous one ended in.
# Confirm that env.step resets on termination, otherwise add a reset here.
# NOTE(review): the recorded output of this cell is a ValueError about a
# non-hashable static argument ([8 8]); the cell did not run to completion.
for _ in tqdm(range(400)):
    done = False
    n_steps = 0

    while not done:
        # env_state is a pair; its first element is the agent's current state.
        state, _ = env_state
        action, key = policy(key, N_ACTIONS, q_values[tuple(state)])
        movement = movements[action]
        env_state, obs, reward, done = env.step(env_state, movement)

        q_values = agent.update(state, action, reward, done, obs, q_values)
        n_steps+=1
    # Episode length, appended to the global `steps` list.
    steps.append(n_steps)


# Heatmap of the best action value per grid cell (max over the action axis).
px.imshow(pd.DataFrame(jnp.max(q_values, axis=2)).round(2))

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]


[0. 0. 0. 0.]


ValueError: Non-hashable static arguments are not supported. An error occurred during a call to 'call' while trying to hash an object of type <class 'jaxlib.xla_extension.ArrayImpl'>, [8 8]. The error was:
TypeError: unhashable type: 'ArrayImpl'


In [None]:
# Batched, jit-compiled counterparts of the environment and policy functions.
# All three vmaps are built through the same partial so they share one axis name.
_vmap_over_envs = partial(vmap, axis_name="batch_axis")

# env.step over a batch:
#   inputs  -> ((env_state pair), action)
#   outputs -> ((env_state pair), obs, reward, done)
v_step = jit(
    _vmap_over_envs(
        env.step,
        in_axes=((0, 0), 0),
        out_axes=((0, 0), 0, 0, 0),
    )
)

# env.reset over a batch of keys (default in_axes); outputs -> ((env_state pair), obs)
v_reset = jit(_vmap_over_envs(env.reset, out_axes=((0, 0), 0)))

# policy.call over a batch: one key per environment, a shared (static)
# second argument, and the third argument sliced along its axis 3.
v_policy = jit(
    _vmap_over_envs(
        policy.call,
        in_axes=(0, None, 3),
        out_axes=(0, 0),
    ),
    static_argnums=1,
)

In [None]:
def _greedy_action_fn(subkey, state, q_values):
    """
    Selects the greedy action with random tie-break
    """
    q_max = jnp.max(q_values[tuple(state)])
    q_max_mask = jnp.equal(q_values, q_max)
    p = q_max_mask / jnp.sum(q_max_mask)
    jax.debug.print("{x}", x=p.shape)
    choice = random.choice(subkey, jnp.arange(q_values.shape[-1]), p=p)
    return jnp.int32(choice)
batch_greedy_action_fn = vmap(_greedy_action_fn, in_axes=(0, 0, 3))

In [None]:
# NOTE(review): this call depends on out-of-order notebook state. `keys` is
# only assigned in a later cell, and per the recorded traceback `state` is the
# unbatched int32[2] state while the keys have a leading axis of size 10, so
# vmap rejects the mismatched axis sizes. A state with one row per key is needed.
batch_greedy_action_fn(keys, state, q_values)

ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * most axes (2 of them) had size 10, e.g. axis 0 of argument subkey of type uint32[10,2];
  * one axis had size 2: axis 0 of argument state of type int32[2]

In [None]:
# Batched setup: NUM_ENV independent environments, one PRNG key each.
NUM_ENV = 10
key = random.PRNGKey(SEED)
keys = random.split(key, NUM_ENV)
env_state, obs = v_reset(keys)
# One Q-table per environment; the environment axis is last (axis 3), which is
# the axis v_policy maps over.
q_values = jnp.zeros([GRID_SIZE[0], GRID_SIZE[1], N_ACTIONS, NUM_ENV], dtype=jnp.float32)


# NOTE(review): the recorded error for this cell ("p must be None or match the
# shape of a") is raised inside v_policy. The single-environment code passes
# the policy the action values of the current state, whereas here each
# environment gets its whole table - confirm and gather per-state values.
# The policy returns an (action, key) pair, so unpack it and carry the keys
# forward instead of binding the whole pair to `action`.
action, keys = v_policy(keys, N_ACTIONS, q_values)
# As in the single-environment loop, the environment is stepped with the
# movement offset looked up from the action index.
env_state, obs, reward, done = v_step(env_state, movements[action])

ValueError: p must be None or match the shape of a