In [1]:
import numpy as np
import pandas as pd
import seaborn as sns

import petab
import pypesto
import pypesto.petab


In [2]:
# rng seed
np.random.seed(500)

Import petab files and define the problem

In [3]:
petab_yaml = 'petab_files/age_of_infection_model.yaml'
petab.validate(petab_yaml)
petab_problem = petab.Problem.from_yaml(petab_yaml)

problem = pypesto.petab.PetabImporter(
        petab_problem,
        hierarchical=False,
        model_name=f"Age_of_Infection_Model",
    ).create_problem(force_compile=True, verbose=False)
problem.objective.amici_model.setAllStatesNonNegative()

Compiling amici model to folder /Users/yuhongliu/Documents/OV/model/age_of_infection_model/amici_models/0.27.0/Age_of_Infection_Model.


# Gradient Check

In [4]:
# obtain random startpoints
startpoints = problem.get_startpoints(n_starts=2)

In [5]:
problem.objective.check_grad(
    x = startpoints[0],
    eps = 1e-5,  # default
    verbosity = 0,
)

Unnamed: 0,grad,fd_f,fd_b,fd_c,fd_err,abs_err,rel_err
rho,8311.124286,8345.590983,8276.392723,8310.991853,69.19826,0.132433,1.593473e-05
kappa,5964.70415,5987.552814,5947.242734,5967.397774,40.31008,2.693624,0.0004513901
psi,-207.244997,-219.736412,-196.191357,-207.963885,23.545055,0.718888,0.003456792
phi,-574.908675,-587.731662,-561.673381,-574.702522,26.058281,0.206153,0.000358713
beta,-9.4e-05,-13.393312,13.397471,0.002079,26.790783,0.002173,1.039971
alpha,-2.1732,-14.198518,8.79826,-2.700129,22.996779,0.526929,0.1951504
delta,27.638521,15.992469,41.587793,28.790131,25.595324,1.15161,0.04000014
sigma_a,-6748.10079,-6761.428395,-6734.773305,-6748.10085,26.655089,6e-05,8.855918e-09
sigma_b,-3048.476873,-3061.892249,-3035.060152,-3048.4762,26.832097,0.000673,2.207325e-07


In [6]:
problem.objective.check_grad(
    x = problem.get_reduced_vector(startpoints[0]),
    eps = 1e-5,  # default
    verbosity = 0,
)

Unnamed: 0,grad,fd_f,fd_b,fd_c,fd_err,abs_err,rel_err
rho,8311.124286,8345.590983,8276.392723,8310.991853,69.19826,0.132433,1.593473e-05
kappa,5964.70415,5987.552814,5947.242734,5967.397774,40.31008,2.693624,0.0004513901
psi,-207.244997,-219.736412,-196.191357,-207.963885,23.545055,0.718888,0.003456792
phi,-574.908675,-587.731662,-561.673381,-574.702522,26.058281,0.206153,0.000358713
beta,-9.4e-05,-13.393312,13.397471,0.002079,26.790783,0.002173,1.039971
alpha,-2.1732,-14.198518,8.79826,-2.700129,22.996779,0.526929,0.1951504
delta,27.638521,15.992469,41.587793,28.790131,25.595324,1.15161,0.04000014
sigma_a,-6748.10079,-6761.428395,-6734.773305,-6748.10085,26.655089,6e-05,8.855918e-09
sigma_b,-3048.476873,-3061.892249,-3035.060152,-3048.4762,26.832097,0.000673,2.207325e-07


In [7]:
gc = problem.objective.check_grad_multi_eps(
    x=startpoints[0],
    verbosity=0,
    label='rel_err',  # default
)
gc

Unnamed: 0,grad,fd_f,fd_b,fd_c,fd_err,abs_err,rel_err,eps
rho,8311.124286,8309.378309,8312.889634,8311.133972,3.511325,0.009685193,1.165327e-06,0.001
kappa,5964.70415,5967.86859,5962.088125,5964.978358,5.780466,0.2742076,4.596957e-05,0.001
psi,-207.244997,-207.507447,-206.980502,-207.243974,0.526945,0.001022514,4.933887e-06,0.001
phi,-574.908675,-576.069756,-573.766012,-574.917884,2.303744,0.009209621,1.601905e-05,0.001
beta,-9.4e-05,-0.001445,0.001256,-9.4e-05,0.002701,8.93082e-07,8.939259e-06,0.1
alpha,-2.1732,-2.302367,-2.021702,-2.162034,0.280665,0.01116574,0.005166852,0.001
delta,27.638521,27.553397,27.750685,27.652041,0.197288,0.01351967,0.0004889036,0.001
sigma_a,-6748.10079,-8087.393644,-5408.808033,-6748.100839,2678.585611,4.846016e-05,7.181303e-09,1e-07
sigma_b,-3048.476873,-136977.822876,130880.869619,-3048.476628,267858.692496,0.0002447758,8.029448e-08,1e-09


In [8]:
def highlight_value_above_threshold(x, threshold=10):
    return ['color: darkorange' if xi > threshold else None for xi in x]

gc.style.apply(
    highlight_value_above_threshold, subset=["fd_err"],
).background_gradient(
    cmap=sns.light_palette("purple", as_cmap=True), subset=["abs_err"],
).background_gradient(
    cmap=sns.light_palette("red", as_cmap=True), subset=["rel_err"],
)

Unnamed: 0,grad,fd_f,fd_b,fd_c,fd_err,abs_err,rel_err,eps
rho,8311.124286,8309.378309,8312.889634,8311.133972,3.511325,0.009685,1e-06,0.001
kappa,5964.70415,5967.86859,5962.088125,5964.978358,5.780466,0.274208,4.6e-05,0.001
psi,-207.244997,-207.507447,-206.980502,-207.243974,0.526945,0.001023,5e-06,0.001
phi,-574.908675,-576.069756,-573.766012,-574.917884,2.303744,0.00921,1.6e-05,0.001
beta,-9.4e-05,-0.001445,0.001256,-9.4e-05,0.002701,1e-06,9e-06,0.1
alpha,-2.1732,-2.302367,-2.021702,-2.162034,0.280665,0.011166,0.005167,0.001
delta,27.638521,27.553397,27.750685,27.652041,0.197288,0.01352,0.000489,0.001
sigma_a,-6748.10079,-8087.393644,-5408.808033,-6748.100839,2678.585611,4.8e-05,0.0,0.0
sigma_b,-3048.476873,-136977.822876,130880.869619,-3048.476628,267858.692496,0.000245,0.0,0.0
