In [1]:
import numpy as np
import matplotlib.pyplot as plt
import random
import os
import torch 
import vtk
import pyvista
# Render PyVista scenes inside the notebook through the interactive trame backend
# (pass "static" to set_jupyter_backend instead for non-interactive images).
pyvista.global_theme.notebook = True
pyvista.set_jupyter_backend("trame")

seed = 231024
# NOTE: only processes started after this point inherit this value; the hash
# randomisation of the already-running interpreter was fixed at start-up.
os.environ["PYTHONHASHSEED"] = str(seed)
# Seed every random number generator the notebook relies on (stdlib, torch, numpy).
for _seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    _seed_fn(seed)
del _seed_fn
import time
from utils import *
from  utils_training import *
import prepare_data
from utils_compare_methods import *
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable
import pandas as pd
import gc
from utils_plot import * 

sns.set_theme("paper", rc={"xtick.bottom": True, "ytick.left": True}, font_scale=1.1)
colors = sns.color_palette("mako").as_hex()
my_cmap = sns.color_palette("viridis", as_cmap=True)

%load_ext autoreload
%autoreload 2


In [None]:
save_figs = True
small_data = False
data = DataLoader(small_data)
agent = Agent(data, l2_lambda=1e-7, initial_lr=5e-4, level=1)

model = agent.model
device = agent.device
images_repo = "../images"
models_repo = "./models"
results_repo = "../results"

# Restore the best checkpoint. `map_location` remaps the stored tensors onto the
# agent's device, so a checkpoint saved on GPU can also be loaded on a CPU-only
# machine (and vice versa).
best_model = torch.load(f"{models_repo}/best_model.pkl", map_location=device)
model.load_state_dict(best_model["model_state_dict"])
# Switch to evaluation mode: only inference is performed below.
model.eval()

print(f"Best epoch std = {best_model['epoch']}")
# exist_ok avoids the check-then-create race and is a no-op when the folder exists.
if save_figs:
    os.makedirs(f"{images_repo}/", exist_ok=True)
os.makedirs(f"{results_repo}/", exist_ok=True)

# Compare the relative $L^2$ errors of $\phi$-FEM, standard FEM and the FNO-based methods against a reference standard FEM solution computed on a fine mesh

In [3]:
def error_L2(uh, u_ex, degree_raise=3):
    """Relative L2 error ``||uh - u_ex|| / ||u_ex||`` computed on the mesh of ``uh``.

    Both functions are interpolated into an enriched space of the same element
    family as ``uh``, with the polynomial degree raised by ``degree_raise``.
    The integrals are then assembled locally and summed over all MPI ranks.

    Parameters
    ----------
    uh : dolfinx.fem.Function
        Approximate solution; its mesh and element define the enriched space.
    u_ex : dolfinx.fem.Function or callable
        Reference solution, interpolated into the enriched space. It must be
        interpolable there, e.g. defined on the same mesh as ``uh``.
    degree_raise : int
        Number of degrees added to ``uh``'s element degree.

    Returns
    -------
    float
        The relative L2 error.
    """
    element = uh.function_space.ufl_element()
    mesh = uh.function_space.mesh
    W = dolfinx.fem.functionspace(
        mesh, (element.family_name, element.degree + degree_raise)
    )

    uh_W = dolfinx.fem.Function(W)
    uh_W.interpolate(uh)
    ref_W = dolfinx.fem.Function(W)
    ref_W.interpolate(u_ex)

    diff_W = dolfinx.fem.Function(W)
    diff_W.x.array[:] = uh_W.x.array - ref_W.x.array

    def squared_norm(f):
        # Assemble \int f.f dx on this rank, then sum the contributions of all ranks.
        local_value = dolfinx.fem.assemble_scalar(
            dolfinx.fem.form(ufl.inner(f, f) * ufl.dx)
        )
        return mesh.comm.allreduce(local_value, op=MPI.SUM)

    err_sq = squared_norm(diff_W)
    ref_sq = squared_norm(ref_W)
    return np.sqrt(err_sq / ref_sq)


def non_matching_interpolation(uh, V_to, padding=1e-14):
    """Interpolate ``uh`` into ``V_to``, whose mesh may differ from ``uh``'s mesh.

    Parameters
    ----------
    uh : dolfinx.fem.Function
        Source function.
    V_to : dolfinx function space
        Target space.
    padding : float
        Forwarded unchanged to
        ``dolfinx.fem.create_nonmatching_meshes_interpolation_data``
        (see the dolfinx docs for its meaning).

    Returns
    -------
    dolfinx.fem.Function
        A new function on ``V_to`` holding the interpolated values.
    """
    target = dolfinx.fem.Function(V_to)
    target_space = target.function_space

    # Build the data dolfinx needs to interpolate between the two meshes.
    interp_data = dolfinx.fem.create_nonmatching_meshes_interpolation_data(
        target_space.mesh,
        target_space.element,
        uh.function_space.mesh,
        padding=padding,
    )
    target.interpolate(uh, nmm_interpolation_data=interp_data)
    # Push owned values to ghost entries so the parallel vector is consistent.
    target.x.scatter_forward()
    return target

In [4]:
def compute_fno_error(param, Plot=False, screenshot=False):
    """Evaluate the Std-FEM-FNO on one parameter set.

    The function:
    1. solves a reference problem with ``StandardFEMSolver``;
    2. runs the FNO on the same parameters;
    3. interpolates the prediction onto the reference space;
    4. returns the relative L2 error against the reference solution.

    Parameters
    ----------
    param : np.ndarray
        Array whose first row (``param[0]``) holds the 12 problem parameters
        ``(mu0, mu1, sigma_x, sigma_y, amplitude, x_0, y_0, lx, ly, theta,
        alpha, beta)``. It is also passed as-is to ``StandardFEMSolver``.
    Plot, screenshot : bool
        Unused; kept so existing call sites keep working.

    Returns
    -------
    tuple
        ``(l2_error_fno, time_call)``:
        - ``l2_error_fno``: relative L2 error of the FNO prediction.
        - ``time_call``: wall-clock time in seconds of the model forward pass alone.
    """
    standard_solver = StandardFEMSolver(params=param)
    mu0, mu1, sigma_x, sigma_y, amplitude, x_0, y_0, lx, ly, theta, alpha, beta = param[
        0
    ]

    # Reference solution from the standard solver with reference_fem=True.
    # NOTE(review): 0.005 presumably sets the fine mesh size — TODO confirm.
    u_ref, V_ref, _dx_ref, _h_ref = standard_solver.solve_one(
        0, 0.005, reference_fem=True
    )

    # FNO input, generated at resolution 64.
    F = generate_F_numpy(mu0, mu1, sigma_x, sigma_y, amplitude, 64)
    G = generate_G_numpy(alpha, beta, 64)
    phi = generate_phi_numpy(x_0, y_0, lx, ly, theta, 64)
    X = generate_manual_new_data_numpy(F, phi, G).to(device)

    # start std-FEM-FNO
    # Pure inference: no autograd graph is needed, which saves memory and time.
    with torch.no_grad():
        x_normed = data.x_normalizer.encode(X)
        # CUDA kernels run asynchronously. Synchronise before and after the call
        # so the timer measures the forward pass itself, not just kernel launches
        # or leftover work from the encoding step.
        if X.is_cuda:
            torch.cuda.synchronize(X.device)
        start = time.perf_counter()
        Y_normed = model(x_normed)
        if Y_normed.is_cuda:
            torch.cuda.synchronize(Y_normed.device)
        time_call = time.perf_counter() - start
        Y = data.y_normalizer.decode(Y_normed)
    predicted_solution = Y[:, 0, :, :].detach().cpu().numpy()
    predicted_sol_fenics = convert_numpy_matrix_to_fenicsx(predicted_solution, 64, 1)
    # The prediction lives on a different mesh than the reference: interpolate it
    # onto V_ref before comparing.
    predicted_sol_fenics_proj_V_ref = non_matching_interpolation(
        predicted_sol_fenics, V_ref
    )
    l2_error_fno = error_L2(predicted_sol_fenics_proj_V_ref, u_ref)
    # end std-FEM-FNO

    print(f"{l2_error_fno=:.3e}")

    return l2_error_fno, time_call

In [None]:
save_list_error = True
# Number of test parameter sets to evaluate (capped at what the file provides).
n_test_cases = 300

params = np.load("../../data_test/params.npy")

# Clip to the number of available rows, so a smaller test file does not raise
# IndexError part-way through a long run.
indices = list(range(min(n_test_cases, len(params))))
L2_error_fno_std = []
Times_fno_std = []
for index in indices:
    print(f"Iter : {index+1}/{len(indices)}")
    # Add a leading axis: compute_fno_error reads the parameters from param[0].
    l2_error_fno_std, time_fno_std = compute_fno_error(
        np.array([params[index]]), Plot=False
    )

    L2_error_fno_std.append(l2_error_fno_std)
    Times_fno_std.append(time_fno_std)
    if save_list_error:
        # Rewrite the files after every case so partial results are kept on disk
        # if the run is interrupted.
        np.save(
            f"{results_repo}/L2_error_std_fem_fno.npy", np.array([L2_error_fno_std])
        )
        np.save(f"{results_repo}/Times_std_fem_fno.npy", np.array([Times_fno_std]))
# Keep a leading singleton axis, matching the layout of the saved .npy files.
L2_error_fno_std = np.array([L2_error_fno_std])
Times_fno_std = np.array([Times_fno_std])

In [None]:
# Relative L2 errors of each method, previously saved in results_repo.
L2_error_phi_fem_fno = np.load(f"{results_repo}/L2_error_phi_fem_fno.npy")
L2_error_phi_fem = np.load(f"{results_repo}/L2_error_phi_fem.npy")
L2_error_std_fem = np.load(f"{results_repo}/L2_error_std_fem.npy")
L2_error_std_fem_fno = np.load(f"{results_repo}/L2_error_std_fem_fno.npy")
L2_error_phi_fem_fno_2 = np.load(f"{results_repo}/L2_error_phi_fem_fno_2.npy")
# Order of error_tab must match the labels in abs_str.
error_tab = [
    L2_error_phi_fem,
    L2_error_std_fem,
    L2_error_phi_fem_fno,
    L2_error_phi_fem_fno_2,
    L2_error_std_fem_fno,
]
abs_str = [
    r"$\varphi$-FEM",
    "Standard FEM",
    r"$\varphi$-FEM-FNO",
    r"$\varphi$-FEM-FNO 2",
    "Std-FEM-FNO",
]
assert len(error_tab) == len(abs_str), "one label per method is required"
# The saved arrays may carry a leading singleton axis; flatten each one to get
# an array with one row per method and one column per test case.
errors = np.array([np.ravel(err) for err in error_tab])
print(np.shape(errors))


# Four cubehelix shades, with firebrick inserted as the third colour so the
# third column (phi-FEM-FNO) stands out.
palette = sns.cubehelix_palette(n_colors=4, start=0.25, rot=-0.25, gamma=0.5)
palette = palette.as_hex()
palette.insert(-2, "#b22222")
palette = sns.color_palette(palette)

dataframe = pd.DataFrame(errors.transpose(), columns=abs_str)

fig, ax = plt.subplots(figsize=(6, 4))
sns.boxplot(
    data=dataframe,
    palette=palette,
    # "x" is an unfilled marker: its visible colour is the edge colour, so set
    # markeredgecolor to actually draw the outliers in black.
    flierprops={"marker": "x", "markerfacecolor": "black", "markeredgecolor": "black"},
    ax=ax,
)
ax.set_yscale("log")
ax.set_xlabel("Method", fontsize=16)
ax.set_ylabel("Relative $L^2$ error", fontsize=16)
ax.grid(axis="y", visible=True, which="both")

fig.tight_layout()
if save_figs:
    fig.savefig(f"{images_repo}/boxplots_new_data_compare_5_methods.pdf")
plt.show()