Our task now is to use a neural network $f_{\theta}$ to approximate the simulator $\mathcal{P}$, which was built with the fourth-order Runge-Kutta (RK4) method. The properties and functionality of the simulator can be found in the `lorenz63.py` file. We will use a Multilayer Perceptron (MLP) and learn the parameters $\theta$ for which $f_{\theta}$ approximates $\mathcal{P}$: $$f_{\theta} \approx  \mathcal{P}$$ In other words, we try to emulate the Runge-Kutta scheme with a neural network. In RK4, we used the previous state to compute the next state explicitly. Now, blind to the scheme, we learn the parameters $\theta$ from trajectory data by having the network predict the next state from the current state. I will try both cases: with and without noise in the Lorenz simulator.

For practitioners: what I am doing is learning the discrete flow map with an MLP-based emulator. Lorenz '63 is a first-order autonomous ODE system (time does not appear explicitly in the equations), so for a fixed step size the map from one state to the next is the same at every step.
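
To make "discrete flow map" concrete, here is a minimal sketch of a single RK4 step for Lorenz '63 in JAX. The parameter values ($\sigma = 10$, $\rho = 28$, $\beta = 8/3$) and the step size `DT = 0.01` are the usual textbook choices and are assumptions here; the names `lorenz_rhs` and `rk4_step` are hypothetical, and the stepper in `lorenz63.py` may be organized differently.

```python
import jax.numpy as jnp

# Assumed standard Lorenz '63 parameters and step size (not taken from the notebook).
SIGMA, RHO, BETA, DT = 10.0, 28.0, 8.0 / 3.0, 0.01


def lorenz_rhs(u):
    """Right-hand side du/dt = F(u); autonomous, so there is no time argument."""
    x, y, z = u[0], u[1], u[2]
    return jnp.array([SIGMA * (y - x), x * (RHO - z) - y, x * y - BETA * z])


def rk4_step(u, dt=DT):
    """One RK4 step: the discrete flow map that sends u_n to u_{n+1}."""
    k1 = lorenz_rhs(u)
    k2 = lorenz_rhs(u + 0.5 * dt * k1)
    k3 = lorenz_rhs(u + 0.5 * dt * k2)
    k4 = lorenz_rhs(u + dt * k3)
    return u + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


# The same function is applied at every step; only the state changes.
u0 = jnp.array([1.0, 1.0, 1.0])
u1 = rk4_step(u0)
u2 = rk4_step(u1)
```

The emulator $f_{\theta}$ is meant to stand in for a function like `rk4_step`: it takes a state of shape `(3,)` and returns the next state of the same shape.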

In [6]:
# import libraries 
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import jax.tree_util as jtu

# deep learning framework; we'll use its MLP architecture
import equinox as eqx

# gradient processing library, for adam optimizer 
import optax

# training monitoring 
from tqdm import tqdm

# generate data
from Study.lorenz63JAX import rollout


In [None]:
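# 9 random initial conditions, each a 3-component state
u_0_set = jax.random.normal(jax.random.PRNGKey(0), (9, 3))
lorenzStepper = LorenzSimulatorK4()
iterations = 7000

# roll every initial condition forward with the stepper; vmap maps over the 9 initial conditions
rollout_func = rollout(lorenzStepper, iterations, include_init=False)
dataset = jax.vmap(rollout_func)(u_0_set)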

There should be a more robust way to check this, but in my experience the initial phase of a trajectory is quite erratic and sensitive to the initial condition. So we will discard the first 2000 data points of every trajectory.

In [None]:
# recompute the rollouts and drop the first 2000 steps of every trajectory
dataset = jax.vmap(rollout_func)(u_0_set)[:, 2000:]

# split into training and testing sets: the first 6 trajectories for training, the remaining 3 for testing
training_set, testing_set = dataset[:6], dataset[6:]

$$\renewcommand{\cP}{\mathcal{P}}$$

### Online learning

Recall how the simulator/stepper works: $\cP(u_n) = u_{n+1}$. Our emulator is also a stepper function: $$f_{\theta}(u_n) = \hat{u}_{n+1}$$ From the dataset above, we already have $$\text{Dataset}: \{ u_0, \cP(u_0),\cdots, \cP(u_{n-1})\}$$

Our MLP-based emulator first makes a prediction from $u_0$, giving $f_\theta(u_0) = \hat{u}_1$. Next, it computes the error $$ e = \|\hat{u}_1 - \cP(u_0)\|^2/d = \|\hat{u}_1 - u_1\|^2/d$$ where $d = 3$ is the number of state components, so $e$ is the mean squared error. With the error in hand, we take a gradient descent step on $\theta$ to reduce it, which gives updated parameters $$\theta \xrightarrow{\text{gradient descent}} \hat{\theta}$$ With the new (hopefully better) parameters, we predict from $u_1$, get $f_{\hat{\theta}}(u_1) = \hat{u}_2$, and repeat the whole procedure.

Note that this is technically a time series, so we have to create a moving window that feeds an input $u_n$ to $f_{\theta}$ with the current best parameters $\theta$ and checks the prediction $f_{\theta}(u_n) = \hat{u}_{n+1}$ against $u_{n+1}$. Each window must therefore contain two consecutive data points from our training set, $(u_n, u_{n+1})$.
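
The windowing and the one-pair-at-a-time update described above can be sketched as follows. This is a minimal illustration, not the training code of this notebook: it uses a small hand-written tanh MLP and plain gradient descent in pure JAX so that it runs on its own, and a random-walk array stands in for the Lorenz trajectories. The helpers `make_pairs`, `init_mlp`, `mlp`, and `online_step`, the layer sizes, and the learning rate are all hypothetical choices.

```python
import jax
import jax.numpy as jnp


def make_pairs(trajectories):
    """Slide a window of length 2 over (num_traj, T, d) data -> (u_n, u_{n+1}) pairs.

    Windows never straddle two different trajectories.
    """
    d = trajectories.shape[-1]
    inputs = trajectories[:, :-1].reshape(-1, d)
    targets = trajectories[:, 1:].reshape(-1, d)
    return inputs, targets


def init_mlp(key, sizes):
    """Hypothetical helper: a list of (W, b) pairs for a small MLP."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in), jnp.zeros(n_out))
        for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])
    ]


def mlp(params, u):
    """Emulator f_theta: tanh hidden layers, linear output layer."""
    for W, b in params[:-1]:
        u = jnp.tanh(u @ W + b)
    W, b = params[-1]
    return u @ W + b


def loss_fn(params, u_n, u_next):
    # mean over the d state components, i.e. ||u_hat - u_next||^2 / d
    return jnp.mean((mlp(params, u_n) - u_next) ** 2)


@jax.jit
def online_step(params, u_n, u_next, lr=1e-2):
    """Predict from u_n, measure the error against u_next, take one gradient step."""
    loss, grads = jax.value_and_grad(loss_fn)(params, u_n, u_next)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss


data_key, model_key = jax.random.split(jax.random.PRNGKey(0))
# Toy stand-in for the training set: 2 random-walk "trajectories" of 50 steps each.
toy_trajectories = jnp.cumsum(0.01 * jax.random.normal(data_key, (2, 50, 3)), axis=1)

inputs, targets = make_pairs(toy_trajectories)
params = init_mlp(model_key, [3, 16, 3])

losses = []
for u_n, u_next in zip(inputs, targets):  # data "arrives" one pair at a time
    params, loss = online_step(params, u_n, u_next)
    losses.append(float(loss))
```

With the real data, `toy_trajectories` would be replaced by `training_set`. The imports at the top of this notebook suggest `eqx.nn.MLP` and `optax.adam` in place of the hand-written network and the plain gradient step.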

A few things I will need to keep in mind:

- Normalization: scale each state component to mean 0 and variance 1.
- Batching: instead of processing the data as if it arrived one sample at a time, we can, for the sake of time, batch the whole thing and run one gradient step on the batch.
- A residual network that learns the change, so that the network outputs an increment rather than the next state itself: $$u_{n+1} = u_n + \Delta t \cdot f_\theta(u_n)$$
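
The three points in this list can be combined in one short sketch. Everything here is an assumption made for illustration: random data stands in for the $(u_n, u_{n+1})$ pairs, a single affine layer stands in for the MLP, `DT = 0.01` is a guessed step size, and the names `net`, `residual_stepper`, `batch_loss`, and `batch_step` are hypothetical. The normalization statistics are computed from the training inputs only and reused for the targets.

```python
import jax
import jax.numpy as jnp

DT = 0.01  # assumed step size, not taken from the notebook


def net(params, u):
    """Stand-in for the MLP: a single affine layer (hypothetical)."""
    W, b = params
    return u @ W + b


def residual_stepper(params, u):
    """Residual form: u_{n+1} = u_n + dt * f_theta(u_n)."""
    return u + DT * net(params, u)


def batch_loss(params, u_batch, u_next_batch):
    # vmap applies the single-state stepper to every row of the batch
    preds = jax.vmap(residual_stepper, in_axes=(None, 0))(params, u_batch)
    return jnp.mean((preds - u_next_batch) ** 2)


@jax.jit
def batch_step(params, u_batch, u_next_batch, lr=1e-2):
    """One gradient descent step on the whole batch."""
    loss, grads = jax.value_and_grad(batch_loss)(params, u_batch, u_next_batch)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss


# Toy stand-in data: 256 (u_n, u_{n+1}) pairs with non-zero mean and non-unit variance.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
raw_inputs = 5.0 + 3.0 * jax.random.normal(k1, (256, 3))
raw_targets = raw_inputs + 0.05 * jax.random.normal(k2, (256, 3))

# Normalization: per-component statistics from the training inputs.
mean, std = raw_inputs.mean(axis=0), raw_inputs.std(axis=0)
u_batch = (raw_inputs - mean) / std
u_next_batch = (raw_targets - mean) / std

# Zero-initialized parameters make the residual stepper start as the identity map.
params = (jnp.zeros((3, 3)), jnp.zeros(3))
initial_prediction = residual_stepper(params, u_batch[0])
params, loss = batch_step(params, u_batch, u_next_batch)
```

One consequence of the residual form is visible in `initial_prediction`: before any training the emulator already returns $u_n$ as its guess for $u_{n+1}$, which is a reasonable starting point when consecutive states are close.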