In [None]:
import numpy as np
import matplotlib.pyplot as plt
import random
import os
import torch 

# Seed every random number generator used by the notebook with the same value.
seed = 2023
os.environ["PYTHONHASHSEED"] = str(seed)
for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    seed_rng(seed)
import dolfin as df
import time
from utils import *
from utils_training import *
import prepare_data
from utils_compare_methods import *
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable
import pandas as pd
import gc

# Global seaborn/matplotlib styling shared by the figures of this notebook.
sns.set_theme()
sns.set_context("paper")
# NOTE(review): `sns.set` is documented as an alias of `sns.set_theme`, whose
# default context is "notebook" - this call probably resets the "paper"
# context chosen on the previous line. Confirm which context is intended.
sns.set(rc={"xtick.bottom": True, "ytick.left": True})
# Hex colour codes of the "mako" palette (not referenced in the cells visible here).
colors = sns.color_palette("mako").as_hex()
# Continuous "viridis" colormap passed as `cmap` to the dolfin solution plots.
my_cmap = sns.color_palette("viridis", as_cmap=True)

%load_ext autoreload
%autoreload 2


In [None]:
# --- Experiment configuration -------------------------------------------------
# save_figs  : when True, figures are written to `images_repo`.
# small_data : forwarded to DataLoader (presumably selects a reduced dataset - confirm).
save_figs = True
small_data = False
data = DataLoader(small_data)

# The hyper-parameters must match the ones used to train the checkpoint loaded
# below, otherwise `load_state_dict` cannot map the saved weights onto the model.
agent = Agent(
    data,
    level=2,
    relative=True,
    squared=False,
    initial_lr=5e-3,
    n_modes=10,
    width=20,
    batch_size=32,
    pad_prop=0.05,
    pad_mode="reflect",
    l2_lambda=1e-3,
)
model = agent.model
device = agent.device

models_repo = "./models"
images_repo = "./images"

# `map_location` lets a checkpoint written on a GPU machine be restored on
# whatever device the agent selected (e.g. a CPU-only machine).
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
best_model = torch.load(f"{models_repo}/best_model.pkl", map_location=device)
model.load_state_dict(best_model["model_state_dict"])
model.eval()  # inference mode for every evaluation below

print(f"Best epoch = {best_model['epoch']}")
if save_figs:
    # exist_ok avoids the check-then-create race and makes the cell re-runnable.
    os.makedirs(images_repo, exist_ok=True)

# Compare the errors of $\phi$-FEM, standard FEM and FNO with respect to a reference fine standard FEM solution

In [None]:
def compare_std_phi_fem_and_fno(param, coeffs_ls, Plot=False):
    r"""Solve one test case with phi-FEM, standard FEM and the FNO, and compare them.

    A fine standard-FEM solution is used as reference. Each of the three
    solutions is projected onto the reference function space and its relative
    L2 error with respect to the reference is computed.

    Parameters
    ----------
    param : array-like
        ``param[0]`` must unpack into
        ``(mu0, mu1, sigma_x, sigma_y, amplitude, alpha, beta)``. The whole
        object is also forwarded to both solvers as ``params``.
    coeffs_ls : array-like
        Forwarded unchanged to both solvers and to ``generate_phi_numpy``
        (presumably the level-set coefficients - confirm in ``utils``).
    Plot : bool, default False
        When True, display a four-panel figure (reference, phi-FEM, standard
        FEM, FNO prediction) with the relative errors in the titles.

    Returns
    -------
    tuple
        ``(l2_error_phi_fem, l2_error_std_fem, l2_error_fno,
        cell_selection_phi, submesh_construction_phi,
        ghost_cell_selection_phi, resolution_time_phi,
        construction_time_standard, resolution_time_standard, time_fno)``.
        The first three entries are relative L2 errors; ``time_fno`` is the
        wall-clock duration in seconds of one FNO forward pass; the remaining
        entries are returned as reported by the solvers.
    """
    standard_solver = StandardFEMSolver(params=param, coeffs_ls=coeffs_ls)
    solver = PhiFemSolver_error(nb_cell=64 - 1, params=param, coeffs_ls=coeffs_ls)
    mu0, mu1, sigma_x, sigma_y, amplitude, alpha, beta = param[0]

    # Reference solution; 0.002 is passed as second argument of solve_one
    # (presumably the mesh size of the fine reference mesh - confirm).
    u_ref, V_ref, dx_ref = standard_solver.solve_one(0, 0.002, reference_fem=True)

    # The norm of the reference is the denominator of every relative error:
    # assemble it once instead of once per method.
    ref_norm = df.assemble((((u_ref)) ** 2) * dx_ref) ** (0.5)

    def project_on_reference(u):
        """Project ``u`` onto the reference function space."""
        return df.project(
            u, V_ref, solver_type="gmres", preconditioner_type="hypre_amg"
        )

    def relative_l2_error(u_proj):
        """Relative L2 distance between ``u_proj`` and the reference solution."""
        return (
            df.assemble((((u_ref - u_proj)) ** 2) * dx_ref) ** (0.5)
        ) / ref_norm

    (
        u_phi_fem,
        _V_phi_fem,
        _dx_phi_fem,
        _h_phi,
        cell_selection_phi,
        submesh_construction_phi,
        ghost_cell_selection_phi,
        resolution_time_phi,
    ) = solver.solve_one(0)

    (
        u_std,
        _h_std,
        construction_time_standard,
        resolution_time_standard,
    ) = standard_solver.solve_one(0, 0.0224, reference_fem=False)

    # Build the FNO input on the 64 x 64 grid.
    F = generate_F_numpy(mu0, mu1, sigma_x, sigma_y, amplitude, 64)
    G = generate_G_numpy(alpha, beta, 64)
    phi = generate_phi_numpy(coeffs_ls, 0.4, 64)
    X = generate_manual_new_data_numpy(F, phi, G).to(device)
    x_normed = data.x_normalizer.encode(X)

    def wait_for_device():
        # CUDA kernels are launched asynchronously: wait for them so that the
        # wall-clock measurement covers the whole forward pass.
        if X.is_cuda:
            torch.cuda.synchronize()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # Warm-up call (as in compare_std_phi_fem_and_fno_size_error) so that
        # one-off initialisation costs are not attributed to the FNO timing.
        model(x_normed)
        wait_for_device()
        start_call = time.perf_counter()
        Y_normed = model(x_normed)
        wait_for_device()
        end_call = time.perf_counter()
    time_fno = end_call - start_call

    Y = data.y_normalizer.decode(Y_normed)
    # Combine the network output with input channels 1 and 2 of X to obtain
    # the predicted solution.
    predicted_solution = (
        (X[:, 1, :, :] * Y[:, 0, :, :] + X[:, 2, :, :]).cpu().detach().numpy()
    )
    predicted_solution = np.reshape(predicted_solution, (64, 64))
    predicted_sol_fenics = convert_numpy_matrix_to_fenics(predicted_solution, 64, 1)

    u_phi_fem_proj = project_on_reference(u_phi_fem)
    l2_error_phi_fem = relative_l2_error(u_phi_fem_proj)

    u_std_fem_proj = project_on_reference(u_std)
    l2_error_std_fem = relative_l2_error(u_std_fem_proj)

    predicted_sol_fenics_proj_V_ref = project_on_reference(predicted_sol_fenics)
    l2_error_fno = relative_l2_error(predicted_sol_fenics_proj_V_ref)

    if Plot:
        error_line = "\n$L^2$ relative error : "
        panels = (
            (u_ref, "Reference solution"),
            (
                u_phi_fem_proj,
                # raw string: "\p" is not a valid escape sequence
                r"$\phi$-FEM solution " + error_line + f"{l2_error_phi_fem:.5f}",
            ),
            (
                u_std_fem_proj,
                "Standard FEM solution " + error_line + f"{l2_error_std_fem:.5f}",
            ),
            (
                predicted_sol_fenics_proj_V_ref,
                "Predicted solution " + error_line + f"{l2_error_fno:.5f}",
            ),
        )

        fig = plt.figure(figsize=(16, 4))
        for position, (field, title) in enumerate(panels, start=1):
            ax = plt.subplot(1, 4, position)
            img = df.plot(field, mode="color", cmap=my_cmap)
            # Horizontal colorbar in a thin axes appended below the panel.
            cax = make_axes_locatable(ax).append_axes("bottom", size="5%", pad=0.3)
            ax.grid(False)
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)
            ax.set_title(title, fontsize=15)
            plt.colorbar(img, cax=cax, orientation="horizontal")
        plt.show()
        # The function is called in long loops: release the figure explicitly.
        plt.close(fig)

    print(f"{l2_error_fno=}")
    print(f"{l2_error_phi_fem=}")
    print(f"{l2_error_std_fem=}")

    print(f"{_h_std=}")
    print(f"{_h_phi=}")

    return (
        l2_error_phi_fem,
        l2_error_std_fem,
        l2_error_fno,
        cell_selection_phi,
        submesh_construction_phi,
        ghost_cell_selection_phi,
        resolution_time_phi,
        construction_time_standard,
        resolution_time_standard,
        time_fno,
    )

In [None]:
# Visual sanity check: generate three test cases and plot the three methods
# side by side for each of them.
F, phi, G, params, coeffs_ls = create_FG_numpy(
    3, 64, n_mode=3, seed=250124, compare_methods=True
)

i = 0
for i in range(len(params)):
    # The comparison routine expects a batch of size one for both inputs.
    single_param = np.array([params[i]])
    single_coeffs = np.array([coeffs_ls[i]])
    (
        l2_error_phi_fem,
        l2_error_std_fem,
        l2_error_fno,
        cell_selection_phi,
        submesh_construction_phi,
        ghost_cell_selection_phi,
        resolution_time_phi,
        construction_time_standard,
        resolution_time_standard,
        time_fno,
    ) = compare_std_phi_fem_and_fno(single_param, single_coeffs, True)
    print(f"{l2_error_fno=}")

In [None]:
# Error statistics over 300 generated test cases. The error lists are
# re-written to disk after every case so a partial run can still be analysed.
save_list_error = True
L2_error_fno, L2_error_phi_fem, L2_error_std_fem = [], [], []
Time_phi, Time_std, Time_fno = [], [], []

F, phi, G, params, coeffs_ls = create_FG_numpy(
    300, 64, n_mode=3, seed=250124, compare_methods=True
)
indices = list(range(0, len(params)))
for index in indices:
    print(f"Iter : {index+1}/{len(params)}")
    outcome = compare_std_phi_fem_and_fno(
        np.array([params[index]]), np.array([coeffs_ls[index]]), Plot=True
    )
    (
        l2_error_phi_fem,
        l2_error_std_fem,
        l2_error_fno,
        cell_selection_phi,
        submesh_construction_phi,
        ghost_cell_selection_phi,
        resolution_time_phi,
        construction_time_standard,
        resolution_time_standard,
        time_fno,
    ) = outcome

    # phi-FEM: error and the four timing components
    L2_error_phi_fem.append(l2_error_phi_fem)
    phi_fem_timings = [
        cell_selection_phi,
        submesh_construction_phi,
        ghost_cell_selection_phi,
        resolution_time_phi,
    ]
    Time_phi.append(phi_fem_timings)

    # FNO: error and forward-pass time
    L2_error_fno.append(l2_error_fno)
    Time_fno.append(time_fno)

    # standard FEM: error and the two timing components
    L2_error_std_fem.append(l2_error_std_fem)
    Time_std.append([construction_time_standard, resolution_time_standard])

    if save_list_error:
        for file_stem, history in (
            ("L2_error_phi_fem", L2_error_phi_fem),
            ("L2_error_std_fem", L2_error_std_fem),
            ("L2_error_fno", L2_error_fno),
        ):
            np.save(f"{models_repo}/{file_stem}.npy", np.array([history]))

In [None]:
# Persist the inputs generated for the comparison so the same test set can be
# reloaded later.
data_dir = "./data_test_compare_methods"
# exist_ok replaces the check-then-create pattern (no race, re-runnable cell).
os.makedirs(data_dir, exist_ok=True)
np.save(f"{data_dir}/F.npy", F)
np.save(f"{data_dir}/Phi.npy", phi)
np.save(f"{data_dir}/G.npy", G)
np.save(f"{data_dir}/params.npy", params)

In [None]:
# Reload the error lists written by the benchmark loop, so the plots below can
# be produced without re-running it.
L2_error_phi_fem, L2_error_std_fem, L2_error_fno = (
    np.load(f"{models_repo}/{file_stem}.npy")
    for file_stem in ("L2_error_phi_fem", "L2_error_std_fem", "L2_error_fno")
)

In [None]:
# Box plot of the relative L2 errors of the three methods (one column each).
abs_str = [r"$\phi$-FEM", "Standard FEM", r"$\phi$-FEM-FNO "]
error_tab = [L2_error_phi_fem, L2_error_std_fem, L2_error_fno]

# Stack the three error lists and flatten any leading singleton axes so that
# each row holds the errors of one method.
errors = np.array(error_tab[:])
errors = errors.reshape((3, errors.shape[-1]))
print(errors.shape)
dataframe = pd.DataFrame(errors.T, columns=abs_str)

sns.set(font_scale=1.1)

plt.figure(figsize=(6, 4))
ax_box = sns.boxplot(
    data=dataframe,
    palette="ch:s=.0,rot=0.0,dark=0.5",
    flierprops={"marker": "x", "markerfacecolor": "black"},
)
ax_box.set_yscale("log")
ax_box.set_xlabel("Method", fontsize=16)
ax_box.set_ylabel("Relative $L^2$ error", fontsize=16)
ax_box.grid(axis="y", visible=True, which="both")

plt.tight_layout()
if save_figs:
    plt.savefig(f"{images_repo}/boxplots_new_data_compare_method.pdf")
plt.show()

# Compare the total computation times

In [None]:
def compare_std_phi_fem_and_fno_size_error(
    param, coeffs_ls, size_std, size_phi_fem, u_ref, V_ref, dx_ref
):
    """Compare phi-FEM, standard FEM and the FNO against a given reference.

    Unlike ``compare_std_phi_fem_and_fno`` the reference solution is supplied
    by the caller, and the resolution of the two FEM solvers is configurable.
    The FNO is always evaluated on the 64 x 64 grid.

    Parameters
    ----------
    param : array-like
        ``param[0]`` must unpack into
        ``(mu0, mu1, sigma_x, sigma_y, amplitude, alpha, beta)``; also
        forwarded to both solvers as ``params``.
    coeffs_ls : array-like
        Forwarded unchanged to both solvers and to ``generate_phi_numpy``.
    size_std : float
        Second positional argument of ``StandardFEMSolver.solve_one``
        (presumably the target mesh size - confirm).
    size_phi_fem : int
        phi-FEM resolution; the solver is built with
        ``nb_cell=size_phi_fem - 1``. Also passed as third positional argument
        of ``StandardFEMSolver.solve_one``.
    u_ref, V_ref, dx_ref
        Reference solution, its function space and the integration measure
        used to assemble the error norms.

    Returns
    -------
    tuple
        ``(l2_error_phi_fem, l2_error_std_fem, l2_error_fno,
        cell_selection_phi, submesh_construction_phi,
        ghost_cell_selection_phi, resolution_time_phi, mesh_times_std,
        resolution_time_standard, time_fno)``.
    """
    # The norm of the reference is the denominator of every relative error:
    # assemble it once instead of once per method.
    ref_norm = df.assemble((((u_ref)) ** 2) * dx_ref) ** (0.5)

    def project_on_reference(u):
        """Project ``u`` onto the reference function space."""
        return df.project(
            u, V_ref, solver_type="gmres", preconditioner_type="hypre_amg"
        )

    def relative_l2_error(u_proj):
        """Relative L2 distance between ``u_proj`` and the reference solution."""
        return (
            df.assemble((((u_ref - u_proj)) ** 2) * dx_ref) ** (0.5)
        ) / ref_norm

    # --- phi-FEM ---------------------------------------------------------
    solver = PhiFemSolver_error(
        nb_cell=size_phi_fem - 1, params=param, coeffs_ls=coeffs_ls
    )
    (
        u_phi_fem,
        _V_phi_fem,
        _dx_phi_fem,
        _h_phi,
        cell_selection_phi,
        submesh_construction_phi,
        ghost_cell_selection_phi,
        resolution_time_phi,
    ) = solver.solve_one(0)
    u_phi_fem_proj = project_on_reference(u_phi_fem)
    l2_error_phi_fem = relative_l2_error(u_phi_fem_proj)

    # --- standard FEM ----------------------------------------------------
    standard_solver = StandardFEMSolver(params=param, coeffs_ls=coeffs_ls)
    (
        u_std,
        _h_std,
        mesh_times_std,
        resolution_time_standard,
    ) = standard_solver.solve_one(0, size_std, size_phi_fem, False)
    u_std_fem_proj = project_on_reference(u_std)
    l2_error_std_fem = relative_l2_error(u_std_fem_proj)

    # --- FNO -------------------------------------------------------------
    mu0, mu1, sigma_x, sigma_y, amplitude, alpha, beta = param[0]

    F = generate_F_numpy(mu0, mu1, sigma_x, sigma_y, amplitude, 64)
    G = generate_G_numpy(alpha, beta, 64)
    phi = generate_phi_numpy(coeffs_ls, threshold=0.4, nb_vert=64)
    X = generate_manual_new_data_numpy(F, phi, G).to(device)
    x_normed = data.x_normalizer.encode(X)

    def wait_for_device():
        # CUDA kernels are launched asynchronously: wait for them so that the
        # wall-clock measurement covers the whole forward pass.
        if X.is_cuda:
            torch.cuda.synchronize()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # Warm-up call so one-off initialisation costs are not timed.
        model(x_normed)
        wait_for_device()
        start_call = time.perf_counter()
        Y_normed = model(x_normed)
        wait_for_device()
        end_call = time.perf_counter()
    time_fno = end_call - start_call

    Y = data.y_normalizer.decode(Y_normed)
    # Combine the network output with input channels 1 and 2 of X to obtain
    # the predicted solution.
    predicted_solution = (
        (X[:, 1, :, :] * Y[:, 0, :, :] + X[:, 2, :, :]).cpu().detach().numpy()
    )
    predicted_solution = np.reshape(predicted_solution, (64, 64))
    predicted_solution_fenics = convert_numpy_matrix_to_fenics(
        predicted_solution, 64, 1
    )
    predicted_solution_fenics_proj_V_ref = project_on_reference(
        predicted_solution_fenics
    )
    l2_error_fno = relative_l2_error(predicted_solution_fenics_proj_V_ref)

    return (
        l2_error_phi_fem,
        l2_error_std_fem,
        l2_error_fno,
        cell_selection_phi,
        submesh_construction_phi,
        ghost_cell_selection_phi,
        resolution_time_phi,
        mesh_times_std,
        resolution_time_standard,
        time_fno,
    )

In [None]:
# Generate and persist the 20 test cases used for the convergence / timing study.
seed = 260124
F, phi, G, params, coeffs_ls = create_FG_numpy(
    20, 64, n_mode=3, seed=seed, compare_methods=True
)
data_dir = "./data_test_compare_methods"
# exist_ok replaces the check-then-create pattern (no race, re-runnable cell).
os.makedirs(data_dir, exist_ok=True)
np.save(f"{data_dir}/F_convergence.npy", F)
np.save(f"{data_dir}/Phi_convergence.npy", phi)
np.save(f"{data_dir}/G_convergence.npy", G)
np.save(f"{data_dir}/params_convergence.npy", params)

In [None]:
# For every phi-FEM resolution, record hmax of the matching structured
# unit-square mesh; these values are used as sizes for the standard FEM.
sizes_phi_fem = [16, 32, 64, 128, 256]
sizes_std = [
    df.UnitSquareMesh(size - 1, size - 1).hmax() for size in sizes_phi_fem
]
print(
    sizes_std
)  # [0.09428090415820647, 0.045619792334616084, 0.02244783432338254, 0.01113553986120561, 0.005545935538718157]

In [None]:
# Convergence / timing study: for every test case, compute a fine reference
# solution once, then evaluate the three methods at each resolution.

# Resolutions compared. `sizes_std` holds the hmax values of the structured
# meshes with `sizes_phi_fem` vertices per side (see the cell computing them).
# They do not depend on the test case, so they are defined once, outside the loop.
sizes_phi_fem = [16, 32, 64, 128, 256]
sizes_std = [
    0.09428090415820647,
    0.045619792334616084,
    0.02244783432338254,
    0.01113553986120561,
    0.005545935538718157,
]

# One entry per test case; each entry is a list with one value per resolution.
L2_error_phi_fem_array = []
L2_error_std_fem_array = []
L2_error_fno_array = []
Cell_selection_phi_array = []
Submesh_construction_phi_array = []
Ghost_cell_selection_phi_array = []
Resolution_time_phi_array = []
Construction_time_standard_array = []
Resolution_time_standard_array = []
Time_fno_array = []

u_refs, V_refs, dx_refs = [], [], []

for index in range(len(params)):
    # Reference solution for this test case (computed once, reused for all sizes).
    standard_solver = StandardFEMSolver(
        params=np.array([params[index]]), coeffs_ls=np.array([coeffs_ls[index]])
    )
    u_ref, V_ref, dx_ref = standard_solver.solve_one(0, 0.003, reference_fem=True)
    u_refs.append(u_ref)
    V_refs.append(V_ref)
    dx_refs.append(dx_ref)

    # Per-resolution results for the current test case.
    L2_error_phi_fem = []
    L2_error_std_fem = []
    L2_error_fno = []
    Cell_selection_phi = []
    Submesh_construction_phi = []
    Ghost_cell_selection_phi = []
    Resolution_time_phi = []
    Construction_time_standard = []
    Resolution_time_standard = []
    Time_fno = []

    for i in range(len(sizes_phi_fem)):
        # One-based counter, consistent with the other progress messages.
        print(f"Param : {index + 1}/{len(params)}    {sizes_phi_fem[i]=}")
        (
            l2_error_phi_fem,
            l2_error_std_fem,
            l2_error_fno,
            cell_selection_phi,
            submesh_construction_phi,
            ghost_cell_selection_phi,
            resolution_time_phi,
            construction_time_standard,
            resolution_time_standard,
            time_fno,
        ) = compare_std_phi_fem_and_fno_size_error(
            np.array([params[index]]),
            np.array([coeffs_ls[index]]),
            sizes_std[i],
            sizes_phi_fem[i],
            u_ref,
            V_ref,
            dx_ref,
        )

        print(f"{l2_error_phi_fem=:.3e}")
        print(f"{l2_error_std_fem=:.3e}")
        print(f"{l2_error_fno=:.3e}")

        L2_error_phi_fem.append(l2_error_phi_fem)
        L2_error_std_fem.append(l2_error_std_fem)
        L2_error_fno.append(l2_error_fno)
        Cell_selection_phi.append(cell_selection_phi)
        Submesh_construction_phi.append(submesh_construction_phi)
        Ghost_cell_selection_phi.append(ghost_cell_selection_phi)
        Resolution_time_phi.append(resolution_time_phi)
        Construction_time_standard.append(construction_time_standard)
        Resolution_time_standard.append(resolution_time_standard)
        Time_fno.append(time_fno)

    L2_error_phi_fem_array.append(L2_error_phi_fem)
    L2_error_std_fem_array.append(L2_error_std_fem)
    L2_error_fno_array.append(L2_error_fno)
    Cell_selection_phi_array.append(Cell_selection_phi)
    Submesh_construction_phi_array.append(Submesh_construction_phi)
    Ghost_cell_selection_phi_array.append(Ghost_cell_selection_phi)
    Resolution_time_phi_array.append(Resolution_time_phi)
    Construction_time_standard_array.append(Construction_time_standard)
    Resolution_time_standard_array.append(Resolution_time_standard)
    Time_fno_array.append(Time_fno)

In [None]:
# Convert the nested result lists to arrays (first axis: test case, second
# axis: resolution) and persist them.
L2_error_phi_fem_array = np.asarray(L2_error_phi_fem_array)
L2_error_std_fem_array = np.asarray(L2_error_std_fem_array)
L2_error_fno_array = np.asarray(L2_error_fno_array)
Cell_selection_phi_array = np.asarray(Cell_selection_phi_array)
Submesh_construction_phi_array = np.asarray(Submesh_construction_phi_array)
Ghost_cell_selection_phi_array = np.asarray(Ghost_cell_selection_phi_array)
Resolution_time_phi_array = np.asarray(Resolution_time_phi_array)
Construction_time_standard_array = np.asarray(Construction_time_standard_array)
Resolution_time_standard_array = np.asarray(Resolution_time_standard_array)
Time_fno_array = np.asarray(Time_fno_array)

compare_dir = "./compare_methods"
# exist_ok replaces the check-then-create pattern (no race, re-runnable cell).
os.makedirs(compare_dir, exist_ok=True)

# Each array is stored as "<variable name>.npy", which is what the loading
# cell expects.
arrays_to_save = {
    "L2_error_phi_fem_array": L2_error_phi_fem_array,
    "L2_error_std_fem_array": L2_error_std_fem_array,
    "L2_error_fno_array": L2_error_fno_array,
    "Cell_selection_phi_array": Cell_selection_phi_array,
    "Submesh_construction_phi_array": Submesh_construction_phi_array,
    "Ghost_cell_selection_phi_array": Ghost_cell_selection_phi_array,
    "Resolution_time_phi_array": Resolution_time_phi_array,
    "Construction_time_standard_array": Construction_time_standard_array,
    "Resolution_time_standard_array": Resolution_time_standard_array,
    "Time_fno_array": Time_fno_array,
}
for array_name, array_values in arrays_to_save.items():
    np.save(f"{compare_dir}/{array_name}.npy", array_values)

In [None]:
# Reload the convergence-study arrays from disk so the figures below can be
# produced without re-running the study.
(
    L2_error_phi_fem_array,
    L2_error_std_fem_array,
    L2_error_fno_array,
    Cell_selection_phi_array,
    Submesh_construction_phi_array,
    Ghost_cell_selection_phi_array,
    Resolution_time_phi_array,
    Construction_time_standard_array,
    Resolution_time_standard_array,
    Time_fno_array,
) = (
    np.load(f"./compare_methods/{array_name}.npy")
    for array_name in (
        "L2_error_phi_fem_array",
        "L2_error_std_fem_array",
        "L2_error_fno_array",
        "Cell_selection_phi_array",
        "Submesh_construction_phi_array",
        "Ghost_cell_selection_phi_array",
        "Resolution_time_phi_array",
        "Construction_time_standard_array",
        "Resolution_time_standard_array",
        "Time_fno_array",
    )
)

In [None]:
# Mesh size (hmax) of the structured unit-square mesh for each phi-FEM
# resolution; used as abscissa of the convergence plot.
sizes_phi_fem_h = [
    df.UnitSquareMesh(size - 1, size - 1).hmax() for size in sizes_phi_fem
]

# The FNO works on the 64 x 64 grid, hence the mesh size of the 63 x 63 cell mesh.
size_fno = df.UnitSquareMesh(64 - 1, 64 - 1).hmax()

In [None]:
# Relative L2 error versus mesh size for the three methods.
# Statistics are taken over axis 0 (the test cases), giving one value per resolution.
mean_errors_L2_phi_fem = np.mean(L2_error_phi_fem_array, axis=0)
mean_errors_L2_std_fem = np.mean(L2_error_std_fem_array, axis=0)
mean_errors_L2_FNO = np.mean(L2_error_fno_array, axis=0)

std_errors_L2_phi_fem = np.std(L2_error_phi_fem_array, axis=0)
std_errors_L2_std_fem = np.std(L2_error_std_fem_array, axis=0)
std_errors_L2_FNO = np.std(L2_error_fno_array, axis=0)

min_errors_L2_phi_fem = np.min(L2_error_phi_fem_array, axis=0)
min_errors_L2_std_fem = np.min(L2_error_std_fem_array, axis=0)
min_errors_L2_FNO = np.min(L2_error_fno_array, axis=0)

max_errors_L2_phi_fem = np.max(L2_error_phi_fem_array, axis=0)
max_errors_L2_std_fem = np.max(L2_error_std_fem_array, axis=0)
max_errors_L2_FNO = np.max(L2_error_fno_array, axis=0)

# Cubehelix palette with "#b22222" inserted just before its last entry, so
# that `palette[-2]` is the colour used for the FNO marker.
palette = sns.cubehelix_palette(n_colors=4, start=0.25, rot=-0.25, gamma=0.5)
palette = palette.as_hex()
palette.insert(-1, "#b22222")
palette = sns.color_palette(palette)

# The FNO is evaluated on the same 64 x 64 grid for every column of the
# result arrays; the last column is the one reported.
center = [size_fno, mean_errors_L2_FNO[-1]]

plt.figure(figsize=(6, 4))

# Standard FEM: mean curve with a +/- one standard deviation band.
plt.loglog(sizes_std, mean_errors_L2_std_fem, "-+", markersize=8, label="Std-FEM")
ci_error_std_FEM = std_errors_L2_std_fem
plt.fill_between(
    sizes_std,
    mean_errors_L2_std_fem - ci_error_std_FEM,
    mean_errors_L2_std_fem + ci_error_std_FEM,
    alpha=0.2,
)

# phi-FEM: mean curve with a +/- one standard deviation band.
plt.loglog(
    sizes_phi_fem_h,
    mean_errors_L2_phi_fem,
    "-+",
    markersize=8,
    label=r"$\phi$-FEM",
)
ci_error_phi_fem = std_errors_L2_phi_fem
plt.fill_between(
    sizes_phi_fem_h,
    mean_errors_L2_phi_fem - ci_error_phi_fem,
    mean_errors_L2_phi_fem + ci_error_phi_fem,
    alpha=0.2,
)

# FNO: a single point (fixed grid) with a one standard deviation error bar.
plt.plot(
    size_fno,
    mean_errors_L2_FNO[-1],
    "x",
    markersize=8,
    label=r"$\phi$-FEM-FNO",
    color=palette[-2],
)
ci = [
    std_errors_L2_FNO[-1],
]
plt.errorbar(
    size_fno,
    mean_errors_L2_FNO[-1],
    yerr=ci[-1],
    ecolor=palette[-2],
    capsize=7,
    elinewidth=1.5,
    capthick=1.5,
    fmt="none",
    alpha=0.7,
)

plt.legend(fontsize=12, loc="lower right", ncol=1)
plt.xlabel("$h$", fontsize=16)
plt.ylabel("Relative $L^2$ error", fontsize=16)
plt.grid(axis="y", visible=True, which="both")
plt.ylim(1.3e-5, 5.2e-2)
plt.tight_layout()
# Honour `save_figs`: the images directory is only created when it is True,
# so saving unconditionally could fail on a missing directory.
if save_figs:
    plt.savefig(f"{images_repo}/error_size_3_methods.pdf")
plt.show()

In [None]:
# Total computation time of each FEM method, per test case and resolution.
# (The former np.zeros pre-allocations were overwritten immediately and have
# been dropped.)

# phi-FEM: sum of its four timed stages.
Total_time_phi_fem_array = (
    Cell_selection_phi_array
    + Submesh_construction_phi_array
    + Ghost_cell_selection_phi_array
    + Resolution_time_phi_array
)
# Standard FEM: construction stages (all but the last three entries of the
# last axis) plus the resolution time.
Total_time_std_fem_array = (
    np.sum(Construction_time_standard_array[:, :, :-3], axis=2)
    + Resolution_time_standard_array
)  # removes : total construction time, fenics read time and conversion time

In [None]:
# Relative L2 error versus total computation time for the three methods.
# Statistics are taken over axis 0 (the test cases), giving one value per resolution.
mean_times_L2_phi_fem = np.mean(Total_time_phi_fem_array, axis=0)
mean_times_L2_std_fem = np.mean(Total_time_std_fem_array, axis=0)
mean_times_L2_FNO = np.mean(Time_fno_array, axis=0)

std_times_L2_phi_fem = np.std(Total_time_phi_fem_array, axis=0)
std_times_L2_std_fem = np.std(Total_time_std_fem_array, axis=0)
std_times_L2_FNO = np.std(Time_fno_array, axis=0)

min_times_L2_phi_fem = np.min(Total_time_phi_fem_array, axis=0)
min_times_L2_std_fem = np.min(Total_time_std_fem_array, axis=0)
min_times_L2_FNO = np.min(Time_fno_array, axis=0)

max_times_L2_phi_fem = np.max(Total_time_phi_fem_array, axis=0)
max_times_L2_std_fem = np.max(Total_time_std_fem_array, axis=0)
max_times_L2_FNO = np.max(Time_fno_array, axis=0)

# One standard deviation is used for both the error band and the time bars.
ci_error_std_FEM = std_errors_L2_std_fem
ci_time_std_FEM = std_times_L2_std_fem
ci_error_phi_FEM = std_errors_L2_phi_fem
ci_time_phi_FEM = std_times_L2_phi_fem

fig, ax = plt.subplots(1, 1, figsize=(6, 4))

# (label, mean times, mean errors, error band, time bars, error-bar colour)
fem_curves = (
    (
        "Std-FEM",
        mean_times_L2_std_fem,
        mean_errors_L2_std_fem,
        ci_error_std_FEM,
        ci_time_std_FEM,
        sns.color_palette()[0],
    ),
    (
        r"$\phi$-FEM",
        mean_times_L2_phi_fem,
        mean_errors_L2_phi_fem,
        ci_error_phi_FEM,
        ci_time_phi_FEM,
        sns.color_palette()[1],
    ),
)
for curve_label, times, errors, error_band, time_bars, bar_color in fem_curves:
    # Mean curve, vertical +/- std band on the error, horizontal std bars on the time.
    ax.loglog(times, errors, "-+", markersize=8, label=curve_label)
    ax.fill_between(times, errors - error_band, errors + error_band, alpha=0.2)
    ax.errorbar(
        times,
        errors,
        xerr=time_bars,
        ecolor=bar_color,
        capsize=7,
        elinewidth=1.5,
        capthick=1.5,
        fmt="none",
        alpha=0.7,
    )

# FNO: a single point (last column of the result arrays) plus the ellipse
# drawn by `confidence_ellipse` from the per-case (time, error) samples.
ax.plot(
    mean_times_L2_FNO[-1],
    mean_errors_L2_FNO[-1],
    "x",
    markersize=8,
    label=r"$\phi$-FEM-FNO",
    color=palette[-2],
)
confidence_ellipse(
    Time_fno_array[:, -1],
    L2_error_fno_array[:, -1],
    ax,
    alpha=0.8,
    facecolor=sns.color_palette("pastel")[3],
    edgecolor=sns.color_palette("pastel")[3],
)

ax.legend(fontsize=12, loc="lower left", ncol=2)
ax.set_xlabel("Computation time (s)", fontsize=16)
ax.set_ylabel("Relative $L^2$ error", fontsize=16)
# Use the explicit axes rather than the pyplot "current axes" state.
ax.grid(axis="y", visible=True, which="both")
ax.set_ylim(1.3e-5, 5.2e-2)
ax.grid(axis="x", visible=True, which="major")

fig.tight_layout()
# Honour `save_figs`: the images directory is only created when it is True,
# so saving unconditionally could fail on a missing directory.
if save_figs:
    fig.savefig(f"{images_repo}/error_time_3_methods.pdf")
plt.show()