# Emulator Architectures across Difficulties for Diffusion PDE

This notebook trains a range of neural emulator architectures:

- Feed-Forward ConvNets (of varying depths)
- ResNet
- UNet
- Dilated ResNet
- FNO

for the diffusion PDE

$$
\frac{\partial u}{\partial t} = \nu \nabla^2 u
$$

using a "difficulty parameterization" as explained in the [APEBench
paper](https://arxiv.org/abs/2411.00180). We use the $\gamma_2$ (difficulty
associated with a PDE with a spatial derivative of order 2) defined by (in 1D)

$$
\gamma_2 = \frac{2 \nu \Delta t N^2}{L^2}
$$

Conveniently, $\gamma_2 \in [0, 1]$ defines the stability region for the most
compact finite difference scheme of this equation, the [FTCS
method](https://en.wikipedia.org/wiki/FTCS_scheme). It thereby poses a minimum
requirement on the receptive field of the neural emulator as we shall see below.

In [None]:
import apebench
import seaborn as sns

In [None]:
# One configuration dictionary per (network, difficulty) pair. The network
# loop is the outer one, so all difficulties of a given network are listed
# consecutively. Only "Conv" descriptors are used here; they differ in the
# third descriptor field, which `split_net` later reads as `conv_depth`.
CONFIGS = [
    dict(
        scenario="diff_diff",
        task="predict",
        net=net,
        train="one",
        start_seed=0,
        num_seeds=20,
        diffusion_gamma=diffusion_gamma,
    )
    for net in ["Conv;34;{};relu".format(depth) for depth in (0, 1, 2, 10)]
    for diffusion_gamma in (0.5, 2.5, 10.5)
]

In [None]:
# Run the study for every entry of CONFIGS. The call returns four values:
# the metric frame, the loss frame (requested via `do_loss=True`), a third
# value that is not needed here and therefore discarded, and the list of
# networks.
df_metric, df_loss, _, network_list = apebench.run_study_convenience(
    CONFIGS,
    "difficulty_and_receptive_field_diffusion",
    do_loss=True,
)

In [None]:
def split_net(df):
    """Add columns derived from the ``;``-separated ``net`` descriptor.

    The frame is modified in place and also returned. Three columns are
    written:

    - ``net_type``: the first descriptor field, e.g. ``"Conv"``.
    - ``hidden_channels``: the second field (as ``int``) if the descriptor
      has exactly four fields, otherwise the third field.
    - ``conv_depth``: the third field (as ``int``) if the first field is
      ``"conv"`` (compared case-insensitively); every other network type
      gets the fixed value ``10``.

    Args:
        df: DataFrame with a string column ``net``.

    Returns:
        The same DataFrame object, with the three columns added.
    """

    def _net_type(descriptor):
        return descriptor.split(";")[0]

    def _hidden_channels(descriptor):
        fields = descriptor.split(";")
        # Four-field descriptors carry the width in the second slot,
        # longer ones in the third.
        position = 1 if len(fields) == 4 else 2
        return int(fields[position])

    def _conv_depth(descriptor):
        fields = descriptor.split(";")
        if fields[0].lower() != "conv":
            return 10
        return int(fields[2])

    df["net_type"] = df["net"].apply(_net_type)
    df["hidden_channels"] = df["net"].apply(_hidden_channels)
    df["conv_depth"] = df["net"].apply(_conv_depth)

    return df

In [None]:
# The plots of this study facet on a `diffusion_gamma` column. The extended
# study further down passes its frames through `apebench.read_in_kwargs`
# before `split_net`; do the same here so both studies are post-processed
# identically.
# NOTE(review): presumably `read_in_kwargs` is what expands the scenario
# keyword arguments (such as `diffusion_gamma`) into columns - confirm
# against the apebench documentation.
df_metric = split_net(apebench.read_in_kwargs(df_metric))
df_loss = split_net(apebench.read_in_kwargs(df_loss))

In [None]:
# Training loss against the update step, one panel per `diffusion_gamma`
# value and one line style per `conv_depth`. Lines show the median; the band
# is the 50% percentile interval.
facet = sns.relplot(
    data=df_loss,
    kind="line",
    x="update_step",
    y="train_loss",
    col="diffusion_gamma",
    style="conv_depth",
    style_order=[10, 2, 1, 0],
    estimator="median",
    errorbar=("pi", 50),
)
# Logarithmic y-axis and a grid on every panel.
for ax in facet.axes.flat:
    ax.set_yscale("log")
    ax.grid(True)

In [None]:
# `mean_nRMSE` against the time step, one panel per `diffusion_gamma` value
# and one line style per `conv_depth`. Lines show the median; the band is
# the 50% percentile interval.
facet = sns.relplot(
    data=df_metric,
    kind="line",
    x="time_step",
    y="mean_nRMSE",
    col="diffusion_gamma",
    style="conv_depth",
    style_order=[10, 2, 1, 0],
    estimator="median",
    errorbar=("pi", 50),
)
# Same y-range on all panels, with a small margin around [0, 1].
for ax in facet.axes.flat:
    ax.grid(True)
    ax.set_ylim(-0.05, 1.05)

Zooming in

In [None]:
# `mean_nRMSE` against the time step (median line, 50% percentile interval
# band), with the axes restricted to small errors and early time steps.
facet = sns.relplot(
    data=df_metric,
    kind="line",
    x="time_step",
    y="mean_nRMSE",
    col="diffusion_gamma",
    style="conv_depth",
    style_order=[10, 2, 1, 0],
    estimator="median",
    errorbar=("pi", 50),
)
# Zoom window: errors up to 0.07 and time steps up to 25.
for ax in facet.axes.flat:
    ax.grid(True)
    ax.set_ylim(-0.001, 0.07)
    ax.set_xlim(-1, 25)

## Extended Study

In [None]:
# One configuration dictionary per (network, difficulty) pair, now covering
# the four "Conv" variants plus four further architectures. The network loop
# is the outer one, so all difficulties of a given network are listed
# consecutively.
CONFIGS_EXTENDED = [
    dict(
        scenario="diff_diff",
        task="predict",
        net=net,
        train="one",
        start_seed=0,
        num_seeds=20,
        diffusion_gamma=diffusion_gamma,
    )
    for net in [
        *("Conv;34;{};relu".format(depth) for depth in (0, 1, 2, 10)),
        # Parameter counts and receptive fields (per direction) as noted by
        # the notebook author:
        "UNet;12;2;relu",  # 27'193 params, receptive field 29
        "Res;26;8;relu",  # 32'943 params, receptive field 16
        "FNO;12;18;4;gelu",  # 32'527 params, receptive field inf
        "Dil;2;32;2;relu",  # 31'777 params, receptive field 20
    ]
    for diffusion_gamma in (0.5, 2.5, 10.5)
]

In [None]:
# Run the study for every entry of CONFIGS_EXTENDED, using the same study
# name string as the basic study. The third return value is not needed and
# is discarded; `network_list` is rebound to the result of this call.
df_metric_extended, df_loss_extended, _, network_list = (
    apebench.run_study_convenience(
        CONFIGS_EXTENDED,
        "difficulty_and_receptive_field_diffusion",
        do_loss=True,
    )
)

In [None]:
# Post-process both frames in two steps: first `apebench.read_in_kwargs`,
# then `split_net`, which adds the `net_type`, `hidden_channels` and
# `conv_depth` columns.
df_loss_extended = apebench.read_in_kwargs(df_loss_extended)
df_loss_extended = split_net(df_loss_extended)

df_metric_extended = apebench.read_in_kwargs(df_metric_extended)
df_metric_extended = split_net(df_metric_extended)

In [None]:
# Training loss against the update step for all architectures: colour
# encodes `net_type`, line style encodes `conv_depth`, one panel per
# `diffusion_gamma` value. Lines show the median; the band is the 50%
# percentile interval.
facet = sns.relplot(
    data=df_loss_extended,
    kind="line",
    x="update_step",
    y="train_loss",
    col="diffusion_gamma",
    hue="net_type",
    hue_order=["Conv", "Res", "UNet", "Dil", "FNO"],
    palette=["#377eb8", "#4daf4a", "#e41a1c", "#ff7f00", "#984ea3"],
    style="conv_depth",
    style_order=[10, 2, 1, 0],
    estimator="median",
    errorbar=("pi", 50),
)
# Logarithmic y-axis and a grid on every panel.
for ax in facet.axes.flat:
    ax.set_yscale("log")
    ax.grid(True)

In [None]:
# `mean_nRMSE` against the time step for all architectures: colour encodes
# `net_type`, line style encodes `conv_depth`, one panel per
# `diffusion_gamma` value. Lines show the median; the band is the 50%
# percentile interval.
facet = sns.relplot(
    data=df_metric_extended,
    kind="line",
    x="time_step",
    y="mean_nRMSE",
    col="diffusion_gamma",
    hue="net_type",
    hue_order=["Conv", "Res", "UNet", "Dil", "FNO"],
    palette=["#377eb8", "#4daf4a", "#e41a1c", "#ff7f00", "#984ea3"],
    style="conv_depth",
    style_order=[10, 2, 1, 0],
    estimator="median",
    errorbar=("pi", 50),
)
# Same y-range on all panels, with a small margin around [0, 1].
for ax in facet.axes.flat:
    ax.grid(True)
    ax.set_ylim(-0.05, 1.05)

In [None]:
# `mean_nRMSE` against the time step for all architectures (median line,
# 50% percentile interval band), with the axes restricted to small errors
# and early time steps.
facet = sns.relplot(
    data=df_metric_extended,
    kind="line",
    x="time_step",
    y="mean_nRMSE",
    col="diffusion_gamma",
    hue="net_type",
    hue_order=["Conv", "Res", "UNet", "Dil", "FNO"],
    palette=["#377eb8", "#4daf4a", "#e41a1c", "#ff7f00", "#984ea3"],
    style="conv_depth",
    style_order=[10, 2, 1, 0],
    estimator="median",
    errorbar=("pi", 50),
)
# Zoom window: errors up to 0.07 and time steps up to 25.
for ax in facet.axes.flat:
    ax.grid(True)
    ax.set_ylim(-0.001, 0.07)
    ax.set_xlim(-1, 25)