<h2 style="text-align: center;">Value-Iteration Gridworld Example</h2>

- - -

In [None]:
from typing import NamedTuple
from functools import partial
import jax
import jax.numpy as jnp
import jax.random as jrd

from jaxdp.iterations.base import train
from jaxdp.iterations.iteration import q_iteration_update
from jaxdp.mdp.grid_world import grid_world


# ASCII layout of the grid world: 13 rows, each 32 characters wide.
# NOTE(review): the meaning of the glyphs ('#', 'X', 'P', '@', ' ') is decided
# by jaxdp's grid_world parser and is not visible here -- presumably
# wall / obstacle / start / goal / free cell; confirm against the parser.
# NOTE(review): the bottom row places its third and later 'X' one column to
# the left of the pattern used by the other odd rows -- confirm intentional.
GRID_LAYOUT = [
    "################################",
    "#   X   X   X   X   X   X   X @#",
    "# X   X   X   X   X   X   X   X#",
    "#   X   X   X   X   X   X   X  #",
    "# X   X   X   X   X   X   X   X#",
    "#   X   X   X   X   X   X   X  #",
    "# X   X   X   X   X   X   X   X#",
    "#   X   X   X   X   X   X   X  #",
    "# X   X   X   X   X   X   X   X#",
    "#   X   X   X   X   X   X   X  #",
    "# X   X   X   X   X   X   X   X#",
    "#P  X   X  X   X   X   X   X   #",
    "################################",
]

# Build the MDP object from the textual layout.
mdp = grid_world(GRID_LAYOUT)


class Args(NamedTuple):
    """Run configuration for this notebook; every field has a default.

    In the cells shown here only ``n_iterations`` and ``gamma`` are read
    (both are forwarded to ``train`` when the trainer is built). The other
    fields are declared but not referenced by the visible code.
    """
    # Forwarded to ``train`` as its ``n_iterations`` keyword.
    n_iterations: int = 100
    # Not read by the visible cells.
    max_episode_length: int = 200
    # Forwarded to ``train`` as ``gamma`` (presumably the discount factor -- confirm).
    gamma: float = 0.999
    # NOTE(review): the trainer setup passes a literal ``verbose=False`` and
    # does not read this field -- confirm whether that is intentional.
    verbose: bool = True
    # Not read by the visible cells.
    seed: int = 42


# Instantiate the configuration with its defaults.
args = Args()

# Zero-filled starting table with one row per action and one column per state.
init_values = jnp.zeros(shape=(mdp.action_size, mdp.state_size))

# Bind every keyword option of ``train`` up front so the compiled callable
# only needs the positional arguments supplied at call time.
# NOTE(review): ``verbose`` is pinned to False here rather than taken from
# ``args.verbose`` (which defaults to True) -- confirm this is intentional.
_bound_train = partial(
    train,
    update_fn=q_iteration_update,
    n_iterations=args.n_iterations,
    gamma=args.gamma,
    verbose=False,
)
trainer = jax.jit(_bound_train)



<h2 style="text-align: center;">Solve</h2>

- - -

In [None]:
# Run the jitted trainer on the grid-world MDP starting from the zero table.
# The result unpacks into a metrics record and the resulting values.
metrics, values = trainer(mdp, init_values)

<h2 style="text-align: center;">Plot</h2>

- - -

In [None]:
from plotting import line_plot


def _metrics_to_records(metric_log):
    """Turn a named tuple of per-iteration metric series into a list of dicts.

    Each returned dict holds the key ``"iteration"`` (the 0-based position in
    the series) followed by one key per field of ``metric_log`` whose value is
    the ``.item()`` of that field's entry at that position.

    Assumes every field is a sequence whose leading axis indexes iterations
    and whose entries expose ``.item()`` -- TODO confirm against ``train``.
    """
    field_names = metric_log._fields
    records = []
    # zip(*...) walks all metric series in lock-step, one tuple per iteration.
    for step, row in enumerate(zip(*(metric_log._asdict().values()))):
        record = {"iteration": step}
        for field_name, entry in zip(field_names, row):
            record[field_name] = entry.item()
        records.append(record)
    return records


data = _metrics_to_records(metrics)

line_plot(data, "iteration", "Q-Iteration")