In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import pathlib
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['text.usetex'] = True
mpl.rcParams['text.latex.preamble']=r"\usepackage{bm}"
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp
# jax.config.update("jax_enable_x64", True)
gpus = jax.devices()
jax.config.update("jax_default_device", gpus[0])

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs

import exciting_exciting_systems
from exciting_exciting_systems.models import NeuralEulerODEPendulum, NeuralODEPendulum, NeuralEulerODE
from exciting_exciting_systems.models.model_utils import simulate_ahead_with_env
from exciting_exciting_systems.models.model_training import ModelTrainer
from exciting_exciting_systems.excitation import loss_function, Exciter

from exciting_exciting_systems.utils.density_estimation import (
    update_density_estimate_single_observation, update_density_estimate_multiple_observations, DensityEstimate
)
from exciting_exciting_systems.utils.signals import aprbs
from exciting_exciting_systems.evaluation.plotting_utils import (
    plot_sequence, append_predictions_to_sequence_plot, plot_sequence_and_prediction, plot_model_performance
)
from exciting_exciting_systems.evaluation.experiment_utils import (
    get_experiment_ids, load_experiment_results, quick_eval, evaluate_experiment_metrics, evaluate_algorithm_metrics, evaluate_metrics
)

---

In [None]:
def identity(x):
    """Return the argument unchanged.

    Args:
        x: Any object.

    Returns:
        The very same object ``x`` (no copy is made).
    """
    result = x
    return result

def featurize_theta(obs_action):
    """Replace the angle in column 0 by its sine and cosine.

    Raw angles are hard to compare in a loss: e.g. 1.99 * pi and 0 describe almost
    the same orientation but are numerically far apart. Mapping the angle to
    (sin, cos) removes that wrap-around discontinuity.

    Column 0 is multiplied by ``np.pi`` before the trigonometric functions are
    applied, so it is presumably the angle normalized by pi (TODO confirm against
    the environment's observation scaling).

    Args:
        obs_action: Array whose last axis is ``[angle / pi, *remaining_features]``.

    Returns:
        NumPy array whose last axis is ``[sin(angle), cos(angle), *remaining_features]``,
        i.e. one entry longer than the input's last axis.
    """
    angle = np.pi * obs_action[..., 0]
    sin_col = np.expand_dims(np.sin(angle), -1)
    cos_col = np.expand_dims(np.cos(angle), -1)
    return np.concatenate([sin_col, cos_col, obs_action[..., 1:]], axis=-1)

In [None]:
# Presumably fluid-tank parameters (judging by the keys; the fluid_tank result
# directories are evaluated further down).
# NOTE(review): `env_params` is not read by any other cell in this notebook, and
# `batch_size` / `tau` set here are reassigned by the Pendulum cell that follows.
batch_size = 1
tau = 5

env_params = {
    "batch_size": batch_size,
    "tau": tau,
    "max_height": 3,
    "max_inflow": 0.2,
    "base_area": jnp.pi,
    "orifice_area": jnp.pi * 0.1**2,  # area of a circle with radius 0.1
    "c_d": 0.6,
    "g": 9.81,
    "env_solver": diffrax.Tsit5(),
}

In [None]:
# Pendulum environment. This `env` is what every later `quick_eval` / plotting
# cell receives, including the cells that read fluid_tank result directories.
batch_size = 1
tau = 2e-2

env = excenvs.make(
    env_id='Pendulum-v0',
    tau=tau,
    batch_size=batch_size,
    solver=diffrax.Tsit5(),
    static_params=dict(g=9.81, l=1, m=1),
    action_constraints=dict(torque=5),
)

In [None]:
# env_params = dict(
#     batch_size=1,
#     tau=2e-2,
#     max_force=5,
#     static_params={
#         "mu_p": 0,
#         "mu_c": 0,
#         "l": 1,
#         "m_p": 1,
#         "m_c": 1,
#         "g": 9.81,
#     },
#     physical_constraints={
#         "deflection": 1,
#         "velocity": 10,
#         "theta": jnp.pi,
#         "omega": 10,
#     },
#     env_solver=diffrax.Tsit5(),
# )
# env = excenvs.make(
#     env_id="CartPole-v0",
#     batch_size=env_params["batch_size"],
#     action_constraints={"force": env_params["max_force"]},
#     physical_constraints=env_params["physical_constraints"],
#     static_params=env_params["static_params"],
#     solver=env_params["env_solver"],
#     tau=env_params["tau"],
# )

## DMPE: quick experiment evaluation

In [None]:
# NOTE(review): this cell sits under the DMPE header but reads iGOATS fluid_tank
# results, and evaluates them with the Pendulum `env` created above — confirm intended.
results_path = pathlib.Path("/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/igoats/fluid_tank")

# Run the quick evaluation for every experiment id found under results_path.
for exp_idx, identifier in enumerate(get_experiment_ids(results_path)):
    print(exp_idx)  # progress: index of the experiment being evaluated
    quick_eval(env, identifier, results_path, None)

In [None]:
# Number of experiment ids found under the current results_path.
len(get_experiment_ids(results_path))

In [None]:
# Load params, observations, actions and model of the last experiment id returned
# by get_experiment_ids for the current results_path.
params, observations, actions, model = load_experiment_results(get_experiment_ids(results_path)[-1], results_path, None)

In [None]:
# Inspect the shape of the loaded observations.
observations.shape

In [None]:
# Plot a window of the loaded trajectory: one figure per observation dimension,
# followed by one figure for the first action dimension.
# Looping over the observation columns replaces four copy-pasted plot calls and
# also covers observation spaces with a different number of dimensions.
a = 3000  # first step of the window
b = a + 1000  # exclusive end of the window (1000 steps)

for obs_dim in range(observations.shape[1]):
    plt.plot(observations[a:b, obs_dim])
    plt.title(f"observation {obs_dim}, steps {a} to {b}")
    plt.xlabel("step within window")
    plt.show()

plt.plot(actions[a:b, 0])
plt.title(f"action 0, steps {a} to {b}")
plt.xlabel("step within window")
plt.show()

- **TODO:** How should coverage be evaluated in a 5-dimensional space?

In [None]:
# DMPE pendulum results: run the quick evaluation for every stored experiment.
results_path = pathlib.Path("/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/dmpe/pendulum")

for exp_idx, identifier in enumerate(get_experiment_ids(results_path)):
    print(exp_idx)  # progress: index of the experiment being evaluated
    # `quick_eval_pendulum` is neither defined nor imported in this notebook, so the
    # cell raised NameError on a fresh kernel. `quick_eval` is imported from
    # experiment_utils and takes the same arguments in the other evaluation cells.
    quick_eval(env, identifier, results_path, None)

---

In [None]:
# Load the first experiment id (index 0) under the current results_path.
params, observations, actions, model = load_experiment_results(get_experiment_ids(results_path)[0], results_path, None)

In [None]:
# Plot the loaded observation/action sequences, using the env's step size and
# observation/action descriptions for the axes.
plot_sequence(observations, actions, env.tau, env.obs_description, env.action_description)

In [None]:
# Density estimate of the visited (observation, action) points on a grid with 50
# points per dimension, using a narrow bandwidth of 0.01.
density_est = DensityEstimate.from_dataset(observations, actions, points_per_dim=50, bandwidth=0.01)
exciting_exciting_systems.evaluation.plotting_utils.plot_2d_kde_as_contourf(
    density_est.p, density_est.x_g, [env.obs_description[0], env.action_description[0]]
)
# Overlay the grid points x_g on the contour plot.
plt.scatter(density_est.x_g[:, 0], density_est.x_g[:, 1], s=1)
# End the cell with plt.show() so the PathCollection repr returned by scatter is not printed.
plt.show()

In [None]:
# Same density estimate with a wider bandwidth (0.05) for comparison.
density_est = DensityEstimate.from_dataset(observations, actions, points_per_dim=50, bandwidth=0.05)
exciting_exciting_systems.evaluation.plotting_utils.plot_2d_kde_as_contourf(
    density_est.p, density_est.x_g, [env.obs_description[0], env.action_description[0]]
)
# Finish with plt.show() so whatever the plotting helper returns is not echoed as cell output.
plt.show()

## GOATS: quick experiment evaluation

In [None]:
# sGOATS fluid_tank results.
# NOTE(review): these are evaluated with the Pendulum `env` created earlier —
# confirm that is intended for fluid-tank experiments.
results_path = pathlib.Path("/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/sgoats/fluid_tank/")

for idx, identifier in enumerate(get_experiment_ids(results_path)):
    print(idx)  # progress: index of the experiment being evaluated
    quick_eval(env, identifier, results_path, None)

In [None]:
# Scratch arithmetic. When cells run in order, `tau` is 2e-2 here (set by the
# Pendulum cell, which reassigns the fluid-tank value tau = 5).
# NOTE(review): meaning of the 20 * 20 factor is not stated — presumably a step count; confirm.
20 * 20 * tau

In [None]:
# Scratch arithmetic, same as above with 20 * 100 instead of 20 * 20.
# NOTE(review): presumably a step count times the step size `tau`; confirm.
20 * 100 * tau

In [None]:
# Show the step size currently bound to `tau`.
tau

In [None]:
# Load the experiment at index 1 under the current results_path and inspect its observation shape.
params, observations, actions, model = load_experiment_results(get_experiment_ids(results_path)[1], results_path, None)
observations.shape

In [None]:
# iGOATS fluid_tank results: evaluate only the last listed experiment.
# NOTE(review): these are evaluated with the Pendulum `env` — confirm intended.
results_path = pathlib.Path("/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/igoats/fluid_tank/")

for idx, identifier in enumerate(get_experiment_ids(results_path)[-1:]):
    print(idx)
    # `quick_eval_pendulum` is not defined or imported anywhere in this notebook
    # (NameError on a fresh kernel); `quick_eval` takes the same arguments.
    quick_eval(env, identifier, results_path, None)

## Check out the difference in support points in the metric computation for sGOATS

In [None]:
from exciting_exciting_systems.evaluation.experiment_utils import extract_metrics_over_timesteps

In [None]:
# sGOATS pendulum results: run the quick evaluation for every stored experiment.
results_path = pathlib.Path("/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/sgoats/pendulum")

for idx, identifier in enumerate(get_experiment_ids(results_path)):
    print(idx)  # progress: index of the experiment being evaluated
    # `quick_eval_pendulum` is not defined or imported anywhere in this notebook
    # (NameError on a fresh kernel); `quick_eval` takes the same arguments.
    quick_eval(env, identifier, results_path, None)

In [None]:
# Load the first sGOATS pendulum experiment with to_array=False. Presumably this
# keeps the observations as a sequence of subsequences: the cells below take
# len() of each element. The model output is discarded.
params, observations, actions, _ = load_experiment_results(
    get_experiment_ids(results_path)[0],
    results_path,
    None,
    to_array=False,
)

In [None]:
# Number of stored elements in observations (presumably subsequences, since it was loaded with to_array=False).
len(observations)

In [None]:
# Evaluate metrics of the first experiment at 15 evenly spaced sequence lengths
# 1000, 2000, ..., 15000 (int32).
lengths = jnp.linspace(1000, 15000, 15, dtype=jnp.int32)
results_set_dist = extract_metrics_over_timesteps(
    experiment_ids=get_experiment_ids(results_path)[:1],  # only the first experiment
    results_path=results_path,
    lengths=lengths,
)

In [None]:
# Cumulative end index of every stored subsequence except the last one; used as the
# evaluation lengths for the second metric run below.
# NOTE(review): this rebinds `lengths`, replacing the evenly spaced array from the previous cell.
lengths = np.cumsum([len(subsequence) for subsequence in observations][:-1])
lengths

In [None]:
# Show the current evaluation lengths.
lengths

In [None]:
# Same metric extraction for the first experiment, now evaluated at the
# subsequence boundaries currently stored in `lengths`.
results_nonset_dist = extract_metrics_over_timesteps(
    experiment_ids=get_experiment_ids(results_path)[:1],
    results_path=results_path,
    lengths=lengths,
)

In [None]:
def plot_metrics_by_sequence_length_for_all_algos(data_per_algo, lengths, algo_names, use_log=False):
    """Plot mean and standard deviation of every metric over the sequence length, per algorithm.

    One subplot row is created per metric key (taken from the first entry of
    ``data_per_algo``). Each algorithm gets its own color. Its mean is drawn as a line,
    with one standard deviation shaded around it.

    Args:
        data_per_algo: Sequence with one mapping per algorithm from metric name to an
            array of metric values. NaN-aware mean and std are taken over axis 0. That
            axis is presumably the experiment/repetition axis, with the remaining axis
            aligned to the matching entry of ``lengths`` (TODO confirm).
        lengths: Sequence with one 1-D array of x-values (timesteps) per algorithm.
        algo_names: Sequence with one legend label per algorithm.
        use_log: If True, the natural log of the metric values is taken before mean
            and std are computed, and the y-labels get a log prefix.

    Returns:
        The created matplotlib figure.

    Raises:
        AssertionError: If the numbers of result dicts, x-value arrays and names differ.
    """
    assert len(data_per_algo) == len(algo_names), "Mismatch in number of algo results and number of algo names"
    # zip() below would silently drop algorithms if the number of x-value arrays differed.
    assert len(lengths) == len(algo_names), "Mismatch in number of length arrays and number of algo names"

    metric_keys = list(data_per_algo[0].keys())

    # One row per metric (previously hard-coded to 3). squeeze=False keeps `axs`
    # indexable for a single metric; max(..., 1) avoids plt.subplots(0).
    n_rows = max(len(metric_keys), 1)
    fig, axs = plt.subplots(n_rows, figsize=(19, 6 * n_rows), sharex=True, squeeze=False)
    axs = axs[:, 0]
    colors = plt.rcParams["axes.prop_cycle"]()

    for length, algo_name, data in zip(lengths, algo_names, data_per_algo):
        c = next(colors)["color"]

        for ax, metric_key in zip(axs, metric_keys):
            # Transform once so mean and std are computed on the same values.
            values = jnp.log(data[metric_key]) if use_log else data[metric_key]
            mean = jnp.nanmean(values, axis=0)
            std = jnp.nanstd(values, axis=0)

            ax.plot(length, mean, label=algo_name, color=c)
            ax.fill_between(length, mean - std, mean + std, color=c, alpha=0.1)

    # Previously a "log " label was set and then unconditionally overwritten below;
    # the log marker is now part of the final label. Raw f-strings avoid the
    # invalid "\m" escape-sequence warnings.
    log_prefix = r"\log " if use_log else ""
    for ax, metric_key in zip(axs, metric_keys):
        ax.set_ylabel(rf"${log_prefix}\mathcal{{L}}_\mathrm{{{metric_key.upper()}}}$")

    axs[-1].set_xlabel(r"$\mathrm{timesteps}$")
    # Span the x-range of all algorithms rather than only the first one.
    x_min = min(float(length[0]) for length in lengths)
    x_max = max(float(length[-1]) for length in lengths)
    axs[-1].set_xlim(x_min - 100, x_max + 100)
    for ax in axs:
        ax.grid(True)
    axs[0].legend()
    plt.tight_layout()

    return fig


In [None]:
# Compare the metrics computed at the evenly spaced lengths (results_set_dist)
# with those computed at the subsequence boundaries (results_nonset_dist).
# The linspace is rebuilt here because `lengths` now holds the subsequence boundaries.
# The trailing ";" stops the returned Figure from being rendered a second time as
# the cell's output (the inline backend already shows it).
plot_metrics_by_sequence_length_for_all_algos(
    [results_set_dist, results_nonset_dist],
    [jnp.linspace(1000, 15000, 15, dtype=jnp.int32), lengths],
    ["set", "nonset"],
);