In [1]:
from src.zpinn.dataio import PressureDataset, DomainSampler, BoundarySampler


data_path = r"C:\Users\STNj\dtu\thesis\code\data\processed\inf_baffle.pkl"
dataset = PressureDataset(data_path)
dataloader = dataset.get_dataloader(batch_size=16, shuffle=True)
data_iterator = iter(dataloader)

In [2]:
dom_sampler = DomainSampler(
    batch_size=16,
    limits=dict(x=(0, 1), y=(0, 1), z=(0, 1), f=(0, 1)),
    distributions=dict(x="uniform", y="uniform", z="uniform", f="uniform"),
    transforms=dataset.transforms,
)

dom_iterator = iter(dom_sampler)
next(dom_iterator)

{'x': Array([0.89965916, 0.54451245, 0.7551072 , 0.30612645, 1.0907483 ,
        0.527942  , 0.7002718 , 0.7091689 , 0.5983184 , 1.2476665 ,
        0.7626662 , 0.17276604, 1.2757558 , 0.53877735, 0.14619748,
        0.11825243], dtype=float32),
 'y': Array([1.2415363 , 1.2588104 , 0.89180833, 0.20118618, 1.3057855 ,
        0.03785117, 1.3305675 , 0.7668254 , 0.07583634, 0.86100835,
        0.76941425, 1.3311567 , 0.90818405, 0.5990197 , 0.21758191,
        0.7012175 ], dtype=float32),
 'z': Array([ 0.26074818,  1.0214976 ,  1.5902928 ,  0.03995725,  9.103167  ,
         9.785904  ,  0.3121819 , 11.8482485 ,  6.5587196 ,  0.3870913 ,
        13.150018  , 13.229212  , 12.3033085 ,  6.931617  , 10.591384  ,
         8.704803  ], dtype=float32),
 'f': Array([-1.0200034, -1.0198369, -1.0201769, -1.0200241, -1.0201287,
        -1.0200919, -1.0199314, -1.0199411, -1.0201936, -1.0201573,
        -1.019948 , -1.0200167, -1.0199752, -1.0201061, -1.0200771,
        -1.0200541], dtype=float32)}

In [3]:
bnd_sampler = BoundarySampler(
    batch_size=16,
    limits=dict(x=(0, 1), y=(0, 1), z=(0, 0), f=(0, 1)),
    distributions=dict(x="grid", y="grid", z="uniform", f="uniform"),
    transforms=dataset.transforms,
)

bnd_iterator = iter(bnd_sampler)

In [4]:
dataset.transforms


transforms = dict(
    x0=dataset.transforms["x"][0],
    xc=dataset.transforms["x"][1],
    y0=dataset.transforms["y"][0],
    yc=dataset.transforms["y"][1],
    z0=dataset.transforms["z"][0],
    zc=dataset.transforms["z"][1],
    f0=dataset.transforms["f"][0],
    fc=dataset.transforms["f"][1],
    a0=dataset.transforms["real_pressure"][0],
    ac=dataset.transforms["real_pressure"][1],
    b0=dataset.transforms["imag_pressure"][0],
    bc=dataset.transforms["imag_pressure"][1],
)

In [5]:
from src.zpinn.models import SIREN, BVPModel


model = SIREN.SIREN(
    in_features=4,
    out_features=2,
    hidden_features=256,
    hidden_layers=3,
    outermost_linear=True,
)

bvp_model = BVPModel.BVPModel(
    model=model,
    transforms=transforms,
    impedance_model="single_freq",
)

params = bvp_model.parameters()
coeffs = bvp_model.coefficients

In [6]:
from jax import vmap

r_net = bvp_model.r_net
x, y, z, f = next(dom_iterator).values()
vmap(bvp_model.r_net, in_axes=(None, *[0] * 4))(params, *(x, y, z, f))

(Array([ 13.896487 ,  15.151459 ,  10.504659 ,  -7.150683 , -19.713139 ,
         -3.3349352,  -5.4443455,   1.6721005,  -7.265933 , -20.758698 ,
        -21.95194  , -12.516868 ,  -6.036775 , -11.564328 , -11.95933  ,
          3.715053 ], dtype=float32),
 Array([-10.273615  ,  -8.151482  ,   1.2610406 ,  -3.8613884 ,
          6.1925154 ,  -0.72648686,  -4.5691957 ,  -0.1657106 ,
         -1.697756  ,   3.3997982 ,  15.014544  ,  -6.9602947 ,
        -10.531594  , -19.61528   , -23.171238  , -12.085783  ],      dtype=float32))

In [7]:
# test individual models

bvp_model.p_loss(params, next(data_iterator))
bvp_model.r_loss(params, next(dom_iterator))
bvp_model.z_loss(params, coeffs, next(bnd_iterator))

(Array(4.005037e-09, dtype=float32), Array(2.2584195e-09, dtype=float32))

In [8]:
from jax import vmap
coords, gt = next(data_iterator)  # unpack the data batch
f, x, y, z = coords.values()

vmap(bvp_model.p_net, in_axes=(None, *[0] * 4))(params, *(x, y, z, f))

(Array([ 0.04716443,  0.03689361,  0.02221077,  0.03747791,  0.04307295,
        -0.00032674,  0.01515444,  0.06518535,  0.04870435,  0.05693201,
         0.07002046,  0.04120278, -0.0098152 ,  0.02148196,  0.00976468,
         0.05594826], dtype=float32),
 Array([-0.01987108, -0.00024862, -0.02182606, -0.04618606, -0.03799538,
        -0.01071216, -0.04871527, -0.01985353, -0.01966653, -0.00612449,
        -0.01326989, -0.02304624, -0.03284209, -0.026276  , -0.01182558,
        -0.03974533], dtype=float32))

In [9]:
# bvp_model.p_loss(params, (coords, gt))
bvp_model.r_loss(params, next(dom_iterator))

(Array(199.24384, dtype=float32), Array(3857.212, dtype=float32))

In [10]:
bvp_model.losses(
    params,
    coeffs,
    next(data_iterator),
    next(dom_iterator),
    next(bnd_iterator),
)

{'data_re': Array(0.8149791, dtype=float32),
 'data_im': Array(0.7857977, dtype=float32),
 'pde_re': Array(393.76633, dtype=float32),
 'pde_im': Array(3131.1226, dtype=float32),
 'bc_re': Array(4.000836e-09, dtype=float32),
 'bc_im': Array(2.253595e-09, dtype=float32)}

In [11]:
bvp_model.coefficients

{'alpha': 0.0, 'beta': 0.0}

In [13]:
batches = dict(
    dat_batch=next(data_iterator),
    dom_batch=next(dom_iterator),
    bnd_batch=next(bnd_iterator),
)
w = bvp_model.compute_weights(**batches)

TypeError: Argument '<bound method BVPModel.losses of BVPModel(criterion=<function <lambda> at 0x0000016827EFF4C0>, coefficients={'alpha': 0.0, 'beta': 0.0}, model=SIREN(
  layers=[
    SineLayer(
      omega_0=30,
      is_first=True,
      in_features=4,
      out_features=256,
      linear=Linear(
        weight=f32[256,4],
        bias=f32[256],
        in_features=4,
        out_features=256,
        use_bias=True
      ),
      weight=f32[256,4]
    ),
    SineLayer(
      omega_0=30.0,
      is_first=False,
      in_features=256,
      out_features=256,
      linear=Linear(
        weight=f32[256,256],
        bias=f32[256],
        in_features=256,
        out_features=256,
        use_bias=True
      ),
      weight=f32[256,256]
    ),
    SineLayer(
      omega_0=30.0,
      is_first=False,
      in_features=256,
      out_features=256,
      linear=Linear(
        weight=f32[256,256],
        bias=f32[256],
        in_features=256,
        out_features=256,
        use_bias=True
      ),
      weight=f32[256,256]
    ),
    SineLayer(
      omega_0=30.0,
      is_first=False,
      in_features=256,
      out_features=256,
      linear=Linear(
        weight=f32[256,256],
        bias=f32[256],
        in_features=256,
        out_features=256,
        use_bias=True
      ),
      weight=f32[256,256]
    ),
    Linear(
      weight=f32[2,256],
      bias=f32[2],
      in_features=256,
      out_features=2,
      use_bias=True
    )
  ]
), x0=0.0, xc=0.75, y0=0.0, yc=0.75, z0=0.0, zc=0.07, a0=-0.058731083193746154, ac=0.8051005708866225, b0=-0.18135189435224625, bc=0.86690303980624, impedance_model=<function constant_impedance at 0x0000016827EFF6A0>)>' of type <class 'method'> is not a valid JAX type.