In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

from functools import partial
import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['text.usetex'] = True
mpl.rcParams['text.latex.preamble']=r"\usepackage{bm}"
import plotly.express as px


In [None]:
import jax
import jax.numpy as jnp
# jax.config.update("jax_enable_x64", True)
gpus = jax.devices()
print(gpus)

# Make the first listed device JAX's default placement target.
first_device = gpus[0]
jax.config.update("jax_default_device", first_device)

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs

import exciting_exciting_systems
from exciting_exciting_systems.models import NeuralEulerODEPendulum, NeuralODEPendulum, NeuralEulerODECartpole
from exciting_exciting_systems.models.model_utils import simulate_ahead_with_env
from exciting_exciting_systems.models.model_training import ModelTrainer
from exciting_exciting_systems.excitation import loss_function, Exciter

from exciting_exciting_systems.utils.density_estimation import (
    select_bandwidth, update_density_estimate_single_observation, update_density_estimate_multiple_observations, DensityEstimate
)
from exciting_exciting_systems.utils.signals import aprbs
from exciting_exciting_systems.evaluation.plotting_utils import (
    plot_sequence, append_predictions_to_sequence_plot, plot_sequence_and_prediction, plot_model_performance, plot_2d_kde_as_contourf
)

from exciting_exciting_systems.models.model_utils import ModelEnvWrapperCartPole

---

In [None]:
# Root PRNG key for the exploratory part of the notebook; every sub-key below is derived from it.
root_seed = 2
key = jax.random.PRNGKey(seed=root_seed)

(data_key, model_key, loader_key, expl_key, key) = jax.random.split(key, num=5)
data_rng = PRNGSequence(data_key)

In [None]:
batch_size = 1
tau = 2e-2  # step size handed to the environment below as `tau`

# cartpole parameters: "Evaluation of Policy Gradient Methods and Variants on the Cart-Pole Benchmark" Riedmiller2007
# comparable: Nagendra2017
# NOTE(review): the title previously given for Nagendra2017 duplicated the Riedmiller2007 title
# above. TODO: insert the correct Nagendra2017 title.

# Build env_params from the constants above instead of repeating the literals. Later cells read
# `tau` / `batch_size` directly, and this keeps them consistent with the environment.
env_params = dict(
    batch_size=batch_size,
    tau=tau,
    max_force=10,
    env_solver=diffrax.Tsit5(),  # alternative: diffrax.Euler()
)
env = excenvs.make(
    env_id="CartPole-v0",
    batch_size=env_params["batch_size"],
    action_constraints={"force": env_params["max_force"]},
    static_params={  # typical values? 10.1109/TSMC.1983.6313077
        "mu_p": 0.002,  # ?0.000002, 2e-6
        "mu_c": 0.5,  # ?0.0005, 5e-4
        "l": 0.5,
        "m_p": 0.1,
        "m_c": 1,
        "g": 9.81,
    },
    physical_constraints={
        "deflection": 2.4,
        "velocity": 8,
        "theta": jnp.pi,
        "omega": 8,
    },
    solver=env_params["env_solver"],
    tau=env_params["tau"],
)

### Test simulation

- start from the initial state/observation ($\mathbf{x}_0$ / $\mathbf{y}_0$),
- apply $N = 999$ actions $\mathbf{u}_0, \dots, \mathbf{u}_{N-1}$ (**here**: a constant action of 1 for the first 10 steps and 0 afterwards; a random APRBS sequence is the commented-out alternative),
- which yields the state trajectory $\mathbf{x}_0, \dots, \mathbf{x}_N$ with $N+1 = 1000$ elements.

In [None]:
# Reset the environment. Keep the first entry of the returned observations (presumably the only
# one, since batch_size == 1) and the state exactly as returned.
reset_obs, state = env.reset()
obs = reset_obs[0]

# Number of actions N applied in the test simulation (N = 999, as stated in the markdown cell).
n_steps = 999
# Length of the initial constant-action pulse, in steps.
n_pulse_steps = 10

# Alternative excitation: a random APRBS sequence of the same length
# actions = aprbs(n_steps, batch_size, 1, 10, next(data_rng))[0]

# Action 1 for the first `n_pulse_steps` steps, 0 afterwards. The array is sized from `n_steps`
# instead of a hard-coded 1000, which contradicted the N = 999 stated above.
actions = jnp.zeros((n_steps, 1)).at[:n_pulse_steps].set(1.0)

In [None]:
# Roll the environment forward from the reset observation/state under the action sequence above.
observations, _ = simulate_ahead_with_env(env, obs, state, actions)

print("actions.shape:", actions.shape)
print("observations.shape:", observations.shape)

print(" \n One of the trajectories:")
# Pass the step size the environment was actually built with. Previously a separately declared
# global was used here, which could drift from env_params["tau"] if only one of them was edited.
fig, axs = plot_sequence(
    observations=observations,
    actions=actions,
    tau=env_params["tau"],
    obs_labels=[r"$p_x$", r"$v_x$", r"$\theta$", r"$\omega$"],
    action_labels=[r"$F$"],
)
plt.show()

---

In [None]:
from exciting_exciting_systems.algorithms import excite_with_dmpe

In [None]:
def featurize_theta(obs):
    """Replace the angle entry of an observation by its sine and cosine.

    Raw angles are awkward to compare in a loss, because e.g. 1.99 * pi and 0
    describe essentially the same pole position. The angle column ``obs[..., 2]``
    is multiplied by pi before the trigonometric functions are applied. The angle
    is therefore assumed to be stored normalized (angle / pi), presumably matching
    the environment's ``theta: pi`` physical constraint. TODO confirm.

    Args:
        obs: array whose last axis is read at indices 0..3, presumably
            ``(p_x, v_x, theta, omega)``. Any leading axes are preserved.

    Returns:
        Array whose last axis is
        ``(obs0, obs1, sin(obs2 * pi), cos(obs2 * pi), obs3)``.
    """
    angle = obs[..., 2] * jnp.pi
    columns = (obs[..., 0], obs[..., 1], jnp.sin(angle), jnp.cos(angle), obs[..., 3])
    return jnp.stack(columns, axis=-1)

In [None]:
seed = 55551212515

points_per_dim = 15

# Excitation-algorithm hyperparameters for the first run.
# NOTE(review): select_bandwidth's arguments are presumably (data range, number of feature
# dimensions, points per dimension, scale). 5 would match featurize_theta's output width. TODO confirm.
alg_params = {
    "bandwidth": select_bandwidth(2, 5, points_per_dim, 0.1),
    "n_prediction_steps": 50,
    "points_per_dim": points_per_dim,
    "action_lr": 1e-2,
    "n_opt_steps": 10,
    "rho_obs": 1,
    "rho_act": 1,
    "penalty_order": 2,
    "clip_action": True,
    "n_starts": 5,
    "reuse_proposed_actions": True,
}

# Experiment description. No model class, trainer parameters or model parameters are set.
# Predictions presumably go through ModelEnvWrapperCartPole around the true env. TODO confirm.
exp_params = {
    "seed": None,
    "n_timesteps": 15_000,
    "model_class": None,
    "env_params": env_params,
    "alg_params": alg_params,
    "model_trainer_params": None,
    "model_params": None,
    "model_env_wrapper": ModelEnvWrapperCartPole,
}

In [None]:
# Derive the sub-keys for this run from `seed`. Only the data and exploration keys are used here.
key = jax.random.PRNGKey(seed=seed)
split_keys = jax.random.split(key, 5)
data_key, expl_key, key = split_keys[0], split_keys[3], split_keys[4]
data_rng = PRNGSequence(data_key)

# Initial guess: an APRBS sequence spanning the prediction horizon. `[0]` keeps the first of the
# returned sequences, presumably the first batch element.
horizon = exp_params["alg_params"]["n_prediction_steps"]
proposed_actions = aprbs(horizon, env.batch_size, 1, 10, next(data_rng))[0]

# Run the excitation algorithm. The 4th positional argument receives `loader_key` in the
# model-based run further down. It is None here, presumably because no model is trained.
observations, actions, model, density_estimate, losses, proposed_actions = excite_with_dmpe(
    env,
    exp_params,
    proposed_actions,
    None,
    expl_key,
    plot_every=500,
)

In [None]:
# Deliberate stop, so "Run All" does not continue into the cells below.
# A bare `raise` with no active exception fails with the opaque
# "RuntimeError: No active exception to reraise". Raise the same exception type with a clear message.
raise RuntimeError("Intentional stop: execute the cells below manually.")

In [None]:
seed = 4444

points_per_dim = 15

# Excitation-algorithm hyperparameters for the model-based run.
alg_params = {
    "bandwidth": select_bandwidth(2, 5, points_per_dim, 0.1),
    "n_prediction_steps": 50,
    "points_per_dim": points_per_dim,
    "action_lr": 1e-2,
    "n_opt_steps": 10,
    "rho_obs": 1,
    "rho_act": 1,
    "penalty_order": 2,
    "clip_action": True,
    "n_starts": 5,
    "reuse_proposed_actions": True,
}

# Model training settings. Training sequences have the same length as the prediction horizon,
# and observations are passed through featurize_theta.
model_trainer_params = {
    "start_learning": alg_params["n_prediction_steps"],
    "training_batch_size": 128,
    "n_train_steps": 5,
    "sequence_length": alg_params["n_prediction_steps"],
    "featurize": featurize_theta,
    "model_lr": 1e-4,
}
# The model key is filled in once the run's PRNG keys are derived.
model_params = {
    "obs_dim": env.physical_state_dim,
    "action_dim": env.action_dim,
    "width_size": 128,
    "depth": 3,
    "key": None,
}

exp_params = {
    "seed": None,
    "n_timesteps": 15_000,
    "model_class": NeuralEulerODECartpole,
    "env_params": env_params,
    "alg_params": alg_params,
    "model_trainer_params": model_trainer_params,
    "model_params": model_params,
}

In [None]:
# Derive all sub-keys for the model-based run from `seed`.
(data_key, model_key, loader_key, expl_key, key) = jax.random.split(
    jax.random.PRNGKey(seed=seed), num=5
)
data_rng = PRNGSequence(data_key)

# Record the seed and hand the model its initialization key.
# This writes into the same dict object that `model_params` refers to.
exp_params.update(seed=int(seed))
exp_params["model_params"].update(key=model_key)

# Initial guess: an APRBS sequence spanning the prediction horizon. `[0]` keeps the first of the
# returned sequences, presumably the first batch element.
horizon = exp_params["alg_params"]["n_prediction_steps"]
proposed_actions = aprbs(horizon, env.batch_size, 1, 10, next(data_rng))[0]

# Model-based excitation run.
observations, actions, model, density_estimate, losses, proposed_actions = excite_with_dmpe(
    env,
    exp_params,
    proposed_actions,
    loader_key,
    expl_key,
    plot_every=500,
)

In [None]:
# Reuse the bandwidth configured for the excitation run instead of repeating the select_bandwidth
# arguments, so this kernel inspection cannot silently diverge from alg_params.
bw = alg_params["bandwidth"]

In [None]:
# Show the bandwidth that is passed to gaussian_kernel in the cells below.
bw

In [None]:
from exciting_exciting_systems.utils.density_estimation import gaussian_kernel

In [None]:
# Kernel value at an offset of 2/30 in each of the 5 dimensions.
gaussian_kernel(x=jnp.array([2 / 30] * 5), bandwidth=bw)

In [None]:
# Computes bw^5 * (2*pi)^(5/2). This is presumably the normalization constant h^d * (2*pi)^(d/2)
# of a d = 5 dimensional Gaussian kernel with bandwidth h = bw. TODO confirm against gaussian_kernel.
bw**5 * jnp.power(2 * jnp.pi, 5 / 2)

In [None]:
# Probe points:
# - start with 2/30 in every dimension and zero one more leading coordinate per point,
#   ending at the origin;
# - then add one point offset by 8/30 along the first axis.
points = [jnp.array([0] * n_zeros + [2 / 30] * (5 - n_zeros)) for n_zeros in range(6)]
points.append(jnp.array([8 / 30] + [0] * 4))

# Print each kernel value multiplied by bw^5 * (2*pi)^(5/2).
for point in points:
    scaled_kernel_value = gaussian_kernel(x=point, bandwidth=bw) * bw**5 * jnp.power(2 * jnp.pi, 5 / 2)
    print(scaled_kernel_value)

In [None]:
# Decimal value of the 2/30 offset used in the probe points.
2/30

In [None]:
# Rank of a probe point. jax arrays expose `.ndim`; there is no `.dim` attribute (AttributeError).
probe_point = jnp.array([2 / 30] * 5)
probe_point.ndim