In [None]:
# Re-import edited modules automatically before each cell runs; draw figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from functools import partial
import os
import pickle as pkl
from collections.abc import MutableMapping

import matplotlib.pyplot as plt
import matplotlib.tri as tri
import matplotlib.ticker as ticker
import numpy as np
import tqdm

# Hide all CUDA devices (JAX runs on CPU) and select the JAX backend for DeepXDE.
# Both are set before jax / deepxde are imported so that they take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["DDE_BACKEND"] = "jax"

# Optional XLA GPU-memory settings (currently disabled).
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".XX"
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"

from jax import config
# Use float64 for all JAX computations (NTK eigendecompositions below).
config.update("jax_enable_x64", True)
# config.update("jax_debug_nans", True)

import jax
import jax.numpy as jnp
import flax
from flax import linen as nn
import optax

# Report the JAX device counts. Querying the "gpu" backend raises RuntimeError
# when no GPU backend is available (expected here, since CUDA_VISIBLE_DEVICES
# is emptied above); report 0 GPUs instead of silently printing nothing.
try:
    n_gpu = jax.local_device_count("gpu")
except RuntimeError:
    n_gpu = 0
print(f'Jax: CPUs={jax.local_device_count("cpu")} - GPUs={n_gpu}')
    
import deepxde_al_patch.deepxde as dde

from deepxde_al_patch.model_loader import construct_model, construct_net
from deepxde_al_patch.modified_train_loop import ModifiedTrainLoop
from deepxde_al_patch.plotters import plot_residue_loss, plot_error, plot_prediction
from deepxde_al_patch.train_set_loader import load_data

from deepxde_al_patch.ntk import NTKHelper
from deepxde_al_patch.utils import get_pde_residue, print_dict_structure

In [None]:
# Global figure styling shared by every plot in this notebook.
plt.rcParams.update({
    'figure.dpi': 150,
    'font.size': 18,
    "figure.titlesize": 24,
    'text.usetex': False,  # render math with matplotlib's mathtext, not LaTeX
})

In [None]:
# NOTE(review): not used in the visible cells; same value as `graph_root` below.
main_graph = 'al_pinn_graphs_final/main'

In [None]:
# NOTE(review): `graph_root` is not used in the visible cells.
graph_root = 'al_pinn_graphs_final/main'
# Presumably the maximum number of runs to load per algorithm — confirm against the run-loading cell.
max_runs = 10
# Result sub-folder name -> (legend label, matplotlib style kwargs) for each point-selection algorithm.
algs = {
    'random_pseudo_prop-0.8': ('Uniform Rand', dict(c='black', ls=':', marker='p')),
    'random_Hammersley_prop-0.8': ('Hammersley', dict(c='grey', ls=':', marker='h')),
    'residue_prop-0.8': ('RAD', dict(c='red', ls='--', marker='v')),
    'residue_prop-0.8_alltype': ('RAD-All', dict(c='orange', ls='--', marker='^')),
    'sampling_alignment_scale-none_mem_autoal': ('PINNACLE-S (ours)', dict(c='green', ls='-', marker='s')),
    'kmeans_alignment_scale-none_mem_autoal': ('PINNACLE-K (ours)', dict(c='blue', ls='-', marker='o')),
}

In [None]:
# Root directory that result case folders are resolved against.
data_folder = '../../'

In [None]:
def contour_on_ax(ax, xs, zs, levels, res=200, rm_axis=False):
    """Draw a filled contour of scattered 2-D data on `ax`.

    The values `zs` at the scattered points `xs` are linearly interpolated
    (over a Delaunay triangulation of `xs`) onto a regular `res` x `res` grid
    spanning the bounding box of `xs`, then drawn with `contourf`.

    Args:
        ax: matplotlib Axes to draw on.
        xs: (N, 2) array of point coordinates.
        zs: length-N array of values at those points.
        levels: contour levels passed to `ax.contourf`.
        res: number of interpolation grid points along each axis.
        rm_axis: if True, hide tick labels; otherwise put major ticks every
            1.0 when an axis spans more than 1.0, else every 0.5.

    Returns:
        The contour set returned by `ax.contourf` (usable for a colour bar).
    """
    xi, yi = [np.linspace(np.min(xs[:, i]), np.max(xs[:, i]), res) for i in range(2)]
    interpolator = tri.LinearTriInterpolator(tri.Triangulation(xs[:, 0], xs[:, 1]), zs)
    # Grid points outside the triangulation come back masked and stay unfilled.
    zi = interpolator(*np.meshgrid(xi, yi))
    cb = ax.contourf(xi, yi, zi, levels=levels, cmap="RdBu_r")
    if rm_axis:
        ax.set_xticklabels([])
        ax.set_yticklabels([])
    else:
        for axis, coords in ((ax.xaxis, xs[:, 0]), (ax.yaxis, xs[:, 1])):
            span = np.max(coords) - np.min(coords)
            axis.set_major_locator(ticker.MultipleLocator(1 if span > 1.0 else 0.5))
    return cb


def _colour_limits(values, ptile, sym_colour, p_d=1):
    """Return ``(min_, max_)`` colour limits shared by all arrays in `values`.

    Uses the p_d / (100 - p_d) percentiles when `ptile`, else the min/max.
    With `sym_colour`, a range straddling zero is made symmetric about zero.
    A degenerate (constant) range is widened slightly, because `contourf`
    requires strictly increasing levels.
    """
    if ptile:
        min_ = np.percentile(values, p_d)
        max_ = np.percentile(values, 100 - p_d)
    else:
        min_ = np.min(values)
        max_ = np.max(values)
    if sym_colour and (min_ < 0 < max_):
        m = max(-min_, max_)
        min_, max_ = -m, m
    if not max_ > min_:
        pad = max(abs(min_), 1.0) * 1e-6
        min_, max_ = min_ - pad, max_ + pad
    return min_, max_


def plot_contours(xs, ys_list, titles, res=200, sym_colour=False, ptile=False, cbar=True):
    """Plot several fields as a grid of filled contours.

    Column j shows ``ys_list[j]``; row i shows its i-th output component.
    All panels in a row share one colour range (and one colour bar).

    Args:
        xs: (N, 2) point coordinates shared by all panels.
        ys_list: list of (N, R) arrays; R is the number of rows.
        titles: column titles, set on the first row.
        res: interpolation grid size and number of contour levels.
        sym_colour: make a colour range straddling zero symmetric about zero.
        ptile: use the 1st/99th percentiles as the colour range and clip to it.
        cbar: add a colour bar per row (tick labels are hidden when False).

    Returns:
        ``(fig, axs)``: `axs` is a flat list of Axes when R == 1, otherwise a
        2-D (R, len(ys_list)) array of Axes.
    """
    nrows = ys_list[0].shape[1]
    ncols = len(ys_list)
    fig, axs = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        sharex=True,
        sharey=True,
        # One extra panel width for the colour bars, one fewer without them
        # (never below one panel width).
        figsize=(max(3 * (ncols + (1 if cbar else -1)), 3), 3 * nrows + 2),
        constrained_layout=True,
        squeeze=False,  # always 2-D, so axs[i] is a row even for 1 row / 1 column
    )

    p_d = 1
    # Single-row figures clip exactly to the limits; multi-row ones just inside them.
    clip_eps = 0.0 if nrows == 1 else 1e-9
    for i in range(nrows):
        ys_row = [y[:, i] for y in ys_list]
        min_, max_ = _colour_limits(ys_row, ptile, sym_colour, p_d)
        if ptile:
            ys_row = [np.clip(y, min_ + clip_eps, max_ - clip_eps) for y in ys_row]
        levels = np.linspace(min_, max_, num=res)
        for ax, zs in zip(axs[i], ys_row):
            cb = contour_on_ax(ax, xs, zs, levels, res, rm_axis=not cbar)
        if cbar:
            fig.colorbar(cb, ax=axs[i])
    for ax, title in zip(axs[0], titles):
        ax.set_title(title)

    if nrows == 1:
        return fig, axs.ravel().tolist()
    return fig, axs

In [None]:
def plot_training_data(ax, samples):
    """Overlay the training points in `samples` on `ax`.

    Residual points are drawn as black circles. Anchor points, if present, are
    blue triangles. The k-th boundary/initial-condition set (1-based) is drawn
    as squares in colour ``C{k}``. Markers are drawn above other artists and
    are not clipped to the axes.

    Args:
        ax: matplotlib Axes (anything with a ``plot`` method).
        samples: mapping with ``'res'`` (point array), ``'bcs'`` (sequence of
            point arrays) and optionally ``'anc'``; columns 0 and 1 are the x/y
            coordinates.
    """
    def _draw(pts, marker, colour):
        ax.plot(pts[:, 0], pts[:, 1], marker, color=colour, ms=4., alpha=0.95, zorder=10, clip_on=False)

    _draw(samples['res'], 'o', 'black')
    if 'anc' in samples:
        _draw(samples['anc'], '^', 'blue')
    for k, bc_pts in enumerate(samples['bcs'], start=1):
        _draw(bc_pts, 's', f'C{k}')

In [None]:
# Run whose snapshot_data* pickles are analysed below, and the folder the eigen-plots are written to.
example_folder = '../../al_pinn_results/conv-1d{1.0}_pb-80_ic/nn-None-8-128_adam_bcsloss-1.0_budget-1000-200-0/kmeans_alignment_scale-none_mem_autoal/20230914101511'
eigplot_folder = '../../al_pinn_graphs_final/eigplots/conv-80'

os.makedirs(eigplot_folder, exist_ok=True)

In [None]:
# For each checkpoint step: eigendecompose the empirical NTK on an ~80x80 set
# of test points, then plot reconstructions of the true solution, the NN
# prediction and their difference from the top-k eigenvectors.
from scipy.spatial.distance import cdist

# Title of the first panel (the un-reconstructed series itself) per series.
series_titles = {
    'ys_true': 'True solution',
    'ys_pred': 'NN prediction',
    'ys_res': 'Pred. - true',
}

for s in [10000, 100000]:

    steps_range = [s]
    step_idx = 0

    print('plots number', step_idx)

    # Delete every live device buffer left over from earlier cells/iterations.
    for buf in jax.devices()[0].client.live_buffers():
        buf.delete()

    model, model_aux = construct_model(

        pde_name='conv-1d', 
        data_seed=40,
        pde_const=(1.0,), 
        use_pdebench=True,
        test_max_pts=50000,
        include_ic=True,
        data_root='~/pdebench',

        # model params
        hidden_layers=8, 
        hidden_dim=128, 
        activation='tanh', 
        initializer='Glorot uniform', 
        arch=None, 

    )

    # Merge all snapshot_data* pickles of the run into one dict keyed by step
    # (key None holds the test set).
    # NOTE(review): pickle.load executes arbitrary code — only load trusted result files.
    d = dict()
    for file in os.listdir(example_folder):
        if file.startswith('snapshot_data'):
            fname = f'{example_folder}/{file}'
            with open(fname, 'rb') as f:
                d_update = pkl.load(f)
            d.update(d_update)

    x_test = d[None]['x_test']
    y_test = d[None]['y_test']

    d_modified = {
        'x_test': x_test,
        'y_test': d[None]['y_test'],
        'steps': steps_range,
        'res_mean': [d[k]['residue_test_mean'] for k in steps_range],
        'err_mean': [d[k]['error_test_mean'] for k in steps_range],
        'err_q50': [np.percentile(d[k]['error_test'], 50) for k in steps_range],
        'err_q90': [np.percentile(d[k]['error_test'], 90) for k in steps_range],
        'err_q95': [np.percentile(d[k]['error_test'], 95) for k in steps_range],
        'err_q100': [np.percentile(d[k]['error_test'], 100) for k in steps_range],
        'res': [d[k]['residue_test'] for k in steps_range],
        'err': [d[k]['error_test'] for k in steps_range],
        'pred': [d[k]['pred_test'] for k in steps_range],
        'chosen_pts': [d[k]['al_intermediate']['chosen_pts'] for k in steps_range],
        'inv': [d[k]['params'][1] for k in steps_range],
        'params': [d[k]['params'][0] for k in steps_range],
    }

    ntk = NTKHelper(model)

    # Snap a res x res grid onto the nearest test points, so the true solution
    # is known at every grid point.
    res = 80
    xi, yi = [jnp.linspace(jnp.min(x_test[:,i]), jnp.max(x_test[:,i]), res) for i in range(2)]
    grid = jnp.array([y.flatten() for y in jnp.meshgrid(xi, yi)]).T
    grid_idxs = np.argmin(cdist(grid, x_test), axis=1)
    grid = x_test[grid_idxs]
    grid_ans = y_test[grid_idxs]

    jac_I = ntk.get_jac(grid, code=-2, params=d_modified['params'][step_idx])
    jac_N = ntk.get_jac(grid, code=-1, params=d_modified['params'][step_idx])

    T_ii = ntk.get_ntk(jac1=jac_I, jac2=jac_I)
    T_in = ntk.get_ntk(jac1=jac_I, jac2=jac_N)
    T_nn = ntk.get_ntk(jac1=jac_N, jac2=jac_N)

    T = np.block([[T_ii, T_in], [T_in.T, T_nn]])
    T = T + 1e-9 * np.eye(T.shape[0])  # small diagonal jitter for numerical stability

    # eigh returns ascending eigenvalues; flip so index 0 is the largest and
    # row k of `eigvects` is the k-th eigenvector.
    eigvals, eigvects = np.linalg.eigh(T)
    eigvals = eigvals[::-1] / (res**2)
    eigvects = eigvects.T[::-1]

    # Target: solution values followed by a zero PDE residual.
    ans_flat = grid_ans.reshape(-1)
    ys_true = np.concatenate([ans_flat, jnp.zeros_like(ans_flat)])

    ys_ = model.net.apply(d_modified['params'][step_idx], grid)
    ys_res = model.data.pde(grid, (ys_, lambda x: model.net.apply(d_modified['params'][step_idx], x)))[0]
    ys_pred = np.concatenate([ys_.reshape(-1), ys_res.reshape(-1)])

    ys_diff = ys_pred - ys_true

    for k, eigvals_rank in enumerate([[10, 20, 50, 100, 200, 500, 1000], [10, 50, 500]]):

        for ys, name in [
            (ys_true, 'ys_true'),
            (ys_pred, 'ys_pred'),
            (ys_diff, 'ys_res'),
        ]:

            # Project onto the eigenvectors; partial sums give top-i reconstructions.
            coeffs = np.sum(ys * eigvects, axis=1)
            scaled_vects = coeffs[:,None] * eigvects

            fig, axs = plot_contours(
                xs=grid, 
                ys_list=[ys.reshape(2, -1).T] + [
                    np.sum(scaled_vects[:i], axis=0).reshape(2, -1).T
                    for i in eigvals_rank
                ], 
                titles=[series_titles[name]] + [
                    f'Top {i} eig.fn.'
                    for i in eigvals_rank
                ], 
                res=200, sym_colour=False, ptile=False, cbar=True)

            axs[0,0].set_ylabel('Experimental pts.')
            axs[1,0].set_ylabel('PDE Collocation pts.')
            fig.savefig(os.path.join(eigplot_folder, f'eigdecomp-s{steps_range[step_idx]}-{name}-{k}.png'), bbox_inches='tight', pad_inches=0.1)
            plt.close('all')

In [None]:
def plot_contours_eigval(xs, ys_list, titles, res=200, cbar=False):
    """Plot a field next to NTK eigenvectors, one row per output component.

    Column 0 shows ``ys_list[0]`` with its own min-max colour range. The
    remaining columns share a colour range that is symmetric about zero, with
    magnitude set by the 1st/99th percentiles of their values; values outside
    this range are clipped.

    Args:
        xs: (N, 2) point coordinates shared by all panels.
        ys_list: list of (N, R) arrays; R rows of panels are drawn.
        titles: column titles, set on the first row.
        res: interpolation grid size and number of contour levels.
        cbar: if True, add one colour bar per row (for the shared eigenvector
            colour range, or column 0's range when there are no other columns).

    Returns:
        ``(fig, axs)`` where `axs` is a 2-D (R, len(ys_list)) array of Axes.
    """
    def _levels(lo, hi):
        # contourf needs strictly increasing levels; widen a constant range.
        if not hi > lo:
            pad = max(abs(lo), 1.0) * 1e-6
            lo, hi = lo - pad, hi + pad
        return np.linspace(lo, hi, num=res)

    nrows = ys_list[0].shape[1]
    fig, axs = plt.subplots(
        nrows=nrows,
        ncols=len(ys_list),
        sharex=True,
        sharey=True,
        figsize=(4 * len(ys_list) - 3, 4 * nrows),
        constrained_layout=True,
        squeeze=False,  # always 2-D, so axs[i][j] works for a single row or column
    )

    p_d = 1
    for i in range(nrows):
        first, *rest = [y[:, i] for y in ys_list]
        cb = contour_on_ax(
            axs[i][0], xs, first,
            _levels(np.min(first), np.max(first)),
            res, rm_axis=False)
        if rest:
            min_ = np.percentile(rest, p_d)
            max_ = np.percentile(rest, 100 - p_d)
            m = max(-min_, max_)
            levels = _levels(-m, m)
            lo, hi = levels[0], levels[-1]
            for ax, zs in zip(axs[i][1:], rest):
                cb = contour_on_ax(ax, xs, np.clip(zs, lo + 1e-9, hi - 1e-9), levels, res, rm_axis=True)
        if cbar:
            fig.colorbar(cb, ax=axs[i])
    for ax, title in zip(axs[0], titles):
        ax.set_title(title)

    return fig, axs

In [None]:
# For each checkpoint step: plot the NTK eigenvalue spectrum, and the leading
# eigenvectors on a 90x90 grid next to the NN output / PDE residual.
eigvals_rank = [1, 2, 3, 10, 20, 100, 1000]  # 1-based eigenvalue ranks to plot
steps_range = [10000, 100000]

for step_idx in range(len(steps_range)):
    
    print('plots number', step_idx)
    
    # Delete every live device buffer left over from earlier cells/iterations.
    for buf in jax.devices()[0].client.live_buffers():
        buf.delete()
    
    model, model_aux = construct_model(

        pde_name='conv-1d', 
        data_seed=40,
        pde_const=(1.0,), 
        use_pdebench=True,
        test_max_pts=50000,
        include_ic=True,
        data_root='~/pdebench',

        # model params
        hidden_layers=8, 
        hidden_dim=128, 
        activation='tanh', 
        initializer='Glorot uniform', 
        arch=None, 

    )

    # Merge all snapshot_data* pickles of the run into one dict keyed by step
    # (key None holds the test set).
    # NOTE(review): pickle.load executes arbitrary code — only load trusted result files.
    d = dict()
    for file in os.listdir(example_folder):
        if file.startswith('snapshot_data'):
            fname = f'{example_folder}/{file}'
            with open(fname, 'rb') as f:
                d_update = pkl.load(f)
            d.update(d_update)

    x_test = d[None]['x_test']

    d_modified = {
        'x_test': x_test,
        'y_test': d[None]['y_test'],
        'steps': steps_range,
        'res_mean': [d[k]['residue_test_mean'] for k in steps_range],
        'err_mean': [d[k]['error_test_mean'] for k in steps_range],
        'err_q50': [np.percentile(d[k]['error_test'], 50) for k in steps_range],
        'err_q90': [np.percentile(d[k]['error_test'], 90) for k in steps_range],
        'err_q95': [np.percentile(d[k]['error_test'], 95) for k in steps_range],
        'err_q100': [np.percentile(d[k]['error_test'], 100) for k in steps_range],
        'res': [d[k]['residue_test'] for k in steps_range],
        'err': [d[k]['error_test'] for k in steps_range],
        'pred': [d[k]['pred_test'] for k in steps_range],
        'chosen_pts': [d[k]['al_intermediate']['chosen_pts'] for k in steps_range],
        'inv': [d[k]['params'][1] for k in steps_range],
        'params': [d[k]['params'][0] for k in steps_range],
    }

    ntk = NTKHelper(model)

    res = 90
    xi, yi = [jnp.linspace(jnp.min(x_test[:,i]), jnp.max(x_test[:,i]), res) for i in range(2)]
    grid = jnp.array([y.flatten() for y in jnp.meshgrid(xi, yi)]).T
    
    ys = model.net.apply(d_modified['params'][step_idx], grid)
    ys_res = model.data.pde(grid, (ys, lambda x: model.net.apply(d_modified['params'][step_idx], x)))[0]
    ys_pred_grid = jnp.concatenate([ys, ys_res], axis=1)

    jac_I = ntk.get_jac(grid, code=-2, params=d_modified['params'][step_idx])
    jac_N = ntk.get_jac(grid, code=-1, params=d_modified['params'][step_idx])

    T_ii = ntk.get_ntk(jac1=jac_I, jac2=jac_I)
    T_in = ntk.get_ntk(jac1=jac_I, jac2=jac_N)
    T_nn = ntk.get_ntk(jac1=jac_N, jac2=jac_N)

    T = jnp.block([[T_ii, T_in], [T_in.T, T_nn]])
    T = T + 1e-9 * jnp.eye(T.shape[0])  # small diagonal jitter for numerical stability

    # eigh returns ascending eigenvalues; flip so index 0 is the largest and
    # row k of `eigvects` is the k-th eigenvector.
    eigvals, eigvects = jnp.linalg.eigh(T)
    eigvals = eigvals[::-1] / (res**2)
    eigvects = eigvects.T[::-1]
    
    fig, ax = plt.subplots()
    ax.semilogy(eigvals[:1000])
    ax.set_xlabel('Eigenvalue rank')
    ax.set_ylabel('Eigenvalue')
    fig.tight_layout()
    fig.savefig(os.path.join(eigplot_folder, f's{steps_range[step_idx]}-eigval.pdf'))
    plt.close('all')
    
    # Ranks are 1-based while arrays are 0-based: rank i lives at index i - 1,
    # for both the eigenvector and the eigenvalue shown in its title.
    eigvects_modifies = [jnp.array([eigvects[idx-1, :res**2], eigvects[idx-1, res**2:]]).T for idx in eigvals_rank]
    fig, axs = plot_contours_eigval(
        grid, 
        [ys_pred_grid] + eigvects_modifies, 
        ['NN output'] + [rf'$\lambda_{{{i}}}$ = {eigvals[i - 1]:.1E}' for i in eigvals_rank], 
    )
    axs[0,0].set_ylabel('Experimental pts.')
    axs[1,0].set_ylabel('PDE Collocation pts.')
    fig.suptitle(f'Step {steps_range[step_idx]}', fontsize=1.8*plt.rcParams['font.size'])
    fig.savefig(os.path.join(eigplot_folder, f's{steps_range[step_idx]}-eigvect.png'), bbox_inches='tight', pad_inches=0.1)
    plt.close('all')

In [None]:
eigvals_rank = [0, 1, 10, 100]  # 0-based eigenvalue indices to plot

# Alternative configuration (PINNACLE-K run):
# example_folder = '../../al_pinn_results/conv-1d{1.0}_pb-40_ic/nn-None-8-128_adam_bcsloss-1.0_budget-1000-200-0/kmeans_alignment-none_mem_autoal/'
# eigplot_folder = '../../al_pinn_graphs_eigplots/conv-40'
# alg_name = 'PINNACLE-K'
# step_num = 50000

example_folder = '../../al_pinn_results/conv-1d{1.0}_pb-40_ic/nn-None-8-128_adam_bcsloss-1.0_budget-1000-200-0/sampling_alignment_scale-none_mem_autoal/20230901072031'
eigplot_folder = '../../al_pinn_graphs_eigplots/conv-40'
alg_name = 'PINNACLE-S'
step_num = 50000

# Later cells save figures into eigplot_folder, so make sure it exists.
os.makedirs(eigplot_folder, exist_ok=True)

In [None]:
# Delete every live device buffer, then rebuild the conv-1d model and load the run in `example_folder`.
[x.delete() for x in jax.devices()[0].client.live_buffers()];

model, model_aux = construct_model(

    pde_name='conv-1d', 
    data_seed=40,
    pde_const=(1.0,), 
    use_pdebench=True,
    test_max_pts=50000,
    include_ic=True,
    data_root='~/pdebench',

    # model params
    hidden_layers=8, 
    hidden_dim=128, 
    activation='tanh', 
    initializer='Glorot uniform', 
    arch=None, 

)

d = dict()

# Merge every snapshot_data* pickle of the run into one dict keyed by step (key None holds the test set).
# NOTE(review): pickle.load executes arbitrary code — only load trusted result files.
for file in os.listdir(example_folder):

    if file.startswith('snapshot_data'):

        fname = f'{example_folder}/{file}'

        with open(fname, 'rb') as f:
            d_update = pkl.load(f)

        d.update(d_update)

x_test = d[None]['x_test']

# Network parameters at `step_num` (element 0 of the stored params).
params = d[step_num]['params'][0]
# Training points chosen by the point selector at every stored step.
train_pts_series = {s: d[s]['al_intermediate']['chosen_pts'] for s in d.keys() if s is not None}
# Points newly selected at `step_num`.
chosen_pts = d[step_num]['al_intermediate']['new_points']

print(d[step_num]['error_test_mean'])

In [None]:
# For every stored step, the share of residual, bcs[0] ('IC') and bcs[1] ('BC')
# points in the chosen training set, saved as a stacked-area plot.
pts_prop = []
for s in sorted(train_pts_series.keys()):
    tr = train_pts_series[s]
    n_res = tr['res'].shape[0]
    n_ic = tr['bcs'][0].shape[0]
    n_bc = tr['bcs'][1].shape[0]
    pts_prop.append((n_res, n_ic, n_bc))
    
pts_prop = np.array(pts_prop)
# Normalise each step's counts to proportions; transpose to one series per point type.
plt.stackplot(sorted(train_pts_series.keys()), *(pts_prop / np.sum(pts_prop, axis=1)[:,None]).T, alpha=0.8,
             labels=['Residual', 'IC', 'BC'])
plt.legend(loc='lower right')
plt.xlabel('Steps')
plt.ylabel('Proportion of training set')
plt.xticks(range(0, 150001, 50000))
plt.savefig(os.path.join(eigplot_folder, f'all-pointsel_{alg_name}.pdf'), bbox_inches='tight', pad_inches=0.1)
plt.close()

In [None]:
# Empirical NTK of the trained network on a res x res grid spanning the test domain.
ntk = NTKHelper(model)

res = 90
xi, yi = [jnp.linspace(jnp.min(x_test[:,i]), jnp.max(x_test[:,i]), res) for i in range(2)]
grid = jnp.array([y.flatten() for y in jnp.meshgrid(xi, yi)]).T

# NN output and PDE residual on the grid, as two columns.
ys = model.net.apply(params, grid)
ys_res = model.data.pde(grid, (ys, lambda x: model.net.apply(params, x)))[0]
ys_pred_grid = jnp.concatenate([ys, ys_res], axis=1)

# Jacobians for the two blocks (code=-2 / code=-1). Presumably output and PDE residual,
# matching the column order of ys_pred_grid — confirm against NTKHelper.
jac_I = ntk.get_jac(grid, code=-2, params=params)
jac_N = ntk.get_jac(grid, code=-1, params=params)

T_ii = ntk.get_ntk(jac1=jac_I, jac2=jac_I)
T_in = ntk.get_ntk(jac1=jac_I, jac2=jac_N)
T_nn = ntk.get_ntk(jac1=jac_N, jac2=jac_N)

T = jnp.block([[T_ii, T_in], [T_in.T, T_nn]])
T = T + 1e-9 * jnp.eye(T.shape[0])  # small diagonal jitter for numerical stability

# eigh returns ascending eigenvalues; flip so index 0 is the largest and row k of
# `eigvects` is the k-th eigenvector. Eigenvalues are divided by res**2.
eigvals, eigvects = jnp.linalg.eigh(T)
eigvals = eigvals[::-1] / (res**2)
eigvects = eigvects.T[::-1]

In [None]:
# Evaluate the network and its Jacobians on a finer res=200 grid.
# NOTE: this rebinds res, grid and ys_pred_grid, which later cells use.
res = 200
xi, yi = [jnp.linspace(jnp.min(x_test[:,i]), jnp.max(x_test[:,i]), res) for i in range(2)]
grid = jnp.array([y.flatten() for y in jnp.meshgrid(xi, yi)]).T

ys = model.net.apply(params, grid)
ys_res = model.data.pde(grid, (ys, lambda x: model.net.apply(params, x)))[0]
ys_pred_grid = jnp.concatenate([ys, ys_res], axis=1)

jac_Ip = ntk.get_jac(grid, code=-2, params=params)
jac_Np = ntk.get_jac(grid, code=-1, params=params)

In [None]:
# Cross kernel between the coarse res=90 grid (jac_I, jac_N) and the fine res=200 grid (jac_Ip, jac_Np).
# Rebinds T. NOTE(review): assuming get_ntk(jac1, jac2) is (len1, len2), this is
# 2*90**2 x 2*200**2 float64 (~10 GB) — memory heavy.
T = jnp.block([
    [ntk.get_ntk(jac1=jac_I, jac2=jac_Ip), ntk.get_ntk(jac1=jac_I, jac2=jac_Np)], 
    [ntk.get_ntk(jac1=jac_N, jac2=jac_Ip), ntk.get_ntk(jac1=jac_N, jac2=jac_Np)]
])

In [None]:
# `eigvects`/`eigvals` come from the res=90 grid, but `res`, `grid` and
# `ys_pred_grid` are now at res=200. Slicing the coarse eigenvectors with the
# fine res**2 would give ragged arrays. Instead, extend the selected
# eigenvectors to the fine grid with the cross kernel T (coarse rows x fine
# columns) using the Nyström formula  v_fine = (v @ T) / lambda_raw.
# Here lambda_raw = eigvals * n_coarse undoes the 1/res**2 scaling applied at res=90.
# NOTE(review): assumes T is the coarse-by-fine block kernel from the cell above — confirm.
n_coarse = eigvects.shape[0] // 2  # coarse points per block (output / PDE residual)
rank_idx = jnp.array(eigvals_rank)
eigvects_fine = (eigvects[rank_idx] @ T) / (eigvals[rank_idx][:, None] * n_coarse)

eigvects_modifies = [jnp.array([v[:res**2], v[res**2:]]).T for v in eigvects_fine]
fig, axs = plot_contours_eigval(
    grid, 
    [ys_pred_grid] + eigvects_modifies, 
    ['NN output'] + [rf'$\lambda_{{{i}}}$ = {eigvals[i]:.1E}' for i in eigvals_rank], 
    cbar=True,
)
# Overlay the selected points on every eigenvector panel.
for ax_row in axs:
    for ax in ax_row[1:]:
        plot_training_data(ax, chosen_pts)
axs[0,0].set_ylabel('Prediction')
axs[1,0].set_ylabel('PDE Residual')
fig.suptitle(alg_name)
fig.savefig(os.path.join(eigplot_folder, f's{step_num}-pointsel_{alg_name}.png'), bbox_inches='tight', pad_inches=0.1)
plt.close('all')

In [None]:
# Residual-only version of the point-selection figure.
# `eigvects`/`eigvals` come from the res=90 grid while `res`/`grid` are now
# res=200, so the selected eigenvectors are first extended to the fine grid
# with the cross kernel T (coarse rows x fine columns):
#   v_fine = (v @ T) / lambda_raw, with lambda_raw = eigvals * n_coarse.
# NOTE(review): assumes T is the coarse-by-fine block kernel computed earlier — confirm.
n_coarse = eigvects.shape[0] // 2  # coarse points per block (output / PDE residual)
rank_idx = jnp.array(eigvals_rank)
eigvects_fine = (eigvects[rank_idx] @ T) / (eigvals[rank_idx][:, None] * n_coarse)

# Keep only the PDE-residual half of each eigenvector.
eigvects_modifies = [jnp.array([v[res**2:]]).T for v in eigvects_fine]
fig, axs = plot_contours_eigval(
    grid, 
    [ys_pred_grid[:,1:2]] + eigvects_modifies, 
    ['NN output'] + [rf'$\lambda_{{{i}}}$ = {eigvals[i]:.1E}' for i in eigvals_rank], 
    cbar=True,
)
# Overlay the selected points on the eigenvector panels of this figure
# (previously a stale `ax` from an earlier, already-closed figure was used).
for ax_row in axs:
    for ax in ax_row[1:]:
        plot_training_data(ax, chosen_pts)
axs[0][0].set_ylabel('PDE Residual')
fig.suptitle(alg_name)
fig.savefig(os.path.join(eigplot_folder, f's{step_num}-pointsel_{alg_name}-resonly.png'), bbox_inches='tight', pad_inches=0.1)
plt.close('all')

In [None]:
# Inspect the 'P' matrix stored by the point selector at step 0 of the loaded run.
P = d[0]['al_intermediate']['P']
P.shape

In [None]:
# Heat map of log|P|; exact zeros become -inf (with a RuntimeWarning from np.log).
plt.imshow(np.log(np.abs(P)))
plt.colorbar()

In [None]:
# Scatter of P's last column against its second-to-last column, labelled as the 1st and 2nd eigenvector coefficients.
# idx = np.argsort(np.linalg.norm(P, axis=1))[-50:]
plt.plot(P[:,-1], P[:,-2], '.')
# plt.plot(P[idx,-1], P[idx,-2], '.')
plt.xlabel('1st eigenvector coefficient')
plt.ylabel('2nd eigenvector coefficient')
# plt.xscale('symlog')
# plt.yscale('symlog')

In [None]:
# Same scatter as the previous cell, zoomed in around the origin.
# idx = np.argsort(np.linalg.norm(P, axis=1))[-50:]
plt.plot(P[:,-1], P[:,-2], '.')
# plt.plot(P[idx,-1], P[idx,-1000], '.')
plt.xlabel('1st eigenvector coefficient')
# Column -2 is plotted, which the previous cell labels the 2nd eigenvector
# coefficient; plot P[:, -100] instead to compare against the 100th.
plt.ylabel('2nd eigenvector coefficient')
# plt.xscale('symlog')
# plt.yscale('symlog')
plt.xlim(-0.008, 0.008)
plt.ylim(-0.008, 0.008)

In [None]:
# Set up the 2D fluid-dynamics inverse-problem case: free device buffers, choose the
# results folder and plotted steps, then rebuild the matching model and NTK helper.
[x.delete() for x in jax.devices()[0].client.live_buffers()];

case_folder, steps_plot, suptit = [
    'al_pinn_results/fd-2d{1.0-0.01}_inv_anc[0,1]/nn-None-4-64_adam_bcsloss-1.0_budget-1000-200-30',
    [0, 5000, 10000, 50000],
    '2D Fluid Dynamics (Inv)',
]
    
# Run folders that are skipped during loading.
throwout = []

print('PROCESSING:', case_folder)

max_steps = steps_plot[-1]

root_folder = os.path.join(data_folder, case_folder)

# Parse 'nn-<arch>-<depth>-<width>' from the last path component (e.g. 'nn-None-4-64').
_, arch, depth, width = root_folder.split('/')[-1].split('_')[0].split('-')


model, _ = construct_model(

    pde_name='fd-2d', 
    data_seed=40,
    pde_const=(1.0, 0.01), 
    use_pdebench=False,
    test_max_pts=50000,
    include_ic=True,
    inverse_problem=True,

    # model params
    hidden_layers=int(depth), 
    hidden_dim=int(width), 
    arch=(None if arch == 'None' else arch),
    activation='tanh', 
    initializer='Glorot uniform', 

)

ntk = NTKHelper(model)

In [None]:
# Load up to `max_runs` runs per algorithm for this case, keeping only runs
# that reached `max_steps`, and derive summary statistics for each run.
# NOTE(review): pickle.load executes arbitrary code — only load trusted result files.
cases = {
    x: sorted(os.listdir(f'{root_folder}/{x}'))  # sorted: os.listdir order is arbitrary
    for x in algs.keys() if os.path.exists(f'{root_folder}/{x}')
}
print('Exist:', list(cases.keys()))

data = dict()
steps_min = dict()
plotted_cases = dict()

for c in cases.keys():

    s_min = float('inf')

    runs = []
    runs_cases = []

    for r in cases[c]:

        try:

            d = dict()

            # Only read files while under the cap; beyond it `d` stays empty
            # and the run ends up in `throwout`.
            if len(runs) < max_runs:

                for file in os.listdir(f'{root_folder}/{c}/{r}'):

                    if file.startswith('snapshot_data'):

                        fname = f'{root_folder}/{c}/{r}/{file}'

                        with open(fname, 'rb') as f:
                            d_update = pkl.load(f)

                        d.update(d_update)

            steps_range = sorted([x for x in d.keys() if (x is not None) and (max_steps >= x)])
            # Keep the run only if it has test data (key None) and reached max_steps.
            if (len(steps_range) > 0) and (max_steps == steps_range[-1]) and (None in d.keys()):

                print(c, r, sorted([x for x in d.keys() if (x is not None)])[-1])

                s_min = min(s_min, steps_range[-1])

                x_test = d[None]['x_test']

                d_modified = {
                    'x_test': x_test,
                    'y_test': d[None]['y_test'],
                    'steps': steps_range,
                    'res_mean': [d[k]['residue_test_mean'] for k in steps_range],
                    'err_mean': [d[k]['error_test_mean'] for k in steps_range],
                    'err_q50': [np.percentile(d[k]['error_test'], 50) for k in steps_range],
                    'err_q90': [np.percentile(d[k]['error_test'], 90) for k in steps_range],
                    'err_q95': [np.percentile(d[k]['error_test'], 95) for k in steps_range],
                    'err_q100': [np.percentile(d[k]['error_test'], 100) for k in steps_range],
                    'res': [d[k]['residue_test'] for k in steps_range],
                    'err': [d[k]['error_test'] for k in steps_range],
                    'pred': [d[k]['pred_test'] for k in steps_range],
                    'chosen_pts': [d[k]['al_intermediate']['chosen_pts'] for k in steps_range],
                    'inv': [d[k]['params'][1] for k in steps_range],
                    'params': [d[k]['params'][0] for k in steps_range],
                }

                if x_test.shape[1] == 2:

                    # Reshape outputs onto the (outputs, n_x0, n_x1) test grid
                    # and compare prediction and truth in 2-D Fourier space.
                    arr_shape = [d_modified['y_test'].shape[1]] + [np.unique(x).shape[0] for x in d_modified['x_test'].T]
                    d_modified['y_test_fft'] = np.fft.fftn(
                        d_modified['y_test'].reshape(*arr_shape), 
                        axes=[1, 2]
                    )
                    d_modified['pred_fft'] = [np.fft.fftn(
                        y.reshape(*arr_shape), axes=[1, 2]) 
                        for y in d_modified['pred']]

                    d_modified['fft_err'] = [np.abs(yf - d_modified['y_test_fft'])
                        for yf in d_modified['pred_fft']]

                    idxs = np.array(np.meshgrid(np.arange(arr_shape[1]), np.arange(arr_shape[2]))).swapaxes(1, 2)

                    # Index bands: both FFT indices <= 4 (low), <= 12 but not low (mid), the rest (high).
                    klow = (idxs <= 4).all(axis=0).astype(float)
                    kmid = (idxs <= 12).all(axis=0).astype(float) - klow
                    khigh = (idxs <= np.inf).all(axis=0).astype(float) - kmid - klow

                    for s, k in [('low', klow), ('mid', kmid), ('high', khigh)]:
                        d_modified[f'fft_mean_{s}'] = [np.sum(yf * k[None, :]) / (np.sum(k) * yf.shape[0])
                            for yf in d_modified['fft_err']]

                if 'darcy' in case_folder:
                    a_pred = [y[:,0] for y in d_modified['pred']]
                    a_true = d_modified['y_test'][:,0]
                    d_modified['a_err_mean'] = [np.mean((a_true - y)**2) for y in a_pred]
                    d_modified['a_err_q90'] = [np.percentile((a_true - y)**2, 90) for y in a_pred]
                    f_true = np.array(a_true > 0.5, dtype=float)
                    f_pred = [np.array(y > 0.5, dtype=float) for y in a_pred]
                    d_modified['bool_err_mean'] = [np.mean(np.abs(f_true - y)) for y in f_pred]
                    multidim = True
                elif 'reacdiff' in case_folder:
                    multidim = True
                elif 'fd-2d' in case_folder:
                    multidim = True
                else:
                    multidim = False

                if 'inv' in case_folder:
                    # True inverse-problem constants are encoded in the folder name, e.g. '{1.0-0.01}'.
                    d_modified['inv_param_true'] = [float(x) for x in case_folder.split('{')[1].split('}')[0].split('-')]
                    d_modified['inv_param_pred'] = tuple([float(d[k]['params'][1][i]) for k in steps_range] for i in range(len(d_modified['inv_param_true'])))
                    inv_params = d_modified['inv_param_true']
                    plot_inv_param = True
                else:
                    plot_inv_param = False

                runs.append(d_modified)
                runs_cases.append(r)

            else:
                throwout.append(f'{root_folder}/{c}/{r}')

        except Exception as e:
            # Best effort: skip unreadable or incomplete runs, but report why.
            print(f'Skipping {root_folder}/{c}/{r}: {e!r}')
            throwout.append(f'{root_folder}/{c}/{r}')

    if len(runs) > 0:
        data[c] = runs
        steps_min[c] = s_min
        plotted_cases[c] = runs_cases

print('To plot algorithms =', {k: len(data[k]) for k in data.keys()})
print()
print()

In [None]:
# For every algorithm, run and stored step: the mean row norm of each of the two
# blocks returned by ntk.get_pde_jac_crossterm, on every 100th test point.
# NOTE(review): x_test is whatever the last loaded run left behind; the np.array
# below assumes every run of an algorithm has the same number of steps.
all_vals = dict()

for c in data.keys():
    val = []
    for d in data[c]:
        trials = []
        # p = (network params, inverse-problem params) at one step
        for p in zip(d['params'], d['inv']):
            jac_cross = ntk.get_pde_jac_crossterm(x_test[::100], p)
            cross_size = [
                jnp.mean(jnp.linalg.norm(jnp.concatenate(list(jac_cross[i].values()), axis=1), axis=1))
                for i in range(2)
            ]
            trials.append(cross_size)
        val.append(trials)
    # shape (runs, steps, 2)
    all_vals[c] = np.array(val)

In [None]:
# Mean over runs (with one-sided +1 std error bars) of the first cross-term size, per step.
fig, ax = plt.subplots()
ax.set_yscale("log")
for c in data.keys():
    v = all_vals[c][...,0]  # (runs, steps)
    mean = np.mean(v, axis=0)
    err = np.std(v, axis=0)
    label, marker = algs[c]  # legend label, matplotlib style kwargs
    ax.errorbar(data[c][0]['steps'], mean, [np.zeros_like(err), err], capsize=2, label=label, alpha=0.7, **marker)
ax.set_xlabel('Steps')
ax.set_ylabel('Mean cross-term norm')
ax.legend()

In [None]:
# Load the step-100000 snapshot of one conv-1d PINNACLE-K run (pickle: only load trusted files).
with open('../../al_pinn_results/conv-1d{1.0}_pb-40_ic/nn-None-8-128_adam_bcsloss-1.0_budget-1000-200-0/kmeans_alignment_scale-none_mem_autoal/20230912192601/snapshot_data_s100000.pkl', 'rb') as f:
    d = pkl.load(f)

In [None]:
# Keep only the stored params at step 100000 and free the rest of the snapshot.
params = d[100000]['params']
del d

In [None]:
# Run one round of active point selection starting from the loaded checkpoint
# `params`, and store the selected points.
# This is a single configuration of a sweep over factor in [n, 2*n, 4*n],
# res_scale in [2, 4] and a in ['eig_kmeans', 'eig_sampling', 'residue'].
n = 100  # points selected per round; also the memory budget below
active_set = {}  # a -> {(factor, res_scale): result of the selection round}

a = 'eig_kmeans'
factor = 1000
res_scale = 4

model, model_aux = construct_model(
    pde_name='conv-1d', 
    data_seed=40,
    pde_const=(1.,), 
    use_pdebench=True,
    num_domain=2000, 
    num_boundary=500, 
    num_initial=500,
    include_ic=True,
    data_root='~/pdebench',
    test_max_pts=50000,
    hidden_layers=8, 
    hidden_dim=128, 
    activation='tanh', 
    initializer='Glorot uniform', 
)

if a == 'residue':
    al_args = dict(
        res_proportion=0.8,
        select_icbc_with_residue=True,
        select_anc_with_residue=True,
    )

elif a.startswith('eig'):
    al_args = dict(
        num_points_round=n,
        weight_method= "alignment", 
        num_candidates_res=res_scale*factor,
        num_candidates_bcs=factor,
        num_candidates_init=factor,
        memory=True, # True to remember old points and add on new ones
        sampling='pseudo', # uniform, pseudo
        min_num_points_bcs=1,
        min_num_points_res=1,
    )

else:
    raise ValueError(f'Unknown point selector {a!r}')

# A single training step: just enough to trigger one selection round.
optim_args = dict(
    train_steps=1,
    al_every=1,
    select_anchors_every=1,
    snapshot_every=100,
    optim_method='adam', 
    optim_lr=1e-3, 
    optim_args=dict(),
)


train_loop = ModifiedTrainLoop(
    model=model, 
    point_selector_method=a,
    point_selector_args=al_args,
    mem_pts_total_budget=n,
    anchor_budget=10,
    autoscale_loss_w_bcs=False,
    ntk_ratio_threshold=0.5,
    tensorboard_plots=False,
    **optim_args,
)

# Start from the loaded checkpoint instead of the fresh initialisation.
model.params = params
model.net.params = params[0]

train_loop.train()

active_set.setdefault(a, {})[factor, res_scale] = train_loop.al_data_round[0]

In [None]:
# Intermediate point-selection data from the step-0 snapshot of the training loop above.
demo_data = {
    k: train_loop.snapshot_data[0]['al_intermediate'][k]
    for k in ['P', 'candidate_pts', 'residual_candidates', 'K_train_test', 'NTK']
}

In [None]:
# Cache demo_data to disk so the following cells can run without retraining.
with open(f'int_data_demo.pkl', 'wb+') as f:
    pkl.dump(demo_data, f)

In [None]:
# Reload the cached demo_data (pickle: only load trusted files).
with open(f'int_data_demo.pkl', 'rb') as f:
    demo_data = pkl.load(f)

In [None]:
from sklearn.cluster._kmeans import kmeans_plusplus
from scipy.spatial import ConvexHull

In [None]:
# Candidate point arrays in order [res, bcs[0], bcs[1], anc].
# NOTE: this rebinds `data`, which earlier held the per-algorithm run dict.
data = [
    demo_data['candidate_pts']['res'],
    demo_data['candidate_pts']['bcs'][0],
    demo_data['candidate_pts']['bcs'][1],
    demo_data['candidate_pts']['anc']
]

# Split the rows of P.T into one block per candidate array.
# Assumes P.T rows follow the same concatenated candidate order — TODO confirm.
i = 0
Ps = []
for d in data:
    j = d.shape[0]
    Ps.append(np.array(demo_data['P'].T)[i:i+j])
    i += j

In [None]:
# Candidate pool sizes per trial for (res, bcs[0], bcs[1], anc):
# res and anc are scaled by the multiplier j, the two bcs pools are not.
pools = []
for j in (1, 2, 4):
    for i in (100, 200, 400, 800):
        pools += [(j * i, i, i, j * i)]

pools

In [None]:
# Compare three ways of choosing n points from the candidate pools in `pools`:
# k-means++ seeding, sampling proportional to the squared row norm of P, and
# greedily taking the largest norms. Each comparison is repeated 10 times.
n = 100
SEED = 0
rs = np.random.RandomState(SEED)  # one seeded RNG so the whole experiment is reproducible
pts = {k: [] for k in ['kmeans', 'sampling', 'greedy']}


for _ in range(10):

    for prop in pools:

        # Anchor candidates are excluded from selection (pool size forced to 0).
        prop = list(prop)
        prop[-1] = 0

        # The candidate subsets do not depend on the selector, so build them once.
        data_sub = [d[:i] for (i, d) in zip(prop, data)]
        P_sub = np.concatenate([d[:i] for (i, d) in zip(prop, Ps)])

        for k in pts.keys():

            if k == 'kmeans':
                _, idxs = kmeans_plusplus(P_sub, n, random_state=rs)
                idxs = np.sort(idxs)
            elif k == 'greedy':
                v = np.linalg.norm(P_sub, axis=1)
                idxs = np.argsort(v)[-n:]
            elif k == 'sampling':
                v = np.linalg.norm(P_sub, axis=1)
                prob = v**2 / np.sum(v**2)
                idxs = rs.choice(P_sub.shape[0], size=n, replace=False, p=prob)
            else:
                continue

            # Map the flat indices back to per-pool selections.
            data_chosen = []
            i = 0
            for d in data_sub:
                id_chosen = idxs[(i <= idxs) & (idxs < i + d.shape[0])] - i
                data_chosen.append(d[id_chosen])
                i += d.shape[0]

            pts[k].append(data_chosen)

In [None]:
# For each selector: the numbers of bcs[0] vs bcs[1] points chosen in every trial,
# with the convex hull of those counts outlined.
for i, k in enumerate(pts.keys()):
    pt = pts[k]
    # one row per trial: number of points chosen from each pool
    szs = np.array([[d.shape[0] for d in e] for e in pt])
    plt.plot(szs[:,1], szs[:,2], '.', color=f'C{i}', label=k, alpha=0.6)
    hull = ConvexHull(szs[:,1:3])
    for simplex in hull.simplices:
        plt.plot(szs[simplex, 1], szs[simplex, 2], color=f'C{i}')
    
plt.axis('square')
plt.legend()

In [None]:
# One row per selector, one column per trial: the points chosen in trials 8-11,
# i.e. the pools with multiplier 4 in the first repetition.
fig, axs = plt.subplots(
    nrows=3, 
    ncols=4, 
    sharex=True, 
    sharey=True, 
    figsize=(15, 11),
    constrained_layout=True
)

for ax_row, a in zip(axs, pts.keys()):
    for ax, samples in zip(ax_row, pts[a][8:12]):
        
        ms = 3.
        # samples = [res, bcs[0], bcs[1], anc]; legend entries show the point counts
        ax.plot(samples[0][:, 0], samples[0][:, 1], 'o', color='black', ms=ms, label=samples[0].shape[0])
#         ax.plot(samples[3][:, 0], samples[3][:, 1], '^', color='blue', ms=ms, label=samples['anc'].shape[0])
        for i in range(2):
            ax.plot(samples[i+1][:, 0], samples[i+1][:, 1], 's', color=f'C{i+1}', ms=ms, label=samples[i+1].shape[0])
        ax.legend(loc='upper left', borderpad=0.1)