In [None]:
# Automatically reload edited Python modules before each cell executes.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
from functools import partial
import os
import pickle as pkl
from collections.abc import MutableMapping
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import tqdm

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["DDE_BACKEND"] = "jax"

# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".XX"
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"

from jax import config
config.update("jax_enable_x64", True)
# config.update("jax_debug_nans", True)

import jax
import jax.numpy as jnp
import flax
from flax import linen as nn
import optax

def jax_device_summary() -> str:
    """Return a one-line summary of the local JAX CPU and GPU device counts.

    A backend that JAX cannot provide (e.g. no GPU backend on a CPU-only
    install) is reported as 0 devices instead of aborting or silently
    skipping the whole summary.
    """
    def count(backend):
        try:
            return jax.local_device_count(backend)
        except RuntimeError:
            # JAX raises RuntimeError when the requested backend is unavailable.
            return 0

    return f'Jax: CPUs={count("cpu")} - GPUs={count("gpu")}'


print(jax_device_summary())
    
import deepxde_al_patch.deepxde as dde

from deepxde_al_patch.model_loader import construct_model
from deepxde_al_patch.modified_train_loop import ModifiedTrainLoop
from deepxde_al_patch.plotters import plot_residue_loss, plot_error, plot_prediction
from deepxde_al_patch.train_set_loader import load_data

from deepxde_al_patch.ntk import NTKHelper
from deepxde_al_patch.utils import get_pde_residue, print_dict_structure

In [None]:
# Default figure styling for every plot in this notebook.
plt.rcParams.update({
    'figure.figsize': (8, 6),
    'figure.dpi': 200,
    'font.size': 12,
    'text.usetex': False,  # use matplotlib's mathtext; no LaTeX install needed
})

## Setup

In [None]:
inverse_problem = False

# Per-PDE data settings. Select the active problem with `pde_name`; settings
# shared by every problem are passed once in the construct_model call below.
pde_data_configs = {
    'conv-1d': dict(
        data_seed=40,
        pde_const=(1.,),
        include_ic=(not inverse_problem),
    ),
    'burgers-1d': dict(
        data_seed=20,
        pde_const=(0.02,),
        include_ic=True,
        # inverse_problem=inverse_problem,
        # inverse_problem_guess=(0.8,),
    ),
}
pde_name = 'conv-1d'

model, model_aux = construct_model(
    # data (loaded from PDEBench)
    pde_name=pde_name,
    **pde_data_configs[pde_name],
    use_pdebench=True,
    num_domain=2000,
    num_boundary=500,
    num_initial=500,
    # `~` is expanded here so the path works whether or not construct_model
    # expands it itself (expanduser is a no-op on an already-expanded path).
    data_root=os.path.expanduser('~/pdebench'),
    test_max_pts=50000,

    # model params
    hidden_layers=4,
    hidden_dim=64,
    activation='tanh',
    initializer='Glorot uniform',
    # arch='pfnn',
)

### Experiment configuration and training


In [None]:
# Point-selection strategy passed to ModifiedTrainLoop.
# NOTE(review): `method` is 'random' while al_args['method'] is 'pseudo';
# confirm which of the two the point selector actually uses.
method = 'random'

al_args = {
    'method': 'pseudo',
    'res_proportion': 0.8,
}

# Step counts / schedule and optimiser settings, forwarded as keyword args.
optim_args = {
    'train_steps': 100000,
    'al_every': 5000,
    'select_anchors_every': 100000,
    'snapshot_every': 1000,
    'optim_method': 'adam',
    'optim_lr': 1e-3,
    'optim_args': {},
}

train_loop = ModifiedTrainLoop(
    model=model,
    inverse_problem=inverse_problem,
    # point selection and budgets
    point_selector_method=method,
    point_selector_args=al_args,
    mem_pts_total_budget=10000,
    anchor_budget=0,
    # remaining loop options
    autoscale_loss_w_bcs=False,
    ntk_ratio_threshold=None,
    tensorboard_plots=False,
    **optim_args,
)

In [None]:
# Prediction of the model before training (this cell runs before train_loop.train()).
plot_prediction(
    train_loop=train_loop,
    res=200,
    out_idx=0,
);

In [None]:
# Run training with the step counts and schedule configured in `optim_args` above.
train_loop.train()

In [None]:
# Training points at step index 20000; only the figure is displayed.
snapshot_step = 20000
fig, _ = train_loop.plot_training_data(step_idx=snapshot_step)
fig

### Visualisation

In [None]:
# Loss curves tracked by `train_loop` during the training run above.

# Uncomment to also view the training points at step index 0:
# train_loop.plot_training_data(step_idx=0)
train_loop.plot_losses()

In [None]:
# Plot at the final training step; derived from the config so it stays valid
# if `train_steps` is changed (previously hard-coded to the same 100000).
steps = [optim_args['train_steps']]

In [None]:
# Prediction at the selected training steps, at t = 0, without training points.
plot_prediction(
    train_loop=train_loop,
    step_idxs=steps,
    out_idx=0,
    t_plot=0.,
    plot_training_data=False,
);

In [None]:
# NTK helper bound to the (trained) model; used below to compute the kernel on a grid.
ntk_fn = NTKHelper(model)

In [None]:
# Split the space-time domain into its spatial and temporal parts for the grid below.
# Assign separately: the former chained `timedomain = geom = ...timedomain`
# overwrote `geom`, so the "x" axis of the grid was built from the time interval.
geom = model.data.geom.geometry
timedomain = model.data.geom.timedomain

In [None]:
# Evaluate the NTK on a res x res grid: first axis from `geom` [l, r],
# second from `timedomain` [t0, t1].
res = 30
grid = jnp.meshgrid(
    jnp.linspace(geom.l, geom.r, res),                # varies along columns
    jnp.linspace(timedomain.t0, timedomain.t1, res),  # varies along rows
)
# Flatten to (res*res, 2) query points, one (x, t) pair per row, row-major over the grid.
pool_pts = jnp.stack([g.ravel() for g in grid], axis=1)

ntk = ntk_fn.get_ntk(xs1=pool_pts, code1=0)

In [None]:
# Kernel values between one reference grid point and all grid points,
# reshaped back onto the (res, res) grid (presumably K(x_i, .) — confirm
# against NTKHelper.get_ntk's output layout).
i = 20  # index of the reference point in pool_pts
T = ntk[i].reshape(res, res)

ntk_fig, ntk_ax = plt.subplots()
mesh = ntk_ax.pcolormesh(*grid, T, cmap='RdBu_r')
# Mark the reference point itself.
ntk_ax.plot([pool_pts[i, 0]], [pool_pts[i, 1]], 'x', color='black', ms=10)
ntk_ax.set(xlabel='x', ylabel='t', title=f'NTK slice for reference point {i}')
ntk_fig.colorbar(mesh, ax=ntk_ax);