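<h2 style="text-align: center;">Q-Learning Gridworld Example</h2>

Trains a tabular Q-learning agent with ε-greedy exploration on a grid-world MDP built with `jaxdp.mdp.grid_world`, then plots the evaluation metrics recorded during training.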

- - -

In [1]:
from typing import NamedTuple
from functools import partial
import jax
import jax.numpy as jnp
import jax.random as jrd

import jaxdp
from jaxdp.mdp import MDP
from jaxdp.mdp.grid_world import grid_world
from jaxdp.mdp.delayed_reward import delayed_reward_mdp
from jaxdp.mdp.sequential import sequential_mdp
from jaxdp.learning.learning import train
from jaxdp.learning.q_learning import q_learning_update
from jaxdp.learning.sampler import SamplerState, rollout_sample


# ASCII map handed to jaxdp's `grid_world`. `#` appears to mark walls; the
# meaning of `P`, `X` and `@` is defined by `grid_world` (presumably start,
# hazard/obstacle and goal) — NOTE(review): confirm against jaxdp.
GRID_LAYOUT = [
    "################################",
    "#                           X@ #",
    "#                            X #",
    "#                              #",
    "#                              #",
    "#                              #",
    "#                              #",
    "#                              #",
    "#                              #",
    "#                              #",
    "#                              #",
    "#                              #",
    "#P                             #",
    "################################",
]

mdp = grid_world(GRID_LAYOUT)


class Args(NamedTuple):
    """Hyper-parameters for the Q-learning run in this notebook.

    Because this is a NamedTuple, positional construction follows the field
    order below; prefer keyword overrides, e.g. ``Args(epsilon=0.1)``.
    """

    # --- sampling (SamplerState.initialize_rollout_state / rollout_sample) ---
    batch_size: int = 100            # number of parallel rollout states
    queue_size: int = 10             # forwarded to initialize_rollout_state
    rollout_len: int = 20            # forwarded to rollout_sample
    # --- training loop (train) ---
    n_iterations: int = 1_000        # passed to `train` as n_steps
    max_episode_length: int = 200    # forwarded to rollout_sample
    epsilon: float = 0.25            # exploration rate for e_greedy_policy
    gamma: float = 0.99              # discount passed to q_learning_update
    alpha: float = 1.0               # step size passed to q_learning_update
    eval_steps: int = 10             # passed to `train`; also spaces the plot's x-axis
    verbose: bool = True             # passed to `train`
    # --- reproducibility ---
    seed: int = 42                   # root PRNG seed


# Default configuration; every random quantity below is derived from args.seed.
args = Args()

# Split the seeded root key into four streams: a carry key plus dedicated keys
# for training, initial Q-values and initial sampler states.
key, train_key, init_value_key, init_state_key = jrd.split(jrd.PRNGKey(args.seed), 4)

# Randomly initialised value table with shape (action_size, state_size).
init_value = jrd.normal(init_value_key, (mdp.action_size, mdp.state_size))

sampler_state = SamplerState.initialize_rollout_state(
    mdp,
    batch_size=args.batch_size,
    queue_size=args.queue_size,
    init_state_key=init_state_key,
)

# Q-learning update with the configured discount (gamma) and step size (alpha).
update_fn = partial(q_learning_update, gamma=args.gamma, alpha=args.alpha)

# Batched sampler: `rollout_sample` is vmapped along axis 0 of its 2nd and 4th
# positional arguments, while its 1st and 3rd arguments are shared (in_axes=None).
sample_fn = jax.vmap(
    partial(
        rollout_sample,
        max_episode_length=args.max_episode_length,
        rollout_len=args.rollout_len,
    ),
    in_axes=(None, 0, None, 0),
)


def policy_fn(v, _):
    """Wrap ``jaxdp.e_greedy_policy`` with the configured epsilon; the second argument is ignored."""
    return jaxdp.e_greedy_policy(v, args.epsilon)


# Bind all configuration up front so the jitted callable only receives the
# runtime inputs passed in the Train cell (sampler state, initial values, MDP, key).
trainer = jax.jit(
    partial(
        train,
        update_fn=update_fn,
        sample_fn=sample_fn,
        policy_fn=policy_fn,
        n_steps=args.n_iterations,
        eval_steps=args.eval_steps,
        verbose=args.verbose,
    )
)

<h2 style="text-align: center;">Train</h2>

- - -

In [2]:
# Run the jitted training loop. The first output holds the evaluation metrics
# plotted below; the second is presumably the final value table — confirm in jaxdp `train`.
metrics, value = trainer(sampler_state, init_value, mdp, train_key)

<h2 style="text-align: center;">Plot</h2>

- - -

In [None]:
from plotting import line_plot


# Plot label -> metric series. Each label sits next to the field it comes from
# so the two cannot drift apart; `max_value_diff` is therefore shown as
# "max value diff" rather than being presented as a value norm.
plotted_metrics = {
    "avg episode reward": metrics.avg_episode_rewards,
    "std episode reward": metrics.std_episode_rewards,
    "avg episode length": metrics.avg_episode_lengths,
    "std episode length": metrics.std_episode_lengths,
    "max value diff": metrics.max_value_diff,
}

# One record per evaluation point, series walked in lockstep (zip truncates to
# the shortest). The x-position assumes one entry is logged every
# `args.eval_steps` training iterations — TODO confirm against jaxdp `train`.
data = [
    {
        "step": index * args.eval_steps,
        **{label: scalar.item() for label, scalar in zip(plotted_metrics, row)},
    }
    for index, row in enumerate(zip(*plotted_metrics.values()))
]


line_plot(
    data,
    "step",
    "Q-Learning"
)