In [3]:
from typing import Optional, Callable, Tuple, Dict, Union, Any, NewType, Sequence
from flax import struct
from typing import Tuple
from functools import partial
import jax.numpy as jnp
import jax.random as jrd
import jax
from nestedtuple import nestedtuple
from gymnax.environments.environment import Environment
from jax.typing import ArrayLike as KeyType
import distrax

import jaxdp
from jaxdp.learning.algorithms import q_learning, StepSample
from jaxdp.learning.runner import train, reducer
from jaxdp.mdp.sampler import SamplerState, RolloutSample, sample_gymnax_rollout
from jaxdp.typehints import F, QType

from mjnax.pendulum import MjModelType, MjStateType, DiscretizedPendulum


# Hyper-parameters of the experiment, grouped by the component that consumes them.
# Each nested class is a separate group; elsewhere in this notebook the groups are
# overridden with e.g. `Arg.policy_fn(epsilon=0.25)` and read as `arg.policy_fn.epsilon`.
@nestedtuple
class Arg:
    seed: int = 42                     # Initial seeds
    n_seed: int = 10                  # Number of seeds to execute the same algorithm
    n_env: int = 4                     # Number of parallel environments for sampling

    # Parameters of the behaviour policy
    class policy_fn:
        epsilon: float = 0.15          # Epsilon-greedy parameter

    # Parameters of the value update
    class update_fn:
        alpha: float = 0.10            # Step size (a.k.a learning rate)

    # Parameters of the training loop
    class train_loop:
        gamma: float = 0.99            # Discount factor
        n_steps: int = 1000            # Number of steps
        eval_period: int = 50          # Evaluation period (in terms of <n_steps>)

    # Parameters used when the sampler state is initialised
    class sampler_init:
        queue_size: int = 50           # Queue size of the sampler for the metrics

    # Parameters used on every sampling call
    class sampler_fn:
        max_episode_len: int = 125  # Maximum length of an episode allowed by the sampler
        rollout_len: int = 10          # Length of a rollout

    # Bounds of the uniform distribution used to initialise the value table;
    # the group is splatted into `jrd.uniform` via `_asdict()` in the setup cell.
    class value_init:
        minval: float = -1.0            # Minimum value of the uniform distribution
        maxval: float = 1.0            # Maximum value of the uniform distribution


@struct.dataclass
class EGreedyPolicyState():
    """State carried by the epsilon-greedy policy."""
    # Q-value table; the setup cell initialises it with shape (num_actions, num_states).
    value: QType
    # Exploration parameter passed to `jaxdp.e_greedy_policy.q`.
    epsilon: float


class EGreedyPolicy():
    """Epsilon-greedy behaviour policy built on top of a Q-value table."""

    def sample(self, key, state: EGreedyPolicyState, obs: F["S"]):
        """Draw an action for the observation `obs`.

        Args:
            key: PRNG key forwarded to `jaxdp.sample_from`.
            state: Policy state holding the Q-values and epsilon.
            obs: Vector over the state axis; it is contracted against the state
                axis of the policy table (presumably a one-hot state encoding —
                TODO confirm).

        Returns:
            A pair `(action, state)`; the policy state is returned unchanged.
        """
        policy_table = jaxdp.e_greedy_policy.q(state.value, state.epsilon)
        # Contract the state axis so only the per-action entries remain.
        action_probs = jnp.einsum("as,s->a", policy_table, obs)
        return jaxdp.sample_from(action_probs, key), state

    @staticmethod
    def reset(key: KeyType, state: EGreedyPolicyState) -> EGreedyPolicyState:
        """Return `state` untouched; `key` is ignored."""
        return state


@struct.dataclass
class Metric():
    """Buffers of training metrics collected between two logging steps."""
    # One slot per training step of a logging window: `reset_metric` fills the
    # buffer with NaN and `train_step` writes slot `step % eval_period`.
    td_error: F["N"]


@struct.dataclass
class RunState():
    """Values carried through the training loop (the `fori_loop` carry)."""
    key: jax.Array            # PRNG key; split on every sampling call
    sampler: SamplerState     # Sampler state, vmapped over the environments
    env_state: MjStateType    # Environment state, vmapped over the environments
    env_model: MjModelType    # Environment parameters, shared by all environments
    pi: EGreedyPolicyState    # State of the behaviour policy (holds the Q-values)
    metric: Metric            # Metric buffers of the current logging window


@struct.dataclass
class RunStatic():
    """Non-array components of a run; passed to `train_step` as a static argument."""
    env: Environment      # Environment handed to the rollout sampler
    pi: EGreedyPolicy     # Policy object handed to the rollout sampler
    logger: Callable      # Called as logger(state, step, metric) on logging steps
    sampler: Callable     # Called as sampler(state, static, arg) -> (rollout, state)
    updater: Callable     # Called as updater(rollout, state, static, arg) -> (state, losses)


def reset_metric(size: int) -> Metric:
    """Build a `Metric` whose every field is a NaN-filled buffer of length `size`."""
    empty_buffers = {
        field_name: jnp.full(size, jnp.nan)
        for field_name in Metric.__dataclass_fields__
    }
    return Metric(**empty_buffers)


# @partial(jax.jit, static_argnums=[1, 2])
def sample_batch_rollout(state: RunState, static: RunStatic, arg: Arg
                         ) -> Tuple[RolloutSample, RunState]:
    """Collect one rollout from each of the `arg.n_env` environments.

    Args:
        state: Current run state; supplies the PRNG key, the per-environment
            sampler/environment states, the policy state and the environment model.
        static: Static components; supplies the environment and the policy object.
        arg: Hyper-parameters; `n_env` and the `sampler_fn` group are read.

    Returns:
        The batched rollout and the run state with its key, sampler state,
        environment state and policy state replaced.
    """
    next_key, rollout_key = jrd.split(state.key)
    per_env_keys = jrd.split(rollout_key, arg.n_env)

    single_env_rollout = partial(
        sample_gymnax_rollout,
        env=static.env,
        policy=static.pi,
        rollout_length=arg.sampler_fn.rollout_len,
        max_episode_length=arg.sampler_fn.max_episode_len,
    )
    # Keys, sampler states and environment states are mapped over their leading
    # axis; the policy state and the environment model are shared (None). The
    # returned policy state is likewise requested unbatched.
    batched_rollout = jax.vmap(
        single_env_rollout,
        in_axes=(0, 0, 0, None, None),
        out_axes=(0, 0, 0, 0, EGreedyPolicyState(value=None, epsilon=None)),
    )
    outputs = batched_rollout(
        per_env_keys,
        state.sampler,
        state.env_state,
        state.pi,
        state.env_model,
    )
    # The first output is not used by this notebook.
    _, rollout, sampler_state, env_state, policy_state = outputs

    next_state = state.replace(
        key=next_key,
        sampler=sampler_state,
        env_state=env_state,
        pi=policy_state,
    )
    return rollout, next_state


# @partial(jax.jit, static_argnums=[2, 3])
def update_ql(rollout: RolloutSample,
              state: RunState,
              static: RunStatic,
              arg: Arg
              ) -> Tuple[RunState, Dict[str, Any]]:
    """Apply one Q-learning update computed from a batch of rollouts.

    Args:
        rollout: Batched rollout; its obs/next_obs/action/reward/terminal/timeout
            fields are repacked into a `StepSample`.
        state: Current run state; only `state.pi` is read and replaced.
        static: Static components (not used by this updater).
        arg: Hyper-parameters; `train_loop.gamma` and `update_fn.alpha` are read.

    Returns:
        The run state with the updated policy state, and a dict holding the
        mean absolute per-step value under the key "td_error".
    """
    steps = StepSample(
        state=rollout.obs,
        next_state=rollout.next_obs,
        action=rollout.action,
        reward=rollout.reward,
        terminal=rollout.terminal,
        timeout=rollout.timeout,
    )
    q_value = state.pi.value

    # Two nested vmaps: the step function is mapped over the two leading axes
    # of the samples while the value table and gamma are shared.
    map_axes = (0, None, None)
    per_rollout_step = jax.vmap(q_learning.asynchronous.step, map_axes)
    per_batch_step = jax.vmap(per_rollout_step, map_axes)
    step_targets = per_batch_step(steps, q_value, arg.train_loop.gamma)

    reduced_target = reducer.every_visit(steps, step_targets)
    new_value = q_learning.update(
        q_value, reduced_target, alpha=arg.update_fn.alpha
    )

    new_pi = state.pi.replace(value=new_value)
    avg_tde = jnp.abs(step_targets).mean()
    return state.replace(pi=new_pi), {"td_error": avg_tde}


@partial(jax.jit, static_argnums=[2, 3])
def train_step(step: int, state: RunState, static: RunStatic, arg: Arg):
    """Run one training iteration: sample, update, record metrics, maybe log.

    Args:
        step: Index of the current training step.
        state: Run state carried between steps.
        static: Static components (sampler, updater, logger); jit-static.
        arg: Hyper-parameters; jit-static.

    Returns:
        The run state after the update. On the last step of every
        `eval_period` window the logger is invoked and the metric buffers and
        sampler queues are reset; otherwise the freshly recorded metrics are kept.
    """
    period = arg.train_loop.eval_period
    slot = step % period

    rollout, state = static.sampler(state, static, arg)
    state, losses = static.updater(rollout, state, static, arg)

    # Store every reported loss in the slot that belongs to this step.
    recorded = {}
    for loss_name, loss_val in losses.items():
        recorded[loss_name] = getattr(state.metric, loss_name).at[slot].set(loss_val)
    metric = state.metric.replace(**recorded)

    is_log_step = slot == (period - 1)

    def _host_log(should_log, *payload):
        # Host-side guard: only call the logger when the flag is set.
        return static.logger(*payload) if should_log else None

    def _log_and_reset(should_log, run_state, *extra):
        jax.debug.callback(_host_log, should_log, run_state, *extra)
        return {"metric": reset_metric(period),
                "sampler": run_state.sampler.refresh_queues()}

    def _keep(_should_log, run_state, _step, current_metric):
        return {"metric": current_metric, "sampler": run_state.sampler}

    replacements = jax.lax.cond(
        is_log_step,
        _log_and_reset,
        _keep,
        is_log_step, state, step, metric
    )
    return state.replace(**replacements)

In [None]:
@partial(jax.jit, static_argnums=[])
def eval_policy():
    """Placeholder for policy evaluation; currently does nothing.

    TODO: Implement one step eval episode
    """
    return None


def logger(state: RunState, step: int, metric: Metric) -> None:
    """Print a table of training statistics for the window ending at `step`.

    Args:
        state: Run state; the sampler's episode reward/length queues are read.
        step: Zero-based step index; printed as `step + 1`.
        metric: Metric buffers; the NaN-aware mean of `td_error` is reported.
    """
    reward_queue = state.sampler.episode_reward_queue
    length_queue = state.sampler.episode_length_queue
    # NaN-aware reductions: the metric buffers are NaN-initialised by `reset_metric`.
    rows = (
        ("mean_behavior_reward", jnp.nanmean(reward_queue)),
        ("mean_behavior_length", jnp.nanmean(length_queue)),
        ("std_behavior_reward", jnp.nanstd(reward_queue)),
        ("std_behavior_length", jnp.nanstd(length_queue)),
        ("mean_td_error", jnp.nanmean(metric.td_error)),
    )

    title = "Training Metrics - Step"
    print("=" * 43)
    print(f"{title:^40} {step + 1}")
    print("-" * 43)
    for key, stat in rows:
        label = key.replace("_", " ").title()
        print(f"{label:<25} | {stat:>15.4f}")


# Experiment configuration: overrides of the defaults declared in `Arg`.
arg = Arg(
    n_seed=1,
    n_env=16,
    policy_fn=Arg.policy_fn(epsilon=0.25),
    sampler_fn=Arg.sampler_fn(rollout_len=32),
    train_loop=Arg.train_loop(eval_period=50, n_steps=1000, gamma=0.99),
    # The keyword previously had no value (`update_fn=`), a SyntaxError that
    # prevented the whole cell from running. The step size below is the default
    # declared in `Arg.update_fn`; adjust it if another value was intended.
    update_fn=Arg.update_fn(alpha=0.10),
)

# --- Environment -----------------------------------------------------------
env = DiscretizedPendulum()
env_model = env.default_params

# Seed from the configuration instead of a hard-coded literal, so changing
# `arg.seed` actually changes the run (the default is still 42).
key, env_reset_key, pi_reset_key = jrd.split(jrd.PRNGKey(arg.seed), 3)

# One independent reset per parallel environment.
obs, env_state = jax.vmap(env.reset)(jrd.split(env_reset_key, arg.n_env))

# --- Sampler and policy ----------------------------------------------------
# One sampler state per environment; the queue size is shared.
sampler_state = jax.vmap(SamplerState.initialize_rollout_state, in_axes=(0, None)
                         )(obs, arg.sampler_init.queue_size)

# Q-table of shape (num_actions, num_states), drawn uniformly from the
# [minval, maxval) range configured in `arg.value_init`.
pi_state = EGreedyPolicyState(
    value=jrd.uniform(pi_reset_key, (env.num_actions, env.num_states,),
                      dtype="float32", **arg.value_init._asdict()),
    epsilon=arg.policy_fn.epsilon)
policy = EGreedyPolicy()

# --- Run containers --------------------------------------------------------
run_state = RunState(
    key=key,
    sampler=sampler_state,
    env_state=env_state,
    env_model=env_model,
    pi=pi_state,
    metric=reset_metric(arg.train_loop.eval_period),
)

run_static = RunStatic(
    env=env,
    pi=policy,
    logger=logger,
    sampler=sample_batch_rollout,
    updater=update_ql,
)

# --- Training loop ---------------------------------------------------------
# `train_step(step, state, ...)` is the loop body; the static components and
# the hyper-parameters are bound up front.
final_state = jax.lax.fori_loop(
    0,
    arg.train_loop.n_steps,
    partial(train_step, static=run_static, arg=arg),
    run_state)

  return lax_numpy.astype(self, dtype, copy=copy, device=device)


        Training Metrics - Step          50
-------------------------------------------
Mean Behavior Reward      |          0.0000
Mean Behavior Length      |        125.0000
Std Behavior Reward       |          0.0000
Std Behavior Length       |          0.0000
Mean Td Error             |          0.0884
        Training Metrics - Step          100
-------------------------------------------
Mean Behavior Reward      |          0.0000
Mean Behavior Length      |        125.0000
Std Behavior Reward       |          0.0000
Std Behavior Length       |          0.0000
Mean Td Error             |          0.0135
        Training Metrics - Step          150
-------------------------------------------
Mean Behavior Reward      |          0.0087
Mean Behavior Length      |        125.0000
Std Behavior Reward       |          0.1248
Std Behavior Length       |          0.0000
Mean Td Error             |          0.0238
        Training Metrics - Step          200
-----------------------------

In [5]:
import jax.numpy as jnp
import jax


def jax_static_method(method):
    """Wrap `method` so the first positional argument of a call is discarded.

    The returned function accepts `(self, *args, **kwargs)` and forwards
    everything except `self` to `method`.
    """

    def static_method(self, *positional, **named):
        # `self` is accepted only to be dropped.
        return method(*positional, **named)

    return static_method

class X():
    """Scratch class used to try out a method that takes no instance."""

    # Declared as a static method: `fn` has no `self` parameter, so without the
    # decorator it only worked when accessed on the class (`X.fn(a, b)`) and
    # raised TypeError when called on an instance (`X().fn(a, b)`).
    @staticmethod
    def fn(x, y):
        """Return the product `x * y`."""
        return x * y
    
# Call `fn` through the class itself (no instance involved); the recorded
# output of this cell is Array([4., 4., 4.], dtype=float32).
X.fn(jnp.ones(3), 4)

Array([4., 4., 4.], dtype=float32)

In [4]:
from typing import NamedTuple


# NOTE(review): another cell of this notebook also defines a class named `X`;
# whichever cell runs last decides what `X` refers to.
class X(NamedTuple):
    """Two-field named tuple with default values, used to inspect `_asdict()`."""
    x: int = 4
    y: float = 6.


# Build an instance from the defaults and convert it to a dict;
# the recorded output is {'x': 4, 'y': 6.0}.
X()._asdict()

{'x': 4, 'y': 6.0}