In [None]:
%load_ext autoreload
# Reload all (edited) modules, e.g. exciting_exciting_systems, before each cell runs.
%autoreload 2
# Makes the %lprun line-by-line profiling magic available.
%reload_ext line_profiler

In [None]:
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs

from exciting_exciting_systems.utils.density_estimation import build_grid_2d
from exciting_exciting_systems.models.model_utils import simulate_ahead_with_env
from exciting_exciting_systems.utils.signals import aprbs
from exciting_exciting_systems.evaluation.plotting_utils import plot_sequence

---

In [None]:
# Reproducible PRNG setup (previously tried seed: 21). The root key is split into
# a key for data generation, a key for the optimizer (GA), and a leftover key.
key = jax.random.PRNGKey(seed=3333)
data_key, opt_key, key = jax.random.split(key, 3)

# Iterator that hands out fresh subkeys via next(data_rng).
data_rng = PRNGSequence(data_key)

In [None]:
# Pendulum environment; tau is the step size that is also passed to plot_sequence
# (previously tried: 5e-2).
batch_size = 1
tau = 2e-2

env = excenvs.make(env_id='Pendulum-v0', batch_size=batch_size, tau=tau)

In [None]:
# Initial observation/state of the environment, cast to float32.
obs, state = env.reset()
obs, state = obs.astype(jnp.float32), state.astype(jnp.float32)

# Random APRBS excitation used as a baseline trajectory.
n_steps = 999
actions = aprbs(n_steps, batch_size, 1, 10, next(data_rng))

In [None]:
# Roll out the random actions; vmap over the batch axis of every argument except env.
simulate_batch = jax.vmap(simulate_ahead_with_env, in_axes=(None, 0, 0, 0, 0, 0, 0))
observations = simulate_batch(
    env,
    obs,
    state,
    actions,
    env.env_state_normalizer,
    env.action_normalizer,
    env.static_params,
)

print("actions.shape:", actions.shape)
print("observations.shape:", observations.shape)
print(" \n One of the trajectories:")

# Plot the first batch element.
first_obs, first_actions = observations[0, ...], actions[0, ...]
fig, axs = plot_sequence(
    observations=first_obs,
    actions=first_actions,
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
);
plt.show()

### Metrics:

- Maximum nearest neighbor sequence (MNNS) **[Smits+Nelles2024]**:

\begin{align}
    f_{\mathrm{MNNS}} = &- \frac{1} {L} \sum_{k=N+1}^{N+L} \min_{i \in \{1, ..., N \}} \| \mathbf{x}_i - \mathbf{x}_k \|_2 \\
                        &+ \# u_{v, l_v} d_{max},
\end{align}

where $N$ is the number of existing data points $\mathbf{x}_1, \dots, \mathbf{x}_N$, $L$ is the length of the new sequence made up of the next $h$ steps (which can vary in length), and $\mathbf{x}_k$, $k = N+1, \dots, N+L$, are the new data points that are simulated from the model. The term $\# u_{v, l_v} d_{max}$ is meant to weaken the effect of overemphasized boundaries and corners.
**Open question (TODO: verify against [Smits+Nelles2024]): does $\# u_{v, l_v}$ denote the counter of the amplitude levels of the $v$-th input dimension?** Furthermore, $d_{max} = k_{d_{max}} \Delta$, where

\begin{align}
    \Delta = \frac{2}{N(N-1)} \sum_{i=1}^N \sum_{k=i+1}^N \| \mathbf{x}_i - \mathbf{x}_k \|_2.
\end{align}

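- Audze-Eglais criterion **[Smits+Nelles2024]**. In its standard form it is the "potential energy" $U_{\mathrm{AE}} = \sum_{i=1}^{N} \sum_{k=i+1}^{N} \frac{1}{\| \mathbf{x}_i - \mathbf{x}_k \|_2^2}$, where smaller is better. The implementation in `utils.metrics` may use a different normalization (TODO: confirm):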

In [None]:
from sklearn.neighbors import NearestNeighbors

from exciting_exciting_systems.utils.metrics import audze_eglais, MC_uniform_sampling_distribution_approximation, MNNS_without_penalty

In [None]:
# Treat the first 900 steps as the existing data set and the remaining
# steps as the newly collected points scored by MNNS.
split_step = 900
old_observations = observations[:, :split_step, :]
new_observations = observations[:, split_step:, :]

mnns_score = jax.vmap(MNNS_without_penalty)(
    data_points=old_observations,
    new_data_points=new_observations,
)

mnns_score

In [None]:
# Audze-Eglais score of each full trajectory (vmapped over the batch axis).
batched_audze_eglais = jax.vmap(audze_eglais)
ae_score = batched_audze_eglais(data_points=observations)

print(ae_score)

In [None]:
# 2D grid of support points on [-1, 1] (presumably 50 points per axis — see build_grid_2d).
support_points = build_grid_2d(-1, 1, 50)

# Map over trajectories, share the same support points for all of them.
batched_mcusda = jax.vmap(
    MC_uniform_sampling_distribution_approximation,
    in_axes=(0, None),
)
MCUSDA_score = batched_mcusda(observations, support_points)

MCUSDA_score

### Offline iGOATs:
Implement and use the iGOATs algorithm as described in **[Smits+Nelles2024]** and **[Smits+Nelles2022]**.

- Start with a random APRBS and optimize it by rolling out the next steps using the perfect model.
- Add the chosen actions to the sequence and continue with the next random APRBS steps.
- **in any case:** The simulation rollout should be part of the fitness function

- Further **TODO**: As a sanity check, test the metrics given above on random data and on the optimized trajectory from your algorithm.

In [None]:
import evosax  # genetic algorithm in jax, The state of the library might be interesting to investigate for your own implementations

In [None]:
def generate_aprbs(amplitudes, durations):
    """Build an APRBS signal from its parameters.

    Each amplitude is held constant for the corresponding number of steps, and
    the resulting segments are concatenated into one 1D signal.

    Args:
        amplitudes: Amplitude levels, one per segment.
        durations: Integer segment lengths (number of steps). Amplitudes and
            durations are paired with ``zip``, so surplus entries of the longer
            sequence are ignored.

    Returns:
        1D ``jnp`` array whose length is the sum of the paired durations.

    Note:
        Not jittable: the output length depends on the *values* in
        ``durations``, which cannot be traced. A possible jittable alternative
        is ``jnp.repeat(..., total_repeat_length=...)`` with a fixed total
        length (untested).
    """
    segments = []
    for level, length in zip(amplitudes, durations):
        segments.append(jnp.ones(length) * level)
    return jnp.concatenate(segments)

In [None]:
h = 20  # number of APRBS segments

# Random APRBS parameters: h amplitudes in [-1, 1) followed by h integer
# durations in [2, 50). Draw order (amplitudes first) fixes the key stream.
init_amplitudes = jax.random.uniform(key=next(data_rng), shape=(h,), minval=-1, maxval=1)
init_durations = jax.random.randint(key=next(data_rng), shape=(h,), minval=2, maxval=50)
action_parameters = jnp.concatenate([init_amplitudes, init_durations])

# Decode into an action sequence of shape (1, T, 1).
actions = generate_aprbs(
    amplitudes=action_parameters[:h],
    durations=action_parameters[h:].astype(jnp.int32),
)[None, :, None]

plt.plot(jnp.squeeze(actions))

In [None]:
@jax.jit
def featurize_theta(obs):
    """Replace the angle by its sine and cosine.

    Angles such as 1.99 * pi and 0 are essentially the same but far apart as
    raw numbers. Using (sin, cos) makes them close in feature space, which is
    what the loss should compare.

    Args:
        obs: Array whose last axis holds ``(theta, omega)`` (only indices 0 and
            1 are read). Assumes theta is normalized so that ``theta * pi`` is
            the angle in radians — TODO confirm against the env normalizer.

    Returns:
        Array with last axis ``(sin(pi*theta), cos(pi*theta), omega)``.
    """
    angle = obs[..., 0] * jnp.pi
    columns = (jnp.sin(angle), jnp.cos(angle), obs[..., 1])
    return jnp.stack(columns, axis=-1)

@jax.jit
def soft_penalty(a, a_max=1):
    """Linear penalty on the amount by which |a| exceeds ``a_max``.

    The same symmetric bound ``[-a_max, a_max]`` applies to every entry.
    Violations are summed over the last two axes, and size-1 axes are squeezed
    out of the result.
    """
    excess = jnp.abs(a) - a_max
    total = jax.nn.relu(excess).sum(axis=(-2, -1))
    return jnp.squeeze(total)

In [None]:
def fitness_function(
    env,
    obs,
    state,
    action_parameters,
    h,
    rho_obs=1e4,
    rho_act=1e4,
):
    """Fitness (lower is better) of a single APRBS parameter vector.

    The parameters are decoded into an action sequence. That sequence is rolled
    out on ``env`` starting from ``obs``/``state``, and the result is scored with
    the Audze-Eglais criterion on the featurized observations. Soft penalties are
    added for observations and actions leaving ``[-1, 1]``.

    Args:
        env: Environment used for the rollout.
        obs: Initial observations, batched (vmapped over the leading axis).
        state: Initial environment state, batched (vmapped over the leading axis).
        action_parameters: 1D vector of length ``2 * h``. The first ``h`` entries
            are amplitudes. The last ``h`` entries are durations (absolute value,
            truncated to int).
        h: Number of APRBS segments.
        rho_obs: Weight of the observation-bound penalty (default 1e4).
        rho_act: Weight of the action-bound penalty (default 1e4).

    Returns:
        Fitness value, squeezed (a scalar for the batch size of 1 used here).
    """
    amplitudes = action_parameters[:h]
    # The optimizer samples unconstrained reals; map them to non-negative integer durations.
    durations = jnp.abs(action_parameters[h:]).astype(jnp.int32)
    actions = generate_aprbs(amplitudes=amplitudes, durations=durations)[None, :, None]

    observations = jax.vmap(simulate_ahead_with_env, in_axes=(None, 0, 0, 0, 0, 0, 0))(
        env,
        obs,
        state,
        actions,
        env.env_state_normalizer,
        env.action_normalizer,
        env.static_params,
    )

    # Compare angles via sin/cos so that wrap-around does not distort distances.
    feat_observations = jax.vmap(featurize_theta, in_axes=0)(observations)
    ae_score = jax.vmap(audze_eglais)(data_points=feat_observations)

    # Penalize leaving the normalized range [-1, 1] (raw, non-featurized observations).
    penalty_terms = (
        rho_obs * soft_penalty(a=observations, a_max=1)
        + rho_act * soft_penalty(a=actions, a_max=1)
    )

    return jnp.squeeze(ae_score) + penalty_terms

In [None]:
@jax.jit
def dummy_fitness_func(x):
    """Rosenbrock function with a=2, b=100; its global minimum 0 is at x = (2, 4).

    A toy objective, presumably meant for sanity-checking the evolution strategy.
    """
    a, b = 2, 100
    x0, x1 = x[0], x[1]
    return (a - x0) ** 2 + b * (x1 - x0 ** 2) ** 2

In [None]:
opt_key  # PRNGKey for the GA (CMA-ES); split off the root key in the PRNG setup cell

In [None]:
# CMA-ES over the APRBS parameter vector: h amplitudes followed by h durations.
num_generations = 100
population_size = 20
num_params = 2 * h

strategy = evosax.CMA_ES(popsize=population_size, num_dims=num_params, elite_ratio=0.5)
es_params = strategy.default_params
opt_state = strategy.initialize(opt_key, es_params)

In [None]:
es_params  # default CMA-ES hyperparameters, passed to strategy.ask / strategy.tell

In [None]:
# CMA-ES loop: each generation samples a population of APRBS parameter vectors,
# scores them with `fitness_function`, and updates the optimizer state.
for _generation in tqdm(range(num_generations)):
    # The third key is unused; the 3-way split is kept so the key stream is unchanged.
    opt_key, opt_gen_key, _opt_eval_key = jax.random.split(opt_key, 3)
    population, opt_state = strategy.ask(opt_gen_key, opt_state, es_params)

    # Start every evaluation from a freshly reset environment.
    obs, env_state = env.reset()
    obs = obs.astype(jnp.float32)
    env_state = env_state.astype(jnp.float32)

    # Members are evaluated one by one instead of with jax.vmap, presumably because
    # generate_aprbs is not traceable (its output length depends on the durations).
    fitness = jnp.stack([fitness_function(env, obs, env_state, member, h) for member in population])

    # Bug fix: pass the optimizer state. `state` is the *environment* state from the
    # reset cell and is not a valid evosax strategy state.
    opt_state = strategy.tell(population, fitness, opt_state, es_params)

# Get best overall population member & its fitness
opt_state.best_member, opt_state.best_fitness

In [None]:
opt_state.best_member  # best overall parameter vector: h amplitudes followed by h raw (real-valued) durations

In [None]:
# Decode the best CMA-ES candidate into an action sequence of shape (1, T, 1).
# Bug fix: read from the optimizer state `opt_state`; `state` is the environment
# state array and has no `best_member` attribute.
actions = generate_aprbs(
    amplitudes=opt_state.best_member[:h],
    durations=jnp.abs(opt_state.best_member[h:]).astype(jnp.int32)
)[None, :, None]

# Roll the best actions out from a freshly reset environment.
obs, env_state = env.reset()
obs = obs.astype(jnp.float32)
env_state = env_state.astype(jnp.float32)
observations = jax.vmap(simulate_ahead_with_env, in_axes=(None, 0, 0, 0, 0, 0, 0))(
    env,
    obs,
    env_state,
    actions,
    env.env_state_normalizer,
    env.action_normalizer,
    env.static_params
)

fig, axs = plot_sequence(
    observations=observations[0, ...],
    actions=actions[0, ...],
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
);
plt.show()

In [None]:
# Decoded action sequence (shape (1, T, 1)) of the best candidate in the optimizer state.
generate_aprbs(
    amplitudes=opt_state.best_member[:h],
    durations=jnp.abs(opt_state.best_member[h:]).astype(jnp.int32)
)[None, :, None]

In [None]:
opt_state.best_member[:h]  # amplitudes of the best candidate (read from the optimizer state)

In [None]:
jnp.abs(opt_state.best_member[h:]).astype(jnp.int32)  # integer durations of the best candidate (optimizer state)