# Fourier Neural Operators and their active modes

Loosely speaking, the number of active (retained Fourier) modes of an FNO plays
the role that the kernel size plays in a convolutional architecture.

In [None]:
import apebench
import jax
import jax.numpy as jnp
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# One study entry per number of active FNO modes on the "diff_adv" prediction
# task, each trained with 20 seeds starting from seed 0. The "net" string is
# assumed to follow apebench's "FNO;<modes>;<width>;<blocks>;<activation>"
# layout -- TODO confirm against the apebench network-spec docs.
CONFIGS = [
    dict(
        scenario="diff_adv",
        task="predict",
        net=f"FNO;{num_active_modes};18;4;gelu",
        train="one",
        start_seed=0,
        num_seeds=20,
    )
    for num_active_modes in (1, 4, 5, 6, 7, 10, 40, 80)
]

In [None]:
# Run the study described by CONFIGS and unpack its four result objects.
# "fno_and_active_modes/" is presumably the directory apebench writes
# results to (and reads existing ones from) -- confirm with the apebench docs.
metric_df, loss_df, sample_rollout_df, network_weights_list = (
    apebench.run_study_convenience(
        CONFIGS,
        "fno_and_active_modes/",
        do_loss=True,
        do_metrics=True,
        do_sample_rollouts=True,
    )
)

In [None]:
# Training loss over update steps, one line per FNO configuration ("net"),
# without an error band; the y-axis of the same Axes is made logarithmic.
sns.lineplot(
    data=loss_df,
    x="update_step",
    y="train_loss",
    hue="net",
    errorbar=None,
).set_yscale("log")

In [None]:
# Mean nRMSE over rollout time steps, one line per FNO configuration ("net"),
# drawn without an error band.
sns.lineplot(
    data=metric_df,
    x="time_step",
    y="mean_nRMSE",
    hue="net",
    errorbar=None,
)

### Relate to the Spectra

In [None]:
# Instantiate the registered "diff_adv" scenario (the one the first study
# trained on) by calling its registry entry without arguments, i.e. with
# default settings.
adv_scenario = apebench.scenarios.scenario_dict[
    "diff_adv"
]()

In [None]:
# Test trajectories of the advection scenario; the layout is
# (num_test_samples, test_temporal_horizon+1, num_channels, num_points)
test_trj_set = adv_scenario.get_test_data()

test_trj_set.shape

In [None]:
def _magnitude_spectrum(state):
    """Return the magnitude (``power=False``) spectrum of a single state."""
    return apebench.exponax.get_spectrum(state, power=False)


# Nested vmap: the outer map runs over test samples, the inner one over time
# steps, giving the spatial magnitude spectrum of every sample at every step.
test_trj_set_spectrum = jax.vmap(jax.vmap(_magnitude_spectrum))(test_trj_set)

# (num_test_samples, test_temporal_horizon+1, num_channels=1, num_points//2+1)
test_trj_set_spectrum.shape

In [None]:
# Visualize the spectrum of the zeroth test sample at a selection of time
# steps (assumes test_temporal_horizon >= 200 -- TODO confirm).
SAMPLE_IDX = 0
TIME_STEP_IDX = [0, 1, 2, 5, 10, 50, 100, 200]

for time_step_idx in TIME_STEP_IDX:
    # Trailing 0 selects the single channel
    sample_spectrum = test_trj_set_spectrum[SAMPLE_IDX, time_step_idx, 0]
    plt.semilogy(sample_spectrum, label=f"[t]={time_step_idx}")

plt.legend()
plt.grid()

## Extend to more scenarios

In [None]:
# Every (scenario, number of active FNO modes) combination, scenario-major.
# Each entry is a prediction task trained with 20 seeds starting from seed 0.
# The "net" string is assumed to be "FNO;<modes>;<width>;<blocks>;<activation>"
# -- TODO confirm against the apebench network-spec docs.
CONFIGS = [
    dict(
        scenario=scenario_name,
        task="predict",
        net=f"FNO;{num_active_modes};18;4;gelu",
        train="one",
        start_seed=0,
        num_seeds=20,
    )
    for scenario_name in ("diff_adv", "diff_burgers", "diff_ks")
    for num_active_modes in (1, 4, 5, 6, 7, 10, 40, 80)
]

In [None]:
# Run the multi-scenario study and unpack its four result objects. The same
# "fno_and_active_modes/" folder as the single-scenario study is used; apebench
# presumably reuses any results already stored there -- confirm in its docs.
metric_df, loss_df, sample_rollout_df, network_weights_list = (
    apebench.run_study_convenience(
        CONFIGS,
        "fno_and_active_modes/",
        do_loss=True,
        do_metrics=True,
        do_sample_rollouts=True,
    )
)

In [None]:
# Training loss, one panel per scenario and one line per FNO configuration.
# Lines show the median over repeated runs (presumably the seeds) with the
# 50% percentile interval (25th-75th percentile) as the band.
loss_facet = sns.relplot(
    data=loss_df,
    kind="line",
    x="update_step",
    y="train_loss",
    hue="net",
    col="scenario",
    estimator="median",
    errorbar=("pi", 50),
)
# FacetGrid.set applies the setting to every panel's Axes
loss_facet.set(yscale="log")
for ax in loss_facet.axes.flat:
    ax.grid()

In [None]:
# Mean nRMSE over rollout time, one panel per scenario and one line per FNO
# configuration. Lines show the median with the 25th-75th percentile band.
metric_facet = sns.relplot(
    data=metric_df,
    kind="line",
    x="time_step",
    y="mean_nRMSE",
    hue="net",
    col="scenario",
    estimator="median",
    errorbar=("pi", 50),
)
# Same fixed y-range on every panel so the scenarios are directly comparable
metric_facet.set(ylim=(-0.1, 1.1))
for ax in metric_facet.axes.flat:
    ax.grid()

## Relate to the Spectra

In [None]:
# Instantiate the two additional scenarios with default settings (Burgers
# first, then Kuramoto-Sivashinsky).
burger_scenario, ks_scenario = (
    apebench.scenarios.scenario_dict[scenario_name]()
    for scenario_name in ("diff_burgers", "diff_ks")
)

In [None]:
burger_test_trj_set = burger_scenario.get_test_data()
ks_test_trj_set = ks_scenario.get_test_data()


def _magnitude_spectrum(state):
    """Return the magnitude (``power=False``) spectrum of a single state."""
    return apebench.exponax.get_spectrum(state, power=False)


# Nested vmap: outer map over test samples, inner map over time steps.
_spectrum_over_samples_and_time = jax.vmap(jax.vmap(_magnitude_spectrum))

burger_test_trj_set_spectrum = _spectrum_over_samples_and_time(burger_test_trj_set)
ks_test_trj_set_spectrum = _spectrum_over_samples_and_time(ks_test_trj_set)

In [None]:
# Side-by-side magnitude spectra of the same test sample (SAMPLE_IDX) at the
# time indices in TIME_STEP_IDX, one panel per scenario.
fig, axs = plt.subplots(1, 3, figsize=(15, 5))

spectrum_panels = [
    ("Advection", test_trj_set_spectrum),
    ("Burgers", burger_test_trj_set_spectrum),
    ("Kuramoto-Sivashinsky", ks_test_trj_set_spectrum),
]

for ax, (title, spectrum) in zip(axs, spectrum_panels):
    for time_step_idx in TIME_STEP_IDX:
        # Trailing 0 selects the single channel
        ax.semilogy(
            spectrum[SAMPLE_IDX, time_step_idx, 0],
            label=f"[t]={time_step_idx}",
        )
    # Per-panel decorations do not depend on the time step; set them once
    ax.set_title(title)
    ax.set_xlabel("mode index")
    ax.grid()

axs[0].set_ylabel("magnitude")

# Have joint legend (all panels share the same time-step labels). Reserve the
# right 10% of the figure for it so it does not cover the third panel.
handles, labels = axs[0].get_legend_handles_labels()
fig.legend(handles, labels, loc="center right")
fig.tight_layout(rect=(0, 0, 0.9, 1))