In [None]:
# Automatically reload edited Python modules (e.g. deepxde_al_patch) before each cell runs.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [None]:
from functools import partial
import os
import pickle as pkl
from collections.abc import MutableMapping

import matplotlib.pyplot as plt
import matplotlib.tri as tri
import numpy as np
import tqdm

# Hide all CUDA devices so JAX runs on the CPU, and select the JAX backend for
# DeepXDE. Both variables are set before `jax` and `deepxde_al_patch` are imported below.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["DDE_BACKEND"] = "jax"

# Optional XLA GPU-memory allocation settings (disabled):
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".XX"
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"

from jax import config
# Use 64-bit floating point in JAX (its default is 32-bit).
config.update("jax_enable_x64", True)
# config.update("jax_debug_nans", True)

import jax
import jax.numpy as jnp
import flax
from flax import linen as nn
import optax

# Report the devices JAX can see. With CUDA_VISIBLE_DEVICES="" there is normally
# no GPU backend, and querying it raises. Treat that as zero GPUs rather than
# hiding the whole report (including the CPU count) behind a bare `except`.
try:
    n_gpus = jax.local_device_count("gpu")
except Exception:  # e.g. RuntimeError for an unknown/unavailable backend
    n_gpus = 0
print(f'Jax: CPUs={jax.local_device_count("cpu")} - GPUs={n_gpus}')
    
import deepxde_al_patch.deepxde as dde

from deepxde_al_patch.model_loader import construct_model, construct_net
from deepxde_al_patch.modified_train_loop import ModifiedTrainLoop
from deepxde_al_patch.plotters import plot_residue_loss, plot_error, plot_prediction
from deepxde_al_patch.train_set_loader import load_data

from deepxde_al_patch.ntk import NTKHelper
from deepxde_al_patch.utils import get_pde_residue, print_dict_structure

In [None]:
# Figure defaults: 300 dpi output, 12pt text, and matplotlib's built-in
# mathtext renderer instead of an external LaTeX installation.
plt.rcParams.update({
    'figure.dpi': 300,
    'font.size': 12,
    'text.usetex': False,
})

In [None]:
# Algorithm result-folder name -> (legend label, matplotlib format string).
# Insertion order fixes the order in which the curves are plotted.
_alg_specs = [
    ('random', 'Random', 'x:k'),
    ('sampling_residue_scale-none_mem_autoal', 'Residue', '^--r'),
    ('greedy_nystrom_wo_N_scale-none_mem_autoal', 'PINNAcLe-Gr', 'p-m'),
    ('sampling_nystrom_wo_N_scale-none_mem_autoal', 'PINNAcLe-Sa', 'h-c'),
    ('kmeans_nystrom_wo_N_scale-none_mem_autoal', 'PINNAcLe-KM', 'o-b'),
]
algs = {folder: (label, fmt) for folder, label, fmt in _alg_specs}

In [None]:
# Root folder holding the `al_pinn_results*` inputs and the `al_pinn_graphs`
# output directory, relative to the notebook's working directory.
data_folder = '../../'

In [None]:
# Each entry: (result folder relative to `data_folder`, training steps to show in
# the per-step contour panels). The last entry of the step list is also the
# step every run must have reached to be loaded.
case_list = [

    (
        'al_pinn_results/conv-1d{1.0}_pb-40_anc/nn-None-8-128_adam_bcsloss-1.0_budget-1000-200-2',
        [0, 10000, 20000, 50000, 100000, 150000],
    ),

    (
        # The comma after the folder string is required: without it the list
        # below is parsed as a subscript of the string and raises TypeError.
        'al_pinn_results/conv-1d{1.0}_pb-40_ic/nn-None-8-128_adam_bcsloss-1.0_budget-1000-200-0',
        [0, 10000, 20000, 50000, 100000, 200000],
    ),

    (
        'al_pinn_results/conv-1d{1.0}_pb-40_inv_anc/nn-None-8-128_adam_bcsloss-1.0_budget-1000-200-5',
        [0, 10000, 20000, 50000, 100000, 200000],
    ),

    (
        'al_pinn_results_ic_change/conv-1d{1.0}_ftic-40-898_anc/nn-None-8-128_adam_bcsloss-1.0_budget-200-50-0',
        [0, 20000, 40000, 60000, 80000, 100000],
    ),

    (
        'al_pinn_results/conv-1d{1.0}_pb-80_anc/nn-None-8-128_adam_bcsloss-1.0_budget-1000-200-2',
        [0, 20000, 40000, 60000, 80000, 100000],
    ),

    (
        'al_pinn_results/conv-1d{1.0}_pb-80_ic/nn-None-8-128_adam_bcsloss-1.0_budget-1000-200-0',
        [0, 10000, 20000, 50000, 100000, 200000],
    ),

    (
        'al_pinn_results/conv-1d{1.0}_pb-80_inv_anc/nn-None-8-128_adam_bcsloss-1.0_budget-1000-200-5',
        [0, 10000, 20000, 50000, 100000, 200000],
    ),

    (
        'al_pinn_results_ic_change/conv-1d{1.0}_ftic-80-272_ic/nn-None-8-128_adam_bcsloss-1.0_budget-200-50-0',
        [0, 20000, 40000, 60000, 80000, 100000],
    ),

    (
        'al_pinn_results_ic_change/conv-1d{1.0}_ftic-80-272_anc/nn-None-8-128_adam_bcsloss-1.0_budget-200-50-2',
        [0, 20000, 40000, 60000, 80000, 100000],
    ),

]
assert all(len(entry) == 2 for entry in case_list), "each case must be (folder, steps)"

# Index of the case analysed by the rest of the notebook.
case_idx = 8
case_folder, steps_plot = case_list[case_idx]
max_steps = steps_plot[-1]

In [None]:
# Directory of the selected case; it contains one sub-folder per algorithm in `algs`.
root_folder = os.path.join(data_folder, case_folder)

In [None]:
# The last path component starts with the network spec, e.g. "nn-None-8-128_adam_...":
# "<prefix>-<arch>-<depth>-<width>" up to the first underscore.
net_spec = root_folder.split('/')[-1].split('_')[0]
_, arch, depth, width = net_spec.split('-')

net, _ = construct_net(
    input_dim=2,
    output_dim=1,
    hidden_layers=int(depth),
    hidden_dim=int(width),
    arch=None if arch == 'None' else arch,
)

In [None]:
# Algorithm folder -> list of its run sub-folders, for algorithms present in this case.
cases = {}
for alg_name in algs.keys():
    alg_dir = f'{root_folder}/{alg_name}'
    if os.path.exists(alg_dir):
        cases[alg_name] = os.listdir(alg_dir)
cases.keys()

In [None]:
# Load every run of every algorithm and condense it into `data[c]` (a list with
# one dict per kept run). A run is kept only if its last snapshot at or below
# `max_steps` is exactly `max_steps`; `steps_min[c]` records the smallest such
# final step (inf if no run was kept).
# NOTE(review): pickle.load can execute arbitrary code -- only point this at
# result folders you produced yourself.
data = dict()
steps_min = dict()

for c in cases.keys():
    
    s_min = float('inf')
    
    runs = []
    for r in cases[c]:
        
        d = dict()
        
        # Merge all `snapshot_data*` pickles of this run into one dict. Its keys
        # are training steps plus a None key holding x_test / y_test (read below).
        for file in os.listdir(f'{root_folder}/{c}/{r}'):
            
            if file.startswith('snapshot_data'):
        
                fname = f'{root_folder}/{c}/{r}/{file}'

                with open(fname, 'rb') as f:
                    d_update = pkl.load(f)
                
                d.update(d_update)
        
        # Saved steps up to and including max_steps.
        steps_range = sorted([x for x in d.keys() if (x is not None) and (max_steps >= x)])
        if (len(steps_range) > 0) and (max_steps == steps_range[-1]):
            
            # Log algorithm, run folder and the latest step saved for the run.
            print(c, r, sorted([x for x in d.keys() if (x is not None)])[-1])
            
            s_min = min(s_min, steps_range[-1])
            
            x_test = d[None]['x_test']

            # Per-step quantities as lists aligned with `steps`. Predictions are
            # recomputed with `net` from params[0] when `pred_test` was not saved.
            # NOTE(review): params[1] is stored as 'inv' -- presumably the
            # inverse-problem parameters; confirm.
            d_modified = {
                'x_test': x_test,
                'y_test': d[None]['y_test'],
                'steps': steps_range,
                'res_mean': [d[k]['residue_test_mean'] for k in steps_range],
                'err_mean': [d[k]['error_test_mean'] for k in steps_range],
                'res': [d[k]['residue_test'] for k in steps_range],
                'err': [d[k]['error_test'] for k in steps_range],
                'pred': [d[k]['pred_test'] if 'pred_test' in d[k].keys() 
                         else net.apply(d[k]['params'][0], x_test)
                         for k in steps_range],
                'chosen_pts': [d[k]['al_intermediate']['chosen_pts'] for k in steps_range],
                'inv': [d[k]['params'][1] for k in steps_range],
            }
            
            # Grid shape: [n_outputs] + [number of distinct values in each x_test
            # column]; the FFT runs over the two coordinate axes.
            # NOTE(review): reshaping (n_points, n_outputs) straight into
            # (n_outputs, ...) is only meaningful for n_outputs == 1, and it
            # assumes x_test is a full tensor-product grid in matching order.
            arr_shape = [d_modified['y_test'].shape[1]] + [np.unique(x).shape[0] for x in d_modified['x_test'].T]
            d_modified['y_test_fft'] = np.fft.fftn(
                d_modified['y_test'].reshape(*arr_shape), 
                axes=[1, 2]
            )
            d_modified['pred_fft'] = [np.fft.fftn(
                y.reshape(*arr_shape), axes=[1, 2]) 
                for y in d_modified['pred']]
            
            # Magnitude of the spectral error at every FFT index, per step.
            d_modified['fft_err'] = [np.abs(yf - d_modified['y_test_fft'])
                for yf in d_modified['pred_fft']]
            
#             idxs = np.meshgrid(np.arange(arr_shape[1]), np.arange(arr_shape[2]))[0].T
#             idxs = np.array([idxs, idxs])
            
            # idxs has shape (2, n1, n2) with idxs[0][i, j] == i and idxs[1][i, j] == j.
            idxs = np.array(np.meshgrid(np.arange(arr_shape[1]), np.arange(arr_shape[2]))).swapaxes(1, 2)
            
            # Band masks over FFT indices: low = both indices <= 4, mid = both
            # <= 12 but not low, high = all remaining indices.
            # NOTE(review): only small non-negative indices count as low
            # frequencies; indices near the end of an axis (negative frequencies,
            # cf. np.fft.fftfreq) fall into "high". Confirm this is intended.
            klow = (idxs <= 4).all(axis=0).astype(float)
            kmid = (idxs <= 12).all(axis=0).astype(float) - klow
            khigh = (idxs <= np.inf).all(axis=0).astype(float) - kmid - klow
            
            # Mean spectral error per band: averaged over the band's FFT indices
            # and over the leading (output) axis yf.shape[0].
            for s, k in [('low', klow), ('mid', kmid), ('high', khigh)]:
                d_modified[f'fft_mean_{s}'] = [np.sum(yf * k[None, :]) / (np.sum(k) * yf.shape[0])
                    for yf in d_modified['fft_err']]

            runs.append(d_modified)
        
    data[c] = runs
    steps_min[c] = s_min

In [None]:
# Output directory for figures, mirroring the case's path under `al_pinn_graphs`.
graph_folder = os.path.join(data_folder, 'al_pinn_graphs', case_folder)
os.makedirs(graph_folder, exist_ok=True)

In [None]:
# Visual check of the low / mid / high FFT-index masks built in the loading cell
# (transposed so that the second index runs vertically).
for band_mask in (klow, kmid, khigh):
    plt.imshow(band_mask.T)
    plt.show()

In [None]:
def contour_on_ax(ax, xs, zs, levels, res=200):
    """Draw a filled contour plot of scattered values on ``ax``.

    The samples ``zs`` located at ``xs`` are linearly interpolated on a Delaunay
    triangulation and evaluated on a regular ``res x res`` grid spanning the
    bounding box of the points.

    Args:
        ax: matplotlib Axes to draw on.
        xs: (n, >=2) array of sample coordinates; column 0 maps to the
            horizontal axis and column 1 to the vertical axis.
        zs: length-n array of values at the sample points.
        levels: contour levels forwarded to ``ax.contourf``.
        res: number of interpolation grid points per axis.

    Returns:
        The contour set returned by ``ax.contourf`` (usable for a colorbar).
    """
    xi, yi = [np.linspace(np.min(xs[:, i]), np.max(xs[:, i]), res) for i in range(2)]
    triang = tri.Triangulation(xs[:, 0], xs[:, 1])
    interpolator = tri.LinearTriInterpolator(triang, zs)
    # Points of the regular grid outside the triangulation come back masked.
    Xi, Yi = np.meshgrid(xi, yi)
    zi = interpolator(Xi, Yi)
    return ax.contourf(xi, yi, zi, levels=levels, cmap="RdBu_r")


def plot_contours(xs, ys_list, titles, res=200):
    """Plot several scattered fields side by side on a shared colour scale.

    Args:
        xs: (n, 2) array of sample coordinates shared by all panels.
        ys_list: sequence of length-n value arrays, one per panel.
        titles: panel titles, matched to ``ys_list`` by position.
        res: number of contour levels and interpolation grid points per axis.

    Returns:
        ``(fig, axs)``, where ``axs`` is a flat list of the panel Axes (also
        when there is a single panel).
    """
    n_panels = len(ys_list)
    # squeeze=False keeps `axs` a 2-D array even for a single panel, so the
    # zip/ravel below work for any number of panels.
    fig, axs = plt.subplots(
        nrows=1,
        ncols=n_panels,
        sharex=True,
        sharey=True,
        figsize=(4 * (n_panels + 1), 4),
        squeeze=False,
    )
    axs = axs.ravel().tolist()

    # Shared levels make colours comparable across panels.
    vmin, vmax = np.min(ys_list), np.max(ys_list)
    if vmin == vmax:
        # contourf needs strictly increasing levels; widen a constant field slightly.
        pad = max(abs(vmin), 1.0) * 1e-6
        vmin, vmax = vmin - pad, vmax + pad
    levels = np.linspace(vmin, vmax, num=res)

    for ax, zs, title in zip(axs, ys_list, titles):
        cb = contour_on_ax(ax, xs, zs, levels, res)
        ax.set_title(title)
    fig.colorbar(cb, ax=axs)
    return fig, axs

In [None]:
def plot_training_data(ax, samples):
    """Overlay stored training points on ``ax``.

    Args:
        ax: object with a matplotlib-style ``plot`` method.
        samples: mapping with
            ``'res'`` -- (n, 2) array, drawn as black circles;
            ``'anc'`` -- optional (m, 2) array (presumably anchor points),
                         drawn as blue triangles;
            ``'bcs'`` -- sequence of (k, 2) arrays, one per boundary/initial
                         condition set, drawn as squares in colours C1, C2, ...
    """
    def _scatter(pts, fmt, colour):
        ax.plot(pts[:, 0], pts[:, 1], fmt, color=colour, ms=2.)

    _scatter(samples['res'], 'o', 'black')
    if 'anc' in samples:
        _scatter(samples['anc'], '^', 'blue')
    for colour_idx, bc_pts in enumerate(samples['bcs'], start=1):
        _scatter(bc_pts, 's', f'C{colour_idx}')

In [None]:
# Mean test error vs. step for each algorithm's best run (the run with the
# lowest final mean test error).
fig, ax = plt.subplots(figsize=(5, 4))
for c in cases.keys():
    runs_c = data[c]
    min_idx = np.argmin([run['err_mean'][-1] for run in runs_c])
    label, marker = algs[c]
    best_run = runs_c[min_idx]
    ax.semilogy(best_run['steps'], best_run['err_mean'], marker, label=label)
ax.legend()
ax.set(xlabel='Steps', ylabel='Mean error')
fig.tight_layout()
for suffix in ('pdf', 'png'):
    fig.savefig(os.path.join(graph_folder, f'err_mean.{suffix}'))

In [None]:
# Mean test error averaged over all runs of each algorithm (no error bars).
fig, ax = plt.subplots(figsize=(5, 4))
ax.set_yscale("log")
for c in cases.keys():
    # Averaging over runs requires equal-length curves, i.e. all runs of an
    # algorithm must share their snapshot steps; x-values are taken from run 0.
    mean = np.mean([y['err_mean'] for y in data[c]], axis=0)
    label, marker = algs[c]
    ax.errorbar(data[c][0]['steps'], mean, fmt=marker, capsize=2, label=label, alpha=0.7)

ax.legend()
ax.set_xlabel('Steps')
ax.set_ylabel('Mean error')
fig.tight_layout()
for ext in ['pdf', 'png']:
    fig.savefig(os.path.join(graph_folder, f'err_mean_avg.{ext}'))

In [None]:
# Run-averaged mean test error with one-sided (+1 std) error bars.
fig, ax = plt.subplots(figsize=(5, 4))
for c in cases.keys():
    ys = [run['err_mean'] for run in data[c]]
    mean, err = np.mean(ys, axis=0), np.std(ys, axis=0)
    label, marker = algs[c]
    ax.set_yscale("log")
    # The lower bar is zero, so only mean + std is drawn (presumably to keep the
    # bars positive on the log axis).
    ax.errorbar(data[c][0]['steps'], mean, yerr=[np.zeros_like(err), err],
                fmt=marker, capsize=2, label=label, alpha=0.7)
ax.legend()
ax.set(xlabel='Steps', ylabel='Mean error')
fig.tight_layout()
for suffix in ('pdf', 'png'):
    fig.savefig(os.path.join(graph_folder, f'err_mean_bar.{suffix}'))

In [None]:
# RMSE of the pointwise test error vs. step for each algorithm's best run
# (lowest final mean test error).
fig, ax = plt.subplots(figsize=(5, 4))
for c in cases.keys():
    label, marker = algs[c]
    best = data[c][np.argmin([x['err_mean'][-1] for x in data[c]])]
    # Steps and errors must come from the same run: each run's `steps` list is
    # built from that run's own snapshots.
    ax.semilogy(
        best['steps'],
        [jnp.sqrt(jnp.mean(e**2)) for e in best['err']],
        marker, label=label
    )
ax.legend()
ax.set_xlabel('Steps')
ax.set_ylabel('RMSE')
fig.tight_layout()
for ext in ['pdf', 'png']:
    fig.savefig(os.path.join(graph_folder, f'rmse.{ext}'))

In [None]:
# Mean residue vs. step for each algorithm's best run, where "best" here means
# the lowest final mean residue.
fig, ax = plt.subplots(figsize=(5, 4))
for c in cases.keys():
    runs_c = data[c]
    min_idx = np.argmin([run['res_mean'][-1] for run in runs_c])
    label, marker = algs[c]
    best_run = runs_c[min_idx]
    ax.semilogy(best_run['steps'], best_run['res_mean'], marker, label=label)
ax.legend()
ax.set(xlabel='Steps', ylabel='Mean residue')
fig.tight_layout()
for suffix in ('pdf', 'png'):
    fig.savefig(os.path.join(graph_folder, f'res_mean.{suffix}'))

In [None]:
# Mean residue averaged over all runs of each algorithm (no error bars).
fig, ax = plt.subplots(figsize=(5, 4))
ax.set_yscale("log")
for c in cases.keys():
    # Averaging over runs requires equal-length curves, i.e. all runs of an
    # algorithm must share their snapshot steps; x-values are taken from run 0.
    mean = np.mean([y['res_mean'] for y in data[c]], axis=0)
    label, marker = algs[c]
    ax.errorbar(data[c][0]['steps'], mean, fmt=marker, capsize=2, label=label, alpha=0.7)

ax.legend()
ax.set_xlabel('Steps')
ax.set_ylabel('Mean residue')
fig.tight_layout()
for ext in ['pdf', 'png']:
    fig.savefig(os.path.join(graph_folder, f'res_mean_avg.{ext}'))

In [None]:
# Run-averaged mean residue with one-sided (+1 std) error bars.
fig, ax = plt.subplots(figsize=(5, 4))
for c in cases.keys():
    ys = [run['res_mean'] for run in data[c]]
    mean, err = np.mean(ys, axis=0), np.std(ys, axis=0)
    label, marker = algs[c]
    ax.set_yscale("log")
    ax.errorbar(data[c][0]['steps'], mean, yerr=[np.zeros_like(err), err],
                fmt=marker, capsize=2, label=label, alpha=0.7)
ax.legend()
ax.set(xlabel='Steps', ylabel='Mean residue')
fig.tight_layout()
for suffix in ('pdf', 'png'):
    fig.savefig(os.path.join(graph_folder, f'res_mean_bar.{suffix}'))

In [None]:
# Per frequency band: spectral error vs. step for each algorithm's best run
# (best = lowest final mean test error, not lowest spectral error).
for s in ['low', 'mid', 'high']:
    metric = f'fft_mean_{s}'
    fig, ax = plt.subplots(figsize=(5, 4))
    for c in cases.keys():
        label, marker = algs[c]
        min_idx = np.argmin([run['err_mean'][-1] for run in data[c]])
        best_run = data[c][min_idx]
        ax.semilogy(best_run['steps'], best_run[metric], marker, label=label)
    ax.legend()
    ax.set(xlabel='Steps', ylabel=f'Mean FFT ({s}) diff.')
    fig.tight_layout()
    for suffix in ('pdf', 'png'):
        fig.savefig(os.path.join(graph_folder, f'fft-{s}_mean.{suffix}'))

In [None]:
# Per frequency band: spectral error averaged over all runs of each algorithm.
for s in ['low', 'mid', 'high']:

    fig, ax = plt.subplots(figsize=(5, 4))
    ax.set_yscale("log")
    for c in cases.keys():
        # Averaging over runs requires that all runs share their snapshot steps;
        # x-values are taken from run 0.
        mean = np.mean([y[f'fft_mean_{s}'] for y in data[c]], axis=0)
        label, marker = algs[c]
        ax.errorbar(data[c][0]['steps'], mean, fmt=marker, capsize=2, label=label, alpha=0.7)

    ax.legend()
    ax.set_xlabel('Steps')
    ax.set_ylabel(f'Mean FFT ({s}) diff.')
    fig.tight_layout()
    for ext in ['pdf', 'png']:
        fig.savefig(os.path.join(graph_folder, f'fft-{s}_mean_avg.{ext}'))

In [None]:
# Per frequency band: run-averaged spectral error with one-sided (+1 std) error bars.
for s in ['low', 'mid', 'high']:
    metric = f'fft_mean_{s}'
    fig, ax = plt.subplots(figsize=(5, 4))
    for c in cases.keys():
        ys = [run[metric] for run in data[c]]
        mean, err = np.mean(ys, axis=0), np.std(ys, axis=0)
        label, marker = algs[c]
        ax.set_yscale("log")
        ax.errorbar(data[c][0]['steps'], mean, yerr=[np.zeros_like(err), err],
                    fmt=marker, capsize=2, label=label, alpha=0.7)
    ax.legend()
    ax.set(xlabel='Steps', ylabel=f'Mean FFT ({s}) diff.')
    fig.tight_layout()
    for suffix in ('pdf', 'png'):
        fig.savefig(os.path.join(graph_folder, f'fft-{s}_mean_bar.{suffix}'))

In [None]:
# Final predictions of each algorithm's best run (lowest final mean error),
# preceded by the true solution. For '_change/' cases the step-0 prediction of
# the first random run is also shown as the initial model.
if '_change/' in case_folder:
    start_y = [data['random'][0]['pred'][0][:,0], data['random'][0]['y_test'][:,0]]
    start_title = ['Initial model', 'True solution']
else:
    start_y = [data['random'][0]['y_test'][:,0]]
    start_title = ['True solution']

best_runs = {c: data[c][np.argmin([x['err_mean'][-1] for x in data[c]])] for c in cases.keys()}

fig, axs = plot_contours(
    xs=data['random'][0]['x_test'],
    ys_list=start_y + [best_runs[c]['pred'][-1][:,0] for c in cases.keys()],
    titles=start_title + [algs[c][0] for c in cases.keys()],
)

# Every run kept by the loading loop ends exactly at max_steps, so use it
# directly instead of a loop variable left over from an earlier cell.
fig.savefig(os.path.join(graph_folder, f'pred_s{max_steps}.png'))

In [None]:
# Pointwise test error at the final step for each algorithm's best run
# (lowest final mean error).
fig, axs = plot_contours(
    xs=data['random'][0]['x_test'],
    ys_list=[data[c][np.argmin([x['err_mean'][-1] for x in data[c]])]['err'][-1][:,0] for c in cases.keys()],
    titles=[algs[c][0] for c in cases.keys()],
)

# Every run kept by the loading loop ends exactly at max_steps, so use it
# directly instead of a loop variable left over from an earlier cell.
fig.savefig(os.path.join(graph_folder, f'err_s{max_steps}.png'))

In [None]:
# Pointwise PDE residue at the final step for each algorithm's best run
# (selected by lowest final mean *error*, as in the other figures).
fig, axs = plot_contours(
    xs=data['random'][0]['x_test'],
    ys_list=[data[c][np.argmin([x['err_mean'][-1] for x in data[c]])]['res'][-1][:,0] for c in cases.keys()],
    titles=[algs[c][0] for c in cases.keys()],
)

# Every run kept by the loading loop ends exactly at max_steps, so use it
# directly instead of a loop variable left over from an earlier cell.
fig.savefig(os.path.join(graph_folder, f'res_s{max_steps}.png'))

In [None]:
# For each algorithm's best run (lowest final mean error): the true solution
# followed by the predictions at `steps_plot`, with the run's stored
# `chosen_pts` for each step overlaid.
for c in cases.keys():

    print(algs[c][0])

    min_idx = np.argmin([x['err_mean'][-1] for x in data[c]])
    best = data[c][min_idx]
    # Positions of the requested steps in this run's snapshot list.
    steps = [best['steps'].index(s) for s in steps_plot]

    fig, axs = plot_contours(
        # Predictions are evaluated at this run's own test points.
        xs=best['x_test'],
        ys_list=(
            [best['y_test'][:,0]] +
            [best['pred'][s][:,0] for s in steps]
        ),
        titles=(
            ['True solution'] +
            [f'Step {s}' for s in steps_plot]
        ),
    )

    # `steps` indexes this run's snapshots, so the overlaid points must come
    # from the same run as the plotted prediction.
    for ax, s in zip(axs[1:], steps):
        plot_training_data(ax, best['chosen_pts'][s])

    fig.savefig(os.path.join(graph_folder, f'data_pred_{algs[c][0]}.png'))
    plt.show()

In [None]:
# For each algorithm's best run (lowest final mean error): the PDE residue at
# `steps_plot`, with the run's stored `chosen_pts` for each step overlaid.
for c in cases.keys():

    print(algs[c][0])

    min_idx = np.argmin([x['err_mean'][-1] for x in data[c]])
    best = data[c][min_idx]
    # Positions of the requested steps in this run's snapshot list.
    steps = [best['steps'].index(s) for s in steps_plot]

    fig, axs = plot_contours(
        xs=best['x_test'],
        ys_list=[best['res'][s][:,0] for s in steps],
        titles=[f'Step {best["steps"][s]}' for s in steps],
    )

    # `steps` indexes this run's snapshots, so titles and overlaid points must
    # come from the same run as the plotted residue.
    for ax, s in zip(axs, steps):
        plot_training_data(ax, best['chosen_pts'][s])

    fig.savefig(os.path.join(graph_folder, f'data_res_{algs[c][0]}.png'))
    plt.show()

In [None]:
import torch.nn.functional as F
import torch

# FftLoss Function
class FftMseLoss(object):
    """
    Squared error between two fields measured in Fourier space.

    Both inputs are transformed with an n-dimensional FFT over every non-batch
    axis, restricted to the frequency-index window ``[flow, fhigh)`` on each of
    those axes, and the squared magnitude of the difference is reduced
    according to ``reduction``.

    June 2022, F.Alesiani
    """
    def __init__(self, reduction='mean'):
        """
        Args:
            reduction: 'mean' or 'sum' reduce to a scalar; any other value
                returns the (batch, n_freqs) tensor of squared errors.
        """
        super(FftMseLoss, self).__init__()
        self.reduction = reduction

    def __call__(self, x, y, flow=None, fhigh=None, eps=1e-20):
        """
        Args:
            x, y: numpy arrays or torch tensors of identical shape
                (batch, d1, ..., dk); every non-batch dimension must be > 1.
            flow: first FFT index kept on each transformed axis (default 0).
            fhigh: one-past-last FFT index kept (default: the largest non-batch
                dimension, i.e. keep everything).
            eps: unused; kept for interface compatibility.

        Returns:
            A non-negative scalar tensor for 'mean'/'sum' reductions, otherwise
            the (batch, n_freqs) tensor of squared magnitudes.
        """
        # as_tensor accepts numpy input and does not copy tensors that are passed in.
        x = torch.as_tensor(x)
        y = torch.as_tensor(y)
        num_examples = x.shape[0]
        others_dims = x.shape[1:]
        for d in others_dims:
            assert (d > 1), "we expect the dimension to be the same and greater the 1"
        # Transform every non-batch axis, so the window below slices genuine
        # frequency indices on every axis it touches.
        dims = list(range(1, x.ndim))
        xf = torch.fft.fftn(x, dim=dims)
        yf = torch.fft.fftn(y, dim=dims)
        if flow is None:
            flow = 0
        if fhigh is None:
            fhigh = max(xf.shape[1:])

        # Keep the whole batch axis and [flow, fhigh) on each transformed axis.
        window = (slice(None),) + (slice(flow, fhigh),) * len(others_dims)
        _diff = xf[window] - yf[window]
        _diff = _diff.reshape(num_examples, -1).abs() ** 2
        if self.reduction in ['mean']:
            return torch.mean(_diff).abs()
        if self.reduction in ['sum']:
            return torch.sum(_diff).abs()
        return _diff.abs()

In [None]:
# Mean-reduced Fourier-space squared error, used for the band comparison below.
fftmseloss_fn = FftMseLoss(reduction="mean")

In [None]:
# Compare the true solution with a prediction in three FFT-index bands.
# The flattened fields are reshaped to (1, n0, -1) grids. n0 is the number of
# distinct values in the first x_test column, the same convention the loading
# loop uses for its FFTs, instead of a hard-coded 1024.
# Note: `x` holds the final prediction of the first greedy run, not coordinates.
ref_run = data['random'][-1]
n0 = np.unique(ref_run['x_test'][:, 0]).shape[0]
u0 = np.array(ref_run['y_test'].reshape(1, n0, -1))
x = np.array(data['greedy_nystrom_wo_N_scale-none_mem_autoal'][0]['pred'][-1].reshape(1, n0, -1))

# Index windows applied on each FFT axis: [0, fmid), [fmid, 2*fmid), [2*fmid, end).
fmid = u0.shape[2]//4
fftmseloss_low_u0 = fftmseloss_fn(u0, x, 0, fmid).item()
fftmseloss_mid_u0 = fftmseloss_fn(u0, x, fmid, 2*fmid).item()
fftmseloss_hi_u0 = fftmseloss_fn(u0, x, 2*fmid).item()

In [None]:
# Reshape the greedy run's final prediction into a (1, 1024, -1) numpy array for
# the FFT experiments below. x stays a numpy array (the torch conversion is
# commented out).
x = np.array(data['greedy_nystrom_wo_N_scale-none_mem_autoal'][0]['pred'][-1].reshape(1, 1024, -1))
# x = torch.tensor(x)

In [None]:
# torch.fft requires a tensor, but x is a numpy array here, so convert it explicitly.
# Applying the forward FFT twice yields N1*N2 * x with both axes index-reversed
# (mod N); the following cells inspect y and z.
y = torch.fft.fftn(torch.as_tensor(x), dim=[1, 2])
z = torch.fft.fftn(y, dim=[1, 2])

In [None]:
# Prediction grid (transposed so that the second axis runs vertically).
plt.imshow(x[0].T)
plt.colorbar()

In [None]:
# Real part of the double forward FFT of x.
plt.imshow(z[0].real.T)
plt.colorbar()

In [None]:
# Real part of the spectrum of x.
plt.imshow(y[0].real.T)

In [None]:
# Elementwise ratio to the original grid. A double forward FFT reverses the
# indices as well as scaling, so this ratio is generally not constant.
z[0].real / x[0]

In [None]:
# Crude low-pass experiment: zero every FFT coefficient with index >= 100 along
# axis 1 or >= 186 along axis 2, then transform back.
# NOTE(review): this also zeroes the negative-frequency end of each axis (cf.
# np.fft.fftfreq). The spectrum of the real field is therefore in general no
# longer Hermitian-symmetric, and `z` is complex; only `z.real` is plotted
# below. The cut-offs 100 / 186 are magic numbers -- confirm their intent.
x = np.array(data['greedy_nystrom_wo_N_scale-none_mem_autoal'][0]['pred'][-1].reshape(1, 1024, -1))
y = np.fft.fftn(x, axes=[1, 2])

y[:, 100:] = 0.
y[:, :, 186:] = 0.

# y[:-50] = 0.
# y[:, :-20] = 0.

z = np.fft.ifftn(y, axes=[1, 2])

In [None]:
# Shape of the filtered (complex) field.
z.shape

In [None]:
# Unfiltered prediction grid.
plt.imshow(x[0].T)
# plt.colorbar()

In [None]:
# Real part of the filtered field.
plt.imshow(z[0].real.T)
# plt.colorbar()

In [None]:
# Imaginary part of the first 30 axis-1 coefficients of the filtered spectrum.
plt.imshow(y[0, :30].imag.T)