The purpose of this guide is to show how AMICI can be combined with differentiable programming in JAX.
As an example, we will demonstrate how to implement custom parameter transformations.

In [1]:
import jax
import jax.numpy as jnp

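The key to differentiable programming with AMICI is a custom Jacobian-vector product (JVP) rule. JAX cannot trace through AMICI's simulation code, so we have to specify how the JVP of the simulation results is computed. JAX then combines this rule with the derivatives of all other operations via the chain rule. In JAX, gradients are computed with the `grad` function, which can derive reverse-mode gradients from a custom JVP rule as long as that rule is linear in the tangents. To interface AMICI with JAX, we will therefore use the `custom_jvp` function to define how the Jacobian-vector product of the simulation results is computed.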

In [2]:
from jax import custom_jvp, grad
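To illustrate the mechanism before bringing AMICI in, here is a minimal sketch similar to the introductory example in the JAX documentation on custom derivative rules. The function `f` is a toy example, not part of AMICI. The JVP rule returns the primal output together with the Jacobian applied to the tangent vector, and `grad` obtains the gradient from it.

```python
import jax.numpy as jnp
from jax import custom_jvp, grad


@custom_jvp
def f(x, y):
    return jnp.sin(x) * y


@f.defjvp
def f_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = f(x, y)
    # Jacobian [y*cos(x), sin(x)] applied to the tangent vector (x_dot, y_dot)
    tangent_out = jnp.cos(x) * y * x_dot + jnp.sin(x) * y_dot
    return primal_out, tangent_out


# grad obtains reverse-mode derivatives from the (linear) JVP rule
print(grad(f)(2.0, 3.0))             # 3 * cos(2.0)
print(grad(f, argnums=1)(2.0, 3.0))  # sin(2.0)
```

For AMICI, the tangent output will come from the sensitivities computed by the simulator instead of a closed-form expression.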

For native JAX support, we would need to implement a custom JAX primitive for AMICI simulations, but this would require substantial engineering effort, including writing C code.
Instead, support is enabled through an experimental JAX feature called `host_callback`, which lets JAX computations call back into Python code running on the host.
As a consequence, AMICI code only runs on the CPU; this is not a real limitation, since AMICI code is not amenable to GPU vectorization anyway.

In [3]:
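# experimental API; deprecated in recent JAX releases in favour of the
# external callback APIs such as jax.pure_callback
import jax.experimental.host_callback as hcb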
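The pattern of wrapping non-JAX simulation code can be sketched without AMICI. This is a minimal sketch under several assumptions: `toy_simulate` is a hypothetical NumPy stand-in for an AMICI simulation that returns outputs together with their sensitivities, and the sketch uses `jax.pure_callback` rather than `host_callback`, because the latter is deprecated in recent JAX releases. The `custom_jvp` structure should carry over to either callback mechanism.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import custom_jvp, grad


def toy_simulate(theta):
    """Hypothetical NumPy stand-in for an AMICI simulation.

    Returns "observables" y = exp(-theta) and their sensitivities dy/dtheta.
    JAX cannot trace this function; it only sees the returned arrays.
    """
    theta = np.asarray(theta)
    y = np.exp(-theta)
    sy = np.diag(-np.exp(-theta))
    return y.astype(theta.dtype), sy.astype(theta.dtype)


def _call_host(theta):
    # Tell JAX the shapes and dtypes that the host function will return.
    n = theta.shape[0]
    out_shapes = (
        jax.ShapeDtypeStruct((n,), theta.dtype),
        jax.ShapeDtypeStruct((n, n), theta.dtype),
    )
    return jax.pure_callback(toy_simulate, out_shapes, theta)


@custom_jvp
def simulate(theta):
    y, _ = _call_host(theta)
    return y


@simulate.defjvp
def simulate_jvp(primals, tangents):
    (theta,), (theta_dot,) = primals, tangents
    y, sy = _call_host(theta)
    # Jacobian-vector product: sensitivity matrix times tangent vector
    return y, sy @ theta_dot


theta = jnp.array([0.0, 1.0])
print(grad(lambda t: jnp.sum(simulate(t)))(theta))  # equals -exp(-theta)
```

The JVP rule is linear in `theta_dot`, which is what allows `grad` to run in reverse mode. In this sketch the primal function also computes sensitivities it discards; a real implementation would likely request sensitivities only when derivatives are needed.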

Another important tool that we will use is `partial` from the `functools` module. `partial` fixes some of a function's arguments; applied to a decorator such as `custom_jvp`, it lets us pass arguments to that decorator while still using decorator syntax.
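A minimal sketch of `partial` applied to `custom_jvp`: `partial(custom_jvp, nondiff_argnums=(0,))` produces a decorator that marks the first argument as non-differentiable. The toy function `power` is purely illustrative; in the AMICI setting, non-array arguments such as a model or solver object would be natural candidates for `nondiff_argnums` (an assumption about how the guide proceeds).

```python
from functools import partial

import jax.numpy as jnp
from jax import custom_jvp, grad


# partial(custom_jvp, nondiff_argnums=(0,)) is a decorator equivalent to
# wrapping the function as custom_jvp(power, nondiff_argnums=(0,)).
@partial(custom_jvp, nondiff_argnums=(0,))
def power(exponent, x):
    return x ** exponent


# With nondiff_argnums, the non-differentiable arguments come first in the
# JVP rule, followed by the primals and tangents of the remaining arguments.
@power.defjvp
def power_jvp(exponent, primals, tangents):
    (x,), (x_dot,) = primals, tangents
    return power(exponent, x), exponent * x ** (exponent - 1) * x_dot


print(grad(power, argnums=1)(3, 2.0))  # 3 * 2.0**2 = 12.0
```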

To get started, we load a PEtab problem definition. We use the PEtab benchmark collection [insert ref] for this; for more details, see the PEtab notebook [ref].

In [4]:
!git clone --depth 1 https://github.com/Benchmarking-Initiative/Benchmark-Models-PEtab.git tmp/benchmark-models || (cd tmp/benchmark-models && git pull)
from pathlib import Path
folder_base = Path('.') / "tmp" / "benchmark-models" / "Benchmark-Models"

fatal: destination path 'tmp/benchmark-models' already exists and is not an empty directory.
Already up to date.


Now we can load the Boehm model (`Boehm_JProteomeRes2014`).

In [5]:
import petab
model_name = "Boehm_JProteomeRes2014"
yaml_file = folder_base / model_name / (model_name + ".yaml")
petab_problem = petab.Problem.from_yaml(yaml_file)

Parameter scaling is defined in the parameter table. For the Boehm model, all estimated parameters (those with `1` in the `petab.ESTIMATE` column) use `petab.LOG10` scaling.

In [6]:
petab_problem.parameter_df

parameterId           parameterName           parameterScale  lowerBound  upperBound  nominalValue  estimate
Epo_degradation_BaF3  EPO_{degradation,BaF3}  log10           1e-05       100000      0.026983      1
k_exp_hetero          k_{exp,hetero}          log10           1e-05       100000      1e-05         1
k_exp_homo            k_{exp,homo}            log10           1e-05       100000      0.00617       1
k_imp_hetero          k_{imp,hetero}          log10           1e-05       100000      0.016368      1
k_imp_homo            k_{imp,homo}            log10           1e-05       100000      97749.379402  1
k_phos                k_{phos}                log10           1e-05       100000      15766.50702   1
ratio                 ratio                   lin             -5.0        5           0.693         0
sd_pSTAT5A_rel        \sigma_{pSTAT5A,rel}    log10           1e-05       100000      3.852612      1
sd_pSTAT5B_rel        \sigma_{pSTAT5B,rel}    log10           1e-05       100000      6.591478      1
sd_rSTAT5A_rel        \sigma_{rSTAT5A,rel}    log10           1e-05       100000      3.152713      1
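The observation that every estimated parameter uses log10 scaling can also be checked programmatically. A minimal sketch with pandas, assuming only that the parameter table is a `DataFrame` with PEtab's standard `estimate` and `parameterScale` columns; the helper `estimated_scales` is hypothetical, and the small table below copies three rows from the output above.

```python
import pandas as pd

# Column and value names used by PEtab; they match petab.ESTIMATE,
# petab.PARAMETER_SCALE and petab.LOG10 in the petab package.
ESTIMATE, PARAMETER_SCALE, LOG10 = "estimate", "parameterScale", "log10"


def estimated_scales(parameter_df: pd.DataFrame) -> pd.Series:
    """Return the parameterScale entries of all estimated parameters."""
    return parameter_df.loc[parameter_df[ESTIMATE] == 1, PARAMETER_SCALE]


# A few rows copied from the Boehm parameter table above
df = pd.DataFrame(
    {PARAMETER_SCALE: ["log10", "log10", "lin"], ESTIMATE: [1, 1, 0]},
    index=pd.Index(["Epo_degradation_BaF3", "k_phos", "ratio"], name="parameterId"),
)
print(estimated_scales(df).eq(LOG10).all())  # True
```

Applied to the full table, `estimated_scales(petab_problem.parameter_df).eq(LOG10).all()` should likewise return `True`, since `ratio` is the only parameter with linear scaling and it is not estimated.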


To compare our implementation with AMICI's native parameter scaling, we copy the problem and set the parameter scaling in the copy to `petab.LIN` (linear), so that the log10 transformation can be implemented in JAX instead.

In [7]:
import copy
petab_problem_unscaled = copy.deepcopy(petab_problem)
petab_problem_unscaled.parameter_df[petab.PARAMETER_SCALE] = petab.LIN
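Because the copy uses linear scaling, the log10 transformation has to happen in JAX before the parameters are passed to the model. A minimal sketch of that step, assuming estimated parameters are stored on log10 scale as in the table above; `unscale_log10` and `objective_lin` are hypothetical names, the latter a placeholder for a cost function evaluated on linear-scale parameters (eventually, the unscaled AMICI model).

```python
import jax
import jax.numpy as jnp
import numpy as np


def unscale_log10(theta):
    """Map log10-scaled parameters to linear scale: p = 10**theta."""
    return jnp.power(10.0, theta)


def objective_lin(p):
    # Hypothetical stand-in for a cost function of linear-scale parameters,
    # e.g. the result of simulating the unscaled model.
    return jnp.sum((p - 1.0) ** 2)


theta = jnp.array([-2.0, 0.0, 1.0])

# Elementwise transform -> diagonal Jacobian with entries ln(10) * 10**theta
jac = jax.jacfwd(unscale_log10)(theta)

# Chain rule: dJ/dtheta = dJ/dp * ln(10) * 10**theta, derived automatically
grad_theta = jax.grad(lambda t: objective_lin(unscale_log10(t)))(theta)
print(jnp.diag(jac), grad_theta)
```

JAX differentiates the transformation itself, so only the simulation step needs a custom JVP rule.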

Now we import both PEtab problems using `amici.petab_import`.

In [8]:
from amici.petab_import import import_petab_problem

In [11]:
# each model gets its own output directory; by default both would be compiled
# into the same directory (amici_models/Boehm_JProteomeRes2014)
amici_model = import_petab_problem(petab_problem, model_output_dir='amici_models/Boehm_scaled', model_name='Boehm_scaled')
amici_model_unscaled = import_petab_problem(petab_problem_unscaled, model_output_dir='amici_models/Boehm_unscaled', model_name='Boehm_unscaled')