In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
# plt.rcParams['text.usetex'] = True
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs

import exciting_exciting_systems as eesys
from exciting_exciting_systems.models import NeuralEulerODEPendulum
from exciting_exciting_systems.models.model_utils import simulate_ahead_with_env
from exciting_exciting_systems.models.model_training import ModelTrainer
from exciting_exciting_systems.excitation import loss_function, Exciter

from exciting_exciting_systems.utils.density_estimation import (
    update_density_estimate_single_observation, update_density_estimate_multiple_observations, DensityEstimate
)
from exciting_exciting_systems.utils.signals import aprbs
from exciting_exciting_systems.evaluation.plotting_utils import (
    plot_sequence, append_predictions_to_sequence_plot, plot_sequence_and_prediction, plot_model_performance
)

---

In [None]:
from exciting_exciting_systems.utils.metrics import JSDLoss

In [None]:
# Evaluation grid and KDE settings for the 2D density-estimation experiment.
dim = 2
n_grid_points = 2500
# The grid is square, with `points_per_dim` points along each of the `dim` axes.
# Guard against a non-square count, which `int(np.sqrt(...))` would silently truncate.
points_per_dim = round(n_grid_points ** (1 / dim))
assert points_per_dim**dim == n_grid_points, "n_grid_points must be a perfect power of dim"

# NOTE(review): this bandwidth is far smaller than the grid spacing (~2 / points_per_dim),
# so each observation only affects (almost) a single grid point — confirm this is intended.
bandwidth = 0.001

# Uniform target density on [-1, 1]^dim: constant value 1 / volume = 1 / (1 - (-1))**dim.
target = jnp.ones(shape=(n_grid_points, 1)) * 1 / (1 - (-1))**dim

In [None]:
# Empty density estimate on a grid built by `build_grid_2d` over [-1, 1],
# with `points_per_dim` points per axis.
# NOTE(review): the leading axis of size 1 in `p` / `bandwidth` / `n_observations`
# looks like a batch axis — confirm.
density_estimate = DensityEstimate(
    p=jnp.zeros([1, n_grid_points, 1]),
    x_g=eesys.utils.density_estimation.build_grid_2d(
        low=-1,
        high=1,
        points_per_dim=points_per_dim
    ),
    bandwidth=jnp.array([bandwidth]),
    n_observations=jnp.array([0])
)

# Add a single observation at the corner (-1, -1) of the state space.
density_estimate_1 = update_density_estimate_single_observation(
    density_estimate, -jnp.ones((1, dim))
)

# Grid spacing along the first coordinate and the resulting area per grid point.
# Assumes consecutive rows of `x_g` differ in the first coordinate — TODO confirm
# against `build_grid_2d`. The assertion below fails loudly instead of silently showing 0.
delta_x_g = jnp.abs(density_estimate.x_g[0, 0] - density_estimate.x_g[1, 0])
assert float(delta_x_g) > 0, "x_g rows 0 and 1 share their first coordinate; spacing cannot be read this way"
delta_x_g**dim

In [None]:
from copy import deepcopy

In [None]:
# Estimate from a single observation at the opposite corner (1, 1), starting from the empty estimate.
# Uses `dim` instead of a hard-coded 2 so the observation shape follows the configured dimension.
density_estimate_2 = update_density_estimate_single_observation(
    density_estimate, jnp.ones((1, dim))
)

In [None]:
# Contour plot of the estimate after the single observation at (-1, -1).
state_labels = [r"$\theta$", r"$\omega$"]
fig, axs, cax = eesys.evaluation.plotting_utils.plot_2d_kde_as_contourf(
    density_estimate_1.p,
    density_estimate.x_g,
    state_labels,
)

In [None]:
# Contour plot of the estimate after the single observation at (1, 1).
state_labels = [r"$\theta$", r"$\omega$"]
fig, axs, cax = eesys.evaluation.plotting_utils.plot_2d_kde_as_contourf(
    density_estimate_2.p,
    density_estimate.x_g,
    state_labels,
)

In [None]:
def normalized_JSDLoss(estimate, target):
    """Apply ``JSDLoss`` to two arrays after scaling each so that it sums to one.

    Each input is divided by its own total, so only the relative shape of the two
    arrays matters, not their absolute scale. There is no guard against a zero
    total: division by zero then yields inf/NaN values.

    Args:
        estimate: Values of the estimated density on the evaluation grid
            (presumably non-negative — not checked here).
        target: Values of the target density on the same grid
            (presumably non-negative — not checked here).

    Returns:
        Whatever ``JSDLoss`` returns for ``p`` = normalized estimate and
        ``q`` = normalized target (presumably a Jensen-Shannon divergence, going by the name).
    """
    estimate_total = jnp.sum(estimate)
    target_total = jnp.sum(target)
    normalized_estimate = estimate / estimate_total
    normalized_target = target / target_total
    return JSDLoss(p=normalized_estimate, q=normalized_target)

In [None]:
# Divergence between the single-observation estimate (first leading-axis entry) and the uniform target.
estimate_1_values = density_estimate_1.p[0]
normalized_JSDLoss(estimate=estimate_1_values, target=target)

- Tested shifting the grid points (to the cell centres) so that every point covers the same area within the constraints.

I do not see major benefits in this approach; I guess it would mainly let us put more weight on the edge values.

In [None]:
# Cell-centred 1D grid: split [-1, 1] into `points_per_dim` equal cells and keep each cell's
# midpoint, so every point represents the same width inside the constraints.
cell_edges = jnp.linspace(-1, 1, points_per_dim + 1)
space_between_elements = cell_edges[1] - cell_edges[0]
points = cell_edges[:-1] + space_between_elements / 2

points