In [24]:
'''
FIND_SWITCHING_TIMES_TOHIGH.PY: Find times taken by the switch to transition between states
'''
# By Kirill Sechkar

# PACKAGE IMPORTS 
import numpy as np
import jax
import jax.numpy as jnp
import functools
from diffrax import diffeqsolve, Dopri5, ODETerm, SaveAt, PIDController, SteadyStateEvent
import pandas as pd
from bokeh import plotting as bkplot, models as bkmodels, layouts as bklayouts, io as bkio
from bokeh.colors import RGB as bkRGB
import time

# set up jax
# NOTE(review): `jax.lib.xla_bridge` is an internal JAX module; presumably `jax.default_backend()`
# reports the same information through the public API - confirm before upgrading JAX.
from jax.lib import xla_bridge
jax.config.update('jax_platform_name', 'gpu')  # request the GPU platform
jax.config.update("jax_enable_x64", True)  # 64-bit precision instead of JAX's 32-bit default
# report the backend actually in use; the output recorded for this cell is 'cpu',
# i.e. that run did NOT end up on a GPU despite the request above
print(xla_bridge.get_backend().platform)

# set up bokeh
bkio.reset_output()  # drop any previously configured bokeh output target
bkio.output_notebook()  # render bokeh figures inline in the notebook

# OWN CODE IMPORTS
import synthetic_circuits as circuits
from cell_model import *
from get_steady_state import values_for_analytical
from Fig2.design_guidance_tools import F_real_calc,F_req_calc, pint_from_pswitch_and_xi
from switching_time_estimation_tools import *

cpu


In [25]:
# INITIALISE CELL MODEL, LOAD THE CIRCUIT

# helper object of the host-cell model: supplies defaults and attaches synthetic circuits
cellmodel_auxil = CellModelAuxiliary()
par = cellmodel_auxil.default_params()  # default parameter values
init_conds = cellmodel_auxil.default_init_conds(par)  # default initial conditions

# attach the 'punisher_b' circuit; note that `par` and `init_conds` are REBOUND here
# to the versions returned by add_circuit
(ode_with_circuit, circuit_F_calc, circuit_eff_m_het_div_k_het,
 par, init_conds,
 circuit_genes, circuit_miscs, circuit_name2pos,
 circuit_styles, circuit_v) = cellmodel_auxil.add_circuit(
    circuits.punisher_b_initialise,
    circuits.punisher_b_ode,
    circuits.punisher_b_F_calc,
    circuits.punisher_b_eff_m_het_div_k_het,
    par, init_conds,
    circuit_v=circuits.punisher_b_v,
)

In [26]:
# PARAMETERISE THE CIRCUIT
# (values are written in grouped update() calls; keys are set in the same order as before)

# BURDENSOME SYNTHETIC GENE
par.update({
    'c_b': 1,
    'a_b': 1e5,
})

# PUNISHER: genes and proteins
par.update({
    # switch gene
    'c_switch': 10.0,  # gene concentration (nM)
    'a_switch': 400.0,  # promoter strength (unitless)
    'd_switch': 0.01836,
    # integrase - expressed from the switch gene's operon, not its own gene => c_int, a_int irrelevant
    'k+_int': par['k+_switch']/80.0,  # RBS weaker than for the switch gene
    'd_int': 0.0,  # rate of integrase degradation per protease molecule (1/nM/h); 0.01836 is the alternative value
    # CAT (antibiotic resistance) gene
    'a_cat': 500.0,  # promoter strength (unitless)
    # synthetic protease gene
    'c_prot': 10.0,  # gene concentration (nM)
    'a_prot': 25.0,  # promoter strength (unitless)
})

# PUNISHER: transcription regulation function of the switch gene
par.update({
    'K_switch': 300.0,  # half-saturation constant for the self-activating switch gene promoter (nM)
    'eta_switch': 2,  # Hill coefficient for the self-activating switch gene promoter (unitless)
    'baseline_switch': 0.025,  # baseline value of the switch gene's transcription activation function
    'baseline_switch_alt': 0,
    'p_switch_ac_frac': 0.85,  # active fraction of protein (i.e. share of molecules bound by the inducer)
})

# CULTURE MEDIUM
par.update({'h_ext': 10.5 * (10.0 ** 3)})

# LOOKING AT TRANSITIONS WITH CAT ALWAYS PRESENT => INTEGRASE INACTIVE
par.update({'k_sxf': 0.0})

# INITIAL CONDITIONS
init_conds.update({
    # CAT gene concentration (nM) - an INITIAL CONDITION, NOT A PARAMETER, as the gene can be cut out by the integrase
    'cat_pb': 10.0,
    # protease level at the start: if zero at start, the punisher's triggered prematurely
    'p_prot': 1500.0,
    # culture medium
    's': 0.5,
})

In [27]:
# SET DETERMINISTIC SIMULATION PARAMETERS

# diffrax simulator
savetimestep = 0.1  # spacing of the time points at which the deterministic trajectory is saved (presumably hours, like the tau-leap duration)
rtol = 1e-6  # relative tolerance for the ODE solver
atol = 1e-6  # absolute tolerance for the ODE solver

In [28]:
# SET TAU-LEAP SIMULATION PARAMETERS

tau = 5e-7  # simulation time step
# NOTE(review): the earlier comment described this as the "number of ODE integration steps in a
# single tau-leap step", but the value equals tau and so reads like a step SIZE - confirm
# against tauleap_sim_switch
tau_odestep = 5e-7
tau_savetimestep = 1e-2  # save time step - a multiple of tau
# random number generator seeds - NUMBER OF KEYS DEFINES NUMBER OF TRAJECTORIES
# NOTE(review): the outputs recorded in this notebook report 10 samples, so they were
# presumably produced with a much shorter seed list than the 30000 seeds set here
key_seeds=np.arange(0,30000,1)
key_seeds_no_b=key_seeds.copy()  # same number of trajectories for the case with the xtra gene absent
tau_leap_sim_duration=10 # duration of the tau-leap simulation (h)

In [29]:
# cellular variables in steady state without burden, as returned by values_for_analytical
e, F_r, h, xis, Ps = values_for_analytical(
    par, ode_with_circuit, init_conds,
    circuit_genes, circuit_miscs, circuit_name2pos,
    circuit_F_calc, circuit_eff_m_het_div_k_het)

# gather everything the analytical tools need into one dictionary
cellvars = dict(
    # translation elongation rate and ribosome transcription regulation
    e=e, F_r=F_r,
    # intracellular chloramphenicol concentration
    h=h,
    # burden values
    xi_a=xis['a'], xi_r=xis['r'], xi_other_genes=xis['other'], xi_cat=xis['cat'],
    xi_switch_max=xis['switch (max)'], xi_int_max=xis['int (max)'], xi_prot=xis['prot'],
    # protein degradation correction factors for the switch protein and the integrase
    chi_switch=Ps['switch'], chi_int=Ps['int'],
)

# total burden: accumulate the individual contributions left to right
xi_total = xis['a']
for burden_key in ('r', 'other', 'cat', 'switch (max)', 'int (max)', 'prot'):
    xi_total = xi_total + xis[burden_key]

# burden without the burdensome gene
xi_no_b = xi_total - xis['other']

In [30]:
# DEFINE SWITCH STATES FOR ALL GENES PRESENT
# greatest possible p_switch level - for zero extra burden so that it's max across all conditions
p_switch_upper_bound = np.ceil(cellvars['xi_switch_max'] * (1/(1+cellvars['chi_switch'])) / (
        cellvars['xi_switch_max'] + cellvars['xi_int_max'] + cellvars['xi_prot'] + cellvars['xi_cat'] + cellvars['xi_a'] + cellvars['xi_r']
        ) * par['M'] * (1 - par['phi_q']) / par['n_switch'])  # upper bound for p_switch (to get the high equilibrium)

# axis of p_switch values - there can only be an integer number of proteins
p_switches=np.arange(0,p_switch_upper_bound+0.1,1)

# squared difference between the actual and the required F value at a given p_switch;
# its local minima along the p_switch axis are taken as the equilibrium candidates below
sqdiff = lambda p_switch: (F_real_calc(p_switch, par) - F_req_calc(p_switch, xi_total, par, cellvars)) ** 2

# evaluate the squared difference at every integer p_switch value
sqdiffs=np.zeros_like(p_switches)
for i, p_switch_value in enumerate(p_switches):
    sqdiffs[i]=sqdiff(p_switch_value)

# scan the interior points for local minima and maxima (end points have only one neighbour and are skipped);
# a point on a plateau satisfies both tests and is counted as a minimum only
loc_minima=[]
loc_maxima=[]
for i in range(1,len(p_switches)-1):
    # detect a local minimum
    if(sqdiffs[i]<=sqdiffs[i-1] and sqdiffs[i]<=sqdiffs[i+1]):
        loc_minima.append(p_switches[i])
    # detect a local maximum
    elif(sqdiffs[i]>=sqdiffs[i-1] and sqdiffs[i]>=sqdiffs[i+1]):
        loc_maxima.append(p_switches[i])

# print results
print('Local minima in squared difference:',loc_minima)
print('Local maxima in squared difference:',loc_maxima)

# based on our prior findings, with burden present we expect two stable equilibria and one unstable equilibrium between them
# this corresponds to three local minima and two local maxima in the squared difference between F values
if (len(loc_minima)==3 and len(loc_maxima)==2):
    p_switch_low=loc_minima[0]  # low state standing for low-concentration stable equilibrium
    p_switch_high=loc_minima[2] # high state standing for high-concentration stable equilibrium
    p_switch_boundary=loc_minima[1] # boundary between states standing for unstable equilibrium
    print('Low state representative p_switch:',p_switch_low)
    print('High state representative p_switch:',p_switch_high)
    print('Boundary between states:',p_switch_boundary)
else:
    # Fail fast: carrying on would leave p_switch_low/high/boundary undefined (NameError in a later
    # cell) or, on a re-run, silently reuse stale values from a previous parameterisation.
    raise RuntimeError(
        'Something weird is going on: expected 3 local minima and 2 local maxima in the squared '
        'difference, found {} and {}'.format(len(loc_minima), len(loc_maxima)))

Local minima in squared difference: [35.0, 102.0, 878.0]
Local maxima in squared difference: [67.0, 461.0]
Low state representative p_switch: 35.0
High state representative p_switch: 878.0
Boundary between states: 102.0


In [31]:
# SIMULATE FOR XTRA GENE PRESENT

# GET LOW-EXPRESSION DETERMINISTIC STEADY STATE FOR ALL GENES PRESENT
init_conds['p_switch'] = 0 # starting from zero switch protein leads to the low equilibrium

# initial simulation to get the steady state without gene expression loss
tf = (0, 50)  # simulation time frame
x0_det = cellmodel_auxil.x0_from_init_conds(init_conds,circuit_genes,circuit_miscs)  # initial condition VECTOR
synth_gene_params_det = cellmodel_auxil.synth_gene_params_for_jax(par,circuit_genes)  # synthetic gene parameters for calculating k values
ts_save_det = jnp.arange(tf[0], tf[1]+savetimestep/2, savetimestep)  # time axis for saving the system's state (end point included)
sol=ode_sim(
    par,                # dictionary with model parameters
    ode_with_circuit,   # ODE function for the cell with synthetic circuit
    x0_det,
    len(circuit_genes), len(circuit_miscs), circuit_name2pos,  # circuit gene / miscellaneous species counts, species name to vector position decoder
    synth_gene_params_det,
    tf, ts_save_det,
    rtol, atol,         # relative and absolute tolerances
)

# bring the trajectory over to numpy
ts_det=np.array(sol.ts)
xs_det=np.array(sol.ys)

# the last saved state is taken as the deterministic steady state
det_steady_x = xs_det[-1, :]

In [32]:
# INTEGRASE CONCENTRATIONS FROM SWITCH PROTEIN CONCENTRATION

# burden of all genes other than the switch and the integrase - shared by both states
xi_excl_switch_int = xis['a'] + xis['r'] + xis['other'] + xis['cat'] + xis['prot']

p_int_low=pint_from_pswitch_and_xi(p_switch_low, xi_excl_switch_int, par, cellvars)
p_int_high=pint_from_pswitch_and_xi(p_switch_high, xi_excl_switch_int, par, cellvars)

print('Integrase concentration at low state:',p_int_low)
print('Integrase concentration at high state:',p_int_high)

Integrase concentration at low state: 3.5086100980161605
Integrase concentration at high state: 87.91830437347966


In [33]:
# SWITCH CONDITION DEFINITION

# condition for RISING to a higher state
def switch_condition_rise(x,next_x,switch_boundary,p_pos):
    """
    Check whether the species at position p_pos crosses switch_boundary from below in one step.

    x: state before the step
    next_x: state after the step
    switch_boundary: threshold level of the tracked species
    p_pos: position of the tracked species in the state vector

    Returns a boolean (jax) value: True if x[p_pos] is strictly below the boundary
    while next_x[p_pos] is at or above it.
    """
    was_below = x[p_pos]<switch_boundary
    is_at_or_above = next_x[p_pos]>=switch_boundary
    return jnp.logical_and(was_below, is_at_or_above)

In [34]:
# MEMORY-EFFICIENT SWITCH TIME DETERMINATION

# we consider state switched when we come within  10% OF CHARACTERISTIC PROTEIN CONCENTRATION
def sc_low_to_high(x, next_x):
    """Low-to-high switching: the integrase level rises across 90% of its high-state value."""
    return switch_condition_rise(x,next_x,0.9*p_int_high,circuit_name2pos['p_int'])


def sc_unnecessary(x, next_x):
    """Second switch condition check unnecessary here: always evaluates to False."""
    return jnp.zeros_like(x[circuit_name2pos['p_int']],dtype=bool)

In [35]:
# the stochastic run starts at the end time of the deterministic simulation and lasts tau_leap_sim_duration
tf_hybrid = (ts_det[-1], ts_det[-1] + tau_leap_sim_duration)  # simulation time frame
# prepare the tau-leap inputs from the deterministic steady state; the seeds in key_seeds
# set the number of trajectories
mRNA_count_scales, S, x0_tauleap, circuit_synpos2genename, keys0 = tauleap_sim_prep(par, len(circuit_genes),
                                                                                    len(circuit_miscs),
                                                                                    circuit_name2pos, det_steady_x,
                                                                                    key_seeds=key_seeds)

# simulate and record the switch times; the second returned value belongs to the second
# switch condition (sc_unnecessary) and is discarded
low_to_high_switch_times_since_stochstart, _, \
            final_keys_hybrid = tauleap_sim_switch(par,  # dictionary with model parameters
                                                    circuit_v,  # circuit reaction propensity calculator
                                                    circuit_eff_m_het_div_k_het, # calculate the total effective mRNA conc./k value for all synthetic genes
                                                    x0_tauleap, # initial condition VECTOR (processed to make sure random variables are appropriate integers)
                                                    len(circuit_genes), len(circuit_miscs),
                                                    circuit_name2pos,
                                                    cellmodel_auxil.synth_gene_params_for_jax(par,
                                                                                       circuit_genes), # synthetic gene parameters for calculating k values
                                                    tf_hybrid, tau, tau_odestep, tau_savetimestep, # simulation parameters: time frame, tau leap step size, ODE integration step setting, save time step
                                                    mRNA_count_scales, S, circuit_synpos2genename, # mRNA count scaling factor, stoichiometry matrix, synthetic gene number in list of synth. genes to name decoder
                                                    keys0, # starting random number generation keys
                                                    sc_low_to_high, sc_unnecessary # switch conditions: low-to-high crossing and an always-False placeholder
                                                    )  

# record switching times in numpy arrays (times are counted from the start of the stochastic simulation)
# NOTE(review): judging by the recorded outputs, an entry of 0.0 marks a trajectory that did not
# switch within the simulated window - confirm against tauleap_sim_switch
low_to_high_switch_times=np.array(low_to_high_switch_times_since_stochstart)

# print results
print('Low-to-high switch times:',low_to_high_switch_times)

# clean memory
del final_keys_hybrid

Low-to-high switch times: [4.25 3.93 0.   5.22 0.   9.05 0.   0.   9.19 0.  ]


In [36]:
# ESTIMATE SWITCHING TIMES AND CONFIDENCE INTERVALS

# low-to-high switch times
# NOTE(review): the `high_to_low_*` variables below hold LOW-TO-HIGH estimates (as the printed
# labels say); the names are left unchanged because the plotting cell reads them
N=len(key_seeds) # number of samples
T=tau_leap_sim_duration # total simulation time
print('Number of samples:',N)
print('')
print('Sampling time interval for low-to-high switch times:',T)
high_to_low_mle=find_mle(low_to_high_switch_times,N,T) # maximum likelihood estimate
# 0.05 is the significance level passed to find_confint, giving the 95% confidence interval
high_to_low_leftconfint, high_to_low_rightconfint=np.array(find_confint(high_to_low_mle,N,T,0.05))
print('Low-to-high switch times MLE:',high_to_low_mle)
print('Low-to-high switch times 95% confidence interval:',high_to_low_leftconfint, high_to_low_rightconfint)

Number of samples: 10

Sampling time interval for low-to-high switch times: 10
Low-to-high switch times MLE: 16.328
Low-to-high switch times 95% confidence interval: 7.645510221141436 42.532143005737794


In [37]:
# PLOT SWITCH TIMES
hist_bins=100

# Entries equal to 0.0 stand for trajectories that did NOT switch within the simulated window
# (the recorded MLE only adds up if they are treated as non-switching). They are not switch
# times, so they are left out of the histogram instead of producing a spurious spike at zero.
switched_times_withburden=np.asarray(low_to_high_switch_times)
switched_times_withburden=switched_times_withburden[switched_times_withburden>0]

# HISTOGRAM OF LOW-TO-HIGH SWITCH TIMES
# (the `falltime_htl_*` variable names are historical and kept for compatibility)
falltime_htl_hist_withburden=bkplot.figure(title='Distribution of low-to-high switch times',
                                   x_axis_label='Switch time [h]',
                                   y_axis_label='Frequency',
                                   width=400,height=300)
# get numpy histogram (raw counts per bin)
ft_htl_hist_withburden, ftl_htl_edges_withburden=np.histogram(switched_times_withburden,bins=hist_bins,density=False)
# plot histogram: one rectangle per bin
falltime_htl_hist_withburden.quad(top=ft_htl_hist_withburden, bottom=0,
                              left=ftl_htl_edges_withburden[:-1], right=ftl_htl_edges_withburden[1:],
                            color='magenta')
# mark the maximum likelihood estimate of the switch time
falltime_htl_hist_withburden.add_layout(bkmodels.Span(location=high_to_low_mle, dimension='height',
                                              line_dash='dashed', line_color='navy', line_width=2, line_alpha=0.5))
# mark 95% confidence interval for the estimate
falltime_htl_hist_withburden.add_layout(bkmodels.BoxAnnotation(left=high_to_low_leftconfint, right=high_to_low_rightconfint,
                                                       fill_color='navy', fill_alpha=0.1)) # range

# show plot
bkplot.show(falltime_htl_hist_withburden)

In [38]:
# THE CASE WITH XTRA GENE MUTATED

par_no_b=par.copy() # parameters for the case with the xtra gene absent
# NOTE(review): the earlier comment called 'func_b' the xtra gene concentration; presumably it is
# a flag/fraction marking the gene as functional (0 => mutated) - confirm in synthetic_circuits
par_no_b['func_b']=0.0

cellvars_no_b=cellvars.copy() # cellular variables for the case with the xtra gene absent
# NOTE(review): `cellvars` is built with the key 'xi_other_genes', not 'xi_b', so this line ADDS a
# new entry and leaves 'xi_other_genes' untouched - confirm which key the analytical tools read
cellvars_no_b['xi_b']=0 # no  burdensome gene

In [39]:
# GET THE HIGH STATE WITH THE XTRA GENE ABSENT (LOW STATE REUSED)

# axis of p_switch values - there can only be an integer number of proteins
p_switches=np.arange(0,p_switch_upper_bound+0.1,1)

# squared difference between the actual and the required F value with the xtra gene absent
# (burden xi_no_b, parameters par_no_b)
sqdiff_no_b = lambda p_switch: (F_real_calc(p_switch, par_no_b) - F_req_calc(p_switch, xi_no_b, par_no_b, cellvars_no_b)) ** 2

# evaluate the squared difference at every integer p_switch value
# (NOTE: sqdiffs, loc_minima and loc_maxima are rebound here and no longer hold the with-burden results)
sqdiffs=np.zeros_like(p_switches)
for i, p_switch_value in enumerate(p_switches):
    sqdiffs[i]=sqdiff_no_b(p_switch_value)

# scan the interior points for local minima and maxima (end points have only one neighbour and are skipped);
# a point on a plateau satisfies both tests and is counted as a minimum only
loc_minima=[]
loc_maxima=[]
for i in range(1,len(p_switches)-1):
    # detect a local minimum
    if(sqdiffs[i]<=sqdiffs[i-1] and sqdiffs[i]<=sqdiffs[i+1]):
        loc_minima.append(p_switches[i])
    # detect a local maximum
    elif(sqdiffs[i]>=sqdiffs[i-1] and sqdiffs[i]>=sqdiffs[i+1]):
        loc_maxima.append(p_switches[i])

# print results
print('Local minima in squared difference:',loc_minima)
print('Local maxima in squared difference:',loc_maxima)

# based on our prior findings, WITHOUT the burdensome gene we expect one stable equilibrium and one pseudo-equilibrium
# this corresponds to two local minima and one local maximum in the squared difference between F values
if (len(loc_minima)==2 and len(loc_maxima)==1):
    p_switch_high_no_b=loc_minima[1] # REDEFINE high state standing for high-concentration stable equilibrium
    print('High state representative p_switch:',p_switch_high_no_b)
else:
    # Fail fast: carrying on would leave p_switch_high_no_b undefined (NameError in a later cell)
    # or, on a re-run, silently reuse a stale value from a previous parameterisation.
    raise RuntimeError(
        'Something weird is going on: expected 2 local minima and 1 local maximum in the squared '
        'difference, found {} and {}'.format(len(loc_minima), len(loc_maxima)))

Local minima in squared difference: [58.0, 1044.0]
Local maxima in squared difference: [501.0]
High state representative p_switch: 1044.0


In [40]:
# INTEGRASE CONCENTRATIONS FROM SWITCH PROTEIN CONCENTRATIONS

# the low state is reused from the case with burden, hence so is its integrase concentration
p_int_low_no_b=p_int_low

# burden of all genes other than the switch and the integrase, with the xtra gene's share left out
xi_excl_switch_int_no_b = xis['a'] + xis['r'] + xis['cat'] + xis['prot']
p_int_high_no_b=pint_from_pswitch_and_xi(p_switch_high_no_b, xi_excl_switch_int_no_b, par_no_b, cellvars_no_b)

print('Integrase concentration at low state:',p_int_low_no_b)
print('Integrase concentration at high state:',p_int_high_no_b)

Integrase concentration at low state: 4.010225208880506
Integrase concentration at high state: 104.63004474352397


In [41]:
# GET DETERMINISTIC STEADY STATE

# the starting point is taken from the case with burden: independent copies of its
# time axis, trajectory and final state
ts_det_no_b, xs_det_no_b, det_steady_x_no_b = (
    with_burden_array.copy() for with_burden_array in (ts_det, xs_det, det_steady_x)
)

In [42]:
# MEMORY-EFFICIENT SWITCH TIME DETERMINATION

# we consider state switched when we come within  10% OF CHARACTERISTIC PROTEIN CONCENTRATION
# NOTE(review): an earlier remark here said the switch states are assumed identical with and without
# the xtra gene, yet the threshold below uses the no-burden high state (p_int_high_no_b); only the
# starting (low) state is shared between the two cases - confirm this is intended
sc_low_to_high_no_b=lambda x, next_x: switch_condition_rise(x,next_x,0.9*p_int_high_no_b,circuit_name2pos['p_int']) # low to (new) high switching

# the stochastic run starts at the end time of the deterministic simulation and lasts tau_leap_sim_duration
tf_hybrid_no_b = (ts_det_no_b[-1], ts_det_no_b[-1] + tau_leap_sim_duration)  # simulation time frame
# prepare the tau-leap inputs with the no-burden parameters; the seeds in key_seeds_no_b
# set the number of trajectories
mRNA_count_scales_no_b, S_no_b, x0_tauleap_no_b, circuit_synpos2genename_no_b, keys0_no_b = tauleap_sim_prep(par_no_b, len(circuit_genes),
                                                                                    len(circuit_miscs),
                                                                                    circuit_name2pos, det_steady_x_no_b,
                                                                                    key_seeds=key_seeds_no_b)   

# simulate and record the switch times; the second returned value belongs to the second
# switch condition (sc_unnecessary) and is discarded
low_to_high_switch_times_no_b_since_stochastart_no_b, _, \
            final_keys_hybrid_no_b = tauleap_sim_switch(par_no_b,  # dictionary with model parameters
                                                    circuit_v,  # circuit reaction propensity calculator
                                                    circuit_eff_m_het_div_k_het, # calculate the total effective mRNA conc./k value for all synthetic genes
                                                    x0_tauleap_no_b, # initial condition VECTOR (processed to make sure random variables are appropriate integers)
                                                    len(circuit_genes), len(circuit_miscs),
                                                    circuit_name2pos,
                                                    cellmodel_auxil.synth_gene_params_for_jax(par_no_b,
                                                                                       circuit_genes), # synthetic gene parameters for calculating k values
                                                    tf_hybrid_no_b, tau, tau_odestep, tau_savetimestep, # simulation parameters: time frame, tau leap step size, ODE integration step setting, save time step
                                                    mRNA_count_scales_no_b, S_no_b, circuit_synpos2genename_no_b, # mRNA count scaling factor, stoichiometry matrix, synthetic gene number in list of synth. genes to name decoder
                                                    keys0_no_b, # starting random number generation keys
                                                    sc_low_to_high_no_b, sc_unnecessary # switch conditions: low-to-high crossing and the always-False placeholder defined in an earlier cell
                                                    )

# record switching times in numpy arrays (times are counted from the start of the stochastic simulation)
# NOTE(review): judging by the recorded outputs, an entry of 0.0 marks a trajectory that did not
# switch within the simulated window - confirm against tauleap_sim_switch
low_to_high_switch_times_no_b=np.array(low_to_high_switch_times_no_b_since_stochastart_no_b)

print('Low-to-high switch times:',low_to_high_switch_times_no_b)

# clean memory
del final_keys_hybrid_no_b

Low-to-high switch times: [3.52 3.86 0.   2.93 5.93 3.75 3.23 0.   4.33 2.84]


In [43]:
# ESTIMATE SWITCHING TIMES AND CONFIDENCE INTERVALS

# low-to-high switch times
# NOTE(review): the `high_to_low_*_no_b` variables below hold LOW-TO-HIGH estimates (as the printed
# labels say); the names are left unchanged because the plotting cell reads them
N=len(key_seeds) # number of samples (key_seeds_no_b is a copy of key_seeds, so the count is the same)
T=tau_leap_sim_duration # total simulation time
print('Number of samples:',N)
print('')
print('Sampling time interval for low-to-high switch times:',T)
high_to_low_mle_no_b=find_mle(low_to_high_switch_times_no_b,N,T) # maximum likelihood estimate
# 0.05 is the significance level passed to find_confint, giving the 95% confidence interval
high_to_low_leftconfint_no_b, high_to_low_rightconfint_no_b=np.array(find_confint(high_to_low_mle_no_b,N,T,0.05))
print('Low-to-high switch times MLE:',high_to_low_mle_no_b)
print('Low-to-high switch times 95% confidence interval:',high_to_low_leftconfint_no_b, high_to_low_rightconfint_no_b)

Number of samples: 10

Sampling time interval for low-to-high switch times: 10
Low-to-high switch times MLE: 6.29875
Low-to-high switch times 95% confidence interval: 3.6345914578149463 13.889287482647369


In [44]:
# PLOT SWITCH TIMES
hist_bins=100

# Entries equal to 0.0 stand for trajectories that did NOT switch within the simulated window
# (the recorded MLE only adds up if they are treated as non-switching). They are not switch
# times, so they are left out of the histogram instead of producing a spurious spike at zero.
switched_times_noburden=np.asarray(low_to_high_switch_times_no_b)
switched_times_noburden=switched_times_noburden[switched_times_noburden>0]

# HISTOGRAM OF LOW-TO-HIGH SWITCH TIMES
# (the `falltime_htl_*` variable names are historical and kept for compatibility)
falltime_htl_hist_noburden=bkplot.figure(title='Distribution of low-to-high switch times',
                                   x_axis_label='Switch time [h]',
                                   y_axis_label='Frequency',
                                   width=400,height=300)
# get numpy histogram (raw counts per bin)
ft_htl_hist_noburden, ftl_htl_edges_noburden=np.histogram(switched_times_noburden,bins=hist_bins,density=False)
# plot histogram: one rectangle per bin
falltime_htl_hist_noburden.quad(top=ft_htl_hist_noburden, bottom=0,
                              left=ftl_htl_edges_noburden[:-1], right=ftl_htl_edges_noburden[1:],
                            color='magenta')
# mark the maximum likelihood estimate of the switch time
falltime_htl_hist_noburden.add_layout(bkmodels.Span(location=high_to_low_mle_no_b, dimension='height',
                                              line_dash='dashed', line_color='navy', line_width=2, line_alpha=0.5))
# mark 95% confidence interval for the estimate
falltime_htl_hist_noburden.add_layout(bkmodels.BoxAnnotation(left=high_to_low_leftconfint_no_b, right=high_to_low_rightconfint_no_b,
                                                       fill_color='navy', fill_alpha=0.1)) # range

# show plot
bkplot.show(falltime_htl_hist_noburden)

In [45]:
# SAVE SWITCH TIMES

# save the switching times
# both result arrays (with and without the burdensome gene) are bundled into one dictionary
saved_switchtimes={'low_to_high_switch_times_withburden':low_to_high_switch_times,
                   'low_to_high_switch_times_no_b':low_to_high_switch_times_no_b}
# NOTE: np.save stores a dict as a pickled object array, so reading it back needs
# np.load(..., allow_pickle=True).item(); despite its name, the file also holds the no-burden
# times. The path is relative to the current working directory.
np.save('low_to_high_switch_times_withburden.npy',saved_switchtimes)

In [46]:
# CLEAR CACHES TO PREVENT CPU/GPU ERRORS
# jax.clear_caches()