In [None]:
import numpy as np
import matplotlib.pyplot as plt
import random
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf

# Seed every random number generator used by the notebook (Python, NumPy,
# TensorFlow and TensorFlow's NumPy front-end) with the same value.
seed = 2023
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)
tf.experimental.numpy.random.seed(seed)
# Ask TensorFlow / cuDNN for deterministic kernels.
# NOTE(review): these variables are set after `import tensorflow`; confirm
# that TensorFlow still honours them when they are set this late.
os.environ["TF_CUDNN_DETERMINISTIC"] = "1"
os.environ["TF_DETERMINISTIC_OPS"] = "1"
# Set a fixed value for the hash seed.
# NOTE(review): PYTHONHASHSEED is read when the interpreter starts, so this
# assignment can only influence child processes, not the running kernel.
os.environ["PYTHONHASHSEED"] = str(seed)
import dolfin as df
import time
from utils import *
from utils_training import *
import prepare_data
from utils_compare_methods import *
import seaborn as sns
import domain_generator as dga
from mpl_toolkits.axes_grid1 import make_axes_locatable
import pandas as pd

# Global plotting style: seaborn default theme with the "paper" context.
sns.set_theme()
sns.set_context("paper")
# Hex colours of the "mako" palette; the histogram cells index into this list.
colors = sns.color_palette("mako").as_hex()
# Continuous colormap passed to imshow / dolfin plots of fields.
my_cmap = sns.color_palette("viridis", as_cmap=True)

In [None]:
# --- Data set / model selection -------------------------------------------
nb_data_origin = 2811
nb_data_used = 1500
save_figs = True

# When `small_data` is False the full data set size is used everywhere,
# including in the names of the model and image directories.
small_data = False
if not small_data:
    nb_data_used = nb_data_origin
level = 2  # NOTE(review): not referenced in the visible cells.
data = DataLoader(small_data)
agent = Agent(data, small_data)

# Create the output directory up front; `exist_ok=True` makes the cell safe
# to re-run and avoids the check-then-create race of os.path.exists().
if save_figs:
    os.makedirs(f"./images_{nb_data_used}/", exist_ok=True)

# --- Checkpoints compared in the figures -----------------------------------
epochs = [100, 200, 500, 750, 1000, 1250, 1500, 2000]
print(len(epochs))
print(epochs)

# --- Layout of the multi-panel histogram figures ----------------------------
indices = list(range(len(epochs)))
size_per_fig = 4
nb_rows = 2
# Ceiling division: the grid must hold every entry of `indices` even when
# their number is not a multiple of `nb_rows` (plain truncation would make
# the plotting loops index past the last row).
size_row = -(-len(indices) // nb_rows)

In [None]:
# Restore the checkpoint saved after `nb_epochs` training epochs.
nb_epochs = 2000
weights_path = f"./models_{nb_data_used}/model_{nb_epochs}/model_weights"
agent.model.load_weights(weights_path)

In [None]:
def create_params(domains):
    """Draw one (mu0, mu1, sigma) triple per domain, resampling until valid.

    Parameters are drawn uniformly (mu0, mu1 in [0.2, 0.8], sigma in
    [0.1, 0.5]). For each domain the triple is redrawn until
    ``np.max(f * domain) >= 80.0``, where ``f`` is the output of ``call_F``
    on a regular grid of the unit square. The original comment ties this
    threshold to the constraint used by ``create_FG_numpy()``.

    Parameters
    ----------
    domains : array-like, indexed as ``domains[i, :, :]``
        One mask per sample; the last axis length fixes the grid resolution.
        NOTE(review): presumably a 0/1 indicator of the domain; an all-zero
        mask would make the resampling loop run forever.

    Returns
    -------
    numpy.ndarray of shape ``(nb_data, 3)``
        Columns are ``mu0``, ``mu1`` and ``sigma``.
    """
    nb_data = np.shape(domains)[0]
    nb_dofs = np.shape(domains)[-1]

    # Flattened coordinates of the regular nb_dofs x nb_dofs grid on [0, 1]^2.
    xy = np.linspace(0.0, 1.0, nb_dofs)
    XX, YY = np.meshgrid(xy, xy)
    XXYY = np.stack([np.reshape(XX, [-1]), np.reshape(YY, [-1])])

    mu0 = np.random.uniform(0.2, 0.8, size=[nb_data, 1])
    mu1 = np.random.uniform(0.2, 0.8, size=[nb_data, 1])
    sigma = np.random.uniform(0.1, 0.5, size=[nb_data, 1])

    F = call_F(np, XXYY, mu0, mu1, sigma)
    F = np.reshape(F, [nb_data, nb_dofs, nb_dofs])

    for i in range(nb_data):
        domain = domains[i, :, :]
        f = F[i]
        new_gen = 0
        while np.max(f * domain) < 80.0:
            # Same draw order and shapes as before so the random stream (and
            # the arguments handed to call_F) are unchanged.
            new_mu0 = np.random.uniform(0.2, 0.8, size=[1, 1])[0]
            new_mu1 = np.random.uniform(0.2, 0.8, size=[1, 1])[0]
            new_sigma = np.random.uniform(0.1, 0.5, size=[1, 1])[0]
            # Store the scalar explicitly: assigning a 1-element array to a
            # single slot is deprecated since NumPy 1.25.
            mu0[i, 0] = new_mu0[0]
            mu1[i, 0] = new_mu1[0]
            sigma[i, 0] = new_sigma[0]
            f = call_F(np, XXYY, new_mu0, new_mu1, new_sigma)
            f = np.reshape(f, [nb_dofs, nb_dofs])
            new_gen += 1
            print(f"{new_gen=}")

    # The former second evaluation of F on the whole batch was discarded
    # without being used, so only the parameters are assembled here.
    return np.concatenate([mu0, mu1, sigma], axis=1)


def create_multiple_params_unique_domain(domain, nb_data):
    """Draw ``nb_data`` (mu0, mu1, sigma) triples that are valid on one domain.

    Parameters are drawn uniformly (mu0, mu1 in [0.2, 0.8], sigma in
    [0.1, 0.5]). Each triple is redrawn until
    ``np.max(f * domain) >= 80.0``, where ``f`` is the output of ``call_F``
    on a regular grid of the unit square.

    Parameters
    ----------
    domain : 2-D array-like
        Mask shared by all samples; its last axis length fixes the grid
        resolution. NOTE(review): presumably a 0/1 indicator of the domain;
        an all-zero mask would make the resampling loop run forever.
    nb_data : int
        Number of parameter triples to generate.

    Returns
    -------
    numpy.ndarray of shape ``(nb_data, 3)``
        Columns are ``mu0``, ``mu1`` and ``sigma``.
    """
    nb_dofs = np.shape(domain)[-1]

    # Flattened coordinates of the regular nb_dofs x nb_dofs grid on [0, 1]^2.
    xy = np.linspace(0.0, 1.0, nb_dofs)
    XX, YY = np.meshgrid(xy, xy)
    XXYY = np.stack([np.reshape(XX, [-1]), np.reshape(YY, [-1])])

    mu0 = np.random.uniform(0.2, 0.8, size=[nb_data, 1])
    mu1 = np.random.uniform(0.2, 0.8, size=[nb_data, 1])
    sigma = np.random.uniform(0.1, 0.5, size=[nb_data, 1])

    F = call_F(np, XXYY, mu0, mu1, sigma)
    F = np.reshape(F, [nb_data, nb_dofs, nb_dofs])

    for i in range(nb_data):
        f = F[i]
        new_gen = 0
        while np.max(f * domain) < 80.0:
            # Same draw order and shapes as before so the random stream (and
            # the arguments handed to call_F) are unchanged.
            new_mu0 = np.random.uniform(0.2, 0.8, size=[1, 1])[0]
            new_mu1 = np.random.uniform(0.2, 0.8, size=[1, 1])[0]
            new_sigma = np.random.uniform(0.1, 0.5, size=[1, 1])[0]
            # Store the scalar explicitly: assigning a 1-element array to a
            # single slot is deprecated since NumPy 1.25.
            mu0[i, 0] = new_mu0[0]
            mu1[i, 0] = new_mu1[0]
            sigma[i, 0] = new_sigma[0]
            f = call_F(np, XXYY, new_mu0, new_mu1, new_sigma)
            f = np.reshape(f, [nb_dofs, nb_dofs])
            new_gen += 1
            print(f"{new_gen=}")

    # The former second evaluation of F on the whole batch was discarded
    # without being used, so only the parameters are assembled here.
    return np.concatenate([mu0, mu1, sigma], axis=1)


def compare_std_phi_fem_and_fno(param, phi, domain, Plot=False):
    """Compare phi-FEM, standard FEM and the network prediction on one geometry.

    For each entry of ``param`` a reference solution is obtained from
    ``StandardFEMSolver.solve_one(i, None, True)``. The phi-FEM solution, the
    standard FEM solution requested with size 64, and the prediction of
    ``agent.model`` are each projected onto the reference space, and their
    relative L2 error against the reference is recorded.

    Parameters
    ----------
    param : sequence of (mu0, mu1, sigma)
        Source-term parameters, one triple per sample.
    phi : array
        Level-set data forwarded to both solvers as ``phi_vector``.
    domain : array
        Unused; kept so existing callers keep working.
    Plot : bool, default False
        When True, draw the four solutions side by side for every sample and,
        if the global ``save_figs`` is set, save the figure.

    Returns
    -------
    tuple of six lists
        ``(L2_error_phi_fem, L2_error_std_fem, L2_error_fno,
        Temps_phi, Temps_std, Temps_fno)``, each with one value per sample.
        The FEM timings are those reported by the solvers; the network timing
        is the wall-clock duration of ``agent.model.call``.
    """
    nb_vert = 64  # resolution used for the coarse solvers and the network input

    def project_on_reference(u, V_ref):
        # Every comparison is done in the reference space.
        return df.project(
            u, V_ref, solver_type="gmres", preconditioner_type="hypre_amg"
        )

    def relative_l2_error(u_proj, u_ref, dx_ref, norm_ref):
        return (
            df.assemble(((u_ref - u_proj) ** 2) * dx_ref) ** (0.5)
        ) / norm_ref

    solveur_standard = StandardFEMSolver(params=param, phi_vector=phi)
    solver = PhiFemSolver_error(
        nb_cell=nb_vert - 1, params=param, phi_vector=phi
    )

    L2_error_fno, L2_error_phi_fem, L2_error_std_fem = [], [], []
    Temps_phi, Temps_std, Temps_fno = [], [], []
    for i in range(len(param)):
        u_ex, V_ex, dx_ex = solveur_standard.solve_one(i, None, True)
        # The reference norm is shared by the three relative errors below.
        norm_u_ex = df.assemble(((u_ex) ** 2) * dx_ex) ** (0.5)

        # --- phi-FEM -------------------------------------------------------
        u_phi_fem, _, temps_phi, phi_64, _ = solver.solve_one(i)
        u_phi_fem_proj = project_on_reference(u_phi_fem, V_ex)
        l2_error_phi_fem = relative_l2_error(
            u_phi_fem_proj, u_ex, dx_ex, norm_u_ex
        )
        L2_error_phi_fem.append(l2_error_phi_fem)
        Temps_phi.append(temps_phi)

        # --- standard FEM --------------------------------------------------
        u_std, temps_std, _ = solveur_standard.solve_one(i, nb_vert, True)
        u_std_fem_proj = project_on_reference(u_std, V_ex)
        l2_error_std_fem = relative_l2_error(
            u_std_fem_proj, u_ex, dx_ex, norm_u_ex
        )
        L2_error_std_fem.append(l2_error_std_fem)
        Temps_std.append(temps_std)

        # --- network prediction ---------------------------------------------
        mu0, mu1, sigma = param[i]
        F = generate_F_numpy(mu0, mu1, sigma, nb_vert) / data.max_norm_F
        X = generate_manual_new_data_numpy(phi_64, F)
        start_call = time.time()
        Y = agent.model.call(X)
        end_call = time.time()
        temps_fno = end_call - start_call
        # Channel 1 of the input multiplies the network output.
        # NOTE(review): presumably that channel is the level set; confirm
        # against generate_manual_new_data_numpy.
        solution_predite = X[:, :, :, 1] * Y[:, :, :, 0]  # * X[:, :, :, -2]
        solution_predite = np.reshape(solution_predite, (nb_vert, nb_vert))
        sol_predite_fenics = convert_numpy_matrix_to_fenics(
            solution_predite, nb_vert, 1
        )
        sol_predite_fenics_proj_V_ex = project_on_reference(
            sol_predite_fenics, V_ex
        )
        l2_error_fno = relative_l2_error(
            sol_predite_fenics_proj_V_ex, u_ex, dx_ex, norm_u_ex
        )
        L2_error_fno.append(l2_error_fno)
        Temps_fno.append(temps_fno)

        if Plot:
            # "\\phi" keeps the backslash for mathtext without relying on an
            # invalid "\p" escape sequence.
            panels = [
                (u_ex, "Exact solution"),
                (
                    u_phi_fem_proj,
                    "$\\phi$-FEM solution \n$L^2$ relative error : "
                    + f"{l2_error_phi_fem:.5f}",
                ),
                (
                    u_std_fem_proj,
                    "Standard FEM solution \n$L^2$ relative error : "
                    + f"{l2_error_std_fem:.5f}",
                ),
                (
                    sol_predite_fenics_proj_V_ex,
                    "Predicted solution \n$L^2$ relative error : "
                    + f"{l2_error_fno:.5f}",
                ),
            ]
            plt.figure(figsize=(20, 5.5))
            for position, (field, title) in enumerate(panels, start=1):
                plt.subplot(1, 4, position)
                p = df.plot(field, mode="color", vmin=0.0, cmap=my_cmap)
                plt.colorbar(p)
                plt.title(title, fontsize="14")
            plt.tight_layout()
            if save_figs:
                # NOTE(review): the file name does not depend on i, so each
                # sample overwrites the previous figure; only the last remains.
                plt.savefig(
                    f"./images_{nb_data_used}/example_output_FEMs_FNO.png"
                )
            plt.show()
    return (
        L2_error_phi_fem,
        L2_error_std_fem,
        L2_error_fno,
        Temps_phi,
        Temps_std,
        Temps_fno,
    )

In [None]:
# (mu0, mu1, sigma, x_0, y_0, lx, ly, theta) = [0.33092311 0.61749416 0.19945455 0.56279884 0.48285894 0.25675309
#  0.37431148 0.5453267 ]
# `save_figs` is deliberately NOT reassigned here: the configuration cell is
# the single place that decides whether figures are written (and whether the
# output directory gets created).
sample_index = 5  # entry of the stored test set used for this example
phi = np.array(
    [
        np.load(
            "../shape_generation/compare_methods/test_level_set_10_1023.npy"
        )[sample_index]
    ]
)
domains = np.array(
    [
        np.load(
            "../shape_generation/compare_methods/original_domains_10_1023.npy"
        )[sample_index]
    ]
)

# Alternative kept for reference: one parameter triple per domain.
# params = create_params(domains=domains)
# print(params)
params = create_multiple_params_unique_domain(domains[0], 10)
print(f"{params=}")

# Show the level-set data (left) next to the domain array (right).
plt.figure(figsize=(12, 6))
for position, image in enumerate((phi[0, :, :], domains[0, :, :]), start=1):
    plt.subplot(1, 2, position)
    plt.imshow(image, cmap=my_cmap, origin="lower")
plt.tight_layout()
plt.show()

(
    l2_error_phi_fem,
    l2_error_std_fem,
    l2_error_fno,
    temps_phi,
    temps_std,
    temps_fno,
) = compare_std_phi_fem_and_fno(params, phi[0], domains[0], True)
print(f"{l2_error_fno=}")

In [None]:
def compute_errors_multiple_epochs(epochs, params, phi, domains):
    """Evaluate phi-FEM, standard FEM and several network checkpoints.

    The FEM part is done once per sample: a reference solution is obtained
    from ``StandardFEMSolver.solve_one(0, None, True)``, then the phi-FEM and
    the size-64 standard FEM solutions are projected onto the reference space
    and their relative L2 errors are recorded. Afterwards, for every entry of
    ``epochs``, the matching weights are loaded into ``agent.model`` and the
    prediction error and call duration are measured on every sample.

    Side effect: ``agent.model`` is left with the weights of the last entry
    of ``epochs``.

    Parameters
    ----------
    epochs : iterable of int
        Checkpoints to evaluate; each must exist under
        ``./models_{nb_data_used}/model_{epoch}/model_weights``.
    params : sequence of (mu0, mu1, sigma)
        One parameter triple per sample.
    phi : sequence
        ``phi[i]`` is forwarded to the solvers as ``phi_vector`` for sample i.
    domains : array
        Unused; kept so existing callers keep working.

    Returns
    -------
    tuple
        ``(L2_error_phi_fem, L2_error_std_fem, errors_fno, Temps_phi,
        Temps_std, times_fno)``. The FEM entries are lists with one value per
        sample; ``errors_fno`` and ``times_fno`` are lists with one inner
        list (one value per sample) per checkpoint.
    """

    def project_on_reference(u, V_ref):
        return df.project(
            u, V_ref, solver_type="gmres", preconditioner_type="hypre_amg"
        )

    def relative_l2_error(u_proj, u_ref, dx_ref, norm_ref):
        return (
            df.assemble(((u_ref - u_proj) ** 2) * dx_ref) ** (0.5)
        ) / norm_ref

    u_exs, V_exs, dx_exs, norms_ex = [], [], [], []
    L2_error_phi_fem, L2_error_std_fem = [], []
    Temps_phi, Temps_std = [], []

    Phi_64 = []
    for i in range(len(params)):
        solveur_standard = StandardFEMSolver(
            params=[params[i]], phi_vector=phi[i]
        )
        solver = PhiFemSolver_error(
            nb_cell=64 - 1, params=[params[i]], phi_vector=phi[i]
        )
        u_ex, V_ex, dx_ex = solveur_standard.solve_one(0, None, True)
        # Assemble the reference norm once; it is reused for both FEM errors
        # here and for every checkpoint below.
        norm_ex = df.assemble(((u_ex) ** 2) * dx_ex) ** (0.5)
        u_exs.append(u_ex)
        V_exs.append(V_ex)
        dx_exs.append(dx_ex)
        norms_ex.append(norm_ex)

        u_phi_fem, _, temps_phi, phi_64, _ = solver.solve_one(0)
        u_phi_fem_proj = project_on_reference(u_phi_fem, V_ex)
        l2_error_phi_fem = relative_l2_error(
            u_phi_fem_proj, u_ex, dx_ex, norm_ex
        )

        u_std, temps_std, _ = solveur_standard.solve_one(0, 64, True)
        u_std_fem_proj = project_on_reference(u_std, V_ex)
        l2_error_std_fem = relative_l2_error(
            u_std_fem_proj, u_ex, dx_ex, norm_ex
        )

        L2_error_phi_fem.append(l2_error_phi_fem)
        L2_error_std_fem.append(l2_error_std_fem)
        Temps_phi.append(temps_phi)
        Temps_std.append(temps_std)
        Phi_64.append(phi_64)

    Phi_64 = np.array(Phi_64)
    errors_fno, times_fno = [], []

    for j in epochs:
        print(f"Epoch : {j}")
        agent.model.load_weights(
            f"./models_{nb_data_used}/model_{j}/model_weights"
        )
        L2_error_fno = []
        Temps_fno = []
        for i in range(len(params)):
            mu0, mu1, sigma = params[i]
            phi_i = Phi_64[i]
            F = generate_F_numpy(mu0, mu1, sigma, 64) / data.max_norm_F
            X = generate_manual_new_data_numpy(phi_i, F)
            start_call = time.time()
            Y = agent.model.call(X)
            end_call = time.time()
            # Channel 1 of the input multiplies the network output.
            solution_predite = (
                X[:, :, :, 1] * Y[:, :, :, 0]
            )  # * X[:, :, :, -2]
            solution_predite = np.reshape(solution_predite, (64, 64))
            sol_predite_fenics = convert_numpy_matrix_to_fenics(
                solution_predite, 64, 1
            )
            sol_predite_fenics_proj_V_ex = project_on_reference(
                sol_predite_fenics, V_exs[i]
            )

            l2_error_fno = relative_l2_error(
                sol_predite_fenics_proj_V_ex,
                u_exs[i],
                dx_exs[i],
                norms_ex[i],
            )
            L2_error_fno.append(l2_error_fno)
            Temps_fno.append(end_call - start_call)
        errors_fno.append(L2_error_fno)
        times_fno.append(Temps_fno)

    return (
        L2_error_phi_fem,
        L2_error_std_fem,
        errors_fno,
        Temps_phi,
        Temps_std,
        times_fno,
    )

In [None]:
# Number of freshly generated geometries used for the comparison below; it is
# also part of the file names loaded in the next cell.
nb_new_data = 100
# NOTE(review): not referenced in the visible cells; the value matches the
# 127 used in the generator call and in the data file names.
size_fine_exact_mesh = 127
# One-off generation of the level sets (kept for provenance; the results are
# loaded from disk in the next cell):
# tmp_phi = dga.create_level_set(
#     127,
#     127,
#     n_mode=3,
#     batch_size=nb_new_data,
#     threshold=0.4,
#     seed=seed,
#     save=True,
# )

In [None]:
# Load the pre-generated level sets and the matching domain arrays.
phi = np.load(f"./absolute/test_level_set_{nb_new_data}_127_3.npy")
domains = np.load(f"./absolute/original_domains_{nb_new_data}_127_3.npy")

# Draw the source-term parameters exactly once. The cell used to call
# create_params twice in a row and throw the first result away, which doubled
# the cost and advanced the global NumPy random stream for nothing.
# NOTE: with a fixed seed the drawn values therefore differ from runs made
# with the duplicated call.
params = create_params(domains=domains)

plt.figure()
plt.imshow(domains[0], origin="lower", cmap=my_cmap)
plt.show()
print(params)

In [None]:
# Errors and timings of both FEM solvers plus every checkpoint in `epochs`.
L2_error_phi_fem, L2_error_std_fem, errors_fno, Temps_phi, Temps_std, times_fno = (
    compute_errors_multiple_epochs(epochs, params, phi, domains)
)

In [None]:
# Per-sample relative L2 error of every method (quick diagnostic plot).
plt.figure()
plt.semilogy(L2_error_phi_fem, label="phiFEM")
plt.semilogy(L2_error_std_fem, label="StdFEM")
for epoch_index in indices:
    plt.semilogy(
        errors_fno[epoch_index], label=f"FNO {epochs[epoch_index]} epochs"
    )

plt.legend(ncol=2)
# Honour `save_figs` like the other figure cells: the output directory is
# only created when saving is enabled, so an unconditional save would fail.
if save_figs:
    plt.savefig(
        f"./images_{nb_data_used}/plot_pas_joli_error_methods_new_data.png"
    )
plt.show()

In [None]:
# Histograms of the relative L2 error: one panel per checkpoint, each showing
# the network next to both FEM solvers. `save_figs` is taken from the
# configuration cell instead of being forced to True here.
colors = sns.color_palette("mako").as_hex()
fig, axes = plt.subplots(
    nb_rows,
    size_row,
    figsize=(size_per_fig * size_row, size_per_fig * nb_rows),
)

for position, epoch_index in enumerate(indices):
    # Row-major placement of the panel in the nb_rows x size_row grid.
    ax = axes[divmod(position, size_row)]
    series = [
        (
            errors_fno[epoch_index],
            20,
            colors[2],
            "FNO " + str(epochs[epoch_index]) + " epochs",
        ),
        # "\\phi": explicit backslash instead of the invalid "\p" escape.
        (L2_error_phi_fem, 10, colors[3], "$\\phi$-FEM"),
        (L2_error_std_fem, 10, colors[1], "Standard FEM"),
    ]
    for values, nb_bins, color, label in series:
        sns.histplot(
            data=values,
            kde=True,
            bins=nb_bins,
            color=color,
            edgecolor="k",
            log_scale=True,
            label=label,
            stat="proportion",
            legend=True,
            ax=ax,
        )
    # ax.set_xlim(1e-4, 1e0)
    # ax.set_ylim(0.0, 0.30)
    ax.set_xlabel("$L^2$ relative error")
    ax.legend()

plt.tight_layout()
if save_figs:
    plt.savefig(
        f"./images_{nb_data_used}/histograms_new_data_compare_methods_L2.png"
    )
plt.show()

In [None]:
# Checkpoints shown in the box plot (positions in `epochs`).
plot_errors_fno = [0, 1, 2, 4, 5, 6, 7]
new_fno_errors = [errors_fno[idx] for idx in plot_errors_fno]
plot_epochs = [epochs[idx] for idx in plot_errors_fno]

# One row per method: both FEM solvers first, then the selected checkpoints.
error_tab = [L2_error_phi_fem, L2_error_std_fem, *new_fno_errors]
abs_str = ["PhiFEM", "StdFEM", *(str(epochs[idx]) for idx in plot_errors_fno)]

errors = np.array(error_tab)
print(errors.shape)
# Transposed so that each method becomes a column of the frame.
dataframe = pd.DataFrame(errors.T, columns=abs_str)

plt.figure(figsize=(10, 5))
sns.boxplot(data=dataframe, palette="ch:s=.25,rot=-.25")
plt.yscale("log")
plt.xlabel("epoch")
plt.ylabel("error")
if save_figs:
    boxplot_path = (
        f"./images_{nb_data_used}/boxplots_new_data_compare_method.png"
    )
    plt.savefig(boxplot_path)
plt.show()

In [None]:
# time / error, per checkpoint and per sample for the network ...
ratio_time_error_fno = [
    [duration / error for duration, error in zip(times_row, errors_row)]
    for times_row, errors_row in zip(times_fno, errors_fno)
]
# ... and per sample for both FEM solvers.
ratio_time_error_phi_fem = [
    duration / error for duration, error in zip(Temps_phi, L2_error_phi_fem)
]
ratio_time_error_std_fem = [
    duration / error for duration, error in zip(Temps_std, L2_error_std_fem)
]

# error / time, same layout as above.
ratio_error_time_fno = [
    [error / duration for error, duration in zip(errors_row, times_row)]
    for errors_row, times_row in zip(errors_fno, times_fno)
]
ratio_error_time_phi_fem = [
    error / duration for error, duration in zip(L2_error_phi_fem, Temps_phi)
]
ratio_error_time_std_fem = [
    error / duration for error, duration in zip(L2_error_std_fem, Temps_std)
]

In [None]:
# Histograms of error / time: one panel per checkpoint, each showing the
# network next to both FEM solvers.
colors = sns.color_palette("mako").as_hex()
fig, axes = plt.subplots(
    nb_rows,
    size_row,
    figsize=(size_per_fig * size_row, size_per_fig * nb_rows),
)

for position, epoch_index in enumerate(indices):
    # Row-major placement of the panel in the nb_rows x size_row grid.
    ax = axes[divmod(position, size_row)]
    series = [
        (
            ratio_error_time_fno[epoch_index],
            15,
            colors[2],
            "FNO " + str(epochs[epoch_index]) + " epochs",
        ),
        # "\\phi": explicit backslash instead of the invalid "\p" escape.
        (ratio_error_time_phi_fem, 10, colors[3], "$\\phi$-FEM"),
        (ratio_error_time_std_fem, 10, colors[1], "Standard FEM"),
    ]
    for values, nb_bins, color, label in series:
        sns.histplot(
            data=values,
            kde=True,
            bins=nb_bins,
            color=color,
            edgecolor="k",
            log_scale=True,
            label=label,
            stat="proportion",
            legend=True,
            ax=ax,
        )
    ax.set_xlabel("error/time")
    ax.legend()
    ax.set_xlim(1e-3, 1e3)
    # ax.set_ylim(0.0, 0.50)

plt.tight_layout()
if save_figs:
    plt.savefig(
        f"./images_{nb_data_used}/histograms_new_data_compare_methods_error_time.png"
    )
plt.show()

In [None]:
# Histograms of time / error: one panel per checkpoint, each showing the
# network next to both FEM solvers.
colors = sns.color_palette("mako").as_hex()
fig, axes = plt.subplots(
    nb_rows,
    size_row,
    figsize=(size_per_fig * size_row, size_per_fig * nb_rows),
)

for position, epoch_index in enumerate(indices):
    # Row-major placement of the panel in the nb_rows x size_row grid.
    ax = axes[divmod(position, size_row)]
    series = [
        (
            ratio_time_error_fno[epoch_index],
            15,
            colors[2],
            "FNO " + str(epochs[epoch_index]) + " epochs",
        ),
        # "\\phi": explicit backslash instead of the invalid "\p" escape.
        (ratio_time_error_phi_fem, 10, colors[3], "$\\phi$-FEM"),
        (ratio_time_error_std_fem, 10, colors[1], "Standard FEM"),
    ]
    for values, nb_bins, color, label in series:
        sns.histplot(
            data=values,
            kde=True,
            bins=nb_bins,
            color=color,
            edgecolor="k",
            log_scale=True,
            label=label,
            stat="proportion",
            legend=True,
            ax=ax,
        )
    ax.set_xlabel("time/error")
    ax.legend()
    # ax.set_xlim(1e-3, 0.3)
    # ax.set_ylim(0.0, 0.50)

plt.tight_layout()
if save_figs:
    plt.savefig(
        f"./images_{nb_data_used}/histograms_new_data_compare_methods_time_error.png"
    )
plt.show()

In [None]:
# Dense sweep over the saved checkpoints: every 50 epochs from 50 to 2000.
epochs_means = [50 * step for step in range(1, 41)]

# compute_errors_multiple_epochs requires (epochs, params, phi, domains);
# the call previously omitted `phi` and `domains` and raised a TypeError.
# NOTE: this overwrites the results of the earlier 8-checkpoint run.
(
    L2_error_phi_fem,
    L2_error_std_fem,
    errors_fno,
    Temps_phi,
    Temps_std,
    times_fno,
) = compute_errors_multiple_epochs(epochs_means, params, phi, domains)

In [None]:
# Statistics of the network error over the samples, one value per checkpoint.
# They must come from `errors_fno` (one row per entry of `epochs_means`), not
# from the box-plot array `errors`, whose rows are methods and whose length
# does not match `epochs_means`.
errors_by_epoch = np.array(errors_fno)
assert len(errors_by_epoch) == len(epochs_means), (
    "errors_fno must hold one row per entry of epochs_means; "
    "re-run the epoch sweep cell first"
)
means = np.mean(errors_by_epoch, axis=1)
standard_deviation = np.std(errors_by_epoch, axis=1)
maxs = np.max(errors_by_epoch, axis=1)
mins = np.min(errors_by_epoch, axis=1)

plt.figure(figsize=(8, 3))

# Left panel: mean and standard deviation.
plt.subplot(1, 2, 1)
plt.semilogy(epochs_means, means, "-+", label="Mean")
plt.semilogy(
    epochs_means, standard_deviation, "-+", label="Standard deviation"
)
plt.xlabel("Epochs")
plt.ylabel("$L^2$ relative error")
plt.grid(True, which="both", axis="both")
plt.legend()

# Right panel: extreme values.
plt.subplot(1, 2, 2)
plt.semilogy(epochs_means, maxs, "-+", label="Maximum")
plt.semilogy(epochs_means, mins, "-+", label="Minimum")
plt.grid(True, which="both", axis="both")
plt.legend()
plt.xlabel("Epochs")
plt.ylabel("$L^2$ relative error")
plt.tight_layout()


if save_figs:
    plt.savefig(
        f"./images_{nb_data_used}/min_mean_max_error_epochs_new_data_compare_methods.png"
    )
plt.show()