In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs

from exciting_exciting_systems.utils.density_estimation import build_grid_2d
from exciting_exciting_systems.models.model_utils import simulate_ahead_with_env
from exciting_exciting_systems.utils.signals import aprbs
from exciting_exciting_systems.evaluation.plotting_utils import plot_sequence

---

In [None]:
# PRNG setup: one root key is split into independent streams for data
# generation, model initialisation and data loading; `key` keeps the remainder
# for later use. (Alternative seed tried before: 21.)
root_key = jax.random.PRNGKey(seed=3333)
data_key, model_key, loader_key, key = jax.random.split(root_key, num=4)

# Stateful iterator over sub-keys of `data_key` (each `next(...)` yields a fresh key).
data_rng = PRNGSequence(data_key)

In [None]:
# Environment: `batch_size` pendulums created through the exciting_environments
# registry. `tau` is passed straight to the environment (presumably its
# sampling time — confirm in exciting_environments); alternative value: 5e-2.
batch_size = 1
tau = 2e-2

env = excenvs.make(env_id='Pendulum-v0', batch_size=batch_size, tau=tau)

In [None]:
# Start from the environment's reset state and keep everything in float32.
obs, state = env.reset()
obs, state = obs.astype(jnp.float32), state.astype(jnp.float32)

# Excitation signal from `aprbs`; see its signature for the meaning of the
# positional arguments `1` and `10`. Consumes one key from `data_rng`.
n_steps = 999
actions = aprbs(n_steps, batch_size, 1, 10, next(data_rng))

In [None]:
# Roll out every batch element independently: `env` is shared (in_axes=None),
# every other argument is mapped over its leading axis.
batched_rollout = jax.vmap(simulate_ahead_with_env, in_axes=(None, 0, 0, 0, 0, 0, 0))
observations = batched_rollout(
    env,
    obs,
    state,
    actions,
    env.env_state_normalizer,
    env.action_normalizer,
    env.static_params,
)

print("actions.shape:", actions.shape)
print("observations.shape:", observations.shape)

# Visual sanity check on the first trajectory of the batch.
print(" \n One of the trajectories:")
fig, axs = plot_sequence(
    observations=observations[0],
    actions=actions[0],
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
)
plt.show()

### Metrics:

- maximum nearest neighbor sequence **[Smits+Nelles2024]**:

\begin{align}
    f_{\mathrm{MNNS}} = &- \frac{1} {L} \sum_{k=N+1}^{N+L} \min_{i \in \{1, ..., N \}} \| \mathbf{x}_i - \mathbf{x}_k \|_2 \\
                        &+ \# u_{v, l_v} d_{max},
\end{align}

where $\mathbf{x}_1, \dots, \mathbf{x}_N$ are the $N$ already available data points and $\mathbf{x}_{N+1}, \dots, \mathbf{x}_{N+L}$ are the new data points simulated from the model. $L$ is the length of this new sequence, i.e. the next $h$ steps (so $L$ can vary). The term $\# u_{v, l_v} d_{max}$ is meant to weaken the effect of overemphasized boundaries and corners.
Here, $\# u_{v, l_v}$ presumably denotes the counter of amplitude level $l_v$ of the $v$-th input dimension (**TODO: confirm in [Smits+Nelles2024]**), and $d_{max} = k_{d_{max}} \Delta$ with a factor $k_{d_{max}}$ and the mean pairwise distance

\begin{align}
    \Delta = \frac{2}{N(N-1)} \sum_{i=1}^N \sum_{k=i+1}^N \| \mathbf{x}_i - \mathbf{x}_k \|_2.
\end{align}

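- Audze–Eglais **[Smits+Nelles2024]**, which penalizes pairs of data points that lie close to each other:

\begin{align}
    f_{\mathrm{AE}} = \frac{2}{N(N-1)} \sum_{i=1}^N \sum_{k=i+1}^N \frac{1}{\| \mathbf{x}_i - \mathbf{x}_k \|_2^2}.
\end{align}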

In [None]:
from sklearn.neighbors import NearestNeighbors

In [None]:
@jax.jit
def max_nearest_neighbour_seq(
        data_points: jnp.ndarray,
        new_data_points: jnp.ndarray
) -> jnp.ndarray:
    """Maximum nearest neighbour sequence (MNNS) criterion from [Smits+Nelles2024].

    For every new point the Euclidean distance to its closest existing point
    is taken; the negated mean of these distances is returned, i.e.
    ``-(1/L) * sum_k min_i ||x_i - x_k||_2`` with L = number of new points.

    Implementation inspired by https://github.com/google/jax/discussions/9813

    TODO: add the amplitude-level penalty term of the paper (not implemented yet).

    Args:
        data_points: Existing points, shape (N, dim).
        new_data_points: Newly proposed points, shape (L, dim).

    Returns:
        Scalar criterion value.
    """
    # Broadcast to (N, L, dim): rows index existing points, columns new ones.
    pairwise_diff = data_points[:, None, :] - new_data_points[None, ...]
    # Reduce over the existing-points axis -> nearest-neighbour distance per new point.
    nn_distances = jnp.linalg.norm(pairwise_diff, axis=-1).min(axis=0)
    return -nn_distances.mean()


@jax.jit
def audze_eglais(data_points: jnp.ndarray) -> jnp.ndarray:
    """Audze-Eglais criterion from [Smits+Nelles2024].

    The maximin-design penalizes points that are too close to each other in
    the point distribution:

        f_AE = 2 / (N (N-1)) * sum_{i<k} 1 / ||x_i - x_k||_2^2

    Only the N (N-1) / 2 unordered pairs i < k are formed, and the squared
    distances are computed directly as a sum of squares. Building the full
    N x N matrix with ``jnp.linalg.norm`` and squaring afterwards gives the
    same value, but that matrix contains zero self-distances on its diagonal.
    The ``sqrt`` inside the norm has an infinite derivative there, which makes
    ``jax.grad`` of the criterion NaN even though the diagonal is discarded.

    Args:
        data_points: Array of shape (N, dim) with N >= 2 points.

    Returns:
        Scalar criterion value; lower means fewer closely spaced pairs.
        Coinciding points yield ``inf``.
    """
    N = data_points.shape[0]
    # N comes from the (static) array shape, so the pair indices are concrete under jit.
    row_idx, col_idx = jnp.triu_indices(N, k=1)
    pair_differences = data_points[row_idx] - data_points[col_idx]
    # Squared distances without a sqrt -> well-defined gradients.
    squared_distances = jnp.sum(pair_differences**2, axis=-1)
    return 2 / (N * (N - 1)) * jnp.sum(1 / squared_distances)


# jax does not like the dynamic slicing....
# @jax.jit
# def audze_eglais_loop(data_points: jnp.ndarray) -> jnp.ndarray:
#     """From [Smits+Nelles2024].

#     Single loop based implementation.
#     """
#     N = data_points.shape[0]

#     def body_fun(i, carry):
#         sum_value = carry

#         helper_points = data_points[i+1:, :]
#         distance_matrix = jnp.linalg.norm(data_points[:, None, :] - data_points[None, ...], axis=-1)
#         sum_value += jnp.sum(distance_matrix)
#         return sum_value        

#     result = jax.lax.fori_loop(
#         lower=0, 
#         upper=N,
#         body_fun=body_fun,
#         init_val=jnp.array([0.])
#     )                  

#     return result * 2 / (N * (N-1))


@jax.jit
def audze_eglais_loops(data_points: jnp.ndarray) -> jnp.ndarray:
    """Loop-based reference implementation of the Audze-Eglais criterion [Smits+Nelles2024].

    Visits every unordered pair (i, k) with i < k through two nested
    ``jax.lax.fori_loop`` calls and accumulates ``1 / ||x_i - x_k||_2^2``,
    then scales by ``2 / (N (N-1))``. Pretty slow; kept only for comparison
    with the vectorised version.

    Returns:
        Array of shape (1,) (the accumulator starts as ``jnp.array([0.])``).
    """
    num_points = data_points.shape[0]

    def accumulate_pair(k, state):
        # Add the inverse squared distance between the anchor row and point k.
        anchor, total = state
        gap = jnp.linalg.norm(data_points[anchor, :] - data_points[k, :], axis=-1)
        return anchor, total + 1 / gap**2

    def accumulate_row(anchor, total):
        # Sweep all partners k > anchor so each pair is counted exactly once.
        _, total = jax.lax.fori_loop(anchor + 1, num_points, accumulate_pair, (anchor, total))
        return total

    total = jax.lax.fori_loop(0, num_points, accumulate_row, jnp.array([0.]))
    return total * 2 / (num_points * (num_points - 1))


@jax.jit
def MC_uniform_sampling_distribution_approximation(
        data_points: jnp.ndarray,
        support_points: jnp.ndarray
) -> jnp.ndarray:
    """Monte-Carlo minimax criterion from [Smits+Nelles2024].

    For every support point the distance to its *closest data point* is
    computed, and these distances are averaged over all M support points.
    Lower values mean the data covers the support set more evenly.

    Because the minimum is taken over the data points separately for each
    support point, every support point contributes to the score. If the data
    points flocked to a single support point, all other support points would
    report large distances, so such a design is penalised.

    Args:
        data_points: Data points, shape (N, dim).
        support_points: Monte-Carlo support points, shape (M, dim).

    Returns:
        Scalar mean nearest-data-point distance over the support points.
    """
    # (N, M) distance table: rows are data points, columns support points.
    distances = jnp.linalg.norm(data_points[:, None, :] - support_points[None, ...], axis=-1)
    # Reduce over the data axis: gap between each support point and the data.
    coverage_gaps = distances.min(axis=0)
    return coverage_gaps.mean()

In [None]:
# Split each trajectory at sample 900: the prefix plays the role of the
# existing data set, the suffix is the new sequence whose MNNS is scored.
split_idx = 900
old_observations = observations[:, :split_idx, :]
new_observations = observations[:, split_idx:, :]

mnns_score = jax.vmap(max_nearest_neighbour_seq)(
    data_points=old_observations,
    new_data_points=new_observations,
)
mnns_score

In [None]:
# Audze-Eglais score per trajectory of the batch (lower = fewer closely spaced
# point pairs). Displayed via the last expression, like the other metric cells.
ae_score = jax.vmap(audze_eglais)(
    data_points=observations,
)
ae_score

In [None]:
# Support points for the Monte-Carlo coverage criterion: a 2D grid over
# [-1, 1] from `build_grid_2d` (50 presumably = points per axis; check helper).
support_points = build_grid_2d(-1, 1, 50)

# Map over the batch of trajectories; the support grid is shared (None).
batched_mcusda = jax.vmap(
    MC_uniform_sampling_distribution_approximation,
    in_axes=(0, None),
)
MCUSDA_score = batched_mcusda(observations, support_points)
MCUSDA_score

### Offline iGOATs:

In [None]:
def generate_aprbs