In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import pathlib
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['text.usetex'] = True
mpl.rcParams['text.latex.preamble']=r"\usepackage{bm}"
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp
# jax.config.update("jax_enable_x64", True)
gpus = jax.devices()
jax.config.update("jax_default_device", gpus[0])

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs

import dmpe
from dmpe.models import NeuralEulerODEPendulum, NeuralODEPendulum, NeuralEulerODE, NeuralEulerODECartpole
from dmpe.models.model_utils import simulate_ahead_with_env
from dmpe.models.model_training import ModelTrainer
from dmpe.excitation import loss_function, Exciter

from dmpe.utils.density_estimation import (
    update_density_estimate_single_observation, update_density_estimate_multiple_observations, DensityEstimate
)
from dmpe.utils.signals import aprbs
from dmpe.evaluation.plotting_utils import (
    plot_sequence, append_predictions_to_sequence_plot, plot_sequence_and_prediction, plot_model_performance
)
from dmpe.evaluation.experiment_utils import (
    get_experiment_ids, load_experiment_results, quick_eval, evaluate_experiment_metrics, evaluate_algorithm_metrics, evaluate_metrics
)
from dmpe.utils.density_estimation import select_bandwidth
from dmpe.evaluation.experiment_utils import default_jsd, default_ae, default_mcudsa, default_ksfc

---

In [None]:
def identity(x):
    """Return ``x`` unchanged."""
    return x

def featurize_theta(obs_action):
    """Replace the leading angle entry by its sine and cosine.

    A raw angle is awkward to compare in a loss because values such as
    1.99 * pi and 0 describe almost the same orientation. The first entry
    along the last axis is therefore multiplied by pi and mapped to
    (sin, cos); every remaining entry is passed through untouched.

    Args:
        obs_action: Array that supports ``obs_action[..., i]`` indexing. Entry 0
            of the last axis is treated as the angle.

    Returns:
        NumPy array whose last axis is ``[sin, cos, *obs_action[..., 1:]]``,
        i.e. one element longer than the last axis of the input.
    """
    angle = obs_action[..., 0] * np.pi
    trig_features = np.stack((np.sin(angle), np.cos(angle)), axis=-1)
    return np.concatenate((trig_features, obs_action[..., 1:]), axis=-1)

In [None]:
# Fluid tank environment. This cell rebinds the notebook-level names
# `batch_size`, `tau`, `env_params` and `env`.
batch_size = 1
tau = 5

# All settings are collected in one dict and forwarded to `excenvs.make` below.
env_params = dict(
    batch_size=batch_size,
    tau=tau,
    max_height=3,
    max_inflow=0.2,
    base_area=jnp.pi,
    orifice_area=jnp.pi * 0.1**2,
    c_d=0.6,
    g=9.81,
    env_solver=diffrax.Tsit5(),
)
# NOTE(review): `env_params["batch_size"]` is not passed to `excenvs.make` in
# this cell - confirm that this is intended.
env = excenvs.make(
    "FluidTank-v0",
    physical_constraints=dict(height=env_params["max_height"]),
    action_constraints=dict(inflow=env_params["max_inflow"]),
    static_params=dict(
        base_area=env_params["base_area"],
        orifice_area=env_params["orifice_area"],
        c_d=env_params["c_d"],
        g=env_params["g"],
    ),
    tau=env_params["tau"],
    solver=env_params["env_solver"],
)

In [None]:
# Pendulum environment. This cell rebinds the notebook-level names
# `batch_size`, `tau` and `env`.
batch_size = 1
tau = 2e-2

env = excenvs.make(
    env_id='Pendulum-v0',
    batch_size=batch_size,
    action_constraints={"torque": 5},
    static_params={"g": 9.81, "l": 1, "m": 1},
    solver=diffrax.Tsit5(),
    tau=tau,
)

In [None]:
# Cart-pole environment. This cell rebinds the notebook-level names
# `env_params` and `env`; `batch_size` and `tau` are only set inside the dict.
env_params = dict(
    batch_size=1,
    tau=2e-2,
    max_force=10,
    static_params={
        "mu_p": 0.002,
        "mu_c": 0.5,
        "l": 0.5,
        "m_p": 0.1,
        "m_c": 1,
        "g": 9.81,
    },
    physical_constraints={
        "deflection": 2.4,
        "velocity": 8,
        "theta": jnp.pi,
        "omega": 8,
    },
    env_solver=diffrax.Tsit5(),
)
# Every entry of `env_params` is forwarded to the environment constructor.
env = excenvs.make(
    env_id="CartPole-v0",
    batch_size=env_params["batch_size"],
    action_constraints={"force": env_params["max_force"]},
    physical_constraints=env_params["physical_constraints"],
    static_params=env_params["static_params"],
    solver=env_params["env_solver"],
    tau=env_params["tau"],
)

## Quick experiment evaluation

In [None]:
# For every (algorithm, environment) result folder, print whether it contains
# exactly 30 experiment ids.
# The inner loop variable is named `env_name` rather than `env` so that it does
# not overwrite the environment object `env`, which other cells pass to
# `quick_eval` and read `env.tau` from.
for algo in ["dmpe", "perfect_model_dmpe", "igoats", "sgoats"]:
    for env_name in ["fluid_tank", "pendulum", "cart_pole"]:
        results_path = pathlib.Path(f"/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/{algo}/{env_name}")
        print(algo, env_name, ":", len(get_experiment_ids(results_path)) == 30)

In [None]:
# Compare the experiments stored in two result folders pairwise.
results_path_1 = pathlib.Path("/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/igoats/fluid_tank/old")
results_path_2 = pathlib.Path("/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/igoats/fluid_tank/")

# Look the ids up once per folder. The number of comparisons is derived from
# the two folders that are actually compared, not from whatever `results_path`
# happens to hold from an earlier cell.
experiment_ids_1 = get_experiment_ids(results_path_1)
experiment_ids_2 = get_experiment_ids(results_path_2)
if len(experiment_ids_1) != len(experiment_ids_2):
    print(
        "Warning: folders contain a different number of experiments:",
        len(experiment_ids_1),
        "vs",
        len(experiment_ids_2),
    )

for experiment_id_1, experiment_id_2 in zip(experiment_ids_1, experiment_ids_2):
    params_1, observations_1, actions_1, _ = load_experiment_results(experiment_id_1, results_path_1, None)
    params_2, observations_2, actions_2, _ = load_experiment_results(experiment_id_2, results_path_2, None)

    # The last observation of the first dataset is dropped before comparing.
    assert jnp.all(observations_1[:-1] == observations_2)
    assert jnp.all(actions_1 == actions_2)

print("Datasets are equal!")

In [None]:
params_2, observations_2, actions_2, _ = load_experiment_results(get_experiment_ids(results_path)[30], results_path, None, to_array=False)

In [None]:
for obs in observations_2:
    print(len(obs))

In [None]:
len(get_experiment_ids(results_path))

In [None]:
# TODO: unfinished cell - the original content was only an incomplete `for`
# statement, which is a SyntaxError and stops "Run All".

In [None]:
# for i in range(-30, 0):
#     params, _, _, _ = load_experiment_results(get_experiment_ids(results_path)[i], results_path, None)
#     print(params["seed"])

In [None]:
from dmpe.evaluation.experiment_utils import default_jsd, default_ae, default_mcudsa, default_ksfc

In [None]:
test_loss = partial(default_ksfc, points_per_dim=20, eps=0.01)

In [None]:
params, observations, actions, _ = load_experiment_results(get_experiment_ids(results_path)[0], results_path, None)

In [None]:
test_loss(observations, actions)

In [None]:
# Stop marker (presumably deliberate, to halt "Run All" at this point). A bare
# `raise` outside an `except` block also ends in a RuntimeError, but with the
# unrelated message "No active exception to reraise"; raising it explicitly
# keeps the exception type and states the intent.
raise RuntimeError("Execution stopped on purpose at this cell.")

In [None]:
results_path = pathlib.Path("/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/sgoats/pendulum/test_N")

for exp_idx, identifier in enumerate(get_experiment_ids(results_path)[-30:]):
    print(exp_idx)
    quick_eval(env, identifier, results_path, None)

In [None]:
params, observations, actions, model = load_experiment_results(get_experiment_ids(results_path)[93], results_path, None)

In [None]:
params

In [None]:
plot_sequence(observations[:3000], actions[:2999], env.tau, env.obs_description, env.action_description)

In [None]:
from dmpe.utils.density_estimation import build_grid
from dmpe.utils.metrics import kiss_space_filling_cost

In [None]:
from dmpe.evaluation.experiment_utils import default_jsd

In [None]:
cov_factor = 0.05

In [None]:
# Evaluate the space-filling cost for several reference data sets.
# `support_points` has to exist before the loop; it was previously only defined
# in later cells, so this cell failed with a NameError on a fresh kernel. The
# grid matches the one those later cells build.
support_points = build_grid(2, -1, 1, 100)

for data_points in [
    build_grid(2, -1, 1, 100),
    jnp.concatenate([observations[:3000],actions[:3000]], axis=-1),
    np.random.uniform(-1, 1, size=(5000, 2)),
    np.random.uniform(-0.7, 0.7, size=(5000, 2)),
    jnp.ones((5000, 2)),
    np.random.normal(0, 0.1, size=(5000, 2)),
    np.random.normal(0, 0.1, size=(10000, 2)),
    np.zeros((5000, 2)),
    np.zeros((10000, 2))
]:
    value = kiss_space_filling_cost(data_points=data_points, support_points=support_points, covariance_matrix=jnp.eye(2) * cov_factor)
    print(value)

In [None]:
%debug

In [None]:
data_points = build_grid(2, -1, 1, 100)
support_points = build_grid(2, -1, 1, 100)
kiss_space_filling_cost(data_points=data_points, support_points=support_points, covariance_matrix=jnp.eye(2) * cov_factor)

In [None]:
data_points = jnp.concatenate([observations[:3000], actions[:3000]], axis=-1)
support_points = build_grid(2, -1, 1, 100)
kiss_space_filling_cost(data_points=data_points, support_points=support_points, covariance_matrix=jnp.eye(2) * cov_factor)

In [None]:
data_points = np.random.uniform(-1, 1, size=(5000, 2))
support_points = build_grid(2, -1, 1, 100)
kiss_space_filling_cost(data_points=data_points, support_points=support_points, covariance_matrix=jnp.eye(2) * cov_factor)

In [None]:
data_points = np.random.uniform(-0.7, 0.7, size=(5000, 2))
support_points = build_grid(2, -1, 1, 100)
kiss_space_filling_cost(data_points=data_points, support_points=support_points, covariance_matrix=jnp.eye(2) * cov_factor)

In [None]:
data_points = jnp.concatenate([observations[:1000], actions[:1000]], axis=-1)
support_points = build_grid(2, -1, 1, 100)
data_points = jnp.concatenate([data_points, jnp.ones((2000, 2))], axis=0)
kiss_space_filling_cost(data_points=data_points, support_points=support_points, covariance_matrix=jnp.eye(2) * cov_factor)

In [None]:
support_points = build_grid(2, -1, 1, 100)
data_points = jnp.ones((100, 2))
kiss_space_filling_cost(data_points=data_points, support_points=support_points, covariance_matrix=jnp.eye(2) * cov_factor)

In [None]:
data_points.shape

In [None]:
%debug

In [None]:
# TODO: unfinished cell - the original content was only `datapoints =` without
# a right-hand side, which is a SyntaxError and stops "Run All".

In [None]:
# Plot a 1000-step window of the first four observation dimensions and of the
# first action dimension, one figure per signal.
a = 3000
b = a + 1000
window = slice(a, b)

for obs_dim in range(4):
    plt.plot(observations[window, obs_dim])
    plt.show()

plt.plot(actions[window, 0])
plt.show()

- How do you even evaluate the coverage in 5D?

In [None]:
# Run the quick evaluation for every experiment in the DMPE pendulum folder.
results_path = pathlib.Path("/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/dmpe/pendulum")

for exp_idx, identifier in enumerate(get_experiment_ids(results_path)):
    print(exp_idx)
    # `quick_eval_pendulum` is neither imported nor defined in this notebook;
    # `quick_eval` is the imported helper that the other cells call with the
    # same arguments.
    quick_eval(env, identifier, results_path, None)

---

In [None]:
params, observations, actions, model = load_experiment_results(get_experiment_ids(results_path)[0], results_path, None)

In [None]:
plot_sequence(observations, actions, env.tau, env.obs_description, env.action_description)

In [None]:
density_est = DensityEstimate.from_dataset(observations, actions, points_per_dim=50, bandwidth=0.01)
dmpe.evaluation.plotting_utils.plot_2d_kde_as_contourf(
    density_est.p, density_est.x_g, [env.obs_description[0], env.action_description[0]]
)
plt.scatter(density_est.x_g[:, 0], density_est.x_g[:, 1], s=1)

In [None]:
density_est = DensityEstimate.from_dataset(observations, actions, points_per_dim=50, bandwidth=0.05)
dmpe.evaluation.plotting_utils.plot_2d_kde_as_contourf(
    density_est.p, density_est.x_g, [env.obs_description[0], env.action_description[0]]
)

## GOATS quick experiment eval:

In [None]:
results_path = pathlib.Path("/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/sgoats/fluid_tank/")

for idx, identifier in enumerate(get_experiment_ids(results_path)):
    print(idx)
    quick_eval(env, identifier, results_path, None)

In [None]:
20 * 20 * tau

In [None]:
20 * 100 * tau

In [None]:
tau

In [None]:
params, observations, actions, model = load_experiment_results(get_experiment_ids(results_path)[1], results_path, None)
observations.shape

In [None]:
# Run the quick evaluation for the last experiment in the iGOATS fluid tank folder.
results_path = pathlib.Path("/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/igoats/fluid_tank/")

for idx, identifier in enumerate(get_experiment_ids(results_path)[-1:]):
    print(idx)
    # `quick_eval_pendulum` is neither imported nor defined in this notebook;
    # `quick_eval` is the imported helper that the other cells call with the
    # same arguments.
    quick_eval(env, identifier, results_path, None)

## Check out the difference in support points in the metric computation for sGOATS

- Check out the full NumPy implementation. Is the jitting maybe a problem?

In [None]:
from dmpe.evaluation.experiment_utils import extract_metrics_over_timesteps, evaluate_experiment_metrics

In [None]:
results_path = pathlib.Path("/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/sgoats/cart_pole")

# for idx, identifier in enumerate(get_experiment_ids(results_path)[-5:]):
#     print(idx)
#     quick_eval(env, identifier, results_path, None)

In [None]:
params, observations, actions, _ = load_experiment_results(get_experiment_ids(results_path)[7], results_path, None, to_array=False)

In [None]:
len(observations)

In [None]:
# Time the NumPy re-implementation of the metrics.
# The NumPy variants are imported under aliases so that they do not overwrite
# the JAX-based `default_ae` / `default_mcudsa` / `default_ksfc` that other
# cells of this notebook use under the same names.
from dmpe.related_work.np_reimpl.metric_utils import (
    default_ae as np_default_ae,
    default_mcudsa as np_default_mcudsa,
    default_ksfc as np_default_ksfc,
)

# perf_counter is a monotonic, high-resolution clock meant for measuring durations.
start = time.perf_counter()

evaluate_experiment_metrics(
    np.concatenate(observations),
    np.concatenate(actions),
    metrics={
        # "jsd": partial(default_jsd, points_per_dim=20, bandwidth=select_bandwidth(2, 5, 20, 0.1)),
        "ae": np_default_ae,
        "mcudsa": partial(np_default_mcudsa, points_per_dim=20),
        "ksfc": partial(np_default_ksfc, points_per_dim=20, variance=0.1, eps=1e-6),
    }
)

end = time.perf_counter()
print("Computation time numpy:", end - start)

In [None]:
# Time the JAX implementation of the metrics on the same data.
from dmpe.evaluation.experiment_utils import default_jsd, default_ae, default_mcudsa, default_ksfc

# perf_counter is a monotonic, high-resolution clock meant for measuring durations.
start = time.perf_counter()

# JAX dispatches work asynchronously, so a call can return before the result
# has been computed. Blocking on the returned pytree makes sure the measured
# time covers the actual computation (non-array leaves are passed through).
# The first execution also includes any compilation time.
jax.block_until_ready(
    evaluate_experiment_metrics(
        jnp.array(np.concatenate(observations)),
        jnp.array(np.concatenate(actions)),
        metrics={
            "jsd": partial(default_jsd, points_per_dim=20, bandwidth=select_bandwidth(2, 5, 20, 0.1)),
            "ae": default_ae,
            "mcudsa": partial(default_mcudsa, points_per_dim=20),
            "ksfc": partial(default_ksfc, points_per_dim=20, variance=0.1, eps=1e-6),
        }
    )
)

end = time.perf_counter()
print("Computation time jax:", end - start)

In [None]:
from dmpe.evaluation.experiment_utils import default_jsd, default_ae, default_mcudsa, default_ksfc

In [None]:
# Compute the metrics at 15 sequence lengths (1000 ... 15000) for the first
# experiment in `results_path` and report the elapsed time.
# perf_counter is a monotonic, high-resolution clock meant for measuring durations.
start = time.perf_counter()

lengths = jnp.linspace(1000, 15000, 15, dtype=jnp.int32)
# JAX dispatches work asynchronously; block on the returned pytree so that the
# measured time covers the computation. `block_until_ready` returns the pytree
# itself, so `results_set_dist` holds the same values as before.
results_set_dist = jax.block_until_ready(
    extract_metrics_over_timesteps(
        experiment_ids=get_experiment_ids(results_path)[:1],
        results_path=results_path,
        lengths=lengths,
        metrics={
            # "jsd": partial(default_jsd, points_per_dim=20, bandwidth=select_bandwidth(2, 5, 20, 0.1)),
            "ae": default_ae,
            "mcudsa": partial(default_mcudsa, points_per_dim=20),
            "ksfc": partial(default_ksfc, points_per_dim=20, variance=0.1, eps=1e-6),
        }
    )
)

end = time.perf_counter()
print("Computation time jax:", end - start)

In [None]:
# Load the first experiment as a list of sub-sequences and compute the
# cumulative number of observations after each sub-sequence except the last.
params, observations, actions, _ = load_experiment_results(
    get_experiment_ids(results_path)[0], results_path, None, to_array=False
)
raw_lengths = np.cumsum([len(subsequence) for subsequence in observations][:-1])
raw_lengths

In [None]:
# Same metric extraction as for `results_set_dist`, but including the JSD
# metric; the elapsed time is reported.
# perf_counter is a monotonic, high-resolution clock meant for measuring durations.
start = time.perf_counter()

# JAX dispatches work asynchronously; block on the returned pytree so that the
# measured time covers the computation. `block_until_ready` returns the pytree
# itself, so `results_nonset_dist` holds the same values as before.
results_nonset_dist = jax.block_until_ready(
    extract_metrics_over_timesteps(
        experiment_ids=get_experiment_ids(results_path)[:1],
        results_path=results_path,
        lengths=lengths,
        metrics={
            "jsd": partial(default_jsd, points_per_dim=20, bandwidth=select_bandwidth(2, 5, 20, 0.1)),
            "ae": default_ae,
            "mcudsa": partial(default_mcudsa, points_per_dim=20),
            "ksfc": partial(default_ksfc, points_per_dim=20, variance=0.1, eps=1e-6),
        }
    )
)

end = time.perf_counter()
print("Computation time jax:", end - start)

In [None]:
lengths

In [None]:
raw_results = results_nonset_dist["jsd"][0]

In [None]:
interpolated_results = jnp.interp(
    x=lengths,
    xp=raw_lengths,
    fp=results_nonset_dist["jsd"][0],
)

In [None]:
plt.plot(lengths, interpolated_results, 'r.')
plt.plot(raw_lengths, raw_results, 'b.')

In [None]:
def plot_metrics_by_sequence_length_for_all_algos(data_per_algo, lengths, algo_names, use_log=False):
    """Plot mean +/- standard deviation of each metric over the sequence length.

    One subplot is created per metric and one line (plus shaded band) per
    algorithm.

    Args:
        data_per_algo: Sequence with one mapping ``metric name -> array`` per
            algorithm. Each array is reduced over axis 0 (NaNs are ignored) and
            the result is plotted against the matching entry of ``lengths``.
            The metric names of the first entry decide which subplots are
            drawn.
        lengths: Sequence with one array of x-values per algorithm.
        algo_names: Sequence with one legend label per algorithm.
        use_log: If True, the logarithm of the metric values is taken before
            the mean and standard deviation are computed, and the y-labels are
            marked accordingly.

    Returns:
        The created matplotlib figure.

    Raises:
        AssertionError: If the numbers of algorithms, length arrays and names
            do not match, or if the first entry contains no metrics.
    """
    assert len(data_per_algo) == len(algo_names), "Mismatch in number of algo results and number of algo names"
    # `zip` below would silently drop algorithms if `lengths` were shorter.
    assert len(lengths) == len(algo_names), "Mismatch in number of length arrays and number of algo names"

    metric_keys = list(data_per_algo[0].keys())
    n_metrics = len(metric_keys)
    assert n_metrics > 0, "No metrics to plot"

    # One axis per metric instead of a fixed count of four. `squeeze=False`
    # always yields an array of axes, also for a single metric. The height
    # scales so that four metrics give the previous (19, 18) figure size.
    fig, axs = plt.subplots(n_metrics, figsize=(19, 4.5 * n_metrics), sharex=True, squeeze=False)
    axs = axs.ravel()
    colors = plt.rcParams["axes.prop_cycle"]()

    for length, algo_name, data in zip(lengths, algo_names, data_per_algo):
        c = next(colors)["color"]

        for ax, metric_key in zip(axs, metric_keys):
            values = jnp.log(data[metric_key]) if use_log else data[metric_key]
            mean = jnp.nanmean(values, axis=0)
            std = jnp.nanstd(values, axis=0)

            ax.plot(length, mean, label=algo_name, color=c)
            ax.fill_between(length, mean - std, mean + std, color=c, alpha=0.1)

    # The y-label is set once per axis. Previously a plain-text "log ..." label
    # was set first and then always overwritten by the LaTeX label, so the log
    # scaling never showed up in the figure. Raw strings avoid invalid escape
    # sequences such as "\m".
    log_prefix = r"\log " if use_log else ""
    for ax, metric_key in zip(axs, metric_keys):
        ax.set_ylabel(rf"${log_prefix}\mathcal{{L}}_\mathrm{{{metric_key.upper()}}}$")
        ax.grid(True)

    axs[-1].set_xlabel(r"$\mathrm{timesteps}$")
    # Cover the x-range of every algorithm, not only that of the first one.
    x_min = min(float(length[0]) for length in lengths)
    x_max = max(float(length[-1]) for length in lengths)
    axs[-1].set_xlim(x_min - 100, x_max + 100)
    axs[0].legend()
    plt.tight_layout()

    return fig


In [None]:
plot_metrics_by_sequence_length_for_all_algos(
    [results_set_dist, results_nonset_dist],
    [jnp.linspace(1000, 15000, 15, dtype=jnp.int32), lengths],
    ["set", "nonset"],
    use_log=True
);
plt.savefig("test_interpolation.pdf")

## Check blockwise metrics

In [None]:
from dmpe.utils.metrics import blockwise_ksfc, kiss_space_filling_cost, blockwise_mcudsa, MC_uniform_sampling_distribution_approximation
from dmpe.utils.density_estimation import build_grid

In [None]:
results_path = pathlib.Path(f"/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/dmpe/cart_pole")
params, observations, actions, _ = load_experiment_results(get_experiment_ids(results_path)[0], results_path, None)

In [None]:
observations.shape

In [None]:
support_points = build_grid(4, -1, 1, 14)
support_points.shape

In [None]:
# Compare the blockwise MCUDSA computation with the full computation.
blockwise_value = blockwise_mcudsa(observations, support_points).item()
full_value = MC_uniform_sampling_distribution_approximation(observations, support_points).item()

# Two floating-point results that are accumulated in a different order are
# generally not bit-identical, so compare with a relative tolerance instead of
# `==`. The exact ratio is inspected in the following cell.
bool(np.isclose(blockwise_value, full_value, rtol=1e-5, atol=0.0))

In [None]:
full_value/blockwise_value

In [None]:
# Compare the blockwise KSFC computation with the full computation.
blockwise_value = blockwise_ksfc(observations, support_points, variances=0.1, eps=1e-6).item()
full_value = kiss_space_filling_cost(observations, support_points, variances=0.1, eps=1e-6).item()

# Two floating-point results that are accumulated in a different order are
# generally not bit-identical, so compare with a relative tolerance instead of
# `==`. The exact ratio is inspected in the following cell.
bool(np.isclose(blockwise_value, full_value, rtol=1e-5, atol=0.0))

In [None]:
full_value/blockwise_value

In [None]:
blockwise_value

In [None]:
full_value