In [1]:
import numpy as np
import matplotlib.pyplot as plt
import random
import os
import torch 
import vtk
import pyvista
# Render PyVista output inside the notebook using the 'static' Jupyter backend.
pyvista.global_theme.notebook = True
pyvista.set_jupyter_backend('static')
# Alternative backend (kept for reference):
# pyvista.set_jupyter_backend('trame')

seed = 160318
# Seed the Python, PyTorch and NumPy global RNGs (in that order) with the same value.
for _seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    _seed_fn(seed)
# NOTE: the running interpreter's hash randomisation is fixed at start-up, so this
# only influences subprocesses launched afterwards.
os.environ["PYTHONHASHSEED"] = str(seed)
import time
from utils import *
from utils_training import *
import prepare_data
from utils_compare_methods import *
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable
from prepare_data import rotate, outside_ball
import pandas as pd
import gc
from utils_plot import * 
from scipy.spatial.distance import directed_hausdorff

# Seaborn "paper" context, with tick marks enabled on the bottom x-axis and left y-axis.
sns.set_theme("paper", rc={"xtick.bottom": True, "ytick.left": True}, font_scale=1.1)
# Colour palette / colormap defined for plotting (not referenced elsewhere in this notebook section).
colors = sns.color_palette("mako").as_hex()
my_cmap = sns.color_palette("viridis", as_cmap=True)

# Automatically reload edited modules (utils, utils_plot, ...) before executing each cell.
%load_ext autoreload
%autoreload 2


In [None]:
# Configuration: figure saving, data size, and the FNO agent/model to evaluate.
save_figs = True
small_data = True
data = DataLoader(small_data)
agent = Agent(
    data,
    initial_lr=5e-4,
    l2_lambda=1e-4,
    level=0.5,
    pad_prop=0.1,
    batch_size=8,
    n_modes=10,
    width=20,
    activation="gelu",
)

model = agent.model
device = agent.device

models_repo = "./models"
images_repo = "../images"
results_repo = "../results"
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted checkpoints.
# map_location moves tensors onto the agent's device, so a checkpoint saved on a
# GPU can still be loaded on a CPU-only machine.
best_model = torch.load(f"{models_repo}/best_model.pkl", map_location=device)
model.load_state_dict(best_model["model_state_dict"])
model.eval()

print(f"Best epoch = {best_model['epoch']}")
# Output directories; the images directory is only needed when figures are saved.
if save_figs:
    os.makedirs(images_repo, exist_ok=True)
os.makedirs(results_repo, exist_ok=True)

# Compare the errors of $\varphi$-FEM, standard FEM and $\varphi$-FEM-FNO against a fine reference standard-FEM solution

In [5]:
def error_L2(uh, u_ex, degree_raise=0):
    """Absolute and relative L2 errors of ``uh`` with respect to ``u_ex``.

    Both functions are interpolated into a common space ``W`` on the mesh of
    ``uh`` (same element family as ``uh``, degree raised by ``degree_raise``,
    same value shape as ``uh``). The squared norms are assembled locally and
    summed over all MPI ranks of the mesh communicator.

    Parameters
    ----------
    uh : dolfinx.fem.Function
        Approximate solution; its mesh is where the error is integrated.
    u_ex : dolfinx.fem.Function
        Reference solution, interpolated into ``W``.
    degree_raise : int
        Extra polynomial degree for ``W`` relative to ``uh``'s element.

    Returns
    -------
    (float, float)
        ``||uh - u_ex||_L2`` and ``||uh - u_ex||_L2 / ||u_ex||_L2``.
        The relative error is inf/nan if ``u_ex`` has zero norm.
    """
    element = uh.function_space.ufl_element()
    mesh = uh.function_space.mesh
    # Take the value shape from uh instead of hard-coding (2,), so the helper
    # also works for scalar fields and other vector dimensions.
    W = dolfinx.fem.functionspace(
        mesh, (element.family_name, element.degree + degree_raise, uh.ufl_shape)
    )

    u_W = dolfinx.fem.Function(W)
    u_W.interpolate(uh)
    u_ex_W = dolfinx.fem.Function(W)
    u_ex_W.interpolate(u_ex)

    e_W = dolfinx.fem.Function(W)
    e_W.x.array[:] = u_W.x.array - u_ex_W.x.array

    def _global_integral(integrand):
        """Assemble ``integrand * dx`` on this rank and sum it over all ranks."""
        local_value = dolfinx.fem.assemble_scalar(dolfinx.fem.form(integrand * ufl.dx))
        return mesh.comm.allreduce(local_value, op=MPI.SUM)

    error_global = _global_integral(ufl.inner(e_W, e_W))
    norm_global = _global_integral(ufl.inner(u_ex_W, u_ex_W))
    return np.sqrt(error_global), np.sqrt(error_global / norm_global)


def non_matching_interpolation(uh, V_to, padding=1e-14):
    """Interpolate ``uh`` into ``V_to``, whose mesh may differ from ``uh``'s mesh.

    ``padding`` is forwarded unchanged to
    ``dolfinx.fem.create_nonmatching_meshes_interpolation_data``.
    Ghost values of the returned function are updated before it is returned.
    """
    target = dolfinx.fem.Function(V_to)
    target_space = target.function_space

    interpolation_data = dolfinx.fem.create_nonmatching_meshes_interpolation_data(
        target_space.mesh,
        target_space.element,
        uh.function_space.mesh,
        padding=padding,
    )
    target.interpolate(uh, nmm_interpolation_data=interpolation_data)
    target.x.scatter_forward()
    return target


def compute_domain(X):
    """Boolean mask (on ``device``) of grid points treated as the domain.

    A point is selected when channel 0 of ``X`` is <= 3e-16 there, or when
    ``Loss().neighborhood_6`` marks it for that first mask (presumably a
    one-layer dilation -- TODO confirm). ``X`` is indexed as
    ``X[:, channel, :, :]``, so the mask has shape ``X[:, 0].shape``.
    """
    loss_helper = Loss()
    inside = (X[:, 0, :, :] <= 3e-16).to(device)
    grown = loss_helper.neighborhood_6(inside).to(device)
    # Union of both masks, computed on integer casts exactly as before.
    union = (grown.int() + inside.int()) != 0
    return union.to(device)


In [6]:
def _synchronize_device(tensor):
    """Block until queued CUDA work on ``tensor``'s device has finished (no-op on CPU).

    CUDA kernels are launched asynchronously, so a wall-clock timer read without
    synchronising can stop before the GPU has actually done the work.
    """
    if tensor.is_cuda:
        torch.cuda.synchronize(tensor.device)


def _predict_with_fno(param_):
    """Evaluate the FNO on one parameter vector.

    ``param_[0]`` is passed to ``generate_G_numpy`` and ``param_[1:]`` is reshaped
    into rows of 3 for ``generate_phi_numpy``, both on a 64-point grid.

    Returns
    -------
    (np.ndarray, float)
        The prediction (``U_pred`` transposed with axes (2, 1, 0)) and the
        wall-clock time in seconds of the timed forward pass + post-processing.
    """
    gamma_G = param_[0]
    params_holes = np.array(param_[1:]).reshape((-1, 3))

    G = generate_G_numpy(gamma_G, 64)
    phi = generate_phi_numpy(params_holes, 64)
    X = generate_manual_new_data_numpy(phi, G).to(device)

    # Pure inference: building an autograd graph would only waste time and memory.
    with torch.no_grad():
        x_normed = data.x_normalizer.encode(X)
        # Untimed warm-up pass so one-off initialisation costs are not measured.
        model(x_normed)
        _synchronize_device(X)

        start_call = time.time()
        Y_normed = model(x_normed)
        Y = data.y_normalizer.decode(Y_normed)

        domain = compute_domain(X)
        # 1.0 on the domain mask, NaN elsewhere (float64, as before) so that
        # nanmean below ignores points outside the mask. Built directly on the
        # device instead of a GPU -> CPU -> GPU round-trip.
        domains_nan = torch.full(
            domain.shape, float("nan"), dtype=torch.float64, device=X.device
        )
        domains_nan[domain] = 1.0

        # Per-sample, per-channel mean over axis 2 of the masked slice at index 0
        # of the last axis; subtracting it shifts the whole prediction
        # (presumably to fix a rigid translation -- TODO confirm).
        avg_low_displacement = torch.nanmean(
            (Y * domains_nan[:, None, :, :])[:, :, :, 0], dim=2
        )[:, :, None, None]
        # Wait for the GPU so the measured time covers the actual computation.
        _synchronize_device(X)
        end_call = time.time()

        U_pred = (Y - avg_low_displacement).cpu().numpy()[0, :, :, :]

    return U_pred.transpose((2, 1, 0)), end_call - start_call


def compare_std_phi_fem_and_fno(param, Plot=False, screenshot=False):
    """Compare standard FEM, phi-FEM and the FNO against a fine reference solution.

    Every approximation is interpolated onto the reference function space and
    compared to the reference with ``error_L2``.

    Parameters
    ----------
    param : np.ndarray
        One parameter set wrapped in an outer axis (callers pass
        ``np.array([params[i]])``); it is given to both solvers, and
        ``param[0]`` drives the FNO input generation.
    Plot : bool
        If True, plot the reference and the three interpolated approximations,
        labelled with their relative L2 errors.
    screenshot : bool or str
        Forwarded to ``plot_2D_vector_function_list_v2`` (callers pass ``False``
        when no screenshot is wanted, or a file path).

    Returns
    -------
    tuple
        ``(l2_error_phi_fem, l2_error_std_fem, l2_error_fno,
        rl2_error_phi_fem, rl2_error_std_fem, rl2_error_fno,
        time_fno, time_std_fem, time_phi_fem)`` -- absolute then relative L2
        errors, then the FNO wall-clock time (s) and the times reported by the
        two solvers' ``solve_one``.
    """
    standard_solver = StandardFEMSolver(params=param)
    solver = PhiFemSolver(nb_cell=32 - 1, params=param)

    predicted_solution, time_fno = _predict_with_fno(param[0])
    predicted_sol_fenics = convert_numpy_matrix_to_fenicsx(predicted_solution, 64)
    predicted_sol_fenics.x.scatter_forward()

    # Reference and standard FEM differ by the second solve_one argument
    # (0.007 vs 0.035, presumably a mesh size); phi-FEM uses its own grid.
    u_ref, V_ref = standard_solver.solve_one(0, 0.007, reference_fem=True)
    u_phi_fem, V_phi_fem, time_phi_fem = solver.solve_one(0)
    u_std, V_std, time_std_fem = standard_solver.solve_one(0, 0.035)

    predicted_sol_fenics_proj_V_ref = non_matching_interpolation(
        predicted_sol_fenics, V_ref
    )
    predicted_sol_fenics_proj_V_ref.x.scatter_forward()
    l2_error_fno, rl2_error_fno = error_L2(predicted_sol_fenics_proj_V_ref, u_ref)

    # A larger interpolation padding (1e-3) is used for the standard-FEM mesh.
    u_std_fem_proj_V_ref = non_matching_interpolation(u_std, V_ref, 1e-3)
    u_std_fem_proj_V_ref.x.scatter_forward()
    l2_error_std_fem, rl2_error_std_fem = error_L2(u_std_fem_proj_V_ref, u_ref)

    u_phi_fem_proj_V_ref = non_matching_interpolation(u_phi_fem, V_ref)
    u_phi_fem_proj_V_ref.x.scatter_forward()
    l2_error_phi_fem, rl2_error_phi_fem = error_L2(u_phi_fem_proj_V_ref, u_ref)

    if Plot:
        labels = [
            r"$\boldsymbol{u}_{ref}$",
            r"$\bar{L_2}(\boldsymbol{u}_{ref},\boldsymbol{u}_{std})=$"
            + f"{rl2_error_std_fem:.3e}",
            r"$\bar{L_2}(\boldsymbol{u}_{ref},\boldsymbol{u}_{\varphi})=$"
            + f"{rl2_error_phi_fem:.3e}",
            r"$\bar{L_2}(\boldsymbol{u}_{ref},\boldsymbol{u}_{\theta})=$"
            + f"{rl2_error_fno:.3e}",
        ]

        plot_2D_vector_function_list_v2(
            [
                u_ref,
                u_std_fem_proj_V_ref,
                u_phi_fem_proj_V_ref,
                predicted_sol_fenics_proj_V_ref,
            ],
            labels,
            screenshot=screenshot,
        )

    return (
        l2_error_phi_fem,
        l2_error_std_fem,
        l2_error_fno,
        rl2_error_phi_fem,
        rl2_error_std_fem,
        rl2_error_fno,
        time_fno,
        time_std_fem,
        time_phi_fem,
    )

In [None]:
save_list_error = True
params = np.load("../../data_test/params.npy")
indices = list(range(0, 300))
L2_error_fno, L2_error_phi_fem, L2_error_std_fem = [], [], []
rL2_error_fno, rL2_error_phi_fem, rL2_error_std_fem = [], [], []
Times_fno, Times_std_fem, Times_phi_fem = [], [], []

# Accumulators in the same order as the tuple returned by compare_std_phi_fem_and_fno.
accumulators = (
    L2_error_phi_fem,
    L2_error_std_fem,
    L2_error_fno,
    rL2_error_phi_fem,
    rL2_error_std_fem,
    rL2_error_fno,
    Times_fno,
    Times_std_fem,
    Times_phi_fem,
)
# (file stem, accumulator) pairs written to results_repo after every sample,
# each saved as an array of shape (1, n_done).
checkpoints = (
    ("L2_error_phi_fem", L2_error_phi_fem),
    ("rL2_error_phi_fem", rL2_error_phi_fem),
    ("L2_error_std_fem", L2_error_std_fem),
    ("rL2_error_std_fem", rL2_error_std_fem),
    ("L2_error_fno", L2_error_fno),
    ("rL2_error_fno", rL2_error_fno),
    ("Times_fno", Times_fno),
    ("Times_std_fem", Times_std_fem),
    ("Times_phi_fem", Times_phi_fem),
)

for index in indices:
    print(f"Iter : {index+1}/{len(indices)}")

    # Plot every 50th sample; screenshots stay disabled inside this loop.
    Plot = index % 50 == 0
    screenshot = False  # f"{images_repo}/output_{index}.png"

    outputs = compare_std_phi_fem_and_fno(
        np.array([params[index]]), Plot=Plot, screenshot=screenshot
    )
    for accumulator, value in zip(accumulators, outputs):
        accumulator.append(value)

    if save_list_error:
        # Re-write all arrays every iteration so partial results survive a crash.
        for stem, values in checkpoints:
            np.save(f"{results_repo}/{stem}.npy", np.array([values]))

In [None]:
# Re-run the comparison on the sample whose FNO relative L2 error is the median
# (upper median for an even count).
rL2_error_fno = np.load(f"{results_repo}/rL2_error_fno.npy")
params = np.load("../../data_test/params.npy")
# Assumes result position i corresponds to params[i], i.e. the evaluation loop
# visited indices 0..N-1 in order.
sorted_errors_indices = np.argsort(rL2_error_fno.flatten())
index = sorted_errors_indices[len(sorted_errors_indices) // 2]
compare_std_phi_fem_and_fno(
    np.array([params[index]]),
    Plot=True,
    # Only write the screenshot when figures are saved: the images directory is
    # only guaranteed to exist in that case.
    screenshot=f"{images_repo}/output_median_index_L2.png" if save_figs else False,
)

In [None]:
# Box plots of the relative L2 errors of the three methods over all evaluated samples.
rL2_error_phi_fem = np.load(
    f"{results_repo}/rL2_error_phi_fem.npy",
)
rL2_error_std_fem = np.load(f"{results_repo}/rL2_error_std_fem.npy")
rL2_error_fno = np.load(f"{results_repo}/rL2_error_fno.npy")
error_tab = [rL2_error_phi_fem, rL2_error_std_fem, rL2_error_fno]

abs_str = [r"$\varphi$-FEM", "Standard FEM", r"$\varphi$-FEM-FNO"]
# Each saved array has shape (1, n_samples); stack into (n_methods, n_samples).
errors = np.reshape(np.array(error_tab), (len(error_tab), -1))
print(np.shape(errors))
dataframe = pd.DataFrame(errors.transpose(), columns=abs_str)

palette = sns.cubehelix_palette(n_colors=2, start=0.25, rot=-0.25, gamma=0.5)
palette = palette.as_hex()
palette.append("#b22222")
palette = sns.color_palette(palette)

fig, ax = plt.subplots(figsize=(6, 4))
sns.boxplot(
    data=dataframe,
    palette=palette,
    # "x" is an unfilled marker: its colour comes from the edge colour, so
    # markerfacecolor alone would not make the outliers black.
    flierprops={"marker": "x", "markeredgecolor": "black"},
    ax=ax,
)
ax.set_yscale("log")
ax.set_ylabel("Relative $L^2$ error", fontsize=16)
ax.grid(axis="y", visible=True, which="both")
ax.tick_params(axis="both", labelsize=12)

fig.tight_layout()
if save_figs:
    fig.savefig(f"{images_repo}/boxplots_new_data_compare_method_relative_L2.pdf")
plt.show()

In [None]:
# Mean relative L2 error vs mean computation time per method, with confidence
# ellipses drawn from the per-sample (time, error) pairs.


def _load_results_without_first(name):
    """Load ``{results_repo}/{name}.npy`` as a flat array without its first entry.

    The first sample is dropped for every quantity, so errors and times stay
    paired; presumably its timing includes one-off warm-up/JIT costs -- TODO confirm.
    """
    return np.load(f"{results_repo}/{name}.npy").flatten()[1:]


rL2_error_phi_fem = _load_results_without_first("rL2_error_phi_fem")
rL2_error_std_fem = _load_results_without_first("rL2_error_std_fem")
rL2_error_fno = _load_results_without_first("rL2_error_fno")

Times_phi_fem = _load_results_without_first("Times_phi_fem")
Times_std_fem = _load_results_without_first("Times_std_fem")
Times_fno = _load_results_without_first("Times_fno")

mean_time_fno = np.mean(Times_fno)
mean_time_std_fem = np.mean(Times_std_fem)
mean_time_phi_fem = np.mean(Times_phi_fem)

std_time_fno = np.std(Times_fno)
std_time_std_fem = np.std(Times_std_fem)
std_time_phi_fem = np.std(Times_phi_fem)

mean_error_fno = np.mean(rL2_error_fno)
mean_error_std_fem = np.mean(rL2_error_std_fem)
mean_error_phi_fem = np.mean(rL2_error_phi_fem)

std_error_fno = np.std(rL2_error_fno)
std_error_std_fem = np.std(rL2_error_std_fem)
std_error_phi_fem = np.std(rL2_error_phi_fem)

palette = sns.color_palette("Paired")

# (times, errors, mean time, mean error, marker, marker size, colour, legend label)
methods = [
    (Times_std_fem, rL2_error_std_fem, mean_time_std_fem, mean_error_std_fem,
     "2", 12, palette[7], "Std-FEM"),
    (Times_phi_fem, rL2_error_phi_fem, mean_time_phi_fem, mean_error_phi_fem,
     "1", 12, palette[1], r"$\varphi$-FEM"),
    (Times_fno, rL2_error_fno, mean_time_fno, mean_error_fno,
     "x", 8, palette[5], r"$\varphi$-FEM-FNO"),
]

fig, ax = plt.subplots(1, 1, figsize=(6, 4))
for times, errs, mean_t, mean_e, marker, size, color, label in methods:
    ax.plot(mean_t, mean_e, marker, color=color, markersize=size, label=label)
    confidence_ellipse(
        times,
        errs,
        ax,
        alpha=0.5,
        facecolor=color,
        edgecolor=color,
    )

ax.legend(fontsize=14, loc="lower left")
ax.set_xlabel("Computation time (s)", fontsize=16)
ax.set_ylabel(r"Relative $L^2$ error", fontsize=16)
ax.set_xscale("log")
ax.set_yscale("log")
ax.grid(axis="y", visible=True, which="both")
ax.grid(axis="x", visible=True, which="major")

fig.tight_layout()
# The images directory is only created when save_figs is True.
if save_figs:
    fig.savefig(f"{images_repo}/relative_L2_vs_time.pdf")
plt.show()