# Hierarchical optimization

In this notebook we illustrate how to do hierarchical optimization in pyPESTO.

A frequent situation in parameter estimation for dynamical systems is that the objective function takes the form

$$ J(\theta, s, b, \sigma^2) = \sum_i \left[\log(2\pi\sigma_i^2) + \frac{(\bar y_i - (s_iy_i(\theta) + b_i))^2}{\sigma_i^2}\right] $$

with data $\bar y_i$, parameters $\eta = (\theta,s,b,\sigma^2)$, and ODE simulations $y(\theta)$. Here, we consider a Gaussian noise model, but other noise models (e.g. Laplace) are also possible. The key point is that the parameter vector $\eta$ can be split into "dynamic" parameters $\theta$, which are required for simulating the ODE, and "static" parameters $s,b,\sigma^2$, which are only required to scale the simulations and to formulate the objective function. Since simulating the ODE is usually the time-critical part, this separation can be exploited by formulating an outer optimization problem in which $\theta$ is optimized, and an inner optimization problem in which $s,b,\sigma^2$ are optimized conditioned on $\theta$. This approach has been shown to outperform the classic approach of jointly optimizing $\eta$.
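For the Gaussian noise model, the inner problem can be solved in closed form under a simplifying assumption. The derivation below is a sketch that assumes all $n$ data points of one observable share a single scaling $s$, offset $b$ and noise variance $\sigma^2$ (the objective above allows these to differ between data points; in practice they are typically shared within groups of measurements), and that the simulations $y_i(\theta)$ are not all equal. Because $\sigma^2$ appears as a common factor in both stationarity conditions $\partial J/\partial s = \partial J/\partial b = 0$, it drops out, and $s, b$ follow from ordinary least squares:

$$ \hat s(\theta) = \frac{\sum_i \left(y_i(\theta) - \langle y(\theta)\rangle\right)\left(\bar y_i - \langle \bar y\rangle\right)}{\sum_i \left(y_i(\theta) - \langle y(\theta)\rangle\right)^2}, \qquad \hat b(\theta) = \langle \bar y\rangle - \hat s(\theta)\,\langle y(\theta)\rangle, $$

with the means $\langle y(\theta)\rangle = \frac{1}{n}\sum_i y_i(\theta)$ and $\langle \bar y\rangle = \frac{1}{n}\sum_i \bar y_i$. Setting $\partial J/\partial \sigma^2 = n/\sigma^2 - \sum_i r_i^2/\sigma^4 = 0$ for the residuals $r_i = \bar y_i - (\hat s(\theta) y_i(\theta) + \hat b(\theta))$ gives

$$ \hat\sigma^2(\theta) = \frac{1}{n}\sum_i \left(\bar y_i - (\hat s(\theta)\, y_i(\theta) + \hat b(\theta))\right)^2, $$

and substituting all three back into $J$ leaves an objective that depends on $\theta$ only:

$$ J\left(\theta, \hat s(\theta), \hat b(\theta), \hat\sigma^2(\theta)\right) = n\left[\log\left(2\pi\hat\sigma^2(\theta)\right) + 1\right]. $$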

In pyPESTO, we have implemented the algorithms developed in [Loos et al.; Hierarchical optimization for the efficient parametrization of ODE models; Bioinformatics 2018](https://academic.oup.com/bioinformatics/article/34/24/4266/5053308) (covering Gaussian and Laplace noise models with gradients computed via forward sensitivity analysis) and [Schmiester et al.; Efficient parameterization of large-scale dynamic models based on relative measurements; Bioinformatics 2019](https://academic.oup.com/bioinformatics/article/36/2/594/5538985) (extending to offset parameters and adjoint sensitivity analysis).

In [None]:
import time

import amici
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import to_rgba

import pypesto
import pypesto.engine
import pypesto.optimize
import pypesto.startpoint
import pypesto.visualize
from pypesto.hierarchical.solver import (
    AnalyticalInnerSolver,
    NumericalInnerSolver,
)
from pypesto.optimize.options import OptimizeOptions
from pypesto.petab import PetabImporter

We consider a version of the [Boehm et al.; Journal of Proteome Research 2014] model, modified to include scalings $s$, offsets $b$, and noise parameters $\sigma^2$.

In [None]:
# get the PEtab problem
# requires installation of
from pypesto.testing.examples import (
    get_Boehm_JProteomeRes2014_hierarchical_petab,
)

petab_problem = get_Boehm_JProteomeRes2014_hierarchical_petab()

The PEtab observable table contains placeholders for scaling parameters $s$ (`observableParameter1_{pSTAT5A_rel,pSTAT5B_rel,rSTAT5A_rel}`), offsets $b$ (`observableParameter2_{pSTAT5A_rel,pSTAT5B_rel,rSTAT5A_rel}`), and noise parameters $\sigma^2$ (`noiseParameter1_{pSTAT5A_rel,pSTAT5B_rel,rSTAT5A_rel}`) that are overridden by the `observableParameters` and `noiseParameters` columns of the measurement table.

In [None]:
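from IPython.display import display
from pandas import option_context

# Show full cell contents so the observable formulas are not truncated
with option_context('display.max_colwidth', 400):
    display(petab_problem.observable_df)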
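The override mechanism follows the PEtab convention that the n-th semicolon-separated entry of a measurement's `observableParameters` (or `noiseParameters`) cell replaces the placeholder `observableParameter{n}_{observableId}` (or `noiseParameter{n}_{observableId}`). The sketch below resolves one such cell into a placeholder-to-value mapping; the function `map_placeholders` and the parameter IDs `scaling_pSTAT5A_rel`, `offset_pSTAT5A_rel` and `sd_pSTAT5A_rel` are hypothetical and chosen for illustration only, not taken from the actual problem files.

```python
import math


def map_placeholders(observable_id, prefix, overrides):
    """Return {placeholder_id: override} for one measurement-table cell.

    The n-th semicolon-separated entry replaces `{prefix}{n}_{observable_id}`.
    Hypothetical helper for illustration, not part of petab or pyPESTO.
    """
    # Empty cells (None, "" or NaN as read by pandas) override nothing
    if overrides is None:
        return {}
    if isinstance(overrides, float) and math.isnan(overrides):
        return {}
    if str(overrides).strip() == "":
        return {}
    values = [value.strip() for value in str(overrides).split(";")]
    return {
        f"{prefix}{n}_{observable_id}": value
        for n, value in enumerate(values, start=1)
    }


# Hypothetical measurement-table cells for the pSTAT5A_rel observable
obs_map = map_placeholders(
    "pSTAT5A_rel", "observableParameter", "scaling_pSTAT5A_rel;offset_pSTAT5A_rel"
)
noise_map = map_placeholders("pSTAT5A_rel", "noiseParameter", "sd_pSTAT5A_rel")
print(obs_map)
print(noise_map)
```

Since the mapping is positional, the order of entries in each cell must match the placeholder numbering used in the observable formulas.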

Parameters to be optimized in the inner problem are selected via the PEtab parameter table by setting a value in the non-standard column `parameterType` (`offset` for offset parameters, `scaling` for scaling parameters, and `sigma` for sigma parameters):

In [None]:
petab_problem.parameter_df
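To make the role of the `parameterType` column concrete, here is a small pandas sketch that splits a parameter table into outer (dynamic) parameters, whose `parameterType` is empty, and inner parameters grouped by type. The table is a hypothetical excerpt with illustrative IDs, not the actual Boehm parameter table, and pyPESTO's importer may perform this selection differently internally.

```python
import numpy as np
import pandas as pd

# Hypothetical excerpt of a parameter table with the non-standard
# `parameterType` column; IDs and values are illustrative only
parameter_df = pd.DataFrame(
    {
        "parameterId": [
            "Epo_degradation_BaF3",
            "scaling_pSTAT5A_rel",
            "offset_pSTAT5A_rel",
            "sd_pSTAT5A_rel",
        ],
        "estimate": [1, 1, 1, 1],
        "parameterType": [np.nan, "scaling", "offset", "sigma"],
    }
).set_index("parameterId")

# Rows without a parameterType belong to the outer (dynamic) problem
is_inner = parameter_df["parameterType"].notna()
outer_ids = parameter_df.index[~is_inner].tolist()

# Group the inner parameters by their type (offset / scaling / sigma)
inner_by_type = {
    ptype: group.index.tolist()
    for ptype, group in parameter_df[is_inner].groupby("parameterType")
}
print(outer_ids)
print(inner_by_type)
```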

In [None]:
# Create pypesto Objectives with and without hierarchical optimization
importer = PetabImporter(petab_problem, hierarchical=True)
objective = importer.create_objective()
problem = importer.create_problem(objective)
problem.objective.amici_solver.setSensitivityMethod(
    amici.SensitivityMethod_adjoint
)

importer2 = PetabImporter(petab_problem, hierarchical=False)
objective2 = importer2.create_objective()
problem2 = importer2.create_problem(objective2)
problem2.objective.amici_solver.setSensitivityMethod(
    amici.SensitivityMethod_adjoint
)

# Use the same starting points for both problems
n_starts = 3
startpoints = pypesto.startpoint.latin_hypercube(
    n_starts=n_starts, lb=problem2.lb_full, ub=problem2.ub_full
)
# The hierarchical problem only contains the outer parameters, so select the
# matching columns of the start points
outer_indices = [problem2.x_names.index(x_id) for x_id in problem.x_names]
problem.set_x_guesses(startpoints[:, outer_indices])
problem2.set_x_guesses(startpoints)

options = OptimizeOptions(allow_failed_starts=False)
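The cell above draws joint start points for all parameters and then reuses the outer-parameter columns for the hierarchical problem, so both runs start from the same dynamic parameters. The numpy sketch below mimics both steps with a simple Latin hypercube sampler (one point per stratum in every dimension) and hypothetical parameter names; it is an illustration of the technique, not pyPESTO's `latin_hypercube` implementation.

```python
import numpy as np


def latin_hypercube(n_starts, lb, ub, rng):
    """Draw n_starts points in [lb, ub] with one point per stratum and dimension."""
    lb = np.asarray(lb, dtype=float)
    ub = np.asarray(ub, dtype=float)
    n_dim = lb.size
    # Column j holds a random permutation of the stratum indices 0..n_starts-1
    strata = np.column_stack([rng.permutation(n_starts) for _ in range(n_dim)])
    # Place each point uniformly at random inside its stratum of [0, 1)
    unit = (strata + rng.uniform(size=(n_starts, n_dim))) / n_starts
    return lb + unit * (ub - lb)


# Hypothetical names: the non-hierarchical problem estimates all four
# parameters, the hierarchical outer problem only the dynamic ones
x_names_full = ["k1", "k2", "scaling_obs", "offset_obs"]
x_names_outer = ["k2", "k1"]

rng = np.random.default_rng(0)
startpoints = latin_hypercube(3, lb=[0, 0, 0, -1], ub=[1, 1, 10, 1], rng=rng)

# Select the outer columns, in the order the outer problem expects them
outer_indices = [x_names_full.index(x_id) for x_id in x_names_outer]
startpoints_outer = startpoints[:, outer_indices]
print(outer_indices)
print(startpoints_outer)
```

Looking up indices by name, rather than assuming the outer parameters come first, keeps the selection correct even if the two problems order their parameters differently.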

In [None]:
# Run hierarchical optimization using NumericalInnerSolver
start_time = time.time()
problem.objective.calculator.inner_solver = NumericalInnerSolver()
problem.objective.calculator.inner_solver.n_starts = 1
engine = pypesto.engine.MultiProcessEngine(n_procs=6)
result_num = pypesto.optimize.minimize(
    problem,
    n_starts=n_starts,
    engine=engine,
    options=options,
)
print(f"{result_num.optimize_result.get_for_key('fval')=}")
time_num = time.time() - start_time
print(f"{time_num=}")
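Conceptually, a numerical inner solver runs a local optimizer on the inner problem $J(s, b, \sigma^2 \mid \theta)$ each time the outer optimizer proposes a new $\theta$. The sketch below does this for one toy observable with `scipy.optimize.minimize`, optimizing $\log\sigma^2$ so that $\sigma^2$ stays positive, and compares $s, b$ with a direct least-squares fit. The data are made up, and this is not how pyPESTO's `NumericalInnerSolver` is implemented internally.

```python
import numpy as np
from scipy.optimize import minimize


def inner_objective(params, y_sim, y_data):
    """Gaussian objective J for fixed simulations y_sim.

    params = (s, b, log_sigma2); optimizing log(sigma^2) keeps sigma^2 > 0.
    """
    s, b, log_sigma2 = params
    sigma2 = np.exp(log_sigma2)
    residuals = y_data - (s * y_sim + b)
    return np.sum(np.log(2 * np.pi * sigma2) + residuals**2 / sigma2)


# Made-up simulation output y(theta) for one fixed theta, and noisy data
y_sim = np.array([0.0, 1.0, 2.0, 3.0])
y_data = np.array([1.1, 2.9, 5.2, 6.8])

# Numerical inner solve: a single local optimization from one start point
res = minimize(
    inner_objective, x0=[1.0, 0.0, 0.0], args=(y_sim, y_data), method="BFGS"
)
s_num, b_num, sigma2_num = res.x[0], res.x[1], np.exp(res.x[2])

# For comparison: least-squares fit of the data against the simulation
s_ls, b_ls = np.polyfit(y_sim, y_data, deg=1)
print(s_num, b_num, sigma2_num)
print(s_ls, b_ls)
```

Every inner solve of this kind costs several objective evaluations per outer iteration, which is one reason a closed-form inner solution, where available, can be cheaper.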

In [None]:
# Run hierarchical optimization using AnalyticalInnerSolver
start_time = time.time()
problem.objective.calculator.inner_solver = AnalyticalInnerSolver()
engine = pypesto.engine.MultiProcessEngine(n_procs=6)
result_ana = pypesto.optimize.minimize(
    problem, n_starts=n_starts, engine=engine, options=options
)
print(f"{result_ana.optimize_result.get_for_key('fval')=}")
time_ana = time.time() - start_time
print(f"{time_ana=}")
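For Gaussian noise, and assuming a single scaling $s$, offset $b$ and noise variance $\sigma^2$ per observable, the inner optimum can be written down directly: $s$ and $b$ are the ordinary least-squares fit of the data against the simulation, and $\sigma^2$ is the mean squared residual. The numpy sketch below computes these values for made-up data and checks that perturbing them increases $J$; it illustrates the idea only and is not pyPESTO's `AnalyticalInnerSolver`.

```python
import numpy as np


def analytical_inner_solution(y_sim, y_data):
    """Closed-form optimum of s, b and sigma^2 for one observable.

    Assumes Gaussian noise and a single s, b and sigma^2 shared by all data
    points of the observable, and that y_sim is not constant.
    """
    y_sim = np.asarray(y_sim, dtype=float)
    y_data = np.asarray(y_data, dtype=float)
    # s and b: ordinary least-squares fit of the data against the simulation
    sim_centered = y_sim - y_sim.mean()
    data_centered = y_data - y_data.mean()
    s = np.sum(sim_centered * data_centered) / np.sum(sim_centered**2)
    b = y_data.mean() - s * y_sim.mean()
    # sigma^2: maximum-likelihood estimate, i.e. the mean squared residual
    residuals = y_data - (s * y_sim + b)
    sigma2 = np.mean(residuals**2)
    return s, b, sigma2


def gaussian_objective(y_sim, y_data, s, b, sigma2):
    """J = sum_i [log(2*pi*sigma^2) + (ybar_i - (s*y_i + b))^2 / sigma^2]."""
    residuals = np.asarray(y_data) - (s * np.asarray(y_sim) + b)
    return np.sum(np.log(2 * np.pi * sigma2) + residuals**2 / sigma2)


# Made-up simulation output y(theta) for one fixed theta, and noisy data
y_sim = np.array([0.0, 1.0, 2.0, 3.0])
y_data = np.array([1.1, 2.9, 5.2, 6.8])

s_hat, b_hat, sigma2_hat = analytical_inner_solution(y_sim, y_data)
print(s_hat, b_hat, sigma2_hat)  # approximately 1.94 1.09 0.0205
```

At the optimum, $J$ reduces to $n\,[\log(2\pi\hat\sigma^2) + 1]$, so each outer evaluation needs only one simulation plus a few vector operations for the inner parameters.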

In [None]:
# Waterfall plot - analytical vs numerical inner solver
pypesto.visualize.waterfall(
    [result_num, result_ana],
    legends=['Numerical', 'Analytical'],
    size=(15, 6),
    order_by_id=True,
    colors=np.array(list(map(to_rgba, ('green', 'purple')))),
)
plt.savefig("num_ana.png")

In [None]:
# Time comparison - analytical vs numerical inner solver
fig, ax = plt.subplots()
ax.bar(x=[0, 1], height=[time_ana, time_num], color=['purple', 'green'])
ax.set_xticks([0, 1])
ax.set_xticklabels(['Analytical', 'Numerical'])
ax.set_ylabel('Time [s]')
fig.savefig("num_ana_time.png")

In [None]:
# Run standard (non-hierarchical) optimization
start_time = time.time()
engine = pypesto.engine.MultiProcessEngine(n_procs=6)
result_ord = pypesto.optimize.minimize(
    problem2, n_starts=n_starts, engine=engine, options=options
)
print(f"{result_ord.optimize_result.get_for_key('fval')=}")
time_ord = time.time() - start_time
print(f"{time_ord=}")
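The timing cells all repeat the same start/stop pattern with `time.time()`. As an optional alternative, a small context manager can record durations with `time.perf_counter()`, which is monotonic and intended for measuring intervals. The helper `timed` and the `timings` dictionary are hypothetical names introduced here; they are not part of pyPESTO.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(timings, label):
    """Store the wall-clock duration of the `with` block in timings[label]."""
    start = time.perf_counter()
    try:
        yield
    finally:
        # Recorded even if the block raises, so failed runs are still timed
        timings[label] = time.perf_counter() - start


timings = {}
with timed(timings, "sleep"):
    time.sleep(0.01)
print(timings)
```

In the cells above, a run could then be wrapped as `with timed(timings, "Non-Hierarchical"): result_ord = pypesto.optimize.minimize(...)`, and the bar plots could read their heights from `timings`.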

In [None]:
# Waterfall plot - hierarchical optimization with analytical inner solver vs standard optimization
pypesto.visualize.waterfall(
    [result_ana, result_ord],
    legends=['Analytical', 'Non-Hierarchical'],
    order_by_id=True,
    colors=np.array(list(map(to_rgba, ('purple', 'orange')))),
    size=(15, 6),
)
plt.savefig("ana_ord.png")

In [None]:
# Time comparison - hierarchical optimization with analytical inner solver vs standard optimization
fig, ax = plt.subplots()
ax.bar(x=[0, 1], height=[time_ana, time_ord], color=['purple', 'orange'])
ax.set_xticks([0, 1])
ax.set_xticklabels(['Analytical', 'Non-Hierarchical'])
ax.set_ylabel('Time [s]')
fig.savefig("ana_ord_time.png")

In [None]:
# Run hierarchical optimization with analytical inner solver and forward sensitivities
start_time = time.time()
problem.objective.calculator.inner_solver = AnalyticalInnerSolver()
problem.objective.amici_solver.setSensitivityMethod(
    amici.SensitivityMethod_forward
)
engine = pypesto.engine.MultiProcessEngine(n_procs=6)
result_ana_fw = pypesto.optimize.minimize(
    problem, n_starts=n_starts, engine=engine, options=options
)
print(f"{result_ana_fw.optimize_result.get_for_key('fval')=}")
time_ana_fw = time.time() - start_time
print(f"{time_ana_fw=}")

In [None]:
# Waterfall plot - compare all scenarios
pypesto.visualize.waterfall(
    [result_ana, result_ana_fw, result_num, result_ord],
    legends=['Analytical', 'Analytical forward', 'Numerical', 'Non-Hierarchical'],
    colors=np.array(list(map(to_rgba, ('purple', 'blue', 'green', 'orange')))),
    order_by_id=True,
    size=(15, 6),
)
plt.savefig("all.png")
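A waterfall plot is often read by how many starts reach the best objective value found. The sketch below counts, for each run, the starts whose final value lies within a tolerance of the overall best, treating non-finite values as failed starts. The function `n_converged`, the tolerance of 0.1 and the objective values are hypothetical and made up for illustration; with the results above, the arrays could come from `result_ana.optimize_result.get_for_key('fval')` and its counterparts.

```python
import numpy as np


def n_converged(fvals, best, tol=0.1):
    """Count starts whose final objective value lies within tol of best."""
    fvals = np.asarray(fvals, dtype=float)
    # Failed starts (NaN or inf) never count as converged
    finite = np.isfinite(fvals)
    return int(np.sum(finite & (fvals <= best + tol)))


# Made-up final objective values for two hypothetical runs
fvals_by_run = {
    "run_A": [138.2, 138.25, 150.0],
    "run_B": [138.2, np.inf, 160.0],
}

# Best finite value across all runs serves as the reference
all_fvals = np.concatenate(
    [np.asarray(v, dtype=float) for v in fvals_by_run.values()]
)
best = np.min(all_fvals[np.isfinite(all_fvals)])
counts = {label: n_converged(v, best) for label, v in fvals_by_run.items()}
print(best, counts)
```

Comparing such counts only makes sense when all runs report values of the same objective function, with the inner parameters at their optimum for the hierarchical runs.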

In [None]:
# Time comparison of all scenarios
fig, ax = plt.subplots()
ax.bar(
    x=[0, 1, 2, 3],
    height=[time_ana, time_ana_fw, time_num, time_ord],
    color=['purple', 'blue', 'green', 'orange'],
)
ax.set_xticks([0, 1, 2, 3])
ax.set_xticklabels(
    ['Analytical', 'Analytical forward', 'Numerical', 'Non-Hierarchical']
)
ax.set_ylabel('Time [s]')
fig.savefig("all_time.png")