# How To: Deep Mechanistic Models

In a deep mechanistic model (DMM), the outputs of a neural network are used as inputs to the ODE. Here, using the Lotka-Volterra system as an example, the output of the network replaces the parameter $\gamma$.

$$\frac{\mathrm{d} \text{prey}}{\mathrm{d} t} = \alpha \cdot \text{prey} - \beta \cdot \text{prey} \cdot \text{predator}$$

$$\frac{\mathrm{d} \text{predator}}{\mathrm{d} t} = \text{NN}[0] \cdot \text{prey} \cdot \text{predator} - \delta \cdot \text{predator}$$

This example also demonstrates how array input data for the neural network is handled across multiple conditions.
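
As a rough illustration of the two equations above, the sketch below writes the right-hand side in plain Python. The names `toy_network` and `lotka_volterra_rhs`, the stand-in network, the input values and the initial state are assumptions made for illustration only; they are not part of AMICI or PEtab SciML. The values of `alpha`, `beta` and `delta` are the nominal values from the parameters table shown later. The stand-in network is evaluated once, before the ODE, and its first output takes the place of $\gamma$.

```python
def toy_network(net_input):
    """Hypothetical stand-in for the neural network: returns a list of outputs."""
    # A fixed, non-negative function of the input; a real network would be trained.
    return [max(0.0, 0.1 * sum(net_input))]


def lotka_volterra_rhs(prey, predator, alpha, beta, delta, nn_output):
    """Right-hand side of the hybrid model, with NN[0] in place of gamma."""
    d_prey = alpha * prey - beta * prey * predator
    d_predator = nn_output[0] * prey * predator - delta * predator
    return d_prey, d_predator


# The network is evaluated before the ODE; its output does not depend on the state.
nn_output = toy_network([2.0, 3.0, 3.0])

# One explicit Euler step, only to show how the pieces fit together.
prey, predator, dt = 1.0, 1.0, 0.01
d_prey, d_predator = lotka_volterra_rhs(
    prey, predator, alpha=1.3, beta=0.9, delta=1.8, nn_output=nn_output
)
prey_next = prey + dt * d_prey
predator_next = predator + dt * d_predator
```

In the actual workflow the solver and the network evaluation are handled by the ``JAXProblem``; the Euler step here only makes the data flow explicit.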

## Loading the PEtab problem

Let's load the PEtab problem, build our model and define the overall hybrid problem as a ``JAXProblem``.

In [None]:
from amici.petab import import_petab_problem
from amici.jax import (
    JAXProblem,
    run_simulations,
)
from petab.v2 import Problem

# Create the PEtab problem
petab_problem = Problem.from_yaml("problem.yaml")

# Create AMICI model for the petab problem
jax_model = import_petab_problem(
    petab_problem,
    model_output_dir="model",
    compile_=True,
    jax=True
)

# Create the JAXProblem - wrapper for the AMICI model
jax_problem = JAXProblem(jax_model, petab_problem)

By looking at the hybridization, parameters and mapping tables we can see how this deep mechanistic modelling problem has been defined.

In [3]:
jax_problem._petab_problem.hybridization_df

| targetId | targetValue  |
|----------|--------------|
| gamma    | net3_output1 |


The ``gamma`` parameter of the model is mapped to the PEtab identifier ``net3_output1``. The mapping table shows that this PEtab id is defined as the first output of the neural network.

In [4]:
petab_problem.mapping_df

| petabEntityId | modelEntityId      |
|---------------|--------------------|
| input0        | net3.inputs[0]     |
| net3_output1  | net3.outputs[0][0] |
| net3_ps       | net3.parameters    |
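
The lookup chain described above (hybridization target, then PEtab identifier, then model entity) can be sketched with plain dictionaries. The helpers `parse_model_entity` and `resolve_target` are hypothetical and written only for this illustration; they are not how AMICI or PEtab SciML resolve the tables internally.

```python
import re

# Rows copied from the hybridization and mapping tables shown above.
hybridization = {"gamma": "net3_output1"}
mapping = {
    "input0": "net3.inputs[0]",
    "net3_output1": "net3.outputs[0][0]",
    "net3_ps": "net3.parameters",
}


def parse_model_entity(model_entity_id):
    """Split e.g. 'net3.outputs[0][0]' into ('net3', 'outputs', (0, 0))."""
    match = re.fullmatch(r"(\w+)\.(\w+)((?:\[\d+\])*)", model_entity_id)
    if match is None:
        raise ValueError(f"Unrecognised model entity id: {model_entity_id!r}")
    net_id, attribute, index_part = match.groups()
    indices = tuple(int(i) for i in re.findall(r"\[(\d+)\]", index_part))
    return net_id, attribute, indices


def resolve_target(target_id):
    """Follow targetId -> PEtab id -> model entity id."""
    petab_id = hybridization[target_id]
    return parse_model_entity(mapping[petab_id])


gamma_source = resolve_target("gamma")
print(gamma_source)
```

Here `gamma_source` names the network, the attribute and the indices that supply the value of ``gamma``.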


Finally, in the parameters table, the network parameters are listed with infinite bounds and no nominal value, so that they are free to be optimised. Also note that $\gamma$ does not appear in the parameters table, because we will set it with the neural network instead.

In [5]:
petab_problem.parameter_df

| parameterId | parameterScale | lowerBound | upperBound | nominalValue | estimate |
|-------------|----------------|------------|------------|--------------|----------|
| alpha       | lin            | 0.0        | 15.0       | 1.3          | 1        |
| delta       | lin            | 0.0        | 15.0       | 1.8          | 1        |
| beta        | lin            | 0.0        | 15.0       | 0.9          | 1        |
| net3_ps     | lin            | -inf       | inf        |              | 1        |


### Array input data

Under ``extensions_config``, the PEtab problem specifies two array data files to be added to the problem:
- ``net3_ps.hdf5`` to set the values of the network parameters
- ``net3_input2.hdf5`` to provide the inputs into the network

In [12]:
import json # for pretty printing only

print(json.dumps(petab_problem.extensions_config, indent=4))

{
    "sciml": {
        "array_files": [
            "net3_ps.hdf5",
            "net3_input2.hdf5"
        ],
        "hybridization_files": [
            "hybridization.tsv"
        ],
        "neural_nets": {
            "net3": {
                "location": "net3.yaml",
                "static": true,
                "format": "YAML"
            }
        }
    }
}


The nested structure of the neural network input file can be seen below. The input data to the neural network is under ``inputs/input0`` and data for two conditions is specified. The keys for the different conditions match those defined in the conditions table.

In [13]:
import h5py

# Convenience function to show the nested structure of an HDF5 file
def show_h5_struct(file):
    file.visit(lambda x: print("  " * (len(x.split("/")) - 1), x.split("/")[-1]))

# Open read-only and close the file automatically when done
with h5py.File("net3_input2.hdf5", "r") as file:
    show_h5_struct(file)

 inputs
   input0
     cond1
     cond2
 metadata
   perm


In [14]:
petab_problem.condition_df

cond1
cond2


Also note the ``static: true`` setting in the neural network definition of the extensions config. This means that the input to the neural network is not expected to depend on the model. The keyword indicates to PEtab SciML importers that the network precedes the ODE, as opposed to being inside it (i.e. a UDE) or in one of the observable formulae.
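
As a small sketch of how a tool could read this flag, the snippet below parses the extensions config printed earlier in this section and separates static networks from the rest. The helper `split_networks_by_static` is hypothetical, and treating a missing ``static`` key as non-static is an assumption of this sketch, not documented importer behaviour.

```python
import json

# The extensions config printed earlier in this section, as JSON text.
config_text = """
{
    "sciml": {
        "array_files": ["net3_ps.hdf5", "net3_input2.hdf5"],
        "hybridization_files": ["hybridization.tsv"],
        "neural_nets": {
            "net3": {"location": "net3.yaml", "static": true, "format": "YAML"}
        }
    }
}
"""
extensions_config = json.loads(config_text)


def split_networks_by_static(config):
    """Return (static_ids, non_static_ids) for the networks in the config."""
    static_ids, non_static_ids = [], []
    for net_id, net_config in config["sciml"]["neural_nets"].items():
        # Assumption of this sketch: a missing key is treated as non-static.
        if net_config.get("static", False):
            static_ids.append(net_id)
        else:
            non_static_ids.append(net_id)
    return static_ids, non_static_ids


static_ids, non_static_ids = split_networks_by_static(extensions_config)
```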

### Network Architecture

The PyTorch snippet below shows how a network architecture can be defined and exported to YAML format using PEtab SciML. The predefined YAML file is also provided in the PEtab SciML repo for completeness. We use a convolutional architecture here to indicate how the DMM problem setup could allow information from high-dimensional inputs to be included in the mechanistic model.

In [15]:
from petab_sciml.standard.nn_model import Input, NNModel, NNModelStandard
import torch
from torch import nn
import torch.nn.functional as F

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Conv2d(3, 1, (5, 5), stride=(1, 1), padding=(0, 0), dilation=(1, 1))
        self.layer2 = torch.nn.Flatten()
        self.layer3 = torch.nn.Linear(36, 1)

    def forward(self, net_input):
        x = self.layer1(net_input)
        x = self.layer2(x)
        x = self.layer3(x)
        x = F.relu(x)
        return x

net1 = NeuralNetwork()
nn_model1 = NNModel.from_pytorch_module(
    module=net1, nn_model_id="net3", inputs=[Input(input_id="input0")]
)
NNModelStandard.save_data(
    data=nn_model1, filename="net3_from_pytorch.yaml"
)
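
The ``Linear(36, 1)` layer above fixes the number of features that the convolution and flatten layers must produce. The sketch below checks this with the standard convolution output-size formula. The 10 x 10 spatial input size is an assumption, inferred from the layer sizes for a square input; it is not stated in the example itself. The helper `conv2d_output_size` is written for this illustration only.

```python
def conv2d_output_size(size, kernel, stride=1, padding=0, dilation=1):
    """Spatial output size of a 2D convolution along one dimension."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# Assumption: each condition provides a 3-channel, 10 x 10 input array.
height = width = 10
out_channels = 1  # second argument of Conv2d(3, 1, ...)

out_height = conv2d_output_size(height, kernel=5)
out_width = conv2d_output_size(width, kernel=5)
flattened_features = out_channels * out_height * out_width
print(out_height, out_width, flattened_features)  # 6 6 36
```

Under that assumption, the flattened feature count matches the 36 inputs expected by the linear layer.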

### Network inputs as parameters

An alternative to defining network inputs as array data in files is to define them in the parameters table, with appropriate bounds and nominal values.

| parameterId     | parameterScale | lowerBound | upperBound | nominalValue | estimate |
|-----------------|----------------|------------|------------|--------------|----------|
| net1_input_pre1 | lin            | -inf       | inf        | 1            | 0        |
| net1_input_pre2 | lin            | -inf       | inf        | 1            | 0        |

The corresponding mapping table then defines the model entities for those PEtab identifiers.

| petabEntityId   | modelEntityId     |
|-----------------|-------------------|
| net1_input_pre1 | net1.inputs[0][0] |
| net1_input_pre2 | net1.inputs[0][1] |
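
To make the relationship between the two tables concrete, the sketch below assembles an input vector from the nominal values and the mapping entries. The helper `assemble_input` is hypothetical, and it assumes that ``net1.inputs[0][i]`` refers to element ``i`` of the first network input; neither is taken from the AMICI or PEtab SciML implementation.

```python
import re

# Rows copied from the two tables above.
nominal_values = {"net1_input_pre1": 1.0, "net1_input_pre2": 1.0}
mapping = {
    "net1_input_pre1": "net1.inputs[0][0]",
    "net1_input_pre2": "net1.inputs[0][1]",
}


def assemble_input(net_id, input_index, values, mapping):
    """Build the input vector for one network input from scalar parameters."""
    pattern = re.compile(
        rf"{re.escape(net_id)}\.inputs\[{input_index}\]\[(\d+)\]"
    )
    entries = {}
    for petab_id, model_entity_id in mapping.items():
        match = pattern.fullmatch(model_entity_id)
        if match:
            entries[int(match.group(1))] = values[petab_id]
    # Order the elements by their position in the input.
    return [entries[i] for i in sorted(entries)]


net1_input = assemble_input("net1", 0, nominal_values, mapping)
print(net1_input)  # [1.0, 1.0]
```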