In [1]:
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px

from jax import random, lax, jit, vmap, pmap
from tqdm.auto import tqdm
from functools import partial

from policies import BasePolicy
from agents import Q_learning
from envs import GridWorld
from policies import EpsilonGreedy

In [None]:
SEED = 2
GRID_DIM = jnp.array([8, 8])
INITIAL_STATE = jnp.array([8, 8])
GOAL_STATE = jnp.array([0, 0])
GRID_SIZE = jnp.array([8, 8])
N_STATES = jnp.prod(GRID_DIM)
N_ACTIONS = 4
DISCOUNT = 0.9
LEARNING_RATE = 0.1

key = random.PRNGKey(SEED)

env = GridWorld(INITIAL_STATE, GOAL_STATE, GRID_SIZE)
policy = EpsilonGreedy(0.1)
agent = Q_learning(
    key,
    N_STATES,
    N_ACTIONS,
    DISCOUNT,
    LEARNING_RATE,
    policy,
)

q_values = jnp.zeros([GRID_SIZE[0], GRID_SIZE[1], N_ACTIONS], dtype=jnp.float32)
movements = jnp.array([[0, 1], [1, 0], [0, -1], [-1, 0]])

env_state, obs = env.reset(key) 
done = False

steps = []

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)



for _ in tqdm(range(400)):
    done = False
    n_steps = 0

    while not done:
        state, _ = env_state
        action, key = policy(key, N_ACTIONS, state, q_values)
        movement = movements[action]
        env_state, obs, reward, done = env.step(env_state, movement)

        q_values = agent.update(state, action, reward, done, obs, q_values)
        n_steps+=1
    steps.append(n_steps)


px.imshow(pd.DataFrame(jnp.max(q_values, axis=2)).round(2), title="Maximal Q-value for each state")

In [None]:
v_step = jit(vmap(env.step, 
                  in_axes=((0, 0), 0), # ((env_state), action)
                  out_axes=((0, 0), 0, 0, 0), # ((env_state), obs, reward, done)
                  axis_name="batch_axis"
                  ))

v_reset = jit(vmap(env.reset,
                   out_axes=((0,0), 0), # ((env_state), obs)
                   axis_name="batch_axis"
                   ))

v_policy = jit(vmap(policy.call, 
                in_axes=(0, None, 1, -1), # (keys, n_actions, state, q_values) 
                axis_name="batch_axis"
                ),
                static_argnums=(1,)
                )

In [None]:
# @partial(jit, static_argnums=(1))
# def policy(key, n_actions, state, q_values):
#     def _random_action_fn(subkey):
#         return random.choice(subkey, jnp.arange(n_actions))

#     def _greedy_action_fn(subkey):
#         """
#         Selects the greedy action with random tie-break
#         """
#         q = q_values[state[0], state[1]]
#         q_max = jnp.max(q, axis=-1)
#         q_max_mask = jnp.equal(q, q_max)
#         p = jnp.divide(q_max_mask, q_max_mask.sum())
#         choice = random.choice(subkey, jnp.arange(n_actions), p=p)
#         return jnp.int32(choice)

#     explore = random.uniform(key) < 0.1
#     key, subkey = random.split(key)
#     action = lax.cond(
#         explore,
#         _random_action_fn,
#         _greedy_action_fn,
#         operand=subkey,
#     )

#     return action, subkey

# v_policy = vmap(policy, 
#                 in_axes=(0, None, 0, -1), # (keys, n_actions, state, q_values) 
#                 axis_name="batch_axis"
#                 )
                
# N_ENV = 10
# key = random.PRNGKey(SEED)
# keys = random.split(key, N_ENV)
# env_state, obs = v_reset(keys)
# q_values = jnp.zeros([GRID_SIZE[0], GRID_SIZE[1], N_ACTIONS, N_ENV], dtype=jnp.float32)

# action, keys = v_policy(keys, N_ACTIONS, states, q_values)

(Array(2, dtype=int32), Array([2425776485,  230565590], dtype=uint32))

In [None]:
N_ENV = 10
key = random.PRNGKey(SEED)
keys = random.split(key, N_ENV)
env_state, obs = v_reset(keys)
states, keys = env_state
q_values = jnp.zeros([GRID_SIZE[0], GRID_SIZE[1], N_ACTIONS, N_ENV], dtype=jnp.float32)

action, keys = v_policy(keys, N_ACTIONS, states, q_values)

state, key = env_state
# env_state, obs, reward, done = v_step(env_state, action)

ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * most axes (2 of them) had size 10, e.g. axis 0 of argument key of type uint32[10,2];
  * one axis had size 2: axis 1 of argument state of type int32[10,2]