In [None]:
# Enable autoreload so that edited project modules are re-imported
# automatically before each cell is executed (mode 2 = all modules).
%load_ext autoreload
%autoreload 2
# Load (or reload) the line_profiler extension, which provides the %lprun magic.
%reload_ext line_profiler

In [None]:
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs

from exciting_exciting_systems.utils.density_estimation import build_grid_2d
from exciting_exciting_systems.models.model_utils import simulate_ahead_with_env
from exciting_exciting_systems.utils.signals import aprbs
from exciting_exciting_systems.evaluation.plotting_utils import plot_sequence

---

In [None]:
# Root PRNG key for the notebook.
# NOTE(review): the original trailing note "21" presumably records a
# previously used seed -- confirm before relying on it.
key = jax.random.PRNGKey(seed=3333)

# Split the root key into four: three named sub-keys (data, model, loader)
# plus a replacement for `key` that can be split again later.
data_key, model_key, loader_key, key = jax.random.split(key, num=4)

# Key iterator seeded with the data key; consumed via next(data_rng).
data_rng = PRNGSequence(data_key)

In [None]:
# Environment configuration.
batch_size = 1  # batch size handed to excenvs.make and to the action generator
# `tau` is passed to the environment and later to plot_sequence; presumably
# the sampling time in seconds -- TODO confirm. The author noted 5e-2 as an
# alternative value.
tau = 2e-2

env = excenvs.make(env_id='Pendulum-v0', batch_size=batch_size, tau=tau)

In [None]:
# Number of action steps to generate.
n_steps = 999

# Initial observation and state from the environment, cast to float32.
obs, state = env.reset()
obs, state = obs.astype(jnp.float32), state.astype(jnp.float32)

# Random excitation signal; draws a fresh key from the data key sequence, so
# re-running this cell without re-running the PRNG setup gives new actions.
# NOTE(review): the literals 1 and 10 are positional arguments of `aprbs`
# whose meaning is not visible here -- check the signature in utils.signals.
actions = aprbs(n_steps, batch_size, 1, 10, next(data_rng))

In [None]:
# Batched rollout: the environment object is shared across the batch
# (in_axes entry None) while the six remaining arguments are mapped over
# their leading axis.
observations = jax.vmap(
    simulate_ahead_with_env,
    in_axes=(None,) + (0,) * 6,
)(
    env,
    obs,
    state,
    actions,
    env.env_state_normalizer,
    env.action_normalizer,
    env.static_params,
)

print("actions.shape:", actions.shape)
print("observations.shape:", observations.shape)

# Plot only the first batch entry.
print(" \n One of the trajectories:")
fig, axs = plot_sequence(
    observations=observations[0],
    actions=actions[0],
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
)
plt.show()

### Metrics:

- maximum nearest neighbor sequence (MNNS) **[Smits+Nelles2024]**:

\begin{align}
    f_{\mathrm{MNNS}} = &- \frac{1} {L} \sum_{k=N+1}^{N+L} \min_{i \in \{1, ..., N \}} \| \mathbf{x}_i - \mathbf{x}_k \|_2 \\
                        &+ \# u_{v, l_v} d_{max},
\end{align}

where $N$ is the number of already existing data points $\mathbf{x}_i$, $L$ is the length of the sequence made up of the next $h$ steps (which can vary in length), and $\mathbf{x}_k$ are the new data points that are simulated from the model. The term $\# u_{v, l_v} d_{max}$ is meant to weaken the effect of overemphasized boundaries and corners.
**Open question: does $\# u_{v, l_v}$ denote the counter of the amplitude levels of the $v$-th input dimension?** Furthermore, $d_{max} = k_{d_{max}} \Delta$, where $\Delta$ is the mean pairwise distance between the existing data points:

\begin{align}
    \Delta = \frac{2}{N(N-1)} \sum_{i=1}^N \sum_{k=i+1}^N \| \mathbf{x}_i - \mathbf{x}_k \|_2.
\end{align}

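- Audze-Eglais criterion **[Smits+Nelles2024]**: penalizes data points that lie close to each other. In its standard form (the implementation in `audze_eglais` may differ, e.g. by a normalization factor; to be checked against the paper):

\begin{align}
    f_{\mathrm{AE}} = \sum_{i=1}^{N} \sum_{k=i+1}^{N} \frac{1}{\| \mathbf{x}_i - \mathbf{x}_k \|_2^2}.
\end{align}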

In [None]:
from sklearn.neighbors import NearestNeighbors

from exciting_exciting_systems.utils.metrics import audze_eglais, MC_uniform_sampling_distribution_approximation, MNNS_without_penalty

In [None]:
# Index along axis 1 (presumably the time axis -- TODO confirm) at which each
# trajectory is split into "existing" and "new" data points. Kept in a single
# named constant so the two slices below cannot drift apart.
n_old_points = 900

old_observations = observations[:, :n_old_points, :]
new_observations = observations[:, n_old_points:, :]

# Keyword arguments of a vmapped function are mapped over their leading axis,
# so one score is computed per batch entry.
mnns_score = jax.vmap(MNNS_without_penalty)(
    data_points=old_observations,
    new_data_points=new_observations,
)

mnns_score

In [None]:
# Audze-Eglais score, evaluated per batch entry on the full observation sequence.
ae_score = jax.vmap(audze_eglais)(data_points=observations)

print(ae_score)

In [None]:
# Support grid for the metric, built once and shared across the batch
# (in_axes entry None); the observations are mapped over their leading axis.
support_points = build_grid_2d(-1, 1, 50)

MCUSDA_score = jax.vmap(
    MC_uniform_sampling_distribution_approximation,
    in_axes=(0, None),
)(observations, support_points)

MCUSDA_score

### Offline iGOATs:
Implement and use the iGOATs algorithm as described in **[Smits+Nelles2024]** and **[Smits+Nelles2022]**.

- start with a random aprbs and optimize it by rolling out the next steps using the perfect model
- add the chosen actions to the sequence and start with the next random aprbs steps
- **in any case:** The simulation rollout should be part of the fitness function

- Further **TODO**: As a sanity check, evaluate the metrics given above both on random data and on the optimized trajectory produced by your algorithm.

In [None]:
import evosax  # genetic algorithm in jax, The state of the library might be interesting to investigate

In [None]:
def generate_aprbs():
    """Placeholder for a parameterizable APRBS generator.

    Intended to transform a set of APRBS parameters into the corresponding
    signal. Not implemented yet: the function currently takes no arguments
    and returns ``None``.
    """