# ⚡***Lightning fast parallelized training with*** <img src='https://upload.wikimedia.org/wikipedia/commons/8/86/Google_JAX_logo.svg' alt="JAX" width="60" />

## **2) Expected SARSA with Softmax Policy on a GridWorld**

In [1]:
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px

from jax import random, lax, jit, vmap, pmap
from functools import partial
from jax_tqdm import loop_tqdm
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import sys

sys.path.append("../../../")

from src import GridWorld, Expected_Sarsa, EpsilonGreedy, animated_heatmap, tabular_rollout, tabular_parallel_rollout, plot_path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Configuration: grid geometry, learning hyper-parameters, run length and seeding.
SEED = 2
INITIAL_STATE = jnp.array([7, 11])
GOAL_STATE = jnp.array([0, 0])
GRID_SIZE = jnp.array([8, 12])
N_STATES = jnp.prod(GRID_SIZE)  # total number of grid cells (8 * 12)
N_ACTIONS = 4
DISCOUNT = 0.9
LEARNING_RATE = 0.1
EPSILON = 0.1  # epsilon passed to EpsilonGreedy (presumably the exploration probability)
TIME_STEPS = 100_000
STOCHASTIC_RESET = False

key = random.PRNGKey(SEED)

env = GridWorld(INITIAL_STATE, GOAL_STATE, GRID_SIZE, STOCHASTIC_RESET)
# NOTE(review): the section title mentions a softmax policy, but an epsilon-greedy
# policy is constructed here — confirm which one is intended.
policy = EpsilonGreedy(EPSILON)
agent = Expected_Sarsa(
    DISCOUNT,
    LEARNING_RATE,
)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [3]:
# Run one environment for TIME_STEPS steps, then visualise the recorded
# Q-values (animated heatmap) and the observations (plot_path).
(
    env_states,
    action_keys,
    obs,
    rewards,
    done,
    q_values,
) = tabular_rollout(key, TIME_STEPS, N_ACTIONS, GRID_SIZE, env, agent, policy)
animated_heatmap(
    q_values,
    dims=jnp.asarray(GRID_SIZE),
    agent_name="Expected Sarsa",
    sample_freq=1_000,
    log_scale=False,
)
plot_path(obs)

Running for 100,000 iterations: 100%|██████████| 100000/100000 [00:00<00:00, 1664188.42it/s]


In [4]:
# Total reward collected over the single-environment run.
jnp.sum(rewards)

Array(4665, dtype=int32)

In [6]:
# Parallel training: N_ENV environments run through tabular_parallel_rollout,
# one PRNG key per environment, split from a fresh key built from SEED.
N_ENV = 30

key = random.PRNGKey(SEED)
keys = random.split(key, N_ENV)

all_obs, all_rewards, all_done, all_q_values = tabular_parallel_rollout(
    keys, TIME_STEPS, N_ACTIONS, GRID_SIZE, N_ENV, env, agent, policy
)

# Average the Q-values across environments before animating.
# NOTE(review): assumes the environment axis is the last one (axis=-1) — confirm
# against tabular_parallel_rollout's output layout.
mean_q_values = jnp.mean(all_q_values, axis=-1)
animated_heatmap(
    mean_q_values,
    dims=jnp.asarray(GRID_SIZE),
    agent_name=f"{N_ENV} E-Sarsa average",
    sample_freq=1_000,
    log_scale=False,
)

Running for 100,000 iterations: 100%|██████████| 100000/100000 [00:02<00:00, 34026.92it/s]
