
## Planning Algorithms in Gridworld

- ### Value Iteration

In [None]:
import jax.numpy as jnp
import jax.random as jrd
import jaxdp
from jaxdp.planning.algorithms import value_iteration
from jaxdp.planning.runner import train, no_update_state


# Experiment configuration: everything a reader might want to tune lives here.
args = {
    "seed": 42,
    # Keyword arguments forwarded verbatim to `train`.
    "train_loop": {
        "gamma": 0.99,
        "n_iterations": 100,
        "verbose": False,
    },
    # Bounds of the uniform distribution the initial values are drawn from.
    "value_init": {
        "minval": 0.0,
        "maxval": 1.0,
    },
    # Keyword arguments forwarded verbatim to `jaxdp.mdp.grid_world`.
    "mdp_init": {
        "p_slip": 0.1,
        "board": [
            "#####",
            "#  @#",
            "#  X#",
            "#P  #",
            "#####",
        ],
    },
}

# Build the gridworld MDP and draw random initial values of shape
# (action_size, state_size) from the seeded PRNG key.
key = jrd.PRNGKey(args["seed"])
mdp = jaxdp.mdp.grid_world(**args["mdp_init"])
init_value = jrd.uniform(
    key,
    shape=(mdp.action_size, mdp.state_size),
    dtype="float32",
    minval=args["value_init"]["minval"],
    maxval=args["value_init"]["maxval"],
)

# Value-iteration Q update wrapped with `no_update_state`; the update state
# starts out as None.
update_fn = no_update_state(value_iteration.update.q)
update_state = None

# Run the training loop.
# NOTE(review): `value_star` is an all-zero array here, presumably a
# placeholder for the reference values used by the metrics - confirm.
train_kwargs = dict(
    mdp=mdp,
    init_value=init_value,
    update_state=update_state,
    update_fn=update_fn,
    value_star=jnp.zeros_like(init_value),
    **args["train_loop"],
)
metrics, value, update_state = train(**train_kwargs)

- ### Visualize

In [None]:
from itertools import chain
import pandas as pd
from plot_util import make_figure


# Summarise every metric by its low / median / high percentile and collect
# the result in a dataframe indexed by (ENV, ALG, STEP).
percentile = 25

data = []
for name in metrics._fields:
    values = getattr(metrics, name)
    if values.ndim == 1:
        # Promote 1-D metrics to a single-row 2-D array so the reduction
        # over axis 0 below is well defined.
        values = values.reshape(1, -1)
    # Shape (3, ...): one row each for the low, median and high percentile.
    # NOTE(review): axis 0 presumably indexes independent runs - confirm.
    percentiles = jnp.nanpercentile(
        values, q=jnp.array([percentile, 50, 100 - percentile]), axis=0)
    data.append(percentiles)

# Flatten the per-metric percentile rows and stack them as columns, giving
# one column per (metric, percentile) pair and one row per step.
percentile_table = jnp.stack(list(chain(*data)), axis=1)

# Take the number of steps from the data instead of hard-coding it, so the
# index stays in sync when `n_iterations` in the training config changes.
n_steps = percentile_table.shape[0]

index = pd.MultiIndex.from_product(
    [["gridworld"], ["VI"], list(range(n_steps))],
    names=["ENV", "ALG", "STEP"])
columns = pd.MultiIndex.from_product(
    [metrics._fields, ["low", "med", "high"]],
    names=["METRIC", "PERCENTILE"])

df = pd.DataFrame(data=percentile_table, columns=columns, index=index)

# Generate the figure
make_figure(df.loc["gridworld"])