# ⚡***Lightning-fast parallelized training with*** <img src="https://upload.wikimedia.org/wikipedia/commons/8/86/Google_JAX_logo.svg" alt="JAX logo" width="60" />

## **1) Q-learning on a GridWorld**

In [1]:
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px

from jax import random, lax, jit, vmap, pmap
from tqdm.auto import tqdm
from functools import partial
from jax_tqdm import loop_tqdm
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import sys

sys.path.append("../")

from src import GridWorld, Q_learning, EpsilonGreedy
from utils import animated_heatmap, rollout, parallel_rollout

  from .autonotebook import tqdm as notebook_tqdm


In [9]:
# --- Experiment configuration ------------------------------------------------
SEED = 2

# Grid layout, each given as a 2-element integer array.
GRID_SIZE = jnp.array([8, 12])
GOAL_STATE = jnp.array([0, 0])
# NOTE(review): INITIAL_STATE is equal to GRID_SIZE. If GridWorld uses
# zero-based cell indices, the last valid cell would be [7, 11] -- confirm how
# the environment treats this value.
INITIAL_STATE = jnp.array([8, 12])

# Sizes handed to the agent / rollout helpers.
N_STATES = GRID_SIZE.prod()  # product of the grid dimensions (a 0-d JAX array)
N_ACTIONS = 4

# Learning hyper-parameters.
DISCOUNT = 0.9
LEARNING_RATE = 0.1
TIME_STEPS = 100_000

key = jax.random.PRNGKey(SEED)

# Environment, behaviour policy and learning agent shared by the cells below.
env = GridWorld(INITIAL_STATE, GOAL_STATE, GRID_SIZE)
policy = EpsilonGreedy(0.1)  # presumably the exploration rate epsilon -- confirm
# NOTE(review): `key` is handed to the agent here and is passed to `rollout`
# again in a later cell; consider `random.split` if independent streams matter.
agent = Q_learning(key, N_STATES, N_ACTIONS, DISCOUNT, LEARNING_RATE)

In [12]:
# Run a single-environment rollout of the Q-learning agent for TIME_STEPS
# steps; `rollout` returns six values, of which only `q_values` is plotted here.
env_states, action_keys, obs, rewards, done, q_values = rollout(
    key, TIME_STEPS, N_ACTIONS, GRID_SIZE, env, agent, policy
)

# The agent is a `Q_learning` instance, so the plot is labelled accordingly
# (it was previously mislabelled "Expected Sarsa").
# GRID_SIZE is already a JAX array, so it is passed as-is for `dims`.
animated_heatmap(
    q_values,
    dims=GRID_SIZE,
    agent_name="Q-learning",
    sample_freq=1_000,
    log_scale=False,
)

Running for 100,000 iterations: 100%|██████████| 100000/100000 [00:00<00:00, 1900704.67it/s]


In [13]:
# Every vectorised wrapper below names its mapped axis "batch_axis"; bind that
# keyword once instead of repeating it in each vmap call.
_batched = partial(vmap, axis_name="batch_axis")

# Batched environment reset.
# Outputs are stacked on axis 0: ((env_state), obs).
v_reset = jit(_batched(env.reset, out_axes=((0, 0), 0)))

# Batched environment step.
# Inputs mapped on axis 0:   ((env_state), action)
# Outputs stacked on axis 0: ((env_state), obs, reward, done)
v_step = jit(
    _batched(
        env.step,
        in_axes=((0, 0), 0),
        out_axes=((0, 0), 0, 0, 0),
    )
)

# Batched agent update.
# The first five arguments are mapped on axis 0 and the sixth on its last
# axis; the output is stacked along its last dimension, i.e. the batch
# dimension sits last in agent.update's output.
# NOTE(review): the sixth argument is presumably the Q-value table (it uses the
# same -1 axis as `q_values` in v_policy) -- confirm against agent.update.
v_update = jit(
    _batched(
        agent.update,
        in_axes=(0, 0, 0, 0, 0, -1),
        out_axes=-1,
    )
)

# Batched policy call.
# Arguments: (keys, n_actions, state, q_values). n_actions is broadcast
# (in_axes=None) and declared static for jit; q_values is mapped on its
# last axis.
v_policy = jit(
    _batched(policy.call, in_axes=(0, None, 0, -1)),
    static_argnums=(1,),
)

In [15]:
# Batch size: number of PRNG keys created below and passed to parallel_rollout.
N_ENV = 30

# Re-create the base key from SEED and derive one key per batch entry.
key = random.PRNGKey(SEED)
keys = random.split(key, N_ENV)

all_obs, all_reward, all_done, all_q_values = parallel_rollout(
    keys, TIME_STEPS, N_ACTIONS, GRID_SIZE, N_ENV, v_policy, v_update, v_step, v_reset
)

# Average the Q-values over the last axis before plotting.
# NOTE(review): the last axis is presumably the batch axis (v_update stacks its
# output with out_axes=-1) -- confirm against parallel_rollout.
# The agent is a `Q_learning` instance, so the plot is labelled accordingly
# (it was previously mislabelled "E-Sarsa average").
# GRID_SIZE is already a JAX array, so it is passed as-is for `dims`.
animated_heatmap(
    jnp.mean(all_q_values, axis=-1),
    dims=GRID_SIZE,
    agent_name=f"{N_ENV} Q-learning average",
    sample_freq=1_000,
    log_scale=False,
)

In [None]:
# Average number of episodes played: sum the rewards along axis 0, then take
# the mean of the resulting totals.
# NOTE(review): this matches an episode count only if each finished episode
# contributes exactly one unit of reward and axis 0 is the time axis -- confirm
# against the environment and parallel_rollout.
jnp.array(all_reward).sum(axis=0).mean()

Array(6429.2, dtype=float32)