In [None]:
# Reload edited modules before each cell runs, so changes to the project package
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Line-by-line profiler (provides the %lprun magic).
%reload_ext line_profiler

In [None]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import pathlib
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['text.usetex'] = True
mpl.rcParams['text.latex.preamble']=r"\usepackage{bm}"
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp
# Uncomment to enable 64-bit (double precision) arrays.
# jax.config.update("jax_enable_x64", True)
# Use the first device reported by JAX as the default placement for computations.
# NOTE(review): jax.devices() lists the devices of JAX's default backend, which is the
# CPU when no GPU is available, so `gpus` may not actually hold GPUs on every machine.
gpus = jax.devices()
jax.config.update("jax_default_device", gpus[0])

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs

import exciting_exciting_systems
from exciting_exciting_systems.models import NeuralEulerODEPendulum, NeuralODEPendulum, NeuralEulerODE
from exciting_exciting_systems.models.model_utils import simulate_ahead_with_env
from exciting_exciting_systems.models.model_training import ModelTrainer
from exciting_exciting_systems.excitation import loss_function, Exciter

from exciting_exciting_systems.utils.density_estimation import (
    update_density_estimate_single_observation, update_density_estimate_multiple_observations, DensityEstimate
)
from exciting_exciting_systems.utils.signals import aprbs
from exciting_exciting_systems.evaluation.plotting_utils import (
    plot_sequence, append_predictions_to_sequence_plot, plot_sequence_and_prediction, plot_model_performance
)
from exciting_exciting_systems.evaluation.experiment_utils import (
    get_experiment_ids, load_experiment_results, quick_eval_pendulum, evaluate_experiment_metrics, evaluate_algorithm_metrics, evaluate_metrics
)

---

In [None]:
def identity(x):
    """No-op transform: hand back the argument itself, untouched."""
    return x

def featurize_theta(obs_action):
    """Swap the leading angle feature for its sine and cosine.

    A raw angle is awkward to compare in the loss: values on opposite sides of
    the wrap-around point (e.g. 1.99 * pi and 0) describe nearly the same state
    but are numerically far apart. Encoding the angle as (sin, cos) removes
    that discontinuity.

    Args:
        obs_action: Array whose last axis holds the features. Entry 0 of that
            axis is the angle expressed in units of pi (it is multiplied by
            ``np.pi`` before taking sin/cos); the remaining entries are kept.

    Returns:
        A numpy array with the same leading shape and one extra entry on the
        last axis: ``[sin(phi), cos(phi), *obs_action[..., 1:]]`` where
        ``phi = obs_action[..., 0] * pi``.
    """
    phi = obs_action[..., 0] * np.pi
    trig_features = np.stack([np.sin(phi), np.cos(phi)], axis=-1)
    return np.concatenate([trig_features, obs_action[..., 1:]], axis=-1)

## DMPE quick experiment evaluation

In [None]:
# Pendulum environment, explicit Euler solver, torque constraint of 5
# (tau is presumably the sampling/step time of the environment).
# NOTE(review): the FluidTank cell right after this one reassigns `env`, `batch_size`
# and `tau`, so this pendulum environment is not what the evaluation cells below use
# when the notebook is run top to bottom.
batch_size = 1
tau = 2e-2

env = excenvs.make(
    env_id="Pendulum-v0",
    batch_size=batch_size,
    tau=tau,
    solver=diffrax.Euler(),
    static_params=dict(g=9.81, l=1, m=1),
    action_constraints=dict(torque=5),
)

In [None]:
# FluidTank environment with an explicit Euler solver; this is the `env` used by the
# evaluation cells below.
# NOTE(review): `batch_size` is set here but, unlike in the pendulum cell, it is not
# passed to `excenvs.make` — confirm the environment's default batch size is intended.
batch_size = 1
tau = 5e-1

env = excenvs.make("FluidTank-v0", tau=tau, solver=diffrax.Euler())

In [None]:
# Quick-look evaluation of every experiment stored under perfect_model_dmpe/fluid_tank/15k.
results_path = pathlib.Path(
    "/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/perfect_model_dmpe/fluid_tank/15k"
)

for exp_idx, identifier in enumerate(get_experiment_ids(results_path)):
    # Running index, printed so each experiment's evaluation output can be told apart.
    # NOTE(review): `quick_eval_pendulum` is applied to the FluidTank env here; confirm it
    # is environment-agnostic despite its name.
    print(exp_idx)
    quick_eval_pendulum(env, identifier, results_path, None)

In [None]:
# Quick-look evaluation of every experiment stored under perfect_model_dmpe/fluid_tank.
results_path = pathlib.Path(
    "/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/perfect_model_dmpe/fluid_tank"
)

for exp_idx, identifier in enumerate(get_experiment_ids(results_path)):
    # Running index, printed so each experiment's evaluation output can be told apart.
    print(exp_idx)
    quick_eval_pendulum(env, identifier, results_path, None)

In [None]:
# Quick-look evaluation of every experiment stored under dmpe/fluid_tank.
# The `results_path` assigned here is also the one read by the loading cell further down.
results_path = pathlib.Path(
    "/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/dmpe/fluid_tank"
)

for exp_idx, identifier in enumerate(get_experiment_ids(results_path)):
    # Running index, printed so each experiment's evaluation output can be told apart.
    print(exp_idx)
    quick_eval_pendulum(env, identifier, results_path, None)

---

In [None]:
# Load the first experiment found under `results_path` for a closer look.
# NOTE: `results_path` is whatever the most recently executed cell assigned; in
# top-to-bottom order that is the dmpe/fluid_tank cell above.
experiment_ids = get_experiment_ids(results_path)
if len(experiment_ids) == 0:
    # Fail with a readable message instead of a bare IndexError from `[0]`.
    raise FileNotFoundError(f"No experiment results found in {results_path}")
params, observations, actions, model = load_experiment_results(experiment_ids[0], results_path, None)

In [None]:
# Time-series view of the loaded observations and actions
# (env.tau is passed, presumably as the sampling time for the time axis).
plot_sequence(
    observations,
    actions,
    env.tau,
    env.obs_description,
    env.action_description,
)

In [None]:
# Kernel density estimate of the recorded data with a narrow bandwidth (0.01), drawn as
# a contour plot over the first observation and the first action dimension.
density_est = DensityEstimate.from_dataset(observations, actions, points_per_dim=50, bandwidth=0.01)
exciting_exciting_systems.evaluation.plotting_utils.plot_2d_kde_as_contourf(
    density_est.p, density_est.x_g, [env.obs_description[0], env.action_description[0]]
)
# Overlay the points of `x_g` (presumably the grid on which `p` is evaluated) to show
# its resolution. The trailing `;` suppresses the PathCollection repr that
# plt.scatter would otherwise print below the figure.
plt.scatter(density_est.x_g[:, 0], density_est.x_g[:, 1], s=1);

In [None]:
# KDE contour plot of the same data with a five times wider bandwidth (0.05 instead
# of 0.01), i.e. a smoother density estimate; no grid-point overlay.
density_est = DensityEstimate.from_dataset(
    observations,
    actions,
    points_per_dim=50,
    bandwidth=0.05,
)
exciting_exciting_systems.evaluation.plotting_utils.plot_2d_kde_as_contourf(
    density_est.p,
    density_est.x_g,
    [env.obs_description[0], env.action_description[0]],
)

## GOATS quick experiment evaluation

In [None]:
# Quick-look evaluation of every experiment stored under sgoats/fluid_tank.
results_path = pathlib.Path("/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/sgoats/fluid_tank")

# Print the running index before each evaluation (as the DMPE cells do) so the
# output of consecutive experiments can be told apart.
for exp_idx, identifier in enumerate(get_experiment_ids(results_path)):
    print(exp_idx)
    quick_eval_pendulum(env, identifier, results_path, None)

In [None]:
# Quick-look evaluation of every experiment stored under igoats/fluid_tank.
results_path = pathlib.Path("/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/igoats/fluid_tank")

# Print the running index before each evaluation (as the DMPE cells do) so the
# output of consecutive experiments can be told apart.
for exp_idx, identifier in enumerate(get_experiment_ids(results_path)):
    print(exp_idx)
    quick_eval_pendulum(env, identifier, results_path, None)