In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
import os
# Turn off XLA's GPU memory preallocation. The variable is set in this cell,
# i.e. before `jax` is imported further down in the notebook.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import time
import matplotlib.pyplot as plt
import numpy as np

import matplotlib as mpl
# Render all figure text through LaTeX and load `bm`, which the axis labels
# below rely on for \bm{...}.
# NOTE(review): usetex presumably needs a working LaTeX installation on the
# machine that runs the notebook — confirm.
mpl.rcParams['text.usetex'] = True
mpl.rcParams['text.latex.preamble']=r"\usepackage{bm}"


In [None]:
import jax
import jax.numpy as jnp
# Select the GPU platform and make the first device reported by jax the default.
# NOTE(review): this presumably fails on machines without a GPU backend —
# confirm before running the notebook elsewhere.
jax.config.update('jax_platform_name', 'gpu')
gpus = jax.devices()
jax.config.update("jax_default_device", gpus[0])


import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_exciting_systems

from exciting_exciting_systems.utils.signals import aprbs
from exciting_exciting_systems.evaluation.plotting_utils import plot_sequence

from exciting_exciting_systems.models.model_utils import simulate_ahead_with_env

In [None]:
# setup PRNG
# One seed is split into separate keys for data generation, model
# initialisation and the data loader; `key` itself is rebound to the fourth
# sub-key.
key = jax.random.PRNGKey(seed=2)

data_key, model_key, loader_key, key = jax.random.split(key, 4)
# `next(data_rng)` is used throughout the notebook to draw a fresh key for
# every random input signal.
data_rng = PRNGSequence(data_key)

# Usage Presentation

## Exciting Environments:

- [Available on github](https://github.com/ExcitingSystems/exciting-environments)
- Uses [jax](https://github.com/google/jax) to provide simulators for physical systems that are ...
  - ... differentiable
  - ... vectorizable
  - ... just-in-time compilable

In [None]:
import exciting_environments as excenvs

In [None]:
# Batched pendulum environment integrated with the Tsit5 solver.
# `tau` is also used below to scale step indices into the time axis that is
# labelled "t in seconds".
batch_size = 10
tau = 2e-2

env = excenvs.make(
    env_id='Pendulum-v0',
    batch_size=batch_size,
    action_constraints={"torque": 8},
    static_params={"g": 9.81, "l": 1, "m": 1},
    solver=diffrax.Tsit5(),
    tau=tau,
)

In [None]:
# Initial observation/state of the batched environment and the number of
# steps used for the example input signals below.
obs, state = env.reset()
n_steps = 1000

In [None]:
# Multi-sine example input: a sum of `n_freq` sinusoids with random amplitudes
# and phases.

# Step indices 0 .. n_steps-1 as a column so they broadcast against the
# (1, n_freq) parameter rows.
k = jnp.linspace(0, n_steps-1, n_steps)[..., None]

n_freq = 20

# Random amplitude in [-1, 1] and phase in [-pi, pi] per frequency component.
A_l = jax.random.uniform(key=next(data_rng), shape=(1, n_freq,), dtype=jnp.float32, minval=-1, maxval=1)
phi_l = jax.random.uniform(key=next(data_rng), shape=(1, n_freq,), dtype=jnp.float32, minval=-jnp.pi, maxval=jnp.pi)
# Component l completes w_l = 1 .. n_freq full periods over the n_steps samples.
w_l = jnp.linspace(1, n_freq, n_freq)

# Sum over the frequency axis gives one value per step.
sin_actions = jnp.sum(A_l * jnp.sin(2 * jnp.pi * w_l/ n_steps * k + phi_l), axis=-1)
plt.plot(sin_actions)
# Rescale so that the largest absolute value is 1, then plot the scaled
# signal into the same figure for comparison with the raw one.
sin_actions /= jnp.max(jnp.abs(sin_actions))
plt.plot(sin_actions)

In [None]:
# Side-by-side comparison of an APRBS input and the multi-sine input.
n_steps = 1000
aprbs_actions = aprbs(n_steps, batch_size, 20, 100, next(data_rng))

fig, axs = plt.subplots(1, 2, figsize=(10, 4), sharey=True)

# Both panels share the same time axis: step index scaled by tau.
time_axis = jnp.linspace(0, (n_steps-1) *tau, n_steps)

# Only the first APRBS sequence of the batch is shown.
panels = [("APRBS", aprbs_actions[0]), ("Sinusoidal", sin_actions)]
for panel_ax, (panel_title, signal) in zip(axs, panels):
    panel_ax.title.set_text(panel_title)
    panel_ax.plot(time_axis, signal)
    panel_ax.set_xlabel(r"$t$ in seconds")
    panel_ax.grid()

# The y axis is shared, so only the left panel carries the label.
axs[0].set_ylabel(r"$\bm{u}_k$")

plt.tight_layout()

# plt.savefig("results/plots/input_examples.pdf")
plt.show()

In [None]:
# Shorter APRBS input batch (one sequence per environment in the batch) for
# the simulation example in the next cell.
n_steps = 400
actions = aprbs(n_steps, batch_size, 10, 50, next(data_rng))

In [None]:
# Simulate the whole action batch at once; only the first of the five return
# values (the observations) is used here.
# NOTE(review): the meaning of the two `env.tau` arguments is not visible in
# this notebook — check `vmap_sim_ahead` before changing them.
observations, _, _, _, _ = env.vmap_sim_ahead(
    state, actions, env.tau, env.tau
)

print("actions.shape:", actions.shape)
print("observations.shape:", observations.shape)
# Note: this changes the font size globally, i.e. for all later figures too.
mpl.rcParams.update({'font.size': 20})
fig,axs = plt.subplots(nrows=1, ncols=3, figsize=(18, 6))

print(" \n One of the trajectories:")
# Plot only the first trajectory of the batch into the prepared axes.
fig, axs = plot_sequence(
    fig=fig,
    axs=axs,
    observations=observations[0, ...],
    actions=actions[0, ...],
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$M$"],
);

# plt.savefig("results/plots/excenvs_pendulum_simulation_example.pdf")
plt.show()

#### Speed:

In [None]:
# Single-instance pendulum with the Euler solver for the speed comparison
# against the NumPy re-implementation further down.
# Note: `batch_size`, `env`, `obs`, `state` and `n_steps` from the cells above
# are rebound here.
batch_size = 1
tau = 2e-2

env = excenvs.make(
    env_id='Pendulum-v0',
    batch_size=batch_size,
    action_constraints={"torque": 8},
    static_params={"g": 9.81, "l": 1, "m": 1},
    solver=diffrax.Euler(),
    tau=tau,
)

obs, state = env.reset()
# Keep only the first (and, with batch_size = 1, only) entry of the batch.
obs = obs[0]

n_steps = 999

In [None]:
# Action sequences for the timing runs below. The second argument is the slot
# where the other `aprbs` calls in this notebook pass `batch_size`, so this
# presumably yields 10 sequences — confirm against `aprbs`.
actions = aprbs(n_steps, 10, 5, 10, next(data_rng))

In [None]:
# Time the JAX simulator on every action sequence and keep the resulting
# trajectories for the comparison plots below.
jax_observations = []

for acts in actions:
    # perf_counter is the clock intended for measuring short durations:
    # unlike time.time() it is monotonic and uses the highest available
    # resolution.
    start = time.perf_counter()
    # block_until_ready() makes sure the computation has actually finished
    # before the clock is stopped.
    # NOTE(review): the first iteration presumably also contains compilation
    # time — treat that number separately when comparing.
    observations = simulate_ahead_with_env(env, obs, state, acts).block_until_ready()
    end = time.perf_counter()
    print(end - start)

    jax_observations.append(observations)

In [None]:
from exciting_exciting_systems.related_work.np_reimpl.pendulum import Pendulum
from exciting_exciting_systems.related_work.np_reimpl.env_utils import simulate_ahead_with_env as np_simulate_ahead_with_env

In [None]:
# NumPy re-implementation of the pendulum with the same tau and torque limit
# as the JAX environment used for the timing above.
batch_size = 1
tau = 2e-2

np_env = Pendulum(
    batch_size=batch_size,
    tau=tau,
    max_torque=8
)

# Note: `obs` is rebound to the NumPy environment's initial observation.
# Both values are cast to float32 and reduced to the first batch entry.
obs, env_state = np_env.reset()
obs = obs.astype(np.float32)[0]
env_state = env_state.astype(np.float32)[0]

In [None]:
# Time the NumPy simulator on the same action sequences as the JAX run.
np_observations = []
for acts in actions:
    # Convert to a NumPy array outside the timed region so the conversion
    # does not count towards the simulation time.
    acts = np.array(acts)
    # perf_counter (monotonic, high resolution) instead of time.time(), the
    # same clock as in the JAX timing loop.
    start = time.perf_counter()
    observations, _ = np_simulate_ahead_with_env(
        np_env,
        obs,
        env_state,
        acts,
    )
    end = time.perf_counter()
    print(end - start)

    np_observations.append(observations)

In [None]:
# Overlay the JAX and NumPy trajectories for the first three action
# sequences, one figure per observation component (index 0, then index 1).
for (jax_obs, np_obs) in zip(jax_observations[:3], np_observations[:3]):
    for component in (0, 1):
        plt.plot(jax_obs[..., component])
        plt.plot(np_obs[..., component])
        plt.show()

#### Automatic Differentiation:

In [None]:
# Batched Euler environment and one APRBS action sequence per batch entry for
# the gradient example below.
# Note: `batch_size`, `env`, `obs`, `state`, `n_steps` and `actions` are
# rebound here.
batch_size = 20
tau = 2e-2

env = excenvs.make(
    env_id='Pendulum-v0',
    batch_size=batch_size,
    action_constraints={"torque": 8},
    static_params={"g": 9.81, "l": 1, "m": 1},
    solver=diffrax.Euler(),
    tau=tau,
)

obs, state = env.reset()
n_steps = 999

actions = aprbs(n_steps, batch_size, 5, 10, next(data_rng))

In [None]:
def loss_function(
    target_observations, actions, env, obs, state
):
    """Mean squared error between simulated and target observations.

    `simulate_ahead_with_env` is vmapped over the leading axis of `obs`,
    `state` and `actions`; `env` is shared across the batch. The result is
    the mean of the squared element-wise difference to `target_observations`.
    """
    batched_simulation = jax.vmap(simulate_ahead_with_env, in_axes=(None, 0, 0, 0))
    simulated = batched_simulation(env, obs, state, actions)
    squared_error = (simulated - target_observations)**2
    return jnp.mean(squared_error)

In [None]:
# Gradient of the loss w.r.t. argument 1 (`actions`). Because `argnums` is a
# list, the result is a one-element sequence, which is why the call below
# indexes it with [0].
grad_function = jax.grad(loss_function, argnums=[1])

In [None]:
# Gradient of the simulation loss with respect to the whole action batch.
# The target is a constant 0.5 for every observation entry. Its shape is
# derived from the settings instead of being hard-coded, so it stays valid
# when `batch_size` or `n_steps` change: one trajectory per batch entry and
# n_steps + 1 observations per trajectory (this reproduces the previous
# (20, 1000, 2) for batch_size = 20, n_steps = 999). The last axis holds the
# two observation entries.
target_observations = jnp.ones((batch_size, n_steps + 1, 2)) * 0.5

grads = grad_function(
    target_observations,
    actions,
    env=env,
    obs=obs,
    state=state
)[0]
print("actions.shape", actions.shape)

print("grads.shape", grads.shape)

## Differentiable Model Predictive Excitation (DMPE):

- excite systems by simultaneous identification and input optimization

In [None]:
# Single-instance pendulum with the Tsit5 solver, used for the DMPE
# experiment. Note: `batch_size` and `env` are rebound here.
batch_size = 1
tau = 2e-2

env = excenvs.make(
    env_id='Pendulum-v0',
    batch_size=batch_size,
    action_constraints={"torque": 8},
    static_params={"g": 9.81, "l": 1, "m": 1},
    solver=diffrax.Tsit5(),
    tau=tau,
)

In [None]:
def featurize_theta(obs):
    """Replace the angle entry of an observation by its sine and cosine.

    Comparing raw angles in a loss is misleading because angles close to a
    full turn and angles close to zero describe almost the same position.
    The first entry is therefore multiplied by pi and mapped to (sin, cos);
    the second entry is passed through unchanged.

    Returns an array whose last axis is [sin, cos, obs[..., 1]].
    """
    angle = obs[..., 0] * jnp.pi
    remaining = obs[..., 1]
    return jnp.stack([jnp.sin(angle), jnp.cos(angle), remaining], axis=-1)

In [None]:
# Settings for the density estimate and the excitation run.
bandwidth = 0.05
n_prediction_steps = 50

dim_obs_space = 2
dim_action_space = 1

# The density estimate lives on a regular grid over the joint
# observation/action space, hence points_per_dim ** dim grid points
# (50 ** 3 = 125000).
dim = dim_obs_space + dim_action_space
points_per_dim = 50
n_grid_points=points_per_dim**dim

n_timesteps = 15_000

In [None]:
# Initial observation of the single environment (batch axis dropped).
obs, state = env.reset()
obs = obs[0]

# Preallocated buffers: n_timesteps observations (row 0 is the initial
# observation) and one action fewer.
observations = jnp.zeros((n_timesteps, dim_obs_space))
observations = observations.at[0].set(obs)
actions = jnp.zeros((n_timesteps-1, dim_action_space))

# APRBS sequence of n_prediction_steps actions; only the first batch entry is
# kept. It is the `proposed_actions` argument of the (commented-out)
# `excite_and_fit` call below.
proposed_actions = aprbs(n_prediction_steps, batch_size, 1, 10, next(data_rng))[0]

In [None]:
from exciting_exciting_systems.models.model_training import ModelTrainer
from exciting_exciting_systems.excitation import loss_function, Exciter
from exciting_exciting_systems.models import NeuralEulerODEPendulum
from exciting_exciting_systems.utils.density_estimation import (
    DensityEstimate, build_grid_2d
)

In [None]:
# Excitation optimiser. `loss_function` here is the one imported from
# `exciting_exciting_systems.excitation`, which replaces the notebook-level
# definition from the autodiff section. `argnums=(2)` is the plain integer 2,
# i.e. the gradient is taken w.r.t. the third positional argument.
# The target distribution is the constant 1 / 2**dim on every grid point,
# i.e. the uniform density on the cube [-1, 1]**dim.
exciter = Exciter(
    grad_loss_function=jax.grad(loss_function, argnums=(2)),
    excitation_optimizer=optax.adabelief(1e-2),
    n_opt_steps=50,
    tau=tau,
    target_distribution=jnp.ones(shape=(n_grid_points, 1)) * 1 / (1 - (-1))**dim,
    rho_obs=1e10,
    rho_act=1e10
)

# Model fitting settings; uses the jnp-based `featurize_theta` defined above.
model_trainer = ModelTrainer(
    start_learning=n_prediction_steps,
    training_batch_size=128,
    n_train_steps=1,
    sequence_length=n_prediction_steps,
    featurize=featurize_theta,
    model_optimizer=optax.adabelief(1e-4),
    tau=tau
)

# Empty density estimate (all zeros, no observations yet) on a 3-D grid over
# [-1, 1] with points_per_dim points per axis.
density_estimate = DensityEstimate(
    p=jnp.zeros([n_grid_points, 1]),
    x_g=exciting_exciting_systems.utils.density_estimation.build_grid_3d(
        low=-1,
        high=1,
        points_per_dim=points_per_dim
    ),
    bandwidth=jnp.array([bandwidth]),
    n_observations=jnp.array([0])
)


model = NeuralEulerODEPendulum(
    obs_dim=dim_obs_space,
    action_dim=dim_action_space,
    width_size=128,
    depth=3,
    key=model_key
)
# The optimiser state is initialised only on the model's inexact
# (floating-point) array leaves.
opt_state_model = model_trainer.model_optimizer.init(eqx.filter(model, eqx.is_inexact_array))

In [None]:
from exciting_exciting_systems.algorithms import excite_and_fit

In [None]:
# The DMPE run itself is commented out; its stored results are loaded from
# disk instead. To regenerate them, uncomment the call and the two np.save
# lines below.
# NOTE: on a fresh checkout the two np.load calls fail unless
# results/obs_dmpe.npy and results/act_dmpe.npy already exist.

# dmpe_observations, dmpe_actions, model, density_estimate, losses = excite_and_fit(
#     n_timesteps=n_timesteps,
#     env=env,
#     model=model,
#     obs=obs,
#     state=state,
#     proposed_actions=proposed_actions,
#     exciter=exciter,
#     model_trainer=model_trainer,
#     density_estimate=density_estimate,
#     observations=observations,
#     actions=actions,
#     opt_state_model=opt_state_model,
#     loader_key=loader_key,
#     plot_every=2500,
# )

# np.save("results/obs_dmpe.npy", np.stack(dmpe_observations))
# np.save("results/act_dmpe.npy", np.stack(dmpe_actions))

dmpe_observations = np.load("results/obs_dmpe.npy")
dmpe_actions = np.load("results/act_dmpe.npy")

In [None]:
# Plot the DMPE trajectory. Keyword arguments are used for the data as in the
# other `plot_sequence` calls of this notebook, so the call does not depend
# on the positional parameter order.
fig, axs = plot_sequence(
    observations=dmpe_observations,
    actions=dmpe_actions,
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
);
# Save before showing, while the figure is still the current one.
plt.savefig("results/plots/dmpe_example_result.pdf")

# `plt.plot()` without arguments only adds an empty line to the current axes;
# what is wanted here is to display the figure.
plt.show()

## sGOATs reimplementation:

In [None]:
from exciting_exciting_systems.related_work.np_reimpl.pendulum import Pendulum
from exciting_exciting_systems.related_work.algorithms import excite_with_sGOATs, excite_with_GOATs, excite_with_iGOATs

In [None]:
def featurize_theta(obs_action):
    """NumPy variant: replace the angle entry by its sine and cosine.

    Comparing raw angles is misleading because angles close to a full turn
    and angles close to zero describe almost the same position. The first
    entry is therefore multiplied by pi and mapped to (sin, cos); all
    remaining entries along the last axis are appended unchanged.

    NOTE: this redefines the jnp-based `featurize_theta` from the DMPE section
    for every cell that runs after this one.
    """
    angle = obs_action[..., 0] * np.pi
    trig_features = np.stack([np.sin(angle), np.cos(angle)], axis=-1)
    untouched = obs_action[..., 1:]

    return np.concatenate([trig_features, untouched], axis=-1)

In [None]:
# NumPy pendulum for the GOATs-family re-implementations.
# Note: `env` is rebound from the JAX environment to the NumPy one here.
batch_size = 1
tau = 2e-2

env = Pendulum(
    batch_size=batch_size,
    tau=tau,
    max_torque=8
)

In [None]:
# The sGOATs run itself is commented out; its stored results are loaded from
# disk instead. To regenerate them, uncomment the block below.
# NOTE(review): the commented-out np.save lines used to write "..._sGOATS.npy"
# while the loads read "..._sGOATs.npy", which are different files on a
# case-sensitive file system. The save paths below now use the spelling of the
# loads; rename any existing files accordingly.

# all_observations = []
# all_actions = []

# all_observations, all_actions = excite_with_sGOATs(
#     n_amplitudes=600,
#     n_amplitude_groups=6,
#     reuse_observations=True,
#     all_observations=all_observations,
#     all_actions=all_actions,
#     env=env,
#     bounds_duration=(1,50),
#     population_size=50,
#     n_generations=100,
#     # n_support_points=1600,
#     featurize=featurize_theta,
#     seed=0,
#     verbose=True
# )

# sgoats_observations = np.concatenate(all_observations)
# sgoats_actions = np.concatenate(all_actions)

# np.save("results/obs_sGOATs.npy", np.stack(sgoats_observations))
# np.save("results/act_sGOATs.npy", np.stack(sgoats_actions))


sgoats_observations = np.load("results/obs_sGOATs.npy")
sgoats_actions = np.load("results/act_sGOATs.npy")

In [None]:
print("sgoats actions.shape:", sgoats_actions.shape)
print("sgoats observations.shape:", sgoats_observations.shape)

# The last action is dropped for plotting; elsewhere in the notebook the
# sGOATs actions are concatenated with the observations without slicing.
fig, axs = plot_sequence(
    observations=sgoats_observations,
    actions=sgoats_actions[:-1, ...],
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
);

plt.savefig("results/plots/sGOATS_example_result.pdf")

plt.show()

In [None]:
# Count "singleton" inputs: interior actions that differ from both their
# predecessor and their successor. The first and the last action are never
# counted because they only have one neighbour.
number_of_singletons = 0
for idx in range(1, sgoats_actions.shape[0] - 1):
    current_action = sgoats_actions[idx]
    if not (sgoats_actions[idx-1] != current_action):
        continue
    if current_action != sgoats_actions[idx+1]:
        number_of_singletons = number_of_singletons + 1

print("number of singleton inputs: ", number_of_singletons)

In [None]:
# Length of every run of identical consecutive actions, in order of
# appearance. The array has 600 slots (the commented-out sGOATs call above
# uses n_amplitudes=600); slots beyond the last run stay 0.
number_of_cons_inputs = np.zeros(600)

j = 0
# The very first action opens run 0.
number_of_cons_inputs[j] += 1
for idx in range(sgoats_actions.shape[0] - 1):
    if sgoats_actions[idx] == sgoats_actions[idx+1]:
        # Action idx+1 continues the current run.
        number_of_cons_inputs[j] += 1
    else:
        # Action idx+1 starts a new run: move to the next slot first and
        # count the action there. Incrementing before moving on (as the
        # previous version did) made the first run one too long and the last
        # run one too short.
        j+=1
        number_of_cons_inputs[j] += 1

In [None]:
# Number of runs that are exactly 50 steps long (50 is the upper value of
# bounds_duration in the commented-out sGOATs call above).
np.sum(number_of_cons_inputs == 50)

In [None]:
# Histogram of the run lengths. Unused slots of the 600-element array are 0
# and therefore end up in the lowest bin.
plt.hist(number_of_cons_inputs, bins=50)

## GOATS implementation:

In [None]:
# goats_observations, goats_actions = excite_with_GOATs(
#     n_amplitudes=600,
#     env=env,
#     bounds_duration=(1,50),
#     population_size=50,
#     n_generations=100,
#     featurize=featurize_theta,
#     seed=0,
#     verbose=True
# )

In [None]:
# print("goats actions.shape:", goats_actions.shape)
# print("goats observations.shape:", goats_observations.shape)

# fig, axs = plot_sequence(
#     observations=goats_observations,
#     actions=goats_actions,
#     tau=tau,
#     obs_labels=[r"$\theta$", r"$\omega$"],
#     action_labels=[r"$u$"],
# );

# plt.savefig("results/plots/GOATS_example_result.pdf")

# plt.show()

- I suspect the optimization problem to be way too hard?
- unable to stabilize

## iGOATS implementation:

In [None]:
# The iGOATs run itself is commented out; its stored results are loaded from
# disk instead. To regenerate them, uncomment the block below.
# NOTE: on a fresh checkout the two np.load calls fail unless
# results/obs_iGOATS.npy and results/act_iGOATS.npy already exist.

# h = 2
# a = 2

# igoats_actions = []
# igoats_observations = []

# igoats_observations, igoats_actions = excite_with_iGOATs(
#     n_timesteps=15000,
#     env=env,
#     actions=igoats_actions,
#     observations=igoats_observations,
#     h=h,
#     a=a,
#     bounds_amplitude=[-1, 1],
#     bounds_duration=[1, 50],
#     population_size=20,
#     n_generations=20,
#     mean=np.hstack([np.zeros(h), np.ones(h) * 25]),
#     sigma=2.0,
#     featurize=featurize_theta
# )

# np.save("results/obs_iGOATS.npy", np.stack(igoats_observations))
# np.save("results/act_iGOATS.npy", np.stack(igoats_actions))


igoats_observations = np.load("results/obs_iGOATS.npy")
igoats_actions = np.load("results/act_iGOATS.npy")

In [None]:
# The labels now say "igoats": this cell reports the iGOATs data, and the old
# "goats" labels were easy to confuse with the separate GOATs section above.
print("igoats actions.shape:", igoats_actions.shape)
print("igoats observations.shape:", igoats_observations.shape)

fig, axs = plot_sequence(
    observations=igoats_observations,
    actions=igoats_actions,
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
);

# Save before showing, while the figure is still the current one.
plt.savefig("results/plots/iGOATS_example_result.pdf")

plt.show()

## Comparison:

#### Qualitative:

- Note that sGOATS uses **explicit Euler**, while DMPE uses **Tsit5** as ODE solver in the simulation
- both algorithms consider $\mathbf{u}_k$ and $\mathbf{y}_k$, first only look at observation distribution

In [None]:
# Scatter plot of the visited observations, one panel per algorithm.
mpl.rcParams.update({'font.size': 20})

fig, ax = plt.subplots(1, 3, figsize=(18,6), sharey=True)

per_algorithm = zip(
    [dmpe_observations, sgoats_observations, igoats_observations],
    ["dmpe", "sgoats", "igoats"],
)
# NOTE: the loop variable `observations` rebinds the notebook-level name of
# the same spelling (the buffer prepared for the DMPE run).
for idx, (observations, name) in enumerate(per_algorithm):
    panel = ax[idx]
    panel.scatter(observations[..., 0], observations[..., 1], s=1)
    panel.grid()
    panel.title.set_text(name + " observations, timeseries")

fig.tight_layout()

# plt.savefig("results/plots/comparison_observation_space.pdf")
plt.show()

In [None]:
from exciting_exciting_systems.utils.density_estimation import update_density_estimate_multiple_observations

In [None]:
# Empty 2-D density estimate over the observation space only; the same empty
# estimate is used as the starting point for all three algorithms.
# Note: `density_estimate` is rebound here (it was the 3-D estimate before).
density_estimate = DensityEstimate(
    p=jnp.zeros([points_per_dim**2, 1]),
    x_g=build_grid_2d(low=-1, high=1, points_per_dim=points_per_dim),
    bandwidth=jnp.array([bandwidth]),
    n_observations=jnp.array([0])
)

dmpe_density_estimate = update_density_estimate_multiple_observations(
    density_estimate, dmpe_observations,
)

sgoats_density_estimate = update_density_estimate_multiple_observations(
    density_estimate, sgoats_observations,
)

igoats_density_estimate = update_density_estimate_multiple_observations(
    density_estimate, igoats_observations,
)

In [None]:
# Contour plot of the 2-D observation density, one figure per algorithm.
# The three formerly copy-pasted calls are folded into one loop and every
# figure is now followed by plt.show(); before, only the first one was.
mpl.rcParams.update({'font.size': 20, 'figure.autolayout': True})

for kde_name, kde in (
    ("dmpe", dmpe_density_estimate),
    ("sgoats", sgoats_density_estimate),
    ("igoats", igoats_density_estimate),
):
    fig, axs, cax = exciting_exciting_systems.evaluation.plotting_utils.plot_2d_kde_as_contourf(
        kde.p, kde.x_g, [r"$\theta$", r"$\omega$"]
    )
    # To export the figure, save it before showing:
    # plt.savefig("results/plots/" + kde_name + "_example_kde.pdf")
    plt.show()

- joint distributions $\mathbf{u}_k$ and $\mathbf{y}_k$

In [None]:
# Empty 3-D density estimate over the joint (observation, action) space.
# Note: `density_estimate` and the three per-algorithm estimates are rebound
# here, replacing the 2-D ones.
points_per_dim = 50
n_grid_points=points_per_dim**3
density_estimate = DensityEstimate(
    p=jnp.zeros([n_grid_points, 1]),
    x_g=exciting_exciting_systems.utils.density_estimation.build_grid_3d(
        low=-1,
        high=1,
        points_per_dim=points_per_dim
    ),
    bandwidth=jnp.array([bandwidth]),
    n_observations=jnp.array([0])
)


# DMPE and iGOATs drop the last observation before the concatenation with
# the actions; the sGOATs arrays are concatenated as they are.
dmpe_density_estimate = update_density_estimate_multiple_observations(
    density_estimate, jnp.concatenate([dmpe_observations[0:-1, :], dmpe_actions], axis=-1),
)

sgoats_density_estimate = update_density_estimate_multiple_observations(
    density_estimate, jnp.concatenate([sgoats_observations, sgoats_actions], axis=-1),
)

igoats_density_estimate = update_density_estimate_multiple_observations(
    density_estimate, jnp.concatenate([igoats_observations[0:-1, :], igoats_actions], axis=-1),
)

In [None]:
# Slice the joint 3-D density along its last grid axis (the action u_k) and
# show DMPE next to sGOATs for every slice.
x_plot = dmpe_density_estimate.x_g.reshape((points_per_dim, points_per_dim, points_per_dim, 3))
ims = []

# Loop-invariant pieces: the cube-shaped densities and the grid values that
# label the slices.
cube_shape = (points_per_dim, points_per_dim, points_per_dim)
slice_values = jnp.linspace(-1, 1, points_per_dim)
labelled_densities = [
    ("dmpe", dmpe_density_estimate.p.reshape(cube_shape)),
    ("sgoats", sgoats_density_estimate.p.reshape(cube_shape)),
]

for i in range(points_per_dim):
    fig, axs = plt.subplots(1, 2, figsize=(12, 6))
    for panel, (label, density_cube) in zip(axs, labelled_densities):
        panel.contourf(
            density_cube[:, :, i],
            antialiased=False,
            levels=100,
            alpha=0.9,
            cmap=plt.cm.coolwarm
        )
        panel.title.set_text(label + ": u_k = " + str(slice_values[i]))
    plt.show()

#### Quantitative:

- metrics used for comparison:
  - Jensen Shannon divergence (JSD) **<- optimization metric for dmpe optimization**
  - MC uniform sampling distribution approximation (MCUDSA) **<- optimization metric for sGOATs**
  - Audze-Eglais (AE)
- metric used for iGOATs optimization:
  - Maximum nearest neighbor sequence (MNNS) (without penalty, only for MISO systems?)

In [None]:
import exciting_exciting_systems

In [None]:
from exciting_exciting_systems.utils.metrics import JSDLoss
from exciting_exciting_systems.related_work.np_reimpl.metrics import (
    MC_uniform_sampling_distribution_approximation, audze_eglais
)
from exciting_exciting_systems.related_work.excitation_utils import latin_hypercube_sampling

**JSD:**

In [None]:
# Jensen-Shannon divergence between each normalised density estimate and the
# normalised target distribution of the exciter.
# NOTE(review): `exciter.target_distribution` has points_per_dim ** 3 entries,
# so this presumably expects the 3-D joint estimates — confirm they are the
# ones currently bound.
target_pmf = exciter.target_distribution / jnp.sum(exciter.target_distribution)

dmpe_pmf = dmpe_density_estimate.p / jnp.sum(dmpe_density_estimate.p)
dmpe_jsd_loss = JSDLoss(p=dmpe_pmf, q=target_pmf)
print("dmpe jsd loss: ", dmpe_jsd_loss)

sgoats_pmf = sgoats_density_estimate.p / jnp.sum(sgoats_density_estimate.p)
sgoats_jsd_loss = JSDLoss(p=sgoats_pmf, q=target_pmf)
print("sgoats jsd loss: ", sgoats_jsd_loss)

igoats_pmf = igoats_density_estimate.p / jnp.sum(igoats_density_estimate.p)
igoats_jsd_loss = JSDLoss(p=igoats_pmf, q=target_pmf)
print("igoats jsd loss: ", igoats_jsd_loss)

**MCUDSA:**

In [None]:
# MCUDSA metric on the raw (observation, action) points; all three algorithms
# are evaluated against the same support points.
support_points = latin_hypercube_sampling(d=3, n=20**3)

dmpe_points = np.concatenate([dmpe_observations[0:-1, :], dmpe_actions], axis=-1)
dmpe_mcudsa_loss = MC_uniform_sampling_distribution_approximation(
    data_points=dmpe_points, support_points=support_points
)
print("dmpe mcudsa loss: ", dmpe_mcudsa_loss)

sgoats_points = np.concatenate([sgoats_observations, sgoats_actions], axis=-1)
sgoats_mcudsa_loss = MC_uniform_sampling_distribution_approximation(
    data_points=sgoats_points, support_points=support_points
)
print("sgoats mcudsa loss: ", sgoats_mcudsa_loss)

igoats_points = np.concatenate([igoats_observations[0:-1, :], igoats_actions], axis=-1)
igoats_mcudsa_loss = MC_uniform_sampling_distribution_approximation(
    data_points=igoats_points, support_points=support_points
)
print("igoats mcudsa loss: ", igoats_mcudsa_loss)

In [None]:
# MCUDSA metric in feature space: data and support points both go through
# `featurize_theta` first. Note: this rebinds the *_mcudsa_loss names from the
# previous cell.
support_points = latin_hypercube_sampling(d=3, n=20**3)
featurized_support = featurize_theta(support_points)

dmpe_features = featurize_theta(
    np.concatenate([dmpe_observations[0:-1, :], dmpe_actions], axis=-1)
)
dmpe_mcudsa_loss = MC_uniform_sampling_distribution_approximation(
    data_points=dmpe_features, support_points=featurized_support
)
print("dmpe mcudsa loss: ", dmpe_mcudsa_loss)

sgoats_features = featurize_theta(
    np.concatenate([sgoats_observations, sgoats_actions], axis=-1)
)
sgoats_mcudsa_loss = MC_uniform_sampling_distribution_approximation(
    data_points=sgoats_features, support_points=featurized_support
)
print("sgoats mcudsa loss: ", sgoats_mcudsa_loss)

igoats_features = featurize_theta(
    np.concatenate([igoats_observations[0:-1, :], igoats_actions], axis=-1)
)
igoats_mcudsa_loss = MC_uniform_sampling_distribution_approximation(
    data_points=igoats_features, support_points=featurized_support
)
print("igoats mcudsa loss: ", igoats_mcudsa_loss)

**AE:**

In [None]:
# Audze-Eglais metric on the raw (observation, action) points.
dmpe_points = np.concatenate([dmpe_observations[0:-1, :], dmpe_actions], axis=-1)
dmpe_ae_loss = audze_eglais(dmpe_points)
print("dmpe ae loss: ", dmpe_ae_loss)

sgoats_points = np.concatenate([sgoats_observations, sgoats_actions], axis=-1)
sgoats_ae_loss = audze_eglais(sgoats_points)
print("sgoats ae loss: ", sgoats_ae_loss)

igoats_points = np.concatenate([igoats_observations[0:-1, :], igoats_actions], axis=-1)
igoats_ae_loss = audze_eglais(igoats_points)
print("igoats ae loss: ", igoats_ae_loss)

In [None]:
# Audze-Eglais metric in feature space (points passed through
# `featurize_theta` first). Note: this rebinds the *_ae_loss names from the
# previous cell.
dmpe_features = featurize_theta(
    np.concatenate([dmpe_observations[0:-1, :], dmpe_actions], axis=-1)
)
dmpe_ae_loss = audze_eglais(dmpe_features)
print("dmpe ae loss: ", dmpe_ae_loss)

sgoats_features = featurize_theta(
    np.concatenate([sgoats_observations, sgoats_actions], axis=-1)
)
sgoats_ae_loss = audze_eglais(sgoats_features)
print("sgoats ae loss: ", sgoats_ae_loss)

igoats_features = featurize_theta(
    np.concatenate([igoats_observations[0:-1, :], igoats_actions], axis=-1)
)
igoats_ae_loss = audze_eglais(igoats_features)
print("igoats ae loss: ", igoats_ae_loss)

## dev compression

observe:
- sign changes in $\theta$
- strong curvature (change in $\theta$)

In [None]:
from exciting_exciting_systems.related_work.excitation_utils import compress_datapoints

In [None]:
# Line-by-line profile of `compress_datapoints` on the first N sGOATs
# observations (uses the line_profiler extension loaded in the first cell).
N = 1000
%lprun -f compress_datapoints compress_datapoints(sgoats_observations[:N], N_c=100, feature_dimension=1)

In [None]:
# Wall-clock timing of the same call; the actual results are computed again in
# the next cell.
%timeit compressed_data, indices = compress_datapoints(sgoats_observations[:N], N_c=100, feature_dimension=1)

In [None]:
# Compressed subset of the first N observations together with the indices of
# the retained samples (used to place the markers in the plots below).
compressed_data, indices = compress_datapoints(sgoats_observations[:N], N_c=100, feature_dimension=1)

In [None]:
# Second observation component over the step index, with the samples kept by
# the compression marked as red dots.
step_axis = np.linspace(0, N-1, N)
plt.plot(step_axis, sgoats_observations[:N, 1])
plt.plot(step_axis[indices], compressed_data[..., 1], 'r.')
plt.show()

In [None]:
# First observation component over the step index, with the samples kept by
# the compression marked as red dots.
step_axis = np.linspace(0, N-1, N)
plt.plot(step_axis, sgoats_observations[:N, 0])
plt.plot(step_axis[indices], compressed_data[..., 0], 'r.')

#### KDE for feature space:

In [None]:
# 3-D density estimate in feature space: `featurize_theta` (the NumPy variant
# defined in the sGOATs section) turns each two-entry observation into
# (sin, cos, second entry).
# Note: `density_estimate` and the three per-algorithm estimates are rebound
# here, replacing the joint (observation, action) ones.
points_per_dim = 50
n_grid_points=points_per_dim**3
density_estimate = DensityEstimate(
    p=jnp.zeros([n_grid_points, 1]),
    x_g=exciting_exciting_systems.utils.density_estimation.build_grid_3d(
        low=-1,
        high=1,
        points_per_dim=points_per_dim
    ),
    bandwidth=jnp.array([bandwidth]),
    n_observations=jnp.array([0])
)


dmpe_density_estimate = update_density_estimate_multiple_observations(
    density_estimate, featurize_theta(dmpe_observations),
)

sgoats_density_estimate = update_density_estimate_multiple_observations(
    density_estimate, featurize_theta(sgoats_observations),
)

igoats_density_estimate = update_density_estimate_multiple_observations(
    density_estimate, featurize_theta(igoats_observations),
)

In [None]:
# Slice the feature-space density along its last grid axis (labelled omega in
# the titles) and show DMPE next to sGOATs for every slice.
x_plot = dmpe_density_estimate.x_g.reshape((points_per_dim, points_per_dim, points_per_dim, 3))
ims = []

# Loop-invariant pieces: the cube-shaped densities and the grid values that
# label the slices.
cube_shape = (points_per_dim, points_per_dim, points_per_dim)
slice_values = jnp.linspace(-1, 1, points_per_dim)
labelled_densities = [
    ("dmpe", dmpe_density_estimate.p.reshape(cube_shape)),
    ("sgoats", sgoats_density_estimate.p.reshape(cube_shape)),
]

for i in range(points_per_dim):
    fig, axs = plt.subplots(1, 2, figsize=(12, 6))
    for panel, (label, density_cube) in zip(axs, labelled_densities):
        panel.contourf(
            density_cube[:, :, i],
            antialiased=False,
            levels=100,
            alpha=0.9,
            cmap=plt.cm.coolwarm
        )
        panel.title.set_text(label + ": omega = " + str(slice_values[i]))
    plt.show()