In [None]:
import numpy as np
import matplotlib.pyplot as plt
import random
import os
import torch 

# Seed the Python, PyTorch and NumPy random number generators with one shared value.
seed = 2023
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
# NOTE(review): PYTHONHASHSEED is normally read when the interpreter starts, so
# setting it here presumably only reaches child processes, not this kernel — confirm.
os.environ["PYTHONHASHSEED"] = str(seed)
import dolfin as df
import time
from utils import *
from utils_training import *
import prepare_data
from utils_compare_methods import *
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable
from prepare_data import rotate, outside_ball
import pandas as pd
import gc

# Global seaborn styling shared by the figures of this notebook.
sns.set_theme()
sns.set_context("paper")
# Hex colour list ("mako") and a continuous colormap ("viridis") kept as module-level names.
colors = sns.color_palette("mako").as_hex()
my_cmap = sns.color_palette("viridis", as_cmap=True)

# Reload imported modules automatically before each cell execution (autoreload mode 2).
%load_ext autoreload
%autoreload 2


In [None]:
save_figs = True
small_data = False
data = DataLoader(small_data)

images_repo = "./images"


def _load_best_agent(level, models_repo):
    """Build an ``Agent`` for one loss level and restore its best checkpoint.

    Parameters
    ----------
    level : int
        Value forwarded to ``Agent(level=...)``. This notebook uses 2, 1 and 0
        for the checkpoints stored under the H2, H1 and l2 folders.
    models_repo : str
        Directory containing ``best_model.pkl``.

    Returns
    -------
    tuple
        ``(agent, checkpoint)`` where ``checkpoint`` is the object returned by
        ``torch.load`` (it is indexed with "model_state_dict" and "epoch").
    """
    agent = Agent(
        data,
        level=level,
        relative=True,
        squared=False,
        initial_lr=5e-3,
        n_modes=10,
        width=20,
        batch_size=32,
        pad_prop=0.05,
        pad_mode="reflect",
        l2_lambda=1e-3,
    )
    # map_location lets a checkpoint saved from a GPU be restored on a CPU-only
    # machine. NOTE(review): torch.load unpickles the file, only use trusted checkpoints.
    checkpoint = torch.load(
        f"{models_repo}/best_model.pkl", map_location=agent.device
    )
    agent.model.load_state_dict(checkpoint["model_state_dict"])
    agent.model.eval()  # inference mode for every evaluation below
    return agent, checkpoint


models_repo_H2 = "./models_H2"
agent_H2, best_model_H2 = _load_best_agent(2, models_repo_H2)
model_H2 = agent_H2.model
print(f"Best epoch H2 = {best_model_H2['epoch']}")

models_repo_H1 = "./models_H1"
agent_H1, best_model_H1 = _load_best_agent(1, models_repo_H1)
model_H1 = agent_H1.model

print(f"Best epoch H1 = {best_model_H1['epoch']}")

models_repo_L2 = "./models_l2"
agent_L2, best_model_L2 = _load_best_agent(0, models_repo_L2)
model_L2 = agent_L2.model

print(f"Best epoch l2 = {best_model_L2['epoch']}")

# Device used by the later cells; as before, it is the one of the last agent built.
device = agent_L2.device

if save_figs:
    os.makedirs(images_repo, exist_ok=True)

# With respect to a $\phi$-FEM solution

In [None]:
from generate_data import PhiFemSolver
from prepare_data import set_seed

# NOTE(review): presumably re-seeds the RNGs so the generated test parameters are
# reproducible — confirm against prepare_data.set_seed.
set_seed(16012023)

In [None]:
# Drop references to large arrays/tensors possibly left over from earlier runs,
# then release the memory they held.
X_test, Y_test, x_normed, Y_pred, X_denormed = None, None, None, None, None
X_val, Y_val = None, None
Phi, G, domain, U_true, U_pred = None, None, None, None, None
error, magnitude, error_inf, magnitude_inf = None, None, None, None
gc.collect()
torch.cuda.empty_cache()  # PyTorch thing

nb_test_data = 10000
test_data_dir = f"./data_test_phi_fem_{nb_test_data}"

# W.npy is written last, so its presence means the whole data set was generated.
# Testing the directory alone would treat an interrupted run as a complete cache.
if not os.path.exists(f"{test_data_dir}/W.npy"):
    os.makedirs(test_data_dir, exist_ok=True)
    F, Phi, G, params = create_FG_numpy(nb_test_data, 64)
    print("Parameters generated")
    # One phi-FEM solve per parameter set; these are the targets of the comparison.
    solver = PhiFemSolver(nb_cell=64 - 1, params=params)
    W_phi_fem = solver.solve_several()
    # The full test set is deliberately not moved to the device here: the
    # evaluation loop builds its input tensors batch by batch.

    np.save(f"{test_data_dir}/F.npy", F)
    np.save(f"{test_data_dir}/agentParams.npy", params)
    np.save(f"{test_data_dir}/Phi.npy", Phi)
    np.save(f"{test_data_dir}/G.npy", G)
    np.save(f"{test_data_dir}/W.npy", W_phi_fem)

In [None]:
# Forget any large objects from previous cells before loading the test set.
X_test = Y_test = x_normed = Y_pred = X_denormed = None
Phi = G = domain = U_true = U_pred = None
gc.collect()
torch.cuda.empty_cache()  # PyTorch thing

# Inputs (F, Phi, G) and phi-FEM outputs (W) saved by the generation cell.
phi_fem_dir = f"./data_test_phi_fem_{nb_test_data}"
F = np.load(f"{phi_fem_dir}/F.npy")
Phi = np.load(f"{phi_fem_dir}/Phi.npy")
G = np.load(f"{phi_fem_dir}/G.npy")
W_phi_fem = np.load(f"{phi_fem_dir}/W.npy")

In [None]:
def compute_L2_norm_squared(U, domain):
    """Return the mean of ``U**2`` over the entries selected by ``domain``.

    Both sums run over axes 1 and 2, so one value is returned per entry of
    axis 0. ``domain`` acts as a multiplicative mask (0/1 or boolean).
    """
    spatial_axes = (1, 2)
    vertex_count = domain.sum(dim=spatial_axes, keepdim=False)
    masked_energy = (U**2 * domain).sum(dim=spatial_axes, keepdim=False)
    return (1.0 / vertex_count) * masked_energy


def compute_Linf_norm(U, domain):
    """Return the largest absolute value of ``U * domain`` over axes 1 and 2.

    Entries where ``domain`` is zero/False contribute 0 to the maximum.
    """
    masked = U * domain
    return masked.abs().amax(dim=(1, 2))

In [None]:
def _batch_relative_L2_error(
    model, X_normed, Phi_batch, G_batch, U_true, domain, magnitude
):
    """Relative L2 error of one model on one batch, as a NumPy array.

    The network output is decoded with ``data.y_normalizer`` and turned into a
    solution with ``u = w * phi + g`` (same reconstruction as ``U_true``).
    ``magnitude`` is the squared L2 norm of ``U_true`` on ``domain``.
    One value per sample of the batch is returned.
    """
    Y_pred = data.y_normalizer.decode(model(X_normed))
    U_pred = Y_pred[:, 0, :, :] * Phi_batch + G_batch
    error = compute_L2_norm_squared(U_pred - U_true, domain)
    return torch.sqrt(error / magnitude).cpu().detach().numpy()


L2_error_H2, L2_error_H1, L2_error_L2 = [], [], []
test_batch_size = 100
nb_test_batch = nb_test_data // test_batch_size
print(f"{test_batch_size=}")
print(f"{nb_test_batch=}")

model_H2.eval()
model_H1.eval()
model_L2.eval()

# Pure inference: without no_grad every forward pass would record an autograd
# graph, which is what made the manual memory clean-up necessary.
with torch.no_grad():
    for j in range(nb_test_batch):
        print(f"Batch : {j} / {nb_test_batch}")

        sli = slice(j * test_batch_size, (j + 1) * test_batch_size)
        X_test = generate_manual_new_data_numpy(F[sli], Phi[sli], G[sli]).to(device)
        X_test_normed_j = data.x_normalizer.encode(X_test)
        # Channels 1 and 2 of the network input hold phi and g.
        Phi_batch, G_batch = X_test[:, 1, :, :], X_test[:, 2, :, :]
        # Mask of the vertices with phi <= 3e-16.
        # NOTE(review): presumably a float tolerance around phi = 0 — confirm.
        domain = (Phi_batch <= 3e-16).to(device)
        Y_test_batch = torch.tensor(W_phi_fem[sli, None, :, :]).to(device)
        U_true = Y_test_batch[:, 0, :, :] * Phi_batch + G_batch
        magnitude = compute_L2_norm_squared(U_true, domain)

        # Same evaluation order as before: H2, then H1, then L2.
        for fno, batch_errors in (
            (model_H2, L2_error_H2),
            (model_H1, L2_error_H1),
            (model_L2, L2_error_L2),
        ):
            batch_errors.append(
                _batch_relative_L2_error(
                    fno,
                    X_test_normed_j,
                    Phi_batch,
                    G_batch,
                    U_true,
                    domain,
                    magnitude,
                )
            )

        gc.collect()
        torch.cuda.empty_cache()  # PyTorch thing

# One flat array of per-sample relative errors for each loss.
L2_error_H2 = np.array(L2_error_H2).flatten()
L2_error_H1 = np.array(L2_error_H1).flatten()
L2_error_L2 = np.array(L2_error_L2).flatten()

In [None]:
# One row of errors per loss, in the same order as the column labels below.
error_tab = [L2_error_L2, L2_error_H1, L2_error_H2]

abs_str = [
    r"FNO $\mathcal{L}_{L_2}$ loss",
    r"FNO $\mathcal{L}_{H_1}$ loss",
    r"FNO $\mathcal{L}$ loss",
]
errors = np.array(error_tab)
errors = errors.reshape((3, errors.shape[-1]))
print(errors.shape)
# Samples as rows, one column per loss.
dataframe = pd.DataFrame(errors.T, columns=abs_str)

sns.set(font_scale=1.1)

fig, ax = plt.subplots(figsize=(6, 4))
sns.boxplot(
    data=dataframe,
    palette="ch:s=.0,rot=0.0,dark=0.5",
    flierprops={"marker": "x", "markerfacecolor": "black"},
    ax=ax,
)
ax.set_yscale("log")
ax.set_xlabel("Method", fontsize=16)
ax.set_ylabel("Relative $L^2$ error", fontsize=16)
ax.grid(axis="y", visible=True, which="both")

fig.tight_layout()
if save_figs:
    fig.savefig(f"{images_repo}/boxplots_compare_losses_L2_phi_FEM.pdf")
plt.show()

In [None]:
# Summary statistics of the relative L2 errors, one column per training loss.
dataframe.describe()

# With respect to a reference solution

In [None]:
def _relative_l2_error_to_reference(model, X, x_normed, u_ref, V_ref, dx_ref, ref_norm):
    """Relative L2 error of one FNO prediction against the reference solution.

    The prediction is rebuilt as ``phi * w + g`` from channels 1 and 2 of ``X``,
    converted to a FEniCS function, projected on ``V_ref`` and compared with
    ``u_ref`` using the measure ``dx_ref``. ``ref_norm`` is the L2 norm of
    ``u_ref``, computed once by the caller.
    """
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        Y = data.y_normalizer.decode(model(x_normed))
        predicted_solution = (
            (X[:, 1, :, :] * Y[:, 0, :, :] + X[:, 2, :, :]).cpu().detach().numpy()
        )
    predicted_solution = np.reshape(predicted_solution, (64, 64))
    predicted_sol_fenics = convert_numpy_matrix_to_fenics(predicted_solution, 64, 1)

    predicted_sol_fenics_proj_V_ref = df.project(
        predicted_sol_fenics,
        V_ref,
        solver_type="gmres",
        preconditioner_type="hypre_amg",
    )

    error_norm = df.assemble(
        ((u_ref - predicted_sol_fenics_proj_V_ref) ** 2) * dx_ref
    ) ** (0.5)
    return error_norm / ref_norm


def compare_loss_levels(param):
    """Compare the three trained FNOs with a reference FEM solution.

    Parameters
    ----------
    param : array-like
        Container whose first row holds, in this order: ``mu0, mu1, sigma_x,
        sigma_y, amplitude, x_0, y_0, lx, ly, theta, alpha, beta``.

    Returns
    -------
    tuple
        ``(l2_error_fno_H2, l2_error_fno_H1, l2_error_fno_L2)``: relative L2
        errors of ``model_H2``, ``model_H1`` and ``model_L2``. The three values
        are also printed.
    """
    standard_solver = StandardFEMSolver(params=param)
    mu0, mu1, sigma_x, sigma_y, amplitude, x_0, y_0, lx, ly, theta, alpha, beta = param[
        0
    ]
    # NOTE(review): 0 is presumably the index of the parameter set and 0.002 the
    # mesh size of the reference solve — confirm against StandardFEMSolver.solve_one.
    u_ref, V_ref, dx_ref = standard_solver.solve_one(0, 0.002, reference_fem=True)

    # Network input built on a 64 x 64 grid from the same parameters.
    F = generate_F_numpy(mu0, mu1, sigma_x, sigma_y, amplitude, 64)
    G = generate_G_numpy(alpha, beta, 64)
    phi = generate_phi_numpy(x_0, y_0, lx, ly, theta, 64)
    X = generate_manual_new_data_numpy(F, phi, G).to(device)
    x_normed = data.x_normalizer.encode(X)

    # The norm of the reference solution is shared by the three relative errors,
    # so it is assembled once instead of once per model.
    ref_norm = df.assemble((((u_ref)) ** 2) * dx_ref) ** (0.5)

    l2_error_fno_H2 = _relative_l2_error_to_reference(
        model_H2, X, x_normed, u_ref, V_ref, dx_ref, ref_norm
    )
    l2_error_fno_H1 = _relative_l2_error_to_reference(
        model_H1, X, x_normed, u_ref, V_ref, dx_ref, ref_norm
    )
    l2_error_fno_L2 = _relative_l2_error_to_reference(
        model_L2, X, x_normed, u_ref, V_ref, dx_ref, ref_norm
    )

    print(f"{l2_error_fno_H2=:.3e}")
    print(f"{l2_error_fno_H1=:.3e}")
    print(f"{l2_error_fno_L2=:.3e}")

    return l2_error_fno_H2, l2_error_fno_H1, l2_error_fno_L2

In [None]:
compare_methods_dir = "./data_test_compare_methods"

# params.npy is written last, so its presence means the data set is complete.
# Testing the directory alone would treat an interrupted run as a valid cache.
if not os.path.exists(f"{compare_methods_dir}/params.npy"):
    os.makedirs(compare_methods_dir, exist_ok=True)
    # Bind the same names (F, Phi, G, params) as the loading branch below, so
    # that Phi never keeps a stale value from an earlier cell.
    F, Phi, G, params = create_FG_numpy(300, 64)
    np.save(f"{compare_methods_dir}/F.npy", F)
    np.save(f"{compare_methods_dir}/Phi.npy", Phi)
    np.save(f"{compare_methods_dir}/G.npy", G)
    np.save(f"{compare_methods_dir}/params.npy", params)
else:
    F = np.load(f"{compare_methods_dir}/F.npy")
    G = np.load(f"{compare_methods_dir}/G.npy")
    Phi = np.load(f"{compare_methods_dir}/Phi.npy")
    params = np.load(f"{compare_methods_dir}/params.npy")

# Output folder for the errors computed against the reference solution.
os.makedirs("./compare_losses", exist_ok=True)

In [None]:
L2_error_fno_H2, L2_error_fno_H1, L2_error_fno_L2 = [], [], []
indices = list(range(len(params)))
for index in indices:
    print(f"Iter : {index+1}/{len(params)}")
    single_param = np.array([params[index]])
    l2_error_fno_H2, l2_error_fno_H1, l2_error_fno_L2 = compare_loss_levels(
        single_param
    )
    L2_error_fno_H2.append(l2_error_fno_H2)
    L2_error_fno_H1.append(l2_error_fno_H1)
    L2_error_fno_L2.append(l2_error_fno_L2)
    # Rewrite the three result files after every solve, so the errors gathered
    # so far are on disk if the loop is interrupted.
    for out_path, collected in (
        ("./compare_losses/L2_error_fno_H2.npy", L2_error_fno_H2),
        ("./compare_losses/L2_error_fno_H1.npy", L2_error_fno_H1),
        ("./compare_losses/L2_error_fno_L2.npy", L2_error_fno_L2),
    ):
        np.save(out_path, np.array([collected]))

In [None]:
# Errors of the two FEM solvers are read from the H2 model folder; the FNO
# errors are the files written by the comparison loop.
L2_error_phi_fem = np.load(f"{models_repo_H2}/L2_error_phi_fem.npy")
L2_error_std_fem = np.load(f"{models_repo_H2}/L2_error_std_fem.npy")
L2_error_fno_H2 = np.load("compare_losses/L2_error_fno_H2.npy")
L2_error_fno_H1 = np.load("compare_losses/L2_error_fno_H1.npy")
L2_error_fno_L2 = np.load("compare_losses/L2_error_fno_L2.npy")

# Shape check of every loaded array before stacking them.
for loaded in (
    L2_error_fno_L2,
    L2_error_fno_H1,
    L2_error_fno_H2,
    L2_error_std_fem,
    L2_error_phi_fem,
):
    print(loaded.shape)

In [None]:
# One row of errors per method, in the same order as the column labels below.
error_tab = [
    L2_error_phi_fem,
    L2_error_std_fem,
    L2_error_fno_L2,
    L2_error_fno_H1,
    L2_error_fno_H2,
]

abs_str = [
    r"$\phi$-FEM",
    "Std FEM",
    r"FNO $\mathcal{L}_0$",
    r"FNO $\mathcal{L}_{H_1}$",
    r"FNO $\mathcal{L}$",
]
errors = np.array(error_tab)
errors = errors.reshape((5, errors.shape[-1]))
print(errors.shape)
# Samples as rows, one column per method.
dataframe = pd.DataFrame(errors.T, columns=abs_str)

sns.set(font_scale=1.1)

fig, ax = plt.subplots(figsize=(6, 4))
sns.boxplot(
    data=dataframe,
    palette="ch:s=.0,rot=0.0,dark=0.5",
    flierprops={"marker": "x", "markerfacecolor": "black"},
    ax=ax,
)
ax.set_yscale("log")
ax.set_xlabel("Method", fontsize=16)
ax.set_ylabel("Relative $L^2$ error", fontsize=16)
ax.grid(axis="y", visible=True, which="both")

fig.tight_layout()
if save_figs:
    fig.savefig(f"{images_repo}/boxplots_compare_losses.pdf")
plt.show()

In [None]:
# Summary statistics of the relative L2 errors, one column per method.
dataframe.describe()