
## Learning Algorithms in Gridworld

- ### Q-learning with synchronous sampling

In [1]:
from functools import partial
import jax.numpy as jnp
import jax.random as jrd
import jaxdp

from jaxdp.learning.algorithms import q_learning
from jaxdp.learning.runner import train, no_learner_state, no_step_index


# Experiment configuration. Each top-level section is unpacked into the call
# that consumes it further down in this cell.
args = {
    # Seed for the root PRNG key
    "seed": 42,
    # Keyword arguments bound onto the update function via `partial`
    "update_fn": {"alpha": 0.1},
    # Keyword arguments forwarded to `train.sync`
    "train_loop": {"gamma": 0.99, "n_steps": 100, "eval_period": 10},
    # Bounds of the uniform distribution used for the initial values
    "value_init": {"minval": 0.0, "maxval": 1.0},
    # Keyword arguments forwarded to `jaxdp.mdp.grid_world`
    "mdp_init": {
        "p_slip": 0.1,
        "board": [
            "#####",
            "#  @#",
            "#  X#",
            "#P  #",
            "#####",
        ],
    },
}


# Build the gridworld MDP from the configured board and slip probability
mdp = jaxdp.mdp.grid_world(**args["mdp_init"])

# Split the seeded root key in two: one key for the training loop and one
# for sampling the initial values
train_key, value_key = jrd.split(jrd.PRNGKey(args["seed"]), num=2)

# Draw the initial Q-values uniformly at random as an array of shape
# (action_size, state_size)
init_value = jrd.uniform(
    value_key,
    shape=(mdp.action_size, mdp.state_size),
    dtype="float32",
    **args["value_init"],
)

# Bind the configured hyperparameters onto the synchronous Q-learning update.
# NOTE(review): judging by their names, the two wrappers adapt the update to a
# runner signature that also carries a step index and a learner state --
# confirm against jaxdp.learning.runner.
update_fn = partial(
    no_learner_state(
        no_step_index(q_learning.update.sync)
    ),
    **args["update_fn"],
)

# No learner state is tracked for this update rule
learner_state = None

# Policy used during training: act greedily with respect to the current values
def greedy_policy_fn(q, i):
    """Return the greedy policy for the values ``q``; the second argument is ignored."""
    return jaxdp.greedy_policy.q(q)


# Run the synchronous training loop.
# NOTE(review): `value_star` is passed as all zeros, presumably a placeholder
# for the optimal values -- confirm how `train.sync` uses it before trusting
# any metric derived from it.
metrics, value, learner_state = train.sync(
    init_value,
    mdp,
    key=train_key,
    learner_state=learner_state,
    update_fn=update_fn,
    policy_fn=greedy_policy_fn,
    value_star=jnp.zeros_like(init_value),
    **args["train_loop"],
)