In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp

jax.config.update('jax_platform_name', 'cpu')

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
from cmaes import CMAwM

import exciting_environments as excenvs

import exciting_exciting_systems
from exciting_exciting_systems.models.model_utils import simulate_ahead_with_env
from exciting_exciting_systems.evaluation.plotting_utils import plot_sequence

from exciting_exciting_systems.related_work.excitation_utils import generate_aprbs

---

In [None]:
# Root PRNG key for the notebook (the trailing "21" of the original comment is
# presumably an earlier seed -- kept here for reference).
key = jax.random.PRNGKey(3333)

# One split yields the data-generation key, the optimizer key and a fresh root key.
data_key, opt_key, key = jax.random.split(key, num=3)
# Key stream for data generation; cells draw from it via `next(data_rng)`.
data_rng = PRNGSequence(data_key)

In [None]:
# Environment configuration. `tau` is handed to the environment and reused by the
# plotting helper further down (presumably the sampling time -- confirm).
batch_size = 1
tau = 2e-2  # alternative value noted by the author: 5e-2

# Pendulum environment built from the settings above.
env = excenvs.make(env_id='Pendulum-v0', batch_size=batch_size, tau=tau)

- at runtime, generate a mask inside the jitted function
- `loss = jnp.sum(jnp.where(a != jnp.inf, a, 0))`
- **this does not solve the issue with mixed-integer optimization!?**

In [None]:
# Number of signal segments: one amplitude and one duration is drawn per segment.
h = 100

# Flat parameter vector laid out as [h amplitudes, h durations]. Concatenating the
# float amplitudes with the integer durations yields one common-dtype array, which
# is why the durations are cast back to int32 before they are used below.
action_parameters = jnp.concatenate([
    jax.random.uniform(
        key=next(data_rng),
        shape=(h,),
        minval=-1,
        maxval=1
    ),
    # `maxval` is exclusive for jax.random.randint, so durations lie in [2, 49].
    jax.random.randint(
        key=next(data_rng),
        shape=(h,),
        minval=2,
        maxval=50
    )
])

# Build the action signal and add a leading and a trailing singleton axis
# (presumably batch and action dimension -- confirm against generate_aprbs).
actions = generate_aprbs(
    amplitudes=action_parameters[:h],
    durations=np.abs(action_parameters[h:].astype(jnp.int32))
)[None, :, None]

# The trailing semicolon keeps the `[<Line2D ...>]` repr out of the cell output.
plt.plot(jnp.squeeze(actions));

In [None]:
# Reset the environment and cast both returned values to float32.
obs, state = (initial.astype(jnp.float32) for initial in env.reset())

# Simulate the action sequence with the environment. `env` itself is not mapped
# (in_axes entry None); the remaining six arguments are mapped over axis 0.
observations = jax.vmap(
    simulate_ahead_with_env,
    in_axes=(None,) + (0,) * 6,
)(
    env,
    obs,
    state,
    actions,
    env.env_state_normalizer,
    env.action_normalizer,
    env.static_params,
)

print("actions.shape:", actions.shape)
print("observations.shape:", observations.shape)

# Plot the first entry along the mapped axis.
print(" \n One of the trajectories:")
fig, axs = plot_sequence(
    observations=observations[0],
    actions=actions[0],
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
)
plt.show()

---

In [None]:
from exciting_exciting_systems.related_work.algorithms import excite_with_iGOATs

In [None]:
def featurize_theta(obs):
    """Replace the angle entry of an observation by its sine and cosine.

    A raw angle is awkward to compare in a loss because values one full turn
    apart (e.g. 1.99 * pi and 0) describe nearly the same pose. The angle is
    therefore encoded as (sin, cos), which is continuous across the wrap-around.

    Args:
        obs: Array whose last axis holds the angle at index 0 and a second
            quantity at index 1. The angle is multiplied by pi before the
            trigonometric functions are applied (presumably because it is
            normalized to [-1, 1] -- confirm against the environment).

    Returns:
        Array with the same leading axes as ``obs`` and a last axis of size 3:
        ``[sin(pi * obs[..., 0]), cos(pi * obs[..., 0]), obs[..., 1]]``.
    """
    angle = obs[..., 0] * np.pi
    features = [np.sin(angle), np.cos(angle), obs[..., 1]]
    return np.stack(features, axis=-1)

In [None]:
# Settings forwarded to excite_with_iGOATs as `h` and `a`.
# NOTE: this rebinds the global `h` that the APRBS cell above set to 100.
h, a = 2, 2

# Start from empty histories; the call returns the collected data.
observations, actions = [], []

# Initial search mean: h zeros followed by h entries of 25 (presumably the
# amplitude part and the duration part of the parameter vector -- confirm).
observations, actions = excite_with_iGOATs(
    n_timesteps=5000,
    env=env,
    actions=actions,
    observations=observations,
    h=h,
    a=a,
    bounds_amplitude=[-1, 1],
    bounds_duration=[1, 50],
    population_size=20,
    n_generations=100,
    mean=np.concatenate([np.zeros(h), np.full(h, 25.0)]),
    sigma=2.0,
    featurize=featurize_theta,
)

In [None]:
# Show the data returned by excite_with_iGOATs, stacked into arrays.
fig, axs = plot_sequence(
    observations=jnp.stack(observations), actions=jnp.stack(actions), tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"], action_labels=[r"$u$"],
)
plt.show()

---

In [None]:
# Uncomment to store the iGOATs results on disk (the next cell reads these files):
# np.save("obs_iGOATs.npy", np.stack(observations))
# np.save("act_iGOATs.npy", np.stack(actions))

In [None]:
import os  # local import: only this caching cell needs it

# Load-or-create cache for the iGOATs results.
# Previously this cell only loaded the files, so it raised FileNotFoundError on a
# fresh checkout because nothing in the notebook writes them. Now:
#   * both files present -> load them (same behaviour as before);
#   * otherwise          -> keep the data computed above, stack it into arrays
#                           and write it to disk for later runs.
if os.path.exists("obs_iGOATs.npy") and os.path.exists("act_iGOATs.npy"):
    observations = np.load("obs_iGOATs.npy")
    actions = np.load("act_iGOATs.npy")
else:
    observations = np.stack(observations)
    actions = np.stack(actions)
    np.save("obs_iGOATs.npy", observations)
    np.save("act_iGOATs.npy", actions)

In [None]:
# Show the observations/actions that were read from disk.
fig, axs = plot_sequence(
    observations=np.stack(observations), actions=np.stack(actions), tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"], action_labels=[r"$u$"],
)
plt.show()

In [None]:
from exciting_exciting_systems.utils.metrics import MNNS_without_penalty, audze_eglais, MC_uniform_sampling_distribution_approximation
from scipy.stats.qmc import LatinHypercube

In [None]:
# Two-dimensional Latin hypercube sampler. It is seeded so that the support
# points drawn from it (and any score computed from them) are reproducible
# across kernel restarts; the seed matches the one used for the JAX PRNG key.
lhc_sampler = LatinHypercube(d=2, seed=3333)

In [None]:
# Score the stacked observations with the `audze_eglais` metric and print it.
ae_score = audze_eglais(jnp.stack(observations))
print(ae_score)

In [None]:
# Score the stacked observations against 1600 Latin-hypercube support points.
# NOTE(review): LatinHypercube.random draws from the unit hypercube [0, 1)^d.
# Confirm that this matches the value range of the observations; if they span
# e.g. [-1, 1], the support points would need rescaling (scipy.stats.qmc.scale).
mcudsa_score = MC_uniform_sampling_distribution_approximation(
    data_points=jnp.stack(observations),
    support_points=lhc_sampler.random(n=1600)
)
print(mcudsa_score)

In [None]:
numpy.save(