
## Learning Algorithms in Gridworld

- ### Q-learning with synchronous sampling

    > Synchronous sampling draws a state transition for every (s, a) pair.

In [None]:
from functools import partial
import jax.numpy as jnp
import jax.random as jrd
import jax
import jaxdp

from jaxdp.learning.algorithms import q_learning
from jaxdp.learning.runner import train, no_learner_state, no_step_index


# Define the arguments
# Every nested dictionary is unpacked into the call named by its key.
args = {
    "seed": 42,                # Seed of the root PRNG key
    "n_seeds": 10,             # Number of PRNG keys (one per training run)
    "update_fn": {
        "alpha": 0.1,          # Step size (a.k.a learning rate)
    },
    "train_loop": {
        "gamma": 0.99,         # Discount factor
        "n_steps": 100,        # Number of steps
        "eval_period": 10,     # Evaluation period (in terms of <n_steps>)
    },
    "value_init": {
        "minval": 0.0,         # Minimum value of the uniform distribution
        "maxval": 1.0,         # Maximum value of the uniform distribution
    },
    "mdp_init": {
        "p_slip": 0.15,        # Probability of slipping
        "board": [             # The board of the gridworld
            "#####",
            "#  @#",
            "#  X#",
            "#P  #",
            "#####",
        ],
    },
}


# Build the gridworld MDP and draw the initial Q values.
# The root key is split in two: one part feeds the training runs, the other
# the value initialization.
_train_key, value_key = jrd.split(jrd.PRNGKey(args["seed"]), num=2)
# One PRNG key per training run
train_keys = jrd.split(_train_key, num=args["n_seeds"])
mdp = jaxdp.mdp.grid_world(**args["mdp_init"])
# Uniformly distributed initial values, laid out as (action, state)
init_value = jrd.uniform(
    value_key,
    shape=(mdp.action_size, mdp.state_size),
    dtype="float32",
    **args["value_init"],
)

# Define learner function
# The synchronous Q-learning update is wrapped by <no_step_index> and
# <no_learner_state> (presumably adapters for the signature the train loop
# expects -- confirm in jaxdp.learning.runner); the step size is then bound
# from the arguments.
update_fn = partial(
    no_learner_state(
        no_step_index(q_learning.update.sync)
    ),
    **args["update_fn"],
)

# Train one policy per seed with a single call.
# 1) Specialize <train.sync> with everything that is shared by all runs.
single_run_train = partial(
    train.sync,
    learner_state=None,
    # No reference values are supplied: <value_star> is filled with NaN
    value_star=jnp.full_like(init_value, jnp.nan),
    target_policy_fn=lambda q, i: jaxdp.greedy_policy.q(q),
    update_fn=update_fn,
    **args["train_loop"],
)
# 2) Vectorize over the PRNG keys only (the initial values and the MDP are
#    broadcast to every run), then JIT compile the "batch" train function.
jitted_batch_train = jax.jit(
    jax.vmap(single_run_train, in_axes=(None, None, 0))
)
# 3) Run all <n_seeds> trainings at once
metrics, value, learner_state = jitted_batch_train(init_value, mdp, train_keys)

### Visualize

In [None]:
from itertools import chain
import pandas as pd
from plot_util import make_figure


# Make dataframe from the metrics
percentile = 25

# Row labels of the STEP level: 0, eval_period, 2 * eval_period, ... (< n_steps).
# NOTE(review): assumes the train loop stores one metric entry per evaluation
# (as in the asynchronous example) rather than one per step -- confirm.
# Enumerating every step here would not match the number of rows in the data.
eval_steps = list(range(0, args["train_loop"]["n_steps"], args["train_loop"]["eval_period"]))

index = pd.MultiIndex.from_product(
    [["gridworld"], ["q-learning"], eval_steps],
    names=["ENV", "ALG", "STEP"])
# One (low, med, high) column triple per metric field
columns = pd.MultiIndex.from_product(
    [metrics._fields, ["low", "med", "high"]],
    names=["METRIC", "PERCENTILE"])

data = []
for name in metrics._fields:
    values = getattr(metrics, name)
    # Promote 1-D metrics to a single row so that axis 0 exists for the reduction
    if values.ndim == 1:
        values = values.reshape(1, -1)
    # NaN-aware percentiles along axis 0: one row per requested level
    percentiles = jnp.nanpercentile(
        values, q=jnp.array([percentile, 50, 100 - percentile]), axis=0)
    data.append(percentiles)


# Flatten the per-metric percentile rows and lay them out as columns
df = pd.DataFrame(data=jnp.stack(list(chain(*data)), axis=1), columns=columns, index=index)

# Generate the figure
make_figure(df.loc["gridworld"])

- ### Q-learning with asynchronous sampling

    > Asynchronous sampling collects state transitions by following a behavior policy.

In [3]:
from functools import partial, reduce
import jax.numpy as jnp
import jax.random as jrd
import jax
import jaxdp

from jaxdp.learning.algorithms import q_learning
from jaxdp.learning.runner import train, no_learner_state, no_step_index, reducer
from jaxdp.learning.sampler import SamplerState, rollout_sample

# Define the arguments
# Every nested dictionary groups the settings of the component named by its key.
args = {
    "seed": 42,                     # Seed of the root PRNG key
    "n_seeds": 10,                  # Number of seeds to run the same algorithm with
    "n_env": 4,                     # Number of parallel environments for sampling
    "policy_fn": {
        "epsilon": 0.15,            # Epsilon-greedy parameter
    },
    "update_fn": {
        "alpha": 0.10,              # Step size (a.k.a learning rate)
    },
    "train_loop": {
        "gamma": 0.99,              # Discount factor
        "n_steps": 1000,            # Number of steps
        "eval_period": 50,          # Evaluation period (in terms of <n_steps>)
    },
    "sampler_init": {
        "queue_size": 50,           # Queue size of the sampler for the metrics
    },
    "sampler_fn": {
        "max_episode_length": 15,   # Maximum episode length allowed by the sampler
        "rollout_len": 10,          # Length of a rollout
    },
    "value_init": {
        "minval": 0.0,              # Minimum value of the uniform distribution
        "maxval": 1.0,              # Maximum value of the uniform distribution
    },
    "mdp_init": {
        "p_slip": 0.15,             # Probability of slipping
        "board": [                  # The board of the gridworld
            "#####",
            "#  @#",
            "#  X#",
            "#P  #",
            "#####",
        ],
    },
}


# Build the gridworld MDP and draw the initial Q values.
# The root key is split in three: training runs, sampler and value initialization.
train_key, sampler_key, value_key = jrd.split(jrd.PRNGKey(args["seed"]), num=3)
# One PRNG key per training run
train_keys = jrd.split(train_key, num=args["n_seeds"])
mdp = jaxdp.mdp.grid_world(**args["mdp_init"])
# Uniformly distributed initial values, laid out as (action, state)
init_value = jrd.uniform(
    value_key,
    shape=(mdp.action_size, mdp.state_size),
    dtype="float32",
    **args["value_init"],
)

# Define learner function
# For multiple sampling environments: <q_learning.step> is mapped over the two
# leading axes of the rollout arrays, while the value table and the discount
# factor are shared by every mapped call.
batch_step_fn = jax.vmap(
    jax.vmap(q_learning.step, in_axes=(0, None, None)),
    in_axes=(0, None, None),
)


def batch_update_fn(index, rollouts, value, learner_state, gamma):
    """Apply one asynchronous Q-learning update from a batch of rollouts.

    Args:
        index: Unused; only part of the expected signature.
        rollouts: Batch of rollouts; its two leading axes are mapped over
            (presumably environment and rollout step -- confirm against the sampler).
        value: Current value table.
        learner_state: Unused; this learner keeps no state.
        gamma: Discount factor forwarded to <q_learning.step>.

    Returns:
        A pair of the updated value table and ``None`` (the learner state).
    """
    # Per-transition targets, then reduced with <reducer.every_visit>
    step_targets = batch_step_fn(rollouts, value, gamma)
    reduced_target = reducer.every_visit(rollouts, step_targets)
    # The step size is read from the notebook-level <args>
    updated_value = q_learning.update.asynchronous(
        value, reduced_target, alpha=args["update_fn"]["alpha"]
    )
    return updated_value, None

# Initiate sampler
# For single sampling environment
sampler_state = SamplerState.initialize_rollout_state(
    mdp,
    init_state_key=sampler_key,
    **args["sampler_init"],
)
# For multiple sampling environments: the initializer is mapped over one PRNG
# key per environment, while the MDP and the queue size are shared.
n_env_sampler_state = jax.vmap(
    SamplerState.initialize_rollout_state,
    in_axes=(None, None, 0),
)(
    mdp,
    args["sampler_init"]["queue_size"],
    jrd.split(sampler_key, num=args["n_env"]),
)
# Rollout sampler mapped over its first and third arguments; the second and
# fourth are shared by every environment.
n_env_sampler = jax.vmap(
    partial(rollout_sample, **args["sampler_fn"]),
    in_axes=(0, None, 0, None),
)


# Train one policy per seed with a single call.
# In each step, a rollout is collected from <n_env> many environments.
# 1) Specialize <train.asynchronous> with everything shared by all runs.
single_run_train = partial(
    train.asynchronous,
    learner_state=None,
    # No reference values are supplied: <value_star> is filled with NaN
    value_star=jnp.full_like(init_value, jnp.nan),
    # Epsilon-greedy behavior policy for sampling, greedy policy as the target
    behavior_policy_fn=lambda q, i: jaxdp.e_greedy_policy.q(q, **args["policy_fn"]),
    target_policy_fn=lambda q, i: jaxdp.greedy_policy.q(q),
    update_fn=batch_update_fn,
    # The incoming key is split so that every environment samples with its own key
    sample_fn=lambda key, *_args: n_env_sampler(jrd.split(key, args["n_env"]), *_args),
    verbose=False,
    **args["train_loop"],
)
# 2) Vectorize over the PRNG keys only (sampler state, initial values and MDP
#    are broadcast to all <n_seeds> runs), then JIT compile.
jitted_batch_train = jax.jit(
    jax.vmap(single_run_train, in_axes=(None, None, None, 0))
)
# 3) Run the jitted and vectorized train function
metrics, value, learner_state, sampler_state = jitted_batch_train(
    n_env_sampler_state, init_value, mdp, train_keys)

In [4]:
# Inspect the learned value at index (0, 0) for every training run.
# The leading axis enumerates the runs; the remaining axes presumably follow
# the (action, state) layout of <init_value> -- confirm.
value[:, 0, 0]

Array([0.9662067 , 0.96218336, 0.9653484 , 0.9641941 , 0.96476895,
       0.9622605 , 0.9632278 , 0.9635721 , 0.9668641 , 0.96149933],      dtype=float32)

### Visualize

In [5]:
from itertools import chain
import pandas as pd
from plot_util import make_figure


# Make dataframe from the metrics
percentile = 25

# Row labels of the STEP level: 0, eval_period, 2 * eval_period, ... (< n_steps)
eval_steps = range(0, args["train_loop"]["n_steps"], args["train_loop"]["eval_period"])
index = pd.MultiIndex.from_product(
    [["gridworld"], ["q-learning"], list(eval_steps)],
    names=["ENV", "ALG", "STEP"],
)
# One (low, med, high) column triple per metric field
columns = pd.MultiIndex.from_product(
    [metrics._fields, ["low", "med", "high"]],
    names=["METRIC", "PERCENTILE"],
)

# Percentile levels of the low / median / high bands
band_levels = jnp.array([percentile, 50, 100 - percentile])

data = []
for name in metrics._fields:
    values = getattr(metrics, name)
    # Promote 1-D metrics to a single row so that axis 0 exists for the reduction
    values = values.reshape(1, -1) if values.ndim == 1 else values
    # NaN-aware percentiles along axis 0: one row per band level
    data.append(jnp.nanpercentile(values, q=band_levels, axis=0))

# Flatten to one row per (metric, band) and lay the rows out as columns
band_rows = [row for bands in data for row in bands]
df = pd.DataFrame(data=jnp.stack(band_rows, axis=1), columns=columns, index=index)

# Generate the figure
make_figure(df.loc["gridworld"])