In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
import time
import matplotlib.pyplot as plt
import numpy as np

In [None]:
import jax
import jax.numpy as jnp
jax.config.update('jax_platform_name', 'gpu')

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_exciting_systems

from exciting_exciting_systems.utils.signals import aprbs
from exciting_exciting_systems.evaluation.plotting_utils import plot_sequence

from exciting_exciting_systems.models.model_utils import simulate_ahead_with_env

In [None]:
# Derive independent PRNG keys for data generation, model initialisation and
# the data loader from one fixed root seed; `key` keeps the remaining split.
root_key = jax.random.PRNGKey(seed=9)
data_key, model_key, loader_key, key = jax.random.split(root_key, 4)

# Key stream that yields a fresh key on every `next(data_rng)`.
data_rng = PRNGSequence(data_key)

# Usage Presentation

## Exciting Environments:

- [Available on github](https://github.com/ExcitingSystems/exciting-environments)
- Uses [jax](https://github.com/google/jax) to provide simulators for physical systems that are ...
  - ... differentiable
  - ... vectorizable
  - ... just-in-time compilable

In [None]:
import exciting_environments as excenvs

In [None]:
# Pendulum environment for the batched simulation demo below.
batch_size = 10
tau = 2e-2  # handed to the environment as `tau`; presumably the sampling time in seconds — TODO confirm

env = excenvs.make(
    env_id='Pendulum-v0',
    batch_size=batch_size,
    action_constraints={"torque": 8},
    static_params={"g": 9.81, "l": 1, "m": 1},
    solver=diffrax.Tsit5(),
    tau=tau,
)

In [None]:
# Reset the environment and draw a random excitation signal for the rollout.
obs, state = env.reset()
n_steps = 999

# The positional arguments 5 and 10 are signal parameters of `aprbs` —
# presumably bounds on the hold duration; TODO confirm against its signature.
actions = aprbs(n_steps, batch_size, 5, 10, next(data_rng))

In [None]:
# Scale the first action channel by the torque limit and simulate the whole batch.
torque_limit = env.env_properties.action_constraints.torque
scaled_actions = actions[..., 0] * torque_limit
observations, _, _, _ = env.vmap_sim_ahead(state, scaled_actions, env.tau, env.tau)

print("actions.shape:", actions.shape)
print("observations.shape:", observations.shape)

# Show only the first trajectory of the batch.
print(" \n One of the trajectories:")
fig, axs = plot_sequence(
    observations=observations[0, ...],
    actions=actions[0, ...],
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
);
plt.show()

#### Speed:

In [None]:
# Single-trajectory pendulum with the explicit Euler solver for the speed comparison.
batch_size = 1
tau = 2e-2

env = excenvs.make(
    env_id='Pendulum-v0',
    batch_size=batch_size,
    action_constraints={"torque": 8},
    static_params={"g": 9.81, "l": 1, "m": 1},
    solver=diffrax.Euler(),
    tau=tau,
)

obs, state = env.reset()
# Keep only the first entry along the leading axis (batch_size is 1 here).
obs = obs[0]

n_steps = 999

In [None]:
# Time ten rollouts of the JAX environment.
# JAX dispatches work asynchronously, so both the inputs and the result are
# blocked on explicitly; otherwise the timer may stop before the computation
# has finished. Any compilation that happens on the first call is included in
# the first printed value.
for _ in range(10):
    actions = jax.block_until_ready(aprbs(n_steps, batch_size, 5, 10, next(data_rng))[0])

    start = time.perf_counter()
    observations = simulate_ahead_with_env(env, obs, state, actions)
    observations = jax.block_until_ready(observations)
    end = time.perf_counter()
    print(end - start)

In [None]:
from exciting_exciting_systems.related_work.np_reimpl.pendulum import Pendulum
from exciting_exciting_systems.related_work.np_reimpl.env_utils import simulate_ahead_with_env as np_simulate_ahead_with_env

In [None]:
# NumPy re-implementation of the pendulum, used as the speed baseline.
batch_size = 1
tau = 2e-2

np_env = Pendulum(
    batch_size=batch_size,
    tau=tau,
    max_torque=8
)

# The former `actions = np.array(actions)` line was removed here: it depended on
# a variable leaked from an earlier cell and the timing loop below always draws
# fresh actions before using them.

obs, env_state = np_env.reset()
# Cast to float32 — presumably to match the default precision of the JAX run; TODO confirm.
obs = obs.astype(np.float32)
env_state = env_state.astype(np.float32)

In [None]:
# Time ten rollouts of the NumPy environment.
# `time.perf_counter()` is used because it is a monotonic, high-resolution clock
# intended for measuring short intervals.
for _ in range(10):
    # Convert before starting the timer so that only the simulation is measured.
    actions = np.array(aprbs(n_steps, batch_size, 5, 10, next(data_rng)))

    start = time.perf_counter()
    observations, _ = np_simulate_ahead_with_env(
        np_env,
        obs,
        env_state,
        actions,
    )
    end = time.perf_counter()
    print(end - start)

#### Automatic Differentiation:

In [None]:
# Batched pendulum (explicit Euler) for the automatic differentiation demo.
batch_size = 20
tau = 2e-2

env = excenvs.make(
    env_id='Pendulum-v0',
    batch_size=batch_size,
    action_constraints={"torque": 8},
    static_params={"g": 9.81, "l": 1, "m": 1},
    solver=diffrax.Euler(),
    tau=tau,
)

obs, state = env.reset()
n_steps = 999

# Random input signals that the loss is differentiated with respect to.
actions = aprbs(n_steps, batch_size, 5, 10, next(data_rng))

In [None]:
def loss_function(target_observations, actions, env, obs, state):
    """Mean squared error between simulated and target observations.

    The rollout `simulate_ahead_with_env` is vectorised with `jax.vmap` over
    the leading axis of `obs`, `state` and `actions`; `env` is shared by all
    batch elements.

    Args:
        target_observations: Reference values the simulated observations are
            compared against (must broadcast against the rollout result).
        actions: Batched action sequences applied in the rollout.
        env: Environment handed unchanged to every rollout.
        obs: Batched initial observations.
        state: Batched initial environment states.

    Returns:
        Scalar mean of the squared differences.
    """
    batched_rollout = jax.vmap(simulate_ahead_with_env, in_axes=(None, 0, 0, 0))
    simulated = batched_rollout(env, obs, state, actions)
    squared_error = (simulated - target_observations) ** 2
    return jnp.mean(squared_error)

In [None]:
# argnums=[1] differentiates w.r.t. the second positional argument (`actions`).
# Because argnums is a sequence, the returned gradient is a tuple with one entry.
grad_function = jax.grad(loss_function, argnums=[1])

In [None]:
# Constant dummy target, sized from the current configuration instead of a
# hard-coded (20, 1000, 2). The rollout yields one more time step than there are
# actions (n_steps + 1) — presumably because the initial observation is included.
target_observations = jnp.full((batch_size, n_steps + 1, obs.shape[-1]), 0.5)

# grad_function returns a one-element tuple (argnums was a list), hence the [0].
grads = grad_function(
    target_observations,
    actions,
    env=env,
    obs=obs,
    state=state
)[0]
print("actions.shape", actions.shape)

print("grads.shape", grads.shape)

## Differentiable Model Predictive Excitation (DMPE):

- excite systems by simultaneous identification and input optimization

In [None]:
# Single pendulum with the Tsit5 solver; this is the system excited by DMPE below.
batch_size = 1
tau = 2e-2

env = excenvs.make(
    env_id='Pendulum-v0',
    batch_size=batch_size,
    action_constraints={"torque": 8},
    static_params={"g": 9.81, "l": 1, "m": 1},
    solver=diffrax.Tsit5(),
    tau=tau,
)

In [None]:
def featurize_theta(obs):
    """Replace the angle by its sine and cosine so that the loss sees it as periodic.

    Comparing raw angles is misleading because values that are almost a full
    turn apart describe nearly the same pose. The first observation entry is
    multiplied by pi before the trigonometric functions are applied (i.e. it is
    presumably stored normalised by pi — TODO confirm); the second entry is
    passed through unchanged.

    Args:
        obs: Array whose last axis holds (angle, second observation).

    Returns:
        Array with last axis (sin(angle * pi), cos(angle * pi), second observation).
    """
    angle = obs[..., 0] * jnp.pi
    return jnp.stack([jnp.sin(angle), jnp.cos(angle), obs[..., 1]], axis=-1)

In [None]:
# Hyperparameters of the DMPE run.
bandwidth = 0.05  # passed to the density estimates as their bandwidth
n_prediction_steps = 50  # length of the proposed action sequence and of the training sequences

dim_obs_space = 2
dim_action_space = 1

# The density is estimated over the joint observation-action space.
dim = dim_obs_space + dim_action_space
points_per_dim = 50
n_grid_points=points_per_dim**dim  # 50**3 = 125,000 grid points

n_timesteps = 15_000

In [None]:
# Initial condition and pre-allocated buffers for the DMPE run.
obs, state = env.reset()
obs = obs[0]  # keep only the first entry along the leading axis (batch_size is 1)

# One observation per time step (starting with the initial one) and one action
# fewer than observations.
observations = jnp.zeros((n_timesteps, dim_obs_space))
observations = observations.at[0].set(obs)
actions = jnp.zeros((n_timesteps-1, dim_action_space))

# Initial guess for the action sequence over the prediction horizon.
proposed_actions = aprbs(n_prediction_steps, batch_size, 1, 10, next(data_rng))[0]

In [None]:
from exciting_exciting_systems.models.model_training import ModelTrainer
from exciting_exciting_systems.excitation import loss_function, Exciter
from exciting_exciting_systems.models import NeuralEulerODEPendulum
from exciting_exciting_systems.utils.density_estimation import (
    DensityEstimate, build_grid_2d
)

In [None]:
# NOTE: `loss_function` here is the one imported from
# `exciting_exciting_systems.excitation`, which shadows the function of the same
# name defined in the automatic differentiation section.
# argnums=(3) is the plain integer 3, i.e. the gradient is taken w.r.t. the
# fourth positional argument of that loss.
exciter = Exciter(
    grad_loss_function=jax.grad(loss_function, argnums=(3)),
    excitation_optimizer=optax.adabelief(1e-1),
    tau=tau,
    # constant value 1 / 2**dim at every grid point, i.e. a uniform density on [-1, 1]**dim
    target_distribution=jnp.ones(shape=(n_grid_points, 1)) * 1 / (1 - (-1))**dim
)

model_trainer = ModelTrainer(
    start_learning=n_prediction_steps,
    training_batch_size=128,
    n_train_steps=1,
    sequence_length=n_prediction_steps,
    featurize=featurize_theta,
    model_optimizer=optax.adabelief(1e-4),
    tau=tau
)

# Empty density estimate (all zeros, no observations yet) on a 3D grid over [-1, 1].
density_estimate = DensityEstimate(
    p=jnp.zeros([n_grid_points, 1]),
    x_g=exciting_exciting_systems.utils.density_estimation.build_grid_3d(
        low=-1,
        high=1,
        points_per_dim=points_per_dim
    ),
    bandwidth=jnp.array([bandwidth]),
    n_observations=jnp.array([0])
)


model = NeuralEulerODEPendulum(
    obs_dim=dim_obs_space,
    action_dim=dim_action_space,
    width_size=128,
    depth=3,
    key=model_key
)
# The optimizer state is initialised only for the inexact (floating point) array leaves of the model.
opt_state_model = model_trainer.model_optimizer.init(eqx.filter(model, eqx.is_inexact_array))

In [None]:
from exciting_exciting_systems.algorithms import excite_and_fit

In [None]:
# Run DMPE. Judging by the unpacked names, the call returns the collected
# observations and actions together with the updated model and density
# estimate — TODO confirm against `excite_and_fit`.
# NOTE: `model` and `density_estimate` are rebound here, so re-running this cell
# continues from the previous result unless the setup cell is re-run first.
dmpe_observations, dmpe_actions, model, density_estimate = excite_and_fit(
    n_timesteps=n_timesteps,
    env=env,
    model=model,
    obs=obs,
    state=state,
    proposed_actions=proposed_actions,
    exciter=exciter,
    model_trainer=model_trainer,
    density_estimate=density_estimate,
    observations=observations,
    actions=actions,
    opt_state_model=opt_state_model,
    loader_key=loader_key,
    plot_every=2500,
)

In [None]:
# Plot the complete DMPE trajectory (observations and applied actions).
fig, axs = plot_sequence(
    dmpe_observations,
    dmpe_actions,
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
);
# `plt.show()` renders the figure; the former `plt.plot()` was a no-op call
# with no data.
plt.show()

## sGOATs reimplementation:

In [None]:
from exciting_exciting_systems.related_work.np_reimpl.pendulum import Pendulum
from exciting_exciting_systems.related_work.algorithms import excite_with_sGOATs

In [None]:
def featurize_theta(obs_action):
    """NumPy variant: replace the angle by its sine and cosine, keep everything else.

    Raw angles that are almost a full turn apart describe nearly the same pose,
    which a distance-based loss cannot see. The first entry of the last axis is
    therefore multiplied by pi and mapped to (sin, cos); all remaining entries
    are appended unchanged.

    NOTE(review): this definition shadows the JAX `featurize_theta` defined
    earlier in the notebook.

    Args:
        obs_action: Array whose last axis starts with the angle.

    Returns:
        Array with last axis (sin(angle * pi), cos(angle * pi), remaining entries...).
    """
    angle = obs_action[..., 0] * np.pi
    trig_features = np.stack([np.sin(angle), np.cos(angle)], axis=-1)
    return np.concatenate([trig_features, obs_action[..., 1:]], axis=-1)

In [None]:
# NOTE: rebinds `env` from the JAX environment to the NumPy pendulum;
# `batch_size` and `tau` are taken from the DMPE section above.
env = Pendulum(
    batch_size=batch_size,
    tau=tau,
    max_torque=8
)

In [None]:
# The sGOATs optimization below is commented out because it is time-intensive.
# The stored results are loaded from disk instead (see the end of this cell).

# all_observations = []
# all_actions = []

# all_observations, all_actions = excite_with_sGOATs(
#     n_amplitudes=600,
#     n_amplitude_groups=6,
#     reuse_observations=True,
#     all_observations=all_observations,
#     all_actions=all_actions,
#     env=env,
#     bounds_duration=(1,50),
#     population_size=20,
#     n_generations=50,
#     n_support_points=1600,
#     featurize=featurize_theta,
#     seed=0,
#     verbose=True
# )

# sgoats_observations = np.concatenate(all_observations)
# sgoats_actions = np.concatenate(all_actions)

# Paths are relative to the working directory of the notebook kernel.
sgoats_observations = np.load("results/obs_sGOATs.npy")
sgoats_actions = np.load("results/act_sGOATs.npy")

In [None]:
print("sgoats actions.shape:", sgoats_actions.shape)
print("sgoats observations.shape:", sgoats_observations.shape)

# The last action is dropped for plotting — presumably so that there is one
# action fewer than observations, as in the DMPE arrays; TODO confirm.
fig, axs = plot_sequence(
    observations=sgoats_observations,
    actions=sgoats_actions[:-1, ...],
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
);
plt.show()

## Comparison:

#### Qualitative:

- Note that sGOATs uses **explicit Euler**, while DMPE uses **Tsit5** as ODE solver in the simulation
- Both algorithms consider $\mathbf{u}_k$ and $\mathbf{y}_k$; first, we only look at the distribution of the observations

In [None]:
# Side-by-side scatter plots of the visited observations.
fig, ax = plt.subplots(1, 2, figsize=(12,6), sharey=True)

datasets = [dmpe_observations, sgoats_observations]
labels = ["dmpe", "sgoats"]
for idx, (observations, name) in enumerate(zip(datasets, labels)):
    panel = ax[idx]
    panel.scatter(observations[..., 0], observations[..., 1], s=1)
    panel.grid()
    panel.title.set_text(name + " observations, timeseries")

fig.tight_layout()
plt.show()

In [None]:
from exciting_exciting_systems.utils.density_estimation import update_density_estimate_multiple_observations

In [None]:
# Empty 2D density estimate over the observation space only; it is the common
# starting point for both algorithms.
# NOTE: this rebinds `density_estimate`, which previously held the result of the DMPE run.
density_estimate = DensityEstimate(
    p=jnp.zeros([points_per_dim**2, 1]),
    x_g=build_grid_2d(low=-1, high=1, points_per_dim=points_per_dim),
    bandwidth=jnp.array([bandwidth]),
    n_observations=jnp.array([0])
)

dmpe_density_estimate = update_density_estimate_multiple_observations(
    density_estimate, dmpe_observations,
)

sgoats_density_estimate = update_density_estimate_multiple_observations(
    density_estimate, sgoats_observations,
)

In [None]:
# Observation-space density estimate of the DMPE data.
fig, axs, cax = exciting_exciting_systems.evaluation.plotting_utils.plot_2d_kde_as_contourf(
    dmpe_density_estimate.p, dmpe_density_estimate.x_g, [r"$\theta$", r"$\omega$"]
)
plt.show()

# Observation-space density estimate of the sGOATs data.
fig, axs, cax = exciting_exciting_systems.evaluation.plotting_utils.plot_2d_kde_as_contourf(
    sgoats_density_estimate.p, sgoats_density_estimate.x_g, [r"$\theta$", r"$\omega$"]
)
# Render the second figure explicitly as well, matching the first one.
plt.show()

- joint distribution of $\mathbf{u}_k$ and $\mathbf{y}_k$

In [None]:
# Joint density over (observation, action): empty 3D estimate as the common
# starting point, then one update per algorithm.
points_per_dim = 50
n_grid_points = points_per_dim**3

joint_grid = exciting_exciting_systems.utils.density_estimation.build_grid_3d(
    low=-1, high=1, points_per_dim=points_per_dim
)
density_estimate = DensityEstimate(
    p=jnp.zeros([n_grid_points, 1]),
    x_g=joint_grid,
    bandwidth=jnp.array([bandwidth]),
    n_observations=jnp.array([0]),
)

# DMPE has one action fewer than observations, so the last observation is dropped.
dmpe_joint_samples = jnp.concatenate([dmpe_observations[:-1, :], dmpe_actions], axis=-1)
dmpe_density_estimate = update_density_estimate_multiple_observations(
    density_estimate, dmpe_joint_samples
)

sgoats_joint_samples = jnp.concatenate([sgoats_observations, sgoats_actions], axis=-1)
sgoats_density_estimate = update_density_estimate_multiple_observations(
    density_estimate, sgoats_joint_samples
)

In [None]:
# Plot 2D slices of the joint density, one figure per index along the last grid
# axis. The titles label that axis as the action u_k — presumably matching the
# axis order of `build_grid_3d`; TODO confirm.
# Loop-invariant work (grid values, reshapes, plot options) is computed once,
# and the unused `x_plot` / `ims` variables of the earlier version are gone.
slice_values = jnp.linspace(-1, 1, points_per_dim)
grid_shape = (points_per_dim, points_per_dim, points_per_dim)
dmpe_density_grid = dmpe_density_estimate.p.reshape(grid_shape)
sgoats_density_grid = sgoats_density_estimate.p.reshape(grid_shape)
contour_options = dict(antialiased=False, levels=100, alpha=0.9, cmap=plt.cm.coolwarm)

for i in range(points_per_dim):
    fig, axs = plt.subplots(1, 2, figsize=(12, 6))
    axs[0].contourf(dmpe_density_grid[:, :, i], **contour_options)
    axs[1].contourf(sgoats_density_grid[:, :, i], **contour_options)

    axs[0].title.set_text("dmpe: u_k = " + str(slice_values[i]))
    axs[1].title.set_text("goats: u_k = " + str(slice_values[i]))
    plt.show()
    # Release the figure so that the many figures of this loop do not pile up in memory.
    plt.close(fig)

#### Quantitative:

- metrics used for comparison:
  - Jensen Shannon divergence (JSD) **<- optimization metric for dmpe optimization**
  - MC uniform sampling distribution approximation (MCUDSA) **<- optimization metric for sGOATs**
  - Audze-Eglais (AE)
- metric used for iGOATs optimization:
  - Maximum nearest neighbor sequence (MNNS) (without penalty, only for MISO systems?)

In [None]:
import exciting_exciting_systems

In [None]:
from exciting_exciting_systems.utils.metrics import JSDLoss
from exciting_exciting_systems.related_work.np_reimpl.metrics import (
    MC_uniform_sampling_distribution_approximation, audze_eglais
)
from exciting_exciting_systems.related_work.excitation_utils import latin_hypercube_sampling

**JSD:**

In [None]:
# Both density estimates and the target are normalised to sum to one before
# they are compared; the target is the same for both algorithms.
target_normalized = exciter.target_distribution / jnp.sum(exciter.target_distribution)

dmpe_p_normalized = dmpe_density_estimate.p / jnp.sum(dmpe_density_estimate.p)
dmpe_jsd_loss = JSDLoss(p=dmpe_p_normalized, q=target_normalized)
print("dmpe jsd loss: ", dmpe_jsd_loss)

sgoats_p_normalized = sgoats_density_estimate.p / jnp.sum(sgoats_density_estimate.p)
sgoats_jsd_loss = JSDLoss(p=sgoats_p_normalized, q=target_normalized)
print("sgoats jsd loss: ", sgoats_jsd_loss)

**MCUDSA:**

In [None]:
# The same support points are used for both data sets.
support_points = latin_hypercube_sampling(d=3, n=20**3)

# DMPE has one action fewer than observations, so the last observation is dropped.
dmpe_mcudsa_points = np.concatenate([dmpe_observations[:-1, :], dmpe_actions], axis=-1)
dmpe_mcudsa_loss = MC_uniform_sampling_distribution_approximation(
    data_points=dmpe_mcudsa_points, support_points=support_points
)
print("dmpe mcudsa loss: ", dmpe_mcudsa_loss)

sgoats_mcudsa_points = np.concatenate([sgoats_observations, sgoats_actions], axis=-1)
sgoats_mcudsa_loss = MC_uniform_sampling_distribution_approximation(
    data_points=sgoats_mcudsa_points, support_points=support_points
)
print("sgoats mcudsa loss: ", sgoats_mcudsa_loss)

**AE:**

In [None]:
# Audze-Eglais metric on the joint (observation, action) samples of each algorithm.
dmpe_ae_points = np.concatenate([dmpe_observations[:-1, :], dmpe_actions], axis=-1)
dmpe_ae_loss = audze_eglais(dmpe_ae_points)
print("dmpe ae loss: ", dmpe_ae_loss)

sgoats_ae_points = np.concatenate([sgoats_observations, sgoats_actions], axis=-1)
sgoats_ae_loss = audze_eglais(sgoats_ae_points)
print("sgoats ae loss: ", sgoats_ae_loss)

# Questions:

- how fast are your GOATs implementations roughly? How long did the optimization take and why is that so important?
  - i.e. could the optimization not just run through the night?
- Is your code fully MATLAB? Are you using the "standard" GA or did you implement one yourself?
- What about the compression? Why is it actually necessary?
- Why is Audze-Eglais not used in the Evaluation?
- I found that if the bounds of the input space are large I get stability issues with GOATs and sGOATs. Did you observe something similar?
- I did not really understand the penalty term in the MNNS loss function for iGOATs. Can you explain your idea?

# Feedback:

What do you think about my approach?
- What are possible vulnerabilities?

In [None]:
# np.save("results/obs_dmpe.npy", np.stack(dmpe_observations))
# np.save("results/act_dmpe.npy", np.stack(dmpe_actions))

In [None]:
# np.save("results/obs_sGOATs.npy", np.stack(sgoats_observations))
# np.save("results/act_sGOATs.npy", np.stack(sgoats_actions))