1d spatial pde

u_t = alpha *u_xx

Recover alpha + predict u_t

In [30]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [31]:
L = 1.0
alpha_true = 0.3
Q = 2.0
a = 0
b = 0
N_modes = 100

xs = np.linspace(0, L, 200)
ts = np.linspace(0, 0.5, 100)

def u_s(x, alpha):
    C2 = a
    C1 = (b-a + (Q/(2*alpha))*L**2)/L
    return -Q/(2*alpha)*x**2 + C1*x + C2

u0 = u_s(xs, alpha_true) + 0.5*np.exp(-((xs-0.4)**2)/(2*(0.05**2)))

b_n = np.zeros(N_modes)
for n in range(1, N_modes+1):
    phi = np.sin(n*np.pi*xs/L)
    integrand = (u0 - u_s(xs, alpha_true)) * phi
    b_n[n-1] = 2.0/L * np.trapezoid(integrand, xs)

def u_xt(x, t, alpha):
    s = u_s(x, alpha)
    for n in range(1, N_modes+1):
        s += b_n[n-1]*np.sin(n*np.pi*x/L)*np.exp(-alpha*(n*np.pi/L)**2 * t)
    return s

U = np.zeros((len(xs), len(ts)))
for i, x in enumerate(xs):
    U[i, :] = u_xt(x, ts, alpha_true)

noise_level = 0.20
U_noisy = U + noise_level*np.random.randn(*U.shape)*np.abs(U)

X, T = np.meshgrid(xs, ts, indexing='ij')
x_data_tensor = torch.tensor(X.flatten(), dtype=torch.float32).view(-1,1)
t_data_tensor = torch.tensor(T.flatten(), dtype=torch.float32).view(-1,1)
U_data_tensor = torch.tensor(U_noisy.flatten(), dtype=torch.float32).view(-1,1)
U_exact_tensor = torch.tensor(U.flatten(), dtype=torch.float32).view(-1, 1)

x_ic_tensor = torch.tensor(xs, dtype=torch.float32).view(-1,1)
t_ic_tensor = torch.zeros_like(x_ic_tensor)
u_ic_tensor = torch.tensor(u0, dtype=torch.float32).view(-1,1)

t_bc_tensor = torch.tensor(ts, dtype=torch.float32).view(-1,1)

In [32]:
class PINN(nn.Module):
    def __init__(self, n_hidden=20, n_layers=2):
        super().__init__()
        layers = [nn.Linear(2, n_hidden), nn.Tanh()]
        for _ in range(n_layers - 1):
            layers += [nn.Linear(n_hidden, n_hidden), nn.Tanh()]
        layers += [nn.Linear(n_hidden, 1)]
        self.net = nn.Sequential(*layers)

        self.alpha = nn.Parameter(torch.rand(1))

    def forward(self, x, t):
        XT = torch.cat((x, t), dim=1)
        return self.net(XT)

model = PINN()

In [33]:
def derivative(y, x):
    return torch.autograd.grad(
        y, x,
        grad_outputs=torch.ones_like(y),
        create_graph=True
    )[0]

In [36]:
def pde_loss(model, x, t):
    t.requires_grad_(True)
    x.requires_grad_(True)
    u_pred = model(x, t)
    du_dt_pred = derivative(u_pred, t)
    du_dx_pred = derivative(u_pred, x)
    du_dxx_pred = derivative(du_dx_pred, x)
    return torch.mean((du_dt_pred - model.alpha * du_dxx_pred)**2)

def data_loss(model, x, t, u_data):
    u_pred = model(x, t)
    return torch.mean((u_pred - u_data)**2)

def ic_loss(model, x_ic, t_ic, u_ic):
    u_pred = model(x_ic, t_ic)
    return torch.mean((u_pred - u_ic) ** 2)

def bc_loss(model, t_bc, L):
    x_left = torch.zeros_like(t_bc)
    u_pred_left = model(x_left, t_bc)
    loss_left = torch.mean((u_pred_left - a) ** 2)

    x_right = torch.full_like(t_bc, L)
    u_pred_right = model(x_right, t_bc)
    loss_right = torch.mean((u_pred_right - b) ** 2)

    return loss_left + loss_right

In [37]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

lambda_data = 1.0
lambda_pde  = 1.0
lambda_ic   = 1.0
lambda_bc   = 1.0

num_epochs = 2000
print_every = num_epochs // 10

model.train()
for epoch in range(num_epochs):
    optimizer.zero_grad()

    l_data = data_loss(model, x_data_tensor, t_data_tensor, U_data_tensor)
    l_pde  = pde_loss(model, x_data_tensor, t_data_tensor)
    l_ic   = ic_loss(model, x_ic_tensor, t_ic_tensor, u_ic_tensor)
    l_bc   = bc_loss(model, t_bc_tensor, L)

    loss = lambda_data * l_data + lambda_pde * l_pde + lambda_ic * l_ic + lambda_bc * l_bc
    loss.backward()
    optimizer.step()

    if (epoch+1) % print_every == 0:
        print(f"Epoch {epoch+1}/{num_epochs}, "
              f"Total Loss = {loss.item():.6f}, "
              f"Data Loss = {l_data.item():.6f}, "
              f"PDE Loss = {l_pde.item():.6f}, "
              f"IC Loss = {l_ic.item():.6f}, "
              f"Predicted Alpha = {model.alpha.item():.6f}")

Epoch 200/2000, Total Loss = 0.030925, Data Loss = 0.019341, PDE Loss = 0.002674, IC Loss = 0.008651, Predicted Alpha = 0.014166
Epoch 400/2000, Total Loss = 0.028332, Data Loss = 0.019085, PDE Loss = 0.001516, IC Loss = 0.007596, Predicted Alpha = 0.012791
Epoch 600/2000, Total Loss = 0.038254, Data Loss = 0.022193, PDE Loss = 0.001145, IC Loss = 0.011004, Predicted Alpha = 0.011999
Epoch 800/2000, Total Loss = 0.027131, Data Loss = 0.018993, PDE Loss = 0.000892, IC Loss = 0.007179, Predicted Alpha = 0.011131
Epoch 1000/2000, Total Loss = 0.026636, Data Loss = 0.019025, PDE Loss = 0.000787, IC Loss = 0.006767, Predicted Alpha = 0.009903
Epoch 1200/2000, Total Loss = 0.026049, Data Loss = 0.019205, PDE Loss = 0.000642, IC Loss = 0.006149, Predicted Alpha = 0.008262
Epoch 1400/2000, Total Loss = 0.025571, Data Loss = 0.019505, PDE Loss = 0.000562, IC Loss = 0.005464, Predicted Alpha = 0.006827
Epoch 1600/2000, Total Loss = 0.025311, Data Loss = 0.019726, PDE Loss = 0.000549, IC Loss = 0

In [38]:
print("\nTrue alpha:", alpha_true)
print("Learned alpha:", model.alpha.item())


True alpha: 0.3
Learned alpha: 0.005424311384558678
