# Surrogate Model :

In [18]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Shared floating-point precision for every tensor created in this notebook.
dtype = torch.float32

## Forward models (physics -> seismograms):

### Finite-difference solver (differentiable physics, DP):

In [None]:
class DPForwardSolver:
    """
    Differentiable-physics forward solver for the 2D acoustic wave equation

        u_tt = c^2 (u_xx + u_yy) + f(x, y, t; e_x, e_y),

    discretised with an explicit central-difference (leapfrog) scheme in time,
    a 5-point Laplacian in space, zero initial conditions and homogeneous
    Dirichlet boundaries.

    forward(e_x, e_y) -> seismograms (Nt, K)

    Layout: the wavefield is a (1, 1, Ny, Nx) tensor whose rows follow y and
    whose columns follow x. This is the convention ``F.grid_sample`` uses when
    the field is interpolated at the sensor positions.

    Args:
        sensors: (K, 2) tensor of sensor positions (x, y) in physical coordinates.
        c: wave speed (scalar, or a tensor broadcastable to (Ny, Nx)).
        x_min, x_max, y_min, y_max: rectangular domain bounds.
        Nx, Ny: grid points along x and y (>= 2; must give dx == dy).
        Nt: number of time samples (>= 2), including t = 0.
        T: final time; dt = T / (Nt - 1).
        A, beta, t0, gamma: source f = A exp(-beta (t - t0)^2) exp(-gamma |r - e|^2).
        device, dtype: tensor placement. Defaults are CUDA when available
            (else CPU) and float32.

    Raises:
        ValueError: if Nx, Ny or Nt < 2, if dx != dy, or if the CFL condition
            is violated.
    """

    def __init__(
        self,
        sensors: torch.Tensor,         # (K,2) in physical coords (x, y)
        c=1,
        x_min=-1, x_max=1,
        y_min=-1, y_max=1,
        Nx=101, Ny=101, Nt=201, T=1,
        # source params
        A=5, beta=200, t0=0.2, gamma=200,
        # tensor placement (optional)
        device=None, dtype=None,
    ):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.dtype = torch.float32 if dtype is None else dtype

        if Nx < 2 or Ny < 2:
            raise ValueError("Nx and Ny must be >= 2")
        if Nt < 2:
            raise ValueError("Nt must be >= 2")

        # domain / grid
        self.x_min, self.x_max = x_min, x_max
        self.y_min, self.y_max = y_min, y_max
        self.Nx, self.Ny, self.Nt, self.T = Nx, Ny, Nt, T

        # physics
        self.c = torch.as_tensor(c, device=self.device, dtype=self.dtype)

        # source
        self.A = A
        self.beta = beta
        self.t0 = t0
        self.gamma = gamma

        # sensors
        self.sensors = sensors.to(self.device, dtype=self.dtype)
        self.K = self.sensors.shape[0]

        # The stencil below assumes a square cell. Compare the analytic
        # spacings with a relative tolerance; an absolute 1e-7 threshold on
        # float32 linspace differences is fragile.
        h = (x_max - x_min) / (Nx - 1)
        if not math.isclose(h, (y_max - y_min) / (Ny - 1), rel_tol=1e-6):
            raise ValueError("DPForwardSolver requires dx == dy (uniform square grid) for this stencil.")

        # build grid
        self.x = torch.linspace(self.x_min, self.x_max, self.Nx, device=self.device, dtype=self.dtype)
        self.y = torch.linspace(self.y_min, self.y_max, self.Ny, device=self.device, dtype=self.dtype)
        self.dx = self.x[1] - self.x[0]
        self.dy = self.y[1] - self.y[0]
        self.h = self.dx

        self.dt = self.T / (self.Nt - 1)

        # CFL condition for the 2D 5-point leapfrog scheme: c * dt / h <= 1/sqrt(2).
        # max |c| also covers a spatially varying speed tensor.
        c_max = float(self.c.abs().max())
        if c_max * self.dt / h > 1 / math.sqrt(2):
            raise ValueError("CFL unstable: decrease dt (increase Nt), increase h (decrease Nx/Ny), or reduce c.")

        # indexing="xy" returns (Ny, Nx) arrays with X[i, j] = x[j] and
        # Y[i, j] = y[i], which is exactly the wavefield layout. Transposing
        # here would swap the roles of e_x and e_y (and break Nx != Ny).
        X, Y = torch.meshgrid(self.x, self.y, indexing="xy")
        self.X = X.contiguous()  # (Ny,Nx)
        self.Y = Y.contiguous()  # (Ny,Nx)

        # Loop-invariant Laplacian kernel, built once instead of every step.
        stencil = torch.tensor([[0.0, 1.0, 0.0],
                                [1.0, -4.0, 1.0],
                                [0.0, 1.0, 0.0]],
                               device=self.device, dtype=self.dtype)
        self._lap_kernel = (stencil / (self.h * self.h)).view(1, 1, 3, 3)

    def get_bounds(self):
        """Return ((x_min, x_max), (y_min, y_max)) of the simulation domain."""
        return (self.x_min, self.x_max), (self.y_min, self.y_max)

    def _laplacian(self, u: torch.Tensor) -> torch.Tensor:
        """5-point Laplacian of u (1, 1, Ny, Nx), zero-padded at the edges."""
        return F.conv2d(u, self._lap_kernel, padding=1)

    def _apply_dirichlet(self, u: torch.Tensor) -> torch.Tensor:
        """Return a copy of u with its outer ring of grid points set to zero."""
        # Index-assign on a clone so the caller's tensor is left untouched.
        v = u.clone()
        v[..., 0, :] = 0
        v[..., -1, :] = 0
        v[..., :, 0] = 0
        v[..., :, -1] = 0
        return v

    def _source(self, t, e_x: torch.Tensor, e_y: torch.Tensor) -> torch.Tensor:
        """Evaluate the Gaussian-in-time, Gaussian-in-space source at time t -> (1,1,Ny,Nx)."""
        t = torch.as_tensor(t, device=self.device, dtype=self.dtype)
        time_env = self.A * torch.exp(-self.beta * (t - self.t0) ** 2)
        r2 = (self.X - e_x) ** 2 + (self.Y - e_y) ** 2
        space_env = torch.exp(-self.gamma * r2)
        return (time_env * space_env).unsqueeze(0).unsqueeze(0)

    def _sample_sensors(self, u: torch.Tensor) -> torch.Tensor:
        """Bilinearly interpolate u (1,1,Ny,Nx) at the sensor positions -> (K,)."""
        x = self.sensors[:, 0]
        y = self.sensors[:, 1]
        # grid_sample expects coordinates in [-1, 1]: first = width (x), second = height (y).
        x_n = 2 * (x - self.x_min) / (self.x_max - self.x_min) - 1
        y_n = 2 * (y - self.y_min) / (self.y_max - self.y_min) - 1
        grid = torch.stack([x_n, y_n], dim=-1).view(1, -1, 1, 2)  # (1,K,1,2)
        vals = F.grid_sample(u, grid, mode="bilinear", align_corners=True)  # (1,1,K,1)
        return vals.view(-1)

    def forward(self, e_x, e_y, return_wavefield=False):
        """
        Simulate the wavefield for a source at (e_x, e_y) and record it at the sensors.

        Args:
            e_x, e_y: epicenter coordinates (float, tensor, or nn.Parameter).
                Gradients flow back to them when they require grad.
            return_wavefield: if True, also return the final field.

        Returns:
            seismograms: (Nt, K) tensor; row n is the field sampled at t = n * dt.
            u_final: (1, 1, Ny, Nx) field at t = (Nt - 1) * dt. Returned only
                when return_wavefield=True.
        """
        e_x = torch.as_tensor(e_x, device=self.device, dtype=self.dtype)
        e_y = torch.as_tensor(e_y, device=self.device, dtype=self.dtype)

        u_prev = torch.zeros((1, 1, self.Ny, self.Nx), device=self.device, dtype=self.dtype)
        u_curr = torch.zeros_like(u_prev)

        cdt2 = (self.c * self.dt) ** 2
        dt2 = self.dt ** 2

        # Trace n is the field after n updates; update n uses the source at t = n * dt.
        traces = [self._sample_sensors(u_curr)]
        for n in range(self.Nt - 1):
            t = torch.as_tensor(n * self.dt, device=self.device, dtype=self.dtype)
            lap = self._laplacian(u_curr)
            f = self._source(t, e_x, e_y)
            u_next = self._apply_dirichlet(2 * u_curr - u_prev + cdt2 * lap + dt2 * f)
            u_prev, u_curr = u_curr, u_next
            traces.append(self._sample_sensors(u_curr))

        seismograms = torch.stack(traces)  # (Nt, K)
        if return_wavefield:
            return seismograms, u_curr
        return seismograms

### PINN : 

In [None]:
class PINN_2D(nn.Module):
    """
    Fully-connected surrogate u_theta(x, y, t; e_x, e_y) for the 2D wavefield.

    The five scalar inputs are stacked along a new trailing axis, then passed
    through a stack of `Linear(·, width)` + `act()` hidden layers (the first
    maps 5 -> width; `depth - 1` further layers map width -> width), and
    finally through `Linear(width, 1)`. The trailing singleton is squeezed, so
    the output has the broadcast shape of the inputs.
    """
    def __init__(self, width=128, depth=8, act=nn.Tanh):
        super().__init__()
        # Input size of every hidden Linear, in creation order.
        fan_ins = [5] + [width] * (depth - 1)
        modules = []
        for fan_in in fan_ins:
            modules.append(nn.Linear(fan_in, width))
            modules.append(act())
        modules.append(nn.Linear(width, 1))
        self.net = nn.Sequential(*modules)

    def forward(self, x, y, t, e_x, e_y):
        features = torch.stack([x, y, t, e_x, e_y], dim=-1)  # (..., 5)
        return self.net(features).squeeze(-1)


class PINNForwardSolver2D:
    """
    Wrap a trained PINN into a forward solver that outputs seismograms.

    The network is queried at every (t_n, sensor_k) pair on the uniform time
    grid t_n = t_min + n * dt, n = 0..Nt-1. The results are arranged as an
    (Nt, K) matrix. Evaluation is differentiable w.r.t. the epicenter, so the
    wrapper can drive gradient-based inversion.

    Args:
        model: module called as model(x, y, t, e_x, e_y) on equal-length 1-D
            tensors, returning one value per point (e.g. PINN_2D).
        sensors: (K, 2) tensor of sensor positions (x, y).
        x_min, x_max, y_min, y_max: domain bounds reported by get_bounds().
        Nt: number of time samples (>= 2).
        T: time of the last sample; t_min: time of the first sample.
        batch_size: maximum number of query points per model call (>= 1).
        device, dtype: tensor placement. Defaults are CUDA when available
            (else CPU) and float32.

    Raises:
        ValueError: on Nt < 2, batch_size < 1, or sensors not shaped (K, 2).
    """

    def __init__(
        self,
        model: nn.Module,
        sensors: torch.Tensor,
        x_min=-1.0, x_max=1.0,
        y_min=-1.0, y_max=1.0,
        Nt=201, T=1.0, t_min=0.0,
        batch_size=4096,
        device=None, dtype=None,
    ):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.dtype = torch.float32 if dtype is None else dtype

        self.model = model
        self.model.to(device=self.device)
        self.model.eval()

        # domain bounds (for inverse solver clamping)
        self.x_min, self.x_max = float(x_min), float(x_max)
        self.y_min, self.y_max = float(y_min), float(y_max)

        # time grid
        self.Nt = int(Nt)
        self.T = float(T)
        self.t_min = float(t_min)
        if self.Nt < 2:
            raise ValueError("Nt must be >= 2")
        self.dt = (self.T - self.t_min) / (self.Nt - 1)
        self.t_grid = torch.linspace(self.t_min, self.T, self.Nt, device=self.device, dtype=self.dtype)

        # sensors
        sensors = sensors.to(device=self.device, dtype=self.dtype)
        if sensors.ndim != 2 or sensors.shape[1] != 2:
            raise ValueError("sensors must be (K,2)")
        self.sensors = sensors
        self.K = int(sensors.shape[0])

        self.batch_size = int(batch_size)
        if self.batch_size < 1:
            # A non-positive batch would never advance through the query points.
            raise ValueError("batch_size must be >= 1")

        # Flattened, time-major query points: index n * K + k <-> (t_n, sensor k).
        self._xq = self.sensors[:, 0].repeat(self.Nt)
        self._yq = self.sensors[:, 1].repeat(self.Nt)
        self._tq = self.t_grid.repeat_interleave(self.K)

    def get_bounds(self):
        """Return ((x_min, x_max), (y_min, y_max)) of the modelled domain."""
        return (self.x_min, self.x_max), (self.y_min, self.y_max)

    def _as_tensor(self, e):
        """Convert an epicenter coordinate to a tensor on this solver's device/dtype."""
        if torch.is_tensor(e):
            # `.to` is differentiable and returns `e` itself when nothing
            # changes, so an nn.Parameter keeps its autograd graph.
            return e.to(device=self.device, dtype=self.dtype)
        return torch.as_tensor(e, device=self.device, dtype=self.dtype)

    def forward(self, e_x, e_y):
        """
        Evaluate the PINN at every (time, sensor) pair.

        Args:
            e_x, e_y: epicenter coordinates. May be Python floats or scalar
                tensors; pass tensors with requires_grad=True (e.g.
                nn.Parameter) for inversion.

        Returns:
            seismograms: (Nt, K) tensor.
        """
        e_x = self._as_tensor(e_x)
        e_y = self._as_tensor(e_y)

        n_points = self._xq.shape[0]
        if n_points == 0:
            return torch.empty((self.Nt, self.K), device=self.device, dtype=self.dtype)

        # Do NOT wrap with no_grad(): inversion needs gradients wrt e_x, e_y.
        chunks = []
        for start in range(0, n_points, self.batch_size):
            stop = start + self.batch_size
            xs = self._xq[start:stop]
            ys = self._yq[start:stop]
            ts = self._tq[start:stop]
            chunks.append(self.model(xs, ys, ts, e_x.expand_as(xs), e_y.expand_as(xs)))

        return torch.cat(chunks).to(self.dtype).reshape(self.Nt, self.K)

### ML pipeline :

## Inverse problem solver :

In [None]:
def inverse_solver(
    forward,               # object with: forward(e_x, e_y) -> traces_pred (Nt,K), get_bounds()
    traces_obs,            # (Nt,K) torch tensor
    dt,                    # float
    t_star=0,              # ignore data before t_star
    init=(0, 0),           # initial guess (e_x0, e_y0)
    steps=500,             # outer steps
    lr=1,
    lam=1e-6,
    device=None,           # default: forward.device, else CUDA/CPU auto-detect
    dtype=None,            # default: forward.dtype, else float32
):
    """
    Estimate the epicenter (e_x, e_y) by fitting predicted to observed seismograms.

    Minimises

        sum_{n >= n_star, k} ((pred[n, k] - obs[n, k]) / s_k)^2 + lam * (e_x^2 + e_y^2)

    with L-BFGS (strong-Wolfe line search, one inner iteration per outer
    step). Here s_k is the population std of observed trace k over
    n >= n_star, clamped to >= 1e-6. After every step the estimate is
    projected in place into the forward model's bounds, shrunk by 1e-3 on
    each side.

    Args:
        forward: object exposing forward(e_x, e_y) -> (Nt, K) tensor and
            get_bounds() -> ((x_min, x_max), (y_min, y_max)).
        traces_obs: observed seismograms, same shape as forward's output.
        dt: sampling interval, used to convert t_star to the index n_star.
        t_star: samples with t < t_star are ignored in the misfit.
        init: initial guess (e_x0, e_y0).
        steps: number of outer L-BFGS steps.
        lr: L-BFGS learning rate.
        lam: Tikhonov weight on |e|^2.
        device, dtype: placement of data and parameters.

    Returns:
      e_hat: (2,) tensor [e_x_hat, e_y_hat]
      traces_pred_final: (Nt,K) tensor
      history: list of dicts (diagnostics every 50 steps and at the last step)
      n_star: int
    """
    if device is None:
        device = getattr(forward, "device", None)
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if dtype is None:
        dtype = getattr(forward, "dtype", None)
        if dtype is None:
            dtype = torch.float32

    traces_obs = traces_obs.to(device=device, dtype=dtype)
    steps = int(steps)

    Nt = traces_obs.shape[0]
    n_star = int(round(t_star / dt))
    n_star = max(0, min(n_star, Nt - 1))

    # trainable epicenter
    e_x = torch.nn.Parameter(torch.tensor(float(init[0]), device=device, dtype=dtype))
    e_y = torch.nn.Parameter(torch.tensor(float(init[1]), device=device, dtype=dtype))

    (x_min, x_max), (y_min, y_max) = forward.get_bounds()
    margin = 1e-3

    def clamp():
        # Must be in place: Tensor.clamp() returns a new tensor and would leave
        # the parameters unconstrained.
        with torch.no_grad():
            e_x.clamp_(x_min + margin, x_max - margin)
            e_y.clamp_(y_min + margin, y_max - margin)
    clamp()

    # fixed per-sensor scaling based on observed data (stabilizes 2-sensors case)
    obs = traces_obs[n_star:]
    scale = obs.std(dim=0, unbiased=False).clamp_min(1e-6)
    obs_scaled = obs / scale

    opt = torch.optim.LBFGS([e_x, e_y], lr=lr, max_iter=1, line_search_fn="strong_wolfe")
    history = []

    def loss_and_pred():
        traces_pred = forward.forward(e_x, e_y)  # (Nt,K)
        pred_scaled = traces_pred[n_star:] / scale
        loss_data = torch.sum((pred_scaled - obs_scaled) ** 2)
        loss_reg = lam * (e_x**2 + e_y**2)
        return loss_data + loss_reg, loss_data, traces_pred

    def closure():
        opt.zero_grad()
        loss, _, _ = loss_and_pred()
        loss.backward()
        return loss

    for it in range(steps):
        opt.step(closure)
        clamp()

        if it % 50 == 0 or it == steps - 1:
            print(f"{it=}")
            with torch.no_grad():
                loss, loss_data, _ = loss_and_pred()
            history.append({
                "it": it,
                "loss": float(loss.item()),
                "loss_data": float(loss_data.item()),
                "e_x": float(e_x.item()),
                "e_y": float(e_y.item()),
            })

    with torch.no_grad():
        _, _, traces_pred_final = loss_and_pred()

    e_hat = torch.stack([e_x.detach(), e_y.detach()])
    return e_hat, traces_pred_final.detach(), history, n_star

## Comparisons : 

In [None]:
# Receiver positions (x, y) in physical coordinates -> tensor of shape (K=2, 2).
sensors = torch.tensor(
    [
        [-0.9, -0.2],
        [0.6, 0.8],
    ],
    device=device,
    dtype=torch.float32,
)

# Ground-truth epicenter (e_x, e_y) used to generate the synthetic observations.
e_true = (0.25, -0.10)

In [None]:
forward_model = DPForwardSolver(sensors=sensors)

# Synthetic "observed" data from the true epicenter; no autograd graph needed.
with torch.no_grad():
    traces_obs = forward_model.forward(*e_true)  # (Nt,K)

# Recover the epicenter from samples with t >= 0.35, starting at the origin.
e_hat, traces_pred, history, n_star = inverse_solver(
    forward=forward_model,
    traces_obs=traces_obs,
    dt=forward_model.dt,
    init=(0.0, 0.0),
    t_star=0.35,
    lam=1e-6,
    lr=1.0,
    steps=500,
)

print("true:", e_true)
print("hat :", tuple(float(v) for v in e_hat))

it=0
it=50
it=100
it=150
it=200
it=250
it=300
it=350
it=400
it=450
it=499
true: (0.25, -0.1)
hat : (0.25000008940696716, -0.1000000610947609)


In [None]:
# NOTE(review): `model` is expected to be a trained PINN_2D from the
# "ML pipeline" section, which is empty in this excerpt. Confirm it is
# defined before running this cell.
# Device/dtype are left to the solver's defaults.
forward_pinn = PINNForwardSolver2D(
    model=model,
    sensors=sensors,
    x_min=-1.0, x_max=1.0,
    y_min=-1.0, y_max=1.0,
    Nt=201, T=1.0, t_min=0.0,
    batch_size=4096,
)

# Observations are generated by the PINN itself (self-consistent test).
# The epicenter is passed as tensors so the wrapper can broadcast it over
# the query points.
with torch.no_grad():
    traces_obs = forward_pinn.forward(
        *(torch.tensor(v, device=device, dtype=dtype) for v in e_true)
    )  # (Nt,K)

e_hat, traces_pred, history, n_star = inverse_solver(
    forward=forward_pinn,
    traces_obs=traces_obs,
    dt=forward_pinn.dt,
    t_star=0.35,
    init=(0.0, 0.0),
    steps=500,
    lr=1.0,
    lam=1e-6,
)

print("true:", e_true)
print("hat :", (float(e_hat[0]), float(e_hat[1])))