In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
from functools import partial

from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs
from metrics import JSDLoss, KLDLoss, kullback_leibler_divergence, jensen_shannon_divergence
from density_estimation import gaussian_kernel, update_kde_grid

---

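### Dummy example: Lotka-Volterra predator-prey system

Simulate a batch of Lotka-Volterra trajectories with forward Euler and incrementally build a kernel density estimate (KDE) of the visited states on a fixed 2-D grid.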

In [None]:
@jax.jit
def lotka_volterra_ODE(t, x, *, alpha=0.1, beta=0.02, delta=0.01, gamma=0.1):
    """Right-hand side of the Lotka-Volterra predator-prey ODE.

    The system is autonomous, so ``t`` is accepted only to match the common
    ``f(t, x)`` ODE signature and is ignored.

    Args:
        t: Time (unused).
        x: State array whose last axis holds ``[prey, predator]``. Any leading
            (batch) dimensions are supported, e.g. ``(batch_size, 2)`` or ``(2,)``.
        alpha: Reproduction rate of prey.
        beta: Predation rate.
        delta: Reproduction rate of predator.
        gamma: Death rate of predator.

    Returns:
        Array with the same shape as ``x`` holding ``[dprey/dt, dpredator/dt]``
        on the last axis.
    """
    # Index the last axis so batched and unbatched states work alike.
    prey, predator = x[..., 0], x[..., 1]
    dprey_dt = alpha * prey - beta * prey * predator
    dpredator_dt = delta * prey * predator - gamma * predator
    return jnp.stack([dprey_dt, dpredator_dt], axis=-1)

@partial(jax.jit, static_argnums=(0,))
def forward_euler(ode, state, tau, t=None):
    """Advance ``state`` by one explicit (forward) Euler step.

    Computes ``state + tau * ode(t, state)``.

    Args:
        ode: Callable ``ode(t, state)`` returning the time derivative of
            ``state``. It is a static argument of ``jax.jit``, so it must be
            hashable.
        state: Current state array.
        tau: Step size.
        t: Current time forwarded to ``ode``. Defaults to ``None``, which is
            enough for autonomous systems that ignore ``t``.

    Returns:
        The state after one step, with the same shape as ``state``.
    """
    derivative = ode(t, state)
    return state + tau * derivative

In [None]:
SEED = 42  # fixed seed so the random initial states are reproducible
rng = PRNGSequence(SEED)

In [None]:
# Evaluation grid for the density estimate: a regular n_grid_points x n_grid_points
# lattice over [grid_min, grid_max]^2, flattened to shape (n_grid_points**2, 2).
grid_min, grid_max, n_grid_points = 0, 15, 100
x1 = jnp.linspace(grid_min, grid_max, n_grid_points)
x2 = jnp.linspace(grid_min, grid_max, n_grid_points)
x = jnp.stack(jnp.meshgrid(x1, x2), axis=-1).reshape(-1, 2)

In [None]:
episode_len = 10_001
batch_size = 100
n_features = 2
bandwidth = 0.15
tau = 5e-2
t0, tf = 0, (episode_len - 1) * tau
# linspace guarantees exactly `episode_len` samples; a float-step
# `arange(t0, tf + tau, tau)` can gain or lose an element through rounding and
# then no longer line up with `states` when plotting.
t = jnp.linspace(t0, tf, episode_len)

states = []
# One density estimate per batch element, evaluated on every point of grid `x`.
p_est = jnp.zeros([batch_size, x.shape[0], 1])

# Random initial populations, uniform in [0, 10) for both species.
state = 10 * jax.random.uniform(next(rng), (batch_size, n_features))

for i in tqdm(range(episode_len), desc="simulating"):
    states.append(state)
    # Fold the current batch of states into the running KDE on grid `x`.
    # NOTE(review): `n_measurements=i` presumably counts the samples already
    # folded in - confirm against update_kde_grid.
    p_est = update_kde_grid(p_est, x[None, ...], measurement=state[:, None, :], n_measurements=i, bandwidth=bandwidth)
    state = forward_euler(ode=lotka_volterra_ODE, state=state, tau=tau)

# Stack along axis 1 so the result is (batch_size, episode_len, n_features).
states = jnp.stack(states, axis=1)

In [None]:
# Trajectories are stored as (batch_size, episode_len, n_features).
assert states.shape == (batch_size, episode_len, n_features), states.shape
states.shape

In [None]:
# Initialized as (batch_size, x.shape[0], 1); the final shape depends on update_kde_grid.
jnp.shape(p_est)

In [None]:
fig, ax = plt.subplots(figsize=(12, 12))

# `x` is a flattened square grid; recover its side length and fail loudly if it
# is not square (e.g. because another cell rebound the name `x`).
grid_len_per_dim = int(round(np.sqrt(x.shape[0])))
assert grid_len_per_dim ** 2 == x.shape[0], f"x is not a flattened square grid: {x.shape}"
x_plot = x.reshape((grid_len_per_dim, grid_len_per_dim, 2))

# Density estimate of the first trajectory in the batch.
contour = ax.contourf(
    x_plot[..., 0],
    x_plot[..., 1],
    p_est[0, ...].reshape(x_plot.shape[:-1]),
    antialiased=False,
    levels=100,
    alpha=0.9,
)
fig.colorbar(contour, ax=ax, label="estimated density")
ax.set(xlabel="prey", ylabel="predator", title="KDE of visited states (trajectory 0)")
plt.show()

In [None]:
# Populations of the first trajectory over time (feature 0 = prey, 1 = predator).
fig, ax = plt.subplots()
ax.plot(t, states[0, :, 0], label="prey")
ax.plot(t, states[0, :, 1], label="predator")
ax.set(xlabel="t", ylabel="population", title="Trajectory 0 over time")
ax.legend()
ax.grid()
plt.show()

In [None]:
# Phase portrait: all trajectories in one figure instead of one figure per batch
# element (batch_size separate figures would flood the notebook output).
fig, ax = plt.subplots(figsize=(8, 8))
for idx in range(states.shape[0]):
    ax.plot(states[idx, :, 0], states[idx, :, 1], linewidth=0.8, alpha=0.6)
# Mark where each trajectory starts.
ax.scatter(states[:, 0, 0], states[:, 0, 1], s=10, color="black", zorder=3, label="initial state")
ax.set(xlabel="prey", ylabel="predator", title="Phase portrait of all trajectories")
ax.legend()
ax.grid()
plt.show()

---

In [None]:
# Sanity check: the divergence of an input from itself should be ~0 if the
# implementation is correct.
# NOTE(review): entries sum to 10 along axis 1, so this is not a normalized
# distribution - fine for a p == q check.
p_const = jnp.full([10, 100, 1], 0.1)
loss = kullback_leibler_divergence(p=p_const, q=p_const)
loss

In [None]:
# Visualize the kernel shape for several bandwidths. Dedicated names keep this
# cell from overwriting the evaluation grid `x` and the simulation `bandwidth`,
# which earlier cells (e.g. the contour plot) rely on when re-run.
x_kernel = jnp.linspace(-1, 1, 100)[:, None]
# Integer-based range avoids float-step arange rounding: 0.05, 0.10, ..., 0.35.
kernel_bandwidths = 0.05 * jnp.arange(1, 8)

fig, ax = plt.subplots()
for bw in kernel_bandwidths:
    ax.plot(x_kernel, gaussian_kernel(x_kernel, bandwidth=bw), label=f"{float(bw):.2f}")

ax.set(xlabel="x", ylabel="kernel value", title="Gaussian kernel for different bandwidths")
ax.legend(title="bandwidth")
ax.grid()
plt.show()