In [None]:
# Automatically reload edited Python modules (e.g. the exciting_exciting_systems package)
# before executing each cell, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2
# Provides the %lprun line-profiling magic. NOTE(review): not used in the cells shown here.
%reload_ext line_profiler

In [None]:
import pathlib

import os
# Disable XLA's up-front memory preallocation. This must stay above `import jax` so the
# setting is in place before JAX initialises its backend.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import rc
# Alternative font setup (disabled):
# rc('font',**{'family':'serif','serif':['Helvetica']})
# Render all figure text through LaTeX (requires a LaTeX installation on the machine).
mpl.rcParams['text.usetex'] = True
# NOTE(review): 10 * 2.54 = 25.4 pt; the intent of this scaling factor is not stated -- confirm.
mpl.rcParams.update({'font.size': 10 * 2.54})
# Make \bm and the amsmath commands available inside figure labels.
mpl.rcParams['text.latex.preamble']=r"\usepackage{bm}\usepackage{amsmath}"

import jax
import jax.numpy as jnp
import jax_dataclasses as jdc
from jax.tree_util import tree_flatten, tree_unflatten
# NOTE(review): np, rc, jdc, tree_flatten, tree_unflatten and diffrax are not used in the
# cells shown here; remove them if no other cell needs them.

# Force JAX onto the CPU backend.
jax.config.update('jax_platform_name', 'cpu')

# GPU alternative (disabled). NOTE(review): while the CPU platform is forced above,
# jax.devices() presumably returns only CPU devices, so that line would need removing too.
# gpus = jax.devices()
# jax.config.update("jax_default_device", gpus[0])

import diffrax

In [None]:
from exciting_exciting_systems.evaluation.experiment_utils import extract_metrics_over_timesteps
from exciting_exciting_systems.evaluation.plotting_utils import plot_metrics_by_sequence_length_for_all_algos
from exciting_exciting_systems.evaluation.experiment_utils import get_experiment_ids

In [None]:
# Sequence lengths (in timesteps) at which the metrics are evaluated: 500, 1000, ..., 5000.
MIN_SEQUENCE_LENGTH = 500
MAX_SEQUENCE_LENGTH = 5000
SEQUENCE_LENGTH_STEP = 500
# `arange` excludes its stop value; `MAX + 1` includes MAX_SEQUENCE_LENGTH when it lies on the
# step grid, without ever overshooting it.
lengths = jnp.arange(MIN_SEQUENCE_LENGTH, MAX_SEQUENCE_LENGTH + 1, SEQUENCE_LENGTH_STEP)
lengths

In [None]:
# Load the iGOATS results stored under igoats/fluid_tank/run1.
# The results root defaults to the original author's checkout; set the environment variable
# EES_RESULTS_ROOT to load results from another location without editing this cell.
igoats_results_path = pathlib.Path(
    os.environ.get(
        "EES_RESULTS_ROOT",
        "/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results",
    )
) / "igoats" / "fluid_tank" / "run1"
# Fail early with a clear message instead of deep inside the loading helpers.
assert igoats_results_path.is_dir(), f"results directory not found: {igoats_results_path}"
igoats_results_by_metric = extract_metrics_over_timesteps(
    experiment_ids=get_experiment_ids(igoats_results_path),
    results_path=igoats_results_path,
    lengths=lengths,
)

In [None]:
# Load the sGOATS results stored under sgoats/fluid_tank/run1.
# The results root defaults to the original author's checkout; set the environment variable
# EES_RESULTS_ROOT to load results from another location without editing this cell.
sgoats_results_path = pathlib.Path(
    os.environ.get(
        "EES_RESULTS_ROOT",
        "/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results",
    )
) / "sgoats" / "fluid_tank" / "run1"
# Fail early with a clear message instead of deep inside the loading helpers.
assert sgoats_results_path.is_dir(), f"results directory not found: {sgoats_results_path}"
sgoats_results_by_metric = extract_metrics_over_timesteps(
    experiment_ids=get_experiment_ids(sgoats_results_path),
    results_path=sgoats_results_path,
    lengths=lengths,
)

In [None]:
# Load the DMPE results stored under dmpe/fluid_tank.
# The results root defaults to the original author's checkout; set the environment variable
# EES_RESULTS_ROOT to load results from another location without editing this cell.
dmpe_results_path = pathlib.Path(
    os.environ.get(
        "EES_RESULTS_ROOT",
        "/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results",
    )
) / "dmpe" / "fluid_tank"
# Fail early with a clear message instead of deep inside the loading helpers.
assert dmpe_results_path.is_dir(), f"results directory not found: {dmpe_results_path}"
dmpe_results_by_metric = extract_metrics_over_timesteps(
    experiment_ids=get_experiment_ids(dmpe_results_path),
    results_path=dmpe_results_path,
    lengths=lengths,
)

In [None]:
# Load the perfect-model DMPE results stored under perfect_model_dmpe/fluid_tank.
# The results root defaults to the original author's checkout; set the environment variable
# EES_RESULTS_ROOT to load results from another location without editing this cell.
pm_dmpe_results_path = pathlib.Path(
    os.environ.get(
        "EES_RESULTS_ROOT",
        "/home/hvater@uni-paderborn.de/projects/exciting-exciting-systems/eval/results",
    )
) / "perfect_model_dmpe" / "fluid_tank"
# Fail early with a clear message instead of deep inside the loading helpers.
assert pm_dmpe_results_path.is_dir(), f"results directory not found: {pm_dmpe_results_path}"
pm_dmpe_results_by_metric = extract_metrics_over_timesteps(
    experiment_ids=get_experiment_ids(pm_dmpe_results_path),
    results_path=pm_dmpe_results_path,
    lengths=lengths,
)

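## Plot results

For every metric, plot the mean ± one standard deviation across runs against the sequence length, with one line per algorithm. The statistics are computed on the log of the metric when `use_log=True`.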

In [None]:
def plot_metrics_by_sequence_length_for_all_algos(data_per_algo, lengths, algo_names, use_log=False, x_margin=100):
    """Plot every metric against the sequence length, one line per algorithm.

    NOTE: this notebook-local definition shadows the function of the same name imported from
    ``exciting_exciting_systems.evaluation.plotting_utils``.

    For each metric, the NaN-ignoring mean over axis 0 (presumably the experiment runs --
    confirm against ``extract_metrics_over_timesteps``) is drawn as a line. The band
    mean ± one standard deviation is shaded around it.

    Args:
        data_per_algo: One dict per algorithm mapping metric name -> array. Reducing each array
            over axis 0 must leave one value per entry of ``lengths``. The metric names (and
            their order) are taken from the first dict.
        lengths: Sequence lengths used as x-axis values.
        algo_names: Legend label for each entry of ``data_per_algo`` (same order).
        use_log: If True, mean and std are computed on the natural log of each metric, and the
            y-labels are prefixed with ``\\log``.
        x_margin: Padding added to both ends of the x-axis range.

    Returns:
        The matplotlib Figure, with one subplot row per metric.

    Raises:
        AssertionError: if the number of result dicts and names differ, or no results are given.
    """
    assert len(data_per_algo) == len(algo_names), "Mismatch in number of algo results and number of algo names"
    assert len(data_per_algo) > 0, "At least one algorithm's results are required"

    metric_keys = list(data_per_algo[0].keys())
    n_metrics = len(metric_keys)

    # One row per metric instead of a hard-coded 3 (6 inches of height per row, as before).
    # squeeze=False keeps `axs` indexable even when there is only a single metric.
    fig, axs = plt.subplots(n_metrics, 1, figsize=(19, 6 * n_metrics), sharex=True, squeeze=False)
    axs = axs[:, 0]
    colors = plt.rcParams["axes.prop_cycle"]()

    for algo_name, data in zip(algo_names, data_per_algo):
        # Same color for an algorithm across all metric subplots.
        color = next(colors)["color"]

        for ax, metric_key in zip(axs, metric_keys):
            # Take the log once and reuse it for both statistics.
            values = jnp.log(data[metric_key]) if use_log else data[metric_key]
            mean = jnp.nanmean(values, axis=0)
            std = jnp.nanstd(values, axis=0)

            ax.plot(lengths, mean, label=algo_name, color=color)
            ax.fill_between(lengths, mean - std, mean + std, color=color, alpha=0.1)

    # Label each axis once. Previously a "log" label was set and then always overwritten, so
    # log-scaled plots did not say so.
    log_prefix = r"\log " if use_log else ""
    for ax, metric_key in zip(axs, metric_keys):
        ax.set_ylabel(rf"${log_prefix}\mathcal{{L}}_\mathrm{{{metric_key.upper()}}}$")
        ax.grid(True)

    axs[-1].set_xlabel(r"$\mathrm{timesteps}$")
    axs[-1].set_xlim(lengths[0] - x_margin, lengths[-1] + x_margin)
    axs[0].legend()
    fig.tight_layout()

    return fig


In [None]:
# Compare all four algorithms, using statistics of the log of each metric.
metrics_fig = plot_metrics_by_sequence_length_for_all_algos(
    data_per_algo=[pm_dmpe_results_by_metric, dmpe_results_by_metric, sgoats_results_by_metric, igoats_results_by_metric],
    lengths=lengths,
    # Raw strings: the LaTeX backslashes are not Python escape sequences.
    algo_names=[r"$\mathrm{PM-DMPE}$", r"$\mathrm{DMPE}$", r"$\mathrm{sGOATS}$", r"$\mathrm{iGOATS}$"],
    use_log=True,
)
# Save the returned figure explicitly rather than relying on pyplot's "current figure".
metrics_fig.savefig("metrics_per_sequence_length.pdf")

In [None]:
# Position along axis 0 (presumably the experiment-run axis -- confirm) of the DMPE run that is
# treated as an outlier. This index is positional, so it assumes get_experiment_ids returns the
# experiments in a stable order. TODO: record which experiment id this corresponds to.
DMPE_OUTLIER_IDX = 5
dmpe_results_by_metric_wo_outlier = {
    metric_key: jnp.delete(metric_results, DMPE_OUTLIER_IDX, axis=0)
    for metric_key, metric_results in dmpe_results_by_metric.items()
}

In [None]:
# Same comparison as above, but with the DMPE outlier run removed.
metrics_fig_wo_outlier = plot_metrics_by_sequence_length_for_all_algos(
    data_per_algo=[pm_dmpe_results_by_metric, dmpe_results_by_metric_wo_outlier, sgoats_results_by_metric, igoats_results_by_metric],
    lengths=lengths,
    # Raw strings: the LaTeX backslashes are not Python escape sequences.
    algo_names=[r"$\mathrm{PM-DMPE}$", r"$\mathrm{DMPE \ without \ outlier} $", r"$\mathrm{sGOATS}$", r"$\mathrm{iGOATS}$"],
    use_log=True,
)
# Save the returned figure explicitly rather than relying on pyplot's "current figure".
metrics_fig_wo_outlier.savefig("metrics_per_sequence_length_without_outlier.pdf")