In [1]:
%env XLA_PYTHON_CLIENT_ALLOCATOR=platform

env: XLA_PYTHON_CLIENT_ALLOCATOR=platform


In [2]:
import jax
import optax
import equinox as eqx
import matplotlib.pyplot as plt

from jax import numpy as jnp, random, config
from model import PINN
from pdekernel import PDEKernel
from derivative import Derivative
from residual import Residual
from sampler import Sampler

key = random.PRNGKey(1312)
key, subkey = random.split(key)

In [3]:
s = Sampler()
x = s.get_sample(key,[5,5,5],[(0,0),(-0.5,0.5),(0,1)])
t = s.get_sample(subkey, [5,5,5],[(0,0),(-0.5,0.5),(0,1)])
y = s.get_sample(subkey, [5,5,5],[(0,0),(-0.5,0.5),(0,1)])
x_sub,t_sub,y_sub = s.subsample(key,5,x,t,y)
print(x_sub)
print(t_sub)
print(y_sub)
inp_ic = {'x':x, 't':t}
inp_bc = {'x':x, 't':t}
inp_colloc = {'x':x, 't':t}
inp_data = {'x':x, 't':t, 'phi':y, 'c':y}
inp = {'ic': inp_ic, 'bc':inp_bc, 'colloc':inp_colloc, 'data':inp_data}

[0.16686788 0.06698012 0.36686788 0.         0.86698012]
[0.20649349 0.08812114 0.40649349 0.         0.88812114]
[0.20649349 0.08812114 0.40649349 0.         0.88812114]


In [4]:
# set up the dimensionless PDE
params = {
    "alpha_phi": 9.62e-5,
    "omega_phi": 1.663e7,
    "A": 5.35e7,
    "c_se": 1.0,
    "c_le": 5100/1.43e5,
    "x_range_phys": (-50e-6, 50e-6),
    "t_range_phys": (0, 1e5),
    "L": 1e-11,
    "M": jnp.float64(8.5e-10) / (2 * jnp.float64(5.35e7)),
    "dc": 1.0-5100/1.43e5
}
pdekernel = PDEKernel(params)

In [5]:
# set up the model
inp_idx = {'x':0, 't':1}
out_idx = {'phi':0, 'c':0}
width = 16
depth = 4
model = PINN(inp_idx, out_idx, width, depth)
x_arr = jnp.arange(5)/10
t_arr = jnp.arange(5)/10
""" 
Note that model accepts three inputs x,t,M
- x,t must have matching dimensions (either scalar or 1D array)
- M is always a scalar
"""
print("test scalar")
print(model(1,2))
print("------------------------")
print("test batched")
print(model(x_arr,t_arr))

test scalar
(Array(0.09886375, dtype=float64), Array(-0.2572521, dtype=float64))
------------------------
test batched
(Array([0.17734853, 0.17461186, 0.17081278, 0.16607631, 0.16056868],      dtype=float64), Array([-0.36224114, -0.35612096, -0.34941203, -0.34216648, -0.33446341],      dtype=float64))


In [6]:
# set up the derivataive
phys_span = {'x':pdekernel.x_range_phys, 't':pdekernel.t_range_phys} # the physical span is now a dimensionless span
norm_span = {'x':(-0.5,0.5), 't':(0,1)}
d = Derivative(inp_idx, out_idx, phys_span, norm_span)
d.create_deriv_fn('phi_t')
d.create_deriv_fn('phi_x')
d.create_deriv_fn('phi_2x')
d.create_deriv_fn('c_t')
d.create_deriv_fn('c_x')
d.create_deriv_fn('c_2x')
print("test evaluate")
print(d.evaluate(model,0.1,0.1,function_names = ['phi_t','c_2x']))
print(d.evaluate(model,x_arr,t_arr,function_names = ['phi_t','c_2x']))

test evaluate
{'phi_t': Array(-1.57404653e-07, dtype=float64), 'c_2x': Array(-3074163.98140331, dtype=float64)}
{'phi_t': Array([-1.00226367e-07, -1.57404653e-07, -2.10276336e-07, -2.56133905e-07,
       -2.92842793e-07], dtype=float64), 'c_2x': Array([-3361116.06241818, -3074163.98140332, -2672613.92849501,
       -2191974.2743584 , -1671860.94259358], dtype=float64)}


In [7]:
# set up the residual
r = Residual(phys_span, norm_span, pdekernel, d)
print("test res_ic")
print(r.res_ic(model,inp['ic']))
print("--------------------------------")
print("test res_bc")
print(r.res_bc(model,inp['bc']))
print("--------------------------------")
print("test res_phys")
print(r.res_phys(model,inp['colloc']))
print("--------------------------------")
print("test res_data")
print(r.res_data(model,inp['data']))

test res_ic
{'ic': Array([-0.32265147, -0.32265147, -0.32265147, -0.32265147, -0.32265147,
       -0.82323762, -0.82059647, -0.6975871 ,  0.17133518,  0.16133215,
        0.15622898,  0.16721631,  0.15588142,  0.14287816,  0.12976466,
       -0.32265147, -0.32265147, -0.32265147, -0.32265147, -0.32265147,
       -0.82323762, -0.82059759, -0.779534  ,  0.17138997,  0.16133215,
        0.17425489,  0.16721647,  0.15588142,  0.14287816,  0.12976466],      dtype=float64)}
--------------------------------
test res_bc
{'bc': Array([ 0.17734853,  0.17734853,  0.17734853,  0.17734853,  0.17734853,
       -0.82323762, -0.82059759, -0.82234213,  0.17138998,  0.16133215,
        0.17533619,  0.16721647,  0.15588142,  0.14287816,  0.12976466,
        0.17734853,  0.17734853,  0.17734853,  0.17734853,  0.17734853,
       -0.82323762, -0.82059759, -0.82234213,  0.17138998,  0.16133215,
        0.17533619,  0.16721647,  0.15588142,  0.14287816,  0.12976466],      dtype=float64)}
---------------------

In [8]:
r.compute_ntk_weights(model, inp, batch_size=2)

{'ic': Array(3.00004629, dtype=float64),
 'bc': Array(3.00004629, dtype=float64),
 'ac': Array(48450776.86323194, dtype=float64),
 'ch': Array(64901.55268937, dtype=float64),
 'data': Array(3.00004629, dtype=float64)}