In [None]:
# Load the autoreload extension and set it to mode 2 so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# (Re)load the line_profiler extension, which provides the %lprun magic.
%reload_ext line_profiler

In [None]:
from functools import partial

from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs
from metrics import JSDLoss, KLDLoss, kullback_leibler_divergence, jensen_shannon_divergence
from density_estimation import gaussian_kernel, update_kde_grid

---

In [None]:
def build_grid(space, points_per_dim):
    """Return a flattened, regular 2-D grid spanning ``space``.

    Both grid axes use the same ``points_per_dim`` evenly spaced values
    between ``space.low`` and ``space.high`` (endpoints included).

    Args:
        space: Object exposing ``low`` and ``high`` bounds. Assumes these are
            scalars shared by both dimensions -- TODO confirm against the
            environment's observation space.
        points_per_dim: Number of grid points along each of the two axes.

    Returns:
        Array of shape ``(points_per_dim ** 2, 2)``; each row is one grid
        point, with the first coordinate varying fastest.
    """
    # The two axes are identical, so a single linspace serves both.
    axis_values = jnp.linspace(space.low, space.high, points_per_dim)
    mesh = jnp.meshgrid(axis_values, axis_values)
    # Pair up the coordinates along a trailing axis, then flatten to rows.
    return jnp.stack(mesh, axis=-1).reshape(-1, 2)

In [None]:
# Seeded key sequence (seed 0); consumed via next(rng) when sampling actions.
rng = PRNGSequence(0)

In [None]:
# Single-instance pendulum environment (batch_size=1).
# NOTE(review): tau is presumably the simulation step size -- confirm.
env = excenvs.make(env_id='Pendulum-v0', batch_size=1, tau=1e-3)

# Evaluation grid for the density estimate: 100 points per dimension,
# flattened to one row per grid point.
x = build_grid(env.env_observation_space, points_per_dim=100)

# The estimate starts at zero for every grid point; the leading axis of
# size 1 lines up with the x[None, ...] passed to update_kde_grid.
p_est = jnp.zeros((1, x.shape[0], 1))

# Bandwidth handed to update_kde_grid during the rollout.
bandwidth = 0.015

In [None]:
n_steps = 999

# Start from an empty density estimate every time this cell runs. The loop
# counter `n` (passed as n_measurements) and `observations` restart here, so
# the estimate has to restart with them; otherwise re-running the cell would
# fold new samples into a stale estimate with a mismatched measurement count.
# Same shape as the initial estimate created alongside the grid `x`.
p_est = jnp.zeros([1, x.shape[0], 1])

observations = []

obs, state = env.reset()
observations.append(obs)

for n in range(n_steps):
    # Random action drawn with a fresh key from the PRNG sequence.
    action = env.action_space.sample(next(rng))
    # Fold the current observation into the grid estimate; `n` is the number
    # of measurements incorporated so far. The observation gets an extra
    # middle axis, presumably to broadcast against the grid-point axis of
    # x[None, ...] -- confirm against update_kde_grid.
    p_est = update_kde_grid(p_est, x[None, ...], measurement=obs[:, None, :], n_measurements=n, bandwidth=bandwidth)

    obs, reward, terminated, truncated, state = env.step(action, state)
    # NOTE: the observation appended in the final iteration is recorded in
    # `observations` but never folded into p_est.
    observations.append(obs)

In [None]:
# Combine the recorded observations into one array and drop singleton axes.
observations = jnp.array(observations).squeeze()

In [None]:
# Quick shape check of the stacked observations
# (presumably (n_steps + 1, 2) with batch_size=1 -- confirm).
observations.shape

In [None]:
# Both observation components over the course of the rollout.
plt.plot(observations[:, 0], label="theta")
plt.plot(observations[:, 1], label="omega")
# Label the axes so the figure can be read on its own; the x axis is the
# index into the recorded observation sequence.
plt.xlabel("step")
plt.ylabel("observation value")
plt.legend()
plt.grid()
plt.show()

In [None]:
# Phase-plane view: every recorded (theta, omega) observation as a red dot.
plt.plot(observations[:, 0], observations[:, 1], 'r.')
plt.xlabel("theta")
plt.ylabel("omega")
plt.grid()

In [None]:
fig, ax = plt.subplots(
    figsize=(8, 8)
)

# `x` is the flattened square grid; recover the side length and restore the
# 2-D mesh layout that contourf expects.
grid_len_per_dim = int(np.sqrt(x.shape[0]))
x_plot = x.reshape((grid_len_per_dim, grid_len_per_dim, 2))

# Filled contour of the estimate for the single batch entry (index 0).
cax = ax.contourf(
    x_plot[..., 0],
    x_plot[..., 1],
    p_est[0, ...].reshape(x_plot.shape[:-1]),
    antialiased=False,
    levels=100,
    alpha=0.9,
    cmap=plt.cm.coolwarm
)
# Grid coordinates follow the observation layout (column 0 = theta,
# column 1 = omega), matching the labels of the phase-plane scatter plot.
ax.set_xlabel("theta")
ax.set_ylabel("omega")
# Trailing semicolon keeps the Colorbar repr out of the cell output.
fig.colorbar(cax, label="estimated density p_est");