In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['text.usetex'] = True
mpl.rcParams['text.latex.preamble']=r"\usepackage{bm}"
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp
# jax.config.update("jax_enable_x64", True)
import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs

import exciting_exciting_systems
from exciting_exciting_systems.models import NeuralEulerODEPendulum, NeuralODEPendulum
from exciting_exciting_systems.models.model_utils import simulate_ahead_with_env
from exciting_exciting_systems.models.model_training import ModelTrainer
from exciting_exciting_systems.excitation import loss_function, Exciter

from exciting_exciting_systems.utils.density_estimation import (
    update_density_estimate_single_observation, update_density_estimate_multiple_observations, DensityEstimate
)
from exciting_exciting_systems.utils.signals import aprbs
from exciting_exciting_systems.evaluation.plotting_utils import (
    plot_sequence, append_predictions_to_sequence_plot, plot_sequence_and_prediction, plot_model_performance
)

---

In [None]:
# setup PRNG
# One root key (seed 2) is split into separate keys for data generation,
# model initialisation and the training data loader. The fourth split
# replaces `key`, so the root key itself is never reused.
key = jax.random.PRNGKey(seed=2)#8)

data_key, model_key, loader_key, key = jax.random.split(key, 4)
# `next(data_rng)` is used below wherever a fresh key for random signals is needed.
data_rng = PRNGSequence(data_key)

In [None]:
# Environment configuration: a single pendulum instance (batch_size = 1)
# stepped with step size `tau`.
batch_size = 1
tau = 2e-2

# The torque constraint of 8 matches the value discussed in the closing notes
# of this notebook. The continuous dynamics are integrated with diffrax's
# Tsit5 solver.
env = excenvs.make(
    env_id='Pendulum-v0',
    batch_size=batch_size,
    action_constraints={"torque": 8},
    static_params={"g": 9.81, "l": 1, "m": 1},
    solver=diffrax.Tsit5(),
    tau=tau,
)

### Test simulation:

- starting from the initial state/obs ($\mathbf{x}_0$ / $\mathbf{y}_0$)
- apply $N = 999$ actions $\mathbf{u}_0 \dots \mathbf{u}_{N-1}$ (**here**: random APRBS actions)
- which results in the state trajectory $\mathbf{x}_0 \dots \mathbf{x}_N$ with $N+1 = 1000$ elements

In [None]:
# Reset the environment and keep only the first entry along the leading axis
# of the returned observation (presumably the batch axis, since
# batch_size = 1 -- confirm against the environment API).
obs, state = env.reset()
obs = obs[0]

n_steps = 999

# Random APRBS action sequence of length `n_steps`; `[0]` again selects the
# first entry along the leading axis.
# NOTE(review): the meaning of the positional arguments `1` and `10` is
# defined by `aprbs` and is not visible in this notebook.
actions = aprbs(n_steps, batch_size, 1, 10, next(data_rng))[0]

In [None]:
# Roll the environment forward under the APRBS actions, starting from the
# reset observation/state.
observations = simulate_ahead_with_env(env, obs, state, actions)

# Shape check: per the description above there should be one more observation
# than actions.
print("actions.shape:", actions.shape)
print("observations.shape:", observations.shape)

print(" \n One of the trajectories:")
fig, axs = plot_sequence(
    observations=observations,
    actions=actions,
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
);
plt.show()

## Build an algorithm that simultaneously learns the model and optimizes its trajectory:

In [None]:
def featurize_theta(obs):
    """Replace the angle entry of an observation by its sine and cosine.

    Comparing raw angles in a loss is misleading because values at opposite
    ends of the range (e.g. 1.99 * pi and 0) describe almost the same pose.
    The first observation entry is therefore scaled by pi and mapped to
    (sin, cos); the second entry is passed through unchanged.

    Args:
        obs: Array whose last axis holds at least two entries; index 0 is
            the angle and index 1 is kept as-is. Presumably the angle is
            normalised so that multiplying by pi yields radians -- confirm
            against the environment's observation definition.

    Returns:
        Array with the same leading axes as `obs` and a last axis of size 3:
        (sin(pi * obs[..., 0]), cos(pi * obs[..., 0]), obs[..., 1]).
    """
    angle = obs[..., 0] * jnp.pi
    velocity = obs[..., 1]
    return jnp.stack((jnp.sin(angle), jnp.cos(angle), velocity), axis=-1)

In [None]:
# Hyper-parameters for the excitation experiment.
# `bandwidth` is handed to the DensityEstimate below. `n_prediction_steps`
# sets the length of the proposed action sequence and the ModelTrainer's
# `start_learning` / `sequence_length`.
bandwidth = 0.05
n_prediction_steps = 50

dim_obs_space = 2
dim_action_space = 1

# The density is estimated over the joint observation/action space, so the
# grid has `points_per_dim ** dim` = 50 ** 3 points.
dim = dim_obs_space + dim_action_space
points_per_dim = 50
n_grid_points=points_per_dim**dim

# Total number of time steps of the experiment.
n_timesteps = 15_000

In [None]:
# Start the experiment from a fresh reset. This overwrites `obs`, `state`,
# `observations` and `actions` from the test simulation above.
obs, state = env.reset()
obs = obs[0]

# Pre-allocated buffers: `n_timesteps` observations (the first one is the
# reset observation) and one action fewer.
observations = jnp.zeros((n_timesteps, dim_obs_space))
observations = observations.at[0].set(obs)
actions = jnp.zeros((n_timesteps-1, dim_action_space))

# Initial guess for the action sequence over the prediction horizon.
proposed_actions = aprbs(n_prediction_steps, batch_size, 1, 10, next(data_rng))[0]

In [None]:
# Excitation part: gradient of `loss_function` w.r.t. its argument at index 2,
# optimised with AdaBelief (learning rate 1e-1). The target distribution is
# constant over the grid with value 1 / (1 - (-1))**dim, i.e. the uniform
# density on the cube [-1, 1]^dim.
exciter = Exciter(
    grad_loss_function=jax.grad(loss_function, argnums=(2)),
    excitation_optimizer=optax.adabelief(1e-1),
    tau=tau,
    target_distribution=jnp.ones(shape=(n_grid_points, 1)) * 1 / (1 - (-1))**dim
)

# Model-fitting part: observations are featurised with `featurize_theta`
# and the model is optimised with AdaBelief (learning rate 1e-4).
# NOTE(review): the exact semantics of `start_learning`, `n_train_steps` and
# `sequence_length` are defined by ModelTrainer -- confirm there.
model_trainer = ModelTrainer(
    start_learning=n_prediction_steps,
    training_batch_size=128,
    n_train_steps=1,
    sequence_length=n_prediction_steps,
    featurize=featurize_theta,
    model_optimizer=optax.adabelief(1e-4),
    tau=tau
)

# Previous 2-D (observation-only) variant, kept for reference.
# NOTE(review): it refers to an `eesys` alias that is not imported in this
# notebook.
# density_estimate = DensityEstimate(
#     p=jnp.zeros([batch_size, n_grid_points, 1]),
#     x_g=eesys.utils.density_estimation.build_grid_2d(
#         low=env.env_observation_space.low,
#         high=env.env_observation_space.high,
#         points_per_dim=points_per_dim
#     ),
#     bandwidth=jnp.array([bandwidth]),
#     n_observations=jnp.array([0])
# )

# Empty density estimate on a 3-D grid over [-1, 1]^3 (observation + action).
density_estimate = DensityEstimate(
    p=jnp.zeros([n_grid_points, 1]),
    x_g=exciting_exciting_systems.utils.density_estimation.build_grid_3d(
        low=-1,
        high=1,
        points_per_dim=points_per_dim
    ),
    bandwidth=jnp.array([bandwidth]),
    n_observations=jnp.array([0])
)

# Neural ODE model of the pendulum, integrated with an explicit Euler solver.
model = NeuralODEPendulum(
    solver=diffrax.Euler(),
    obs_dim=dim_obs_space,
    action_dim=dim_action_space,
    width_size=128,
    depth=3,
    key=model_key
)

# Alternative model class, kept for reference.
# model = NeuralEulerODEPendulum(
#     obs_dim=dim_obs_space,
#     action_dim=dim_action_space,
#     width_size=128,
#     depth=3,
#     key=model_key
# )

# Optimiser state is initialised on the inexact (floating point) array leaves
# of the model only.
opt_state_model = model_trainer.model_optimizer.init(eqx.filter(model, eqx.is_inexact_array))

In [None]:
from exciting_exciting_systems.algorithms import excite_and_fit

In [None]:
# Run the combined excitation / model-learning algorithm for `n_timesteps`
# steps. The returned buffers, model and density estimate replace the
# initial ones defined above.
# NOTE(review): `plot_every=500` presumably controls how often intermediate
# plots are produced -- confirm in `excite_and_fit`.
observations, actions, model, density_estimate = excite_and_fit(
    n_timesteps=n_timesteps,
    env=env,
    model=model,
    obs=obs,
    state=state,
    proposed_actions=proposed_actions,
    exciter=exciter,
    model_trainer=model_trainer,
    density_estimate=density_estimate,
    observations=observations,
    actions=actions,
    opt_state_model=opt_state_model,
    loader_key=loader_key,
    plot_every=500,
)

In [None]:
# NOTE(review): leftover post-mortem debugging magic. It is only useful
# interactively, right after the previous cell raised; consider removing it
# before sharing the notebook.
%debug

In [None]:
from exciting_exciting_systems.utils.metrics import JSDLoss, KLDLoss

In [None]:
# Jensen-Shannon-type loss between the estimated density and the target
# distribution. Both arrays are normalised to sum to one before comparison.
JSDLoss(
    p=density_estimate.p / jnp.sum(density_estimate.p),
    q=exciter.target_distribution / jnp.sum(exciter.target_distribution),
)

In [None]:
# Plot the full trajectory (observations and applied actions) produced by the
# excitation run.
fig, axs = plot_sequence(
    observations,
    actions,
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
);
# Render the figure. The previous `plt.plot()` call only added an empty line
# to the current axes and echoed an empty list as cell output.
plt.show()

### load data from experiments:

In [None]:
import json
import pathlib
import glob

from exciting_exciting_systems.models.model_utils import load_model

In [None]:
def get_experiment_ids(results_path: pathlib.Path):
    """Return the sorted, unique experiment identifiers found in a folder.

    Every ``*.json`` file directly inside `results_path` is considered. The
    identifier is the file stem with everything up to and including the first
    underscore removed (e.g. ``params_<id>.json`` -> ``<id>``). A stem without
    an underscore is used as-is.

    Args:
        results_path: Directory containing the stored experiment files.

    Returns:
        Alphabetically sorted list of distinct identifiers.
    """
    pattern = str(results_path / pathlib.Path("*.json"))
    identifiers = {
        pathlib.Path(json_file).stem.split('_', maxsplit=1)[-1]
        for json_file in glob.glob(pattern)
    }
    return sorted(identifiers)

In [None]:
def load_experiment_results(
    exp_id: str,
    model_class,
    results_path: pathlib.Path
):
    """Load parameters, recorded data and the trained model of one experiment.

    Three files are expected inside `results_path`: ``params_<exp_id>.json``,
    ``data_<exp_id>.json`` (with the keys ``"observations"`` and
    ``"actions"``) and ``model_<exp_id>.json``.

    Args:
        exp_id: Identifier shared by the three files of the experiment.
        model_class: Class handed to `load_model` to rebuild the model.
        results_path: Directory containing the experiment files.

    Returns:
        Tuple ``(params, observations, actions, model)``. `params` is the
        decoded JSON content; `observations` and `actions` are jax arrays.
    """

    def _read_json(file_name):
        # Files are opened in binary mode; json.load decodes the bytes itself.
        with open(results_path / pathlib.Path(file_name), "rb") as fp:
            return json.load(fp)

    params = _read_json(f"params_{exp_id}.json")

    data = _read_json(f"data_{exp_id}.json")
    observations = jnp.array(data["observations"])
    actions = jnp.array(data["actions"])

    model_file = results_path / pathlib.Path(f"model_{exp_id}.json")
    model = load_model(model_file, model_class)

    return params, observations, actions, model

In [None]:
# NOTE(review): hard-coded absolute path into one user's home directory; the
# cells below only work on that machine. Consider deriving it from a
# configurable base directory.
results_path = pathlib.Path("/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/dmpe/")

In [None]:
# Show which experiment identifiers are available in `results_path`.
get_experiment_ids(results_path)

In [None]:
def quick_eval(identifier, model_class=None, n_eval_steps=1000):
    """Load one stored experiment and plot its data and its model's quality.

    Prints the identifier and the stored algorithm parameters, plots the
    complete recorded trajectory and then the model performance on the first
    `n_eval_steps` observations.

    Relies on the notebook globals `results_path` (where to load from) and
    `tau` (step size used for the time axis of both plots).

    Args:
        identifier: Experiment identifier as returned by `get_experiment_ids`.
        model_class: Class used to rebuild the stored model. Defaults to
            `NeuralEulerODEPendulum`, the class that was previously
            hard-coded here.
        n_eval_steps: Number of leading observations used for the model
            performance plot; one action fewer is used, matching the
            observation/action layout of the recorded data.
    """
    if model_class is None:
        model_class = NeuralEulerODEPendulum

    params, observations, actions, model = load_experiment_results(
        exp_id=identifier,
        model_class=model_class,
        results_path=results_path
    )

    print(identifier)
    print(params["alg_params"])

    # Both plots use the same step size `tau`. Previously the first plot read
    # `env.tau` and the second the global `tau`.
    fig, axs = plot_sequence(
        observations=observations,
        actions=actions,
        tau=tau,
        obs_labels=[r"$\theta$", r"$\omega$"],
        action_labels=[r"$u$"],
    );
    plt.show()

    # N observations correspond to N - 1 actions; guard against a negative
    # slice end for n_eval_steps = 0.
    n_eval_actions = max(n_eval_steps - 1, 0)
    fig, axs = plot_model_performance(
        model=model,
        true_observations=observations[:n_eval_steps],
        actions=actions[:n_eval_actions],
        tau=tau,
        obs_labels=[r"$\theta$", r"$\omega$"],
        action_labels=[r"$u$"],
    );
    plt.show()

In [None]:
# Three hand-picked runs, evaluated one after the other:
#   '2024-06-26_11:13:15' -> standard
#   '2024-06-26_14:13:52' -> 0.9 u_max
#   '2024-06-26_14:21:15' -> clipping
labelled_runs = (
    ("standard", '2024-06-26_11:13:15'),
    ("0.9 u_max", '2024-06-26_14:13:52'),
    ("clipping", '2024-06-26_14:21:15'),
)

for run_label, run_id in labelled_runs:
    print(run_label)
    quick_eval(identifier=run_id)

In [None]:
# Evaluate every stored experiment. `quick_eval` prints the identifier
# itself, so no extra print is needed here (it used to appear twice).
for identifier in get_experiment_ids(results_path):
    quick_eval(identifier)

### Interpolation tests:

in exc_envs:
- action is given with $(batch\_size, n\_steps, action\_dim)$
- also $\tau_{act}$ is given
- The duration of the action sequence is therefore $n\_steps \cdot \tau_{act}$ seconds when each input is held constant (zero-order hold) for $\tau_{act}$ seconds
- Note however, that the last element starts at the time $(n\_steps - 1) \cdot \tau_{act}$
- This occurs because the first element starts at time $0$

In [None]:
tau = 2e-2
actions = aprbs(20, 1, 1, 10, next(data_rng))[0]

t_true = jnp.linspace(0, (actions.shape[0]-1) * tau, actions.shape[0])
plt.plot(t_true, actions, 'r.')

In [None]:
t_int = jnp.linspace(0, (actions.shape[0]) * tau, 1000)
interpolated_actions = actions[jnp.array(t_int / tau, int), 0]

In [None]:
plt.plot(t_int, interpolated_actions, label=r"$u(t)$")
plt.plot(t_true, actions, label=r"$u_k$")
plt.legend()
plt.grid()

### KDEs & Metrics:

In [None]:
# Contour plot of the density estimate. It only applies to a 2-D estimate,
# so it is skipped for the 3-D observation/action setup used above.
# The package is imported under its full name in this notebook; the `eesys`
# alias used here before was never defined.
if dim == 2:
    fig, axs, cax = exciting_exciting_systems.evaluation.plotting_utils.plot_2d_kde_as_contourf(
        density_estimate.p, density_estimate.x_g, [r"$\theta$", r"$\omega$"]
    )
    # fig.savefig("excited_pendulum_kde_contourf.png")

In [None]:
# Surface and contour plots of the density estimate and of its absolute
# deviation from the target distribution (2-D estimates only).
# The package is imported under its full name in this notebook; the `eesys`
# alias used here before was never defined.
if dim == 2:
    fig, axs = exciting_exciting_systems.evaluation.plotting_utils.plot_2d_kde_as_surface(
        density_estimate.p, density_estimate.x_g, [r"$\theta$", r"$\omega$"]
    )
    fig.suptitle("Vanilla KDE")
    # fig.savefig("excited_pendulum_kde_surface.png")
    plt.show()

    fig, axs = exciting_exciting_systems.evaluation.plotting_utils.plot_2d_kde_as_surface(
        jnp.abs(density_estimate.p - exciter.target_distribution), density_estimate.x_g, [r"$\theta$", r"$\omega$"]
    )
    fig.suptitle("Difference")

    plt.show()

    fig, axs, cax = exciting_exciting_systems.evaluation.plotting_utils.plot_2d_kde_as_contourf(
        jnp.abs(density_estimate.p - exciter.target_distribution), density_estimate.x_g, [r"$\theta$", r"$\omega$"]
    )
    plt.colorbar(cax)
    fig.suptitle("Abs Difference")

    plt.show()

In [None]:
# Deliberate stop for "Run All": a bare `raise` with no active exception
# raises a RuntimeError, so the cells below are only executed manually.
raise

In [None]:
from exciting_exciting_systems.utils.metrics import MNNS_without_penalty, audze_eglais, MC_uniform_sampling_distribution_approximation
from scipy.stats.qmc import LatinHypercube

In [None]:
# Latin hypercube sampler for the 2-D observation space. It is seeded so that
# the support points, and therefore the metric computed from them, are
# reproducible across kernel restarts.
lhc_sampler = LatinHypercube(d=2, seed=0)

In [None]:
# Audze-Eglais score of the recorded observations (metric imported from
# `exciting_exciting_systems.utils.metrics`).
ae_score = audze_eglais(observations)
print(ae_score)

In [None]:
def MC_uniform_sampling_distribution_approximation(
        data_points: jnp.ndarray,
        support_points: jnp.ndarray
) -> jnp.ndarray:
    """Mean distance from each support point to its nearest data point.

    From [Smits+Nelles2024]: the minimax design tries to make the distances
    between the support points and the data points small.

    Open question kept from the original notes: only the shortest distance
    per support point enters the score, so what prevents the data points
    from flocking to a single support point?

    Args:
        data_points: Array of shape (N, d).
        support_points: Array of shape (M, d).

    Returns:
        Scalar: the sum over all support points of the Euclidean distance to
        the closest data point, divided by M.
    """
    n_support = support_points.shape[0]

    # (N, M, d) pairwise differences -> (N, M) Euclidean distances
    pairwise_diff = data_points[:, None, :] - support_points[None, ...]
    pairwise_dist = jnp.linalg.norm(pairwise_diff, axis=-1)

    # closest data point for every support point
    nearest_dist = jnp.min(pairwise_dist, axis=0)

    return jnp.sum(nearest_dist) / n_support

In [None]:
# Score of the recorded observations against 1600 Latin-hypercube support
# points. The sampler draws from the unit hypercube, so `* 2 - 1` rescales
# the points to [-1, 1).
# This calls the definition from the previous cell, which shadows the
# function of the same name imported from the metrics module.
mcudsa_score = MC_uniform_sampling_distribution_approximation(
    data_points=observations,
    support_points=lhc_sampler.random(n=1600) * 2 - 1
)
print(mcudsa_score)

---
### Look at the actions:

In [None]:
def build_grid_3d(low, high, points_per_dim):
    """Build a regular grid over the cube [low, high]^3.

    Args:
        low: Lower bound, shared by all three axes.
        high: Upper bound, shared by all three axes.
        points_per_dim: Number of equally spaced points per axis.

    Returns:
        Array of shape (points_per_dim**3, 3) with one grid point per row.
        Rows follow the flattened layout of `jnp.meshgrid` with its default
        indexing.
    """
    axis_ticks = jnp.linspace(low, high, points_per_dim)

    mesh = jnp.meshgrid(axis_ticks, axis_ticks, axis_ticks)
    x_g = jnp.stack(mesh, axis=-1).reshape(-1, 3)

    assert x_g.shape[0] == points_per_dim**3
    return x_g

In [None]:
# Shape check for the joint (observation, action) samples: every observation
# except the last is paired with an action, so both arrays must have the same
# number of rows.
jnp.concatenate([observations[0:-1, :], actions], axis=-1).shape

In [None]:
# Re-estimate the density of the joint (observation, action) samples on a
# coarser 40^3 grid. This overwrites `points_per_dim`, `n_grid_points` and
# `density_estimate` from the experiment above.
points_per_dim = 40
n_grid_points=points_per_dim**3

# Empty estimate with a leading batch axis on `p`.
density_estimate = DensityEstimate(
    p=jnp.zeros([batch_size, n_grid_points, 1]),
    x_g=build_grid_3d(-1, 1, points_per_dim),
    bandwidth=jnp.array([bandwidth]),
    n_observations=jnp.array([0])
)

# Only the first field of the estimate (`p`) is mapped over its leading axis;
# the remaining three fields are shared across the batch.
batched_estimate_axes = DensityEstimate(0, None, None, None)
batched_update = jax.vmap(
    update_density_estimate_multiple_observations,
    in_axes=(batched_estimate_axes, 0),
    out_axes=batched_estimate_axes
)

# `[None]` adds the batch axis the vmapped update maps over.
joint_samples = jnp.concatenate([observations[0:-1, :], actions], axis=-1)[None]
density_estimate = batched_update(density_estimate, joint_samples)

In [None]:
# Shape check of the batched density values.
density_estimate.p.shape

In [None]:
# Grid coordinates arranged as a cube; only needed if explicit x/y
# coordinates are passed to `contourf`.
x_plot = density_estimate.x_g.reshape((points_per_dim, points_per_dim, points_per_dim, 3))

fig, axs = plt.subplots(figsize=(6, 6))

# Density values of the first batch entry arranged as a cube, reshaped once
# instead of once per frame. Summing this cube over its last axis would give
# a single aggregated view instead of one frame per slice.
density_cube = density_estimate.p[0].reshape((points_per_dim, points_per_dim, points_per_dim))

# One filled contour plot per index along the last grid axis, all drawn on
# the same axes; each becomes one frame of the animation.
ims = []
for i in range(points_per_dim):
    cax = axs.contourf(
        density_cube[:, :, i],
        antialiased=False,
        levels=100,
        alpha=0.9,
        cmap=plt.cm.coolwarm
    )
    ims.append([cax])

In [None]:
import matplotlib.animation as animation

In [None]:
# Turn the collected frames into an animation: 50 ms between frames and a
# 1000 ms pause before it repeats.
ani = animation.ArtistAnimation(fig, ims, interval=50, blit=True,
                                repeat_delay=1000)

In [None]:
# Export the animation as a GIF at 5 frames per second; the file is written
# to the current working directory.
writer = animation.PillowWriter(fps=5,
                                metadata=dict(artist='Me'),
                                bitrate=1800)
ani.save('opt_wrt_obs_and_act.gif', writer=writer)

- maybe look at the vector fields here as well?
- I think it is possible that the system does not have enough torque to pass through the upper equilibrium at max velocity?

AFAICT ~$6.2 \ \mathrm{ Nm}$ are necessary for max angular velocity at the top position and only $5 \ \mathrm{ Nm}$ were available

$8 \ \mathrm{ Nm}$ seems to be a good value, where stability is also still maintainable