In [1]:
%load_ext autoreload
%autoreload 2

import copy
import time

import gillespy2
import numpy as np
from dask.distributed import Client
from gillespy2.solvers.numpy import NumPySSASolver
from sklearn.metrics import mean_absolute_error
from tsfresh.feature_extraction.settings import MinimalFCParameters

from sciope.inference.abc_inference import ABC
from sciope.utilities.distancefunctions import naive_squared
from sciope.utilities.priors import uniform_prior
from sciope.utilities.summarystats import auto_tsfresh

## Define gillespy2 model

In [2]:
class ToggleSwitch(gillespy2.Model):
    """ Gardner et al. Nature (1999)
    'Construction of a genetic toggle switch in Escherichia coli'

    Two species, U and V. Each is produced at a rate repressed by the other
    (alpha1/(1+V^beta) for U, alpha2/(1+U^gamma) for V), and both are
    degraded with mass-action rate mu. The timespan is 101 evenly spaced
    points on [0, 50].

    parameter_values - optional dict mapping a parameter name (alpha1, alpha2,
        beta, gamma, mu) to a new expression. These overrides are applied on
        top of the defaults below. None (the default) keeps the defaults.
    """
    def __init__(self, parameter_values=None):
        # Initialize the model.
        gillespy2.Model.__init__(self, name="toggle_switch")
        # Parameters (defaults)
        alpha1 = gillespy2.Parameter(name='alpha1', expression=1)
        alpha2 = gillespy2.Parameter(name='alpha2', expression=1)
        beta = gillespy2.Parameter(name='beta', expression="2.0")
        gamma = gillespy2.Parameter(name='gamma', expression="2.0")
        mu = gillespy2.Parameter(name='mu', expression=1.0)
        self.add_parameter([alpha1, alpha2, beta, gamma, mu])

        # Species
        U = gillespy2.Species(name='U', initial_value=10)
        V = gillespy2.Species(name='V', initial_value=10)
        self.add_species([U, V])

        # Reactions: repressed production of U and V, first-order decay of each
        cu = gillespy2.Reaction(name="r1",reactants={}, products={U:1},
                propensity_function="alpha1/(1+pow(V,beta))")
        cv = gillespy2.Reaction(name="r2",reactants={}, products={V:1},
                propensity_function="alpha2/(1+pow(U,gamma))")
        du = gillespy2.Reaction(name="r3",reactants={U:1}, products={},
                rate=mu)
        dv = gillespy2.Reaction(name="r4",reactants={V:1}, products={},
                rate=mu)
        self.add_reaction([cu,cv,du,dv])
        self.timespan(np.linspace(0,50,101))

        # Apply caller-supplied overrides. Previously this argument was
        # accepted but silently ignored. set_expression is the same call
        # set_model_parameters uses to update a built model.
        if parameter_values is not None:
            for pname, value in parameter_values.items():
                self.get_parameter(pname).set_expression(value)

toggle_model = ToggleSwitch(parameter_values=None)  # model at its default parameters

# Helper: write a parameter vector into the model (used by the simulator below)
def set_model_parameters(params, model):
    """Overwrite every parameter of ``model`` in place and return the same model.

    params - indexable sequence; entry i is written to the i-th parameter of
             model.listOfParameters (same order as that mapping)
    model  - model exposing ``listOfParameters`` and ``get_parameter``; it is
             mutated, not copied
    """
    names = list(model.listOfParameters.keys())
    for idx in range(len(names)):
        model.get_parameter(names[idx]).set_expression(params[idx])
    return model

# Here we use gillespy2 numpy solver
def simulator(params, model):
    """Run one stochastic simulation of ``model`` at the parameter point ``params``.

    params - array, same order as model.listOfParameters
    model  - gillespy2 model used as a template. It is deep-copied, so the
             caller's model (e.g. the shared ``toggle_model``) is never changed.

    Returns an array intended to be (N, S, T): trajectories x species x
    time points, with the time row removed.
    """
    # Work on a private copy. set_model_parameters mutates in place, so
    # mutating the shared model would leak the last sampled parameters into
    # later cells (e.g. re-generating the observed data). It could also race
    # if several simulations run concurrently in threads of one process.
    model_update = set_model_parameters(params, copy.deepcopy(model))
    num_trajectories = 1  # TODO: howto handle ensembles

    res = model_update.run(solver=NumPySSASolver, show_labels=False,
                           number_of_trajectories=num_trajectories)
    tot_res = np.asarray([x.T for x in res])  # stack transposed trajectories
    return tot_res[:, 1:, :]  # drop row 0 (time points) -> (N, S, T)

# Wrapper: the simulator passed to ABC must take a single argument (the parameter point)
def simulator2(x):
    """Simulate the shared toggle_model at parameter point ``x`` (ABC's one-argument form)."""
    trajectories = simulator(params=x, model=toggle_model)
    return trajectories

# Set up the prior
# Prior box relative to the model's default parameter values
PRIOR_LOWER_FACTOR = 0.1  # lower bound = 10% of each default value
PRIOR_UPPER_FACTOR = 2.0  # upper bound = 200% of each default value

# Reference point: the Parameter objects in model order, which is the same
# order set_model_parameters writes a parameter vector back in.
default_param = np.array(list(toggle_model.listOfParameters.values()), dtype=object)
# Numeric defaults. float() is needed because some expressions are strings (e.g. "2.0").
bound = np.array([float(p.expression) for p in default_param])

# set the bounds
dmin = bound * PRIOR_LOWER_FACTOR
dmax = bound * PRIOR_UPPER_FACTOR

# Here we use uniform prior
uni_prior = uniform_prior.UniformPrior(dmin, dmax)

In [3]:
#generate some fixed(observed) data based on default parameters of model 
fixed_data = toggle_model.run(solver=NumPySSASolver, number_of_trajectories=100, show_labels=False)

In [4]:
#reshape data to (n_points,n_species,n_timepoints)
fixed_data = np.asarray([x.T for x in fixed_data])
#and remove timepoints array
fixed_data = fixed_data[:,1:, :]

In [5]:

# Summary statistics computed with tsfresh
summ_func = auto_tsfresh.SummariesTSFRESH()

# Distance between summary statistics
ns = naive_squared.NaiveSquaredDistance()

# ABC instance wiring together observed data, simulator, prior,
# summary statistics and distance
abc = ABC(
    fixed_data,
    sim=simulator2,
    prior_function=uni_prior,
    summaries_function=summ_func.compute,
    distance_function=ns,
)

In [6]:
# first compute the fixed (observed) mean
abc.compute_fixed_mean(chunk_size=2)

# Run in multiprocessing mode (no dask Client has been created yet).
# time.perf_counter is the monotonic, high-resolution clock meant for
# measuring elapsed time; time.time can jump if the system clock changes.
tic = time.perf_counter()
res = abc.infer(num_samples=100, batch_size=10, chunk_size=2)
toc = time.perf_counter() - tic  # elapsed seconds
toc

1.4781534671783447

In [7]:
# The observed data were generated at the model defaults, so those are the ground truth
true_params = bound
mae_inference = mean_absolute_error(
    y_true=true_params,
    y_pred=abc.results['inferred_parameters'],
)
mae_inference

0.15405426658826327

In [8]:
# Full results dictionary (accepted samples, inferred parameters, ...). The output is large.
abc.results

{'accepted_samples': [array([0.42264517, 1.60126453, 3.82637908, 2.87611469, 0.50141053]),
  array([1.12710902, 0.99433467, 2.86800125, 1.63835487, 1.23474087]),
  array([1.96269661, 0.76164626, 1.33447611, 2.05317688, 1.2373525 ]),
  array([0.4872649 , 0.11811028, 3.8111812 , 3.46402187, 1.14853157]),
  array([0.37531391, 0.11532176, 2.35999917, 1.42360544, 0.63735148]),
  array([1.09498762, 1.38035601, 1.80949211, 2.46592874, 0.81031725]),
  array([1.21346181, 1.92796342, 3.54428444, 1.8380288 , 1.57460927]),
  array([0.66952272, 1.19958311, 1.55311969, 2.27341387, 1.29677108]),
  array([0.95864775, 0.78839037, 0.34814486, 2.60510545, 1.60836774]),
  array([1.11766771, 0.86288675, 1.24923151, 3.07657179, 0.57502609]),
  array([0.2960871 , 0.29467484, 3.20542349, 3.90353867, 1.76779826]),
  array([1.78420243, 1.99022219, 2.55099795, 3.73393436, 1.9975284 ]),
  array([0.22616766, 1.99773313, 2.56009043, 2.93703942, 1.47079032]),
  array([1.36493307, 0.18834989, 3.53277985, 1.16122458, 

In [9]:
# Set up a local cluster (dask client). Client is imported in the top import
# cell. A newly created Client becomes dask's default scheduler, so the
# following abc.infer call runs on this cluster.
c = Client()
c

0,1
Client  Scheduler: tcp://127.0.0.1:52545  Dashboard: http://127.0.0.1:8787/status,Cluster  Workers: 4  Cores: 4  Memory: 16.83 GB


In [None]:
# Run in local cluster mode (uses the dask Client created above)
res = abc.infer(
    num_samples=200,
    batch_size=20,
    chunk_size=10,
)

In [11]:
# Error of the cluster-mode estimate against the same ground truth
mae_inference = mean_absolute_error(y_true=true_params, y_pred=abc.results['inferred_parameters'])
mae_inference

0.13784505637461192

In [12]:
# see results in res or abc.results
# NOTE(review): the accepted samples in the output below contain exact
# duplicate parameter vectors (several arrays appear twice). The cluster-mode
# posterior, and the MAE computed above, may therefore be based on repeated
# draws. Check how prior samples are generated on the dask workers.
res

{'accepted_samples': [array([1.32208828, 1.14534649, 0.4986438 , 1.20983794, 1.18827748]),
  array([1.91812249, 1.64061583, 0.46538878, 0.75388476, 1.89918401]),
  array([1.23287685, 0.85513311, 0.60576717, 3.45204693, 1.82118477]),
  array([1.9016429 , 0.48101515, 2.78249477, 3.19669539, 1.02484516]),
  array([0.48146076, 0.57475249, 1.37114835, 0.67247573, 0.77390968]),
  array([0.27135556, 0.52451269, 1.64252337, 0.88952863, 0.48671847]),
  array([0.44404986, 1.17052035, 2.66822734, 1.43435953, 0.70134477]),
  array([1.23287685, 0.85513311, 0.60576717, 3.45204693, 1.82118477]),
  array([1.9016429 , 0.48101515, 2.78249477, 3.19669539, 1.02484516]),
  array([0.27135556, 0.52451269, 1.64252337, 0.88952863, 0.48671847]),
  array([0.44404986, 1.17052035, 2.66822734, 1.43435953, 0.70134477]),
  array([1.32208828, 1.14534649, 0.4986438 , 1.20983794, 1.18827748]),
  array([1.91812249, 1.64061583, 0.46538878, 0.75388476, 1.89918401]),
  array([0.48146076, 0.57475249, 1.37114835, 0.67247573, 