# Getting Started

## Overview

This notebook outlines how to set up and run a PEtab SciML simulation using [AMICI](https://amici.readthedocs.io/en/latest/index.html). It highlights the aspects of the PEtab SciML format specification that are relevant for implementing this type of simulation. Some familiarity with the PEtab format is assumed; see the [PEtab documentation](https://petab.readthedocs.io/en/latest/) for a refresher.

The environment and example PEtab files needed to run this notebook are provided in the [PEtab SciML repo](https://github.com/PEtab-dev/petab_sciml).

## Model Specification

As an example, we model the Lotka-Volterra system, which is described by:

$$\frac{\mathrm{d} \text{prey}}{\mathrm{d} t} = \alpha \cdot \text{prey} - \beta \cdot \text{prey} \cdot \text{predator}$$

$$\frac{\mathrm{d} \text{predator}}{\mathrm{d} t} = \gamma \cdot \text{prey} \cdot \text{predator} - \delta \cdot \text{predator}$$

We will replace the two interaction terms (those involving $\beta$ and $\gamma$) with outputs from a neural network.
The $\text{prey}$ and $\text{predator}$ variables in the model are used as the inputs to this network:

$$\frac{\mathrm{d} \text{prey}}{\mathrm{d} t} = \alpha \cdot \text{prey} - \text{NN}(\text{prey}, \text{predator})[1]$$

$$\frac{\mathrm{d} \text{predator}}{\mathrm{d} t} = \text{NN}(\text{prey}, \text{predator})[2] - \delta \cdot \text{predator}$$
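
To make the substitution concrete, the hybrid right-hand side can be sketched in plain NumPy. This is an illustration only and not how AMICI builds the model: the helper names (`init_params`, `nn`, `hybrid_rhs`), the hidden layer width, the random initialisation and the placement of the `tanh` activations are all assumptions. Note that the equations index the network outputs from 1, whereas Python indexes from 0.

```python
import numpy as np


def init_params(sizes, seed=0):
    """Random weights and zero biases for a small dense network (assumed sizes)."""
    rng = np.random.default_rng(seed)
    return [
        (rng.normal(scale=0.1, size=(n_out, n_in)), np.zeros(n_out))
        for n_in, n_out in zip(sizes[:-1], sizes[1:])
    ]


def nn(prey, predator, params):
    """Evaluate the network; tanh follows every layer except the last (assumption)."""
    x = np.array([prey, predator], dtype=float)
    for weight, bias in params[:-1]:
        x = np.tanh(weight @ x + bias)
    weight, bias = params[-1]
    return weight @ x + bias


def hybrid_rhs(state, alpha, delta, params):
    """Right-hand side of the hybrid system with NN outputs as interaction terms."""
    prey, predator = state
    out = nn(prey, predator, params)
    d_prey = alpha * prey - out[0]          # NN(prey, predator)[1] in the equations
    d_predator = out[1] - delta * predator  # NN(prey, predator)[2] in the equations
    return np.array([d_prey, d_predator])


# Three linear layers: 2 inputs -> 5 -> 5 -> 2 outputs (hidden width is assumed)
params = init_params([2, 5, 5, 2])
derivative = hybrid_rhs(np.array([1.0, 0.5]), alpha=1.3, delta=1.8, params=params)
```

Any ODE solver that accepts a right-hand-side function could integrate `hybrid_rhs`; in this notebook that role is played by the AMICI JAX model.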

## Environment

Support for PEtab SciML models is under active development. A ``requirements.txt`` file is provided with these docs; it installs the branches of the external dependencies in which support has been implemented. The key dependencies to note are:
- PEtab: https://github.com/PEtab-dev/libpetab-python/tree/sciml
- AMICI: https://github.com/AMICI-dev/AMICI/tree/jax_sciml

## Loading the PEtab problem

Our first step is to load the PEtab problem. This is done using the PEtab Python library.

In [74]:
from petab.v2 import Problem
from yaml import safe_load

with open("problem.yaml") as f:
    petab_yaml = safe_load(f)

petab_yaml["format_version"] = "2.0.0"
petab_problem = Problem.from_yaml(petab_yaml)

The PEtab problem now contains some key SciML elements. Some are new elements of PEtab v2.0.0 and some are extensions of existing tables. We will now look at these in more depth.

### Measurements, Observables and Conditions tables

These tables describe the measurements collected and the experimental conditions they were collected under, and link the measured entities to the model. In this case, two observables have been measured: ``prey_o`` and ``predator_o``. Refer to the [PEtab tutorial](https://petab.readthedocs.io/en/latest/v1/tutorial/tutorial.html) and [specification](https://petab.readthedocs.io/en/latest/v1/documentation_data_format.html) for more details about these files.

In [80]:
petab_problem.measurement_df

| | observableId | simulationConditionId | measurement | time |
|---|---|---|---|---|
| 0 | prey_o | cond1 | 0.173017 | 1.0 |
| 1 | prey_o | cond1 | 0.489177 | 2.0 |
| 2 | prey_o | cond1 | 1.643996 | 3.0 |
| 3 | prey_o | cond1 | 5.451963 | 4.0 |
| 4 | prey_o | cond1 | 2.977522 | 5.0 |
| 5 | prey_o | cond1 | 0.181663 | 6.0 |
| 6 | prey_o | cond1 | 0.348112 | 7.0 |
| 7 | prey_o | cond1 | 0.937919 | 8.0 |
| 8 | prey_o | cond1 | 3.11324 | 9.0 |
| 9 | prey_o | cond1 | 8.863933 | 10.0 |


The observables table links the model entities to the measured values. A noise model can also be defined in this table.

In [83]:
petab_problem.observable_df

| observableId | observableFormula | noiseFormula | observableTransformation | noiseDistribution |
|---|---|---|---|---|
| prey_o | prey | 0.05 | lin | normal |
| predator_o | predator | 0.05 | lin | normal |
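
With a ``lin`` transformation and a ``normal`` noise distribution, each measurement contributes a Gaussian log-density with standard deviation given by ``noiseFormula``. The sketch below spells this out in plain Python; the function name `normal_log_likelihood` and the simulated values are hypothetical, and only the three measurement values and the noise level of 0.05 come from the tables above.

```python
import math

SIGMA = 0.05  # noiseFormula from the observables table


def normal_log_likelihood(measurements, simulations, sigma=SIGMA):
    """Sum of log N(measurement | simulation, sigma^2) over all data points."""
    return sum(
        -0.5 * math.log(2 * math.pi * sigma**2) - 0.5 * ((m - y) / sigma) ** 2
        for m, y in zip(measurements, simulations)
    )


measured = [0.173017, 0.489177, 1.643996]  # first three prey_o measurements
simulated = list(measured)                 # hypothetical perfect fit
llh_perfect = normal_log_likelihood(measured, simulated)

# Shifting every simulated value by one standard deviation lowers the
# log-likelihood by 0.5 per data point.
llh_shifted = normal_log_likelihood(measured, [y + SIGMA for y in simulated])
```

A smaller noise value penalises the same residual more strongly, which is why the noise model affects how closely the fit follows the data.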


The conditions table is intended to fully describe the experimental conditions under which the measurements were collected. It can have any number of columns, but in this example there is only one experimental condition and no additional information about it.

In [81]:
petab_problem.condition_df

cond1


### Mapping file

The mapping table is new to PEtab v2.0.0. Its purpose is to define PEtab-compatible identifiers for model entities whose ids would otherwise not be valid in PEtab. Inputs, outputs and parameters of our neural network are all examples of model entities whose ids are not [valid PEtab identifiers](https://petab.readthedocs.io/en/latest/v2/documentation_data_format.html#v2-identifiers), so we define PEtab ids for them in this file.

Take the first row, for example. The valid PEtab identifier ``net1_input1`` is mapped to the model id ``net1.inputs[0][0]``. This model id denotes the first element of the first input to the neural network.

``net1.parameters`` refers to all the trainable parameters in the network, i.e. the weights and biases of all the layers. To refer to more specific parameters within the network we could write, for example, ``net1.parameters[layer1].weight`` to select the weights of the layer with id ``layer1``.

In [79]:
petab_problem.mapping_df

| petabEntityId | modelEntityId |
|---|---|
| net1_input1 | net1.inputs[0][0] |
| net1_input2 | net1.inputs[0][1] |
| net1_output1 | net1.outputs[0][0] |
| net1_output2 | net1.outputs[0][1] |
| net1_ps | net1.parameters |
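
The structure of these model ids can be made explicit with a small parser. This is a minimal sketch for illustration: the regular expression and the function `parse_model_entity_id` are hypothetical and cover only the id shapes that appear in this guide; they are not the parser used by the PEtab library.

```python
import re

# Illustrative pattern for ids such as "net1.inputs[0][0]" or
# "net1.parameters[layer1].weight" (assumed grammar, not the official one).
ENTITY_PATTERN = re.compile(
    r"^(?P<net>\w+)\.(?P<kind>inputs|outputs|parameters)"
    r"(?P<selectors>(?:\[\w+\])*)(?:\.(?P<attribute>\w+))?$"
)


def parse_model_entity_id(entity_id):
    """Split a model entity id into network id, kind, selectors and attribute."""
    match = ENTITY_PATTERN.match(entity_id)
    if match is None:
        raise ValueError(f"Unrecognised model entity id: {entity_id!r}")
    selectors = re.findall(r"\[(\w+)\]", match["selectors"])
    return match["net"], match["kind"], selectors, match["attribute"]


# Rows of the mapping table shown above
mapping = {
    "net1_input1": "net1.inputs[0][0]",
    "net1_input2": "net1.inputs[0][1]",
    "net1_output1": "net1.outputs[0][0]",
    "net1_output2": "net1.outputs[0][1]",
    "net1_ps": "net1.parameters",
}
parsed = {
    petab_id: parse_model_entity_id(model_id)
    for petab_id, model_id in mapping.items()
}
```

Here `parsed["net1_input1"]` is `("net1", "inputs", ["0", "0"], None)`, while `parsed["net1_ps"]` has no selectors because it refers to every parameter of the network.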


### Parameters file

This file contains information on all model parameters, which in our case includes the parameters of the neural network. The file specifies the scale, bounds and nominal values of all the parameters. We want our neural network parameters to be freely optimised during training, so we set the bounds to ``[-inf, inf]`` and leave the nominal value blank (this translates to a NaN when loaded into the PEtab problem).

In [62]:
petab_problem.parameter_df

| parameterId | parameterScale | lowerBound | upperBound | nominalValue | estimate |
|---|---|---|---|---|---|
| alpha | lin | 0.0 | 15.0 | 1.3 | 1 |
| delta | lin | 0.0 | 15.0 | 1.8 | 1 |
| net1_ps | lin | -inf | inf | | 1 |
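
The effect of the blank nominal value and the infinite bounds can be reproduced with the standard library alone. This is a sketch under stated assumptions: the TSV string mirrors the table above, and the helper `to_float` is hypothetical; the PEtab library performs its own loading.

```python
import csv
import io
import math

# Contents mirror the parameter table; the blank nominalValue of net1_ps
# is the point of interest.
PARAMETERS_TSV = (
    "parameterId\tparameterScale\tlowerBound\tupperBound\tnominalValue\testimate\n"
    "alpha\tlin\t0.0\t15.0\t1.3\t1\n"
    "delta\tlin\t0.0\t15.0\t1.8\t1\n"
    "net1_ps\tlin\t-inf\tinf\t\t1\n"
)


def to_float(text):
    """Convert a table cell to float, treating a blank cell as NaN."""
    return float(text) if text.strip() else math.nan


rows = {
    row["parameterId"]: {
        key: to_float(row[key])
        for key in ("lowerBound", "upperBound", "nominalValue")
    }
    for row in csv.DictReader(io.StringIO(PARAMETERS_TSV), delimiter="\t")
}

# Parameters without finite bounds, and parameters lacking a nominal value
unbounded = [
    pid for pid, row in rows.items()
    if math.isinf(row["lowerBound"]) and math.isinf(row["upperBound"])
]
needs_initial_values = [
    pid for pid, row in rows.items() if math.isnan(row["nominalValue"])
]
```

Both lists contain only `net1_ps`, the entry whose values have to be supplied from elsewhere.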


## Building the JAX model

At this point we can build the model from the petab problem in order to see these mappings and parameters in action, and to demonstrate how the neural network inputs and outputs are inserted into the ODE system.

In [None]:
from amici.petab import import_petab_problem
from amici.jax import (
    JAXProblem,
    run_simulations,
)

# Create AMICI model for the petab problem
jax_model = import_petab_problem(
    petab_problem,
    model_output_dir="model",
    compile_=True,
    jax=True
)

# Create the JAXProblem - wrapper for the AMICI model
jax_problem = JAXProblem(jax_model, petab_problem)

In [85]:
jax_problem._petab_problem.hybridization_df

| targetId | targetValue |
|---|---|
| net1_input1 | prey |
| net1_input2 | predator |
| beta | net1_output1 |
| gamma | net1_output2 |


### Hybridization

The ``JAXProblem`` stores the PEtab problem, including the contents of the hybridization file shown above. The hybridization table defines where the neural network inputs and outputs are used in the model.

In this example, the inputs to the network are the ``prey`` and ``predator`` values, and the outputs of the network provide $\beta$ and $\gamma$. Note that the inputs and outputs of the network are referred to by their PEtab identifiers, as defined in the mapping file.
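
The two directions of the hybridization table (model quantities feeding the network, and network outputs feeding the model) can be separated by looking each identifier up in the mapping table. The following is a minimal sketch with both tables written as plain dictionaries; the function `split_hybridization` is hypothetical and is not part of PEtab or AMICI.

```python
# Tables copied from the outputs above, written as plain dictionaries
mapping = {
    "net1_input1": "net1.inputs[0][0]",
    "net1_input2": "net1.inputs[0][1]",
    "net1_output1": "net1.outputs[0][0]",
    "net1_output2": "net1.outputs[0][1]",
    "net1_ps": "net1.parameters",
}
hybridization = {
    "net1_input1": "prey",
    "net1_input2": "predator",
    "beta": "net1_output1",
    "gamma": "net1_output2",
}


def split_hybridization(hybridization, mapping):
    """Separate rows that feed the network from rows that consume its outputs."""
    net_inputs, net_outputs = {}, {}
    for target_id, target_value in hybridization.items():
        if ".inputs" in mapping.get(target_id, ""):
            # The target is a network input; the value is a model quantity.
            net_inputs[mapping[target_id]] = target_value
        elif ".outputs" in mapping.get(target_value, ""):
            # The target is a model quantity; the value is a network output.
            net_outputs[target_id] = mapping[target_value]
        else:
            raise ValueError(f"Cannot classify row {target_id!r} -> {target_value!r}")
    return net_inputs, net_outputs


net_inputs, net_outputs = split_hybridization(hybridization, mapping)
```

After the call, `net_inputs` maps each network input to the model quantity that drives it, and `net_outputs` maps `beta` and `gamma` to the network outputs that provide them.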

### Network YAML file

This file defines the network architecture. An extract is shown here; the full file is included in the repo. The provided example defines a network with three ``Linear`` layers and a ``tanh`` activation function. See the page on supported layers and activation functions for a complete list.

```yaml
nn_model_id: net1
inputs:
  - input_id: input0
layers:
  - layer_id: layer1
    layer_type: Linear
    args:
      in_features: 2
      out_features: 5
      bias: true
...
```

Note that where network entities are referenced in the ``mapping.tsv`` file, the network id at the start of their ``modelEntityId`` (``net1`` in ``net1.inputs[0][0]``) must match the ``nn_model_id`` given in this YAML. Likewise, if any specific layers or parameters are referenced in ``mapping.tsv``, they must be referred to by the layer ids in this file.
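
As a rough illustration of what the layer arguments imply, the sketch below derives the parameter array shapes for the one layer shown in the extract. It assumes the PyTorch convention that a ``Linear`` weight has shape ``(out_features, in_features)``; the function `linear_parameter_shapes` is hypothetical, and the layers elided from the extract are deliberately not reproduced.

```python
import math

# The single layer shown in the extract, written as the dictionary a YAML
# parser would produce.
layer1 = {
    "layer_id": "layer1",
    "layer_type": "Linear",
    "args": {"in_features": 2, "out_features": 5, "bias": True},
}


def linear_parameter_shapes(layer):
    """Return the array shapes of a Linear layer's trainable parameters."""
    if layer["layer_type"] != "Linear":
        raise ValueError(f"Unsupported layer type: {layer['layer_type']!r}")
    args = layer["args"]
    # Assumption: PyTorch convention, weight is (out_features, in_features).
    shapes = {"weight": (args["out_features"], args["in_features"])}
    if args.get("bias", True):
        shapes["bias"] = (args["out_features"],)
    return shapes


shapes = linear_parameter_shapes(layer1)
n_parameters = sum(math.prod(shape) for shape in shapes.values())
```

For ``layer1`` this gives a 5 by 2 weight matrix and a bias vector of length 5, so 15 trainable values in total.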

### YAML file

The ``problem.yaml`` file describes how the problem should be constructed from the PEtab files.

```yaml
format_version: 2
parameter_file: "parameters.tsv"
problems:
  - model_files:
      lv:
        location: "lv.xml"
        language: "sbml"
    measurement_files:
      - "measurements.tsv"
    observable_files:
      - "observables.tsv"
    condition_files:
      - "conditions.tsv"
    mapping_files:
      - "mapping.tsv"
extensions:
  sciml:
    array_files:
      - "net1_ps.hdf5"
    hybridization_files:
      - "hybridization.tsv"
    neural_nets:
      net1:
        location: "net1.yaml"
        static: false
        format: "YAML"
```

### Array inputs

In this example, the array file listed under ``array_files`` in the problem YAML file provides the initial values of the neural network parameters.
The structure inside the HDF5 file is as follows.

```
   net1_ps.hdf5
   ├── metadata
   │   └── perm
   └── parameters
       └── net1
           ├── layer1
           │   ├── weight
           │   └── bias
           ├── layer2
           │   ├── weight
           │   └── bias
           └── layer3
               ├── weight
               └── bias
```

Note that the network id ``net1`` matches the identifier in the problem and network YAML files, and the network id used in the ``modelEntityId`` column of ``mapping.tsv``. Likewise, the layer names (``layer1``, ``layer2``, ``layer3``) in the HDF5 file are identifiers and must match the layer ids in the network YAML and mapping files.
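
These cross-file requirements can be expressed as a simple consistency check. The sketch below works on hand-written Python structures rather than on the real files; the function `check_ids` and the way the inputs are represented are assumptions made for illustration, and the layer ids listed for the network YAML are taken to be the three named in the HDF5 layout.

```python
def check_ids(neural_net_ids, mapping_entities, array_layout, yaml_layers):
    """Collect readable problems; an empty list means the ids agree."""
    problems = []
    for entity in mapping_entities:
        net_id = entity.split(".", 1)[0]
        if net_id not in neural_net_ids:
            problems.append(f"mapping refers to unknown network {net_id!r}")
    for net_id, layers in array_layout.items():
        if net_id not in neural_net_ids:
            problems.append(f"array file refers to unknown network {net_id!r}")
            continue
        for layer_id in layers:
            if layer_id not in yaml_layers.get(net_id, ()):
                problems.append(
                    f"array file layer {layer_id!r} is not defined for {net_id!r}"
                )
    return problems


neural_net_ids = {"net1"}  # keys under neural_nets in the problem YAML
mapping_entities = [       # modelEntityId column of mapping.tsv
    "net1.inputs[0][0]",
    "net1.inputs[0][1]",
    "net1.outputs[0][0]",
    "net1.outputs[0][1]",
    "net1.parameters",
]
array_layout = {"net1": ["layer1", "layer2", "layer3"]}  # groups in the HDF5 file
yaml_layers = {"net1": ["layer1", "layer2", "layer3"]}   # layer ids in net1.yaml

problems = check_ids(neural_net_ids, mapping_entities, array_layout, yaml_layers)
```

For the example files `problems` is empty; a misspelt layer group in the array file would produce one entry naming the offending layer.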

## Running Simulations

We are now ready to run a simulation using the JAX problem. AMICI's `run_simulations` simulates all conditions provided; in this example there is only one.

In [67]:
# Run simulations: the log-likelihood is returned in llh, further metrics in r
llh, r = run_simulations(jax_problem)

# Run simulations in gradient mode: the gradients are returned in sllh
sllh, rgrad = run_simulations(
    jax_problem,
    is_grad_mode=True,
)

From this point, we could implement a minimal training loop. For an AMICI-based example of such model training, see [this guide](https://amici.readthedocs.io/en/latest/examples/example_jax_petab/ExampleJaxPEtab.html#Model-training).
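
The shape of such a loop can be sketched independently of AMICI. In the example below, `toy_llh_and_grad` is a hypothetical stand-in for a simulation call that returns a log-likelihood and its gradient; it is a simple quadratic, not the Lotka-Volterra model, and the learning rate and step count are arbitrary choices. The linked guide shows a version that uses the real model.

```python
import numpy as np


def toy_llh_and_grad(theta):
    """Hypothetical stand-in: a quadratic log-likelihood and its gradient."""
    target = np.array([1.3, 1.8])
    residual = theta - target
    return -0.5 * float(residual @ residual), -residual


def train(llh_and_grad, theta0, learning_rate=0.1, n_steps=200):
    """Plain gradient ascent on the log-likelihood."""
    theta = np.asarray(theta0, dtype=float)
    history = []
    for _ in range(n_steps):
        llh, grad = llh_and_grad(theta)
        history.append(llh)
        theta = theta + learning_rate * grad  # ascend: larger llh is better
    return theta, history


theta_hat, history = train(toy_llh_and_grad, theta0=[0.0, 0.0])
```

The loop only needs a callable that returns a value and a gradient, so the stand-in could in principle be swapped for a wrapper around the gradient-mode simulation above.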