In [None]:
# Reload edited project modules automatically, so changes to the imported
# packages are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Load the line_profiler extension (provides the %lprun magic).
%reload_ext line_profiler

In [None]:
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs

from exciting_exciting_systems.utils.density_estimation import build_grid_2d
from exciting_exciting_systems.models.model_utils import simulate_ahead_with_env
from exciting_exciting_systems.utils.signals import aprbs
from exciting_exciting_systems.evaluation.plotting_utils import plot_sequence
from exciting_exciting_systems.excitation.excitation_utils import soft_penalty

---

In [None]:
# Set up the PRNG: one root key is split into separate keys for data
# generation and for the optimizer; `key` keeps the remainder.
# NOTE(review): the trailing "21)" looks like a previously used seed — confirm.
key = jax.random.PRNGKey(seed=3333) # 21)

data_key, opt_key, key = jax.random.split(key, 3)
# `next(data_rng)` hands out a fresh key on every call.
data_rng = PRNGSequence(data_key)

In [None]:
# Environment configuration; both values are handed to `excenvs.make` below.
batch_size = 1
# `tau` is also passed to `plot_sequence` later in the notebook — presumably
# the sampling time in seconds (TODO confirm). The trailing value looks like
# an alternative setting that was tried.
tau = 2e-2 # 5e-2

# Build the pendulum environment used for all simulations in this notebook.
env = excenvs.make(
    env_id='Pendulum-v0',
    batch_size=batch_size,
    tau=tau
)

In [None]:
# Reset the environment and cast the initial observation and state to float32.
obs, env_state = env.reset()
obs = obs.astype(jnp.float32)
env_state = env_state.astype(jnp.float32)
n_steps = 999

# Draw an excitation signal with a fresh key from `data_rng`.
# NOTE(review): the positional arguments 1 and 10 are presumably the bounds
# of the holding time of each amplitude level — confirm against `aprbs`.
actions = aprbs(n_steps, batch_size, 1, 10, next(data_rng))

In [None]:
# Roll out the action sequence with the environment itself. `in_axes` keeps
# `env` unbatched (None) and maps every other argument over its leading axis.
observations = jax.vmap(simulate_ahead_with_env, in_axes=(None, 0, 0, 0, 0, 0, 0))(
    env,
    obs,
    env_state,
    actions,
    env.env_state_normalizer,
    env.action_normalizer,
    env.static_params
)

print("actions.shape:", actions.shape)
print("observations.shape:", observations.shape)

# Only the first element along the leading axis is plotted.
print(" \n One of the trajectories:")
fig, axs = plot_sequence(
    observations=observations[0, ...],
    actions=actions[0, ...],
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
);
plt.show()

### Metrics:

- Maximum nearest neighbor sequence (MNNS) **[Smits+Nelles2024]**:

\begin{align}
    f_{\mathrm{MNNS}} = &- \frac{1} {L} \sum_{k=N+1}^{N+L} \min_{i \in \{1, ..., N \}} \| \mathbf{x}_i - \mathbf{x}_k \|_2 \\
                        &+ \# u_{v, l_v} d_{max},
\end{align}

where $\mathbf{x}_i$, $i = 1, \dots, N$, are the already available data points, $L$ is the length of the new sequence made up of the next $h$ steps (which can vary in length), and $\mathbf{x}_k$ are the new data points that are simulated from the model. The term $\# u_{v, l_v} d_{max}$ is meant to weaken the effect of overemphasized boundaries and corners.
**Open question: does $\# u_{v, l_v}$ denote the counter of the amplitude levels of the $v$-th input dimension?** Furthermore, $d_{max} = k_{d_{max}} \Delta$, where

\begin{align}
    \Delta = \frac{2}{N(N-1)} \sum_{i=1}^N \sum_{k=i+1}^N \| \mathbf{x}_i - \mathbf{x}_k \|_2.
\end{align}

- Audze-Eglais criterion **[Smits+Nelles2024]**:

In [None]:
from sklearn.neighbors import NearestNeighbors

from exciting_exciting_systems.utils.metrics import audze_eglais, MC_uniform_sampling_distribution_approximation, MNNS_without_penalty

In [None]:
# Treat the first 900 samples of every trajectory as already collected data
# and the remaining samples as the "new" sequence that is scored against it.
new_observations = observations[:, 900:, :]
old_observations = observations[:, :900, :]

# Keyword arguments of a vmapped function are mapped over their leading axis,
# so one score is computed per trajectory.
mnns_score = jax.vmap(MNNS_without_penalty)(
    data_points=old_observations,
    new_data_points=new_observations
)

mnns_score

In [None]:
# Audze-Eglais score of the full random trajectory, computed separately for
# every element along the leading axis of `observations`.
ae_score = jax.vmap(audze_eglais)(
    data_points=observations,
)

print(ae_score)

In [None]:
# Support points for the MCUSDA metric.
# NOTE(review): presumably a regular grid with 50 points per axis on
# [-1, 1]^2 — confirm against `build_grid_2d`.
support_points = build_grid_2d(-1, 1, 50)
# in_axes=(0, None): the observations are mapped over their leading axis,
# the support points are shared between all trajectories.
MCUSDA_score = jax.vmap(MC_uniform_sampling_distribution_approximation, in_axes=(0, None),)(
    observations,
    support_points
)
MCUSDA_score

### Offline iGOATs:
Implement and use the iGOATs algorithm as described in **[Smits+Nelles2024]** and **[Smits+Nelles2022]**.

- Start with a random APRBS (amplitude-modulated pseudo-random binary signal) and optimize it by rolling out the next steps using the perfect model.
- Append the chosen actions to the sequence and continue with a new random APRBS for the next steps.
- **in any case:** The simulation rollout should be part of the fitness function

- Further **TODO**: As a sanity check, evaluate the metrics given above on random data and on the optimized trajectory produced by the algorithm.

In [None]:
import evosax  # genetic algorithm in jax, The state of the library might be interesting to investigate for your own implementations

In [None]:
@jax.jit
def featurize_theta(obs):
    """Encode the angle as (sin, cos) so that the loss treats it as periodic.

    Comparing raw angles is misleading because values such as 1.99 * pi and 0
    describe almost the same pose. The first entry of the last axis is
    therefore scaled by pi and replaced by its sine and cosine, while the
    second entry is passed through unchanged.

    Args:
        obs: Array whose last axis holds at least two entries; entry 0 is
            used as the angle (multiplied by pi here) and entry 1 is kept.

    Returns:
        Array with the same leading axes as ``obs`` and a last axis of size 3:
        ``[sin(pi * obs[..., 0]), cos(pi * obs[..., 0]), obs[..., 1]]``.
    """
    angle = obs[..., 0] * jnp.pi
    passthrough = obs[..., 1]
    features = (jnp.sin(angle), jnp.cos(angle), passthrough)
    return jnp.stack(features, axis=-1)

In [None]:
def MNNS_without_penalty(
        data_points: jnp.ndarray,
        new_data_points: jnp.ndarray
) -> jnp.ndarray:
    """Maximum-nearest-neighbor-sequence score without the penalty term.

    From [Smits+Nelles2024]. For every new point, the Euclidean distance to
    its closest existing point is computed; the score is the negative mean of
    these distances, so a more negative value means the new points lie
    farther away from the existing data.

    Implementation inspired by https://github.com/google/jax/discussions/9813

    Note: this notebook-local definition shadows the function of the same
    name imported from ``exciting_exciting_systems.utils.metrics``.

    TODO: Not sure about this penalty. Seems difficult to use for continuous action-spaces?

    Args:
        data_points: Existing points, indexed as (N, dim).
        new_data_points: New points, indexed as (L, dim).

    Returns:
        Scalar array ``-(1 / L) * sum_k min_i ||x_i - x_k||_2``.
    """
    # (N, 1, dim) - (1, L, dim) broadcasts to all pairwise offsets (N, L, dim)
    pairwise_offsets = data_points[:, None, :] - new_data_points[None, ...]
    pairwise_distances = jnp.linalg.norm(pairwise_offsets, axis=-1)
    # closest existing point for every new point
    nearest_distances = jnp.min(pairwise_distances, axis=0)
    n_new_points = new_data_points.shape[0]
    return - jnp.sum(nearest_distances) / n_new_points

In [None]:
def fitness_function(
    env,
    obs,
    env_state,
    prev_observations,
    actions,
    rho_obs=1e10,
    rho_act=1e10,
):
    """Score one candidate action sequence by simulating it with the environment.

    The candidate is rolled out from the current state, the resulting
    observations are featurized, and the MNNS score of the new observations
    with respect to the previously collected ones is computed. Constraint
    violations of observations and actions are added as weighted soft
    penalties.

    Args:
        env: Environment used as the (perfect) simulation model.
        obs: Current observation, with a leading batch axis.
        env_state: Current environment state, with a leading batch axis.
        prev_observations: Previously collected observations; they are passed
            through ``featurize_theta`` before the comparison.
        actions: Candidate action sequence. It is indexed with
            ``[None, :, None]``, i.e. presumably a 1-D array with one scalar
            action per step — TODO confirm.
        rho_obs: Weight of the soft penalty on the simulated observations.
        rho_act: Weight of the soft penalty on the actions.

    Returns:
        The MNNS score plus the weighted penalty terms. With this sign
        convention smaller values are better.
    """

    # add batch and action-dimension axes: (n_steps,) -> (1, n_steps, 1)
    actions = actions[None, :, None]

    observations = jax.vmap(simulate_ahead_with_env, in_axes=(None, 0, 0, 0, 0, 0, 0))(
        env,
        obs,
        env_state,
        actions,
        env.env_state_normalizer,
        env.action_normalizer,
        env.static_params
    )

    # featurize_theta only touches the last axis, so it can be applied to the
    # batched array directly
    feat_observations = featurize_theta(observations)

    # the first simulated sample is skipped when computing the score
    score = MNNS_without_penalty(
        data_points=featurize_theta(prev_observations),
        new_data_points=feat_observations[0, 1:, :]
    )

    penalty_terms = rho_obs * soft_penalty(a=observations, a_max=1) + rho_act * soft_penalty(a=actions, a_max=1)

    return jnp.squeeze(score) + penalty_terms

In [None]:
def optimize_actions(
    opt_key,
    strategy,
    es_params,
    env,
    obs,
    env_state,
    prev_observations,
    num_generations
):
    """Optimize one action sequence with an ask/tell evolution strategy.

    TODO: Put this in a lax loop. Not possible because of the loss....

    Args:
        opt_key: PRNG key for this optimization run. It is only split here,
            never consumed directly.
        strategy: Strategy object providing ``initialize``, ``ask`` and
            ``tell`` (an evosax strategy in this notebook).
        es_params: Hyperparameters handed to every strategy call.
        env: Environment used for the rollouts inside the fitness function.
        obs: Current observation the rollouts start from.
        env_state: Current environment state the rollouts start from.
        prev_observations: Observations collected so far, passed to the
            fitness function.
        num_generations: Number of ask/evaluate/tell iterations.

    Returns:
        The best member found, expanded with ``[None, :, None]`` to a leading
        batch axis and a trailing action axis.
    """

    # A key must not be consumed and afterwards split again, so a dedicated
    # key is derived for the initialization.
    opt_key, opt_init_key = jax.random.split(opt_key)
    opt_state = strategy.initialize(opt_init_key, es_params)

    # only the candidate action sequences (last argument) are batched
    evaluate_population = jax.vmap(fitness_function, in_axes=(None, None, None, None, 0))

    for _ in range(num_generations):
        opt_key, opt_gen_key = jax.random.split(opt_key)
        actions, opt_state = strategy.ask(opt_gen_key, opt_state, es_params)

        fitness = evaluate_population(env, obs, env_state, prev_observations, actions)
        opt_state = strategy.tell(actions, fitness, opt_state, es_params)

    return opt_state.best_member[None, :, None]

In [None]:
# Offline iGOATs loop: repeatedly optimize the next `n_prediction_steps`
# actions against all observations gathered so far, then apply them to the
# environment and append the results.
# NOTE: `observations` and `actions` are rebound to Python lists here; the
# cells below stack them with `jnp.stack`.
observations = []
actions = []

n_total_steps = 5000
n_prediction_steps = 50

population_size = 20
n_generations = 100

obs, env_state = env.reset()
obs = obs.astype(jnp.float32)
env_state = env_state.astype(jnp.float32)

observations.append(obs[0])

strategy = evosax.CMA_ES(popsize=population_size, num_dims=n_prediction_steps, elite_ratio=0.5)
es_params = strategy.default_params

n_segments = n_total_steps // n_prediction_steps

# One independent key per segment. Passing the same `opt_key` to every call
# would make the optimizer draw identical random numbers in every segment.
# Deriving the keys from `opt_key` without rebinding it keeps the cell
# idempotent on re-execution.
segment_keys = jax.random.split(opt_key, n_segments)

for j in tqdm(range(n_segments)):

    new_actions = optimize_actions(
        segment_keys[j],
        strategy,
        es_params,
        env,
        obs,
        env_state,
        prev_observations=jnp.stack(observations),
        num_generations=n_generations
    )

    # apply the optimized actions one by one to the real environment
    for i in range(new_actions.shape[1]):
        action = new_actions[:, i, :]
        obs, _, _, _, env_state = env.step(action, env_state)

        observations.append(obs[0])
        actions.append(action[0])


In [None]:
# Plot the trajectory produced by the optimization loop. The lists are
# stacked into arrays first; `observations` holds one more entry than
# `actions` because the initial observation was appended before the loop.
fig, axs = plot_sequence(
    observations=jnp.stack(observations),
    actions=jnp.stack(actions),
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
);
plt.show()

In [None]:
from exciting_exciting_systems.utils.metrics import MNNS_without_penalty, audze_eglais, MC_uniform_sampling_distribution_approximation
from scipy.stats.qmc import LatinHypercube

In [None]:
# Seeded so that the sampled support points, and with them the MCUSDA score,
# are reproducible on Restart & Run All.
lhc_sampler = LatinHypercube(d=2, seed=3333)

In [None]:
# Audze-Eglais score of the optimized trajectory (single trajectory, so no
# vmap is needed). This rebinds `ae_score` from the random-data cell above.
ae_score = audze_eglais(jnp.stack(observations))
print(ae_score)

In [None]:
# Latin-hypercube samples are drawn from the unit hypercube [0, 1)^2. They
# are rescaled to [-1, 1)^2 so that the support points span the same range
# as the grid used for the random-data MCUSDA score (`build_grid_2d(-1, 1, 50)`)
# instead of only one quadrant.
# NOTE(review): assumes the observations are normalized to [-1, 1] — confirm.
mcudsa_score = MC_uniform_sampling_distribution_approximation(
    data_points=jnp.stack(observations),
    support_points=2.0 * lhc_sampler.random(n=1600) - 1.0
)
print(mcudsa_score)