In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
# plt.rcParams['text.usetex'] = True
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs

import exciting_exciting_systems as eesys
from exciting_exciting_systems.models import NeuralEulerODEPendulum
from exciting_exciting_systems.models.model_utils import simulate_ahead_with_env
from exciting_exciting_systems.models.model_training import ModelTrainer
from exciting_exciting_systems.excitation import loss_function, Exciter

from exciting_exciting_systems.utils.density_estimation import (
    update_density_estimate_single_observation, update_density_estimate_multiple_observations, DensityEstimate
)
from exciting_exciting_systems.utils.signals import aprbs
from exciting_exciting_systems.evaluation.plotting_utils import (
    plot_sequence, append_predictions_to_sequence_plot, plot_sequence_and_prediction, plot_model_performance
)

---

In [None]:
# Set up the PRNG: one root key with a fixed seed so that a fresh
# "Restart & Run All" reproduces the same random draws.
key = jax.random.PRNGKey(seed=33) # 21)

# Derive separate keys for data generation, model initialization and the
# training data loader; the fourth split replaces the root `key`.
data_key, model_key, loader_key, key = jax.random.split(key, 4)
# Key sequence seeded with `data_key`; a new key is drawn with `next(data_rng)`.
data_rng = PRNGSequence(data_key)

In [None]:
# Environment configuration.
# `batch_size` is forwarded to the environment and is also used below as the
# leading dimension of the density-estimate array.
batch_size = 1
# NOTE(review): `tau` is presumably the sampling time in seconds; it is passed
# to the environment, the exciter, the model trainer and the plotting helpers.
tau = 2e-2 # 5e-2

# Pendulum environment with the settings above and a max_torque of 5.
env = excenvs.make(
    env_id='Pendulum-v0',
    batch_size=batch_size,
    tau=tau,
    max_torque=5
)

### Test simulation:

- starting from the initial state/obs ($\mathbf{x}_0$ / $\mathbf{y}_0$)
- apply $N = 999$ actions $\mathbf{u}_0 \dots \mathbf{u}_{N-1}$ (**here**: random APRBS actions)
- which results in the state trajectory $\mathbf{x}_0 \dots \mathbf{x}_N$ with $N+1 = 1000$ elements

In [None]:
# Reset the environment and cast the initial observation/state to float32.
obs, state = env.reset()
obs = obs.astype(jnp.float32)
state = state.astype(jnp.float32)
# Number of actions to apply in the test simulation.
n_steps = 999

# Random APRBS action sequence drawn with a fresh key from `data_rng`.
# NOTE(review): the literals 1 and 10 are presumably the minimum / maximum
# hold duration of the signal — confirm against the `aprbs` signature.
actions = aprbs(n_steps, batch_size, 1, 10, next(data_rng))

In [None]:
# Simulate the environment over the whole action sequence. `jax.vmap` maps
# over axis 0 (the batch axis) of every argument except `env`, which is
# passed through unbatched (in_axes entry `None`).
observations = jax.vmap(simulate_ahead_with_env, in_axes=(None, 0, 0, 0, 0, 0, 0))(
    env,
    obs,
    state,
    actions,
    env.env_state_normalizer,
    env.action_normalizer,
    env.static_params
)

# Quick shape check of the simulated data.
print("actions.shape:", actions.shape)
print("observations.shape:", observations.shape)

# Plot the first trajectory of the batch (index 0 along the batch axis).
print(" \n One of the trajectories:")
fig, axs = plot_sequence(
    observations=observations[0, ...],
    actions=actions[0, ...],
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
);
plt.show()

## Build an algorithm that simultaneously learns the model and optimizes its trajectory:

In [None]:
def featurize_theta(obs):
    """Replace the angle entry of an observation by its sine and cosine.

    A raw angle is a poor feature for a loss because values such as
    1.99 * pi and 0 describe almost the same pendulum position while being
    numerically far apart. Comparing sin(theta) and cos(theta) instead
    removes this discontinuity.

    Args:
        obs: Array whose last axis holds (theta, omega). The angle entry is
            multiplied by pi before sin/cos are applied.
            NOTE(review): this presumably means theta is normalized to
            [-1, 1] — confirm against the environment.

    Returns:
        Array with the same leading axes as `obs` and a last axis of size 3:
        (sin(pi * theta), cos(pi * theta), omega).
    """
    angle = obs[..., 0] * jnp.pi
    velocity = obs[..., 1]
    return jnp.stack([jnp.sin(angle), jnp.cos(angle), velocity], axis=-1)

In [None]:
# Hyperparameters of the excitation experiment.
# Bandwidth handed to the DensityEstimate objects.
bandwidth = 0.1
# Length of the proposed action sequence; also used as `start_learning` and
# `sequence_length` of the ModelTrainer.
n_prediction_steps = 50

# Dimension of the density-estimation grid. The 2D KDE plotting cells further
# down only execute when dim == 2.
dim = 3
points_per_dim = 50
# Total number of grid points (50**3 = 125_000 for the values above).
n_grid_points=points_per_dim**dim

# Size of the observation buffer / number of time steps passed to excite_and_fit.
n_timesteps = 15_000

In [None]:
# Fresh reset for the excitation experiment (overwrites `obs` / `state` from
# the test simulation above).
obs, state = env.reset()
obs = obs.astype(jnp.float32)
state = state.astype(jnp.float32)

# Pre-allocate the result buffers: `n_timesteps` observations, with the
# initial observation of the first batch element stored at index 0, and
# `n_timesteps - 1` actions.
observations = jnp.zeros((n_timesteps, env.env_observation_space.shape[-1]))
observations = observations.at[0].set(obs[0])
actions = jnp.zeros((n_timesteps-1, env.action_space.shape[-1]))

# Random APRBS sequence of length `n_prediction_steps`, handed to
# excite_and_fit as `proposed_actions`.
proposed_actions = aprbs(n_prediction_steps, batch_size, 1, 10, next(data_rng))

In [None]:
# Exciter: holds the gradient of `loss_function` with respect to its
# positional argument at index 3, the optimizer used for the excitation, and
# a target distribution that has the same constant value at every grid point.
# NOTE(review): the constant is 1 / (high - low)**dim; whether `high` / `low`
# are scalars or arrays (and hence the resulting shape) is not visible here —
# confirm.
exciter = Exciter(
    grad_loss_function=jax.grad(loss_function, argnums=(3)),
    excitation_optimizer=optax.adabelief(1e-1),
    tau=tau,
    target_distribution=jnp.ones(shape=(n_grid_points, 1)) * 1 / (env.env_observation_space.high - env.env_observation_space.low)**dim
)

# Model trainer configuration; observations are featurized with
# `featurize_theta` (sin/cos of the angle) for training.
model_trainer = ModelTrainer(
    start_learning=n_prediction_steps,
    training_batch_size=128,
    n_train_steps=1,
    sequence_length=n_prediction_steps,
    featurize=featurize_theta,
    model_optimizer=optax.adabelief(1e-4),
    tau=tau
)

# Zero-initialized density estimate on a 3D grid built from the bounds of the
# observation space, with zero observations counted so far.
density_estimate = DensityEstimate(
    p=jnp.zeros([batch_size, n_grid_points, 1]),
    x_g=eesys.utils.density_estimation.build_grid_3d(
        low=env.env_observation_space.low,
        high=env.env_observation_space.high,
        points_per_dim=points_per_dim
    ),
    bandwidth=jnp.array([bandwidth]),
    n_observations=jnp.array([0])
)

# 2D-grid variant of the density estimate, kept for reference.
# NOTE(review): presumably to be used together with dim = 2 — confirm.
# density_estimate = DensityEstimate(
#     p=jnp.zeros([batch_size, n_grid_points, 1]),
#     x_g=eesys.utils.density_estimation.build_grid_2d(
#         low=env.env_observation_space.low,
#         high=env.env_observation_space.high,
#         points_per_dim=points_per_dim
#     ),
#     bandwidth=jnp.array([bandwidth]),
#     n_observations=jnp.array([0])
# )

# Neural model of the pendulum dynamics, initialized with `model_key`.
model = NeuralEulerODEPendulum(
    obs_dim=env.env_observation_space.shape[-1],
    action_dim=env.action_space.shape[-1],
    width_size=128,
    depth=3,
    key=model_key
)

# Optimizer state covering only the inexact (floating-point) array leaves of
# the model.
opt_state_model = model_trainer.model_optimizer.init(eqx.filter(model, eqx.is_inexact_array))

In [None]:
from exciting_exciting_systems.algorithms import excite_and_fit

In [None]:
# Run the combined excitation / model-fitting procedure for `n_timesteps`.
# The returned observations, actions, model and density estimate overwrite
# the pre-allocated buffers and the freshly initialized objects of the same
# names, so re-running this cell alone continues from the already-updated
# values rather than from the initial ones.
observations, actions, model, density_estimate = excite_and_fit(
    n_timesteps=n_timesteps,
    env=env,
    model=model,
    obs=obs,
    state=state,
    proposed_actions=proposed_actions,
    exciter=exciter,
    model_trainer=model_trainer,
    density_estimate=density_estimate,
    observations=observations,
    actions=actions,
    opt_state_model=opt_state_model,
    loader_key=loader_key
)

In [None]:
# Plot the complete observation / action trajectory returned by excite_and_fit.
fig, axs = plot_sequence(
    observations,
    actions,
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
);
# Render the figure. The previous `plt.plot()` drew nothing and only echoed
# an empty list as the cell output.
plt.show()

In [None]:
# Compare the learned model against a window of the recorded trajectory:
# 1000 observations (indices 1000..1999) and the 999 actions between them.
fig, axs = plot_model_performance(
    model=model,
    true_observations=observations[1000:2000],
    actions=actions[1000:1999],
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
);
# Render the figure. The previous `plt.plot()` drew nothing and only echoed
# an empty list as the cell output.
plt.show()

In [None]:
# Filled-contour plot of the density estimate. Only runs for a 2D grid, so it
# is skipped with the current setting dim = 3.
if dim == 2:
    fig, axs, cax = eesys.evaluation.plotting_utils.plot_2d_kde_as_contourf(
        density_estimate.p, density_estimate.x_g, [r"$\theta$", r"$\omega$"]
    )
    # fig.savefig("excited_pendulum_kde_contourf.png")

In [None]:
# Surface / contour plots of the density estimate and of its absolute
# deviation from the exciter's target distribution. Only runs for a 2D grid,
# so it is skipped with the current setting dim = 3.
if dim == 2:
    # Estimated density itself.
    fig, axs = eesys.evaluation.plotting_utils.plot_2d_kde_as_surface(
        density_estimate.p, density_estimate.x_g, [r"$\theta$", r"$\omega$"]
    )
    fig.suptitle("Vanilla KDE")
    # fig.savefig("excited_pendulum_kde_surface.png")
    plt.show()
    
    # Absolute difference between estimate and target, as a surface.
    fig, axs = eesys.evaluation.plotting_utils.plot_2d_kde_as_surface(
        jnp.abs(density_estimate.p - exciter.target_distribution), density_estimate.x_g, [r"$\theta$", r"$\omega$"]
    )
    fig.suptitle("Difference")
    
    plt.show()
    
    # Same absolute difference as a filled-contour plot with a colorbar.
    fig, axs, cax = eesys.evaluation.plotting_utils.plot_2d_kde_as_contourf(
        jnp.abs(density_estimate.p - exciter.target_distribution), density_estimate.x_g, [r"$\theta$", r"$\omega$"]
    )
    plt.colorbar(cax)
    fig.suptitle("Abs Difference")
    
    plt.show()

In [None]:
from exciting_exciting_systems.utils.metrics import MNNS_without_penalty, audze_eglais, MC_uniform_sampling_distribution_approximation
from scipy.stats.qmc import LatinHypercube

In [None]:
# Latin-hypercube sampler for the 2D support points of the MC metric.
# Seeded so that the support points (and therefore the reported score) are
# reproducible, like the JAX PRNG key set up at the top of the notebook.
lhc_sampler = LatinHypercube(d=2, seed=33)

In [None]:
# Audze-Eglais score of the recorded observations, computed with the
# `audze_eglais` helper imported from the project's metrics module.
ae_score = audze_eglais(observations)
print(ae_score)

In [None]:
def MC_uniform_sampling_distribution_approximation(
        data_points: jnp.ndarray,
        support_points: jnp.ndarray
) -> jnp.ndarray:
    """Average distance from each support point to its nearest data point.

    From [Smits+Nelles2024]. The minimax design tries to minimize the
    distances of the data points to the support points.

    Open question from the original author: what stops the data points from
    just flocking to a single support point? Note that the minimum is taken
    over the data points separately for every support point, so each support
    point contributes its own nearest-neighbour distance to the average.

    NOTE: this notebook-local definition shadows the function of the same
    name imported from `exciting_exciting_systems.utils.metrics`.

    Args:
        data_points: Array of shape (n_data, dim).
        support_points: Array of shape (n_support, dim).

    Returns:
        Scalar array: the sum of the per-support-point minimal Euclidean
        distances, divided by the number of support points.
    """
    n_support = support_points.shape[0]

    # Broadcast to (n_data, n_support, dim) and reduce the last axis to get
    # the full pairwise Euclidean distance matrix.
    offsets = data_points[:, None, :] - support_points[None, ...]
    pairwise = jnp.linalg.norm(offsets, axis=-1)

    # For every support point keep only the closest data point.
    nearest = jnp.min(pairwise, axis=0)
    return jnp.sum(nearest) / n_support

In [None]:
# Evaluate the MC metric on the recorded observations against 1600
# Latin-hypercube support points. The sampler output is rescaled with
# `* 2 - 1`, i.e. from the unit square to the square [-1, 1) x [-1, 1).
mcudsa_score = MC_uniform_sampling_distribution_approximation(
    data_points=observations,
    support_points=lhc_sampler.random(n=1600) * 2 - 1
)
print(mcudsa_score)

---
### Look at the actions:

In [None]:
def build_grid_3d(low, high, points_per_dim):
    """Build the points of a regular 3D grid.

    The same `points_per_dim` linearly spaced values between `low` and `high`
    (both inclusive) are used along each of the three axes.

    Args:
        low: Lower bound shared by all three axes.
        high: Upper bound shared by all three axes.
        points_per_dim: Number of grid values per axis.

    Returns:
        Array of shape (points_per_dim**3, 3), one grid point per row.
    """
    ticks = jnp.linspace(low, high, points_per_dim)

    # Three coordinate arrays, combined on a new trailing axis and then
    # flattened to a list of 3D points.
    mesh = jnp.meshgrid(ticks, ticks, ticks)
    grid = jnp.stack(list(mesh), axis=-1).reshape(-1, 3)

    assert grid.shape[0] == points_per_dim**3
    return grid

In [None]:
# Shape check: all observations except the last one, joined with the actions
# along the last axis (one combined row per applied action).
jnp.concatenate([observations[0:-1, :], actions], axis=-1).shape

In [None]:
# Density estimate over the joint (observation, action) samples on a coarser
# 3D grid spanning [-1, 1] along every axis.
# NOTE: this overwrites `points_per_dim`, `n_grid_points` and the
# `density_estimate` returned by excite_and_fit; cells above that read these
# names will see the new values if they are re-run afterwards.
points_per_dim = 40
n_grid_points=points_per_dim**3
density_estimate = DensityEstimate(
    p=jnp.zeros([batch_size, n_grid_points, 1]),
    x_g=build_grid_3d(-1, 1, points_per_dim),
    bandwidth=jnp.array([bandwidth]),
    n_observations=jnp.array([0])
)

# Update the estimate with all samples at once. Only the first field of
# DensityEstimate is mapped over axis 0 (presumably `p` — confirm the field
# order); the remaining fields are shared across the batch. `[None]` adds the
# leading batch axis to the sample array.
density_estimate = jax.vmap(
    update_density_estimate_multiple_observations,
    in_axes=(DensityEstimate(0, None, None, None), 0),
    out_axes=(DensityEstimate(0, None, None, None))
)(
    density_estimate,
    jnp.concatenate([observations[0:-1, :], actions], axis=-1)[None],
)

In [None]:
# Inspect the shape of the updated density values.
density_estimate.p.shape

In [None]:
# Grid coordinates reshaped back to the 3D grid layout. Currently only
# referenced by the commented-out contourf arguments below.
x_plot = density_estimate.x_g.reshape((points_per_dim, points_per_dim, points_per_dim, 3))

fig, axs = plt.subplots(
    figsize=(6, 6)
)

# One filled-contour frame per index along the last grid axis; all frames are
# drawn onto the same axes and collected in `ims` for the animation.
# NOTE(review): the last grid axis presumably corresponds to the action
# coordinate (third column of the concatenated samples) — confirm.
ims = []
for i in range(points_per_dim):
    # fig, axs = plt.subplots(
    #     figsize=(6, 6)
    # )
    cax = axs.contourf(
        # x_plot[:, :, 0, :-1][..., 0],
        #x_plot[:, :, 0, :-1][..., 1],
        density_estimate.p[0].reshape((points_per_dim, points_per_dim, points_per_dim))[:, :, i],
        #jnp.sum(density_estimate.p[0].reshape((points_per_dim, points_per_dim, points_per_dim)), axis=-1),
        antialiased=False,
        levels=100,
        alpha=0.9,
        cmap=plt.cm.coolwarm
    )
    # Each animation frame is a list of artists; here a single contour set.
    ims.append([cax])
    # plt.title(jnp.linspace(-1, 1, points_per_dim)[i])
    # plt.show()

In [None]:
import matplotlib.animation as animation

In [None]:
# Animate the collected contour frames: 50 ms between frames and a 1000 ms
# pause before the animation repeats.
ani = animation.ArtistAnimation(fig, ims, interval=50, blit=True,
                                repeat_delay=1000)

In [None]:
# Export the animation as a GIF at 5 frames per second; the file is written
# to the current working directory.
writer = animation.PillowWriter(fps=5,
                                metadata=dict(artist='Me'),
                                bitrate=1800)
ani.save('opt_wrt_obs_and_act.gif', writer=writer)

- maybe look at the vector fields here as well?
- I think it is possible that the system does not have enough strength to go through the upper equilibrium at max velocity?