In [1]:
from model import *
from pde import *
from derivative import *
from residual import *
from params import *

In [2]:
key = jax.random.PRNGKey(123)
model = PINN(key)
derivative = Derivative(x_coef,t_coef)
pde = PDE(
    alpha_phi=alpha_phi,
    omega_phi=omega_phi,
    M=M,
    A=A,
    L=L,
    c_se=c_se,
    c_le=c_le,
)
r = Residual(x_coef,t_coef,pde,derivative)

In [3]:
x = jnp.arange(10)/10.0
t = jnp.arange(10)/10.0

In [4]:
r.res_ic(model,x,t)

{'ic': Array([-0.23823779,  0.2622055 ,  0.26260259,  0.26295337,  0.26326489,
         0.26355026,  0.26382633,  0.26411083,  0.26441964,  0.26476475,
        -0.53540194, -0.03861464, -0.04206976, -0.04564288, -0.04922042,
        -0.05270728, -0.05603095, -0.05914191, -0.06201156, -0.0646285 ],      dtype=float64)}

In [5]:
r.res_bc(model,x,t)

{'bc': Array([-2.38237790e-01,  9.99762206e+02,  1.99976260e+03,  2.99976295e+03,
         3.99976326e+03,  4.99976355e+03,  5.99976383e+03,  6.99976411e+03,
         7.99976442e+03,  8.99976476e+03, -5.35401936e-01,  9.99461385e+02,
         1.99945793e+03,  2.99945436e+03,  3.99945078e+03,  4.99944729e+03,
         5.99944397e+03,  6.99944086e+03,  7.99943799e+03,  8.99943537e+03],      dtype=float64)}

In [6]:
r.res_phys(model,x,t)

{'ac': Array([0.00031112, 0.00031585, 0.00032079, 0.00032578, 0.00033071,
        0.00033548, 0.00034005, 0.00034438, 0.00034848, 0.00035236],      dtype=float64),
 'ch': Array([-0.00037037, -0.00130973, -0.00218534, -0.00293175, -0.00350621,
        -0.00389223, -0.00409706, -0.00414515, -0.00406997, -0.00390668],      dtype=float64)}

In [8]:
x_dict = {'ic':x,'bc':x,'colloc':x,'adapt':x}
t_dict = {'ic':t,'bc':t,'colloc':t,'adapt':t}

In [9]:
r.compute_loss(model,x_dict,t_dict)

{'ic': Array(0.04968087, dtype=float64),
 'bc': Array(28496433.08224288, dtype=float64),
 'ac': Array(1.10734733e-07, dtype=float64),
 'ch': Array(1.0846143e-05, dtype=float64)}

In [11]:
r.get_noisy_points(model,x,t,6)

(Array([0.9, 0.8, 0.7, 0.7, 0.6, 0.8], dtype=float64, weak_type=True),
 Array([0.9, 0.8, 0.7, 0.7, 0.6, 0.8], dtype=float64, weak_type=True))

In [13]:
r.compute_ntk_weights(model,x_dict,t_dict)

{'ic': Array(2.00048761, dtype=float64),
 'bc': Array(2.00048761, dtype=float64),
 'ac': Array(698339.09989741, dtype=float64),
 'ch': Array(4126.89538644, dtype=float64)}