In [1]:
import math
import time
from dataclasses import dataclass
from typing import Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F

# ------------------------------
# Config
# ------------------------------
N_PDE = 100          # collocation (interior) points
N_BC_EDGE = 100      # boundary samples per edge
HIDDEN = [20, 20, 20]
DTYPE = torch.float32
SEED = 0
LR = 1e-3

# ------------------------------
# Utils
# ------------------------------
def set_seed(seed: int = 0):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def bytes_to_mb(nbytes: int) -> float:
    return nbytes / (1024**2)

def on_cuda() -> bool:
    return torch.cuda.is_available()

def device():
    return torch.device("cuda") if on_cuda() else torch.device("cpu")

# ------------------------------
# Model
# ------------------------------
class MLP(nn.Module):
    def __init__(self, in_dim=2, hidden: List[int] = [20,20,20], out_dim=1):
        super().__init__()
        layers = []
        last = in_dim
        for h in hidden:
            layers.append(nn.Linear(last, h))
            last = h
        self.hidden = nn.ModuleList(layers)
        self.out = nn.Linear(last, out_dim)
        self.act = nn.Tanh()

        # (optional) nice init
        for lin in self.hidden:
            nn.init.xavier_normal_(lin.weight, gain=1.0)
            nn.init.zeros_(lin.bias)
        nn.init.xavier_normal_(self.out.weight, gain=1.0)
        nn.init.zeros_(self.out.bias)

    def forward(self, z0):  # z0: (N,2) with columns [x,y]
        z = z0
        Ls, Zs = [], []
        for lin in self.hidden:
            # Pre-activation: L = Z @ W^T + b    (shapes: (N,last)*(h,last)^T => (N,h))
            L = z @ lin.weight.t() + lin.bias
            z = self.act(L)
            Ls.append(L)
            Zs.append(z)
        # Output (linear): u = Z @ W_out^T + b_out  => (N,1)
        u = z @ self.out.weight.t() + self.out.bias
        return u, Ls, Zs

# ------------------------------
# Manual forward-mode derivatives
# ------------------------------
def tanh_prime_from_Z(Z):
    # σ'(a) = 1 - tanh(a)^2 ; since Z=tanh(a), σ'(a)=1 - Z^2
    return 1.0 - Z**2

def tanh_doubleprime_from_L_and_Z(L, Z):
    # σ''(a) = -2 * tanh(a) * (1 - tanh(a)^2) = -2 * Z * (1 - Z^2)
    return -2.0 * Z * (1.0 - Z**2)

@torch.enable_grad()   # ensure params get grads through manual ops
def manual_seconds(model: MLP, z0: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute u, u_xx, u_yy with forward-mode streams (no autograd.grad on inputs).
    z0: (N,2) with columns [x,y]. requires_grad not needed here.
    """
    # Forward pass (cache Ls, Zs)
    u, Ls, Zs = model(z0)
    # Weight matrices to use: for derivative recurrences we need W_i^T (same as forward)
    WTs = [lin.weight for lin in model.hidden]  # (out,in); we'll use .t() when multiplying
    W_out_T = model.out.weight  # (1, last)

    N = z0.shape[0]
    widths = [lin.out_features for lin in model.hidden]

    # Base derivative streams for x and y at input:
    # G0 has shape (N, in_dim=2); for ∂/∂x it's [1,0], for ∂/∂y it's [0,1].
    G0x = torch.zeros_like(z0); G0x[:, 0] = 1.0
    G0y = torch.zeros_like(z0); G0y[:, 1] = 1.0

    # First-order streams G (post-activation) and H (post-linear)
    # Second-order streams F (post-linear) and E (post-activation)
    # We'll carry "current" G and F as we traverse layers; H,E are per-layer temporaries.
    def streams_for_coord(G0):
        G_prev = G0                     # (N, in_dim or width of previous)
        F_prev = None                   # None means zeros at first layer
        for i, (L, Z, W) in enumerate(zip(Ls, Zs, WTs)):
            # Through linear: H_i = G_{i-1} @ W_i^T  where W_i is (out,in)
            H_i = G_prev @ W.t()        # (N, width_i)
            # First derivative through activation:
            sig1 = tanh_prime_from_Z(Z)     # (N, width_i)
            G_i = sig1 * H_i
            # Second derivative:
            sig2 = tanh_doubleprime_from_L_and_Z(L, Z)  # (N, width_i)
            if F_prev is None:
                F_prev = torch.zeros_like(G_i)          # F_1 = 0 (N, width_1)
            C_i = sig2 * (H_i * H_i)                    # (N, width_i)
            E_i = C_i + sig1 * F_prev                   # (N, width_i)
            # Prepare for next layer: F_{i+1} = E_i @ W_{i+1}^T (if there is a next hidden layer)
            F_next = None
            if i < len(Ls) - 1:
                W_next = WTs[i+1]
                F_next = E_i @ W_next.t()               # (N, width_{i+1})
            # roll
            G_prev = G_i
            F_prev = F_next if F_next is not None else F_prev
            # Cache last layer's E, G for output derivatives
            if i == len(Ls) - 1:
                E_last = E_i
                G_last = G_i
        # Output derivatives (linear head):
        # u_c  = G_last @ W_out^T   ; u_cc = E_last @ W_out^T
        u_c  = G_last @ W_out_T.t()   # (N,1)
        u_cc = E_last @ W_out_T.t()   # (N,1)
        return u_c, u_cc

    ux, uxx = streams_for_coord(G0x)   # (N,1), (N,1)
    uy, uyy = streams_for_coord(G0y)   # (N,1), (N,1)
    return u, uxx, uyy  # (N,1) each

# ------------------------------
# autograd.grad-based seconds
# ------------------------------
def autograd_seconds(model: MLP, x: torch.Tensor, y: torch.Tensor):
    """
    Compute u, u_xx, u_yy using autograd.grad with create_graph=True.
    x,y must have requires_grad=True.
    """
    z0 = torch.cat([x, y], dim=1)            # (N,2)
    u, _, _ = model(z0)                      # (N,1)

    ones = torch.ones_like(u)

    # First derivatives: retain_graph because we'll also get grads w.r.t. the *other* input
    ux = torch.autograd.grad(u, x, grad_outputs=ones, create_graph=True, retain_graph=True)[0]   # (N,1)
    uy = torch.autograd.grad(u, y, grad_outputs=ones, create_graph=True, retain_graph=True)[0]   # (N,1)

    # Second derivatives:
    uxx = torch.autograd.grad(ux, x, grad_outputs=torch.ones_like(ux), create_graph=True)[0]     # (N,1)
    uyy = torch.autograd.grad(uy, y, grad_outputs=torch.ones_like(uy), create_graph=True)[0]     # (N,1)

    return u, uxx, uyy

# ------------------------------
# Data (PDE + boundary samples)
# ------------------------------
def sample_interior(N: int, dev):
    # Uniform interior in (0,1)^2
    x = torch.rand(N, 1, device=dev, dtype=DTYPE)
    y = torch.rand(N, 1, device=dev, dtype=DTYPE)
    return x, y

def sample_boundaries(N_edge: int, dev):
    y0 = torch.rand(N_edge, 1, device=dev, dtype=DTYPE)
    y1 = torch.rand(N_edge, 1, device=dev, dtype=DTYPE)
    x0 = torch.rand(N_edge, 1, device=dev, dtype=DTYPE)
    x1 = torch.rand(N_edge, 1, device=dev, dtype=DTYPE)

    # Edges:
    xb0 = torch.zeros_like(y0); yb0 = y0                   # x=0, u=0
    xb1 = torch.ones_like(y1);  yb1 = y1                   # x=1, u=sin(pi*y)
    yb2 = torch.zeros_like(x0); xb2 = x0                   # y=0, u=0
    yb3 = torch.ones_like(x1);  xb3 = x1                   # y=1, u=0

    # Targets on boundaries:
    u_b0 = torch.zeros_like(y0)                            # u(0,y)=0
    u_b1 = torch.sin(math.pi * y1)                         # u(1,y)=sin(pi*y)
    u_b2 = torch.zeros_like(x0)                            # u(x,0)=0
    u_b3 = torch.zeros_like(x1)                            # u(x,1)=0

    # Stack each edge into batches
    Bx = [xb0, xb1, xb2, xb3]
    By = [yb0, yb1, yb2, yb3]
    Bu = [u_b0, u_b1, u_b2, u_b3]
    return Bx, By, Bu

def bc_loss(model: MLP, Bx: List[torch.Tensor], By: List[torch.Tensor], Bu: List[torch.Tensor]) -> torch.Tensor:
    mse = nn.MSELoss()
    total = 0.0
    for xb, yb, ub in zip(Bx, By, Bu):
        z = torch.cat([xb, yb], dim=1)
        u_pred, _, _ = model(z)
        total = total + mse(u_pred, ub)
    return total

# ------------------------------
# Memory measurement helpers
# ------------------------------
class PeakMem:
    def __init__(self, label: str):
        self.label = label
        self.cuda = on_cuda()
        self.peak_bytes = 0

    def __enter__(self):
        if self.cuda:
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()
        return self

    def __exit__(self, exc_type, exc, tb):
        if self.cuda:
            torch.cuda.synchronize()
            self.peak_bytes = torch.cuda.max_memory_allocated()
            print(f"[{self.label}] Peak CUDA allocated: {bytes_to_mb(self.peak_bytes):.3f} MB")
        else:
            # CPU mode: report rough param + tensor allocations (not perfect)
            print(f"[{self.label}] CPU mode: use theoretical estimates above; "
                  f"PyTorch doesn't expose 'graph-only' CPU memory precisely.")

# ------------------------------
# One training step for each method
# ------------------------------
def train_step_manual(model: MLP, opt, x_pde, y_pde, Bx, By, Bu):
    model.train()
    opt.zero_grad(set_to_none=True)

    z0 = torch.cat([x_pde, y_pde], dim=1)  # (N,2)
    u, uxx, uyy = manual_seconds(model, z0)  # (N,1) each

    # PDE residual: u_xx + u_yy = 0
    r = uxx + uyy
    loss_pde = (r**2).mean()

    # Boundary conditions
    loss_bc = bc_loss(model, Bx, By, Bu)

    loss = loss_pde + loss_bc
    loss.backward()
    opt.step()
    return float(loss.item())

def train_step_autograd(model: MLP, opt, x_pde, y_pde, Bx, By, Bu):
    model.train()
    opt.zero_grad(set_to_none=True)

    # Need grads on inputs for autograd.grad
    x = x_pde.clone().detach().requires_grad_(True)
    y = y_pde.clone().detach().requires_grad_(True)

    u, uxx, uyy = autograd_seconds(model, x, y)

    # PDE residual: u_xx + u_yy = 0
    r = uxx + uyy
    loss_pde = (r**2).mean()

    # Boundary conditions
    loss_bc = bc_loss(model, Bx, By, Bu)

    loss = loss_pde + loss_bc
    loss.backward()
    opt.step()
    return float(loss.item())

# ------------------------------
# Theoretical estimator (activations only)
# ------------------------------
@dataclass
class EstimationConfig:
    N: int
    n0: int
    hidden: List[int]
    k_coords: int = 2
    dtype_bytes: int = 4  # fp32

def estimate_activation_bytes_forward_mode(cfg: EstimationConfig, baseline: str = "upper",
                                           need_first: bool = True, need_pure_second: bool = True,
                                           include_mixed: bool = False) -> int:
    S = sum(cfg.hidden)
    n0 = cfg.n0
    N = cfg.N
    k = cfg.k_coords
    b = cfg.dtype_bytes

    if baseline == "lower":
        base_elts = N * (n0 + S)                 # Z0 + one of (L/Z) per layer
    else:
        base_elts = N * (n0 + 2*S)               # Z0 + both (L and Z) per layer

    extra = 0
    if need_first:
        extra += N * (2 * k * S)                 # H,G per layer per coord
    if need_pure_second:
        extra += N * (2 * k * S)                 # F,E per layer per coord
    if include_mixed and k >= 2:
        extra += N * (2 * (k*(k-1)//2) * S)      # K,J per layer per pair

    return (base_elts + extra) * b

def estimate_activation_bytes_autograd(cfg: EstimationConfig, baseline: str = "upper",
                                       need_first: bool = True, need_pure_second: bool = True) -> Tuple[int,int]:
    """
    Return (low_bytes, up_bytes) for autograd.grad.
    Low ~ 1 activation per layer per coord per grad graph; Up ~ 2 (like H/G pair).
    """
    S = sum(cfg.hidden)
    n0 = cfg.n0
    N = cfg.N
    k = cfg.k_coords
    b = cfg.dtype_bytes

    if baseline == "lower":
        base_elts = N * (n0 + S)
    else:
        base_elts = N * (n0 + 2*S)

    low_extra = 0
    up_extra  = 0
    if need_first:
        low_extra += N * (1 * k * S)
        up_extra  += N * (2 * k * S)
    if need_pure_second:
        low_extra += N * (1 * k * S)
        up_extra  += N * (2 * k * S)
    # (Mixed terms would add more; Laplace uses only pure seconds.)

    return (base_elts + low_extra) * b, (base_elts + up_extra) * b

# ------------------------------
# Main demo
# ------------------------------
def main():
    set_seed(SEED)
    dev = device()
    print(f"Device: {dev.type.upper()}")

    # Data
    x_pde, y_pde = sample_interior(N_PDE, dev)
    Bx, By, Bu = sample_boundaries(N_BC_EDGE, dev)

    # Theoretical estimates (activations/graph stash only)
    cfg = EstimationConfig(N=N_PDE, n0=2, hidden=HIDDEN, k_coords=2, dtype_bytes=4)
    fwd_bytes = estimate_activation_bytes_forward_mode(cfg, baseline="upper",
                                                       need_first=True, need_pure_second=True, include_mixed=False)
    ag_low, ag_up = estimate_activation_bytes_autograd(cfg, baseline="upper",
                                                       need_first=True, need_pure_second=True)
    print("\n[Theoretical graph/activation stash (fp32)]")
    print(f"Forward-mode (manual, with u_xx & u_yy): ~{bytes_to_mb(fwd_bytes):.3f} MB")
    print(f"autograd.grad  (low..up bounds):        ~{bytes_to_mb(ag_low):.3f} .. {bytes_to_mb(ag_up):.3f} MB")

    # Empirical peak memory per method (one training step)
    # Fresh model/opt each to avoid cross-graph interference
    model_A = MLP(2, HIDDEN, 1).to(dev).to(DTYPE)
    opt_A = torch.optim.Adam(model_A.parameters(), lr=LR)

    with PeakMem("Manual forward-mode (u_xx, u_yy)"):
        lossA = train_step_manual(model_A, opt_A, x_pde, y_pde, Bx, By, Bu)

    model_B = MLP(2, HIDDEN, 1).to(dev).to(DTYPE)
    opt_B = torch.optim.Adam(model_B.parameters(), lr=LR)

    with PeakMem("autograd.grad (u_xx, u_yy)"):
        lossB = train_step_autograd(model_B, opt_B, x_pde, y_pde, Bx, By, Bu)

    print(f"\nLoss (manual):   {lossA:.6f}")
    print(f"Loss (autograd): {lossB:.6f}")
    print("\nNote: CUDA numbers are peak *allocated* VRAM in the step,")
    print("      dominated by 'saved for backward' activations + grad graphs.\n")

if __name__ == "__main__":
    main()

Device: CUDA

[Theoretical graph/activation stash (fp32)]
Forward-mode (manual, with u_xx & u_yy): ~0.230 MB
autograd.grad  (low..up bounds):        ~0.138 .. 0.230 MB
[Manual forward-mode (u_xx, u_yy)] Peak CUDA allocated: 16.816 MB
[autograd.grad (u_xx, u_yy)] Peak CUDA allocated: 16.790 MB

Loss (manual):   0.866808
Loss (autograd): 0.619902

Note: CUDA numbers are peak *allocated* VRAM in the step,
      dominated by 'saved for backward' activations + grad graphs.

