In [None]:
import jax
import jax.numpy as jnp
import diffrax

from dmpe.utils.signals import aprbs
import exciting_environments as excenvs
from dmpe.models import NeuralEulerODE
from dmpe.algorithms import excite_with_dmpe
from dmpe.utils.density_estimation import select_bandwidth

In [None]:
# Set up the cart-pole environment.
# `env_params` is a plain dict because it is later stored in `exp_params`
# together with the algorithm, trainer and model parameters.
env_params = {
    "batch_size": 1,
    "tau": 2e-2,  # forwarded to the env as `tau` (presumably the step time in s)
    "max_force": 10,  # becomes the "force" action constraint below
    "static_params": {
        "mu_p": 0.002,
        "mu_c": 0.5,
        "l": 0.5,
        "m_p": 0.1,
        "m_c": 1,
        "g": 9.81,
    },
    # NOTE(review): presumably these bounds are used to normalize the observations,
    # so that theta / pi lies in [-1, 1] -- confirm against exciting_environments.
    "physical_constraints": {
        "deflection": 2.4,
        "velocity": 8,
        "theta": jnp.pi,
        "omega": 8,
    },
    "env_solver": diffrax.Tsit5(),
}
env = excenvs.make(
    env_id="CartPole-v0",
    batch_size=env_params["batch_size"],
    action_constraints={"force": env_params["max_force"]},
    physical_constraints=env_params["physical_constraints"],
    static_params=env_params["static_params"],
    solver=env_params["env_solver"],
    tau=env_params["tau"],
)

### Apply DMPE to the cart-pole system

In [None]:
# setup a featurization method for the angle information (map theta to sin, cos)

def featurize_theta_cart_pole(obs):
    """The angle itself is difficult to properly interpret in the loss as angles
    such as 1.99 * pi and 0 are essentially the same. Therefore the angle is
    transformed to sin(phi) and cos(phi) for comparison in the loss."""
    feat_obs = jnp.stack(
        [obs[..., 0], obs[..., 1], jnp.sin(obs[..., 2] * jnp.pi), jnp.cos(obs[..., 2] * jnp.pi), obs[..., 3]],
        axis=-1,
    )
    return feat_obs

In [None]:
# build a model for the cart pole system that wraps the angle to the other side when it exceeds the max value

class NeuralEulerODECartpole(NeuralEulerODE):
    """Cartpole specific model that deals with the periodic properties of the angle information."""

    def step(self, obs, action, tau):
        next_obs = super().step(obs, action, tau)
        next_obs = jnp.stack(
            [next_obs[..., 0], next_obs[..., 1], (((next_obs[..., 2] + 1) % 2) - 1), next_obs[..., 3]], axis=-1
        )
        return next_obs

In [None]:
# setup algorithm parameters

# overall algorithm parameters, mostly concerning the excitation optimization problem
points_per_dim = 20
alg_params = dict(
    # NOTE(review): the positional arguments of select_bandwidth are not visible here;
    # presumably (value range of the normalized data, dimensionality of the estimated
    # density -- e.g. 4 state + 1 action dims --, grid points per dimension, target
    # fraction). Confirm against dmpe.utils.density_estimation before changing them.
    bandwidth=select_bandwidth(2, 5, points_per_dim, 0.1),
    n_prediction_steps=50,
    points_per_dim=points_per_dim,
    action_lr=1e-1,
    n_opt_steps=5,
    rho_obs=1,
    rho_act=1,
    penalty_order=2,
    clip_action=True,
    n_starts=5,
    reuse_proposed_actions=True,
)

# parameters for the training of the model
model_trainer_params = dict(
    # presumably: only start training once one full prediction horizon of data exists
    start_learning=alg_params["n_prediction_steps"],
    training_batch_size=128,
    n_train_steps=10,
    sequence_length=alg_params["n_prediction_steps"],
    # compare angles via sin/cos in the training loss instead of the raw wrapped value
    featurize=featurize_theta_cart_pole,
    model_lr=1e-4,
)

# parameters of the model itself; `key` is a placeholder that the PRNG cell overwrites
# with a dedicated model key before the experiment runs
model_params = dict(obs_dim=env.physical_state_dim, action_dim=env.action_dim, width_size=128, depth=3, key=None)


# setup the whole experiment parameter dict
# (note: the nested dicts are shared by reference, so later edits to e.g.
# exp_params["model_params"] are also visible through `model_params`)
exp_params = dict(
    seed=42,
    n_time_steps=15_000,
    model_class=NeuralEulerODECartpole,
    env_params=env_params,
    alg_params=alg_params,
    model_trainer_params=model_trainer_params,
    model_params=model_params,
)

In [None]:
# setup PRNG: a single root key derived from the experiment seed is split into
# independent streams. data_key seeds the initial APRBS guess, model_key initializes
# the model, loader_key and expl_key are handed to excite_with_dmpe; the leftover
# `key` remains available for any further randomness.
key = jax.random.PRNGKey(exp_params["seed"])
data_key, model_key, loader_key, expl_key, key = jax.random.split(key, num=5)
exp_params["model_params"].update(key=model_key)

In [None]:
# initial guess U_k for the action sequence over the prediction horizon, drawn with aprbs
# NOTE(review): the meaning of the literal arguments 1 and 10 is not visible here
# (presumably minimum/maximum hold duration in steps) -- confirm against
# dmpe.utils.signals.aprbs. `[0]` selects the first element along the leading axis
# (presumably the batch axis).
proposed_actions = aprbs(
    exp_params["alg_params"]["n_prediction_steps"],
    env.batch_size,
    1,
    10,
    data_key,
)[0]

In [None]:
# run the algorithm
# NOTE: this cell rebinds `proposed_actions` to the value returned by excite_with_dmpe.
# Re-running it without first re-running the initial-guess cell therefore starts from
# that returned value rather than from the fresh APRBS guess.
(
    observations,
    actions,
    model,
    density_estimate,
    losses,
    proposed_actions,
) = excite_with_dmpe(
    env,
    exp_params,
    proposed_actions,
    loader_key,
    expl_key,
    plot_every=200,
)