In [16]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Prefer the GPU when CUDA is usable; otherwise everything stays on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class PINN(nn.Module):
    """Fully connected tanh network mapping a point (x, y, z, t) to a scalar u.

    Architecture: ``Linear(4, width)`` + ``Tanh``, then ``depth - 1`` further
    ``Linear(width, width)`` + ``Tanh`` pairs, then a ``Linear(width, 1)``
    read-out with no activation.

    Args:
        width: number of units in every hidden layer.
        depth: number of hidden (Linear + Tanh) layers; values below 1 still
            produce one hidden layer.
    """

    def __init__(self, width=128, depth=8):
        super().__init__()

        # Input size of each hidden layer: 4 coordinates for the first one,
        # `width` for each of the remaining depth - 1 layers.
        fan_ins = [4] + [width] * max(depth - 1, 0)

        modules = []
        for fan_in in fan_ins:
            modules.append(nn.Linear(fan_in, width))
            modules.append(nn.Tanh())
        modules.append(nn.Linear(width, 1))

        self.net = nn.Sequential(*modules)

    def forward(self, x, y, z, t):
        """Evaluate the network at the given coordinates.

        ``x``, ``y``, ``z`` and ``t`` must all have the same shape ``(...,)``
        (they are stacked along a new last axis). Returns a tensor of that
        same shape ``(...,)``.
        """
        coords = torch.stack((x, y, z, t), dim=-1)  # (..., 4)
        out = self.net(coords)                      # (..., 1)
        return out.squeeze(-1)                      # (...,)

In [18]:
def grad(u, x):
    """Differentiate ``u`` with respect to ``x`` via autograd.

    A tensor of ones shaped like ``u`` is used as ``grad_outputs``, so the
    result is the gradient of ``u.sum()`` with respect to ``x`` and has the
    shape of ``x``. The graph is kept (``create_graph=True``) so the returned
    tensor can be differentiated again.
    """
    ones = torch.ones_like(u)
    derivs = torch.autograd.grad(u, x, grad_outputs=ones, create_graph=True)
    return derivs[0]

def laplacian(u, x, y, z):
    """Return u_xx + u_yy + u_zz, built from repeated calls to ``grad``.

    ``u`` must have been computed from ``x``, ``y`` and ``z`` with autograd
    tracking enabled, otherwise ``grad`` cannot differentiate through it.
    """
    second_derivs = []
    for coord in (x, y, z):
        first = grad(u, coord)               # du/dcoord
        second_derivs.append(grad(first, coord))  # d2u/dcoord2

    d2x, d2y, d2z = second_derivs
    return d2x + d2y + d2z

In [19]:
def residual(model, x, y, z, t, c, alpha, f_fn):
    """Evaluate the PDE residual of ``model`` at the given points.

    The residual is ``R = u_tt - c * lap(u) + alpha * u_t - f``, where
    ``f = f_fn(x, y, z, t)`` and ``lap`` is the spatial Laplacian.

    Returns:
        ``(R, u, u_t)``: the residual, the model output and its first time
        derivative, all evaluated at the input points.

    NOTE(review): the Laplacian is multiplied by ``c`` itself, not ``c**2``.
    If ``c`` is meant to be a wave speed, confirm this is intended.
    """
    # Work on fresh leaf tensors so derivatives are taken with respect to
    # these coordinates and never flow back into the caller's tensors.
    x, y, z, t = (v.detach().requires_grad_(True) for v in (x, y, z, t))

    u = model(x, y, z, t)
    u_t = grad(u, t)
    u_tt = grad(u_t, t)
    lap_u = laplacian(u, x, y, z)
    source = f_fn(x, y, z, t)

    res = u_tt - c * lap_u + alpha * u_t - source
    return res, u, u_t

In [20]:
def f_fn(x, y, z, t):
    """Source term of the PDE: identically zero.

    Only ``x`` is used, as the template for the output's shape, dtype and
    device; ``y``, ``z`` and ``t`` are accepted so the signature matches what
    ``residual`` passes in.
    """
    no_forcing = torch.zeros_like(x)
    return no_forcing

In [21]:
mse = nn.MSELoss()


def loss_pinn(model,
              x_p, y_p, z_p, t_p,
              x_0, y_0, z_0, u0, v0,
              x_b, y_b, z_b, t_b, g,
              c, alpha, f_fn,
              w_phys=1.0, w_ic=10.0, w_bc=10.0):
    """Weighted PINN loss: PDE residual + initial conditions + boundary data.

    Args:
        model: network called as ``model(x, y, z, t)``.
        x_p, y_p, z_p, t_p: points where the PDE residual is penalised.
        x_0, y_0, z_0: points at t = 0 where the initial state is imposed.
        u0, v0: target values of u and of du/dt at those initial points.
        x_b, y_b, z_b, t_b: boundary points.
        g: target value of u at the boundary points (Dirichlet data).
        c, alpha, f_fn: PDE coefficients and source, forwarded to ``residual``.
        w_phys, w_ic, w_bc: weights of the three loss terms.

    Returns:
        ``(total, phys, ic, bc)``: ``total`` carries the autograd graph; the
        other three are detached copies of the unweighted terms, where ``ic``
        is the sum of the displacement and velocity errors.
    """
    # PDE term: mean squared residual at the interior points.
    pde_res, _, _ = residual(model, x_p, y_p, z_p, t_p, c, alpha, f_fn)
    phys_term = pde_res.pow(2).mean()

    # Initial-condition term. The time input needs requires_grad so the
    # predicted initial velocity du/dt at t = 0 can be obtained by autograd.
    t_init = torch.zeros_like(x_0, requires_grad=True)
    u_init = model(x_0, y_0, z_0, t_init)
    disp_term = mse(u_init, u0)
    vel_term = mse(grad(u_init, t_init), v0)
    ic_term = disp_term + vel_term

    # Boundary term: prediction versus prescribed boundary values.
    bc_term = mse(model(x_b, y_b, z_b, t_b), g)

    total = w_phys * phys_term + w_ic * ic_term + w_bc * bc_term
    return total, phys_term.detach(), ic_term.detach(), bc_term.detach()

# Test: damped wave equation on the cube [-1, 1]^3 with a Gaussian initial pulse

In [23]:
# ---- Problem setup: domain, PDE coefficients and training points ----

# Spatial cube and time interval
x_min, x_max = -1.0, 1.0
y_min, y_max = -1.0, 1.0
z_min, z_max = -1.0, 1.0
t_min, t_max = 0.0, 1.0

# PDE coefficients (scalar tensors living on `device`)
c = torch.tensor(1.0, device=device)
alpha = torch.tensor(0.1, device=device)


def _uniform(n, lo, hi):
    """Return a 1-D tensor of `n` uniform random samples between lo and hi on `device`."""
    return torch.empty(n, device=device).uniform_(lo, hi)


# Collocation points for the PDE residual (random in space and time)
N_phys = 5000
x_p = _uniform(N_phys, x_min, x_max)
y_p = _uniform(N_phys, y_min, y_max)
z_p = _uniform(N_phys, z_min, z_max)
t_p = _uniform(N_phys, t_min, t_max)

# Initial condition: Gaussian bump centred at the origin, zero velocity
N_ic = 200
x_0 = _uniform(N_ic, x_min, x_max)
y_0 = _uniform(N_ic, y_min, y_max)
z_0 = _uniform(N_ic, z_min, z_max)
u0 = torch.exp(-5.0 * (x_0**2 + y_0**2 + z_0**2))
v0 = torch.zeros_like(x_0)

# Boundary points (u = 0 on boundary): each pair of opposite faces gets random
# (t, other-two-coordinates) samples, half pinned to the min face and the rest
# to the max face.

# Faces x = x_min and x = x_max
nb_x = 200
t_bx = _uniform(nb_x, t_min, t_max)
y_bx = _uniform(nb_x, y_min, y_max)
z_bx = _uniform(nb_x, z_min, z_max)
x_left = torch.full((nb_x // 2,), x_min, device=device)
x_right = torch.full((nb_x - nb_x // 2,), x_max, device=device)
x_bx = torch.cat((x_left, x_right))

# Faces y = y_min and y = y_max
nb_y = 200
t_by = _uniform(nb_y, t_min, t_max)
x_by = _uniform(nb_y, x_min, x_max)
z_by = _uniform(nb_y, z_min, z_max)
y_bottom = torch.full((nb_y // 2,), y_min, device=device)
y_top = torch.full((nb_y - nb_y // 2,), y_max, device=device)
y_by = torch.cat((y_bottom, y_top))

# Faces z = z_min and z = z_max
nb_z = 200
t_bz = _uniform(nb_z, t_min, t_max)
x_bz = _uniform(nb_z, x_min, x_max)
y_bz = _uniform(nb_z, y_min, y_max)
z_bottom = torch.full((nb_z // 2,), z_min, device=device)
z_top = torch.full((nb_z - nb_z // 2,), z_max, device=device)
z_bz = torch.cat((z_bottom, z_top))

# All boundary samples in one set, in the order: x faces, y faces, z faces
x_b = torch.cat((x_bx, x_by, x_bz))
y_b = torch.cat((y_bx, y_by, y_bz))
z_b = torch.cat((z_bx, z_by, z_bz))
t_b = torch.cat((t_bx, t_by, t_bz))
g = torch.zeros_like(x_b)   # homogeneous Dirichlet target: u = 0

In [36]:
def train(runs=2000, lr=1e-3, width=128, depth=8,
          w_phys=1.0, w_ic=10.0, w_bc=10.0, log_every=1000):
    """Train a fresh ``PINN`` with Adam on the module-level training points.

    The training data (``x_p`` ... ``t_p``, ``x_0`` ... ``v0``, ``x_b`` ...
    ``g``), the PDE inputs (``c``, ``alpha``, ``f_fn``) and ``device`` are
    read from module scope. Calling ``train()`` with no arguments reproduces
    the previously hard-coded configuration.

    Args:
        runs: last iteration index; the loop performs ``runs + 1`` optimizer
            steps (iterations 0 .. runs inclusive).
        lr: Adam learning rate.
        width, depth: forwarded to ``PINN``.
        w_phys, w_ic, w_bc: loss weights forwarded to ``loss_pinn``.
        log_every: print the losses every this many iterations; 0 or None
            disables logging.

    Returns:
        The trained model (left on ``device``).
    """
    model = PINN(width=width, depth=depth).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for it in range(runs + 1):
        optimizer.zero_grad()
        L, Lp, Lic, Lbc = loss_pinn(
            model,
            x_p, y_p, z_p, t_p,
            x_0, y_0, z_0, u0, v0,
            x_b, y_b, z_b, t_b, g,
            c, alpha, f_fn,
            w_phys=w_phys, w_ic=w_ic, w_bc=w_bc
        )
        L.backward()
        optimizer.step()

        # Guard against log_every == 0, which would otherwise divide by zero.
        if log_every and it % log_every == 0:
            print(f"{it:05d}  total={L.item():.3e}  phys={Lp.item():.3e}  ic={Lic.item():.3e}  bc={Lbc.item():.3e}")
    return model

In [37]:
model = train()
model.eval()

# ---- grid definition ----
Nx, Ny = 101, 101
x_lin = torch.linspace(x_min, x_max, Nx, device=device)
y_lin = torch.linspace(y_min, y_max, Ny, device=device)
X, Y = torch.meshgrid(x_lin, y_lin, indexing='ij')

# The flattened x/y inputs are identical for every slice and time: build once.
x_in = X.reshape(-1)
y_in = Y.reshape(-1)

# times to visualize (one subplot per time)
times_to_plot = [0.0, 0.33, 0.66, 1.0]

# One figure per fixed-z plane. The planes are symmetric about z = 0; the
# list previously contained 0.5 twice, so the z = -0.5 plane was never drawn.
for z_slice in [-1.0, -0.5, 0.0, 0.5, 1.0]:
    Z = torch.full_like(X, z_slice)
    z_in = Z.reshape(-1)

    plt.figure(figsize=(15, 3))

    # Pure inference: no autograd graph is needed for plotting.
    with torch.no_grad():
        for i, t_val in enumerate(times_to_plot, start=1):

            T = torch.full_like(X, float(t_val))
            t_in = T.reshape(-1)

            # evaluate 3D PINN on the flattened slice
            u_flat = model(x_in, y_in, z_in, t_in)       # (Nx*Ny,)
            u_grid = u_flat.reshape(Nx, Ny).cpu().numpy()

            # Plot slice; transpose because the grid is indexed (x, y) while
            # imshow expects rows to run along the vertical (y) axis.
            ax = plt.subplot(1, len(times_to_plot), i)
            im = ax.imshow(
                u_grid.T,
                origin='lower',
                extent=[x_min, x_max, y_min, y_max],
                aspect='equal'
            )
            # Report the actual plane being shown (was hard-coded to "z = 0").
            ax.set_title(f"t = {t_val:.2f}, z = {z_slice:.2f}")
            ax.set_xlabel("x")
            if i == 1:
                ax.set_ylabel("y")

            plt.colorbar(im, ax=ax, shrink=0.8)

    plt.tight_layout()
    plt.show()

00000  total=2.107e-01  phys=1.587e-06  ic=1.056e-02  bc=1.051e-02


KeyboardInterrupt: 

In [None]:
# from pyawd import VectorAcousticWaveDataset3D  # adjust import if needed
# import torch
# import numpy
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # --- generate 1 synthetic experiment with PyAWD ---
# interrogators = [(10, 0, 0), (-10, 0, 0)]
# dataset = VectorAcousticWaveDataset3D(size=1, interrogators=interrogators)

# experiment = dataset[0]
# field = experiment[0]          # torch.Tensor
# print("field.shape:", field.shape)
# print("field.dtype:", field.dtype)
# print("field.device:", field.device)
# print("field.ndim:", field.ndim)

# comp = 0                           # choose which component to use
# field_comp = field[comp]          # (Nt, Nx, Ny, Nz)

# Nt, Nx, Ny, Nz = field_comp.shape

# # Reorder to (Nx, Ny, Nz, Nt) for consistency with our PINN
# u_torch = field_comp.permute(1, 2, 3, 0).contiguous()  # (Nx, Ny, Nz, Nt)
# u_torch = u_torch.to(device)
# Nx, Ny, Nz, Nt = u_torch.shape

len(experiment): 2
type(experiment[0]): <class 'torch.Tensor'>
type(experiment[1]): <class 'dict'>

Keys in meta:
meta is not a dict, type: <class 'torch.Tensor'>


In [None]:
# x_idx = torch.arange(Nx, dtype=torch.float32, device=device)
# y_idx = torch.arange(Ny, dtype=torch.float32, device=device)
# z_idx = torch.arange(Nz, dtype=torch.float32, device=device)

# x_coords = (x_idx - (Nx - 1) / 2) / ((Nx - 1) / 2)
# y_coords = (y_idx - (Ny - 1) / 2) / ((Ny - 1) / 2)
# z_coords = (z_idx - (Nz - 1) / 2) / ((Nz - 1) / 2)

# t_idx = torch.arange(Nt, dtype=torch.float32, device=device)
# t_coords = t_idx / max(Nt - 1, 1)   # in [0,1]

# u0_grid = u_torch[..., 0]     # (Nx, Ny, Nz)

# if Nt >= 2:
#     dt = 1.0 / max(Nt - 1, 1)    # or the true dt from PyAWD
#     v0_grid = (u_torch[..., 1] - u_torch[..., 0]) / dt
# else:
#     v0_grid = torch.zeros_like(u0_grid)

In [None]:
# N_phys = 5000

# x_p = torch.empty(N_phys, device=device).uniform_(x_coords[0], x_coords[-1])
# y_p = torch.empty(N_phys, device=device).uniform_(y_coords[0], y_coords[-1])
# z_p = torch.empty(N_phys, device=device).uniform_(z_coords[0], z_coords[-1])
# t_p = torch.empty(N_phys, device=device).uniform_(0.0, 1.0)

# N_ic = 2000

# ix = torch.randint(0, Nx, (N_ic,), device=device)
# iy = torch.randint(0, Ny, (N_ic,), device=device)
# iz = torch.randint(0, Nz, (N_ic,), device=device)

# x_0 = x_coords[ix]
# y_0 = y_coords[iy]
# z_0 = z_coords[iz]

# u0 = u0_grid[ix, iy, iz]
# v0 = v0_grid[ix, iy, iz]

# N_bc = 4000
# n_face = N_bc // 6

# x_min, x_max = x_coords[0].item(), x_coords[-1].item()
# y_min, y_max = y_coords[0].item(), y_coords[-1].item()
# z_min, z_max = z_coords[0].item(), z_coords[-1].item()

# # faces x = const
# t_x = torch.empty(2*n_face, device=device).uniform_(0.0, 1.0)
# y_x = torch.empty(2*n_face, device=device).uniform_(y_min, y_max)
# z_x = torch.empty(2*n_face, device=device).uniform_(z_min, z_max)
# x_x = torch.cat([
#     torch.full((n_face,), x_min, device=device),
#     torch.full((n_face,), x_max, device=device)
# ], dim=0)

# # faces y = const
# t_y = torch.empty(2*n_face, device=device).uniform_(0.0, 1.0)
# x_y = torch.empty(2*n_face, device=device).uniform_(x_min, x_max)
# z_y = torch.empty(2*n_face, device=device).uniform_(z_min, z_max)
# y_y = torch.cat([
#     torch.full((n_face,), y_min, device=device),
#     torch.full((n_face,), y_max, device=device)
# ], dim=0)

# # faces z = const
# t_z = torch.empty(2*n_face, device=device).uniform_(0.0, 1.0)
# x_z = torch.empty(2*n_face, device=device).uniform_(x_min, x_max)
# y_z = torch.empty(2*n_face, device=device).uniform_(y_min, y_max)
# z_z = torch.cat([
#     torch.full((n_face,), z_min, device=device),
#     torch.full((n_face,), z_max, device=device)
# ], dim=0)

# x_b = torch.cat([x_x, x_y, x_z], dim=0)
# y_b = torch.cat([y_x, y_y, y_z], dim=0)
# z_b = torch.cat([z_x, z_y, z_z], dim=0)
# t_b = torch.cat([t_x, t_y, t_z], dim=0)
# g   = torch.zeros_like(x_b)    # homogeneous Dirichlet