Nonlinear MLP Stencil

Goal: Use an MLP to learn a nonlinear (solution-dependent) finite-difference stencil.

In [None]:
# Essential packages
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.func import jacrev, vmap, functional_call

Define an MLP that outputs stencil-coefficient perturbations, and build the nonlinear stencil operator

In [5]:
class StencilMLP(nn.Module):
    """Small MLP mapping a local stencil window to coefficient perturbations.

    Architecture: Linear -> Tanh -> Linear -> Tanh -> Linear, with
    ``stencil_width`` inputs, ``hidden`` units per hidden layer and
    ``stencil_width`` outputs.

    The output layer is initialised with zero weights and zero biases, so a
    freshly constructed model returns an all-zero perturbation for any input.
    """

    def __init__(self, stencil_width=3, hidden=16):
        super().__init__()

        # Layers are created in input -> output order and stored under
        # ``self.net`` so indices 0..4 address them individually.
        layers = [
            nn.Linear(stencil_width, hidden),
            nn.Tanh(),
            nn.Linear(hidden, hidden),
            nn.Tanh(),
            nn.Linear(hidden, stencil_width),
        ]
        self.net = nn.Sequential(*layers)

        # Zero every parameter (weight and bias) of the output layer without
        # recording the in-place writes in autograd.
        output_layer = self.net[-1]
        with torch.no_grad():
            for param in output_layer.parameters():
                param.zero_()

    def forward(self, u_local):
        """Return the network output for the local solution values ``u_local``."""
        return self.net(u_local)
    
def extract_stencil_window(u, stencil_radius=1):
    """Return every local stencil window of ``u`` under periodic boundaries.

    ``u`` is extended along dimension 0 with ``stencil_radius`` wrapped
    entries on each side, then cut into overlapping windows of width
    ``2 * stencil_radius + 1`` with stride 1.  For a 1-D ``u`` of length N
    the result has shape ``(N, 2 * stencil_radius + 1)`` and row ``i`` holds
    ``u[(i - r) % N], ..., u[i], ..., u[(i + r) % N]``.

    Args:
        u: Solution tensor; windows are taken along dimension 0.
        stencil_radius: Number of neighbours on each side (>= 0).

    Returns:
        Tensor of windows produced by ``unfold`` on the periodically padded ``u``.

    Raises:
        ValueError: If ``stencil_radius`` is negative.
    """
    if stencil_radius < 0:
        raise ValueError(f"stencil_radius must be non-negative, got {stencil_radius}")

    n_points = u.shape[0]
    width = 2 * stencil_radius + 1

    # Build the periodic halo with modular indices instead of slices:
    #   * ``u[-0:]`` would return the whole vector for radius 0, and
    #   * plain slices silently under-pad when the radius exceeds len(u).
    left_idx = torch.arange(-stencil_radius, 0, device=u.device) % n_points
    right_idx = torch.arange(n_points, n_points + stencil_radius, device=u.device) % n_points

    u_padded = torch.cat([u.index_select(0, left_idx), u, u.index_select(0, right_idx)])
    return u_padded.unfold(0, width, 1)

def apply_nonlinear_stencil(u, mlp_params, mlp_buffers, mlp_forward, base_coeffs):
    """Evaluate the solution-dependent stencil at every point of ``u``.

    Each periodic stencil window of ``u`` is passed through
    ``mlp_forward(mlp_params, mlp_buffers, windows)`` to obtain a perturbation,
    which is added to ``base_coeffs`` (broadcast over the leading window
    dimension).  The resulting coefficients weight the window values and are
    summed along dimension 1.

    Returns:
        The weighted sum per window, i.e. the stencil applied to ``u``.
    """
    local_windows = extract_stencil_window(u)
    perturbation = mlp_forward(mlp_params, mlp_buffers, local_windows)

    # Effective coefficients = fixed base stencil + learned perturbation.
    effective_coeffs = base_coeffs[None] + perturbation

    weighted = effective_coeffs * local_windows
    return weighted.sum(dim=1)

Implicit Euler Residual and Jacobian

In [None]:
def implicit_euler_residual(u_new, u_old, dt, mlp_params, mlp_buffers,
                            mlp_forward, base_coeffs):
    """Return the implicit-Euler residual ``u_new - u_old - dt * L(u_new)``.

    ``L(u_new)`` is the nonlinear stencil applied to the full vector ``u_new``
    via ``apply_nonlinear_stencil``; the residual is evaluated for the whole
    system at once.
    """
    stencil_term = apply_nonlinear_stencil(
        u_new, mlp_params, mlp_buffers, mlp_forward, base_coeffs
    )
    residual = u_new - u_old - dt * stencil_term
    return residual

def local_residual_i(u_local_3, i, u_full, u_old_i, dt, mlp_buffers, mlp_forward,
                     base_coeefs, mlp_params=None):
    """Implicit-Euler residual at one spatial point from its 3-point stencil.

    Computes ``u_local_3[1] - u_old_i - dt * sum(a_eff * u_local_3)`` where
    ``a_eff`` is the base stencil plus the perturbation returned by
    ``mlp_forward(mlp_params, mlp_buffers, u_local_3[None])``.

    Args:
        u_local_3: The three stencil values; index 1 is the centre point.
        i: Point index.  Not used in the computation.
        u_full: Full solution vector.  Not used in the computation.
        u_old_i: Previous-step value at this point.
        dt: Time-step size.
        mlp_buffers: Buffers forwarded to ``mlp_forward``.
        mlp_forward: Callable ``(params, buffers, x) -> perturbation``.
        base_coeefs: Base stencil coefficients.  The misspelled name is kept
            so existing keyword callers continue to work.
        mlp_params: Parameters forwarded to ``mlp_forward``.  When omitted,
            a module-level ``mlp_params`` is used, which is what this
            function implicitly relied on before the argument existed.

    Returns:
        The scalar residual at this point.

    Raises:
        NameError: If ``mlp_params`` is omitted and no module-level
            ``mlp_params`` exists.
    """
    if mlp_params is None:
        # Backward compatibility: earlier versions silently read the global.
        try:
            mlp_params = globals()["mlp_params"]
        except KeyError:
            raise NameError(
                "mlp_params was not passed and no module-level 'mlp_params' is defined"
            ) from None

    delta_a = mlp_forward(mlp_params, mlp_buffers, u_local_3.unsqueeze(0)).squeeze(0)
    # Use the coefficients that were actually passed in, not a global.
    a_eff = base_coeefs + delta_a
    Lu_i = (a_eff * u_local_3).sum()
    return u_local_3[1] - u_old_i - dt * Lu_i

def assemble_jacobian_banded(u_new, u_old, dt, mlp_params, mlp_buffers, mlp_forward, base_coeffs):
    """Assemble the dense N x N Jacobian of the implicit-Euler residual.

    The residual at point ``i`` depends only on its 3-point periodic stencil,
    so the derivative of each local residual with respect to its three
    stencil values is computed with ``jacrev`` and batched over all points
    with ``vmap``.  The three derivatives of row ``i`` are then written to
    columns ``(i - 1) % N``, ``i`` and ``(i + 1) % N``.

    Args:
        u_new: Current iterate; the Jacobian is taken with respect to it.
        u_old: Previous-step solution, one value per point.
        dt: Time-step size.
        mlp_params: Parameters forwarded to ``mlp_forward``.
        mlp_buffers: Buffers forwarded to ``mlp_forward``.
        mlp_forward: Callable ``(params, buffers, x) -> perturbation``.
        base_coeffs: Base 3-point stencil coefficients.

    Returns:
        Tensor ``J`` of shape ``(N, N)`` with the dtype and device of ``u_new``.
    """
    N = u_new.shape[0]
    windows = extract_stencil_window(u_new)

    def local_res_fn(u_local_3, u_old_i):
        # Residual at one point as a function of its own 3 stencil values.
        delta_a = mlp_forward(mlp_params, mlp_buffers, u_local_3.unsqueeze(0)).squeeze(0)
        a_eff = base_coeffs + delta_a
        Lu_i = (a_eff * u_local_3).sum()
        return u_local_3[1] - u_old_i - dt * Lu_i

    # d(residual_i) / d(stencil values), evaluated for every point at once.
    local_jac_fn = jacrev(local_res_fn, argnums=0)
    all_local_jacs = vmap(local_jac_fn)(windows, u_old)

    J = torch.zeros(N, N, dtype=u_new.dtype, device=u_new.device)
    idx = torch.arange(N, device=u_new.device)

    for k in range(3):
        col = (idx + (k - 1)) % N
        # Accumulate rather than overwrite: for N < 3 the periodic wrap maps
        # several stencil slots onto the same column and their derivatives
        # must be summed.  For N >= 3 the targets are distinct, so this is
        # identical to plain assignment.
        J[idx, col] += all_local_jacs[:, k]

    return J

Newton Solver with Armijo backtracking

In [8]:
def newton_solve(u_init, u_old, dt, mlp_params, mlp_buffers, mlp_forward,
                 base_coeffs, tol=1e-10, max_iter=20, verbose=False):
    """Solve the implicit-Euler system with damped Newton iterations.

    Each iteration evaluates the residual ``F(u)``, assembles the Jacobian,
    solves ``J delta_u = -F`` and applies an Armijo backtracking line search
    on the merit function ``phi(u) = 0.5 * ||F(u)||^2``.  The iteration runs
    on detached values without recording gradients.  After the loop, one
    extra Newton step is evaluated with gradient recording active, so that
    step is the only part of the solve tracked by autograd.

    Args:
        u_init: Initial guess (detached and cloned before use).
        u_old: Previous-step solution.
        dt: Time-step size.
        mlp_params: Parameters forwarded to ``mlp_forward``.
        mlp_buffers: Buffers forwarded to ``mlp_forward``.
        mlp_forward: Callable ``(params, buffers, x) -> perturbation``.
        base_coeffs: Base stencil coefficients.
        tol: Stop once the residual 2-norm falls below this value.
        max_iter: Maximum number of Newton iterations.
        verbose: If true, print the residual norm at every iteration.

    Returns:
        The solution after the final Newton step.
    """
    u = u_init.detach().clone()

    for k in range(max_iter):
        with torch.no_grad():
            F_val = implicit_euler_residual(u, u_old, dt, mlp_params, mlp_buffers, mlp_forward, base_coeffs)

        res_norm = F_val.norm().item()
        if verbose:
            print(f"Newton iter {k}: ||F|| = {res_norm:.3e}")
        # The convergence test must not depend on ``verbose``; otherwise a
        # silent solve always runs all ``max_iter`` iterations.
        if res_norm < tol:
            break

        # jacrev needs gradient mode even if the caller disabled it.
        with torch.enable_grad():
            J = assemble_jacobian_banded(u, u_old, dt, mlp_params, mlp_buffers, mlp_forward, base_coeffs)

        with torch.no_grad():
            delta_u = torch.linalg.solve(J.detach(), -F_val)

            # Armijo backtracking on phi(u) = 0.5 * ||F(u)||^2.
            alpha = 1.0
            c1 = 1e-4
            tau = 0.5
            phi_current = 0.5 * F_val.dot(F_val)
            # grad(phi) . delta_u = F^T J delta_u = -F^T F for the Newton step.
            directional_deriv = -F_val.dot(F_val)

            for _ in range(20):
                u_trial = u + alpha * delta_u
                F_trial = implicit_euler_residual(u_trial, u_old, dt, mlp_params, mlp_buffers, mlp_forward, base_coeffs)
                phi_trial = 0.5 * F_trial.dot(F_trial)
                if phi_trial <= phi_current + c1 * alpha * directional_deriv:
                    break
                alpha *= tau

            # If no step length satisfied the Armijo condition, the last
            # (smallest) alpha is still applied.
            u = u + alpha * delta_u

    # Final full Newton step from the detached iterate, evaluated outside
    # ``no_grad`` so that this step is recorded by autograd.
    u_detached = u.detach()
    F_val = implicit_euler_residual(u_detached, u_old, dt, mlp_params, mlp_buffers, mlp_forward, base_coeffs)
    J = assemble_jacobian_banded(u_detached, u_old, dt, mlp_params, mlp_buffers, mlp_forward, base_coeffs)
    delta_u = torch.linalg.solve(J, -F_val)
    u_out = u_detached + delta_u

    return u_out