In [None]:
import numpy as np
from training import *
from prepare_data import *
from utils import *
from utilities import *
from utils_compare_methods import * 
from generate_data import StdFemSolver
import time 
import matplotlib.pyplot as plt
import random
import os

# Single seed shared by every random number generator used in this notebook.
seed = 2023
# Seed Python's built-in generator and NumPy's legacy global generator.
random.seed(seed)
np.random.seed(seed)
import dolfin as df
from utils import *
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable
import pandas as pd
import torch
import gc

# Seed PyTorch's default generator with the same value as the other libraries.
torch.manual_seed(seed)

# Global seaborn / matplotlib styling used by every figure below.
# NOTE(review): per the seaborn documentation `sns.set` is an alias of
# `sns.set_theme`, whose default context is "notebook"; the third call below
# therefore presumably overrides the "paper" context chosen just before it —
# confirm this is intended.
sns.set_theme()
sns.set_context("paper")
sns.set(rc={"xtick.bottom": True, "ytick.left": True})
# Colour helpers: a list of hex codes and a continuous colormap.
colors = sns.color_palette("mako").as_hex()
my_cmap = sns.color_palette("viridis", as_cmap=True)

# IPython magics: reload edited modules automatically before running a cell.
%load_ext autoreload
%autoreload 2

In [None]:
# Build the data loader and the training agent, then restore the best
# checkpoint into the two networks exposed by the agent.
data = DataLoader(False)

training_agent = Agent(data)
model_fno = training_agent.model
model_iphi = training_agent.model_iphi
device = training_agent.device

# `map_location` keeps the restore working when the checkpoint was written on
# a different device than the one selected by the agent (e.g. GPU -> CPU).
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
best_model = torch.load("models/best_model.pkl", map_location=device)
model_fno.load_state_dict(best_model["model_state_dict"])
model_iphi.load_state_dict(best_model["model_iphi_state_dict"])

# Put both networks in evaluation mode before running inference.
model_fno.eval()
model_iphi.eval()

In [None]:
# Solve the problems described by the saved test parameters with the standard
# FEM solver and store the resulting arrays for the comparison below.
params = np.load("../main/data_test_compare_methods/params.npy")
print(params.shape)
solver = StdFemSolver(nb_cell=64 - 1, params=params)
U, XY, F, G, Phi, HH = solver.solve_several()

# exist_ok=True replaces the separate existence check with a single call that
# is safe to re-run.
os.makedirs("./new_data_test", exist_ok=True)
np.save("./new_data_test/XY_compare_methods.npy", XY)
np.save("./new_data_test/Phi_compare_methods.npy", Phi)
np.save("./new_data_test/F_compare_methods.npy", F)
np.save("./new_data_test/G_compare_methods.npy", G)

In [None]:
# Reload the test parameters and the arrays written by the generation cell.
params = np.load("../main/data_test_compare_methods/params.npy")
XY = np.load("./new_data_test/XY_compare_methods.npy")
Phi = np.load("./new_data_test/Phi_compare_methods.npy")

# F and G get a new axis at position 2 so they can later be concatenated with
# XY along the last dimension.
F = np.load("./new_data_test/F_compare_methods.npy")[:, :, np.newaxis]
G = np.load("./new_data_test/G_compare_methods.npy")[:, :, np.newaxis]

In [None]:
# Run the restored Geo-FNO on every test sample and compare each prediction
# with a reference FEM solution evaluated at the same points.
xy = torch.tensor(XY, dtype=torch.float).to(device)
f = torch.tensor(F, dtype=torch.float).to(device)
g = torch.tensor(G, dtype=torch.float).to(device)

# Network input: coordinates followed by the F and G channels.
x = torch.cat([xy, f, g], dim=-1)

# Inference in fixed-size batches. no_grad avoids building autograd graphs,
# and slicing by batch size works for any number of samples instead of the
# previous hard-coded [:100], [100:200], [200:] split.
batch_size = 100
with torch.no_grad():
    y_pred = torch.cat(
        [
            model_fno(x[start : start + batch_size], iphi=model_iphi)
            for start in range(0, x.shape[0], batch_size)
        ]
    )
print(y_pred.shape)
loss = LpLoss(size_average=False)

# A figure is drawn for every `plot_every`-th sample.
plot_every = 100

# Host copy of the coordinates: the reference solution is evaluated point by
# point, so converting once avoids iterating over a device tensor.
xy_np = xy.cpu().numpy()

errors = []
standard_solver = StandardFEMSolver(params=params)
for i in range(len(params)):
    print(f"Params : {i+1}/{len(params)}")
    u_ref, V_ref, dx_ref = standard_solver.solve_one(i, 0.002, reference_fem=True)

    # Reference values at the sample's points (1-D array, one entry per point).
    y_true = np.array([u_ref(point) for point in xy_np[i]])
    y_true_torch = torch.tensor(y_true[None, :, None], dtype=torch.float).to(device)
    error = loss(y_pred[i : i + 1], y_true_torch).item()
    errors.append(error)
    print(error)

    if i % plot_every == 0:
        pred = y_pred[i].squeeze().detach().cpu().numpy()
        px = x[i, :, 0].detach().cpu().numpy()
        py = x[i, :, 1].detach().cpu().numpy()
        # Truth and prediction share one colour scale; the difference panel
        # uses its own automatic range.
        lims = dict(cmap="RdBu_r", vmin=y_true.min(), vmax=y_true.max())
        panels = [
            ("Truth", y_true, lims),
            ("Pred", pred, lims),
            ("Diff", y_true - pred, dict(cmap="RdBu_r")),
        ]

        plt.figure(figsize=(15, 5))
        for position, (title, values, style) in enumerate(panels, start=1):
            plt.subplot(1, 3, position)
            plt.scatter(px, py, c=values, edgecolor="w", **style)
            plt.xlim(0, 1)
            plt.ylim(0, 1)
            plt.gca().set_aspect("equal")
            plt.colorbar(shrink=0.6)
            plt.title(title)
        plt.show()

    # Rewritten after every sample, presumably so partial results survive an
    # interrupted run.
    np.save("./errors_geo_fno.npy", np.array(errors))

In [None]:
# Final write of the per-sample Geo-FNO errors collected in `errors`.
np.save("./errors_geo_fno.npy", np.asarray(errors))

In [None]:
# Load the relative L2 errors of the five methods being compared.
models_repo_phi_fem = "../main/models_H2"
models_repo_phi_fem_fno = "../main/compare_losses"
L2_error_phi_fem = np.load(f"{models_repo_phi_fem}/L2_error_phi_fem.npy")
L2_error_std_fem = np.load(f"{models_repo_phi_fem}/L2_error_std_fem.npy")
L2_error_fno = np.load(f"{models_repo_phi_fem_fno}/L2_error_fno_H2.npy")
L2_error_fno_std = np.load("../main_standard_fem/models_H2/L2_error_fno_std.npy")
L2_error_geo_fno = np.load("./errors_geo_fno.npy")

print(L2_error_phi_fem.shape)
print(L2_error_std_fem.shape)
print(L2_error_fno.shape)
print(L2_error_fno_std.shape)
# Turn the Geo-FNO errors into a single row. -1 lets NumPy infer the number of
# samples instead of assuming exactly 300.
L2_error_geo_fno = np.reshape(L2_error_geo_fno, (1, -1))
print(L2_error_geo_fno.shape)

In [None]:
# Box plot of the relative L2 error of each method on the test set.
abs_str = [r"$\phi$-FEM", "Std FEM", r"$\phi$-FEM-FNO", "Std-FEM-FNO", "Geo-FNO"]
error_tab = [
    L2_error_phi_fem,
    L2_error_std_fem,
    L2_error_fno,
    L2_error_fno_std,
    L2_error_geo_fno,
]

# One row per method. A dedicated name is used so the `errors` list filled by
# the Geo-FNO evaluation cell is not overwritten (re-running the cell that
# saves it would otherwise write this matrix to disk instead).
errors_by_method = np.array(error_tab)
errors_by_method = np.reshape(
    errors_by_method, (len(abs_str), np.shape(errors_by_method)[-1])
)
print(np.shape(errors_by_method))

# One column per method, one row per test sample.
dataframe = pd.DataFrame(errors_by_method.transpose(), columns=abs_str)

sns.set(font_scale=1.0)

# Four cubehelix colours plus one red entry inserted at index -2.
palette = sns.cubehelix_palette(n_colors=4, start=0.25, rot=-0.25, gamma=0.5)
palette = palette.as_hex()
palette.insert(-2, "#b22222")
palette = sns.color_palette(palette)


plt.figure(figsize=(6, 4))
sns.boxplot(
    data=dataframe,
    palette=palette,
    flierprops={"marker": "x", "markerfacecolor": "black"},
)
plt.yscale("log")
plt.xlabel("Method", fontsize=16)
plt.ylabel("Relative $L^2$ error", fontsize=16)
plt.grid(axis="y", visible=True, which="both")

plt.tight_layout()
# savefig does not create missing directories, so make sure the target exists.
os.makedirs("./images", exist_ok=True)
plt.savefig("./images/boxplots_new_data_compare_method_geo_fno.pdf")
plt.show()

In [None]:
# Solve the problems of the convergence study with the standard FEM solver and
# store the resulting arrays, including the mesh sizes returned by the solver.
params = np.load("../main/data_test_compare_methods/params_convergence.npy")
print(params.shape)
solver = StdFemSolver(nb_cell=64 - 1, params=params, compare_methods=True)
U, XY, F, G, Phi, hh = solver.solve_several()

# exist_ok=True replaces the separate existence check with a single call that
# is safe to re-run.
os.makedirs("./data_compare_methods_convergence", exist_ok=True)
np.save("./data_compare_methods_convergence/XY_compare_methods_convergence.npy", XY)
np.save("./data_compare_methods_convergence/Phi_compare_methods_convergence.npy", Phi)
np.save("./data_compare_methods_convergence/F_compare_methods_convergence.npy", F)
np.save("./data_compare_methods_convergence/G_compare_methods_convergence.npy", G)
np.save("./data_compare_methods_convergence/hmax.npy", hh)

In [None]:
# Evaluate the Geo-FNO on the convergence-study samples: per-sample relative
# error against a reference FEM solution, and per-sample inference time.
data = DataLoader(False)

XY = np.load("./data_compare_methods_convergence/XY_compare_methods_convergence.npy")
Phi = np.load("./data_compare_methods_convergence/Phi_compare_methods_convergence.npy")
F = np.load("./data_compare_methods_convergence/F_compare_methods_convergence.npy")
G = np.load("./data_compare_methods_convergence/G_compare_methods_convergence.npy")
hh = np.load("./data_compare_methods_convergence/hmax.npy")
params = np.load("../main/data_test_compare_methods/params_convergence.npy")


training_agent = Agent(data)
model_fno = training_agent.model
model_iphi = training_agent.model_iphi
device = training_agent.device

# `map_location` keeps the restore working when the checkpoint was written on
# a different device than the one selected by the agent (e.g. GPU -> CPU).
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
best_model = torch.load("models/best_model.pkl", map_location=device)
model_fno.load_state_dict(best_model["model_state_dict"])
model_iphi.load_state_dict(best_model["model_iphi_state_dict"])
model_fno.eval()
model_iphi.eval()

F = F[:, :, None]
G = G[:, :, None]
xy = torch.tensor(XY, dtype=torch.float).to(device)
f = torch.tensor(F, dtype=torch.float).to(device)
g = torch.tensor(G, dtype=torch.float).to(device)

# Network input: coordinates followed by the F and G channels.
x = torch.cat([xy, f, g], dim=-1)

# Full-batch forward pass. Its result is not used by the loop below;
# presumably it acts as a warm-up so the first timed call does not include
# one-off initialisation costs — TODO confirm.
with torch.no_grad():
    y_pred = model_fno(x[:, :, :], iphi=model_iphi)

loss = LpLoss(size_average=False)

# CUDA kernels are launched asynchronously, so the device must be synchronised
# around the timed region for the wall-clock measurement to cover the actual
# computation.
on_cuda = x.is_cuda

# Host copy of the coordinates: the reference solution is evaluated point by
# point, so converting once avoids iterating over a device tensor.
xy_np = xy.cpu().numpy()

errors = []
times = []
standard_solver = StandardFEMSolver(params=params)
for i in range(len(params)):
    print(f"Param : {i + 1}/{len(params)}")
    u_ref, V_ref, dx_ref = standard_solver.solve_one(i, 0.001, reference_fem=True)

    # Reference values at the sample's points.
    y_true = np.array([u_ref(point) for point in xy_np[i]])
    y_true_torch = torch.tensor(y_true[None, :, None], dtype=torch.float).to(device)

    # Timed single-sample inference (monotonic clock, no autograd bookkeeping).
    if on_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.no_grad():
        y_pred_i = model_fno(x[i : i + 1], iphi=model_iphi)
    if on_cuda:
        torch.cuda.synchronize()
    end = time.perf_counter()

    error = loss(y_pred_i, y_true_torch).item()
    errors.append(error)
    times.append(end - start)
    # Rewritten after every sample, presumably so partial results survive an
    # interrupted run.
    np.save("./errors_geo_fno_convergence.npy", np.array(errors))
    np.save("./times_geo_fno_convergence.npy", np.array(times))

In [None]:
# Mesh sizes used on the x-axis of the convergence plot, followed by the
# precomputed error / timing arrays of every method.
sizes_phi_fem = [16, 32, 64, 128, 256]

# hmax of the unit-square mesh for each resolution. The same meshes serve both
# the standard FEM and the phi-FEM curves, so they are built only once.
hmax_per_size = [
    df.UnitSquareMesh(size - 1, size - 1).hmax() for size in sizes_phi_fem
]
sizes_std = list(hmax_per_size)
sizes_phi_fem_h = list(hmax_per_size)
# Values noted by the original author for this print:
# [0.09428090415820647, 0.045619792334616084, 0.02244783432338254, 0.01113553986120561, 0.005545935538718157]
print(sizes_std)

# Mesh size at which the FNO results are plotted.
size_fno = df.UnitSquareMesh(64 - 1, 64 - 1).hmax()


# Errors of phi-FEM, standard FEM and phi-FEM-FNO.
L2_error_phi_fem_array = np.load("../main/compare_methods/L2_error_phi_fem_array.npy")
L2_error_std_fem_array = np.load("../main/compare_methods/L2_error_std_fem_array.npy")
L2_error_fno_array = np.load("../main/compare_methods/L2_error_fno_array.npy")

# Timing components of phi-FEM.
Cell_selection_phi_array = np.load(
    "../main/compare_methods/Cell_selection_phi_array.npy"
)
Submesh_construction_phi_array = np.load(
    "../main/compare_methods/Submesh_construction_phi_array.npy"
)
Ghost_cell_selection_phi_array = np.load(
    "../main/compare_methods/Ghost_cell_selection_phi_array.npy"
)
Resolution_time_phi_array = np.load(
    "../main/compare_methods/Resolution_time_phi_array.npy"
)

# Timing components of standard FEM.
Construction_time_standard_array = np.load(
    "../main/compare_methods/Construction_time_standard_array.npy"
)
Resolution_time_standard_array = np.load(
    "../main/compare_methods/Resolution_time_standard_array.npy"
)
Time_fno_array = np.load("../main/compare_methods/Time_fno_array.npy")


# Errors and times of the FNO trained on standard FEM data.
L2_error_fno_std_array = np.load("../main/compare_methods/L2_error_fno_std_array.npy")
Time_fno_std_array = np.load("../main/compare_methods/Time_fno_std_array.npy")

# Geo-FNO results written by the evaluation cell of this notebook.
L2_error_geo_fno_array = np.load("./errors_geo_fno_convergence.npy")
times_geo_fno_array = np.load("./times_geo_fno_convergence.npy")
hh = np.load("./data_compare_methods_convergence/hmax.npy")

In [None]:
# Repeat every Geo-FNO value so the arrays can be reshaped like the other FNO
# result arrays (presumably one column per mesh size — TODO confirm).
# The arrays are rebuilt from the saved files rather than from their current
# in-memory value, so re-running this cell does not repeat them again.
n_repeats = 5
L2_error_geo_fno_array = np.repeat(
    np.load("./errors_geo_fno_convergence.npy"), n_repeats, axis=0
)
times_geo_fno_array = np.repeat(
    np.load("./times_geo_fno_convergence.npy"), n_repeats, axis=0
)

In [None]:
# Give the Geo-FNO arrays the same layout as the Std-FEM-FNO error array so
# the per-column statistics below apply to all methods alike.
target_shape = L2_error_fno_std_array.shape
L2_error_geo_fno_array = L2_error_geo_fno_array.reshape(target_shape)
times_geo_fno_array = times_geo_fno_array.reshape(target_shape)

In [None]:
# Error statistics over axis 0 of each result array, grouped per method.
mean_errors_L2_phi_fem = np.mean(L2_error_phi_fem_array, axis=0)
std_errors_L2_phi_fem = np.std(L2_error_phi_fem_array, axis=0)
min_errors_L2_phi_fem = np.min(L2_error_phi_fem_array, axis=0)
max_errors_L2_phi_fem = np.max(L2_error_phi_fem_array, axis=0)

mean_errors_L2_std_fem = np.mean(L2_error_std_fem_array, axis=0)
std_errors_L2_std_fem = np.std(L2_error_std_fem_array, axis=0)
min_errors_L2_std_fem = np.min(L2_error_std_fem_array, axis=0)
max_errors_L2_std_fem = np.max(L2_error_std_fem_array, axis=0)

mean_errors_L2_FNO = np.mean(L2_error_fno_array, axis=0)
std_errors_L2_FNO = np.std(L2_error_fno_array, axis=0)
min_errors_L2_FNO = np.min(L2_error_fno_array, axis=0)
max_errors_L2_FNO = np.max(L2_error_fno_array, axis=0)

mean_errors_L2_FNO_std = np.mean(L2_error_fno_std_array, axis=0)
std_errors_L2_FNO_std = np.std(L2_error_fno_std_array, axis=0)
min_errors_L2_FNO_std = np.min(L2_error_fno_std_array, axis=0)
max_errors_L2_FNO_std = np.max(L2_error_fno_std_array, axis=0)

mean_errors_L2_geo_fno = np.mean(L2_error_geo_fno_array, axis=0)
std_errors_L2_geo_fno = np.std(L2_error_geo_fno_array, axis=0)
min_errors_L2_geo_fno = np.min(L2_error_geo_fno_array, axis=0)
max_errors_L2_geo_fno = np.max(L2_error_geo_fno_array, axis=0)

# Four cubehelix colours plus one red entry inserted before the last colour.
palette = sns.cubehelix_palette(n_colors=4, start=0.25, rot=-0.25, gamma=0.5)
palette = palette.as_hex()
palette.insert(-1, "#b22222")
palette = sns.color_palette(palette)

center = [size_fno, mean_errors_L2_FNO[0]]
center_fno_std = [size_fno, mean_errors_L2_FNO_std[0]]


def _draw_convergence_curve(h_values, mean_values, spread, label):
    """Log-log curve of the mean error with a shaded mean +/- spread band."""
    plt.loglog(h_values, mean_values, "-+", markersize=8, label=label)
    plt.fill_between(
        h_values,
        mean_values - spread,
        mean_values + spread,
        alpha=0.2,
    )


def _draw_point_with_error_bar(h_value, mean_value, spread, label, color):
    """Single cross marker with a vertical error bar in the same colour."""
    plt.plot(h_value, mean_value, "x", markersize=8, label=label, color=color)
    plt.errorbar(
        h_value,
        mean_value,
        yerr=spread,
        ecolor=color,
        capsize=7,
        elinewidth=1.5,
        capthick=1.5,
        fmt="none",
        alpha=0.7,
    )


plt.figure(figsize=(6, 4))

# FEM solvers: one point per mesh size, band of one standard deviation.
ci_error_std_FEM = std_errors_L2_std_fem
_draw_convergence_curve(
    sizes_std, mean_errors_L2_std_fem, ci_error_std_FEM, "Std-FEM"
)
ci_error_phi_fem = std_errors_L2_phi_fem
_draw_convergence_curve(
    sizes_phi_fem_h, mean_errors_L2_phi_fem, ci_error_phi_fem, r"$\phi$-FEM"
)

# Neural operators: a single point each, taken from the last column of the
# statistics, with a one-standard-deviation error bar.
_draw_point_with_error_bar(
    size_fno,
    mean_errors_L2_FNO[-1],
    std_errors_L2_FNO[-1],
    r"$\phi$-FEM-FNO",
    palette[-2],
)
_draw_point_with_error_bar(
    size_fno,
    mean_errors_L2_FNO_std[-1],
    std_errors_L2_FNO_std[-1],
    r"Std-FEM-FNO",
    sns.color_palette("tab10")[4],
)
# Geo-FNO is placed at the mean of the saved mesh sizes.
_draw_point_with_error_bar(
    np.mean(hh),
    mean_errors_L2_geo_fno[-1],
    std_errors_L2_geo_fno[-1],
    r"Geo-FNO",
    sns.color_palette("tab10")[5],
)

plt.legend(fontsize=12, loc="lower right", ncol=1)
plt.xlabel("$h$", fontsize=16)
plt.ylabel("Relative $L^2$ error", fontsize=16)
plt.grid(axis="y", visible=True, which="both")
plt.ylim(1.3e-5, 7e-2)
plt.tight_layout()
plt.savefig("./error_size_5_methods.pdf")
plt.show()

In [None]:
# Total phi-FEM time: sum of the four measured stages.
# (The previous zero-filled pre-allocations were overwritten immediately and
# have been dropped.)
Total_time_phi_fem_array = (
    Cell_selection_phi_array
    + Submesh_construction_phi_array
    + Ghost_cell_selection_phi_array
    + Resolution_time_phi_array
)

# Total standard FEM time: construction stages plus resolution. The last three
# entries of the construction array are excluded from the sum
# (removes : total construction time, fenics read time and conversion time).
Total_time_std_fem_array = (
    np.sum(Construction_time_standard_array[:, :, :-3], axis=2)
    + Resolution_time_standard_array
)

In [None]:
# Timing statistics over axis 0 of each timing array, grouped per method.
mean_times_L2_phi_fem = np.mean(Total_time_phi_fem_array, axis=0)
std_times_L2_phi_fem = np.std(Total_time_phi_fem_array, axis=0)
min_times_L2_phi_fem = np.min(Total_time_phi_fem_array, axis=0)
max_times_L2_phi_fem = np.max(Total_time_phi_fem_array, axis=0)

mean_times_L2_std_fem = np.mean(Total_time_std_fem_array, axis=0)
std_times_L2_std_fem = np.std(Total_time_std_fem_array, axis=0)
min_times_L2_std_fem = np.min(Total_time_std_fem_array, axis=0)
max_times_L2_std_fem = np.max(Total_time_std_fem_array, axis=0)

mean_times_L2_FNO = np.mean(Time_fno_array, axis=0)
std_times_L2_FNO = np.std(Time_fno_array, axis=0)
min_times_L2_FNO = np.min(Time_fno_array, axis=0)
max_times_L2_FNO = np.max(Time_fno_array, axis=0)

mean_times_L2_FNO_std = np.mean(Time_fno_std_array, axis=0)
std_times_L2_FNO_std = np.std(Time_fno_std_array, axis=0)
min_times_L2_FNO_std = np.min(Time_fno_std_array, axis=0)
max_times_L2_FNO_std = np.max(Time_fno_std_array, axis=0)

mean_times_L2_geo_fno = np.mean(times_geo_fno_array, axis=0)
std_times_L2_geo_fno = np.std(times_geo_fno_array, axis=0)
min_times_L2_geo_fno = np.min(times_geo_fno_array, axis=0)
max_times_L2_geo_fno = np.max(times_geo_fno_array, axis=0)


def _draw_time_error_curve(axes, time_mean, time_spread, err_mean, err_spread, label, bar_color):
    """Log-log error/time curve with a vertical band and horizontal error bars."""
    axes.loglog(time_mean, err_mean, "-+", markersize=8, label=label)
    axes.fill_between(
        time_mean,
        err_mean - err_spread,
        err_mean + err_spread,
        alpha=0.2,
    )
    axes.errorbar(
        time_mean,
        err_mean,
        xerr=time_spread,
        ecolor=bar_color,
        capsize=7,
        elinewidth=1.5,
        capthick=1.5,
        fmt="none",
        alpha=0.7,
    )


def _draw_time_error_point(
    axes, time_mean, err_mean, time_samples, err_samples, label, marker_color, ellipse_color
):
    """Cross marker at the mean plus a confidence ellipse of the samples."""
    axes.plot(
        time_mean, err_mean, "x", markersize=8, label=label, color=marker_color
    )
    confidence_ellipse(
        time_samples,
        err_samples,
        axes,
        alpha=0.8,
        facecolor=ellipse_color,
        edgecolor=ellipse_color,
    )


fig, ax = plt.subplots(1, 1, figsize=(6, 4))

# FEM solvers: one point per mesh size.
ci_error_std_FEM = std_errors_L2_std_fem
ci_time_std_FEM = std_times_L2_std_fem
_draw_time_error_curve(
    ax,
    mean_times_L2_std_fem,
    ci_time_std_FEM,
    mean_errors_L2_std_fem,
    ci_error_std_FEM,
    "Std-FEM",
    sns.color_palette()[0],
)

ci_error_phi_FEM = std_errors_L2_phi_fem
ci_time_phi_FEM = std_times_L2_phi_fem
_draw_time_error_curve(
    ax,
    mean_times_L2_phi_fem,
    ci_time_phi_FEM,
    mean_errors_L2_phi_fem,
    ci_error_phi_FEM,
    r"$\phi$-FEM",
    sns.color_palette()[1],
)

# Neural operators: mean of the last column plus the ellipse of its samples.
_draw_time_error_point(
    ax,
    mean_times_L2_FNO[-1],
    mean_errors_L2_FNO[-1],
    Time_fno_array[:, -1],
    L2_error_fno_array[:, -1],
    r"$\phi$-FEM-FNO",
    palette[-2],
    sns.color_palette("pastel")[3],
)
_draw_time_error_point(
    ax,
    mean_times_L2_FNO_std[-1],
    mean_errors_L2_FNO_std[-1],
    Time_fno_std_array[:, -1],
    L2_error_fno_std_array[:, -1],
    r"Std-FEM-FNO",
    sns.color_palette("tab10")[4],
    sns.color_palette("pastel")[4],
)
_draw_time_error_point(
    ax,
    mean_times_L2_geo_fno[-1],
    mean_errors_L2_geo_fno[-1],
    times_geo_fno_array[:, -1],
    L2_error_geo_fno_array[:, -1],
    r"Geo-FNO",
    sns.color_palette("tab10")[5],
    sns.color_palette("pastel")[5],
)

ax.legend(fontsize=12, loc="lower left", ncol=2)
ax.set_xlabel("Computation time (s)", fontsize=16)
ax.set_ylabel("Relative $L^2$ error", fontsize=16)
plt.grid(axis="y", visible=True, which="both")
# Same vertical range as the error-vs-h figure.
plt.ylim(1.3e-5, 7e-2)

plt.grid(axis="x", visible=True, which="major")

plt.tight_layout()
plt.savefig("./error_time_5_methods.pdf")
plt.show()