<a href="https://colab.research.google.com/github/aidancrilly/AIMSLecture/blob/main/2026/Project_TFPINN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Physics-Informed Neural Networks for a Non-Linear Poisson Equation (the Thomas-Fermi Equation)


The Thomas-Fermi equation describes a compressed neutral atom semi-classically and is derived from Poisson's equation. It takes the form of a second-order, non-linear ordinary differential equation:

$$ \frac{d^2 y}{d x^2} = \frac{(y x_0)^{\frac{3}{2}}}{x^{\frac{1}{2}}} $$

Subject to the following boundary conditions:

$$ y(0) = 1 $$
$$ \frac{dy}{dx}(1) = y(1) $$

Where $x_0$ is a constant (the radius of the "ion-sphere").

We are interested in a PINN solution to the Thomas-Fermi equation.

First, we consider how to construct a PDE loss term.

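It is clear from the Thomas-Fermi equation and the boundary condition at $x = 0$ that the second derivative of $y$ diverges at $x = 0$: with $y(0) = 1$, the right-hand side behaves like $x_0^{3/2} x^{-1/2}$. Therefore, we choose a loss function that suppresses the contribution from points approaching $x = 0$, by multiplying the equation through by $x^{\frac{1}{2}+m}/x_0^{\frac{3}{2}}$ before squaring the residual: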

$$ \mathcal{L}(y,y'',x,x_0) = \left(\frac{x^{\frac{1}{2}+m}}{x_0^{\frac{3}{2}}}\frac{d^2 y}{d x^2}  - x^m y^{\frac{3}{2}}\right)^2 $$

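Where $m$ is some positive real number, e.g. $\frac{1}{2}$. With this weighting, both terms scale like $x^m$ as $x \to 0$ for the true solution, so they vanish rather than diverge.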
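As a sketch of the weighted loss above, the function below evaluates $\mathcal{L}$ pointwise from given values of $y$ and $y''$. The name `tf_weighted_loss` and its argument order are illustrative assumptions, not part of the project scaffold. It shows that a $y''$ satisfying the ODE gives zero loss, and that near $x = 0$ even a badly wrong $y''$ contributes only about $x^{2m}$.

```python
import jax.numpy as jnp

def tf_weighted_loss(y, d2y, x, x0, m=0.5):
    """Squared, x-weighted Thomas-Fermi residual at a single point.

    Hypothetical helper: multiplies y'' = (y x0)^{3/2} / x^{1/2} through by
    x^{1/2+m} / x0^{3/2}, so both terms scale like x^m as x -> 0.
    """
    lhs = x ** (0.5 + m) / x0 ** 1.5 * d2y
    rhs = x ** m * y ** 1.5
    return (lhs - rhs) ** 2

# A y'' that satisfies the ODE exactly gives zero loss (up to rounding).
x, x0, y = 0.3, 2.0, 0.8
d2y_exact = (y * x0) ** 1.5 / jnp.sqrt(x)
print(tf_weighted_loss(y, d2y_exact, x, x0))

# Near x = 0, setting y'' = 0 (badly wrong) only costs about x^{2m}.
print(tf_weighted_loss(1.0, 0.0, 1e-6, 2.0))
```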

Next, we consider how we can enforce the boundary conditions by construction - rather than including them as loss terms.

To do this, we apply a transformation to the output of our PINN. A sensible choice is of the form:

$$ w(x, x_0) = A(x) + B(x) \mathcal{N} + C(x) \mathcal{N}'$$
$$ y(x, x_0) = t(w) $$

Where $t$ is a non-linear function that maintains positivity (we will use $t(w) = \exp(w)$). $A$, $B$ and $C$ are polynomial functions of $x$ only; these can be derived from the boundary conditions. $\mathcal{N}$ and $\mathcal{N}'$ are the PINN output and its derivative w.r.t. $x$.

Next, we know the leading-order behaviour of the solution as we approach $x = 0$:

$$ y(x) \approx 1 + \alpha_1 x + \alpha_2 x^{3/2} + \mathcal{O}(x^2) $$

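Where the $\alpha_i$ are constants that depend on $x_0$: substituting the expansion into the Thomas-Fermi equation fixes $\alpha_2 = \frac{4}{3}x_0^{3/2}$, while $\alpha_1$ is set by the boundary condition at $x = 1$ and is not known in advance. Therefore, in addition to $x$, we provide $x^{3/2}$ as an input to the PINN so that it can represent this behaviour.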
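The value of $\alpha_2$ can be checked numerically. For the truncated expansion, the ratio of $y''$ to the right-hand side of the Thomas-Fermi equation should approach 1 as $x \to 0$, whatever $\alpha_1$ is. The sketch below uses `jax.grad` for $y''$. The values $\alpha_1 = -1$ and $x_0 = 2$ are arbitrary illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def series_ratio(x, x0, alpha1):
    """Ratio y''(x) / RHS(x) for the truncated near-origin expansion."""
    alpha2 = 4.0 / 3.0 * x0 ** 1.5  # coefficient fixed by the ODE

    def y(x):
        # 1 + alpha1 x + alpha2 x^{3/2}, written with sqrt for a clean gradient
        return 1.0 + alpha1 * x + alpha2 * x * jnp.sqrt(x)

    d2y = jax.grad(jax.grad(y))(x)
    rhs = (y(x) * x0) ** 1.5 / jnp.sqrt(x)
    return d2y / rhs

# The ratio tends to 1 as x decreases
for x in [1e-2, 1e-4, 1e-6]:
    print(x, series_ratio(x, x0=2.0, alpha1=-1.0))
```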

Finally, since the boundary value $y(1)$ is a function of $x_0$ only, we make our PINN model a product of two models:

$$ \mathcal{N}(x,x_0) = \mathcal{N}_{PINN} (x^{3/2},x,x_0) \mathcal{N}_{BC}(x_0) $$

In the following project, you will implement the above PINN solution and explore its behaviour.

_Thanks to Guzman Sanchez Gonzalez for his work on PINNs for the Thomas-Fermi equation, which acted as the basis for this project._

## Problem statement:

Your task is to create a PINN capable of solving the Thomas-Fermi equation for a range of $x_0$ values.

- We must derive suitable functions for $A(x)$, $B(x)$ and $C(x)$ to respect the boundary conditions
- We must construct suitable NN architectures for the PINN and BC components
- We must create a training method to train the PINN on varied $x$ and $x_0$ values

In [None]:
!pip install optax equinox
import jax
import jax.numpy as jnp
import jax.nn as jnn
import equinox as eqx
import matplotlib.pyplot as plt

### Deriving the boundary conditions

We wish to derive suitable functions that enforce our boundary conditions. To do this we use the following:

$$ w(x, x_0) = A(x) + B(x) \mathcal{N} + C(x) \mathcal{N}'$$
$$ y(x, x_0) = \exp(w) $$

With boundary conditions:

$$ y(0) = 1 $$
$$ \frac{dy}{dx}(1) = y(1) $$

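These must hold for any $x_0$. Since $y = \exp(w)$ implies $\frac{dy}{dx} = \frac{dw}{dx}\,y$, they become the following conditions on $w$:

$$ w(0) = 0 $$
$$ \frac{dw}{dx}(1) = 1 $$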

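Requiring these to hold for any network output, i.e. for arbitrary values of $\mathcal{N}$, $\mathcal{N}'$ and $\mathcal{N}''$, and using $w' = A' + B'\mathcal{N} + (B + C')\mathcal{N}' + C\mathcal{N}''$, leads to the following constraints on $A$, $B$ and $C$: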

$$ A(0) = 0, \ B(0) = 0, \ C(0) = 0 $$
$$ A'(1) = 1, \ B'(1) = 0, \ B(1) = -C'(1), \ C(1) = 0 $$

A suitable set of functions that satisfies all these constraints is:

$$ A(x) = \frac{1}{2}x^2 $$
$$ B(x) = 1 - (1 - x)^2 $$
$$ C(x) = x (1 - x) $$
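These choices can be checked with automatic differentiation. The sketch below builds $w$ and $y = \exp(w)$ from $A$, $B$ and $C$ using an arbitrary smooth stand-in for the network output. The function `fake_N` is a hypothetical placeholder, not a trained PINN. The sketch then confirms that both boundary conditions hold regardless of $\mathcal{N}$.

```python
import jax
import jax.numpy as jnp

def A(x):
    return 0.5 * x ** 2

def B(x):
    return 1.0 - (1.0 - x) ** 2

def C(x):
    return x * (1.0 - x)

def fake_N(x):
    # Arbitrary smooth function standing in for the network output N(x)
    return jnp.sin(3.0 * x) + x ** 2

dN = jax.grad(fake_N)  # N'(x)

def y(x):
    w = A(x) + B(x) * fake_N(x) + C(x) * dN(x)
    return jnp.exp(w)

dydx = jax.grad(y)
print(y(0.0))             # boundary condition y(0) = 1
print(dydx(1.0), y(1.0))  # boundary condition y'(1) = y(1)
```

Swapping `fake_N` for any other differentiable function should leave both checks satisfied, which is the point of enforcing the boundary conditions by construction.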

### PINN model

Below, you will complete the following PINN model, implemented using JAX and Equinox. Refer to the Equinox documentation (and the course materials) for neural network implementations:

- [Equinox Documentation](https://docs.kidger.site/equinox/)

In [None]:
class TFPINNModel(eqx.Module):
    mlp_PINN: eqx.nn.MLP
    mlp_BC: eqx.nn.MLP
    m: float

    def __init__(self, width_size, depth, m, *, key, **kwargs):
        super().__init__(**kwargs)
        self.m = m
        key, subkey = jax.random.split(key)

        # Create a suitable NN architecture
        # To be completed as exercise
        # Use eqx MLP like in tutorial
        base_model =
        # Initialize the final linear layer's weights to zero (any final bias keeps its default initialization)
        where = lambda model: model.layers[-1].weight
        self.mlp_PINN = eqx.tree_at(where, base_model, jnp.zeros_like(base_model.layers[-1].weight))

        key, subkey = jax.random.split(key)
        # To be completed as exercise
        # Use eqx MLP like in tutorial
        self.mlp_BC =

    def boundary_condition_constraints(self,x):
        # To be completed as exercise
        A =
        B =
        C =
        return A, B, C

    def power_input(self, x, x0):
        return jnp.sqrt(x)*x

    def PINN(self, x, x0):
        return self.mlp_PINN(jnp.array([self.power_input(x,x0), x, jnp.log10(x0)]))*self.mlp_BC(jnp.log10(x0))

    def dPINN(self, x, x0):
        # To be completed as exercise
        # Compute grad of PINN
        dPINNdx =
        return dPINNdx

    def y_NN(self, x, x0):
        A, B, C = self.boundary_condition_constraints(x)
        N = self.PINN(x, x0)
        Nprime = self.dPINN(x, x0)
        w = A + B * N + C * Nprime
        # Choose suitable transformation of w to ensure positivity, e.g. exponential
        # This needs to be consistent with the boundary conditions
        # To be completed as exercise
        y = # Some function of w
        return y

    def y(self, x, x0):
        return self.y_NN(x, x0)

    def dydx(self, x, x0):
        return jax.grad(self.y)(x, x0)

    def d2ydx2(self, x, x0):
        return jax.grad(jax.grad(self.y))(x, x0)

    def __call__(self, x, x0):
        return eqx.filter_jit(eqx.filter_vmap(self.y))(x,x0)

    def residual_loss(self, x, x0):
        # To be completed as exercise
        loss =
        return loss

### Training data creation

Training data simply consists of randomly sampled $x$ and $x_0$ values. Below is a simple implementation that performs uniform sampling - however other sampling schemes can be considered.

In [None]:
def create_training_data(Nsamples,x0_range,x_lower=1e-3, key = jax.random.PRNGKey(42)):
    key,subkey = jax.random.split(key)
    xs = x_lower + (1.0-x_lower)*jax.random.uniform(subkey,shape=(Nsamples,))
    key,subkey = jax.random.split(key)
    x0s = x0_range[0] + (x0_range[1]-x0_range[0])*jax.random.uniform(subkey,shape=(Nsamples,))
    return xs, x0s
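
One alternative sampling scheme, sketched below as an assumption rather than part of the project code, draws $x_0$ log-uniformly. This matches the model's use of $\log_{10}(x_0)$ as an input and gives each decade of $x_0$ equal weight. The sketch can optionally draw $x$ log-uniformly too, so that more points lie near $x = 0$. The function name `create_training_data_log` is hypothetical. It returns arrays in the same format as `create_training_data`.

```python
import jax
import jax.numpy as jnp

def create_training_data_log(Nsamples, x0_range, x_lower=1e-3, log_x=False,
                             key=jax.random.PRNGKey(42)):
    """Like create_training_data, but samples x0 uniformly in log10(x0).

    If log_x is True, x is also sampled uniformly in log10(x) on [x_lower, 1],
    which concentrates samples near x = 0.
    """
    key, subkey = jax.random.split(key)
    u = jax.random.uniform(subkey, shape=(Nsamples,))
    if log_x:
        # u = 0 maps to x_lower, u -> 1 maps to 1
        xs = 10.0 ** (jnp.log10(x_lower) * (1.0 - u))
    else:
        xs = x_lower + (1.0 - x_lower) * u
    key, subkey = jax.random.split(key)
    lo, hi = jnp.log10(x0_range[0]), jnp.log10(x0_range[1])
    x0s = 10.0 ** (lo + (hi - lo) * jax.random.uniform(subkey, shape=(Nsamples,)))
    return xs, x0s

xs, x0s = create_training_data_log(1000, [0.5, 5.0], x_lower=1e-5, log_x=True)
```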

### Training loop

Below, you should complete the training loop for the PINN model, which uses JAX and Optax. Refer to the course materials and the Optax documentation:

- [Optax Documentation](https://optax.readthedocs.io/en/latest/)

In [None]:
import optax
from jax.flatten_util import ravel_pytree

def train_TFPINN(TFPINN, xs, x0s, batch_size, num_epochs, learning_rate, key = jax.random.PRNGKey(42), print_every=10):
    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(eqx.filter(TFPINN, eqx.is_inexact_array))

    def batch_loss_fn(TFPINN, xs, x0s):
        # Compute the mean residual loss over the batch
        batched_loss = eqx.filter_vmap(TFPINN.residual_loss)
        loss = jnp.mean(batched_loss(xs, x0s))
        return loss

    @eqx.filter_jit
    def update(TFPINN, opt_state, batch_xs, batch_x0s):
        # Compute loss and gradients
        loss, grads = eqx.filter_value_and_grad(batch_loss_fn)(TFPINN, batch_xs, batch_x0s)
        flat_grad, _ = ravel_pytree(grads)
        grad_norm = jnp.linalg.norm(flat_grad)
        updates, opt_state = optimizer.update(grads, opt_state)
        TFPINN = eqx.apply_updates(TFPINN, updates)
        return TFPINN, opt_state, loss, grad_norm

    history = []
    for epoch in range(num_epochs):
        loss = []
        grad_norm = []
        for _ in range(len(xs)//batch_size):
            # Set up batching, to be completed
            # Look at jax.random.choice for help
            key, subkey = jax.random.split(key)
            batch_xs =
            key, subkey = jax.random.split(key)
            batch_x0s =
            TFPINN, opt_state, _loss, _grad_norm = update(TFPINN, opt_state, batch_xs, batch_x0s)
            loss.append(_loss)
            grad_norm.append(_grad_norm)
        loss = jnp.mean(jnp.array(loss))
        grad_norm = jnp.mean(jnp.array(grad_norm))
        if epoch % print_every == 0:
            print(f"Epoch {epoch}, Loss: {loss}, Grad Norm: {grad_norm}")
        history.append(loss)

    return TFPINN, history
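
For the batching step, one option is to draw a single set of indices with `jax.random.choice` and use it for both arrays, so that each sampled $(x, x_0)$ pair stays together. This is a sketch, not the only valid choice, and the helper name `sample_batch` is hypothetical. Because $x$ and $x_0$ are sampled independently in `create_training_data`, drawing `batch_xs` and `batch_x0s` with separate keys is also valid, as the loop above is set up to do.

```python
import jax
import jax.numpy as jnp

def sample_batch(key, xs, x0s, batch_size):
    """Sample a batch of (x, x0) pairs, with replacement, using shared indices."""
    # Indices are drawn from range(len(xs)). Sampling with replacement is
    # typically cheaper than replace=False, which permutes the whole index range.
    idx = jax.random.choice(key, xs.shape[0], shape=(batch_size,), replace=True)
    return xs[idx], x0s[idx]

xs = jnp.linspace(0.0, 1.0, 10)
x0s = 2.0 * xs
batch_xs, batch_x0s = sample_batch(jax.random.PRNGKey(0), xs, x0s, batch_size=4)
```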


## Problems

- Train a PINN for $x_0 = 1$. What is the shape of the solution as a function of $x$, and what is its value at $x = 1$? What is the lowest loss value obtained?
- Train a PINN for $x_0$ in the range $[0.5,5]$. How does the shape of the solution change with $x_0$? What is the lowest loss value obtained?
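
To answer these questions, it helps to plot $y(x)$ for a few $x_0$ values alongside the loss history. The helper below is a sketch, and its name and arguments are assumptions. It only requires that the model can be called as `model(xs, x0s)` on equal-length 1D arrays, as `TFPINNModel.__call__` can, and that `history` is the list returned by `train_TFPINN`.

```python
import jax.numpy as jnp
import matplotlib.pyplot as plt

def plot_solutions(model, x0_values, history=None, x_min=1e-4, n_points=200):
    """Plot y(x) for several x0 values and, optionally, the loss history.

    Hypothetical helper: assumes model(xs, x0s) maps two equal-length 1D arrays
    to the corresponding y values. Returns the figure and a dict {x0: y(1)}.
    """
    xs = jnp.linspace(x_min, 1.0, n_points)  # last grid point is x = 1
    ncols = 1 if history is None else 2
    fig, axes = plt.subplots(1, ncols, figsize=(5 * ncols, 4), squeeze=False)
    y_at_one = {}
    for x0 in x0_values:
        ys = model(xs, jnp.full_like(xs, x0))
        axes[0, 0].plot(xs, ys, label=f"$x_0 = {x0}$")
        y_at_one[x0] = float(ys[-1])
    axes[0, 0].set_xlabel("$x$")
    axes[0, 0].set_ylabel("$y$")
    axes[0, 0].legend()
    if history is not None:
        axes[0, 1].semilogy(jnp.asarray(history))
        axes[0, 1].set_xlabel("Epoch")
        axes[0, 1].set_ylabel("Mean training loss")
    fig.tight_layout()
    return fig, y_at_one
```

After training, `fig, y1 = plot_solutions(TFPINN, [0.5, 1.0, 5.0], history)` followed by `plt.show()` would display the curves. Since `history` stores per-epoch mean losses, `min(history)` gives the lowest epoch-mean loss.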

In [None]:
# Training data hyperparameters
Nsamples = # Number of samples to create, e.g. int(1e6)
x0_range = # List of two floats defining the range of x0 values

xs, x0s = create_training_data(Nsamples,x0_range,x_lower=1e-5)

In [None]:
# Model hyperparameters
# Pick suitable values (e.g. WIDTH_SIZE=16, DEPTH=3, M_LOSS=0.5)
WIDTH_SIZE =
DEPTH =
M_LOSS =

TFPINN = TFPINNModel(width_size=WIDTH_SIZE, depth=DEPTH, m=M_LOSS, key=jax.random.PRNGKey(1234))

In [None]:
# Training hyperparameters
# Note: model evaluation is very cheap, so BATCH_SIZE can be large; with the right model hyperparameters, training will be very fast
BATCH_SIZE =
NUM_EPOCHS =
LEARNING_RATE =

TFPINN, history = train_TFPINN(TFPINN, xs, x0s, BATCH_SIZE, NUM_EPOCHS, LEARNING_RATE, print_every=1)

## Further problems:
- What pathological solutions can the PINN fall into? Try a large value of $m$ to find these. Why does this occur?
- Can you foresee an issue with large $x_0$ values, e.g. $x_0 = 10000$?

## Extension problems:

- Can you train over a larger range of $x_0$ values? What changes to the model/training/data need to be made?
- (Difficult) The Thomas-Fermi-Dirac equation is an extension of Thomas-Fermi theory. The same boundary conditions apply, but the equation is modified to the following:

$$ \frac{d^2 y}{d x^2} =  x \left(\epsilon x_0 + \sqrt{\frac{y x_0}{x}}\right)^3$$

Where $\epsilon$ is an additional constant (with value $\lesssim 0.2$). What modifications need to be made to the model/training/data? Show results for a range of $x_0$ and $\epsilon$ values.
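
One possible starting point for the Thomas-Fermi-Dirac extension is to apply the same weighting as in the Thomas-Fermi loss. This is an illustration, not a prescribed method. Multiply both sides by $x^{\frac{1}{2}+m}/x_0^{\frac{3}{2}}$ and use $x^{\frac{3}{2}} = (x^{\frac{1}{2}})^3$ to move the powers of $x$ inside the cube:

$$ \frac{x^{\frac{1}{2}+m}}{x_0^{\frac{3}{2}}}\frac{d^2 y}{d x^2} = \frac{x^{m}}{x_0^{\frac{3}{2}}}\left(\epsilon x_0 x^{\frac{1}{2}} + \sqrt{y x_0}\right)^3 = x^{m}\left(\sqrt{y} + \epsilon\sqrt{x_0 x}\right)^3 $$

$$ \mathcal{L}_{TFD}(y,y'',x,x_0,\epsilon) = \left(\frac{x^{\frac{1}{2}+m}}{x_0^{\frac{3}{2}}}\frac{d^2 y}{d x^2} - x^m\left(\sqrt{y} + \epsilon\sqrt{x_0 x}\right)^3\right)^2 $$

Setting $\epsilon = 0$ recovers the Thomas-Fermi loss $\mathcal{L}(y,y'',x,x_0)$ defined earlier.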