In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['text.usetex'] = True
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs

from density_estimation import update_kde_grid, update_kde_grid_multiple_observations
from metrics import JSDLoss
from model_utils import simulate_ahead
from plotting_utils import plot_sequence
from signals import generate_constant_action, aprbs
from optimization_utils import loss_function, optimize

---

In [None]:
# Seed the PRNG once and derive independent keys from it:
#   - data_key  -> wrapped in a haiku PRNGSequence, so every `next(data_rng)` yields a fresh key
#   - model_key -> reserved for model-related randomness
#   - key       -> carry key for any further splits
root_key = jax.random.PRNGKey(seed=21)
data_key, model_key, key = jax.random.split(root_key, 3)

data_rng = PRNGSequence(data_key)

In [None]:
batch_size = 1
tau = 1e-1  # step size handed to the environment and to plot_sequence (1e-3 was also tried)

# `env` plays the role of the system and `model` the role of the predictor; both are
# built from exactly the same settings.
pendulum_settings = dict(env_id='Pendulum-v0', batch_size=batch_size, tau=tau)
env = excenvs.make(**pendulum_settings)
model = excenvs.make(**pendulum_settings)

In [None]:
# Open-loop rollout of the model under a random APRBS excitation, as a first look at the
# pendulum dynamics before any optimization. The two integer arguments of `aprbs` are its
# signal-shape parameters (presumably min/max hold lengths — see signals.aprbs).
n_steps = 1_000

obs, state = env.reset()
obs, state = obs.astype(jnp.float32), state.astype(jnp.float32)

actions = aprbs(n_steps, batch_size, 200, 500, next(data_rng))
observations = simulate_ahead(model=model, n_steps=n_steps, obs=obs, state=state, actions=actions)

In [None]:
# Visualize the open-loop rollout; the trailing semicolon suppresses the cell's output repr.
plot_sequence(
    observations=observations, actions=actions, tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"], action_labels=[r"$u$"],
);

In [None]:
# Evaluation grid for the kernel density estimate over the 2-D observation space
# (theta, omega). Both axes share the same bounds.
# NOTE(review): assumes env.env_observation_space.low/high are scalars, i.e. identical
# bounds for every observation dimension — TODO confirm.
obs_low = env.env_observation_space.low
obs_high = env.env_observation_space.high
n_grid_points_per_dim = 100

x1 = jnp.linspace(obs_low, obs_high, n_grid_points_per_dim)
x2 = jnp.linspace(obs_low, obs_high, n_grid_points_per_dim)

# meshgrid -> (n_grid_points_per_dim, n_grid_points_per_dim, 2) -> flattened (n_grid_points, 2)
x = jnp.stack(jnp.meshgrid(x1, x2), axis=-1).reshape(-1, 2)
n_grid_points = x.shape[0]

# Counter of observations folded into p_est (the loop below increments it once per step).
start_n_measurments = 0

bandwidth = 0.15  # KDE kernel bandwidth
p_est = jnp.zeros([batch_size, n_grid_points, 1])  # density estimate on the grid, starts empty

# Uniform target density over the box [low, high]^2, i.e. 1 / area.
target_distribution = jnp.ones(shape=(batch_size, n_grid_points, 1)) * (1 / (obs_high - obs_low) ** 2)

n_prediction_steps = 100  # horizon passed to `optimize` as n_steps
n_time_steps = 100        # number of closed-loop steps

In [None]:
# Gradient of the loss w.r.t. its first positional argument (presumably the proposed action
# sequence that `optimize` refines — verify against optimization_utils.loss_function).
grad_loss_function = jax.grad(loss_function, argnums=0)

In [None]:
# Closed-loop excitation. At every time step:
#   1. optimize a sequence of future actions on `model` (see optimization_utils.optimize;
#      presumably it minimizes a divergence between the predicted state density and
#      `target_distribution`),
#   2. fold the current observation into the running density estimate `p_est`,
#   3. apply only the first optimized action to the system (receding horizon).
observations = []
actions = []

# Reset the density estimate and its observation counter so that re-running this cell starts
# from scratch. Otherwise it would accumulate onto the previous run's estimate, and
# `start_n_measurments` would no longer match `k`.
p_est = jnp.zeros([batch_size, n_grid_points, 1])
start_n_measurments = 0

obs, state = env.reset()
obs = obs.astype(jnp.float32)
state = state.astype(jnp.float32)

for k in tqdm(range(n_time_steps)):
    # Random APRBS action sequence as the optimizer's initial guess.
    # NOTE(review): its length is `n_steps` (set in the open-loop cell) while the optimizer
    # predicts only `n_prediction_steps` ahead — confirm which length is intended.
    proposed_actions = aprbs(n_steps, batch_size, 100, 200, next(data_rng))
    proposed_actions = optimize(
        grad_loss_function=grad_loss_function,
        proposed_actions=proposed_actions,
        model=model,
        init_obs=obs,
        init_state=state,
        n_steps=n_prediction_steps,
        p_est=p_est,
        x=x,
        start_n_measurments=start_n_measurments,
        bandwidth=bandwidth,
        target_distribution=target_distribution,
    )

    # Add the current observation (the k-th one) to the KDE evaluated on the grid `x`.
    p_est = update_kde_grid(
        kde_grid=p_est,
        x_eval=x,
        observation=obs,
        n_observations=k,
        bandwidth=bandwidth,
    )
    start_n_measurments += 1

    # Receding horizon: keep only the first time step of the optimized sequence
    # (assumes layout (batch, time, action_dim) — TODO confirm).
    action = proposed_actions[:, 0, :]

    actions.append(action)
    observations.append(obs)

    # Advance the real system (`env`, which was reset above), not the prediction model.
    obs, _, _, _, state = env.step(action, state)

In [None]:
# Trajectory produced by the closed-loop excitation. The per-step lists are stacked along
# axis 1, presumably the time axis of a (batch, time, ...) layout — TODO confirm.
closed_loop_observations = jnp.stack(observations, axis=1)
closed_loop_actions = jnp.stack(actions, axis=1)

plot_sequence(
    observations=closed_loop_observations, actions=closed_loop_actions, tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"], action_labels=[r"$u$"],
);

- **TODO:** inspect the result of the action optimization at each time step.
# **TODO: run on a GPU**

In [None]:
# Recompute the density estimate from scratch over the whole recorded trajectory in one call.
# Starting from an all-zero grid (with n_observations=0) makes the result independent of the
# running estimate left over from the loop, so re-running this cell is idempotent.
p_est = update_kde_grid_multiple_observations(
    jnp.zeros_like(p_est),
    x,
    jnp.stack(observations, axis=1),
    n_observations=0,
    bandwidth=bandwidth,
)

In [None]:
# Contour plot of the estimated density over the (theta, omega) evaluation grid.
fig, ax = plt.subplots(figsize=(6, 6))

# `x` is a flattened square grid; recover its side length to reshape it back to 2-D.
grid_len_per_dim = int(round(np.sqrt(x.shape[0])))
assert grid_len_per_dim ** 2 == x.shape[0], "evaluation grid `x` is not square"
x_plot = x.reshape((grid_len_per_dim, grid_len_per_dim, 2))

# Only the first batch element is shown.
density = ax.contourf(
    x_plot[..., 0],
    x_plot[..., 1],
    p_est[0, ...].reshape(x_plot.shape[:-1]),
    antialiased=False,
    levels=30,
    alpha=0.9,
    cmap=plt.cm.coolwarm,
)
fig.colorbar(density, ax=ax, label=r"$\hat{p}$")
ax.set_title("Kernel density estimate of visited observations")
ax.set_xlabel(r"$\theta$")
ax.set_ylabel(r"$\omega$");

In [None]:
# fig.savefig("excited_pendulum_stepwise.png")