# Surrogate Model :

## Problem setup

We consider a 2D acoustic wave equation.
The forward problem consists of computing the seismograms recorded at fixed sensors for a given epicenter location.

Goals:
- Use a PINN (or any other model that could outperform it) as a surrogate forward model
- Use this surrogate to solve an inverse problem (epicenter estimation)

In [63]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from dataclasses import dataclass


# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Single floating-point precision shared by the solvers and models below.
dtype = torch.float32

## Forward models (physics → seismograms)

### Finite-difference solver (DP, differentiable physics)

This solver numerically solves the wave equation using finite differences.
It is treated as the reference forward operator.

In [77]:
class DPForwardSolver:
    """
    Differentiable-physics forward solver for the 2D acoustic wave equation

        u_tt = c^2 (u_xx + u_yy) + f(x, y, t; e_x, e_y),
        f    = A exp(-beta (t - t0)^2) exp(-gamma ((x - e_x)^2 + (y - e_y)^2)),

    discretised with an explicit second-order central-difference scheme in time,
    a 5-point Laplacian in space, homogeneous Dirichlet boundaries and zero
    initial conditions. Every step is a torch op, so the seismograms are
    differentiable with respect to (e_x, e_y).

    forward(e_x, e_y) -> seismograms (Nt, K)

    Wavefields are stored as (1, 1, Ny, Nx) tensors: rows follow y and columns
    follow x, which is the layout ``F.grid_sample`` assumes when the sensors
    are sampled.

    Parameters
    ----------
    sensors : (K, 2) tensor of sensor (x, y) positions in physical coordinates.
    c : wave speed.
    x_min, x_max, y_min, y_max : rectangular domain.
    Nx, Ny : number of grid points along x and y (>= 2; must give dx == dy).
    Nt, T : number of time samples (>= 2) and final time; dt = T / (Nt - 1).
    A, beta, t0, gamma : source amplitude, temporal decay, peak time, spatial decay.
    device, dtype : storage device / precision. They default to CUDA when
        available (else CPU) and float32, the same choice as the notebook-level
        ``device`` / ``dtype`` settings.

    Raises
    ------
    ValueError
        For invalid sizes, a sensors tensor that is not (K, 2), non-square
        cells, or a CFL-unstable time step.
    """

    def __init__(
        self,
        sensors: torch.Tensor,         # (K,2) in physical coords
        c=1,
        x_min=-1, x_max=1,
        y_min=-1, y_max=1,
        Nx=101, Ny=101, Nt=201, T=1,
        # source params
        A=5, beta=200, t0=0.2, gamma=200,
        device=None,
        dtype=None,
    ):
        if Nx < 2 or Ny < 2 or Nt < 2:
            raise ValueError("DPForwardSolver requires Nx, Ny, Nt >= 2.")

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.dtype = torch.float32 if dtype is None else dtype

        # domain / grid sizes
        self.x_min, self.x_max = x_min, x_max
        self.y_min, self.y_max = y_min, y_max
        self.Nx, self.Ny, self.Nt, self.T = Nx, Ny, Nt, T

        # physics
        self.c = torch.as_tensor(c, device=self.device, dtype=self.dtype)

        # source
        self.A = A
        self.beta = beta
        self.t0 = t0
        self.gamma = gamma

        # sensors
        self.sensors = sensors.to(self.device, dtype=self.dtype)
        if self.sensors.ndim != 2 or self.sensors.shape[1] != 2:
            raise ValueError(f"sensors must have shape (K, 2), got {tuple(self.sensors.shape)}")
        self.K = self.sensors.shape[0]

        # grid coordinates
        self.x = torch.linspace(x_min, x_max, Nx, device=self.device, dtype=self.dtype)
        self.y = torch.linspace(y_min, y_max, Ny, device=self.device, dtype=self.dtype)

        # Spacings computed in double precision from the domain size rather than
        # from float32 grid differences, so rounding cannot trip the check below.
        self.dx = (x_max - x_min) / (Nx - 1)
        self.dy = (y_max - y_min) / (Ny - 1)
        if not math.isclose(self.dx, self.dy, rel_tol=1e-6):
            raise ValueError("DPForwardSolver requires dx == dy (uniform square grid) for this stencil.")
        self.h = self.dx

        self.dt = T / (Nt - 1)

        # CFL condition of the explicit 2D 5-point scheme: c * dt / h <= 1/sqrt(2)
        if float(self.c) * self.dt / self.h > 1 / math.sqrt(2):
            raise ValueError("CFL unstable: decrease dt (increase Nt), increase h (decrease Nx/Ny), or reduce c.")

        # indexing="xy" already returns (Ny, Nx) arrays with X[j, i] = x[i] and
        # Y[j, i] = y[j]: rows follow y and columns follow x, exactly like the
        # wavefield and grid_sample. Do NOT transpose them -- that would swap the
        # roles of x and y in the source term (and break non-square grids).
        X, Y = torch.meshgrid(self.x, self.y, indexing="xy")
        self.X = X.contiguous()  # (Ny, Nx)
        self.Y = Y.contiguous()  # (Ny, Nx)

        # 5-point Laplacian stencil, built once instead of at every time step.
        stencil = torch.tensor(
            [[0.0, 1.0, 0.0],
             [1.0, -4.0, 1.0],
             [0.0, 1.0, 0.0]],
            device=self.device, dtype=self.dtype,
        )
        self._lap_kernel = (stencil / (self.h * self.h)).view(1, 1, 3, 3)

    def get_bounds(self):
        """Return ((x_min, x_max), (y_min, y_max)) of the computational domain."""
        return (self.x_min, self.x_max), (self.y_min, self.y_max)

    def _laplacian(self, u: torch.Tensor) -> torch.Tensor:
        """5-point discrete Laplacian of u (1, 1, Ny, Nx), zero-padded at the edges."""
        return F.conv2d(u, self._lap_kernel, padding=1)

    def _apply_dirichlet(self, u: torch.Tensor) -> torch.Tensor:
        """Return a copy of u with the boundary rows/columns set to zero."""
        # modify a clone rather than u itself, which autograd may still need
        v = u.clone()
        v[..., 0, :] = 0
        v[..., -1, :] = 0
        v[..., :, 0] = 0
        v[..., :, -1] = 0
        return v

    def _source(self, t: torch.Tensor, e_x: torch.Tensor, e_y: torch.Tensor) -> torch.Tensor:
        """Gaussian source evaluated on the grid at time t; returns (1, 1, Ny, Nx)."""
        time_env = self.A * torch.exp(-self.beta * (t - self.t0) ** 2)
        space_env = torch.exp(-self.gamma * ((self.X - e_x) ** 2 + (self.Y - e_y) ** 2))
        f = time_env * space_env
        return f.unsqueeze(0).unsqueeze(0)

    def _sample_sensors(self, u: torch.Tensor) -> torch.Tensor:
        """Bilinearly interpolate u (1, 1, Ny, Nx) at the sensor positions -> (K,)."""
        x = self.sensors[:, 0]
        y = self.sensors[:, 1]
        # grid_sample expects coordinates normalised to [-1, 1] with x -> width
        x_n = 2 * (x - self.x_min) / (self.x_max - self.x_min) - 1
        y_n = 2 * (y - self.y_min) / (self.y_max - self.y_min) - 1
        grid = torch.stack([x_n, y_n], dim=-1).view(1, -1, 1, 2)  # (1,K,1,2)
        vals = F.grid_sample(u, grid, mode="bilinear", align_corners=True)  # (1,1,K,1)
        return vals.view(-1)

    def _as_scalar(self, v) -> torch.Tensor:
        """Python number -> 0-d tensor; tensors are moved/cast with a differentiable .to()."""
        if isinstance(v, torch.Tensor):
            return v.to(device=self.device, dtype=self.dtype)
        return torch.tensor(float(v), device=self.device, dtype=self.dtype)

    def forward(self, e_x, e_y):
        """
        Simulate the wavefield for an epicenter and record it at the sensors.

        e_x, e_y: epicenter coordinates (float, tensor, or nn.Parameter); gradients
        flow back to tensor inputs.

        Returns seismograms (Nt, K): row n is the field at t_n = n * dt sampled at
        the sensors. Row 0 is the zero initial state.
        """
        e_x = self._as_scalar(e_x)
        e_y = self._as_scalar(e_y)

        u_prev = torch.zeros((1, 1, self.Ny, self.Nx), device=self.device, dtype=self.dtype)
        u_curr = torch.zeros_like(u_prev)

        cdt2 = (self.c * self.dt) ** 2
        dt2 = self.dt ** 2

        traces = []
        for n in range(self.Nt):
            traces.append(self._sample_sensors(u_curr))
            if n == self.Nt - 1:
                break
            t = torch.as_tensor(n * self.dt, device=self.device, dtype=self.dtype)
            # u^{n+1} = 2 u^n - u^{n-1} + (c dt)^2 Lap(u^n) + dt^2 f(t_n)
            u_next = 2 * u_curr - u_prev + cdt2 * self._laplacian(u_curr) + dt2 * self._source(t, e_x, e_y)
            u_prev, u_curr = u_curr, self._apply_dirichlet(u_next)

        return torch.stack(traces, dim=0)

### PINN surrogate model : 

The PINN approximates the continuous wavefield $u(x,y,t; e_x,e_y)$.

The epicenter coordinates are network inputs, so a single network can represent the wavefield for many source locations.

$$
f_\theta :\; (e_x, e_y) \;\longmapsto\; \mathcal S_\theta(e_x, e_y) \;=\; \left\{ \hat u_\theta(x_k, y_k, t_n; e_x, e_y) \right\}_{n=0,\dots,N_t-1,\; k=1,\dots,K},
\qquad t_n = n\,\Delta t,\quad \Delta t = \frac{T}{N_t - 1}.
$$


In [159]:
class PINN(nn.Module):
    """
    Fully connected network u_theta(x, y, t; e_x, e_y) -> scalar wavefield value.

    Layout of ``self.net`` (save_pinn infers width/depth from it):
        Linear(5, width), act,
        (depth - 1) x [Linear(width, width), act],
        Linear(width, 1)
    """

    def __init__(self, width=128, depth=8, act=nn.Tanh):
        super().__init__()
        # Layers are created input-to-output so the parameter initialisation
        # consumes the RNG in the natural order.
        head = [nn.Linear(5, width), act()]
        hidden = [
            layer
            for _ in range(depth - 1)
            for layer in (nn.Linear(width, width), act())
        ]
        self.net = nn.Sequential(*head, *hidden, nn.Linear(width, 1))

    def forward(self, x, y, t, e_x, e_y):
        """Evaluate pointwise: all five inputs share one shape S; the output has shape S."""
        features = torch.stack((x, y, t, e_x, e_y), dim=-1)  # (*S, 5)
        return self.net(features).squeeze(-1)



class PINNForwardSolver2D:
    """
    Seismogram forward operator backed by a pointwise PINN.

    Wraps ``model(x, y, t, e_x, e_y) -> u`` so that it exposes the same interface
    as ``DPForwardSolver``. The model takes equal-length 1-D inputs and returns a
    1-D output.

        forward(e_x, e_y) -> traces (Nt, K)
        traces[n, k] = model(x_k, y_k, t_n, e_x, e_y),  with t_n = n * T / (Nt - 1)

    Gradients flow back to tensor ``e_x`` / ``e_y``, which the inverse solver
    relies on.

    Parameters
    ----------
    model : nn.Module (or any callable) evaluated pointwise.
    sensors : (K, 2) tensor of sensor (x, y) positions.
    Nt, T : number of time samples (>= 2) and final time.
    x_min, x_max, y_min, y_max : domain bounds reported by ``get_bounds``.
    device : defaults to the device of the model's first parameter.
    dtype : dtype of the query coordinates and of e_x / e_y.
    batch_size : maximum number of (t_n, sensor) points per model call (>= 1).
    """

    def __init__(
        self,
        model: nn.Module,
        sensors: torch.Tensor,   # (K,2)
        Nt: int,
        T: float,
        x_min=-1.0, x_max=1.0,
        y_min=-1.0, y_max=1.0,
        device=None,
        dtype=torch.float32,
        batch_size=4096,
    ):
        if int(Nt) < 2:
            raise ValueError(f"Nt must be >= 2, got {Nt}")
        if int(batch_size) < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self.model = model
        self.device = device if device is not None else next(model.parameters()).device
        self.dtype = dtype

        self.sensors = sensors.to(self.device, dtype=self.dtype)
        self.K = int(self.sensors.shape[0])

        self.Nt = int(Nt)
        self.T = float(T)
        self.dt = self.T / (self.Nt - 1)
        self.t_grid = torch.linspace(0.0, self.T, self.Nt, device=self.device, dtype=self.dtype)

        self.x_min, self.x_max = float(x_min), float(x_max)
        self.y_min, self.y_max = float(y_min), float(y_max)

        self.batch_size = int(batch_size)

        # All (t_n, sensor_k) query points, flattened time-major so that
        # reshaping the model output to (Nt, K) yields traces[n, k].
        self._x_query = self.sensors[:, 0].repeat(self.Nt)
        self._y_query = self.sensors[:, 1].repeat(self.Nt)
        self._t_query = self.t_grid.repeat_interleave(self.K)

    def get_bounds(self):
        """Return ((x_min, x_max), (y_min, y_max)) used to constrain the epicenter."""
        return (self.x_min, self.x_max), (self.y_min, self.y_max)

    def _as_scalar(self, v) -> torch.Tensor:
        """Python number -> 0-d tensor; tensors are moved/cast with a differentiable .to()."""
        if isinstance(v, torch.Tensor):
            return v.to(device=self.device, dtype=self.dtype)
        return torch.tensor(float(v), device=self.device, dtype=self.dtype)

    def forward(self, e_x, e_y):
        """Predict seismograms (Nt, K) for epicenter (e_x, e_y) (floats or tensors)."""
        # Keep autograd on e_x, e_y for inversion: .to() is differentiable and
        # nothing below detaches them.
        e_x = self._as_scalar(e_x)
        e_y = self._as_scalar(e_y)

        n_points = self._x_query.shape[0]
        if n_points == 0:
            return torch.empty((self.Nt, self.K), device=self.device, dtype=self.dtype)

        chunks = []
        for start in range(0, n_points, self.batch_size):
            stop = min(start + self.batch_size, n_points)
            xs = self._x_query[start:stop]
            chunks.append(
                self.model(
                    xs,
                    self._y_query[start:stop],
                    self._t_query[start:stop],
                    e_x.expand_as(xs),
                    e_y.expand_as(xs),
                )
            )
        return torch.cat(chunks).reshape(self.Nt, self.K)


#### Physics-informed loss

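We enforce the acoustic wave-equation residual $r = \partial_{tt}\hat u_\theta - c^2 \Delta \hat u_\theta - f$ at random collocation points, where $f$ is the Gaussian source term.
This term acts as a regularizer and is weighted weakly compared to the data loss.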

In [154]:
def _grad(u, x):
    return torch.autograd.grad(
        outputs=u, 
        inputs=x,
        grad_outputs=torch.ones_like(u),
        create_graph=True, 
        retain_graph=True
    )[0]

def _laplacian(u, x, y):
    """Return u_xx + u_yy through nested autograd derivatives (graph kept for backprop)."""
    u_xx = _grad(_grad(u, x), x)
    u_yy = _grad(_grad(u, y), y)
    return u_xx + u_yy


def source_gauss(x, y, t, e_x, e_y, A=5, beta=200, t0=0.2, gamma=200):
    """
    Gaussian-in-time times Gaussian-in-space source term, evaluated pointwise:

        f(x, y, t) = A exp(-beta (t - t0)^2) * exp(-gamma ((x - e_x)^2 + (y - e_y)^2))

    Same functional form as ``DPForwardSolver._source``.
    """
    r2 = (x - e_x) ** 2 + (y - e_y) ** 2
    temporal = A * torch.exp(-beta * (t - t0) ** 2)
    return temporal * torch.exp(-gamma * r2)

def pde_residual(model, x, y, t, e_x, e_y, c, A, beta, t0, gamma):
    """
    Pointwise residual of the acoustic wave equation for the network output:

        r = u_tt - c^2 (u_xx + u_yy) - f(x, y, t; e_x, e_y)

    The collocation coordinates are detached and re-marked as requiring grad,
    so derivatives are taken with respect to them. Gradients still flow into
    the model parameters (and into e_x / e_y if those require grad).
    """
    xr, yr, tr = (v.detach().requires_grad_(True) for v in (x, y, t))

    u = model(xr, yr, tr, e_x, e_y)
    u_tt = _grad(_grad(u, tr), tr)
    lap_u = _laplacian(u, xr, yr)
    forcing = source_gauss(xr, yr, tr, e_x, e_y, A=A, beta=beta, t0=t0, gamma=gamma)
    return u_tt - (c**2) * lap_u - forcing

#### Training the PINN surrogate : 

At each epoch:
- New epicenters are sampled uniformly in the domain
- DP seismograms are generated for them
- The PINN is trained to match the DP traces at the sensor locations

The data loss dominates; the physics term acts as a regularizer.

In [158]:
# Training configuration for the PINN surrogate model.
# T and the physics/source fields (c, A, beta, t0, gamma) set the time grid and
# the physics-loss residual in train_pinn. They should agree with the DP solver
# that produces the training targets (DPForwardSolver defaults: c=1, T=1, A=5,
# beta=200, t0=0.2, gamma=200).
@dataclass
class TrainCfg:
    """Hyper-parameters for train_pinn: domain, physics, optimisation, sampling, loss weights."""

    # domain: epicenters and collocation points are sampled uniformly inside it
    x_min: float = -1.0
    x_max: float =  1.0
    y_min: float = -1.0
    y_max: float =  1.0
    T: float = 1.0  # final time of the seismograms / collocation time range

    # wave speed and Gaussian source parameters used in the PDE residual
    c: float = 1.0
    A: float = 5.0
    beta: float = 200.0
    t0: float = 0.2
    gamma: float = 200.0

    # optimization
    epochs: int = 8000
    lr: float = 1e-3  # Adam learning rate

    # sampling
    B: int = 16          # DP simulations per epoch
    N_phys: int = 1024   # PDE points per epoch
    # NOTE(review): N_ic / N_bc are currently unused -- the IC/BC loss terms in
    # train_pinn are commented out.
    N_ic: int = 1024
    N_bc: int = 1024

    # loss weights
    w_data: float = 1
    w_phys: float = 1e-3
    # NOTE(review): w_ic / w_bc are currently unused (see N_ic / N_bc above).
    w_ic: float = 1e-3
    w_bc: float = 1e-3

    # normalization
    normalize_per_sensor: bool = True  # divide the data loss by each sensor's DP-trace std
    warmup_sims: int = 64              # DP simulations used to estimate that std

    # logging
    print_every: int = 200  # epochs between progress prints (epoch 1 is always printed)

@torch.no_grad()
def _sample_epicenters(B, x_min, x_max, y_min, y_max, device, dtype):
    ex = torch.empty(B, device=device, dtype=dtype).uniform_(x_min, x_max)
    ey = torch.empty(B, device=device, dtype=dtype).uniform_(y_min, y_max)
    return ex, ey


def _check_teacher_consistency(dp_forward_solver, sensors, Nt, cfg):
    """Raise ValueError if the DP teacher's time grid, sensors or physics disagree with the training setup."""
    solver_nt = getattr(dp_forward_solver, "Nt", None)
    if solver_nt is not None and int(solver_nt) != int(Nt):
        raise ValueError(f"Nt={Nt} does not match dp_forward_solver.Nt={solver_nt}")

    solver_T = getattr(dp_forward_solver, "T", None)
    if solver_T is not None and not math.isclose(float(solver_T), float(cfg.T), rel_tol=1e-6):
        raise ValueError(f"cfg.T={cfg.T} does not match dp_forward_solver.T={solver_T}")

    solver_sensors = getattr(dp_forward_solver, "sensors", None)
    if solver_sensors is not None:
        solver_sensors = solver_sensors.to(device=sensors.device, dtype=sensors.dtype)
        if solver_sensors.shape != sensors.shape or not torch.allclose(solver_sensors, sensors):
            raise ValueError("sensors do not match dp_forward_solver.sensors")

    # the physics loss uses cfg's parameters, the data targets use the solver's
    for name in ("c", "A", "beta", "t0", "gamma"):
        solver_val = getattr(dp_forward_solver, name, None)
        if solver_val is not None and not math.isclose(float(solver_val), float(getattr(cfg, name)), rel_tol=1e-6):
            raise ValueError(
                f"cfg.{name}={getattr(cfg, name)} does not match dp_forward_solver.{name}={float(solver_val)}"
            )


@torch.no_grad()
def _simulate_batch(dp_forward_solver, ex, ey):
    """Stack DP seismograms for each epicenter (ex[i], ey[i]) into a (B, Nt, K) tensor."""
    return torch.stack([dp_forward_solver.forward(a, b) for a, b in zip(ex, ey)], dim=0)


def train_pinn(
    model: nn.Module,
    dp_forward_solver,
    sensors: torch.Tensor,
    Nt: int,
    cfg: TrainCfg,
    device,
    dtype=torch.float32,
):
    """
    Fit a PINN surrogate to finite-difference (DP) seismograms.

    Each epoch draws cfg.B epicenters uniformly in the domain and simulates
    their seismograms with ``dp_forward_solver`` (without gradients). It then
    takes one Adam step on

        cfg.w_data * MSE(pred / scale, target / scale) + cfg.w_phys * mean(R^2),

    where R is the wave-equation residual at cfg.N_phys random collocation
    points, each paired with one of the epoch's epicenters. IC/BC terms are not
    implemented, so cfg.N_ic, cfg.N_bc, cfg.w_ic and cfg.w_bc are unused.

    Parameters
    ----------
    model : network called as model(x, y, t, e_x, e_y) on 1-D tensors.
    dp_forward_solver : teacher exposing forward(e_x, e_y) -> (Nt, K).
    sensors : (K, 2) sensor positions (must match the teacher's sensors).
    Nt : number of time samples per trace (must match the teacher).
    cfg : TrainCfg with domain, physics, optimisation and logging settings.
    device, dtype : where the model and all training tensors live.

    Returns
    -------
    (model, scale)
        ``model`` is the trained model in eval mode. ``scale`` is the (K,)
        per-sensor std used to normalise the data loss, or None when
        cfg.normalize_per_sensor is False.

    Raises
    ------
    ValueError
        If the teacher's Nt, T, sensors or physics parameters (only those it
        exposes as attributes) disagree with the arguments / cfg.
    """
    model = model.to(device=device, dtype=dtype).train()
    sensors = sensors.to(device=device, dtype=dtype)
    K = sensors.shape[0]
    _check_teacher_consistency(dp_forward_solver, sensors, Nt, cfg)

    # time grid of the targets: t_n = n * T / (Nt - 1), as in the DP solver
    t_grid = torch.linspace(0, cfg.T, Nt, device=device, dtype=dtype)

    # Per-sensor scale: std of DP traces over a warm-up batch, so every sensor
    # contributes comparably to the data loss.
    scale = None
    if cfg.normalize_per_sensor:
        W = int(cfg.warmup_sims)
        ex_w, ey_w = _sample_epicenters(W, cfg.x_min, cfg.x_max, cfg.y_min, cfg.y_max, device, dtype)
        Yw = _simulate_batch(dp_forward_solver, ex_w, ey_w)  # (W, Nt, K)
        scale = Yw.reshape(-1, K).std(dim=0, unbiased=False).clamp_min(1e-6)  # (K,)

    c_t = torch.as_tensor(cfg.c, device=device, dtype=dtype)

    opt = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    mse = nn.MSELoss()

    # Query coordinates of one seismogram, flattened time-major (Nt*K,), then
    # tiled cfg.B times so the whole batch goes through the network in one call.
    n_pts = Nt * K
    x_q = sensors[:, 0].repeat(Nt * cfg.B)
    y_q = sensors[:, 1].repeat(Nt * cfg.B)
    t_q = t_grid.repeat_interleave(K).repeat(cfg.B)

    for ep in range(1, cfg.epochs + 1):
        opt.zero_grad(set_to_none=True)

        # fresh epicenters and their DP targets (B, Nt, K)
        ex_b, ey_b = _sample_epicenters(cfg.B, cfg.x_min, cfg.x_max, cfg.y_min, cfg.y_max, device, dtype)
        Yb = _simulate_batch(dp_forward_solver, ex_b, ey_b)

        # predictions (B, Nt, K) from a single batched forward pass
        preds = model(
            x_q, y_q, t_q,
            ex_b.repeat_interleave(n_pts), ey_b.repeat_interleave(n_pts),
        ).view(cfg.B, Nt, K)

        if scale is not None:
            preds_n = preds / scale.view(1, 1, K)
            Yb_n = Yb / scale.view(1, 1, K)
        else:
            preds_n, Yb_n = preds, Yb

        L_data = mse(preds_n, Yb_n)

        # physics loss at random collocation points, each tied to one of this epoch's epicenters
        x_p = torch.empty(cfg.N_phys, device=device, dtype=dtype).uniform_(cfg.x_min, cfg.x_max)
        y_p = torch.empty(cfg.N_phys, device=device, dtype=dtype).uniform_(cfg.y_min, cfg.y_max)
        t_p = torch.empty(cfg.N_phys, device=device, dtype=dtype).uniform_(0.0, cfg.T)

        b2 = torch.randint(0, cfg.B, (cfg.N_phys,), device=device)
        ex_p = ex_b[b2]
        ey_p = ey_b[b2]

        R = pde_residual(model, x_p, y_p, t_p, ex_p, ey_p, c_t, cfg.A, cfg.beta, cfg.t0, cfg.gamma)
        L_phys = torch.mean(R**2)

        loss = cfg.w_data * L_data + cfg.w_phys * L_phys
        loss.backward()
        opt.step()

        if (ep % cfg.print_every == 0) or (ep == 1):
            print(f"ep {ep:04d}  total={loss.item():.3e}  data={L_data.item():.3e}  phys={L_phys.item():.3e}")

    model.eval()
    return model, scale

#### Helper functions : 
Save and load the trained PINN.

In [None]:
# Save and load the trained model (ChatGPT was used to generate this code)
def save_pinn(path, model, scale, cfg: "TrainCfg"):
    """
    Serialise a trained PINN surrogate to ``path`` with ``torch.save``.

    The checkpoint holds:
      - "state": model.state_dict()
      - "scale": the per-sensor scale on CPU (or None)
      - "cfg":   the training config as a plain dict (dataclasses.asdict)
      - "arch":  {"width", "depth", "act"}, enough to rebuild the network

    width/depth are read from the PINN layout of ``model.net``:
    [Linear, act] followed by (depth - 1) x [Linear, act] and a final Linear.
    "act" is the activation class name, so non-default activations survive a
    round trip.
    """
    from dataclasses import asdict

    net = model.net
    arch = {
        "width": net[0].out_features,
        # len(net) == 2 * depth + 1 for the PINN layout
        "depth": (len(net) - 1) // 2,
        "act": type(net[1]).__name__,
    }
    torch.save(
        {
            "state": model.state_dict(),
            "scale": None if scale is None else scale.detach().cpu(),
            "cfg": asdict(cfg),
            "arch": arch,
        },
        path,
    )

def load_pinn(path, device, dtype=torch.float32):
    """
    Rebuild a PINN surrogate from a checkpoint written by ``save_pinn``.

    Returns (model, scale, cfg):
      - model: the PINN in eval mode on ``device`` / ``dtype``
      - scale: the per-sensor scale on the same device/dtype, or None
      - cfg:   the TrainCfg used for training

    Config keys that TrainCfg does not define are ignored, and missing keys take
    TrainCfg defaults, so checkpoints from other config versions still load.
    The activation recorded under arch["act"] is used when present; older
    checkpoints without it get nn.Tanh (PINN's default).

    Raises ValueError if the recorded activation is not a ``torch.nn`` module class.
    """
    from dataclasses import fields

    # NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
    ckpt = torch.load(path, map_location="cpu")

    known = {f.name for f in fields(TrainCfg)}
    cfg = TrainCfg(**{k: v for k, v in ckpt["cfg"].items() if k in known})

    arch = ckpt["arch"]
    act_name = arch.get("act", "Tanh")
    act = getattr(nn, act_name, None)
    if not (isinstance(act, type) and issubclass(act, nn.Module)):
        raise ValueError(f"unknown activation {act_name!r} in checkpoint {path!r}")

    model = PINN(width=arch["width"], depth=arch["depth"], act=act).to(device=device, dtype=dtype)
    model.load_state_dict(ckpt["state"])
    model.eval()

    scale = ckpt["scale"]
    if scale is not None:
        scale = scale.to(device=device, dtype=dtype)
    return model, scale, cfg


### ML pipeline :

In [153]:
#TODO:

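## Inverse problem solver

Given observed seismograms $\mathcal S_{\text{obs}}$, the solver estimates the epicenter $(e_x, e_y)$ by minimizing the mismatch between the observed and predicted seismograms with a gradient-based optimizer (L-BFGS):

$$
(e_x^*, e_y^*) \;=\; \arg\min_{(e_x,e_y)} \; \sum_{n \ge n^*} \sum_{k=1}^{K} \left( \frac{\mathcal S_\theta(e_x,e_y)_{n,k} - \mathcal S_{\text{obs},\,n,k}}{\sigma_k} \right)^2 \;+\; \lambda \left(e_x^2 + e_y^2\right),
$$

In this expression:
- $n^* = \operatorname{round}(t^*/\Delta t)$ discards the early samples.
- $\sigma_k$ is the standard deviation of the observed trace at sensor $k$ over $n \ge n^*$.
- $\lambda$ is a small regularization weight.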

In [None]:
def inverse_solver(
    forward,               # object with: forward(e_x, e_y) -> traces_pred (Nt,K) and get_bounds()
    traces_obs,            # (Nt,K) torch tensor
    dt,                    # float, time step of traces_obs
    t_star=0,              # ignore data before t_star
    init=(0, 0),           # initial guess (e_x0, e_y0)
    steps=25,              # outer L-BFGS steps
    lr=1,
    lam=1e-6,
    verbose=True,          # print the iteration counter whenever history is recorded
):
    """
    Estimate the epicenter (e_x, e_y) from observed seismograms.

    Minimises, with L-BFGS (strong-Wolfe line search, 20 inner iterations per
    outer step),

        sum_{n >= n_star} sum_k ((pred[n, k] - obs[n, k]) / sigma_k)^2 + lam * (e_x^2 + e_y^2)

    The terms are:
      - pred = forward.forward(e_x, e_y);
      - n_star = round(t_star / dt), clipped to [0, Nt - 1];
      - sigma_k = the std of observed trace k over n >= n_star (floored at 1e-6).

    After every outer step the estimate is projected in place back inside
    forward.get_bounds(), shrunk by a 1e-3 margin.

    The epicenter parameters live on ``forward.device`` / ``forward.dtype`` when
    the forward object exposes them, otherwise on traces_obs' device/dtype.

    Returns:
      e_hat: (2,) tensor [e_x_hat, e_y_hat]
      traces_pred_final: (Nt,K) tensor predicted at e_hat
      history: list of dicts recorded every 5 steps and at the last step
      n_star: int

    Raises:
      ValueError: if dt is not positive.
    """
    if dt <= 0:
        raise ValueError(f"dt must be positive, got {dt}")

    dev = getattr(forward, "device", traces_obs.device)
    dtp = getattr(forward, "dtype", traces_obs.dtype)
    traces_obs = traces_obs.to(device=dev, dtype=dtp)

    Nt = traces_obs.shape[0]
    n_star = min(max(int(round(t_star / dt)), 0), Nt - 1)

    # trainable epicenter
    e_x = torch.nn.Parameter(torch.tensor(float(init[0]), device=dev, dtype=dtp))
    e_y = torch.nn.Parameter(torch.tensor(float(init[1]), device=dev, dtype=dtp))

    (x_min, x_max), (y_min, y_max) = forward.get_bounds()
    margin = 1e-3

    def project():
        # clamp_ works in place; a plain .clamp() returns a new tensor and
        # would leave the parameters unconstrained.
        with torch.no_grad():
            e_x.clamp_(x_min + margin, x_max - margin)
            e_y.clamp_(y_min + margin, y_max - margin)

    project()

    obs = traces_obs[n_star:]
    # fixed scaling based on observed data (stabilizes the 2-sensor case)
    scale = obs.std(dim=0, unbiased=False).clamp_min(1e-6)
    obs_n = obs / scale

    opt = torch.optim.LBFGS([e_x, e_y], lr=lr, max_iter=20, line_search_fn="strong_wolfe")
    history = []

    def loss_and_pred():
        traces_pred = forward.forward(e_x, e_y)  # (Nt,K)
        loss_data = torch.sum((traces_pred[n_star:] / scale - obs_n) ** 2)
        loss_reg = lam * (e_x**2 + e_y**2)
        return loss_data + loss_reg, loss_data, traces_pred

    def closure():
        opt.zero_grad()
        loss, _, _ = loss_and_pred()
        loss.backward()
        return loss

    steps = int(steps)
    for it in range(steps):
        opt.step(closure)
        project()

        if it % 5 == 0 or it == steps - 1:
            if verbose:
                print(f"{it=}")
            with torch.no_grad():
                loss, loss_data, _ = loss_and_pred()
            history.append({
                "it": it,
                "loss": float(loss.item()),
                "loss_data": float(loss_data.item()),
                "e_x": float(e_x.item()),
                "e_y": float(e_y.item()),
            })

    with torch.no_grad():
        _, _, traces_pred_final = loss_and_pred()

    e_hat = torch.stack([e_x.detach(), e_y.detach()])
    return e_hat, traces_pred_final.detach(), history, n_star

## Comparisons : 

Here, the different models and configurations are tested and compared to each other.

### DP :

Test using DP for the forward operation.

In [40]:
# Two receivers, in physical (x, y) coordinates.
sensors = torch.tensor(
    [[-0.9, -0.2], [0.6, 0.8]],
    dtype=torch.float32,
    device=device,
)

# Ground-truth epicenter used to synthesise the "observed" data.
e_true = (0.25, -0.10)

In [None]:
# DP solver used both to synthesise the data and as the inversion's forward model.
forward_model = DPForwardSolver(sensors=sensors)

# Synthetic observations: DP seismograms at the true epicenter.
with torch.no_grad():
    traces_obs = forward_model.forward(*e_true)  # (Nt, K)

e_hat, traces_pred, history, n_star = inverse_solver(
    forward=forward_model,
    traces_obs=traces_obs,
    dt=forward_model.dt,
    init=(0.0, 0.0),
    t_star=0.35,
    steps=50,
    lr=1.0,
    lam=1e-6,
)

print("true:", e_true)
print("hat :", tuple(float(v) for v in e_hat))

it=0
it=50
it=100
it=150
it=200
it=250
it=300
it=350
it=400
it=450
it=499
true: (0.25, -0.1)
hat : (0.25000008940696716, -0.1000000610947609)


### PINN :

Test using the PINN for the forward operation.

In [111]:
# Receiver positions (the same two sensors as in the DP experiment).
sensors = torch.tensor(
    [[-0.9, -0.2], [0.6, 0.8]],
    dtype=torch.float32,
    device=device,
)

# Finite-difference "teacher" that generates the PINN's training targets.
dp_forward_solver = DPForwardSolver(sensors=sensors)
Nt, h = dp_forward_solver.Nt, dp_forward_solver.h


In [None]:
# Train the PINN surrogate on DP seismograms, then checkpoint it.
cfg = TrainCfg(
    T=1,
    # physics / source parameters (same values as the DP teacher's defaults)
    c=1,
    A=5,
    beta=200,
    t0=0.2,
    gamma=200,
    # optimisation
    epochs=4000,
    lr=1e-3,
    print_every=250,
    # sampling
    B=16,
    N_phys=512,
    N_ic=256,
    N_bc=256,
    # loss weights
    w_data=1,
    w_phys=1e-4,
    w_ic=1e-6,
    w_bc=1e-6,
)

model, scale = train_pinn(
    model=PINN(width=128, depth=8).to(device),
    dp_forward_solver=dp_forward_solver,
    sensors=sensors,
    Nt=Nt,
    cfg=cfg,
    device=device,
)

save_pinn("pinn_surrogate_v4.pt", model, scale, cfg)

ep 00001  total=2.031e+01  data=2.031e+01  phys=3.143e-03
ep 00250  total=1.047e+00  data=1.047e+00  phys=1.956e-03
ep 00500  total=5.427e-01  data=5.427e-01  phys=1.183e-06
ep 00750  total=9.477e-01  data=9.477e-01  phys=4.544e-05
ep 01000  total=1.269e+00  data=1.269e+00  phys=1.115e-04
ep 01250  total=9.833e-01  data=9.833e-01  phys=4.368e-03
ep 01500  total=7.922e-01  data=7.922e-01  phys=5.323e-03
ep 01750  total=4.401e-01  data=4.401e-01  phys=2.668e-06
ep 02000  total=6.927e-01  data=6.927e-01  phys=9.494e-05
ep 02250  total=1.042e+00  data=1.042e+00  phys=4.998e-06
ep 02500  total=5.605e-01  data=5.605e-01  phys=4.769e-06
ep 02750  total=1.241e+00  data=1.241e+00  phys=3.949e-05
ep 03000  total=4.764e-01  data=4.764e-01  phys=8.243e-03
ep 03250  total=1.277e+00  data=1.277e+00  phys=1.816e-03
ep 03500  total=5.503e-01  data=5.503e-01  phys=2.507e-04
ep 03750  total=1.480e+00  data=1.480e+00  phys=6.953e-03
ep 04000  total=9.132e-01  data=9.132e-01  phys=1.017e-05


In [150]:
# Expose the trained PINN through the same forward/get_bounds interface as the DP solver.
forward_pinn = PINNForwardSolver2D(
    model=model,
    sensors=sensors,
    Nt=Nt,
    T=cfg.T,
    x_min=cfg.x_min,
    x_max=cfg.x_max,
    y_min=cfg.y_min,
    y_max=cfg.y_max,
    device=device,
)

# "Observed" traces come from the DP solver at the true epicenter.
e_true = (0.25, -0.10)
with torch.no_grad():
    traces_obs = dp_forward_solver.forward(*e_true)

# Invert using the PINN surrogate as the forward operator.
e_hat, traces_pred, history, n_star = inverse_solver(
    forward=forward_pinn,
    traces_obs=traces_obs,
    dt=forward_pinn.dt,
    init=(0.0, 0.0),
    t_star=0,
    steps=50,
    lr=0.2,
    lam=1e-6,
)

print("true:", e_true)
print("hat :", tuple(float(v) for v in e_hat))

it=0
it=5
it=10
it=15
it=20
it=25
it=30
it=35
it=40
it=45
it=49
true: (0.25, -0.1)
hat : (0.2694408893585205, -0.26058971881866455)


### NN :

Test using a neural network for the forward operation.

In [None]:
# TODO