In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
# plt.rcParams['text.usetex'] = True
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs

import exciting_exciting_systems as eesys
from exciting_exciting_systems.models import NeuralEulerODEPendulum
from exciting_exciting_systems.models.model_utils import simulate_ahead, simulate_ahead_with_env
from exciting_exciting_systems.optimization import loss_function, optimize, soft_penalty
from exciting_exciting_systems.models.model_training import make_step, dataloader, load_single_batch

from exciting_exciting_systems.utils.density_estimation import (
    DensityEstimate, update_density_estimate, update_density_estimate_multiple_observations
)
from exciting_exciting_systems.utils.metrics import JSDLoss
from exciting_exciting_systems.utils.signals import generate_constant_action, aprbs
from exciting_exciting_systems.evaluation.plotting_utils import (
    plot_sequence, append_predictions_to_sequence_plot, plot_sequence_and_prediction, plot_model_performance
)

---

In [None]:
# setup PRNG: one root key (seed 213; alternative seed noted previously: 21) is
# split into separate keys; only `data_key` is consumed in the visible cells.
key = jax.random.PRNGKey(213)

data_key, model_key, loader_key, key = jax.random.split(key, num=4)
# key stream for data generation, consumed via next(data_rng)
data_rng = PRNGSequence(data_key)

In [None]:
# Pendulum environment with `batch_size` batched instances; `tau` is passed to
# the environment and reused by the plotting helpers below.
batch_size = 1
tau = 2e-2  # alternative: 5e-2

env = excenvs.make(env_id='Pendulum-v0', batch_size=batch_size, tau=tau)

In [None]:
# Fresh episode (observation and state cast to float32) plus an APRBS input
# signal of length n_steps for each batch element.
obs, state = (arr.astype(jnp.float32) for arr in env.reset())
n_steps = 999

# positional args (1, 10) are forwarded unchanged to `aprbs`; see
# exciting_exciting_systems.utils.signals for their meaning
actions = aprbs(n_steps, batch_size, 1, 10, next(data_rng))

In [None]:
# Open-loop rollout per batch element: `env` is shared (in_axis None), the six
# remaining inputs are mapped over their leading axis.
observations = jax.vmap(simulate_ahead_with_env, in_axes=(None,) + (0,) * 6)(
    env,
    obs,
    state,
    actions,
    env.env_state_normalizer,
    env.action_normalizer,
    env.static_params,
)

print("actions.shape:", actions.shape)
print("observations.shape:", observations.shape)

print(" \n One of the trajectories:")
fig, axs = plot_sequence(
    tau=tau,
    observations=observations[0],
    actions=actions[0],
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
)
plt.show()

### Rebuild density estimator

In [None]:
# NOTE(review): legacy, fully commented-out cell. It calls
# `update_kde_grid_multiple_observations`, which is not imported anywhere in this
# notebook, so uncommenting it as-is raises a NameError. It appears to be an
# earlier version of the DensityEstimate-based benchmark in the next cell;
# consider deleting it.
# x_g = eesys.utils.density_estimation.build_grid_2d(
#     low=env.env_observation_space.low,
#     high=env.env_observation_space.high,
#     points_per_dim=100
# )
# n_grid_points = x_g.shape[0]

# bandwidth = 0.1

# p_est = jnp.zeros([n_grid_points, 1])

# for i in range(100):
#     start = time.time()
#     p_est = update_kde_grid_multiple_observations(p_est, x_g, observations[0, ...], n_observations=0, bandwidth=bandwidth)
#     end = time.time()
#     print("computation time: ", end - start)

# fig, axs, cax = eesys.evaluation.plotting_utils.plot_2d_kde_as_contourf(p_est, x_g, [r"$\theta$", r"$\omega$"])

In [None]:
# Benchmark the batched multi-observation KDE update: the whole trajectory of
# each batch element is folded into the density estimate in a single call.
grid_points_per_dim = 100  # grid resolution per observation dimension

density_estimate = DensityEstimate(
    # one density value per grid point and batch element; sized from the same
    # resolution as the grid below so the two cannot drift apart
    p=jnp.zeros([batch_size, grid_points_per_dim**2, 1]),
    x_g=eesys.utils.density_estimation.build_grid_2d(
        low=env.env_observation_space.low,
        high=env.env_observation_space.high,
        points_per_dim=grid_points_per_dim
    ),
    bandwidth=jnp.array([0.1]),
    n_observations=jnp.array([0])
)
print("before_update: ", density_estimate)

# Only `p` is mapped over the batch axis; grid, bandwidth and observation count
# are shared. Build the vmapped function once instead of on every iteration.
batched_update_multiple = jax.vmap(
    update_density_estimate_multiple_observations,
    in_axes=(DensityEstimate(0, None, None, None), 0),
    out_axes=DensityEstimate(0, None, None, None)
)

# The first iteration also includes tracing overhead.
for i in range(10):
    start = time.time()
    density_estimate_after = batched_update_multiple(density_estimate, observations)
    # JAX dispatches asynchronously: without blocking, the measured time is only
    # the dispatch time, not the actual computation.
    jax.block_until_ready(density_estimate_after)
    end = time.time()
    print("computation time: ", end - start)
print("after_update: ", density_estimate_after)

fig, axs, cax = eesys.evaluation.plotting_utils.plot_2d_kde_as_contourf(
    density_estimate_after.p[0, ...],
    density_estimate_after.x_g, [r"$\theta$", r"$\omega$"])

In [None]:
# Sanity check of the trajectory array layout (indexed as observations[:, i, :] in the next cell).
jnp.shape(observations)

In [None]:
# Benchmark the single-observation KDE update: the trajectories are folded into
# the density estimate one time step at a time (same batch layout as above).
grid_points_per_dim = 100  # grid resolution per observation dimension

density_estimate = DensityEstimate(
    # one density value per grid point and batch element; sized from the same
    # resolution as the grid below so the two cannot drift apart
    p=jnp.zeros([batch_size, grid_points_per_dim**2, 1]),
    x_g=eesys.utils.density_estimation.build_grid_2d(
        low=env.env_observation_space.low,
        high=env.env_observation_space.high,
        points_per_dim=grid_points_per_dim
    ),
    bandwidth=jnp.array([0.1]),
    n_observations=jnp.array([0])
)
print("before_update: ", density_estimate)

# Only `p` is mapped over the batch axis; grid, bandwidth and observation count
# are shared. Build the vmapped function once instead of on every iteration.
batched_update_single = jax.vmap(
    update_density_estimate,
    in_axes=(DensityEstimate(0, None, None, None), 0),
    out_axes=DensityEstimate(0, None, None, None)
)

# Iterate over the actual trajectory length instead of a hard-coded 1000: JAX
# clamps out-of-bounds indices instead of raising, so a count larger than the
# trajectory would silently re-add the last observation.
n_trajectory_steps = observations.shape[1]
for i in range(n_trajectory_steps):
    start = time.time()
    density_estimate = batched_update_single(density_estimate, observations[:, i, :])
    # JAX dispatches asynchronously; block so the timing covers the computation.
    jax.block_until_ready(density_estimate)
    end = time.time()

    if i < 10:
        print("computation time: ", end - start)
print("after_update: ", density_estimate)

fig, axs, cax = eesys.evaluation.plotting_utils.plot_2d_kde_as_contourf(
    density_estimate.p[0, ...],
    density_estimate.x_g, [r"$\theta$", r"$\omega$"])

### Rebuild Exciter

In [None]:
from typing import Callable

In [None]:
from exciting_exciting_systems.utils.density_estimation import DensityEstimate, update_density_estimate

In [None]:
# Fresh episode and an empty density estimate for the excitation run.
obs, state = (arr.astype(jnp.float32) for arr in env.reset())


density_estimate = DensityEstimate(
    p=jnp.zeros([batch_size, 100 * 100, 1]),
    x_g=eesys.utils.density_estimation.build_grid_2d(
        low=env.env_observation_space.low,
        high=env.env_observation_space.high,
        points_per_dim=100,
    ),
    bandwidth=jnp.array([0.1]),
    n_observations=jnp.array([0]),
)

# Unbatched recording buffers: only the first batch element (obs[0]) is stored;
# there is one action fewer than observations.
n_timesteps = 5000
observations = jnp.zeros((n_timesteps, env.env_observation_space.shape[-1])).at[0].set(obs[0])
actions = jnp.zeros((n_timesteps - 1, env.action_space.shape[-1]))

In [None]:
# Configure the exciter: gradient of `loss_function` w.r.t. its 4th argument
# (argnums=3), an adabelief optimizer, and a uniform target density.
# NOTE(review): `Exciter` is neither imported nor defined in any visible cell;
# this only runs if it was defined elsewhere in the kernel — TODO import it.
exciter = Exciter(
    grad_loss_function=jax.grad(eesys.optimization.loss_function, argnums=3),
    excitation_solver=optax.adabelief(1e-1),
    tau=tau,
    # Uniform density over the observation box is 1 / area, i.e. 1 over the
    # product of the per-dimension ranges. The previous `(high - low)**2`
    # squared the range vector elementwise, which broadcasts (100**2, 1) to
    # (100**2, n_dims) and equals the area only when all ranges are equal.
    # Assumes low/high hold one bound per observation dimension — confirm.
    target_distribution=jnp.ones(shape=(100**2, 1)) / jnp.prod(
        env.env_observation_space.high - env.env_observation_space.low, axis=-1
    )
)

proposed_actions = aprbs(50, batch_size, 1, 10, next(data_rng))

In [None]:
# Closed-loop excitation: each step, the exciter returns the next action together
# with updated proposed_actions and density_estimate; interact_and_observe then
# presumably steps the environment and writes the recording buffers at step k.
# NOTE(review): `interact_and_observe` is not imported or defined in any visible
# cell — TODO import it so the notebook runs on a fresh kernel.
# NOTE(review): k runs up to n_timesteps - 1 while `actions` has only
# n_timesteps - 1 rows. If interact_and_observe writes actions[k] /
# observations[k + 1], the last iteration is out of bounds, and JAX does not
# raise on that — confirm the intended loop range.
for k in tqdm(range(n_timesteps)):
    action, proposed_actions, density_estimate = exciter.choose_action(
        obs, state, env, density_estimate, proposed_actions
    )

    obs, state, actions, observations = interact_and_observe(
        env, jnp.array([k]), action, obs, state, actions, observations
    )

In [None]:
# Recorded trajectory of the exciter run (single, unbatched sequence).
fig, axs = plot_sequence(
    tau=tau,
    observations=observations,
    actions=actions,
    action_labels=[r"$u$"],
    obs_labels=[r"$\theta$", r"$\omega$"],
)
plt.show()

In [None]:
# Estimated state-space density after the exciter run.
# NOTE(review): the other KDE plots in this notebook pass `p[0, ...]` (first batch
# element), while here the full `p` is passed. Confirm which shape
# `exciter.choose_action` returns for `density_estimate.p`.
fig, axs, cax = eesys.evaluation.plotting_utils.plot_2d_kde_as_contourf(
    density_estimate.p,
    density_estimate.x_g,
    [r"$\theta$", r"$\omega$"],
)

---
---
---
---
---
---
---
---
---

In [None]:
# Components for the functional `excite` loop: gradient of `loss_function` w.r.t.
# its 4th argument, the optimizer, and the uniform target density.
grad_loss_function = jax.grad(loss_function, argnums=3)
excitation_solver = optax.adabelief(1e-1)
# Uniform density over the observation box is 1 / area, i.e. 1 over the product
# of the per-dimension ranges. Squaring the range vector elementwise would
# broadcast (100**2, 1) to (100**2, n_dims) and equals the area only when all
# ranges are equal. Assumes low/high hold one bound per observation dimension.
target_distribution = jnp.ones(shape=(100**2, 1)) / jnp.prod(
    env.env_observation_space.high - env.env_observation_space.low, axis=-1
)

In [None]:
# Functional variant of the excitation loop: `excite` returns the updated
# obs/state, both recording buffers, proposed_actions and density_estimate.
# NOTE(review): `excite` is not imported or defined in any visible cell — TODO import it.
# NOTE(review): `env` is passed twice (1st and 6th positional argument); check
# excite's signature to see whether the 6th slot expects something else.
# NOTE(review): no cell between the previous loop and this one re-initializes
# obs/state/buffers/proposed_actions/density_estimate, so this run continues from
# the previous loop's final state. Re-run the initialization cell first for an
# independent run.
for k in tqdm(range(n_timesteps)):
    obs, state, actions, observations, proposed_actions, density_estimate = excite(
        env,
        actions,
        observations,
        grad_loss_function,
        proposed_actions,
        env,
        excitation_solver,
        obs,
        state,
        density_estimate,
        tau,
        target_distribution
    )


In [None]:
# Trajectory recorded by the `excite`-based loop.
fig, axs = plot_sequence(
    observations=observations, actions=actions, tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"], action_labels=[r"$u$"],
)
plt.show()