In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
import pathlib

from functools import partial
import os
# Both environment variables are set before `import jax` further down in this cell.
# CUDA_VISIBLE_DEVICES='0' limits this process to the CUDA device with index 0.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# According to the JAX GPU memory documentation, "false" turns off the up-front
# preallocation of GPU memory.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import pickle
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import rc
# rc('font',**{'family':'serif','serif':['Helvetica']})
# Typeset all figure text through LaTeX; the preamble below loads the packages
# made available to LaTeX labels.
mpl.rcParams['text.usetex'] = True
# 10 * 2.54 = 25.4.
# NOTE(review): presumably a 10 pt font scaled by the cm-per-inch factor because
# figure sizes are given in centimetres while matplotlib expects inches - confirm.
mpl.rcParams.update({'font.size': 10 * 2.54})
mpl.rcParams['text.latex.preamble']=r"\usepackage{bm}\usepackage{amsmath}"

import jax
import jax.numpy as jnp
import jax_dataclasses as jdc
from jax.tree_util import tree_flatten, tree_unflatten

# jax.config.update('jax_platform_name', 'cpu')

# jax.config.update("jax_debug_nans", True)

# Make the first device reported by jax.devices() the default device.
# NOTE(review): the name `gpus` assumes a GPU backend is active; jax.devices()
# returns the devices of whichever backend JAX selected - confirm on CPU-only hosts.
gpus = jax.devices()
jax.config.update("jax_default_device", gpus[0])

import diffrax

In [None]:
from exciting_exciting_systems.evaluation.experiment_utils import extract_metrics_over_timesteps, extract_metrics_over_timesteps_via_interpolation
from exciting_exciting_systems.evaluation.plotting_utils import plot_metrics_by_sequence_length_for_all_algos
from exciting_exciting_systems.evaluation.experiment_utils import get_experiment_ids

from exciting_exciting_systems.utils.density_estimation import select_bandwidth
from exciting_exciting_systems.evaluation.experiment_utils import default_jsd, default_ae, default_mcudsa, default_ksfc

In [None]:
# Figure widths handed to matplotlib's `figsize`.
# NOTE(review): the values look like centimetres (8.89 cm = 3.5 in) although
# `figsize` is interpreted in inches - confirm the intended unit.
full_column_width = 18.2
half_colmun_width = 8.89  # spelling ("colmun") kept so existing references keep working

def plot_metrics_by_sequence_length_for_all_algos(data_per_algo, lengths, algo_names, use_log=False):
    """Plot mean +/- one standard deviation of each metric over the sequence length.

    One panel is created per metric (stacked vertically, shared x-axis) and one
    line plus shaded band is drawn per algorithm in every panel.

    Note: a function with the same name is imported from
    ``exciting_exciting_systems.evaluation.plotting_utils`` earlier in this
    notebook; this definition replaces it in the notebook namespace.

    Args:
        data_per_algo: Sequence with one mapping per algorithm, each mapping a
            metric key to an array-like. Mean and standard deviation are taken
            over axis 0 with NaNs ignored (presumably axis 0 indexes the
            repeated experiments and the remaining axis matches ``lengths`` -
            TODO confirm). The metric keys of the first entry define the panels,
            so every other entry must provide the same keys.
        lengths: x-axis values; the first and last entries set the x-limits.
        algo_names: Legend labels, one per entry of ``data_per_algo``.
        use_log: If True, the natural logarithm of the metric values is taken
            before computing mean and standard deviation.

    Returns:
        The created matplotlib figure.

    Raises:
        AssertionError: If ``data_per_algo`` and ``algo_names`` differ in length.
        ValueError: If no results or no metrics are given.
    """
    assert len(data_per_algo) == len(algo_names), "Mismatch in number of algo results and number of algo names"
    if len(data_per_algo) == 0:
        raise ValueError("data_per_algo must contain at least one result set")

    metric_keys = list(data_per_algo[0].keys())
    if not metric_keys:
        raise ValueError("The first entry of data_per_algo does not contain any metrics")

    # One panel per metric instead of a fixed count of four; squeeze=False keeps
    # `axs` indexable even when there is only a single metric.
    fig, axs = plt.subplots(len(metric_keys), figsize=(half_colmun_width, 14), sharex=True, squeeze=False)
    axs = axs[:, 0]
    colors = plt.rcParams["axes.prop_cycle"]()

    for algo_name, data in zip(algo_names, data_per_algo):
        c = next(colors)["color"]
        # The colour '#d62728' is never used for an algorithm; take the next one instead.
        if c == '#d62728':
            c = next(colors)["color"]

        for metric_idx, metric_key in enumerate(metric_keys):
            # Transform once, then reuse for both statistics.
            values = jnp.log(data[metric_key]) if use_log else data[metric_key]
            mean = jnp.nanmean(values, axis=0)
            std = jnp.nanstd(values, axis=0)

            axs[metric_idx].plot(
                lengths,
                mean,
                label=algo_name,
                color=c,
            )
            axs[metric_idx].fill_between(
                lengths,
                mean - std,
                mean + std,
                color=c,
                alpha=0.1,
            )

    for idx, metric_key in enumerate(metric_keys):
        # Raw f-string: the backslashes belong to LaTeX and must not be parsed
        # as (invalid) Python escape sequences.
        axs[idx].set_ylabel(rf"$\mathcal{{L}}_\mathrm{{{metric_key.upper()}}}$")

    axs[-1].set_xlabel("$k$")
    axs[-1].set_xlim(lengths[0], lengths[-1])
    for ax in axs:
        ax.grid(True)
    axs[0].legend()
    plt.tight_layout()

    for ax in axs:
        ax.tick_params(axis="y", direction='in')
        ax.tick_params(axis="x", direction='in')

    # fig.align_ylabels(axs)

    # Applied after tight_layout so the reduced vertical gap is not overwritten.
    plt.subplots_adjust(hspace=0.05)

    return fig

In [None]:
# Sequence lengths 1000, 2000, ..., 15000 as int32.
# Built with an integer arange so every entry is exact; an integer grid obtained
# from a floating-point linspace depends on how the float results are rounded.
lengths = jnp.arange(1000, 15001, 1000, dtype=jnp.int32)
lengths

In [None]:
def extract_results(lengths, raw_results_path, algo_names, interpolate_to_lengths, system_name, metrics=None, extra_folders=None):
    """Collect the metrics of several algorithms for one system.

    For every algorithm the results are read from
    ``raw_results_path / algo_name / system_name`` (optionally followed by
    ``extra_folders``) and evaluated with ``extract_metrics_over_timesteps`` or,
    when the matching flag in ``interpolate_to_lengths`` is true, with
    ``extract_metrics_over_timesteps_via_interpolation``.

    Args:
        lengths: Sequence lengths handed to the extraction function (as
            ``lengths`` or ``target_lengths`` respectively).
        raw_results_path: Root directory that contains one folder per algorithm.
        algo_names: Names of the algorithm folders; also the keys of the result.
        interpolate_to_lengths: One boolean per algorithm selecting the
            interpolation-based extraction.
        system_name: Name of the system sub-folder.
        metrics: Passed through unchanged to the extraction function.
        extra_folders: Optional sub-path appended below the system folder.

    Returns:
        Dict mapping each algorithm name to whatever the extraction function
        returned for it.

    Raises:
        ValueError: If ``algo_names`` and ``interpolate_to_lengths`` differ in
            length (``zip`` would otherwise silently drop the surplus entries).
    """
    algo_names = list(algo_names)
    interpolate_to_lengths = list(interpolate_to_lengths)
    if len(algo_names) != len(interpolate_to_lengths):
        raise ValueError(
            f"Got {len(algo_names)} algo names but {len(interpolate_to_lengths)} interpolation flags"
        )

    all_results_by_metric = {}

    for algo_name, use_interpolation in zip(algo_names, interpolate_to_lengths):
        full_results_path = pathlib.Path(raw_results_path) / algo_name / system_name
        if extra_folders is not None:
            full_results_path = full_results_path / extra_folders

        print("Extract results for", algo_name, "\n at", full_results_path)

        # Needed by both extraction variants, so look the ids up once.
        experiment_ids = get_experiment_ids(full_results_path)

        if use_interpolation:
            all_results_by_metric[algo_name] = extract_metrics_over_timesteps_via_interpolation(
                experiment_ids=experiment_ids,
                results_path=full_results_path,
                target_lengths=lengths,
                metrics=metrics,
            )
        else:
            all_results_by_metric[algo_name] = extract_metrics_over_timesteps(
                experiment_ids=experiment_ids,
                results_path=full_results_path,
                lengths=lengths,
                metrics=metrics,
            )
        print("\n")
    return all_results_by_metric

## fluid_tank:

In [None]:
# system_name = "fluid_tank"

# all_fluid_tank_results_by_metric = extract_results(
#     lengths=lengths,
#     raw_results_path=pathlib.Path("/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/"),
#     algo_names=["dmpe", "sgoats", "perfect_model_dmpe", "igoats"],
#     interpolate_to_lengths=[False, True, False, True],
#     system_name=system_name,
#     extra_folders=None,
#     metrics={
#         "jsd": partial(default_jsd, points_per_dim=50, bandwidth=select_bandwidth(2, 2, 50, 0.3).item()),
#         "ae": default_ae,
#         "mcudsa": partial(default_mcudsa, points_per_dim=50),
#         "ksfc": partial(default_ksfc, points_per_dim=50, eps=1e-6),
#     }
# )
# with open("results/fluid_tank_results.pickle", "wb") as handle:
#     pickle.dump(all_fluid_tank_results_by_metric, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Read the previously pickled fluid_tank metrics instead of recomputing them.
# Only unpickle files you created yourself: pickle can execute code on load.
system_name = "fluid_tank"

with pathlib.Path("results/fluid_tank_results.pickle").open('rb') as handle:
    all_fluid_tank_results_by_metric = pickle.load(handle)

# Show which algorithms are contained.
all_fluid_tank_results_by_metric.keys()

In [None]:
# system_name = "fluid_tank"

# test_metrics = extract_results(
#     lengths=lengths,
#     raw_results_path=pathlib.Path("/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/"),
#     algo_names=["sgoats"],
#     system_name=system_name,
#     extra_folders="old",
#     metrics={
#         "jsd": partial(default_jsd, points_per_dim=50, bandwidth=select_bandwidth(2, 2, 50, 0.3)),
#         "ae": default_ae,
#         "mcudsa": partial(default_mcudsa, points_per_dim=50),
#         "ksfc": partial(default_ksfc, points_per_dim=50, eps=1e-6),
#     }
# )

In [None]:
# pm_dmpe_results_by_metric = all_fluid_tank_results_by_metric["perfect_model_dmpe"]
# dmpe_results_by_metric = all_fluid_tank_results_by_metric["dmpe"]
# sgoats_results_by_metric = all_fluid_tank_results_by_metric["sgoats"] 
# igoats_results_by_metric = all_fluid_tank_results_by_metric["igoats"] 

# plot_metrics_by_sequence_length_for_all_algos(
#     data_per_algo=[pm_dmpe_results_by_metric, dmpe_results_by_metric, sgoats_results_by_metric, igoats_results_by_metric, test_metrics["sgoats"]],
#     lengths=lengths,
#     algo_names=["$\mathrm{PM-DMPE}$", "$\mathrm{DMPE}$", "$\mathrm{sGOATS}$", "$\mathrm{iGOATS}$", "$\mathrm{sGOATS}_\mathrm{old}$"],
#     use_log=True,
# );
# plt.savefig(f"N_loss_test_{system_name}.pdf")

In [None]:
# The two DMPE variants are used as stored; for the two GOATS variants the
# entry under the "interp" key is taken.
pm_dmpe_results_by_metric, dmpe_results_by_metric = (
    all_fluid_tank_results_by_metric[algo] for algo in ("perfect_model_dmpe", "dmpe")
)
sgoats_results_by_metric, igoats_results_by_metric = (
    all_fluid_tank_results_by_metric[algo]["interp"] for algo in ("sgoats", "igoats")
)

In [None]:
# Legend labels are raw strings: the backslashes belong to LaTeX, and "\m" in a
# normal string literal is an invalid escape sequence that Python warns about.
plot_metrics_by_sequence_length_for_all_algos(
    data_per_algo=[pm_dmpe_results_by_metric, dmpe_results_by_metric, sgoats_results_by_metric, igoats_results_by_metric],
    lengths=lengths,
    algo_names=[r"$\mathrm{PM-DMPE}$", r"$\mathrm{DMPE}$", r"$\mathrm{sGOATS}$", r"$\mathrm{iGOATS}$"],
    use_log=True,
);
# plt.savefig writes the current pyplot figure to a PDF named after the system.
plt.savefig(f"metrics_per_sequence_length_{system_name}.pdf")

## pendulum:

In [None]:
# system_name = "pendulum"

# all_pendulum_results_by_metric = extract_results(
#     lengths=lengths,
#     raw_results_path=pathlib.Path("/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/"),
#     algo_names=["dmpe", "sgoats", "perfect_model_dmpe", "igoats"],
#     interpolate_to_lengths=[False, True, False, True],
#     system_name=system_name,
#     extra_folders=None,
#     metrics={
#         "jsd": partial(default_jsd, points_per_dim=50, bandwidth=select_bandwidth(2, 3, 50, 0.3)),
#         "ae": default_ae,
#         "mcudsa": partial(default_mcudsa, points_per_dim=50),
#         "ksfc": partial(default_ksfc, points_per_dim=50, eps=1e-6),
#     }
# )
# with open("results/pendulum_results.pickle", "wb") as handle:
#     pickle.dump(all_pendulum_results_by_metric, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Read the previously pickled pendulum metrics instead of recomputing them.
# Only unpickle files you created yourself: pickle can execute code on load.
system_name = "pendulum"

with pathlib.Path("results/pendulum_results.pickle").open('rb') as handle:
    all_pendulum_results_by_metric = pickle.load(handle)

# Show which algorithms are contained.
all_pendulum_results_by_metric.keys()

In [None]:
# system_name = "pendulum"

# test_metrics = extract_results(
#     lengths=lengths,
#     raw_results_path=pathlib.Path("/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/"),
#     algo_names=["sgoats"],
#     system_name=system_name,
#     extra_folders="old",
#     metrics={
#         "jsd": partial(default_jsd, points_per_dim=50, bandwidth=select_bandwidth(2, 3, 50, 0.3)),
#         "ae": default_ae,
#         "mcudsa": partial(default_mcudsa, points_per_dim=50),
#         "ksfc": partial(default_ksfc, points_per_dim=50, eps=1e-6),
#     }
# )

In [None]:
# pm_dmpe_results_by_metric = all_pendulum_results_by_metric["perfect_model_dmpe"]
# dmpe_results_by_metric = all_pendulum_results_by_metric["dmpe"]
# sgoats_results_by_metric = all_pendulum_results_by_metric["sgoats"] 
# igoats_results_by_metric = all_pendulum_results_by_metric["igoats"] 

# plot_metrics_by_sequence_length_for_all_algos(
#     data_per_algo=[pm_dmpe_results_by_metric, dmpe_results_by_metric, sgoats_results_by_metric, igoats_results_by_metric, test_metrics["sgoats"]],
#     lengths=lengths,
#     algo_names=["$\mathrm{PM-DMPE}$", "$\mathrm{DMPE}$", "$\mathrm{sGOATS}$", "$\mathrm{iGOATS}$", "$\mathrm{sGOATS}_\mathrm{old}$"],
#     use_log=True,
# );
# plt.savefig(f"N_loss_test_{system_name}.pdf")

In [None]:
# The two DMPE variants are used as stored; for the two GOATS variants the
# entry under the "interp" key is taken.
pm_dmpe_results_by_metric, dmpe_results_by_metric = (
    all_pendulum_results_by_metric[algo] for algo in ("perfect_model_dmpe", "dmpe")
)
sgoats_results_by_metric, igoats_results_by_metric = (
    all_pendulum_results_by_metric[algo]["interp"] for algo in ("sgoats", "igoats")
)

In [None]:
# Legend labels are raw strings: the backslashes belong to LaTeX, and "\m" in a
# normal string literal is an invalid escape sequence that Python warns about.
plot_metrics_by_sequence_length_for_all_algos(
    data_per_algo=[pm_dmpe_results_by_metric, dmpe_results_by_metric, sgoats_results_by_metric, igoats_results_by_metric],
    lengths=lengths,
    algo_names=[r"$\mathrm{PM-DMPE}$", r"$\mathrm{DMPE}$", r"$\mathrm{sGOATS}$", r"$\mathrm{iGOATS}$"],
    use_log=True,
);
# plt.savefig writes the current pyplot figure to a PDF named after the system.
plt.savefig(f"metrics_per_sequence_length_{system_name}.pdf")

## cart_pole:

In [None]:
# system_name = "cart_pole"

# all_cart_pole_results_by_metric = extract_results(
#     lengths=lengths,
#     raw_results_path=pathlib.Path("/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/"),
#     algo_names=["dmpe", "sgoats", "perfect_model_dmpe", "igoats"],
#     interpolate_to_lengths=[False, True, False, True],
#     system_name=system_name,
#     extra_folders=None,
#     metrics={
#         "jsd": partial(default_jsd, points_per_dim=20, bandwidth=select_bandwidth(2, 5, 20, 0.1)),
#         "ae": default_ae,
#         "mcudsa": partial(default_mcudsa, points_per_dim=20),
#         "ksfc": partial(default_ksfc, points_per_dim=20, variance=0.1, eps=1e-6),
#     }
# )
# with open("results/cart_pole_results.pickle", "wb") as handle:
#     pickle.dump(all_cart_pole_results_by_metric, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Read the previously pickled cart_pole metrics instead of recomputing them.
# Only unpickle files you created yourself: pickle can execute code on load.
system_name = "cart_pole"
with pathlib.Path("results/cart_pole_results.pickle").open('rb') as handle:
    all_cart_pole_results_by_metric = pickle.load(handle)

In [None]:
# system_name = "cart_pole"

# test_metrics = extract_results(
#     lengths=lengths,
#     raw_results_path=pathlib.Path("/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results/"),
#     algo_names=["sgoats"],
#     system_name=system_name,
#     extra_folders="old",
#     metrics={
#         "jsd": partial(default_jsd, points_per_dim=20, bandwidth=select_bandwidth(2, 5, 20, 0.1)),
#         "ae": default_ae,
#         "mcudsa": partial(default_mcudsa, points_per_dim=20),
#         "ksfc": partial(default_ksfc, points_per_dim=20, variance=0.1, eps=1e-6),
#     }
# )

In [None]:
# pm_dmpe_results_by_metric = all_cart_pole_results_by_metric["perfect_model_dmpe"]
# dmpe_results_by_metric = all_cart_pole_results_by_metric["dmpe"]
# sgoats_results_by_metric = all_cart_pole_results_by_metric["sgoats"] 
# igoats_results_by_metric = all_cart_pole_results_by_metric["igoats"] 

# plot_metrics_by_sequence_length_for_all_algos(
#     data_per_algo=[pm_dmpe_results_by_metric, dmpe_results_by_metric, sgoats_results_by_metric, igoats_results_by_metric, test_metrics["sgoats"]],
#     lengths=lengths,
#     algo_names=["$\mathrm{PM-DMPE}$", "$\mathrm{DMPE}$", "$\mathrm{sGOATS}$", "$\mathrm{iGOATS}$", "$\mathrm{sGOATS}_\mathrm{old}$"],
#     use_log=True,
# );
# plt.savefig(f"N_loss_test_{system_name}.pdf")

In [None]:
# Show only which algorithms are contained, as done for the other systems,
# instead of printing the complete nested results.
all_cart_pole_results_by_metric.keys()

In [None]:
# The two DMPE variants are used as stored; for the two GOATS variants the
# entry under the "interp" key is taken.
pm_dmpe_results_by_metric, dmpe_results_by_metric = (
    all_cart_pole_results_by_metric[algo] for algo in ("perfect_model_dmpe", "dmpe")
)
sgoats_results_by_metric, igoats_results_by_metric = (
    all_cart_pole_results_by_metric[algo]["interp"] for algo in ("sgoats", "igoats")
)

In [None]:
# Legend labels are raw strings: the backslashes belong to LaTeX, and "\m" in a
# normal string literal is an invalid escape sequence that Python warns about.
plot_metrics_by_sequence_length_for_all_algos(
    data_per_algo=[pm_dmpe_results_by_metric, dmpe_results_by_metric, sgoats_results_by_metric, igoats_results_by_metric],
    lengths=lengths,
    algo_names=[r"$\mathrm{PM-DMPE}$", r"$\mathrm{DMPE}$", r"$\mathrm{sGOATS}$", r"$\mathrm{iGOATS}$"],
    use_log=True,
);
# plt.savefig writes the current pyplot figure to a PDF named after the system.
plt.savefig(f"metrics_per_sequence_length_{system_name}.pdf")