Experiment one (2D poisson)

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.autograd as autograd
import math, time

torch.set_default_dtype(torch.float64)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------- Problem definition ----------
pi = math.pi

def u_exact_np(X):
    return np.sin(2*pi*X[:,0]) * np.sin(2*pi*X[:,1])

def f_forcing_np(X):
    return (2*pi)**2 * 2.0 * u_exact_np(X)

def u_exact_torch(x):
    return torch.sin(2*pi*x[:,0]) * torch.sin(2*pi*x[:,1])

def f_forcing_torch(x):
    return (2*pi)**2 * 2.0 * torch.sin(2*pi*x[:,0]) * torch.sin(2*pi*x[:,1])

# ---------- Sampling ----------
def sample_grid_in_box(n_per_dim=50):
    xs = np.linspace(-1.0, 1.0, n_per_dim, dtype=np.float64)
    X, Y = np.meshgrid(xs, xs)
    pts = np.stack([X.ravel(), Y.ravel()], axis=-1)
    return pts

def sample_boundary_points(n_side=50):
    s = np.linspace(-1.0, 1.0, n_side, dtype=np.float64)
    pts = []
    for y in s: pts.append([-1.0, y])
    for y in s: pts.append([1.0, y])
    for x in s: pts.append([x, -1.0])
    for x in s: pts.append([x,  1.0])
    return np.array(pts)

# ---------- TransNet helpers ----------
def sigma(s):
    return np.tanh(s)
def sigma_dd(s):
    t = np.tanh(s)
    sech2 = 1 - t*t
    return -2.0 * t * sech2

def sample_a_r(M, d=2, R=1.5, seed=1234):
    rng = np.random.RandomState(seed)
    A = rng.randn(M, d)
    norms = np.linalg.norm(A, axis=1, keepdims=True)
    A /= norms
    r = rng.rand(M) * R
    return A, r

def build_matrices(X_int, X_bd, A, rvec, gamma):
    Xc_int = X_int
    Xc_bd = X_bd
    s_int = Xc_int.dot(A.T) + rvec.reshape(1,-1)
    s_bd = Xc_bd.dot(A.T) + rvec.reshape(1,-1)
    S_int = gamma * s_int
    S_bd = gamma * s_bd
    Psi_int = sigma(S_int)
    Psi_bd = sigma(S_bd)
    Psi_dd_int = (gamma**2) * sigma_dd(S_int)
    F_int = -Psi_dd_int
    return F_int, Psi_bd, Psi_int, s_int, s_bd

def solve_alpha_ls(F_int, Psi_bd, f_int, g_bd, lambda_L=1.0, lambda_B=1.0, reg=1e-8):
    wL = math.sqrt(lambda_L)
    wB = math.sqrt(lambda_B)
    A_big = np.vstack([wL*F_int, wB*Psi_bd])
    rhs = np.concatenate([wL*f_int, wB*g_bd])
    AtA = A_big.T @ A_big + reg*np.eye(A_big.shape[1])
    Atb = A_big.T @ rhs
    alpha = np.linalg.solve(AtA, Atb)
    return alpha

def u_base_numpy(X, A, rvec, gamma, alpha):
    s = X.dot(A.T) + rvec.reshape(1,-1)
    S = gamma*s
    Psi = sigma(S)
    return Psi.dot(alpha)

# ========== Step 1: TransNet Baseline ==========
M = 300
grid_n = 50
bd_side = 50
X_int = sample_grid_in_box(grid_n)
X_bd = sample_boundary_points(bd_side)
f_int = f_forcing_np(X_int)
g_bd = u_exact_np(X_bd)

A, rvec = sample_a_r(M)

# Golden search for gamma
def eta_of_gamma(gamma):
    F_int, Psi_bd, _, _, _ = build_matrices(X_int, X_bd, A, rvec, gamma)
    alpha = solve_alpha_ls(F_int, Psi_bd, f_int, g_bd)
    resid_int = F_int @ alpha - f_int
    resid_bd = Psi_bd @ alpha - g_bd
    mse_int = np.mean(resid_int**2)
    mse_bd = np.mean(resid_bd**2)
    return mse_int + mse_bd, mse_int, mse_bd       #这里后面两个mse是在每给定一组alpha后，计算的pde残差和边界残差

def golden_search(func, a, b, tol=1e-3, max_iters=50):
    phi = (1 + 5**0.5)/2
    invphi = 1/phi
    c = b - invphi*(b-a)
    d = a + invphi*(b-a)
    fc = func(c)
    fd = func(d)
    for _ in range(max_iters):
        if (b - a) < tol: break
        if fc[0] < fd[0]:
            b = d
            d = c
            fd = fc
            c = b - invphi*(b - a)
            fc = func(c)
        else:
            a = c
            c = d
            fc = fd
            d = a + invphi*(b - a)
            fd = func(d)
    return (c, fc) if fc[0] < fd[0] else (d, fd)

gamma_opt, best = golden_search(eta_of_gamma, 1e-2, 10.0)
alpha_opt = solve_alpha_ls(*build_matrices(X_int, X_bd, A, rvec, gamma_opt)[:2], f_int, g_bd)

print(f"Baseline gamma_opt={gamma_opt:.6f}, MSE_int={best[1]:.3e}, MSE_bd={best[2]:.3e}")

# Evaluate baseline error
X_test = sample_grid_in_box(100)
u_pred_base = u_base_numpy(X_test, A, rvec, gamma_opt, alpha_opt)
mse_base = np.mean((u_pred_base - u_exact_np(X_test))**2)
print(f"[Baseline TransNet] Interior MSE vs exact: {mse_base:.6e}")

# ========== Step 2: Residual Hybrid Skip (PyTorch) ==========
# 构造 u_base(x) 的 Torch 版本，用于计算 Δu_base 和训练残差方程
X_int_torch = torch.tensor(X_int, device=device)
X_bd_torch = torch.tensor(X_bd, device=device)
A_torch = torch.tensor(A, device=device)
rvec_torch = torch.tensor(rvec, device=device)
alpha_torch = torch.tensor(alpha_opt, device=device)

def u_base_torch(x):
    s = x @ A_torch.T + rvec_torch
    S = gamma_opt * s
    Psi = torch.tanh(S)
    return Psi @ alpha_torch

def laplacian(u, x):
    grads = autograd.grad(u, x, grad_outputs=torch.ones_like(u),
                          create_graph=True)[0]
    lap = 0
    for i in range(x.shape[1]):
        grad2 = autograd.grad(grads[:, i], x,
                              grad_outputs=torch.ones_like(grads[:, i]),
                              create_graph=True)[0][:, i]
        lap += grad2
    return lap   #这里是在算拉普拉斯算子

# 残差网络 vθ
class MLP(nn.Module):
    def __init__(self, in_dim=2, hidden=64, depth=4):
        super().__init__()
        layers = [nn.Linear(in_dim, hidden), nn.Tanh()]
        for _ in range(depth-1):
            layers += [nn.Linear(hidden, hidden), nn.Tanh()]
        layers += [nn.Linear(hidden, 1)]
        self.net = nn.Sequential(*layers)
    def forward(self, x):
        return self.net(x).squeeze(-1)

model_res = MLP().to(device)
opt = torch.optim.Adam(model_res.parameters(), lr=1e-3)
lambda_bd = 10.0
num_epochs = 3000

print("\nStart training residual v_theta ...")

for epoch in range(num_epochs):
    opt.zero_grad()
    x_int = X_int_torch
    x_int.requires_grad_(True)
    v_pred = model_res(x_int)
    lap_v = laplacian(v_pred, x_int)
    u_base_val = u_base_torch(x_int)
    lap_u_base = laplacian(u_base_val, x_int)
    r_pde = f_forcing_torch(x_int) + lap_u_base
    loss_pde = torch.mean((-lap_v - r_pde)**2)

    x_bd = X_bd_torch
    v_bd = model_res(x_bd)
    loss_bd = torch.mean(v_bd**2)

    loss = loss_pde + lambda_bd*loss_bd
    loss.backward()
    opt.step()

    if epoch % 500 == 0:
        print(f"Epoch {epoch:5d}: total={loss.item():.3e}, pde={loss_pde.item():.3e}, bd={loss_bd.item():.3e}")

# 评估最终误差
X_test_torch = torch.tensor(X_test, device=device)
u_pred_final = u_base_torch(X_test_torch) + model_res(X_test_torch)
mse_final = torch.mean((u_pred_final - u_exact_torch(X_test_torch))**2).item()

print(f"\n[Residual Hybrid Skip] Interior MSE vs exact: {mse_final:.6e}")
print(f"Baseline MSE: {mse_base:.6e}  →  Residual Hybrid MSE: {mse_final:.6e}")


Baseline gamma_opt=1.641822, MSE_int=3.061e-05, MSE_bd=7.161e-06
[Baseline TransNet] Interior MSE vs exact: 8.175459e-07

Start training residual v_theta ...
Epoch     0: total=9.467e-02, pde=2.201e-04, bd=9.445e-03
Epoch   500: total=3.181e-05, pde=3.152e-05, bd=2.890e-08
Epoch  1000: total=3.096e-05, pde=3.090e-05, bd=5.547e-09
Epoch  1500: total=3.075e-05, pde=3.074e-05, bd=1.713e-09
Epoch  2000: total=3.068e-05, pde=3.067e-05, bd=8.727e-10
Epoch  2500: total=3.065e-05, pde=3.065e-05, bd=7.181e-10

[Residual Hybrid Skip] Interior MSE vs exact: 8.163803e-07
Baseline MSE: 8.175459e-07  →  Residual Hybrid MSE: 8.163803e-07


Experiment Two (wave equation)

In [None]:
# ==== Cell 1: imports & dtype ====
import numpy as np
import math, time
import torch
import torch.nn as nn
import torch.autograd as autograd

# 双精度对二阶导更稳定
torch.set_default_dtype(torch.float64)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pi = math.pi


# ==== Cell 2: wave equation setup & sampling ====

# PDE: u_tt - c u_xx = 0
c = 1.0/(16.0*pi*pi)

def u_exact_np(X):  # X: (N,2) columns: x,t
    x, t = X[:,0], X[:,1]
    return 0.5*(np.sin(4*pi*x + t) + np.sin(4*pi*x - t))

def u_exact_torch(x):  # x: (N,2) torch
    X, T = x[:,0], x[:,1]
    return 0.5*(torch.sin(4*pi*X + T) + torch.sin(4*pi*X - T))

def IC_u_np(x):   # u(x,0)
    return np.sin(4*pi*x)

def IC_ut_np(x):  # u_t(x,0)=0
    return np.zeros_like(x)

# interior: 去掉 t=0 与 x=0/1（这些点用于 IC/BC）
def sample_interior(n_x=60, n_t=60):
    xs = np.linspace(0.0, 1.0, n_x, dtype=np.float64)
    ts = np.linspace(0.0, 2.0, n_t, dtype=np.float64)
    X, T = np.meshgrid(xs, ts, indexing='xy')
    P = np.stack([X.ravel(), T.ravel()], axis=-1)
    mask = (P[:,1] > 0.0) & (P[:,0] > 0.0) & (P[:,0] < 1.0)
    return P[mask]

def sample_IC(n_x=160):
    xs = np.linspace(0.0, 1.0, n_x, dtype=np.float64)
    return np.stack([xs, np.zeros_like(xs)], axis=-1)  # t=0

def sample_BC_periodic(n_t=160):
    ts = np.linspace(0.0, 2.0, n_t, dtype=np.float64)
    P0 = np.stack([np.zeros_like(ts), ts], axis=-1) # x=0
    P1 = np.stack([np.ones_like(ts),  ts], axis=-1) # x=1
    return P0, P1


# ==== Cell 3: TransNet basis & linear blocks for wave ====

def sigma(s):    return np.tanh(s)
def sigma_p(s):  return 1.0 - np.tanh(s)**2
def sigma_dd(s): 
    t = np.tanh(s)
    sech2 = 1.0 - t*t
    return -2.0*t*sech2

def sample_a_r(M, d=2, R=1.3, seed=1234):
    rng = np.random.RandomState(seed)
    A = rng.randn(M, d)
    A /= np.linalg.norm(A, axis=1, keepdims=True)
    r = rng.rand(M)*R
    return A, r

# 构造线性系统各子块：
# Lψ = ψ_tt - c ψ_xx = σ''(S)*γ^2*(a_t^2 - c a_x^2)
# IC-u: ψ(x,0) ;  IC-ut: ψ_t(x,0) = σ'(S)*γ*a_t
# BC-per: ψ(0,t)-ψ(1,t)
def build_blocks_wave(X_int, X_IC, X0_t, X1_t, A, r, gamma, xi_c=(0.5,1.0)):    #这个函数是用来构造线性系统的各个子块，即为后面的最小二乘求解做准备
    X_int = X_int - np.array(xi_c)[None,:]
    X_IC  = X_IC  - np.array(xi_c)[None,:]
    X0_t  = X0_t  - np.array(xi_c)[None,:]
    X1_t  = X1_t  - np.array(xi_c)[None,:]

    S_int = gamma*(X_int @ A.T + r[None,:])
    S_IC  = gamma*(X_IC  @ A.T + r[None,:])
    S0    = gamma*(X0_t  @ A.T + r[None,:])
    S1    = gamma*(X1_t  @ A.T + r[None,:])

    ax = A[:,0][None,:]
    at = A[:,1][None,:]
    L_int  = sigma_dd(S_int)*(gamma**2)*(at**2 - c*ax**2)
    Psi_IC = sigma(S_IC)
    Psi_tIC = sigma_p(S_IC)*(gamma*at)
    BC_per = sigma(S0) - sigma(S1)
    return L_int, Psi_IC, Psi_tIC, BC_per


def solve_alpha_ls(A_top, A_ic, A_ut, A_per, rhs_top, rhs_ic, rhs_ut, rhs_per,
                   w_top=1.0, w_ic=10.0, w_ut=10.0, w_per=10.0, reg=1e-8):
    A_big = np.vstack([w_top*A_top, w_ic*A_ic, w_ut*A_ut, w_per*A_per])
    b_big = np.concatenate([w_top*rhs_top, w_ic*rhs_ic, w_ut*rhs_ut, w_per*rhs_per])
    AtA = A_big.T @ A_big + reg*np.eye(A_big.shape[1])
    Atb = A_big.T @ b_big
    alpha = np.linalg.solve(AtA, Atb)
    return alpha

def u_base_numpy(X, A, r, gamma, alpha, xi_c=(0.5,1.0)):
    X_c = X - np.array(xi_c)[None,:]
    S = gamma*(X_c @ A.T + r[None,:])
    return sigma(S) @ alpha


# ==== Cell 4: golden search for gamma & baseline ====

def golden_search(func, a, b, tol=1e-3, max_iters=50):
    phi = (1 + 5**0.5)/2.0
    invphi = 1.0/phi
    c = b - invphi*(b-a)
    d = a + invphi*(b-a)
    fc = func(c)
    fd = func(d)
    it = 0
    while (b-a)>tol and it<max_iters:
        it += 1
        if fc[0] < fd[0]:
            b, d, fd = d, c, fc
            c = b - invphi*(b-a)
            fc = func(c)
        else:
            a, c, fc = c, d, fd
            d = a + invphi*(b-a)
            fd = func(d)
    return (c, fc) if fc[0] < fd[0] else (d, fd)

# --- 数据采样 ---
X_int = sample_interior(60, 60)       # interior
X_IC  = sample_IC(160)                # t=0
X0_t, X1_t = sample_BC_periodic(160)  # x=0, x=1

rhs_int = np.zeros(X_int.shape[0])        # L(u)=0
rhs_IC  = IC_u_np(X_IC[:,0])              # u(x,0)
rhs_ut  = IC_ut_np(X_IC[:,0])             # u_t(x,0)=0
rhs_per = np.zeros(X0_t.shape[0])         # periodic: u(0,t)-u(1,t)=0

# --- 随机基 ---
M = 300
seed = 2025
np.random.seed(seed)
A, r = sample_a_r(M, d=2, R=1.3, seed=seed+1)

# --- γ 搜索目标 ---
def eta_of_gamma(gamma):
    L_int, Psi_IC, Psi_tIC, BC_per = build_blocks_wave(X_int, X_IC, X0_t, X1_t, A, r, gamma)
    alpha = solve_alpha_ls(L_int, Psi_IC, Psi_tIC, BC_per, rhs_int, rhs_IC, rhs_ut, rhs_per,
                          w_top=1.0, w_ic=10.0, w_ut=10.0, w_per=10.0)
    e_int = L_int@alpha - rhs_int
    e_ic  = Psi_IC@alpha - rhs_IC
    e_ut  = Psi_tIC@alpha - rhs_ut
    e_per = BC_per@alpha - rhs_per
    mse = (np.mean(e_int**2) + 10.0*np.mean(e_ic**2) + 10.0*np.mean(e_ut**2) + 10.0*np.mean(e_per**2))
    return (mse,)

print("Searching gamma ...")
gamma_opt, best = golden_search(eta_of_gamma, 1e-2, 10.0, tol=1e-3, max_iters=50)
print(f"gamma_opt = {gamma_opt:.6f}, eta = {best[0]:.3e}")

# --- 最终 α & 基线误差 ---
L_int, Psi_IC, Psi_tIC, BC_per = build_blocks_wave(X_int, X_IC, X0_t, X1_t, A, r, gamma_opt)
alpha_opt = solve_alpha_ls(L_int, Psi_IC, Psi_tIC, BC_per, rhs_int, rhs_IC, rhs_ut, rhs_per,
                           w_top=1.0, w_ic=10.0, w_ut=10.0, w_per=10.0)

X_test = np.stack(np.meshgrid(np.linspace(0,1,120), np.linspace(0,2,120), indexing='xy'), axis=-1).reshape(-1,2)
u_base = u_base_numpy(X_test, A, r, gamma_opt, alpha_opt)
mse_base = np.mean((u_base - u_exact_np(X_test))**2)
print(f"[Baseline TransNet] MSE vs exact: {mse_base:.6e}")



# ==== Cell 5: Torch u_base / operator / r_PDE precompute ====

# 常量参数转 torch（不需要 requires_grad）
A_t  = torch.tensor(A, device=device)
r_t  = torch.tensor(r, device=device)
alpha_t = torch.tensor(alpha_opt, device=device)
gamma_t = torch.tensor(gamma_opt, device=device)

def u_base_torch(x, xi_c=(0.5,1.0)):
    xi_c_t = torch.tensor(xi_c, device=device)
    x_c = x - xi_c_t[None,:]
    S = gamma_t*(x_c @ A_t.T + r_t[None,:])
    return torch.tanh(S) @ alpha_t


# 波动方程算子 L(u) = u_tt - c u_xx
def wave_operator(u, x):     #这里因为是二阶导，所以要做两遍autograd
    grad = autograd.grad(u, x, grad_outputs=torch.ones_like(u), create_graph=True)[0]
    du_dx = grad[:,0]
    du_dt = grad[:,1]
    d2u_dxx = autograd.grad(du_dx, x, grad_outputs=torch.ones_like(du_dx), create_graph=True)[0][:,0]
    d2u_dtt = autograd.grad(du_dt, x, grad_outputs=torch.ones_like(du_dt), create_graph=True)[0][:,1]
    return d2u_dtt - c*d2u_dxx

# 预计算 r_PDE = -L(u_base) ；不要在 no_grad 中做二阶导
def precompute_r_pde(X, chunk=8192):
    outs = []
    N = X.shape[0]
    for i in range(0, N, chunk):
        xb = X[i:i+chunk].clone().detach().requires_grad_(True)
        ub = u_base_torch(xb)
        Lub = wave_operator(ub, xb)
        outs.append((-Lub).detach())   # 结果 detach 成常量监督
    return torch.cat(outs, dim=0)

# 训练集（torch）
Xin_t  = torch.tensor(X_int, device=device)
Xic_t  = torch.tensor(X_IC,  device=device)
X0_t_t = torch.tensor(X0_t,  device=device)
X1_t_t = torch.tensor(X1_t,  device=device)

# 预计算 r_PDE
r_int = precompute_r_pde(Xin_t, chunk=8192)


# ==== Cell 6: Residual network v_theta & training ====

class MLP(nn.Module):
    def __init__(self, in_dim=2, hidden=64, depth=4):
        super().__init__()
        layers = [nn.Linear(in_dim, hidden), nn.Tanh()]
        for _ in range(depth-1):
            layers += [nn.Linear(hidden, hidden), nn.Tanh()]
        layers += [nn.Linear(hidden, 1)]
        self.net = nn.Sequential(*layers)
    def forward(self, x):
        return self.net(x).squeeze(-1)

model = MLP().to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
lambda_ic  = 10.0
lambda_per = 10.0
epochs = 3000

print("\nStart training residual v_theta ...")

for ep in range(epochs):    #这个循环是在做MLP的训练
    opt.zero_grad()

    # PDE 残差（域内）：L(v) ≈ r_PDE
    x = Xin_t.clone().detach().requires_grad_(True)
    v = model(x)
    L_v = wave_operator(v, x)
    loss_pde = torch.mean((L_v - r_int)**2)

    # 初值：v(x,0)=0, v_t(x,0)=0
    v_ic = model(Xic_t)                     # 值为零
    loss_ic_val = torch.mean(v_ic**2)

    Xic_req = Xic_t.clone().detach().requires_grad_(True)
    v_ic2 = model(Xic_req)
    grad_ic = autograd.grad(v_ic2, Xic_req, grad_outputs=torch.ones_like(v_ic2), create_graph=True)[0]
    vt_ic = grad_ic[:,1]                    # ∂v/∂t at t=0
    loss_ic_vel = torch.mean(vt_ic**2)
    loss_ic = loss_ic_val + loss_ic_vel

    # 周期边界：v(0,t)-v(1,t)=0
    v0 = model(X0_t_t)
    v1 = model(X1_t_t)
    loss_per = torch.mean((v0 - v1)**2)

    loss = loss_pde + lambda_ic*loss_ic + lambda_per*loss_per
    loss.backward()
    opt.step()

    if ep % 400 == 0:
        print(f"Epoch {ep:4d}: total={loss.item():.3e}, pde={loss_pde.item():.3e}, ic={loss_ic.item():.3e}, per={loss_per.item():.3e}")


# ==== Cell 7: evaluation ====
with torch.no_grad():
    Xt = torch.tensor(X_test, device=device)
    u_pred = u_base_torch(Xt) + model(Xt)
    mse_final = torch.mean((u_pred - u_exact_torch(Xt))**2).item()

print(f"\n[Summary] Baseline MSE = {mse_base:.6e} | Residual-Hybrid MSE = {mse_final:.6e}")


Searching gamma ...
gamma_opt = 2.370048, eta = 6.503e-06
[Baseline TransNet] MSE vs exact: 1.561160e-04

Start training residual v_theta ...
Epoch    0: total=4.464e-02, pde=7.177e-05, ic=3.826e-03, per=6.310e-04
Epoch  400: total=5.230e-06, pde=4.694e-06, ic=4.907e-08, per=4.509e-09
Epoch  800: total=4.395e-06, pde=4.277e-06, ic=1.126e-08, per=5.688e-10
Epoch 1200: total=4.207e-06, pde=4.176e-06, ic=2.702e-09, per=3.807e-10
Epoch 1600: total=4.173e-06, pde=4.151e-06, ic=1.844e-09, per=3.577e-10
Epoch 2000: total=4.549e-06, pde=4.205e-06, ic=3.215e-08, per=2.316e-09
Epoch 2400: total=4.165e-06, pde=4.147e-06, ic=1.395e-09, per=4.409e-10
Epoch 2800: total=5.717e-06, pde=4.193e-06, ic=1.484e-07, per=4.098e-09

[Summary] Baseline MSE = 1.561160e-04 | Residual-Hybrid MSE = 1.550566e-04
