The purpose of this guide is to demonstrate how AMICI can be combined with differentiable programming in JAX.
As an example, we will show how to implement custom parameter transformations.

In [1]:
import jax
import jax.numpy as jnp

The key to integrating AMICI with differentiable programming is a custom Jacobian-vector product (JVP). A custom JVP rule tells JAX how to differentiate a function that it cannot trace itself, so that derivatives of composite functions can still be computed using the chain rule. In JAX, gradients are computed with the `grad` function. To interface AMICI with JAX, we will use `custom_jvp` to define how the Jacobian-vector product of the simulation results is computed.
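A minimal sketch of this mechanism without AMICI: the hypothetical function `objective` stands in for a simulator, and its JVP rule supplies a hand-written gradient, similar to how AMICI's sensitivities will supply the gradient later on.

```python
import jax
import jax.numpy as jnp
from jax import custom_jvp


@custom_jvp
def objective(x):
    # Hypothetical stand-in for an external simulator returning a scalar
    return jnp.sum(x ** 2)


@objective.defjvp
def objective_jvp(primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    value = objective(x)
    # Hand-written gradient (the role AMICI's sensitivities play later)
    gradient = 2.0 * x
    # For a scalar function, the JVP is the gradient dotted with the tangent
    return value, gradient.dot(x_dot)


# jax.grad uses the custom rule; result is the gradient [2, 4, 6]
print(jax.grad(objective)(jnp.array([1.0, 2.0, 3.0])))
```

Because the tangent output is linear in `x_dot`, JAX can also transpose the rule, so reverse-mode functions such as `grad` and `value_and_grad` work, including for compositions like `objective(2.0 * x)`.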

In [2]:
from jax import custom_jvp, value_and_grad

For native JAX support, we would need to implement a `lax` primitive for AMICI simulations, which would require considerable engineering effort, including writing C code.
Instead, support will be enabled by an experimental JAX feature called `host_callback`.
This means that AMICI code will only run on the CPU, but AMICI code is not amenable to GPU vectorization anyway.
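The idea of a host callback can be sketched in isolation. `host_callback` has been deprecated in recent JAX releases in favour of JAX's external callbacks, so this sketch assumes the swapped-in `jax.pure_callback` instead; `host_function` and `call_host` are hypothetical names. Such a callback is opaque to JAX's automatic differentiation, which is why a custom JVP rule is needed on top of it.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_function(x):
    # Ordinary NumPy code: receives a NumPy array and runs on the host CPU
    return np.asarray(np.sum(np.sin(x)), dtype=np.float32)


@jax.jit
def call_host(x):
    # The ShapeDtypeStruct declares the shape and dtype of the callback result,
    # which JAX needs in order to trace the function without running the callback
    return jax.pure_callback(
        host_function,
        jax.ShapeDtypeStruct((), jnp.float32),
        x,
    )


print(call_host(jnp.zeros(3)))
```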

In [3]:
import jax.experimental.host_callback as hcb

Another important tool that we will use here is the function `partial` from the `functools` package. `partial` can be used to turn another decorator function, with some of its arguments already applied, into a function decorator.
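A short sketch of this pattern, using `jax.jit` with a static argument; the function `to_scale` is hypothetical and only illustrates how `partial` passes `static_argnums` to the decorator:

```python
from functools import partial

import jax
import jax.numpy as jnp


# partial(jax.jit, static_argnums=1) is itself a decorator: applied to
# `to_scale`, it calls jax.jit(to_scale, static_argnums=1)
@partial(jax.jit, static_argnums=1)
def to_scale(value, scale):
    # `scale` is a static Python string, so ordinary Python branching works
    # inside the jitted function; each distinct string triggers a recompile
    if scale == "log10":
        return jnp.log10(value)
    if scale == "log":
        return jnp.log(value)
    return value


print(to_scale(jnp.array(100.0), "log10"))
```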

To get started, we will import a PEtab problem. We will use the benchmark collection [insert ref] for that; for more details, see the PEtab notebook [ref].

In [4]:
!git clone --depth 1 https://github.com/Benchmarking-Initiative/Benchmark-Models-PEtab.git tmp/benchmark-models || (cd tmp/benchmark-models && git pull)
from pathlib import Path
folder_base = Path('.') / "tmp" / "benchmark-models" / "Benchmark-Models"

Cloning into 'tmp/benchmark-models'...
remote: Enumerating objects: 336, done.
remote: Counting objects: 100% (336/336), done.
remote: Compressing objects: 100% (285/285), done.
remote: Total 336 (delta 88), reused 216 (delta 39), pack-reused 0
Receiving objects: 100% (336/336), 2.11 MiB | 11.02 MiB/s, done.
Resolving deltas: 100% (88/88), done.


Now we can load the Boehm model as a PEtab problem.

In [5]:
import petab
model_name = "Boehm_JProteomeRes2014"
yaml_file = folder_base / model_name / (model_name + ".yaml")
petab_problem = petab.Problem.from_yaml(yaml_file)

The parameter scaling is defined in the parameter table. For the Boehm model, all estimated parameters (`petab.ESTIMATE` column equal to `1`) use `petab.LOG10` as parameter scaling.
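For reference, the three PEtab scales relate a parameter on linear scale to its scaled value as written below; the last expression is the chain rule that converts a log-likelihood gradient with respect to the scaled value into one with respect to the linear-scale parameter. The symbols p_i (linear-scale parameter), theta_i (scaled value) and ell (log-likelihood) are introduced here for illustration only.

```latex
\theta_i =
\begin{cases}
  p_i & \text{lin} \\
  \ln p_i & \text{log} \\
  \log_{10} p_i & \text{log10}
\end{cases}
\qquad
\frac{\partial \theta_i}{\partial p_i} =
\begin{cases}
  1 & \text{lin} \\
  1 / p_i & \text{log} \\
  1 / (p_i \ln 10) & \text{log10}
\end{cases}
\qquad
\frac{\partial \ell}{\partial p_i}
  = \frac{\partial \ell}{\partial \theta_i} \, \frac{\partial \theta_i}{\partial p_i}
```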

In [6]:
petab_problem.parameter_df

parameterId           parameterName           parameterScale  lowerBound  upperBound  nominalValue  estimate
Epo_degradation_BaF3  EPO_{degradation,BaF3}  log10           1e-05       100000      0.026983      1
k_exp_hetero          k_{exp,hetero}          log10           1e-05       100000      1e-05         1
k_exp_homo            k_{exp,homo}            log10           1e-05       100000      0.00617       1
k_imp_hetero          k_{imp,hetero}          log10           1e-05       100000      0.016368      1
k_imp_homo            k_{imp,homo}            log10           1e-05       100000      97749.379402  1
k_phos                k_{phos}                log10           1e-05       100000      15766.50702   1
ratio                 ratio                   lin             -5.0        5           0.693         0
sd_pSTAT5A_rel        \sigma_{pSTAT5A,rel}    log10           1e-05       100000      3.852612      1
sd_pSTAT5B_rel        \sigma_{pSTAT5B,rel}    log10           1e-05       100000      6.591478      1
sd_rSTAT5A_rel        \sigma_{rSTAT5A,rel}    log10           1e-05       100000      3.152713      1


Now we import the PEtab problem using `amici.petab_import`, which generates and compiles the AMICI model.

In [7]:
from amici.petab_import import import_petab_problem

In [8]:
amici_model = import_petab_problem(petab_problem)

2023-02-16 10:46:20.089 - amici.petab_import - INFO - Importing model ...
2023-02-16 10:46:20.090 - amici.petab_import - INFO - Validating PEtab problem ...
2023-02-16 10:46:20.396 - amici.petab_import - INFO - Model name is 'Boehm_JProteomeRes2014'.
Writing model code to '/Users/fabian/Documents/projects/AMICI/documentation/amici_models/Boehm_JProteomeRes2014'.
2023-02-16 10:46:20.397 - amici.petab_import - INFO - Species: 8
2023-02-16 10:46:20.397 - amici.petab_import - INFO - Global parameters: 9
2023-02-16 10:46:20.398 - amici.petab_import - INFO - Reactions: 9
2023-02-16 10:46:20.406 - amici.petab_import - INFO - Observables: 3
2023-02-16 10:46:20.407 - amici.petab_import - INFO - Sigmas: 3
2023-02-16 10:46:20.410 - amici.petab_import - DEBUG - Adding output parameters to model: ['noiseParameter1_pSTAT5A_rel', 'noiseParameter1_pSTAT5B_rel', 'noiseParameter1_rSTAT5A_rel']
2023-02-16 10:46:20.411 - amici.petab_import - DEBUG - Adding initial assignments for []
2023-02-16 10:46:20.41

2023-02-16 10:46:20.752 - amici.ode_export - DEBUG - Finished computing w                       ++++ (8.81E-03s)
2023-02-16 10:46:20.765 - amici.ode_export - DEBUG - Finished running smart_jacobian            ++++ (9.75E-03s)
2023-02-16 10:46:20.771 - amici.ode_export - DEBUG - Finished simplifying dwdp                  ++++ (4.64E-03s)
2023-02-16 10:46:20.772 - amici.ode_export - DEBUG - Finished computing dwdp                     +++ (2.96E-02s)
2023-02-16 10:46:20.775 - amici.ode_export - DEBUG - Finished writing dwdp.cpp                    ++ (3.47E-02s)
2023-02-16 10:46:20.790 - amici.ode_export - DEBUG - Finished running smart_jacobian            ++++ (8.31E-03s)
2023-02-16 10:46:20.795 - amici.ode_export - DEBUG - Finished simplifying dwdx                  ++++ (3.34E-03s)
2023-02-16 10:46:20.795 - amici.ode_export - DEBUG - Finished computing dwdx                     +++ (1.58E-02s)
2023-02-16 10:46:20.798 - amici.ode_export - DEBUG - Finished writing dwdx.cpp                  

2023-02-16 10:46:21.076 - amici.ode_export - DEBUG - Finished writing x0.cpp                      ++ (4.98E-03s)
2023-02-16 10:46:21.081 - amici.ode_export - DEBUG - Finished simplifying x0_fixedParameters    ++++ (2.91E-04s)
2023-02-16 10:46:21.082 - amici.ode_export - DEBUG - Finished computing x0_fixedParameters       +++ (1.98E-03s)
2023-02-16 10:46:21.083 - amici.ode_export - DEBUG - Finished writing x0_fixedParameters.cpp      ++ (4.22E-03s)
2023-02-16 10:46:21.088 - amici.ode_export - DEBUG - Finished running smart_jacobian            ++++ (7.87E-04s)
2023-02-16 10:46:21.090 - amici.ode_export - DEBUG - Finished simplifying sx0                   ++++ (3.49E-05s)
2023-02-16 10:46:21.090 - amici.ode_export - DEBUG - Finished computing sx0                      +++ (4.25E-03s)
2023-02-16 10:46:21.091 - amici.ode_export - DEBUG - Finished writing sx0.cpp                     ++ (5.90E-03s)
2023-02-16 10:46:21.095 - amici.ode_export - DEBUG - Finished running smart_jacobian            

running AmiciInstall
hdf5.h found in /opt/homebrew/Cellar/hdf5/1.12.2_2/include
libhdf5.a found in /opt/homebrew/Cellar/hdf5/1.12.2_2/lib
running build_ext
Changed extra_compile_args for unix to ['-std=c++14']
Building model extension in /Users/fabian/Documents/projects/AMICI/documentation/amici_models/Boehm_JProteomeRes2014
building 'Boehm_JProteomeRes2014._Boehm_JProteomeRes2014' extension
Testing SWIG executable swig4.0... FAILED.
Testing SWIG executable swig3.0... FAILED.
Testing SWIG executable swig... SUCCEEDED.
swigging swig/Boehm_JProteomeRes2014.i to swig/Boehm_JProteomeRes2014_wrap.cpp
swig -python -c++ -modern -outdir Boehm_JProteomeRes2014 -I/Users/fabian/Documents/projects/AMICI/python/sdist/amici/swig -I/Users/fabian/Documents/projects/AMICI/python/sdist/amici/include -o swig/Boehm_JProteomeRes2014_wrap.cpp swig/Boehm_JProteomeRes2014.i
Deprecated command line option: -modern. Ignored, this option is now always on.
creating build
creating build/temp.macosx-13-arm64-cpytho

2023-02-16 10:46:33.666 - amici.petab_import - INFO - Finished Importing PEtab model                (1.36E+01s)
2023-02-16 10:46:33.676 - amici.petab_import - INFO - Successfully loaded model Boehm_JProteomeRes2014 from /Users/fabian/Documents/projects/AMICI/documentation/amici_models/Boehm_JProteomeRes2014.


Now everything is ready to start with the JAX implementation. As a first step, we define a JAX function that runs an AMICI simulation.

In [16]:
from amici.petab_objective import simulate_petab
import amici
import numpy as np

# gradients require first-order sensitivities
amici_solver = amici_model.getSolver()
amici_solver.setSensitivityOrder(amici.SensitivityOrder.first)


def amici_hcb_base(parameters: jnp.ndarray):
    # runs on the host: simulates all conditions of the PEtab problem for
    # the given values of the estimated parameters (on their PEtab scales)
    return simulate_petab(
        petab_problem,
        amici_model,
        problem_parameters=dict(zip(petab_problem.x_free_ids, parameters)),
        scaled_parameters=True,
        solver=amici_solver,
    )


def amici_hcb_llh(parameters: jnp.ndarray):
    # log-likelihood
    return amici_hcb_base(parameters)['llh']


def amici_hcb_sllh(parameters: jnp.ndarray):
    # gradient of the log-likelihood, ordered like petab_problem.x_free_ids
    sllh = amici_hcb_base(parameters)['sllh']
    return np.asarray(tuple(
        sllh[par_id] for par_id in petab_problem.x_free_ids
    ))


@custom_jvp
def jax_objective(parameters: jnp.ndarray):
    # evaluates the log-likelihood on the host via host_callback
    return hcb.call(
        amici_hcb_llh,
        parameters,
        result_shape=jax.ShapeDtypeStruct((), np.float64),
    )


@jax_objective.defjvp
def jax_objective_jvp(primals: tuple, tangents: tuple):
    (parameters,) = primals
    (x_dot,) = tangents
    llh = jax_objective(parameters)
    # a second host call obtains the gradient from AMICI's sensitivities
    sllh = hcb.call(
        amici_hcb_sllh,
        parameters,
        result_shape=jax.ShapeDtypeStruct(
            (len(petab_problem.x_free_ids),), np.float64
        ),
    )
    # for a scalar objective, the JVP is the gradient dotted with the tangent
    return llh, sllh.dot(x_dot)
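The cell above depends on AMICI and the Boehm model. The following self-contained sketch shows the same pattern with a hypothetical NumPy function `mock_simulate` in place of `simulate_petab`, and with `jax.pure_callback` as a swapped-in alternative to `hcb.call` (assumed here because it is available in current JAX versions). Unlike the cell above, a single host call returns both the value and the gradient, so the JVP rule triggers one simulation instead of two.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import custom_jvp


def mock_simulate(theta):
    # Hypothetical stand-in for a simulation with first-order sensitivities:
    # returns the log-likelihood and its gradient as NumPy arrays
    theta = np.asarray(theta)
    llh = -0.5 * np.sum(theta ** 2)
    sllh = -theta
    return np.asarray(llh, dtype=theta.dtype), sllh.astype(theta.dtype)


def _callback(theta):
    # One host call returns (value, gradient); declared dtypes follow the input
    return jax.pure_callback(
        mock_simulate,
        (
            jax.ShapeDtypeStruct((), theta.dtype),
            jax.ShapeDtypeStruct(theta.shape, theta.dtype),
        ),
        theta,
    )


@custom_jvp
def mock_objective(theta):
    llh, _ = _callback(theta)
    return llh


@mock_objective.defjvp
def mock_objective_jvp(primals, tangents):
    (theta,) = primals
    (theta_dot,) = tangents
    llh, sllh = _callback(theta)
    return llh, sllh.dot(theta_dot)


value, gradient = jax.jit(jax.value_and_grad(mock_objective))(jnp.array([1.0, 2.0]))
print(value, gradient)
```

Here the log-likelihood is -0.5 * (1 + 4) = -2.5 and the gradient is the negated input, [-1, -2].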

In [17]:
petab_problem.x_nominal_free

[0.026982514033029,
 1.00067973851508e-05,
 0.006170228086381,
 0.0163679184468,
 97749.3794024716,
 15766.5070195731,
 3.85261197844677,
 6.59147818673419,
 3.15271275648527]

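As a first check, we evaluate the objective at the nominal parameter values, given on their respective scales, and compare the result with a direct call to `simulate_petab`:

In [18]: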
jax_objective(jnp.asarray(petab_problem.x_nominal_free_scaled))

Array(-138.222, dtype=float32)

In [19]:
simulate_petab(petab_problem, amici_model)['llh']

-138.2219962156317

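Now we can implement the custom parameter transformation. `jax_objective` expects the estimated parameters on the scales given in the parameter table, so the function below takes parameters on linear scale, transforms each one according to its `parameterScale` entry, and passes the result to `jax_objective`. Wrapping it with `value_and_grad` returns the objective value together with its gradient with respect to the linear-scale parameters, and `jax.jit` compiles the combined computation.

In [20]: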
parameter_scales = petab_problem.parameter_df.loc[petab_problem.x_free_ids, petab.PARAMETER_SCALE].values

@jax.jit
@value_and_grad
def jax_objective_with_parameter_transform(parameters: jnp.array):
    par_scaled = jnp.asarray(tuple(
        value if scale == petab.LIN
        else jnp.log(value) if scale == petab.LOG
        else jnp.log10(value)
        for value, scale in zip(parameters, parameter_scales)
    ))
    return jax_objective(par_scaled)
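The transformation above builds the scaled vector element by element in a Python loop, which is unrolled during tracing. A possible vectorized alternative is sketched below, assuming the scales are available as a NumPy array of strings (here a hypothetical `scales` array rather than the Boehm `parameter_scales`). It selects the transformation with `jnp.where`; the inner `jnp.where` feeds `1.0` to the logarithms for linear-scale entries, so negative values, which `lin` parameters such as `ratio` may take, do not produce NaN gradients through the unselected branch.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical per-parameter scales, e.g. taken from the parameterScale column
scales = np.array(["log10", "log10", "lin", "log"])

# Masks are computed once with NumPy, outside of JAX tracing
is_log10 = jnp.asarray(scales == "log10")
is_log = jnp.asarray(scales == "log")


@jax.jit
def to_scaled(p):
    # Feed 1.0 to the logarithms for entries that stay on linear scale, so
    # that negative 'lin' values do not yield NaN gradients via log()
    safe = jnp.where(is_log10 | is_log, p, 1.0)
    return jnp.where(
        is_log10,
        jnp.log10(safe),
        jnp.where(is_log, jnp.log(safe), p),
    )


p = jnp.array([100.0, 0.1, -2.0, 1.0])
print(to_scaled(p))
print(jax.grad(lambda q: jnp.sum(to_scaled(q)))(p))
```

With `parameter_scales` in place of `scales`, a function like `to_scaled` could replace the tuple comprehension; whether this pays off depends on the number of parameters.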
    

In [21]:
jax_objective_with_parameter_transform(petab_problem.x_nominal_free)

(Array(-138.222, dtype=float32),
 [Array(-0.36403936, dtype=float32),
  Array(-2401.0105, dtype=float32),
  Array(-0.41067627, dtype=float32),
  Array(-0.16390301, dtype=float32),
  Array(2.0064123e-10, dtype=float32),
  Array(-2.089803e-07, dtype=float32),
  Array(-0.00122289, dtype=float32),
  Array(-0.00158087, dtype=float32),
  Array(-0.00264136, dtype=float32)])
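The gradient above is taken with respect to the parameters on linear scale. For a log10-scaled parameter p, the chain rule multiplies the derivative with respect to log10(p) by 1/(p ln 10); for `k_exp_hetero`, with a nominal value of about 1.0e-05, this factor is roughly 4.3e4, which contributes to its comparatively large gradient entry. The sketch below checks this relation numerically on a hypothetical objective `objective_scaled`.

```python
import jax
import jax.numpy as jnp
import numpy as np


def objective_scaled(theta):
    # Hypothetical smooth objective defined on log10-scaled parameters
    return -jnp.sum((theta - 1.0) ** 2)


def objective_linear(p):
    # The same objective expressed in linear-scale parameters
    return objective_scaled(jnp.log10(p))


p = jnp.array([0.5, 20.0])
grad_scaled = jax.grad(objective_scaled)(jnp.log10(p))
grad_linear = jax.grad(objective_linear)(p)

# Chain rule: d/dp = d/dtheta * 1 / (p * ln 10)
expected = grad_scaled / (p * jnp.log(10.0))
print(grad_linear, expected)
```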

In [22]:
simulate_petab(petab_problem, amici_model, solver=amici_solver)

{'llh': -138.22199662450979,
 'sllh': {'Epo_degradation_BaF3': -0.022031291993031152,
  'k_exp_hetero': -0.05532275416950131,
  'k_exp_homo': -0.005787886630252937,
  'k_imp_hetero': -0.005400220655104336,
  'k_imp_homo': 4.515958094583564e-05,
  'k_phos': -0.007914030504748332,
  'sd_pSTAT5A_rel': -0.010783057977445385,
  'sd_pSTAT5B_rel': -0.02403937268176315,
  'sd_rSTAT5A_rel': -0.019192198115317298},
 'rdatas': [<ReturnDataView(<amici.amici.ReturnData; proxy of <Swig Object of type 'amici::ReturnData *' at 0x2887b82a0> >)>]}

In [None]:
amici.SensitivityOrder(amici_solver.getSensitivityOrder())
