- - -

<h2 style="text-align: center;">Planning Algorithms in Garnet MDPs</h2>

- - -

In [5]:
from ast import Call
from typing import Any, Callable, NamedTuple, Optional, Tuple
from functools import partial
import jax
import jax.numpy as jnp
import jax.random as jrd
from jaxdp.mdp import grid_world
from jaxtyping import ArrayLike as KeyType

from jaxdp.mdp.mdp import MDP
from jaxdp.planning.runner import PlanningMetrics, train, no_update_state
from jaxdp.planning.algorithms import (anderson_vi,
                                       nesterov_vi,
                                       value_iteration,
                                       policy_iteration)
from jaxdp.mdp.garnet import garnet_mdp
from jaxdp.typehints import QType


# Grid-world layout handed to ``grid_world`` by ``gridworld_fn``; one string per row.
# NOTE(review): presumably '#' is a wall, 'P' the start cell, 'X' a hazard cell and
# '@' the goal -- confirm against the jaxdp ``grid_world`` implementation.
board = [
    "#############",
    "#           #",
    "#           #",
    "#           #",
    "#PXXXXXXXXX@#",
    "#############"
]


class Args(NamedTuple):
    """Hyper-parameters attached to each entry of ``algorithms``."""
    # Number of planning iterations; forwarded to ``train`` as ``n_iterations``.
    n_iterations: int = 1000
    # Not read by any code visible in this notebook.
    max_episode_length: int = 200
    # Discount factor; forwarded to ``train`` and to ``Algorithm.init_fn``.
    gamma: float = 0.99
    # Not read by the visible code: ``trainer_for_one_mdp`` always passes
    # ``verbose=False`` to ``train``.
    verbose: bool = True
    # Seed of the PRNG key that ``batch_train`` splits into per-MDP keys.
    seed: int = 42
    # Number of MDPs that ``batch_train`` builds and trains together.
    n_env: int = 25


class Algorithm(NamedTuple):
    """A planning algorithm bundled with the arguments it is run with."""
    # Called as ``init_fn(mdp, init_value, gamma, key)``; its result is passed to
    # ``train`` as the initial ``update_state``.
    init_fn: Callable
    # Forwarded unchanged to ``train`` as ``update_fn``.
    update_fn: Callable
    # Run settings (iteration count, discount factor, seed, batch size).
    args: Args


def garnet_fn(key):
    """Build one Garnet MDP from ``key`` using this notebook's fixed settings.

    The settings are 50 states, 2 actions, a branch size of 10 and
    ``validate=False``; ``key`` is forwarded to ``garnet_mdp`` unchanged.
    """
    garnet_settings = dict(state_size=50,
                           action_size=2,
                           branch_size=10,
                           validate=False)
    return garnet_mdp(key=key, **garnet_settings)


def gridworld_fn(key):
    """Build the grid-world MDP described by the module-level ``board``.

    ``key`` is accepted so the function can be called the same way as
    ``garnet_fn``, but it is not used: every call hands the same board and
    slip probability to ``grid_world``.
    """
    slip_probability = 0.25
    return grid_world(board, p_slip=slip_probability)


def trainer_for_one_mdp(key: KeyType,
                        mdp: MDP,
                        value_star: QType,
                        args: Args,
                        algorithm: Algorithm,
                        ) -> Tuple[PlanningMetrics, QType]:
    """Run ``algorithm`` on a single MDP, starting from an all-zero value table.

    Args:
        key: PRNG key. It is split in two and only the second half is used,
            as the key given to ``algorithm.init_fn``.
        mdp: The MDP to plan in; its ``action_size`` and ``state_size`` fix the
            shape of the initial value table.
        value_star: Reference value table forwarded to ``train``.
        args: Supplies ``gamma`` and ``n_iterations``.
        algorithm: Supplies ``init_fn`` (builds the initial update state) and
            ``update_fn`` (forwarded to ``train``).

    Returns:
        Whatever ``train`` returns.
        NOTE(review): the annotation promises a 2-tuple, but ``batch_train``
        unpacks three values from the vmapped call -- confirm against ``train``.
    """
    # The first half of the split is intentionally discarded.
    _, init_state_key = jrd.split(key, 2)

    value_shape = (mdp.action_size, mdp.state_size)
    init_value = jnp.zeros(value_shape)
    initial_update_state = algorithm.init_fn(
        mdp, init_value, args.gamma, init_state_key)

    # ``verbose`` is fixed to False here; ``args.verbose`` is not consulted.
    return train(mdp,
                 init_value,
                 update_state=initial_update_state,
                 value_star=value_star,
                 n_iterations=args.n_iterations,
                 gamma=args.gamma,
                 update_fn=algorithm.update_fn,
                 verbose=False)


def batch_train(algo: Algorithm, value_star: Optional[QType], mdp_fn: Callable):
    """Train ``algo`` on ``algo.args.n_env`` MDPs in one jitted, vmapped call.

    One PRNG key per MDP is derived from ``algo.args.seed``; ``mdp_fn`` is
    vmapped over those keys to build the batch of MDPs, and
    ``trainer_for_one_mdp`` is vmapped over the keys and MDPs.

    Args:
        algo: Algorithm to run; ``algo.args`` supplies the seed, the batch size
            and the settings forwarded to ``trainer_for_one_mdp``.
        value_star: Reference value table(s).
            * ``None``: an all-zero ``(action_size, state_size)`` table is used
              for every MDP.
            * array with at most 2 dimensions: the same table is shared by
              every MDP.
            * array with a leading batch axis, such as the ``value`` this
              function returns: row ``i`` is used for MDP ``i``.
        mdp_fn: Maps a PRNG key to an MDP.

    Returns:
        ``(metrics, value)``: the first two outputs of the vmapped trainer,
        each with a leading axis of length ``n_env``. The third output is
        discarded.
    """
    args = algo.args
    keys = jrd.split(jrd.PRNGKey(args.seed), args.n_env)
    mdps = jax.vmap(mdp_fn)(keys)
    if value_star is None:
        # Build one extra MDP only to read the value-table shape.
        mdp = mdp_fn(keys[0])
        value_star = jnp.zeros((mdp.action_size, mdp.state_size))
    value_star = jnp.asarray(value_star)

    # Binding a batched ``value_star`` as a constant would hand the whole
    # (n_env, action, state) array to every per-MDP trainer. Map it over its
    # leading axis instead; a plain 2-D table is still shared by all MDPs.
    value_star_axis = 0 if value_star.ndim > 2 else None

    def train_one(key, mdp, mdp_value_star):
        return trainer_for_one_mdp(key, mdp, mdp_value_star, args, algo)

    metrics, value, _ = jax.jit(
        jax.vmap(train_one, in_axes=(0, 0, value_star_axis))
    )(keys, mdps, value_star)
    return metrics, value


# Algorithms to compare, keyed by the label used for results and plotting.
# The insertion order is the order in which they are trained.
algorithms = {
    # PI and VI keep no update state: ``init_fn`` returns None and the update
    # function is wrapped with ``no_update_state``.
    "PI": Algorithm(
        init_fn=lambda *_unused: None,
        update_fn=no_update_state(policy_iteration.update.q),
        args=Args(n_iterations=10),
    ),
    "VI": Algorithm(
        init_fn=lambda *_unused: None,
        update_fn=no_update_state(value_iteration.update.q),
        args=Args(n_iterations=1250),
    ),
    # The accelerated variants use the initial value table as their initial
    # update state.
    "Nesterov-VI": Algorithm(
        init_fn=lambda mdp, init_val, *_unused: init_val,
        update_fn=nesterov_vi.update.q,
        args=Args(n_iterations=250),
    ),
    "Anderson-VI": Algorithm(
        init_fn=lambda mdp, init_val, *_unused: init_val,
        update_fn=anderson_vi.update.q,
        args=Args(n_iterations=250),
    ),
}


# MDP family used for the whole experiment.
mdp_fn = garnet_fn
results = {}

# First pass: run PI without a reference and keep its final values as the
# reference table for every algorithm.
_, value_star = batch_train(algorithms["PI"], None, mdp_fn)

# Second pass: run every algorithm (PI included) against that reference and
# keep only the metrics.
for algo_name in algorithms:
    algo = algorithms[algo_name]
    metrics, _ = batch_train(algo, value_star, mdp_fn)
    results[algo_name] = metrics

- - -

<h2 style="text-align: center;">
Make Figure
</h2>

- - -

In [6]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from plot_util import make_data_table, add_figure_data


# Plot-ready records built from the training metrics; each record is later
# read through its "metric" and "algo" keys.
data = make_data_table(results, percentile=25)

# Metrics to draw: the subplot column each one goes to and the keyword
# arguments applied to that subplot's y-axis.
error_info = dict(
    bellman_error=dict(
        column=1,
        y_axes_layout=dict(
            type="log",
            exponentformat="power",
            dtick=1,
        ),
    ),
    policy_evaluation=dict(
        column=2,
        y_axes_layout=dict(
            type="linear",
        ),
    ),
)

# One row of subplots, one column per plotted metric, independent y-axes.
fig = make_subplots(
    rows=1,
    cols=len(error_info),
    shared_yaxes=False,
    horizontal_spacing=0.15,
    vertical_spacing=0.12,
)

# Line colour per algorithm label.
algorithm_colors = {
    "VI": "#2c7bb6",
    "PI": "#ffa600",
    "Nesterov-VI": "#58508d",
    "Anderson-VI": "#bc5090",
    "R1-VI": "#ff6361",
    "AA-I-S-m": "#9CA986",
}

# Tracks which algorithms already have a legend entry.
algorithm_legend_use = dict.fromkeys(algorithm_colors, False)

# Draw one trace per record whose metric is configured in ``error_info``.
for item in data:
    metric_name = item["metric"]
    if metric_name not in error_info:
        continue

    algo_label = item["algo"]
    subplot_column = error_info[metric_name]["column"]

    add_figure_data(item,
                    algorithm_colors[algo_label],
                    fig,
                    row=1,
                    col=subplot_column)

    # Style the most recently added trace; only the first trace of each
    # algorithm gets a legend entry.
    newest_trace = fig.data[-1]
    newest_trace["showlegend"] = not algorithm_legend_use[algo_label]
    algorithm_legend_use[algo_label] = True
    newest_trace["line"]["width"] = 4

    # Title the subplot's y-axis after the metric and apply its axis settings.
    fig.update_yaxes(title=metric_name.replace("_", " ").capitalize(),
                     row=1,
                     col=subplot_column,
                     **error_info[metric_name]["y_axes_layout"])

# Frame and grid styling shared by the x- and y-axes of every subplot.
axis_frame_style = dict(
    showline=True,
    linecolor="gray",
    linewidth=2,
    mirror=True,
    gridcolor="white",
    gridwidth=3,
)

fig.update_layout(template="plotly")
fig.update_yaxes(**axis_frame_style)
fig.update_xaxes(type="log",
                 exponentformat="power",
                 title="Iteration",
                 **axis_frame_style)
fig.update_layout(font=dict(size=15))
# Kept as the cell's last expression: ``update_layout`` returns the figure,
# which is what makes the notebook display it.
fig.update_layout(
    legend=dict(
        x=0.2,
        y=1.15,
        font=dict(size=20),
        orientation="h",
        visible=True,
    ),
    font=dict(size=16, color="black"),
    plot_bgcolor="#EDEDF3",
    width=550 * len(error_info),
    height=500,
)