# The AdoptODE Cookbook
## 0. Install AdoptODE and JAX

Make sure that JAX is installed (and, if you want to use GPUs, a supported CUDA driver), as well as AdoptODE and its dependencies. An installation guide is provided in the git repository, https://gitlab.gwdg.de/sherzog3/adoptODE.git.

## 1. Define your System
Our example system is $\frac{d}{dt} pop=a\cdot pop + b$, where $pop$ is a scalar population and $a$ and $b$ are the parameters we want to find. We assume the initial population, $a$, and $b$ to be bounded below by zero and above by maxima specified in $\texttt{kwargs\_sys}$.
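Because the system is linear, it has a closed-form solution. This is a standard result, not something AdoptODE needs, but it is handy for checking simulated or fitted trajectories by hand. For $a \neq 0$, shifting the population by $b/a$ turns the equation into pure exponential growth:

$$
\frac{d}{dt}\left(pop + \frac{b}{a}\right) = a\left(pop + \frac{b}{a}\right)
\quad\Longrightarrow\quad
pop(t) = \left(pop(0) + \frac{b}{a}\right)e^{a t} - \frac{b}{a},
$$

while for $a = 0$ the solution is simply $pop(t) = pop(0) + b\,t$.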

In [1]:
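import numpy as np
import jax.numpy as jnp
from jax import jit

def define_system(**kwargs_sys):
    # Upper bounds for the randomly drawn initial population and parameters
    p_max = kwargs_sys['p_max']
    a_max = kwargs_sys['a_max']
    b_max = kwargs_sys['b_max']

    def gen_y0():
        # Random initial population, uniform in [0, p_max)
        ini_pop = np.random.rand()*p_max
        return {'population':ini_pop}

    def gen_params():
        # Random parameters, uniform in [0, a_max) and [0, b_max);
        # the two empty dicts are iparams and exparams (unused here)
        a = np.random.rand()*a_max
        b = np.random.rand()*b_max
        return {'a':a, 'b':b}, {}, {}

    @jit
    def eom(y, t, params, iparams, exparams):
        # Right-hand side of d(pop)/dt = a*pop + b
        pop = y['population']
        a, b = params['a'], params['b']
        return {'population':a*pop+b}

    @jit
    def loss(ys, params, iparams, exparams, targets):
        # Mean squared error between simulated and target populations
        pop = ys['population']
        t_pop = targets['population']
        return jnp.mean((pop-t_pop)**2)

    return eom, loss, gen_params, gen_y0, {}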

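The second and third dictionaries returned by $\texttt{gen\_params}$ are $\texttt{iparams}$ and $\texttt{exparams}$, which this simple example does not need. The functions $\texttt{gen\_y0}$ and $\texttt{gen\_params}$ can be arbitrary Python code, whereas $\texttt{eom}$ and $\texttt{loss}$ have to be implemented with JAX (here $\texttt{jax.numpy}$), since they are compiled with $\texttt{jit}$ and must be differentiable by JAX.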

## 2. Set up a simulation
To set up a simulation, we define the dictionaries $\texttt{kwargs\_sys}$ and $\texttt{kwargs\_adoptODE}$ as well as the times $\texttt{t\_evals}$ at which we assume the system is observed. The keyword $\texttt{N\_sys}$ gives the number of system copies used for multi-experiment fitting; here we consider only one system.

In [2]:
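from adoptODE import simple_simulation, train_adoptODE
kwargs_sys = {'p_max': 2,   # upper bound for the initial population
              'a_max': 1,   # upper bound for a
              'b_max': 3,   # upper bound for b
              'N_sys': 1}   # number of systems (experiments)
kwargs_adoptODE = {'lr':3e-2, 'epochs':200}   # learning rate and number of training epochs
t_evals = np.linspace(0,5,10)   # 10 equally spaced observation times in [0, 5]
dataset = simple_simulation(define_system,
                            t_evals,
                            kwargs_sys,
                            kwargs_adoptODE)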

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In real-life applications, such simulations serve not only as an easy test environment but also as a way to test how reliably the parameters can be recovered. The simulation automatically generated a set of true parameters and a (wrong) initial guess for the parameter recovery, both based on the previously defined $\texttt{gen\_params}$ function:
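Conceptually, such a simulation draws an initial state and true parameters, integrates the equation of motion, records the state at $\texttt{t\_evals}$, and draws a second parameter set as the initial guess. The sketch below mimics that with NumPy and a classic fixed-step RK4 integrator. It is only an illustration under these assumptions, not AdoptODE's implementation: AdoptODE integrates with JAX, and the helper `rk4_trajectory` and the NumPy versions of `gen_y0`, `gen_params` and `eom` are hypothetical stand-ins that reuse the signatures from above.

```python
import numpy as np

# Hypothetical NumPy stand-ins for the toy system (AdoptODE itself requires JAX for eom/loss)
rng = np.random.default_rng(0)
p_max, a_max, b_max = 2.0, 1.0, 3.0

def gen_y0():
    return {'population': rng.random() * p_max}

def gen_params():
    return {'a': rng.random() * a_max, 'b': rng.random() * b_max}, {}, {}

def eom(y, t, params, iparams, exparams):
    return {'population': params['a'] * y['population'] + params['b']}

def rk4_trajectory(eom, y0, t_evals, params, iparams, exparams, substeps=20):
    """Integrate a dict-valued eom with classic RK4; return the states at t_evals."""
    y = dict(y0)
    ys = {k: [v] for k, v in y.items()}
    for t0, t1 in zip(t_evals[:-1], t_evals[1:]):
        h = (t1 - t0) / substeps
        t = t0
        for _ in range(substeps):
            k1 = eom(y, t, params, iparams, exparams)
            k2 = eom({k: y[k] + 0.5 * h * k1[k] for k in y}, t + 0.5 * h, params, iparams, exparams)
            k3 = eom({k: y[k] + 0.5 * h * k2[k] for k in y}, t + 0.5 * h, params, iparams, exparams)
            k4 = eom({k: y[k] + h * k3[k] for k in y}, t + h, params, iparams, exparams)
            y = {k: y[k] + h / 6 * (k1[k] + 2 * k2[k] + 2 * k3[k] + k4[k]) for k in y}
            t += h
        for k in y:
            ys[k].append(y[k])
    return {k: np.array(v) for k, v in ys.items()}

t_evals = np.linspace(0, 5, 10)
y0 = gen_y0()
true_params, iparams, exparams = gen_params()   # parameters used to generate the data
guess_params, _, _ = gen_params()               # independent draw used as the initial guess
targets = rk4_trajectory(eom, y0, t_evals, true_params, iparams, exparams)
```

Because the toy system has a closed-form solution, the simulated trajectory can be checked against $\left(pop(0) + b/a\right)e^{a t} - b/a$.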

In [3]:
print('The true parameters used to generate the data: ', dataset.params)
print('The initial guess of parameters for the recovery: ', dataset.params_train)

The true parameters used to generate the data:  {'a': 0.8683463889134698, 'b': 0.3448656189148058}
The initial guess of parameters for the recovery:  {'a': 0.7324235064410349, 'b': 1.935051710240188}


## 3. Train a simulation
The following command trains on the simulated data; afterwards we print the true parameters next to the recovered ones:

In [4]:
_ = train_adoptODE(dataset)
print('True params: ', dataset.params)
print('Found params: ', dataset.params_train)

Epoch 000:  Loss: 4.6e+02,  Params Err.: 1.6e+00, y0 error: 0.0e+00, Params Norm: 2.1e+00, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 
Epoch 020:  Loss: 1.7e+01,  Params Err.: 1.4e+00, y0 error: 0.0e+00, Params Norm: 1.9e+00, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 
Epoch 040:  Loss: 1.1e+01,  Params Err.: 1.1e+00, y0 error: 0.0e+00, Params Norm: 1.6e+00, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 
Epoch 060:  Loss: 4.8e+00,  Params Err.: 6.8e-01, y0 error: 0.0e+00, Params Norm: 1.2e+00, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 
Epoch 080:  Loss: 2.0e+00,  Params Err.: 3.9e-01, y0 error: 0.0e+00, Params Norm: 1.1e+00, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 
Epoch 100:  Loss: 9.8e-01,  Params Err.: 2.5e-01, y0 error: 0.0e+00, Params Norm: 1.0e+00, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 
Epoch 120:  Loss: 5.5e-01,  Params Err.: 1.5e-01, y0 error: 0.0e+00, Params Norm: 9.7e-01, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 
Epoch 140:  Loss: 2.8e-01,  Params

For more accurate results, try adjusting the learning rate $\texttt{lr}$ or the number of $\texttt{epochs}$ in $\texttt{kwargs\_adoptODE}$!
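To make the roles of $\texttt{lr}$ and $\texttt{epochs}$ concrete, here is a minimal sketch of gradient-based parameter recovery for this toy system. It is not AdoptODE's implementation: it uses the closed-form solution instead of an ODE solver, central finite differences instead of (presumably) JAX's automatic differentiation, and plain gradient descent, whose update rule may differ from AdoptODE's optimizer. The true values, initial guess and step size are made-up illustration values, and the initial population is assumed known.

```python
import numpy as np

def closed_form(t, p0, a, b):
    """Exact solution of d(pop)/dt = a*pop + b (assumes a != 0)."""
    return (p0 + b / a) * np.exp(a * t) - b / a

def mse_loss(params, t, p0, targets):
    a, b = params
    return np.mean((closed_form(t, p0, a, b) - targets) ** 2)

def fd_grad(f, x, eps=1e-6):
    """Central finite-difference gradient of a scalar function f at x."""
    g = np.zeros_like(x)
    for i in range(x.size):
        dx = np.zeros_like(x)
        dx[i] = eps
        g[i] = (f(x + dx) - f(x - dx)) / (2 * eps)
    return g

def fit(params0, t, p0, targets, lr, epochs):
    """Plain gradient descent: lr scales every step, epochs counts the steps."""
    params = np.array(params0, dtype=float)
    loss = lambda p: mse_loss(p, t, p0, targets)
    history = [loss(params)]
    for _ in range(epochs):
        params = params - lr * fd_grad(loss, params)
        history.append(loss(params))
    return params, history

t_evals = np.linspace(0, 5, 10)
p0, a_true, b_true = 1.0, 0.3, 1.0          # hypothetical "true" values
targets = closed_form(t_evals, p0, a_true, b_true)
params, history = fit([0.1, 2.0], t_evals, p0, targets, lr=5e-4, epochs=2000)
print('recovered (a, b):', params, ' final loss:', history[-1])
```

Since the sensitivity of $pop(t)$ to $a$ grows like $t\,e^{a t}$, a step size that is too large quickly overshoots, while a very small one needs many more epochs; that is the trade-off behind tuning $\texttt{lr}$ and $\texttt{epochs}$.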

## 4. Including Data
To include data, we bring it into the same form as the state returned by $\texttt{gen\_y0()}$, but with two additional leading axes. The first indexes the different experiments and has length one here; the second runs over the time points and has the same length as $\texttt{t\_evals}$.
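The shape rule can be spelled out in a few lines of NumPy. The helper `expected_target_shape` is hypothetical (not part of AdoptODE) and simply encodes the rule stated above: `(N_sys, len(t_evals))` followed by the shape of the corresponding entry of the state. The second experiment and the vector-valued `'position'` entry are placeholders for illustration, not real data; using two experiments would also require `N_sys = 2` in $\texttt{kwargs\_sys}$.

```python
import numpy as np

def expected_target_shape(state_value, n_sys, t_evals):
    """(N_sys, len(t_evals)) followed by the shape of one entry of the state from gen_y0()."""
    return (n_sys, len(t_evals)) + np.shape(state_value)

t_evals = np.linspace(0, 5, 10)
obs = np.array([0.86, 1.66, 2.56, 3.59, 4.75, 6.08, 7.58, 9.28, 11.21, 13.40])

# One experiment, scalar state: add a leading experiment axis -> shape (1, 10)
targets = {'population': obs[np.newaxis, :]}   # same as obs.reshape((1, 10))

# Hypothetical second experiment, stacked along axis 0 -> shape (2, 10)
obs_other = np.zeros_like(obs)                 # placeholder for a second measured series
targets_two = {'population': np.stack([obs, obs_other], axis=0)}

# Hypothetical vector-valued state entry of shape (3,) -> targets of shape (1, 10, 3)
y0_vec = {'position': np.zeros(3)}
targets_vec = {'position': np.zeros(expected_target_shape(y0_vec['position'], 1, t_evals))}
```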

Training can now be performed as before, except that no errors of the parameters or the initial condition can be reported, since the true values are unknown (hence the $\texttt{nan}$ entries in the training log):

In [47]:
from adoptODE import dataset_adoptODE
data = np.array([ 0.86, 1.66, 2.56, 3.59, 4.75, 6.08, 7.58, 9.28, 11.21, 13.40]) # Observation of population, shape (10,)
targets = {'population':data.reshape((1,10))}
dataset2 = dataset_adoptODE(define_system,
                            targets,
                            t_evals,
                            kwargs_sys,
                            kwargs_adoptODE)

In [48]:
_ = train_adoptODE(dataset2)
print('Found params: ', dataset2.params_train)

Epoch 000:  Loss: 1.9e+02,  Params Err.: nan, y0 error: nan, Params Norm: 1.6e+00, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 
Epoch 020:  Loss: 1.0e-01,  Params Err.: nan, y0 error: nan, Params Norm: 1.2e+00, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 
Epoch 040:  Loss: 1.1e-03,  Params Err.: nan, y0 error: nan, Params Norm: 1.2e+00, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 
Epoch 060:  Loss: 1.1e-03,  Params Err.: nan, y0 error: nan, Params Norm: 1.2e+00, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 
Epoch 080:  Loss: 3.9e-03,  Params Err.: nan, y0 error: nan, Params Norm: 1.2e+00, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 
Epoch 100:  Loss: 9.6e-03,  Params Err.: nan, y0 error: nan, Params Norm: 1.2e+00, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 
Epoch 120:  Loss: 1.9e-03,  Params Err.: nan, y0 error: nan, Params Norm: 1.2e+00, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 
Epoch 140:  Loss: 9.8e-03,  Params Err.: nan, y0 error: nan, Params Norm: 1.2e+00, iParams
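As an independent plausibility check (not part of AdoptODE), one can exploit that the right-hand side is linear in $a$ and $b$: estimate $\frac{d}{dt}pop$ from the data by central finite differences and fit a straight line against $pop$, so that the slope estimates $a$ and the intercept estimates $b$. With only ten observations this is rough, so treat it as a sanity check for $\texttt{dataset2.params\_train}$ rather than a replacement for training.

```python
import numpy as np

t_evals = np.linspace(0, 5, 10)
data = np.array([0.86, 1.66, 2.56, 3.59, 4.75, 6.08, 7.58, 9.28, 11.21, 13.40])

# Central differences at the interior time points: (p[i+1] - p[i-1]) / (t[i+1] - t[i-1])
dpop_dt = (data[2:] - data[:-2]) / (t_evals[2:] - t_evals[:-2])
pop_mid = data[1:-1]

# Least-squares line dpop_dt ~ a * pop + b; np.polyfit returns [slope, intercept]
a_est, b_est = np.polyfit(pop_mid, dpop_dt, 1)
print(f"a ~ {a_est:.3f}, b ~ {b_est:.3f}")
```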

We hope this notebook provides a helpful starting point. More advanced notebooks, implementing the problems discussed in the paper (DOI), are available in the git repository!