In [None]:
import numpy as np
import matplotlib.pyplot as plt
import random
import os

# Single seed shared by every RNG seeded in this notebook; it is applied here
# to Python's `random` module and to NumPy's global generator.
seed = 2023
random.seed(seed)
np.random.seed(seed)
import dolfin as df
import time
from utils import *
from utils_training import *
from utils_compare_methods import *
import prepare_data
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable
import pandas as pd
import torch
import gc

# Seed PyTorch's default RNG with the same `seed` used for `random` and NumPy.
torch.manual_seed(seed)

# Global seaborn / matplotlib styling used by every figure in the notebook.
sns.set_theme()
sns.set_context("paper")
# NOTE(review): `sns.set` is an alias of `sns.set_theme`, whose default
# `context` is "notebook"; this call therefore likely overrides the "paper"
# context requested on the previous line before applying `rc`. If "paper" is
# the intended context, a single `sns.set_theme(context="paper", rc=...)` call
# would express it -- confirm before changing, since it alters figure output.
sns.set(rc={"xtick.bottom": True, "ytick.left": True})
# Hex colour list from the "mako" palette and a continuous "viridis" colormap.
colors = sns.color_palette("mako").as_hex()
my_cmap = sns.color_palette("viridis", as_cmap=True)

%load_ext autoreload
%autoreload 2

In [None]:
# --- Run configuration -------------------------------------------------------
# `save_figs` controls whether figures are written to `images_repo`.
save_figs = True
small_data = False
data = DataLoader(small_data)

# Hyper-parameters forwarded to `Agent`, which builds the network.
# NOTE(review): these presumably have to match the configuration used when
# the checkpoint below was trained, otherwise `load_state_dict` will fail on
# mismatching tensor shapes -- confirm against the training script.
agent = Agent(
    data,
    level=2,
    relative=True,
    squared=False,
    initial_lr=5e-3,
    n_modes=10,
    width=20,
    batch_size=32,
    pad_prop=0.05,
    pad_mode="reflect",
    l2_lambda=1e-3,
)

model = agent.model
device = agent.device

models_repo = "./models_H2"
images_repo = "./images_H2"

# `map_location=device` remaps the stored tensors onto the device the agent
# actually runs on, so a checkpoint written on a GPU machine can still be
# loaded on a CPU-only one (without it torch tries to restore the original
# device and raises).
# SECURITY: `torch.load` unpickles the file; only load checkpoints you trust.
best_model = torch.load(f"{models_repo}/best_model.pkl", map_location=device)
model.load_state_dict(best_model["model_state_dict"])
model.eval()  # inference mode for the rest of the notebook

print(f"Best epoch = {best_model['epoch']}")

# Create the output directory only when figures are going to be saved.
# `exist_ok=True` replaces the separate existence check and makes the cell
# safe to re-run.
if save_figs:
    os.makedirs(images_repo, exist_ok=True)

In [None]:
def change_size_fno(param, sizes, Plot=False, model=model):
    """Measure the FNO prediction error for one parameter set at several grid sizes.

    A reference solution is computed once with ``StandardFEMSolver``. For every
    entry of ``sizes`` the inputs are regenerated at that resolution, passed
    through ``model``, projected onto the reference function space and compared
    with the reference solution in relative L2 norm.

    Parameters
    ----------
    param : array-like
        Container whose first row ``param[0]`` holds the 12 values
        ``(mu0, mu1, sigma_x, sigma_y, amplitude, x_0, y_0, lx, ly, theta,
        alpha, beta)``. The whole object is also handed to
        ``StandardFEMSolver(params=param)``.
    sizes : sequence of int
        Grid sizes (number of vertices per direction) at which the network is
        evaluated.
    Plot : bool, optional
        When true, draw the reference solution followed by one panel per
        evaluated size.
    model : callable, optional
        Network used for the prediction. Defaults to the notebook-level
        ``model`` bound when this function is defined.

    Returns
    -------
    list of float
        Relative L2 error of the prediction for each entry of ``sizes``, in the
        same order.
    """
    standard_solver = StandardFEMSolver(params=param)
    # NOTE(review): the meaning of the positional arguments (0, 0.0015) is
    # defined by `StandardFEMSolver.solve_one` and is not visible here.
    u_ref, V_ref, dx_ref = standard_solver.solve_one(0, 0.0015, reference_fem=True)

    # The norm of the reference solution does not depend on the grid size, so
    # it is assembled once instead of once per loop iteration.
    ref_norm = df.assemble((((u_ref)) ** 2) * dx_ref) ** (0.5)

    errors, sols = [], []
    mu0, mu1, sigma_x, sigma_y, amplitude, x_0, y_0, lx, ly, theta, alpha, beta = param[
        0
    ]
    for size in sizes:
        print(f"{size=}")
        F = generate_F_numpy(mu0, mu1, sigma_x, sigma_y, amplitude, size)
        G = generate_G_numpy(alpha, beta, size)
        phi = generate_phi_numpy(x_0, y_0, lx, ly, theta, size)
        X = generate_manual_new_data_numpy(F, phi, G).to(device)

        # Pure inference: disabling autograd avoids building (and keeping
        # alive) a computation graph for every resolution.
        with torch.no_grad():
            x_normed = data.x_normalizer.encode(X)
            Y_normed = model(x_normed)
            Y = data.y_normalizer.decode(Y_normed)
            # Solution rebuilt from the network output as X[:,1] * Y + X[:,2]
            # (presumably phi * w + g -- TODO confirm channel layout).
            predicted_solution = (
                (X[:, 1, :, :] * Y[:, 0, :, :] + X[:, 2, :, :]).cpu().detach().numpy()
            )
        predicted_solution = np.reshape(predicted_solution, (size, size))
        predicted_solution_fenics = convert_numpy_matrix_to_fenics(
            predicted_solution, size, 1
        )
        # Project onto the reference space so both fields can be compared
        # with the same measure `dx_ref`.
        predicted_solution_fenics_proj_V_ex = df.project(
            predicted_solution_fenics,
            V_ref,
            solver_type="gmres",
            preconditioner_type="hypre_amg",
        )
        sols.append(predicted_solution_fenics_proj_V_ex)
        l2_error_fno = (
            df.assemble((((u_ref - predicted_solution_fenics_proj_V_ex)) ** 2) * dx_ref)
            ** (0.5)
        ) / ref_norm
        errors.append(l2_error_fno)

    if Plot:
        # One panel for the reference plus one per evaluated size; the figure
        # width scales with the panel count (3 in. each, i.e. 21 in. for the
        # usual six sizes) instead of assuming exactly six sizes.
        n_panels = len(sols) + 1
        plt.figure(figsize=(3 * n_panels, 4))
        plt.subplot(1, n_panels, 1)
        p = df.plot(u_ref, mode="color", cmap=my_cmap)
        plt.colorbar(p)
        plt.title("Reference solution", fontsize="14")
        for panel, (size, sol, err) in enumerate(zip(sizes, sols, errors), start=2):
            plt.subplot(1, n_panels, panel)
            p = df.plot(sol, mode="color", cmap=my_cmap)
            plt.colorbar(p)
            plt.title(
                "FNO prediction for nb_vert = "
                + str(size)
                + "\n$L^2$ relative error : "
                + f"{err:.3e}",
                fontsize="14",
            )

        plt.tight_layout()
        plt.show()
    return errors

In [None]:
# Number of parameter sets requested from `create_FG_numpy`.
nb_params_to_test = 20
# Only `params` is read by the evaluation loop in the visible cells; F, phi
# and G are regenerated per grid size later on.
# NOTE(review): the second argument (64) is presumably the grid resolution at
# which the fields are generated -- confirm against `create_FG_numpy`.
F, phi, G, params = create_FG_numpy(nb_params_to_test, 64)
# Grid sizes at which the network is evaluated.
sizes = [48, 64, 68, 84, 104, 124]
print(f"{len(sizes)=}")
print(f"{nb_params_to_test=}")

In [None]:
# Evaluate every parameter set at every grid size. The result is a 2-D array
# with one row per parameter set and one column per entry of `sizes`.
n_params = np.shape(params)[0]
per_param_errors = []
for idx in range(n_params):
    print(f"Param : {idx+1}/{n_params}")
    # `change_size_fno` reads its parameters from row 0, hence the extra axis.
    single_param = np.array([params[idx]])
    per_param_errors.append(
        change_size_fno(single_param, sizes, Plot=False, model=model)
    )
errors = np.array(per_param_errors)

In [None]:
# Box plot of the relative L2 error per grid size (one box per column).
abs_str = sizes
dataframe = pd.DataFrame(errors, columns=abs_str)

# Grid size drawn in red to make it stand out from the others.
# NOTE(review): 64 is presumably the training resolution -- confirm.
highlighted_size = 64
highlight_color = "#b22222"

if highlighted_size in sizes:
    # One cubehelix colour per non-highlighted size, with the highlight colour
    # inserted at the position of the highlighted size.
    palette = sns.cubehelix_palette(
        n_colors=len(sizes) - 1, start=0.25, rot=-0.25, gamma=0.5
    ).as_hex()
    palette.insert(sizes.index(highlighted_size), highlight_color)
else:
    # Nothing to highlight: previously `sizes.index(64)` raised ValueError
    # here; fall back to a plain cubehelix palette with one colour per size.
    palette = sns.cubehelix_palette(
        n_colors=len(sizes), start=0.25, rot=-0.25, gamma=0.5
    ).as_hex()
palette = sns.color_palette(palette)

plt.figure(figsize=(8, 5))

sns.boxplot(data=dataframe, palette=palette)
plt.yscale("log")
plt.grid(True, which="both", axis="y")
plt.xlabel("Number of vertices", fontsize=18)
plt.ylabel("$L^2$ relative error", fontsize=18)
plt.minorticks_on()
plt.tight_layout()

if save_figs:
    # Same path convention as the directory creation (`{images_repo}/...`);
    # prefixing "./" again would break if `images_repo` were an absolute path.
    plt.savefig(f"{images_repo}/boxplots_new_data_change_size.png")
plt.show()

In [None]:
# Fixed parameter set used to illustrate the prediction at every grid size.
# It does not depend on the size, so it is bound once outside the loop.
mu0, mu1, sigma_x, sigma_y, amplitude, x_0, y_0, lx, ly, theta, alpha, beta = (
    0.30122539,
    0.29640635,
    0.25481616,
    0.46846484,
    25,
    0.46192393,
    0.48159562,
    0.37990447,
    0.32648597,
    0.41180092,
    0.37976006,
    -0.31339751,
)

domains_list = []  # per size: 1.0 where phi <= 3e-16, NaN elsewhere
solutions = []  # per size: predicted solution as a 2-D NumPy array
for size in sizes:
    F = generate_F_numpy(mu0, mu1, sigma_x, sigma_y, amplitude, size)
    G = generate_G_numpy(alpha, beta, size)
    phi = generate_phi_numpy(x_0, y_0, lx, ly, theta, size)
    X = generate_manual_new_data_numpy(F, phi, G).to(device)

    # Pure inference: no autograd graph is needed (or kept) per resolution.
    with torch.no_grad():
        x_normed = data.x_normalizer.encode(X)
        Y_normed = model(x_normed)
        Y = data.y_normalizer.decode(Y_normed)
        # Solution rebuilt as X[0,1] * Y[0,0] + X[0,2]
        # (presumably phi * w + g -- TODO confirm channel layout).
        predicted_solution = (
            (X[0, 1, :, :] * Y[0, 0, :, :] + X[0, 2, :, :]).cpu().detach().numpy()
        )
    solutions.append(predicted_solution)

    # Multiplicative mask: 1.0 where phi <= 3e-16 and NaN elsewhere, so that
    # multiplying a field by it blanks out everything outside that region.
    # The tiny positive threshold presumably absorbs round-off at phi == 0.
    domains = phi <= 3e-16
    domains_nan = np.where(domains, 1.0, np.nan)
    domains_list.append(domains_nan)

In [None]:
# 2 x 3 grid showing the masked prediction at each grid size, with a single
# colorbar shared by all panels.
fig, axes = plt.subplots(figsize=(8, 5.5), nrows=2, ncols=3)

# Masked fields (NaN outside the domain), limited to the number of panels.
masked_fields = [
    (domain_mask * solution)[0, :, :]
    for domain_mask, solution in zip(
        domains_list[: axes.size], solutions[: axes.size]
    )
]

# The figure has one colorbar, so every panel must use the same colour
# limits; otherwise the bar would only describe the last panel drawn.
vmin = min((np.nanmin(field) for field in masked_fields), default=None)
vmax = max((np.nanmax(field) for field in masked_fields), default=None)

im = None
for i, ax in enumerate(axes.flat):
    if i >= len(masked_fields):
        # Fewer sizes than panels: hide the unused axes instead of raising.
        ax.axis("off")
        continue
    im = ax.imshow(
        masked_fields[i], cmap="viridis", origin="lower", vmin=vmin, vmax=vmax
    )
    # Panels 0-2 get their title below the axes (negative y / pad), panels
    # 3-5 at the default position above the axes.
    if i >= 3:
        ax.set_title(f"{sizes[i]} vertices", fontsize=18)
    else:
        ax.set_title(f"{sizes[i]} vertices", fontsize=18, y=-0.2, pad=-14)
    ax.grid(False)

fig.subplots_adjust(right=0.84, hspace=0.5)
if im is not None:
    cbar_ax = fig.add_axes([0.86, 0.14, 0.030, 0.7])
    fig.colorbar(im, cax=cbar_ax)

# Honour `save_figs` like the other figure cells: the output directory is
# only created when it is true, so an unconditional save could fail.
if save_figs:
    plt.savefig(f"{images_repo}/outputs_various_sizes.png")
plt.show()