<a href="https://colab.research.google.com/github/aidancrilly/AIMSLecture/blob/main/2026/Project_DPNDE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Neural Ordinary Differential Equations for Chaotic Systems

The double pendulum is a physical system which exhibits chaotic behaviour for some initial conditions. As the name suggests, it consists of two masses connected in series by rigid rods, which swing under gravity.

Below is a diagram of the double pendulum system:

![](https://dassencio.org/assets/double-pendulum.483b215e.svg)

We will consider the simplified system where $m_1 = m_2$ and $l_1 = l_2$.

The dynamics of this system are described by the equations of motion:

$$ \mathbf{y} = \begin{bmatrix} \theta_1 \\ \theta_2 \\ \dot{\theta}_1 \\ \dot{\theta}_2 \end{bmatrix} \ \ , \ \ \frac{d \mathbf{y}}{dt} = \begin{bmatrix} \dot{\theta}_1 \\ \dot{\theta}_2 \\ h_1(\theta_1,\theta_2,\dot{\theta}_1,\dot{\theta}_2) \\ h_2(\theta_1,\theta_2,\dot{\theta}_1,\dot{\theta}_2)  \end{bmatrix} $$

Where $\theta_1$ and $\theta_2$ are the angles of the pendula (see diagram) and dotted variables denote time derivatives.

Note that $h_1$ and $h_2$ are non-linear functions giving the angular accelerations. These are what we will learn using neural differential equation (NDE) methods.

## Problem statement

Your task is to:

 - Write a neural differential equation system for the double pendulum system.
 - Train this model on data produced by direct numerical solutions to the true double pendulum system.
 - Understand how NDEs train and behave, and investigate their performance on chaotic systems.

In [None]:
!pip install equinox optax diffrax
!git clone https://github.com/aidancrilly/AIMSLecture.git

import sys
sys.path.append('./AIMSLecture/2026/')

import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import jax.nn as jnn
import diffrax
import time
import copy
import optax
import equinox as eqx

# Note this should be found via the git clone + sys.path.append above
from DoublePendulumSolver import DoublePendulum as DPS
from DoublePendulumSolver import g
from DoublePendulumSolver import ODE_kwargs as NDE_args

L = 1.0

### Training data creation

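Here we create training data for the double pendulum problem by running `Nsamples` simulations over $0 \le t \le t_{max}$. Initial angles are sampled uniformly from `theta_range`, and the pendula start from rest: the angular velocity ($\omega = \dot{\theta}$) range is [0, 0].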

In [None]:
def generate_training_data(Nsamples, theta_range, omega_range, tmax, L = L, Nt = 1000, key = jax.random.PRNGKey(42)):
    key,subkey = jax.random.split(key)
    y0s = jax.random.uniform(subkey,shape=(Nsamples,4))
    y0s = y0s.at[:,:2].set(theta_range[0] + (theta_range[1]-theta_range[0])*y0s[:,:2])
    y0s = y0s.at[:,2:].set(omega_range[0] + (omega_range[1]-omega_range[0])*y0s[:,2:])

    ts = jnp.linspace(0.0,tmax,Nt)
    DP = DPS(ts=ts)

    init_cond = {
        'theta1' : y0s[:,0],
        'theta2' : y0s[:,1],
        'theta1dot' : y0s[:,2],
        'theta2dot' : y0s[:,3],
        }

    args = {'L' : L}

    ys = jax.jit(jax.vmap(DP.__call__,in_axes=[0,None]))(init_cond, args)
    return ts, ys

# No need to change these
theta_range = [-jnp.pi/2.0,jnp.pi/2.0]
omega_range = [0.0,0.0]
Nsamples = 1200
tmax = 10.0

ts, full_ys = generate_training_data(Nsamples, theta_range, omega_range, tmax)

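These data contain chaotic trajectories (discussed later). For training purposes we wish to remove them, so we apply a crude cut based on the range of angles in each trajectory: any trajectory in which $|\theta_1|$ or $|\theta_2|$ exceeds $\pi$ (i.e. a pendulum swings over the top) is treated as chaotic.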

In [None]:
def split_ys(full_ys, theta_cut = jnp.pi):
    # Applying a crude filter for chaotic trajectories
    mask = jnp.any((abs(full_ys['theta1']) > theta_cut) | (abs(full_ys['theta2']) > theta_cut), axis=1)
    nonchaotic_ys = jax.tree.map(lambda m: m[~mask], full_ys)
    chaotic_ys = jax.tree.map(lambda m: m[mask], full_ys)
    return nonchaotic_ys, chaotic_ys

training_ys, chaotic_ys = split_ys(full_ys)

### Visualisation

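The trajectories of the double pendula can be visualised in some interesting ways. For example, the 2D histograms below show how often each $(\theta_1, \theta_2)$ pair is visited by the non-chaotic and chaotic trajectories.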

These can be used to understand the dynamics, but also make very nice visuals!

In [None]:
import matplotlib.pyplot as plt

fig = plt.figure(dpi=200,figsize=(6,2.5))
ax1 = fig.add_subplot(121)
plt.hist2d(training_ys['theta1'].flatten(),training_ys['theta2'].flatten(),
           bins=[jnp.linspace(-jnp.pi,jnp.pi,200),jnp.linspace(-jnp.pi,jnp.pi,200)],
           vmax=500,cmap='cubehelix')
plt.colorbar()
plt.xlabel(r'$\theta_1$')
plt.ylabel(r'$\theta_2$')
plt.title('Non-Chaotic')

ax2 = fig.add_subplot(122)
plt.hist2d(chaotic_ys['theta1'].flatten(),chaotic_ys['theta2'].flatten(),
           bins=[jnp.linspace(-jnp.pi,jnp.pi,200),jnp.linspace(-jnp.pi,jnp.pi,200)],
           vmax=10,cmap='cubehelix')
plt.colorbar()
plt.xlabel(r'$\theta_1$')
plt.ylabel(r'$\theta_2$')
plt.title('Chaotic')
fig.tight_layout()

### NDE implementation

Below we define our Neural Differential Equation module using JAX and its ecosystem of libraries, most notably Equinox and Diffrax.

- [Diffrax Documentation](https://docs.kidger.site/diffrax/)
- [Equinox Documentation](https://docs.kidger.site/equinox/)

**We will use a Lagrangian network approach to deriving the equations of motion**

- [Euler-Lagrange equation](https://en.wikipedia.org/wiki/Euler%E2%80%93Lagrange_equation#Statement)

The Lagrangian approach is as follows.

Firstly, we define a scalar function called the "Lagrangian", $\mathcal{L}(\theta_1,\theta_2,\dot{\theta}_1,\dot{\theta}_2)$, which we approximate as a quadratic form in the angular velocities plus a scalar function of the angles:

$$ \mathcal{L}(\theta_1,\theta_2,\dot{\theta}_1,\dot{\theta}_2) = \begin{bmatrix} \dot{\theta}_1 \\ \dot{\theta}_2 \end{bmatrix}^T \mathbf{M}(\theta_1,\theta_2) \begin{bmatrix} \dot{\theta}_1 \\ \dot{\theta}_2 \end{bmatrix} + \mathcal{N}(\theta_1,\theta_2) $$

Here $\mathbf{M}$ is a positive definite matrix whose expression is provided from theory, and $\mathcal{N}$ is a scalar function learned by a neural network.
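For equal (unit) masses and equal lengths $L$, the theoretical matrix used in `compute_mass_matrix` below is $\mathbf{M} = L^2 \begin{bmatrix} 1 & \tfrac{1}{2}\cos(\theta_2-\theta_1) \\ \tfrac{1}{2}\cos(\theta_2-\theta_1) & \tfrac{1}{2} \end{bmatrix}$, so that $\dot{\theta}^T \mathbf{M} \dot{\theta}$ is the kinetic energy of two point masses on massless rods. Its determinant is $L^4\left(\tfrac{1}{2} - \tfrac{1}{4}\cos^2(\theta_2-\theta_1)\right) \ge L^4/4 > 0$ and its top-left entry is positive, so it is positive definite for all angles. A quick numerical check of that claim, sweeping the angle difference (the function name here is illustrative):

```python
import jax
import jax.numpy as jnp

def mass_matrix(theta1, theta2, L=1.0):
    # Same form as compute_mass_matrix in the notebook (unit masses, l1 = l2 = L)
    x = jnp.cos(theta2 - theta1)
    return L**2 * jnp.array([[1.0, 0.5 * x], [0.5 * x, 0.5]])

# M depends only on theta2 - theta1, so sweeping the difference is enough
dtheta = jnp.linspace(-jnp.pi, jnp.pi, 361)
Ms = jax.vmap(mass_matrix, in_axes=(None, 0))(0.0, dtheta)  # shape (361, 2, 2)
eigs = jnp.linalg.eigvalsh(Ms)                              # batched symmetric eigenvalues
min_eig = float(eigs.min())
print(min_eig)  # approximately 0.191, attained where cos(theta2 - theta1) = +/-1
```

The smallest eigenvalue, $(1.5 - \sqrt{1.25})/2 \approx 0.191$ (times $L^2$), stays strictly positive, so $\mathbf{M}$ never becomes singular.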

Secondly, we use automatic differentiation (AD) to form the equations of motion from derivatives of the Lagrangian.

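This involves two Jacobian (second-derivative) matrices of the Lagrangian, namely the Jacobians of $\nabla_{\dot{\theta}}\mathcal{L}$ with respect to $\dot{\theta}$ and to $\theta$ (the row index $i$ labels the component of $\nabla_{\dot{\theta}}\mathcal{L}$):

$$ \left(J_{\dot{\theta}\dot{\theta}}\right)_{ij} = \frac{\partial^2 \mathcal{L}}{\partial \dot{\theta}_i \, \partial \dot{\theta}_j} $$
$$ \left(J_{\theta\dot{\theta}}\right)_{ij} = \frac{\partial^2 \mathcal{L}}{\partial \dot{\theta}_i \, \partial \theta_j} $$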

We then solve the following linear system for the angular accelerations $h_1$ and $h_2$ in our equations of motion (see introduction):

$$ J_{\dot{\theta}\dot{\theta}} \cdot \begin{bmatrix} h_1 \\ h_2 \end{bmatrix} = \begin{bmatrix} \frac{\partial \mathcal{L}}{\partial \theta_1} \\ \frac{\partial \mathcal{L}}{\partial \theta_2} \end{bmatrix} - J_{\theta\dot{\theta}} \cdot \begin{bmatrix} \dot{\theta}_1 \\ \dot{\theta}_2 \end{bmatrix} $$
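For reference, this linear system is the Euler-Lagrange equation with its total time derivative expanded by the chain rule (the Lagrangian has no explicit time dependence). For each component $i$:

$$ \frac{d}{dt}\frac{\partial \mathcal{L}}{\partial \dot{\theta}_i} = \sum_j \frac{\partial^2 \mathcal{L}}{\partial \dot{\theta}_i \, \partial \dot{\theta}_j} \ddot{\theta}_j + \sum_j \frac{\partial^2 \mathcal{L}}{\partial \dot{\theta}_i \, \partial \theta_j} \dot{\theta}_j = \frac{\partial \mathcal{L}}{\partial \theta_i} $$

Setting $\ddot{\theta}_j = h_j$ and moving the second sum to the right-hand side gives the linear system above. Solving it requires $J_{\dot{\theta}\dot{\theta}}$ to be invertible; for the Lagrangian used here $J_{\dot{\theta}\dot{\theta}} = 2\mathbf{M}$ (since $\mathcal{N}$ depends only on the angles), which is why $\mathbf{M}$ must be positive definite.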

Finally, we integrate these equations of motion with an ODE solver to obtain the trajectory of the system in time.
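The AD recipe can be checked in isolation on a system whose answer is known. A minimal sketch, assuming a single pendulum with unit mass and length `L`, for which $\mathcal{L} = \tfrac{1}{2}L^2\dot{\theta}^2 + gL\cos\theta$ and the result should reduce to $\ddot{\theta} = -(g/L)\sin\theta$. The function names and constants are illustrative; the double pendulum module below follows the same steps with length-2 vectors.

```python
import jax
import jax.numpy as jnp

g, L = 9.81, 1.0  # illustrative constants for this sketch

def lagrangian(th, thdot):
    # th, thdot: arrays of shape (1,); returns the scalar T - V
    return 0.5 * L**2 * jnp.sum(thdot**2) + g * L * jnp.sum(jnp.cos(th))

def accelerations(th, thdot):
    dL_dth = jax.grad(lagrangian, argnums=0)(th, thdot)    # dL/dtheta, shape (1,)
    dL_dthdot = jax.grad(lagrangian, argnums=1)            # function returning dL/dthetadot
    J_ddot = jax.jacrev(dL_dthdot, argnums=1)(th, thdot)   # d2L / dthetadot_i dthetadot_j
    J_dot = jax.jacrev(dL_dthdot, argnums=0)(th, thdot)    # d2L / dthetadot_i dtheta_j
    rhs = dL_dth - J_dot @ thdot
    return jnp.linalg.solve(J_ddot, rhs)                   # angular acceleration(s)

th = jnp.array([0.3])
thdot = jnp.array([0.7])
print(accelerations(th, thdot), -(g / L) * jnp.sin(0.3))  # the two should agree
```

Here the cross term `J_dot` is zero because the kinetic energy does not depend on the angle; for the double pendulum it is not, since $\mathbf{M}$ depends on $\theta_2 - \theta_1$.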

In the code block below, we set up this approach to NDEs for the double pendulum system.

In [None]:
class NeuralODE_Lagrangian(eqx.Module):
    mlp_N: eqx.nn.MLP

    # Neural ODE based on Lagrangian mechanics for the double pendulum
    def __init__(self, width_size, depth, activation, *, key, **kwargs):
        super().__init__(**kwargs)

        # To be completed
        # Set up MLP
        # Hint: no final bias aids stability and training
        self.mlp_N =

        # Initialize the final linear layer to zero
        where = lambda m: m.layers[-1].weight
        self.mlp_N = eqx.tree_at(where, self.mlp_N, jnp.zeros_like(self.mlp_N.layers[-1].weight))

    def fourier_features(self, th):
        # Convert angles to Fourier features
        # This embeds the 2*pi periodicity of the angles into the inputs explicitly
        return jnp.array([jnp.cos(th[0]), jnp.sin(th[0]), jnp.cos(th[1]), jnp.sin(th[1])])

    def compute_mass_matrix(self, th):
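        # Theoretical mass matrix (unit masses, l1 = l2 = L); positive definite for all angles
        # since det(M) = L^4 (1/2 - x^2/4) >= L^4/4 > 0 and M[0, 0] > 0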
        x = jnp.cos(th[1]-th[0])
        M = L**2*jnp.array([[1.0,0.5*x],[0.5*x,0.5]])
        return M

    def Lagrangian(self, th, tdot):
        # Convert to fourier features for inputs
        x_N = self.fourier_features(th)

        # Add explicit gravity and length dependence
        # This helps the NN generalise
        N = g*L**2*self.mlp_N(x_N)

        M = self.compute_mass_matrix(th)

        return tdot.T @ (M @ tdot) + N

    def grad_L(self, th, thdot):
        """ Computes all the required gradients of the Lagrangian """
        # To be completed, define gradient functions here
        grad_L_theta =
        grad_L_thdot =

        # Compute required Jacobians
        J_ddot = jax.jacrev(grad_L_thdot,argnums=1)(th,thdot)
        J_dot = jax.jacrev(grad_L_thdot,argnums=0)(th,thdot)

        grad_L_theta = grad_L_theta(th,thdot)

        return grad_L_theta, J_ddot, J_dot

    def __call__(self, t, y, args):
        theta_dot_array = jnp.array([y['theta1dot'], y['theta2dot']])
        theta_array = jnp.array([y['theta1'], y['theta2']])

        grad_L_theta, J_ddot, J_dot = self.grad_L(theta_array, theta_dot_array)

        # To be completed
        # Compute Right Hand Side of equations of motion
        # RHS = dL/dtheta - J_{theta,theta_dot} @ theta_dot
        RHS =

        # Solve for angular accelerations
        hs = jnp.linalg.solve(J_ddot, RHS)

        return {
        'theta1' : y['theta1dot'],
        'theta2' : y['theta2dot'],
        'theta1dot' : hs[0],
        'theta2dot' : hs[1],
        }

class NeuralODE(eqx.Module):
    func: NeuralODE_Lagrangian

    def __init__(self, width_size, depth, activation, *, key, **kwargs):
        super().__init__(**kwargs)
        self.func = NeuralODE_Lagrangian(width_size, depth, activation, key=key)

    def __call__(self, ts, y0, args):
        """
        As in the exercise examples, we set up diffrax.diffeqsolve,
        but now the ODETerm wraps our NDE Equinox Module.
        """
        # Set up and solve ODE
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Tsit5(),
            t0=ts[0],
            t1=ts[-1],
            dt0=ts[1] - ts[0],
            y0=y0,
            args=args,
            stepsize_controller=diffrax.PIDController(rtol=args['rtol'], atol=args['atol']),
            saveat=diffrax.SaveAt(ts=ts),
            max_steps=int(1e6)
        )
        return solution.ys

### Training loop

Here we define a function to perform training on our NDE model based on input training data. This training function involves 3 key strategies:

1. Batching of NDE solves over training data
2. A stepwise learning rate schedule
3. A stepwise training data truncation schedule (fit the early part of each trajectory first, then progressively longer sections)
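For strategy 2, note that `optax.piecewise_constant_schedule` takes a dict mapping step boundaries to *multiplicative* scale factors, where each boundary is the cumulative optimiser step at which its factor starts to apply. A minimal, dependency-free sketch of converting per-phase learning rates and step counts into that dict; both helper names are hypothetical, and `lr_at` is our reading of how optax evaluates such a schedule:

```python
def lr_boundaries_and_scales(lr_strategy, steps_strategy):
    """Hypothetical helper: per-phase learning rates -> {cumulative_step: scale}."""
    boundaries_and_scales = {}
    step = 0
    for i in range(len(lr_strategy) - 1):
        step += steps_strategy[i]  # phase i + 1 starts once phases 0..i are done
        boundaries_and_scales[step] = lr_strategy[i + 1] / lr_strategy[i]
    return boundaries_and_scales

def lr_at(count, init_value, boundaries_and_scales):
    """Hypothetical helper: apply each scale once its boundary has been reached."""
    value = init_value
    for boundary, scale in sorted(boundaries_and_scales.items()):
        if count >= boundary:
            value *= scale
    return value

lr_strategy = [5e-3, 2.5e-3, 1e-3]
steps_strategy = [300, 500, 500]
bs = lr_boundaries_and_scales(lr_strategy, steps_strategy)
print(bs)
# With optax: optax.piecewise_constant_schedule(lr_strategy[0], bs)
```

Passing the absolute learning rates directly as the dict values would instead multiply them together, giving learning rates many orders of magnitude too small after the first phase.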

In [None]:
from jax.flatten_util import ravel_pytree

def MSE_pytrees(pred,truth):
    # Flatten both pytrees (e.g. dicts of trajectories) into 1D vectors and take the mean squared error
    return jnp.mean((ravel_pytree(pred)[0]-ravel_pytree(truth)[0])**2)

def train_DPNDE(ts,ys,args,
    model,
    optimiser,
    lr_strategy,
    steps_strategy,
    length_strategy,
    grad_clip,
    batch_size,
    print_every=50,
    reset_state=True,
):
    @eqx.filter_value_and_grad
    def grad_loss(model, ti, yi):
        """
        Compute loss and gradients
        Total loss is MSE with respect to both trajectories and equations of motion
        """
        # Handle batch
        batched_model = jax.vmap(model,in_axes=(None,0,None))
        y0 = jax.tree.map(lambda v : v[:,0], yi)
        y_pred = batched_model(ti, y0, args)
        # Error with respect to trajectories
        MSE = MSE_pytrees(y_pred,yi)

        # Batched Equations of Motion (EoM)
        EoM_pred = lambda y : jax.vmap(model.func,in_axes=(0,0,None))(ti,y,args)
        batched_EoM_pred = jax.vmap(EoM_pred,in_axes=0)(y_pred)
        EoM_truth = lambda y : jax.vmap(DPS.DP_equations,in_axes=(0,0,None))(ti,y,{'L' : L})
        batched_EoM_truth = jax.vmap(EoM_truth,in_axes=0)(y_pred)
        # Error with respect to equations of motion
        MEoM = MSE_pytrees(batched_EoM_pred,batched_EoM_truth)

        return MSE + MEoM

    @eqx.filter_jit
    def make_step(ti, yi, model, opt_state):
        loss, grads = grad_loss(model, ti, yi)
        flat_grad, _ = ravel_pytree(grads)
        grad_norm = jnp.linalg.norm(flat_grad)

        # To be completed
        # Compute parameter updates and apply to model
        # See NDE example in exercise
        updates, opt_state =
        model =

        return loss, grad_norm, model, opt_state

    history = {'loss' : [], 'grad_norm' : []}
    count = 0
    batch_key = jax.random.PRNGKey(5445)

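    # Set up optimiser, gradient clipping and learning rate schedule.
    # optax.piecewise_constant_schedule expects {cumulative step boundary: multiplicative scale},
    # so convert the per-phase learning rates and step counts accordingly.
    # Note: re-initialising opt_state (reset_state=True) also resets the schedule's step count.
    if(len(lr_strategy) > 1):
        boundaries = [sum(steps_strategy[:i+1]) for i in range(len(lr_strategy)-1)]
        scales = [lr_strategy[i+1]/lr_strategy[i] for i in range(len(lr_strategy)-1)]
        lr_schedule = optax.schedules.piecewise_constant_schedule(lr_strategy[0],dict(zip(boundaries,scales)))
        optim = optax.chain(
            optax.clip_by_global_norm(grad_clip),
            optimiser(learning_rate=lr_schedule)
        )
    else:
        optim = optax.chain(
            optax.clip_by_global_norm(grad_clip),
            optimiser(learning_rate=lr_strategy[0])
        )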

    # Set up optimiser state
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

    length_size = ts.shape[0]
    length_prev = 0
    for steps, length in zip(steps_strategy, length_strategy):
        if reset_state and length_prev != length:
            opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
        length_prev = 1.0*length

        # To be completed
        # Find index respective to length
        Nts =

        _ts = ts[:Nts]
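        for step in range(steps):
            # Use a fresh subkey each step; sample without replacement, capped at the number of
            # trajectories available (jax.random.choice raises an error if shape exceeds the population)
            batch_key, subkey = jax.random.split(batch_key)
            n_traj = ys['theta1'].shape[0]
            bidx = jax.random.choice(subkey, jnp.arange(n_traj), shape=(min(batch_size, n_traj),), replace=False)
            _ys = jax.tree.map(lambda v : v[bidx,:Nts], ys)

            count += 1
            start = time.time()
            loss, grad_norm, model, opt_state = make_step(_ts, _ys, model, opt_state)
            # JAX dispatches work asynchronously, so block before stopping the timer
            loss = jax.block_until_ready(loss)
            end = time.time()
            history['loss'].append(loss)
            history['grad_norm'].append(grad_norm)
            if (step % print_every) == 0 or step == steps - 1:
                print(f"Step: {step}, Loss: {loss}, Computation time: {end - start}")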

    return model, history
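The trajectory loss compares pytrees (here, dicts of trajectories) by flattening each into a single vector with `ravel_pytree` before averaging. A small toy check of that behaviour, with made-up values for illustration:

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def MSE_pytrees(pred, truth):
    # Flatten both pytrees (same structure) into 1D vectors, then average the squared differences
    return jnp.mean((ravel_pytree(pred)[0] - ravel_pytree(truth)[0]) ** 2)

truth = {"theta1": jnp.zeros(3), "theta2": jnp.zeros(3)}
pred = {"theta1": jnp.ones(3), "theta2": jnp.zeros(3)}
print(MSE_pytrees(pred, truth))  # 3 of the 6 entries differ by 1, so the MSE is 0.5
```

Because every entry is weighted equally, errors in angles and in angular velocities contribute to the loss on the same scale.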

Here we can define our base model and learning strategy:

In [None]:
# Model hyperparameters
# To be chosen, only a small network is needed, depth ~ 3
WIDTH_SIZE =
DEPTH =
# Consider smooth activations, tanh, swish, gelu, etc.
ACT_FUNC =

base_model = NeuralODE(width_size=WIDTH_SIZE, depth=DEPTH, activation=ACT_FUNC , key=jax.random.PRNGKey(420))

trained_model = copy.deepcopy(base_model)

In [None]:
# Training hyperparameters
optimiser = optax.adam
# Gradient clipping value, learning rates, steps and lengths to be chosen
# To check gradient clipping value, check grad_norm returned during training (in history)
grad_clip = 1e30 # Arbitrarily large value

lr_strategy = [5e-3,2.5e-3,1e-3]
steps_strategy = [300,500,500]
length_strategy = [0.5,0.75,1.0]
batch_size = 32

## Problems

- Try training a model on a single training data example
    - Consider: How the model trains, what complexity of network is needed, short term vs long term prediction
    - What happens if you try to train on the full time series (i.e. `length_strategy` containing only 1.0) from the first training step?
- Perform a test-train (e.g. roughly 80/20) split, train the model and evaluate its performance.
- Can you improve this model by modifying the hyperparameters?

_Warning: Training can take a long time for large training sets (~100s of ms per step for batch_size=32), experiment on smaller sections of data to understand NDE behaviours before attempting more extensive training._
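For the test-train split, one possible approach is a random permutation of trajectory indices, applied to every entry of the dict with `jax.tree.map`. A minimal sketch, assuming trajectories are stored as a dict of arrays of shape `(n_trajectories, Nt)` like `training_ys`; the helper name, key and 80/20 fraction are illustrative:

```python
import jax
import jax.numpy as jnp

def train_test_split(ys, train_frac=0.8, key=jax.random.PRNGKey(0)):
    """Hypothetical helper: randomly split a dict of (n_traj, Nt) arrays along axis 0."""
    n_traj = ys["theta1"].shape[0]
    perm = jax.random.permutation(key, n_traj)  # shuffled trajectory indices
    n_train = int(train_frac * n_traj)
    # The same permutation is applied to every key, so states stay aligned
    train = jax.tree.map(lambda v: v[perm[:n_train]], ys)
    test = jax.tree.map(lambda v: v[perm[n_train:]], ys)
    return train, test

# Toy example: 10 trajectories of 5 time points each
keys = ["theta1", "theta2", "theta1dot", "theta2dot"]
toy_ys = {k: jnp.arange(50.0).reshape(10, 5) for k in keys}
train, test = train_test_split(toy_ys)
print(train["theta1"].shape, test["theta1"].shape)  # (8, 5) (2, 5)
```

In the notebook this could be called as `train_ys, test_ys = train_test_split(training_ys)`, passing `train_ys` to `train_DPNDE` and `test_ys` to the MSE evaluation.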

In [None]:
split_idx = 1

training_data = jax.tree.map(lambda v : v[:split_idx,:], training_ys)

trained_model, history = train_DPNDE(
    ts,training_data,NDE_args,trained_model,
    optimiser,lr_strategy,steps_strategy,length_strategy,
    grad_clip,batch_size,print_every=50,reset_state=False
    )

### Saving your model

The code below can be used to save (serialise) your trained model.

In [None]:
eqx.tree_serialise_leaves(f"DPNDE_W{WIDTH_SIZE}_D{DEPTH}_act{str(ACT_FUNC).split(' ')[-3]}_trained.eqx", trained_model)

### Plotting loss history

In [None]:
fig = plt.figure(dpi=200)
ax1 = fig.add_subplot(211)
ax2 = fig.add_subplot(212)

ax1.semilogy(history['loss'])
smoothed_loss = jnp.convolve(jnp.array(history['loss']),jnp.ones(batch_size)/batch_size,mode='valid')
ax1.semilogy(smoothed_loss)
ax2.semilogy(history['grad_norm'])
ax1.set_ylabel('Loss')
ax2.set_ylabel('Gradient Norm')
ax2.set_xlabel('Training Step')

Use the code below to compute the MSE of a trained model on the test data.

In [None]:
def compute_loss_from_trained_model(model,ts,test_data):
    batched_model = jax.vmap(model,in_axes=(None,0,None))
    y0 = jax.tree.map(lambda v : v[:,0], test_data)
    y_pred = batched_model(ts,y0,NDE_args)
    return MSE_pytrees(y_pred,test_data)

test_data = jax.tree.map(lambda v : v[split_idx:,:], training_ys)

print(compute_loss_from_trained_model(trained_model,ts,test_data))

Use the code below to plot a true trajectory (coloured lines) against the model prediction (black dashed and dotted lines) for comparison.

In [None]:
test_data = jax.tree.map(lambda v : v[2:3,:], training_ys)

plt.plot(test_data['theta1'].T)
plt.plot(test_data['theta2'].T)
y0 = jax.tree.map(lambda v : v[0,0], test_data)
y_pred = trained_model(ts,y0,NDE_args)

plt.plot(y_pred['theta1'],c='k',ls='--')
plt.plot(y_pred['theta2'],c='k',ls=':')

plt.ylim(-2*jnp.pi,2*jnp.pi)

### Loading pre-trained model

The code above saves your model so that you can reload it later. Here is how you might load it from the saved model file:

```python

model_original = NeuralODE(width_size=WIDTH_SIZE, depth=DEPTH, activation=ACT_FUNC, key=jax.random.PRNGKey(420))
# Note that all sizes and depths must match between model_original and model_loaded
model_loaded = eqx.tree_deserialise_leaves("<model_to_be_loaded>.eqx", model_original)

```

### Chaotic trajectories

For some values of initial conditions, the double pendulum can show chaotic trajectories. In other words, small perturbations to the initial conditions lead to large variation in trajectories. Let us look at a specific example:

In [None]:
ts_chaos = jnp.linspace(0.0,20.0,2000)

DP_chaos = DPS(ts=ts_chaos)

args = {'L' : 1.0}

init_cond = {
    'theta1' : 1.0*jnp.pi/3.0,
    'theta2' : 1.0*jnp.pi/3.0,
    'theta1dot' : 0.0,
    'theta2dot' : 0.0,
    }

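# Relative perturbation of 1e-5: only the (non-zero) initial angles change, the zero velocities stay zero
init_cond_perturb = {k : (1+1e-5)*v for k,v in init_cond.items()}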

ys_nochaos_1 = DP_chaos(init_cond, args)
ys_nochaos_2 = DP_chaos(init_cond_perturb, args)

init_cond = {
    'theta1' : 2.0*jnp.pi/3.0,
    'theta2' : 1.0*jnp.pi/3.0,
    'theta1dot' : 0.0,
    'theta2dot' : 0.0,
    }

init_cond_perturb = {k : (1+1e-5)*v for k,v in init_cond.items()}

ys_chaos_1 = DP_chaos(init_cond, args)
ys_chaos_2 = DP_chaos(init_cond_perturb, args)

fig = plt.figure(dpi=200)
ax1 = fig.add_subplot(211)
ax2 = fig.add_subplot(212,sharex=ax1)

ax1t = ax1.twinx()
ax2t = ax2.twinx()

ax1.plot(ts_chaos,ys_nochaos_1['theta1'],label='Theta 1 - Unperturbed',color='blue')
ax1.plot(ts_chaos,ys_nochaos_2['theta1'],label='Theta 1 - Perturbed',color='orange',ls='--')

ax1t.plot(ts_chaos,ys_nochaos_1['theta2'],label='Theta 2 - Unperturbed',color='red')
ax1t.plot(ts_chaos,ys_nochaos_2['theta2'],label='Theta 2 - Perturbed',color='green',ls='--')

ax2.plot(ts_chaos,ys_chaos_1['theta1'],label='Theta 1 - Unperturbed',color='blue')
ax2.plot(ts_chaos,ys_chaos_2['theta1'],label='Theta 1 - Perturbed',color='orange',ls='--')

ax2t.plot(ts_chaos,ys_chaos_1['theta2'],label='Theta 2 - Unperturbed',color='red')
ax2t.plot(ts_chaos,ys_chaos_2['theta2'],label='Theta 2 - Perturbed',color='green',ls='--')

ax1.set_title('Non-Chaotic Trajectories')
ax2.set_title('Chaotic Trajectories')

ax1.set_ylabel(r'$\theta_1$')
ax2.set_ylabel(r'$\theta_1$')

ax1t.set_ylabel(r'$\theta_2$')
ax2t.set_ylabel(r'$\theta_2$')

ax1.set_xlim([ts_chaos[0],ts_chaos[-1]])
fig.tight_layout()
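To quantify the sensitivity shown above, one option is to plot the distance between the unperturbed and perturbed trajectories in state space over time. For chaotic initial conditions this separation typically grows roughly exponentially at first, which appears as an approximately straight rise on a log scale. A minimal sketch (the helper name is hypothetical), assuming both trajectories are dicts with the four state keys, as the solver outputs appear to be:

```python
import jax.numpy as jnp

STATE_KEYS = ("theta1", "theta2", "theta1dot", "theta2dot")

def trajectory_separation(ys_a, ys_b, keys=STATE_KEYS):
    """Hypothetical helper: Euclidean distance between two trajectories at each saved time."""
    sq = sum((ys_a[k] - ys_b[k]) ** 2 for k in keys)
    return jnp.sqrt(sq)

# Toy check: two constant trajectories offset by (3, 4) in (theta1, theta2)
ya = {k: jnp.zeros(4) for k in STATE_KEYS}
yb = dict(ya, theta1=jnp.full(4, 3.0), theta2=jnp.full(4, 4.0))
print(trajectory_separation(ya, yb))  # [5. 5. 5. 5.]
```

With the arrays from the cell above, `plt.semilogy(ts_chaos, trajectory_separation(ys_chaos_1, ys_chaos_2))`, together with the same call on `ys_nochaos_1` and `ys_nochaos_2`, compares how quickly the two cases diverge.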

## Problems:

- How does your trained model perform on a chaotic trajectory (select a few from the `chaotic_ys` partitioned at the beginning of the project)? Why is the performance likely to be poor?
- What issue will likely arise if a model is only trained on a short timeframe (e.g. [0,10]) given what is shown in the example above?


## Extension problems:

- Include the chaotic trajectories in your training data and see how the NDE model trains.
- What implications does this have for ML models of chaotic systems, e.g. weather?