In [1]:
import numpy as np
import matplotlib.pyplot as plt
import random
import os
import torch 
import vtk
import pyvista
pyvista.global_theme.notebook = True
# pyvista.set_jupyter_backend('static')
pyvista.set_jupyter_backend('trame')
import generate_data as gen_data

# Fix the seed of every random number generator used in this notebook so that
# a "Restart & Run All" reproduces the same numbers.
seed = 231024
# NOTE(review): PYTHONHASHSEED is read when the interpreter starts, so setting it
# here presumably only affects child processes, not this kernel - confirm.
os.environ["PYTHONHASHSEED"] = str(seed)
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)
import time
from utils import *
from utils_training import * 
import prepare_data
from utils_compare_methods import *
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable
import pandas as pd
import gc
from utils_plot import * 

# Global plot styling: seaborn "paper" context, slightly enlarged fonts, and tick
# marks enabled on the bottom and left axes.
sns.set_theme(
    context="paper",
    font_scale=1.1,
    rc={"xtick.bottom": True, "ytick.left": True},
)
# Shared colour resources: "mako" as a list of hex codes, "viridis" as a colormap.
colors = sns.color_palette("mako").as_hex()
my_cmap = sns.color_palette("viridis", as_cmap=True)

%load_ext autoreload
%autoreload 2


In [None]:
# --- Run configuration ---------------------------------------------------
save_figs = True  # when True, figures are written to `images_repo`
small_data = False  # forwarded to DataLoader
data = DataLoader(small_data)
agent = Agent(data)

images_repo = "../images"
models_repo = "./models"
results_repo = "../results"

# --- Restore the trained networks ------------------------------------------
model_fno = agent.model
model_iphi = agent.model_iphi
device = agent.device
# `map_location` remaps the stored tensors onto the agent's device, so a
# checkpoint saved on a GPU can still be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
best_model = torch.load(f"{models_repo}/best_model.pkl", map_location=device)
model_fno.load_state_dict(best_model["model_state_dict"])
model_iphi.load_state_dict(best_model["model_iphi_state_dict"])
# Both networks are only evaluated in this notebook.
model_fno.eval()
model_iphi.eval()

print(f"Best epoch std = {best_model['epoch']}")

# --- Output directories ------------------------------------------------------
# exist_ok=True makes the creation idempotent and avoids the check-then-create race.
if save_figs:
    os.makedirs(f"{images_repo}/", exist_ok=True)
os.makedirs(f"{results_repo}/", exist_ok=True)

# Compare the errors of $\varphi$-FEM, standard FEM and FNO with respect to a fine standard FEM reference solution

In [3]:
def error_L2(uh, u_ex, degree_raise=3):
    """Return the relative L2 error ``||uh - u_ex|| / ||u_ex||``.

    Both functions are interpolated into a space built on the mesh of ``uh``,
    with the same element family as ``uh`` and a polynomial degree raised by
    ``degree_raise``; the difference and the norms are evaluated there.

    Args:
        uh: discrete function whose error is measured.
        u_ex: reference function; it is interpolated into the raised-degree
            space (presumably it lives on the same mesh as ``uh`` - TODO confirm).
        degree_raise: number of degrees added to the degree of ``uh``.

    Returns:
        Square root of the ratio of the globally summed (MPI) squared error to
        the globally summed squared norm of the reference.
    """
    space_h = uh.function_space
    element_h = space_h.ufl_element()
    mesh = space_h.mesh
    W = dolfinx.fem.functionspace(
        mesh, (element_h.family_name, element_h.degree + degree_raise)
    )

    def _interpolated(source):
        # Copy of `source` expressed in the raised-degree space W.
        target = dolfinx.fem.Function(W)
        target.interpolate(source)
        return target

    def _global_squared_norm(fn):
        # Integral of fn * fn over the mesh, summed over all MPI ranks.
        form = dolfinx.fem.form(ufl.inner(fn, fn) * ufl.dx)
        local_value = dolfinx.fem.assemble_scalar(form)
        return mesh.comm.allreduce(local_value, op=MPI.SUM)

    u_W = _interpolated(uh)
    u_ex_W = _interpolated(u_ex)

    # Point-wise difference of the degrees of freedom in W.
    e_W = dolfinx.fem.Function(W)
    e_W.x.array[:] = u_W.x.array - u_ex_W.x.array

    squared_error = _global_squared_norm(e_W)
    squared_reference = _global_squared_norm(u_ex_W)
    return np.sqrt(squared_error / squared_reference)


def non_matching_interpolation(uh, V_to, padding=1e-14):
    """Interpolate ``uh`` into ``V_to`` when the two meshes do not match.

    Args:
        uh: source function, defined on its own mesh.
        V_to: target function space.
        padding: forwarded to
            ``dolfinx.fem.create_nonmatching_meshes_interpolation_data``.

    Returns:
        A new function of ``V_to`` holding the interpolated values, after a
        forward scatter of its degrees of freedom.
    """
    target = dolfinx.fem.Function(V_to)
    target_space = target.function_space

    # Data relating the target mesh/element to the source mesh.
    interpolation_data = dolfinx.fem.create_nonmatching_meshes_interpolation_data(
        target_space.mesh,
        target_space.element,
        uh.function_space.mesh,
        padding=padding,
    )

    target.interpolate(uh, nmm_interpolation_data=interpolation_data)
    target.x.scatter_forward()
    return target


def convert_numpy_matrix_to_fenicsx_given_V(U, X, V):
    """Build a function of ``V`` from values given on a point cloud.

    The degrees of freedom of ``V`` and the points ``X`` are both sorted
    lexicographically (second coordinate first, then first coordinate); the
    k-th value in point order is written to the k-th degree of freedom in dof
    order. The matching is therefore only meaningful when ``X`` contains the
    dof coordinates of ``V`` (in any order).

    Args:
        U: array of values, one per point of ``X`` (flattened internally).
        X: array whose first two columns are the point coordinates.
        V: function space; ``V.tabulate_dof_coordinates()`` must provide one
            coordinate triple per degree of freedom.

    Returns:
        A ``dolfinx.fem.Function`` on ``V`` holding the reordered values.

    Raises:
        ValueError: if the number of values or points differs from the number
            of degrees of freedom.
    """
    dof_coords = V.tabulate_dof_coordinates().reshape((-1, 3))[:, :2]
    values = np.asarray(U).reshape(-1)
    points = np.asarray(X)

    n_dofs = dof_coords.shape[0]
    if values.size != n_dofs or points.shape[0] != n_dofs:
        raise ValueError(
            f"expected {n_dofs} values and points, "
            f"got {values.size} values and {points.shape[0]} points"
        )

    # np.lexsort uses the LAST key as the primary one: sort by y, then by x.
    dof_order = np.lexsort((dof_coords[:, 0], dof_coords[:, 1]))
    point_order = np.lexsort((points[:, 0], points[:, 1]))

    X_fenicsx = dolfinx.fem.Function(V)
    # Write through `x.array` / `x.scatter_forward()` like the rest of this file,
    # instead of the PETSc wrapper `Function.vector` (deprecated in recent
    # DOLFINx releases - NOTE(review): confirm against the installed version).
    X_fenicsx.x.array[dof_order] = values[point_order]
    X_fenicsx.x.scatter_forward()

    return X_fenicsx

# Geo-FNO test data creation

In [None]:
# Solve every test problem with the standard FEM solver from `generate_data` and
# store the resulting arrays for the Geo-FNO evaluation below.
params = np.load("../../data_test/params.npy")
print(params.shape)

solver = gen_data.StdFemSolver(nb_cell=64 - 1, params=params)
# HH is returned by the solver but is not saved.
U, XY, F, G, Phi, HH = solver.solve_several()

for _path, _array in (
    ("../../data_test/XY_geo_fno.npy", XY),
    ("../../data_test/Phi_geo_fno.npy", Phi),
    ("../../data_test/F_geo_fno.npy", F),
    ("../../data_test/G_geo_fno.npy", G),
    ("../../data_test/U_geo_fno.npy", U),
):
    np.save(_path, _array)

# We build the whole test dataset and predict all the solutions once

In [None]:
# Load the saved Geo-FNO test arrays.
XY = np.load("../../data_test/XY_geo_fno.npy")
Phi = np.load("../../data_test/Phi_geo_fno.npy")
F = np.load("../../data_test/F_geo_fno.npy")
G = np.load("../../data_test/G_geo_fno.npy")
params = np.load("../../data_test/params.npy")
# Add a trailing channel axis so F and G can be concatenated with XY.
F = F[:, :, None]
G = G[:, :, None]

xy, f, g = (
    torch.tensor(XY, dtype=torch.float).to(device),
    torch.tensor(F, dtype=torch.float).to(device),
    torch.tensor(G, dtype=torch.float).to(device),
)

# Network input: coordinates followed by the F and G channels.
x = torch.cat([xy, f, g], dim=-1)

# Predict in fixed-size batches (whatever the number of samples) and without
# building an autograd graph, since only the forward values are needed.
batch_size = 100
with torch.no_grad():
    y_pred = np.concatenate(
        [
            model_fno(x[start : start + batch_size, :, :], iphi=model_iphi)
            .cpu()
            .numpy()
            for start in range(0, x.shape[0], batch_size)
        ]
    )
print(y_pred.shape)  # should be (300,2600,1)

In [6]:
def compute_geo_fno_error(param, x, y_pred, index):
    """Relative L2 error and inference time of Geo-FNO for one test sample.

    A reference solution is computed with ``StandardFEMSolver``; the network
    prediction for sample ``index`` is turned into a P1 function on a mesh
    built from the sample's point cloud, interpolated into the reference space
    and compared with the reference through ``error_L2``.

    Args:
        param: parameters forwarded to ``StandardFEMSolver(params=...)``.
        x: input tensor of all samples; the first two channels of
            ``x[index]`` are used as point coordinates.
        y_pred: unused (the prediction is recomputed so that it can be timed);
            kept so existing callers keep working.
        index: position of the sample in ``x``.

    Returns:
        Tuple ``(relative L2 error, inference time in seconds)``.
    """
    standard_solver = StandardFEMSolver(params=param)
    # NOTE(review): the meaning of the arguments (0, 0.005) is defined by
    # StandardFEMSolver.solve_one, which is not visible here.
    u_ref, V_ref, _, _ = standard_solver.solve_one(0, 0.005, reference_fem=True)

    # mesh creation to go from point cloud to mesh --> can compute error with fenicsx
    # Slice before the host transfer so only this sample's coordinates are copied.
    in_points = x[index, :, :2].detach().cpu().numpy()
    mesh_creator = points2mesh()
    mesh = mesh_creator.create_standard_mesh(in_points)
    V = dolfinx.fem.functionspace(mesh, ("CG", 1))

    # perf_counter is a monotonic high-resolution clock meant for intervals; the
    # `.cpu()` copy is inside the timed region, as in the original measurement.
    start = time.perf_counter()
    with torch.no_grad():
        prediction = model_fno(x[index, None, :, :], iphi=model_iphi).cpu().numpy()
    elapsed = time.perf_counter() - start

    predicted_sol_fenics = convert_numpy_matrix_to_fenicsx_given_V(
        prediction, in_points, V
    )
    predicted_sol_fenics_proj_V_ref = non_matching_interpolation(
        predicted_sol_fenics, V_ref
    )
    l2_error_geo_fno = error_L2(predicted_sol_fenics_proj_V_ref, u_ref)

    print(f"{l2_error_geo_fno=:.3e}")

    return l2_error_geo_fno, elapsed

In [None]:
save_list_error = True  # checkpoint the partial results after every sample

params = np.load("../../data_test/params.npy")

# Evaluate every available sample instead of a hard-coded count: the loop needs
# both a parameter row and an input row for each index.
indices = list(range(min(len(params), x.shape[0])))
L2_error_geo_fno = []
Times_geo_fno = []
for index in indices:
    print(f"Iter : {index+1}/{len(indices)}")
    l2_error_geo_fno, time_geo_fno = compute_geo_fno_error(
        np.array([params[index]]), x, y_pred, index
    )
    L2_error_geo_fno.append(l2_error_geo_fno)
    Times_geo_fno.append(time_geo_fno)
    if save_list_error:
        # Saved inside the loop on purpose: an interrupted run keeps its results.
        np.save(f"{results_repo}/L2_error_geo_fno.npy", np.array([L2_error_geo_fno]))
        np.save(f"{results_repo}/Times_geo_fno.npy", np.array([Times_geo_fno]))
# Final arrays have shape (1, number of samples), like the saved files.
L2_error_geo_fno = np.array([L2_error_geo_fno])
Times_geo_fno = np.array([Times_geo_fno])

In [None]:
# Relative L2 errors of each method, as saved in `results_repo`.
L2_error_fno = np.load(f"{results_repo}/L2_error_phi_fem_fno.npy")
L2_error_phi_fem = np.load(f"{results_repo}/L2_error_phi_fem.npy")
L2_error_std_fem = np.load(f"{results_repo}/L2_error_std_fem.npy")
L2_error_std_fno = np.load(f"{results_repo}/L2_error_std_fem_fno.npy")
L2_error_fno_pred_u = np.load(f"{results_repo}/L2_error_phi_fem_fno_2.npy")
L2_error_geo_fno = np.load(f"{results_repo}/L2_error_geo_fno.npy")

# One entry per method, in the same order as the labels in `abs_str`.
error_tab = [
    L2_error_phi_fem,
    L2_error_std_fem,
    L2_error_fno,
    L2_error_fno_pred_u,
    L2_error_std_fno,
    L2_error_geo_fno,
]
# Every other label starts with a newline so neighbouring tick labels sit on
# alternating lines.
abs_str = [
    r"$\varphi$-FEM",
    "\nStd FEM",
    r"$\varphi$-FEM-FNO",
    "\n" + r"$\varphi$-FEM-FNO 2",
    "Std-FEM-FNO",
    "\nGeo-FNO",
]
# Collapse to a (number of methods, number of samples) matrix.
errors = np.array(error_tab)
errors = errors.reshape(len(error_tab), errors.shape[-1])
print(errors.shape)


# Five cubehelix colours plus one red entry inserted at the third position.
palette = sns.cubehelix_palette(
    n_colors=5, start=0.25, rot=-0.25, gamma=0.5
).as_hex()
palette.insert(-3, "#b22222")
palette = sns.color_palette(palette)

# One column per method, one row per test sample.
dataframe = pd.DataFrame(errors.T, columns=abs_str)

plt.figure(figsize=(6, 4))
sns.boxplot(
    data=dataframe,
    palette=palette,
    flierprops={"marker": "x", "markerfacecolor": "black"},
)
plt.yscale("log")
plt.ylabel("Relative $L^2$ error", fontsize=16)
plt.grid(axis="y", visible=True, which="both")
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)

plt.tight_layout()
if save_figs:
    plt.savefig(f"{images_repo}/boxplots_new_data_compare_all_methods.pdf")
plt.show()

In [None]:
# Summary statistics of the relative L2 errors, one column per method.
dataframe.describe()