<a href="https://colab.research.google.com/github/Nourallaah/FitzHugh-Nagumo-PINN-Simulation/blob/main/PINN_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>



---


**This notebook simulates solutions of the reaction-diffusion system based on the FitzHugh-Nagumo PDE model using physics-informed neural networks (PINNs).**


---

Results are shown for the predefined case 4 (Neumann boundary conditions with zero gradient).


---

**Detailed explanations of all the functions are in the Comparison notebook.**

In [29]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn



---


`numpy`: For numerical operations and creating grids.

`matplotlib.pyplot`: For plotting heatmaps of the solution.

`torch` and `torch.nn`: PyTorch library for building and training neural networks.


---



In [30]:
D, a = 1.0, 0.3  # Physical parameters of the FitzHugh-Nagumo (FHN) equation
xmin, xmax = -5, 5  # Spatial domain boundaries
N_x = 200  # Number of spatial points for evaluation and plotting
t_min, t_max = 0.0, 2  # Time domain boundaries
N_t = 100  # Number of time points for heatmap resolution
ncase = 4  # Boundary condition case selector



---
- These parameters define the PDE coefficients, spatial and temporal domain, and resolution.

- The model is evaluated on a different spatial and temporal domain than in the Comparison notebook, so that the differences between the boundary-condition cases are visible in the heatmap.

- `ncase` allows switching between different boundary conditions.


---




In [31]:
def safe_exp(z):
    return np.exp(np.clip(z, -700, 700))
def analytical_solution(x, t, D=1.0, a=0.3):
    return 1 / (1 + safe_exp(x / np.sqrt(2 * D) + (a - 0.5) * t))
def analytical_solution_derivative(x, t, D=1.0, a=0.3):
    exp_term = safe_exp(x / np.sqrt(2 * D) + (a - 0.5) * t)
    denom = (1 + exp_term) ** 2
    return -exp_term / (np.sqrt(2 * D) * denom)

`safe_exp`: Prevents overflow in the exponential by clipping its argument to `[-700, 700]`.

`analytical_solution`: Returns the exact traveling-wave solution at position `x` and time `t`. It supplies the initial condition and, in case 1, the boundary values used in training.

`analytical_solution_derivative`: Computes the spatial derivative of the analytical solution at position `x` and time `t`. It is used for the Neumann condition in case 3.
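
As a quick sanity check on the two closed-form helpers, the sketch below re-implements them for scalar inputs with the standard `math` module and compares them against central finite differences. The names `u_exact`, `ux_exact`, the test point, and the step `h` are illustrative and not part of the notebook. It also checks that this traveling-wave profile satisfies the Nagumo-type equation `u_t = D*u_xx + u*(1 - u)*(u - a)`, which is the equation this closed form solves.

```python
import math

D, a = 1.0, 0.3  # same values as the notebook's parameters

def u_exact(x, t):
    # Same closed form as analytical_solution, for scalar inputs
    z = x / math.sqrt(2 * D) + (a - 0.5) * t
    return 1.0 / (1.0 + math.exp(z))

def ux_exact(x, t):
    # Same closed form as analytical_solution_derivative
    e = math.exp(x / math.sqrt(2 * D) + (a - 0.5) * t)
    return -e / (math.sqrt(2 * D) * (1.0 + e) ** 2)

x, t, h = 0.7, 0.4, 1e-3  # arbitrary test point and finite-difference step

# Central differences in x and t
ux_fd = (u_exact(x + h, t) - u_exact(x - h, t)) / (2 * h)
uxx_fd = (u_exact(x + h, t) - 2 * u_exact(x, t) + u_exact(x - h, t)) / h**2
ut_fd = (u_exact(x, t + h) - u_exact(x, t - h)) / (2 * h)

u = u_exact(x, t)
residual = ut_fd - D * uxx_fd - u * (1 - u) * (u - a)

print(abs(ux_fd - ux_exact(x, t)))  # small: the derivative formula agrees
print(abs(residual))                # small: the profile satisfies the Nagumo form
```

Both printed values should be close to zero, limited only by the finite-difference step.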

---



In [32]:
class FHN_PINN(nn.Module):
    def __init__(self, layers):
        super(FHN_PINN, self).__init__()
        modules = []
        for i in range(len(layers) - 1):
            modules.append(nn.Linear(layers[i], layers[i + 1]))
            if i < len(layers) - 2:
                modules.append(nn.Tanh())  # Activation function
        self.net = nn.Sequential(*modules)

        # Xavier initialization for weights and zero bias initialization
        for m in self.net:
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x, t):
        # Forward pass: concatenate spatial and temporal inputs
        return self.net(torch.cat([x, t], dim=1))

- The network takes spatial coordinate `x` and time `t` as inputs.

- Uses `Tanh` activations for smoothness.

- Xavier initialization helps with stable training.
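
A short shape check, assuming PyTorch is installed. It builds the same kind of layer stack as `FHN_PINN` through a hypothetical helper `make_mlp`, using a smaller layer list (`[2, 16, 16, 2]` is chosen here for speed and is not the notebook's setting), and confirms that a batch of `(x, t)` pairs maps to one output row per point.

```python
import torch
import torch.nn as nn

def make_mlp(layers):
    # Linear layers with Tanh in between, no activation on the output layer
    modules = []
    for i in range(len(layers) - 1):
        modules.append(nn.Linear(layers[i], layers[i + 1]))
        if i < len(layers) - 2:
            modules.append(nn.Tanh())
    net = nn.Sequential(*modules)
    # Xavier-uniform weights and zero biases, as in FHN_PINN
    for m in net:
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            nn.init.zeros_(m.bias)
    return net

net = make_mlp([2, 16, 16, 2])  # illustrative sizes
x = torch.linspace(-5, 5, 8).reshape(-1, 1)
t = torch.zeros_like(x)

# Concatenate x and t column-wise, exactly as FHN_PINN.forward does
out = net(torch.cat([x, t], dim=1))
print(out.shape)
```

Each input must be a column tensor of shape `(N, 1)`; otherwise `torch.cat(..., dim=1)` does not produce the `(N, 2)` input the first linear layer expects.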


---



In [33]:
def fhn_pde_loss(model, x, t, D, a):
    x = x.clone().detach().float().requires_grad_(True)
    t = t.clone().detach().float().requires_grad_(True)
    uv = model(x, t)
    u = uv[:, 0:1]

    # Compute gradients for PDE residual
    u_t = torch.autograd.grad(u, t, torch.ones_like(u), create_graph=True)[0]
    u_x = torch.autograd.grad(u, x, torch.ones_like(u), create_graph=True)[0]
    u_xx = torch.autograd.grad(u_x, x, torch.ones_like(u), create_graph=True)[0]

    # PDE residual with the cubic reaction term u - u^3/3 (note: the parameter `a` is not used here)
    f_u = u_t - D * u_xx - (u - u ** 3 / 3)

    return (f_u ** 2).mean()

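- This loss enforces the PDE by penalizing the mean squared residual `u_t - D*u_xx - (u - u**3/3)` at the collocation points. Only the first network output, `u`, enters the residual, and the parameter `a` is accepted but not used.

- Uses automatic differentiation to compute the derivatives with respect to the inputs `x` and `t`.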
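
To see the same `torch.autograd.grad` pattern on a function with known derivatives, the sketch below replaces the network with the closed form `u(x, t) = sin(x) * exp(-t)`. This stand-in is an assumption made for illustration and is not part of the notebook; for it, `u_t = -u` and `u_xx = -u`.

```python
import torch

x = torch.linspace(-1.0, 1.0, 5).reshape(-1, 1).requires_grad_(True)
t = torch.full((5, 1), 0.3).requires_grad_(True)

u = torch.sin(x) * torch.exp(-t)  # stand-in for the network output

# First derivatives; create_graph=True keeps the graph so that
# the result can itself be differentiated again
u_t = torch.autograd.grad(u, t, torch.ones_like(u), create_graph=True)[0]
u_x = torch.autograd.grad(u, x, torch.ones_like(u), create_graph=True)[0]

# Second derivative in x, obtained by differentiating u_x
u_xx = torch.autograd.grad(u_x, x, torch.ones_like(u_x), create_graph=True)[0]

print(torch.allclose(u_t, -u), torch.allclose(u_xx, -u, atol=1e-6))
```

Because every output row depends only on its own input row, passing a tensor of ones as `grad_outputs` returns the pointwise derivative for each sample.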


---



In [34]:
def initial_condition_loss(model, x0, t0, u0):
    uv_pred = model(x0, t0)
    return ((uv_pred[:, 0:1] - u0) ** 2).mean()

- Penalizes the mean squared difference between the predicted `u` and the known initial values `u0`.



---



In [35]:
def boundary_condition_loss(model, xb, tb, ncase):
    xb = xb.clone().detach().float().requires_grad_(True)
    tb = tb.clone().detach().float()
    uv_pred = model(xb, tb)
    u = uv_pred[:, 0:1]

    if ncase == 1:
        # Case 1: Use analytical solution at boundaries
        u_left = torch.tensor(analytical_solution(xmin, tb[:len(tb)//2].numpy()), dtype=torch.float32).reshape(-1, 1)
        u_right = torch.tensor(analytical_solution(xmax, tb[len(tb)//2:].numpy()), dtype=torch.float32).reshape(-1, 1)
        u_target = torch.cat([u_left, u_right], dim=0)
        return ((u - u_target) ** 2).mean()
    elif ncase == 2:
        # Case 2: Fixed boundary values (example)
        left_vals = torch.ones_like(tb[:len(tb)//2])
        right_vals = torch.zeros_like(tb[len(tb)//2:])
        u_target = torch.cat([left_vals, right_vals]).reshape(-1, 1)
        return ((u - u_target) ** 2).mean()
    elif ncase == 3:
        # Case 3: Neumann BC, matching the analytical spatial derivative at both edges
        u_x = torch.autograd.grad(u, xb, torch.ones_like(u), create_graph=True)[0]
        n_half = len(tb) // 2  # first half: left edge, second half: right edge
        u_x_left = u_x[:n_half]
        u_x_right = u_x[n_half:]
        tb_left_np = tb[:n_half].numpy()
        tb_right_np = tb[n_half:].numpy()
        ux_left_np = analytical_solution_derivative(xmin, tb_left_np, D, a)
        ux_right_np = analytical_solution_derivative(xmax, tb_right_np, D, a)
        ux_left = torch.tensor(ux_left_np, dtype=torch.float32).reshape(-1, 1)
        ux_right = torch.tensor(ux_right_np, dtype=torch.float32).reshape(-1, 1)
        return ((u_x_left - ux_left) ** 2).mean() + ((u_x_right - ux_right) ** 2).mean()
    elif ncase == 4:
        # Case 4: Neumann BC with zero gradient (u_x = 0 at both edges)
        u_x = torch.autograd.grad(u, xb, torch.ones_like(u), create_graph=True)[0]
        return (u_x ** 2).mean()
    elif ncase == 5:
        # case 5: No boundary condition loss
        return torch.tensor(0.0, device=uv_pred.device)
    else:
        raise ValueError(f"Unsupported ncase {ncase}")

This function computes the boundary condition loss for the PINN model. The type of condition enforced depends on `ncase`:
- case 1: **Dirichlet BC from the analytical solution**, where the boundary values are fixed to the analytical solution at `xmin` and `xmax`.
- case 2: **Fixed Dirichlet BC**, where the left boundary is set to 1.0 and the right boundary to 0.0.
- case 3: **Neumann BC**, where the spatial derivative of `u` at the boundaries is matched to the derivative of the analytical solution.
- case 4: **Neumann BC with zero gradient**, where `u_x` is driven to zero at both boundaries. On a grid, this corresponds to setting the boundary values equal to their immediate neighbors inside the domain.
- case 5: **Free of boundary conditions**, where the boundary loss is zero.
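
Written out, the five branches correspond to the penalties below. The notation is introduced here for illustration: $u_\theta$ is the first network output, $u^*$ is the analytical solution, and the mean runs over the sampled boundary points $(x_b, t_b)$ with $x_b \in \{x_{\min}, x_{\max}\}$.

$$
\begin{aligned}
\text{case 1:}\quad & \mathcal{L}_{bc} = \operatorname{mean}\big[(u_\theta(x_b,t_b) - u^*(x_b,t_b))^2\big] \\
\text{case 2:}\quad & \mathcal{L}_{bc} = \operatorname{mean}\big[(u_\theta(x_b,t_b) - g(x_b))^2\big], \quad g(x_{\min}) = 1,\ g(x_{\max}) = 0 \\
\text{case 3:}\quad & \mathcal{L}_{bc} = \operatorname{mean}_{x_b = x_{\min}}\big[(\partial_x u_\theta - \partial_x u^*)^2\big] + \operatorname{mean}_{x_b = x_{\max}}\big[(\partial_x u_\theta - \partial_x u^*)^2\big] \\
\text{case 4:}\quad & \mathcal{L}_{bc} = \operatorname{mean}\big[(\partial_x u_\theta(x_b,t_b))^2\big] \\
\text{case 5:}\quad & \mathcal{L}_{bc} = 0
\end{aligned}
$$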



In [36]:
def train_pinn(t_min, t_max, ncase, epochs=4000):
    layers = [2, 100, 100, 100, 2]  # Input: (x,t), Output: (u,v)
    model = FHN_PINN(layers)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Initial condition points
    x_tensor = torch.tensor(np.linspace(xmin, xmax, N_x), dtype=torch.float32).reshape(-1, 1)
    u0_np = analytical_solution(x_tensor.numpy().flatten(), t_min)
    u0 = torch.tensor(u0_np, dtype=torch.float32).reshape(-1, 1)
    t0 = torch.full_like(x_tensor, t_min)

    # Collocation points inside domain for PDE residual
    N_colloc = 5000
    x_colloc = torch.tensor(np.random.uniform(xmin, xmax, N_colloc), dtype=torch.float32).reshape(-1, 1)
    t_colloc = torch.tensor(np.random.uniform(t_min, t_max, N_colloc), dtype=torch.float32).reshape(-1, 1)

    # Boundary points for enforcing BC
    N_b = 200
    xb = torch.cat([
        torch.full((N_b, 1), xmin, dtype=torch.float32),
        torch.full((N_b, 1), xmax, dtype=torch.float32)
    ], dim=0)
    tb = torch.tensor(np.random.uniform(t_min, t_max, 2*N_b), dtype=torch.float32).reshape(-1, 1)

    for epoch in range(epochs):
        optimizer.zero_grad()
        loss_pde = fhn_pde_loss(model, x_colloc, t_colloc, D, a)
        loss_ic = initial_condition_loss(model, x_tensor, t0, u0)
        loss_bc = boundary_condition_loss(model, xb, tb, ncase)

        # Weighted sum of losses
        loss = 100.0 * loss_pde + 30.0 * loss_ic + 50.0 * loss_bc
        loss.backward()
        optimizer.step()

        if epoch % 500 == 0:
            print(f"Epoch {epoch}: Loss = {loss.item():.4e}")

    return model



---

This function trains a Physics-Informed Neural Network (PINN) to approximate the solution of the PDE on the time interval `[t_min, t_max]` with the boundary conditions selected by `ncase`.

- It defines a fully connected neural network (`FHN_PINN`) with 3 hidden layers of 100 neurons each and two outputs, of which only the first (`u`) is used in the losses.
- Creates tensors for the spatial points, the initial condition at `t_min`, and 5000 collocation points sampled uniformly at random in space-time for enforcing the PDE.
- Defines 200 boundary points at each edge of the spatial domain (`xmin` and `xmax`), with times sampled uniformly at random in `[t_min, t_max]`.
- In each training epoch, it computes a weighted sum of three losses:
  - PDE residual loss at the collocation points (weight 100),
  - Initial condition loss at `t_min` (weight 30),
  - Boundary condition loss at the boundary points (weight 50).
- These weights were chosen by trial and error to balance the three terms and produce accurate results.
- Uses the Adam optimizer with learning rate `1e-4` to minimize this composite loss.
- Prints the loss every 500 epochs, for a total of 4000 epochs by default.
- Returns the trained model after the specified number of epochs.
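
The layout of the boundary batch can be reproduced in isolation. The sketch below uses NumPy only and adds a fixed seed for repeatability (the seed and the name `rng` are assumptions; the notebook samples without one). It shows that the first `N_b` rows sit at `xmin`, the last `N_b` rows at `xmax`, and that the times are spread over the whole interval `[t_min, t_max]`.

```python
import numpy as np

xmin, xmax = -5, 5
t_min, t_max = 0.0, 2
N_b = 200

rng = np.random.default_rng(0)  # the seed is an illustrative addition

# First N_b rows: left edge; last N_b rows: right edge
xb = np.concatenate([
    np.full((N_b, 1), xmin),
    np.full((N_b, 1), xmax),
]).astype(np.float32)

# One random time per boundary point, drawn from the full time interval
tb = rng.uniform(t_min, t_max, 2 * N_b).reshape(-1, 1).astype(np.float32)

print(xb.shape, tb.shape)
print(xb[:N_b].max(), xb[N_b:].min())
print(tb.min() >= t_min, tb.max() <= t_max)
```

This ordering is what the boundary loss relies on when it splits the batch in half to separate the left and right edges.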



---

In [None]:
if __name__ == "__main__":
    # Train the PINN model
    print("Training PINN...")
    pinn_model = train_pinn(t_min, t_max, ncase)

    # Create grid for visualization
    t_grid = np.linspace(t_min, t_max, N_t)
    X, T = np.meshgrid(np.linspace(xmin, xmax, N_x), t_grid, indexing='ij')

    # Predict solution on the grid
    U_pred = np.zeros((N_x, N_t))
    for j, t_val in enumerate(t_grid):
        x_tensor = torch.tensor(np.linspace(xmin, xmax, N_x), dtype=torch.float32).reshape(-1, 1)
        t_tensor = torch.full_like(x_tensor, t_val)
        with torch.no_grad():
            U_pred[:, j] = pinn_model(x_tensor, t_tensor)[:, 0].numpy()

    # Plot heatmap of u(x,t)
    plt.figure(figsize=(10, 6))
    plt.pcolormesh(T, X, U_pred, shading='auto', cmap='viridis')
    plt.colorbar(label='u(x,t)')
    plt.xlabel('Time (t)')
    plt.ylabel('Space (x)')
    plt.title(f'FHN Equation Solution (PINN, Case {ncase})')
    plt.tight_layout()
    plt.show()

Training PINN...
Epoch 0: Loss = 2.0780e+01
Epoch 500: Loss = 6.6864e-01
Epoch 1000: Loss = 1.3532e-01
Epoch 1500: Loss = 7.6801e-02
Epoch 2000: Loss = 5.6128e-02




---
After training, the model is evaluated on a mesh grid of space and time.

`pcolormesh` creates a heatmap showing how the solution evolves over time.

The axes are labeled and a colorbar is added for clarity.
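
The array shapes involved in the plot are easy to mix up, so here is a small NumPy-only sketch with tiny illustrative sizes (`N_x = 4` and `N_t = 3` are chosen for readability and are not the notebook's values). With `indexing='ij'`, both grids have shape `(N_x, N_t)`, which is why `U_pred` is filled one column per time value.

```python
import numpy as np

N_x, N_t = 4, 3  # tiny illustrative sizes
x = np.linspace(-5, 5, N_x)
t = np.linspace(0.0, 2.0, N_t)

X, T = np.meshgrid(x, t, indexing='ij')
print(X.shape, T.shape)  # (4, 3) (4, 3)

# X varies along axis 0 (rows), T varies along axis 1 (columns)
print(np.allclose(X[:, 0], x), np.allclose(T[0, :], t))  # True True
```

Passing `T` first and `X` second to `pcolormesh` then places time on the horizontal axis and space on the vertical axis, matching the axis labels.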


---



