In [None]:
# Auto-reload edited project modules before each cell runs (%autoreload 2)
# and load line_profiler so %lprun is available for profiling.
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
import pathlib
import os
# Expose only GPU 0 to CUDA and disable XLA's default up-front GPU memory
# preallocation. These are set here, before JAX is imported, so the runtime sees them.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

from functools import partial
from typing import Callable

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
# plt.rcParams['text.usetex'] = True
import matplotlib as mpl
# Render all figure text with LaTeX (requires a local LaTeX installation); the
# preamble provides \bm (bold math, e.g. the $\bm{x}$ axis label) and amsmath.
mpl.rcParams['text.usetex'] = True
mpl.rcParams.update({'font.size': 10})
mpl.rcParams['text.latex.preamble']=r"\usepackage{bm}\usepackage{amsmath}"
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp
import jax_dataclasses as jdc
from jax.tree_util import tree_flatten, tree_unflatten

# Make the first device reported by jax.devices() the default placement target
# for JAX arrays (assumes a GPU backend is available, so this is the visible GPU).
# jax.config.update('jax_platform_name', 'gpu')
gpus = jax.devices()
jax.config.update("jax_default_device", gpus[0])
# x64 stays disabled, i.e. JAX's default 32-bit precision is used.
# jax.config.update("jax_enable_x64", True)

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs

import exciting_exciting_systems
import exciting_exciting_systems as eesys
from exciting_exciting_systems.models import NeuralEulerODE
from exciting_exciting_systems.models.model_utils import simulate_ahead_with_env
from exciting_exciting_systems.models.model_training import ModelTrainer
from exciting_exciting_systems.excitation import loss_function, Exciter

from exciting_exciting_systems.utils.density_estimation import (
    update_density_estimate_single_observation, update_density_estimate_multiple_observations, DensityEstimate
)
from exciting_exciting_systems.utils.signals import aprbs
from exciting_exciting_systems.evaluation.plotting_utils import (
    plot_sequence, append_predictions_to_sequence_plot, plot_sequence_and_prediction, plot_model_performance, plot_2d_kde_as_contourf
)

---

In [None]:
# setup PRNG
key = jax.random.PRNGKey(seed=33) # previously seed 21

# Split order matters: it fixes which sub-key each consumer below receives.
data_key, model_key, loader_key, key = jax.random.split(key, 4)
data_rng = PRNGSequence(data_key)

In [None]:
batch_size = 1
tau = 5e-1 # sample time in seconds (used for the time axis of the plots); previously 5e-2

env = excenvs.make(
    "FluidTank-v0",
    tau=tau
)

In [None]:
obs, state = env.reset()
obs = obs[0]  # select the first (only, batch_size = 1) batch entry
n_steps = 999

# Random excitation signal; the arguments 10 and 100 are presumably the
# min/max hold durations of the APRBS — confirm against `aprbs`.
actions = aprbs(n_steps, batch_size, 10, 100, next(data_rng))[0]

In [None]:
# Roll the true environment forward under the random actions and plot the result.
observations, state = simulate_ahead_with_env(env, obs, state, actions)

print("actions.shape:", actions.shape)
print("observations.shape:", observations.shape)

print(" \n One of the trajectories:")
fig, axs = plot_sequence(
    observations=observations,
    actions=actions,
    tau=tau,
    obs_labels=[r"$h$"],
    action_labels=[r"$u$"],
);
plt.show()

In [None]:
from exciting_exciting_systems.algorithms import excite_with_dmpe

In [None]:
from exciting_exciting_systems.evaluation.experiment_utils import (
    get_experiment_ids, load_experiment_results
)

In [None]:
# Load one of the DMPE fluid-tank runs listed in `NaN_indices` (presumably the
# runs whose results contain NaNs). The results directory can be overridden
# with the DMPE_RESULTS_PATH environment variable; the default is the original
# author's local checkout.
idx = 0

NaN_indices = np.array([5, 61, 65, 69], dtype=int)

dmpe_results_path = pathlib.Path(
    os.environ.get(
        "DMPE_RESULTS_PATH",
        "/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/dmpe/fluid_tank/",
    )
)
dmpe_exp_ids = np.array(get_experiment_ids(dmpe_results_path))
NaN_cases = dmpe_exp_ids[NaN_indices].tolist()

params, observations, actions, model = load_experiment_results(NaN_cases[idx], dmpe_results_path, NeuralEulerODE)

In [None]:
# Overrides the NaN-case selection above: load entry 57 of the full experiment
# id list instead. `idx` is also reused later to pick the seed from `seeds`.
idx = 57

params, observations, actions, model = load_experiment_results(dmpe_exp_ids[idx], dmpe_results_path, NeuralEulerODE)

In [None]:
# Plot the loaded observation/action sequence with the env's own labels.
plot_sequence(observations, actions, env.tau, env.obs_description, env.action_description)

In [None]:
# Compare the recorded trajectory with a model prediction that starts from the
# last recorded observation and is driven by 100 constant actions of -1.
# Uses `action_labels`, the keyword accepted by the other plotting helpers
# (plot_sequence, plot_model_performance).
plot_sequence_and_prediction(
    observations=observations,
    actions=actions,
    tau=env.tau,
    obs_labels=env.obs_description,
    action_labels=[r"$u$"],
    model=model,
    init_obs=observations[-1],
    proposed_actions=-jnp.ones(100)[..., None],
)

In [None]:
# Rebuild the fluid-tank env and the DMPE experiment configuration (presumably
# mirroring the settings of the loaded runs — confirm against the stored params).
env_params = dict(
    batch_size=1,
    tau=5e-1,
    max_height=3,
    max_inflow=0.2,
    base_area=jnp.pi,  # circle of radius 1
    orifice_area=jnp.pi * 0.1**2,  # circle of radius 0.1
    c_d=0.6,
    g=9.81,
    env_solver=diffrax.Euler(),
)
env = excenvs.make(
    "FluidTank-v0",
    physical_constraints=dict(height=env_params["max_height"]),
    action_constraints=dict(inflow=env_params["max_inflow"]),
    static_params=dict(
        base_area=env_params["base_area"],
        orifice_area=env_params["orifice_area"],
        c_d=env_params["c_d"],
        g=env_params["g"],
    ),
    tau=env_params["tau"],
    solver=env_params["env_solver"],
)

alg_params = dict(
    bandwidth=0.025,
    n_prediction_steps=100,
    points_per_dim=50,
    action_lr=1e-1,
    n_opt_steps=10,
    rho_obs=1,
    rho_act=1,
    penalty_order=1,
    clip_action=True,
)

# Training starts once one prediction horizon of data is available and trains
# on sequences of that same length.
model_trainer_params = dict(
    start_learning=alg_params["n_prediction_steps"],
    training_batch_size=128,
    n_train_steps=1,
    sequence_length=alg_params["n_prediction_steps"],
    featurize=lambda x: x,  # identity: no feature transform
    model_lr=1e-4,
)
# `key` is filled in below once the PRNG keys are split.
model_params = dict(obs_dim=env.physical_state_dim, action_dim=env.action_dim, width_size=128, depth=3, key=None)

exp_params = dict(
    seed=None,  # set below from `seeds[idx]`
    n_timesteps=5_000,
    model_class=NeuralEulerODE,
    env_params=env_params,
    alg_params=alg_params,
    model_trainer_params=model_trainer_params,
    model_params=model_params,
)
# Seeds are 1..100, so entry `idx` corresponds to seed idx + 1.
seeds = list(np.arange(1, 101))

seed = seeds[idx]

exp_params["seed"] = int(seed)

# setup PRNG
# Split order matters: it fixes which sub-key each consumer receives.
key = jax.random.PRNGKey(seed=exp_params["seed"])
data_key, model_key, loader_key, expl_key, key = jax.random.split(key, 5)
data_rng = PRNGSequence(data_key)
exp_params["model_params"]["key"] = model_key


In [None]:
# Fresh model built from the configuration above (replaces the loaded model).
model = exp_params["model_class"](**exp_params["model_params"])

In [None]:
def plot_model_trajectory(model, obs, actions, tau, obs_labels, action_labels):
    """Roll ``model`` forward from ``obs`` under ``actions`` and plot the predicted sequence.

    Args:
        model: model accepted by ``simulate_ahead`` (a trained/untrained model,
            or an env wrapped as a model, e.g. ``ModelEnvWrapper(env)``).
        obs: initial observation the rollout starts from.
        actions: action sequence applied step by step during the rollout.
        tau: time step forwarded to both ``simulate_ahead`` and ``plot_sequence``.
        obs_labels: labels for the observation channels.
        action_labels: labels for the action channels.

    Returns:
        The ``(fig, axs)`` pair produced by ``plot_sequence``, so callers can
        adjust or save the figure.
    """
    pred_observations = exciting_exciting_systems.models.model_utils.simulate_ahead(model, obs, actions, tau)

    fig, axs = plot_sequence(
        observations=pred_observations,
        actions=actions,
        tau=tau,
        obs_labels=obs_labels,
        action_labels=action_labels,
    )
    return fig, axs

In [None]:
obs, state = env.reset()
obs = obs[0]
 
# Roll out the freshly initialized (untrained) model for 10 steps of constant
# action +1 and plot the predicted trajectory.
plot_model_trajectory(
    model,
    obs=obs,
    actions=+jnp.ones(10)[..., None],
    tau=env.tau,
    obs_labels=env.obs_description,
    action_labels=env.action_description,
)

In [None]:
obs, state = env.reset()
obs = obs[0]
 
# Same rollout, but with the true environment wrapped to act as a model, as a
# reference for the model's prediction.
plot_model_trajectory(
    exciting_exciting_systems.models.model_utils.ModelEnvWrapper(env),
    obs=obs,
    actions=jnp.ones(10)[..., None],
    tau=env.tau,
    obs_labels=env.obs_description,
    action_labels=env.action_description,
)

In [None]:
# Re-run DMPE with the experiment configuration assembled above.
# Initial guess for the proposed action sequence, built the same way as in the
# standalone DMPE run further down, so this cell also works on a fresh kernel.
proposed_actions = aprbs(
    exp_params["alg_params"]["n_prediction_steps"], env.batch_size, 1, 10, next(data_rng)
)[0]

observations, actions, model, density_estimate, losses, proposed_actions = excite_with_dmpe(
    env, exp_params, proposed_actions, loader_key, expl_key, plot_every=100
)

In [None]:
from copy import deepcopy
from exciting_exciting_systems.models.model_utils import simulate_ahead

In [None]:
from exciting_exciting_systems.excitation.excitation_utils import optimize_actions

In [None]:
# Recover the last non-NaN observation (and the action applied just before it)
# from the DMPE run, and rebuild an env state at that observation.
valid_obs_mask = ~jnp.isnan(observations)
init_obs = observations[valid_obs_mask][-1][None]
# Map the normalized observation back to a physical height. The factor 1.5
# presumably equals max_height / 2 with max_height = 3 — confirm.
# The state is built from `init_obs` so it matches the recovered observation.
init_state = env.State(physical_state=env.PhysicalState(height=(init_obs + 1) * 1.5), PRNGKey=0, optional=env.Optional(0))
action = actions[valid_obs_mask[:-1]][-1][None]

# Number of recorded time steps.
k = observations.shape[0]

# Work on copies so the manual debugging below does not touch the DMPE results.
prpsed_actions = deepcopy(proposed_actions)
dnsty_estimate = deepcopy(density_estimate)

In [None]:
# Shape of the copied proposed action sequence.
prpsed_actions.shape

In [None]:
# Start the manual debugging from the last non-NaN observation and the env
# state reconstructed from it (same height mapping as above).
obs = observations[~jnp.isnan(observations)][-1][None]
state = env.State(physical_state=env.PhysicalState(height=(obs + 1) * 1.5), PRNGKey=0, optional=env.Optional(0))

In [None]:
obs

In [None]:
test_key = jax.random.PRNGKey(seed=2)
test_rng = PRNGSequence(test_key)

# Exciter used for the manual single-step debugging below.
# argnums=(2) differentiates loss_function w.r.t. its third positional argument
# (presumably the proposed action sequence — confirm against loss_function).
# The target is a uniform density on the 2-D (obs, action) space [-1, 1]^2,
# evaluated on a 50 x 50 grid: 1 / (1 - (-1))**2 = 0.25 everywhere.
exciter = Exciter(
    loss_function=loss_function,
    grad_loss_function=jax.value_and_grad(loss_function, argnums=(2)),
    excitation_optimizer=optax.adabelief(1e-3),
    tau=tau,
    n_opt_steps=100,
    target_distribution=jnp.ones(shape=(50**2, 1)) * 1 / (1 - (-1))**2,
    rho_obs=1,
    rho_act=1,
    penalty_order=1,
    clip_action=True
)

In [None]:
# Debug: manually run a single step of the excitation loop. The cell optimizes
# the proposed action sequence, applies its first action to the env, shifts the
# sequence, and updates the density estimate. Re-run the cell to take another step.
# NOTE: reset on every execution, so it only ever holds the latest observation.
new_observations = []

prpsed_actions, loss = optimize_actions(
    exciter.loss_function,
    exciter.grad_loss_function,
    prpsed_actions,
    model,
    exciter.excitation_optimizer,
    obs,
    dnsty_estimate,
    exciter.n_opt_steps,
    exciter.tau,
    exciter.target_distribution,
    exciter.rho_obs,
    exciter.rho_act,
    exciter.penalty_order
)

# Apply only the first action of the optimized sequence, clipped to [-1, 1]
# when the exciter is configured to clip.
action = prpsed_actions[0, :]

action = jax.lax.cond(
    exciter.clip_action,
    jnp.clip,
    lambda action, min_val, max_val: action,
    action,
    -1,
    1,
)

# Receding horizon: shift the remaining actions one step forward and fill the
# freed last slot with a fresh uniform sample from [-1, 1]. A new key is drawn
# each time so the filler action differs between steps.
next_proposed_actions = prpsed_actions.at[:-1, :].set(prpsed_actions[1:, :])

new_proposed_action = jax.random.uniform(key=next(test_rng), minval=-1, maxval=1)
next_proposed_actions = next_proposed_actions.at[-1, :].set(new_proposed_action)

# Record the (observation, applied action) pair in the density estimate.
dnsty_estimate = update_density_estimate_single_observation(
    dnsty_estimate, jnp.concatenate([obs, action], axis=-1)
)

prpsed_actions = next_proposed_actions

obs, _, _, _, state = env.step(state, action, env.env_properties)

new_observations.append(obs)

obs

In [None]:
# Standalone DMPE run with a default FluidTank env (no custom constraints or
# static params) and a separate set of algorithm settings.
env_params = dict(batch_size=1, tau=5e-1, env_solver=diffrax.Euler())
env = excenvs.make(
    "FluidTank-v0",
    tau=env_params["tau"],
    solver=env_params["env_solver"]
)


# alg_params = dict(
#     bandwidth=0.1, n_prediction_steps=100, points_per_dim=50, action_lr=1e-1, n_opt_steps=10, rho_obs=1, rho_act=1
# )

# NOTE(review): unlike the earlier configuration, this dict has no
# `clip_action` entry — confirm excite_with_dmpe does not require it.
alg_params = dict(
    bandwidth=0.05, n_prediction_steps=100, points_per_dim=50, action_lr=1e-3, n_opt_steps=100, rho_obs=1, rho_act=1, penalty_order=1
)

model_trainer_params = dict(
    start_learning=alg_params["n_prediction_steps"],
    training_batch_size=128,
    n_train_steps=1,
    sequence_length=alg_params["n_prediction_steps"],
    featurize=lambda obs: obs,  # identity: no feature transform
    model_lr=1e-4,
)
model_params = dict(obs_dim=env.physical_state_dim, action_dim=env.action_dim, width_size=128, depth=3, key=None)

# `seed` is carried over from the earlier configuration cell (seeds[idx]).
exp_params = dict(
    seed=seed,
    n_timesteps=5_000,
    model_class=NeuralEulerODE,
    env_params=env_params,
    alg_params=alg_params,
    model_trainer_params=model_trainer_params,
)

# Split order matters: it fixes which sub-key each consumer receives.
key = jax.random.PRNGKey(seed=exp_params["seed"])
data_key, model_key, loader_key, expl_key, key = jax.random.split(key, 5)
data_rng = PRNGSequence(data_key)

model_params["key"] = model_key
exp_params["model_params"] = model_params

# initial guess
proposed_actions = aprbs(alg_params["n_prediction_steps"], env.batch_size, 1, 10, next(data_rng))[0]

# run excitation algorithm
dmpe_observations, dmpe_actions, model, density_estimate, losses, proposed_actions = excite_with_dmpe(
    env, exp_params, proposed_actions, loader_key, expl_key, plot_every=1000
)

In [None]:
# Time series of the observations and actions collected by the DMPE run.
fig, axs = plot_sequence(
    observations=dmpe_observations,
    actions=dmpe_actions,
    tau=tau,
    obs_labels=[r"$h$"],
    action_labels=[r"$q_{in}$"],
);
plt.show()

In [None]:
def plot_observations(observations, actions, tau, obs_labels, action_labels, fig=None, axs=None, dotted=False):
    """Plot observation time series and the observation x action scatter plane.

    Args:
        observations: 2-D array (n_steps, obs_dim).
        actions: 2-D array (n_steps - 1, action_dim). Action k is paired with
            observation k in the scatter plot.
        tau: sample time in seconds, used to build the time axis.
        obs_labels: one label per observation channel.
        action_labels: one label per action channel; only the first one is used.
        fig: optional existing figure to draw into.
        axs: optional sequence of at least two axes to draw into. If it is
            omitted, two axes are created (in ``fig`` when given, otherwise in
            a new 19 x 9.5 figure).
        dotted: draw the observation series with '.' markers instead of lines.

    Returns:
        ``(fig, axs)``.
    """
    # Honor caller-provided figure/axes instead of always creating new ones.
    if axs is None:
        if fig is None:
            fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(19, 9.5))
        else:
            axs = fig.subplots(nrows=1, ncols=2)
    elif fig is None:
        fig = axs[0].get_figure()

    t = jnp.linspace(0, observations.shape[0] - 1, observations.shape[0]) * tau

    for observation_idx in range(observations.shape[-1]):
        axs[0].plot(
            t,
            jnp.squeeze(observations[..., observation_idx]),
            "." if dotted else "-",
            markersize=1,
            label=obs_labels[observation_idx],
        )
    axs[0].title.set_text("observations, timeseries")
    axs[0].legend()
    axs[0].set_ylabel(r"$\bm{x}$")
    axs[0].set_xlabel("$t$ in seconds")

    # The final observation has no subsequent action, so it is excluded from the scatter.
    axs[1].scatter(jnp.squeeze(actions[..., 0]), jnp.squeeze(observations[:-1, 0]), s=1)
    axs[1].title.set_text("observation $\\times$ action plane")
    axs[1].set_ylabel(obs_labels[0])
    axs[1].set_xlabel(action_labels[0])

    for ax in axs:
        ax.grid(True)
    fig.tight_layout()
    return fig, axs

In [None]:
# Observation time series plus the observation x action coverage of the DMPE data.
fig, axs = plot_observations(
    observations=dmpe_observations,
    actions=dmpe_actions,
    tau=tau,
    obs_labels=[r"$h$"],
    action_labels=[r"$q_{in}$"],
);
plt.show()

In [None]:
# KDE of the collected (h, u) samples, shown as filled contours.
density_estimate = DensityEstimate.from_dataset(dmpe_observations, dmpe_actions, bandwidth=0.05)
fig, axs, cax = plot_2d_kde_as_contourf(density_estimate.p, density_estimate.x_g, observation_labels=["$h$", "$u$"])

In [None]:
from exciting_exciting_systems.related_work.algorithms import excite_with_sGOATS, excite_with_GOATS, excite_with_iGOATS

In [None]:
# NOTE: the current implementation jit-compiles the episode rollout separately
# for every input-sequence length it encounters, which is a bit sketchy, but it
# is still much faster than running the rollout uncompiled.

sgoats_observations, sgoats_actions = excite_with_sGOATS(
    n_amplitudes=200,
    n_amplitude_groups=5,
    reuse_observations=True,
    env=env,
    bounds_duration=(1, 50),
    population_size=50,
    n_generations=100,
    featurize=lambda obs: obs,
    rng=np.random.default_rng(seed=seed),
    verbose=True
)

In [None]:
# Presumably sGOATS returns as many actions as observations; the last action is
# dropped to match the (n, n - 1) obs/action layout used for DMPE.
fig, axs = plot_sequence(
    observations=sgoats_observations,
    actions=sgoats_actions[:-1, :],
    tau=tau,
    obs_labels=[r"$h$"],
    action_labels=[r"$u$"],
);
plt.show()

In [None]:
# KDE of the sGOATS (h, u) samples, same bandwidth as for DMPE.
sgoats_density_estimate = DensityEstimate.from_dataset(sgoats_observations, sgoats_actions[:-1, :], bandwidth=0.05)
fig, axs, cax = plot_2d_kde_as_contourf(sgoats_density_estimate.p, sgoats_density_estimate.x_g, observation_labels=["$h$", "$u$"])

## Metrics

In [None]:
from exciting_exciting_systems.utils.metrics import JSDLoss
from exciting_exciting_systems.related_work.np_reimpl.metrics import (
    MC_uniform_sampling_distribution_approximation, audze_eglais
)
from exciting_exciting_systems.related_work.excitation_utils import latin_hypercube_sampling

In [None]:
# Jensen-Shannon divergence between each method's KDE (on a 30 x 30 grid over
# the 2-D obs/action space) and a uniform target. Both distributions are
# normalized to sum to 1 before comparison.
dim = 2
points_per_dim = 30

target_distribution = jnp.ones(shape=(points_per_dim**dim, 1))[None]

dmpe_density_estimate = DensityEstimate.from_dataset(dmpe_observations, dmpe_actions, points_per_dim=points_per_dim, bandwidth=0.05)
sgoats_density_estimate = DensityEstimate.from_dataset(sgoats_observations, sgoats_actions[:-1, :], points_per_dim=points_per_dim, bandwidth=0.05)

dmpe_jsd_loss = JSDLoss(
    p=dmpe_density_estimate.p / jnp.sum(dmpe_density_estimate.p),
    q=target_distribution / jnp.sum(target_distribution),
)
print("dmpe jsd loss: ", dmpe_jsd_loss)

sgoats_jsd_loss = JSDLoss(
    p=sgoats_density_estimate.p / jnp.sum(sgoats_density_estimate.p),
    q=target_distribution / jnp.sum(target_distribution),
)
print("sgoats jsd loss: ", sgoats_jsd_loss)

In [None]:
# Monte-Carlo uniform-sampling approximation metric. Both datasets are scored
# against the same Latin-hypercube support points. The sampler is seeded with
# the experiment seed so the metric is reproducible, and the number of points
# follows `points_per_dim` so it matches the JSD grid resolution.
support_points = latin_hypercube_sampling(
    d=dim, n=points_per_dim**dim, rng=np.random.default_rng(seed=seed)
)

# DMPE: observation k is paired with action k; the final observation has no action.
dmpe_mcudsa_loss = MC_uniform_sampling_distribution_approximation(
    data_points=np.concatenate([dmpe_observations[:-1, :], dmpe_actions], axis=-1),
    support_points=support_points
)
print("dmpe mcudsa loss: ", dmpe_mcudsa_loss)

sgoats_mcudsa_loss = MC_uniform_sampling_distribution_approximation(
    data_points=np.concatenate([sgoats_observations, sgoats_actions], axis=-1),
    support_points=support_points
)
print("sgoats mcudsa loss: ", sgoats_mcudsa_loss)

In [None]:
# Audze-Eglais (space-filling) metric on the concatenated (observation, action)
# samples, using the same obs/action pairing as the MC-UDSA metric above.
dmpe_ae_loss = audze_eglais(np.concatenate([dmpe_observations[:-1, :], dmpe_actions], axis=-1))
print("dmpe ae loss: ", dmpe_ae_loss)

sgoats_ae_loss = audze_eglais(np.concatenate([sgoats_observations, sgoats_actions], axis=-1))
print("sgoats ae loss: ", sgoats_ae_loss)

In [None]:
# Deliberate stop for "Run All": the cells below are older scratch code (they
# appear to target a previous env API, e.g. `env.env_observation_space`).
raise RuntimeError(
    "Intentional stop: the cells below are legacy scratch code and are not meant to be run top-to-bottom."
)

# Rest:

In [None]:
# Legacy settings for the older excite_and_fit workflow below.
bandwidth = 0.1
n_prediction_steps = 50

dim = 2
points_per_dim = 50
n_grid_points=points_per_dim**dim

n_timesteps = 5_000 #15_000

In [None]:
# NOTE(review): uses the older env API (env_observation_space / action_space)
# and `state.astype` — presumably incompatible with the env built above.
obs, state = env.reset()
obs = obs.astype(jnp.float32)
state = state.astype(jnp.float32)

# Preallocated buffers: n_timesteps observations and n_timesteps - 1 actions.
observations = jnp.zeros((n_timesteps, env.env_observation_space.shape[-1]))
observations = observations.at[0].set(obs[0])
actions = jnp.zeros((n_timesteps-1, env.action_space.shape[-1]))

# Unlike the calls above, the batch dimension of the APRBS output is kept here.
proposed_actions = aprbs(n_prediction_steps, batch_size, 1, 10, next(data_rng))
# proposed_actions = (proposed_actions + 1) / 2  # remap from (-1, 1) to (0, 1)
# proposed_actions = (proposed_actions + 1) / 2  # remap from (-1, 1) to (0, 1)

In [None]:
# Legacy component setup. NOTE(review): `eesys` is presumably an alias for
# `exciting_exciting_systems` and must be imported for this cell to run. The
# Exciter signature here (argnums=(3), no loss_function argument) differs from
# the one used above — confirm against the current API.
exciter = Exciter(
    grad_loss_function=jax.grad(loss_function, argnums=(3)),
    excitation_optimizer=optax.adabelief(1e-1),
    tau=tau,
    target_distribution=jnp.ones(shape=(n_grid_points, 1)) * 1 / (env.env_observation_space.high - env.env_observation_space.low)**dim  # uniform density over the observation box; earlier note: "transposed pdf_vectorfield instead as a test"
)

model_trainer = ModelTrainer(
    start_learning=n_prediction_steps,
    training_batch_size=32,
    n_train_steps=2,
    sequence_length=n_prediction_steps,
    featurize=lambda obs: obs,
    model_optimizer=optax.adabelief(1e-4),
    tau=tau
)

# Empty density estimate (no observations yet) on a 2-D grid spanning the observation space.
density_estimate = DensityEstimate(
    p=jnp.zeros([batch_size, n_grid_points, 1]),
    x_g=eesys.utils.density_estimation.build_grid_2d(
        low=env.env_observation_space.low,
        high=env.env_observation_space.high,
        points_per_dim=points_per_dim
    ),
    bandwidth=jnp.array([bandwidth]),
    n_observations=jnp.array([0])
)

model = NeuralEulerODE(
    obs_dim=env.env_observation_space.shape[-1],
    action_dim=env.action_space.shape[-1],
    width_size=128,
    depth=3,
    key=model_key
)

# Optimizer state over the model's inexact (floating-point) array leaves only.
opt_state_model = model_trainer.model_optimizer.init(eqx.filter(model, eqx.is_inexact_array))

In [None]:
from exciting_exciting_systems.algorithms import excite_and_fit

In [None]:
# Profiling variant of the run below (writes a Perfetto trace to /tmp/jax-trace).
# with jax.profiler.trace("/tmp/jax-trace", create_perfetto_trace=True):
#     observations, actions, model, density_estimate = excite_and_fit(
#         n_timesteps=n_timesteps,
#         env=env,
#         model=model,
#         obs=obs,
#         state=state,
#         proposed_actions=proposed_actions,
#         exciter=exciter,
#         model_trainer=model_trainer,
#         density_estimate=density_estimate,
#         observations=observations,
#         actions=actions,
#         opt_state_model=opt_state_model,
#         loader_key=loader_key
#     )
#     observations.block_until_ready()

In [None]:
# Legacy joint excitation + model fitting loop over n_timesteps steps.
observations, actions, model, density_estimate = excite_and_fit(
    n_timesteps=n_timesteps,
    env=env,
    model=model,
    obs=obs,
    state=state,
    proposed_actions=proposed_actions,
    exciter=exciter,
    model_trainer=model_trainer,
    density_estimate=density_estimate,
    observations=observations,
    actions=actions,
    opt_state_model=opt_state_model,
    loader_key=loader_key
)

In [None]:
# Observation/action time series of the legacy run.
# plt.show() renders the figure; the previous plt.plot() only added an empty
# line plot to the current axes.
fig, axs = plot_sequence(
    observations,
    actions,
    tau=tau,
    obs_labels=[r"$h$"],
    action_labels=[r"$q_{in}$"]
);
plt.show()

In [None]:
# Compare the learned model's predictions with the recorded observations.
# plt.show() renders the figure (plt.plot() was a typo).
fig, axs = plot_model_performance(
    model=model,
    true_observations=observations,
    actions=actions,
    tau=tau,
    obs_labels=[r"$h$"],
    action_labels=[r"$q_{in}$"]
);
plt.show()

In [None]:
# KDE of the legacy run's samples as filled contours.
fig, axs, cax = eesys.evaluation.plotting_utils.plot_2d_kde_as_contourf(
    density_estimate.p, density_estimate.x_g, [r"$h$", r"$q_{in}$"]
)

In [None]:
jnp.max(density_estimate.p)

In [None]:
# Evaluate the model's right-hand side `model.func` on a 50 x 50 grid over
# [-1, 1]^2, with column 0 as h and column 1 as q_in.
grid = eesys.utils.density_estimation.build_grid_2d(-1, 1, 50)
df_dt = jax.vmap(model.func)(grid[:, 0], grid[:, 1])

In [None]:
df_dt.shape

In [None]:
# Quiver plot of the model's derivative on the grid: every arrow points along
# the h-axis with length df_dt (zero component along q_in).
fig, axs = plt.subplots(figsize=(12,12))
axs.quiver(grid[:, 0], grid[:, 1], df_dt, np.zeros(df_dt.shape))
axs.axis('equal')
axs.set_xlabel(r"$h$")
axs.set_ylabel(r"$q_{in}$")
# plt.show() is used consistently with the rest of the notebook (fig.show()
# is meant for GUI backends).
plt.show()

In [None]:
# Inverted magnitude of the vector field: largest where |df_dt| is smallest.
fig, axs, cax = eesys.evaluation.plotting_utils.plot_2d_kde_as_contourf(
   jnp.max(jnp.abs(df_dt)) - jnp.abs(df_dt), grid, [r"$h$", r"$q_{in}$"]
)

fig, axs = eesys.evaluation.plotting_utils.plot_2d_kde_as_surface(
    jnp.max(jnp.abs(df_dt)) - jnp.abs(df_dt), grid, [r"$h$", r"$q_{in}$"]
)

In [None]:
# The legacy run's KDE, as contour and surface plots, for comparison.
fig, axs, cax = eesys.evaluation.plotting_utils.plot_2d_kde_as_contourf(
    density_estimate.p, density_estimate.x_g, [r"$h$", r"$q_{in}$"]
)

fig, axs = eesys.evaluation.plotting_utils.plot_2d_kde_as_surface(
    density_estimate.p, density_estimate.x_g, [r"$h$", r"$q_{in}$"]
)

- How can this vector field be compared to the KDE?
- The target distribution could be adapted based on this vector field.
- The best approach I see is to treat the values as a histogram and normalize them so that the histogram volume equals $1$.

In [None]:
exciter.target_distribution

In [None]:
# Unnormalized candidate target: largest where |df_dt| is smallest.
unnormalized_values = jnp.max(jnp.abs(df_dt)) - jnp.abs(df_dt)

In [None]:
# Area represented by each grid point for the histogram-style normalization.
# The grid comes from build_grid_2d(-1, 1, 50), so it covers [-1, 1]^2 with
# area (1 - (-1))**2 = 4, not 1. This matches the 1 / (high - low)**dim target
# densities used elsewhere. The number of cells is taken from the grid itself
# rather than a hardcoded 2500.
grid_low, grid_high = -1, 1
full_area = (grid_high - grid_low) ** 2
area_of_each_gridpoint = full_area / grid.shape[0]

In [None]:
# Riemann-sum volume of the unnormalized histogram.
normalization_factor = area_of_each_gridpoint * jnp.sum(unnormalized_values)

In [None]:
# Rescale so the histogram volume equals 1.
pdf_vectorfield = 1 / normalization_factor * unnormalized_values

In [None]:
fig, axs, cax = eesys.evaluation.plotting_utils.plot_2d_kde_as_contourf(
    pdf_vectorfield, grid, [r"$h$", r"$q_{in}$"]
)

fig, axs = eesys.evaluation.plotting_utils.plot_2d_kde_as_surface(
    pdf_vectorfield, grid, [r"$h$", r"$q_{in}$"]
)

## Debugging

In [None]:
# Add the point (0.2, 0.2) to every batch entry of the density estimate. Only
# the first DensityEstimate field (p) is mapped over the batch axis; x_g,
# bandwidth and n_observations are shared.
test_density_estimate = jax.vmap(
    update_density_estimate_single_observation,
    in_axes=(DensityEstimate(0, None, None, None), None),
    out_axes=DensityEstimate(0, None, None, None)
)(density_estimate, jnp.stack([0.2, 0.2]))

In [None]:
# Add the same point 400 more times to check how the estimate concentrates.
for i in range(400):
    test_density_estimate = jax.vmap(
        update_density_estimate_single_observation,
        in_axes=(DensityEstimate(0, None, None, None), None),
        out_axes=DensityEstimate(0, None, None, None)
    )(test_density_estimate, jnp.stack([0.2, 0.2]))

In [None]:
test_density_estimate.x_g.shape

In [None]:
fig, axs, cax = eesys.evaluation.plotting_utils.plot_2d_kde_as_contourf(
    test_density_estimate.p, test_density_estimate.x_g, [r"$h$", r"$q_{in}$"]
)

In [None]:
# Only a cell's last expression is displayed, so show both shapes together.
exciter.target_distribution.shape, density_estimate.p.shape

In [None]:
# JSD between the KDE and the target distribution. Both are normalized to sum
# to 1 first, as in the metrics section; the raw KDE values and the pdf-valued
# target are not probability mass vectors.
JSDLoss(
    p=density_estimate.p[0] / jnp.sum(density_estimate.p[0]),
    q=exciter.target_distribution / jnp.sum(exciter.target_distribution),
)

In [None]:
# JSD between the synthetic test KDE and the target distribution, both
# normalized to sum to 1 (consistent with the metrics section).
JSDLoss(
    p=test_density_estimate.p[0] / jnp.sum(test_density_estimate.p[0]),
    q=exciter.target_distribution / jnp.sum(exciter.target_distribution),
)

In [None]:
# Peak KDE value of the first batch entry.
jnp.max(density_estimate.p[0])

In [None]:
# Surface view of the legacy run's KDE.
fig, axs = eesys.evaluation.plotting_utils.plot_2d_kde_as_surface(
    density_estimate.p, density_estimate.x_g, [r"$h$", r"$q_{in}$"]
)