<a href="https://colab.research.google.com/github/aidancrilly/MiniCourse-DifferentiableSimulation/blob/main/02_DifferentiableSimulatorsAndOptimisation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Interactive Exercise

Resources:

- JAX [documentation](https://jax.readthedocs.io/en/latest/quickstart.html)
- Patrick Kidger "On Neural Differential Equations" [ArXiv link](https://arxiv.org/abs/2202.02435)

Note: diffrax and optax are not installed by default on Colab, so we install them first.

In [None]:
!pip install diffrax optax
import diffrax
import optax
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax

# Part 1

For the following exercises we are going to introduce two JAX libraries which will allow us to easily write and train differentiable simulators:

1. diffrax - https://docs.kidger.site/diffrax/
2. optax - https://optax.readthedocs.io/en/latest/getting_started.html

These libraries allow for the numerical forward and adjoint solution to differential equations (diffrax) and the solving of optimisation problems (optax).

First, we introduce diffrax. The key features of diffrax we need to use are:

1. ODETerms, these wrap python functions of the form:
```
def dydt(t : float, y : JAX array, args : dict):
    return dydt_val : JAX array
```
which return the right hand side of your differential equation.

2. Solvers, these implement different numerical methods for evolving in time. For example, Euler's method is available as well as other higher order methods.

3. Step size controllers, the solvers we will use are adaptive (they vary the time step used), so these controllers decide what step size to take to obtain a prescribed accuracy.

4. SaveAts, these simply provide a number of points in time when the numerical solution should be saved and returned to the user.

Using these features, diffrax numerically solves the following equation:

$$
y(t_1) = y(t_0) + \int_{t_0}^{t_1} \frac{dy}{dt} dt
$$

and thus we must also specify the time interval $[t_0,t_1]$ and the initial conditions $y(t_0) \equiv y_0$.
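
To make the role of a solver concrete, here is a minimal fixed-step sketch of Heun's method in plain JAX, applied to the exponential decay ODE. It is only an illustration of the time-stepping idea: the helper names (`decay_rhs`, `heun_step`, `integrate_fixed_step`) are hypothetical, and diffrax's own solvers additionally handle adaptive step sizes, saving and differentiation.

```python
import jax.numpy as jnp

def decay_rhs(t, y, args):
    # Right-hand side of dy/dt = -y / tau
    return -y / args["tau"]

def heun_step(rhs, t, y, dt, args):
    # Predictor: slope at the start of the step (explicit Euler)
    k1 = rhs(t, y, args)
    # Corrector: slope at the end of the step, using the Euler prediction
    k2 = rhs(t + dt, y + dt * k1, args)
    # Advance with the average of the two slopes
    return y + 0.5 * dt * (k1 + k2)

def integrate_fixed_step(rhs, y0, t0, t1, n_steps, args):
    # Accumulate y(t1) = y(t0) + integral of dy/dt using equal steps
    dt = (t1 - t0) / n_steps
    t, y = t0, y0
    for _ in range(n_steps):
        y = heun_step(rhs, t, y, dt, args)
        t = t + dt
    return y

y_end = integrate_fixed_step(decay_rhs, jnp.asarray(1.0), 0.0, 1.0, 200, {"tau": 0.5})
exact = jnp.exp(-1.0 / 0.5)
print(f"Heun: {float(y_end):.6f}, exact: {float(exact):.6f}")
```

With a fixed step the accuracy is set by the number of steps; an adaptive step size controller instead chooses the steps to meet the requested tolerances.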

## Part a: Solving differential equations in JAX

We will first learn to use the key features of diffrax using the exponential decay ODE example from the previous exercise.


In [None]:
def dydt(t,y,args):
  return -y/args['tau']

def diffrax_solve(dydt,t0,t1,Nt,rtol=1e-5,atol=1e-5):
  """
  Here we wrap the diffrax diffeqsolve function such that we can run with
  different y0s and taus over the same time interval easily
  """
  # We convert our python function to a diffrax ODETerm
  term = diffrax.ODETerm(dydt)
  # We choose a solver (time-stepping) method from the diffrax library
  # Heun's method (https://en.wikipedia.org/wiki/Heun%27s_method)
  solver = diffrax.Heun()

  # At what time points you want to save the solution
  saveat = diffrax.SaveAt(ts=jnp.linspace(t0,t1,Nt))
  # Diffrax uses adaptive time stepping to gain accuracy within certain tolerances
  stepsize_controller = diffrax.PIDController(rtol=rtol, atol=atol)

  return lambda y0,tau : diffrax.diffeqsolve(term, solver,
                         y0=y0, args = {'tau' : tau},
                         t0=t0, t1=t1, dt0=(t1-t0)/Nt,
                         saveat=saveat, stepsize_controller=stepsize_controller)

t0 = 0.0
t1 = 1.0
Nt = 100

ODE_solve = diffrax_solve(dydt,t0,t1,Nt)

# Solve for specific y0 and tau
y0 = 1.0
tau = 0.5
sol = ODE_solve(y0,tau)

plt.plot(sol.ts,sol.ys)
plt.plot(sol.ts,y0*jnp.exp(-sol.ts/tau),'k--')
plt.legend(['numerical','exact'])
plt.show()

Diffrax solutions are differentiable by construction - see https://docs.kidger.site/diffrax/api/adjoints/ for details.

We can therefore very easily solve the adjoint state problem using diffrax and JAX AD:

In [None]:
def loss(inputs):
  y0 = inputs['y0']
  tau = inputs['tau']
  sol = ODE_solve(y0,tau)
  return sol.ys[-1]

inputs = {'y0' : y0, 'tau' : tau}
# Returns gradient of loss with respect to all inputs, i.e. dLdtau and dLdy0
jax.grad(loss)(inputs)
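
For this ODE the gradients can be checked by hand, since $y(t_1) = y_0 e^{-t_1/\tau}$ gives $\partial L/\partial y_0 = e^{-t_1/\tau}$ and $\partial L/\partial \tau = y_0 t_1 e^{-t_1/\tau}/\tau^2$. The sketch below differentiates the exact solution with JAX and compares against those formulas; `analytic_loss` and `t_final` are hypothetical names, with `t_final` assumed to match the end time used above.

```python
import jax
import jax.numpy as jnp

t_final = 1.0  # assumed end time of the integration

def analytic_loss(inputs):
    # Exact solution of dy/dt = -y/tau evaluated at t_final
    return inputs["y0"] * jnp.exp(-t_final / inputs["tau"])

inputs = {"y0": 1.0, "tau": 0.5}
grads = jax.grad(analytic_loss)(inputs)

# Hand-derived gradients for comparison
expected_dLdy0 = jnp.exp(-t_final / inputs["tau"])
expected_dLdtau = (inputs["y0"] * t_final / inputs["tau"] ** 2
                   * jnp.exp(-t_final / inputs["tau"]))
print(grads)
print(expected_dLdy0, expected_dLdtau)
```

The gradients returned by differentiating through the numerical solve should agree with these values to within the solver tolerances.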

## Part b: Optimisation in JAX

By way of an introduction to optax, we will solve a very simple optimisation problem:

$$
\underline{p}^* = \mathrm{argmin}_{\underline{p}} f(\underline{p}), \quad f(\underline{p}) = \underline{p}^T \cdot \underline{\underline{d}} \cdot \underline{p}
$$

which, for positive semi-definite $\underline{\underline{d}}$, is trivially solved by $\underline{p} = \underline{0}$. We will also use this opportunity to make use of JAX's random number generators.

Below is some code to set up the problem for randomly generated $\underline{\underline{d}}$ and starting location for the optimisation problem $\underline{p}_0$.

In [None]:
# RNG initialisation
key = jax.random.key(0)

def example_loss(p,d):
  return jnp.dot(p.T,jnp.dot(d,p))

# Dimension of input space
Np = 10
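# Split the key so that each random draw uses its own subkey
key, d_key, p_key = jax.random.split(key, 3)
# Random matrix used to build the quadratic form
d = jax.random.normal(d_key,shape=(Np,Np))
# d d^T is symmetric positive semi-definite, so the loss is convex
d = jnp.dot(d,d.T)
# Random starting location
p0 = jax.random.normal(p_key,shape=(Np,))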
print(f'Starting loss: {example_loss(p0,d)}')


Optax defines the optimisation workflow through two important components:

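1. The optimizer: this defines the optimizer algorithm and uses the gradients of the loss (combined with the optimizer hyperparameters) to compute updates to the trainable parameters
2. The optimizer state: this stores the optimizer's internal variables (for Adam, running estimates of the gradient moments), which are carried from one update to the next. The trainable parameters themselves are kept separately and are passed to `optimizer.init` to create this state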
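
To make these two components concrete, here is a hand-written sketch of a textbook Adam step in plain JAX. It is not optax's implementation: `adam_init`, `adam_update` and `toy_loss` are hypothetical names, and the hyperparameter defaults are assumptions. The point is that the state holds moment estimates and a step counter, while the parameters are updated separately.

```python
import jax
import jax.numpy as jnp

def adam_init(params):
    # Optimiser state: first/second moment estimates and a step counter
    return {"m": jnp.zeros_like(params), "v": jnp.zeros_like(params), "count": 0}

def adam_update(grads, state, lr=1e-1, b1=0.9, b2=0.999, eps=1e-8):
    count = state["count"] + 1
    # Exponential moving averages of the gradient and its square
    m = b1 * state["m"] + (1 - b1) * grads
    v = b2 * state["v"] + (1 - b2) * grads**2
    # Bias correction for the zero initialisation
    m_hat = m / (1 - b1**count)
    v_hat = v / (1 - b2**count)
    updates = -lr * m_hat / (jnp.sqrt(v_hat) + eps)
    return updates, {"m": m, "v": v, "count": count}

def toy_loss(p):
    # Simple convex loss with its minimum at p = 0
    return jnp.sum(p**2)

params = jnp.array([1.0, -2.0, 3.0])
state = adam_init(params)
initial_loss = toy_loss(params)
for _ in range(200):
    grads = jax.grad(toy_loss)(params)
    updates, state = adam_update(grads, state)
    params = params + updates  # the parameters live outside the state
final_loss = toy_loss(params)
print(f"Loss: {float(initial_loss):.3f} -> {float(final_loss):.3e}")
```

The same split between an update rule and a carried state is what `optimizer.update` and `optax.apply_updates` provide.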

In the code below, we define these optax components for our simple convex optimisation problem. We also define the gradient of the loss via AD. Finally, we introduce a training loop which iteratively updates the parameters via the adam optimiser.

In [None]:
# Initialize parameters of the model + optimizer.
learning_rate = 1e-1
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(p0)

# A simple update loop
Nepoch = 200
grad_loss = jax.value_and_grad(example_loss)
p = p0.copy()
history = []
for _ in range(Nepoch):
  loss,grads = grad_loss(p, d)
  # Optax optimizer uses the gradients to update the parameters and the optimiser state
  updates, opt_state = optimizer.update(grads, opt_state)
  p = optax.apply_updates(p, updates)
  history.append(loss)

Plotting the loss history, we see it converging towards the global optimum at 0.

In [None]:
plt.semilogy(history)
plt.xlabel('Epoch')
plt.ylabel('Loss')
print(f'Final p mag: {jnp.linalg.norm(p)}')

# Neural Differential Equations

In its simplest form, a neural differential equation (NDE) has the following structure:

$$
\frac{d y}{d t} = \mathcal{N}_\theta(y,t)
$$

where $\mathcal{N}_\theta(y,t)$ is a neural network with parameters $\theta$.

The numerical solution can be computed using traditional techniques used for ODEs and PDEs. The adjoint solution can then be used to train the neural network to improve the solution. Due to the flexibility of neural networks, NDEs can describe a wide variety of systems in a time-continuous manner.
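
As a minimal sketch of this idea, the block below integrates a scalar NDE with explicit Euler steps in plain JAX and differentiates a data-misfit loss with respect to the network weights. Everything here is an illustrative assumption: the two-layer network, the helper names (`init_params`, `network`, `solve_euler`) and the exponential decay target are not taken from the exercise, and gradients are obtained by differentiating through the discrete solver rather than via diffrax.

```python
import jax
import jax.numpy as jnp

def init_params(key, width=8):
    # Small two-layer network without biases (hypothetical architecture)
    k1, k2 = jax.random.split(key)
    return {
        "W1": 0.1 * jax.random.normal(k1, (width, 2)),
        "W2": 0.1 * jax.random.normal(k2, (1, width)),
    }

def network(params, y, t):
    # N_theta(y, t): maps the state and time to dy/dt
    inputs = jnp.stack([y, t])
    hidden = jnp.tanh(params["W1"] @ inputs)
    return (params["W2"] @ hidden)[0]

def solve_euler(params, y0, ts):
    # Explicit Euler time stepping over the grid ts
    def step(y, t_pair):
        t, t_next = t_pair
        y_next = y + (t_next - t) * network(params, y, t)
        return y_next, y_next
    _, ys = jax.lax.scan(step, y0, (ts[:-1], ts[1:]))
    return jnp.concatenate([y0[None], ys])

ts = jnp.linspace(0.0, 1.0, 51)
y0 = jnp.array(1.0, dtype=ts.dtype)
target = jnp.exp(-ts / 0.5)  # assumed "data": exponential decay

def loss(params):
    return jnp.mean((solve_euler(params, y0, ts) - target) ** 2)

params = init_params(jax.random.PRNGKey(0))
value, grads = jax.value_and_grad(loss)(params)
print(value, jax.tree_util.tree_map(jnp.shape, grads))
```

The gradient has the same structure as the parameters, so it can be passed directly to an optimiser.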

In the following we will consider a simple example where we will use an NDE to describe the response of an electrical circuit for which we have collected data but do not know its internal structure.

For neural networks in JAX we will use equinox:

- equinox [documentation](https://docs.kidger.site/equinox/)

and then diffrax and optax can be used in the same way as before to solve both the forward and adjoint problems.

In [None]:
!pip install equinox
import equinox as eqx
import jax.nn as jnn
import jax.random as jrandom

Given the theory of electrical circuits, we propose a simple NDE of the form:

$$ \frac{d}{dt} \begin{bmatrix}I \\ h\end{bmatrix} = \begin{bmatrix}h \\ \mathcal{N}_\theta(I,h)\end{bmatrix} + \begin{bmatrix}0 \\ \frac{dV_{ext.}}{dt}\end{bmatrix}$$

where $I$ is the electrical current and $V_{ext.}$ is the externally applied voltage. We introduce a hidden state $h$ which linearly responds to the temporal gradient of the applied voltage. The neural network then computes the temporal response of the hidden state given the solution values at time $t$.
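
One way to read this form is to eliminate $h$, which leaves a single second-order equation for the current:

$$
\frac{d^2 I}{dt^2} = \mathcal{N}_\theta\left(I, \frac{dI}{dt}\right) + \frac{dV_{ext.}}{dt}
$$

As an illustrative assumption only (the structure of the actual circuit is not given here), consider a series RLC circuit with unit inductance, resistance $R$ and capacitance $C$. Differentiating Kirchhoff's voltage law in time gives

$$
\frac{d^2 I}{dt^2} = -R \frac{dI}{dt} - \frac{I}{C} + \frac{dV_{ext.}}{dt}
$$

so in that special case the network would only need to represent the linear map $\mathcal{N}_\theta(I,h) = -R h - I/C$.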

First, we load in the training data, a data set of ($t$, $I$) for an oscillating applied voltage $V(t,\omega,T)$ at various $\omega$ values. The currents ($I$) have been corrupted by noise and therefore our NDE model must be robust to this noise to be useful.

In [None]:
T = 10.0 #s for all data sets

# Upload the data provided in course git repo to the sample_data folder
# <- Open the colab file browser via the folder icon on the left
omegas = jnp.load('./sample_data/omegas.npy')
ts = jnp.load('./sample_data/ts.npy')
Is = jnp.load('./sample_data/Is.npy')

# Save only one data set for testing in this simple example
ts_train, Is_train, omegas_train = ts[:], Is[:-1,:], omegas[:-1]
ts_test, Is_test, omegas_test    = ts[:], Is[-1:,:], omegas[-1:]

# Applied voltage function
def V(t,omega,T):
    """
    Applied voltage function (a chirp): the phase grows quadratically in time,
    so the instantaneous frequency increases linearly
    """
    omega0 = omega*(2*t/T)
    return jnp.sin(omega0*t)

# To be completed
# Create a gradient function of V, dVdt
# Also create a vmapped version, which maps over the time dimension (see in_axes argument of jax.vmap)
dVdt =
dVdt_vmapped =
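
As a reminder of how `argnums` and `in_axes` behave, here is a sketch on an unrelated toy function (`toy_signal` and its arguments are hypothetical and are not the solution to the blanks above). `argnums` selects which argument to differentiate with respect to, and `in_axes` selects which arguments are mapped over (`0`) and which are held fixed (`None`).

```python
import jax
import jax.numpy as jnp

def toy_signal(t, amplitude, offset):
    # Simple function of time with two fixed parameters
    return amplitude * t**2 + offset

# Derivative with respect to the first argument (t) only
dtoy_dt = jax.grad(toy_signal, argnums=0)

# Map over an array of times while keeping the parameters fixed
dtoy_dt_vmapped = jax.vmap(dtoy_dt, in_axes=(0, None, None))

times = jnp.array([0.0, 1.0, 2.0])
slopes = dtoy_dt_vmapped(times, 3.0, 1.0)
print(slopes)  # analytic derivative is 2 * amplitude * t
```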

Let's quickly view the data and the forcing term (dV/dt):

In [None]:
fig = plt.figure(dpi=100)
ax1 = fig.add_subplot(211)
ax2 = fig.add_subplot(212,sharex=ax1)
ax1.plot(ts,Is.T,'k')
for omega in omegas:
  ax2.plot(ts,dVdt_vmapped(ts,omega,T),'r')
ax1.set_xlim(ts[0],ts[-1])
ax2.set_xlabel('t')
ax1.set_ylabel('I')
ax2.set_ylabel('dVdt')

Next, we set up an equinox Module that contains a multilayer perceptron (MLP) for $\mathcal{N}_\theta(I,h)$ and whose call method evaluates the right hand side of the full NDE.

In [None]:
# NDE solution
# Following https://docs.kidger.site/diffrax/examples/neural_ode/
class NeuralODE_RHS(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, in_size, out_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.mlp = eqx.nn.MLP(
            in_size=in_size,
            out_size=out_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.leaky_relu,
            key=key,
            use_bias=False,
            use_final_bias=False
        )

    def __call__(self, t, y, args):
        # To be completed
        dhdt =
        dIdt =
        # Stack the temporal responses into dydt
        dydt =
        return dydt

class NeuralODE(eqx.Module):
    in_size: int
    func: NeuralODE_RHS

    def __init__(self, in_size, out_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.in_size = in_size
        self.func = NeuralODE_RHS(in_size, out_size, width_size, depth, key=key)

    def __call__(self, ts, omega, args):
        """
        Similar to our examples above, we set up diffeqsolve
        but now our ODETerm uses our NeuralODE_RHS equinox Module.
        """
        # Add relevant data to the args dictionary
        args['omega'] = omega
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Heun(),
            t0=ts[0],
            t1=ts[-1],
            dt0=ts[1] - ts[0],
            y0=jnp.zeros(self.in_size),
            args=args,
            stepsize_controller=diffrax.PIDController(rtol=1e-2, atol=1e-4),
            saveat=diffrax.SaveAt(ts=ts),
            max_steps=int(1e6)
        )
        return solution.ys

Now we will train our NDE using the adjoint method and the available training data of ($t$,$I$,$\omega$).

The training process will look very similar to what you implemented above in the diffusion example. The only thing to notice is that we break the learning up over increasingly long sections of the data: we train on the first x time steps for n epochs, then on the first y > x time steps for m epochs, etc. This is a standard trick to avoid getting caught in a local minimum.
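
The slicing behind this staged training can be sketched on its own. The arrays below are hypothetical placeholders (a time grid of 101 points and a batch of 3 traces), and the fractions are chosen only for illustration.

```python
import jax.numpy as jnp

ts_demo = jnp.linspace(0.0, 10.0, 101)      # hypothetical time grid
ys_demo = jnp.zeros((3, ts_demo.shape[0]))  # hypothetical batch of traces

length_strategy_demo = (0.5, 1.0)  # fraction of the time series used per stage
section_shapes = []
for length in length_strategy_demo:
    # Keep only the first `length` fraction of the time points
    n_keep = int(ts_demo.shape[0] * length)
    ts_section = ts_demo[:n_keep]
    ys_section = ys_demo[:, :n_keep]
    section_shapes.append((ts_section.shape, ys_section.shape))
print(section_shapes)
```

Each stage always starts from the first time point, so the ODE solve keeps the same initial condition while the fitted window grows.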

There are a few lines to complete in order to run the training of the NDE.

In [None]:
import time

def train_NDE(
    ts,ys,omegas,args,
    width_size,depth,
    optimiser,
    lr_strategy,
    steps_strategy,
    length_strategy,
    seed=420,
    print_every=50
):
    """
    ts : jax array of data times
    ys : jax array of data currents
    omegas : jax array of data drive frequencies
    width_size : number of neurons in hidden layers
    depth : number of hidden layers
    optimiser : optax optimiser to be used
    lr_strategy : learning rate schedule in tuple
    steps_strategy : number of training steps in tuple
    length_strategy : fraction of training data used in training in tuple
    seed : PRNG seed value
    print_every : print every n steps of training
    """
    key = jrandom.PRNGKey(seed)
    __, model_key = jrandom.split(key)

    length_size = ts.shape[0]

    # To be completed, what is the shape of our input and output layer
    in_size =
    out_size =

    model = NeuralODE(in_size, out_size, width_size, depth, key=model_key)

    @eqx.filter_value_and_grad
    def grad_loss(model, ti, yi, omegas):
        # Handle batch
        batched_model = jax.vmap(model,in_axes=(None,0,None))
        y_pred = batched_model(ti, omegas, args)
        # Only compute the MSE using the NDE current (I) prediction
        # To be completed
        MSE =
        return MSE

    @eqx.filter_jit
    def make_step(ti, yi, omegas, model, opt_state):
        loss, grads = grad_loss(model, ti, yi, omegas)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state

    count = 0
    for lr, steps, length in zip(lr_strategy, steps_strategy, length_strategy):
        optim = optimiser(lr)
        opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
        _ts = ts[: int(length_size * length)]
        _ys = ys[:, : int(length_size * length)]
        for step in range(steps):
            count += 1
            start = time.time()
            loss, model, opt_state = make_step(_ts, _ys, omegas, model, opt_state)
            end = time.time()
            if (step % print_every) == 0 or step == steps - 1:
                print(f"Step: {step}, Loss: {loss}, Computation time: {end - start}")

    return model

We will make a very small neural network model in this case, both for quick training and for regularisation such that we do not overfit the noisy data.

In [None]:
# MLP hyperparameters
width_size = 4  # Number of neurons in hidden layers
depth = 2       # Number of hidden layers (including output layer)

optimiser = optax.adabelief

NDE_args = {'dVdt' : dVdt, 'T' : T}

# Training should take just 1-2 minutes
trained_model = train_NDE(ts_train,Is_train,omegas_train,
                          NDE_args,
                          width_size,depth,
                          optimiser,
                          lr_strategy=(1e-2,1e-2,1e-3),
                          steps_strategy=(200,200,400),
                          length_strategy=(0.5,1.0,1.0))

Now that we have trained our NDE, we can test its results at an unseen value of $\omega$.

Below, write code to compute the NDE solution at all omegas (including the test data) and compare the predicted and true current values.

In [None]:
# Only omega changes between data samples
compute_trained_model_pred = lambda omega : trained_model(ts, omega, NDE_args)

# Plot trained model against train and test data

With the tuned hyperparameters above, this NDE can capture the behaviour of the electrical circuit with high accuracy on both training and test data.

# Physics-Informed Neural Network

The following demonstrates the use of a Physics-Informed Neural Network (PINN) to solve the 1D diffusion equation: 

$$ \frac{\partial u}{\partial t} = D \frac{\partial^2 u}{\partial x^2} $$

This is a common example in physics and engineering, and PINNs provide a way to solve such differential equations using neural networks, by encoding the equation itself into the loss function of the network.

Key concepts demonstrated:
- A neural network is used to approximate the solution u(x, t).
- The loss function is composed of two parts:
    1. Mean Squared Error (MSE) loss: This ensures the solution fits any available "ground truth" data points.
    2. Physics loss: This ensures the solution obeys the diffusion equation. This loss is calculated from the residual of the PDE.
- Automatic differentiation (as provided by JAX) is used to compute the derivatives of the neural network's output with respect to its inputs (x and t), which is essential for calculating the physics loss.


First, we define the analytic diffusion solution from which we will draw training data.

In [None]:
# Diffusivity constant for the diffusion equation.
D = 1.0

def diffusion_solution(t, x):
    """
    Analytical solution to the 1D diffusion equation for a Dirac delta
    function initial condition at t=0. This serves as our ground truth for
    generating training data and for comparison.

    Args:
        t (jax.numpy.ndarray): Time coordinates.
        x (jax.numpy.ndarray): Spatial coordinates.

    Returns:
        jax.numpy.ndarray: The value of the solution u(x, t).
    """
    return jnp.exp(-x**2 / (4 * D * t)) / jnp.sqrt(4 * jnp.pi * D * t)
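
As a sanity check of the residual idea, automatic differentiation can be applied to the analytic solution itself, for which the PDE residual should vanish up to floating point error. This is a sketch under stated assumptions: `u_exact`, `D_coeff` and `residual` are hypothetical names, and `u_exact` takes its arguments in the order `(x, t)`.

```python
import jax
import jax.numpy as jnp

D_coeff = 1.0  # assumed diffusivity

def u_exact(x, t):
    # Analytic solution for a Dirac delta initial condition
    return jnp.exp(-x**2 / (4 * D_coeff * t)) / jnp.sqrt(4 * jnp.pi * D_coeff * t)

# Time derivative: differentiate with respect to the second argument
du_dt = jax.grad(u_exact, argnums=1)
# Second spatial derivative: differentiate twice with respect to the first argument
d2u_dx2 = jax.grad(jax.grad(u_exact, argnums=0), argnums=0)

def residual(x, t):
    # Residual of the diffusion equation at a single point
    return du_dt(x, t) - D_coeff * d2u_dx2(x, t)

xs_check = jnp.array([-1.0, 0.3, 2.0])
ts_check = jnp.array([0.5, 1.0, 1.5])
residuals = jax.vmap(residual)(xs_check, ts_check)
print(residuals)
```

A neural network substituted for `u_exact` will generally have a non-zero residual, and the physics loss penalises its mean square.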

Next, we define our PINN with an MLP.

In [None]:
class PINN(eqx.Module):
    """
    The Physics-Informed Neural Network model.

    This is a simple Multi-Layer Perceptron (MLP) that takes spatial (x) and
    temporal (t) coordinates as input and outputs the predicted value of the
    solution u(x, t).
    """
    mlp: eqx.nn.MLP

    def __init__(self, in_size, out_size, width_size, depth, *, key, **kwargs):
        """
        Initializes the MLP.

        Args:
            in_size (int): Input size (2 for x and t).
            out_size (int): Output size (1 for u).
            width_size (int): Number of neurons in each hidden layer.
            depth (int): Number of hidden layers.
            key (jax.random.PRNGKey): JAX random key for initialization.
        """
        super().__init__(**kwargs)
        self.mlp = eqx.nn.MLP(
            in_size=in_size,
            out_size=out_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.tanh,
            key=key,
        )

    def __call__(self, x, t):
        """
        Performs a forward pass of the network.

        Args:
            x (float): Spatial coordinate.
            t (float): Temporal coordinate.

        Returns:
            float: The predicted value of the solution u(x, t).
        """
        # The inputs are concatenated to form a single input vector for the MLP.
        input_vec = jnp.array([x, t])
        y_PINN = self.mlp(input_vec)
        return y_PINN.reshape(())


Finally, we define our training function which includes our definition of both the loss terms.

In [None]:
def train_PINN(
    training_ts, training_xs, training_sol,
    lr_strategy=(1e-3,),
    steps_strategy=(2000,),
    width_size=32,
    depth=3,
    seed=5678,
    print_every=50,
):
    """
    Trains the PINN model.

    Args:
        training_ts (jax.numpy.ndarray): Time coordinates for training data.
        training_xs (jax.numpy.ndarray): Spatial coordinates for training data.
        training_sol (jax.numpy.ndarray): Solution values at training points.
        lr_strategy (tuple): Tuple of learning rates for the optimizer.
        steps_strategy (tuple): Tuple of number of training steps for each learning rate.
        width_size (int): Width of the neural network.
        depth (int): Depth of the neural network.
        seed (int): Random seed.
        print_every (int): How often to print loss values.

    Returns:
        PINN: The trained model.
    """
    key = jrandom.PRNGKey(seed)
    __, model_key = jrandom.split(key)

    in_size = 
    out_size =
    model = PINN(in_size, out_size, width_size, depth, key=model_key)

    @eqx.filter_jit
    def PDE_loss(model, ti, xi):
        """
        Calculates the physics-based loss.

        This function computes the residual of the diffusion equation.
        The goal of the training is to minimize this residual, effectively
        forcing the neural network to satisfy the physics of the problem.
        """
        # Use jax.grad to compute the derivatives of the model's output.
        # `jax.vmap` is used to apply the function over the batch of inputs.

        # ∂²u/∂x²
        d2ydx2 = jax.vmap(jax.grad(jax.grad(model, argnums=0), argnums=0))
        # ∂u/∂t
        dydt = 

        # The residual of the PDE: ∂u/∂t - D * ∂²u/∂x²
        # For a perfect solution, this would be zero.
        g_PDE = dydt(xi, ti) - D * d2ydx2(xi, ti)

        # We return the mean squared residual.
        return jnp.mean(g_PDE**2)

    @eqx.filter_jit
    def MSE_loss(model, ti, xi, yi):
        """
        Calculates the Mean Squared Error (MSE) loss.

        This is the "data-driven" part of the loss. It measures how well the
        network's prediction matches the provided training data.
        """
        y_pred = jax.vmap(model)(xi, ti)
        MSE = 
        return MSE

    @eqx.filter_value_and_grad
    def grad_loss(model, ti, xi, yi):
        """
        Calculates the total loss and its gradient.

        The total loss is a sum of the MSE loss and the PDE loss. This is the
        core idea of a PINN: the model learns to satisfy both the data and the
        underlying physics simultaneously.
        """
        mse_loss = MSE_loss(model, ti, xi, yi)
        pde_loss = PDE_loss(model, ti, xi)
        return mse_loss + pde_loss

    @eqx.filter_jit
    def make_step(ti, xi, yi, model, opt_state):
        """
        Performs a single optimization step.
        """
        loss, grads = grad_loss(model, ti, xi, yi)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state

    # Training loop
    history = []
    count = 0
    for lr, steps in zip(lr_strategy, steps_strategy):
        optim = optax.adabelief(lr)
        opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
        for step in range(steps):
            count += 1
            loss, model, opt_state = make_step(training_ts, training_xs, training_sol, model, opt_state)
            history.append(loss)

            if (step % print_every) == 0 or step == steps - 1:
                print(f"Step: {step}, MSE Loss: {MSE_loss(model, training_ts, training_xs, training_sol)}, PDE Loss: {PDE_loss(model, training_ts, training_xs)}")

    return model, history


Running the code below will initialise and train our PINN on data pulled from the analytic solution.

In [None]:
# Define the spatial and temporal domain for plotting.
extent = 4.0

# Generate training data.
# We create a set of random points in space and time and use the analytical
# solution to get the "ground truth" values at these points.
data_seed = 404
key = jrandom.PRNGKey(data_seed)
Ntrain = 100

# Use a separate subkey for each random draw
x_key, t_key = jrandom.split(key)
training_xs = jax.random.uniform(x_key, shape=(Ntrain,), minval=-extent, maxval=extent)
training_ts = jax.random.uniform(t_key, shape=(Ntrain,), minval=0.5, maxval=1.5)

# Get the solution at the training points.
training_ys = diffusion_solution(training_ts, training_xs)

# Train the PINN model.
NDE_model, history = train_PINN(training_ts, training_xs, training_ys)

In [None]:
# Plot training history
fig = plt.figure(dpi=100)
plt.semilogy(history)
plt.xlabel('Training Step')
plt.ylabel('Total Loss')
plt.title('Training History of PINN')

# Plot solution
plot_ts = jnp.linspace(0.01, 1.5, 100)
plot_xs = jnp.linspace(-extent, extent, 200)
plot_Ts, plot_Xs = jnp.meshgrid(plot_ts, plot_xs)
plot_sol = diffusion_solution(plot_Ts, plot_Xs)
plot_PINN_sol = jax.vmap(NDE_model)(plot_Xs.flatten(), plot_Ts.flatten()).reshape(plot_Xs.shape)

fig = plt.figure(dpi=100)
ax1 = fig.add_subplot(121)
ax2 = fig.add_subplot(122)

c1 = ax1.pcolormesh(plot_Ts, plot_Xs, plot_sol, shading='auto')
fig.colorbar(c1, ax=ax1)
c2 = ax2.pcolormesh(plot_Ts, plot_Xs, plot_PINN_sol, shading='auto')
fig.colorbar(c2, ax=ax2)

# Key takeaways

- Differentiable simulators require the following components:
  1. A pre-defined ODE or PDE structure
  2. A number of trainable parameters which are not a-priori known
  3. A numerical scheme for solution (finite differencing, time-stepping, etc.)
  4. A loss function which defines an optimal model
  5. A means to minimise the loss (adjoint solver, optimiser, etc.)
- Diffrax is a JAX library for the numerical solution (both forward and adjoint) of differential equations
- Optax is a JAX library for optimisation which can be easily interfaced with diffrax
- One can construct differential equations which include neural networks. Equinox is a JAX library which can be used for this task.