- - -

<h2 style="text-align: center;">Planning Algorithms in Gridworld</h2>

- - -

In [12]:
import jax.numpy as jnp
import jax.random as jrd
import jaxdp
from jaxdp.planning.algorithms import value_iteration
from jaxdp.planning.runner import train, no_update_state


# Experiment configuration:
#   seed        -> PRNG seed for the initial Q-values
#   train_loop  -> keyword arguments forwarded to `train`
#   value_init  -> bounds for the uniform initial Q-values
#   mdp_init    -> keyword arguments for `jaxdp.mdp.grid_world`
#                  (slip probability and the board layout; the meaning of
#                  each cell symbol is defined by grid_world)
args = {
    "seed": 42,
    "train_loop": {
        "gamma": 0.99,
        "n_iterations": 100,
        "verbose": False,
    },
    "value_init": {
        "minval": 0.0,
        "maxval": 1.0,
    },
    "mdp_init": {
        "p_slip": 0.1,
        "board": [
            "#####",
            "#  @#",
            "#  X#",
            "#P  #",
            "#####",
        ],
    },
}

# Initialize the grid-world MDP, then draw the initial Q-table uniformly
# from [minval, maxval) with one entry per (action, state) pair.
key = jrd.PRNGKey(args["seed"])
mdp = jaxdp.mdp.grid_world(**args["mdp_init"])
q_table_shape = (mdp.action_size, mdp.state_size)
init_value = jrd.uniform(key, q_table_shape, dtype="float32",
                         **args["value_init"])

# Initialize the update function's state (zeros shaped like the Q-table)
# and wrap the value-iteration Q update with `no_update_state`.
update_state = jnp.zeros_like(init_value)
update_fn = no_update_state(value_iteration.update.q)

# NOTE(review): value_star is an all-zero placeholder, not the optimal
# Q-table. Any metric `train` computes against it is therefore not a
# distance to the optimum. Confirm, or supply a reference solution.
value_star = jnp.zeros_like(init_value)

# Run the training loop with the settings in args["train_loop"].
metrics, value, update_state = train(
    mdp=mdp,
    init_value=init_value,
    update_state=update_state,
    update_fn=update_fn,
    value_star=value_star,
    **args["train_loop"],
)