In [1]:
import numpy as np
import matplotlib.pyplot as plt
import random
import os
import torch 
import vtk
import pyvista
# Render PyVista scenes inline as static images; use the 'trame' backend
# instead when interactive 3-D views are wanted.
pyvista.global_theme.notebook = True
pyvista.set_jupyter_backend('static')

# Seed every random number generator used in the notebook for reproducibility.
seed = 160318
for _seed_rng in (random.seed, torch.manual_seed, np.random.seed):
    _seed_rng(seed)
# NOTE: PYTHONHASHSEED is read when an interpreter starts, so setting it here
# only affects child processes, not str/bytes hashing in this running kernel.
os.environ["PYTHONHASHSEED"] = str(seed)
import time
from utils import *
from utils_training import *
import prepare_data
from utils_compare_methods import *
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable
from prepare_data import rotate, outside_ball, create_parameters
import pandas as pd
import gc
from utils_plot import * 
sns.set_theme("paper", rc={"xtick.bottom": True, "ytick.left": True}, font_scale=1.1)
colors = sns.color_palette("mako").as_hex()
my_cmap = sns.color_palette("viridis", as_cmap=True)

%load_ext autoreload
%autoreload 2


In [None]:
save_figs = True
small_data = False
data = DataLoader(small_data)
agent = Agent(data, level=1)

model = agent.model
device = agent.device

models_repo = "./models"
images_repo = "../images"

# Map the checkpoint tensors onto the agent's device so a checkpoint saved on
# a GPU can be restored on a CPU-only machine (and vice versa).
# NOTE: torch.load unpickles the file — only load trusted checkpoints.
best_model = torch.load(f"{models_repo}/best_model.pkl", map_location=device)
model.load_state_dict(best_model["model_state_dict"])
model.eval()  # inference mode for the rest of the notebook

print(f"Best epoch = {best_model['epoch']}")
# The figure output directory is only needed when figures are saved.
if save_figs:
    os.makedirs(images_repo, exist_ok=True)

In [3]:
def error_L2(uh, u_ex, degree_raise=3):
    """Relative L2 error ``||uh - u_ex|| / ||u_ex||``.

    Both functions are interpolated into a space on ``uh``'s mesh that has the
    same element family as ``uh`` and a degree raised by ``degree_raise``.
    The squared integrals are summed over all MPI ranks of the mesh
    communicator before the ratio is taken.

    Args:
        uh: discrete solution (a dolfinx Function).
        u_ex: reference / exact solution, anything ``Function.interpolate``
            accepts on that space.
        degree_raise: how many degrees to add to ``uh``'s element degree.

    Returns:
        The relative L2 error as a float.
    """
    element = uh.function_space.ufl_element()
    mesh = uh.function_space.mesh
    W = dolfinx.fem.functionspace(
        mesh, (element.family_name, element.degree + degree_raise)
    )

    def _interpolate_into_W(source):
        # Fresh Function on W holding the interpolant of ``source``.
        target = dolfinx.fem.Function(W)
        target.interpolate(source)
        return target

    def _global_integral(integrand):
        # Assemble on this rank, then sum the contributions of every rank.
        local_value = dolfinx.fem.assemble_scalar(
            dolfinx.fem.form(integrand * ufl.dx)
        )
        return mesh.comm.allreduce(local_value, op=MPI.SUM)

    uh_on_W = _interpolate_into_W(uh)
    u_ex_on_W = _interpolate_into_W(u_ex)

    difference = dolfinx.fem.Function(W)
    difference.x.array[:] = uh_on_W.x.array - u_ex_on_W.x.array

    squared_error = _global_integral(ufl.inner(difference, difference))
    squared_norm = _global_integral(ufl.inner(u_ex_on_W, u_ex_on_W))
    return np.sqrt(squared_error / squared_norm)


def non_matching_interpolation(uh, V_to, padding=1e-14):
    """Interpolate ``uh`` into ``V_to``, whose mesh need not match ``uh``'s mesh.

    Args:
        uh: source dolfinx Function.
        V_to: target function space.
        padding: forwarded to
            ``dolfinx.fem.create_nonmatching_meshes_interpolation_data``.

    Returns:
        A new Function on ``V_to`` whose ghost entries have been updated
        (``scatter_forward``).
    """
    result = dolfinx.fem.Function(V_to)
    target_space = result.function_space

    interpolation_data = dolfinx.fem.create_nonmatching_meshes_interpolation_data(
        target_space.mesh,
        target_space.element,
        uh.function_space.mesh,
        padding=padding,
    )
    result.interpolate(uh, nmm_interpolation_data=interpolation_data)
    result.x.scatter_forward()
    return result


def compute_domain(X):
    """Boolean mask built from channel 1 of ``X`` and its neighbourhood.

    A point is flagged when ``X[:, 1, :, :] <= 3e-16``. Channel 1 is presumably
    the level-set function (TODO confirm). The result is nonzero wherever
    either that flag or ``Loss().neighborhood_6`` of the flag is nonzero. This
    is the sum of the two masks cast to int, assumed non-negative.
    Every intermediate is moved to the global ``device``.
    """
    loss = Loss()
    flagged = (X[:, 1, :, :] <= 3e-16).to(device)
    around_flagged = loss.neighborhood_6(flagged).to(device)
    hit_count = around_flagged.int() + flagged.int()
    return (hit_count != 0).to(device)

# Convergence curves: errors and computation times of phi-FEM, standard FEM and the FNO

In [4]:
def convergence_fems(param, size_std, size_phi_fem, u_ref, V_ref, dx_ref):
    """Relative L2 errors of degree-2 phi-FEM and standard FEM vs. a reference.

    Each discrete solution is transferred onto the reference space ``V_ref``
    with ``non_matching_interpolation``. It is then compared with ``u_ref``
    through ``error_L2`` (no degree raise).

    Args:
        param: parameter array forwarded to both solvers as ``params``.
        size_std: mesh-size argument forwarded to the standard FEM solver.
        size_phi_fem: phi-FEM grid size; the solver is built with
            ``nb_cell=size_phi_fem - 1``.
        u_ref: reference solution.
        V_ref: function space of ``u_ref``.
        dx_ref: unused; kept so existing callers keep working.

    Returns:
        ``(phi-FEM relative error, standard FEM relative error,
        mesh size reported by the standard FEM solver)``.
    """
    phi_fem_solver = PhiFemSolver_error(nb_cell=size_phi_fem - 1, params=param)
    u_phi_fem, _, _ = phi_fem_solver.solve_one(0, force_deg2=True)
    phi_fem_error = error_L2(
        non_matching_interpolation(u_phi_fem, V_ref), u_ref, degree_raise=0
    )

    std_solver = StandardFEMSolver(params=param)
    u_std, h_std, _ = std_solver.solve_one(
        0, size_std, size_phi_fem, False, force_deg2=True
    )
    # A much larger padding than the 1e-14 default is used for the standard
    # FEM mesh. NOTE(review): presumably because its boundary does not coincide
    # exactly with the reference mesh — confirm.
    std_fem_error = error_L2(
        non_matching_interpolation(u_std, V_ref, 1e-3), u_ref, degree_raise=0
    )

    return phi_fem_error, std_fem_error, h_std

In [5]:
# Cached problem data for the convergence study. It is generated once with
# create_parameters(5, 128) and reloaded on later runs.
data_convergence_dir = "../../data_convergence_fems"
convergence_paths = {
    name: f"{data_convergence_dir}/{name}_convergence.npy"
    for name in ("F", "Phi", "G", "params")
}

# Regenerate when ANY cached file is missing. A partially written cache is then
# rebuilt instead of making one of the np.load calls below fail.
if not all(os.path.exists(path) for path in convergence_paths.values()):
    F, phi, G, params = create_parameters(5, 128)
    os.makedirs(data_convergence_dir, exist_ok=True)
    for path, array in zip(convergence_paths.values(), (F, phi, G, params)):
        np.save(path, array)

F = np.load(convergence_paths["F"])
phi = np.load(convergence_paths["Phi"])
G = np.load(convergence_paths["G"])
params = np.load(convergence_paths["params"])

In [None]:
# Largest cell diameter h of each phi-FEM background grid on the unit square.
# These are the target mesh sizes used for the standard FEM meshes.
sizes_phi_fem = [16, 32, 64, 128]
sizes_std = []
for size in sizes_phi_fem:
    mesh_macro = dolfinx.mesh.create_rectangle(
        MPI.COMM_WORLD,
        np.array([[0, 0], [1, 1]]),
        np.array([size - 1, size - 1]),
    )
    tdim = mesh_macro.topology.dim
    cell_map = mesh_macro.topology.index_map(tdim)
    # Owned + ghost cells on this rank.
    num_cells = cell_map.size_local + cell_map.num_ghosts
    sizes_std.append(max(mesh_macro.h(tdim, np.arange(num_cells))))
print(sizes_std)
# Expected (serial run):
# [0.09428090415820647, 0.045619792334616084, 0.02244783432338254, 0.01113553986120561]

In [None]:
# Per-parameter error curves: one list per parameter sample, one entry per
# mesh size.
L2_error_phi_fem_deg_2_array = []
L2_error_std_fem_deg_2_array = []
Sizes_std_meshes_deg_2_array = []

# Reference solutions, one per parameter sample.
# NOTE(review): every reference is computed on a very fine mesh and all of them
# stay referenced by these lists, which can be memory-heavy.
u_refs, V_refs, dx_refs = [], [], []

# Loop-invariant discretisations, defined once instead of on every iteration.
# The standard FEM sizes are the max cell diameters of the phi-FEM grids
# (the values printed by the sizes_std cell).
sizes_phi_fem = [16, 32, 64, 128]
sizes_std = [
    0.09428090415820647,
    0.045619792334616084,
    0.02244783432338254,
    0.01113553986120561,
]
# Requested mesh size of the fine standard FEM reference solution.
reference_mesh_size = 0.0009

for index, param in enumerate(params):
    standard_solver = StandardFEMSolver(params=np.array([param]))
    u_ref, V_ref, dx_ref, _h_ref = standard_solver.solve_one(
        0, reference_mesh_size, reference_fem=True
    )
    u_refs.append(u_ref)
    V_refs.append(V_ref)
    dx_refs.append(dx_ref)
    print(f"Desired res = {reference_mesh_size}. {_h_ref=}")

    L2_error_phi_fem_deg_2, L2_error_std_fem_deg_2, Size_std_mesh_deg_2 = [], [], []

    for size_phi_fem, size_std in zip(sizes_phi_fem, sizes_std):
        # 1-based counter so the last sample is reported as N/N.
        print(f"Param : {index + 1}/{len(params)}    {size_phi_fem=}")
        (
            l2_error_phi_fem_deg_2,
            l2_error_std_fem_deg_2,
            size_std_mesh_deg_2,
        ) = convergence_fems(
            np.array([param]),
            size_std,
            size_phi_fem,
            u_ref,
            V_ref,
            dx_ref,
        )

        L2_error_phi_fem_deg_2.append(l2_error_phi_fem_deg_2)
        L2_error_std_fem_deg_2.append(l2_error_std_fem_deg_2)
        Size_std_mesh_deg_2.append(size_std_mesh_deg_2)

    L2_error_phi_fem_deg_2_array.append(L2_error_phi_fem_deg_2)
    L2_error_std_fem_deg_2_array.append(L2_error_std_fem_deg_2)
    Sizes_std_meshes_deg_2_array.append(Size_std_mesh_deg_2)

In [12]:
# Stack the per-parameter lists into arrays (one row per parameter sample,
# one column per mesh size) and persist them.
(
    L2_error_phi_fem_deg_2_array,
    L2_error_std_fem_deg_2_array,
    Sizes_std_meshes_deg_2_array,
) = map(
    np.array,
    (
        L2_error_phi_fem_deg_2_array,
        L2_error_std_fem_deg_2_array,
        Sizes_std_meshes_deg_2_array,
    ),
)

convergence_dir = "../convergence_fems"
if not os.path.exists(f"{convergence_dir}/"):
    os.makedirs(convergence_dir)

for result_name, result_array in (
    ("L2_error_phi_fem_deg_2_array", L2_error_phi_fem_deg_2_array),
    ("L2_error_std_fem_deg_2_array", L2_error_std_fem_deg_2_array),
    ("Sizes_std_meshes_deg_2_array", Sizes_std_meshes_deg_2_array),
):
    np.save(f"{convergence_dir}/{result_name}.npy", result_array)

In [4]:
# Reload the saved convergence results, so the plots below can be produced
# without re-running the solvers.
convergence_dir = "../convergence_fems"
(
    L2_error_phi_fem_deg_2_array,
    L2_error_std_fem_deg_2_array,
    Sizes_std_meshes_deg_2_array,
) = (
    np.load(f"{convergence_dir}/{result_name}.npy")
    for result_name in (
        "L2_error_phi_fem_deg_2_array",
        "L2_error_std_fem_deg_2_array",
        "Sizes_std_meshes_deg_2_array",
    )
)

In [5]:
# Largest cell diameter h of each phi-FEM background grid on the unit square,
# used as the x-coordinate of the phi-FEM convergence curve.
sizes_phi_fem_h = []
sizes_phi_fem = [16, 32, 64, 128]
for size in sizes_phi_fem:
    mesh_macro = dolfinx.mesh.create_rectangle(
        MPI.COMM_WORLD,
        np.array([[0, 0], [1, 1]]),
        np.array([size - 1, size - 1]),
    )
    tdim = mesh_macro.topology.dim
    cell_map = mesh_macro.topology.index_map(tdim)
    # Owned + ghost cells on this rank.
    num_cells = cell_map.size_local + cell_map.num_ghosts
    sizes_phi_fem_h.append(max(mesh_macro.h(tdim, np.arange(num_cells))))

In [None]:
# Statistics over parameter samples (axis 0): one value per mesh size.
mean_errors_L2_phi_fem_deg_2 = np.mean(L2_error_phi_fem_deg_2_array, axis=0)
mean_errors_L2_std_fem_deg_2 = np.mean(L2_error_std_fem_deg_2_array, axis=0)

std_errors_L2_phi_fem_deg_2 = np.std(L2_error_phi_fem_deg_2_array, axis=0)
std_errors_L2_std_fem_deg_2 = np.std(L2_error_std_fem_deg_2_array, axis=0)

min_errors_L2_phi_fem_deg_2 = np.min(L2_error_phi_fem_deg_2_array, axis=0)
min_errors_L2_std_fem_deg_2 = np.min(L2_error_std_fem_deg_2_array, axis=0)

max_errors_L2_phi_fem_deg_2 = np.max(L2_error_phi_fem_deg_2_array, axis=0)
max_errors_L2_std_fem_deg_2 = np.max(L2_error_std_fem_deg_2_array, axis=0)

mean_sizes_std_fem_deg_2 = np.mean(Sizes_std_meshes_deg_2_array, axis=0)

palette = sns.color_palette("Paired")

# Half-width of the shaded bands: one standard deviation across samples.
ci_error_std_FEM_deg_2 = std_errors_L2_std_fem_deg_2
ci_error_phi_fem_deg_2 = std_errors_L2_phi_fem_deg_2


def _plot_mean_band(ax, h, mean, half_width, floor, fmt, color, label):
    """Plot a mean error curve on log-log axes with a shaded mean +/- half_width band.

    The lower edge of the band is clipped at ``floor`` (the smallest observed
    error). On log axes a non-positive ``mean - half_width`` would otherwise be
    pushed to the bottom of the plot and distort the band.
    """
    ax.loglog(h, mean, fmt, markersize=12, color=color, label=label)
    ax.fill_between(
        h,
        np.maximum(mean - half_width, floor),
        mean + half_width,
        alpha=0.2,
        color=color,
    )


fig, ax = plt.subplots(figsize=(6, 6))

_plot_mean_band(
    ax,
    mean_sizes_std_fem_deg_2,
    mean_errors_L2_std_fem_deg_2,
    ci_error_std_FEM_deg_2,
    min_errors_L2_std_fem_deg_2,
    "-2",
    palette[7],
    "Std-FEM",
)
_plot_mean_band(
    ax,
    sizes_phi_fem_h,
    mean_errors_L2_phi_fem_deg_2,
    ci_error_phi_fem_deg_2,
    min_errors_L2_phi_fem_deg_2,
    "-1",
    palette[1],
    r"$\varphi$-FEM",
)

# Reference slope for second-order convergence.
ax.loglog(
    sizes_phi_fem_h,
    [s**2 for s in sizes_phi_fem_h],
    "--",
    label=r"$\mathcal{O}(h^2)$",
)

ax.legend(fontsize=18, loc="lower right", ncol=1)
ax.set_xlabel(r"$h$", fontsize=20)
ax.set_ylabel(r"Relative $L^2$ error", fontsize=20)
ax.grid(axis="y", visible=True, which="both")

fig.tight_layout()
# Only write the figure when saving is enabled: the output directory is only
# created in that case.
if save_figs:
    fig.savefig(f"{images_repo}/convergence_fems.pdf")
plt.show()