In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
import os
# Restrict the process to the first GPU and stop XLA from preallocating most of the GPU memory
# up front (allocate on demand instead). These are set before JAX is imported in a later cell
# so that they are in effect when the JAX backend initializes.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import pathlib
from functools import partial

import pickle

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Render all figure text through LaTeX (requires a local LaTeX installation); the plot labels
# in this notebook use LaTeX math such as \mathrm and \mathcal.
mpl.rcParams['text.usetex'] = True
# NOTE(review): 10 * 2.54 = 25.4 pt base font size; presumably a deliberate cm/inch scaling for
# the paper figures — confirm intent.
mpl.rcParams.update({'font.size': 10 * 2.54})
# Extra LaTeX packages available inside figure text (\bm for bold math, amsmath).
mpl.rcParams['text.latex.preamble']=r"\usepackage{bm}\usepackage{amsmath}"
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp
# 64-bit floats stay disabled (JAX then defaults to float32); uncomment to enable them.
# jax.config.update("jax_enable_x64", True)
# NOTE(review): jax.devices() lists the devices of the default backend; this is a GPU list only
# when a GPU backend is available (on a CPU-only machine it holds the CPU device).
gpus = jax.devices()
# Use the first listed device as the default placement for arrays and computations.
jax.config.update("jax_default_device", gpus[0])

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs

import exciting_exciting_systems
from exciting_exciting_systems.models import NeuralEulerODEPendulum, NeuralODEPendulum, NeuralEulerODE, NeuralEulerODECartpole
from exciting_exciting_systems.models.model_utils import simulate_ahead_with_env
from exciting_exciting_systems.models.model_training import ModelTrainer
from exciting_exciting_systems.excitation import loss_function, Exciter

from exciting_exciting_systems.utils.density_estimation import (
    update_density_estimate_single_observation, update_density_estimate_multiple_observations, DensityEstimate, select_bandwidth
)
from exciting_exciting_systems.utils.signals import aprbs
from exciting_exciting_systems.evaluation.plotting_utils import (
    plot_sequence, append_predictions_to_sequence_plot, plot_sequence_and_prediction, plot_model_performance
)
from exciting_exciting_systems.evaluation.experiment_utils import (
    get_experiment_ids, load_experiment_results, quick_eval, evaluate_experiment_metrics, evaluate_algorithm_metrics, evaluate_metrics
)

from exciting_exciting_systems.evaluation.experiment_utils import extract_metrics_over_timesteps

---

- TODO: Check whether most of the code from the quantitative evaluation can be reused here.

In [None]:
full_column_width = 18.2
half_column_width = 8.89
# Backward-compatible alias for the originally misspelled name.
half_colmun_width = half_column_width


def plot_metrics_by_sequence_length_for_all_algos(
    data_per_algo,
    lengths,
    algo_names,
    use_log=False,
    show_legend=False,
    show_zoomed=False,
    zoom_xlim=(7_500, 10_000),
    ylim=(-0.2, 0.2),
):
    """Plot mean +/- one std of a metric over sequence length, one curve per algorithm.

    Only the first metric key of ``data_per_algo[0]`` is plotted, on a single axis.

    Args:
        data_per_algo: Sequence of mappings (one per algorithm) from metric name to an
            array of metric values. Mean and std are taken over axis 0 with NaNs ignored
            (presumably axis 0 indexes experiment runs - confirm against
            ``extract_metrics_over_timesteps``); the remaining axis must match ``lengths``.
        lengths: x-axis values (sequence lengths in timesteps).
        algo_names: Legend label per entry of ``data_per_algo`` (same order).
        use_log: Compute the statistics on ``log(metric)`` instead of the raw metric;
            the y-label is then prefixed with ``\\log``.
        show_legend: Show the legend on the main axis.
        show_zoomed: Add an inset that zooms into ``zoom_xlim`` and mark it on the main axis.
        zoom_xlim: x-range of the inset; only used when ``show_zoomed`` is True.
        ylim: y-range of the main axis, or None to keep matplotlib's autoscaling.

    Returns:
        The created matplotlib ``Figure``.

    Raises:
        AssertionError: If ``data_per_algo`` and ``algo_names`` differ in length.
        IndexError: If ``data_per_algo`` is empty.
    """
    assert len(data_per_algo) == len(algo_names), "Mismatch in number of algo results and number of algo names"

    # Only one axis exists, so only the first metric is plotted and labelled.
    metric_key = list(data_per_algo[0].keys())[0]

    fig, ax = plt.subplots(1, figsize=(half_column_width, 4))
    colors = plt.rcParams["axes.prop_cycle"]()

    inset_ax = ax.inset_axes(bounds=[0.55, 0.03, 0.3, 0.3]) if show_zoomed else None
    target_axes = [ax] if inset_ax is None else [ax, inset_ax]

    for algo_name, data in zip(algo_names, data_per_algo):
        c = next(colors)["color"]
        # Skip the default cycle's red ('#d62728'); presumably reserved elsewhere in the figures.
        if c == '#d62728':
            c = next(colors)["color"]

        values = jnp.log(data[metric_key]) if use_log else data[metric_key]
        mean = jnp.nanmean(values, axis=0)
        std = jnp.nanstd(values, axis=0)

        # The same curve and std band go on the main axis and, if enabled, the zoomed inset.
        for target_ax in target_axes:
            target_ax.plot(lengths, mean, label=algo_name, color=c)
            target_ax.fill_between(lengths, mean - std, mean + std, color=c, alpha=0.1)

    if inset_ax is not None:
        inset_ax.set_xlim(*zoom_xlim)
        inset_ax.set_xticks([])
        inset_ax.ticklabel_format(style='sci', scilimits=(-3, 4), axis='both')
        ax.indicate_inset_zoom(inset_ax)

    # Raw strings: "\m" is an invalid escape in a normal string literal (SyntaxWarning on 3.12+).
    metric_symbol = rf"\mathcal{{L}}_\mathrm{{{metric_key.upper()}}}"
    ax.set_ylabel(rf"$\log {metric_symbol}$" if use_log else f"${metric_symbol}$")
    ax.set_xlabel(r"$\mathrm{timesteps}$")
    ax.set_xlim(lengths[0] - 100, lengths[-1] + 100)
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.grid(True)
    if show_legend:
        ax.legend()
    fig.tight_layout()

    return fig

In [None]:
# Evaluation grid of sequence lengths: 0, 100, ..., 15_000 timesteps (151 points).
# jnp.arange keeps every value an exact multiple of 100. jnp.linspace(..., dtype=jnp.int32)
# computes in float32 and then floors, which can yield one less than intended (e.g. 7799).
lengths = jnp.arange(0, 15_000 + 1, 100, dtype=jnp.int32)
lengths

In [None]:
def extract_results(lengths, raw_results_path, algo_names, system_name, metrics=None, extra_folders=None):
    """Compute metrics over timesteps for every algorithm's experiments on one system.

    For each algorithm the results are read from
    ``raw_results_path / <algo_name> / system_name`` (optionally extended by
    ``extra_folders``) and passed to ``extract_metrics_over_timesteps`` with
    ``slotted=True``.

    Args:
        lengths: Sequence lengths (timesteps) forwarded to ``extract_metrics_over_timesteps``.
        raw_results_path: Path-like root directory of the raw experiment results.
        algo_names: Names of the algorithm sub-directories to process.
        system_name: Name of the system sub-directory, e.g. "fluid_tank".
        metrics: Mapping from metric name to metric function, forwarded unchanged.
        extra_folders: Optional extra sub-path appended below the system directory.

    Returns:
        dict mapping each algorithm name to its ``extract_metrics_over_timesteps`` result.
    """
    results_per_algo = {}

    for name in algo_names:
        results_dir = raw_results_path / pathlib.Path(name) / pathlib.Path(system_name)
        if extra_folders is not None:
            results_dir = results_dir / pathlib.Path(extra_folders)

        print("Extract results for", name, "\n at", results_dir)

        experiment_ids = get_experiment_ids(results_dir)
        results_per_algo[name] = extract_metrics_over_timesteps(
            experiment_ids=experiment_ids,
            results_path=results_dir,
            lengths=lengths,
            metrics=metrics,
            slotted=True,
        )
        print("\n")

    return results_per_algo

## fluid_tank:

In [None]:
from exciting_exciting_systems.excitation.excitation_utils import soft_penalty

In [None]:
def soft_penalty_wrapper(observations, actions, a_max=1, penalty_order=1):
    """Soft-constraint metric ("sc"): ``soft_penalty`` of the observations divided by
    ``observations.shape[0]``.

    Adapts ``soft_penalty`` to a two-argument ``(observations, actions)`` metric signature
    (presumably how ``extract_metrics_over_timesteps`` calls its metrics - confirm).

    Args:
        observations: Array whose axis 0 length is used for normalization
            (presumably one row per timestep).
        actions: Ignored; actions practically cannot violate the constraints.
        a_max: Constraint bound forwarded to ``soft_penalty``.
        penalty_order: Penalty order forwarded to ``soft_penalty``.
    """
    total_penalty = soft_penalty(observations, a_max, penalty_order)
    n_observations = observations.shape[0]
    return total_penalty / n_observations

In [None]:
# Disabled: recompute the soft-constraint ("sc") metric from the raw experiment results and
# cache it as results/fluid_tank_results_constraints.pickle, which the next cell loads.
# Uncomment to regenerate the cache (presumably expensive).
# NOTE(review): raw_results_path is a hard-coded, user-specific absolute path — adjust locally.
# system_name = "fluid_tank"

# all_fluid_tank_results_by_metric = extract_results(
#     lengths=lengths,
#     raw_results_path=pathlib.Path("/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/"),
#     algo_names=["dmpe", "sgoats", "perfect_model_dmpe", "igoats"],
#     system_name=system_name,
#     extra_folders=None,
#     metrics={
#         "sc": soft_penalty_wrapper,
#     }
# )
# with open("results/fluid_tank_results_constraints.pickle", "wb") as handle:
#     pickle.dump(all_fluid_tank_results_by_metric, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Load the cached per-algorithm constraint results (only unpickle trusted, locally produced files).
system_name = "fluid_tank"
with open(f"results/{system_name}_results_constraints.pickle", "rb") as handle:
    all_fluid_tank_results_by_metric = pickle.load(handle)

In [None]:
# Pooled summary of the soft-constraint metric ("sc") per algorithm, NaNs ignored.
for algo, results_for_algo in all_fluid_tank_results_by_metric.items():
    print(algo)
    for label, statistic in (("mean:", np.nanmean), ("std:", np.nanstd)):
        print(label, statistic(results_for_algo["sc"]))
    print("\n")

In [None]:
pm_dmpe_results_by_metric = all_fluid_tank_results_by_metric["perfect_model_dmpe"]
dmpe_results_by_metric = all_fluid_tank_results_by_metric["dmpe"]
sgoats_results_by_metric = all_fluid_tank_results_by_metric["sgoats"]
igoats_results_by_metric = all_fluid_tank_results_by_metric["igoats"]

# Raw strings for the LaTeX labels: "\m" is an invalid escape in a normal string literal.
fig = plot_metrics_by_sequence_length_for_all_algos(
    data_per_algo=[pm_dmpe_results_by_metric, dmpe_results_by_metric, sgoats_results_by_metric, igoats_results_by_metric],
    # NOTE(review): the metric arrays presumably hold one entry per slot between consecutive
    # `lengths` (slotted=True during extraction), hence lengths[1:] — confirm.
    lengths=lengths[1:],
    algo_names=[r"$\mathrm{PM-DMPE}$", r"$\mathrm{DMPE}$", r"$\mathrm{sGOATS}$", r"$\mathrm{iGOATS}$"],
    use_log=False,
    show_legend=True,
)
# Save the returned figure explicitly instead of relying on pyplot's "current figure".
fig.savefig(f"metrics_per_sequence_length_{system_name}_constraints.pdf")

## pendulum:

In [None]:
# Disabled: recompute the soft-constraint ("sc") metric from the raw experiment results and
# cache it as results/pendulum_results_constraints.pickle, which the next cell loads.
# Uncomment to regenerate the cache (presumably expensive).
# NOTE(review): raw_results_path is a hard-coded, user-specific absolute path — adjust locally.
# system_name = "pendulum"

# all_pendulum_results_by_metric = extract_results(
#     lengths=lengths,
#     raw_results_path=pathlib.Path("/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/"),
#     algo_names=["dmpe", "sgoats", "perfect_model_dmpe", "igoats"],
#     system_name=system_name,
#     extra_folders=None,
#     metrics={
#         "sc": soft_penalty_wrapper,
#     }
# )
# with open("results/pendulum_results_constraints.pickle", "wb") as handle:
#     pickle.dump(all_pendulum_results_by_metric, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Load the cached per-algorithm constraint results (only unpickle trusted, locally produced files).
system_name = "pendulum"
with open(f"results/{system_name}_results_constraints.pickle", "rb") as handle:
    all_pendulum_results_by_metric = pickle.load(handle)

In [None]:
# Pooled summary of the soft-constraint metric ("sc") per algorithm, NaNs ignored.
for algo, results_for_algo in all_pendulum_results_by_metric.items():
    print(algo)
    for label, statistic in (("mean:", np.nanmean), ("std:", np.nanstd)):
        print(label, statistic(results_for_algo["sc"]))
    print("\n")

In [None]:
pm_dmpe_results_by_metric = all_pendulum_results_by_metric["perfect_model_dmpe"]
dmpe_results_by_metric = all_pendulum_results_by_metric["dmpe"]
sgoats_results_by_metric = all_pendulum_results_by_metric["sgoats"]
igoats_results_by_metric = all_pendulum_results_by_metric["igoats"]

# Raw strings for the LaTeX labels: "\m" is an invalid escape in a normal string literal.
fig = plot_metrics_by_sequence_length_for_all_algos(
    data_per_algo=[pm_dmpe_results_by_metric, dmpe_results_by_metric, sgoats_results_by_metric, igoats_results_by_metric],
    # NOTE(review): the metric arrays presumably hold one entry per slot between consecutive
    # `lengths` (slotted=True during extraction), hence lengths[1:] — confirm.
    lengths=lengths[1:],
    algo_names=[r"$\mathrm{PM-DMPE}$", r"$\mathrm{DMPE}$", r"$\mathrm{sGOATS}$", r"$\mathrm{iGOATS}$"],
    use_log=False,
    show_zoomed=True,
)
# Save the returned figure explicitly instead of relying on pyplot's "current figure".
fig.savefig(f"metrics_per_sequence_length_{system_name}_constraints.pdf")

## Cart pole:

In [None]:
# Disabled: recompute the soft-constraint ("sc") metric from the raw experiment results and
# cache it as results/cart_pole_results_constraints.pickle, which the next cell loads.
# Uncomment to regenerate the cache (presumably expensive).
# NOTE(review): raw_results_path is a hard-coded, user-specific absolute path — adjust locally.
# system_name = "cart_pole"

# all_cart_pole_results_by_metric = extract_results(
#     lengths=lengths,
#     raw_results_path=pathlib.Path("/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/"),
#     algo_names=["dmpe", "sgoats", "perfect_model_dmpe", "igoats"],
#     system_name=system_name,
#     extra_folders=None,
#     metrics={
#         "sc": soft_penalty_wrapper,
#     }
# )
# with open("results/cart_pole_results_constraints.pickle", "wb") as handle:
#     pickle.dump(all_cart_pole_results_by_metric, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Load the cached per-algorithm constraint results (only unpickle trusted, locally produced files).
system_name = "cart_pole"
with open(f"results/{system_name}_results_constraints.pickle", "rb") as handle:
    all_cart_pole_results_by_metric = pickle.load(handle)

In [None]:
# Pooled summary of the soft-constraint metric ("sc") per algorithm, NaNs ignored.
for algo, results_for_algo in all_cart_pole_results_by_metric.items():
    print(algo)
    for label, statistic in (("mean:", np.nanmean), ("std:", np.nanstd)):
        print(label, statistic(results_for_algo["sc"]))
    print("\n")

In [None]:
pm_dmpe_results_by_metric = all_cart_pole_results_by_metric["perfect_model_dmpe"]
dmpe_results_by_metric = all_cart_pole_results_by_metric["dmpe"]
sgoats_results_by_metric = all_cart_pole_results_by_metric["sgoats"]
igoats_results_by_metric = all_cart_pole_results_by_metric["igoats"]

# Raw strings for the LaTeX labels: "\m" is an invalid escape in a normal string literal.
fig = plot_metrics_by_sequence_length_for_all_algos(
    data_per_algo=[pm_dmpe_results_by_metric, dmpe_results_by_metric, sgoats_results_by_metric, igoats_results_by_metric],
    # NOTE(review): the metric arrays presumably hold one entry per slot between consecutive
    # `lengths` (slotted=True during extraction), hence lengths[1:] — confirm.
    lengths=lengths[1:],
    algo_names=[r"$\mathrm{PM-DMPE}$", r"$\mathrm{DMPE}$", r"$\mathrm{sGOATS}$", r"$\mathrm{iGOATS}$"],
    use_log=False,
)
# Save the returned figure explicitly instead of relying on pyplot's "current figure".
fig.savefig(f"metrics_per_sequence_length_{system_name}_constraints.pdf")

In [None]:
# Echo the currently selected system name (last set to "cart_pole" in the loading cell above).
system_name

- TODO: Build the full (combined) plot in LaTeX.