In [None]:
# Auto-reload imported project modules (e.g. deepxde_al_patch) when their source changes,
# and render matplotlib figures inline in the notebook.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from functools import partial
import os
import pickle as pkl
from collections.abc import MutableMapping

import matplotlib.pyplot as plt
import matplotlib.tri as tri
import matplotlib.ticker as ticker
import numpy as np
import tqdm

# Hide every CUDA device so JAX runs on CPU, and select DeepXDE's JAX backend.
# Both variables are set before jax and deepxde_al_patch.deepxde are imported below.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["DDE_BACKEND"] = "jax"

# Optional XLA GPU-memory allocator settings, left disabled (no GPU is visible anyway).
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".XX"
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"

from jax import config
# Enable 64-bit arrays in JAX for the whole session.
config.update("jax_enable_x64", True)
# Debugging aid: uncomment to make JAX raise as soon as a NaN is produced.
# config.update("jax_debug_nans", True)

import jax
import jax.numpy as jnp
import flax
from flax import linen as nn
import optax

# Report the devices JAX can see. With CUDA_VISIBLE_DEVICES="" there is usually no GPU
# backend, and JAX raises RuntimeError when asked about it; treat that as zero GPUs.
# (The previous bare `except:` also hid the CPU count and swallowed KeyboardInterrupt.)
n_cpus = jax.local_device_count("cpu")
try:
    n_gpus = jax.local_device_count("gpu")
except RuntimeError:
    n_gpus = 0
print(f'Jax: CPUs={n_cpus} - GPUs={n_gpus}')
    
import deepxde_al_patch.deepxde as dde

from deepxde_al_patch.model_loader import construct_model, construct_net
from deepxde_al_patch.modified_train_loop import ModifiedTrainLoop
from deepxde_al_patch.plotters import plot_residue_loss, plot_error, plot_prediction
from deepxde_al_patch.train_set_loader import load_data

from deepxde_al_patch.ntk import NTKHelper
from deepxde_al_patch.utils import get_pde_residue, print_dict_structure

In [None]:
# Figure defaults shared by every plot in this notebook (resolution, font sizes, no LaTeX).
plt.rcParams.update({
    'figure.dpi': 150,
    'font.size': 12,
    'figure.titlesize': 22,
    'text.usetex': False,
})

In [None]:
# Presumably meant as the graph output directory. NOTE(review): no visible cell reads it;
# the figures below are written under '../../al_pinn_graphs_final/' instead.
graph_root = 'al_pinn_graphs'

In [None]:
data_folder = '../../al_pinn_results_timing/conv-1d{5.0}_ic/nn-laaf-6-64_adam_bcsloss-1.0_budget-'
# Experiment-directory suffix -> (legend label, style). Only 'c' and 'marker' are used below.
algs = {
    '10000-50-0-random_Hammersley_prop-0.8': ('Hamm (10k)', dict(c='darkgrey', ls='--', marker='v')),
    '100000-50-0-random_Hammersley_prop-0.8': ('Hamm (100k)', dict(c='black', ls='--', marker='^')),
    '10000-50-0-residue_prop-0.8_alltype': ('RAD-All (10k)', dict(c='orange', ls=':', marker='p')),
#     '1000-50-0-residue_prop-0.8_alltype_unlimcolloc': ('RAR-D-All (1k)', dict(c='red', ls=':', marker='h')),
    '1000-200-0-sampling_alignment_scale-none_mem_autoal': ('PINNACLE-S (1k)', dict(c='green', ls='-', marker='s')),
    '1000-200-0-kmeans_alignment_scale-none_mem_autoal': ('PINNACLE-K (1k)', dict(c='blue', ls='-', marker='o')),
}
cutoff = 0.01               # target mean error (column 2); runs are truncated once they reach it
max_runs = 5                # load at most this many runs per algorithm
unconverged_penalty = 200.  # ranks runs that never reach `cutoff` after those that do (x-axis spans 0-180 min)

print(data_folder)

# Sorted so the same runs are picked on every execution (os.listdir order is arbitrary).
cases = {x: sorted(os.listdir(f'{data_folder}{x}')) for x in algs if os.path.exists(f'{data_folder}{x}')}

data = dict()
for c, runs in cases.items():
    data[c] = []
    for k in runs:
        try:
            with open(f'{data_folder}{c}/{k}/timing.pkl', 'rb') as f:
                d = pkl.load(f)
        except FileNotFoundError:
            continue
        d = np.array(d)
        # Column 1 is elapsed seconds (converted to minutes below), column 2 the plotted error.
        if d.ndim != 2 or d.shape[0] == 0:
            continue  # empty/malformed log: previously crashed or reused a stale loop index
        # Keep rows up to and including the 2nd consecutive evaluation below `cutoff`.
        stop = d.shape[0]
        streak = 0
        for i in range(d.shape[0]):
            streak = streak + 1 if d[i, 2] < cutoff else 0
            if streak == 2:
                stop = i + 1
                break
        d = d[:stop]
        d[:, 1] = d[:, 1] / 60.  # seconds -> minutes
        data[c].append(d)
        if len(data[c]) == max_runs:
            break

fig, ax = plt.subplots()

for a, runs in data.items():
    if not runs:
        continue
    print(a, len(runs))
    label, style = algs[a]
    # Converged runs rank by time-to-cutoff; unconverged ones come after them, by final error.
    scores = [d[-1, 1] if d[-1, 2] < cutoff else unconverged_penalty + d[-1, 2] for d in runs]
    ranked = sorted(range(len(runs)), key=scores.__getitem__)
    # Best run is emphasised and labelled, the runner-up (if any) faintly shown, the rest ghosted.
    alphas = {ranked[0]: 0.9}
    if len(ranked) > 1:
        alphas[ranked[1]] = 0.2
    for i, d in enumerate(runs):
        kwargs = dict(color=style['c'], alpha=alphas.get(i, 0.05), markerfacecolor='none', lw=1, ms=5)
        if i == ranked[0]:
            kwargs['label'] = label
        ax.semilogy(d[:, 1], d[:, 2], style['marker'] + '-', **kwargs)

ax.axhline(cutoff, linestyle='--', color='darkgrey', zorder=1)
ax.set_xlabel('Time (min)')
ax.set_ylabel('Mean error')
ax.set_xticks(range(0, 181, 30))
ax.set_xlim(0, None)
ax.grid(alpha=0.2)
ax.legend(ncols=1, loc='lower left')
fig.tight_layout()
out_path = '../../al_pinn_graphs_final/timing_forward_conv.pdf'
os.makedirs(os.path.dirname(out_path), exist_ok=True)
fig.savefig(out_path)

In [None]:
data_folder = '../../al_pinn_results_timing/kdv-1d{1.0-0.0}_ic/nn-None-4-32_adam_bcsloss-1.0_budget-'
# Experiment-directory suffix -> (legend label, style). Only 'c' and 'marker' are used below.
algs = {
    '300-50-0-random_Hammersley_prop-0.8': ('Hamm (300)', dict(c='black', ls='--', marker='^')),
    '300-50-0-residue_prop-0.8_alltype': ('RAD-All (300)', dict(c='orange', ls=':', marker='p')),
#     '300-50-0-residue_prop-0.8_alltype_unlimcolloc': ('RAR-D-All (300)', dict(c='red', ls=':', marker='h')),
    '300-100-0-sampling_alignment_scale-none_mem_autoal': ('PINNACLE-S (300)', dict(c='green', ls='-', marker='s')),
    '300-100-0-kmeans_alignment_scale-none_mem_autoal': ('PINNACLE-K (300)', dict(c='blue', ls='-', marker='o')),
}
cutoff = 0.01               # target mean error (column 2)
max_runs = 5                # load at most this many runs per algorithm
max_seconds = 3600          # also truncate at the first evaluation past 60 min (column 1, seconds)
unconverged_penalty = 200.  # ranks runs that never reach `cutoff` after those that do

print(data_folder)

# Sorted so the same runs are picked on every execution (os.listdir order is arbitrary).
cases = {x: sorted(os.listdir(f'{data_folder}{x}')) for x in algs if os.path.exists(f'{data_folder}{x}')}

data = dict()
for c, runs in cases.items():
    data[c] = []
    for k in runs:
        try:
            with open(f'{data_folder}{c}/{k}/timing.pkl', 'rb') as f:
                d = pkl.load(f)
        except FileNotFoundError:
            continue
        d = np.array(d)
        if d.ndim != 2 or d.shape[0] == 0:
            continue  # empty/malformed log: previously crashed or reused a stale loop index
        # Keep rows up to the 2nd consecutive evaluation below `cutoff`, or up to the first
        # evaluation past `max_seconds`, whichever comes first.
        stop = d.shape[0]
        streak = 0
        for i in range(d.shape[0]):
            streak = streak + 1 if d[i, 2] < cutoff else 0
            if streak == 2 or d[i, 1] > max_seconds:
                stop = i + 1
                break
        d = d[:stop]
        d[:, 1] = d[:, 1] / 60.  # seconds -> minutes
        data[c].append(d)
        if len(data[c]) == max_runs:
            break

fig, ax = plt.subplots()

for a, runs in data.items():
    if not runs:
        continue
    print(a, len(runs))
    label, style = algs[a]
    # Converged runs rank by time-to-cutoff; unconverged ones come after them, by final error.
    scores = [d[-1, 1] if d[-1, 2] < cutoff else unconverged_penalty + d[-1, 2] for d in runs]
    best_idx = min(range(len(runs)), key=scores.__getitem__)
    # Only the best run is emphasised in this figure; every other run is ghosted.
    for i, d in enumerate(runs):
        kwargs = dict(color=style['c'], alpha=0.9 if i == best_idx else 0.05,
                      markerfacecolor='none', lw=1, ms=5)
        if i == best_idx:
            kwargs['label'] = label
        ax.semilogy(d[:, 1], d[:, 2], style['marker'] + '-', **kwargs)

ax.axhline(cutoff, linestyle='--', color='darkgrey', zorder=1)
ax.set_xlabel('Time (min)')
ax.set_ylabel('Mean error')
ax.set_xticks(range(0, 61, 30))
ax.set_xlim(0, None)
ax.grid(alpha=0.2)
ax.legend(ncols=1, loc='lower left')
fig.tight_layout()
out_path = '../../al_pinn_graphs_final/timing_forward_kdv.pdf'
os.makedirs(os.path.dirname(out_path), exist_ok=True)
fig.savefig(out_path)

In [None]:
data_folder = '../../al_pinn_results_timing/fd-2d{1.0-0.01}_inv_anc[0,1]/nn-laaf-6-64_adam_bcsloss-1.0_budget-'
# Experiment-directory suffix -> (legend label, style). Only 'c' and 'marker' are used below.
algs = {
    '10000-50-30-random_Hammersley_prop-0.8': ('Hamm (10k)', dict(c='black', ls='--', marker='v')),
#     '1000-50-30-residue_prop-0.8_alltype_unlimcolloc': ('RAR-D-All (1k)', dict(c='red', ls=':', marker='h')),
    '1000-200-30-sampling_alignment_scale-none_mem_autoal': ('PINNACLE-S (1k)', dict(c='green', ls='-', marker='s')),
    '1000-200-30-kmeans_alignment_scale-none_mem_autoal': ('PINNACLE-K (1k)', dict(c='blue', ls='-', marker='o')),
}
c1 = 0.05    # tolerance for column 3 (also the band width drawn around parameter 1)
c2 = 0.0005  # tolerance for column 4 (also the band width drawn around parameter 2)
max_runs = 5                # load at most this many runs per algorithm
required_streak = 5         # consecutive evaluations within both tolerances before truncating
unconverged_penalty = 90.   # ranks unconverged runs after converged ones (x-axis spans 0-90 min)

# Sorted so the same runs are picked on every execution (os.listdir order is arbitrary).
cases = {x: sorted(os.listdir(f'{data_folder}{x}')) for x in algs if os.path.exists(f'{data_folder}{x}')}

data = dict()
for c, runs in cases.items():
    data[c] = []
    for k in runs:
        try:
            with open(f'{data_folder}{c}/{k}/timing.pkl', 'rb') as f:
                d = pkl.load(f)
        except (FileNotFoundError, EOFError):
            continue  # missing or truncated log
        d = np.array(d)
        # Needs >= 7 columns: 1 = seconds, 3/4 = tolerance-checked values,
        # 5/6 = plotted parameter estimates. 1-D arrays used to crash on d.shape[1].
        if d.ndim != 2 or d.shape[0] == 0 or d.shape[1] < 7:
            continue
        stop = d.shape[0]
        streak = 0
        for i in range(d.shape[0]):
            streak = streak + 1 if (d[i, 3] < c1 and d[i, 4] < c2) else 0
            if streak == required_streak:
                stop = i + 1
                break
        d = d[:stop]
        d[:, 1] = d[:, 1] / 60.  # seconds -> minutes
        data[c].append(d)
        if len(data[c]) == max_runs:
            break

# Rank each algorithm's runs once (the same ranking is used for both parameter figures).
# A run counts as converged when its last evaluation meets BOTH tolerances, matching the
# truncation criterion above (previously only column 4 was checked here).
rankings = {}
for a, runs in data.items():
    if not runs:
        continue
    print(a, len(runs))
    scores = [d[-1, 1] if (d[-1, 3] < c1 and d[-1, 4] < c2) else unconverged_penalty + d[-1, 2]
              for d in runs]
    rankings[a] = sorted(range(len(runs)), key=scores.__getitem__)

# One figure per inverse parameter: (reference value v, data column, band half-width).
# NOTE: this loop rebinds the notebook-global `cutoff` (left equal to c2 afterwards);
# later cells that read `cutoff` without assigning it will see that value.
for param_no, (v, idx, cutoff) in enumerate([[1., 5, c1], [0.01, 6, c2]]):

    fig, ax = plt.subplots()

    for a, ranked in rankings.items():
        label, style = algs[a]
        alphas = {ranked[0]: 0.9}
        if len(ranked) > 1:
            alphas[ranked[1]] = 0.2
        for i, d in enumerate(data[a]):
            kwargs = dict(color=style['c'], alpha=alphas.get(i, 0.05), markerfacecolor='none', lw=1, ms=5)
            if i == ranked[0]:
                kwargs['label'] = label
            ax.plot(d[:, 1], d[:, idx], style['marker'] + '-', **kwargs)

    ax.axhline(v + cutoff, linestyle='--', color='darkgrey', zorder=1)
    ax.axhline(v, linestyle='-', color='darkgrey', zorder=1)
    ax.axhline(v - cutoff, linestyle='--', color='darkgrey', zorder=1)
    ax.set_xlabel('Time (min)')
    ax.set_ylabel(f'Inv. param. {param_no+1}')
    ax.set_xticks(range(0, 91, 30))
    ax.set_xlim(0, None)
    ax.set_ylim(v - 5. * cutoff, v + 5. * cutoff)
    ax.grid(alpha=0.2)
    ax.legend(ncols=1)
    fig.tight_layout()
    out_path = f'../../al_pinn_graphs_final/timing_inv_fd_{param_no}.pdf'
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path)

In [None]:
data_folder = '../../al_pinn_results_timing/eik1-3d{}_anc0/nn-laaf-8-32_adam_bcsloss-1.0_budget-'
# Experiment-directory suffix -> (legend label, style). Only 'c' and 'marker' are used below.
algs = {
    '1000-50-5-random_Hammersley_prop-0.8': ('Hamm (1k)', dict(c='darkgrey', ls='--', marker='v')),
    '10000-50-5-random_Hammersley_prop-0.8': ('Hamm (10k)', dict(c='black', ls='--', marker='^')),
    '500-50-5-residue_prop-0.8_alltype': ('RAD-All (500)', dict(c='orange', ls='-.', marker='p')),
#     '500-50-5-residue_prop-0.8_alltype_unlimcolloc': ('RAR-D-All (500)', dict(c='pink', ls=':', marker='*')),
#     '1000-50-5-residue_prop-0.8_alltype_unlimcolloc': ('RAR-D-All (1k)', dict(c='red', ls=':', marker='h')),
    '500-100-5-sampling_alignment_scale-none_mem_autoal': ('PINNACLE-S (500)', dict(c='green', ls='-', marker='s')),
    '500-100-5-kmeans_alignment_scale-none_mem_autoal': ('PINNACLE-K (500)', dict(c='blue', ls='-', marker='o')),
}
# This cell previously read `cutoff` without assigning it, silently inheriting whatever an
# earlier cell left behind (under Run All: 0.0005 from the inverse-FD cell's loop). It is
# now set explicitly to that value so the cell no longer depends on execution order.
# TODO confirm the intended threshold (an older setting was 0.1).
cutoff = 0.0005
max_runs = 5                # load at most this many runs per algorithm
unconverged_penalty = 200.  # ranks runs that never reach `cutoff` after those that do

# Sorted so the same runs are picked on every execution (os.listdir order is arbitrary).
cases = {x: sorted(os.listdir(f'{data_folder}{x}')) for x in algs if os.path.exists(f'{data_folder}{x}')}

data = dict()
for c, runs in cases.items():
    data[c] = []
    for k in runs:
        try:
            with open(f'{data_folder}{c}/{k}/timing.pkl', 'rb') as f:
                d = pkl.load(f)
        except FileNotFoundError:
            continue
        # One row per logged record p: (p[0], p[1] in seconds, RMSE between column 1 of
        # the arrays p[4] and p[5]).
        d = np.array([[p[0], p[1], np.mean((p[4][:, 1] - p[5][:, 1])**2)**0.5] for p in d])
        if d.ndim != 2 or d.shape[0] == 0:
            continue  # empty log: previously crashed or reused a stale loop index
        # Keep rows up to and including the 2nd consecutive evaluation below `cutoff`.
        stop = d.shape[0]
        streak = 0
        for i in range(d.shape[0]):
            streak = streak + 1 if d[i, 2] < cutoff else 0
            if streak == 2:
                stop = i + 1
                break
        d = d[:stop]
        d[:, 1] = d[:, 1] / 60.  # seconds -> minutes
        data[c].append(d)
        if len(data[c]) == max_runs:
            break

fig, ax = plt.subplots()

for a, runs in data.items():
    if not runs:
        continue
    print(a, len(runs))
    label, style = algs[a]
    # Converged runs rank by time-to-cutoff; unconverged ones come after them, by final error.
    scores = [d[-1, 1] if d[-1, 2] < cutoff else unconverged_penalty + d[-1, 2] for d in runs]
    ranked = sorted(range(len(runs)), key=scores.__getitem__)
    # Best run is emphasised and labelled, the runner-up (if any) faintly shown, the rest ghosted.
    alphas = {ranked[0]: 0.9}
    if len(ranked) > 1:
        alphas[ranked[1]] = 0.2
    for i, d in enumerate(runs):
        kwargs = dict(color=style['c'], alpha=alphas.get(i, 0.05), markerfacecolor='none', lw=1, ms=5)
        if i == ranked[0]:
            kwargs['label'] = label
        ax.semilogy(d[:, 1], d[:, 2], style['marker'] + '-', **kwargs)

# The cutoff line is intentionally not drawn for this figure:
# ax.axhline(cutoff, linestyle='--', color='darkgrey', zorder=1)
ax.set_xlabel('Time (min)')
ax.set_ylabel('Mean error')
ax.set_xticks(range(0, 181, 30))
ax.set_xlim(0, None)
ax.grid(alpha=0.2)
ax.legend(ncols=1, loc='lower left')
fig.tight_layout()
out_path = '../../al_pinn_graphs_final/timing_inv_eik.pdf'
os.makedirs(os.path.dirname(out_path), exist_ok=True)
fig.savefig(out_path)

In [None]:
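# NOTE: disabled variant of the Eikonal timing cell above. It uses the same pipeline, but the error
# is computed on column 0 of p[4]/p[5] (instead of column 1), and the figure is not saved.
# data_folder = '../../al_pinn_results_timing/eik1-3d{}_anc0/nn-laaf-8-32_adam_bcsloss-1.0_budget-'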
# algs = {
#     '1000-50-5-random_Hammersley_prop-0.8': ('Hamm (1k)', dict(c='darkgrey', ls='--', marker='v')),
#     '10000-50-5-random_Hammersley_prop-0.8': ('Hamm (10k)', dict(c='black', ls='--', marker='^')),
#     '500-50-5-residue_prop-0.8_alltype': ('RAD-All (500)', dict(c='orange', ls='.-', marker='p')),
# #     '500-50-5-residue_prop-0.8_alltype_unlimcolloc': ('RAR-D-All (500)', dict(c='pink', ls=':', marker='*')),
# #     '1000-50-5-residue_prop-0.8_alltype_unlimcolloc': ('RAR-D-All (1k)', dict(c='red', ls=':', marker='h')),
#     '500-100-5-sampling_alignment_scale-none_mem_autoal': ('PINNACLE-S (500)', dict(c='green', ls='-', marker='s')),
#     '500-100-5-kmeans_alignment_scale-none_mem_autoal': ('PINNACLE-K (500)', dict(c='blue', ls='-', marker='o')),
# }
# # cutoff = 0.1

# cases = {x: os.listdir(f'{data_folder}{x}') for x in algs.keys() if os.path.exists(f'{data_folder}{x}')}

# data = dict()

# for c in cases.keys():

#     data[c] = []

#     for k in cases[c]:

#         try:
#             with open(f'{data_folder}{c}/{k}/timing.pkl', 'rb') as f:
#                 d = pkl.load(f)
#         except FileNotFoundError:
#             continue
#         d = np.array([[p[0], p[1], np.mean((p[4][:,0] - p[5][:,0])**2)**0.5] for p in d])
#         j = 0
#         for i in range(d.shape[0]):
#             if d[i,2] < cutoff:
#                 j += 1
#                 if j == 2:
#                     break
#             else:
#                 j = 0
#         d[:,1] = d[:,1] / 60.
#         d = d[:i+1]
# #             if (d[-1,1] <= 180.) and (d[-1,2] >= cutoff):
# #                 continue
#         data[c].append(d)
#         if len(data[c]) == 5:
#             break

# fig, ax = plt.subplots()

# for a in data.keys():
#     if len(data[a]) > 0:
#         print(a, len(data[a]))
#         best_idx = 0
#         best_t = float('inf')
#         vals = {i: d[-1,1] if d[-1,2] < cutoff else 200. + d[-1,2]
#                 for i, d in enumerate(data[a])}
#         best_idx = sorted(vals.keys(), key=lambda k: vals[k])[0]
#         second_best_idx = sorted(vals.keys(), key=lambda k: vals[k])[1]
#         for i, d in enumerate(data[a]):
#             if i == best_idx:
#                 ax.semilogy(d[:,1],d[:,2], algs[a][1]['marker'] + '-', 
#                             color=algs[a][1]['c'], label=algs[a][0], 
#                             alpha=0.9, markerfacecolor='none', lw=1, ms=5)
#             elif i == second_best_idx:
#                 ax.semilogy(d[:,1],d[:,2], algs[a][1]['marker'] + '-', 
#                             color=algs[a][1]['c'], 
#                             alpha=0.2, markerfacecolor='none', lw=1, ms=5)
#             else:
#                 ax.semilogy(d[:,1],d[:,2], algs[a][1]['marker'] + '-', 
#                             color=algs[a][1]['c'], 
#                             alpha=0.05, markerfacecolor='none', lw=1, ms=5)

# # ax.axhline(cutoff, linestyle='--', color='darkgrey', zorder=1)
# ax.set_xlabel('Time (min)')
# ax.set_ylabel('Mean error')
# ax.set_xticks(range(0, 181, 30))
# ax.set_xlim(0, None)
# #     ax.set_ylim(0.1 * cutoff, None)
# ax.grid(alpha=0.2)
# ax.legend(ncols=1, loc='lower left')
# fig.tight_layout()
# # fig.savefig('../../al_pinn_graphs_final/timing_inv_eik.pdf')

In [None]:
# Scratch: take the first experiment key found by the most recently run plotting cell.
x = [*cases][0]

In [None]:
# Scratch: display the results directory of experiment `x`.
'{}{}'.format(data_folder, x)

In [None]:
# Scratch: show a stored test-set prediction as an image.
# NOTE(review): `d` is whatever the previous cells left behind. After the timing cells above,
# it is a NumPy array, and indexing it with 'pred_test' raises IndexError, so only plot when
# `d` really is a dict carrying that key. reshape(200, 200) assumes a 200x200 test grid
# (TODO confirm).
if isinstance(d, dict) and 'pred_test' in d:
    plt.imshow(d['pred_test'].reshape(200, 200))
    plt.colorbar()
else:
    print(f"Skipping: expected a dict with a 'pred_test' entry, got {type(d).__name__}")

In [None]:
# Scratch: display whatever `d` the previously executed cells left behind (after the timing
# cells above, a truncated timing array). Consider `d[:5]` to avoid dumping it in full.
d