In [1]:
import numpy as np
import matplotlib.pyplot as plt
import random
import os

# One seed shared by Python's `random`, NumPy and (further down) PyTorch.
seed = 251024
random.seed(seed)
np.random.seed(seed)
import time
# NOTE(review): names used later without an explicit import (e.g. DataLoader,
# Agent, Loss, create_parameters, generate_manual_new_data_numpy) presumably come
# from these wildcard imports -- confirm which module provides each one.
from utils import *
from utils_training import *
from prepare_data import *
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable
import pandas as pd
import torch
import gc

# Seed PyTorch with the same value as `random` and NumPy above.
torch.manual_seed(seed)

# Plot styling: seaborn "paper" context with bottom/left tick marks enabled.
sns.set_context("paper", rc={"xtick.bottom": True, "ytick.left": True})
# Hex colour list from the "mako" palette.
colors = sns.color_palette("mako").as_hex()
# Colormap object passed as `cmap=` to the contour plots below.
my_cmap = sns.color_palette("viridis", as_cmap=True)

%load_ext autoreload
%autoreload 2

# Initialization 

We create the data loader and the agent, then build the list of training epochs whose saved models are compared below (the last entry is the epoch of the best model).

In [None]:
# Notebook-wide switches.
save_figs = True  # write figures to `images_repo` in addition to displaying them
small_data = False  # forwarded to DataLoader
data = DataLoader(small_data)
agent = Agent(data, level=1)
model = agent.model
device = agent.device

models_repo = "./models"
images_repo = "../images"
# `map_location` lets a checkpoint saved on another device (e.g. a GPU) be
# opened on whatever device the agent is actually using.
best_model = torch.load(
    f"{models_repo}/best_model.pkl", weights_only=True, map_location=device
)

print(f"Best epoch = {best_model['epoch']}")
if save_figs:
    # exist_ok avoids the separate existence check.
    os.makedirs(f"{images_repo}/", exist_ok=True)

# Epochs whose checkpoints are compared; the last entry is the best model's epoch.
epochs = [50, 100, 500, 1000, 1250, 1500, 1750, 2000, best_model["epoch"]]
print(len(epochs))
print(epochs)

indices = list(range(len(epochs)))

# Minimal error on the validation sample

# Replace minimal by median 

We compute the $L^2$ relative error of each prediction on the validation sample and select the sample whose error is the median. For that sample, we plot the data ($f$ and $g$), the prediction, the ground truth and the absolute difference between the two.

In [3]:
def compute_L2_norm_squared(U, domain):
    """Per-sample mean of ``U**2`` over the vertices selected by ``domain``.

    Both arguments are reduced over axes 1 and 2, so they need a leading
    sample axis followed by two grid axes.

    Args:
        U: tensor of field values.
        domain: mask with the same layout as ``U``; entries equal to zero
            (or False) exclude a vertex from the sum.

    Returns:
        One value per sample: the masked sum of squares divided by the
        number of selected vertices.
    """
    grid_axes = (1, 2)
    vertex_count = domain.sum(dim=grid_axes)
    masked_square_sum = (U.pow(2) * domain).sum(dim=grid_axes)
    return (1.0 / vertex_count) * masked_square_sum


def compute_Linf_norm(U, domain):
    """Per-sample maximum of ``|U|`` restricted to ``domain``.

    Values outside the mask are multiplied by zero before the maximum is
    taken over axes 1 and 2.

    Args:
        U: tensor of field values with a leading sample axis.
        domain: mask with the same layout as ``U``.

    Returns:
        One value per sample.
    """
    masked_magnitude = (U * domain).abs()
    return masked_magnitude.amax(dim=(1, 2))


def compute_domain(X):
    """Build the boolean evaluation mask from channel 1 of ``X``.

    A vertex is kept when channel 1 is at most ``3e-16`` or when
    ``Loss.neighborhood_6`` flags it for that thresholded mask.

    Args:
        X: 4-D tensor indexed as ``X[:, 1, :, :]``; channel 1 is the one the
            callers in this notebook name ``Phi``.

    Returns:
        Boolean tensor on ``device`` with the channel axis removed.
    """
    # NOTE(review): 3e-16 presumably absorbs round-off around Phi == 0, and
    # neighborhood_6 presumably marks vertices adjacent to the mask -- confirm.
    stencil_provider = Loss()
    below_threshold = (X[:, 1, :, :] <= 3e-16).to(device)
    flagged_neighbours = stencil_provider.neighborhood_6(below_threshold).to(device)
    combined = flagged_neighbours.int() + below_threshold.int()
    return (combined != 0).to(device)

In [None]:
# Drop references left over from earlier cells so that the garbage collector
# and the CUDA caching allocator can actually release the memory.
model = None  # clear memory
X_val = Y_val = x_normed = Y_pred = Y_pred_normed = X_val_normed = None
Phi = G = domain = U_true = U_pred = None
error = magnitude = error_inf = magnitude_inf = None
gc.collect()
torch.cuda.empty_cache()  # PyTorch thing


def _plot_field(ax, grid, field, title, tick_format, nb_ticks=5):
    """Draw one filled-contour panel with a horizontal colorbar underneath.

    Colorbar ticks are spread evenly between the NaN-aware min and max of `field`.
    """
    img = ax.contourf(grid, grid, field, cmap=my_cmap, levels=50)
    cax = make_axes_locatable(ax).append_axes("bottom", size="5%", pad=0.3)
    ax.grid(False)
    ax.set_title(title, fontsize=30)
    cbar = plt.colorbar(img, cax=cax, orientation="horizontal", format=tick_format)
    cbar.set_ticks(np.linspace(np.nanmin(field), np.nanmax(field), nb_ticks))
    cbar.ax.tick_params(labelsize=17)


# Restore the best checkpoint once; it serves both the full validation pass and
# the single-sample prediction below.
model = agent.model
best_model = torch.load(
    f"{models_repo}/best_model.pkl", weights_only=True, map_location=device
)
model.load_state_dict(best_model["model_state_dict"])
model.eval()

# --- Relative L2 error of every validation sample -----------------------------
X_val, Y_val = data.X_val.to(device), data.Y_val.to(device)
X_val_normed = data.X_val_normed.to(device)
# Inference only: no_grad avoids keeping the autograd graph of the whole
# validation set alive.
with torch.no_grad():
    Y_pred_normed = model(X_val_normed)
    Y_pred = data.y_normalizer.decode(Y_pred_normed)
Phi, G = X_val[:, 1, :, :], X_val[:, 2, :, :]
domain = compute_domain(X_val)

# u = phi * w + g, for the reference and for the prediction.
U_true = Y_val[:, 0, :, :] * Phi + G
U_pred = Y_pred[:, 0, :, :] * Phi + G

error = compute_L2_norm_squared((U_pred - U_true), domain)
magnitude = compute_L2_norm_squared((U_true), domain)
L2_error = torch.sqrt(error / magnitude)
L2_error = L2_error.cpu().detach().numpy()

# Position n // 2 of the sorted errors is the median (upper median for even n).
# The former `n // 2 + 1` skipped it and was out of range for a single sample.
sorted_errors_indices = np.argsort(L2_error)
index = sorted_errors_indices[len(sorted_errors_indices) // 2]

# --- Prediction for the median sample -----------------------------------------
X = data.X_val[index]
X = X[None, :, :, :].to(device)
X_normed = data.x_normalizer.encode(X)
W_ref = data.Y_val[index].cpu().detach().numpy()
with torch.no_grad():
    Y_pred_normed = model(X_normed)
Y_pred = data.y_normalizer.decode(Y_pred_normed).cpu().detach().numpy()
F, Phi, G = (
    X[0, 0, :, :].cpu().detach().numpy(),
    X[0, 1, :, :].cpu().detach().numpy(),
    X[0, 2, :, :].cpu().detach().numpy(),
)
domain = compute_domain(X).cpu().detach().numpy()
# 1.0 inside the domain, NaN outside, so contourf leaves the outside blank.
domains_nan = np.where(domain[0, :, :], 1.0, np.nan)

# --- Figure: f, g, prediction, reference, absolute difference -----------------
fig, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(1, 5, figsize=(25, 5))
xxx = np.linspace(0, 1, 64)

u_pred = (Y_pred[0, 0, :, :] * Phi + G) * domains_nan
u_true = (W_ref[0, :, :] * Phi + G) * domains_nan
difference = np.absolute(u_true - u_pred)

_plot_field(ax1, xxx, F * domains_nan, r"$f$", "%.0f")
_plot_field(ax2, xxx, G * domains_nan, r"$g$", "%.2f")
_plot_field(ax3, xxx, u_pred, r"$u_{\theta} = \varphi w_{\theta} + g$", "%.2f")
_plot_field(ax4, xxx, u_true, r"$u_{true} = \varphi w_{true} + g$", "%.2f")
_plot_field(ax5, xxx, difference, r"$|u_{true} - u_{\theta}|$", "%.2e", nb_ticks=4)

plt.tight_layout()
if save_figs:
    fig.savefig(f"{images_repo}/median_prediction_{L2_error[index]:6f}.png", dpi=64)
plt.show()

print(f"{L2_error[index]:.3e}")

## On a new test dataset

We now move to the case of new data. We generate a new dataset and compute the error between a $\varphi$-FEM solution of the dataset and a prediction of the trained model. 

In [10]:
# PhiFemSolver is used below to compute the reference solutions of the new test set.
from generate_data import PhiFemSolver
from prepare_data import set_seed

# Seed value differs from the notebook-wide `seed` (251024) set at the top.
# NOTE(review): presumably this makes the generated test parameters independent
# of the training data -- confirm what set_seed re-seeds.
set_seed(161224)

In [11]:
# Drop references left over from earlier cells before the (expensive) generation.
X_test = Y_test = x_normed = Y_pred = X_denormed = None
X_val = Y_val = None
Phi = G = domain = U_true = U_pred = None
error = magnitude = error_inf = magnitude_inf = None
model = None  # clear memory
gc.collect()
torch.cuda.empty_cache()  # PyTorch thing

nb_test_data = 2500
test_data_dir = f"../../data_test_phi_fem_{nb_test_data}"
# Files read back by the loading cell. Checking them, rather than only the
# directory, means an interrupted generation is redone instead of leaving an
# empty directory that is mistaken for a complete cache.
required_files = ("F.npy", "Phi.npy", "G.npy", "W.npy")
cache_complete = all(
    os.path.exists(f"{test_data_dir}/{file_name}") for file_name in required_files
)
if not cache_complete:
    os.makedirs(test_data_dir, exist_ok=True)
    F, Phi, G, params = create_parameters(nb_test_data, 64)
    print("Parameters generated")
    solver = PhiFemSolver(nb_cell=64 - 1, params=params)
    W_phi_fem = solver.solve_several()
    # The input tensor formerly built here was never read (it was reset to None
    # before use), so it is no longer created or moved to the device.

    np.save(f"{test_data_dir}/F.npy", F)
    np.save(f"{test_data_dir}/params.npy", params)
    np.save(f"{test_data_dir}/Phi.npy", Phi)
    np.save(f"{test_data_dir}/G.npy", G)
    # Written last, so its presence implies the other files were written too.
    np.save(f"{test_data_dir}/W.npy", W_phi_fem)

In [12]:
# Release anything still referenced from previous cells.
X_test = Y_test = x_normed = Y_pred = X_denormed = None
Phi = G = domain = U_true = U_pred = None
model = None  # clear memory
gc.collect()
torch.cuda.empty_cache()  # PyTorch thing

# Read the cached test set back from disk: the three input fields, then the
# stored phi-FEM solutions.
F = np.load(f"../../data_test_phi_fem_{nb_test_data}/F.npy")
Phi = np.load(f"../../data_test_phi_fem_{nb_test_data}/Phi.npy")
G = np.load(f"../../data_test_phi_fem_{nb_test_data}/G.npy")
W_phi_fem = np.load(f"../../data_test_phi_fem_{nb_test_data}/W.npy")

In [None]:
# Relative L2 error of every test sample, for every checkpoint in `epochs`.
# Result: L2_errors_fno has shape (len(epochs), number of test samples).
L2_errors_fno = []
test_batch_size = 500
# Ceil division: a trailing partial batch is evaluated instead of silently dropped.
nb_test_batch = (nb_test_data + test_batch_size - 1) // test_batch_size
print(f"{test_batch_size=}")
print(f"{nb_test_batch=}")

model = agent.model
last_position = len(epochs) - 1
for position, i in enumerate(epochs):
    # Release the previous epoch's tensors before loading the next checkpoint.
    X_test = Y_pred = domain = U_true = U_pred = None
    X_test_normed_j = Y_pred_normed_j = Phi_batch = Y_test_batch = G_batch = None
    error = magnitude = None
    gc.collect()
    torch.cuda.empty_cache()  # PyTorch thing

    # The checkpoint only depends on the epoch, so it is loaded once here rather
    # than once per batch. The last list entry is the best model; testing the
    # position (not the value) keeps this right even if the best epoch equals
    # one of the fixed epochs.
    if position != last_position:
        model_i = torch.load(
            f"{models_repo}/model_{i}.pkl", weights_only=True, map_location=device
        )
        model.load_state_dict(model_i["model_state_dict"])
    else:
        best_model = torch.load(
            f"{models_repo}/best_model.pkl", weights_only=True, map_location=device
        )
        model.load_state_dict(best_model["model_state_dict"])
    model.eval()

    L2_error_i = []
    for j in range(nb_test_batch):
        print(f"Epoch = {i} Batch : {j} / {nb_test_batch}")
        sli = slice(j * test_batch_size, (j + 1) * test_batch_size)
        X_test = generate_manual_new_data_numpy(F[sli], Phi[sli], G[sli]).to(device)
        # Inference only: no_grad keeps the autograd graph from being built,
        # which is what made per-batch memory clean-up necessary.
        with torch.no_grad():
            X_test_normed_j = data.x_normalizer.encode(X_test)
            Y_pred_normed_j = model(X_test_normed_j)
            Y_pred = data.y_normalizer.decode(Y_pred_normed_j)

            Phi_batch, G_batch = X_test[:, 1, :, :], X_test[:, 2, :, :]
            domain = compute_domain(X_test)
            Y_test_batch = torch.tensor(W_phi_fem[sli, None, :, :]).to(device)
            # u = phi * w + g, for the reference and for the prediction.
            U_true = Y_test_batch[:, 0, :, :] * Phi_batch + G_batch
            U_pred = Y_pred[:, 0, :, :] * Phi_batch + G_batch

            error = compute_L2_norm_squared((U_pred - U_true), domain)
            magnitude = compute_L2_norm_squared((U_true), domain)
        L2_error_batch = torch.sqrt(error / magnitude).cpu().detach().numpy()
        L2_error_i.append(L2_error_batch)
    # concatenate (instead of array + flatten) also works for unequal batch sizes.
    L2_error_i = np.concatenate(L2_error_i)
    L2_errors_fno.append(L2_error_i)

L2_errors_fno = np.array(L2_errors_fno)
print(L2_errors_fno.shape)

In [None]:
# One cubehelix shade per fixed epoch, plus a highlight colour for the last
# column (the best-model checkpoint).
palette = sns.cubehelix_palette(
    n_colors=len(epochs) - 1, start=0.25, rot=-0.25, gamma=0.5
)
palette = palette.as_hex()
# append, not insert(-1, ...): insert(-1) places the colour *before* the last
# entry, which would highlight the second-to-last box instead of the last one.
palette.append("#b22222")
palette = sns.color_palette(palette)

plt.figure(figsize=(12, 4))
errors_fno = np.array(L2_errors_fno)
print(np.shape(errors_fno))
# One column per checkpoint, one row per test sample.
dataframe = pd.DataFrame(errors_fno.transpose(), columns=epochs)

sns.boxplot(
    data=dataframe,
    palette=palette,
    order=epochs,
    flierprops={"marker": "x", "markerfacecolor": "black"},
)
plt.xlabel("Training epochs", fontsize=16)
plt.ylabel("$L^2$ relative error", fontsize=16)
plt.yscale("log")
# Horizontal grid lines only, on both major and minor ticks of the log axis.
plt.grid(True, which="both", axis="y")
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)

plt.tight_layout()
if save_figs:
    plt.savefig(f"{images_repo}/boxplots_new_data_L2.pdf")
plt.show()