In [None]:
import numpy as np
import matplotlib.pyplot as plt
import random
import os
import torch 

seed = 2023
# Fix every random-number source used by the notebook to the same seed.
# NOTE: PYTHONHASHSEED is only read at interpreter start-up, so setting it here
# does not change hash randomisation in the running kernel; it only reaches
# subprocesses started afterwards.
os.environ["PYTHONHASHSEED"] = str(seed)
random.seed(seed)
np.random.seed(seed)
# torch.manual_seed seeds the default generators of all devices (CPU and CUDA).
torch.manual_seed(seed)
import dolfin as df
import time
from utils import *
from utils_training import *
import prepare_data
from utils_compare_methods import *
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable
from prepare_data import rotate, outside_ball
import pandas as pd
import gc

# Seaborn theme: default style and palette, with the compact "paper" plotting context.
sns.set_theme(context="paper")
# Colormap used for the FEniCS field plots in compare_std_phi_fem_and_fno.
my_cmap = sns.color_palette("viridis", as_cmap=True)
# Hex codes of the "mako" palette (not referenced in the visible cells).
colors = sns.color_palette("mako").as_hex()

# Automatically reload edited project modules (utils, utils_training, ...) before each cell runs.
%load_ext autoreload
%autoreload 2


In [None]:
# Build the FNO agent and restore the best checkpoint for inference.
save_figs = True  # also write figures to `images_repo`, not only display them
small_data = False
data = DataLoader(small_data)
# NOTE(review): the architecture hyper-parameters (n_modes, width, ...) presumably
# have to match those used to train the checkpoint, otherwise load_state_dict fails.
agent = Agent(
    data,
    level=0,
    relative=True,
    squared=False,
    initial_lr=5e-3,
    n_modes=10,
    width=20,
    batch_size=32,
    pad_prop=0.05,
    pad_mode="reflect",
    l2_lambda=1e-3,
)

model = agent.model
device = agent.device

models_repo = "./models_l2"
images_repo = "./images"
# map_location remaps tensors saved on another device (e.g. a GPU) onto the agent's
# device, so the checkpoint also loads on a CPU-only machine.
# torch.load unpickles the file: only load checkpoints from a trusted source.
best_model = torch.load(f"{models_repo}/best_model.pkl", map_location=device)
model.load_state_dict(best_model["model_state_dict"])
model.eval()  # evaluation behaviour for layers such as dropout / batch-norm, if any

print(f"Best epoch = {best_model['epoch']}")
if save_figs:
    # exist_ok avoids the check-then-create race and is a no-op if the folder exists.
    os.makedirs(images_repo, exist_ok=True)

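# Compare the relative $L^2$ errors of $\phi$-FEM, standard FEM and FNO with respect to a reference standard FEM solution computed on a fine mesh

Only the FNO predictions are evaluated in this notebook; the $\phi$-FEM and standard FEM errors are loaded from precomputed files further below.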

In [None]:
def _plot_reference_and_prediction(u_ref, u_pred, l2_error):
    """Plot the reference FEM solution and the FNO prediction side by side.

    Uses the notebook-level ``my_cmap``, ``save_figs`` and ``images_repo``.
    NOTE: the figure is always saved under the same file name, so successive
    calls overwrite each other's PNG.
    """
    plt.figure(figsize=(8, 4))
    panels = (
        (121, u_ref, "Reference solution"),
        (
            122,
            u_pred,
            "Predicted solution \n$L^2$ relative error : " + f"{l2_error:.5f}",
        ),
    )
    for position, field, title in panels:
        ax = plt.subplot(position)
        # df.plot draws on the current axes, i.e. the subplot just created.
        img = df.plot(field, mode="color", cmap=my_cmap)
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("bottom", size="5%", pad=0.3)
        ax.grid(False)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.set_title(title, fontsize=15)
        plt.colorbar(img, cax=cax, orientation="horizontal")
    if save_figs:
        plt.savefig(f"{images_repo}/example_output_FEMs_FNO.png")
    plt.show()


def compare_std_phi_fem_and_fno(param, Plot=False):
    """Evaluate the FNO on one problem against a fine reference standard-FEM solution.

    Despite its name, this function only computes the FNO error; the phi-FEM and
    standard-FEM errors are produced elsewhere.

    Parameters
    ----------
    param : array-like of shape (1, 12)
        One parameter row ``(mu0, mu1, sigma_x, sigma_y, amplitude, x_0, y_0,
        lx, ly, theta, alpha, beta)``; it is passed as-is to ``StandardFEMSolver``
        and ``param[0]`` is unpacked to build the FNO inputs.
    Plot : bool, default False
        If True, display the reference and predicted solutions (and save the
        figure when the notebook-level ``save_figs`` is set).

    Returns
    -------
    l2_error_fno : float
        ``||u_ref - u_fno|| / ||u_ref||`` in L2, integrated with the reference measure.
    time_fno : float
        Wall-clock seconds spent in the FNO forward pass only.
    """
    standard_solver = StandardFEMSolver(params=param)

    mu0, mu1, sigma_x, sigma_y, amplitude, x_0, y_0, lx, ly, theta, alpha, beta = param[
        0
    ]

    # Reference solution; the meaning of (0, 0.002) is defined by
    # StandardFEMSolver.solve_one (presumably index and mesh size — confirm).
    u_ref, V_ref, dx_ref = standard_solver.solve_one(0, 0.002, reference_fem=True)

    # FNO input: source term F, boundary data G and level set phi on a 64x64 grid.
    F = generate_F_numpy(mu0, mu1, sigma_x, sigma_y, amplitude, 64)
    G = generate_G_numpy(alpha, beta, 64)
    phi = generate_phi_numpy(x_0, y_0, lx, ly, theta, 64)
    X = generate_manual_new_data_numpy(F, phi, G).to(device)
    x_normed = data.x_normalizer.encode(X)

    # CUDA kernels run asynchronously: synchronise around the call so the timing
    # covers the actual computation, not just the kernel launch.
    is_cuda = torch.device(device).type == "cuda"
    # No autograd graph is needed for inference.
    with torch.no_grad():
        if is_cuda:
            torch.cuda.synchronize()
        start_call = time.perf_counter()
        Y_normed = model(x_normed)
        if is_cuda:
            torch.cuda.synchronize()
        time_fno = time.perf_counter() - start_call
        Y = data.y_normalizer.decode(Y_normed)

    predicted_solution = (Y[:, 0, :, :]).cpu().detach().numpy()
    predicted_solution = np.reshape(predicted_solution, (64, 64))
    predicted_sol_fenics = convert_numpy_matrix_to_fenics(predicted_solution, 64, 1)
    # Project the grid prediction onto the reference space so both live on the same mesh.
    predicted_sol_fenics_proj_V_ref = df.project(
        predicted_sol_fenics,
        V_ref,
        solver_type="gmres",
        preconditioner_type="hypre_amg",
    )

    l2_error_fno = (
        df.assemble((((u_ref - predicted_sol_fenics_proj_V_ref)) ** 2) * dx_ref)
        ** (0.5)
    ) / (df.assemble((((u_ref)) ** 2) * dx_ref) ** (0.5))

    if Plot:
        _plot_reference_and_prediction(u_ref, predicted_sol_fenics_proj_V_ref, l2_error_fno)

    print(f"{l2_error_fno=}")

    return l2_error_fno, time_fno

In [None]:
# Visual sanity check on freshly generated problems; only the parameter rows are
# used here (F, phi, G are returned by create_FG_numpy but not needed).
F, phi, G, params = create_FG_numpy(10, 64)
# To inspect a hand-picked configuration instead, use e.g.:
# params = np.array(
#     [[0.5, 0.4, 0.3, 0.25, 22.0, 0.51, 0.4, 0.32, 0.25, np.pi / 3.0, 0.60, -0.4]]
# )

# NOTE: with save_figs enabled every iteration overwrites the same PNG file.
for param in params:
    l2_error_fno, time_fno = compare_std_phi_fem_and_fno(np.array([param]), Plot=True)

In [None]:
save_list_error = True
L2_error_fno, Time_fno = [], []

# NOTE(review): presumably the same test parameter sets used to produce the
# phi-FEM / standard-FEM error files loaded in the next cell — confirm.
params = np.load("../main/data_test_compare_methods/params.npy")

indices = list(range(len(params)))
for index in indices:
    print(f"Iter : {index+1}/{len(params)}")
    l2_error_fno, time_fno = compare_std_phi_fem_and_fno(
        np.array([params[index]]), Plot=False
    )
    L2_error_fno.append(l2_error_fno)
    Time_fno.append(time_fno)

    # Rewrite the running list each iteration so partial results survive an interruption.
    if save_list_error:
        np.save(f"{models_repo}/L2_error_fno_std.npy", np.array([L2_error_fno]))

In [None]:
# Precomputed relative L2 errors of each method on the test parameter sets.
models_repo_phi_fem = "../main/models_H2"
models_repo_phi_fem_fno = "../main/compare_losses"  # not used in the visible cells
L2_error_phi_fem, L2_error_std_fem = (
    np.load(f"{models_repo_phi_fem}/L2_error_phi_fem.npy"),
    np.load(f"{models_repo_phi_fem}/L2_error_std_fem.npy"),
)
# NOTE: this rebinds L2_error_fno (the list built in the previous cell, which was
# saved as L2_error_fno_std.npy) to the file L2_error_fno.npy, which this notebook
# does not write.
L2_error_fno, L2_error_fno_std = (
    np.load(f"{models_repo}/L2_error_fno.npy"),
    np.load(f"{models_repo}/L2_error_fno_std.npy"),
)

In [None]:
# Gather the relative L2 errors of every method into one table (one column per method).
error_tab = [L2_error_phi_fem, L2_error_std_fem, L2_error_fno, L2_error_fno_std]
abs_str = [r"$\phi$-FEM", "Standard FEM", r"$\phi$-FEM-FNO", "Std-FEM-FNO"]

# The saved arrays may be 1-D or wrapped as (1, n) (cf. np.array([L2_error_fno])),
# so flatten each one. Building the frame from Series also tolerates methods that
# were evaluated on a different number of samples: missing entries become NaN,
# which seaborn's boxplot drops.
dataframe = pd.DataFrame(
    {name: pd.Series(np.ravel(err)) for name, err in zip(abs_str, error_tab)}
)
errors = dataframe.to_numpy().T  # (n_methods, n_samples)
print(np.shape(errors))

# sns.set_theme (sns.set is an alias) re-applies the whole theme: this resets the
# "paper" context chosen at the top to "notebook", with fonts scaled by 1.1.
sns.set_theme(font_scale=1.1)

fig, ax = plt.subplots(figsize=(6, 4))
sns.boxplot(
    data=dataframe,
    palette="ch:s=.0,rot=0.0,dark=0.5",
    flierprops={"marker": "x", "markerfacecolor": "black"},
    ax=ax,
)
ax.set_yscale("log")
ax.set_xlabel("Method", fontsize=16)
ax.set_ylabel("Relative $L^2$ error", fontsize=16)
ax.grid(axis="y", visible=True, which="both")

fig.tight_layout()
if save_figs:
    fig.savefig(f"{images_repo}/boxplots_new_data_compare_method.png")
plt.show()