In [None]:
# Load the autoreload extension and set it to mode 2.
%load_ext autoreload
%autoreload 2
# (Re)load the line_profiler extension.
%reload_ext line_profiler

In [None]:
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
from cmaes import CMAwM

import exciting_environments as excenvs

from exciting_exciting_systems.models.model_utils import simulate_ahead_with_env
from exciting_exciting_systems.evaluation.plotting_utils import plot_sequence

---

In [None]:
# Root PRNG key of the notebook (the author's inline note lists 21 as an alternative seed).
key = jax.random.PRNGKey(3333)

# Split into a data key, an optimisation key and a fresh root key.
data_key, opt_key, key = jax.random.split(key, num=3)

# Iterator that hands out a new sub-key on every `next(data_rng)`.
data_rng = PRNGSequence(data_key)

In [None]:
# Environment configuration: a single pendulum, `tau` is forwarded to the
# environment and later to the plotting helper (5e-2 is noted as an alternative).
batch_size = 1
tau = 2e-2

env = excenvs.make(env_id='Pendulum-v0', batch_size=batch_size, tau=tau)

In [None]:
def generate_aprbs(amplitudes, durations):
    """Turn APRBS parameters into a piecewise-constant 1-D signal.

    Every amplitude is held for the matching number of samples and the
    constant segments are joined in order.

    Not jittable: the length of the result depends on the values stored in
    ``durations``, which cannot be traced. A fixed full duration with
    duration-dependent indexing was considered as an alternative, but that
    indexing would presumably not be traceable either.

    Args:
        amplitudes: iterable with one amplitude per segment.
        durations: iterable with one integer sample count per segment.

    Returns:
        1-D array containing all segments concatenated.
    """
    segments = []
    for amplitude, duration in zip(amplitudes, durations):
        segments.append(np.ones(duration) * amplitude)
    return np.concatenate(segments)

In [None]:
h = 100

# Draw order matters for reproducibility: amplitudes first, durations second.
random_amplitudes = jax.random.uniform(
    key=next(data_rng), shape=(h,), minval=-1, maxval=1
)
random_durations = jax.random.randint(
    key=next(data_rng), shape=(h,), minval=2, maxval=50
)

# One flat parameter vector: [amplitudes | durations].
action_parameters = jnp.concatenate([random_amplitudes, random_durations])

# The durations are cast back to integers before they are used as sample counts;
# the result gets a leading batch axis and a trailing action axis.
actions = generate_aprbs(
    amplitudes=action_parameters[:h],
    durations=np.abs(action_parameters[h:].astype(jnp.int32)),
)[None, :, None]

plt.plot(jnp.squeeze(actions))

In [None]:
# Start from a freshly reset environment, everything in float32.
obs, state = env.reset()
obs, state = obs.astype(jnp.float32), state.astype(jnp.float32)

# Roll the action sequence out; everything except `env` is mapped over axis 0.
batched_rollout = jax.vmap(simulate_ahead_with_env, in_axes=(None, 0, 0, 0, 0, 0, 0))
observations = batched_rollout(
    env,
    obs,
    state,
    actions,
    env.env_state_normalizer,
    env.action_normalizer,
    env.static_params,
)

print("actions.shape:", actions.shape)
print("observations.shape:", observations.shape)

print(" \n One of the trajectories:")
fig, axs = plot_sequence(
    observations=observations[0, ...],
    actions=actions[0, ...],
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
);
plt.show()

---

In [None]:
def soft_penalty(a, a_max=1):
    """Sum up by how much the entries of ``a`` exceed ``a_max`` in magnitude.

    The same limit is applied to every dimension (symmetric bounds). Entries
    with ``|a| <= a_max`` contribute nothing.

    Args:
        a: array whose last two axes are reduced.
        a_max: magnitude limit.

    Returns:
        The summed excess with singleton axes squeezed away.
    """
    excess = jax.nn.relu(jnp.abs(a) - a_max)
    summed_excess = np.sum(excess, axis=(-2, -1))
    return np.squeeze(summed_excess)


def featurize_theta(obs):
    """Replace the angle feature by its sine and cosine.

    Comparing raw angles in a loss is misleading because values such as
    1.99 * pi and 0 describe almost the same pose. The first feature is
    therefore multiplied by pi and mapped to (sin, cos); the second feature
    is passed through unchanged.

    Args:
        obs: array whose last axis holds the angle feature at index 0 and a
            second feature at index 1.

    Returns:
        Array with the same leading axes and a last axis of size 3:
        ``[sin, cos, obs[..., 1]]``.
    """
    angle = obs[..., 0] * np.pi
    remaining_feature = obs[..., 1]
    return np.stack([np.sin(angle), np.cos(angle), remaining_feature], axis=-1)


def audze_eglais(data_points: np.ndarray) -> np.ndarray:
    """Audze-Eglais criterion from [Smits+Nelles2024].

    This maximin-design measure penalizes point distributions in which
    points lie too close to each other: it is the mean of the inverse
    squared Euclidean distances over all distinct point pairs.

    TODO: There has to be a more efficient way to do this.

    Args:
        data_points: array of shape (N, D), one point per row.

    Returns:
        Scalar criterion value (smaller means better spread).
    """
    N = data_points.shape[0]

    # full (N, N) matrix of pairwise Euclidean distances
    pairwise_differences = data_points[:, None, :] - data_points[None, ...]
    distance_matrix = np.linalg.norm(pairwise_differences, axis=-1)

    # each unordered pair once: strictly upper triangle
    upper_rows, upper_cols = np.triu_indices(N, k=1)
    distances = distance_matrix[upper_rows, upper_cols]

    inverse_squared_sum = np.sum(1 / distances**2)
    return 2 / (N * (N - 1)) * inverse_squared_sum


def MNNS_without_penalty(
        data_points: np.ndarray,
        new_data_points: np.ndarray
) -> np.ndarray:
    """Negative mean nearest-neighbour distance, from [Smits+Nelles2024].

    For every new point the Euclidean distance to its closest existing point
    is determined; the negative average of these distances is returned, so
    minimizing the value pushes new points away from the existing data.

    Implementation inspired by https://github.com/google/jax/discussions/9813

    TODO: Not sure about this penalty. Seems difficult to use for continuous action-spaces?

    Args:
        data_points: existing points, shape (N, D).
        new_data_points: candidate points, shape (L, D).

    Returns:
        Scalar score.
    """
    n_new_points = new_data_points.shape[0]

    # (N, L) matrix: distance of every existing point to every new point
    offsets = data_points[:, None, :] - new_data_points[None, ...]
    distance_matrix = np.linalg.norm(offsets, axis=-1)

    # closest existing point for each new point
    nearest_distances = np.min(distance_matrix, axis=0)
    return -np.sum(nearest_distances) / n_new_points

In [None]:
def fitness_function(
    env,
    obs,
    state,
    prev_observations,
    action_parameters,
    h,
    max_duration
):
    """Evaluate one candidate APRBS parameter vector (lower is better).

    The candidate is converted into an action signal, zero-padded to the
    longest possible signal length (``h * max_duration``) so that every
    rollout has the same length, simulated with ``simulate_ahead_with_env``
    and scored with ``MNNS_without_penalty`` against the featurized previous
    observations. Large soft penalties are added for observations or actions
    whose magnitude exceeds 1.

    Args:
        env: environment passed to ``simulate_ahead_with_env``; its
            ``env_state_normalizer``, ``action_normalizer`` and
            ``static_params`` are forwarded as well.
        obs: current observation with a leading batch axis.
        state: current environment state with a leading batch axis.
        prev_observations: observations gathered so far (one per row).
        action_parameters: flat vector, the first ``h`` entries are the
            amplitudes and the remaining entries the durations.
        h: number of APRBS segments.
        max_duration: largest duration a single segment can take.

    Returns:
        Python float: score plus penalty terms.
    """
    actions = generate_aprbs(
        amplitudes=action_parameters[:h],
        durations=action_parameters[h:].astype(np.int64)
    )[:, None]

    # The signal length lives on axis 0; axis 1 is the singleton action axis.
    n_steps = actions.shape[0]
    max_signal_length = int(h * max_duration)
    # Never negative, so an over-long signal is simply simulated unpadded.
    diff_to_max = max(max_signal_length - n_steps, 0)

    padded_actions = jnp.concatenate([actions, jnp.zeros((diff_to_max, 1))], axis=0)

    padded_observations = jax.vmap(simulate_ahead_with_env, in_axes=(None, 0, 0, 0, 0, 0, 0))(
        env,
        obs,
        state,
        padded_actions[None, ...],
        env.env_state_normalizer,
        env.action_normalizer,
        env.static_params
    )
    padded_observations = np.array(padded_observations[0])

    # Drop the part of the rollout caused by the padding. An explicit end
    # index is used because `[:-diff_to_max]` is empty for diff_to_max == 0.
    n_valid_observations = padded_observations.shape[0] - diff_to_max
    observations = padded_observations[:n_valid_observations, :]
    feat_observations = featurize_theta(observations)
    actions = padded_actions[:n_steps, :]

    # The first simulated observation is skipped; presumably it is the
    # starting observation that is already part of prev_observations.
    score = MNNS_without_penalty(
        data_points=featurize_theta(prev_observations),
        new_data_points=feat_observations[1:, :]
    )

    # Very large weights: leaving the [-1, 1] range dominates the score.
    rho_obs = 1e10
    rho_act = 1e10
    penalty_terms = rho_obs * soft_penalty(a=observations, a_max=1) + rho_act * soft_penalty(a=actions, a_max=1)

    return jnp.squeeze(score).item() + penalty_terms.item()

In [None]:
def optimize_aprbs(
    optimizer,
    obs,
    env_state,
    prev_observations,
    n_generations,
    env,
    h,
    max_duration
):
    """Run the ask/tell loop of ``optimizer`` and return the best candidate.

    In every generation ``optimizer.population_size`` candidates are requested
    with ``optimizer.ask()``, evaluated with ``fitness_function`` and reported
    back with ``optimizer.tell()``. The candidate with the lowest fitness seen
    in ANY generation is returned, so a good candidate from an early
    generation is not lost if the final population happens to be worse.

    Args:
        optimizer: object providing ``population_size``, ``ask()`` (returning
            an evaluation vector and a tell vector) and ``tell(solutions)``.
        obs: current observation, forwarded to ``fitness_function``.
        env_state: current environment state, forwarded to ``fitness_function``.
        prev_observations: observations gathered so far.
        n_generations: number of generations to run, at least 1.
        env: environment, forwarded to ``fitness_function``.
        h: number of APRBS segments.
        max_duration: largest duration a single segment can take.

    Returns:
        Tuple ``(best_parameters, best_value, optimizer)``.

    Raises:
        ValueError: if ``n_generations`` is smaller than 1.
    """
    if n_generations < 1:
        raise ValueError(f"n_generations must be at least 1, got {n_generations}")

    best_x = None
    best_value = None

    for _ in range(n_generations):
        solutions = []

        for _ in range(optimizer.population_size):
            x_for_eval, x_for_tell = optimizer.ask()
            value = fitness_function(
                env,
                obs,
                env_state,
                prev_observations,
                x_for_eval,
                h,
                max_duration=max_duration
            )
            solutions.append((x_for_tell, value))

            # remember the best evaluated candidate across all generations
            if best_value is None or value < best_value:
                best_x, best_value = np.array(x_for_eval), value

        optimizer.tell(solutions)

    return best_x, np.float64(best_value), optimizer

In [None]:
observations = []
actions = []

h = 2
a = 2  # NOTE(review): not used in any cell visible here

continuous_dim = h
discrete_dim = h

population_size = 21
n_generations = 100
n_target_observations = 5000  # gather data until at least this many observations exist

# Search space: amplitudes in [-1, 1] (continuous), durations in [1, 50] (step 1).
bounds = np.concatenate(
    [
        np.tile([-1, 1], (continuous_dim, 1)),
        np.tile([1, 50], (discrete_dim, 1)),
    ]
)
steps = np.concatenate([np.zeros(continuous_dim), np.ones(discrete_dim)])
max_duration = int(bounds[-1, -1])

# `opt_key` is split off during PRNG setup; it is used here to seed every
# CMAwM instance so that the whole data-gathering run is reproducible.
opt_rng = PRNGSequence(opt_key)

obs, env_state = env.reset()
obs = obs.astype(jnp.float32)
env_state = env_state.astype(jnp.float32)

observations.append(obs[0])

while len(observations) < n_target_observations:

    # a fresh optimizer per signal block, each with its own deterministic seed
    optimizer_seed = int(
        jax.random.randint(next(opt_rng), shape=(), minval=0, maxval=2**31 - 1)
    )
    optimizer = CMAwM(
        mean=np.hstack([np.zeros(h), np.ones(h) * 25]),
        sigma=2.0,
        population_size=population_size,
        bounds=bounds,
        steps=steps,
        seed=optimizer_seed
    )

    proposed_aprbs_params, values, optimizer = optimize_aprbs(
        optimizer,
        obs,
        env_state,
        jnp.stack(observations),
        n_generations=n_generations,
        env=env,
        h=h,
        max_duration=max_duration
    )

    amplitudes = proposed_aprbs_params[:h]
    durations = proposed_aprbs_params[h:].astype(np.int64)

    new_actions = generate_aprbs(
        amplitudes=amplitudes,
        durations=durations
    )[None, :, None]

    # apply the chosen signal step by step and record the resulting data
    for i in range(new_actions.shape[1]):
        action = new_actions[:, i, :]
        obs, _, _, _, env_state = env.step(action, env_state)

        observations.append(obs[0])
        actions.append(action[0])

    print("n_observations:", len(observations))

In [None]:
# np.save("obs_iGOATs.npy", np.stack(observations))
# np.save("act_iGOATs.npy", np.stack(actions))

In [None]:
# Prefer the stored run; if the files are missing (the cell that writes them
# is commented out), keep working with the data gathered in this session so
# that Restart & Run All does not stop here.
try:
    stored_observations = np.load("obs_iGOATs.npy")
    stored_actions = np.load("act_iGOATs.npy")
except FileNotFoundError:
    print("Stored iGOATs data not found; using the data gathered in this session.")
    observations = np.stack(observations)
    actions = np.stack(actions)
else:
    observations, actions = stored_observations, stored_actions

In [None]:
# Show the complete gathered trajectory.
trajectory_observations = jnp.stack(observations)
trajectory_actions = jnp.stack(actions)

fig, axs = plot_sequence(
    observations=trajectory_observations,
    actions=trajectory_actions,
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
);
plt.show()

In [None]:
from exciting_exciting_systems.utils.metrics import MNNS_without_penalty, audze_eglais, MC_uniform_sampling_distribution_approximation
from scipy.stats.qmc import LatinHypercube

In [None]:
# 2-D Latin hypercube sampler for the Monte-Carlo support points; seeded so
# that the support points and the score computed from them are reproducible.
lhc_sampler = LatinHypercube(d=2, seed=3333)

In [None]:
# Audze-Eglais criterion of the gathered observations.
stacked_observations = jnp.stack(observations)
ae_score = audze_eglais(stacked_observations)
print(ae_score)

In [None]:
# Monte-Carlo score of the gathered observations against 1600 Latin hypercube support points.
support_points = lhc_sampler.random(n=1600)
mcudsa_score = MC_uniform_sampling_distribution_approximation(
    data_points=jnp.stack(observations),
    support_points=support_points,
)
print(mcudsa_score)

In [None]:
numpy.save(