# PINN basics in JAX

## Import packages

In [None]:
import time
import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx
import optax
import matplotlib.pyplot as plt

## Note on packages

1. JAX gets its speed from features such as just-in-time (jit) compilation and vectorized mapping of functions (vmap). It is highly modular and has a steeper learning curve than PyTorch
2. Unlike PyTorch, which is a single unified, multi-purpose and user-friendly platform, the JAX ecosystem offers several separate libraries to build NNs (e.g. equinox and flax) and to optimize them (e.g. optax)
3. Linen is the module API of flax, not a separate library; equinox describes its model syntax as PyTorch-like
4. Here, we will be using equinox (neural nets focused) and optax (optimization focused)
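
As a minimal sketch of the two features named in point 1, the snippet below lifts a scalar function to a vector with `jax.vmap` and compiles the result with `jax.jit`. The function `scalar_fn` and the sample points are made up for illustration and are not part of the PINN that follows.

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar function, used only for illustration
def scalar_fn(x):
    return jnp.sin(x) * x**2

# vmap turns the scalar function into one that acts on a whole vector
batched_fn = jax.vmap(scalar_fn)

# jit compiles the batched function on its first call; later calls with
# the same input shape and dtype reuse the compiled version
fast_batched_fn = jax.jit(batched_fn)

xs = jnp.linspace(0.0, 1.0, 5)
ys = fast_batched_fn(xs)
print(ys.shape)
```

The same two transformations appear later in the notebook: `jax.vmap` to evaluate the network on the mesh and `eqx.filter_jit` to compile the training step.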

## Problem definition

1D Poisson equation with homogeneous Dirichlet boundary conditions on the unit interval
$$ \\ $$
$$ \frac{\partial^2 u}{\partial x^2} = -f(x), x \in (0,1) $$
$$ \\ $$
$$ u(0) = u(1) = 0 $$

## Problem setup

Motivated by the empirically "observed" universal approximation capabilities of neural networks (NNs), we approximate the unknown solution $u(x)$ by a NN $\hat{u}_{\theta}(x)$ parameterized by weights $\theta$.
$$ \\ $$
Analogous to the form of the PDE, we can cook up a loss function term for a point labelled by $x_{i}$ in the domain as:
$$ \mathcal{L}_{i}(\theta) = (\frac{\partial^2 \hat{u}_{\theta}}{\partial x^2} + f(x))^2 \vert_{x=x_{i}} $$
$$ \\ $$
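The loss function part arising from the PDE averages these terms over the collocation points $x_{1}, \dots, x_{L}$, with a factor of $\frac{1}{2}$ as in the code further down:
$$ \mathcal{L}_{\text{PDE}}(\theta)  = \frac{1}{2L} \sum_{i=1}^{L} \mathcal{L}_{i}(\theta)  = \frac{1}{2L} \sum_{i=1}^{L} \left( \frac{\partial^2 \hat{u}_{\theta}}{\partial x^2} + f(x) \right)^2 \Big\vert_{x=x_{i}} $$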
$$ \\ $$
We will also have a "conventional" loss function part arising from the data. Since only the boundary condition data is provided here, this is the boundary condition loss $\mathcal{L}_{\text{BC}}(\theta)$, a "vanilla" sum of squares:
$$ \mathcal{L}_{\text{BC}}(\theta) = \frac{1}{2}(\hat{u}_{\theta}(0) - u(0))^2 + \frac{1}{2}(\hat{u}_{\theta}(1) - u(1))^2 $$
$$ \\ $$
The total loss $\mathcal{L}(\theta)$, with the boundary condition part weighted by $\lambda_{BC}$ (`bc_loss_weight` in the code), is given by:
$$ \mathcal{L}(\theta) = \mathcal{L}_{\text{PDE}}(\theta) + \lambda_{BC} \mathcal{L}_{\text{BC}}(\theta) $$
$$ \\ $$
The optimization problem is then given by:
$$ \theta^{*} = \arg\min_{\theta} \mathcal{L}(\theta) $$

## Set up hyperparameters

In [None]:
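n_dof_fd = 100        # number of interior mesh points (finite difference degrees of freedom)
L = 50                # number of collocation points (not used by the cells below, which evaluate the loss on the interior mesh)
learning_rate = 1e-3  # step size of the optimizer
n_epochs = 5000       # number of optimization steps
bc_loss_weight = 100  # weight of the boundary condition loss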

## Generate mesh

In [None]:
mesh_full = jnp.linspace(0.0, 1.0, n_dof_fd + 2)
mesh_interior = mesh_full[1:-1]

In [None]:
mesh_full

In [None]:
mesh_full.size

## Define our function f(x)

In [None]:
rhs_function = lambda x: jnp.where((x > 0.3) & (x < 0.5), 1.0, 0.0)

## Plot f(x)

In [None]:
plt.figure()
plt.plot(mesh_full[1:-1], rhs_function(mesh_full[1:-1]), label="Forcing function")
plt.xlabel("x")
plt.legend()

## Reproducibility using random key

In [None]:
key = jr.PRNGKey(42)

## Set up PINN - coordinate based NN

In [None]:
key, init_key = jr.split(key)
pinn = eqx.nn.MLP(
    in_size="scalar",
    out_size="scalar",
    width_size=10,
    depth=4,
    activation=jax.nn.sigmoid,
    key=init_key,  # use the fresh subkey, keep `key` for later splits
)

## Notes on PINN setup

1. We are setting up a network that maps a scalar to a scalar
2. The design choices of the network are arbitrary
3. We are using a small network with 4 hidden layers (`depth=4`)
4. The number of neurons per hidden layer is 10 (`width_size=10`)
5. We are using the sigmoid activation function

## Generate initial prediction at x = 0.2

In [None]:
pinn(0.2)

## Apply PINN on mesh

In [None]:
pinn(mesh_full)

1. We get an error since the PINN was defined to take a scalar input
2. This "scalar" issue can be fixed using jax's functionality called "vmap", which takes a function designed to act on scalars and applies it elementwise to vectors

## Vectorized map functionality in jax - vmap

In [None]:
jax.vmap(pinn)(mesh_full)

## PDE residuum

In [None]:
def pde_residuum(network, x):
    return jax.grad(jax.grad(network))(x) + rhs_function(x)

In [None]:
pde_residuum(pinn, 0.8)

In [None]:
jax.vmap(pde_residuum, in_axes=(None, 0))(pinn, mesh_interior)

## Notes on vmap applied to PDE residuum

1. pde_residuum is a function that takes the pinn and a scalar x as input
2. Since we want to evaluate pde_residuum over the mesh, which is a vector, we need to use vmap
3. Since we do not want to apply the vectorization to our pinn input, we set the first entry of in_axes to "None"
4. Since we want to vectorize over the first axis of the mesh passed as x, we set the second entry of in_axes to the index of that axis, which is zero
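
The `in_axes=(None, 0)` pattern can be tried on a problem with a known answer. The sketch below uses a manufactured example that is not part of this notebook: $u(x) = a \sin(\pi x)$ with $f(x) = \pi^2 \sin(\pi x)$, so the residuum vanishes for $a = 1$ and not for other amplitudes. The amplitude plays the role of the shared network parameters; all names are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical trial solution with a single "parameter", the amplitude
def u_trial(amplitude, x):
    return amplitude * jnp.sin(jnp.pi * x)

# Manufactured forcing: u'' = -f holds exactly for amplitude = 1
def manufactured_rhs(x):
    return jnp.pi**2 * jnp.sin(jnp.pi * x)

# Same structure as pde_residuum: second derivative in x plus forcing term
def residuum(amplitude, x):
    d2u_dx2 = jax.grad(jax.grad(u_trial, argnums=1), argnums=1)(amplitude, x)
    return d2u_dx2 + manufactured_rhs(x)

xs = jnp.linspace(0.1, 0.9, 9)

# in_axes=(None, 0): share the amplitude, map over axis 0 of xs
batched_residuum = jax.vmap(residuum, in_axes=(None, 0))
res_exact = batched_residuum(1.0, xs)
res_wrong = batched_residuum(2.0, xs)
print(jnp.max(jnp.abs(res_exact)), jnp.max(jnp.abs(res_wrong)))
```

The first value should be close to zero up to floating point error, the second clearly nonzero, which is the signal the PDE loss uses to steer the parameters.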

## Total loss function

In [None]:
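def loss_fn(network):
    # PDE residuum at every interior mesh point
    pde_residuum_total = jax.vmap(pde_residuum, in_axes=(None, 0))(network, mesh_interior)
    # Half of the mean squared residuum over the interior points
    pde_loss_total = 0.5 * jnp.mean(jnp.square(pde_residuum_total))
    # Homogeneous Dirichlet conditions: u(0) = u(1) = 0
    bc_loss = 0.5 * jnp.square(network(0.0) - 0.0) + 0.5 * jnp.square(network(1.0) - 0.0)
    # Weighted sum of the PDE and boundary condition parts
    loss_total = pde_loss_total + bc_loss_weight * bc_loss
    return loss_total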

## Check initial loss of PINN

In [None]:
loss_fn(pinn)

## Training loop

In [None]:
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(eqx.filter(pinn, eqx.is_array))

@eqx.filter_jit
def make_step(network, optimizer_state):
    loss, grad = eqx.filter_value_and_grad(loss_fn)(network)
    network_updates, new_optimizer_state = optimizer.update(grad, optimizer_state, network)
    new_network = eqx.apply_updates(network, network_updates)
    return new_network, new_optimizer_state, loss

start_time = time.time()
loss_history = []
for epoch in range(n_epochs):
    pinn, opt_state, loss = make_step(pinn, opt_state)
    loss_history.append(loss)
    if epoch % 100 == 0:
        print(f"Epoch: {epoch}, loss: {loss}")
end_time = time.time()

## Print Execution time

In [None]:
print("Execution time: %s seconds" % (end_time - start_time))

## Notes on Training loop

1. We first create the ADAM optimizer with the learning rate specified in the hyperparameters section
2. Next, the optimizer state is initialized with the "optimizer.init" command - equinox applies a filter to extract the PINN parameters that are arrays (weights and biases), and the optimizer sets up its internal state for them. The weights and biases themselves were already initialized when the PINN was created
3. The make_step function performs a single optimization step
4. In the make_step function, we start by computing the loss and its gradient with respect to the "trainable" parameters of the network, again extracted via the "filter" part in "filter_value_and_grad"
5. Next, we use the gradients to compute the updates (increments to the weights and biases) that need to be applied to the network. This line of code also returns the new state of the ADAM optimizer (running averages of the gradient and of its square)
6. Finally, we apply the updates to the network (change weights and biases)
7. The loop at the end of the section simply performs the optimization iteratively using make step
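
To illustrate what the optimizer state in steps 2 and 5 contains, here is a hand-written Adam step in plain JAX on a toy quadratic loss. It is a sketch of the standard Adam update rule, not optax's implementation, and `toy_loss`, `adam_step` and the chosen hyperparameters are assumptions made for this example.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy loss with its minimum at theta = 3
def toy_loss(theta):
    return jnp.sum((theta - 3.0) ** 2)

def adam_step(theta, state, lr=1e-1, b1=0.9, b2=0.999, eps=1e-8):
    m, v, t = state  # running averages of gradient and squared gradient, step count
    t = t + 1
    loss, g = jax.value_and_grad(toy_loss)(theta)
    m = b1 * m + (1 - b1) * g      # first moment estimate
    v = b2 * v + (1 - b2) * g**2   # second moment estimate
    m_hat = m / (1 - b1**t)        # bias correction
    v_hat = v / (1 - b2**t)
    update = -lr * m_hat / (jnp.sqrt(v_hat) + eps)
    # Return new parameters, new optimizer state and the loss, like make_step
    return theta + update, (m, v, t), loss

theta = jnp.zeros(2)
state = (jnp.zeros(2), jnp.zeros(2), 0)
initial_loss = toy_loss(theta)
for _ in range(200):
    theta, state, loss = adam_step(theta, state)
final_loss = toy_loss(theta)
print(initial_loss, final_loss)
```

The structure mirrors `make_step`: compute loss and gradient, turn the gradient into an update using the stored state, apply the update, and hand back the new state for the next iteration.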

## Plot loss history

In [None]:
plt.figure()
plt.plot(loss_history)
plt.yscale("log")
plt.title("loss history")
plt.xlabel("epoch")
plt.ylabel("loss")

## Function to compute reference solution

In [None]:
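def compute_reference_solution(mesh_full):
    # Solve only for the interior points; the boundary values are known to be zero
    mesh_interior = mesh_full[1:-1]
    rhs_evaluated = rhs_function(mesh_interior)
    dx = mesh_interior[1] - mesh_interior[0]
    # Tridiagonal matrix of the central second difference (1, -2, 1) / dx^2
    A = jnp.diag(jnp.ones(n_dof_fd - 1), -1) + jnp.diag(jnp.ones(n_dof_fd - 1), 1) - jnp.diag(2*jnp.ones(n_dof_fd), 0)
    A /= dx**2
    # Discrete version of u'' = -f
    finite_difference_solution = jnp.linalg.solve(A, -rhs_evaluated)
    # Pad one zero on each side to restore the homogeneous Dirichlet boundary values
    wrap_bc = lambda u: jnp.pad(u, (1, 1), mode="constant")
    reference_solution = wrap_bc(finite_difference_solution)
    return reference_solution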
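
A quick way to gain confidence in this kind of finite difference solver is to run the same stencil on a problem with a known solution. The sketch below is a separate check, not part of the original notebook: it assumes the constant forcing $f(x) = 1$, for which $u(x) = x(1 - x)/2$ satisfies the equation and the boundary conditions. The grid size and all names are chosen for illustration.

```python
import jax.numpy as jnp

# Hypothetical check problem: f(x) = 1 has the exact solution u(x) = x (1 - x) / 2
n_interior = 50
grid = jnp.linspace(0.0, 1.0, n_interior + 2)
dx = grid[1] - grid[0]

# Tridiagonal second-difference matrix, same stencil as in the function above
A = (
    jnp.diag(jnp.ones(n_interior - 1), -1)
    + jnp.diag(jnp.ones(n_interior - 1), 1)
    - 2.0 * jnp.eye(n_interior)
) / dx**2

# Solve the discrete version of u'' = -1 and pad the zero boundary values
u_interior = jnp.linalg.solve(A, -jnp.ones(n_interior))
u_fd = jnp.pad(u_interior, (1, 1), mode="constant")

u_exact = 0.5 * grid * (1.0 - grid)
max_error = jnp.max(jnp.abs(u_fd - u_exact))
print(max_error)
```

Because the central difference is exact for quadratic functions, the remaining error here should come from floating point arithmetic only.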

## Plot solutions post network training

In [None]:
plt.figure()
plt.plot(mesh_full, compute_reference_solution(mesh_full), 'r', label="Reference solution")
plt.plot(mesh_full, jax.vmap(pinn)(mesh_full), 'b*', label='Final PINN solution')
plt.legend()
plt.grid()
plt.title("PDE solutions")
plt.xlabel("x")
plt.ylabel("u(x)")