In [None]:
import copy
from pathlib import Path

from scipy.spatial.distance import cdist
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import torch
from tqdm import tqdm

from _utils.models import NetworkWithPDE, ScaledFNN, PINNLoss, NetPredWrapper
from _utils.plotting import plot_results, plot_losses, scatter_with_colorbar
from train import (L, W, bc_inflow_u, bc_inflow_v, bc_noslip_u, bc_noslip_v, bc_outflow_u, bc_outflow_v,
                   cylinder_center, geom, navier_stokes, navier_stokes_broken, radius)

# Font sizes applied to every figure created in this notebook.
plt.rcParams.update({
    'axes.titlesize': 24,
    'axes.labelsize': 18,
    'xtick.labelsize': 18,
    'ytick.labelsize': 18,
})

In [None]:
# Directory holding the trained checkpoints, and the root directory for generated figures.
model_zoo = Path("./model_zoo")
fig_base_path = Path("./figures")

# Create the figure root first, then one sub-directory per figure family.
for _out_dir in (
    fig_base_path,
    fig_base_path / "model_predictions",
    fig_base_path / "model_losses",
):
    _out_dir.mkdir(exist_ok=True)

In [None]:
device = torch.get_default_device()

layers = [2, 64, 64, 64, 64, 3]
activation = 'gelu'

In [None]:
# NOTE(review): allow_pickle=True unpickles arbitrary objects; only load this file from a trusted source.
_dataset_path = "./dataset/ns_steady.npy"
data = np.load(_dataset_path, allow_pickle=True).item()

# Reference fields, one array per stored key.
u_target, v_target, p_target = (np.array(data[key]) for key in ("u", "v", "p"))

# Coordinates as float32, once as a NumPy array and once as a torch tensor.
coords = np.array(data["coords"]).astype(np.float32)
coords_tensor = torch.as_tensor(coords, dtype=torch.float32)

In [None]:
# Template network; the per-model networks are deep copies of it that later receive checkpoint weights.
# NOTE(review): L and W are presumably the domain extents used for input scaling — confirm in ScaledFNN.
net = ScaledFNN(layers, activation, "Glorot uniform", L, W)

In [None]:
# One independent copy of the template network per trained variant.
nets = {model_name: copy.deepcopy(net) for model_name in ("good", "bad", "broken")}

# Load the `lbfgs*` checkpoint of every variant into its copy.
# The loop variable is deliberately not called `net`, so the template network is not rebound.
for model_name, model in nets.items():
    # glob() yields matches in arbitrary order; sort so the selected file is deterministic.
    checkpoints = sorted((model_zoo / model_name).glob("lbfgs*"))
    if not checkpoints:
        raise FileNotFoundError(
            f"No 'lbfgs*' checkpoint found in {model_zoo / model_name}"
        )
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoints[0], map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])

In [None]:
def continuity(xy, out):
    """Return component 0 of ``navier_stokes(xy, out)``."""
    return navier_stokes(xy, out)[0]


def momentum_x(xy, out):
    """Return component 1 of ``navier_stokes(xy, out)``."""
    return navier_stokes(xy, out)[1]


def momentum_y(xy, out):
    """Return component 2 of ``navier_stokes(xy, out)``."""
    return navier_stokes(xy, out)[2]


def inflow_boundary(x):
    """Boolean mask of the rows of ``x`` whose first coordinate is close to 0."""
    return torch.isclose(x[:, 0], torch.tensor(0, dtype=torch.float32))


def outflow_boundary(x):
    """Boolean mask of the rows of ``x`` whose first coordinate is close to ``L``."""
    return torch.isclose(x[:, 0], torch.tensor(L, dtype=torch.float32))


def noslip_boundary(x):
    """Boolean mask of the rows of ``x`` on either wall (second coordinate close
    to 0 or ``W``) or within ``radius`` of ``cylinder_center``."""
    on_lower_wall = torch.isclose(x[:, 1], torch.tensor(0, dtype=torch.float32))
    on_upper_wall = torch.isclose(x[:, 1], torch.tensor(W, dtype=torch.float32))
    dx = x[:, 0] - cylinder_center[0]
    dy = x[:, 1] - cylinder_center[1]
    within_cylinder = torch.sqrt(dx ** 2 + dy ** 2) <= radius
    return torch.logical_or(
        torch.logical_or(on_lower_wall, on_upper_wall),
        within_cylinder,
    )

def _build_pde_net(model_name, model):
    """Wrap ``model`` in a NetworkWithPDE with three residuals and six boundary conditions."""
    if model_name == "broken":
        # The "broken" variant takes its second residual from navier_stokes_broken.
        x_momentum = lambda xy, out: navier_stokes_broken(xy, out)[1]
    else:
        x_momentum = momentum_x

    boundary_conditions = [
        (bc_inflow_u, inflow_boundary),
        (bc_inflow_v, inflow_boundary),
        (bc_outflow_u, outflow_boundary),
        (bc_outflow_v, outflow_boundary),
        (bc_noslip_u, noslip_boundary),
        (bc_noslip_v, noslip_boundary),
    ]
    return NetworkWithPDE(
        model,
        pdes=[continuity, x_momentum, momentum_y],
        bcs=boundary_conditions,
        geom=geom,
    )


# PDE-aware wrapper for every loaded model, keyed by model name.
pde_nets = {
    model_name: _build_pde_net(model_name, model)
    for model_name, model in nets.items()
}

In [None]:
# Influence arrays stored on disk, one .npz archive per model variant.
influences_arrs = {
    model_key: np.load(f"./model_zoo/{model_key}_influences/influence_arrs.npz")
    for model_key in ("good", "bad", "broken")
}

In [None]:
# Save the plot_results figure of every model; files that already exist are left untouched.
for model_name, net in tqdm(nets.items(), desc='Creating plots for model predictions: '):
    fig_name = fig_base_path / 'model_predictions' / f"{model_name}.png"
    if fig_name.exists():
        continue
    fig, ax = plot_results(net, coords, u_target, v_target, p_target)
    fig.savefig(fig_name)
    # Release the figure once it has been written.
    plt.close(fig)

In [None]:
# Evaluate every model on the dataset coordinates, once per output quantity.
# preds_dims[model_name][dim] holds the corresponding prediction as a NumPy array.
preds_dims = {}
for model_name, net in nets.items():
    preds_dims[model_name] = {}
    # i.e., u, v, p, and sqrt(u**2 + v**2)
    for dim in range(4):
        cur_net = NetPredWrapper(net, dim)
        # .cpu() is a no-op for CPU tensors but is required before .numpy()
        # whenever the model output lives on another device.
        preds_dims[model_name][dim] = cur_net(coords_tensor).detach().cpu().numpy()

# Labels for the four prediction dims (used inside $...$ in plot titles) ...
dim_lookup = dict(enumerate(("u_1", "u_2", "p", "||\\vec{u}||")))
# ... followed by names for the loss-based influence keys.
dim_lookup.update(
    influences_pde="pde_losses",
    bcs="bc_losses",
    all_loss_terms="all_losses",
)

In [None]:
# One figure per model: the first three prediction dims stacked vertically.
for model_name, preds in preds_dims.items():
    fig, ax = plt.subplots(nrows=3, figsize=(22, 4 * 3))
    for dim, axis in enumerate(ax):
        scatter_with_colorbar(fig, axis, coords, preds[dim], limit_axes=False)
        axis.axis("off")
        axis.set_title(f"Predicted ${dim_lookup[dim]}$")
    out_file = fig_base_path / "model_predictions" / f"{model_name}_without_target.png"
    fig.savefig(out_file, bbox_inches='tight')

In [None]:
# One figure per model: all four prediction dims stacked vertically.
for model_name, preds in preds_dims.items():
    fig, ax = plt.subplots(nrows=4, figsize=(22, 4.1 * 4))
    for dim, axis in enumerate(ax):
        scatter_with_colorbar(fig, axis, coords, preds[dim], limit_axes=False)
        axis.axis("off")
        axis.set_title(f"Predicted ${dim_lookup[dim]}$")
    out_file = fig_base_path / "model_predictions" / f"{model_name}_without_target_with_u.png"
    fig.savefig(out_file, bbox_inches='tight')

In [None]:
# Reference values in the same order as the prediction dims: u, v, p and sqrt(u**2 + v**2).
targets = [u_target, v_target, p_target, np.sqrt(u_target**2 + v_target**2)]

# One figure per model: absolute difference between prediction and target for each dim.
for model_name, preds in preds_dims.items():
    fig, ax = plt.subplots(nrows=4, figsize=(22, 4.1 * 4))
    for dim, axis in enumerate(ax):
        abs_error = np.abs(preds[dim].flatten() - targets[dim].flatten())
        scatter_with_colorbar(fig, axis, coords, abs_error, limit_axes=False)
        axis.axis("off")
        axis.set_title(f"Error in ${dim_lookup[dim]}$")
    out_file = fig_base_path / "model_predictions" / f"{model_name}_error.png"
    fig.savefig(out_file, bbox_inches='tight')

In [None]:
# A single figure with the four target quantities stacked vertically.
fig, ax = plt.subplots(nrows=4, figsize=(22, 4.1 * 4))
for e, (axis, target) in enumerate(zip(ax, targets)):
    scatter_with_colorbar(fig, axis, coords, target, limit_axes=False)
    axis.axis("off")
    axis.set_title(f"Target ${dim_lookup[e]}$")
fig.savefig(fig_base_path / "model_predictions" / "target.png", bbox_inches='tight')

In [None]:
loss_fn = PINNLoss()

# Save the plot_losses figure of every PDE-wrapped model; existing files are left untouched.
for model_name, pde_net in tqdm(pde_nets.items(), desc='Creating Loss-Plots'):
    fig_name = Path(f"{fig_base_path / 'model_losses' / model_name}.png")
    if (not fig_name.exists()):
        fig, ax = plot_losses(pde_net, loss_fn, coords_tensor)
        fig.savefig(fig_name)
        # Close the figure after saving, as the prediction-plot loop does,
        # so figures do not accumulate in memory.
        plt.close(fig)

In [None]:
def get_min_distance(x, y):
    """Return the index of the row of ``x`` that is closest to the single point ``y``.

    Despite the name, the result is an index (argmin), not a distance.
    """
    query = np.asarray(y)[np.newaxis]  # cdist expects a 2-D array of points
    distances = cdist(x, query)
    return distances.argmin()

In [None]:
import matplotlib.pyplot as plt

model_key = "good"
pred_key = "output_dim_0"
fig, ax = plt.subplots(figsize=(22, 4.1))

# Training point closest to the location 0.05 below the cylinder centre.
train_x = influences_arrs[model_key]["train_x"]
probe_location = [cylinder_center[0], cylinder_center[1] - 0.05]
idx = get_min_distance(train_x, probe_location)

# multiply by number of training points to retrieve actual influence values
influence_column = influences_arrs[model_key][pred_key][:, idx] * train_x.shape[0]
scatter_with_colorbar(fig, ax, coords, influence_column, norm=True, cmap="bwr")

# Mark the selected training point with a black cross.
ax.scatter(*train_x[idx].T, marker="x", s=400, c='black')
ax.axis("off")

# Horizontal arrow across the upper part of the plot, with a centered label sitting on it.
mid_x, mid_y = L / 2, 5 * W / 6
arrowprops = dict(facecolor='black', width=4, headwidth=30)
ax.annotate('', xy=(L - L/6, mid_y), xytext=(L/6, mid_y), arrowprops=arrowprops)
ax.text(mid_x, mid_y, 'Flow Direction', ha='center', va='bottom', fontsize=32, color='black')

fig.savefig(fig_base_path / 'single_point_right_under_cylinder_output_dim_0.png', bbox_inches='tight')

In [None]:
import matplotlib.pyplot as plt
# Same single-point influence map for the two remaining models (without the flow-direction arrow).
for model_key in ["bad", "broken"]:
    pred_key = "output_dim_0"
    fig, ax = plt.subplots(figsize=(22, 4.1))

    # Training point closest to the location 0.05 below the cylinder centre.
    train_x = influences_arrs[model_key]["train_x"]
    probe_location = [cylinder_center[0], cylinder_center[1] - 0.05]
    idx = get_min_distance(train_x, probe_location)

    # multiply by number of training points to retrieve actual influence values
    influence_column = influences_arrs[model_key][pred_key][:, idx] * train_x.shape[0]
    scatter_with_colorbar(fig, ax, coords, influence_column, norm=True, cmap="bwr")

    # Mark the selected training point with a black cross.
    ax.scatter(*train_x[idx].T, marker="x", s=400, c='black')
    ax.axis("off")

    out_file = fig_base_path / f'single_point_right_under_cylinder_output_dim_0_{model_key}.png'
    fig.savefig(out_file, bbox_inches='tight')

In [None]:
# One map per model: log of the scaled mean absolute 'output_dim_3' influence per training point.
for model_name, model_arrs in influences_arrs.items():
    train_x = model_arrs['train_x']
    mean_abs_influence = np.abs(model_arrs['output_dim_3']).mean(axis=0)
    # The "bad" model is drawn with markers four times as large.
    markersize = 12 * 4 if model_name == "bad" else 12

    fig, ax = plt.subplots(figsize=(22, 4.1))
    scatter_with_colorbar(
        fig,
        ax,
        train_x,
        np.log(10000 * mean_abs_influence),
        cmap='jet',
        markersize=markersize,
    )

    # Outline the region within 1.5 radii of the cylinder centre.
    circle = patches.Circle(cylinder_center, radius * 1.5, edgecolor='black', facecolor='none', linewidth=5, alpha=0.9)
    ax.add_patch(circle)
    ax.axis('off')

    out_file = fig_base_path / f"log_mean_of_absolute_train_influences_{model_name}_output_dim_3.png"
    fig.savefig(out_file, bbox_inches='tight')

In [None]:
# Direction Indicator
# For every row i of an influence array: the share of that row's absolute influence
# that falls on training points whose x coordinate is smaller than coords[i]'s.
# The mean share is printed per influence type and model.
for infl in [
        'output_dim_0',
        'output_dim_1',
        'output_dim_2',
        'output_dim_3',
        'influences_pde',
        'bcs',
        'all_loss_terms'
]:
    print(infl)
    for model_name in ['good', 'broken', 'bad']:
        print(model_name)
        train_x = influences_arrs[model_name]['train_x']
        train_x_coord = train_x[:, 0]  # loop-invariant, sliced once per model
        influences_cur = influences_arrs[model_name][infl]
        direction_indicator_scores = []

        for i, x1 in enumerate(coords[:, 0]):
            # Absolute influences of this row, computed once and reused below.
            row_abs = np.abs(influences_cur[i])
            row_total = row_abs.sum()
            # Only for "bcs": rows without any influence are skipped to avoid 0/0.
            if infl == "bcs" and row_total == 0:
                continue
            upstream_total = row_abs[train_x_coord < x1].sum()
            direction_indicator_scores.append(upstream_total / row_total)

        print(np.mean(direction_indicator_scores).round(2))
    print()

In [None]:
# Object Indicator scores for C_1.5r
# For every row i of an influence array: the share of that row's absolute influence
# that falls on training points within 1.5 radii of the cylinder centre.
# The mean share is printed per influence type and model.
def around_circle(x):
    """Boolean mask of the rows of ``x`` within ``1.5 * radius`` of ``cylinder_center``.

    The centre is taken from ``cylinder_center`` (the same constant used for the
    no-slip mask and the circle drawn on the influence maps) rather than a
    hard-coded (0.2, 0.2), so the mask cannot drift from the geometry.
    """
    dx = x[:, 0] - cylinder_center[0]
    dy = x[:, 1] - cylinder_center[1]
    return np.sqrt(dx ** 2 + dy ** 2) <= (radius * 1.5)


for infl in [
        'output_dim_0',
        'output_dim_1',
        'output_dim_2',
        'output_dim_3',
        'influences_pde',
        'bcs',
        'all_loss_terms'
]:
    print(infl)
    for model_name in ['good', 'broken', 'bad']:
        print(model_name)
        influences_cur = influences_arrs[model_name][infl]
        train_x = influences_arrs[model_name]['train_x']
        indices = around_circle(train_x)
        object_indicator_scores = []
        for i in range(len(coords)):
            # Absolute influences of this row, computed once and reused below.
            row_abs = np.abs(influences_cur[i])
            row_total = row_abs.sum()
            # Only for "bcs": rows without any influence are skipped to avoid 0/0.
            if infl == "bcs" and row_total == 0:
                continue
            object_indicator_scores.append(row_abs[indices].sum() / row_total)
        print(np.mean(object_indicator_scores).round(2))
    print()