# How To: Neural ODE

In this guide we cover how to set up a Neural ODE problem using the PEtab SciML format and utility functions from the `petab_sciml` Python library. We assume some familiarity with the getting started tutorial, which walks through an entire PEtab SciML problem; this guide focuses only on the parts relevant to the Neural ODE use case.

In the Neural ODE case, the entire right-hand side of the ODE model is replaced by a neural network. As an example, we use the Lotka-Volterra system,

$$\frac{\mathrm{d} \text{prey}}{\mathrm{d} t} = \alpha \cdot \text{prey} - \beta \cdot \text{prey} \cdot \text{predator}$$

$$\frac{\mathrm{d} \text{predator}}{\mathrm{d} t} = \gamma \cdot \text{prey} \cdot \text{predator} - \delta \cdot \text{predator}$$

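Configured as a Neural ODE problem, this system becomes

$$\frac{\mathrm{d} \text{prey}}{\mathrm{d} t} = \text{NN}(\text{prey}, \text{predator})[0]$$

$$\frac{\mathrm{d} \text{predator}}{\mathrm{d} t} = \text{NN}(\text{prey}, \text{predator})[1]$$

where $\text{NN}$ maps the current state $(\text{prey}, \text{predator})$ to a two-element vector whose entries, indexed from 0, are the time derivatives of the two species. The mechanistic parameters $\alpha$, $\beta$, $\gamma$ and $\delta$ no longer appear; the network must learn the dynamics from data.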
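To make the substitution concrete, here is a minimal, dependency-free Python sketch. It is not how AMICI or `petab_sciml` evaluate the model; it only illustrates that a Neural ODE swaps one right-hand-side function for another with the same signature. The Lotka-Volterra parameter values, the random (untrained) weights and the fixed-step forward Euler integrator are illustrative assumptions; the network mirrors the 2-10-10-2 tanh architecture defined in the next section.

```python
import math
import random


def lotka_volterra_rhs(state, alpha=1.5, beta=1.0, gamma=1.0, delta=3.0):
    """Mechanistic right-hand side; the parameter values are illustrative."""
    prey, predator = state
    return [
        alpha * prey - beta * prey * predator,
        gamma * prey * predator - delta * predator,
    ]


def make_mlp(sizes, seed=0):
    """Random (untrained) weights and biases for a fully connected network."""
    rng = random.Random(seed)
    layers = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        weights = [[rng.uniform(-0.5, 0.5) for _ in range(n_in)] for _ in range(n_out)]
        biases = [rng.uniform(-0.5, 0.5) for _ in range(n_out)]
        layers.append((weights, biases))
    return layers


def mlp_forward(layers, x):
    """Affine layers with tanh between them and a linear output layer."""
    for i, (weights, biases) in enumerate(layers):
        x = [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, biases)]
        if i < len(layers) - 1:
            x = [math.tanh(v) for v in x]
    return x


def euler(rhs, y0, t_end, n_steps):
    """Fixed-step forward Euler integration; returns the whole trajectory."""
    dt = t_end / n_steps
    y = list(y0)
    trajectory = [list(y)]
    for _ in range(n_steps):
        dy = rhs(y)
        y = [yi + dt * dyi for yi, dyi in zip(y, dy)]
        trajectory.append(list(y))
    return trajectory


# Same integrator, two right-hand sides. For the neural one, output 0 is
# d(prey)/dt and output 1 is d(predator)/dt, as in the equations above.
mechanistic = euler(lotka_volterra_rhs, [1.0, 1.0], t_end=1.0, n_steps=100)
net = make_mlp([2, 10, 10, 2])
neural = euler(lambda state: mlp_forward(net, state), [1.0, 1.0], t_end=1.0, n_steps=100)
```

Only the function passed to `euler` changes; everything downstream (integration, comparison with data) stays the same.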

## Defining the network architecture

```python
from petab_sciml.standard.nn_model import Input, NNModel, NNModelStandard
import torch
from torch import nn
import torch.nn.functional as F


class NeuralNetwork(nn.Module):
    """Maps the two species (prey, predator) to their two time derivatives."""

    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Linear(2, 10)
        self.layer2 = torch.nn.Linear(10, 10)
        self.layer3 = torch.nn.Linear(10, 2)

    def forward(self, net_input):
        x = self.layer1(net_input)
        x = F.tanh(x)
        x = self.layer2(x)
        x = F.tanh(x)
        x = self.layer3(x)
        return x


# Export the PyTorch module to the PEtab SciML YAML format
net1 = NeuralNetwork()
nn_model1 = NNModel.from_pytorch_module(
    module=net1, nn_model_id="net1", inputs=[Input(input_id="input0")]
)
NNModelStandard.save_data(
    data=nn_model1, filename="net1.yaml"
)
```

The network architecture is kept simple for demonstration purposes. Refer to the page on supported layers and activation functions for more options, but note that PEtab SciML and its importers currently only support networks with vector outputs.

## Generating the PEtab files

The `petab_sciml` Python package provides utility functions that generate the model and PEtab files for a Neural ODE problem. Only the names of the species in the ODE system are required to generate the model. The utility functions also generate the hybridization, mapping and parameter files.

```python
from petab_sciml.problem_utils.neural_ode import (
    create_neural_ode,
    create_neural_ode_problem,
)

# Generate the model file lv.xml from the species names
create_neural_ode(["prey", "predator"], model_filename="lv.xml")
```
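The hybridization table produced for this example is printed further below. As a rough sketch of the wiring it encodes, the hypothetical helper `neural_ode_hybridization_rows` reproduces it from the species names alone: species *i* feeds network input *i*, and network output *i* sets a parameter named `<species>_param` that drives that species' rate of change. This naming scheme is inferred from the tables shown later in this guide, not taken from the `petab_sciml` source, so treat it as an illustration only.

```python
def neural_ode_hybridization_rows(species, net_id="net1"):
    """Hypothetical helper: (targetId, targetValue) pairs mirroring the
    hybridization table shown later in this guide."""
    rows = []
    # Network inputs take the current species values.
    for i, name in enumerate(species):
        rows.append((f"{net_id}_input{i}", name))
    # Each network output sets the parameter used as that species' derivative.
    for i, name in enumerate(species):
        rows.append((f"{name}_param", f"{net_id}_output{i}"))
    return rows


for target_id, target_value in neural_ode_hybridization_rows(["prey", "predator"]):
    print(f"{target_id}\t{target_value}")
```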

To complete the PEtab problem, the user must supply the measurement and observable tables and the array input files. A further utility function then generates the `problem.yaml` file, which references all the PEtab files. Example files are included in the docs for demonstration.
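If you write these files yourself rather than starting from the examples in the docs, the observable table can be produced with the standard library. This is a minimal sketch: the observable IDs (`prey_obs`, `predator_obs`), the constant noise standard deviation of `1.0` and the exact column set are assumptions based on our reading of the PEtab v2 observable table, so compare against the PEtab specification and the example files before use. Any measurement table you write must then use the same `observableId` values.

```python
import csv

# Assumed PEtab v2 observable columns; the IDs and noise level are hypothetical.
OBSERVABLE_ROWS = [
    {"observableId": "prey_obs", "observableFormula": "prey",
     "noiseFormula": "1.0", "noiseDistribution": "normal"},
    {"observableId": "predator_obs", "observableFormula": "predator",
     "noiseFormula": "1.0", "noiseDistribution": "normal"},
]


def write_observables(path, rows=OBSERVABLE_ROWS):
    """Write a tab-separated observable table with a header row."""
    with open(path, "w", newline="") as handle:
        writer = csv.DictWriter(
            handle, fieldnames=list(rows[0]), delimiter="\t", lineterminator="\n"
        )
        writer.writeheader()
        writer.writerows(rows)


write_observables("observables.tsv")
```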

```python
# Model, measurement table, observable table, network YAML and array files
create_neural_ode_problem(
    "lv.xml",
    "measurements.tsv",
    "observables.tsv",
    "net1.yaml",
    ["net1_ps.hdf5"],
)
```

## Loading the PEtab problem

```python
from amici.petab import import_petab_problem
from amici.jax import (
    JAXProblem,
    run_simulations,
)
from petab.v2 import Problem

# Load the PEtab problem from its YAML file
petab_problem = Problem.from_yaml("problem.yaml")

# Import the PEtab problem into an AMICI model with the JAX backend
jax_model = import_petab_problem(
    petab_problem,
    model_output_dir="model",
    compile_=True,
    jax=True
)

# Create the JAXProblem - wrapper for the AMICI model
jax_problem = JAXProblem(jax_model, petab_problem)
```

The hybridization and mapping tables show us how the neural network inputs and outputs are mapped to the model.

```python
jax_problem._petab_problem.hybridization_df
```

| targetId | targetValue |
|---|---|
| net1_input0 | prey |
| net1_input1 | predator |
| prey_param | net1_output0 |
| predator_param | net1_output1 |


```python
petab_problem.mapping_df
```

| petabEntityId | modelEntityId |
|---|---|
| net1_input0 | net1.inputs[0][0] |
| net1_input1 | net1.inputs[0][1] |
| net1_output0 | net1.outputs[0][0] |
| net1_output1 | net1.outputs[0][1] |
| net1_ps | net1.parameters |


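The inputs to the neural network are the `prey` and `predator` species. The first network output sets `prey_param`, the parameter that the generated SBML model uses as the rate of change of `prey`; the second output sets `predator_param`, which plays the same role for `predator`. This is the table form of the Neural ODE equations at the top of this guide: $\alpha$ and $\delta$ no longer exist, because the mechanistic Lotka-Volterra terms have been replaced entirely by the network.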
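Reading the two tables together gives the full path from model entities to network ports. The sketch below resolves each hybridization row through the mapping table, using plain dictionaries copied from the outputs above. It only illustrates how the tables relate and is not a `petab` or `petab_sciml` API.

```python
# Rows copied from the hybridization table above: targetId -> targetValue.
hybridization = {
    "net1_input0": "prey",
    "net1_input1": "predator",
    "prey_param": "net1_output0",
    "predator_param": "net1_output1",
}

# Rows copied from the mapping table above: petabEntityId -> modelEntityId.
mapping = {
    "net1_input0": "net1.inputs[0][0]",
    "net1_input1": "net1.inputs[0][1]",
    "net1_output0": "net1.outputs[0][0]",
    "net1_output1": "net1.outputs[0][1]",
    "net1_ps": "net1.parameters",
}


def resolve(table, ids):
    """Replace PEtab IDs with their model-entity IDs where a mapping exists."""
    return {ids.get(k, k): ids.get(v, v) for k, v in table.items()}


for target, value in resolve(hybridization, mapping).items():
    print(f"{target} <- {value}")
```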

Note also that the parameter table contains only the network parameters. Unlike in previous examples, there are no mechanistic parameters to estimate. Optimising the network parameters fits the right-hand side of the ODEs; the simulated trajectories, obtained by integrating the derivatives the network outputs, are what get compared with the measurements.
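As a sketch of what "optimising the network parameters" means, assume independent, normally distributed measurement noise. The objective is then the negative log-likelihood of the measurements given the simulated observables. The function below computes it for plain Python lists; the inputs are hypothetical, and in the workflow above this quantity is produced by the simulation tooling rather than by hand.

```python
import math


def gaussian_nll(measurements, simulations, sigmas):
    """Negative log-likelihood under independent normal noise:
    sum of 0.5 * log(2 * pi * sigma^2) + 0.5 * ((y - h) / sigma)^2."""
    if not (len(measurements) == len(simulations) == len(sigmas)):
        raise ValueError("inputs must have the same length")
    total = 0.0
    for y, h, sigma in zip(measurements, simulations, sigmas):
        total += 0.5 * math.log(2.0 * math.pi * sigma**2) + 0.5 * ((y - h) / sigma) ** 2
    return total


perfect_fit = gaussian_nll([1.0, 2.0], [1.0, 2.0], [1.0, 1.0])
one_unit_off = gaussian_nll([1.0], [0.0], [1.0])
```

Lower values mean the simulated trajectories sit closer to the data; the optimiser adjusts `net1_ps` to reduce this sum.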

```python
petab_problem.parameter_df
```

| parameterId | parameterScale | lowerBound | upperBound | nominalValue | estimate |
|---|---|---|---|---|---|
| net1_ps | lin | -inf | inf | | 1 |
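The single `net1_ps` row stands for every weight and bias in the network. Each `nn.Linear(n_in, n_out)` layer in the `NeuralNetwork` class above contributes `n_in * n_out` weights and `n_out` biases, so the size can be checked by hand; in PyTorch, `sum(p.numel() for p in net1.parameters())` should give the same number.

```python
# (in_features, out_features) of layer1, layer2 and layer3 in NeuralNetwork.
LINEAR_LAYERS = [(2, 10), (10, 10), (10, 2)]


def count_linear_parameters(layers):
    """Weights plus biases for a stack of fully connected layers."""
    return sum(n_in * n_out + n_out for n_in, n_out in layers)


print(count_linear_parameters(LINEAR_LAYERS))  # 162
```

That is 30 + 110 + 22 = 162 values estimated under one parameter ID.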
