In [None]:
import numpy as np
import matplotlib.pyplot as plt
import random
import os

# Reproducibility setup. Environment variables meant for TensorFlow are set
# *before* it is imported so they are guaranteed to be in place when it
# initialises (TF_CPP_MIN_LOG_LEVEL in particular must precede the import).
seed = 2023
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["TF_CUDNN_DETERMINISTIC"] = "1"
os.environ["TF_DETERMINISTIC_OPS"] = "1"
# NOTE: PYTHONHASHSEED is read at interpreter start-up, so this only affects
# Python processes launched from here on (e.g. subprocesses); the hash seed of
# the already-running kernel is not changed by this assignment.
os.environ["PYTHONHASHSEED"] = str(seed)
import tensorflow as tf

# Seed every random number generator used in the notebook.
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)
tf.experimental.numpy.random.seed(seed)
import dolfin as df
import time
from utils import *
import prepare_data
from utils_compare_methods import *
from utils_training import *
import seaborn as sns
import generate_domains as gd

from mpl_toolkits.axes_grid1 import make_axes_locatable
import pandas as pd

sns.set_theme()
sns.set_context("paper")
colors = sns.color_palette("mako").as_hex()
my_cmap = sns.color_palette("viridis", as_cmap=True)

%load_ext autoreload
%autoreload 2

In [None]:
save_figs = True

small_data = False
level = 2
data = DataLoader(small_data)
agent = Agent(data, small_data)

if save_figs:
    # exist_ok keeps this idempotent when the cell is re-run.
    os.makedirs("./images/", exist_ok=True)

# best_epoch.txt records the epoch of the best checkpoint; the last integer
# token in the file is used. Splitting on any whitespace (not only " ") stops
# a trailing newline from hiding the number ("1234\n".isdigit() is False),
# and the `with` block closes the file.
best_epoch = None
with open("./models/best_model/best_epoch.txt") as best_epoch_file:
    for token in best_epoch_file.read().split():
        if token.isdigit():
            best_epoch = int(token)
if best_epoch is None:
    raise ValueError(
        "no integer epoch found in ./models/best_model/best_epoch.txt"
    )

# Checkpoints of interest; the best epoch is deliberately the last entry.
epochs = [50, 100, 250, 500, 1000, 1500, 1890, best_epoch]
print(len(epochs))
print(epochs)

# Panel-grid parameters over the epochs (presumably for a per-epoch figure
# grid; not used in the cells shown here — confirm).
indices = list(range(0, len(epochs)))
size_per_fig = 4
nb_rows = 2
size_row = len(indices) // nb_rows


In [None]:
# Restore the network weights saved at the best epoch.
nb_epochs = best_epoch
agent.model.load_weights("./models/model_{}/model_weights".format(nb_epochs))


# Compare the errors of $\phi$-FEM, standard FEM and FNO against a reference solution computed with standard FEM on a fine mesh

In [None]:
def _relative_l2_error(u_approx, u_ref, dx_ref, ref_norm):
    """Return ||u_ref - u_approx||_{L^2} / ref_norm, integrated with measure ``dx_ref``."""
    return df.assemble(((u_ref - u_approx) ** 2) * dx_ref) ** 0.5 / ref_norm


def _plot_field_panel(position, field, title):
    """Draw ``field`` in subplot ``position`` over [0, 1]^2 with a colorbar below it."""
    ax = plt.subplot(position)
    img = df.plot(field, mode="color", cmap=my_cmap)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("bottom", size="5%", pad=0.3)
    ax.grid(False)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_title(title, fontsize=15)
    plt.colorbar(img, cax=cax, orientation="horizontal")
    return ax


def compare_std_phi_fem_and_fno(
    param, phi, coeffs, Plot=False, epoch=epochs[-1]
):
    """Solve one problem with phi-FEM, standard FEM and the FNO and compare them.

    All three solutions are projected onto the function space of a fine
    standard-FEM reference solution. Their relative L2 errors are computed
    there.

    Parameters
    ----------
    param : array
        Problem parameters; ``param[0]`` unpacks to
        ``(mu0, mu1, sigma, alpha, beta)``.
    phi : array
        Level-set values, passed to the standard FEM solver as ``phi_vector``.
    coeffs : array
        Domain coefficients, passed to the phi-FEM solver.
    Plot : bool
        Forwarded to the standard FEM solver. When True, also draws a
        4-panel comparison figure (saved if the global ``save_figs`` is set).
    epoch : int
        Epoch of the saved FNO weights to load.

    Returns
    -------
    tuple
        ``(l2_error_phi_fem, l2_error_std_fem, l2_error_fno,
        time_phi, time_std, time_fno)``.
    """
    agent.model.load_weights(f"./models/model_{epoch}/model_weights")

    # Resolution of the FNO input grid; phi-FEM uses n_grid - 1 cells.
    n_grid = 64
    solveur_standard = StandardFEMSolver(params=param, phi_vector=phi)
    solver = PhiFemSolver_error(
        nb_cell=n_grid - 1, params=param, coeffs=coeffs, sigma_D=1.0
    )
    mu0, mu1, sigma, alpha, beta = param[0]

    # Reference solution (second argument presumably a mesh size: 0.002 for
    # the fine reference, 0.023 for the standard FEM run — confirm in
    # StandardFEMSolver.solve_one).
    u_ref, V_ref, dx_ref = solveur_standard.solve_one(0, 0.002, Plot, True)
    u_phi_fem, V_phi_fem, time_phi, phi_64, _h_phi = solver.solve_one(0)
    u_std, time_std, _h_std = solveur_standard.solve_one(
        0, 0.023, Plot, False
    )

    # FNO prediction on the n_grid x n_grid grid.
    F = generate_F_numpy(mu0, mu1, sigma, n_grid) / data.max_norm_F
    G = generate_G_numpy(alpha, beta, n_grid)
    X = generate_manual_new_data_numpy(phi_64, F, G)
    # perf_counter is monotonic and high-resolution, suited to interval timing.
    start_call = time.perf_counter()
    Y = agent.model.call(X)
    time_fno = time.perf_counter() - start_call
    # Presumably u = phi * w + g with X channel 1 = phi and channel 2 = G —
    # confirm against generate_manual_new_data_numpy.
    predicted_solution = X[:, :, :, 1] * Y[:, :, :, 0] + X[:, :, :, 2]
    predicted_solution = np.reshape(predicted_solution, (n_grid, n_grid))
    predicted_solution_fenics = convert_numpy_matrix_to_fenics(
        predicted_solution, n_grid, 1
    )

    # The reference norm is shared by all three relative errors.
    ref_norm = df.assemble((u_ref ** 2) * dx_ref) ** 0.5

    u_phi_fem_proj = df.project(
        u_phi_fem,
        V_ref,
        solver_type="gmres",
        preconditioner_type="hypre_amg",
    )
    l2_error_phi_fem = _relative_l2_error(u_phi_fem_proj, u_ref, dx_ref, ref_norm)

    u_std_fem_proj = df.project(
        u_std, V_ref, solver_type="gmres", preconditioner_type="hypre_amg"
    )
    l2_error_std_fem = _relative_l2_error(u_std_fem_proj, u_ref, dx_ref, ref_norm)

    predicted_solution_fenics_proj_V_ref = df.project(
        predicted_solution_fenics,
        V_ref,
        solver_type="gmres",
        preconditioner_type="hypre_amg",
    )
    l2_error_fno = _relative_l2_error(
        predicted_solution_fenics_proj_V_ref, u_ref, dx_ref, ref_norm
    )

    if Plot:
        plt.figure(figsize=(16, 4))
        _plot_field_panel(141, u_ref, "Reference solution")
        # "\\phi" keeps a literal backslash without relying on the invalid
        # escape sequence "\p".
        _plot_field_panel(
            142,
            u_phi_fem_proj,
            "$\\phi$-FEM solution \n$L^2$ relative error : "
            + f"{l2_error_phi_fem:.5f}",
        )
        _plot_field_panel(
            143,
            u_std_fem_proj,
            "Standard FEM solution \n$L^2$ relative error : "
            + f"{l2_error_std_fem:.5f}",
        )
        _plot_field_panel(
            144,
            predicted_solution_fenics_proj_V_ref,
            "Predicted solution \n$L^2$ relative error : "
            + f"{l2_error_fno:.5f}",
        )

        plt.tight_layout(pad=1.3)
        if save_figs:
            plt.savefig("./images/example_output_FEMs_FNO.png")
        plt.show()

    print(f"{l2_error_fno=}")
    print(f"{l2_error_phi_fem=}")
    print(f"{l2_error_std_fem=}")

    print(f"{_h_std=}")
    print(f"{_h_phi=}")

    return (
        l2_error_phi_fem,
        l2_error_std_fem,
        l2_error_fno,
        time_phi,
        time_std,
        time_fno,
    )


In [None]:
save_list_error = True
# NOTE(review): this overrides `save_figs = True` from the setup cell, so the
# summary figures in the later cells are not saved either — confirm intended.
save_figs = False
L2_error_fno, L2_error_phi_fem, L2_error_std_fem = [], [], []
Times_phi, Times_std, Times_fno = [], [], []
nb_data, nb_vert = 30, 256
n_mode = 4

# Generate the random domains only when they are missing. The existence check
# must use the same directory (including the `_{n_mode}` suffix) as the loads
# below; otherwise it never matches the files that are actually read.
domains_dir = f"./data_domains_{nb_data}_{n_mode}"
if not os.path.exists(domains_dir):
    gd.generate_multiple_domains(
        nb_data=nb_data,
        nb_vert=2 * nb_vert - 1,
        seed=seed,
        n_mode=n_mode,
        save=True,
    )

phi = np.load(f"{domains_dir}/level_sets_{nb_data}.npy")
domains = np.load(f"{domains_dir}/domains_{nb_data}.npy")

# Draw the problem parameters once and cache them on disk. Checking the file
# (not only the directory) avoids a crash when the directory exists without it.
compare_dir = "./data_compare_methods"
params_path = f"{compare_dir}/params.npy"
if not os.path.exists(params_path):
    os.makedirs(compare_dir, exist_ok=True)
    F, G, params = create_params(domains)
    np.save(params_path, params)

params = np.load(params_path)
domains = np.load(f"{domains_dir}/domains_{nb_data}.npy")
coeffs = np.load(f"{domains_dir}/params_{nb_data}.npy")

indices = list(range(0, len(phi)))
epoch = best_epoch
for index in indices:
    print(f"index = {index+1}/{len(indices)}")
    (
        l2_error_phi_fem,
        l2_error_std_fem,
        l2_error_fno,
        time_phi,
        time_std,
        time_fno,
    ) = compare_std_phi_fem_and_fno(
        np.array([params[index]]),
        phi[index],
        coeffs[index],
        Plot=True,
        epoch=epoch,
    )

    L2_error_phi_fem.append(l2_error_phi_fem)
    Times_phi.append(time_phi)
    L2_error_fno.append(l2_error_fno)
    Times_fno.append(time_fno)

    L2_error_std_fem.append(l2_error_std_fem)
    Times_std.append(time_std)

    if save_list_error:
        # Re-save every accumulated list after each sample (as a (1, n)
        # array) so partial results are kept if the loop is interrupted.
        for stem, values in (
            ("L2_error_phi_fem", L2_error_phi_fem),
            ("L2_error_std_fem", L2_error_std_fem),
            ("L2_error_fno", L2_error_fno),
            ("Times_phi_fem", Times_phi),
            ("Times_std_fem", Times_std),
            ("Times_fno", Times_fno),
        ):
            np.save(f"{compare_dir}/{stem}.npy", np.array([values]))

In [None]:
# Wrap each accumulated list in an extra axis, giving the same layout as the
# arrays saved to disk by the comparison loop.
L2_error_phi_fem, L2_error_std_fem, L2_error_fno = [
    np.array([values])
    for values in (L2_error_phi_fem, L2_error_std_fem, L2_error_fno)
]

print(f"{np.shape(L2_error_phi_fem)=}")
print(f"{np.shape(L2_error_std_fem)=}")
print(f"{np.shape(L2_error_fno)=}")

In [None]:
# Reload the saved per-sample errors, so the following cells do not depend on
# the comparison loop having run in this kernel.
L2_error_phi_fem, L2_error_std_fem, L2_error_fno = [
    np.load(f"./data_compare_methods/{stem}.npy")
    for stem in ("L2_error_phi_fem", "L2_error_std_fem", "L2_error_fno")
]

In [None]:
# Inspect the shapes and values of the reloaded error arrays.
# The f-string `=` specifier echoes the expression text in the output.
print(f"{np.shape(L2_error_phi_fem)=}")
print(f"{np.shape(L2_error_std_fem)=}")
print(f"{np.shape(L2_error_fno)=}")

print(f"{(L2_error_phi_fem)=}")
print(f"{(L2_error_std_fem)=}")
print(f"{(L2_error_fno)=}")

In [None]:
# One column of relative L2 errors per method, shown as boxplots on a log scale.
error_tab = [L2_error_phi_fem, L2_error_std_fem, L2_error_fno]

# Raw string: "\p" is an invalid escape sequence in a regular string literal.
abs_str = [r"$\phi$-FEM", "Standard FEM", "FNO " + str(epoch) + " epochs"]
errors = np.array(error_tab)
print(np.shape(errors))
# Drop the singleton middle axis of the saved (1, n) arrays -> (n_methods, n).
errors = np.reshape(errors, (np.shape(errors)[0], np.shape(errors)[-1]))
dataframe = pd.DataFrame(errors.transpose(), columns=abs_str)

sns.set(font_scale=1.1)

fig, ax = plt.subplots(figsize=(6, 4))
sns.boxplot(data=dataframe, palette="ch:s=.0,rot=0.0,dark=0.5", ax=ax)
ax.set_yscale("log")
ax.set_xlabel("Method", fontsize=16)
ax.set_ylabel("Relative $L^2$ error", fontsize=16)
fig.tight_layout()
if save_figs:
    fig.savefig("./images/boxplots_new_data_compare_method.png")
plt.show()

# Compare the relative $L^2$ errors against the total computation times

In [None]:
# Per-sample relative L2 errors and computation times, as saved by the
# comparison loop (one row, one column per sample).
L2_error_phi_fem_array = np.load("./data_compare_methods/L2_error_phi_fem.npy")
L2_error_std_fem_array = np.load("./data_compare_methods/L2_error_std_fem.npy")
L2_error_fno_array = np.load("./data_compare_methods/L2_error_fno.npy")

# Summary statistics of the errors along the sample axis.
mean_errors_L2_phi_fem = np.mean(L2_error_phi_fem_array, axis=1)
mean_errors_L2_std_fem = np.mean(L2_error_std_fem_array, axis=1)
mean_errors_L2_FNO = np.mean(L2_error_fno_array, axis=1)

std_errors_L2_phi_fem = np.std(L2_error_phi_fem_array, axis=1)
std_errors_L2_std_fem = np.std(L2_error_std_fem_array, axis=1)
std_errors_L2_FNO = np.std(L2_error_fno_array, axis=1)

min_errors_L2_phi_fem = np.min(L2_error_phi_fem_array, axis=1)
min_errors_L2_std_fem = np.min(L2_error_std_fem_array, axis=1)
min_errors_L2_FNO = np.min(L2_error_fno_array, axis=1)

max_errors_L2_phi_fem = np.max(L2_error_phi_fem_array, axis=1)
max_errors_L2_std_fem = np.max(L2_error_std_fem_array, axis=1)
max_errors_L2_FNO = np.max(L2_error_fno_array, axis=1)

Total_time_phi_fem_array = np.load('./data_compare_methods/Times_phi_fem.npy')
Total_time_std_fem_array = np.load('./data_compare_methods/Times_std_fem.npy')
Total_time_FNO_array = np.load('./data_compare_methods/Times_fno.npy')

# Summary statistics of the computation times along the sample axis.
mean_times_L2_phi_fem = np.mean(Total_time_phi_fem_array, axis=1)
mean_times_L2_std_fem = np.mean(Total_time_std_fem_array, axis=1)
mean_times_L2_FNO = np.mean(Total_time_FNO_array, axis=1)

std_times_L2_phi_fem = np.std(Total_time_phi_fem_array, axis=1)
std_times_L2_std_fem = np.std(Total_time_std_fem_array, axis=1)
std_times_L2_FNO = np.std(Total_time_FNO_array, axis=1)

min_times_L2_phi_fem = np.min(Total_time_phi_fem_array, axis=1)
min_times_L2_std_fem = np.min(Total_time_std_fem_array, axis=1)
min_times_L2_FNO = np.min(Total_time_FNO_array, axis=1)

max_times_L2_phi_fem = np.max(Total_time_phi_fem_array, axis=1)
max_times_L2_std_fem = np.max(Total_time_std_fem_array, axis=1)
max_times_L2_FNO = np.max(Total_time_FNO_array, axis=1)


def _plot_method_summary(
    ax, label, time_values, error_values, mean_time, mean_error, color_index
):
    """Plot one method's mean (time, error) as a cross plus a confidence ellipse.

    The cross uses the ``tab10`` colour and the ellipse the matching
    ``pastel`` colour at ``color_index``.
    """
    ax.plot(
        mean_time,
        mean_error,
        "x",
        markersize=8,
        label=label,
        color=sns.color_palette("tab10")[color_index],
    )
    pastel = sns.color_palette("pastel")[color_index]
    confidence_ellipse(
        time_values,
        error_values,
        ax,
        alpha=0.8,
        facecolor=pastel,
        edgecolor=pastel,
    )


fig, ax = plt.subplots(1, 1, figsize=(6, 4))

# Drawing order fixes the legend order: FNO, standard FEM, phi-FEM.
_plot_method_summary(
    ax, "FNO", Total_time_FNO_array, L2_error_fno_array,
    mean_times_L2_FNO, mean_errors_L2_FNO, 4,
)
_plot_method_summary(
    ax, "Std FEM", Total_time_std_fem_array, L2_error_std_fem_array,
    mean_times_L2_std_fem, mean_errors_L2_std_fem, 3,
)
_plot_method_summary(
    ax, r"$\phi$-FEM", Total_time_phi_fem_array, L2_error_phi_fem_array,
    mean_times_L2_phi_fem, mean_errors_L2_phi_fem, 0,
)

ax.legend(fontsize=16, loc="upper right", ncol=2)
ax.set_xlabel("Computation time (s)", fontsize=16)
ax.set_ylabel("Relative $L^2$ error", fontsize=16)
ax.set_yscale("log")
ax.set_xscale("log")
# Fixed y-range; points outside [1e-5, 1e-1] would be clipped.
ax.set_ylim(1e-5, 1e-1)
fig.tight_layout()
if save_figs:
    fig.savefig("./images/error_time_3_methods.png")
plt.show()