
## Learning Algorithms in Gridworld

- ### Q-learning with synchronous sampling

    > Synchronous sampling samples a state transition for each (s, a) pair. 

In [12]:
from functools import partial
import jax.numpy as jnp
import jax.random as jrd
import jax
import jaxdp

from jaxdp.learning.algorithms import q_learning
from jaxdp.learning.runner import train, no_learner_state, no_step_index


# Experiment configuration. Each nested dictionary is later unpacked as
# keyword arguments into the call named by its key.
args = {
    # Seed for the root PRNG key
    "seed": 42,
    # Number of PRNG keys the training key is split into (one per run)
    "n_env": 10,
    # Keyword arguments bound to the Q-learning update function
    "update_fn": {
        "alpha": 0.1,
    },
    # Keyword arguments bound to the training loop
    "train_loop": {
        "gamma": 0.99,
        "n_steps": 100,
        "eval_period": 10,
    },
    # Bounds of the uniform distribution for the initial Q values
    "value_init": {
        "minval": 0.0,
        "maxval": 1.0,
    },
    # Keyword arguments of the gridworld MDP constructor
    "mdp_init": {
        "p_slip": 0.15,
        "board": [
            "#####",
            "#  @#",
            "#  X#",
            "#P  #",
            "#####",
        ],
    },
}


# Build the gridworld MDP first; the Q-table shape below is read from it.
mdp = jaxdp.mdp.grid_world(**args["mdp_init"])

# Derive two keys from the configured seed: one for training and one for
# drawing the initial Q values.
_train_key, value_key = jrd.split(jrd.PRNGKey(args["seed"]), num=2)
# Split the training key into args["n_env"] keys, one per training run.
train_keys = jrd.split(_train_key, num=args["n_env"])

# Uniformly random initial Q values with shape (action_size, state_size);
# the bounds come from args["value_init"].
init_value = jrd.uniform(
    value_key,
    shape=(mdp.action_size, mdp.state_size),
    dtype="float32",
    **args["value_init"],
)

# Learner: the synchronous Q-learning update passed through the runner's
# `no_step_index` and `no_learner_state` wrappers (presumably adapting it to
# the signature the train loop expects -- confirm against jaxdp.learning.runner).
_wrapped_sync_update = no_learner_state(no_step_index(q_learning.update.sync))
# Bind the configured hyper-parameters (alpha) as keyword arguments.
update_fn = partial(_wrapped_sync_update, **args["update_fn"])

# Train one policy per key in `train_keys` (args["n_env"] independent runs).
# The single-run trainer is configured first, then batched over the keys with
# vmap, and finally JIT-compiled into one "batch" train function.
_train_single_run = partial(
    train.sync,
    # All-NaN reference values -- presumably "no known optimal value";
    # confirm against train.sync.
    value_star=jnp.full_like(init_value, jnp.nan),
    learner_state=None,
    # Greedy policy w.r.t. the current Q values; the second argument is unused.
    policy_fn=lambda q, i: jaxdp.greedy_policy.q(q),
    update_fn=update_fn,
    **args["train_loop"]
)

# Only the third positional argument (the PRNG key) is mapped over; the
# initial values and the MDP are shared by every run.
_train_all_runs = jax.vmap(_train_single_run, in_axes=(None, None, 0))
jitted_batch_train = jax.jit(_train_all_runs)

metrics, value, learner_state = jitted_batch_train(init_value, mdp, train_keys)

- ### Visualize

In [10]:
from itertools import chain
import pandas as pd
from plot_util import make_figure


# Summarise every training metric across runs as percentile bands, collect
# them into a DataFrame, and plot the result.
percentile = 25
# Lower / median / upper percentiles, in the same order as the
# "low" / "med" / "high" column labels below.
quantiles = jnp.array([percentile, 50, 100 - percentile])

data = []
for name in metrics._fields:
    values = getattr(metrics, name)
    if values.ndim == 1:
        # A metric without a leading batch axis is treated as a single run
        values = values.reshape(1, -1)
    # Percentiles over axis 0, ignoring NaN entries
    percentiles = jnp.nanpercentile(values, q=quantiles, axis=0)
    data.append(percentiles)

# Flatten to one column per (metric, percentile) pair; rows follow the
# remaining axis of the metrics (presumably the training step -- confirm).
table = jnp.stack(list(chain(*data)), axis=1)

# Size the STEP level from the data rather than hard-coding 100, so the cell
# keeps working when the number of recorded steps changes.
n_steps = table.shape[0]

index = pd.MultiIndex.from_product(
    [["gridworld"], ["q-learning"], list(range(n_steps))],
    names=["ENV", "ALG", "STEP"])
columns = pd.MultiIndex.from_product(
    [metrics._fields, ["low", "med", "high"]],
    names=["METRIC", "PERCENTILE"])

df = pd.DataFrame(data=table, columns=columns, index=index)

# Generate the figure
make_figure(df.loc["gridworld"])