In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
import os
# Must run before jax is imported so XLA picks these settings up.
os.environ.update(
    CUDA_VISIBLE_DEVICES='0',  # only expose the first GPU to this process
    XLA_PYTHON_CLIENT_PREALLOCATE="false",  # allocate GPU memory on demand instead of up front
)

import pathlib
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import shutil

# LaTeX text rendering needs a TeX installation; fall back to matplotlib's built-in
# mathtext when no `latex` executable is on PATH so Restart & Run All still renders plots.
mpl.rcParams['text.usetex'] = shutil.which("latex") is not None
mpl.rcParams['text.latex.preamble'] = r"\usepackage{bm}"  # only used when usetex is on
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp
# jax.config.update("jax_enable_x64", True)
gpus = jax.devices()
# NOTE(review): jax.devices() lists the default backend's devices; if no GPU is visible this
# is presumably CPU-only, in which case the name `gpus` is misleading.
default_device = gpus[0]
jax.config.update("jax_default_device", default_device)

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
from exciting_exciting_systems.utils.density_estimation import (
    select_bandwidth)

In [None]:
# Library bandwidth heuristic, for comparison with the hand-rolled `calc_bw` further down.
select_bandwidth(dim=1, delta_x=2, n_g=50, percentage=0.3)

In [None]:
@jax.jit
def gaussian_kernel(x: jnp.ndarray, bandwidth: float) -> jnp.ndarray:
    """Evaluate the normalized Gaussian RBF kernel (centered at the origin) at ``x``.

    The last axis of ``x`` is the data dimension ``d`` and is reduced; all leading axes
    are kept, so arbitrary batch shapes are handled by broadcasting.

    Args:
        x: Array of shape ``(..., d)``.
        bandwidth: Kernel bandwidth ``h`` (standard deviation of the Gaussian). May be an
            array that broadcasts against the reduced shape ``(...)``.

    Returns:
        Array of shape ``(...)`` holding
        ``exp(-||x||^2 / (2 h^2)) / ((2 pi)^(d / 2) * h^d)``.
    """
    data_dim = x.shape[-1]
    factor = bandwidth**data_dim * jnp.power(2 * jnp.pi, data_dim / 2)
    # Sum of squares instead of ``linalg.norm(x) ** 2``: identical value, but the norm's
    # sqrt has an undefined derivative at 0, which makes jax.grad return NaN at x == 0.
    squared_norm = jnp.sum(x**2, axis=-1)
    return jnp.exp(-squared_norm / (2 * bandwidth**2)) / factor

In [None]:
# Ten 1-D grid points in [-1, 1], shaped (10, 1) so the last axis is the data dimension.
x_g = jnp.linspace(-1, 1, 10).reshape(-1, 1)

# Fine (1000, 1) evaluation grid and the bandwidth-1 kernel evaluated on it.
x = jnp.linspace(-1.5, 1.5, 1000).reshape(-1, 1)
y = gaussian_kernel(x, bandwidth=1)

In [None]:
# Bandwidth h(r) at which the Gaussian kernel has decayed to a fraction `a` of its peak at
# distance r:  exp(-r^2 / (2 h^2)) = a  =>  h = r / sqrt(-2 ln a).
# Uses its own name so the kernel evaluation grid `x` (plotted again further down together
# with `y`) is not overwritten.
r_dist = jnp.linspace(0, .2, 1000)
for a in np.arange(0.05, 0.5, 0.05):
    plt.plot(r_dist, jnp.sqrt(-1 * jnp.abs(r_dist)**2 / (2 * jnp.log(a))), label=f"$a = {a:.2f}$")

plt.xlabel("distance $r$")
plt.ylabel("bandwidth $h$")
plt.legend()
plt.show()

In [None]:
def calc_bw(delta_x, d, n_g, a):
    """Gaussian-kernel bandwidth that decays to a fraction ``a`` of its peak at a given distance.

    The distance is ``r = delta_x * sqrt(d) / n_g``; solving ``exp(-r^2 / (2 h^2)) = a``
    for ``h`` gives ``h = r / sqrt(-2 ln a)``. ``a`` should lie in (0, 1); otherwise the
    square root argument is non-positive and the result is not a finite bandwidth.
    """
    numerator = delta_x * jnp.sqrt(d)
    denominator = n_g * jnp.sqrt(-2 * jnp.log(a))
    return numerator / denominator

In [None]:
# Bandwidth at which the kernel drops to a = 0.3 of its peak at distance
# delta_x * sqrt(d) / n_g = 2 * sqrt(5) / 20 = sqrt(5) / 10.
h = calc_bw(
    delta_x=2,
    d=5,
    n_g=20,
    a=0.3,
)
h

In [None]:
# Bandwidth at which the kernel drops to a = 0.3 of its peak at distance
# delta_x * sqrt(d) / n_g = 2 * sqrt(5) / 10 = sqrt(5) / 5.
h = calc_bw(
    delta_x=2,
    d=5,
    n_g=10,
    a=0.3,
)
h

In [None]:
calc_bw(a=0.3, n_g=10, d=5, delta_x=2)  # same arguments as above, keyword order swapped

In [None]:
# Kernel height relative to its peak, exp(-r^2 / (2 h^2)), at r = 0.2 * sqrt(5) = sqrt(5) / 5.
# Written inline because the helper `test` is only defined further down, so calling it here
# raises NameError on Restart & Run All. If h = calc_bw(delta_x=2, d=5, n_g=10, a=0.3),
# this evaluates to 0.3 by construction.
dist = jnp.linalg.norm(jnp.array([0.2])[:, None] * jnp.sqrt(5), axis=-1)
jnp.exp(-dist ** 2 / (2 * h**2))

In [None]:
# Kernel height relative to its peak, exp(-r^2 / (2 h^2)), at r = 0.2.
# Written inline because the helper `test` is only defined further down (NameError otherwise).
dist = jnp.linalg.norm(jnp.array([0.2])[:, None], axis=-1)
jnp.exp(-dist ** 2 / (2 * h**2))

In [None]:
np.sqrt(2) * 0.025  # diagonal of a 2-D square with side 0.025

In [None]:
# Bandwidth-1 kernel `y` over its evaluation grid `x`, with the 1-D grid points `x_g` on the axis.
# NOTE(review): relies on `x`, `y`, `x_g` still holding the values from the kernel-evaluation
# cell; any intermediate reassignment of `x` silently changes the x-axis of this plot.
plt.plot(x, y, label="kernel")
plt.plot(x_g, jnp.zeros(x_g.shape), "r.", label="grid points")
plt.xlabel("$x$")
plt.ylabel("kernel value")
plt.legend()
plt.show()

In [None]:
from exciting_exciting_systems.utils.density_estimation import build_grid

In [None]:
# 2-D grid (presumably spanning [-1, 1] per axis) with 10 points per dimension.
# Note: this rebinds `x_g`, replacing the 1-D grid used earlier.
x_g = build_grid(low=-1, high=1, dim=2, points_per_dim=10)

In [None]:
plt.scatter(x=x_g[:, 0], y=x_g[:, 1])  # first grid coordinate vs. second

In [None]:
np.sqrt(sum([0.04 ** 2] * 4))  # Euclidean length of a 4-D vector with all components 0.04

In [None]:
np.sqrt(5) * (2 / 20)  # distance delta_x * sqrt(d) / n_g from calc_bw with delta_x=2, d=5, n_g=20

In [None]:
# Four 1-D points, one per row (shape (4, 1)).
gaussian_kernel(x=jnp.array([[0], [0.5], [1], [-1]]), bandwidth=0.05)

In [None]:
# Four 2-D points (rows after the transpose). Multiplying by 2*pi*h^2 cancels the 2-D
# normalization (2 pi)^(d/2) h^d, giving the height relative to the kernel's peak.
points_2d = jnp.array([[0.04, 0.5, 1, -1], [0, 0.5, 1, -1]]).T
gaussian_kernel(x=points_2d, bandwidth=0.05) * 2 * jnp.pi * 0.05**2

In [None]:
def test(x, bandwidth):
    return jnp.exp(-jnp.linalg.norm(x, axis=-1) ** 2 / (2 * bandwidth**2))

In [None]:
# Four 2-D points, one per row.
test(x=jnp.array([[1, 0], [0.05, 0.0], [0.05 / 2, 0.05 / 2], [-1, -1]]), bandwidth=0.05)

In [None]:
# Four 1-D points, one per row (shape (4, 1)).
test(x=jnp.array([[0.025], [0.04], [1], [-1]]), bandwidth=0.025)

In [None]:
# Relative kernel height at r = 2 / 20 * sqrt(5) as a function of the bandwidth.
bandwidths = jnp.arange(0.025, 0.5, 0.025)
r_point = jnp.array([2/20 * jnp.sqrt(5)])[:, None]
plt.plot(bandwidths, test(x=r_point, bandwidth=bandwidths))
plt.xlabel("bandwidth $h$")
plt.ylabel(r"$\exp(-r^2 / (2 h^2))$")
plt.show()

In [None]:
# 0.14 is close to calc_bw(delta_x=2, d=5, n_g=20, a=0.3) ~= 0.144, so this should be near 0.3.
r_point = jnp.array([[2/20 * jnp.sqrt(5)]])
test(x=r_point, bandwidth=0.14)

In [None]:
# One 5-D point with every component 0.05 (shape (1, 5)).
test(x=jnp.array([[0.05] * 5]), bandwidth=0.1)

In [None]:
np.linalg.norm(np.full(5, 0.05))  # length of the 5-D point probed above (= 0.05 * sqrt(5))