# Physics-Informed Neural Network with Kolmogorov–Arnold Networks (PINN-KAN) for the Allen–Cahn Equation



This notebook solves the Allen–Cahn equation by combining:
- **PINNs (Physics-Informed Neural Networks)** to incorporate physical laws into the training process,
- **KANs (Kolmogorov–Arnold Networks)** as flexible and interpretable neural function approximators,
- **Radial Basis Functions (RBFs)** for effective spatial feature representation.

---

### Notebook Structure

This notebook is organized into ten major stages:

1. **RBF Components** – Define radial basis kernels for spatial encoding
2. **Model Architectures** – Build KAN and PINN models
3. **Physics Residuals** – Formulate the Allen–Cahn PDE residual
4. **Loss Functions** – Combine PDE, boundary, and initial condition losses
5. **Training** – Optimize the model using gradient-based methods
6. **Evaluation** – Assess model accuracy and residual consistency
7. **Visualization** – Compare predicted vs. true field solutions
8. **Main Pipeline** – Integrate all steps into a unified workflow
9. **Run Experiment** – Execute multiple configurations for comparison
10. **Diagnostics & Summary** – Analyze, compare, and interpret final results




In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from typing import Dict, List, Tuple, Optional
import os
import pickle
from scipy.interpolate import griddata
import logging

# Configure a module-level logger; guard against adding duplicate handlers on cell re-runs
logger = logging.getLogger("pinn_logger")
if not logger.handlers:
    handler = logging.StreamHandler()
    formatter = logging.Formatter("%(message)s")
    handler.setFormatter(formatter)
    logger.addHandler(handler)
logger.setLevel(logging.INFO)
# Do not also forward records to the root logger, which would print each message twice
logger.propagate = False

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


## Step 1: Radial Basis Function (RBF) Components
We begin by defining the RBF kernels that help represent the spatial structure of the Allen–Cahn field.

These functions will later be used to build flexible spatial embeddings for the model inputs.

**Key Concepts:**
- RBFs are used to map spatial inputs into higher-dimensional feature spaces.
- Useful for capturing local patterns and smooth variations.
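As a minimal illustration of the idea, the sketch below expands 2D inputs into Gaussian RBF features with NumPy. The function name `gaussian_rbf_features` and the hand-picked centers and widths are assumptions made for this example; the notebook's own layer learns its centers and widths during training.

```python
import numpy as np

def gaussian_rbf_features(x, centers, sigmas):
    """Map inputs (batch, d) to Gaussian RBF features (batch, num_rbfs)."""
    # Broadcast (batch, 1, d) against (1, num_rbfs, d) -> (batch, num_rbfs, d)
    sq_dist = (x[:, None, :] - centers[None, :, :]) ** 2
    # Scale each dimension by its own width, then sum over dimensions
    return np.exp(-(sq_dist / (2.0 * sigmas ** 2)).sum(axis=-1))

# Hypothetical example: three centers in the (x, t) plane, unit widths
centers = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
sigmas = np.ones_like(centers)
points = np.array([[0.0, 0.0], [0.5, 0.5]])

phi = gaussian_rbf_features(points, centers, sigmas)
print(phi.shape)  # (2, 3)
print(phi[0])     # the first point sits on the first center, so phi[0, 0] is 1.0
```

Each column responds most strongly near its own center, which is what lets a linear layer on top of these features represent local structure.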


In [2]:

class RBFEdge(nn.Module):
    """Gaussian Radial Basis Function layer."""
    def __init__(self, input_dim: int, num_rbfs: int):
        super().__init__()
        self.centers = nn.Parameter(torch.randn(num_rbfs, input_dim))
        self.sigmas = nn.Parameter(torch.ones(num_rbfs, input_dim))

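    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, input_dim) -> (batch, 1, input_dim), broadcast against centers (num_rbfs, input_dim)
        expanded = (x.unsqueeze(1) - self.centers) ** 2
        # Per-dimension Gaussian widths; sigmas are unconstrained learnable parameters
        scaled = expanded / (2 * (self.sigmas ** 2))
        # Sum over input dimensions -> (batch, num_rbfs)
        return torch.exp(-scaled.sum(dim=-1))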


class KANLayer(nn.Module):
    """
    KANLayer: maps from input_dim -> output_dim using an RBF expansion
    num_rbfs controls the size of the intermediate RBF feature map.
    """
    def __init__(self, input_dim: int, num_rbfs: int, output_dim: int):
        super().__init__()
        self.rbf_edge = RBFEdge(input_dim=input_dim, num_rbfs=num_rbfs)
        self.linear = nn.Linear(num_rbfs, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        phi = self.rbf_edge(x)
        return self.linear(phi)

## Step 2: Model Architectures
In this section, we define the model architectures for both:
- **Kolmogorov–Arnold Network (KAN)** layers  
- **Physics-Informed Neural Network (PINN)** components  

These models approximate the solution to the PDE while embedding physics constraints directly into the training objective.


In [3]:
class PINN_KAN(nn.Module):
    """
    Physics-Informed Kolmogorov–Arnold Network:
        - Stacked KAN layers (Gaussian RBF expansion + linear map)
        - LayerNorm after each KAN layer for stable PDE training
        - Tanh activation for smooth second derivatives
        - Linear skip connection from input to output, scaled by 0.1
        - Per-layer RBF counts set by num_rbfs_list
    """

    def __init__(self,
                 input_dim: int = 2,
                 num_rbfs_list: List[int] = [32, 48, 32],
                 out_dim: int = 1):

        super().__init__()

        layers = []
        in_dim = input_dim

        for num_rbfs in num_rbfs_list:
            layers.append(KANLayer(
                input_dim=in_dim,
                num_rbfs=num_rbfs,
                output_dim=num_rbfs
            ))
            layers.append(nn.LayerNorm(num_rbfs))
            layers.append(nn.Tanh())
            in_dim = num_rbfs

        layers.append(nn.Linear(in_dim, out_dim))
        self.model = nn.Sequential(*layers)
        self.skip = nn.Linear(input_dim, out_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass:
        u = F_KAN(x) + α * skip(x), with α fixed at 0.1
        """
        return self.model(x) + 0.1 * self.skip(x)


class VanillaMLP(nn.Module):
    """Pure data-driven MLP (no physics loss)."""
    def __init__(self, input_dim: int = 2, hidden_dims: List[int] = [64, 64, 32], out_dim: int = 1):
        super().__init__()
        layers = []

        layers.append(nn.Linear(input_dim, hidden_dims[0]))
        layers.append(nn.Tanh())

        for i in range(len(hidden_dims) - 1):
            layers.append(nn.Linear(hidden_dims[i], hidden_dims[i+1]))
            layers.append(nn.Tanh())

        layers.append(nn.Linear(hidden_dims[-1], out_dim))

        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


class VanillaPINN(nn.Module):
    """Vanilla PINN (MLP + Physics loss)."""
    def __init__(self, input_dim: int = 2, hidden_dims: List[int] = [64, 64, 32], out_dim: int = 1):
        super().__init__()
        layers = []

        layers.append(nn.Linear(input_dim, hidden_dims[0]))
        layers.append(nn.Tanh())

        for i in range(len(hidden_dims) - 1):
            layers.append(nn.Linear(hidden_dims[i], hidden_dims[i+1]))
            layers.append(nn.Tanh())

        layers.append(nn.Linear(hidden_dims[-1], out_dim))

        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

print("✅ Model architectures loaded")

✅ Model architectures loaded
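
The sizes of these architectures can be checked by hand. The sketch below counts trainable parameters from the layer definitions above in plain Python; the helper names (`linear_params`, `kan_layer_params`, and so on) are introduced here for illustration only. For the widths used later in the pipeline, `[30, 40, 30]` RBFs and `[64, 64, 32]` hidden units, the counts come to 8,653 and 6,465, which agrees with the parameter counts printed in the experiment log.

```python
def linear_params(n_in, n_out, bias=True):
    """Weights plus optional bias of a dense layer."""
    return n_in * n_out + (n_out if bias else 0)

def kan_layer_params(n_in, num_rbfs, n_out):
    """RBF centers and widths (num_rbfs x n_in each) plus the linear read-out."""
    return 2 * num_rbfs * n_in + linear_params(num_rbfs, n_out)

def pinn_kan_params(input_dim, num_rbfs_list, out_dim):
    total, n_in = 0, input_dim
    for num_rbfs in num_rbfs_list:
        total += kan_layer_params(n_in, num_rbfs, num_rbfs)
        total += 2 * num_rbfs  # LayerNorm scale and shift
        n_in = num_rbfs
    total += linear_params(n_in, out_dim)                   # output head
    total += linear_params(input_dim, out_dim, bias=False)  # skip connection
    return total

def mlp_params(input_dim, hidden_dims, out_dim):
    dims = [input_dim, *hidden_dims, out_dim]
    return sum(linear_params(a, b) for a, b in zip(dims[:-1], dims[1:]))

print(pinn_kan_params(2, [30, 40, 30], 1))  # 8653
print(mlp_params(2, [64, 64, 32], 1))       # 6465
```

Most of the PINN-KAN parameters sit in the RBF centers and widths of the second and third layers, whose input dimension is the previous layer's width rather than 2.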


## Step 3: Physics Residuals

Here we define the Allen–Cahn PDE. In display form:

$$
u_t = \varepsilon^{2}\, u_{xx} - f(u)
$$

Equivalently, the physics residual (which the PINN minimizes) can be written as

$$
\mathcal{R}(x,t) \;=\; u_t(x,t) \;-\; \varepsilon^{2}\,u_{xx}(x,t) \;+\; f\big(u(x,t)\big).
$$

A commonly used choice for the nonlinear reaction term is the double-well potential derivative:

$$
f(u) = u^{3} - u,
$$

so that the PDE becomes

$$
u_t = \varepsilon^{2}\,u_{xx} - (u^{3} - u).
$$

**Purpose:**  
The residual enforces physical consistency: training drives \(\mathcal{R}(x,t)\) toward zero at the collocation points across the space–time domain.
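
A quick sanity check on the residual formula, independent of any network: the stationary profile \(u(x) = \tanh\big(x / (\sqrt{2}\,\varepsilon)\big)\) is a steady solution of this equation with \(f(u) = u^{3} - u\), so its residual should vanish. The sketch below evaluates \(\mathcal{R}\) with central finite differences. The value \(\varepsilon = 0.1\), the step size, and the helper names are choices made for this illustration, not settings from the notebook.

```python
import math

def kink(x, eps):
    """Stationary Allen-Cahn profile tanh(x / (sqrt(2) * eps))."""
    return math.tanh(x / (math.sqrt(2.0) * eps))

def residual_stationary(x, eps, h=1e-3):
    """R = u_t - eps^2 * u_xx + u^3 - u, with u_t = 0 for a steady profile."""
    u = kink(x, eps)
    # Second derivative via a central finite difference
    u_xx = (kink(x + h, eps) - 2.0 * u + kink(x - h, eps)) / h**2
    return -eps**2 * u_xx + u**3 - u

eps = 0.1
xs = [-1.0 + 0.01 * i for i in range(201)]  # grid on [-1, 1]
max_res = max(abs(residual_stationary(x, eps)) for x in xs)
print(f"max |R| = {max_res:.2e}")  # small, limited by finite-difference error
```

The same kind of check can be run against any residual implementation before training, since a sign error in the reaction term shows up immediately as a large residual on this profile.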


In [4]:
def allen_cahn_pde_residual(model: nn.Module, x: torch.Tensor, t: torch.Tensor, epsilon: float) -> torch.Tensor:
    """
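    Stabilized Allen–Cahn residual for
        u_t = ε² u_xx - (u³ - u)
    Includes:
        - shape-safe handling for x, t
        - u clamped to [-1.5, 1.5] for stability
        - small stabilizer term λu with λ = 1e-4
        - residual clipped to [-5, 5]
    Note: wherever the clamp, stabilizer, or clipping is active, the returned
    value is not the exact PDE residual.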
    """

    if x.dim() == 1:
        x = x.unsqueeze(-1)
    if t.dim() == 1:
        t = t.unsqueeze(-1)

    X = torch.cat([x, t], dim=1).clone().detach().requires_grad_(True)

    u = model(X)
    u = torch.clamp(u, -1.5, 1.5)

    grads = torch.autograd.grad(
        outputs=u, inputs=X, grad_outputs=torch.ones_like(u),
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]

    u_x = grads[:, 0:1]
    u_t = grads[:, 1:2]

    u_xx = torch.autograd.grad(
        outputs=u_x, inputs=X, grad_outputs=torch.ones_like(u_x),
        create_graph=True, retain_graph=True, only_inputs=True
    )[0][:, 0:1]

    stabilizer = 1e-4 * u
    residual = u_t - (epsilon**2) * u_xx + u**3 - u + stabilizer
    residual = torch.clamp(residual, -5.0, 5.0)

    return residual

print("✅ Improved Physics residual function loaded")


# ==================== LOSS FUNCTIONS ====================

def initial_condition_loss(model: nn.Module, x_ic: torch.Tensor, u_ic_true: torch.Tensor) -> torch.Tensor:
    """
    MSE loss enforcing u(x,0) = u0(x)
    x_ic: shape (N_ic, 1)
    u_ic_true: shape (N_ic, 1)
    """
    t_ic = torch.zeros_like(x_ic).to(x_ic.device)
    X_ic = torch.cat([x_ic, t_ic], dim=1)
    u_pred = model(X_ic)
    return nn.MSELoss()(u_pred, u_ic_true)


def boundary_condition_loss(model: nn.Module, x_bc: torch.Tensor, t_bc: torch.Tensor, u_bc_true: torch.Tensor) -> torch.Tensor:
    """
    Enforce Dirichlet boundary values: u(x=±L, t) = known value (0 in this dataset)
    x_bc: (N_bc,1), t_bc: (N_bc,1), u_bc_true: (N_bc,1)
    """
    X_bc = torch.cat([x_bc, t_bc], dim=1).to(x_bc.device)
    u_pred = model(X_bc)
    return nn.MSELoss()(u_pred, u_bc_true)


def data_loss(model: nn.Module, X: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Pure data loss (for Vanilla MLP)."""
    pred = model(X)
    return nn.MSELoss()(pred, y)

✅ Improved Physics residual function loaded


## Step 4: Loss Function Definitions
We now define the composite loss, a weighted sum of:
- **Data loss** (fit to the reference solution samples)
- **PDE residual loss**
- **Initial condition (IC) loss**
- **Boundary condition (BC) loss**

The total loss guides the model toward both data fidelity and physical correctness.


In [5]:
def pinn_loss(model: nn.Module,
              X_data: torch.Tensor, y_data: torch.Tensor,
              X_coll: torch.Tensor, epsilon: float,
              X_ic: Optional[torch.Tensor] = None,
              y_ic: Optional[torch.Tensor] = None,
              X_bc: Optional[torch.Tensor] = None,
              y_bc: Optional[torch.Tensor] = None,
              alpha: float = 0.1,
              beta: float = 1.0,
              gamma_ic: float = 200.0,
              gamma_bc: float = 50.0
              ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
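    Composite PINN loss:
        - data loss (weight alpha, small)
        - physics residual loss (weight grows linearly with the epoch)
        - IC loss (weight gamma_ic, strong)
        - BC loss (weight gamma_bc, moderate)
    PDE curriculum: the physics weight is beta * (1 + 0.001 * epoch_num),
    where epoch_num is a module-level global set by train_pinn.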
    """

    pred = model(X_data)
    loss_data = nn.MSELoss()(pred, y_data)

    x_col = X_coll[:, 0]
    t_col = X_coll[:, 1]
    residual = allen_cahn_pde_residual(model, x_col, t_col, epsilon)
    loss_physics = torch.mean(residual**2)

    loss_ic = torch.tensor(0.0, device=X_data.device)
    if X_ic is not None and y_ic is not None:
        loss_ic = nn.MSELoss()(model(X_ic), y_ic)

    loss_bc = torch.tensor(0.0, device=X_data.device)
    if X_bc is not None and y_bc is not None:
        loss_bc = nn.MSELoss()(model(X_bc), y_bc)

    global epoch_num
    if "epoch_num" not in globals():
        epoch_num = 0

    pde_weight = beta * (1.0 + 0.001 * epoch_num)

    total_loss = (
        alpha * loss_data +
        pde_weight * loss_physics +
        gamma_ic * loss_ic +
        gamma_bc * loss_bc
    )

    return total_loss, loss_data, loss_physics, loss_ic, loss_bc

print("✅ Loss functions loaded")



✅ Loss functions loaded
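
To see how the weights shape the total, the sketch below recombines the per-term losses by hand. The function `weighted_total` is a stand-in written for this example. The input values are the rounded per-term losses printed for the first Adam epoch of PINN-KAN in the experiment log further down, so the result agrees with the logged total of 9.906e+01 only up to rounding.

```python
def weighted_total(loss_data, loss_physics, loss_ic, loss_bc, epoch,
                   alpha=0.1, beta=1.0, gamma_ic=200.0, gamma_bc=50.0):
    """Recombine per-term losses the way pinn_loss does."""
    pde_weight = beta * (1.0 + 0.001 * epoch)  # linear curriculum on the PDE term
    return (alpha * loss_data + pde_weight * loss_physics
            + gamma_ic * loss_ic + gamma_bc * loss_bc)

# Rounded values from the log line "[Adam] Epoch 1/2000" (epoch index 0)
total = weighted_total(5.80e-01, 9.24e-02, 4.49e-01, 1.82e-01, epoch=0)
print(f"{total:.2f}")  # 99.05

# Share of the weighted IC term in the total
ic_share = 200.0 * 4.49e-01 / total
print(f"IC share: {ic_share:.0%}")  # about 91%
```

With these defaults the initial-condition term carries roughly nine tenths of the objective at the first epoch, which is worth keeping in mind when reading the loss curves.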


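## Step 5: Training

After defining all components, we train the models with gradient-based optimizers: the data-only MLP uses Adam, while the PINNs use Adam with gradient clipping followed by an L-BFGS refinement.

Each training function returns per-epoch loss histories. Later steps use them, together with the trained models, to visualize:
- The **predicted vs. true** Allen–Cahn field,
- The **error map** across the domain,
- The **training loss curve** over iterations.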

In [6]:
def train_vanilla_mlp(model: nn.Module, X_data: torch.Tensor, y_data: torch.Tensor,
                      epochs: int = 2000, lr: float = 1e-3,
                      print_every: int = 200) -> Dict[str, List[float]]:
    """Train vanilla MLP (data-driven only)."""
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_history = []

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        loss = data_loss(model, X_data, y_data)
        loss.backward()
        optimizer.step()

        loss_history.append(loss.item())

        if (epoch + 1) % print_every == 0 or epoch == 0:
            logger.info(f"Epoch {epoch+1}/{epochs} | Loss: {loss.item():.6e}")

    return {'loss': loss_history}


# ==================== IMPROVED TRAINING LOOP (ADAM + LBFGS) ====================
def train_pinn(model: nn.Module, X_data: torch.Tensor, y_data: torch.Tensor,
               X_coll: torch.Tensor, epsilon: float,
               X_ic: Optional[torch.Tensor] = None, y_ic: Optional[torch.Tensor] = None,
               X_bc: Optional[torch.Tensor] = None, y_bc: Optional[torch.Tensor] = None,
               epochs: int = 2000, lr: float = 1e-3,
               print_every: int = 200) -> Dict[str, List[float]]:
    """Train PINN with improved Adam + LBFGS optimization."""

    global epoch_num
    epoch_num = 0

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    loss_history = []
    data_loss_history = []
    physics_loss_history = []
    ic_loss_history = []
    bc_loss_history = []

    print("\n=== Starting Adam optimizer phase ===")

    for epoch in range(epochs):
        epoch_num = epoch

        optimizer.zero_grad()

        total, d_loss, p_loss, ic_loss, bc_loss = pinn_loss(
            model,
            X_data, y_data,
            X_coll, epsilon,
            X_ic=X_ic, y_ic=y_ic,
            X_bc=X_bc, y_bc=y_bc,
            alpha=0.1,
            beta=1.0,
            gamma_ic=200.0,
            gamma_bc=50.0
        )

        total.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
        optimizer.step()

        loss_history.append(total.item())
        data_loss_history.append(d_loss.item())
        physics_loss_history.append(p_loss.item())
        ic_loss_history.append(ic_loss.item() if isinstance(ic_loss, torch.Tensor) else ic_loss)
        bc_loss_history.append(bc_loss.item() if isinstance(bc_loss, torch.Tensor) else bc_loss)

        if (epoch + 1) % print_every == 0 or epoch == 0:
            print(f"[Adam] Epoch {epoch+1}/{epochs} | "
                  f"Total: {total.item():.3e} | "
                  f"Data: {d_loss.item():.2e} | Physics: {p_loss.item():.2e} | "
                  f"IC: {ic_loss.item() if isinstance(ic_loss, torch.Tensor) else ic_loss:.2e} | "
                  f"BC: {bc_loss.item() if isinstance(bc_loss, torch.Tensor) else bc_loss:.2e}")

    print("\n=== Starting LBFGS refinement ===")

    lbfgs = torch.optim.LBFGS(model.parameters(),
                              max_iter=500,
                              tolerance_grad=1e-9,
                              tolerance_change=1e-9,
                              history_size=50)

    def closure():
        lbfgs.zero_grad()
        total, *_ = pinn_loss(
            model,
            X_data, y_data,
            X_coll, epsilon,
            X_ic=X_ic, y_ic=y_ic,
            X_bc=X_bc, y_bc=y_bc,
            alpha=0.1,
            beta=1.0,
            gamma_ic=200.0,
            gamma_bc=50.0
        )
        total.backward()
        return total

    lbfgs.step(closure)
    print("LBFGS complete.\n")

    return {
        "loss": loss_history,
        "data_loss": data_loss_history,
        "physics_loss": physics_loss_history,
        "ic_loss": ic_loss_history,
        "bc_loss": bc_loss_history
    }

print("✅ Training functions loaded")


✅ Training functions loaded
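
The Adam phase clips gradients to a global norm of 2.0 before each step. The sketch below shows what global-norm clipping does to a list of gradient arrays, using NumPy. `clip_by_global_norm` is an illustrative helper written for this example, not the PyTorch function, and is only meant to convey the idea.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Rescale all gradients jointly so their combined L2 norm is <= max_norm."""
    total_norm = np.sqrt(sum(float(np.sum(g ** 2)) for g in grads))
    # Only shrink; gradients already inside the ball are left untouched
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    return [g * scale for g in grads], total_norm

grads = [np.array([3.0, 0.0]), np.array([[0.0, 4.0]])]  # global norm = 5
clipped, norm_before = clip_by_global_norm(grads, max_norm=2.0)
norm_after = np.sqrt(sum(float(np.sum(g ** 2)) for g in clipped))

print(norm_before)           # 5.0
print(round(norm_after, 4))  # 2.0
```

Because every array is multiplied by the same factor, the update direction is preserved and only its length changes.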


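## Step 6: Evaluation

We assess each model with data-fit metrics (RMSE, MAE, relative L2 error) and with statistics of the PDE residual at the collocation points.

In [7]: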
def compute_metrics(model: nn.Module, X: torch.Tensor, y: torch.Tensor,
                   X_coll: torch.Tensor, epsilon: float) -> Dict[str, float]:
    """Compute comprehensive evaluation metrics."""
    model.eval()
    with torch.no_grad():
        u_pred = model(X)

        mse = torch.mean((u_pred - y)**2)
        rmse = torch.sqrt(mse)
        mae = torch.mean(torch.abs(u_pred - y))
        rel_l2_error = torch.norm(u_pred - y) / torch.norm(y)

    X_coll_eval = X_coll.clone().detach().requires_grad_(True)
    x_col = X_coll_eval[:, 0]
    t_col = X_coll_eval[:, 1]

    residual = allen_cahn_pde_residual(model, x_col, t_col, epsilon)

    residual_l2_norm = torch.norm(residual, p=2)
    residual_l2_normalized = residual_l2_norm / np.sqrt(len(residual))
    max_residual = torch.max(torch.abs(residual))
    mean_residual = torch.mean(torch.abs(residual))

    metrics = {
        'RMSE': rmse.item(),
        'MAE': mae.item(),
        'Relative_L2_Error': rel_l2_error.item(),
        'Residual_L2_Norm': residual_l2_normalized.item(),
        'Max_Residual': max_residual.item(),
        'Mean_Residual': mean_residual.item()
    }

    return metrics

print("✅ Evaluation functions loaded")


✅ Evaluation functions loaded
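
For reference, the accuracy metrics above reduce to a few array operations. The sketch below reproduces RMSE, MAE, and relative L2 error in NumPy on a tiny made-up example; the arrays are illustrative values, not data from the experiment, and `error_metrics` is a name chosen for this sketch.

```python
import numpy as np

def error_metrics(u_pred, u_true):
    """RMSE, MAE, and relative L2 error between prediction and reference."""
    diff = u_pred - u_true
    return {
        "RMSE": float(np.sqrt(np.mean(diff ** 2))),
        "MAE": float(np.mean(np.abs(diff))),
        # Ratio of Euclidean norms; undefined if the reference is all zeros
        "Relative_L2_Error": float(np.linalg.norm(diff) / np.linalg.norm(u_true)),
    }

u_true = np.array([1.0, 2.0, 5.0])
u_pred = np.array([1.0, 2.0, 3.0])
metrics = error_metrics(u_pred, u_true)
print(metrics)
```

Here a single error of 2 out of three samples gives an RMSE of sqrt(4/3), an MAE of 2/3, and a relative L2 error of 2/sqrt(30).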


## Step 7: Visualization of Results

Here, we visualize:
- The predicted field `u_pred(x,t)` over the domain,
- The true/reference field `u_true(x,t)`, and
- The absolute error distribution `|u_pred - u_true|`.

These visualizations help confirm whether the learned dynamics replicate the Allen–Cahn diffusion-reaction pattern formation.
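
Reshaping scattered `(x, t, u)` samples into a 2D array for a contour plot only works if the samples are ordered consistently with the grid. One way to guarantee this, sorting by `t` and then by `x` before reshaping, is sketched below. The helper `to_grid` and the synthetic field are assumptions made for illustration.

```python
import numpy as np

def to_grid(x, t, u):
    """Arrange flat samples on a full (x, t) grid into a (len(t), len(x)) array."""
    x_unique, t_unique = np.unique(x), np.unique(t)
    if x.size != x_unique.size * t_unique.size:
        raise ValueError("samples do not form a complete grid")
    # lexsort uses the last key as the primary one: sort by t, break ties by x
    order = np.lexsort((x, t))
    return x_unique, t_unique, u[order].reshape(t_unique.size, x_unique.size)

# Synthetic field u = x + 10 t on a 3 x 2 grid, deliberately shuffled
xg, tg = np.meshgrid(np.array([0.0, 0.5, 1.0]), np.array([0.0, 1.0]))
x, t = xg.ravel(), tg.ravel()
perm = np.array([4, 0, 5, 2, 1, 3])
x, t = x[perm], t[perm]
u = x + 10.0 * t

_, _, U = to_grid(x, t, u)
print(U)
```

Rows of `U` correspond to increasing `t` and columns to increasing `x`, which matches the layout produced by `np.meshgrid(x_unique, t_unique)`.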




In [8]:
def plot_prediction_vs_actual(model: nn.Module, X: torch.Tensor, y: torch.Tensor,
                              title: str = "Predictions vs Actual"):
    """Scatter plot of predictions vs actual values."""
    model.eval()
    with torch.no_grad():
        preds = model(X).cpu().numpy()
        actual = y.cpu().numpy()

    plt.figure(figsize=(8, 6))
    plt.scatter(actual, preds, alpha=0.5)
    plt.xlabel('Actual u', fontsize=12)
    plt.ylabel('Predicted u', fontsize=12)
    plt.title(title, fontsize=14, fontweight='bold')
    plt.plot([actual.min(), actual.max()], [actual.min(), actual.max()], 'k--')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(f"{title.replace(' ', '_').replace(':', '')}_scatter.png")
    plt.close() # Close plot to prevent it from displaying in a loop

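def plot_residual_surface(model: nn.Module, X_coll: torch.Tensor, epsilon: float):
    """3D scatter of |PDE residual| at the collocation points."""
    model.eval()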
    x = X_coll[:, 0].detach().cpu().numpy()
    t = X_coll[:, 1].detach().cpu().numpy()

    res_tensor = allen_cahn_pde_residual(
        model,
        X_coll[:, 0],
        X_coll[:, 1],
        epsilon
    )
    # Flatten to 1D: the 3D scatter needs 1D arrays for its z and c arguments
    res = res_tensor.detach().cpu().numpy().flatten()

    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(111, projection='3d')

    ax.scatter(x, t, np.abs(res), s=2, c=np.abs(res), cmap='hot')
    ax.set_xlabel("x")
    ax.set_ylabel("t")
    ax.set_zlabel("|Residual|")
    ax.set_title("PDE Residual Surface")

    plt.savefig(f"residual_surface_{model.__class__.__name__}.png")
    plt.close() # Close plot


def plot_solution_heatmaps(model: nn.Module, X: torch.Tensor, y: torch.Tensor,
                          title_prefix: str = ""):
    """Plot predicted, true, and error heatmaps."""
    model.eval()
    with torch.no_grad():
        u_pred = model(X).cpu().numpy()

    x_vals = X[:, 0].cpu().numpy()
    t_vals = X[:, 1].cpu().numpy()
    x_unique = np.unique(x_vals)
    t_unique = np.unique(t_vals)

    X_grid, T_grid = np.meshgrid(x_unique, t_unique)

    # Handle non-grid data
    if X_grid.size != u_pred.size:
        print(f"Warning: Interpolating for heatmap. Grid size {X_grid.size} != Data size {u_pred.size}")
        U_pred_grid = griddata((x_vals, t_vals), u_pred.flatten(), (X_grid, T_grid), method='cubic')
        U_true_grid = griddata((x_vals, t_vals), y.cpu().numpy().flatten(), (X_grid, T_grid), method='cubic')
    else:
        U_pred_grid = u_pred.reshape(len(t_unique), len(x_unique))
        U_true_grid = y.cpu().numpy().reshape(len(t_unique), len(x_unique))

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    im1 = axes[0].contourf(X_grid, T_grid, U_pred_grid, levels=50, cmap='viridis')
    axes[0].set_title(f'{title_prefix} Predicted Solution u(x,t)',
                      fontsize=14, fontweight='bold')
    axes[0].set_xlabel('x', fontsize=12)
    axes[0].set_ylabel('t', fontsize=12)
    plt.colorbar(im1, ax=axes[0], label='u')

    im2 = axes[1].contourf(X_grid, T_grid, U_true_grid, levels=50, cmap='viridis')
    axes[1].set_title('True Solution u(x,t)', fontsize=14, fontweight='bold')
    axes[1].set_xlabel('x', fontsize=12)
    axes[1].set_ylabel('t', fontsize=12)
    plt.colorbar(im2, ax=axes[1], label='u')

    error = np.abs(U_pred_grid - U_true_grid)
    im3 = axes[2].contourf(X_grid, T_grid, error, levels=50, cmap='hot')
    axes[2].set_title(f'{title_prefix} Absolute Error',
                      fontsize=14, fontweight='bold')
    axes[2].set_xlabel('x', fontsize=12)
    axes[2].set_ylabel('t', fontsize=12)
    plt.colorbar(im3, ax=axes[2], label='Error')

    plt.tight_layout()
    plt.savefig(f"{title_prefix.replace(' ', '_')}_solution_heatmaps.png")
    plt.close() # Close plot


def plot_residual_heatmap(model: nn.Module, X_coll: torch.Tensor, epsilon: float,
                         title_prefix: str = ""):
    """Plot physics residual heatmap."""
    X_coll_eval = X_coll.clone().detach().requires_grad_(True)

    model.eval()
    x_col = X_coll_eval[:, 0]
    t_col = X_coll_eval[:, 1]

    residual = allen_cahn_pde_residual(model, x_col, t_col, epsilon)

    x_coll_np = X_coll[:, 0].cpu().numpy()
    t_coll_np = X_coll[:, 1].cpu().numpy()
    residual_np = residual.detach().cpu().numpy()

    X_grid, T_grid = np.meshgrid(
        np.linspace(x_coll_np.min(), x_coll_np.max(), 100),
        np.linspace(t_coll_np.min(), t_coll_np.max(), 100)
    )
    residual_grid = griddata(
        (x_coll_np, t_coll_np),
        residual_np.flatten(),
        (X_grid, T_grid),
        method='cubic'
    )

    plt.figure(figsize=(10, 6))
    im = plt.contourf(X_grid, T_grid, np.abs(residual_grid), levels=50, cmap='hot')
    plt.colorbar(im, label='|Residual|')
    plt.title(f'{title_prefix} Physics Residual Heatmap',
              fontsize=14, fontweight='bold')
    plt.xlabel('x', fontsize=12)
    plt.ylabel('t', fontsize=12)
    plt.tight_layout()
    plt.savefig(f"{title_prefix.replace(' ', '_')}_residual_heatmap.png")
    plt.close() # Close plot

    print(f"Max absolute residual: {np.abs(residual_np).max():.6e}")
    print(f"Mean absolute residual: {np.abs(residual_np).mean():.6e}")


def plot_loss_curves(loss_dict: Dict[str, List[float]], title: str = "Training Loss"):
    """Plot training loss curves."""
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    axes[0].plot(loss_dict['loss'], label='Total Loss', linewidth=2)
    if 'data_loss' in loss_dict:
        axes[0].plot(loss_dict['data_loss'], label='Data Loss', linewidth=2, alpha=0.7)
        axes[0].plot(loss_dict['physics_loss'], label='Physics Loss', linewidth=2, alpha=0.7)
    axes[0].set_xlabel('Epoch', fontsize=12)
    axes[0].set_ylabel('Loss', fontsize=12)
    axes[0].set_title(f'{title} (Log Scale)', fontsize=14, fontweight='bold')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    axes[0].set_yscale('log')

    n = min(500, len(loss_dict['loss']))
    axes[1].plot(loss_dict['loss'][-n:], label='Total Loss', linewidth=2)
    axes[1].set_xlabel('Epoch', fontsize=12)
    axes[1].set_ylabel('Loss', fontsize=12)
    axes[1].set_title(f'{title} (Last {n} Epochs)', fontsize=14, fontweight='bold')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f"{title.replace(' ', '_').replace(':', '')}_loss_curves.png")
    plt.close() # Close plot


def plot_comparison_bar_chart(results: Dict[str, Dict[str, float]]):
    """Bar chart comparing metrics across models."""
    models = list(results.keys())
    metrics = ['RMSE', 'MAE', 'Relative_L2_Error', 'Max_Residual']

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    axes = axes.flatten()

    for idx, metric in enumerate(metrics):
        values = [results[model][metric] for model in models]
        axes[idx].bar(models, values, color=['blue', 'orange', 'green'][:len(models)])
        axes[idx].set_ylabel(metric, fontsize=12)
        axes[idx].set_title(f'{metric} Comparison', fontsize=13, fontweight='bold')
        axes[idx].grid(True, alpha=0.3, axis='y')

        for i, v in enumerate(values):
            axes[idx].text(i, v, f'{v:.2e}', ha='center', va='bottom', fontsize=10)

    plt.tight_layout()
    plt.savefig("metrics_comparison_barchart.png")
    plt.close() # Close plot

print("✅ Visualization functions loaded")



✅ Visualization functions loaded


## Step 8: Main Pipeline Execution

This section orchestrates the entire workflow:
1. Initializes the dataset and domain,
2. Builds the model architecture,
3. Defines loss functions and PDE residuals,
4. Trains the model,
5. Evaluates and visualizes results.

It acts as a single entry point to reproduce all results from scratch, ensuring the process is **modular, repeatable, and scalable**.


In [9]:
def u0(x):
    return x**2 * np.cos(np.pi * x)

print("✅ Initial condition function u0(x) loaded")


# ==================== MAIN PIPELINE (FIXED) ====================

def run_allen_cahn_experiment(
    data_path: str,
    collocation_path: str,
    epsilon: float = 0.01,
    epochs: int = 2000,
    lr: float = 1e-3,
    save_dir: str = 'allen_cahn_results'
) -> Dict[str, object]:
    """
    Complete experimental pipeline for the Allen–Cahn equation.
    The PINN models (PINN-KAN, Vanilla-PINN) are trained, evaluated, and
    visualized on coordinates scaled to [-1, 1]; the Vanilla-MLP uses the
    unscaled coordinates.
    """

    print("="*70)
    print("ALLEN-CAHN EQUATION PINN-KAN EXPERIMENT")
    print("="*70)

    print("\n1. Loading data...")
    try:
        collocation_df = pd.read_csv(collocation_path)
        full_df = pd.read_csv(data_path)
    except FileNotFoundError as e:
        print(f"Error: Could not load data file. {e}")
        print("Please ensure paths are correct.")
        return {}

    X_collocation = torch.tensor(collocation_df[['x', 't']].values, dtype=torch.float32).to(device)
    X_full = torch.tensor(full_df[['x', 't']].values, dtype=torch.float32).to(device)
    y_full = torch.tensor(full_df['u'].values, dtype=torch.float32).unsqueeze(-1).to(device)

    print(f"   X_full shape: {X_full.shape}")
    print(f"   y_full shape: {y_full.shape}")
    print(f"   X_collocation shape: {X_collocation.shape}")

    x_min, x_max = X_full[:, 0].min(), X_full[:, 0].max()
    t_min, t_max = X_full[:, 1].min(), X_full[:, 1].max()

    def scale_X(X):
        Xs = X.clone()
        Xs[:, 0] = 2.0 * (X[:, 0] - x_min) / (x_max - x_min) - 1.0
        Xs[:, 1] = 2.0 * (X[:, 1] - t_min) / (t_max - t_min) - 1.0
        return Xs

    x_ic_np = np.unique(full_df['x'].values)
    x_ic = torch.tensor(x_ic_np.reshape(-1,1), dtype=torch.float32).to(device)
    y_ic_np = u0(x_ic_np).reshape(-1,1).astype(np.float32)
    y_ic = torch.tensor(y_ic_np, dtype=torch.float32).to(device)
    X_ic = torch.cat([x_ic, torch.zeros_like(x_ic)], dim=1)

    t_bc_np = np.unique(full_df['t'].values)
    x_left = np.full_like(t_bc_np, fill_value=full_df['x'].min()).reshape(-1,1)
    x_right = np.full_like(t_bc_np, fill_value=full_df['x'].max()).reshape(-1,1)
    t_bc = t_bc_np.reshape(-1,1).astype(np.float32)

    X_bc_left = torch.tensor(np.hstack([x_left, t_bc]), dtype=torch.float32).to(device)
    X_bc_right = torch.tensor(np.hstack([x_right, t_bc]), dtype=torch.float32).to(device)
    X_bc = torch.cat([X_bc_left, X_bc_right], dim=0)
    y_bc = torch.zeros((len(X_bc), 1), dtype=torch.float32, device=device)

    X_full_s = scale_X(X_full)
    X_coll_s = scale_X(X_collocation)
    X_ic_s   = scale_X(X_ic)
    X_bc_s   = scale_X(X_bc)

    print("\n2. Initializing models...")
    models = {
        'PINN-KAN': PINN_KAN(input_dim=2, num_rbfs_list=[30, 40, 30], out_dim=1).to(device),
        'Vanilla-MLP': VanillaMLP(input_dim=2, hidden_dims=[64, 64, 32], out_dim=1).to(device),
        'Vanilla-PINN': VanillaPINN(input_dim=2, hidden_dims=[64, 64, 32], out_dim=1).to(device)
    }

    for name, model in models.items():
        n_params = sum(p.numel() for p in model.parameters())
        print(f"   {name}: {n_params:,} parameters")

    print("\n3. Training models...")
    loss_histories = {}

    for name, model in models.items():
        print(f"\n{'='*70}")
        print(f"Training {name}")
        print(f"{'='*70}")

        if name == 'Vanilla-MLP':
            # Train the MLP on unscaled coordinates
            loss_dict = train_vanilla_mlp(model, X_full, y_full, epochs=epochs, lr=lr)
        else:
            # Train the PINNs on coordinates scaled to [-1, 1]
            loss_dict = train_pinn(
                model, X_full_s, y_full, X_coll_s, epsilon,
                X_ic=X_ic_s, y_ic=y_ic,
                X_bc=X_bc_s, y_bc=y_bc,
                epochs=epochs, lr=lr,
                print_every=200
            )

        loss_histories[name] = loss_dict

    print("\n4. Evaluating models...")
    all_metrics = {}

    for name, model in models.items():
        print(f"\nEvaluating {name}...")

        # Evaluate the MLP on unscaled coordinates and the PINNs on scaled ones
        if name == 'Vanilla-MLP':
            metrics = compute_metrics(model, X_full, y_full, X_collocation, epsilon)
        else:
            metrics = compute_metrics(model, X_full_s, y_full, X_coll_s, epsilon)

        all_metrics[name] = metrics

        print(f"   RMSE: {metrics['RMSE']:.6e}")
        print(f"   MAE: {metrics['MAE']:.6e}")
        print(f"   Max Residual: {metrics['Max_Residual']:.6e}")

    print(f"\n5. Generating visualizations... (saving to '{save_dir}')")
    os.makedirs(save_dir, exist_ok=True)

    for name, model in models.items():
        print(f"\nVisualizations for {name}:")

        # Plot the MLP on unscaled coordinates and the PINNs on scaled ones
        if name == 'Vanilla-MLP':
            plot_prediction_vs_actual(model, X_full, y_full, title=f"{name}: Predictions vs Actual")
            plot_solution_heatmaps(model, X_full, y_full, title_prefix=name)
        else:
            plot_prediction_vs_actual(model, X_full_s, y_full, title=f"{name}: Predictions vs Actual")
            plot_solution_heatmaps(model, X_full_s, y_full, title_prefix=name)
            plot_residual_heatmap(model, X_coll_s, epsilon, title_prefix=name)

        plot_loss_curves(loss_histories[name], title=f"{name} Training")

    print("\nGenerating comparison chart...")
    plot_comparison_bar_chart(all_metrics)

    print("\nPlotting Residual Surfaces...")
    # Residual surfaces use the scaled collocation points
    plot_residual_surface(models["PINN-KAN"], X_coll_s, epsilon)
    plot_residual_surface(models["Vanilla-PINN"], X_coll_s, epsilon)

    print(f"\n6. Saving results to {save_dir}...")

    for name, model in models.items():
        model_path = os.path.join(save_dir, f"{name.lower().replace('-', '_')}_model.pth")
        torch.save(model.state_dict(), model_path)

    metrics_df = pd.DataFrame(all_metrics).T
    metrics_df.to_csv(os.path.join(save_dir, 'metrics_comparison.csv'))

    with open(os.path.join(save_dir, 'all_results.pkl'), 'wb') as f:
        pickle.dump({
            'metrics': all_metrics,
            'loss_histories': loss_histories
        }, f)

    print("\n" + "="*70)
    print("EXPERIMENT COMPLETED!")
    print("="*70)

    print("\nFINAL METRICS COMPARISON:")
    print(metrics_df.to_string())

    return {
        'models': models,
        'metrics': all_metrics,
        'loss_histories': loss_histories
    }

print("✅ Main pipeline loaded")



✅ Initial condition function u0(x) loaded
✅ Main pipeline loaded
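
The pipeline maps `x` and `t` to `[-1, 1]` before they reach the PINNs. Derivatives taken with respect to the scaled inputs differ from physical derivatives by constant factors: if \(\hat{x} = 2(x - x_{\min})/(x_{\max} - x_{\min}) - 1\), then \(u_x = s\,u_{\hat{x}}\) and \(u_{xx} = s^{2}\,u_{\hat{x}\hat{x}}\) with \(s = 2/(x_{\max} - x_{\min})\), and likewise for \(t\). The sketch below checks the second-derivative factor numerically on a test function. The domain bounds and the function are made up for the example, and whether a residual should carry these factors depends on the coordinates in which the PDE is meant to hold.

```python
import math

x_min, x_max = 0.0, 4.0          # hypothetical physical domain
s = 2.0 / (x_max - x_min)        # d(x_hat)/dx

def to_scaled(x):
    return s * (x - x_min) - 1.0

def to_physical(x_hat):
    return (x_hat + 1.0) / s + x_min

def u_phys(x):
    """Test function in physical coordinates."""
    return math.sin(3.0 * x)

def u_scaled(x_hat):
    """Same function expressed in scaled coordinates."""
    return u_phys(to_physical(x_hat))

def second_derivative(f, z, h=1e-4):
    """Central finite-difference estimate of f''(z)."""
    return (f(z + h) - 2.0 * f(z) + f(z - h)) / h**2

x0 = 1.3
d2_phys = second_derivative(u_phys, x0)
d2_scaled = second_derivative(u_scaled, to_scaled(x0))

print(d2_phys)            # close to -9 * sin(3.9)
print(s**2 * d2_scaled)   # same value after applying the chain-rule factor
```

Without the factor `s**2`, the two second derivatives differ by a factor of four on this domain, so a residual evaluated purely in scaled coordinates effectively uses a rescaled diffusion coefficient.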


## Step 9: Run Experiment

Here, we run the full pipeline. The cell below executes one configuration (ε = 0.01, 2000 epochs, learning rate 1e-3); other settings, such as different network sizes, learning rates, or basis functions, can be compared by re-running it with different arguments or model definitions.

**Goal:**  
To identify the most accurate and efficient configuration for solving the Allen–Cahn PDE.

Each run reports:
- Training loss (total and per term)
- PDE residual statistics
- Boundary and initial condition losses
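
A sweep over several settings with wall-clock timing could be organized as in the following sketch. Everything here is hypothetical scaffolding: `run_fn` stands in for a call such as `run_allen_cahn_experiment`, the grid values are arbitrary, and the dummy runner only echoes its arguments so that the example runs without data files.

```python
import itertools
import time

def sweep(run_fn, grid):
    """Call run_fn once per combination in grid and record the runtime."""
    keys = list(grid)
    records = []
    for values in itertools.product(*(grid[k] for k in keys)):
        config = dict(zip(keys, values))
        start = time.perf_counter()
        result = run_fn(**config)
        records.append({**config,
                        "runtime_s": time.perf_counter() - start,
                        "result": result})
    return records

def dummy_run(epsilon, lr, epochs):
    """Placeholder for the real pipeline; returns a made-up summary."""
    return {"n_steps": epochs, "tag": f"eps={epsilon}_lr={lr}"}

grid = {"epsilon": [0.01, 0.05], "lr": [1e-3, 1e-4], "epochs": [2000]}
records = sweep(dummy_run, grid)
print(len(records))  # 4 configurations
for r in records:
    print(r["result"]["tag"], f"{r['runtime_s']:.6f}s")
```

Collecting one record per configuration makes it straightforward to load the results into a table and compare metrics side by side.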


In [12]:
if __name__ == "__main__":
    # Run complete experiment
    # *** NOTE: You must have the data files at these paths ***
    results = run_allen_cahn_experiment(
        data_path='/content/allen_cahn_1d.csv',
        collocation_path='/content/allen_cahn_collocation.csv',
        epsilon=0.01,
        epochs=2000,
        lr=1e-3,
        save_dir='Results/allen_cahn_results'
    )

    # An empty dict means the data files could not be loaded
    if results:
        print("\n✅ All experiments completed successfully!")
        print("📊 Results saved in 'Results/allen_cahn_results/' directory")

ALLEN-CAHN EQUATION PINN-KAN EXPERIMENT

1. Loading data...
   X_full shape: torch.Size([51200, 2])
   y_full shape: torch.Size([51200, 1])
   X_collocation shape: torch.Size([20000, 2])

2. Initializing models...
   PINN-KAN: 8,653 parameters
   Vanilla-MLP: 6,465 parameters
   Vanilla-PINN: 6,465 parameters

3. Training models...

Training PINN-KAN

=== Starting Adam optimizer phase ===
[Adam] Epoch 1/2000 | Total: 9.906e+01 | Data: 5.80e-01 | Physics: 9.24e-02 | IC: 4.49e-01 | BC: 1.82e-01
[Adam] Epoch 200/2000 | Total: 2.177e+01 | Data: 1.54e-01 | Physics: 3.28e-02 | IC: 1.05e-01 | BC: 1.57e-02
[Adam] Epoch 400/2000 | Total: 2.163e+01 | Data: 1.60e-01 | Physics: 3.22e-02 | IC: 1.05e-01 | BC: 1.18e-02
[Adam] Epoch 600/2000 | Total: 2.153e+01 | Data: 1.65e-01 | Physics: 3.43e-02 | IC: 1.05e-01 | BC: 1.06e-02
[Adam] Epoch 800/2000 | Total: 2.148e+01 | Data: 1.67e-01 | Physics: 3.85e-02 | IC: 1.04e-01 | BC: 1.09e-02
[Adam] Epoch 1000/2000 | Total: 2.144e+01 | Data: 1.73e-01 | Physics: 

LBFGS complete.


Training Vanilla-MLP

Epoch 1/2000 | Loss: 1.737415e-01
Epoch 200/2000 | Loss: 1.829303e-02
Epoch 400/2000 | Loss: 9.481753e-03
Epoch 600/2000 | Loss: 8.159027e-03
Epoch 800/2000 | Loss: 7.531142e-03
Epoch 1000/2000 | Loss: 7.356492e-03
Epoch 1200/2000 | Loss: 6.560012e-03
Epoch 1400/2000 | Loss: 6.258022e-03
Epoch 1600/2000 | Loss: 5.960371e-03
Epoch 1800/2000 | Loss: 5.648928e-03
Epoch 2000/2000 | Loss: 5.277555e-03



Training Vanilla-PINN

=== Starting Adam optimizer phase ===
[Adam] Epoch 1/2000 | Total: 2.745e+01 | Data: 1.45e-01 | Physics: 8.71e-02 | IC: 1.18e-01 | BC: 7.31e-02
[Adam] Epoch 200/2000 | Total: 6.020e+00 | Data: 2.37e-01 | Physics: 3.63e-01 | IC: 1.70e-02 | BC: 4.31e-02
[Adam] Epoch 400/2000 | Total: 4.551e+00 | Data: 2.07e-01 | Physics: 3.47e-01 | IC: 1.16e-02 | BC: 3.46e-02
[Adam] Epoch 600/2000 | Total: 3.203e+00 | Data: 1.91e-01 | Physics: 3.70e-01 | IC: 6.16e-03 | BC: 2.72e-02
[Adam] Epoch 800/2000 | Total: 2.001e+00 | Data: 1.84e-01 | Physics: 3.60e-01 | IC: 1.49e-03 | BC: 2.07e-02
[Adam] Epoch 1000/2000 | Total: 1.652e+00 | Data: 1.81e-01 | Physics: 3.29e-01 | IC: 4.12e-04 | BC: 1.79e-02
[Adam] Epoch 1200/2000 | Total: 1.410e+00 | Data: 1.79e-01 | Physics: 2.83e-01 | IC: 2.17e-04 | BC: 1.45e-02
[Adam] Epoch 1400/2000 | Total: 1.249e+00 | Data: 1.73e-01 | Physics: 2.44e-01 | IC: 1.37e-04 | BC: 1.24e-02
[Adam] Epoch 1600/2000 | Total: 1.151e+00 | Data: 1.70e-01 | Physics: 2.1