# Statistical Analysis of Emulator Performance

This notebook compares a feed-forward ConvNet and an FNO for emulating the 1D
advection equation. The focus is on using statistical hypothesis tests to
determine which model is better, and under which conditions.

We will discuss the following:

TODO

In [None]:
import apebench
import seaborn as sns
import matplotlib.pyplot as plt
import equinox as eqx
import exponax as ex
import jax
from scipy import stats
import jax.numpy as jnp

In [None]:
CONFIGS = [
    {
        "scenario": "diff_adv",
        "task": "predict",
        "net": net,
        "train": "one",
        "start_seed": 0,
        "num_seeds": 20,
    }
    for net in [
        "Conv;34;10;relu",
        "FNO;12;18;4;gelu",
    ]
]

In [None]:
(
    df_metric,
    df_loss,
    _,
    network_list,
) = apebench.run_study_convenience(
    CONFIGS,
    "statistical_analysis",
    do_loss=True,
)

In [None]:
sns.lineplot(df_loss, x="update_step", y="train_loss", hue="net")
plt.yscale("log")

In [None]:
sns.lineplot(df_metric, x="time_step", y="mean_nRMSE", hue="net")

In [None]:
advection_scenario = apebench.scenarios.difficulty.Advection(
    advection_gamma=2.0, num_test_samples=300
)

In [None]:
fno_data, fno_models = advection_scenario(
    network_config="FNO;12;18;4;gelu",
    num_seeds=20,
)

In [None]:
fno_models

In [None]:
test_ic_set = advection_scenario.get_test_ic_set()
test_trj = advection_scenario.get_test_data()
test_trj_no_init = test_trj[:, 1:]

In [None]:
fno_rollout = eqx.filter_vmap(
    lambda m: jax.vmap(ex.rollout(m, advection_scenario.test_temporal_horizon))(
        test_ic_set
    )
)(fno_models)

In [None]:
fno_rollout.shape

In [None]:
test_trj.shape, test_trj_no_init.shape

In [None]:
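# Outer vmap: over the seed axis of the rollouts; the reference trajectories
# are shared across seeds (in_axes=None). Inner vmaps: over test samples and
# time steps, so nRMSE is evaluated once per (seed, sample, time step).
metric_rollout = jax.vmap(
    jax.vmap(jax.vmap(ex.metrics.nRMSE)),
    in_axes=(0, None),
)(fno_rollout, test_trj_no_init)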

In [None]:
metric_rollout.shape

In [None]:
plt.hist(metric_rollout[0, :, 0])

In [None]:
p_value_for_normality = jnp.array(
    [
        [
            stats.normaltest(metric_rollout[s, :, t]).pvalue
            for t in range(advection_scenario.test_temporal_horizon)
        ]
        for s in range(20)
    ]
)

One might expect that, for most seeds and most time steps, the distribution of
the nRMSE over the 300 test samples is not normal.

Well... actually, judging by the p-values plotted below, normality cannot be
rejected in most cases.
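As a minimal sketch of this kind of check, the snippet below runs `scipy.stats.normaltest` along the sample axis of a synthetic array. The array `synthetic_metric` is a hypothetical stand-in for `metric_rollout` with a smaller shape, and its values are drawn from a normal distribution purely for illustration; it is not data from this study.

```python
import numpy as np
from scipy import stats

# Hypothetical stand-in for `metric_rollout` with layout
# (num_seeds, num_samples, num_time_steps); values are synthetic.
rng = np.random.default_rng(0)
synthetic_metric = rng.normal(loc=0.1, scale=0.01, size=(4, 300, 5))

# D'Agostino-Pearson test applied along the sample axis, which yields
# one p-value per (seed, time step) pair.
p_values = stats.normaltest(synthetic_metric, axis=1).pvalue

# Fraction of (seed, time step) pairs for which normality is rejected
# at the 5% significance level.
alpha = 0.05
fraction_rejected = float(np.mean(p_values < alpha))
print(p_values.shape, fraction_rejected)
```

Passing `axis=1` replaces the nested list comprehension with a single vectorized call. Note that when many tests are run at a 5% level, roughly 5% of them are expected to reject even if every underlying distribution is normal, so a few p-values below the threshold are not by themselves evidence against normality.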

In [None]:
plt.semilogy(p_value_for_normality[:20].T)
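# Significance threshold of 0.05, drawn across the full test horizon
plt.hlines(
    0.05,
    0,
    advection_scenario.test_temporal_horizon,
    colors="r",
    linestyles="--",
    linewidth=3,
)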

In [None]:
mean_metric_rollout = jnp.mean(metric_rollout, axis=1)

In [None]:
p_value_for_normality_mean = jnp.array(
    [
        stats.shapiro(mean_metric_rollout[:, t]).pvalue
        for t in range(advection_scenario.test_temporal_horizon)
    ]
)

In [None]:
plt.plot(p_value_for_normality_mean)

Hence, we might need a non-parametric test to compare the two models.
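One candidate for such a non-parametric comparison is the Mann-Whitney U test, which does not assume normality. The sketch below is only an illustration under stated assumptions: `conv_errors` and `fno_errors` are hypothetical, synthetically generated per-seed mean errors at a single time step (not results from this notebook), and the two groups are treated as independent samples.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
num_seeds = 20

# Hypothetical per-seed mean nRMSE values for each model at one time step.
# These are synthetic and skewed on purpose; they are NOT study results.
conv_errors = rng.lognormal(mean=-2.0, sigma=0.3, size=num_seeds)
fno_errors = rng.lognormal(mean=-2.5, sigma=0.3, size=num_seeds)

# One-sided Mann-Whitney U test. The alternative hypothesis is that the
# FNO errors are stochastically smaller than the ConvNet errors.
result = stats.mannwhitneyu(fno_errors, conv_errors, alternative="less")
print(result.statistic, result.pvalue)
```

A small p-value would support the claim that one model tends to produce lower errors than the other across seeds. Repeating the test per time step would show whether the conclusion changes over the rollout, although the resulting p-values would then need a correction for multiple comparisons.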