In [None]:
import torch
import torch.nn as nn
import torch.autograd as autograd

In [None]:
# -----------------
# PINN definition
# -----------------
# ## Class: PINN
# Defines a fully connected neural network used as the physics-informed neural network (PINN).
# - Input: (x, t)
# - Output: u(x,t), the predicted RF field amplitude
# - Architecture: feedforward MLP with tanh activations
class PINN(nn.Module):
    """Fully connected feedforward network with tanh hidden activations.

    Every linear layer except the last is followed by ``tanh``; the last
    linear layer is returned without an activation.

    Args:
        layers: Sequence of layer widths, e.g. ``[2, 50, 50, 50, 1]``.
            Consecutive entries give the input and output width of one
            linear layer, so at least two entries are required.

    Raises:
        ValueError: If ``layers`` has fewer than two entries (no linear
            layer could be built, and ``forward`` would have nothing to apply).
    """

    def __init__(self, layers):
        super().__init__()
        if len(layers) < 2:
            raise ValueError(
                "layers must contain at least an input and an output width, "
                f"got {len(layers)} entr{'y' if len(layers) == 1 else 'ies'}"
            )
        # Pair each width with its successor: (in_features, out_features).
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(layers[:-1], layers[1:])]
        )
        self.activation = nn.Tanh()

    def forward(self, x):
        """Apply the hidden layers with tanh, then the final linear layer."""
        for layer in self.layers[:-1]:
            x = self.activation(layer(x))
        # No activation on the output layer.
        return self.layers[-1](x)

In [None]:
# -----------------
# PDE residual
# -----------------
# ## Function: wave_residual
# Computes the residual of the 1D wave equation:
#   u_tt - c^2 u_xx = 0
# where:
#   - u = model(xt), where xt is x and t concatenated column-wise into one input tensor
#   - u_tt = second derivative wrt time
#   - u_xx = second derivative wrt space
#
# Inputs:
#   model: PINN instance
#   x: spatial coordinate tensor
#   t: temporal coordinate tensor
#   c: wave propagation speed
# Output:
#   residual tensor
def wave_residual(model, x, t, c=1.0):
    """Evaluate the 1D wave-equation residual ``u_tt - c**2 * u_xx`` of ``model``.

    Args:
        model: Callable (e.g. a PINN instance) that receives ``x`` and ``t``
            concatenated along dim 1 and returns the predicted field ``u``.
        x: Spatial coordinates, one column (shape ``(N, 1)``).
        t: Temporal coordinates, one column (shape ``(N, 1)``).
        c: Wave propagation speed.

    Returns:
        Residual tensor with one column and one row per input point. It is
        built with ``create_graph=True``, so a loss formed from it can be
        backpropagated through the derivatives.
    """
    # Column 0 of the concatenated input is x, column 1 is t.
    xt = torch.cat([x, t], dim=1).requires_grad_(True)
    u = model(xt)

    # A single backward pass yields both first derivatives at once.
    # NOTE(review): using ones as grad_outputs gives per-point derivatives
    # only if each output row depends on its own input row alone.
    grad_u = autograd.grad(u, xt, torch.ones_like(u), create_graph=True)[0]
    u_x = grad_u[:, 0:1]
    u_t = grad_u[:, 1:2]

    # Second derivatives; create_graph keeps them differentiable for the loss.
    u_tt = autograd.grad(u_t, xt, torch.ones_like(u_t), create_graph=True)[0][:, 1:2]
    u_xx = autograd.grad(u_x, xt, torch.ones_like(u_x), create_graph=True)[0][:, 0:1]

    # PDE residual: u_tt - c^2 u_xx = 0
    return u_tt - (c**2) * u_xx


In [None]:
# -----------------
# Training example
# -----------------
# ## Main block
# Example training script:
# - Defines a PINN with architecture [2, 50, 50, 50, 1]
# - Creates synthetic RF signal data (sine wave)
# - Samples collocation points for enforcing PDE residual
# - Trains with Adam optimiser using combined data + physics loss

if __name__ == "__main__":
    # Seed the global RNG so weight initialisation and the sampled training /
    # collocation points are identical on every Restart & Run All.
    SEED = 0
    torch.manual_seed(SEED)

    # Tunable settings gathered in one place.
    WAVE_SPEED = 1.0   # c, shared by the synthetic data and the PDE residual
    N_EPOCHS = 2000
    LOG_EVERY = 200

    # Network: input (x,t) → output u(x,t)
    model = PINN([2, 50, 50, 50, 1])

    # Optimiser
    optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Example training data (synthetic RF samples)
    N_data = 100
    x_data = torch.rand((N_data, 1)) * 2 - 1  # domain [-1, 1)
    t_data = torch.rand((N_data, 1))          # domain [0, 1)
    # Travelling sine wave; it satisfies u_tt = c^2 u_xx for the same c that
    # is used in the physics loss below.
    u_data = torch.sin(2 * torch.pi * (x_data - WAVE_SPEED * t_data))

    xt_data = torch.cat([x_data, t_data], dim=1)

    # Collocation points for PDE residual (sampled once, reused every epoch)
    N_f = 1000
    x_f = torch.rand((N_f, 1)) * 2 - 1
    t_f = torch.rand((N_f, 1))

    # Training loop
    for epoch in range(N_EPOCHS):
        optimiser.zero_grad()

        # Data loss: mean squared error against the synthetic samples
        u_pred = model(xt_data)
        loss_data = torch.mean((u_pred - u_data)**2)

        # Physics loss: mean squared PDE residual at the collocation points
        f_pred = wave_residual(model, x_f, t_f, c=WAVE_SPEED)
        loss_phys = torch.mean(f_pred**2)

        # Total loss (equal weighting of data and physics terms)
        loss = loss_data + loss_phys
        loss.backward()
        optimiser.step()

        if epoch % LOG_EVERY == 0:
            print(f"Epoch {epoch}: Loss={loss.item():.6f} Data={loss_data.item():.6f} Phys={loss_phys.item():.6f}")