In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs
from models import NeuralEulerODE

from density_estimation import update_kde_grid, update_kde_grid_multiple_observations, update_kde_grid_multiple_observations_2
from metrics import JSDLoss

## Find the best actions for good coverage

Necessary steps:

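- simulate a trajectory $\mathbf{x}_{k:k+N}$ with $N+1$ elements by applying $N$ actions $\mathbf{u}_{k:k+N-1}$ to the model
- evaluate the Jensen-Shannon divergence (JSD) between the density estimate that includes the simulated trajectory and the target distribution
- optimize the actions so that this JSD is minimized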
- take a single step in the environment

In [None]:
# Seed the root PRNG key; every random stream in the notebook derives from it.
key = jax.random.PRNGKey(21)

# A single split yields the data key, the model key and the key carried forward.
split_keys = jax.random.split(key, 3)
data_key, model_key, key = split_keys

# Iterator that hands out a fresh key on every next(data_rng) call.
data_rng = PRNGSequence(data_key)

In [None]:
# Settings shared by everything built from the Pendulum configuration below.
batch_size = 1
tau = 1e-3

# Collect the construction arguments once and expand them into the factory call.
pendulum_kwargs = dict(env_id='Pendulum-v0', batch_size=batch_size, tau=tau)

env = excenvs.make(**pendulum_kwargs)

# The planning "model" is built from the very same Pendulum configuration as
# `env`, i.e. it is a second instance created with identical arguments.
model_kwargs = dict(env_id='Pendulum-v0', batch_size=batch_size, tau=tau)
model = excenvs.make(**model_kwargs)

In [None]:
@partial(jax.jit, static_argnums=(0, 1))
def gen_actions(n_steps, batch_size, key):
    """Sample one constant action per batch element and hold it for all steps.

    Args:
        n_steps: Number of time steps the action is held (static under jit).
        batch_size: Number of independent action sequences (static under jit).
        key: JAX PRNG key used for sampling.

    Returns:
        Array of shape (batch_size, n_steps, 1); each sequence is constant over
        time with a value drawn uniformly between -1 and 1.
    """
    held_value = jax.random.uniform(key, shape=(batch_size, 1, 1), minval=-1, maxval=1)
    # Replicate the single sampled value along the time axis.
    return jnp.tile(held_value, (1, n_steps, 1))

In [None]:
@partial(jax.jit, static_argnums=(0, 3))
def simulate_episode(env, obs, state, n_steps, actions):
    """Roll out `env` from an initial observation/state and record the observations.

    Entry 0 of the returned trajectory is the initial observation. Entry n
    (n >= 1) is obtained by applying ``actions[:, n - 1, :]`` to the state that
    produced entry n - 1, so ``n_steps`` observations consume the first
    ``n_steps - 1`` actions (x_{k:k+N} from u_{k:k+N-1}).

    Args:
        env: Object exposing ``step(action, state)`` that returns
            ``(obs, reward, terminated, truncated, state)``. It is a static jit
            argument and therefore has to be hashable.
        obs: Initial observation of shape (batch_size, obs_dim).
        state: Initial state handed to the first ``env.step`` call.
        n_steps: Number of observations in the trajectory (static under jit).
        actions: Action sequence indexed as ``actions[:, n, :]``; it needs at
            least ``n_steps - 1`` entries along axis 1.

    Returns:
        Array of shape (batch_size, n_steps, obs_dim) holding the trajectory.
    """
    batch_size, obs_dim = obs.shape
    observations = jnp.zeros([batch_size, n_steps, obs_dim])
    observations = observations.at[:, 0, :].set(obs)

    def body_fun(n, carry):
        obs, state, observations = carry

        # The transition into step n is driven by the action of step n - 1.
        # Indexing with n here would silently skip actions[:, 0, :].
        action = actions[:, n - 1, :]
        obs, reward, terminated, truncated, state = env.step(action, state)
        observations = observations.at[:, n, :].set(obs)

        return (obs, state, observations)

    # Start at 1 because slot 0 already holds the initial observation.
    obs, state, observations = jax.lax.fori_loop(lower=1, upper=n_steps, body_fun=body_fun, init_val=(obs, state, observations))

    return observations


def plot_episode(observations, actions, max_n=2):
    """Show time-series and phase plots for the first few episodes of a batch.

    For at most `max_n` batch entries this draws, in order: both observation
    channels over time, channel 1 against channel 0, and the first action
    channel over time. Every plot is shown as its own figure.

    Args:
        observations: Array indexed as [batch, time, channel]; channels 0 and 1
            are plotted and labelled "theta" and "omega".
        actions: Array indexed as [batch, time, channel]; channel 0 is plotted.
        max_n: Upper bound on the number of batch entries to plot.
    """
    shown = range(min(max_n, observations.shape[0]))

    # Observation channels over time.
    for b in shown:
        theta = observations[b, :, 0]
        omega = observations[b, :, 1]
        plt.plot(theta, 'r.', label="theta")
        plt.plot(omega, 'b-', label="omega")
        plt.title("observations, timeseries")
        plt.grid()
        plt.legend()
        plt.show()

    # Channel 1 over channel 0.
    for b in shown:
        theta = observations[b, :, 0]
        omega = observations[b, :, 1]
        plt.plot(theta, omega, 'b.')
        plt.title("observations, together")
        plt.grid()
        plt.show()

    # Applied actions over time.
    for b in shown:
        plt.plot(actions[b, :, 0])
        plt.title("actions, timeseries")
        plt.grid()
        plt.show()


def aprbs2(len, t_min, t_max, key):
    """Generate a piecewise-constant random signal with random hold times.

    Segments are appended until at least `len` samples exist. Each segment has
    a length drawn with ``jax.random.randint(minval=t_min, maxval=t_max)`` and
    a constant level drawn uniformly between -1 and 1. The result is truncated
    to exactly `len` samples.

    Args:
        len: Number of samples to return. Values <= 0 yield an empty signal.
        t_min: Lower bound passed to ``randint`` for the segment length. It
            should be >= 1, otherwise zero-length segments can be drawn and
            the loop may make no progress.
        t_max: Upper bound passed to ``randint`` for the segment length.
        key: JAX PRNG key; it is split anew for every segment.

    Returns:
        1-D array of shape (len,).
    """
    if len <= 0:
        # No segment would be generated and jnp.hstack([]) raises, so return
        # the empty signal directly.
        return jnp.zeros((0,))

    t = 0
    sig = []
    while t < len:
        steps_key, value_key, key = jax.random.split(key, 3)

        # Bring the segment length to the host once as a Python int: it is
        # needed both as an array shape and for the loop counter.
        t_step = jax.random.randint(steps_key, shape=(1,), minval=t_min, maxval=t_max).item()
        level = jax.random.uniform(value_key, shape=(1,), minval=-1, maxval=1)

        sig.append(jnp.ones((t_step,)) * level)
        t += t_step

    # The last segment usually overshoots; cut back to the requested length.
    return jnp.hstack(sig)[:len]


def aprbs(n_steps, batch_size, t_min, t_max, key):
    """Build a batch of independent `aprbs2` signals.

    Args:
        n_steps: Length of every signal.
        batch_size: Number of signals to generate.
        t_min: Forwarded to `aprbs2` as the lower segment-length bound.
        t_max: Forwarded to `aprbs2` as the upper segment-length bound.
        key: JAX PRNG key; one subkey is split off per signal.

    Returns:
        Array of shape (batch_size, n_steps, 1).
    """
    signals = []
    for _ in range(batch_size):
        sample_key, key = jax.random.split(key)
        signal = aprbs2(n_steps, t_min, t_max, sample_key)
        # Append a trailing singleton axis for the action dimension.
        signals.append(jnp.expand_dims(signal, axis=-1))
    return jnp.stack(signals, axis=0)

In [None]:
@partial(jax.jit, static_argnums=(1, 4))
def loss_function(
    actions,
    model,
    init_obs,
    init_state,
    n_steps,
    p_est,
    x,
    start_n_measurments,
    bandwidth,
    target_distribution
):
    """JSD between the updated density estimate and the target distribution.

    The raw actions are squashed with tanh, the model is rolled out with them,
    the density estimate `p_est` is updated with the resulting observations
    and the updated estimate is compared to `target_distribution` via JSDLoss.

    Args:
        actions: Unconstrained actions; tanh is applied before simulation.
        model: Model handed to `simulate_episode` (static under jit).
        init_obs: Initial observation for the rollout.
        init_state: Initial state for the rollout.
        n_steps: Rollout length (static under jit).
        p_est: Current density estimate on the grid `x`.
        x: Grid points on which the density is estimated.
        start_n_measurments: Forwarded to the KDE update; presumably the number
            of measurements already contained in `p_est` (confirm against
            density_estimation).
        bandwidth: Bandwidth forwarded to the KDE update.
        target_distribution: Density that the estimate is compared to.

    Returns:
        The value returned by JSDLoss.
    """
    bounded_actions = jax.nn.tanh(actions)
    trajectory = simulate_episode(model, init_obs, init_state, n_steps, bounded_actions)

    updated_density = update_kde_grid_multiple_observations_2(
        n_steps, p_est, x, trajectory, start_n_measurments, bandwidth
    )
    return JSDLoss(p=updated_density, q=target_distribution)

# Gradient of the loss with respect to its first argument (the raw actions).
grad_loss_function = jax.grad(loss_function, argnums=0)

def optimize(
    proposed_actions,
    model,
    init_obs,
    init_state,
    n_steps,
    p_est,
    x,
    start_n_measurments,
    bandwidth,
    target_distribution,
    n_iterations=100,
    learning_rate=1e-1
):
    """Refine an action sequence by gradient descent on the loss.

    Runs `n_iterations` AdaBelief updates on `proposed_actions` using
    `grad_loss_function`. The optimized variables are the raw actions that are
    passed to the loss, i.e. the result is returned in the same raw form.

    Args:
        proposed_actions: Initial guess for the raw actions.
        model: Model used for the rollouts inside the loss.
        init_obs: Initial observation for the rollouts.
        init_state: Initial state for the rollouts.
        n_steps: Rollout length.
        p_est: Current density estimate on the grid `x`.
        x: Grid points of the density estimate.
        start_n_measurments: Forwarded unchanged to the loss.
        bandwidth: Forwarded unchanged to the loss.
        target_distribution: Target density forwarded to the loss.
        n_iterations: Number of optimizer steps (default 100, the previously
            hard-coded value).
        learning_rate: AdaBelief learning rate (default 1e-1, the previously
            hard-coded value).

    Returns:
        The optimized raw actions, same shape as `proposed_actions`.
    """
    solver = optax.adabelief(learning_rate=learning_rate)
    opt_state = solver.init(proposed_actions)

    for _ in tqdm(range(n_iterations)):
        grad = grad_loss_function(
            proposed_actions,
            model,
            init_obs,
            init_state,
            n_steps,
            p_est,
            x,
            start_n_measurments,
            bandwidth,
            target_distribution
        )
        updates, opt_state = solver.update(grad, opt_state, proposed_actions)
        proposed_actions = optax.apply_updates(proposed_actions, updates)

    return proposed_actions

In [None]:
# Bounds of the observation space; the same bounds are used for both grid axes.
obs_low = env.env_observation_space.low
obs_high = env.env_observation_space.high

x1 = jnp.linspace(obs_low, obs_high, 100)
x2 = jnp.linspace(obs_low, obs_high, 100)

# Flatten the 2-D mesh into a list of grid points with two coordinates each.
x = jnp.stack(jnp.meshgrid(x1, x2), axis=-1).reshape(-1, 2)
n_grid_points = x.shape[0]

# Measurement counter passed to the KDE updates; starts at zero.
start_n_measurments = 0

bandwidth = 0.15
p_est = jnp.zeros([batch_size, n_grid_points, 1])

# Constant target density 1 / (high - low)^2 on every grid point.
target_distribution = jnp.ones(shape=(batch_size, n_grid_points, 1)) * (1 / (obs_high - obs_low)**2)

n_steps = 1000

In [None]:
# Reset the environment and cast observation and state to float32.
obs, state = env.reset()
obs = obs.astype(jnp.float32)
state = state.astype(jnp.float32)

# Initial guess: piecewise-constant random actions (randint bounds 10 and 20).
actions = aprbs(n_steps, batch_size, 10, 20, next(data_rng))

plt.plot(jnp.squeeze(actions))
plt.show()

# Refine the raw actions on the model; `actions` is overwritten with the result.
actions = optimize(
    actions,
    model,
    obs,
    state,
    n_steps,
    p_est,
    x,
    start_n_measurments,
    bandwidth,
    target_distribution
)

# NOTE(review): the counter is advanced here although p_est is only updated in
# a later cell, so following cells see the new count together with the old
# p_est. Re-running this cell also increments it again. Confirm the intended order.
start_n_measurments += n_steps

In [None]:
# Evaluate the loss for the current `actions`.
# NOTE(review): start_n_measurments was already increased by n_steps in the
# previous cell while p_est has not been updated yet, so this value is
# presumably not the objective that optimize() minimized -- confirm.
loss = loss_function(
    actions,
    model,
    obs,
    state,
    n_steps,
    p_est,
    x,
    start_n_measurments,
    bandwidth,
    target_distribution
)
loss

In [None]:
# Squash the raw actions once and reuse them for the rollout and the plot.
bounded_actions = jax.nn.tanh(actions)

observations = simulate_episode(model, obs, state, n_steps, bounded_actions)
plot_episode(observations, bounded_actions, max_n=1)

# Fold the simulated observations into the density estimate.
p_est = update_kde_grid_multiple_observations(n_steps, p_est, x, observations, start_n_measurments, bandwidth)

In [None]:
# Bring the flat grid and the first batch entry of p_est back into 2-D form.
grid_len_per_dim = int(np.sqrt(x.shape[0]))
x_plot = x.reshape((grid_len_per_dim, grid_len_per_dim, 2))
density_grid = p_est[0, ...].reshape(x_plot.shape[:-1])

fig, ax = plt.subplots(figsize=(6, 6))

# Filled contour plot of the estimated density over the grid.
cax = ax.contourf(
    x_plot[..., 0],
    x_plot[..., 1],
    density_grid,
    levels=30,
    cmap=plt.cm.coolwarm,
    alpha=0.9,
    antialiased=False
)
ax.set_xlabel(r"$\theta$")
ax.set_ylabel(r"$\omega$")
# fig.colorbar(cax)

---

In [None]:
# Roll out the model with plain APRBS actions (randint bounds 200 and 500).
# No tanh is applied here; `actions` from the optimisation above is overwritten.
actions = aprbs(n_steps, batch_size, 200, 500, next(data_rng))
observations = simulate_episode(model, obs, state, n_steps, actions)
plot_episode(observations, actions, max_n=1)

# Density estimate from the first KDE implementation, measurement count 0.
p_est = update_kde_grid_multiple_observations(n_steps, p_est, x, observations, 0, bandwidth)


# Same update with the second implementation.
# NOTE(review): this call receives the p_est that was just overwritten above,
# not the same input as the first call -- confirm that the previous estimate is
# irrelevant when the measurement count is 0.
p_est_2 = update_kde_grid_multiple_observations_2(n_steps, p_est, x, observations, 0, bandwidth)

In [None]:
from density_estimation import update_kde_grid, update_kde_grid_multiple_observations, update_kde_grid_multiple_observations_2

In [None]:
# Time both KDE update implementations on the same observations.
# JAX dispatches computations asynchronously, so the result has to be awaited
# with block_until_ready before the clock is stopped; otherwise only the
# dispatch overhead is measured. A first call may include compilation time.
start = time.perf_counter()
p_est = jax.block_until_ready(
    update_kde_grid_multiple_observations(n_steps, p_est, x, observations, 0, bandwidth)
)
end = time.perf_counter()
print(end - start)

start = time.perf_counter()
p_est_2 = jax.block_until_ready(
    update_kde_grid_multiple_observations_2(n_steps, p_est, x, observations, 0, bandwidth)
)
end = time.perf_counter()
print(end - start)

# Mean squared difference between the two results.
print(jnp.mean(jnp.abs(p_est_2 - p_est) ** 2))

In [None]:
# Shape of the estimate returned by update_kde_grid_multiple_observations.
p_est.shape

In [None]:
# Shape of the estimate returned by update_kde_grid_multiple_observations_2.
p_est_2.shape

In [None]:
# Mean absolute difference between the two KDE results.
jnp.mean(jnp.abs(p_est_2 - p_est))

In [None]:
# Smallest (most negative) signed difference between the two KDE results.
jnp.min(p_est_2 - p_est)

In [None]:
# Contour plot of the first batch entry of p_est over the 2-D grid.
fig, ax = plt.subplots(
    figsize=(8, 8)
)

# Label the axes like the earlier density plot so the figure stands alone.
ax.set_xlabel(r"$\theta$")
ax.set_ylabel(r"$\omega$")

# Bring the flat grid back into its (n, n, 2) form for plotting.
grid_len_per_dim = int(np.sqrt(x.shape[0]))
x_plot = x.reshape((grid_len_per_dim, grid_len_per_dim, 2))

cax = ax.contourf(
    x_plot[..., 0],
    x_plot[..., 1],
    p_est[0, ...].reshape(x_plot.shape[:-1]),
    antialiased=False,
    levels=30,
    alpha=0.9,
    cmap=plt.cm.coolwarm
)

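- **TODO:** build the cells above into a full algorithm!
- **TODO:** think about the differences between the two KDE update implementations (`update_kde_grid_multiple_observations` and `update_kde_grid_multiple_observations_2`)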

...

In [None]:
# TODO(review): ModelPredictiveExcitation is neither defined nor imported in the
# visible part of this notebook; presumably a placeholder for the planned
# algorithm -- this cell raises NameError until the class exists.
mpe = ModelPredictiveExcitation()