# ⚡***Lightning-fast parallelized training with*** <img src='https://upload.wikimedia.org/wikipedia/commons/8/86/Google_JAX_logo.svg' alt="JAX" width="60" />

## **2) Expected SARSA with Softmax policy on a GridWorld**

In [1]:
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px

from jax import random, lax, jit, vmap, pmap
from tqdm.auto import tqdm
from functools import partial
from jax_tqdm import loop_tqdm
import plotly.graph_objects as go
from plotly.subplots import make_subplots


from src import GridWorld, Softmax_policy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Reproducibility ---------------------------------------------------------
SEED = 2
key = random.PRNGKey(SEED)

# --- GridWorld geometry ------------------------------------------------------
GRID_SIZE = jnp.array([8, 12])
# Total number of grid cells; in the visible code only the commented-out agent uses it.
N_STATES = jnp.prod(GRID_SIZE)
# NOTE(review): INITIAL_STATE equals GRID_SIZE. If GridWorld uses 0-based cell
# indices this is one past the last row/column — confirm against GridWorld.
INITIAL_STATE = jnp.array([8, 12])
GOAL_STATE = jnp.array([0, 0])
N_ACTIONS = 4

# --- Learning hyper-parameters -----------------------------------------------
DISCOUNT = 0.9
LEARNING_RATE = 0.1

env = GridWorld(INITIAL_STATE, GOAL_STATE, GRID_SIZE)
# NOTE(review): 0.1 is presumably the softmax temperature — confirm in src.
policy = Softmax_policy(0.1)

# NOTE(review): later cells call `agent.update`, but the agent construction
# below is commented out, so a fresh top-to-bottom run raises NameError.
# The commented code also names `Q_learning` (not imported here) while the
# notebook title says Expected SARSA — instantiate the intended agent from `src`.
# agent = Q_learning(
#     key,
#     N_STATES,
#     N_ACTIONS,
#     DISCOUNT,
#     LEARNING_RATE,
# )

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [34]:
TIME_STEPS = 3_000_000
key = random.PRNGKey(SEED)


def single_agent_rollout(key: jax.Array, timesteps: int, action_seed: int = 0):
    """Train one agent for ``timesteps`` environment steps inside a single ``lax.fori_loop``.

    Relies on the module-level ``env``, ``policy`` and ``agent`` objects
    (``agent`` must already be defined; its construction is commented out in the
    setup cell). At each step the policy picks an action from the current
    Q-table, the environment is stepped, and ``agent.update`` returns the next
    Q-table.

    Args:
        key: PRNG key passed to ``env.reset`` to obtain the initial env state.
        timesteps: number of loop iterations (environment steps).
        action_seed: seed for the PRNG key threaded through the policy for
            action selection. Defaults to 0, the value previously hard-coded.

    Returns:
        The final loop carry ``(env_states, action_key, all_obs, all_rewards,
        all_done, all_q_values)``. The history arrays have a leading time axis;
        ``all_q_values`` has ``timesteps + 1`` entries, and entry 0 is the
        all-zeros initial table.
    """

    # Size the progress bar by the actual loop length. The global TIME_STEPS
    # would give a wrong/unfinished bar whenever ``timesteps`` differs from it.
    @loop_tqdm(timesteps)
    @jit
    def fori_body(i: int, val: tuple):
        env_states, action_key, all_obs, all_rewards, all_done, all_q_values = val
        states, _ = env_states
        q_values = all_q_values[i, :, :, :]

        # action selection, environment step and Q-table update
        actions, action_key = policy(action_key, N_ACTIONS, states, q_values)
        env_states, obs, rewards, done = env.step(env_states, actions)
        q_values = agent.update(states, actions, rewards, done, obs, q_values)

        # record step i's outputs; the updated table becomes entry i + 1
        all_obs = all_obs.at[i].set(obs)
        all_rewards = all_rewards.at[i].set(rewards)
        all_done = all_done.at[i].set(done)
        all_q_values = all_q_values.at[i + 1, :, :, :].set(q_values)

        return (env_states, action_key, all_obs, all_rewards, all_done, all_q_values)

    # Pre-allocate per-step histories with a leading time axis.
    all_obs = jnp.zeros([timesteps, 2])
    all_rewards = jnp.zeros([timesteps], dtype=jnp.int32)
    all_done = jnp.zeros([timesteps], dtype=jnp.bool_)
    # q_values has first dimension = timesteps + 1, as the update targets step t+1.
    # NOTE(review): this keeps every Q-table: (timesteps+1) * 8*12*4 float32
    # (~4.6 GB for 3M steps) although only the last one is used downstream.
    all_q_values = jnp.zeros(
        [timesteps + 1, GRID_SIZE[0], GRID_SIZE[1], N_ACTIONS], dtype=jnp.float32
    )

    # random key used for policy / action selection
    action_key = random.PRNGKey(action_seed)
    env_states, _ = env.reset(key)

    val_init = (env_states, action_key, all_obs, all_rewards, all_done, all_q_values)
    return lax.fori_loop(0, timesteps, fori_body, val_init)


env_states, action_keys, obs, rewards, done, q_values = single_agent_rollout(key, TIME_STEPS)

Running for 3,000,000 iterations: 100%|██████████| 3000000/3000000 [00:03<00:00, 924042.76it/s] 


In [4]:
# Greedy value of each grid cell (max over the action axis) for the final Q-table.
max_q_per_state = jnp.max(q_values[-1], axis=-1)
fig = px.imshow(max_q_per_state, title="Maximal Q-value for each state")
fig.show()

In [5]:
# Batched versions of the environment, agent and policy functions: each one is
# vmapped over the environment index and then jitted. All share one axis name.
# NOTE(review): v_update needs `agent` to be defined before this cell runs.
_BATCH_AXIS = "batch_axis"

# One PRNG key per environment in; ((env_state), obs) out, batched on axis 0.
v_reset = jit(vmap(env.reset, out_axes=((0, 0), 0), axis_name=_BATCH_AXIS))

# ((env_state), action) in; ((env_state), obs, reward, done) out — all on axis 0.
v_step = jit(
    vmap(
        env.step,
        in_axes=((0, 0), 0),
        out_axes=((0, 0), 0, 0, 0),
        axis_name=_BATCH_AXIS,
    )
)

# (states, actions, rewards, done, obs) are batched on axis 0, while the
# Q-tables carry the environment index in their LAST axis, on input and output.
v_update = jit(
    vmap(
        agent.update,
        in_axes=(0, 0, 0, 0, 0, -1),
        out_axes=-1,
        axis_name=_BATCH_AXIS,
    )
)

# (keys, n_actions, state, q_values): n_actions is shared across environments
# (in_axes=None) and static for jit; Q-tables are batched on their last axis.
# Outputs use vmap's default out_axes=0.
v_policy = jit(
    vmap(
        policy.call,
        in_axes=(0, None, 0, -1),
        axis_name=_BATCH_AXIS,
    ),
    static_argnums=(1,),
)

In [6]:
def plot_performances(all_q_values: jax.Array, time_steps: int, n_env: int) -> None:
    """Show the final Q-values, averaged over actions and agents, as a surface and a heatmap.

    Args:
        all_q_values: Q-table history with a leading time axis. The last entry is
            indexed as (row, col, action, agent), the layout produced by the
            batched rollout in this notebook.
        time_steps: number of training steps; only used in the figure title.
        n_env: number of parallel agents; only used in the figure title.
    """
    # Mean of the final table over the action axis (2) and the agent axis (3),
    # leaving one value per grid cell. Axes given as a tuple, the documented type.
    q_values_avg = pd.DataFrame(jnp.mean(all_q_values[-1], axis=(2, 3)))

    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=("Q-values as a surface", "Q-values as a heatmap"),
        column_widths=[0.5, 0.5],
        specs=[[{"type": "surface"}, {"type": "heatmap"}]],
    )
    fig.add_trace(go.Surface(z=q_values_avg), row=1, col=1)
    fig.add_trace(go.Heatmap(z=q_values_avg), row=1, col=2)

    fig.update_layout(
        title_text=f"Q-values averaged over {n_env} agents, {time_steps} individual training steps",
        height=600,
    )
    # Reverse the heatmap's y axis so row 0 is drawn at the top.
    fig.update_yaxes(autorange="reversed", row=1, col=2)

    fig.show()

In [7]:
N_ENV = 30
TIME_STEPS = 100_000
key = random.PRNGKey(SEED)
keys = random.split(key, N_ENV)


def multi_agent_rollout(keys: jax.Array, timesteps: int, n_env: int, action_seed: int = 0):
    """Train ``n_env`` independent agents in parallel for ``timesteps`` steps.

    Uses the batched helpers ``v_reset``, ``v_policy``, ``v_step`` and
    ``v_update`` inside one ``lax.fori_loop``. Per-step arrays are batched on
    axis 1 (after time). Q-tables carry the environment index in their last axis.

    Args:
        keys: one PRNG key per environment, passed to ``v_reset``.
        timesteps: number of loop iterations (environment steps per agent).
        n_env: number of parallel environments/agents; must match ``len(keys)``.
        action_seed: seed from which the per-environment action keys are split.
            Defaults to 0, the value previously hard-coded.

    Returns:
        ``(all_obs, all_rewards, all_done, all_q_values)``. ``all_q_values`` has
        ``timesteps + 1`` entries along its first axis, and entry 0 is the
        all-zeros initial table.
    """

    # Progress bar sized by the actual loop length (was the global TIME_STEPS).
    @loop_tqdm(timesteps)
    @jit
    def fori_body(i: int, val: tuple):
        env_states, action_keys, all_obs, all_rewards, all_done, all_q_values = val
        states, _ = env_states
        q_values = all_q_values[i, :, :, :, :]

        # batched action selection, environment step and Q-table update
        actions, action_keys = v_policy(action_keys, N_ACTIONS, states, q_values)
        env_states, obs, rewards, done = v_step(env_states, actions)
        q_values = v_update(states, actions, rewards, done, obs, q_values)

        # record step i's outputs; the updated tables become entry i + 1
        all_obs = all_obs.at[i].set(obs)
        all_rewards = all_rewards.at[i].set(rewards)
        all_done = all_done.at[i].set(done)
        all_q_values = all_q_values.at[i + 1, :, :, :, :].set(q_values)

        return (env_states, action_keys, all_obs, all_rewards, all_done, all_q_values)

    all_obs = jnp.zeros([timesteps, n_env, 2])
    all_rewards = jnp.zeros([timesteps, n_env], dtype=jnp.int32)
    all_done = jnp.zeros([timesteps, n_env], dtype=jnp.bool_)
    # q_values has first dimension = timesteps + 1, as the update targets step t+1.
    # NOTE(review): this keeps every Q-table for every agent:
    # (timesteps+1) * 8*12*4 * n_env float32 (~4.6 GB for 100k steps x 30 envs),
    # although only the last entry is plotted.
    all_q_values = jnp.zeros(
        [timesteps + 1, GRID_SIZE[0], GRID_SIZE[1], N_ACTIONS, n_env], dtype=jnp.float32
    )

    action_keys = random.split(random.PRNGKey(action_seed), n_env)
    env_states, _ = v_reset(keys)

    val_init = (env_states, action_keys, all_obs, all_rewards, all_done, all_q_values)
    _, _, all_obs, all_rewards, all_done, all_q_values = lax.fori_loop(
        0, timesteps, fori_body, val_init
    )
    return all_obs, all_rewards, all_done, all_q_values


# Backward-compatible alias: the original cell defined this multi-env rollout
# under the name `single_agent_rollout`, rebinding the single-env version
# from the earlier cell. Prefer `multi_agent_rollout` in new code.
single_agent_rollout = multi_agent_rollout

all_obs, all_reward, all_done, all_q_values = multi_agent_rollout(keys, TIME_STEPS, N_ENV)
plot_performances(all_q_values, TIME_STEPS, N_ENV)

Running for 100,000 iterations: 100%|██████████| 100000/100000 [00:03<00:00, 33085.88it/s]


In [9]:
# Total reward per agent (summed over the time axis), averaged across agents.
# NOTE(review): this equals the mean number of episodes played only if the env
# gives reward 1 on reaching the goal and 0 otherwise — confirm in GridWorld.
jnp.mean(jnp.sum(all_reward, axis=0))

Array(6429.2, dtype=float32)