# Physics-Informed Neural Network (Wave Equation)

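This notebook adapts `examples/integration/pinn/train.py`. We train a lightweight physics-informed
neural network (PINN) to satisfy the second-order ODE $u_{tt} = c^2 u$ with initial conditions $u(0)=0$, $u_t(0)=1$,
whose exact solution is $u(t) = \sinh(ct)/c$. Despite the title, this is a time-only toy problem rather than the
full wave equation $u_{tt} = c^2 u_{xx}$: the network's spatial input $x$ is held at zero throughout.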
The workflow is:

1. Generate a reference trajectory with `DiffsolModule`.
2. Optimise a neural network so that both the ODE residual and the data-fitting loss against the reference are small.
3. Compare the learned trajectory against diffsol's solution.


In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from diffsol_pytorch import DiffsolModule

# Shared notebook settings: compact figures, plus a fixed torch seed so the
# network's initial weights and the random collocation times are reproducible.
plt.rcParams.update({"figure.figsize": (6, 3)})
torch.manual_seed(0)


In [None]:
# DiffSL model with state (u, v): u' = v and v' = c^2 u, i.e. u'' = c^2 u,
# starting from u(0) = 0 and v(0) = u'(0) = 1. `c` is declared as the model input.
WAVE_CODE = '''
in = [c]
c { 1.0 }
u {
    u = 0.0,
    v = 1.0,
}
F {
    v,
    c * c * u,
}
'''

module = DiffsolModule(WAVE_CODE)
times = torch.linspace(0.0, 1.0, 81, dtype=torch.float64)
times_list = times.tolist()
params = [1.0]  # presumably the value bound to the input `c` — confirm against DiffsolModule docs
_, _, flat = module.solve_dense(params, times_list)

# NOTE(review): assumes `flat` interleaves the two states per time point,
# i.e. [u(t0), v(t0), u(t1), v(t1), ...]. The checks below fail loudly if that
# assumption (or the solve itself) is wrong, instead of training on garbage.
n_states = 2  # (u, v)
if len(flat) != n_states * len(times_list):
    raise ValueError(
        f"expected {n_states * len(times_list)} values from solve_dense, got {len(flat)}"
    )
reference_u = torch.tensor(flat[0::n_states], dtype=torch.float32)

# u'' = c^2 u with u(0) = 0, u'(0) = 1 has the exact solution u(t) = sinh(c t) / c
# (which degenerates to u(t) = t when c = 0).
c_ref = params[0]
if c_ref == 0:
    exact_u = times.to(torch.float32)
else:
    exact_u = (torch.sinh(c_ref * times) / c_ref).to(torch.float32)
if not torch.allclose(reference_u, exact_u, rtol=1e-3, atol=1e-3):
    raise RuntimeError(
        "diffsol reference disagrees with the exact solution sinh(c t) / c; "
        "check the layout of the values returned by solve_dense"
    )

In [None]:
class PINN(nn.Module):
    """Small tanh MLP mapping paired inputs ``(x, t)`` to a scalar ``u``.

    Architecture: Linear(2, 32) -> tanh -> Linear(32, 32) -> tanh -> Linear(32, 1).
    """

    def __init__(self):
        super().__init__()
        width = 32
        # Layers are created in the same order as a literal Sequential would,
        # so the seeded weight initialisation is unchanged.
        layers = [
            nn.Linear(2, width),
            nn.Tanh(),
            nn.Linear(width, width),
            nn.Tanh(),
            nn.Linear(width, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x, t):
        """Evaluate the network element-wise on same-shaped tensors ``x`` and ``t``.

        The two inputs are paired along a new trailing axis, and the size-1
        output axis is dropped, so the result has the shape of ``x`` and ``t``.
        """
        pairs = torch.cat((x.unsqueeze(-1), t.unsqueeze(-1)), dim=-1)
        return self.net(pairs).squeeze(-1)


In [None]:
def residual_loss(model: "PINN", collocation: torch.Tensor, c: float = 1.0) -> torch.Tensor:
    """Return the mean squared ODE residual ``u_tt - c**2 * u`` at the collocation times.

    Args:
        model: network evaluated as ``model(x, t)``; the spatial input ``x`` is
            held at zero.
        collocation: times at which the residual is enforced. The tensor is
            cloned, detached and cast to float32, so the caller's tensor is
            neither modified nor tracked by autograd.
        c: coefficient in ``u_tt = c**2 * u``.

    Returns:
        Scalar tensor that remains connected to the model's parameters (the
        derivatives are built with ``create_graph=True``), so it can be
        back-propagated.
    """
    t = collocation.clone().detach().to(dtype=torch.float32).requires_grad_(True)
    # x is a constant input: it needs no gradient, and marking it as a leaf
    # requiring grad would only make loss.backward() accumulate a discarded grad.
    x = torch.zeros_like(t)
    u = model(x, t)
    u_t = torch.autograd.grad(u, t, torch.ones_like(u), create_graph=True)[0]
    u_tt = torch.autograd.grad(u_t, t, torch.ones_like(u_t), create_graph=True)[0]
    residual = u_tt - (c ** 2) * u
    return residual.pow(2).mean()


def data_loss(model: "PINN", times: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
    """Return the mean squared error between model predictions and reference values.

    Args:
        model: network evaluated as ``model(x, t)`` with ``x`` held at zero.
        times: times at which ``reference`` was sampled; cast to float32.
        reference: target values, one per entry of ``times``; cast to the
            prediction dtype.

    Returns:
        Scalar MSE tensor, differentiable w.r.t. the model's parameters.

    Raises:
        ValueError: if the prediction and ``reference`` shapes differ. Without
            this check ``mse_loss`` would silently broadcast compatible shapes
            (it only emits a warning), producing a meaningless loss.
    """
    t = times.to(dtype=torch.float32)
    x = torch.zeros_like(t)
    pred = model(x, t)
    if pred.shape != reference.shape:
        raise ValueError(
            f"prediction shape {tuple(pred.shape)} does not match "
            f"reference shape {tuple(reference.shape)}"
        )
    return torch.nn.functional.mse_loss(pred, reference.to(dtype=pred.dtype))


In [None]:
# Train on the sum of the ODE residual (random collocation times drawn uniformly
# from [0, 1) each step) and the data misfit against the diffsol reference.
n_steps = 200
# Use the same c as the reference solve so the residual and data losses agree.
c_value = params[0]
model = PINN()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_history = []
for _ in range(n_steps):
    optimizer.zero_grad()
    res = residual_loss(model, torch.rand_like(times), c=c_value)
    data = data_loss(model, times, reference_u)
    loss = res + data
    loss.backward()
    optimizer.step()
    loss_history.append(loss.item())
loss_history[:5], len(loss_history)


In [None]:
# Total (residual + data) loss recorded at each optimiser step.
plt.plot(loss_history)
plt.xlabel("Iteration")
plt.ylabel("Total loss")
# Derive the step count from the recorded history so the title cannot drift
# from the number of iterations the training loop actually ran.
plt.title(f"PINN training ({len(loss_history)} steps)")
plt.show()


In [None]:
# Evaluate the trained network on the reference time grid (x held at zero)
# and overlay its prediction on the diffsol trajectory.
t_grid = times.to(dtype=torch.float32)
with torch.no_grad():
    x = torch.zeros_like(t_grid)
    pinn_pred = model(x, t_grid).cpu()

t_axis = times.numpy()
plt.plot(t_axis, reference_u.numpy(), label="diffsol reference")
plt.plot(t_axis, pinn_pred.numpy(), linestyle="--", label="PINN")
plt.xlabel("t")
plt.ylabel("u(t)")
plt.legend()
plt.show()
