In [None]:
# Load IPython's autoreload extension; mode 2 reloads imported modules
# before each cell runs, so edits to the project package are picked up live.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
from functools import partial
import os
import pickle as pkl
from collections.abc import MutableMapping
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import tqdm

# These environment variables are set before jax / the DeepXDE patch are
# imported further down in this cell.
# An empty CUDA_VISIBLE_DEVICES hides every CUDA GPU from this process.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Presumably selects JAX as the backend of the DeepXDE patch imported below
# -- TODO confirm against deepxde_al_patch.
os.environ["DDE_BACKEND"] = "jax"

# Optional XLA client memory-allocation settings (currently disabled).
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".XX"
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"

from jax import config
# Turn on 64-bit (double precision) types in JAX.
config.update("jax_enable_x64", True)
# Debug aid (currently disabled).
# config.update("jax_debug_nans", True)

import jax
import jax.numpy as jnp
import flax
from flax import linen as nn
import optax

try:
    print(f'Jax: CPUs={jax.local_device_count("cpu")} - GPUs={jax.local_device_count("gpu")}')
except:
    pass
    
import deepxde_al_patch.deepxde as dde

from deepxde_al_patch.model_loader import construct_model
from deepxde_al_patch.modified_train_loop import ModifiedTrainLoop
from deepxde_al_patch.plotters import plot_residue_loss, plot_error, plot_prediction
from deepxde_al_patch.train_set_loader import load_data

from deepxde_al_patch.ntk import NTKHelper
from deepxde_al_patch.utils import get_pde_residue, print_dict_structure

In [None]:
import pickle as pkl
from sklearn.cluster._kmeans import kmeans_plusplus

In [None]:
# Global matplotlib defaults for every figure in this notebook:
# 4x4 inch canvas at 300 dpi, 14 pt text, no external LaTeX rendering.
plt.rcParams.update({
    'figure.figsize': (4, 4),
    'figure.dpi': 300,
    'font.size': 14,
    'text.usetex': False,
})

In [None]:
# Relative path of the pickled snapshot to analyse. The adjacent literals are
# concatenated into one path string. (A 'snapshot_data_s50000.pkl' file from
# the same directory was loaded here previously.)
snapshot_path = (
    '../../al_pinn_results/burgers-1d{0.02}_pb-20_ic/'
    'nn-None-4-128_adam_bcsloss-1.0_budget-300-100-0/'
    'kmeans_alignment_scale-none_mem_autoal/20230823161701/'
    'snapshot_data_s100000.pkl'
)

# NOTE: unpickling can execute arbitrary code; only load trusted files.
with open(snapshot_path, 'rb') as snapshot_file:
    dd = pkl.load(snapshot_file)

In [None]:
idx = 100000  # key of the snapshot entry in `dd` to analyse

# Take the stored 'P' matrix, transpose it and reverse its column order,
# then compute the Euclidean norm of every row.
embedding = dd[idx]['al_intermediate']['P'].T[:, ::-1]
row_norms = np.linalg.norm(embedding, axis=1)

# Row order from largest to smallest norm.
ranking = np.argsort(-row_norms)

# Materialise the sorted rows and their norms as NumPy arrays.
xs = np.array(embedding[ranking])
xs_norm = np.array(row_norms[ranking])

In [None]:
n = 30  # number of rows of `xs` to select

# Selection strategy; the same string tags the saved figure file names.
#   'kmeans'   - indices returned by K-Means++ seeding on the rows of `xs`
#   'sampling' - draw without replacement, probability proportional to the
#                squared row norm
#   'greedy'   - the first n rows (`xs` is sorted by descending norm)
label = 'kmeans'

# Fixed seed so the selected subset (and the figures built from it) is the
# same after Restart Kernel -> Run All.
selection_seed = 0

if label == 'kmeans':
    # kmeans_plusplus returns (centers, indices); only the indices are used.
    idx_sel = np.sort(kmeans_plusplus(xs, n, random_state=selection_seed)[1])
elif label == 'sampling':
    sq_norms = xs_norm ** 2
    idx_sel = np.sort(
        np.random.default_rng(selection_seed).choice(
            xs.shape[0], n, replace=False, p=sq_norms / np.sum(sq_norms)
        )
    )
elif label == 'greedy':
    idx_sel = np.arange(n)
else:
    raise ValueError(
        f"Unknown selection method {label!r}; "
        "expected 'kmeans', 'sampling' or 'greedy'."
    )

In [None]:
# Pairs of (1-based) embedding components to scatter against each other.
component_pairs = [(1, 2), (1, 100), (1, 200), (100, 500)]

for i1, i2 in component_pairs:
    # Convert the 1-based component numbers to 0-based column indices.
    col_x, col_y = i1 - 1, i2 - 1

    # All rows drawn faintly, with the selected subset overlaid on top.
    plt.plot(xs[:, col_x], xs[:, col_y], '.', alpha=0.1)
    plt.plot(xs[idx_sel, col_x], xs[idx_sel, col_y], '.', alpha=0.9)

    # Optional symmetric-log axes (disabled):
    # plt.xscale('symlog')
    # plt.yscale('symlog')

    plt.xlabel(f'Component {i1}')
    plt.ylabel(f'Component {i2}')
    plt.tight_layout()

    # One PDF per component pair, tagged with the selection method.
    plt.savefig(f'../../al_pinn_graphs_final/emb_{label}_d{i1}-{i2}.pdf')
    plt.show()

In [None]:
# Sorted row norms on a log scale, with the selected rows marked as dots.
ax = plt.gca()  # current axes (created on demand, as plt.plot would do)
ax.plot(xs_norm)
ax.plot(idx_sel, xs_norm[idx_sel], '.')
ax.set_yscale('log')
ax.set_xlabel('Rank of α(z)')
ax.set_ylabel('Value of α(z)')

plt.tight_layout()
plt.savefig(f'../../al_pinn_graphs_final/emb_{label}_norm.pdf')

In [None]:
k = 300  # number of independent draws per stochastic selection method

# Fixed seeds so this figure is reproducible across kernel restarts.
# scikit-learn takes a RandomState (or int); NumPy draws use a Generator.
draw_rng = np.random.default_rng(0)
kmeans_state = np.random.RandomState(0)


def _alpha(indices):
    """Return the norm of the sum of the rows of `xs` picked by `indices`."""
    return np.linalg.norm(np.sum(xs[indices], axis=0))


num_points = xs.shape[0]

# Loop-invariant sampling distribution: proportional to the squared row norm.
sq_norms = xs_norm ** 2
sampling_probs = sq_norms / np.sum(sq_norms)

# Uniform baseline. Drawn WITHOUT replacement so that, like every other
# method here, it always picks n distinct rows.
alpha_random = [
    _alpha(draw_rng.choice(num_points, n, replace=False)) for _ in range(k)
]
# kmeans_plusplus returns (centers, indices); only the indices are used.
alpha_kmeans = [
    _alpha(kmeans_plusplus(xs, n, random_state=kmeans_state)[1]) for _ in range(k)
]
alpha_sampling = [
    _alpha(draw_rng.choice(num_points, n, replace=False, p=sampling_probs))
    for _ in range(k)
]
# Deterministic: the first n rows of `xs`.
alpha_greedy = _alpha(np.arange(n))

plt.figure(figsize=(6, 3))

# One horizontal strip of markers per method (y is just the strip position).
plt.plot(alpha_random, 4 * np.ones(k), 'o', label='Random', alpha=0.5)
plt.plot(alpha_kmeans, 3 * np.ones(k), '^', label='K-Means++', alpha=0.5)
plt.plot(alpha_sampling, 2 * np.ones(k), 's', label='Sampling', alpha=0.5)
plt.plot([alpha_greedy], [1], 'p', label='Greedy', alpha=0.6)

plt.ylim(0.5, 4.5)
plt.yticks(
    ticks=[4, 3, 2, 1],
    labels=['Random', 'K-means++', 'Sampling', 'Greedy'],
)
plt.xlabel('α(Z)')
plt.ylabel('Method to sample Z')

plt.tight_layout()
plt.savefig('../../al_pinn_graphs_final/emb_sel_method.pdf')