In [1]:
import numpy as np
import matplotlib.pyplot as plt
import random
import os
import torch 
import vtk
import pyvista
pyvista.global_theme.notebook = True
pyvista.set_jupyter_backend('static')
# pyvista.set_jupyter_backend('trame')

# Seed every random number generator used by the notebook (Python, PyTorch,
# NumPy) so that runs are reproducible.
# NOTE: PYTHONHASHSEED is only read when a Python interpreter starts, so setting
# it here does not change hash randomisation inside this kernel; it only affects
# child processes launched afterwards.
seed = 160318
for _seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    _seed_fn(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
import time
from utils import *
from utils_training import *
import prepare_data
from utils_compare_methods import *
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable
from prepare_data import rotate, outside_ball, create_parameters
import pandas as pd
import gc
from utils_plot import * 
sns.set_theme("paper", rc={"xtick.bottom": True, "ytick.left": True}, font_scale=1.1)
colors = sns.color_palette("mako").as_hex()
my_cmap = sns.color_palette("viridis", as_cmap=True)

%load_ext autoreload
%autoreload 2


In [None]:
save_figs = True
small_data = False
data = DataLoader(small_data)
agent = Agent(data, level=1)

model = agent.model
device = agent.device

models_repo = "./models"
images_repo = "../images"
results_repo = "../results"

# map_location restores the tensors directly on the device the agent uses, so a
# checkpoint saved from a GPU run can still be loaded on a CPU-only machine.
best_model = torch.load(f"{models_repo}/best_model.pkl", map_location=device)
model.load_state_dict(best_model["model_state_dict"])
model.eval()

print(f"Best epoch = {best_model['epoch']}")

# Output folders: figures only when they are going to be saved, numerical
# results always. exist_ok makes the cell safe to re-run.
if save_figs:
    os.makedirs(f"{images_repo}/", exist_ok=True)
os.makedirs(f"{results_repo}/", exist_ok=True)

# Compare the errors of $\varphi$-FEM, standard FEM and $\varphi$-FEM-FNO with respect to a reference standard FEM solution computed on a fine mesh

In [3]:
def error_L2(uh, u_ex, degree_raise=3):
    """Relative L2 error ``||uh - u_ex||_L2 / ||u_ex||_L2`` on the mesh of ``uh``.

    Both functions are first interpolated into an enriched space built on
    ``uh``'s mesh: same element family as ``uh.function_space`` but with the
    polynomial degree raised by ``degree_raise``. Both integrals are then
    assembled on that space.

    Args:
        uh: approximate dolfinx ``Function``.
        u_ex: reference, anything ``Function.interpolate`` accepts. In this
            notebook it is always a ``Function`` on the same reference space
            as ``uh``.
        degree_raise: extra polynomial degree of the enriched space.

    Returns:
        ``sqrt(sum_ranks ∫(uh-u_ex)^2 / sum_ranks ∫u_ex^2)``.
    """
    space_in = uh.function_space
    element = space_in.ufl_element()
    mesh = space_in.mesh
    enriched = dolfinx.fem.functionspace(
        mesh, (element.family_name, element.degree + degree_raise)
    )

    def _to_enriched(source):
        # New Function in the enriched space holding the interpolant of `source`.
        target = dolfinx.fem.Function(enriched)
        target.interpolate(source)
        return target

    def _integrate(integrand):
        # Assemble on this rank, then sum the contributions of all MPI ranks.
        local_value = dolfinx.fem.assemble_scalar(dolfinx.fem.form(integrand))
        return mesh.comm.allreduce(local_value, op=MPI.SUM)

    uh_enriched = _to_enriched(uh)
    ref_enriched = _to_enriched(u_ex)

    err = dolfinx.fem.Function(enriched)
    err.x.array[:] = uh_enriched.x.array - ref_enriched.x.array

    squared_error = _integrate(ufl.inner(err, err) * ufl.dx)
    squared_norm = _integrate(ufl.inner(ref_enriched, ref_enriched) * ufl.dx)
    return np.sqrt(squared_error / squared_norm)


def non_matching_interpolation(uh, V_to, padding=1e-14):
    """Interpolate ``uh`` into ``V_to`` when the two live on different meshes.

    Args:
        uh: source dolfinx ``Function``. Its mesh may differ from ``V_to``'s.
        V_to: target function space.
        padding: tolerance forwarded to
            ``dolfinx.fem.create_nonmatching_meshes_interpolation_data``.
            One caller passes a larger value (1e-3) for the standard-FEM
            solution, presumably because its mesh does not match the target
            geometry exactly.

    Returns:
        A new ``Function`` in ``V_to`` holding the interpolant, with ghost
        values updated.
    """
    target = dolfinx.fem.Function(V_to)
    target_space = target.function_space
    interpolation_data = dolfinx.fem.create_nonmatching_meshes_interpolation_data(
        target_space.mesh,
        target_space.element,
        uh.function_space.mesh,
        padding=padding,
    )
    target.interpolate(uh, nmm_interpolation_data=interpolation_data)
    # Propagate owned values to ghost entries so parallel reads are consistent.
    target.x.scatter_forward()
    return target


def compute_domain(X):
    """Boolean mask of the grid points treated as part of the domain.

    A point is kept if channel 1 of ``X`` is <= 3e-16 there, or if it is
    flagged by ``Loss.neighborhood_6`` applied to that initial mask. Channel 1
    is presumably the level-set, non-positive inside the domain (TODO confirm).

    Args:
        X: tensor indexed as ``X[:, channel, i, j]``.

    Returns:
        Boolean tensor moved to ``device``.
    """
    neighbourhood_op = Loss().neighborhood_6
    inside = (X[:, 1, :, :] <= 3e-16).to(device)
    neighbours = neighbourhood_op(inside).to(device)
    # Union of the two masks: non-zero wherever either one is set.
    combined = neighbours.int() + inside.int()
    return (combined != 0).to(device)

In [4]:
def _sync_device():
    """Wait for all queued CUDA work when ``device`` is a GPU; no-op otherwise.

    CUDA kernels are launched asynchronously. Wall-clock timings taken around a
    GPU call only measure the launch unless the device is synchronised first.
    """
    dev = device if isinstance(device, torch.device) else torch.device(device)
    if dev.type == "cuda":
        torch.cuda.synchronize(dev)


def _predict_fno(X):
    """Run the trained FNO on the input tensor ``X``.

    ``X`` is normalised with ``data.x_normalizer``, passed through ``model``,
    and the output is mapped back with ``data.y_normalizer``.

    Returns:
        ``(Y, time_fno)``: the decoded network output, and the wall time in
        seconds of the forward pass alone (normalisation excluded).
    """
    with torch.no_grad():  # inference only: no autograd graph is needed
        x_normed = data.x_normalizer.encode(X)
        _sync_device()  # do not charge pending work (transfer, encode) to the model
        start = time.perf_counter()
        Y_normed = model(x_normed)
        _sync_device()  # wait until the asynchronous forward pass has finished
        time_fno = time.perf_counter() - start
        Y = data.y_normalizer.decode(Y_normed)
    return Y, time_fno


def _abs_error_function(u_approx, u_ref, V_ref):
    """Return a Function on ``V_ref`` holding ``|u_approx - u_ref|`` dof-wise.

    Both inputs must already be Functions on ``V_ref``.
    """
    diff = dolfinx.fem.Function(V_ref)
    diff.x.array[:] = np.absolute(u_approx.x.array[:] - u_ref.x.array[:])
    return diff


def compare_std_phi_fem_and_fno(param, Plot=False, screenshot=False):
    """Compare phi-FEM, standard FEM and phi-FEM-FNO on one problem instance.

    A fine standard-FEM solution ``u_ref`` serves as reference. Each method's
    solution is interpolated onto ``u_ref``'s space ``V_ref`` and compared with
    ``error_L2`` (relative L2 error).

    Args:
        param: array whose first row holds the 12 problem parameters
            ``(mu0, mu1, sigma_x, sigma_y, amplitude, x_0, y_0, lx, ly,
            theta, alpha, beta)``.
        Plot: if True, plot ``u_ref`` and the pointwise absolute error of each
            method.
        screenshot: forwarded to ``plot_scalar_functions_list``.

    Returns:
        ``(l2_error_phi_fem, l2_error_std_fem, l2_error_fno,
        total phi-FEM time, total standard-FEM time, FNO forward-pass time)``.
    """
    # Grid size shared by the phi-FEM solver (nb_cell = n_grid - 1) and the FNO inputs.
    n_grid = 64

    standard_solver = StandardFEMSolver(params=param)
    solver = PhiFemSolver_error(nb_cell=n_grid - 1, params=param)
    mu0, mu1, sigma_x, sigma_y, amplitude, x_0, y_0, lx, ly, theta, alpha, beta = param[0]

    # Reference solution: standard FEM with reference_fem=True and a small
    # (presumably mesh-size) argument 0.005.
    u_ref, V_ref, dx_ref, h_ref = standard_solver.solve_one(
        0, 0.005, reference_fem=True
    )
    u_phi_fem, _h_phi, list_times_phi_fem = solver.solve_one(0)
    # list_times_phi_fem = [self.init_time, cell_selection, submesh_construction, ghost_cell_selection, interp_f_g, resolution_time]
    u_std, _h_std, list_times_std_fem = standard_solver.solve_one(
        0, 0.0224, reference_fem=False
    )
    # list_times_std_fem = [ls_constr, construct_background_mesh, interp_mesh, remesh_time, trunc_mesh, read_fenics, inter_f_g, locte_bc_nodes, solve]

    # --- phi-FEM-FNO ---
    F = generate_F_numpy(mu0, mu1, sigma_x, sigma_y, amplitude, n_grid)
    G = generate_G_numpy(alpha, beta, n_grid)
    phi = generate_phi_numpy(x_0, y_0, lx, ly, theta, n_grid)
    X = generate_manual_new_data_numpy(F, phi, G).to(device)

    Y, time_fno = _predict_fno(X)
    # Reconstruct the solution from the network output: X[:, 1] * Y[:, 0] + X[:, 2].
    predicted_solution = (
        (X[:, 1, :, :] * Y[:, 0, :, :] + X[:, 2, :, :]).cpu().detach().numpy()
    )[0, :, :]
    predicted_sol_fenics = convert_numpy_matrix_to_fenicsx(predicted_solution, n_grid, 1)
    predicted_sol_fenics.x.scatter_forward()
    predicted_sol_fenics_proj_V_ref = non_matching_interpolation(
        predicted_sol_fenics, V_ref
    )
    l2_error_fno = error_L2(predicted_sol_fenics_proj_V_ref, u_ref)
    # --- end phi-FEM-FNO ---

    # The standard-FEM mesh needs a larger padding to be located in the reference mesh.
    u_std_fem_proj_V_ref = non_matching_interpolation(u_std, V_ref, 1e-3)
    l2_error_std_fem = error_L2(u_std_fem_proj_V_ref, u_ref)

    u_phi_fem_proj_V_ref = non_matching_interpolation(u_phi_fem, V_ref)
    l2_error_phi_fem = error_L2(u_phi_fem_proj_V_ref, u_ref)

    if Plot:
        diff_std = _abs_error_function(u_std_fem_proj_V_ref, u_ref, V_ref)
        diff_phi = _abs_error_function(u_phi_fem_proj_V_ref, u_ref, V_ref)
        diff_fno = _abs_error_function(predicted_sol_fenics_proj_V_ref, u_ref, V_ref)
        labels = [
            r"$u_{ref}$",
            r"$E(u_{ref}, u_{std}) =$" + f"{l2_error_std_fem:.3e}",
            r"$E(u_{ref}, u_{\varphi}) =$" + f"{l2_error_phi_fem:.3e}",
            r"$E(u_{ref}, u_{\theta}) =$" + f"{l2_error_fno:.3e}",
        ]
        plot_scalar_functions_list(
            [u_ref, diff_std, diff_phi, diff_fno],
            labels=labels,
            screenshot=screenshot,
            height=400,
        )
    print(f"{l2_error_fno=:.3e}")
    print(f"{l2_error_phi_fem=:.3e}")
    print(f"{l2_error_std_fem=:.3e}")
    print(f"{_h_std=:.6f}")
    print(f"{_h_phi=:.6f}")

    return (
        l2_error_phi_fem,
        l2_error_std_fem,
        l2_error_fno,
        np.sum(list_times_phi_fem),
        np.sum(list_times_std_fem),
        time_fno,
    )

In [None]:
save_list_error = True

params = np.load("../../data_test/params.npy")
indices = list(range(0, 300))
L2_error_fno, L2_error_phi_fem, L2_error_std_fem = [], [], []
Times_fno, Times_phi_fem, Times_std_fem = [], [], []

# File name (inside results_repo) -> list it snapshots. Insertion order matches
# the order of the values returned by compare_std_phi_fem_and_fno, and is also
# the order in which the files are written.
result_lists = {
    "L2_error_phi_fem.npy": L2_error_phi_fem,
    "L2_error_std_fem.npy": L2_error_std_fem,
    "L2_error_phi_fem_fno.npy": L2_error_fno,
    "Times_phi_fem.npy": Times_phi_fem,
    "Times_std_fem.npy": Times_std_fem,
    "Times_phi_fem_fno.npy": Times_fno,
}

for index in indices:
    print(f"Iter : {index+1}/{len(indices)}")
    # Plot one case out of every 50.
    outcome = compare_std_phi_fem_and_fno(
        np.array([params[index]]), Plot=(index % 50) == 0
    )
    (
        l2_error_phi_fem,
        l2_error_std_fem,
        l2_error_fno,
        time_phi_fem,
        time_std_fem,
        time_fno,
    ) = outcome

    for result_store, result_value in zip(result_lists.values(), outcome):
        result_store.append(result_value)

    # Re-save after every case, so partial results survive an interrupted run.
    if save_list_error:
        if not os.path.exists(f"{results_repo}/"):
            os.makedirs(f"{results_repo}/")
        for result_file, result_values in result_lists.items():
            # The extra outer list gives arrays of shape (1, n_done).
            np.save(f"{results_repo}/{result_file}", np.array([result_values]))

In [None]:
L2_error_phi_fem = np.load(f"{results_repo}/L2_error_phi_fem.npy")
L2_error_std_fem = np.load(f"{results_repo}/L2_error_std_fem.npy")
L2_error_fno = np.load(f"{results_repo}/L2_error_phi_fem_fno.npy")

error_tab = [L2_error_phi_fem, L2_error_std_fem, L2_error_fno]

abs_str = [
    r"$\varphi$-FEM",
    "Standard FEM",
    r"$\varphi$-FEM-FNO",
]
# Saved arrays have shape (1, n); stack them into a (3, n) matrix, one row per method.
errors = np.array(error_tab[:])
errors = np.reshape(errors, (3, np.shape(errors)[-1]))
print(np.shape(errors))
dataframe = pd.DataFrame(errors.transpose(), columns=abs_str)

# Two cubehelix shades for the FEM baselines, plus a distinct red for the FNO.
palette = sns.cubehelix_palette(n_colors=2, start=0.25, rot=-0.25, gamma=0.5).as_hex()
palette.append("#b22222")
palette = sns.color_palette(palette)

fig, ax = plt.subplots(figsize=(6, 4))
sns.boxplot(
    data=dataframe,
    palette=palette,
    # "x" is an unfilled marker: matplotlib draws it with the *edge* colour, so
    # markerfacecolor alone does not make the outliers black.
    flierprops={"marker": "x", "markerfacecolor": "black", "markeredgecolor": "black"},
    ax=ax,
)
ax.set_yscale("log")
ax.set_xlabel("Method", fontsize=16)
ax.set_ylabel("Relative $L^2$ error", fontsize=16)
ax.grid(axis="y", visible=True, which="both")

fig.tight_layout()
if save_figs:
    fig.savefig(f"{images_repo}/boxplots_new_data_compare_method.pdf")
plt.show()

In [None]:
# Per-method summary of the relative L2 errors (count, mean, std, min, quartiles, max).
summary_stats = dataframe.describe()
summary_stats

In [None]:
L2_error_fno = np.load(f"{results_repo}/L2_error_phi_fem_fno.npy")
params = np.load("../../data_test/params.npy")
# The saved errors are indexed by loop position. The benchmark loop ran over
# indices = range(0, 300), so a position is also the row of `params`.
# NOTE(review): update this mapping if the loop's `indices` change.
sorted_errors_indices = np.argsort(L2_error_fno.flatten())
# Case with the median FNO error (upper median when the number of cases is even).
index = sorted_errors_indices[len(sorted_errors_indices) // 2]
(
    l2_error_phi_fem,
    l2_error_std_fem,
    l2_error_fno,
    time_phi_fem,
    time_std_fem,
    time_fno,
) = compare_std_phi_fem_and_fno(
    np.array([params[index]]),
    Plot=True,
    # Only write the screenshot when figures are saved: images_repo is created
    # only in that case, and the other figures follow the same switch.
    screenshot=f"{images_repo}/output_median_index.png" if save_figs else False,
)