In [14]:
from src.zpinn.dataio import PressureDataset, DomainSampler, BoundarySampler


import os

# Directory that holds the processed measurement files. It can be redirected
# with the ZPINN_DATA_DIR environment variable so the notebook is not tied to
# one machine; the fallback is the location this notebook was written against.
data_dir = os.environ.get(
    "ZPINN_DATA_DIR", r"C:\Users\STNj\dtu\thesis\code\data\processed"
)
data_path = os.path.join(data_dir, "inf_baffle.pkl")

# Load the pressure data and expose it as a shuffled stream of 16-sample batches.
dataset = PressureDataset(data_path)
dataloader = dataset.get_dataloader(batch_size=16, shuffle=True)
data_iterator = iter(dataloader)

In [15]:
# Sampler for points inside the domain: all four coordinates share the same
# (0, 1) limits and the "uniform" distribution. The dataset's transforms are
# handed to the sampler as well.
dom_sampler = DomainSampler(
    batch_size=16,
    limits=dict.fromkeys(("x", "y", "z", "f"), (0, 1)),
    distributions=dict.fromkeys(("x", "y", "z", "f"), "uniform"),
    transforms=dataset.transforms,
)

# Draw one batch; as the last expression of the cell it is displayed.
dom_iterator = iter(dom_sampler)
next(dom_iterator)

{'x': Array([0.89965916, 0.54451245, 0.7551072 , 0.30612645, 1.0907483 ,
        0.527942  , 0.7002718 , 0.7091689 , 0.5983184 , 1.2476665 ,
        0.7626662 , 0.17276604, 1.2757558 , 0.53877735, 0.14619748,
        0.11825243], dtype=float32),
 'y': Array([1.2415363 , 1.2588104 , 0.89180833, 0.20118618, 1.3057855 ,
        0.03785117, 1.3305675 , 0.7668254 , 0.07583634, 0.86100835,
        0.76941425, 1.3311567 , 0.90818405, 0.5990197 , 0.21758191,
        0.7012175 ], dtype=float32),
 'z': Array([ 0.26074818,  1.0214976 ,  1.5902928 ,  0.03995725,  9.103167  ,
         9.785904  ,  0.3121819 , 11.8482485 ,  6.5587196 ,  0.3870913 ,
        13.150018  , 13.229212  , 12.3033085 ,  6.931617  , 10.591384  ,
         8.704803  ], dtype=float32),
 'f': Array([-1.0200034, -1.0198369, -1.0201769, -1.0200241, -1.0201287,
        -1.0200919, -1.0199314, -1.0199411, -1.0201936, -1.0201573,
        -1.019948 , -1.0200167, -1.0199752, -1.0201061, -1.0200771,
        -1.0200541], dtype=float32)}

In [16]:
# Sampler for boundary points: z is pinned by the degenerate (0, 0) limits,
# x and y use the "grid" distribution, z and f use "uniform".
bnd_sampler = BoundarySampler(
    batch_size=16,
    limits={"x": (0, 1), "y": (0, 1), "z": (0, 0), "f": (0, 1)},
    distributions={"x": "grid", "y": "grid", "z": "uniform", "f": "uniform"},
    transforms=dataset.transforms,
)

bnd_iterator = iter(bnd_sampler)

In [17]:
dataset.transforms  # bare expression mid-cell: evaluated, but not displayed


# Flatten the per-quantity transform entries into a single dict with short keys:
# "<prefix>0" holds element [0] and "<prefix>c" holds element [1] of each entry.
# Prefixes: x, y, z, f for the coordinates, a / b for real / imaginary pressure.
transforms = {
    f"{prefix}{suffix}": dataset.transforms[field][position]
    for field, prefix in (
        ("x", "x"),
        ("y", "y"),
        ("z", "z"),
        ("f", "f"),
        ("real_pressure", "a"),
        ("imag_pressure", "b"),
    )
    for position, suffix in enumerate("0c")
}

In [18]:
from src.zpinn.models import SIREN, BVPModel


# Network with 4 inputs and 2 outputs, three hidden layers of 256 features and
# a linear outermost layer.
# NOTE(review): presumably the inputs are (x, y, z, f) and the outputs the
# real/imaginary pressure parts; confirm against the SIREN/BVPModel code.
model = SIREN.SIREN(
    in_features=4, out_features=2,
    hidden_features=256, hidden_layers=3,
    outermost_linear=True,
)

# Wrap the network in the boundary-value-problem model, using the flattened
# transforms dict and the "single_freq" impedance model.
bvp_model = BVPModel.BVPModel(
    model=model, transforms=transforms, impedance_model="single_freq"
)

# Handles used by the loss evaluations further down.
params, coeffs = bvp_model.parameters(), bvp_model.coefficients

In [29]:
from jax import vmap

# Evaluate the residual function on one batch of domain points.
r_net = bvp_model.r_net

# Select the coordinates by key instead of relying on the order in which the
# sampler happens to insert them into the batch dict.
dom_batch = next(dom_iterator)
x, y, z, f = (dom_batch[key] for key in ("x", "y", "z", "f"))

# params is shared by every sample (None); each coordinate array is mapped
# along its leading axis (0).
vmap(r_net, in_axes=(None, 0, 0, 0, 0))(params, x, y, z, f)

Array([-579.3683   ,   79.734436 , -173.08574  ,  309.32535  ,
       -436.06314  ,  379.29288  , -118.4902   ,  -13.519647 ,
        -28.325058 ,  -50.780792 ,  209.21278  ,  205.2106   ,
         -4.0723944,  385.00708  , -113.67912  , -123.55937  ],      dtype=float32)

ValueError: vmap in_axes must be an int, None, or a tuple of entries corresponding to the positional arguments passed to the function, but got len(in_axes)=6, len(args)=5

In [None]:
# in_axes layout with an unmapped first and last argument and four mapped ones.
(None,) + (0,) * 4 + (None,)

(None, 0, 0, 0, 0, None)

In [None]:
# Exercise the loss terms one at a time. Only p_loss (on a batch from the data
# iterator) is active; the r_loss and b_loss calls are left disabled below.
_data_batch = next(data_iterator)
bvp_model.p_loss(params, _data_batch)
# bvp_model.r_loss(params, next(dom_iterator))
# bvp_model.b_loss(params, next(bnd_iterator))

(Array(1.0073488, dtype=float32), Array(1.1958337, dtype=float32))

In [None]:
from jax import vmap
# A data batch unpacks into the coordinate dict and a second element
# (gt, presumably the ground-truth pressure; not used in this cell).
coords, gt = next(data_iterator)

# Select the coordinates by name: unpacking coords.values() positionally would
# silently feed the wrong coordinate to p_net if the dict's insertion order
# ever differed from the order assumed here.
# NOTE(review): key names are taken from dataset.transforms ("x", "y", "z",
# "f"); confirm they match the keys of the data batch.
x, y, z, f = (coords[key] for key in ("x", "y", "z", "f"))

# params is shared by every sample (None); each coordinate array is mapped
# along its leading axis (0).
vmap(bvp_model.p_net, in_axes=(None, 0, 0, 0, 0))(params, x, y, z, f)

(Array([ 0.04716443,  0.03689361,  0.02221077,  0.03747791,  0.04307295,
        -0.00032674,  0.01515444,  0.06518535,  0.04870435,  0.05693201,
         0.07002046,  0.04120278, -0.0098152 ,  0.02148196,  0.00976468,
         0.05594826], dtype=float32),
 Array([-0.01987108, -0.00024862, -0.02182606, -0.04618606, -0.03799538,
        -0.01071216, -0.04871527, -0.01985353, -0.01966653, -0.00612449,
        -0.01326989, -0.02304624, -0.03284209, -0.026276  , -0.01182558,
        -0.03974533], dtype=float32))

In [None]:
# bvp_model.p_loss(params, (coords, gt))
# Residual loss on a fresh batch of domain points.
# NOTE(review): the saved output of this cell is a ValueError raised by vmap
# (an argument of rank 0 mapped along axis 0); the cause lies inside r_loss and
# needs to be checked in BVPModel.
_dom_batch = next(dom_iterator)
bvp_model.r_loss(params, _dom_batch)

ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())

In [None]:
# Evaluate all loss terms together on one batch from each source, drawn in the
# order data -> domain -> boundary.
_batches = (
    next(data_iterator),
    next(dom_iterator),
    next(bnd_iterator),
)
bvp_model.losses(params, coeffs, *_batches)

ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())