In [3]:
# ========================= EXAMPLE 1: P4 (Cahn–Hilliard) =========================
# u(x,y) = (1/(2π²)) sin(πx) sin(πy)
# Domain: Ω = (0,1)²
# PDE: Δ²u = f, with P4 boundary conditions:
#       ∂u/∂n = g1,  ∂(Δu)/∂n = g2 on ∂Ω
# Network: [2, 84, 84, 84, 84, 1]
# Training: 20000 epochs, 20000 interior pts, 6000 boundary pts
# Uses Adam + ReduceLROnPlateau LR scheduler.

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
import os
import time

# ------------------------------------------------------------------------------
# Setup
# ------------------------------------------------------------------------------

torch.manual_seed(42)
np.random.seed(42)

results_dir = "pinn_biharmonic_results_pytorch_corrected"
os.makedirs(results_dir, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ------------------------------------------------------------------------------
# Neural Network
# ------------------------------------------------------------------------------

class BiharmonicPINN(nn.Module):
    def __init__(self, layers, activation=nn.Tanh()):
        super().__init__()
        self.layers = layers
        self.activation = activation

        self.linears = nn.ModuleList()
        for i in range(len(layers) - 1):
            self.linears.append(nn.Linear(layers[i], layers[i + 1]))
        self.init_weights()

    def init_weights(self):
        for linear in self.linears:
            nn.init.xavier_normal_(linear.weight)
            nn.init.constant_(linear.bias, 0.0)

    def forward(self, x):
        for i, linear in enumerate(self.linears):
            x = linear(x)
            if i < len(self.linears) - 1:
                x = self.activation(x)
        return x

# ------------------------------------------------------------------------------
# Differential operators (up to 4th order)
# ------------------------------------------------------------------------------

def compute_derivatives(u, x, y):
    """
    Compute derivatives up to 4th order using autograd.

    Returns:
        u, u_x, u_y, u_xx, u_yy, u_xy,
        u_xxx, u_xxy, u_xyy, u_yyy,
        u_xxxx, u_xxyy, u_yyyy
    """
    # first derivatives
    u_x = torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u),
                              create_graph=True, retain_graph=True)[0]
    u_y = torch.autograd.grad(u, y, grad_outputs=torch.ones_like(u),
                              create_graph=True, retain_graph=True)[0]

    # second derivatives
    u_xx = torch.autograd.grad(u_x, x, grad_outputs=torch.ones_like(u_x),
                               create_graph=True, retain_graph=True)[0]
    u_yy = torch.autograd.grad(u_y, y, grad_outputs=torch.ones_like(u_y),
                               create_graph=True, retain_graph=True)[0]
    u_xy = torch.autograd.grad(u_x, y, grad_outputs=torch.ones_like(u_x),
                               create_graph=True, retain_graph=True)[0]

    # third derivatives
    u_xxx = torch.autograd.grad(u_xx, x, grad_outputs=torch.ones_like(u_xx),
                                create_graph=True, retain_graph=True)[0]
    u_xxy = torch.autograd.grad(u_xx, y, grad_outputs=torch.ones_like(u_xx),
                                create_graph=True, retain_graph=True)[0]
    u_xyy = torch.autograd.grad(u_xy, y, grad_outputs=torch.ones_like(u_xy),
                                create_graph=True, retain_graph=True)[0]
    u_yyy = torch.autograd.grad(u_yy, y, grad_outputs=torch.ones_like(u_yy),
                                create_graph=True, retain_graph=True)[0]

    # fourth derivatives
    u_xxxx = torch.autograd.grad(u_xxx, x, grad_outputs=torch.ones_like(u_xxx),
                                 create_graph=True, retain_graph=True)[0]
    u_xxyy = torch.autograd.grad(u_xxy, y, grad_outputs=torch.ones_like(u_xxy),
                                 create_graph=True, retain_graph=True)[0]
    u_yyyy = torch.autograd.grad(u_yyy, y, grad_outputs=torch.ones_like(u_yyy),
                                 create_graph=True, retain_graph=True)[0]

    return (u, u_x, u_y, u_xx, u_yy, u_xy,
            u_xxx, u_xxy, u_xyy, u_yyy,
            u_xxxx, u_xxyy, u_yyyy)

def compute_biharmonic(u_xxxx, u_xxyy, u_yyyy):
    """Δ²u = u_xxxx + 2 u_xxyy + u_yyyy"""
    return u_xxxx + 2.0 * u_xxyy + u_yyyy

def compute_normals(x, y):
    """
    Outward unit normal on ∂Ω for Ω = (0,1)².

        left   (x=0): n = (-1, 0)
        right  (x=1): n = ( 1, 0)
        bottom (y=0): n = ( 0,-1)
        top    (y=1): n = ( 0, 1)
    """
    n_x = torch.zeros_like(x)
    n_y = torch.zeros_like(y)
    eps = 1e-6

    left   = x <= eps
    right  = x >= 1.0 - eps
    bottom = y <= eps
    top    = y >= 1.0 - eps

    n_x[left] = -1.0
    n_x[right] = 1.0
    n_y[bottom] = -1.0
    n_y[top] = 1.0

    return n_x, n_y

def compute_normal_derivatives(x, y,
                               u_x, u_y,
                               u_xx, u_yy, u_xy,
                               u_xxx, u_xxy, u_xyy, u_yyy):
    """
    P4 boundary operators:
        ∂u/∂n      = n · ∇u
        ∂(Δu)/∂n   = n · ∇(Δu)
    """
    n_x, n_y = compute_normals(x, y)

    # ∂u/∂n
    u_n = n_x * u_x + n_y * u_y

    # Δu = u_xx + u_yy
    # ∂(Δu)/∂x = u_xxx + u_xyy
    # ∂(Δu)/∂y = u_xxy + u_yyy
    lap_x = u_xxx + u_xyy
    lap_y = u_xxy + u_yyy
    lap_n = n_x * lap_x + n_y * lap_y

    return u_n, lap_n

def compute_errors(u_pred, u_exact, x, y):
    """
    L2 and "energy" (H1-like) errors:
        L2  = ||uθ - u||_L2
        H1  ≈ L2(uθ - u) + L2(∇uθ - ∇u)
    plus relative versions.
    """
    diff = u_pred - u_exact
    l2_error = torch.sqrt(torch.mean(diff ** 2))

    # gradients
    u_pred_x = torch.autograd.grad(u_pred, x, grad_outputs=torch.ones_like(u_pred),
                                   create_graph=True, retain_graph=True)[0]
    u_pred_y = torch.autograd.grad(u_pred, y, grad_outputs=torch.ones_like(u_pred),
                                   create_graph=True, retain_graph=True)[0]
    u_exact_x = torch.autograd.grad(u_exact, x, grad_outputs=torch.ones_like(u_exact),
                                    create_graph=True, retain_graph=True)[0]
    u_exact_y = torch.autograd.grad(u_exact, y, grad_outputs=torch.ones_like(u_exact),
                                    create_graph=True, retain_graph=True)[0]

    grad_diff_sq = (u_pred_x - u_exact_x) ** 2 + (u_pred_y - u_exact_y) ** 2
    grad_error = torch.sqrt(torch.mean(grad_diff_sq))

    energy_error = l2_error + grad_error

    # exact norms
    l2_norm_exact = torch.sqrt(torch.mean(u_exact ** 2))
    grad_exact_sq = u_exact_x ** 2 + u_exact_y ** 2
    grad_norm_exact = torch.sqrt(torch.mean(grad_exact_sq))
    energy_norm_exact = l2_norm_exact + grad_norm_exact

    l2_rel = l2_error / (l2_norm_exact + 1e-12)
    energy_rel = energy_error / (energy_norm_exact + 1e-12)

    return l2_error, energy_error, l2_rel, energy_rel

# ------------------------------------------------------------------------------
# Example 1: exact solution and source term
# ------------------------------------------------------------------------------

def exact_solution1(x, y):
    return (1.0 / (2.0 * np.pi ** 2)) * torch.sin(np.pi * x) * torch.sin(np.pi * y)

def source_term1(x, y):
    # Derived from u: Δ²u = 2π² sin(πx) sin(πy)
    return (2.0 * np.pi ** 2) * torch.sin(np.pi * x) * torch.sin(np.pi * y)

# ------------------------------------------------------------------------------
# Training data
# ------------------------------------------------------------------------------

N_int = 20000   # interior points
N_bc  = 6000    # boundary points

# interior in (0,1)²
x_int = torch.rand((N_int, 1), dtype=torch.float32, requires_grad=True).to(device)
y_int = torch.rand((N_int, 1), dtype=torch.float32, requires_grad=True).to(device)

# boundary points – 4 sides
N_bc_side = N_bc // 4

# bottom: y = 0
x_bc_bottom = torch.rand((N_bc_side, 1), dtype=torch.float32, requires_grad=True).to(device)
y_bc_bottom = torch.zeros_like(x_bc_bottom, requires_grad=True)

# top: y = 1
x_bc_top = torch.rand((N_bc_side, 1), dtype=torch.float32, requires_grad=True).to(device)
y_bc_top = torch.ones_like(x_bc_top, requires_grad=True)

# left: x = 0
x_bc_left = torch.zeros((N_bc_side, 1), dtype=torch.float32, requires_grad=True).to(device)
y_bc_left = torch.rand((N_bc_side, 1), dtype=torch.float32, requires_grad=True).to(device)

# right: x = 1
x_bc_right = torch.ones((N_bc_side, 1), dtype=torch.float32, requires_grad=True).to(device)
y_bc_right = torch.rand((N_bc_side, 1), dtype=torch.float32, requires_grad=True).to(device)

x_bc = torch.cat([x_bc_bottom, x_bc_top, x_bc_left, x_bc_right], dim=0)
y_bc = torch.cat([y_bc_bottom, y_bc_top, y_bc_left, y_bc_right], dim=0)

# ------------------------------------------------------------------------------
# Model, optimizer, scheduler, and test points
# ------------------------------------------------------------------------------

layers = [2, 84, 84, 84, 84, 1]
pinn = BiharmonicPINN(layers).to(device)

optimizer = torch.optim.Adam(pinn.parameters(), lr=1e-3)

# LR scheduler: reduce LR by factor 0.5 when loss plateaus
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=500)

epochs = 20000
print_interval = 100

# histories (one entry every 100 epochs)
loss_history = []
int_loss_history = []
bc_loss_history = []
l2_error_history = []
energy_error_history = []

# fixed test points for monitoring errors
x_test = torch.rand((1000, 1), dtype=torch.float32, requires_grad=True).to(device)
y_test = torch.rand((1000, 1), dtype=torch.float32, requires_grad=True).to(device)

def train_step():
    optimizer.zero_grad()

    # ---- interior (PDE: Δ²u = f) ----
    X_int = torch.cat([x_int, y_int], dim=1)
    u_int = pinn(X_int)

    (u_int,
     u_x_int, u_y_int,
     u_xx_int, u_yy_int, u_xy_int,
     u_xxx_int, u_xxy_int, u_xyy_int, u_yyy_int,
     u_xxxx_int, u_xxyy_int, u_yyyy_int) = compute_derivatives(u_int, x_int, y_int)

    bih_int = compute_biharmonic(u_xxxx_int, u_xxyy_int, u_yyyy_int)
    f_int = source_term1(x_int, y_int)

    loss_int = torch.mean((bih_int - f_int) ** 2)

    # ---- boundary (P4: ∂u/∂n = g1, ∂Δu/∂n = g2) ----
    X_bc = torch.cat([x_bc, y_bc], dim=1)
    u_bc = pinn(X_bc)

    (u_bc,
     u_x_bc, u_y_bc,
     u_xx_bc, u_yy_bc, u_xy_bc,
     u_xxx_bc, u_xxy_bc, u_xyy_bc, u_yyy_bc,
     _, _, _) = compute_derivatives(u_bc, x_bc, y_bc)

    # predicted BC operators
    u_n_bc, lap_n_bc = compute_normal_derivatives(
        x_bc, y_bc,
        u_x_bc, u_y_bc,
        u_xx_bc, u_yy_bc, u_xy_bc,
        u_xxx_bc, u_xxy_bc, u_xyy_bc, u_yyy_bc
    )

    # exact BC from exact u
    u_exact_bc = exact_solution1(x_bc, y_bc)

    u_exact_x_bc = torch.autograd.grad(u_exact_bc, x_bc,
                                       grad_outputs=torch.ones_like(u_exact_bc),
                                       create_graph=True, retain_graph=True)[0]
    u_exact_y_bc = torch.autograd.grad(u_exact_bc, y_bc,
                                       grad_outputs=torch.ones_like(u_exact_bc),
                                       create_graph=True, retain_graph=True)[0]

    n_x_bc, n_y_bc = compute_normals(x_bc, y_bc)
    u_n_exact_bc = n_x_bc * u_exact_x_bc + n_y_bc * u_exact_y_bc

    u_exact_xx_bc = torch.autograd.grad(u_exact_x_bc, x_bc,
                                        grad_outputs=torch.ones_like(u_exact_x_bc),
                                        create_graph=True, retain_graph=True)[0]
    u_exact_yy_bc = torch.autograd.grad(u_exact_y_bc, y_bc,
                                        grad_outputs=torch.ones_like(u_exact_y_bc),
                                        create_graph=True, retain_graph=True)[0]
    lap_exact_bc = u_exact_xx_bc + u_exact_yy_bc

    lap_x_exact = torch.autograd.grad(lap_exact_bc, x_bc,
                                      grad_outputs=torch.ones_like(lap_exact_bc),
                                      create_graph=True, retain_graph=True)[0]
    lap_y_exact = torch.autograd.grad(lap_exact_bc, y_bc,
                                      grad_outputs=torch.ones_like(lap_exact_bc),
                                      create_graph=True, retain_graph=True)[0]

    lap_n_exact_bc = n_x_bc * lap_x_exact + n_y_bc * lap_y_exact

    loss_bc = torch.mean((u_n_bc - u_n_exact_bc) ** 2) + \
              torch.mean((lap_n_bc - lap_n_exact_bc) ** 2)

    # total loss
    total_loss = loss_int + loss_bc
    total_loss.backward()
    optimizer.step()

    return total_loss, loss_int, loss_bc

print("=" * 80)
print("Training Example 1: u = (1/(2π²)) sin(πx) sin(πy) with P4 BCs")
print("=" * 80)

start_time = time.time()

for epoch in range(epochs):
    total_loss, loss_int, loss_bc = train_step()

    # Learning rate scheduler (use scalar loss)
    scheduler.step(total_loss.item())

    if epoch % print_interval == 0:
        # compute monitoring errors (need gradients -> no torch.no_grad())
        x_eval = x_test
        y_eval = y_test
        X_eval = torch.cat([x_eval, y_eval], dim=1)

        u_pred = pinn(X_eval)
        u_exact = exact_solution1(x_eval, y_eval)

        l2_err, en_err, l2_rel, en_rel = compute_errors(
            u_pred, u_exact, x_eval, y_eval
        )

        loss_history.append(total_loss.item())
        int_loss_history.append(loss_int.item())
        bc_loss_history.append(loss_bc.item())
        l2_error_history.append(l2_err.item())
        energy_error_history.append(en_err.item())

        current_lr = optimizer.param_groups[0]["lr"]

        print(f"Epoch {epoch:5d} | "
              f"Total: {total_loss.item():.3e} | "
              f"Int: {loss_int.item():.3e} | "
              f"BC: {loss_bc.item():.3e} | "
              f"L2: {l2_err.item():.3e} | "
              f"Energy: {en_err.item():.3e} | "
              f"LR: {current_lr:.2e}")

training_time = time.time() - start_time
print(f"\nTraining time (Example 1): {training_time:.2f} s")

# ------------------------------------------------------------------------------
# Post-processing: plots and model save
# ------------------------------------------------------------------------------

# grid for visualization (no gradients needed here)
x_plot = np.linspace(0, 1, 100)
y_plot = np.linspace(0, 1, 100)
X_plot, Y_plot = np.meshgrid(x_plot, y_plot)
X_flat = X_plot.flatten().reshape(-1, 1)
Y_flat = Y_plot.flatten().reshape(-1, 1)
XY_plot = torch.tensor(np.hstack([X_flat, Y_flat]), dtype=torch.float32).to(device)

with torch.no_grad():
    u_pred_plot = pinn(XY_plot).cpu().numpy().reshape(100, 100)

u_exact_plot = exact_solution1(
    torch.tensor(X_flat, dtype=torch.float32).to(device),
    torch.tensor(Y_flat, dtype=torch.float32).to(device)
).detach().cpu().numpy().reshape(100, 100)

error_plot = np.abs(u_pred_plot - u_exact_plot)

# 1) loss history
plt.figure(figsize=(10, 6))
plt.semilogy(loss_history, label="Total Loss", linewidth=2)
plt.semilogy(int_loss_history, label="Interior Loss", linewidth=2)
plt.semilogy(bc_loss_history, label="Boundary Loss", linewidth=2)
plt.xlabel("Index (every 100 epochs)")
plt.ylabel("Loss")
plt.title("Example 1: Training Loss History (P4)")
plt.legend()
plt.grid(alpha=0.3)
plt.savefig(os.path.join(results_dir, "example1_loss_history.png"), dpi=300, bbox_inches="tight")
plt.close()

# 2) error history
plt.figure(figsize=(10, 6))
plt.semilogy(l2_error_history, label="L2 Error", linewidth=2)
plt.semilogy(energy_error_history, label="Energy Error", linewidth=2)
plt.xlabel("Index (every 100 epochs)")
plt.ylabel("Error")
plt.title("Example 1: Error History (P4)")
plt.legend()
plt.grid(alpha=0.3)
plt.savefig(os.path.join(results_dir, "example1_error_history.png"), dpi=300, bbox_inches="tight")
plt.close()

# 3) predicted solution
plt.figure(figsize=(8, 6))
contour = plt.contourf(X_plot, Y_plot, u_pred_plot, levels=50, cmap="viridis")
plt.colorbar(contour)
plt.xlabel("x")
plt.ylabel("y")
plt.title("Example 1: PINN Predicted Solution (P4)")
plt.savefig(os.path.join(results_dir, "example1_predicted_solution.png"), dpi=300, bbox_inches="tight")
plt.close()

# 4) exact solution
plt.figure(figsize=(8, 6))
contour = plt.contourf(X_plot, Y_plot, u_exact_plot, levels=50, cmap="viridis")
plt.colorbar(contour)
plt.xlabel("x")
plt.ylabel("y")
plt.title("Example 1: Exact Solution")
plt.savefig(os.path.join(results_dir, "example1_exact_solution.png"), dpi=300, bbox_inches="tight")
plt.close()

# 5) absolute error
plt.figure(figsize=(8, 6))
contour = plt.contourf(X_plot, Y_plot, error_plot, levels=50, cmap="hot")
plt.colorbar(contour)
plt.xlabel("x")
plt.ylabel("y")
plt.title("Example 1: Absolute Error |u - uθ|")
plt.savefig(os.path.join(results_dir, "example1_absolute_error.png"), dpi=300, bbox_inches="tight")
plt.close()

# save model
torch.save(
    {
        "model_state_dict": pinn.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "training_time": training_time,
    },
    os.path.join(results_dir, "example1_model.pth"),
)

# ------------------------------------------------------------------------------
# Final error evaluation (NEEDS gradients -> no torch.no_grad)
# ------------------------------------------------------------------------------

x_eval = torch.rand((2000, 1), dtype=torch.float32, requires_grad=True).to(device)
y_eval = torch.rand((2000, 1), dtype=torch.float32, requires_grad=True).to(device)
X_eval = torch.cat([x_eval, y_eval], dim=1)

u_pred_final = pinn(X_eval)
u_exact_final = exact_solution1(x_eval, y_eval)

l2_error_final, energy_error_final, l2_rel_final, energy_rel_final = compute_errors(
    u_pred_final, u_exact_final, x_eval, y_eval
)

print("\n" + "=" * 60)
print("FINAL RESULTS - EXAMPLE 1 (P4)")
print("=" * 60)
print(f"L2 Error           : {l2_error_final.item():.6e}")
print(f"Energy Error       : {energy_error_final.item():.6e}")
print(f"Relative L2 Error  : {l2_rel_final.item():.6e}")
print(f"Relative H1 Error  : {energy_rel_final.item():.6e}")
print(f"All plots & model saved in: {results_dir}")
print("=" * 60)


Using device: cuda
Training Example 1: u = (1/(2π²)) sin(πx) sin(πy) with P4 BCs
Epoch     0 | Total: 9.450e+01 | Int: 8.995e+01 | BC: 4.551e+00 | L2: 4.322e-01 | Energy: 9.489e-01 | LR: 1.00e-03
Epoch   100 | Total: 1.808e+00 | Int: 1.213e+00 | BC: 5.951e-01 | L2: 2.948e-01 | Energy: 9.516e-01 | LR: 1.00e-03
Epoch   200 | Total: 5.658e-01 | Int: 3.186e-01 | BC: 2.472e-01 | L2: 5.184e-01 | Energy: 6.908e-01 | LR: 1.00e-03
Epoch   300 | Total: 1.233e-01 | Int: 8.710e-02 | BC: 3.620e-02 | L2: 1.549e-01 | Energy: 2.211e-01 | LR: 1.00e-03
Epoch   400 | Total: 8.053e-02 | Int: 6.377e-02 | BC: 1.676e-02 | L2: 4.491e-01 | Energy: 4.959e-01 | LR: 1.00e-03
Epoch   500 | Total: 5.749e-02 | Int: 4.644e-02 | BC: 1.104e-02 | L2: 6.275e-01 | Energy: 6.584e-01 | LR: 1.00e-03
Epoch   600 | Total: 3.724e-02 | Int: 3.010e-02 | BC: 7.138e-03 | L2: 7.535e-01 | Energy: 7.820e-01 | LR: 1.00e-03
Epoch   700 | Total: 2.396e-02 | Int: 1.734e-02 | BC: 6.616e-03 | L2: 8.283e-01 | Energy: 8.619e-01 | LR: 1.00e-03

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
import os
import time

# ------------------------------------------------------------------------------
# Setup
# ------------------------------------------------------------------------------

torch.manual_seed(42)
np.random.seed(42)

results_dir = "pinn_biharmonic_results_pytorch_corrected"
os.makedirs(results_dir, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ------------------------------------------------------------------------------
# Neural Network
# ------------------------------------------------------------------------------

class BiharmonicPINN(nn.Module):
    # --- CORRECTED: Switched to SiLU for more stable high-order gradients ---
    def __init__(self, layers, activation=nn.SiLU()):
    # --- END CORRECTION ---
        super().__init__()
        self.layers = layers
        self.activation = activation

        self.linears = nn.ModuleList()
        for i in range(len(layers) - 1):
            self.linears.append(nn.Linear(layers[i], layers[i + 1]))
        self.init_weights()

    def init_weights(self):
        for linear in self.linears:
            nn.init.xavier_normal_(linear.weight)
            nn.init.constant_(linear.bias, 0.0)

    def forward(self, x):
        for i, linear in enumerate(self.linears):
            x = linear(x)
            if i < len(self.linears) - 1:
                x = self.activation(x)
        return x

# ------------------------------------------------------------------------------
# Differential operators (up to 4th order)
# ------------------------------------------------------------------------------

def compute_derivatives(u, x, y):
    """
    Compute derivatives up to 4th order using autograd.
    """
    # first derivatives
    u_x = torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u),
                             create_graph=True, retain_graph=True)[0]
    u_y = torch.autograd.grad(u, y, grad_outputs=torch.ones_like(u),
                             create_graph=True, retain_graph=True)[0]

    # second derivatives
    u_xx = torch.autograd.grad(u_x, x, grad_outputs=torch.ones_like(u_x),
                              create_graph=True, retain_graph=True)[0]
    u_yy = torch.autograd.grad(u_y, y, grad_outputs=torch.ones_like(u_y),
                              create_graph=True, retain_graph=True)[0]
    u_xy = torch.autograd.grad(u_x, y, grad_outputs=torch.ones_like(u_x),
                              create_graph=True, retain_graph=True)[0]

    # third derivatives
    u_xxx = torch.autograd.grad(u_xx, x, grad_outputs=torch.ones_like(u_xx),
                               create_graph=True, retain_graph=True)[0]
    u_xxy = torch.autograd.grad(u_xx, y, grad_outputs=torch.ones_like(u_xx),
                               create_graph=True, retain_graph=True)[0]
    u_xyy = torch.autograd.grad(u_xy, y, grad_outputs=torch.ones_like(u_xy),
                               create_graph=True, retain_graph=True)[0]
    u_yyy = torch.autograd.grad(u_yy, y, grad_outputs=torch.ones_like(u_yy),
                               create_graph=True, retain_graph=True)[0]

    # fourth derivatives
    u_xxxx = torch.autograd.grad(u_xxx, x, grad_outputs=torch.ones_like(u_xxx),
                                  create_graph=True, retain_graph=True)[0]
    u_xxyy = torch.autograd.grad(u_xxy, y, grad_outputs=torch.ones_like(u_xxy),
                                  create_graph=True, retain_graph=True)[0]
    u_yyyy = torch.autograd.grad(u_yyy, y, grad_outputs=torch.ones_like(u_yyy),
                                  create_graph=True, retain_graph=True)[0]

    return (u, u_x, u_y, u_xx, u_yy, u_xy,
            u_xxx, u_xxy, u_xyy, u_yyy,
            u_xxxx, u_xxyy, u_yyyy)

def compute_biharmonic(u_xxxx, u_xxyy, u_yyyy):
    """Δ²u = u_xxxx + 2 u_xxyy + u_yyyy"""
    return u_xxxx + 2.0 * u_xxyy + u_yyyy

def compute_normals(x, y):
    """
    Outward unit normal on ∂Ω for Ω = (0,1)².
    """
    n_x = torch.zeros_like(x)
    n_y = torch.zeros_like(y)
    eps = 1e-6

    left   = x <= eps
    right  = x >= 1.0 - eps
    bottom = y <= eps
    top    = y >= 1.0 - eps

    n_x[left] = -1.0
    n_x[right] = 1.0
    n_y[bottom] = -1.0
    n_y[top] = 1.0

    return n_x, n_y

def compute_normal_derivatives(x, y,
                               u_x, u_y,
                               u_xx, u_yy, u_xy,
                               u_xxx, u_xxy, u_xyy, u_yyy):
    """
    P4 boundary operators:
        ∂u/∂n      = n · ∇u
        ∂(Δu)/∂n   = n · ∇(Δu)
    """
    n_x, n_y = compute_normals(x, y)

    # ∂u/∂n
    u_n = n_x * u_x + n_y * u_y

    # ∂(Δu)/∂n
    lap_x = u_xxx + u_xyy
    lap_y = u_xxy + u_yyy
    lap_n = n_x * lap_x + n_y * lap_y

    return u_n, lap_n

def compute_errors(u_pred, u_exact, x, y):
    """
    L2 and "energy" (H1-like) errors.
    """
    diff = u_pred - u_exact
    l2_error = torch.sqrt(torch.mean(diff ** 2))

    # gradients
    u_pred_x = torch.autograd.grad(u_pred, x, grad_outputs=torch.ones_like(u_pred),
                                  create_graph=True, retain_graph=True)[0]
    u_pred_y = torch.autograd.grad(u_pred, y, grad_outputs=torch.ones_like(u_pred),
                                  create_graph=True, retain_graph=True)[0]
    u_exact_x = torch.autograd.grad(u_exact, x, grad_outputs=torch.ones_like(u_exact),
                                   create_graph=True, retain_graph=True)[0]
    u_exact_y = torch.autograd.grad(u_exact, y, grad_outputs=torch.ones_like(u_exact),
                                   create_graph=True, retain_graph=True)[0]

    grad_diff_sq = (u_pred_x - u_exact_x) ** 2 + (u_pred_y - u_exact_y) ** 2
    grad_error = torch.sqrt(torch.mean(grad_diff_sq))

    energy_error = l2_error + grad_error

    # exact norms
    l2_norm_exact = torch.sqrt(torch.mean(u_exact ** 2))
    grad_exact_sq = u_exact_x ** 2 + u_exact_y ** 2
    grad_norm_exact = torch.sqrt(torch.mean(grad_exact_sq))
    energy_norm_exact = l2_norm_exact + grad_norm_exact

    l2_rel = l2_error / (l2_norm_exact + 1e-12)
    energy_rel = energy_error / (energy_norm_exact + 1e-12)

    return l2_error, energy_error, l2_rel, energy_rel

# ------------------------------------------------------------------------------
# Example 1: exact solution and source term
# ------------------------------------------------------------------------------

def exact_solution1(x, y):
    return (1.0 / (2.0 * np.pi ** 2)) * torch.sin(np.pi * x) * torch.sin(np.pi * y)

def source_term1(x, y):
    # Derived from u: Δ²u = 2π² sin(πx) sin(πy)
    return (2.0 * np.pi ** 2) * torch.sin(np.pi * x) * torch.sin(np.pi * y)

# ------------------------------------------------------------------------------
# Training data
# ------------------------------------------------------------------------------

N_int = 20000  # interior points
N_bc  = 6000   # boundary points

# interior in (0,1)²
x_int = torch.rand((N_int, 1), dtype=torch.float32, requires_grad=True).to(device)
y_int = torch.rand((N_int, 1), dtype=torch.float32, requires_grad=True).to(device)

# boundary points – 4 sides
N_bc_side = N_bc // 4

# bottom: y = 0
x_bc_bottom = torch.rand((N_bc_side, 1), dtype=torch.float32, requires_grad=True).to(device)
y_bc_bottom = torch.zeros_like(x_bc_bottom, requires_grad=True)

# top: y = 1
x_bc_top = torch.rand((N_bc_side, 1), dtype=torch.float32, requires_grad=True).to(device)
y_bc_top = torch.ones_like(x_bc_top, requires_grad=True)

# left: x = 0
x_bc_left = torch.zeros((N_bc_side, 1), dtype=torch.float32, requires_grad=True).to(device)
y_bc_left = torch.rand((N_bc_side, 1), dtype=torch.float32, requires_grad=True).to(device)

# right: x = 1
x_bc_right = torch.ones((N_bc_side, 1), dtype=torch.float32, requires_grad=True).to(device)
y_bc_right = torch.rand((N_bc_side, 1), dtype=torch.float32, requires_grad=True).to(device)

x_bc = torch.cat([x_bc_bottom, x_bc_top, x_bc_left, x_bc_right], dim=0)
y_bc = torch.cat([y_bc_bottom, y_bc_top, y_bc_left, y_bc_right], dim=0)

# ------------------------------------------------------------------------------
# Model, optimizer, scheduler, and test points
# ------------------------------------------------------------------------------

layers = [2, 84, 84, 84, 84, 1]
pinn = BiharmonicPINN(layers).to(device)

optimizer = torch.optim.Adam(pinn.parameters(), lr=1e-3)

# LR scheduler: reduce LR by factor 0.5 when loss plateaus
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=500)

epochs = 20000
print_interval = 100

# histories (one entry every 100 epochs)
loss_history = []
int_loss_history = []
bc_loss_history = []
l2_error_history = []
energy_error_history = []

# fixed test points for monitoring errors
x_test = torch.rand((1000, 1), dtype=torch.float32, requires_grad=True).to(device)
y_test = torch.rand((1000, 1), dtype=torch.float32, requires_grad=True).to(device)

def train_step():
    optimizer.zero_grad()

    # ---- interior (PDE: Δ²u = f) ----
    X_int = torch.cat([x_int, y_int], dim=1)
    u_int = pinn(X_int)

    (u_int,
     u_x_int, u_y_int,
     u_xx_int, u_yy_int, u_xy_int,
     u_xxx_int, u_xxy_int, u_xyy_int, u_yyy_int,
     u_xxxx_int, u_xxyy_int, u_yyyy_int) = compute_derivatives(u_int, x_int, y_int)

    bih_int = compute_biharmonic(u_xxxx_int, u_xxyy_int, u_yyyy_int)
    f_int = source_term1(x_int, y_int)

    loss_int = torch.mean((bih_int - f_int) ** 2)

    # ---- boundary (P4: ∂u/∂n = g1, ∂Δu/∂n = g2) ----
    X_bc = torch.cat([x_bc, y_bc], dim=1)
    u_bc = pinn(X_bc)

    (u_bc,
     u_x_bc, u_y_bc,
     u_xx_bc, u_yy_bc, u_xy_bc,
     u_xxx_bc, u_xxy_bc, u_xyy_bc, u_yyy_bc,
     _, _, _) = compute_derivatives(u_bc, x_bc, y_bc)

    # predicted BC operators
    u_n_bc, lap_n_bc = compute_normal_derivatives(
        x_bc, y_bc,
        u_x_bc, u_y_bc,
        u_xx_bc, u_yy_bc, u_xy_bc,
        u_xxx_bc, u_xxy_bc, u_xyy_bc, u_yyy_bc
    )

    # exact BC from exact u
    u_exact_bc = exact_solution1(x_bc, y_bc)

    u_exact_x_bc = torch.autograd.grad(u_exact_bc, x_bc,
                                      grad_outputs=torch.ones_like(u_exact_bc),
                                      create_graph=True, retain_graph=True)[0]
    u_exact_y_bc = torch.autograd.grad(u_exact_bc, y_bc,
                                      grad_outputs=torch.ones_like(u_exact_bc),
                                      create_graph=True, retain_graph=True)[0]

    n_x_bc, n_y_bc = compute_normals(x_bc, y_bc)
    u_n_exact_bc = n_x_bc * u_exact_x_bc + n_y_bc * u_exact_y_bc

    u_exact_xx_bc = torch.autograd.grad(u_exact_x_bc, x_bc,
                                       grad_outputs=torch.ones_like(u_exact_x_bc),
                                       create_graph=True, retain_graph=True)[0]
    u_exact_yy_bc = torch.autograd.grad(u_exact_y_bc, y_bc,
                                       grad_outputs=torch.ones_like(u_exact_y_bc),
                                       create_graph=True, retain_graph=True)[0]
    lap_exact_bc = u_exact_xx_bc + u_exact_yy_bc

    lap_x_exact = torch.autograd.grad(lap_exact_bc, x_bc,
                                     grad_outputs=torch.ones_like(lap_exact_bc),
                                     create_graph=True, retain_graph=True)[0]
    lap_y_exact = torch.autograd.grad(lap_exact_bc, y_bc,
                                     grad_outputs=torch.ones_like(lap_exact_bc),
                                     create_graph=True, retain_graph=True)[0]

    lap_n_exact_bc = n_x_bc * lap_x_exact + n_y_bc * lap_y_exact

    loss_bc = torch.mean((u_n_bc - u_n_exact_bc) ** 2) + \
              torch.mean((lap_n_bc - lap_n_exact_bc) ** 2)

    # --- CORRECTED: Added lambda weighting to balance loss terms ---
    lambda_int = 1.0
    lambda_bc = 100.0  # Force optimizer to prioritize boundary
    total_loss = (lambda_int * loss_int) + (lambda_bc * loss_bc)
    # --- END CORRECTION ---
    
    total_loss.backward()
    optimizer.step()

    return total_loss, loss_int, loss_bc

print("=" * 80)
print("Training Example 1: u = (1/(2π²)) sin(πx) sin(πy) with P4 BCs")
print("=" * 80)

start_time = time.time()

for epoch in range(epochs):
    total_loss, loss_int, loss_bc = train_step()

    # Learning rate scheduler (use scalar loss)
    scheduler.step(total_loss.item())

    if epoch % print_interval == 0:
        # compute monitoring errors (need gradients -> no torch.no_grad())
        x_eval = x_test
        y_eval = y_test
        X_eval = torch.cat([x_eval, y_eval], dim=1)

        u_pred = pinn(X_eval)
        u_exact = exact_solution1(x_eval, y_eval)

        l2_err, en_err, l2_rel, en_rel = compute_errors(
            u_pred, u_exact, x_eval, y_eval
        )

        loss_history.append(total_loss.item())
        int_loss_history.append(loss_int.item())
        bc_loss_history.append(loss_bc.item())
        l2_error_history.append(l2_err.item())
        energy_error_history.append(en_err.item())

        current_lr = optimizer.param_groups[0]["lr"]

        print(f"Epoch {epoch:5d} | "
              f"Total: {total_loss.item():.3e} | "
              f"Int: {loss_int.item():.3e} | "
              f"BC: {loss_bc.item():.3e} | "
              f"L2: {l2_err.item():.3e} | "
              f"Energy: {en_err.item():.3e} | "
              f"LR: {current_lr:.2e}")

training_time = time.time() - start_time
print(f"\nTraining time (Example 1): {training_time:.2f} s")

# ------------------------------------------------------------------------------
# Post-processing: plots and model save
# ------------------------------------------------------------------------------

# grid for visualization (no gradients needed here)
x_plot = np.linspace(0, 1, 100)
y_plot = np.linspace(0, 1, 100)
X_plot, Y_plot = np.meshgrid(x_plot, y_plot)
X_flat = X_plot.flatten().reshape(-1, 1)
Y_flat = Y_plot.flatten().reshape(-1, 1)
XY_plot = torch.tensor(np.hstack([X_flat, Y_flat]), dtype=torch.float32).to(device)

with torch.no_grad():
    u_pred_plot = pinn(XY_plot).cpu().numpy().reshape(100, 100)

u_exact_plot = exact_solution1(
    torch.tensor(X_flat, dtype=torch.float32).to(device),
    torch.tensor(Y_flat, dtype=torch.float32).to(device)
).detach().cpu().numpy().reshape(100, 100)

error_plot = np.abs(u_pred_plot - u_exact_plot)

# 1) loss history
plt.figure(figsize=(10, 6))
plt.semilogy(loss_history, label="Total Loss", linewidth=2)
plt.semilogy(int_loss_history, label="Interior Loss", linewidth=2)
plt.semilogy(bc_loss_history, label="Boundary Loss", linewidth=2)
plt.xlabel("Index (every 100 epochs)")
plt.ylabel("Loss")
plt.title("Example 1: Training Loss History (P4)")
plt.legend()
plt.grid(alpha=0.3)
plt.savefig(os.path.join(results_dir, "example1_loss_history.png"), dpi=300, bbox_inches="tight")
plt.close()

# 2) error history
plt.figure(figsize=(10, 6))
plt.semilogy(l2_error_history, label="L2 Error", linewidth=2)
plt.semilogy(energy_error_history, label="Energy Error", linewidth=2)
plt.xlabel("Index (every 100 epochs)")
plt.ylabel("Error")
plt.title("Example 1: Error History (P4)")
plt.legend()
plt.grid(alpha=0.3)
plt.savefig(os.path.join(results_dir, "example1_error_history.png"), dpi=300, bbox_inches="tight")
plt.close()

# 3) predicted solution
plt.figure(figsize=(8, 6))
contour = plt.contourf(X_plot, Y_plot, u_pred_plot, levels=50, cmap="viridis")
plt.colorbar(contour)
plt.xlabel("x")
plt.ylabel("y")
plt.title("Example 1: PINN Predicted Solution (P4)")
plt.savefig(os.path.join(results_dir, "example1_predicted_solution.png"), dpi=300, bbox_inches="tight")
plt.close()

# 4) exact solution
plt.figure(figsize=(8, 6))
contour = plt.contourf(X_plot, Y_plot, u_exact_plot, levels=50, cmap="viridis")
plt.colorbar(contour)
plt.xlabel("x")
plt.ylabel("y")
plt.title("Example 1: Exact Solution")
plt.savefig(os.path.join(results_dir, "example1_exact_solution.png"), dpi=300, bbox_inches="tight")
plt.close()

# 5) absolute error
plt.figure(figsize=(8, 6))
contour = plt.contourf(X_plot, Y_plot, error_plot, levels=50, cmap="hot")
plt.colorbar(contour)
plt.xlabel("x")
plt.ylabel("y")
plt.title("Example 1: Absolute Error |u - uθ|")
plt.savefig(os.path.join(results_dir, "example1_absolute_error.png"), dpi=300, bbox_inches="tight")
plt.close()

# save model
torch.save(
    {
        "model_state_dict": pinn.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "training_time": training_time,
    },
    os.path.join(results_dir, "example1_model.pth"),
)

# ------------------------------------------------------------------------------
# Final error evaluation (NEEDS gradients -> no torch.no_grad)
# ------------------------------------------------------------------------------

x_eval = torch.rand((2000, 1), dtype=torch.float32, requires_grad=True).to(device)
y_eval = torch.rand((2000, 1), dtype=torch.float32, requires_grad=True).to(device)
X_eval = torch.cat([x_eval, y_eval], dim=1)

u_pred_final = pinn(X_eval)
u_exact_final = exact_solution1(x_eval, y_eval)

l2_error_final, energy_error_final, l2_rel_final, energy_rel_final = compute_errors(
    u_pred_final, u_exact_final, x_eval, y_eval
)

print("\n" + "=" * 60)
print("FINAL RESULTS - EXAMPLE 1 (P4)")
print("=" * 60)
print(f"L2 Error            : {l2_error_final.item():.6e}")
print(f"Energy Error        : {energy_error_final.item():.6e}")
print(f"Relative L2 Error   : {l2_rel_final.item():.6e}")
print(f"Relative H1 Error   : {energy_rel_final.item():.6e}")
print(f"All plots & model saved in: {results_dir}")
print("=" * 60)

Using device: cuda
Training Example 1: u = (1/(2π²)) sin(πx) sin(πy) with P4 BCs
Epoch     0 | Total: 5.899e+02 | Int: 9.749e+01 | BC: 4.924e+00 | L2: 1.701e-02 | Energy: 1.315e-01 | LR: 1.00e-03
Epoch   100 | Total: 5.437e+01 | Int: 2.273e+01 | BC: 3.163e-01 | L2: 3.864e-01 | Energy: 5.067e-01 | LR: 1.00e-03
Epoch   200 | Total: 2.055e+01 | Int: 8.719e+00 | BC: 1.184e-01 | L2: 2.574e-01 | Energy: 3.155e-01 | LR: 1.00e-03
Epoch   300 | Total: 1.476e+01 | Int: 7.224e+00 | BC: 7.531e-02 | L2: 1.940e-01 | Energy: 2.262e-01 | LR: 1.00e-03
Epoch   400 | Total: 8.733e+00 | Int: 5.153e+00 | BC: 3.579e-02 | L2: 1.105e-01 | Energy: 1.370e-01 | LR: 1.00e-03
Epoch   500 | Total: 4.212e-01 | Int: 2.438e-01 | BC: 1.774e-03 | L2: 2.557e-02 | Energy: 4.247e-02 | LR: 1.00e-03
