# Performance across metrics

This notebook shows how much the picture of a model's performance can change depending on which error metric is used to evaluate it.

In [None]:
import apebench
import jax
import jax.numpy as jnp
import seaborn as sns
import matplotlib.pyplot as plt
from scipy import stats
import pandas as pd

RMSE is always a good choice because it is consistent with the $L^2$ function norm!

In [None]:
# Comma-separated string naming every metric to report. The metrics are grouped
# by family below and flattened in order, so the resulting string lists them in
# exactly the order they appear here.
# Fourier metrics carry three ";"-separated numbers; judging from the group
# comments these are the lower/upper frequency bound and the derivative order
# (NOTE(review): confirm against the apebench metric documentation).
REPORT_METRICS = ",".join(
    metric
    for family in (
        # Normalized metrics: always a good starting point.
        ("mean_nRMSE", "mean_nMAE", "mean_nMSE"),
        # Absolute metrics: be careful!
        ("mean_RMSE", "mean_MAE", "mean_MSE"),
        # Symmetric metrics: bounded range, nice!
        ("mean_sRMSE", "mean_sMSE", "mean_sMAE"),
        # Frequency range of the bandlimited initial condition.
        ("mean_fourier_nRMSE;0;5;0", "mean_fourier_RMSE;0;5;0"),
        # Frequency range slightly above the bandlimited initial condition.
        # Caution: a normalized metric cannot be used here because the
        # reference is zero (linear PDEs on periodic BCs remain bandlimited).
        ("mean_fourier_RMSE;6;10;0",),
        # All frequencies beyond that.
        ("mean_fourier_RMSE;11;81;0",),
        # How well do the derivatives match?
        (
            "mean_fourier_RMSE;0;5;1",
            "mean_fourier_RMSE;0;5;2",
            "mean_fourier_RMSE;0;81;1",
            "mean_fourier_RMSE;0;81;2",
        ),
        # Sobolev.
        ("mean_H1_RMSE",),
    )
    for metric in family
)
# Experiment configurations as a list of dicts: one entry per combination of
# start seed and (num_training_steps, unrolled_steps) pair.
# NOTE(review): the two values of each pair are never read inside the dict, so
# adding more pairs would only duplicate the same configuration - confirm the
# intent. CONFIGS is also not referenced by the cells visible in this notebook.
# Previously considered alternatives: start seeds 10, 20, 30, 40 and the pairs
# (50_000, 1) and (10_000, 1).
CONFIGS = [
    dict(
        scenario="diff_adv",
        task="predict",
        # 32'943 params, 16 receptive field per direction
        net="Res;26;8;relu",
        start_seed=seed,
        num_seeds=10,
        report_metrics=REPORT_METRICS,
    )
    for seed in (0,)
    for _num_training_steps, _unrolled_steps in ((100, 5),)
]

In [None]:
# Instantiate the advection scenario from apebench's "difficulty" family and
# hand it the full metric selection through `report_metrics`.
advection_scenario = apebench.scenarios.difficulty.Advection(report_metrics=REPORT_METRICS)

In [None]:
# Run the scenario: calling it returns a pair that is unpacked into the result
# table `data` and the trained networks `trained_nets`.
data, trained_nets = advection_scenario(
    num_seeds=10,
    task_config="predict",
    # One-step supervised learning.
    train_config="one",
    # 32'943 params, 16 receptive field per direction.
    network_config="Res;26;8;relu",
)

In [None]:
# Reshape the raw results into long-form tables: one for the loss and one for
# the metrics, the latter restricted to the names listed in REPORT_METRICS.
loss_data = apebench.melt_loss(data)
metric_data = apebench.melt_metrics(
    data,
    metric_name=REPORT_METRICS.split(","),
)

In [None]:
# Melt once more so that the individual metric columns collapse into a
# ("metric_name", "metric_value") pair; the base columns and "time_step" are
# kept as identifiers. This is the layout used for hue-based plotting below.
# NOTE(review): BASE_NAMES is taken from the private module `apebench._utils`.
metric_data_further_flattened = metric_data.melt(
    id_vars=[*apebench._utils.BASE_NAMES, "time_step"],
    value_vars=REPORT_METRICS.split(","),
    var_name="metric_name",
    value_name="metric_value",
)

In [None]:
# Lesson 1: Know whether your metric is consistent with a norm or not!
# Plot the three normalized metrics against the time step. The line is the
# median and the band the 25th-75th percentile ("pi", 50) of all rows sharing a
# time step (presumably the different seeds - TODO confirm).
shown_metrics = ["mean_nRMSE", "mean_nMSE", "mean_nMAE"]
ax = sns.lineplot(
    # Select the rows explicitly instead of relying on hue_order to drop the
    # metrics that should not be drawn.
    metric_data_further_flattened[
        metric_data_further_flattened["metric_name"].isin(shown_metrics)
    ],
    x="time_step",
    y="metric_value",
    hue="metric_name",
    hue_order=shown_metrics,
    estimator="median",
    errorbar=("pi", 50),
)
ax.set(
    title="Normalized metrics",
    xlabel="time step",
    ylabel="metric value (median, 25-75th percentile band)",
    ylim=(0, 1),
)
# Render the figure without echoing a return value as cell output.
plt.show()

In [None]:
# Lesson 2: Metric orders of magnitude can vary, so better use a normalized
# version.
# Plot the three absolute metrics against the time step. The line is the median
# and the band the 25th-75th percentile ("pi", 50) of all rows sharing a time
# step (presumably the different seeds - TODO confirm).
shown_metrics = ["mean_RMSE", "mean_MSE", "mean_MAE"]
ax = sns.lineplot(
    # Select the rows explicitly instead of relying on hue_order to drop the
    # metrics that should not be drawn.
    metric_data_further_flattened[
        metric_data_further_flattened["metric_name"].isin(shown_metrics)
    ],
    x="time_step",
    y="metric_value",
    hue="metric_name",
    hue_order=shown_metrics,
    estimator="median",
    errorbar=("pi", 50),
)
ax.set(
    title="Absolute metrics",
    xlabel="time step",
    ylabel="metric value (median, 25-75th percentile band)",
)
# Render the figure without echoing the Axes repr as cell output.
plt.show()

In [None]:
# Symmetric metrics are also an option.
# Plot the three symmetric metrics against the time step. The line is the
# median and the band the 25th-75th percentile ("pi", 50) of all rows sharing a
# time step (presumably the different seeds - TODO confirm).
shown_metrics = ["mean_sRMSE", "mean_sMSE", "mean_sMAE"]
ax = sns.lineplot(
    # Select the rows explicitly instead of relying on hue_order to drop the
    # metrics that should not be drawn.
    metric_data_further_flattened[
        metric_data_further_flattened["metric_name"].isin(shown_metrics)
    ],
    x="time_step",
    y="metric_value",
    hue="metric_name",
    hue_order=shown_metrics,
    estimator="median",
    errorbar=("pi", 50),
)
ax.set(
    title="Symmetric metrics",
    xlabel="time step",
    ylabel="metric value (median, 25-75th percentile band)",
)
# Render the figure without echoing the Axes repr as cell output.
plt.show()

In [None]:
# Compare the plain RMSE with the Fourier RMSE restricted to three frequency
# ranges (0-5, 6-10 and 11-81, all with the last metric argument set to 0).
# The line is the median and the band the 25th-75th percentile ("pi", 50) of
# all rows sharing a time step (presumably the different seeds - TODO confirm).
shown_metrics = [
    "mean_RMSE",
    "mean_fourier_RMSE;0;5;0",
    "mean_fourier_RMSE;6;10;0",
    "mean_fourier_RMSE;11;81;0",
]
ax = sns.lineplot(
    # Select the rows explicitly instead of relying on hue_order to drop the
    # metrics that should not be drawn.
    metric_data_further_flattened[
        metric_data_further_flattened["metric_name"].isin(shown_metrics)
    ],
    x="time_step",
    y="metric_value",
    hue="metric_name",
    hue_order=shown_metrics,
    estimator="median",
    errorbar=("pi", 50),
)
ax.set(
    title="RMSE by frequency range",
    xlabel="time step",
    ylabel="metric value (median, 25-75th percentile band)",
)
# Render the figure without echoing the Axes repr as cell output.
plt.show()

In [None]:
# How do the derivatives match? Compare the Fourier RMSE variants whose last
# metric argument is 1 or 2 (presumably the derivative order - TODO confirm),
# once for the frequency range 0-5 and once for 0-81.
# The line is the median and the band the 25th-75th percentile ("pi", 50) of
# all rows sharing a time step (presumably the different seeds - TODO confirm).
shown_metrics = [
    "mean_fourier_RMSE;0;5;1",
    "mean_fourier_RMSE;0;5;2",
    "mean_fourier_RMSE;0;81;1",
    "mean_fourier_RMSE;0;81;2",
]
ax = sns.lineplot(
    # Select the rows explicitly instead of relying on hue_order to drop the
    # metrics that should not be drawn.
    metric_data_further_flattened[
        metric_data_further_flattened["metric_name"].isin(shown_metrics)
    ],
    x="time_step",
    y="metric_value",
    hue="metric_name",
    hue_order=shown_metrics,
    estimator="median",
    errorbar=("pi", 50),
)
ax.set(
    title="Fourier RMSE of derivatives",
    xlabel="time step",
    ylabel="metric value (median, 25-75th percentile band)",
    # Logarithmic y-axis, as in the original cell.
    yscale="log",
)
# Render the figure without echoing a return value as cell output.
plt.show()