In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
# plt.rcParams['text.usetex'] = True
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs

import exciting_exciting_systems as eesys
from exciting_exciting_systems.models import NeuralEulerODEPendulum
from exciting_exciting_systems.models.model_utils import simulate_ahead, simulate_ahead_with_env
from exciting_exciting_systems.optimization import loss_function, optimize, soft_penalty
from exciting_exciting_systems.models.model_training import make_step, dataloader

from exciting_exciting_systems.utils.density_estimation import update_kde_grid, update_kde_grid_multiple_observations
from exciting_exciting_systems.utils.metrics import JSDLoss
from exciting_exciting_systems.utils.signals import generate_constant_action, aprbs
from exciting_exciting_systems.evaluation.plotting_utils import (
    plot_sequence, append_predictions_to_sequence_plot, plot_sequence_and_prediction, plot_model_performance
)

---

In [None]:
# PRNG setup: a single fixed seed, split into independent streams for data
# generation, model initialisation and the training-data loader.
key = jax.random.PRNGKey(21)
data_key, model_key, loader_key, key = jax.random.split(key, num=4)

# Iterator over data keys; later cells draw from it with next(data_rng).
data_rng = PRNGSequence(data_key)

In [None]:
# Batched pendulum environment.
# tau: simulation step size handed to the env and the plotting helpers
# (presumably in seconds -- TODO confirm in exciting_environments).
batch_size = 1
tau = 1e-2

env = excenvs.make(env_id="Pendulum-v0", batch_size=batch_size, tau=tau)

### Test simulation:

- starting from the initial state/obs ($\mathbf{x}_0$ / $\mathbf{y}_0$)
- apply $N = 999$ actions $\mathbf{u}_0 \dots \mathbf{u}_{N-1}$ (**here**: random APRBS actions)
- which results in the state trajectory $\mathbf{x}_0 \dots \mathbf{x}_N$ with $N+1 = 1000$ elements

In [None]:
# Start a fresh episode and cast observation/state to float32.
obs, state = env.reset()
obs, state = obs.astype(jnp.float32), state.astype(jnp.float32)

# N = 999 random APRBS actions per batch element.
# NOTE(review): 200 and 500 are presumably the min/max hold lengths of the
# APRBS levels -- confirm against the aprbs signature.
n_steps = 999
actions = aprbs(n_steps, batch_size, 200, 500, next(data_rng))

In [None]:
# Roll the true environment forward along `actions`, vmapped over axis 0 of
# every argument except the environment object itself (in_axes=None).
simulate_batch = jax.vmap(simulate_ahead_with_env, in_axes=(None, 0, 0, 0, 0, 0, 0))
observations = simulate_batch(
    env,
    obs,
    state,
    actions,
    env.env_state_normalizer,
    env.action_normalizer,
    env.static_params,
)

print("actions.shape:", actions.shape)
print("observations.shape:", observations.shape)

# Plot the first batch element's trajectory.
print(" \n One of the trajectories:")
fig, axs = plot_sequence(
    observations=observations[0],
    actions=actions[0],
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
)
plt.show()

## Build an algorithm that simultaneously learns the model and optimizes the applied action sequence:

In [None]:
def featurize_theta(obs):
    """Replace the angle channel of `obs` by its sine and cosine.

    Raw angles compare badly inside a loss: e.g. 1.99 * pi and 0 describe
    almost the same pendulum position but are numerically far apart. Here the
    angle channel obs[..., 0] is scaled by pi and replaced by its sin and cos.
    The scaling presumably means the angle is stored normalized to [-1, 1]
    (TODO confirm against the env).

    Args:
        obs: array whose last axis holds (angle, angular velocity, ...);
            any leading batch/time axes are preserved.

    Returns:
        Array with the same leading axes and a last axis of size 3:
        (sin(pi * obs[..., 0]), cos(pi * obs[..., 0]), obs[..., 1]).
        Channels beyond index 1 are dropped.
    """
    angle = jnp.pi * obs[..., 0]
    return jnp.stack([jnp.sin(angle), jnp.cos(angle), obs[..., 1]], axis=-1)

In [None]:
# --- hyperparameters ---

# density estimation
bandwidth = 0.05  # KDE bandwidth used by update_kde_grid and optimize

# planning
n_prediction_steps = 50  # length of the proposed action sequence
# TODO: a separate number of optimisation steps per timestep (n_opt_steps) is not implemented yet.

# model training
n_train_steps = 1  # make_step calls per environment timestep
training_batch_size = 32
sequence_length = n_prediction_steps  # training windows span one prediction horizon

# experiment length
n_timesteps = 5_000

In [None]:
# Evaluation grid for the KDE, spanning the observation space with
# `points_per_dim` points along each of its two dimensions.
points_per_dim = 40
x_g = eesys.utils.density_estimation.build_grid_2d(
    low=env.env_observation_space.low,
    high=env.env_observation_space.high,
    points_per_dim=points_per_dim,
)
# number of grid points (presumably points_per_dim**2 -- TODO confirm)
n_grid_points = x_g.shape[0]

In [None]:
# Target: uniform density over the observation space, one value per grid point.
# NOTE(review): `(high - low)**2` is the area of the box only when low/high are
# scalars (identical bounds for both dimensions); with per-dimension bounds this
# would need jnp.prod(high - low) -- confirm against env.env_observation_space.
obs_extent = env.env_observation_space.high - env.env_observation_space.low
target_distribution = jnp.ones(shape=(n_grid_points, 1)) * (1 / (obs_extent)**2)

# Fresh episode for the online experiment, cast to float32.
obs, state = env.reset()
obs, state = obs.astype(jnp.float32), state.astype(jnp.float32)

# Interaction log: observations x_0, x_1, ... and actions u_0, u_1, ...
memory = dict(observations=[obs], actions=[])

# Running KDE estimate per batch element, evaluated on the grid; starts at zero.
p_est = jnp.zeros([batch_size, n_grid_points, 1])

# Initial guess for the planned action sequence over the prediction horizon.
proposed_actions = aprbs(n_prediction_steps, batch_size, 20, 50, next(data_rng))

# Learned dynamics model, trained online with AdaBelief.
model = NeuralEulerODEPendulum(
    obs_dim=env.env_observation_space.shape[-1],
    action_dim=env.action_space.shape[-1],
    width_size=128,
    depth=3,
    key=model_key,
)
lr = 1e-3
optim = optax.adabelief(lr)
# Optimizer state covers only the floating-point array leaves of the model.
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

# Count of observations already folded into p_est; the loop increments it once per timestep.
start_n_measurments = jnp.array([0])
# Gradient of loss_function w.r.t. its positional argument at index 3, i.e. the actions.
grad_loss_function = jax.grad(loss_function, argnums=3)

In [None]:
for k in tqdm(range(n_timesteps)):
    # 1) Plan: optimise the proposed action sequence given the current model,
    #    the KDE estimate p_est and the target distribution.
    proposed_actions, final_loss = optimize(
        grad_loss_function=grad_loss_function,
        proposed_actions=proposed_actions,
        model=model,
        init_obs=obs,
        init_state=state,
        p_est=p_est,
        x=x_g,
        start_n_measurments=start_n_measurments,
        bandwidth=bandwidth,
        tau=tau,
        target_distribution=target_distribution
    )

    # Periodically show the recorded trajectory together with the true env's
    # response to the currently proposed actions.
    if k % 100 == 0 and k > 0:
        print(f"Loss iteration: {k}: {final_loss[0]}")
        # NOTE(review): this helper is called with `actions_labels` while
        # plot_sequence / plot_model_performance take `action_labels` --
        # confirm the intended keyword for plot_sequence_and_prediction.
        fig, axs = plot_sequence_and_prediction(
            observations=jnp.stack(memory["observations"], axis=1)[0, :],
            actions=jnp.stack(memory["actions"], axis=1)[0, :],
            tau=tau,
            obs_labels=[r"$\theta$", r"$\omega$"],
            actions_labels=[r"$u$"],
            model=env,
            init_obs=obs[0, :],
            init_state=state[0, :],
            proposed_actions=proposed_actions[0, :]
        )
        plt.show()

    # 2) Update the grid KDE with x_k.
    p_est = jax.vmap(update_kde_grid, in_axes=[0, None, 0, None, None])(
        p_est, x_g, obs, k, bandwidth
    )
    start_n_measurments = start_n_measurments + 1

    # 3) Apply u_k = \hat{u}_{k+1} (first planned action) and go to x_{k+1}.
    action = proposed_actions[:, 0, :]
    obs, _, _, _, state = env.step(action, state)
    # Keep the float32 dtype that was established right after env.reset().
    obs = obs.astype(jnp.float32)
    state = state.astype(jnp.float32)
    # Shift the plan by one step; the last action stays in place (it is repeated).
    proposed_actions = proposed_actions.at[:, :-1, :].set(proposed_actions[:, 1:, :])

    memory["actions"].append(action)  # store u_k
    memory["observations"].append(obs)  # store x_{k+1}

    # 4) Update the model once enough data has been collected.
    if k > n_prediction_steps:
        # Split off a fresh key every timestep. Passing the same `loader_key`
        # to every new dataloader would reuse one PRNG key and make each call
        # draw the same random numbers.
        loader_key, batch_key = jax.random.split(loader_key)
        for _, (observations_batch, actions_batch) in zip(
            range(n_train_steps), dataloader(memory, training_batch_size, sequence_length, key=batch_key)
        ):
            # Keep env-batch element 0 (presumably axis 2, since each memory
            # entry is (batch_size, dim)); drop the last action of each window.
            observations_batch = observations_batch[:, :, 0, :]
            actions_batch = actions_batch[:, :-1, 0, :]

            model_training_loss, model, opt_state = make_step(
                model,
                observations_batch,
                actions_batch,
                tau,
                opt_state,
                featurize_theta,
                optim
            )
        if k % 10 == 0:
            print(f"iter: {k}, Loss: {model_training_loss}")

In [None]:
# Evaluate the learned model on a window of the recorded trajectory:
# `eval_length` observations and the `eval_length - 1` actions between them.
eval_start, eval_length = 700, 1000

# Stacked along a new leading time axis -> (time, batch, dim).
recorded_observations = jnp.stack(memory["observations"])
recorded_actions = jnp.stack(memory["actions"])
assert recorded_observations.shape[0] >= eval_start + eval_length, (
    "evaluation window exceeds the recorded trajectory; "
    "reduce eval_start/eval_length or increase n_timesteps"
)

fig, axs = plot_model_performance(
    model=model,
    true_observations=recorded_observations[eval_start:eval_start + eval_length, 0],
    actions=recorded_actions[eval_start:eval_start + eval_length - 1, 0],
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
)
# plt.show() renders the figure; the previous plt.plot() only issued an empty
# plot call on the current axes and echoed `[]` as cell output.
plt.show()