In [1]:
# Core JAX imports.
import jax
import jax.numpy as jnp
from jax import config
# Enable 64-bit floats; this runs before any key or array is created below.
config.update("jax_enable_x64", True)
# Project-local modules are star-imported, so the origin of names used in later
# cells (e.g. Sampler, x_span, t_span) is not visible from this notebook.
from sampler import *
from params import *
# Fixed seed so the recorded outputs are reproducible.
key = jax.random.PRNGKey(123)
from model import *
# Build the network from the seeded key; PINN is presumably exported by `model`.
model = PINN(key)
# NOTE(review): `from model import *` and `from params import *` each appear twice
# in this cell. With star imports the last import wins on name clashes, so confirm
# that no shadowing is relied on before de-duplicating.
from model import *
from pde import *
from derivative import *
from residual import *
from params import *

In [3]:
# Seeded RNG key and network shared by the cells below.
key = jax.random.PRNGKey(123)
model = PINN(key)

# Derivative helper built from the x/t coefficients.
derivative = Derivative(x_coef, t_coef)

# Gather the PDE coefficients first, then pass them on as keyword arguments.
pde_coefficients = dict(
    alpha_phi=alpha_phi,
    omega_phi=omega_phi,
    M=M,
    A=A,
    L=L,
    c_se=c_se,
    c_le=c_le,
)
pde = PDE(**pde_coefficients)

# Residual object receiving the same coefficients plus the PDE and derivative helper.
r = Residual(x_coef, t_coef, pde, derivative)

In [4]:
# Build the Sampler exercised by the test cells below.
# The two dicts hold the size settings for full samples and for subsamples.
sample_size = dict(ic=8, bc=6, colloc_x=5, colloc_t=5, adapt=10)
subsample_size = dict(ic=4, bc=3, colloc=6, adapt=5)
noise = 0.5
sampler = Sampler(
    x_span,
    t_span,
    sample_size,
    subsample_size,
    r,
    noise,
)

In [5]:
# Test _denormalize: 0.5 mapped onto the span (0, 10).
denormalized = sampler._denormalize(0.5, (0, 10))
print("Denormalize 0.5 in (0,10):", denormalized)
# Pin the value from the recorded run (5.0) so a regression fails the cell
# instead of only changing the printout.
assert abs(float(denormalized) - 5.0) < 1e-12, denormalized

Denormalize 0.5 in (0,10): 5.0


In [6]:
# Test _lhs (1-D): draw 5 samples with the shared key.
lhs_samples = sampler._lhs(key, dim=1, num=5)
print("LHS samples (dim=1):", lhs_samples)
# Check the stratification visible in the recorded output: exactly one sample
# falls in each of the 5 equal-width bins of [0, 1).
assert lhs_samples.shape == (5,), lhs_samples.shape
assert sorted(int(u * 5) for u in lhs_samples.tolist()) == list(range(5)), lhs_samples

LHS samples (dim=1): [0.7288015  0.05534138 0.33729004 0.59948698 0.89211653]


In [7]:
# Test _lhs (2-D): the recorded result is a pair of arrays, one per dimension.
lhs_samples_2d = sampler._lhs(key, dim=2, num=5)
print("LHS samples (dim=2):", lhs_samples_2d)
# Each dimension must be stratified on its own: one sample per 1/5-wide bin.
assert len(lhs_samples_2d) == 2, len(lhs_samples_2d)
assert all(
    sorted(int(u * 5) for u in coords.tolist()) == list(range(5))
    for coords in lhs_samples_2d
), lhs_samples_2d

LHS samples (dim=2): (Array([0.7288015 , 0.05534138, 0.33729004, 0.59948698, 0.89211653],      dtype=float64), Array([0.44926569, 0.13196439, 0.94082043, 0.75244268, 0.21987435],      dtype=float64))


In [8]:
# Test _make_uniform_grid: 5 points on the span (0, 1).
uniform_grid = sampler._make_uniform_grid(key, 5, (0, 1))
print("Uniform grid:", uniform_grid)
# The recorded grid stays inside the span and is evenly spaced by 1/5;
# assert both so the cell fails on a regression.
assert uniform_grid.shape == (5,), uniform_grid.shape
assert all(0.0 <= g <= 1.0 for g in uniform_grid.tolist()), uniform_grid
assert all(
    abs((upper - lower) - 0.2) < 1e-9
    for lower, upper in zip(uniform_grid.tolist(), uniform_grid.tolist()[1:])
), uniform_grid

Uniform grid: [0.02821268 0.22821268 0.42821268 0.62821268 0.82821268]


In [9]:
# Test _get_ic: initial-condition sample points.
x_ic, t_ic = sampler._get_ic(key)
print("IC x:", x_ic)
print("IC t:", t_ic)
# NOTE(review): the recorded run returns 6 points although sample_size['ic'] is 8;
# confirm that this reduction is intended by the sampler.
# x and t must pair up one-to-one, and every recorded initial-condition time is 0.
assert x_ic.shape == t_ic.shape, (x_ic.shape, t_ic.shape)
assert all(t == 0.0 for t in t_ic.tolist()), t_ic

IC x: [-0.45485972 -0.25485972 -0.07742986  0.02257014  0.14514028  0.34514028]
IC t: [0. 0. 0. 0. 0. 0.]


In [10]:
# Exercise _get_bc and show both coordinate arrays with their labels.
x_bc, t_bc = sampler._get_bc(key)
for label, values in (("BC x:", x_bc), ("BC t:", t_bc)):
    print(label, values)

BC x: [-0.5 -0.5 -0.5  0.5  0.5  0.5]
BC t: [0.05642535 0.38975869 0.72309202 0.05642535 0.38975869 0.72309202]


In [11]:
# Exercise _get_colloc and show both coordinate arrays with their labels.
x_colloc, t_colloc = sampler._get_colloc(key)
for label, values in (("Colloc x:", x_colloc), ("Colloc t:", t_colloc)):
    print(label, values)

Colloc x: [-0.41949907 -0.41949907 -0.41949907 -0.41949907 -0.41949907 -0.21949907
 -0.21949907 -0.21949907 -0.21949907 -0.21949907 -0.01949907 -0.01949907
 -0.01949907 -0.01949907 -0.01949907  0.18050093  0.18050093  0.18050093
  0.18050093  0.18050093  0.38050093  0.38050093  0.38050093  0.38050093
  0.38050093]
Colloc t: [0.03079106 0.23079106 0.43079106 0.63079106 0.83079106 0.03079106
 0.23079106 0.43079106 0.63079106 0.83079106 0.03079106 0.23079106
 0.43079106 0.63079106 0.83079106 0.03079106 0.23079106 0.43079106
 0.63079106 0.83079106 0.03079106 0.23079106 0.43079106 0.63079106
 0.83079106]


In [12]:
# Exercise _get_adapt with the shared key and the model, then show x followed by t.
# NOTE(review): in the recorded output the x values (~0.63-0.98) lie outside the
# +/-0.5 boundary positions reported by _get_bc; confirm that _get_adapt maps its
# points into x_span and returns them in (x, t) order.
x_adapt, t_adapt = sampler._get_adapt(key, model)
print(x_adapt, t_adapt, sep="\n")

[0.97936736 0.84199366 0.94507517 0.833729   0.81488725 0.64815095
 0.62921165 0.73948772 0.833729   0.81488725]
[0.08104491 0.02684082 0.31918632 0.17408204 0.28214012 0.11166575
 0.12198744 0.27993646 0.17408204 0.28214012]


In [13]:
# Test get_sample: draw the full sample dictionaries for x and t.
x_sample, t_sample = sampler.get_sample(key, model)
print("Sample keys:", x_sample.keys())
# x and t must describe the same points: identical categories...
assert set(x_sample) == set(t_sample), (set(x_sample), set(t_sample))
for k in x_sample:
    print(f"{k} x shape: {x_sample[k].shape}, t shape: {t_sample[k].shape}")
    # ...and identical shapes within each category.
    assert x_sample[k].shape == t_sample[k].shape, k

Sample keys: dict_keys(['ic', 'bc', 'colloc', 'adapt'])
ic x shape: (6,), t shape: (6,)
bc x shape: (6,), t shape: (6,)
colloc x shape: (25,), t shape: (25,)
adapt x shape: (10,), t shape: (10,)


In [14]:
# Test get_subsample: cut the full sample down per category.
x_subsample, t_subsample = sampler.get_subsample(key, x_sample, t_sample)
for k in x_subsample:
    print(f"Subsample {k} x shape: {x_subsample[k].shape}, t shape: {t_subsample[k].shape}")
    # In the recorded run every category has exactly subsample_size[k] points for
    # both x and t; assert it so a size regression fails the cell.
    assert x_subsample[k].shape == t_subsample[k].shape == (subsample_size[k],), k

Subsample ic x shape: (4,), t shape: (4,)
Subsample bc x shape: (3,), t shape: (3,)
Subsample colloc x shape: (6,), t shape: (6,)
Subsample adapt x shape: (5,), t shape: (5,)
