In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
# plt.rcParams['text.usetex'] = True
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs

import exciting_exciting_systems as eesys
from exciting_exciting_systems.models import NeuralEulerODEPendulum
from exciting_exciting_systems.models.model_utils import simulate_ahead, simulate_ahead_with_env
from exciting_exciting_systems.optimization import loss_function, optimize, soft_penalty
from exciting_exciting_systems.models.model_training import make_step, dataloader, load_single_batch

from exciting_exciting_systems.utils.density_estimation import update_kde_grid, update_kde_grid_multiple_observations
from exciting_exciting_systems.utils.metrics import JSDLoss
from exciting_exciting_systems.utils.signals import generate_constant_action, aprbs
from exciting_exciting_systems.evaluation.plotting_utils import (
    plot_sequence, append_predictions_to_sequence_plot, plot_sequence_and_prediction, plot_model_performance
)

---

In [None]:
# Derive every random stream from one root key so a Restart-&-Run-All
# reproduces the same data, model initialization and training batches.
SEED = 23
key = jax.random.PRNGKey(seed=SEED)

# Four independent sub-keys; `key` is rebound to the last one for later splits.
data_key, model_key, loader_key, key = jax.random.split(key, 4)
data_rng = PRNGSequence(data_key)  # yields a fresh key on each next(data_rng)

In [None]:
# Environment configuration: one pendulum instance with sampling interval `tau`
# (presumably in seconds — TODO confirm; 5e-2 was tried as an alternative).
tau = 2e-2
batch_size = 1

env = excenvs.make(env_id='Pendulum-v0', batch_size=batch_size, tau=tau)

### Test simulation:

- start from the initial state/observation ($\mathbf{x}_0$ / $\mathbf{y}_0$)
- apply $N = 999$ actions $\mathbf{u}_0 \dots \mathbf{u}_{N-1}$ (**here**: random APRBS actions)
- which results in the state trajectory $\mathbf{x}_0 \dots \mathbf{x}_N$ with $N+1 = 1000$ elements

In [None]:
# Number of actions to apply; the simulated trajectory has n_steps + 1 observations.
n_steps = 999

# Fresh episode, cast to float32 for the JAX computations below.
obs, state = (arr.astype(jnp.float32) for arr in env.reset())

# Random APRBS excitation; 200 and 500 are presumably the min/max hold
# lengths of each amplitude level — TODO confirm against `aprbs`.
actions = aprbs(n_steps, batch_size, 200, 500, next(data_rng))

In [None]:
# Roll the true environment forward under the APRBS actions. `env` is shared
# (in_axes=None); every other argument is mapped over its leading (batch) axis.
batched_simulate = jax.vmap(simulate_ahead_with_env, in_axes=(None, 0, 0, 0, 0, 0, 0))
observations = batched_simulate(
    env, obs, state, actions,
    env.env_state_normalizer, env.action_normalizer, env.static_params,
)

for label, arr in (("actions.shape:", actions), ("observations.shape:", observations)):
    print(label, arr.shape)

print(" \n One of the trajectories:")
traj_idx = 0
fig, axs = plot_sequence(
    observations=observations[traj_idx, ...],
    actions=actions[traj_idx, ...],
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
)
plt.show()

## Build an algorithm that simultaneously learns the model and optimizes the excitation trajectory:

In [None]:
from exciting_exciting_systems.algorithms import excite, fit

In [None]:
def featurize_theta(obs):
    """Replace the angle feature by its sine and cosine for use in the loss.

    Raw angles are hard to compare: 1.99 * pi and 0 describe almost the same
    pendulum position but are far apart numerically. The first feature is
    therefore scaled by pi and mapped to (sin, cos); the second feature is
    passed through unchanged.

    Args:
        obs: array whose last axis holds (theta, omega). theta is presumably
            normalized so that theta * pi is in radians — TODO confirm.

    Returns:
        Array with the same leading shape and a last axis of size 3:
        (sin(pi * theta), cos(pi * theta), omega).
    """
    angle = jnp.pi * obs[..., 0]
    omega = obs[..., 1]
    return jnp.stack((jnp.sin(angle), jnp.cos(angle), omega), axis=-1)

In [None]:
# --- algorithm parameters ---------------------------------------------------
n_timesteps = 5_000        # steps of the excite-and-fit run (also sizes the buffers)
n_prediction_steps = 50    # length of the proposed action sequence
bandwidth = 0.1            # KDE kernel bandwidth
# n_opt_steps = 5  # TODO: Not yet implemented

# --- model training ---------------------------------------------------------
n_train_steps = 1
training_batch_size = 32
sequence_length = n_prediction_steps  # train on sequences as long as the prediction horizon

In [None]:
# 2D evaluation grid for the KDE spanning the observation space (100 points per dim).
obs_low, obs_high = env.env_observation_space.low, env.env_observation_space.high
x_g = eesys.utils.density_estimation.build_grid_2d(low=obs_low, high=obs_high, points_per_dim=100)
n_grid_points = x_g.shape[0]  # number of grid points = rows of x_g

In [None]:
obs_space = env.env_observation_space

# Uniform target density over the observation box. The box area is the product
# of the per-dimension extents: identical to extent**2 for scalar bounds, and
# still a scalar for per-dimension array bounds (where extent**2 would
# broadcast the target to (n_grid_points, 2) instead of (n_grid_points, 1)).
box_extent = jnp.ones(2) * (obs_space.high - obs_space.low)
target_distribution = jnp.ones(shape=(n_grid_points, 1)) / jnp.prod(box_extent)

# Fresh episode, cast to float32 for the JAX computations.
obs, state = env.reset()
obs = obs.astype(jnp.float32)
state = state.astype(jnp.float32)

# Trajectory buffers: n_timesteps observations and one action fewer
# (N actions produce N + 1 observations). obs[0] is the first batch element.
observations = jnp.zeros((n_timesteps, obs_space.shape[-1]))
observations = observations.at[0].set(obs[0])
actions = jnp.zeros((n_timesteps - 1, env.action_space.shape[-1]))

# Running density estimate on the grid, with a leading batch axis.
p_est = jnp.zeros([batch_size, n_grid_points, 1])

# Initial action sequence over the prediction horizon; presumably refined by
# `solver_prediction` inside excite_and_fit — TODO confirm.
proposed_actions = aprbs(n_prediction_steps, batch_size, 20, 50, next(data_rng))

# Neural ODE model of the pendulum and its optimizer.
model = NeuralEulerODEPendulum(
    obs_dim=obs_space.shape[-1],
    action_dim=env.action_space.shape[-1],
    width_size=128,
    depth=3,
    key=model_key,
)
lr = 1e-4
solver_model = optax.adabelief(lr)
# Only the floating-point array leaves of the model are optimized.
opt_state_model = solver_model.init(eqx.filter(model, eqx.is_inexact_array))

start_n_measurments = jnp.array([0])  # NOTE(review): not referenced elsewhere in the visible notebook
# Gradient of the loss w.r.t. its positional argument 3, i.e. the actions.
grad_loss_function = jax.grad(loss_function, argnums=3)

# Optimizer for the proposed action sequence.
solver_prediction = optax.adabelief(learning_rate=1e-1)

In [None]:
from exciting_exciting_systems.algorithms import excite_and_fit

In [None]:
# Jointly excite the system and fit the model for `n_timesteps` steps. The
# returned values replace `observations`, `actions` and `model`. All arguments
# are positional, so their order must match the signature of `excite_and_fit`.
observations, actions, model = excite_and_fit(
    n_timesteps, env, grad_loss_function, obs, state, proposed_actions,
    p_est, x_g, bandwidth, tau, target_distribution, n_prediction_steps,
    training_batch_size, model, n_train_steps, sequence_length,
    observations, actions, featurize_theta,
    solver_prediction, solver_model, opt_state_model, loader_key,
)

In [None]:
# Full trajectory recorded by excite_and_fit.
fig, axs = plot_sequence(
    observations,
    actions,
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
)
# plt.show() renders the figure; the previous bare plt.plot() only added an
# empty line to the current axes and echoed `[]`.
plt.show()

In [None]:
# Compare the learned model against the recorded trajectory on a fixed window.
eval_start, eval_stop = 700, 1000  # observation indices [eval_start, eval_stop)
fig, axs = plot_model_performance(
    model=model,
    true_observations=observations[eval_start:eval_stop],
    actions=actions[eval_start:eval_stop - 1],  # one action fewer than observations
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
)
plt.show()

In [None]:
# KDE of all recorded observations on the grid, starting from an empty estimate
# (n_observations=0 presumably means no prior samples are folded in — TODO confirm).
p_est = update_kde_grid_multiple_observations(
    jnp.zeros([n_grid_points, 1]), x_g, observations, n_observations=0, bandwidth=bandwidth
)
fig, axs, cax = eesys.evaluation.plotting_utils.plot_2d_kde_as_contourf(
    p_est, x_g, [r"$\theta$", r"$\omega$"]
)
# fig.savefig("excited_pendulum_kde_contourf.png")

In [None]:
# KDE of the recorded observations compared against the uniform target.
kde_axis_labels = [r"$\theta$", r"$\omega$"]
p_est = update_kde_grid_multiple_observations(
    jnp.zeros([n_grid_points, 1]), x_g, observations, n_observations=0, bandwidth=bandwidth
)
# Pointwise absolute deviation from the target, computed once and plotted twice.
abs_difference = jnp.abs(p_est - target_distribution)

fig, axs = eesys.evaluation.plotting_utils.plot_2d_kde_as_surface(p_est, x_g, kde_axis_labels)
fig.suptitle("Vanilla KDE")
# fig.savefig("excited_pendulum_kde_surface.png")
plt.show()

fig, axs = eesys.evaluation.plotting_utils.plot_2d_kde_as_surface(abs_difference, x_g, kde_axis_labels)
fig.suptitle("Abs Difference")  # plotted quantity is |p_est - target|, not a signed difference
plt.show()

fig, axs, cax = eesys.evaluation.plotting_utils.plot_2d_kde_as_contourf(abs_difference, x_g, kde_axis_labels)
plt.colorbar(cax)
fig.suptitle("Abs Difference")
plt.show()

### Timing the algorithm:

In [None]:
# Parameters for the profiling run (same values as the main run).
bandwidth = 0.1            # KDE kernel bandwidth
n_prediction_steps = 50    # length of the proposed action sequence
# n_opt_steps = 5  # TODO: Not yet implemented

n_train_steps = 1
training_batch_size = 32
sequence_length = n_prediction_steps

n_timesteps = 5_000        # sizes the trajectory buffers below

In [None]:
# Rebuild the 2D KDE evaluation grid (100 points per dimension) for the profiling run.
obs_low, obs_high = env.env_observation_space.low, env.env_observation_space.high
x_g = eesys.utils.density_estimation.build_grid_2d(low=obs_low, high=obs_high, points_per_dim=100)
n_grid_points = x_g.shape[0]

In [None]:
obs_space = env.env_observation_space

# Uniform target density over the observation box; the area is the product of
# the per-dimension extents (equals extent**2 for scalar bounds, and avoids
# broadcasting the target to (n_grid_points, 2) for per-dimension array bounds).
box_extent = jnp.ones(2) * (obs_space.high - obs_space.low)
target_distribution = jnp.ones(shape=(n_grid_points, 1)) / jnp.prod(box_extent)

# Fresh episode, cast to float32.
obs, state = env.reset()
obs = obs.astype(jnp.float32)
state = state.astype(jnp.float32)

# Trajectory buffers. `actions` holds one entry fewer than `observations`
# (N actions -> N + 1 observations), matching the main run so the profiled
# call sees identically shaped inputs.
observations = jnp.zeros((n_timesteps, obs_space.shape[-1]))
observations = observations.at[0].set(obs[0])
actions = jnp.zeros((n_timesteps - 1, env.action_space.shape[-1]))

# Running density estimate on the grid, with a leading batch axis.
p_est = jnp.zeros([batch_size, n_grid_points, 1])

proposed_actions = aprbs(n_prediction_steps, batch_size, 20, 50, next(data_rng))

# Fresh model and optimizer state so profiling starts from scratch.
model = NeuralEulerODEPendulum(
    obs_dim=obs_space.shape[-1],
    action_dim=env.action_space.shape[-1],
    width_size=128,
    depth=3,
    key=model_key,
)
lr = 1e-4
solver_model = optax.adabelief(lr)
opt_state_model = solver_model.init(eqx.filter(model, eqx.is_inexact_array))

start_n_measurments = jnp.array([0])  # NOTE(review): not referenced elsewhere in the visible notebook
# Gradient of the loss w.r.t. its positional argument 3, i.e. the actions.
grad_loss_function = jax.grad(loss_function, argnums=3)

solver_prediction = optax.adabelief(learning_rate=1e-1)

In [None]:
# Line-profile excite_and_fit on a shorter run: the first argument (number of
# timesteps, as in the main run) is 1000 here rather than n_timesteps, while
# the buffers, model and optimizer state come from the setup cell above.
%lprun -f excite_and_fit excite_and_fit(1000,env,grad_loss_function,obs,state,proposed_actions,p_est,x_g,bandwidth,tau,target_distribution,n_prediction_steps,training_batch_size,model,n_train_steps,sequence_length,observations,actions,featurize_theta,solver_prediction,solver_model,opt_state_model,loader_key)