# ADS Thesis Adaptive PINNs


## Import Libraries, PINN and Plotting functions

In [None]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import time
from scipy.linalg import solve
from scipy.stats import norm
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from scipy.interpolate import interp1d

#### Define Test Cases

In [None]:
# Pick the GPU when PyTorch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# -------------------------
# Case 1: Single cosine wave
# -------------------------
def gen_testdata_1(n=1000):
    """Evenly spaced test points on [0, 1] with the exact case-1 solution.

    u(x) = 10 * cos(pi * (x - 0.5) / 2)

    Returns
    -------
    (x, u) : tuple of np.ndarray, each of shape (n, 1)
    """
    grid = np.linspace(0, 1, n).reshape(-1, 1)
    solution = 10 * np.cos(np.pi * (grid - 0.5) / 2)
    return grid, solution

def gen_bc_1(n=200):
    """Boundary samples for case 1.

    The first ``n // 2`` samples sit at x = 0 and the remaining ones at x = 1.
    Each is paired with the exact value u(x) = 10 * cos(pi * (x - 0.5) / 2).
    Returns (x, u), both of shape (n, 1).
    """
    n_left = n // 2
    x = np.concatenate([np.zeros(n_left), np.ones(n - n_left)])[:, None]
    return x, 10 * np.cos(np.pi * (x - 0.5) / 2)

def pde_residual_1(model, x_tensor):
    """Residual u'(x) - f(x) for case 1, with f(x) = -5*pi*sin(pi*(x - 0.5)/2).

    ``x_tensor`` is copied into a fresh gradient-tracking leaf, so the caller's
    tensor is left untouched. The derivative graph is kept (``create_graph=True``)
    so that the residual can be back-propagated into the model parameters.
    """
    pts = x_tensor.clone().detach().requires_grad_(True)
    pred = model(pts)
    (du_dx,) = torch.autograd.grad(pred, pts, torch.ones_like(pred), create_graph=True)
    forcing = -5 * np.pi * torch.sin(np.pi * (pts - 0.5) / 2)
    return du_dx - forcing


# -------------------------
# Case 2: Steep V-shape (smooth)
# -------------------------
def gen_testdata_2(n=1000):
    """Test grid and exact solution for case 2 (smoothed steep V-shape).

    u(x) = log(cosh(k (x - x0))) / k - C0, with k = 200 and x0 = 0.5.
    C0 = log(cosh(-k x0)) / k shifts the curve so that u(0) = 0.
    Returns (x, u), both of shape (n, 1).
    """
    steep, centre = 200.0, 0.5
    grid = np.linspace(0, 1, n)[:, None]
    shift = np.log(np.cosh(-steep * centre)) / steep
    profile = np.log(np.cosh(steep * (grid - centre))) / steep
    return grid, profile - shift

def gen_bc_2(n=200):
    """Boundary samples for case 2.

    The first ``n // 2`` samples are at x = 0 and the rest at x = 1, each paired
    with the exact case-2 value (the same formula as ``gen_testdata_2``).
    Returns (x, u), both of shape (n, 1).
    """
    steep, centre = 200.0, 0.5
    half = n // 2
    x = np.concatenate((np.zeros(half), np.ones(n - half))).reshape(-1, 1)
    shift = np.log(np.cosh(-steep * centre)) / steep
    return x, (np.log(np.cosh(steep * (x - centre))) / steep) - shift

def pde_residual_2(model, x_tensor):
    """Residual u'(x) - tanh(k (x - x0)) for case 2 (k = 200, x0 = 0.5).

    The input is cloned into a new gradient-tracking leaf. The graph is retained
    (``create_graph=True``) so that the residual can enter a training loss.
    """
    steep, centre = 200.0, 0.5
    xs = x_tensor.clone().detach().requires_grad_(True)
    out = model(xs)
    grad_u = torch.autograd.grad(out, xs, grad_outputs=torch.ones_like(out), create_graph=True)[0]
    return grad_u - torch.tanh(steep * (xs - centre))


# -------------------------
# Case 3: Steep tanh ramp
# -------------------------
def gen_testdata_3(n=1000):
    """Test grid and exact solution for case 3 (steep tanh ramp).

    u(x) = 0.5 * (1 + tanh(k (x - x0))) - C0, with k = 150 and x0 = 0.6.
    C0 = 0.5 * (1 + tanh(-k x0)) makes u(0) = 0.
    Returns (x, u), both of shape (n, 1).
    """
    steep, centre = 150.0, 0.6
    grid = np.linspace(0, 1, n).reshape(-1, 1)
    shift = 0.5 * (1 + np.tanh(-steep * centre))
    return grid, 0.5 * (1 + np.tanh(steep * (grid - centre))) - shift

def gen_bc_3(n=200):
    """Boundary samples for case 3.

    The first ``n // 2`` samples are at x = 0 and the rest at x = 1, each paired
    with the exact case-3 value (the same formula as ``gen_testdata_3``).
    Returns (x, u), both of shape (n, 1).
    """
    steep, centre = 150.0, 0.6
    half = n // 2
    x = np.concatenate((np.zeros(half), np.ones(n - half)))[:, None]
    shift = 0.5 * (1 + np.tanh(-steep * centre))
    return x, 0.5 * (1 + np.tanh(steep * (x - centre))) - shift

def pde_residual_3(model, x_tensor):
    """Residual u'(x) - 0.5 * k * sech^2(k (x - x0)) for case 3 (k = 150, x0 = 0.6).

    The input is cloned into a new gradient-tracking leaf, and the graph is kept
    (``create_graph=True``) for back-propagation.

    Overflow note: for float32 inputs far from x0 (|k (x - x0)| greater than
    about 89), ``cosh`` overflows to inf. The sech^2 term then evaluates to 0,
    which is its limit.
    """
    steep, centre = 150.0, 0.6
    pts = x_tensor.clone().detach().requires_grad_(True)
    pred = model(pts)
    (slope,) = torch.autograd.grad(pred, pts, torch.ones_like(pred), create_graph=True)
    sech_sq = 1 / torch.cosh(steep * (pts - centre))**2
    return slope - 0.5 * steep * sech_sq


# -------------------------
# Case 4: Exponential increase
# -------------------------
def gen_testdata_4(n=1000):
    """Test grid and exact solution for case 4 (exponential increase).

    u(x) = u0 + (u1 - u0) * (1 - exp(k x)) / (1 - exp(k)), with u0 = 0,
    u1 = 5 and k = 10, so that u(0) = u0 and u(1) = u1.
    Returns (x, u), both of shape (n, 1).
    """
    start, end, rate = 0.0, 5.0, 10.0
    grid = np.linspace(0, 1, n).reshape(-1, 1)
    return grid, start + (end - start) * (1 - np.exp(rate * grid)) / (1 - np.exp(rate))

def gen_bc_4(n=200):
    """Boundary samples for case 4.

    The first ``n // 2`` samples are at x = 0 and the rest at x = 1, each paired
    with the exact case-4 value (the same formula as ``gen_testdata_4``).
    Returns (x, u), both of shape (n, 1).
    """
    start, end, rate = 0.0, 5.0, 10.0
    half = n // 2
    x = np.concatenate([np.zeros(half), np.ones(n - half)]).reshape(-1, 1)
    return x, start + (end - start) * (1 - np.exp(rate * x)) / (1 - np.exp(rate))

def pde_residual_4(model, x_tensor):
    """Residual for case 4: u'(x) - (u1 - u0) * (-k exp(k x)) / (1 - exp(k L)).

    Constants: u0 = 0, u1 = 5, k = 10 and domain length L = 1. The input is
    cloned into a new gradient-tracking leaf, and the graph is kept for
    back-propagation.
    """
    start, end, rate = 0.0, 5.0, 10.0
    length = 1.0
    pts = x_tensor.clone().detach().requires_grad_(True)
    pred = model(pts)
    (slope,) = torch.autograd.grad(pred, pts, torch.ones_like(pred), create_graph=True)
    denom = 1 - torch.exp(torch.tensor(rate * length))
    forcing = (end - start) * (-rate * torch.exp(rate * pts)) / denom
    return slope - forcing


# -------------------------
# Case 5: Multi-frequency decaying cosine
# -------------------------
def gen_testdata_5(n=1000):
    """Test grid and exact solution for case 5 (decaying multi-frequency cosine).

    u(x) = exp(-x) * cos(5 pi (x - 0.5)); returns (x, u), both of shape (n, 1).
    """
    grid = np.linspace(0, 1, n).reshape(-1, 1)
    return grid, np.exp(-grid) * np.cos(5 * np.pi * (grid - 0.5))

def gen_bc_5(n=200):
    """Boundary samples for case 5.

    The first ``n // 2`` samples are at x = 0 and the rest at x = 1, each paired
    with u(x) = exp(-x) * cos(5 pi (x - 0.5)). Returns (x, u), both of shape (n, 1).
    """
    half = n // 2
    x = np.concatenate((np.zeros(half), np.ones(n - half)))[:, None]
    return x, np.exp(-x) * np.cos(5 * np.pi * (x - 0.5))

def pde_residual_5(model, x_tensor):
    """Residual u'(x) - f(x) for case 5, whose reference solution is u(x) = exp(-x) cos(5 pi (x - 0.5)).

    By the product and chain rules, the exact derivative is
        f(x) = -exp(-x) cos(5 pi (x - 0.5)) - 5 pi exp(-x) sin(5 pi (x - 0.5)),
    so the residual vanishes on the case-5 reference solution.

    ``x_tensor`` is cloned into a new gradient-tracking leaf. The graph is kept
    (``create_graph=True``) so that the residual can be back-propagated into
    the model parameters.
    """
    x = x_tensor.clone().detach().requires_grad_(True)
    u = model(x)
    u_x = torch.autograd.grad(u, x, torch.ones_like(u), create_graph=True)[0]
    expm = torch.exp(-x)
    arg = 5 * np.pi * (x - 0.5)
    # d/dx cos(arg) = -5*pi*sin(arg), hence the minus sign on the sine term.
    rhs = -expm * torch.cos(arg) - expm * 5 * np.pi * torch.sin(arg)
    return u_x - rhs


# -------------------------
# Case 6: Piecewise constant (discontinuous)
# -------------------------
def gen_testdata_6(n=1000):
    """Test grid and reference solution for case 6 (piecewise constant).

    u(x) = 10 for x < 0.5 and 1 otherwise; returns (x, u), both of shape (n, 1).
    """
    grid = np.linspace(0, 1, n).reshape(-1, 1)
    values = np.where(grid < 0.5, 10.0, 1.0)
    return grid, values

def gen_bc_6(n=200):
    """Boundary samples for case 6.

    The first ``n // 2`` samples are at x = 0 (u = 10) and the rest at x = 1 (u = 1).
    Returns (x, u), both of shape (n, 1).
    """
    half = n // 2
    x = np.concatenate([np.zeros(half), np.ones(n - half)])[:, None]
    return x, np.where(x < 0.5, 10.0, 1.0)

def pde_residual_6(model, x_tensor):
    """Residual of u'(x) = 0 for case 6.

    The case-6 reference solution is discontinuous at x = 0.5, and its two
    boundary values (10 and 1) differ. The ODE itself has a zero right-hand side.
    The input is cloned into a new gradient-tracking leaf, and the graph is kept
    for back-propagation.
    """
    pts = x_tensor.clone().detach().requires_grad_(True)
    pred = model(pts)
    (slope,) = torch.autograd.grad(pred, pts, torch.ones_like(pred), create_graph=True)
    return slope


# -------------------------
# Case Dictionary
# -------------------------
# Registry of test problems: case id -> (test-data generator, BC generator, residual fn).
ode_cases = dict(enumerate([
    (gen_testdata_1, gen_bc_1, pde_residual_1),
    (gen_testdata_2, gen_bc_2, pde_residual_2),
    (gen_testdata_3, gen_bc_3, pde_residual_3),
    (gen_testdata_4, gen_bc_4, pde_residual_4),
    (gen_testdata_5, gen_bc_5, pde_residual_5),
    (gen_testdata_6, gen_bc_6, pde_residual_6),
], start=1))


#### Define Plotting Functions

In [None]:

def plot_pde_bc_loss(pde_hist, bc_hist):
    """Plot the PDE-residual and boundary-condition loss histories on a log y-axis."""
    plt.figure(figsize=(12, 8))
    series = (
        (pde_hist, 'PDE Loss', 'blue'),
        (bc_hist, 'BC Loss', 'orange'),
    )
    for history, name, colour in series:
        plt.plot(history, label=name, color=colour)
    plt.yscale('log')
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.xlabel('Epochs', fontsize=16)
    plt.ylabel('Loss', fontsize=16)
    plt.legend(fontsize=16)
    plt.grid(ls='--', alpha=0.4)
    plt.tight_layout()
    plt.show()

def plot_loss_comparison(results, metric="loss", case_name="Case 3"):
    """Overlay one training-history curve per run for the chosen metric.

    Parameters
    ----------
    results : list of dict
        One dict per run. Keys read here are 'name' (str),
        'include_boundaries' (bool) and the history list selected by
        ``metric``: 'loss_hist', 'pde_loss_hist', 'bc_loss_hist' or 'l2_hist'.
    metric : {"loss", "pde", "bc", "l2"}
        Which history to plot.
    case_name : str
        Appended to the figure title.

    Raises
    ------
    ValueError
        If ``metric`` is not one of the supported keys.
    """
    # metric -> (history key, label infix, plot title, y-axis label)
    specs = {
        "loss": ('loss_hist', "", "Total Loss (PDE + BC)", "Loss"),
        "pde": ('pde_loss_hist', " | PDE loss", "PDE Residual Loss", "Loss"),
        "bc": ('bc_loss_hist', " | BC loss", "Boundary Condition Loss", "Loss"),
        "l2": ('l2_hist', " | L2 error", "L2 Relative Error on Test Set", "Relative L2 error"),
    }
    if metric not in specs:
        raise ValueError(f"Unknown metric {metric!r}; expected one of {sorted(specs)}")
    hist_key, infix, title, ylabel = specs[metric]

    plt.figure(figsize=(10, 6))
    for result in results:
        bc_mark = '✓' if result['include_boundaries'] else '✗'
        plt.plot(result[hist_key], label=f"{result['name']}{infix} | BC {bc_mark}")

    plt.title(f"{title} | {case_name}")
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.yscale("log")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

def plot_solution_snapshots_compare(
    x_test: np.ndarray,
    u_test: np.ndarray,
    pred_hist_adapt: list[np.ndarray],
    pred_hist_nonadapt: list[np.ndarray],
    ep_hist_pred: list[int],
    n_cols: int = 4
):
    """Draw a grid of solution snapshots comparing the adaptive and non-adaptive PINNs.

    Panel i is titled with ``ep_hist_pred[i]``. Each panel overlays three curves:
      - the true solution (black solid);
      - the adaptive prediction (blue dashed);
      - the non-adaptive prediction (red dash-dot).

    Only as many panels as the shortest of the three histories are drawn, and
    unused grid cells are removed. A single 3-column legend, built from the
    first panel, sits above the grid.

    Raises
    ------
    ValueError
        If any of the three histories is empty.
    """
    import math

    xs = x_test.flatten()
    truth = u_test.flatten()

    # Only snapshots present in all three histories can be compared.
    n_panels = min(len(pred_hist_adapt), len(pred_hist_nonadapt), len(ep_hist_pred))
    if n_panels == 0:
        raise ValueError("No snapshots to plot!")

    n_rows = math.ceil(n_panels / n_cols)
    fig, axes = plt.subplots(n_rows, n_cols,
                             figsize=(4*n_cols, 3*n_rows),
                             squeeze=False)
    flat_axes = axes.ravel()  # row-major, same order as axes[i // n_cols][i % n_cols]

    # zip stops at the shortest history, i.e. after n_panels snapshots.
    snapshots = zip(ep_hist_pred, pred_hist_adapt, pred_hist_nonadapt)
    for panel, (ep, pred_ad, pred_na) in enumerate(snapshots):
        ax = flat_axes[panel]
        curves = (
            (truth, 'k-', 'True'),
            (pred_ad.flatten(), 'b--', 'Adaptive'),
            (pred_na.flatten(), 'r-.', 'Non-adaptive'),
        )
        for ys, fmt, name in curves:
            # Only the first panel is labelled; its handles feed the figure legend.
            extra = {'label': name} if panel == 0 else {}
            ax.plot(xs, ys, fmt, lw=1.5, **extra)
        ax.set_title(f"Epoch {ep}")
        ax.set_xlabel('x')
        ax.set_ylabel('u(x)')
        ax.grid(alpha=0.3)

    for spare in flat_axes[n_panels:]:
        fig.delaxes(spare)

    handles, labels = flat_axes[0].get_legend_handles_labels()
    fig.legend(
        handles, labels,
        loc='lower center',
        bbox_to_anchor=(0.5, 0.98),
        ncol=3,
        frameon=False,
        fontsize=14
    )
    fig.tight_layout(rect=[0, 0, 1, 0.99])
    plt.show()


def plot_continuous_collocation_error_evolution(
    pos_hist: list[np.ndarray],
    ep_hist: list[int],
    pred_hist: list[np.ndarray],
    res_hist: list[np.ndarray],
    u_test_fn: callable,  # function that gives u_true(x)
    n_grid: int = 500,
):
    """Plot heatmaps of |PDE residual| and |solution error| over x and training epochs.

    For every snapshot, the per-node residual and prediction are linearly
    interpolated onto one common x-grid spanning all recorded node positions.
    Grid points outside that snapshot's node range are NaN. The stacked rows
    are drawn as two images:
      - Top: absolute PDE residual.
      - Bottom: |u_pred - u_true|, with u_true from ``u_test_fn``.

    Parameters
    ----------
    pos_hist : list of np.ndarray
        Node positions per snapshot, in any order (flattened and sorted internally).
    ep_hist : list of int
        Epoch labels. Only ``ep_hist[0]`` (top edge) and ``ep_hist[-1]``
        (bottom edge) set the image's y-extent.
    pred_hist : list of np.ndarray
        Model predictions at the nodes, aligned element-wise with ``pos_hist``.
    res_hist : list of np.ndarray
        PDE residuals at the nodes, aligned element-wise with ``pos_hist``.
    u_test_fn : callable
        Returns u_true(x) for a 1D array x.
    n_grid : int, optional
        Number of points in the common x-grid (default 500).

    Raises
    ------
    ValueError
        If ``pos_hist`` is empty.
    """
    if len(pos_hist) == 0:
        raise ValueError("No snapshots to plot!")

    # One common grid covering every node position that was ever recorded.
    all_x = np.concatenate([np.ravel(p) for p in pos_hist])
    x_min, x_max = all_x.min(), all_x.max()
    x_grid = np.linspace(x_min, x_max, n_grid)
    # The reference solution on the grid does not change between snapshots.
    u_true_grid = u_test_fn(x_grid)

    residual_field = []
    sol_error_field = []
    for p, pred, res in zip(pos_hist, pred_hist, res_hist):
        x_i = np.ravel(p)
        # np.interp requires increasing sample points, but collocation nodes
        # are generally unsorted (e.g. uniform random). Sort them and carry the
        # per-node values along.
        order = np.argsort(x_i)
        x_sorted = x_i[order]
        r_sorted = np.abs(np.ravel(res))[order]
        u_sorted = np.ravel(pred)[order]

        residual_field.append(
            np.interp(x_grid, x_sorted, r_sorted, left=np.nan, right=np.nan)
        )
        u_pred_i = np.interp(x_grid, x_sorted, u_sorted, left=np.nan, right=np.nan)
        sol_error_field.append(np.abs(u_pred_i - u_true_grid))

    residual_field = np.array(residual_field)
    sol_error_field = np.array(sol_error_field)
    epochs = np.array(ep_hist)

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(18, 10), sharex=True)

    # Row 0 (first snapshot) is drawn at the top, labelled epochs[0].
    im1 = ax1.imshow(
        residual_field,
        aspect='auto',
        extent=[x_min, x_max, epochs[-1], epochs[0]],
        cmap='magma'
    )
    ax1.set_title("|PDE residual| evolution")
    ax1.set_ylabel("Epoch")
    fig.colorbar(im1, ax=ax1, label="|r|")

    im2 = ax2.imshow(
        sol_error_field,
        aspect='auto',
        extent=[x_min, x_max, epochs[-1], epochs[0]],
        cmap='magma'
    )
    ax2.set_title("Solution error evolution")
    ax2.set_xlabel("x")
    ax2.set_ylabel("Epoch")
    fig.colorbar(im2, ax=ax2, label="|u_pred - u_true|")

    plt.tight_layout()
    plt.show()

def plot_discrete_collocation_error_evolution(
    pos_hist: list[np.ndarray],
    ep_hist: list[int],
    pred_hist: list[np.ndarray],
    res_hist: list[np.ndarray],
    u_test_fn: callable  # function that gives u_true(x)
):
    """Scatter collocation nodes over epochs, coloured by |PDE residual| and |solution error|.

    Each snapshot becomes one horizontal row of points at y = epoch. All rows
    in a panel share one colour scale, so colours are comparable across epochs
    and the colour bar is valid for every row.

    Parameters
    ----------
    pos_hist : list of np.ndarray
        Node positions at each snapshot (flattened internally).
    ep_hist : list of int
        Epoch of each snapshot, used as the y-coordinate.
    pred_hist : list of np.ndarray
        Predictions at the nodes of each snapshot.
    res_hist : list of np.ndarray
        Residuals at the nodes of each snapshot.
    u_test_fn : callable
        Returns u_true(x) for an arbitrary 1D array x.

    Raises
    ------
    ValueError
        If there is no snapshot to plot.
    """
    # Gather every snapshot first so that global colour limits can be computed.
    snapshots = []
    for x_i, ep, u_pred_i, r_i in zip(pos_hist, ep_hist, pred_hist, res_hist):
        x_flat = np.ravel(x_i)
        err = np.abs(np.ravel(u_pred_i) - u_test_fn(x_flat))
        snapshots.append((x_flat, ep, np.abs(np.ravel(r_i)), err))
    if not snapshots:
        raise ValueError("No snapshots to plot!")

    r_all = np.concatenate([snap[2] for snap in snapshots])
    e_all = np.concatenate([snap[3] for snap in snapshots])
    r_lim = dict(vmin=np.nanmin(r_all), vmax=np.nanmax(r_all))
    e_lim = dict(vmin=np.nanmin(e_all), vmax=np.nanmax(e_all))

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(18, 10))
    for x_flat, ep, r_abs, err in snapshots:
        rows = np.full_like(x_flat, ep)
        sc1 = ax1.scatter(x_flat, rows, c=r_abs, s=20, cmap='magma', alpha=0.8, **r_lim)
        sc2 = ax2.scatter(x_flat, rows, c=err, s=20, cmap='magma', alpha=0.8, **e_lim)

    for ax, title in zip((ax1, ax2), ("|PDE residual|", "Solution error")):
        ax.set_xlabel('x')
        ax.set_ylabel('Epoch')
        ax.set_title(title)
        ax.grid(ls='--', alpha=0.3)

    # Every scatter shares the same vmin/vmax, so any of them defines the colour bar.
    fig.colorbar(sc1, ax=ax1, label='|r|')
    fig.colorbar(sc2, ax=ax2, label='|uₚᵣₑd − u_true|')

    plt.tight_layout()
    plt.show()

def plot_collocation_error_evolution(
    pos_hist, ep_hist,
    model,
    pde_residual_fn,
    x_test, u_test,
    device='cpu'
):
    """Scatter collocation nodes over epochs, coloured by |PDE residual| and |solution error|.

    The residual and prediction are recomputed here with the current ``model``
    at every recorded node set. Note that this function puts the model in
    eval() mode and leaves it there. All rows in a panel share one colour scale,
    so colours are comparable across epochs.

    Parameters
    ----------
    pos_hist : list of np.ndarray
        Node positions at each recorded epoch (flattened internally).
    ep_hist : list of int
        Corresponding epochs (y-coordinate of each row).
    model : nn.Module
        The trained PINN.
    pde_residual_fn : callable
        Residual function: ``res = pde_residual_fn(model, x_tensor)``.
    x_test, u_test : np.ndarray
        Dense test grid and true solution on it. At least 4 points are needed
        for the cubic interpolant.
    device : str or torch.device
        Device on which the node tensors are created.

    Raises
    ------
    ValueError
        If there is no snapshot to plot.
    """
    model.eval()
    # Cubic interpolant of the reference solution, evaluated at arbitrary node positions.
    interp_u = interp1d(x_test.flatten(), u_test.flatten(), kind='cubic',
                        fill_value='extrapolate')

    # Compute every snapshot first so that global colour limits can be set.
    snapshots = []
    for pts, epoch in zip(pos_hist, ep_hist):
        pts = np.ravel(pts)
        x_dom = pts[:, None]
        x_t = torch.tensor(x_dom, dtype=torch.float32, device=device, requires_grad=True)
        # The residual needs autograd (du/dx) even if the caller disabled gradients.
        with torch.enable_grad():
            r = pde_residual_fn(model, x_t).abs().detach().cpu().numpy().flatten()
        with torch.no_grad():
            u_pred = model(torch.tensor(x_dom, dtype=torch.float32, device=device)).cpu().numpy().flatten()
        sol_err = np.abs(u_pred - interp_u(pts))
        snapshots.append((pts, epoch, r, sol_err))
    if not snapshots:
        raise ValueError("No snapshots to plot!")

    r_all = np.concatenate([snap[2] for snap in snapshots])
    e_all = np.concatenate([snap[3] for snap in snapshots])
    r_lim = dict(vmin=np.nanmin(r_all), vmax=np.nanmax(r_all))
    e_lim = dict(vmin=np.nanmin(e_all), vmax=np.nanmax(e_all))

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(18, 10))
    for pts, epoch, r, sol_err in snapshots:
        rows = np.full_like(pts, epoch)
        sc1 = ax1.scatter(pts, rows, c=r, s=20, cmap='magma', alpha=0.8, **r_lim)
        sc2 = ax2.scatter(pts, rows, c=sol_err, s=20, cmap='magma', alpha=0.8, **e_lim)

    for ax, title in zip((ax1, ax2), ("|PDE residual|", "Solution error")):
        ax.set_xlabel('x')
        ax.set_ylabel('Epoch')
        ax.set_title(title)
        ax.grid(ls='--', alpha=0.3)
    # Shared vmin/vmax makes any scatter a valid source for the colour bar.
    fig.colorbar(sc1, ax=ax1, label='|r|')
    fig.colorbar(sc2, ax=ax2, label='|uₚᵣₑd–u_true|')

    plt.tight_layout()
    plt.show()


# --- Plotting utils (unchanged) ---
def plot_training_and_solution(loss_hist, l2_hist, t_test, u_test, model):
    """Draw a three-panel summary of training and the resulting solution.

    Panels: training loss (log scale), test-error history (log scale), and the
    predicted vs. true solution. Uses the module-level ``device`` for the test
    inputs and switches ``model`` to eval() mode.
    """
    model.eval()
    inputs = torch.tensor(t_test, dtype=torch.float32, device=device)
    with torch.no_grad():
        u_pred = model(inputs).cpu().numpy()

    fig, (ax_loss, ax_test, ax_sol) = plt.subplots(1, 3, figsize=(18, 6))
    histories = (
        (ax_loss, loss_hist, 'Training Loss'),
        (ax_test, l2_hist, 'Test Loss'),
    )
    for ax, history, title in histories:
        ax.plot(history)
        ax.set_yscale('log')
        ax.set_title(title, fontsize=14)

    ax_sol.plot(t_test, u_test, 'k-', label='True')
    ax_sol.plot(t_test, u_pred, 'r--', label='Pred')
    ax_sol.set_title('Solution')
    ax_sol.legend()
    plt.tight_layout()
    plt.show()

def plot_test_error(l2_hist_ad, l2_hist_na, epochs_ad, epochs_na):
    """Compare adaptive vs. non-adaptive test L2-error histories on a log y-axis."""
    plt.figure(figsize=(10, 5))
    curves = [
        (epochs_ad, l2_hist_ad, 'b-', 'Adaptive PINN'),
        (epochs_na, l2_hist_na, 'r--', 'Non-adaptive PINN'),
    ]
    for epochs, errors, fmt, name in curves:
        plt.plot(epochs, errors, fmt, label=name)
    plt.yscale('log')
    plt.xlabel('Epochs', fontsize=14)
    plt.ylabel('L2 Error', fontsize=14)
    plt.legend()
    plt.grid(ls='--', alpha=0.3)
    plt.tight_layout()
    plt.show()

def plot_collocation_evolution(pos_hist, ep_hist):
    """Scatter each recorded set of collocation points at y = its epoch."""
    plt.figure(figsize=(18, 6))
    for nodes, step in zip(pos_hist, ep_hist):
        row = np.full_like(nodes, step)
        plt.scatter(nodes, row, s=5, alpha=0.6)
    plt.xlabel('x', fontsize=14)
    plt.ylabel('Epoch', fontsize=14)
    plt.title('Collocation Evolution', fontsize=16)
    plt.tight_layout()
    plt.show()


def plot_adaptation_density(
    model, x_dom_before, x_dom_after,
    pde_residual_fn, device='cpu',
    smoothing_window=5, density_bins=30, epoch=None,
    save_dir="plots",
):
    """Visualise the PDE residual and the collocation densities before/after adaptation.

    The figure shows:
      - a smoothed line of |PDE residual| evaluated at the pre-adaptation
        points, sorted by x;
      - filled histogram densities on [0, 1] for ``x_dom_before`` (black) and
        ``x_dom_after`` (red).

    Parameters
    ----------
    model : nn.Module
        The PINN. Its train/eval mode is not changed here.
    x_dom_before, x_dom_after : np.ndarray
        Collocation points before and after adaptation (flattened internally).
    pde_residual_fn : callable
        ``pde_residual_fn(model, x_tensor) -> residual tensor``.
    device : str or torch.device
        Device on which the residual input tensor is created.
    smoothing_window : int
        Width of the moving average, clipped to the number of points.
        Values <= 1 disable smoothing.
    density_bins : int
        Number of histogram bins on [0, 1].
    epoch : int or None
        Used in the title and in the saved file name.
    save_dir : str or None
        Directory for the PNG (created if missing). ``None`` skips saving.
    """
    x_sorted = np.sort(np.ravel(x_dom_before))

    # The residual needs autograd (derivatives w.r.t. x), so there is no torch.no_grad() here.
    x_t = torch.tensor(x_sorted[:, None], dtype=torch.float32, device=device).requires_grad_(True)
    r = pde_residual_fn(model, x_t).abs().detach().cpu().numpy().flatten()

    # A window wider than the signal would make np.convolve(mode='same') return
    # more samples than there are points, so clip the window.
    window = min(int(smoothing_window), r.size)
    if window > 1:
        kern = np.ones(window)
        # Divide by the number of real samples under the kernel, so the zero
        # padding of mode='same' does not drag the curve down at both ends.
        r = np.convolve(r, kern, mode='same') / np.convolve(np.ones_like(r), kern, mode='same')

    bins = np.linspace(0.0, 1.0, density_bins + 1)
    d_before, _ = np.histogram(np.ravel(x_dom_before), bins=bins, density=True)
    d_after, _ = np.histogram(np.ravel(x_dom_after), bins=bins, density=True)
    centers = 0.5 * (bins[:-1] + bins[1:])

    fig, ax1 = plt.subplots(figsize=(8, 4))
    ax1.plot(x_sorted, r, color='black', lw=2, label='Smoothed |PDE residual|')
    ax1.set_xlabel('x')
    ax1.set_ylabel('Residual', color='black')
    ax1.tick_params(axis='y', labelcolor='black')

    ax2 = ax1.twinx()
    ax2.fill_between(centers, d_before, color='black', alpha=0.3, label='Density before')
    ax2.fill_between(centers, d_after, color='red', alpha=0.3, label='Density after')
    ax2.set_ylabel('Density')
    # NOTE(review): a lower limit of 0.25 hides density values below it; confirm this is intended.
    ax2.set_ylim(0.25, None)
    ax2.tick_params(axis='y', color='tab:gray')

    # One legend for the lines and fills of both y-axes.
    h1, l1 = ax1.get_legend_handles_labels()
    h2, l2 = ax2.get_legend_handles_labels()
    ax1.legend(h1 + h2, l1 + l2, loc='upper left')

    ax1.grid(ls='--', alpha=0.3)
    plt.title(f"Adaptive Node Movement at epoch: {epoch}")
    plt.tight_layout()
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(os.path.join(save_dir, f"node_movement_at-{epoch}.png"), dpi=300)
    plt.show()




def plot_spatial_loss_evolution(pde_grid, l2_grid, x_eval_grid, ep_hist):
    """Show log-scaled heatmaps of the pointwise PDE loss and the test L2 error.

    Each panel has x along the horizontal axis and epochs along the vertical
    axis. The rows of ``pde_grid`` / ``l2_grid`` are flipped vertically so that
    the last row is drawn at the top, where the extent places ``ep_hist[-1]``.
    """
    panels = (
        (np.array(pde_grid), "PDE Residual Loss |r(x)|²", "PDE Loss"),
        (np.array(l2_grid), "Test L2 Loss |u_pred(x) - u_true(x)|²", "L2 Error"),
    )
    extent = [x_eval_grid.min(), x_eval_grid.max(), ep_hist[0], ep_hist[-1]]
    log_norm = plt.matplotlib.colors.LogNorm

    fig, axes = plt.subplots(2, 1, figsize=(18, 10), sharex=True)
    for ax, (field, title, cbar_label) in zip(axes, panels):
        image = ax.imshow(
            np.flipud(field),
            aspect='auto',
            extent=extent,
            cmap='magma',
            norm=log_norm()
        )
        ax.set_title(title, fontsize=16)
        ax.set_ylabel("Epoch", fontsize=14)
        fig.colorbar(image, ax=ax, label=cbar_label)
    axes[-1].set_xlabel("x", fontsize=14)

    plt.tight_layout()
    plt.show()

#### Define PINN Class

In [None]:
class PINN(nn.Module):
    """Fully connected network with tanh activations between its linear layers.

    ``layers`` lists every layer width; for example, ``[1, 32, 32, 1]`` builds
    Linear(1, 32) -> tanh -> Linear(32, 32) -> tanh -> Linear(32, 1).
    The output layer has no activation.
    """

    def __init__(self, layers):
        super().__init__()
        self.act = torch.tanh
        self.layers = nn.ModuleList(
            nn.Linear(n_in, n_out) for n_in, n_out in zip(layers[:-1], layers[1:])
        )

    def forward(self, x):
        out_layer = self.layers[-1]
        for linear in self.layers[:-1]:
            x = self.act(linear(x))
        return out_layer(x)

------------
## Adaptive Node Moving function

In [None]:
def node_moving_1d(
    coordinate: np.ndarray,
    estimated_error: np.ndarray,
    boundary_nodes_id: np.ndarray,
    ratio_nodal_distance: float = 5,
    relaxation: float = 1.0,
    stifness: float = 0.5
) -> np.ndarray:
    """Reallocate 1D nodes with a spring model driven by an error estimate.

    Consecutive nodes (by array index) are joined by springs. The stiffness of
    the spring between nodes i and i+1 is
    ``stifness * (e_i**b + e_{i+1}**b)``, where the exponent ``b`` maps the
    max/min error ratio onto ``ratio_nodal_distance``. Solving for the
    equilibrium with the boundary nodes held fixed gives every spring the same
    force, so spacing is inversely proportional to stiffness and nodes cluster
    where the error is large.

    Parameters
    ----------
    coordinate : np.ndarray
        Node coordinates, 1D or (n, 1); the output has the same shape. Springs
        connect index i to i+1, so nodes are expected to be ordered along the
        domain. TODO(review): confirm callers pass sorted coordinates.
    estimated_error : np.ndarray
        Non-negative error estimate per node (n values; flattened).
        Higher error leads to denser spacing.
    boundary_nodes_id : array-like of int
        Indices of nodes whose positions stay fixed.
    ratio_nodal_distance : float, optional
        Target ratio between the largest and smallest adjusted error (and so,
        roughly, between the largest and smallest nodal spacing). Default 5.
    relaxation : float, optional
        When > 0, the result is ``old + relaxation * (new - old)``. When <= 0,
        the fully relocated positions are returned without blending. Default 1.0.
    stifness : float, optional
        Global spring scale (spelling kept for backward compatibility). Default 0.5.

    Returns
    -------
    np.ndarray
        The reallocated node coordinates.

    Notes
    -----
    If the error is uniform, or its minimum is zero, all springs are made
    equal (b = 0). The interior nodes then relax to an equispaced layout.
    """
    coords = np.asarray(coordinate)
    out_shape = coords.shape
    coords = coords.reshape(-1)
    if not np.issubdtype(coords.dtype, np.floating):
        # Integer input would silently truncate the new positions on assignment.
        coords = coords.astype(float)
    num_nodes = coords.size
    if num_nodes == 0:
        return coords.reshape(out_shape).copy()

    error = np.asarray(estimated_error, dtype=float).reshape(-1)
    e_max, e_min = error.max(), error.min()
    if e_min > 0 and e_max > e_min:
        b = np.log(ratio_nodal_distance) / np.log(e_max / e_min)
    else:
        # A uniform error would give b = log(r) / log(1) = inf, which collapses
        # or overflows the springs. A zero minimum already gives b = 0 in the
        # formula above. Use equal springs in both cases.
        b = 0.0

    adjusted_error_estimate = error ** b
    ke = stifness * (adjusted_error_estimate[:-1] + adjusted_error_estimate[1:])

    # Tridiagonal stiffness matrix of the spring chain, with a small diagonal
    # regulariser so that K stays invertible.
    k = (np.diag(np.hstack([0, ke])) + np.diag(np.hstack([ke, 0]))
         - np.diag(ke, -1) - np.diag(ke, 1))
    reg = 1e-6
    k += np.eye(num_nodes) * reg

    # A boolean mask keeps the fixed and free index sets consistent, including
    # for negative indices.
    fixed = np.asarray(boundary_nodes_id, dtype=int).reshape(-1)
    free_mask = np.ones(num_nodes, dtype=bool)
    free_mask[fixed] = False
    free = np.flatnonzero(free_mask)

    # Move the known boundary positions to the right-hand side: f = -K[:, fixed] @ x_fixed.
    f = -(k[:, fixed] @ coords[fixed])

    recollocated = coords.copy()
    if free.size:
        recollocated[free] = solve(k[np.ix_(free, free)], f[free])

    if relaxation > 0.0:
        recollocated = coords + relaxation * (recollocated - coords)

    return recollocated.reshape(out_shape)

### (Extra) Random Resampling

In [None]:
def random_resampling_1d(
    coordinate: np.ndarray,
    boundary_nodes_id: np.ndarray,
    **kwargs  # absorb other arguments like ratio_nodal_distance, stifness
) -> np.ndarray:
    """Redraw every non-boundary node uniformly within [min(coordinate), max(coordinate)].

    Extra keyword arguments are accepted and ignored, so this function can be
    swapped in for ``node_moving_1d``.

    Parameters
    ----------
    coordinate : np.ndarray
        Current 1D node coordinates.
    boundary_nodes_id : np.ndarray
        Indices of nodes that keep their position.

    Returns
    -------
    np.ndarray
        A copy of ``coordinate`` in which the interior entries are replaced by
        draws from ``np.random.uniform`` (the global NumPy RNG).
    """
    lo, hi = coordinate.min(), coordinate.max()
    movable = list(filter(lambda idx: idx not in boundary_nodes_id, range(len(coordinate))))
    resampled_coords = np.copy(coordinate)
    resampled_coords[movable] = np.random.uniform(lo, hi, size=len(movable))
    return resampled_coords

### The Train Loop

In [None]:
def _train_model_device(model):
    """Return the device of the model's first parameter (CPU if it has none)."""
    first_param = next(model.parameters(), None)
    return first_param.device if first_param is not None else torch.device("cpu")


def _train_pointwise_errors(model, pde_residual, x_grid_t, u_grid_true):
    """Return (squared PDE residual, squared solution error) at each grid point."""
    with torch.enable_grad():
        r_grid = pde_residual(model, x_grid_t).detach().cpu().numpy().flatten()
    with torch.no_grad():
        u_pred_grid = model(x_grid_t).cpu().numpy().flatten()
    return r_grid**2, (u_pred_grid - u_grid_true)**2


def _train_refine_collocation(model, x_dom, gen_bc, pde_residual, adaptive_fn,
                              relaxation, stifness, dev):
    """Return new collocation points of shape (N, 1) produced by ``adaptive_fn``.

    Points are sorted before adaptation. Unknown ``adaptive_fn`` values leave
    ``x_dom`` unchanged.
    """
    coords_sorted = np.sort(x_dom.flatten())
    boundary_ids = np.array([0, len(coords_sorted) - 1], dtype=int)

    if adaptive_fn is random_resampling_1d:
        return random_resampling_1d(coords_sorted, boundary_ids).reshape(-1, 1)
    if adaptive_fn is not node_moving_1d:
        return x_dom

    # Error indicator for node moving: |PDE residual| at every point, plus the
    # boundary-value mismatch added at the outermost collocation points.
    x_full_t = torch.tensor(coords_sorted[:, None], dtype=torch.float32, device=dev).requires_grad_(True)
    r_full = pde_residual(model, x_full_t).detach().abs().cpu().numpy().flatten()

    # Assumes gen_bc(2) returns the x=0 value first and the x=1 value second
    # (true for the generators in this file's header) -- confirm for new cases.
    _, u_bc_small = gen_bc(2)
    u_bc_small = u_bc_small.flatten()
    with torch.no_grad():
        u_full_pred = model(x_full_t).cpu().numpy().flatten()

    bc_err = np.zeros_like(r_full)
    bc_err[0] = abs(u_full_pred[0] - u_bc_small[0])
    bc_err[-1] = abs(u_full_pred[-1] - u_bc_small[1])

    new_coords_full = node_moving_1d(
        coords_sorted,
        r_full + bc_err,
        boundary_ids,
        ratio_nodal_distance=5,
        relaxation=relaxation,
        stifness=stifness,
    )
    return new_coords_full.reshape(-1, 1)


def train(model, optimizer,
          gen_testdata, gen_bc, pde_residual,
          n_dom=200, n_bc=200, n_test=1000,
          epochs=1000, refine_every=None, relaxation=1.0, include_boundaries=False, adaptive_fn=None, stifness=0.5,
          stop_threshold=0.001, l2_every=20, grid_every=1, n_grid=500):
    """
    Train a PINN on a 1D problem, optionally adapting the collocation points.

    Each epoch:
    1. Takes one optimizer step on ``loss_pde + loss_bc``. Both terms are mean
       squared errors.
    2. Evaluates the relative L2 error against ``gen_testdata(n_test)``.

    Training stops early once that error drops below ``stop_threshold``.
    Every ``refine_every`` epochs the collocation points are moved by
    ``adaptive_fn`` (``node_moving_1d`` or ``random_resampling_1d``). Any other
    value leaves them unchanged.

    Parameters
    ----------
    model : torch.nn.Module
        Network evaluated on inputs of shape (N, 1). Inputs are placed on the
        device of its first parameter.
    optimizer : torch.optim.Optimizer
        Optimizer over ``model``'s parameters.
    gen_testdata, gen_bc : callable
        ``n -> (x, u)`` NumPy arrays for the reference solution and the
        boundary data. Presumably shape (n, 1), as in the header generators.
    pde_residual : callable
        ``(model, x_tensor) -> residual tensor``.
    n_dom, n_bc, n_test : int
        Number of collocation, boundary and test points.
    epochs : int
        Maximum number of training epochs.
    refine_every : int or None
        Adaptation period in epochs. A falsy value disables adaptation.
    relaxation, stifness : float
        Forwarded to ``node_moving_1d``.
    include_boundaries : bool
        If True, also place collocation points at x=0 and x=1.
    adaptive_fn : callable or None
        Adaptation routine (see above).
    stop_threshold : float
        Early-stopping threshold on the relative L2 test error.
    l2_every : int
        Period (in epochs) at which the L2 error is appended to ``l2_hist``.
        The error of the last epoch run is always the last entry.
    grid_every : int
        Period at which pointwise PDE/L2 errors on the evaluation grid are
        recorded. A falsy value disables this.
    n_grid : int
        Number of points in the uniform evaluation grid on [0, 1].

    Returns
    -------
    tuple
        (loss_hist, l2_hist, pos_hist, ep_hist, ep_hist_pred, x_test, u_test,
        pred_hist, epochs_to_threshold, elapsed_seconds, pde_loss_hist,
        bc_loss_hist, res_hist, pde_grid, l2_grid, x_eval_grid)
    """
    dev = _train_model_device(model)

    # Collocation points drawn uniformly on [0, 1], optionally with the endpoints.
    x_dom = np.random.uniform(0, 1, (n_dom, 1))
    if include_boundaries:
        x_dom = np.vstack(([[0.0]], x_dom, [[1.0]]))

    x_bc, u_bc = gen_bc(n_bc)
    x_test, u_test = gen_testdata(n_test)

    x_bc_t = torch.tensor(x_bc, dtype=torch.float32, device=dev)
    u_bc_t = torch.tensor(u_bc, dtype=torch.float32, device=dev)
    x_test_t = torch.tensor(x_test, dtype=torch.float32, device=dev)
    u_test_norm = np.linalg.norm(u_test)

    # Evaluation grid and the interpolated reference solution on it are
    # loop-invariant, so they are built once.
    u_test_fn = interp1d(x_test.flatten(), u_test.flatten(), kind='cubic', fill_value='extrapolate')
    x_grid = np.linspace(0, 1, n_grid)[:, None]
    x_grid_t = torch.tensor(x_grid, dtype=torch.float32, device=dev, requires_grad=True)
    u_grid_true = u_test_fn(x_grid.flatten())

    loss_hist, l2_hist = [], []
    pred_hist, ep_hist_pred, res_hist = [], [], []
    pos_hist, ep_hist = [x_dom.flatten().copy()], []
    pde_loss_hist, bc_loss_hist = [], []
    pde_grid, l2_grid = [], []
    x_eval_grid = None

    l2 = float('nan')
    epochs_to_threshold = 0  # defined even when epochs == 0
    last_logged_ep = 0

    start = time.perf_counter()
    for ep in range(1, epochs + 1):
        epochs_to_threshold = ep

        # --- one optimization step on PDE + boundary loss ---
        model.train()
        optimizer.zero_grad()

        x_dom_t = torch.tensor(x_dom, dtype=torch.float32, device=dev)
        r = pde_residual(model, x_dom_t)
        loss_pde = (r**2).mean()

        u_bc_pred = model(x_bc_t)
        loss_bc = ((u_bc_pred - u_bc_t)**2).mean()

        pde_loss_hist.append(loss_pde.item())
        bc_loss_hist.append(loss_bc.item())

        loss = loss_pde + loss_bc
        loss.backward()
        optimizer.step()
        loss_hist.append(loss.item())

        # --- relative L2 test error after the step ---
        model.eval()
        with torch.no_grad():
            u_pred = model(x_test_t).cpu().numpy()
        l2 = float(np.linalg.norm(u_pred - u_test) / u_test_norm)

        if l2_every and ep % l2_every == 0:
            l2_hist.append(l2)
            last_logged_ep = ep

        # --- pointwise errors over the evaluation grid ---
        if grid_every and ep % grid_every == 0:
            pde_sq, err_sq = _train_pointwise_errors(model, pde_residual, x_grid_t, u_grid_true)
            pde_grid.append(pde_sq)
            l2_grid.append(err_sq)
            x_eval_grid = x_grid.flatten()

        if l2 < stop_threshold:
            break

        # .cpu() is required before .numpy() when the model lives on a GPU.
        with torch.no_grad():
            pred_hist.append(model(x_dom_t).cpu().numpy())
        res_hist.append(r.detach().cpu().numpy())
        ep_hist_pred.append(ep)

        pos_hist.append(x_dom.flatten().copy())
        ep_hist.append(ep)

        if refine_every and ep % refine_every == 0:
            # Kept for compatibility: exposes the active residual at module level.
            global pde_residual_global
            pde_residual_global = pde_residual

            if adaptive_fn is not None:
                x_dom = _train_refine_collocation(
                    model, x_dom, gen_bc, pde_residual, adaptive_fn,
                    relaxation, stifness, dev,
                )

    end = time.perf_counter() - start

    # Callers report l2_hist[-1] as the final error, so make sure it is the
    # error of the last epoch actually run (early stop or epochs % l2_every != 0).
    if epochs_to_threshold > last_logged_ep:
        l2_hist.append(l2)

    return loss_hist, l2_hist, pos_hist, ep_hist, ep_hist_pred, x_test, u_test, pred_hist, \
        epochs_to_threshold, end, pde_loss_hist, bc_loss_hist, res_hist, pde_grid, l2_grid, x_eval_grid


------------
## Calling the Model

In [None]:
if __name__=='__main__':
    # Experiment driver: trains one PINN per (config, repeat), prints progress,
    # and pickles one summary row per run to results/results.pkl.

    # random sampling and node moving
    repeats = 1
    model_configs = []
    cases = [1]
    ndom_list = [200]
    reallocate_every = [200]
    for case in cases:
        for n_dom in ndom_list:
            for method in [None, node_moving_1d, random_resampling_1d]:
                method_name = method.__name__ if method is not None else "None"
                for refine in reallocate_every:
                    model_configs.append({
                        'case':             case,
                        'name':             f"Case{case}_nDom{n_dom}_refine{refine}_method{method_name}",
                        'n_dom':            n_dom,
                        'refine':           refine,
                        'adaptive_method':  method,
                        'epochs':           10000,
                        'n_repeats':        repeats,
                        'include_boundaries': False,
                    })

    results_dir = 'results'
    os.makedirs(results_dir, exist_ok=True)
    result_columns = [
        'case','name', 'method', 'NN', 'n_dom', 'n_bc', 'refine',
        'l2_error','l2_hist','loss_hist','x_test','u_test','pred_hist', 'u_pred', 'time',
        'epochs', 'pos_hist', 'include_boundaries','pde_loss_hist', 'bc_loss_hist', 'delay_adapt', 'stifness'
    ]
    # Rows are collected as dicts and the DataFrame is built once at the end
    # (growing it row by row with .loc is quadratic).
    result_rows = []

    print(f"Running {len(model_configs)} models with {repeats} repeats each.")
    print("Results will be saved to:", results_dir)
    total = len(model_configs) * repeats
    i, ave_time = 0, 0  # i: completed runs, ave_time: accumulated seconds

    for cfg in model_configs:
        # ETA from the mean duration of the runs completed so far.
        mean_run_time = ave_time / i if i else 0.0
        waiting_time_mins = mean_run_time * (total - i) / 60
        waiting_time_str = f"{waiting_time_mins:.2f} mins"
        print(f"\n### {cfg['name']} \t|\t Wait ~ {waiting_time_str}###")
        gen_xd, gen_bc_f, pde_res = ode_cases[cfg['case']]

        custom_arch = ('n_layers' in cfg) and ('n_neurons' in cfg)
        adaptive_method = cfg.get('adaptive_method')
        # Store a name rather than the function object: pickled functions from
        # __main__ cannot be unpickled in another session.
        method_label = adaptive_method.__name__ if adaptive_method is not None else 'PINN'

        l2_average = []

        for run in range(cfg['n_repeats']):
            i += 1
            # reseed for reproducibility
            seed = 43 + run
            torch.manual_seed(seed)
            np.random.seed(seed)

            # Define the PINN
            if custom_arch:
                model = PINN([1] + [cfg['n_neurons']]*(cfg['n_layers']) + [1]).to(device)
            else:
                model = PINN([1,20,20,1]).to(device)
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

            # Call the training loop
            loss_hist, l2_hist_ad, pos_hist, ep_hist_ad, ep_hist_pred,\
                x_test, u_test, pred_hist, epochs_to_threshold_ad, \
                    end_ad, pde_loss_hist, bc_loss_hist, res_hist, \
                        pde_grid, l2_grid, x_eval_grid = train(

                model, optimizer,
                gen_xd, gen_bc_f, pde_res,
                n_dom=cfg.get('n_dom', 50),
                n_bc=cfg.get('n_bc', 2),
                n_test=cfg.get('n_test', 10000),
                epochs=cfg.get('epochs', 10000),
                refine_every=cfg['refine'],
                adaptive_fn=adaptive_method,
                relaxation=cfg.get('relaxation', 1.0),
                # Configs set 'include_boundaries'; 'boundary_in_xdom' is the legacy key.
                include_boundaries=cfg.get('include_boundaries', cfg.get('boundary_in_xdom', False)),
                stifness = cfg.get('spring_stifness', 0.5),
            )
            ave_time += end_ad

            with torch.no_grad():
                u_ad = model(torch.tensor(x_test, dtype=torch.float32, device=device)).cpu().numpy()
            # Final error of the trained model. l2_hist_ad is only sampled
            # periodically, so its last entry can be stale or missing.
            l2_final = float(np.linalg.norm(u_ad - u_test) / np.linalg.norm(u_test))

            print(f"Iteration: {i}/{total}, \t|\t Final L2 error: {l2_final:.3e}, \t|\t Time: {end_ad:.2f}s, \t|\t Epochs: {epochs_to_threshold_ad}")


            # # --- PLOTTING!! ---
            # if cfg['refine'] != 0:
            #     plot_collocation_evolution(pos_hist, ep_hist_ad)

            # plot_pde_bc_loss(pde_loss_hist, bc_loss_hist)
            # # # print(f"PDE loss: {pde_loss_hist[-1]:.3e}, \t|\t BC loss: {bc_loss_hist[-1]:.3e}")
            # # plot_training_and_solution(
            # #     loss_hist, l2_hist_ad, x_test, u_test, model
            # # )

            # plot_spatial_loss_evolution(pde_grid, l2_grid, x_eval_grid, ep_hist_pred)


            # Store the results
            # ! Lists take up a lot of space so only store them when necessary
            result_rows.append({
                'case':    cfg['case'],
                'name':    cfg['name'],
                'method':  method_label,
                'NN':      f"{cfg['n_layers']}x{cfg['n_neurons']}" if custom_arch else '2x20',
                'n_dom':   cfg['n_dom'],
                'refine':  cfg['refine'],
                'l2_error': l2_final,
                # 'relaxation': cfg.get('relaxation', 1.0),
                'l2_hist': l2_hist_ad,
                # 'loss_hist': loss_hist,
                # 'x_test':  x_test,
                # 'u_test':  u_test,
                # 'pred_hist':  pred_hist,
                'time':    end_ad,
                'epochs':  epochs_to_threshold_ad,
                # 'pos_hist': pos_hist,
                # 'include_boundaries': cfg.get('boundary_in_xdom', False),
                'pde_loss_hist': pde_loss_hist[-1],
                'bc_loss_hist': bc_loss_hist[-1],
                # 'delay_adapt': cfg.get('delay_adapt', 0),
                # 'stifness': cfg.get('spring_stifness', 0.5),
            })

            l2_average.append(l2_final)

        if l2_average:
            print(f"\nAverage L2 error across all runs: {np.mean(l2_average):.3e}")
        if i:
            print(f"Average time per run: {ave_time / i:.2f} seconds")

    results_df = pd.DataFrame(result_rows, columns=result_columns)

    # Save results to CSV or pickle
    # results_df.to_csv(os.path.join(results_dir, 'results.csv'), index=False)
    results_df.to_pickle(os.path.join(results_dir, 'results.pkl'))