In [None]:
# Re-import edited modules (e.g. exciting_exciting_systems) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# line_profiler provides %lprun; NOTE(review): no %lprun call is visible in this notebook.
%reload_ext line_profiler

In [None]:
import pathlib
import os
# JAX reads these when it initialises its GPU backend, so they are set before jax is
# imported in the next cell: expose only GPU 0 and do not preallocate GPU memory.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

from functools import partial
from typing import Callable

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
# plt.rcParams['text.usetex'] = True
import matplotlib as mpl
# LaTeX text rendering for all figures; requires a working LaTeX installation.
mpl.rcParams['text.usetex'] = True
mpl.rcParams.update({'font.size': 10})
# bm / amsmath allow bold-symbol and AMS math in axis labels.
mpl.rcParams['text.latex.preamble']=r"\usepackage{bm}\usepackage{amsmath}"
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp
import jax_dataclasses as jdc
from jax.tree_util import tree_flatten, tree_unflatten

# jax.config.update('jax_platform_name', 'cpu')
# jax.config.update("jax_debug_nans", True)
# gpus = jax.devices()
# jax.config.update("jax_default_device", gpus[0])
# jax.config.update("jax_enable_x64", True)

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs

import exciting_exciting_systems
from exciting_exciting_systems.algorithms import excite_with_dmpe
from exciting_exciting_systems.models import NeuralEulerODE
from exciting_exciting_systems.models.model_utils import simulate_ahead_with_env
from exciting_exciting_systems.models.model_training import ModelTrainer
from exciting_exciting_systems.excitation import loss_function, Exciter
from exciting_exciting_systems.related_work.algorithms import excite_with_sGOATS

from exciting_exciting_systems.utils.density_estimation import (
    select_bandwidth, update_density_estimate_single_observation, update_density_estimate_multiple_observations, DensityEstimate
)
from exciting_exciting_systems.utils.signals import aprbs
from exciting_exciting_systems.evaluation.plotting_utils import (
    plot_sequence, append_predictions_to_sequence_plot, plot_sequence_and_prediction, plot_model_performance, plot_2d_kde_as_contourf
)
from exciting_exciting_systems.models.model_utils import ModelEnvWrapperFluidTank

---

In [None]:
# Global PRNG for the data-generation demo below (alternative seed tried: 21).
key = jax.random.PRNGKey(33)
# Split off per-purpose keys; the last slot becomes the running key.
data_key, model_key, loader_key, key = jax.random.split(key, num=4)
# haiku's PRNGSequence hands out a new subkey on every next() call.
data_rng = PRNGSequence(data_key)

In [None]:
# Fluid-tank environment used for the APRBS demo below.
# NOTE(review): "batch_size" is recorded here but not forwarded to excenvs.make, so
# env.batch_size is whatever FluidTank-v0 defaults to — confirm this is intended.
env_params = {
    "batch_size": 1,
    "tau": 5,
    "max_height": 3,
    "max_inflow": 0.2,
    "base_area": jnp.pi,
    "orifice_area": jnp.pi * 0.1**2,
    "c_d": 0.6,
    "g": 9.81,
    "env_solver": diffrax.Tsit5(),
}
env = excenvs.make(
    "FluidTank-v0",
    physical_constraints={"height": env_params["max_height"]},
    action_constraints={"inflow": env_params["max_inflow"]},
    static_params={name: env_params[name] for name in ("base_area", "orifice_area", "c_d", "g")},
    tau=env_params["tau"],
    solver=env_params["env_solver"],
)


In [None]:
n_steps = 999  # length of the APRBS excitation signal

# reset() presumably returns batched observations (cf. env.batch_size); keep the first element.
obs, state = env.reset()
obs = obs[0]

# APRBS action sequence generated for a single batch entry and unbatched via [0].
actions = aprbs(n_steps, 1, 10, 100, next(data_rng))[0]

In [None]:
# Roll the environment forward under the APRBS actions and plot the resulting trajectory.
observations, state = simulate_ahead_with_env(env, obs, state, actions)

print("actions.shape:", actions.shape)
print("observations.shape:", observations.shape)
print(" \n One of the trajectories:")

fig, axs = plot_sequence(
    observations=observations,
    actions=actions,
    tau=env.tau,
    obs_labels=[r"$h$"],
    action_labels=[r"$u$"],
)
plt.show()

In [None]:
# sGOATS baseline (related work) on the fluid tank. The results are stored as
# `sgoats_observations` / `sgoats_actions` because the DMPE cells below reuse and
# overwrite the generic `observations` / `actions` names, while the evaluation
# cells at the end read the sGOATS-specific names.
seed = 0


env_params = dict(
    batch_size=1,
    tau=5,
    max_height=3,
    max_inflow=0.2,
    base_area=jnp.pi,
    orifice_area=jnp.pi * 0.1**2,
    c_d=0.6,
    g=9.81,
    env_solver=diffrax.Tsit5(),
)
env = excenvs.make(
    "FluidTank-v0",
    physical_constraints=dict(height=env_params["max_height"]),
    action_constraints=dict(inflow=env_params["max_inflow"]),
    static_params=dict(
        base_area=env_params["base_area"],
        orifice_area=env_params["orifice_area"],
        c_d=env_params["c_d"],
        g=env_params["g"],
    ),
    tau=env_params["tau"],
    solver=env_params["env_solver"],
)
alg_params = dict(
    n_amplitudes=779,
    n_amplitude_groups=41,
    reuse_observations=True,
    bounds_duration=(5, 50),
    population_size=50,
    n_generations=25,
    compress_data=True,
    compression_target_N=500,
    compression_dist_th=0.1,
    compression_feature_dim=-2,
    rho_obs=1e3,
    rho_act=1e3,
    featurize=lambda x: x,
)
exp_params = dict(
    seed=int(seed),
    alg_params=alg_params,
    env_params=env_params,
)

# setup PRNG: this generator is handed to the algorithm directly, so the run is
# determined by `seed` alone.
rng = np.random.default_rng(seed=exp_params["seed"])

# run excitation algorithm (explicitly imported in the project-import cell, so it
# does not depend on `related_work` being imported as a side effect)
sgoats_observations, sgoats_actions = excite_with_sGOATS(
    n_amplitudes=alg_params["n_amplitudes"],
    n_amplitude_groups=alg_params["n_amplitude_groups"],
    reuse_observations=alg_params["reuse_observations"],
    env=env,
    bounds_duration=alg_params["bounds_duration"],
    population_size=alg_params["population_size"],
    n_generations=alg_params["n_generations"],
    featurize=alg_params["featurize"],
    compress_data=alg_params["compress_data"],
    compression_target_N=alg_params["compression_target_N"],
    compression_dist_th=alg_params["compression_dist_th"],
    compression_feat_dim=alg_params["compression_feature_dim"],
    rho_obs=alg_params["rho_obs"],
    rho_act=alg_params["rho_act"],
    rng=rng,
    verbose=False,
    plot_every_subsequence=True,
)
# keep the generic names bound as well, as this cell did before
observations, actions = sgoats_observations, sgoats_actions

In [None]:
# pm-dmpe
env_params = dict(
    batch_size=1,
    tau=5,
    max_height=3,
    max_inflow=0.2,
    base_area=jnp.pi,
    orifice_area=jnp.pi * 0.1**2,
    c_d=0.6,
    g=9.81,
    env_solver=diffrax.Tsit5(),
)
env = excenvs.make(
    "FluidTank-v0",
    physical_constraints=dict(height=env_params["max_height"]),
    action_constraints=dict(inflow=env_params["max_inflow"]),
    static_params=dict(
        base_area=env_params["base_area"],
        orifice_area=env_params["orifice_area"],
        c_d=env_params["c_d"],
        g=env_params["g"],
    ),
    tau=env_params["tau"],
    solver=env_params["env_solver"],
)

alg_params = dict(
    bandwidth=None,
    n_prediction_steps=10,
    points_per_dim=50,
    action_lr=1e-1,
    n_opt_steps=10,
    rho_obs=1,
    rho_act=1,
    penalty_order=1,
    clip_action=True,
)
alg_params["bandwidth"] = select_bandwidth(
    delta_x=2,
    dim=env.physical_state_dim + env.action_dim,
    n_g=alg_params["points_per_dim"],
    percentage=0.3,
)

exp_params = dict(
    seed=None,
    n_timesteps=15_000,
    model_class=None,
    env_params=env_params,
    alg_params=alg_params,
    model_trainer_params=None,
    model_params=None,
    model_env_wrapper=ModelEnvWrapperFluidTank,
)

seed=4

In [None]:
# Record the seed in the experiment parameters and derive all PRNG keys from it.
exp_params["seed"] = int(seed)

# Only the data and exploration keys are used here; the two discarded slots line up
# with the model/loader keys split off in the dmpe run cell.
key = jax.random.PRNGKey(seed=exp_params["seed"])
data_key, _, _, expl_key, key = jax.random.split(key, 5)
data_rng = PRNGSequence(data_key)

# initial guess: APRBS proposal over one prediction horizon, first batch element only
proposed_actions = aprbs(
    exp_params["alg_params"]["n_prediction_steps"],
    env.batch_size,
    1,
    10,
    next(data_rng),
)[0]

# run excitation algorithm; None fills the slot where the dmpe run passes loader_key,
# and 500 the slot where it passes 1000.
observations, actions, model, density_estimate, losses, proposed_actions = excite_with_dmpe(
    env, exp_params, proposed_actions, None, expl_key, 500
)

In [None]:
# dmpe
env_params = dict(
    batch_size=1,
    tau=5,
    max_height=3,
    max_inflow=0.2,
    base_area=jnp.pi,
    orifice_area=jnp.pi * 0.1**2,
    c_d=0.6,
    g=9.81,
    env_solver=diffrax.Tsit5(),
)
env = excenvs.make(
    "FluidTank-v0",
    physical_constraints=dict(height=env_params["max_height"]),
    action_constraints=dict(inflow=env_params["max_inflow"]),
    static_params=dict(
        base_area=env_params["base_area"],
        orifice_area=env_params["orifice_area"],
        c_d=env_params["c_d"],
        g=env_params["g"],
    ),
    tau=env_params["tau"],
    solver=env_params["env_solver"],
)

alg_params = dict(
    bandwidth=None,
    n_prediction_steps=10,
    points_per_dim=50,
    action_lr=1e-1,
    n_opt_steps=10,
    rho_obs=1,
    rho_act=1,
    penalty_order=1,
    clip_action=True,
)
alg_params["bandwidth"] = select_bandwidth(
    delta_x=2,
    dim=env.physical_state_dim + env.action_dim,
    n_g=alg_params["points_per_dim"],
    percentage=0.3,
)

model_trainer_params = dict(
    start_learning=alg_params["n_prediction_steps"],
    training_batch_size=128,
    n_train_steps=1,
    sequence_length=alg_params["n_prediction_steps"],
    featurize=lambda x: x,
    model_lr=1e-4,
)
model_params = dict(obs_dim=env.physical_state_dim, action_dim=env.action_dim, width_size=128, depth=3, key=None)

exp_params = dict(
    seed=None,
    n_timesteps=1,
    model_class=NeuralEulerODE,
    env_params=env_params,
    alg_params=alg_params,
    model_trainer_params=model_trainer_params,
    model_params=model_params,
)

seed = 194

In [None]:
from exciting_exciting_systems.algorithms import excite_with_dmpe

In [None]:
# Record the seed in the experiment parameters (as the pm-dmpe run cell does) so
# exp_params describes this run, then derive all PRNG keys from it.
exp_params["seed"] = int(seed)

key = jax.random.PRNGKey(seed=exp_params["seed"])
data_key, model_key, loader_key, expl_key, key = jax.random.split(key, 5)
data_rng = PRNGSequence(data_key)
exp_params["model_params"]["key"] = model_key

# initial guess: APRBS proposal over one prediction horizon, first batch element only
proposed_actions = aprbs(exp_params["alg_params"]["n_prediction_steps"], env.batch_size, 1, 10, next(data_rng))[0]

print(proposed_actions[0])

# run excitation algorithm
observations, actions, model, density_estimate, losses, proposed_actions = excite_with_dmpe(
    env, exp_params, proposed_actions, loader_key, expl_key, 1000
)
# Keep the DMPE data under dedicated names: the comparison cells below read
# dmpe_observations / dmpe_actions, and `observations` / `actions` are generic.
dmpe_observations, dmpe_actions = observations, actions

In [None]:
# Scratch/debug: fresh reset of the current env; rebinds the global `obs` and `state`.
obs, state = env.reset()

In [None]:
# Scratch/debug: display the observation returned by the reset above.
obs

In [None]:
# Scratch/debug: one manual env step with the first proposed action.
# NOTE(review): `state` comes straight from reset() (batched, cf. `obs[0]` earlier) while
# proposed_actions[0] drops the leading axis — confirm env.step accepts these shapes.
env.step(state, proposed_actions[0], env.env_properties)

In [None]:
# Scratch/debug: inspect the proposed actions returned by the last excite_with_dmpe call.
proposed_actions

In [None]:
from exciting_exciting_systems.utils.metrics import JSDLoss
from exciting_exciting_systems.related_work.np_reimpl.metrics import (
    MC_uniform_sampling_distribution_approximation, audze_eglais
)
from exciting_exciting_systems.related_work.excitation_utils import latin_hypercube_sampling

In [None]:
# Jensen-Shannon divergence between each method's density estimate over the joint
# (observation, action) grid and a uniform target on the same grid.
dim = env.physical_state_dim + env.action_dim  # joint obs+action dimension (2 for the fluid tank)
points_per_dim = 30
kde_bandwidth = 0.05

# uniform target: equal mass on all points_per_dim**dim grid points, normalised once
target_distribution = jnp.ones(shape=(points_per_dim**dim, 1))[None]
target_p = target_distribution / jnp.sum(target_distribution)

dmpe_density_estimate = DensityEstimate.from_dataset(
    dmpe_observations, dmpe_actions, points_per_dim=points_per_dim, bandwidth=kde_bandwidth
)
# NOTE(review): sGOATS actions are trimmed by one step while the DMPE arrays are passed
# as-is — presumably the two algorithms return different obs/action lengths; confirm.
sgoats_density_estimate = DensityEstimate.from_dataset(
    sgoats_observations, sgoats_actions[:-1, :], points_per_dim=points_per_dim, bandwidth=kde_bandwidth
)

dmpe_jsd_loss = JSDLoss(
    p=dmpe_density_estimate.p / jnp.sum(dmpe_density_estimate.p),
    q=target_p,
)
print("dmpe jsd loss: ", dmpe_jsd_loss)

sgoats_jsd_loss = JSDLoss(
    p=sgoats_density_estimate.p / jnp.sum(sgoats_density_estimate.p),
    q=target_p,
)
print("sgoats jsd loss: ", sgoats_jsd_loss)

In [None]:
# Monte-Carlo uniform-sampling distribution approximation (MCUDSA) loss of each method's
# joint (observation, action) samples against a shared set of Latin-hypercube support points.
# The fixed seed keeps the support points (and thus the reported losses) reproducible
# across kernel restarts. The number of points reuses points_per_dim from the JSD cell
# instead of repeating the literal 30.
support_points = latin_hypercube_sampling(
    d=dim, n=points_per_dim**dim, rng=np.random.default_rng(seed=0)
)

dmpe_mcudsa_loss = MC_uniform_sampling_distribution_approximation(
    data_points=np.concatenate([dmpe_observations[:-1, :], dmpe_actions], axis=-1),
    support_points=support_points,
)
print("dmpe mcudsa loss: ", dmpe_mcudsa_loss)

sgoats_mcudsa_loss = MC_uniform_sampling_distribution_approximation(
    data_points=np.concatenate([sgoats_observations, sgoats_actions], axis=-1),
    support_points=support_points,
)
print("sgoats mcudsa loss: ", sgoats_mcudsa_loss)

In [None]:
# Audze-Eglais criterion (via audze_eglais) on each method's joint (observation, action) samples.
dmpe_data_points = np.concatenate([dmpe_observations[:-1, :], dmpe_actions], axis=-1)
dmpe_ae_loss = audze_eglais(dmpe_data_points)
print("dmpe ae loss: ", dmpe_ae_loss)

sgoats_data_points = np.concatenate([sgoats_observations, sgoats_actions], axis=-1)
sgoats_ae_loss = audze_eglais(sgoats_data_points)
print("sgoats ae loss: ", sgoats_ae_loss)