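Investigates 3(b) of the APEBench paper: does unrolled (5-step) supervised training help on a high-difficulty diffusion-advection (`diff_adv`) scenario, compared with one-step training using the same number (10,000) or five times as many (50,000) optimizer updates?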

In [None]:
import apebench
import jax
import jax.numpy as jnp
import seaborn as sns
import matplotlib.pyplot as plt
from scipy import stats
import pandas as pd

In [None]:
def _make_config(start_seed, num_training_steps, unrolled_steps):
    """Return one study configuration for the ``diff_adv`` scenario.

    Args:
        start_seed: Value stored under ``"start_seed"`` (each config also
            requests ``num_seeds`` = 10).
        num_training_steps: Number of optimizer updates; the final
            ``optim_config`` field is ``num_training_steps // 5``
            (presumably the warmup length -- confirm against apebench).
        unrolled_steps: Unroll length, written zero-padded as ``"sup;NN"``.
    """
    warmup = num_training_steps // 5
    return {
        "scenario": "diff_adv",
        "task": "predict",
        "net": "Res;26;8;relu",  # 32'943 params, 16 receptive field per direction
        "train": f"sup;{unrolled_steps:02d}",
        "start_seed": start_seed,
        "num_seeds": 10,
        "advection_gamma": 10.5,
        "optim_config": f"adam;{num_training_steps};warmup_cosine;0.0;1e-3;{warmup}",
    }


# Sweep over (num_training_steps, unrolled_steps): 5-step unrolled training
# versus one-step training with the same (10k) and with 5x (50k) updates.
CONFIGS = [
    _make_config(seed, n_updates, n_unroll)
    for seed in (0,)  # extra seed offsets 10, 20, 30, 40 currently disabled
    for n_updates, n_unroll in (
        (10_000, 5),
        (50_000, 1),
        (10_000, 1),
    )
]

In [None]:
# Run every configuration in CONFIGS through apebench and collect the four
# requested outputs (metrics, training loss, sample rollouts, network weights).
# The second argument is a directory name; presumably apebench stores/looks up
# the study's results there -- confirm in the apebench docs.
metric_df, loss_df, sample_rollout_df, network_weights_list = apebench.run_study_convenience(
    CONFIGS,
    "unrolled_training_for_high_difficulty/",
    do_metrics=True,
    do_loss=True,
    do_sample_rollouts=True,
)

In [None]:
# Post-process both result tables identically with the apebench helpers.
# split_train presumably expands the "train" string (e.g. into the `rollout`
# column used by get_setup below); read_in_kwargs presumably expands the
# config kwargs (e.g. `optim_config`) into columns -- confirm in apebench.
loss_df = apebench.read_in_kwargs(apebench.split_train(loss_df))
metric_df = apebench.read_in_kwargs(apebench.split_train(metric_df))

In [None]:
def get_setup(df):
    """Tag every row of ``df`` with a compact training-setup label.

    Adds two columns to ``df`` *in place* and returns the same object:

    * ``num_training_steps`` -- the second ``;``-separated field of
      ``optim_config`` as an int (``"adam;10000;..."`` -> ``10000``).
    * ``setup`` -- ``"<rollout>-<num_training_steps>"``, e.g. ``"5-10000"``.

    Requires ``optim_config`` and ``rollout`` columns to be present.
    """
    updates = df["optim_config"].map(lambda cfg: int(cfg.split(";")[1]))
    df["num_training_steps"] = updates
    df["setup"] = df["rollout"].astype(str).str.cat(updates.astype(str), sep="-")
    return df

In [None]:
# Attach the `num_training_steps` and `setup` label columns to both tables.
loss_df, metric_df = get_setup(loss_df), get_setup(metric_df)

In [None]:
# Training-loss curves per setup, on a logarithmic y-axis.
sns.lineplot(data=loss_df, x="update_step", y="train_loss", hue="setup").set_yscale("log")

In [None]:
# Mean nRMSE over the test rollout, one line per setup (linear scale).
sns.lineplot(data=metric_df, hue="setup", x="time_step", y="mean_nRMSE")

In [None]:
# Same rollout-error curves as above, with a logarithmic y-axis.
sns.lineplot(data=metric_df, hue="setup", x="time_step", y="mean_nRMSE").set_yscale("log")

In [None]:
# Bar chart of mean nRMSE per setup at rollout time steps 1, 2 and 5.
sns.catplot(
    kind="bar",
    data=metric_df,
    x="time_step",
    y="mean_nRMSE",
    hue="setup",
    order=[1, 2, 5],
)

In [None]:
# Two-sided t-test at time step 5: does one-step training (10k updates)
# differ from 5-step unrolled training (10k updates)?
# NOTE(review): scipy's default is Student's test (equal_var=True); if the
# per-seed spread differs between setups, Welch (equal_var=False) may be safer.
stats.ttest_ind(
    *(
        metric_df.query(f"setup == '{setup}' and time_step == 5")["mean_nRMSE"]
        for setup in ("1-10000", "5-10000")
    )
).pvalue

In [None]:
# metric_df.groupby(["time_step"]).apply(
#     lambda x: x.assign(
#         p_1_short_over5=stats.ttest_ind(
#             x.query("setup == '1-10000'")["mean_nRMSE"],
#             x.query("setup == '5-10000'")["mean_nRMSE"]
#         ).pvalue
#     ),
#     include_groups=False
# )

In [None]:
def _one_sided_pvalue(group, candidate, baseline):
    """Return the p-value that ``candidate`` has lower mean nRMSE than ``baseline``.

    Runs a one-sided Student t-test (``alternative="less"``) on the
    ``mean_nRMSE`` values of the two setups within ``group``.
    """
    candidate_errors = group.loc[group["setup"] == candidate, "mean_nRMSE"]
    baseline_errors = group.loc[group["setup"] == baseline, "mean_nRMSE"]
    # H1: mean(candidate_errors) < mean(baseline_errors).
    return stats.ttest_ind(candidate_errors, baseline_errors, alternative="less").pvalue


def _pvalues_per_time_step(df, candidate, baseline):
    """Apply ``_one_sided_pvalue`` separately to every ``time_step`` group of ``df``."""
    return df.groupby(["time_step"]).apply(
        _one_sided_pvalue, candidate, baseline, include_groups=False
    )


# One p-value per rollout time step for each pairwise "is better than" claim.
p_5_better_than_1_10000 = _pvalues_per_time_step(metric_df, "5-10000", "1-10000")
p_5_better_than_1_50000 = _pvalues_per_time_step(metric_df, "5-10000", "1-50000")
p_1_50000_better_than_1_10000 = _pvalues_per_time_step(metric_df, "1-50000", "1-10000")

In [None]:
# Per-time-step one-sided p-values for each pairwise comparison, on a log
# axis, against the conventional 0.05 and 0.01 significance thresholds.
# NOTE(review): one test per time step, with no multiple-comparison
# correction -- isolated dips below a threshold should be read with care.
plt.plot(p_5_better_than_1_10000, label="5-10000 better than 1-10000")
plt.plot(p_5_better_than_1_50000, label="5-10000 better than 1-50000")
plt.plot(p_1_50000_better_than_1_10000, label="1-50000 better than 1-10000")
# axhline spans the whole x-axis, so the threshold lines do not depend on a
# hard-coded time-step range matching the rollout length.
plt.axhline(0.05, color="red", linestyle="dashed", label="p=0.05")
plt.axhline(0.01, color="red", linestyle="dotted", label="p=0.01")
plt.yscale("log")
plt.legend()
plt.xlabel("Time step")
plt.ylabel("p-value")