In [1]:
import numpy as np
import matplotlib.pyplot as plt
import random
import os
import torch 
import vtk
import pyvista
pyvista.global_theme.notebook = True
pyvista.set_jupyter_backend('static')
# pyvista.set_jupyter_backend('trame')

seed = 160318
# Seed every random number generator the notebook relies on, in order:
# Python's `random`, PyTorch (torch.manual_seed seeds all devices) and
# NumPy's legacy global RNG.
for _seed_rng in (random.seed, torch.manual_seed, np.random.seed):
    _seed_rng(seed)
# NOTE(review): PYTHONHASHSEED is only read when a Python interpreter starts,
# so assigning it here does not change hash randomisation in this kernel; it is
# only inherited by child processes started afterwards. Export it before
# launching Jupyter if hash-order reproducibility matters.
os.environ["PYTHONHASHSEED"] = str(seed)
import time
from utils import *
from utils_training import *
import prepare_data
from utils_compare_methods import *
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable
from prepare_data import rotate, outside_ball, create_parameters
import pandas as pd
import gc
from utils_plot import * 
sns.set_theme("paper", rc={"xtick.bottom": True, "ytick.left": True}, font_scale=1.1)
colors = sns.color_palette("mako").as_hex()
my_cmap = sns.color_palette("viridis", as_cmap=True)

%load_ext autoreload
%autoreload 2


In [None]:
# Run configuration: figure saving, dataset size and output folders.
save_figs = True
small_data = False
data = DataLoader(small_data)
agent = Agent(data, level=1)

model = agent.model
device = agent.device

models_repo = "./models"
images_repo = "../images"
results_repo = "../results"

# map_location puts the checkpoint tensors on the device the agent actually
# uses, so a checkpoint written on a GPU machine also loads on a CPU-only one.
# NOTE: torch.load unpickles the file; only load trusted checkpoints.
best_model = torch.load(f"{models_repo}/best_model.pkl", map_location=device)
model.load_state_dict(best_model["model_state_dict"])
model.eval()

print(f"Best epoch = {best_model['epoch']}")
# exist_ok makes re-running the cell harmless when the folders already exist.
if save_figs:
    os.makedirs(images_repo, exist_ok=True)
os.makedirs(results_repo, exist_ok=True)

# Compare the relative $L^2$ errors of $\varphi$-FEM, standard FEM, the FNO-based methods and $\varphi$-FEM-UNET with respect to a fine reference standard-FEM solution

In [3]:
def error_L2(uh, u_ex, degree_raise=3):
    """Relative L2 error ``||uh - u_ex|| / ||u_ex||`` on the mesh of ``uh``.

    Both functions are interpolated into a space of the same element family as
    ``uh`` but with the polynomial degree raised by ``degree_raise``. ``u_ex``
    must therefore be interpolable into that space (e.g. a Function on the
    same mesh). The two integrals are summed over all MPI ranks of the mesh
    communicator before taking the ratio.
    """
    source_element = uh.function_space.ufl_element()
    raised_degree = source_element.degree + degree_raise
    family = source_element.family_name
    mesh = uh.function_space.mesh
    W = dolfinx.fem.functionspace(mesh, (family, raised_degree))

    def _into_W(func):
        # Interpolate `func` into the enriched space W.
        target = dolfinx.fem.Function(W)
        target.interpolate(func)
        return target

    def _integrate(integrand):
        # Assemble integrand * dx locally, then sum across ranks.
        local_value = dolfinx.fem.assemble_scalar(dolfinx.fem.form(integrand * ufl.dx))
        return mesh.comm.allreduce(local_value, op=MPI.SUM)

    u_W = _into_W(uh)
    u_ex_W = _into_W(u_ex)

    # Pointwise DOF difference, stored as its own Function in W.
    e_W = dolfinx.fem.Function(W)
    e_W.x.array[:] = u_W.x.array - u_ex_W.x.array

    squared_error = _integrate(ufl.inner(e_W, e_W))
    squared_reference = _integrate(ufl.inner(u_ex_W, u_ex_W))
    return np.sqrt(squared_error / squared_reference)


def non_matching_interpolation(uh, V_to, padding=1e-14):
    """Interpolate ``uh`` into ``V_to`` when the two live on different meshes.

    Builds dolfinx non-matching-mesh interpolation data (target mesh/element,
    source mesh, ``padding`` passed through), interpolates, then scatters the
    ghost values forward. Returns the new Function in ``V_to``.
    """
    result = dolfinx.fem.Function(V_to)
    target_space = result.function_space
    source_mesh = uh.function_space.mesh

    nmm_data = dolfinx.fem.create_nonmatching_meshes_interpolation_data(
        target_space.mesh,
        target_space.element,
        source_mesh,
        padding=padding,
    )
    result.interpolate(uh, nmm_interpolation_data=nmm_data)
    result.x.scatter_forward()
    return result


def compute_domain(X):
    """Boolean mask of the points where ``X[:, 1] <= 3e-16``, widened by one neighbourhood.

    Channel 1 of ``X`` is presumably the level-set ``phi`` (TODO confirm), which
    makes the base mask roughly ``{phi <= 0}`` with a small tolerance. The mask is
    then united with ``Loss().neighborhood_6`` of itself. Every intermediate is
    moved to the global ``device``.
    """
    loss = Loss()
    base_mask = (X[:, 1, :, :] <= 3e-16).to(device)
    ring = loss.neighborhood_6(base_mask).to(device)
    union = (ring.int() + base_mask.int()) != 0
    return union.to(device)

In [4]:
def compare_error_unet(param, Plot=False, screenshot=False):
    """Evaluate phi-FEM-UNET on one parameter set against a fine standard-FEM reference.

    Parameters
    ----------
    param : numpy.ndarray
        Batch holding one parameter set. ``param[0]`` unpacks into
        ``mu0, mu1, sigma_x, sigma_y, amplitude, x_0, y_0, lx, ly, theta, alpha, beta``,
        and the whole array is passed to ``StandardFEMSolver``.
    Plot, screenshot : bool
        Not used; kept so existing call sites keep working.

    Returns
    -------
    tuple
        ``(l2_error_unet, elapsed)``: the relative L2 error of the UNET-based
        solution w.r.t. the reference, and the wall-clock time in seconds of the
        network forward pass only. Normalisation, conversion and the error
        computation are excluded from the timing.
    """
    # Grid resolution used to sample the inputs and to rebuild the prediction.
    resolution = 64

    standard_solver = StandardFEMSolver(params=param)
    mu0, mu1, sigma_x, sigma_y, amplitude, x_0, y_0, lx, ly, theta, alpha, beta = param[0]

    # Reference solution. The meaning of the positional arguments (0, 0.005)
    # is defined by StandardFEMSolver.solve_one.
    u_ref, V_ref, dx_ref, h_ref = standard_solver.solve_one(
        0, 0.005, reference_fem=True
    )

    F = generate_F_numpy(mu0, mu1, sigma_x, sigma_y, amplitude, resolution)
    G = generate_G_numpy(alpha, beta, resolution)
    phi = generate_phi_numpy(x_0, y_0, lx, ly, theta, resolution)
    X = generate_manual_new_data_numpy(F, phi, G).to(device)

    # start phi-FEM-UNET
    # Inference only: no_grad avoids building an autograd graph for every case.
    with torch.no_grad():
        x_normed = data.x_normalizer.encode(X)
        # CUDA launches are asynchronous. Without synchronising, the clock would
        # only measure kernel launch time. Flush pending work before starting
        # the clock, and wait for the forward pass to finish before stopping it.
        if x_normed.is_cuda:
            torch.cuda.synchronize(x_normed.device)
        start = time.perf_counter()
        Y_normed = model(x_normed)
        if Y_normed.is_cuda:
            torch.cuda.synchronize(Y_normed.device)
        end = time.perf_counter()
        Y = data.y_normalizer.decode(Y_normed)

    # Rebuild the solution from channels 1 and 2 of X and the network output.
    # Given the (F, phi, G) input order this is presumably u = phi * w + g
    # (TODO confirm).
    predicted_solution = (
        (X[:, 1, :, :] * Y[:, 0, :, :] + X[:, 2, :, :]).detach().cpu().numpy()
    )[0, :, :]
    predicted_sol_fenics = convert_numpy_matrix_to_fenicsx(
        predicted_solution, resolution, 1
    )
    predicted_sol_fenics.x.scatter_forward()
    # Move the coarse prediction onto the reference mesh before comparing.
    predicted_sol_fenics_proj_V_ref = non_matching_interpolation(
        predicted_sol_fenics, V_ref
    )
    l2_error_unet = error_L2(predicted_sol_fenics_proj_V_ref, u_ref)
    # end phi-FEM-UNET

    print(f"{l2_error_unet=:.3e}")

    return l2_error_unet, end - start

In [None]:
# Evaluate phi-FEM-UNET on the first 300 test parameter sets. When
# save_list_error is set, the error/time lists are checkpointed to
# `results_repo` after every case.
save_list_error = True

params = np.load("../../data_test/params.npy")
indices = list(range(0, 300))
L2_error_unet = []
Times_unet = []
for index in indices:
    print(f"Iter : {index+1}/{len(indices)}")
    # Keep a leading batch axis; compare_error_unet reads the row via param[0].
    single_case = np.array([params[index]])
    l2_error_unet, time_unet = compare_error_unet(single_case, Plot=False)
    L2_error_unet.append(l2_error_unet)
    Times_unet.append(time_unet)
    if not save_list_error:
        continue
    # Saved with shape (1, cases_done) so later cells can slice columns.
    np.save(f"{results_repo}/L2_error_phi_fem_unet.npy", np.array([L2_error_unet]))
    np.save(f"{results_repo}/Times_phi_fem_unet.npy", np.array([Times_unet]))

In [None]:
# Load the per-case relative L2 errors of every method (one .npy per method).
L2_error_fno = np.load(f"{results_repo}/L2_error_phi_fem_fno.npy")
L2_error_phi_fem = np.load(f"{results_repo}/L2_error_phi_fem.npy")
L2_error_std_fem = np.load(f"{results_repo}/L2_error_std_fem.npy")
L2_error_std_fno = np.load(f"{results_repo}/L2_error_std_fem_fno.npy")
L2_error_fno_pred_u = np.load(f"{results_repo}/L2_error_phi_fem_fno_2.npy")
L2_error_geo_fno = np.load(f"{results_repo}/L2_error_geo_fno.npy")
L2_error_phi_fem_unet = np.load(f"{results_repo}/L2_error_phi_fem_unet.npy")

# One entry per method; `abs_str` (box labels) and `palette` follow the same order.
error_tab = [
    L2_error_phi_fem,
    L2_error_std_fem,
    L2_error_fno,
    L2_error_fno_pred_u,
    L2_error_std_fno,
    L2_error_geo_fno,
    L2_error_phi_fem_unet,
]
# A leading "\n" shifts every other label one line down so that neighbouring
# labels do not overlap on the x axis.
abs_str = [
    r"$\varphi$-FEM",
    "\nStd FEM",
    r"$\varphi$-FEM-FNO",
    "\n" + r"$\varphi$-FEM-FNO 2",
    "Std-FEM-FNO",
    "\nGeo-FNO",
    r"$\varphi$-FEM-UNET",
]
assert len(abs_str) == len(error_tab), "need exactly one label per method"

# Stack into (n_methods, n_cases). Each saved array is assumed to hold a
# single row of errors (e.g. shape (1, n_cases)); the method count is derived
# from the list rather than hard-coded.
errors = np.array(error_tab)
errors = np.reshape(errors, (len(error_tab), np.shape(errors)[-1]))
print(np.shape(errors))

# Six cubehelix shades plus a red highlight inserted at index 2 (the
# phi-FEM-FNO box), giving one colour per method.
palette = sns.cubehelix_palette(n_colors=6, start=0.25, rot=-0.25, gamma=0.5).as_hex()
palette.insert(-4, "#b22222")
palette = sns.color_palette(palette)
assert len(palette) == len(error_tab), "need exactly one colour per method"

dataframe = pd.DataFrame(errors.transpose(), columns=abs_str)

fig, ax = plt.subplots(figsize=(6, 4))
sns.boxplot(
    data=dataframe,
    palette=palette,
    flierprops={"marker": "x", "markerfacecolor": "black"},
    ax=ax,
)
ax.set_yscale("log")
ax.set_ylabel("Relative $L^2$ error", fontsize=16)
ax.grid(axis="y", visible=True, which="both")
ax.tick_params(axis="both", which="major", labelsize=12)

fig.tight_layout()
if save_figs:
    fig.savefig(f"{images_repo}/boxplots_new_data_compare_all_methods_unet.pdf")
plt.show()

In [None]:
# Per-method summary of the relative L2 errors (count, mean, std, min, quartiles, max).
dataframe.describe(percentiles=[0.25, 0.5, 0.75])

In [22]:
def _load_results(name, start=1, stop=101):
    """Load ``{results_repo}/{name}.npy``, keep columns ``start:stop`` and flatten.

    The saved arrays are 2-D; for example, the UNET loop writes shape
    ``(1, n_cases)``. Column 0 is skipped, presumably to discard a warm-up run
    whose timing is not representative (TODO confirm), and the next 100 cases
    are kept.

    The path is built as ``f"{results_repo}/..."``, as in the box-plot cell. The
    previous ``"./"`` prefix would have turned an absolute ``results_repo``
    into a relative path.
    """
    return np.load(f"{results_repo}/{name}.npy")[:, start:stop].flatten()


L2_error_geo_fno_array = _load_results("L2_error_geo_fno")
Time_geo_fno_array = _load_results("Times_geo_fno")

L2_error_fno_std_array = _load_results("L2_error_std_fem_fno")
Time_fno_std_array = _load_results("Times_std_fem_fno")

L2_error_fno_pred_u_array = _load_results("L2_error_phi_fem_fno_2")
Time_fno_pred_u_array = _load_results("Times_phi_fem_fno_2")

L2_error_phi_fem_unet_array = _load_results("L2_error_phi_fem_unet")
Time_phi_fem_unet_array = _load_results("Times_phi_fem_unet")

L2_error_phi_fem_array = _load_results("L2_error_phi_fem")
Time_phi_fem_array = _load_results("Times_phi_fem")

L2_error_std_fem_array = _load_results("L2_error_std_fem")
Time_std_fem_array = _load_results("Times_std_fem")

L2_error_fno_array = _load_results("L2_error_phi_fem_fno")
Time_fno_array = _load_results("Times_phi_fem_fno")

In [None]:
# Mean and standard deviation, per method, of the relative L2 error and of the
# computation time over the loaded cases. All four unpackings below use the
# same method order: phi-FEM, Std-FEM, phi-FEM-FNO, phi-FEM-FNO 2,
# Std-FEM-FNO, Geo-FNO, phi-FEM-UNET.
_l2_samples = (
    L2_error_phi_fem_array,
    L2_error_std_fem_array,
    L2_error_fno_array,
    L2_error_fno_pred_u_array,
    L2_error_fno_std_array,
    L2_error_geo_fno_array,
    L2_error_phi_fem_unet_array,
)
_time_samples = (
    Time_phi_fem_array,
    Time_std_fem_array,
    Time_fno_array,
    Time_fno_pred_u_array,
    Time_fno_std_array,
    Time_geo_fno_array,
    Time_phi_fem_unet_array,
)

(
    mean_errors_L2_phi_fem,
    mean_errors_L2_std_fem,
    mean_errors_L2_FNO,
    mean_errors_L2_fno_pred_u,
    mean_errors_L2_fno_std,
    mean_errors_L2_geo_fno,
    mean_errors_L2_phi_fem_unet,
) = map(np.mean, _l2_samples)

(
    std_errors_L2_phi_fem,
    std_errors_L2_std_fem,
    std_errors_L2_FNO,
    std_errors_L2_fno_pred_u,
    std_errors_L2_fno_std,
    std_errors_L2_geo_fno,
    std_errors_L2_phi_fem_unet,
) = map(np.std, _l2_samples)

(
    mean_times_phi_fem,
    mean_times_std_fem,
    mean_times_FNO,
    mean_times_fno_pred_u,
    mean_times_fno_std,
    mean_times_geo_fno,
    mean_times_phi_fem_unet,
) = map(np.mean, _time_samples)

(
    std_times_phi_fem,
    std_times_std_fem,
    std_times_FNO,
    std_times_fno_pred_u,
    std_times_fno_std,
    std_times_geo_fno,
    std_times_phi_fem_unet,
) = map(np.std, _time_samples)

palette = sns.color_palette("Paired")
fig, ax = plt.subplots(1, 1, figsize=(6, 4))

# One entry per method, drawn in this order (which is also the legend order):
# (label, marker, markersize, colour, ellipse alpha, times, relative L2 errors).
# Each method is shown as its mean (time, error) point plus the ellipse drawn
# by confidence_ellipse from the per-case samples.
method_specs = [
    (r"$\varphi$-FEM", "1", 12, palette[1], 0.5,
     Time_phi_fem_array, L2_error_phi_fem_array),
    (r"Std-FEM", "2", 12, palette[7], 0.5,
     Time_std_fem_array, L2_error_std_fem_array),
    (r"$\varphi$-FEM-FNO", "x", 8, palette[5], 0.5,
     Time_fno_array, L2_error_fno_array),
    (r"$\varphi$-FEM-FNO 2", "*", 8, "k", 0.3,
     Time_fno_pred_u_array, L2_error_fno_pred_u_array),
    (r"Std-FEM-FNO", "+", 8, palette[9], 0.4,
     Time_fno_std_array, L2_error_fno_std_array),
    (r"Geo-FNO", "3", 12, palette[11], 0.5,
     Time_geo_fno_array, L2_error_geo_fno_array),
    (r"$\varphi$-FEM-UNET", "4", 12, palette[3], 0.5,
     Time_phi_fem_unet_array, L2_error_phi_fem_unet_array),
]
for label, marker, markersize, color, ellipse_alpha, times, rel_errors in method_specs:
    ax.plot(
        np.mean(times),
        np.mean(rel_errors),
        marker,
        markersize=markersize,
        label=label,
        color=color,
    )
    confidence_ellipse(
        times,
        rel_errors,
        ax,
        alpha=ellipse_alpha,
        facecolor=color,
        edgecolor=color,
    )

ax.legend(fontsize=12, loc="upper right", ncol=2)
ax.set_xlabel("Computation time (s)", fontsize=16)
ax.set_ylabel("Relative $L^2$ error", fontsize=16)
ax.grid(axis="y", visible=True, which="both")
ax.grid(axis="x", visible=True, which="major")
ax.set_xscale("log")
ax.set_yscale("log")
fig.tight_layout()
# Respect save_figs like the box-plot cell. The images folder is only created
# when save_figs is true, and the "./" prefix would break an absolute images_repo.
if save_figs:
    fig.savefig(f"{images_repo}/error_time_all_methods.pdf")
plt.show()