The purpose of this guide is to showcase how AMICI can be combined with differentiable programming in JAX. We will do so by reimplementing the parameter transformations available in AMICI in JAX and comparing them to the native implementation.

In [1]:
import jax
import jax.numpy as jnp

# Preparation

To get started, we will import a model using [PEtab](https://petab.readthedocs.io). To this end, we will use the [benchmark collection](https://github.com/Benchmarking-Initiative/Benchmark-Models-PEtab), which features a variety of different models. For more details about PEtab import, see the respective [PEtab notebook](https://amici.readthedocs.io/en/latest/petab.html).

In [2]:
!git clone --depth 1 https://github.com/Benchmarking-Initiative/Benchmark-Models-PEtab.git tmp/benchmark-models || (cd tmp/benchmark-models && git pull)
from pathlib import Path

folder_base = Path(".") / "tmp" / "benchmark-models" / "Benchmark-Models"

Cloning into 'tmp/benchmark-models'...
remote: Enumerating objects: 336, done.
remote: Counting objects: 100% (336/336), done.
remote: Compressing objects: 100% (285/285), done.
remote: Total 336 (delta 88), reused 216 (delta 39), pack-reused 0
Receiving objects: 100% (336/336), 2.11 MiB | 7.48 MiB/s, done.
Resolving deltas: 100% (88/88), done.


From the benchmark collection, we now import the Boehm model.

In [3]:
import petab

model_name = "Boehm_JProteomeRes2014"
yaml_file = folder_base / model_name / (model_name + ".yaml")
petab_problem = petab.Problem.from_yaml(yaml_file)

The PEtab problem includes information about parameter scaling in its parameter table. For the Boehm model, all estimated parameters (`petab.ESTIMATE` column equal to `1`) use `petab.LOG10` as parameter scale.

In [4]:
petab_problem.parameter_df

| parameterId | parameterName | parameterScale | lowerBound | upperBound | nominalValue | estimate |
|---|---|---|---|---|---|---|
| Epo_degradation_BaF3 | EPO_{degradation,BaF3} | log10 | 1e-05 | 100000 | 0.026983 | 1 |
| k_exp_hetero | k_{exp,hetero} | log10 | 1e-05 | 100000 | 1e-05 | 1 |
| k_exp_homo | k_{exp,homo} | log10 | 1e-05 | 100000 | 0.00617 | 1 |
| k_imp_hetero | k_{imp,hetero} | log10 | 1e-05 | 100000 | 0.016368 | 1 |
| k_imp_homo | k_{imp,homo} | log10 | 1e-05 | 100000 | 97749.379402 | 1 |
| k_phos | k_{phos} | log10 | 1e-05 | 100000 | 15766.50702 | 1 |
| ratio | ratio | lin | -5.0 | 5 | 0.693 | 0 |
| sd_pSTAT5A_rel | \sigma_{pSTAT5A,rel} | log10 | 1e-05 | 100000 | 3.852612 | 1 |
| sd_pSTAT5B_rel | \sigma_{pSTAT5B,rel} | log10 | 1e-05 | 100000 | 6.591478 | 1 |
| sd_rSTAT5A_rel | \sigma_{rSTAT5A,rel} | log10 | 1e-05 | 100000 | 3.152713 | 1 |
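To make the `parameterScale` column concrete, here is a minimal NumPy sketch of how values map to and from their scale, assuming the PEtab convention that a `log10`-scaled parameter is represented as log10(θ) and a `log`-scaled one as ln(θ). The helpers `to_scale` and `from_scale` are hypothetical and not part of the petab or AMICI API.

```python
import numpy as np

# Hypothetical helpers (not part of petab/AMICI): map linear-scale values
# to their PEtab parameter scale and back.
_FORWARD = {"lin": lambda x: x, "log": np.log, "log10": np.log10}
_INVERSE = {"lin": lambda x: x, "log": np.exp, "log10": lambda x: np.power(10.0, x)}


def to_scale(values, scales):
    """Transform linear-scale values to their respective parameter scales."""
    return np.array([_FORWARD[s](v) for v, s in zip(values, scales)])


def from_scale(values, scales):
    """Transform scaled values back to linear scale."""
    return np.array([_INVERSE[s](v) for v, s in zip(values, scales)])


# Nominal values of Epo_degradation_BaF3, k_imp_homo and ratio (parameter table)
nominal = np.array([0.026983, 97749.379402, 0.693])
scales = ["log10", "log10", "lin"]

scaled = to_scale(nominal, scales)
print(scaled)  # approximately [-1.569, 4.990, 0.693]
```

A dictionary lookup per scale keeps the transformation in one place; the JAX implementation later in this notebook expresses the same mapping with a conditional expression instead.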


We now import the PEtab problem using [`amici.petab_import`](https://amici.readthedocs.io/en/latest/generated/amici.petab_import.html#amici.petab_import.import_petab_problem).

In [5]:
from amici.petab.petab_import import import_petab_problem

amici_model = import_petab_problem(petab_problem, compile_=True)

2023-02-16 12:37:18.049 - amici.petab_import - INFO - Importing model ...
2023-02-16 12:37:18.050 - amici.petab_import - INFO - Validating PEtab problem ...
2023-02-16 12:37:18.343 - amici.petab_import - INFO - Model name is 'Boehm_JProteomeRes2014'.
Writing model code to '/Users/fabian/Documents/projects/AMICI/documentation/amici_models/Boehm_JProteomeRes2014'.
2023-02-16 12:37:18.344 - amici.petab_import - INFO - Species: 8
2023-02-16 12:37:18.344 - amici.petab_import - INFO - Global parameters: 9
2023-02-16 12:37:18.344 - amici.petab_import - INFO - Reactions: 9
2023-02-16 12:37:18.353 - amici.petab_import - INFO - Observables: 3
2023-02-16 12:37:18.353 - amici.petab_import - INFO - Sigmas: 3
2023-02-16 12:37:18.357 - amici.petab_import - DEBUG - Adding output parameters to model: ['noiseParameter1_pSTAT5A_rel', 'noiseParameter1_pSTAT5B_rel', 'noiseParameter1_rSTAT5A_rel']
2023-02-16 12:37:18.357 - amici.petab_import - DEBUG - Adding initial assignments for []
2023-02-16 12:37:18.36

2023-02-16 12:37:18.707 - amici.ode_export - DEBUG - Finished computing w                       ++++ (8.93E-03s)
2023-02-16 12:37:18.719 - amici.ode_export - DEBUG - Finished running smart_jacobian            ++++ (9.60E-03s)
2023-02-16 12:37:18.726 - amici.ode_export - DEBUG - Finished simplifying dwdp                  ++++ (4.59E-03s)
2023-02-16 12:37:18.726 - amici.ode_export - DEBUG - Finished computing dwdp                     +++ (2.93E-02s)
2023-02-16 12:37:18.730 - amici.ode_export - DEBUG - Finished writing dwdp.cpp                    ++ (3.43E-02s)
2023-02-16 12:37:18.743 - amici.ode_export - DEBUG - Finished running smart_jacobian            ++++ (8.10E-03s)
2023-02-16 12:37:18.749 - amici.ode_export - DEBUG - Finished simplifying dwdx                  ++++ (3.24E-03s)
2023-02-16 12:37:18.749 - amici.ode_export - DEBUG - Finished computing dwdx                     +++ (1.54E-02s)
2023-02-16 12:37:18.753 - amici.ode_export - DEBUG - Finished writing dwdx.cpp                  

2023-02-16 12:37:19.040 - amici.ode_export - DEBUG - Finished writing x0.cpp                      ++ (5.34E-03s)
2023-02-16 12:37:19.046 - amici.ode_export - DEBUG - Finished simplifying x0_fixedParameters    ++++ (3.45E-04s)
2023-02-16 12:37:19.046 - amici.ode_export - DEBUG - Finished computing x0_fixedParameters       +++ (2.33E-03s)
2023-02-16 12:37:19.047 - amici.ode_export - DEBUG - Finished writing x0_fixedParameters.cpp      ++ (4.79E-03s)
2023-02-16 12:37:19.053 - amici.ode_export - DEBUG - Finished running smart_jacobian            ++++ (9.02E-04s)
2023-02-16 12:37:19.055 - amici.ode_export - DEBUG - Finished simplifying sx0                   ++++ (3.90E-05s)
2023-02-16 12:37:19.055 - amici.ode_export - DEBUG - Finished computing sx0                      +++ (4.79E-03s)
2023-02-16 12:37:19.056 - amici.ode_export - DEBUG - Finished writing sx0.cpp                     ++ (6.54E-03s)
2023-02-16 12:37:19.061 - amici.ode_export - DEBUG - Finished running smart_jacobian            

running AmiciInstall
hdf5.h found in /opt/homebrew/Cellar/hdf5/1.12.2_2/include
libhdf5.a found in /opt/homebrew/Cellar/hdf5/1.12.2_2/lib
running build_ext
Changed extra_compile_args for unix to ['-std=c++14']
Building model extension in /Users/fabian/Documents/projects/AMICI/documentation/amici_models/Boehm_JProteomeRes2014
building 'Boehm_JProteomeRes2014._Boehm_JProteomeRes2014' extension
Testing SWIG executable swig4.0... FAILED.
Testing SWIG executable swig3.0... FAILED.
Testing SWIG executable swig... SUCCEEDED.
swigging swig/Boehm_JProteomeRes2014.i to swig/Boehm_JProteomeRes2014_wrap.cpp
swig -python -c++ -modern -outdir Boehm_JProteomeRes2014 -I/Users/fabian/Documents/projects/AMICI/python/sdist/amici/swig -I/Users/fabian/Documents/projects/AMICI/python/sdist/amici/include -o swig/Boehm_JProteomeRes2014_wrap.cpp swig/Boehm_JProteomeRes2014.i
Deprecated command line option: -modern. Ignored, this option is now always on.
creating build
creating build/temp.macosx-13-arm64-cpytho

2023-02-16 12:37:31.673 - amici.petab_import - INFO - Finished Importing PEtab model                (1.36E+01s)
2023-02-16 12:37:31.684 - amici.petab_import - INFO - Successfully loaded model Boehm_JProteomeRes2014 from /Users/fabian/Documents/projects/AMICI/documentation/amici_models/Boehm_JProteomeRes2014.


# JAX implementation

For full JAX support, we would have to implement a new [primitive](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html), which would require considerable engineering effort and, in the end, would add little benefit, since AMICI cannot run on GPUs. Instead, we interface AMICI through the experimental JAX module [`host_callback`](https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html), which lets JAX call regular Python code on the host. Note that newer JAX releases deprecate `host_callback` in favor of JAX's external callbacks such as `jax.pure_callback`.

To do so, we define a base function that takes only a single argument (the parameters) and simulates the PEtab problem via [`simulate_petab`](https://amici.readthedocs.io/en/latest/generated/amici.petab_objective.html#amici.petab_objective.simulate_petab). To enable gradient computation later on, we create a solver object, set its sensitivity order to first order, and pass it to `simulate_petab`. Moreover, `simulate_petab` expects a dictionary of parameters, so we build one from the free parameter IDs of the PEtab problem. As we want to implement the parameter transformation in JAX, we pass `scaled_parameters=True`, so that `simulate_petab` treats the provided values as already being on the parameter scale and does not apply the transformation itself.

In [6]:
from amici.petab.simulations import simulate_petab
import amici

amici_solver = amici_model.getSolver()
amici_solver.setSensitivityOrder(amici.SensitivityOrder.first)


def amici_hcb_base(parameters: jnp.array):
    return simulate_petab(
        petab_problem,
        amici_model,
        problem_parameters=dict(zip(petab_problem.x_free_ids, parameters)),
        scaled_parameters=True,
        solver=amici_solver,
    )

Now we can use this base function to create two separate functions that compute the log-likelihood (`llh`) and its gradient (`sllh`), respectively. Note that, since both use the same base function, the log-likelihood computation also runs with sensitivities, which is unnecessary and adds some overhead. We do this only for convenience; in applications where efficiency is important, the log-likelihood should be computed without sensitivities.

In [7]:
def amici_hcb_llh(parameters: jnp.array):
    return amici_hcb_base(parameters)["llh"]


def amici_hcb_sllh(parameters: jnp.array):
    sllh = amici_hcb_base(parameters)["sllh"]
    return jnp.asarray(
        tuple(sllh[par_id] for par_id in petab_problem.x_free_ids)
    )
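One way to avoid simulating twice at the same parameters is to cache the most recent simulation result and let both wrappers read from it. The sketch below uses a hypothetical stand-in `toy_simulate` (a Gaussian log-likelihood, not an AMICI model) that returns a dictionary with an `llh` entry and a per-parameter `sllh` dictionary, mirroring how the results of `simulate_petab` are accessed above. The parameter IDs and data are likewise hypothetical; in practice, `toy_simulate` would be replaced by the call to `simulate_petab`.

```python
import functools

import numpy as np

PARAMETER_IDS = ("a", "b")  # hypothetical free parameter IDs
DATA = np.array([1.0, 2.0])  # hypothetical measurements
n_simulations = 0  # counts how often the expensive "simulation" runs


def toy_simulate(problem_parameters: dict) -> dict:
    """Hypothetical stand-in for simulate_petab: Gaussian log-likelihood
    with unit variance and its gradient, keyed by parameter ID."""
    global n_simulations
    n_simulations += 1
    x = np.array([problem_parameters[pid] for pid in PARAMETER_IDS])
    residuals = x - DATA
    llh = -0.5 * np.sum(residuals**2)
    sllh = dict(zip(PARAMETER_IDS, -residuals))
    return {"llh": llh, "sllh": sllh}


@functools.lru_cache(maxsize=1)
def cached_simulate(parameters: tuple) -> dict:
    # lru_cache requires hashable arguments, hence a tuple of floats
    return toy_simulate(dict(zip(PARAMETER_IDS, parameters)))


def llh(parameters) -> float:
    return float(cached_simulate(tuple(map(float, parameters)))["llh"])


def sllh(parameters) -> np.ndarray:
    result = cached_simulate(tuple(map(float, parameters)))
    return np.array([result["sllh"][pid] for pid in PARAMETER_IDS])


p = np.array([0.5, 2.5])
value, gradient = llh(p), sllh(p)
print(value, gradient, n_simulations)  # one simulation serves both calls
```

Caching on exact parameter values only helps when value and gradient are requested at identical parameters, which is the case in a JVP rule; `maxsize=1` keeps just the latest result so memory use stays bounded.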

Now we can finally define the JAX function that runs the AMICI simulation via the host callback. We add a `custom_jvp` decorator so that we can define a custom Jacobian-vector product (JVP) function in the next step. More details about custom JVP functions can be found in the [JAX documentation](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html).

In [8]:
import jax.experimental.host_callback as hcb
from jax import custom_jvp

import numpy as np


@custom_jvp
def jax_objective(parameters: jnp.array):
    return hcb.call(
        amici_hcb_llh,
        parameters,
        result_shape=jax.ShapeDtypeStruct((), np.float64),
    )

Now we define the function that implements the Jacobian-vector product. It simply returns the objective function value (computed using the previously defined `jax_objective`) together with the inner product of the gradient (computed via a host callback to the previously defined `amici_hcb_sllh`) and the tangent vector. Note that this implementation performs two simulation runs, one for the function value and one for the gradient, which is inefficient and could be avoided by caching simulation results.

In [9]:
@jax_objective.defjvp
def jax_objective_jvp(primals: jnp.array, tangents: jnp.array):
    (parameters,) = primals
    (x_dot,) = tangents
    llh = jax_objective(parameters)
    sllh = hcb.call(
        amici_hcb_sllh,
        parameters,
        result_shape=jax.ShapeDtypeStruct(
            (petab_problem.parameter_df.estimate.sum(),), np.float64
        ),
    )
    return llh, sllh.dot(x_dot)

As a last step, we implement the parameter transformation in JAX. The function extracts the parameter scales from the PEtab problem, rescales the parameters in JAX, and then passes the scaled parameters to the objective function we defined previously. The `value_and_grad` decorator makes the resulting function return both the function value and its gradient as a tuple. Moreover, the `jax.jit` decorator ensures that the function is [just-in-time compiled](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) upon its first call.

In [10]:
from jax import value_and_grad

parameter_scales = petab_problem.parameter_df.loc[
    petab_problem.x_free_ids, petab.PARAMETER_SCALE
].values


@jax.jit
@value_and_grad
def jax_objective_with_parameter_transform(parameters: jnp.array):
    par_scaled = jnp.asarray(
        tuple(
            value
            if scale == petab.LIN
            else jnp.log(value)
            if scale == petab.LOG
            else jnp.log10(value)
            for value, scale in zip(parameters, parameter_scales)
        )
    )
    return jax_objective(par_scaled)
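For reference, the gradient that `value_and_grad` propagates through this transformation follows from the chain rule. Writing $\theta_i$ for a parameter on linear scale, $p_i$ for its scaled value that is passed to AMICI, and $\tilde\ell(p)$ for the log-likelihood as a function of the scaled parameters, each $p_i$ depends only on $\theta_i$, so:

```latex
\ell(\theta) = \tilde\ell\big(p(\theta)\big), \qquad
\frac{\partial \ell}{\partial \theta_i}
  = \frac{\partial \tilde\ell}{\partial p_i}\,\frac{\mathrm{d} p_i}{\mathrm{d} \theta_i},
\qquad
\frac{\mathrm{d} p_i}{\mathrm{d} \theta_i} =
\begin{cases}
1 & \text{lin: } p_i = \theta_i, \\[2pt]
\dfrac{1}{\theta_i} & \text{log: } p_i = \ln \theta_i, \\[4pt]
\dfrac{1}{\theta_i \ln 10} & \text{log10: } p_i = \log_{10} \theta_i.
\end{cases}
```

For the Boehm model, all free parameters are log10-scaled, so each gradient entry with respect to $\theta_i$ is the corresponding entry with respect to $p_i$ divided by $\theta_i \ln 10$.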

# Testing

We can now run the function to compute the log-likelihood and its gradient at the nominal parameter values.

In [11]:
llh_jax, sllh_jax = jax_objective_with_parameter_transform(
    petab_problem.x_nominal_free
)

As a sanity check, we compare the computed values to those obtained with AMICI's native parameter transformation.

In [12]:
r = simulate_petab(petab_problem, amici_model, solver=amici_solver)
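# TODO remove me as soon as sllh in simulate_petab is fixed
# Convert the gradient w.r.t. the log10-scaled parameters into a gradient
# w.r.t. the linear-scale parameters via the chain rule,
# d llh / d theta = (d llh / d log10(theta)) / (ln(10) * theta).
# This assumes that all free parameters are log10-scaled (true for this model)
# and that r["sllh"] lists them in the same order as x_nominal_free.
sllh = {
    name: value / (np.log(10) * par_value)
    for (name, value), par_value in zip(
        r["sllh"].items(), petab_problem.x_nominal_free
    )
}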
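The conversion in the cell above applies the chain rule for log10-scaled parameters: a gradient with respect to p = log10(θ) is divided by ln(10)·θ to obtain the gradient with respect to θ. A small NumPy check against finite differences confirms this; it uses a hypothetical toy function `g` of the scaled parameters (not the AMICI objective) and two nominal values from the parameter table.

```python
import numpy as np


def g(p: np.ndarray) -> float:
    # Hypothetical toy objective as a function of log10-scaled parameters p
    return float(np.sum(np.sin(p)) + p[0] * p[1])


def grad_g(p: np.ndarray) -> np.ndarray:
    # Gradient of g with respect to the scaled parameters p
    return np.cos(p) + np.array([p[1], p[0]])


theta = np.array([0.026983, 0.00617])  # linear-scale nominal values
p = np.log10(theta)

# Chain rule, as in the conversion above: d/dtheta = d/dp / (ln(10) * theta)
grad_theta = grad_g(p) / (np.log(10) * theta)

# Central finite differences of theta -> g(log10(theta)), one coordinate at a time
h = 1e-8
fd = np.array([
    (g(np.log10(theta + h * e)) - g(np.log10(theta - h * e))) / (2 * h)
    for e in np.eye(len(theta))
])
print(grad_theta, fd)
```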

In [13]:
import pandas as pd

pd.Series(dict(amici=r["llh"], jax=float(llh_jax)))

amici   -138.221997
jax     -138.222000
dtype: float64

In [14]:
pd.DataFrame(
    index=sllh.keys(), data=dict(amici=sllh.values(), jax=np.asarray(sllh_jax))
)

| parameter | amici | jax |
|---|---|---|
| Epo_degradation_BaF3 | -0.3546026 | -0.3640394 |
| k_exp_hetero | -2401.005 | -2401.01 |
| k_exp_homo | -0.4073832 | -0.4106763 |
| k_imp_hetero | -0.1432855 | -0.163903 |
| k_imp_homo | 2.006412e-10 | 2.006412e-10 |
| k_phos | -2.17995e-07 | -2.089803e-07 |
| sd_pSTAT5A_rel | -0.001215545 | -0.001222887 |
| sd_pSTAT5B_rel | -0.001583889 | -0.00158087 |
| sd_rSTAT5A_rel | -0.002643776 | -0.002641361 |


We see considerable differences in the gradients. The primary reason is that, in its default configuration, JAX uses float32 precision both for the parameters that are passed to AMICI, which itself uses float64, and for the derivative of the parameter transformation.

Since the AMICI simulations, which run on the CPU, are the most expensive operation, using float64 instead of float32 in JAX comes at barely any cost. Therefore, we configure JAX to use float64 and rerun the simulations.

In [15]:
jax.config.update("jax_enable_x64", True)
llh_jax, sllh_jax = jax_objective_with_parameter_transform(
    petab_problem.x_nominal_free
)
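The float32 rounding discussed above can be illustrated with plain NumPy. The snippet below applies the log10 transformation to three nominal values from the parameter table in float32 and in float64, and prints the gap between adjacent representable numbers near the magnitude of the log-likelihood printed above. It only illustrates rounding and does not reproduce the JAX/AMICI pipeline.

```python
import numpy as np

# Nominal values of three estimated parameters (from the parameter table)
nominal = np.array([0.026983, 0.00617, 97749.379402])

# log10 transformation evaluated in float32 vs. float64
scaled32 = np.log10(nominal.astype(np.float32)).astype(np.float64)
scaled64 = np.log10(nominal)
print(np.abs(scaled32 - scaled64))  # at most a few times 1e-7

# Spacing of representable numbers near the magnitude of the log-likelihood
llh_magnitude = 138.221997
print(np.spacing(np.float32(llh_magnitude)))  # 2**-16, about 1.5e-05
print(np.spacing(np.float64(llh_magnitude)))  # 2**-45, about 2.8e-14
```

The float32 spacing of about 1.5e-05 near 138 is consistent with the float32 run printing `-138.222000` instead of `-138.221997`.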

We can now evaluate the results again and see that the differences between the pure AMICI and the AMICI/JAX implementations are much smaller.

In [16]:
pd.Series(dict(amici=r["llh"], jax=float(llh_jax)))

amici   -138.221997
jax     -138.221997
dtype: float64

In [17]:
pd.DataFrame(
    index=sllh.keys(), data=dict(amici=sllh.values(), jax=np.asarray(sllh_jax))
)

| parameter | amici | jax |
|---|---|---|
| Epo_degradation_BaF3 | -0.3546026 | -0.3546504 |
| k_exp_hetero | -2401.005 | -2401.005 |
| k_exp_homo | -0.4073832 | -0.4074248 |
| k_imp_hetero | -0.1432855 | -0.1433139 |
| k_imp_homo | 2.006412e-10 | 2.006412e-10 |
| k_phos | -2.17995e-07 | -2.179076e-07 |
| sd_pSTAT5A_rel | -0.001215545 | -0.001215596 |
| sd_pSTAT5B_rel | -0.001583889 | -0.001583805 |
| sd_rSTAT5A_rel | -0.002643776 | -0.002643703 |
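The improvement can be quantified directly from the two gradient tables. The snippet below copies the printed values (AMICI reference, JAX with float32, JAX with float64) and computes the largest relative deviation from the AMICI reference; it only restates the numbers shown above.

```python
import numpy as np

# Gradient entries as printed above, in the order of the table rows
amici = np.array([-0.3546026, -2401.005, -0.4073832, -0.1432855, 2.006412e-10,
                  -2.17995e-07, -0.001215545, -0.001583889, -0.002643776])
jax_f32 = np.array([-0.3640394, -2401.01, -0.4106763, -0.163903, 2.006412e-10,
                    -2.089803e-07, -0.001222887, -0.00158087, -0.002641361])
jax_f64 = np.array([-0.3546504, -2401.005, -0.4074248, -0.1433139, 2.006412e-10,
                    -2.179076e-07, -0.001215596, -0.001583805, -0.002643703])


def max_relative_difference(reference: np.ndarray, values: np.ndarray) -> float:
    # Largest |values - reference| / |reference| over all parameters
    return float(np.max(np.abs(values - reference) / np.abs(reference)))


print(max_relative_difference(amici, jax_f32))  # about 0.14 (k_imp_hetero)
print(max_relative_difference(amici, jax_f64))  # about 4e-4 (k_phos)
```

Switching to float64 reduces the largest relative deviation by more than two orders of magnitude.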
