In [None]:
# Automatically reload imported modules before each cell is executed, so that
# edits to the project source are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# (Re)load the line_profiler extension that provides the line-profiling magics.
%reload_ext line_profiler

In [None]:
import os
# Both variables are set before `jax` is imported further down in the notebook;
# keep them above that import so they are in place when the backend starts.
# Expose only the GPU with index 1 to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
# Ask XLA not to preallocate GPU memory up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

from functools import partial
import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
# Render all figure text through LaTeX (requires a working LaTeX installation)
# and load the `bm` package in the LaTeX preamble.
mpl.rcParams['text.usetex'] = True
mpl.rcParams['text.latex.preamble']=r"\usepackage{bm}"
import plotly.express as px


In [None]:
import jax
import jax.numpy as jnp
# 64-bit mode is left at JAX's default; uncomment to set it explicitly:
# jax.config.update("jax_enable_x64", False)
# Pin all computations to the first device JAX reports.
# NOTE(review): `jax.devices()` lists the devices of the default backend, so
# despite its name `gpus` may hold CPU devices when no GPU backend is available.
gpus = jax.devices()
jax.config.update("jax_default_device", gpus[0])

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs

import dmpe
from dmpe.models import NeuralEulerODEPendulum, NeuralODEPendulum
from dmpe.models.model_utils import simulate_ahead_with_env
from dmpe.models.model_training import ModelTrainer
from dmpe.excitation import loss_function, Exciter

from dmpe.utils.density_estimation import (
    update_density_estimate_single_observation, update_density_estimate_multiple_observations, DensityEstimate, select_bandwidth
)
from dmpe.utils.signals import aprbs
from dmpe.evaluation.plotting_utils import (
    plot_sequence, append_predictions_to_sequence_plot, plot_sequence_and_prediction, plot_model_performance
)

In [None]:
from dmpe.models.model_utils import ModelEnvWrapperPendulum

---

In [None]:
from dmpe.algorithms import excite_with_dmpe, default_dmpe, default_dmpe_parameterization

In [None]:
# Settings of the pendulum environment; every entry is forwarded to
# `excenvs.make` below.
env_params = {
    "batch_size": 1,
    "tau": 2e-2,
    "max_torque": 5,
    "g": 9.81,
    "l": 1,
    "m": 1,
    "env_solver": diffrax.Tsit5(),
}
env = excenvs.make(
    env_id="Pendulum-v0",
    batch_size=env_params["batch_size"],
    action_constraints=dict(torque=env_params["max_torque"]),
    static_params={name: env_params[name] for name in ("g", "l", "m")},
    solver=env_params["env_solver"],
    tau=env_params["tau"],
)

In [None]:
def featurize_theta(obs):
    """Replace the angle entry of an observation by a periodic encoding.

    The first entry along the last axis is multiplied by pi and mapped to its
    sine and cosine, so that angles which differ by a full turn yield the same
    features when compared in the loss. The second entry is passed through
    unchanged.

    Args:
        obs: Array whose last axis holds (angle, angular velocity).

    Returns:
        Array with the same leading axes and a last axis of size three:
        (sin(angle * pi), cos(angle * pi), angular velocity).
    """
    angle = obs[..., 0] * jnp.pi
    omega = obs[..., 1]
    return jnp.stack((jnp.sin(angle), jnp.cos(angle), omega), axis=-1)

# Run the library's `default_dmpe` entry point on the pendulum with a fixed
# seed, the sin/cos angle featurization and the Euler neural-ODE model class.
# NOTE(review): `plot_every=200` presumably controls how often intermediate
# plots are shown — confirm against `dmpe.algorithms.default_dmpe`.
default_dmpe(env, seed=0, featurize=featurize_theta, model_class=NeuralEulerODEPendulum, plot_every=200)

In [None]:
# --- environment -------------------------------------------------------------
env_params = {
    "batch_size": 1,
    "tau": 2e-2,
    "max_torque": 5,
    "g": 9.81,
    "l": 1,
    "m": 1,
    "env_solver": diffrax.Tsit5(),
}
env = excenvs.make(
    env_id="Pendulum-v0",
    batch_size=env_params["batch_size"],
    action_constraints=dict(torque=env_params["max_torque"]),
    static_params={name: env_params[name] for name in ("g", "l", "m")},
    solver=env_params["env_solver"],
    tau=env_params["tau"],
)

# --- algorithm settings ------------------------------------------------------
alg_params = {
    "bandwidth": None,  # placeholder, computed from the grid resolution below
    "n_prediction_steps": 50,
    "points_per_dim": 50,
    "action_lr": 1e-1,
    "n_opt_steps": 10,
    "rho_obs": 1,
    "rho_act": 1,
    "penalty_order": 2,
    "clip_action": True,
    "n_starts": 5,
    "reuse_proposed_actions": True,
}
alg_params["bandwidth"] = float(
    select_bandwidth(
        delta_x=2,
        dim=env.physical_state_dim + env.action_dim,
        n_g=alg_params["points_per_dim"],
        percentage=0.3,
    )
)

# --- experiment description --------------------------------------------------
# No model class / trainer / model parameters are given here; a model-env
# wrapper is passed instead.
exp_params = {
    "seed": None,
    "n_timesteps": 15_000,
    "model_class": None,
    "env_params": env_params,
    "alg_params": alg_params,
    "model_trainer_params": None,
    "model_params": None,
    "model_env_wrapper": ModelEnvWrapperPendulum,
}
seed = 0

exp_params["seed"] = int(seed)

# --- random number generation ------------------------------------------------
key = jax.random.PRNGKey(seed=exp_params["seed"])
data_key, _, _, expl_key, key = jax.random.split(key, 5)
data_rng = PRNGSequence(data_key)

# --- initial action proposal: a random APRBS signal over the horizon ---------
proposed_actions = aprbs(exp_params["alg_params"]["n_prediction_steps"], env.batch_size, 1, 10, next(data_rng))[0]

# --- excitation run (no loader key is passed) --------------------------------
observations, actions, model, density_estimate, losses, proposed_actions = excite_with_dmpe(
    env,
    exp_params,
    proposed_actions,
    None,
    expl_key,
)

In [None]:
# Set up the PRNG: one root key is split into independent keys for data
# generation, model initialisation, data loading and exploration; the last
# piece replaces the root key.
key = jax.random.PRNGKey(seed=22)  # leftover note "8)" — presumably an earlier seed

data_key, model_key, loader_key, expl_key, key = jax.random.split(key, 5)
# Iterator that hands out a fresh key on every `next(data_rng)`.
data_rng = PRNGSequence(data_key)

In [None]:
# Same pendulum configuration as in the earlier cells, rebuilt here so that the
# test simulation starts from a freshly created environment.
env_params = {
    "batch_size": 1,
    "tau": 2e-2,
    "max_torque": 5,
    "g": 9.81,
    "l": 1,
    "m": 1,
    "env_solver": diffrax.Tsit5(),
}
env = excenvs.make(
    env_id="Pendulum-v0",
    batch_size=env_params["batch_size"],
    action_constraints=dict(torque=env_params["max_torque"]),
    static_params={name: env_params[name] for name in ("g", "l", "m")},
    solver=env_params["env_solver"],
    tau=env_params["tau"],
)

### Test simulation:

- starting from the initial state/obs ($\mathbf{x}_0$ / $\mathbf{y}_0$)
- apply $N = 999$ actions $\mathbf{u}_0 \dots \mathbf{u}_{N-1}$ (**here**: random APRBS actions)
- which results in the state trajectory $\mathbf{x}_0 \dots \mathbf{x}_N$ with $N+1 = 1000$ elements

In [None]:
# Reset the environment and keep only the first entry along the leading axis of
# the observation (presumably the batch axis; the env uses batch_size=1).
obs, state = env.reset()
obs = obs[0]

n_steps = 4000

# Hand-crafted torque profile, used here instead of a random APRBS signal:
#   steps   0 ..  29  ->  +1
#   steps  30 ..  99  ->  -1
#   steps 100 .. end  ->   0
actions = jnp.zeros((n_steps, 1)).at[:30].set(1).at[30:100].set(-1)

In [None]:
# Roll the environment forward from (obs, state) under the hand-crafted actions;
# the second return value is not needed here.
observations, _ = simulate_ahead_with_env(env, obs, state, actions)

# Quick sanity check of the array shapes.
print("actions.shape:", actions.shape)
print("observations.shape:", observations.shape)

print(" \n One of the trajectories:")
# The trailing `;` suppresses the cell's text output for the returned tuple.
fig, axs = plot_sequence(
    observations=observations,
    actions=actions,
    tau=env.tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
);
# Save before `plt.show()` so that the figure is still the current one.
plt.savefig("Pendulum_pushup_tsit5.pdf")
plt.show()

In [None]:
# Show observations near the end of the simulated trajectory.
# NOTE(review): the slice stops before the last row, so the final observation is
# not displayed — use `observations[-5:]` if it should be included.
observations[-5:-1]

In [None]:
from dmpe.algorithms import excite_with_dmpe
from dmpe.utils.density_estimation import select_bandwidth

In [None]:
# Seed for the PRNG key of the experiment that is set up further down.
seed = 22

In [None]:
def featurize_theta(obs):
    """Encode the angle of an observation as sine and cosine.

    Comparing raw angles in a loss is problematic because values that differ by
    a full turn describe the same position. The first entry along the last axis
    is therefore scaled by pi and replaced by (sin, cos); the second entry is
    kept as it is.

    Args:
        obs: Array whose last axis holds (angle, angular velocity).

    Returns:
        Array with a last axis of size three:
        (sin(angle * pi), cos(angle * pi), angular velocity).
    """
    angle = obs[..., 0] * jnp.pi
    omega = obs[..., 1]
    return jnp.stack((jnp.sin(angle), jnp.cos(angle), omega), axis=-1)

In [None]:
def featurize_theta(obs):
    """Turn (angle, angular velocity) into (sin, cos, angular velocity).

    The angle entry (index 0 of the last axis) is multiplied by pi and expressed
    through its sine and cosine, which makes angles that are a full turn apart
    indistinguishable in the loss. The entry at index 1 is copied unchanged.

    Args:
        obs: Array whose last axis holds (angle, angular velocity).

    Returns:
        Array with the same leading axes and a last axis of size three.
    """
    angle = obs[..., 0] * jnp.pi
    omega = obs[..., 1]
    return jnp.stack((jnp.sin(angle), jnp.cos(angle), omega), axis=-1)

# --- environment -------------------------------------------------------------
env_params = {
    "batch_size": 1,
    "tau": 2e-2,
    "max_torque": 5,
    "g": 9.81,
    "l": 1,
    "m": 1,
    "env_solver": diffrax.Tsit5(),
}
env = excenvs.make(
    env_id="Pendulum-v0",
    batch_size=env_params["batch_size"],
    action_constraints=dict(torque=env_params["max_torque"]),
    static_params={name: env_params[name] for name in ("g", "l", "m")},
    solver=env_params["env_solver"],
    tau=env_params["tau"],
)

# --- algorithm settings ------------------------------------------------------
alg_params = {
    "bandwidth": None,  # placeholder, computed from the grid resolution below
    "n_prediction_steps": 50,
    "points_per_dim": 50,
    "action_lr": 1e-1,
    "n_opt_steps": 10,
    "rho_obs": 1,
    "rho_act": 1,
    "penalty_order": 2,
    "clip_action": True,
}
alg_params["bandwidth"] = select_bandwidth(
    delta_x=2,
    dim=env.physical_state_dim + env.action_dim,
    n_g=alg_params["points_per_dim"],
    percentage=0.3,
)

# --- model training and model architecture -----------------------------------
# Training starts after one prediction horizon of data and uses sequences of
# exactly that length.
model_trainer_params = {
    "start_learning": alg_params["n_prediction_steps"],
    "training_batch_size": 128,
    "n_train_steps": 1,
    "sequence_length": alg_params["n_prediction_steps"],
    "featurize": featurize_theta,
    "model_lr": 1e-4,
}
model_params = {
    "obs_dim": env.physical_state_dim,
    "action_dim": env.action_dim,
    "width_size": 128,
    "depth": 3,
    "key": None,  # placeholder, set once the PRNG keys exist
}

# --- experiment description --------------------------------------------------
exp_params = {
    "seed": None,
    "n_timesteps": 15_000,
    "model_class": NeuralEulerODEPendulum,
    "env_params": env_params,
    "alg_params": alg_params,
    "model_trainer_params": model_trainer_params,
    "model_params": model_params,
}

# --- random number generation (uses `seed` from an earlier cell) -------------
key = jax.random.PRNGKey(seed=seed)
data_key, model_key, loader_key, expl_key, key = jax.random.split(key, 5)
data_rng = PRNGSequence(data_key)

model_params["key"] = model_key
exp_params["model_params"] = model_params

# --- initial action proposal: a random APRBS signal over the horizon ---------
proposed_actions = aprbs(alg_params["n_prediction_steps"], env.batch_size, 1, 10, next(data_rng))[0]

# --- excitation run -----------------------------------------------------------
observations, actions, model, density_estimate, losses, proposed_actions = excite_with_dmpe(
    env, exp_params, proposed_actions, loader_key, expl_key, plot_every=100
)

In [None]:
# Kernel bandwidth of the density estimate and length of the prediction horizon.
bandwidth = 0.05
n_prediction_steps = 50

# The density grid spans the joint observation/action space.
dim_obs_space, dim_action_space = 2, 1
dim = dim_obs_space + dim_action_space

# Regular grid with the same resolution along every dimension.
points_per_dim = 50
n_grid_points = points_per_dim ** dim

# Total number of steps of the excitation experiment.
n_timesteps = 15_000

In [None]:
# Start a fresh episode and keep only the first entry along the leading axis of
# the observation (presumably the batch axis; the env uses batch_size=1).
obs, state = env.reset()
obs = obs[0]

# Pre-allocate the experiment buffers: one observation per time step and one
# action between each pair of consecutive observations.
observations = jnp.zeros((n_timesteps, dim_obs_space))
observations = observations.at[0].set(obs)
actions = jnp.zeros((n_timesteps-1, dim_action_space))

# Initial guess for the action sequence: a random APRBS signal over the
# prediction horizon. The batch size is read from the environment (as in the
# other experiment cells) because no `batch_size` variable is defined in this
# notebook.
proposed_actions = aprbs(n_prediction_steps, env.batch_size, 1, 10, next(data_rng))[0]

In [None]:
# The sampling time is read from the environment (`env.tau`, as in the plotting
# cell of the test simulation) because no `tau` variable is defined in this
# notebook.

# Optimizes the proposed actions. The target distribution is the uniform density
# on the cube [-1, 1]^dim, i.e. 1 / 2**dim at every grid point.
exciter = Exciter(
    loss_function=loss_function,
    grad_loss_function=jax.value_and_grad(loss_function, argnums=2),
    excitation_optimizer=optax.lbfgs(),
    tau=env.tau,
    n_opt_steps=10,
    target_distribution=jnp.ones(shape=(n_grid_points, 1)) * 1 / (1 - (-1))**dim,
    rho_obs=1,
    rho_act=1
)

# Fits the model; learning starts after one prediction horizon of data and uses
# sequences of exactly that length.
model_trainer = ModelTrainer(
    start_learning=n_prediction_steps,
    training_batch_size=128,
    n_train_steps=1,
    sequence_length=n_prediction_steps,
    featurize=featurize_theta,
    model_optimizer=optax.adabelief(1e-4),
    tau=env.tau
)

# Empty density estimate on a regular 3-D grid over [-1, 1]^3
# (two observation dimensions plus one action dimension).
density_estimate = DensityEstimate(
    p=jnp.zeros([n_grid_points, 1]),
    x_g=dmpe.utils.density_estimation.build_grid_3d(
        low=-1,
        high=1,
        points_per_dim=points_per_dim
    ),
    bandwidth=jnp.array([bandwidth]),
    n_observations=jnp.array([0])
)

# Explicit-Euler neural ODE model. `NeuralODEPendulum` (imported above) is the
# alternative that takes a diffrax solver, e.g. `solver=diffrax.Euler()`.
model = NeuralEulerODEPendulum(
    obs_dim=dim_obs_space,
    action_dim=dim_action_space,
    width_size=128,
    depth=3,
    key=model_key
)

# Optimizer state covers only the inexact (floating point) array leaves of the
# model.
opt_state_model = model_trainer.model_optimizer.init(eqx.filter(model, eqx.is_inexact_array))

In [None]:
from dmpe.algorithms import excite_and_fit

In [None]:
# Run the excitation experiment with the manually assembled components.
# The call overwrites `observations`, `actions`, `model`, `density_estimate`
# and `proposed_actions` with its results, so re-running this cell without
# re-running the set-up cells continues from the previous results instead of
# starting from scratch.
# NOTE(review): `plot_every=250` presumably sets how often intermediate plots
# are shown — confirm against `dmpe.algorithms.excite_and_fit`.
observations, actions, model, density_estimate, losses, proposed_actions = excite_and_fit(
    n_timesteps=n_timesteps,
    env=env,
    model=model,
    obs=obs,
    state=state,
    proposed_actions=proposed_actions,
    exciter=exciter,
    model_trainer=model_trainer,
    density_estimate=density_estimate,
    observations=observations,
    actions=actions,
    opt_state_model=opt_state_model,
    loader_key=loader_key,
    expl_key=expl_key,
    plot_every=250,
)

In [None]:
# Plot the full excitation result. The sampling time is read from the
# environment because no `tau` variable is defined in this notebook.
fig, axs = plot_sequence(
    observations=observations,
    actions=actions,
    tau=env.tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
);
# `savefig` does not create missing directories, so make sure the target exists.
os.makedirs("results/plots", exist_ok=True)
plt.savefig("results/plots/dmpe_example_result.pdf")

# Display the figure (an argument-less `plt.plot()` would draw nothing).
plt.show()

In [None]:
# Deliberate stop marker: a bare `raise` with no active exception fails with a
# RuntimeError, which halts "Run All" here so the exploratory cells below are
# only executed manually.
raise

---

In [None]:
# Surface/contour plots of the density estimate for a 2-D grid. With the
# configuration above `dim` is 3, so this branch is skipped.
# NOTE(review): `eesys` is never imported in this notebook, so this branch would
# raise a NameError if it ran. It looks like a former name of the `dmpe`
# package — confirm that these helpers exist in `dmpe.evaluation.plotting_utils`
# before reviving this cell.
if dim == 2:
    # Estimated density itself.
    fig, axs = eesys.evaluation.plotting_utils.plot_2d_kde_as_surface(
        density_estimate.p, density_estimate.x_g, [r"$\theta$", r"$\omega$"]
    )
    fig.suptitle("Vanilla KDE")
    # fig.savefig("excited_pendulum_kde_surface.png")
    plt.show()
    
    # Absolute deviation from the target distribution, as a surface ...
    fig, axs = eesys.evaluation.plotting_utils.plot_2d_kde_as_surface(
        jnp.abs(density_estimate.p - exciter.target_distribution), density_estimate.x_g, [r"$\theta$", r"$\omega$"]
    )
    fig.suptitle("Difference")
    
    plt.show()
    
    # ... and as a filled contour plot with a colorbar.
    fig, axs, cax = eesys.evaluation.plotting_utils.plot_2d_kde_as_contourf(
        jnp.abs(density_estimate.p - exciter.target_distribution), density_estimate.x_g, [r"$\theta$", r"$\omega$"]
    )
    plt.colorbar(cax)
    fig.suptitle("Abs Difference")
    
    plt.show()

In [None]:
# Deliberate stop marker: a bare `raise` with no active exception fails with a
# RuntimeError, which halts "Run All" before the action analysis below.
raise

---
### Look at the actions:

In [None]:
def build_grid_3d(low, high, points_per_dim):
    """Build a regular grid over the cube [low, high]^3.

    Every dimension uses the same `points_per_dim` equally spaced values from
    `low` to `high` (both included).

    The points are generated with `jnp.meshgrid` in its default "xy" indexing
    and then flattened, so in the returned order the second coordinate varies
    slowest, then the first, and the third varies fastest.

    Args:
        low: Lower bound, shared by all three dimensions.
        high: Upper bound, shared by all three dimensions.
        points_per_dim: Number of grid values per dimension.

    Returns:
        Array of shape (points_per_dim**3, 3) with one grid point per row.
    """
    # All three axes share the same values, so one linspace is enough.
    axis_points = jnp.linspace(low, high, points_per_dim)

    mesh = jnp.meshgrid(axis_points, axis_points, axis_points)
    x_g = jnp.stack(mesh, axis=-1).reshape(-1, 3)

    assert x_g.shape[0] == points_per_dim**3
    return x_g

In [None]:
# Shape check: drop the last observation and append the actions along the last
# axis, giving one joint (observation, action) sample per row.
jnp.concatenate([observations[0:-1, :], actions], axis=-1).shape

In [None]:
# Coarser grid than in the experiment above. NOTE: this overwrites the
# notebook-level `points_per_dim`, `n_grid_points` and `density_estimate`.
points_per_dim = 40
n_grid_points = points_per_dim**3

# Joint (observation, action) samples, one per row, with a leading batch axis of
# size one added for the vmapped update below.
obs_act_samples = jnp.concatenate([observations[0:-1, :], actions], axis=-1)[None]

# The batch axis of `p` is taken from the samples so that both vmapped inputs
# always agree (no `batch_size` variable is defined in this notebook).
density_estimate = DensityEstimate(
    p=jnp.zeros([obs_act_samples.shape[0], n_grid_points, 1]),
    x_g=build_grid_3d(-1, 1, points_per_dim),
    bandwidth=jnp.array([bandwidth]),
    n_observations=jnp.array([0])
)

# Update the estimate with all samples at once. Only `p` is mapped over the
# batch axis; grid, bandwidth and observation count are shared.
density_estimate = jax.vmap(
    update_density_estimate_multiple_observations,
    in_axes=(DensityEstimate(0, None, None, None), 0),
    out_axes=(DensityEstimate(0, None, None, None))
)(
    density_estimate,
    obs_act_samples,
)

In [None]:
# Shape check of the density values after the vmapped update.
density_estimate.p.shape

In [None]:
# Grid coordinates reshaped to (n, n, n, 3). Only the commented-out lines below
# use `x_plot`; the live code does not.
x_plot = density_estimate.x_g.reshape((points_per_dim, points_per_dim, points_per_dim, 3))

fig, axs = plt.subplots(
    figsize=(6, 6)
)

# One animation frame per index along the third grid axis: each frame is a
# filled contour plot of the corresponding 2-D slice of the density.
# NOTE(review): `ArtistAnimation` expects lists of artists; whether the object
# returned by `contourf` qualifies depends on the matplotlib version — confirm.
ims = []
for i in range(points_per_dim):
    # fig, axs = plt.subplots(
    #     figsize=(6, 6)
    # )
    cax = axs.contourf(
        # Without explicit coordinates the axes show grid indices:
        # x_plot[:, :, 0, :-1][..., 0],
        #x_plot[:, :, 0, :-1][..., 1],
        density_estimate.p[0].reshape((points_per_dim, points_per_dim, points_per_dim))[:, :, i],
        # Alternative: sum over the third axis instead of slicing it:
        #jnp.sum(density_estimate.p[0].reshape((points_per_dim, points_per_dim, points_per_dim)), axis=-1),
        antialiased=False,
        levels=100,
        alpha=0.9,
        cmap=plt.cm.coolwarm
    )
    ims.append([cax])
    # plt.title(jnp.linspace(-1, 1, points_per_dim)[i])
    # plt.show()

In [None]:
import matplotlib.animation as animation

In [None]:
# Assemble the frames collected in `ims` into an animation on `fig`:
# 50 ms between frames and a 1000 ms pause before each repetition.
ani = animation.ArtistAnimation(fig, ims, interval=50, blit=True,
                                repeat_delay=1000)

In [None]:
# Export the animation as a GIF through Pillow at 5 frames per second; the file
# is written to the current working directory.
writer = animation.PillowWriter(fps=5,
                                metadata=dict(artist='Me'),
                                bitrate=1800)
ani.save('opt_wrt_obs_and_act.gif', writer=writer)