In [1]:
# some initial imports

import jax
import jax.numpy as jnp
import jax.lax
import numpy as np  # used below for time grids and for converting JAX arrays

# Using the JAX-powered resource-aware cell modelling package _rc_e_coli_jax_ for gene circuit simulations

This Jupyter notebook provides a step-by-step guide for simulating how a gene circuit behaves in the context of the host cell and other competing synthetic genes. Underlying these simulations is a resource-aware coarse-grained cell model published in [Sechkar et al., 2024](https://doi.org/10.1038/s41467-024-46410-9), which here has been implemented in JAX to enable efficient parallelised simulations on the GPU.

As an example, we consider a gene _ta_ that encodes a transcription activation factor which, upon chemical induction, promotes the expression of the output gene _x_. In addition to these and the host cell's own genes, we consider another, disturbing synthetic gene _dist_ that competes with _ta_ and _x_ for the host's gene expression resources. Each gene has two modelled variables associated with it: its mRNA concentration in the cell _m_ and its protein level _p_. The resultant circuit is depicted in the figure below (adapted from [Sechkar et al., 2024](https://doi.org/10.1038/s41467-024-46410-9)).

<div>
<img src="example.png" width="750"/>
</div>

## Circuit model definition

We start by defining the functions needed to model the circuit of interest: the initialiser _initialise()_, the gene regulation function _F_calc()_, the deterministic ordinary differential equation (ODE) model function _ode()_, and, if needed, a reaction propensity function _v()_ for stochastic simulations.

In the file _one_constit.py_, they are specified for simulating a single constitutive gene present in the cell. For any new gene circuit, one can copy this file's contents into a new file and then modify the circuit-specific code fragments, which have all been highlighted as follows:

````{python}
    # -------- SPECIFY [X] FROM HERE...
    # -------- ...TO HERE
````

### Circuit initialisation
In our case, the genes specified in the initialiser function are _ta_, _x_ and _dist_. For our case study, we assume that the DNA concentrations $c$ and promoter strengths $\alpha$ of _ta_ and _x_ differ from the universal default values of 1 nM and 100, so we re-specify them below. We additionally set the parameters of the Hill functions that describe the chemical inducer's binding to the transcription activation factors and the subsequent binding between the output gene's DNA and the inducer-bound transcription activation factor.

In [2]:
def initialise_taxdist():
    # -------- SPECIFY CIRCUIT COMPONENTS FROM HERE...
    genes = ['ta','x','dist']  # names of genes in the circuit
    miscs = []  # names of miscellaneous species involved in the circuit (e.g. metabolites)
    # -------- ...TO HERE

    # for convenience, one can refer to the species' concs. by names instead of positions in x
    # e.g. x[name2pos['m_xtra']] will return the concentration of mRNA of the gene 'xtra'
    name2pos = {}
    for i in range(0, len(genes)):
        name2pos['m_' + genes[i]] = 8 + i  # mRNA
        name2pos['p_' + genes[i]] = 8 + len(genes) + i  # protein
    for i in range(0, len(miscs)):
        name2pos[miscs[i]] = 8 + len(genes) * 2 + i  # miscellaneous species
    for i in range(0, len(genes)):
        name2pos['k_' + genes[i]] =  i  # effective mRNA-ribosome dissociation constants (in k_het, not x!!!)
    for i in range(0, len(genes)):
        name2pos['F_' + genes[i]] =  i  # transcription regulation functions (in F, not x!!!)
    for i in range(0, len(genes)):
        name2pos['mscale_' + genes[i]] =  i  # mRNA count scaling factors (in mRNA_count_scales, not x!!!)

    # default gene parameters to be imported into the main model's parameter dictionary
    default_par = {}
    for gene in genes: # gene parameters
        default_par['func_' + gene] = 1.0  # gene functionality - 1 if working, 0 if mutated
        default_par['c_' + gene] = 1.0  # copy no. (nM)
        default_par['a_' + gene] = 100.0  # promoter strength (unitless)
        default_par['b_' + gene] = 6.0  # mRNA decay rate (/h)
        default_par['k+_' + gene] = 60.0  # ribosome binding rate (/h/nM)
        default_par['k-_' + gene] = 60.0  # ribosome unbinding rate (/h)
        default_par['n_' + gene] = 300.0  # protein length (aa)
        default_par['d_' + gene] = 0.0  # rate of active protein degradation - zero by default (/h)

    # default initial conditions
    default_init_conds = {}
    for gene in genes:
        default_init_conds['m_' + gene] = 0
        default_init_conds['p_' + gene] = 0
    for misc in miscs:
        default_init_conds[misc] = 0

    # -------- DEFAULT VALUES OF CIRCUIT-SPECIFIC PARAMETERS CAN BE SPECIFIED FROM HERE...
    
    # gene concentrations (nM) and promoter strengths (unitless)
    default_par['c_ta']=100
    default_par['c_x']=100
    default_par['a_ta']=50
    default_par['a_x']=50
    
    # binding between the chemical inducer and the transcription activation factor
    default_par['K_ta-f']=1000  # Half-saturation constant (nM)
    
    # binding between the output gene's DNA and the inducer-bound transcription activation factor
    default_par['K_dna(x)-taf']=700  # Half-saturation constant (nM)
    default_par['eta_dna(x)-taf']=2 # Hill coefficient/cooperativity of binding
    default_par['baseline']=0.1 # baseline output gene promoter activity in absence of binding
    
    # time of the inducer's addition to the medium (h)
    default_par['t_add']=15
    default_par['f_added']=1000  # added inducer concentration (nM)
    # -------- ...TO HERE

    # default palette and dashes for plotting (5 genes + misc. species max)
    default_palette = ["#0072BD", "#D95319", "#4DBEEE", "#A2142F", "#FF00FF"]
    default_dash = ['solid']
    # match default palette to genes and miscellaneous species, looping over the five colours we defined
    circuit_styles = {'colours': {}, 'dashes': {}}  # initialise dictionary
    for i in range(0, len(genes)):
        circuit_styles['colours'][genes[i]] = default_palette[i % len(default_palette)]
        circuit_styles['dashes'][genes[i]] = default_dash[i % len(default_dash)]
    for i in range(len(genes), len(genes) + len(miscs)):
        circuit_styles['colours'][miscs[i - len(genes)]] = default_palette[i % len(default_palette)]
        circuit_styles['dashes'][miscs[i - len(genes)]] = default_dash[i % len(default_dash)]

    # --------  YOU CAN RE-SPECIFY COLOURS FOR PLOTTING FROM HERE...
    # -------- ...TO HERE

    return default_par, default_init_conds, genes, miscs, name2pos, circuit_styles
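To make the indexing convention concrete, the following standalone sketch rebuilds the state-vector part of the _name2pos_ mapper for our three genes and prints the resulting layout. The helper name `build_name2pos` is hypothetical; the offset of 8 is copied from _initialise_taxdist()_, and the sketch assumes that positions 0 to 7 of the state vector are reserved for the host cell's own variables.

```python
# Hypothetical helper mirroring the name2pos loops in initialise_taxdist().
# Assumption: the first `offset` entries of x belong to the host cell model.
def build_name2pos(genes, miscs, offset=8):
    name2pos = {}
    for i, gene in enumerate(genes):
        name2pos['m_' + gene] = offset + i               # mRNAs come first
        name2pos['p_' + gene] = offset + len(genes) + i  # then proteins
    for i, misc in enumerate(miscs):
        name2pos[misc] = offset + 2 * len(genes) + i     # then misc. species
    return name2pos

name2pos = build_name2pos(['ta', 'x', 'dist'], [])

# list the species in the order in which they appear in the state vector
for name, pos in sorted(name2pos.items(), key=lambda item: item[1]):
    print(pos, name)
```

With this layout, all mRNA entries precede all protein entries, which is the same ordering that the ODE function has to follow.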

### Gene transcription regulation

We now proceed to define the gene transcription regulation function _F_calc_taxdist()_. Where finding the function's value requires a certain state variable, the variable with a given name can be retrieved from the state vector using the name2pos mapper, e.g. _x[name2pos['m_ta']]_ will return the value of the mRNA concentration of the gene _ta_.

The genes _ta_ and _dist_ are constitutive, so the value of their transcription regulation function _F_ is 1 at all times. The output gene _x_ is regulated in two steps. First, the inducer with concentration $f$ binds the transcription activation protein $p_{ta}$, so the concentration of the active, inducer-bound factor is given by a Hill function:
$$p_{ta}^{act}=p_{ta} \frac{f}{f+K_{ta-f}}$$
Second, the active transcription activation factor binds the output gene's DNA, so the output gene's promoter activity is given by another Hill function, scaled to range from the baseline value to 1:
$$F_{x}=baseline+(1-baseline)\frac{(p_{ta}^{act})^{\eta_{dna(x)-taf}}}{(p_{ta}^{act})^{\eta_{dna(x)-taf}} + (K_{dna(x)-taf})^{\eta_{dna(x)-taf}}}$$

The inducer need not be present in the culture medium from the start; instead, it can be added at some time $t_{add}$ as a step input, as given below. This can be implemented using the _jax.lax.select()_ function in _F_calc()_.
$$f(t)=\begin{cases}0 & \text{ if } t<t_{add} \\ f_{added} & \text{ if } t\geq t_{add}\end{cases}$$
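As a quick numerical check of the regulation logic, the sketch below evaluates the output gene's regulation function before and after inducer addition. It is a standalone illustration rather than the package's code: it uses `jnp.where` in place of _jax.lax.select()_ because `jnp.where` promotes mixed scalar types automatically, and the helper names `inducer_conc` and `F_x_value`, as well as the protein level of 1400 nM, are assumptions made for the example.

```python
import jax.numpy as jnp

# parameter values copied from the defaults in initialise_taxdist()
par = {'t_add': 15.0, 'f_added': 1000.0, 'K_ta-f': 1000.0,
       'K_dna(x)-taf': 700.0, 'eta_dna(x)-taf': 2.0, 'baseline': 0.1}

def inducer_conc(t, par):
    # step input: zero before t_add, f_added from t_add onwards
    return jnp.where(t < par['t_add'], 0.0, par['f_added'])

def F_x_value(t, p_ta, par):
    f = inducer_conc(t, par)
    # inducer binding to the transcription activation factor
    p_ta_act = p_ta * f / (f + par['K_ta-f'])
    # binding of the active factor to the output gene's DNA
    eta = par['eta_dna(x)-taf']
    hill = p_ta_act**eta / (p_ta_act**eta + par['K_dna(x)-taf']**eta)
    return par['baseline'] + (1 - par['baseline']) * hill

p_ta = 1400.0  # assumed protein level (nM), chosen so that p_ta_act = K_dna(x)-taf after induction
F_before = F_x_value(10.0, p_ta, par)  # no inducer yet: only baseline activity
F_after = F_x_value(20.0, p_ta, par)   # half-saturated Hill term
print(float(F_before), float(F_after))
```

Before induction only the baseline activity of 0.1 remains; after induction half of the factor is inducer-bound, the Hill term equals 0.5, and the regulation function is 0.1 + 0.9 × 0.5 = 0.55.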


In [3]:
def F_calc_taxdist(t, x, par, name2pos):
    # --------  SPECIFY THE TRANSCRIPTION REGULATION FUNCTION FROM HERE...
    F_ta = 1 # ta gene is constitutive
    F_dist = 1 # dist gene is constitutive
    
    # get the time-dependent inducer concentration
    f = jax.lax.select(t<par['t_add'], 0, par['f_added'])
    
    # get the concentration of the transcription activation factor using the 'variable name-position in state vector' mapper
    p_ta = x[name2pos['p_ta']] 
    
    # binding between the chemical inducer and the transcription activation factor
    p_ta_act = p_ta * f/(f+par['K_ta-f'])
    
    # binding between the output gene's DNA and the inducer-bound transcription activation factor
    F_x = par['baseline']+(1-par['baseline']) * (p_ta_act**par['eta_dna(x)-taf'])/(p_ta_act**par['eta_dna(x)-taf']+par['K_dna(x)-taf']**par['eta_dna(x)-taf'])
    
    # returning the regulation function values in the same order as we specified the genes in the 'initialise()' function
    return jnp.array([F_ta, F_x, F_dist])
    # -------- ...TO HERE

### Deterministic ODE model definition

We now define the deterministic ODE model function _ode()_, based on Equations (S128) to (S133) of the Supplementary Information to [Sechkar et al., 2024](https://doi.org/10.1038/s41467-024-46410-9):

\begin{align}
    \dot{m}_{ta} &= F_{ta} c_{ta} \alpha_{ta} \lambda(\epsilon,B) - (\beta_{ta} + \lambda(\epsilon,B))m_{ta}
    \\
    \dot{m}_{x} &= F_{x}(f,p_{ta}) \cdot c_{x} \alpha_{x} \lambda(\epsilon,B) - (\beta_{x} + \lambda(\epsilon,B))m_{x}
    \\
    \dot{m}_{dist} &= F_{dist} c_{dist} \alpha_{dist} \lambda(\epsilon,B) - (\beta_{dist} + \lambda(\epsilon,B))m_{dist}
    \\
    \dot{p}_{ta} &= \frac{\epsilon(t^c)}{n_{ta}} \cdot
    \frac{m_{ta} / k_{ta}}{D} R - (\delta_{ta} + \lambda(\epsilon,B)) \cdot p_{ta}
    \\
    \dot{p}_{x} &= \frac{\epsilon(t^c)}{n_{x}} \cdot
    \frac{m_{x} / k_{x}}{D} R - (\delta_{x} + \lambda(\epsilon,B)) \cdot p_{x}
    \\
    \dot{p}_{dist} &= \frac{\epsilon(t^c)}{n_{dist}} \cdot
    \frac{m_{dist} / k_{dist}}{D} R - (\delta_{dist} + \lambda(\epsilon,B)) \cdot p_{dist}
\end{align}

Here, $\lambda$ is the cell's growth rate, $R$ is the concentration of ribosomes in the cell, $\epsilon$ is the translation elongation rate, and $D$ is the 'resource competition denominator' capturing the extent of competition for gene expression resources (i.e. ribosomes) in the cell. For gene $i$, $F_i$ is its transcription regulation function (implemented in _F_calc_taxdist()_ as discussed above), $c_i$ is its DNA concentration, $\alpha_i$ is its promoter strength, $\beta_i$ is its mRNA degradation rate, $\delta_i$ is the protein degradation rate, $n_i$ is the number of amino acids in the encoded protein, and $k_i$ is the effective mRNA-ribosome dissociation constant. A more detailed discussion of the model equations can be found in [Sechkar et al., 2024](https://doi.org/10.1038/s41467-024-46410-9).

In our JAX implementation _ode_taxdist()_, each of the ODEs above is simply typed into the corresponding entry in the array dxdt, where the order of the genes is the same as specified in _initialise_taxdist()_ and all mRNA level ODEs come before those for the protein concentrations. Once again, the variable with a given name can be retrieved from the state vector using the name2pos mapper.
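To build some intuition for the mRNA equations, the sketch below computes the steady-state mRNA level that they imply when the growth rate is held fixed. This is a simplification made only for illustration: in the full model $\lambda$ changes with the cell's state, and the value of 1 /h used here is an arbitrary assumption. The function names are hypothetical.

```python
def mrna_rhs(m, F, c, a, b, l):
    # right-hand side of the mRNA ODE: synthesis minus degradation and dilution
    return F * c * a * l - (b + l) * m

def mrna_steady_state(F, c, a, b, l):
    # setting the right-hand side to zero and solving for m
    return F * c * a * l / (b + l)

# c and a for the gene ta as set in initialise_taxdist(), default b; l is an assumed fixed value
F, c, a, b, l = 1.0, 100.0, 50.0, 6.0, 1.0
m_ss = mrna_steady_state(F, c, a, b, l)

print(round(m_ss, 1))                # steady-state mRNA concentration (nM)
print(mrna_rhs(m_ss, F, c, a, b, l)) # derivative at the steady state, numerically zero
```

Under these assumptions, synthesis (5000 nM/h) is balanced by degradation and dilution at $m_{ta} = 5000/7 \approx 714$ nM.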

In [4]:
def ode_taxdist(F_calc,     # calculating the transcription regulation functions
            t,  x,  # time, cell state, external inputs
            e, l, # translation elongation rate, growth rate
            R, # ribosome count in the cell, resource
            k_het, D, # effective mRNA-ribosome dissociation constants for synthetic genes, resource competition denominator
            par,  # system parameters
            name2pos  # name to position decoder
            ):
    # GET REGULATORY FUNCTION VALUES
    F = F_calc(t, x, par, name2pos)

    # --------  SPECIFY THE ODEs FROM HERE...
    return [# mRNAs
            F[name2pos['F_ta']] * par['c_ta'] * par['a_ta'] * l - (par['b_ta'] + l) * x[name2pos['m_ta']],  # m_ta
            F[name2pos['F_x']] * par['c_x'] * par['a_x'] * l - (par['b_x'] + l) * x[name2pos['m_x']],  # m_x
            F[name2pos['F_dist']] * par['c_dist'] * par['a_dist'] * l - (par['b_dist'] + l) * x[name2pos['m_dist']],  # m_dist
            # proteins
            (e / par['n_ta']) * (x[name2pos['m_ta']] / k_het[name2pos['k_ta']] / D) * R - (l + par['d_ta']) * x[name2pos['p_ta']],  # p_ta
            (e / par['n_x']) * (x[name2pos['m_x']] / k_het[name2pos['k_x']] / D) * R - (l + par['d_x']) * x[name2pos['p_x']],  # p_x
            (e / par['n_dist']) * (x[name2pos['m_dist']] / k_het[name2pos['k_dist']] / D) * R - (l + par['d_dist']) * x[name2pos['p_dist']]  # p_dist
            ]
    # -------- ...TO HERE

### Stochastic model definition

To account for the stochasticity of gene expression, our package also supports hybrid simulations of gene circuit performance, where the host cell variables are still treated deterministically (since they are coarse-grained variables representing the mean dynamics of many different variables, whose fluctuations are averaged out) and the synthetic gene variables are treated stochastically.

Stochastic simulation of synthetic gene expression requires defining the reaction propensity function _v_taxdist()_. This can be done simply by putting each ODE term, in order of appearance in the ODE array above, into a separate entry in the propensity vector, with a few caveats.

First, due to the way that the translation of a single mRNA by several ribosomes is modelled, all terms involving mRNAs must be scaled by a factor found in the _mRNA_count_scales_ vector, which can be accessed using the same decoder _name2pos_. 

Second, the model can account for mRNA removal due to antibiotic action, which we do not consider here as we assume that the culture medium contains no antibiotic. However, to maintain the correct order of the propensities, we include an additional term, set to zero, for all mRNAs. Should it be needed, the details of how to model antibiotic action can be found in the Supplementary Information to [Sechkar et al., 2024](https://doi.org/10.1038/s41467-024-46410-9).
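The relationship between the ODE terms and the propensities can be checked on a single mRNA species. The sketch below is a simplified illustration: it ignores the mRNA count scaling, uses arbitrary values for the state and the growth rate, and assumes a stoichiometry of +1 for synthesis and -1 for each removal reaction. The function names are hypothetical and the package's own stoichiometry matrix may be organised differently.

```python
def mrna_ode_rhs(m, F, c, a, b, l):
    # deterministic ODE right-hand side for one mRNA species
    return F * c * a * l - (b + l) * m

def mrna_propensities(m, F, c, a, b, l):
    # one entry per reaction, in the order used for mRNAs in the propensity vector
    return [F * c * a * l,  # synthesis
            b * m,          # degradation
            l * m,          # dilution
            0.0]            # removal due to antibiotic action - none in the medium

# assumed change in the mRNA level caused by each reaction firing once
stoichiometry = [+1, -1, -1, -1]

m, F, c, a, b, l = 200.0, 1.0, 100.0, 50.0, 6.0, 1.0  # arbitrary illustrative values
v = mrna_propensities(m, F, c, a, b, l)
net_rate = sum(s * v_i for s, v_i in zip(stoichiometry, v))

print(net_rate, mrna_ode_rhs(m, F, c, a, b, l))  # the two values coincide
```

Weighting the propensities by the stoichiometry recovers the ODE right-hand side, which is why each ODE term can be copied into its own propensity entry.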

In [5]:
def v_taxdist(F_calc,     # calculating the transcription regulation functions
            t,  x,  # time, cell state, external inputs
            e, l, # translation elongation rate, growth rate
            R, # ribosome count in the cell, resource
            k_het, D, # effective mRNA-ribosome dissociation constants for synthetic genes, resource competition denominator
            mRNA_count_scales, # scaling factors for mRNA counts
            par,  # system parameters
            name2pos
            ):
    # GET REGULATORY FUNCTION VALUES
    F = F_calc(t, x, par, name2pos)

    # --------  SPECIFY THE PROPENSITIES FROM HERE...
    return [
            # synthesis, degradation, dilution of ta gene mRNA - note the scaling factor added
            F[name2pos['F_ta']] * par['c_ta'] * par['a_ta'] * l / mRNA_count_scales[name2pos['mscale_ta']],
            par['b_ta'] * x[name2pos['m_ta']] / mRNA_count_scales[name2pos['mscale_ta']],
            l * x[name2pos['m_ta']] / mRNA_count_scales[name2pos['mscale_ta']],
            # mRNA removal due to chloramphenicol action - set to zero
            0,
            # synthesis, degradation, dilution of x gene mRNA - note the scaling factor added
            F[name2pos['F_x']] * par['c_x'] * par['a_x']* l / mRNA_count_scales[name2pos['mscale_x']],
            par['b_x'] * x[name2pos['m_x']] / mRNA_count_scales[name2pos['mscale_x']],
            l * x[name2pos['m_x']] / mRNA_count_scales[name2pos['mscale_x']],
            # mRNA removal due to chloramphenicol action - set to zero
            0,
            # synthesis, degradation, dilution of dist gene mRNA - note the scaling factor added
            F[name2pos['F_dist']] * par['c_dist'] * par['a_dist'] * l / mRNA_count_scales[name2pos['mscale_dist']],
            par['b_dist'] * x[name2pos['m_dist']] / mRNA_count_scales[name2pos['mscale_dist']],
            l * x[name2pos['m_dist']] / mRNA_count_scales[name2pos['mscale_dist']],
            # mRNA removal due to chloramphenicol action - set to zero
            0,
            # synthesis, degradation, dilution of ta gene protein
            (e / par['n_ta']) * (x[name2pos['m_ta']] / k_het[name2pos['k_ta']] / D) * R,
            par['d_ta'] * x[name2pos['p_ta']],
            l * x[name2pos['p_ta']],
            # synthesis, degradation, dilution of x gene protein
            (e / par['n_x']) * (x[name2pos['m_x']] / k_het[name2pos['k_x']] / D) * R,
            par['d_x'] * x[name2pos['p_x']],
            l * x[name2pos['p_x']],
            # synthesis, degradation, dilution of dist gene protein
            (e / par['n_dist']) * (x[name2pos['m_dist']] / k_het[name2pos['k_dist']] / D) * R,
            par['d_dist'] * x[name2pos['p_dist']],
            l * x[name2pos['p_dist']]
    ]
    # -------- ...TO HERE

## Gene circuit simulation

Now that the circuit has been defined, we can proceed to simulate its behaviour.

### Setting up the simulation

We start by initialising the cell model simulator and loading the gene circuit model functions defined above. 

Note that instead of defining these circuit functions in the same script, one can define them in a separate file and then import them, as is done in _sim_script.ipynb_ and for the antithetic integral feedback controller. This can be useful when one wants to run several scripts simulating the same synthetic gene circuit in different scenarios.

In [6]:
jax.config.update('jax_platform_name', 'gpu')   # make JAX use the GPU
jax.config.update("jax_enable_x64", True)   # enable 64-bit precision for JAX arrays
print(jax.default_backend())    # check that the GPU is being used

from jax_cell_simulator import *    # import the cell model simulator and auxiliary functions
from het_modules.one_constit import *          # import the synthetic gene circuit model functions

# initialise cell model
cellmodel_auxil = CellModelAuxiliary()  # auxiliary tools for simulating the model and plotting simulation outcomes
par = cellmodel_auxil.default_params()  # get default parameter values
init_conds = cellmodel_auxil.default_init_conds(par)  # get default initial conditions

# load the synthetic gene circuit defined above
(
    # OUTPUTS
    ode_with_circuit,  # deterministic ODE model function with the circuit loaded
    circuit_F_calc,    # transcription regulation function for the circuit
    par, init_conds,   # updated parameters and initial conditions with the circuit loaded
    circuit_genes,     # list of the circuit's gene names
    circuit_miscs,     # list of the circuit's miscellaneous species names
    circuit_name2pos,  # name to position decoder for the circuit
    circuit_styles,    # plotting styles for the circuit
    circuit_v          # stochastic reaction propensity calculation function for hybrid simulations
 ) = cellmodel_auxil.add_circuit(
    # oneconstit_init, oneconstit_ode, oneconstit_F_calc, par, init_conds, oneconstit_v)
    # INPUTS
    initialise_taxdist, ode_taxdist, F_calc_taxdist,    # circuit model functions as defined above
    par, init_conds,                                    # host cell model parameters and initial conditions   
    v_taxdist                                           # propensity calculation function for hybrid simulations as defined above - input None if only doing deterministic simulations
)

For a particular scenario, circuit and cell model parameters can be re-specified by updating the _par_ dictionary, which we do here as an example.

In [7]:
# regulated gene
par['c_ta']=100             # transcription activation factor gene concentration (nM)
par['c_x']=100              # output gene concentration (nM)
par['a_ta']=50              # transcription activation factor gene promoter strength (unitless)
par['a_x']=50               # output gene promoter strength (unitless)
par['K_ta-f']=1000          # transcription activation factor-inducer half-saturation constant (nM)
par['K_dna(x)-taf']=700     # transcription activation factor-DNA binding half-saturation constant
par['eta_dna(x)-taf']=2     # transcription activation factor-DNA binding cooperativity
par['baseline']=0.1         # baseline output gene promoter activity without transcription activation factor bound

# inducer addition
par['t_add']=35             # time of inducer addition to the medium (h)

### Deterministic simulation

We now proceed to deterministically simulate the circuit's behaviour in the host cell using the jax-powered ODE simulation package [diffrax](https://github.com/patrick-kidger/diffrax), described in [Kidger 2022](https://doi.org/10.48550/arXiv.2202.02435). 

Note that from this point onwards, none of the variables here are circuit-specific (e.g. the complete ODE for the system will *always* be _ode_with_circuit_ once you load your particular circuit as above). Hence, the code from now on can be reused in future simulation scripts with minimal changes.

In [8]:
# specify the simulation characteristics
tf = (0, 50)  # simulation time frame
rtol = 1e-6  # relative tolerance for the ODE solver
atol = 1e-6  # absolute tolerance for the ODE solver
savetimestep = 0.01  # time step at which the simulation will be saved
t_save = np.arange(tf[0], tf[1] + savetimestep / 2, savetimestep)  # time points at which the simulation will be recorded

# run the simulation, get a diffrax solution object
sol = ode_sim(
    # INPUTS
    par,                # model parameters
    ode_with_circuit,   # ODE function for the cell with synthetic circuit
    cellmodel_auxil.x0_from_init_conds(init_conds, circuit_genes, circuit_miscs),   # initial condition VECTOR
    len(circuit_genes), len(circuit_miscs),     # number of synthetic genes and miscellaneous species in the circuit
    circuit_name2pos,                           # species name to vector position decoder for the synthetic gene circuit
    cellmodel_auxil.synth_gene_params_for_jax(par, circuit_genes),                  # jax.numpy array of certain synthetic gene parameters for efficient calculations
    tf, t_save,     # simulation time frame and time points at which the simulation will be recorded
    rtol=rtol,      # relative tolerance for the ODE solver
    atol=atol       # absolute tolerance for the ODE solver
)

# convert the diffrax solution's jax.numpy arrays into regular numpy arrays
xs = np.array(sol.ys)   # recorded states are a 2D array, where the first dimension is the time point and the second is the position in the state vector
ts = np.array(sol.ts)   # recorded simulation time points

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


The auxiliary function collection _CellModelAuxiliary_ provides a number of tools for visualising the simulation results.

Having imported and set up the _Bokeh_ package for plotting, we start by visualising the simulation for the cell at large.

In [9]:
# bokeh imports and setup
from bokeh import plotting as bkplot, models as bkmodels, layouts as bklayouts, io as bkio
from bokeh.colors import RGB as bkRGB
bkio.reset_output()
bkplot.output_notebook()

# figure for the cell protein mass breakdown - blue patch stands for all heterologous (i.e. synthetic) proteins
mass_fig=cellmodel_auxil.plot_protein_masses(ts,xs,par,circuit_genes) 

# figures for the native mRNA, protein and tRNA levels, as well as the intracellular antibiotic level (if present in the medium)
(
    nat_mrna_fig, nat_prot_fig, nat_trna_fig,
    h_fig   # intracellular antibiotic (chloramphenicol) level figure - will be always zero here as none present in the medium
) = cellmodel_auxil.plot_native_concentrations(
    ts, xs,             # recorded simulation time points and states
    par, circuit_genes  # model parameters and synthetic gene names
)  

# figures for:
(
    l_figure,      # cell growth rate
    e_figure,      # translation elongation rate
    Fr_figure,     # ribosomal gene transcription regulation function (between 0 and 1)
    ppGpp_figure,  # concentration of the alarmone ppGpp, which determines resource allocation in the cell by regulating ribosome synthesis
    nu_figure,     # tRNA aminoacylation rate
    D_figure       # resource competition denominator - captures the extent of competition for ribosomes in the cell
 ) = cellmodel_auxil.plot_phys_variables(
    ts, xs,                         # recorded simulation time points and states
    par,                            # model parameters
    circuit_genes, circuit_miscs,   # synthetic gene and miscellaneous species names
    circuit_name2pos               # name to position decoder for the synthetic gene circuit
)

# show plots
bkplot.show(bklayouts.grid([[mass_fig, nat_mrna_fig, nat_prot_fig],
                            [nat_trna_fig, h_fig, l_figure],
                            [e_figure, Fr_figure, D_figure]]))

Now, we visualise the synthetic gene circuit's behaviour in the host cell. Note the step increase in the output gene's transcription regulation function at _par['t_add']_=35 h, when the inducer is added to the medium.

In [10]:
# figures for:
(
    het_mrna_fig, # synthetic gene mRNA levels
    het_prot_fig, # synthetic gene protein levels
    misc_fig      # synthetic circuit's miscellaneous species levels (none present in this particular circuit, hence an empty plot returned)
) = cellmodel_auxil.plot_circuit_concentrations(
    ts, xs,                         # recorded simulation time points and states    
    par,                            # model parameters
    circuit_genes, circuit_miscs,   # synthetic gene and miscellaneous species names
    circuit_name2pos,               # name to position decoder for the synthetic gene circuit
    circuit_styles                  # plotting styles for the synthetic gene circuit
)

# synthetic gene regulation
F_fig = cellmodel_auxil.plot_circuit_regulation(ts, xs, circuit_F_calc, par, circuit_genes, circuit_miscs, circuit_name2pos, circuit_styles)

# show plots
bkplot.show(bklayouts.grid([[het_mrna_fig, het_prot_fig, misc_fig],
                            [F_fig, None, None]]))

### Stochastic simulation

Finally, we simulate the circuit's stochastic behaviour using a hybrid tau-leap simulation algorithm (see [Sechkar et al. 2024](https://doi.org/10.1038/s41467-024-46410-9) for more details). As stochastic simulations are computationally more expensive, it usually makes sense to first obtain a circuit's deterministic steady state and then use it as the initial condition for stochastic simulations.
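For readers unfamiliar with tau-leaping, the sketch below shows the basic idea on a one-species birth-death process: over each step of length tau, the number of firings of every reaction is drawn from a Poisson distribution whose mean is the propensity multiplied by tau. This is a generic textbook-style illustration, not the package's hybrid algorithm; the function name and all rate values are assumptions made for the example.

```python
import jax
import jax.numpy as jnp

def tau_leap_birth_death(key, x0, k_syn, k_deg, tau, n_steps):
    """Generic tau-leaping for 0 -> X (rate k_syn) and X -> 0 (rate k_deg * x)."""
    def step(x, key_step):
        key_birth, key_death = jax.random.split(key_step)
        # number of firings of each reaction over one leap of length tau
        births = jax.random.poisson(key_birth, k_syn * tau)
        deaths = jax.random.poisson(key_death, k_deg * x * tau)
        # update the molecule count, preventing it from becoming negative
        x_new = jnp.maximum(x + births - deaths, 0).astype(x.dtype)
        return x_new, x_new

    keys = jax.random.split(key, n_steps)   # one random key per leap
    _, xs = jax.lax.scan(step, x0, keys)
    return xs

x0 = jnp.asarray(100, dtype=jnp.int32)      # initial molecule count
xs_demo = tau_leap_birth_death(jax.random.PRNGKey(0), x0,
                               k_syn=100.0, k_deg=1.0, tau=0.01, n_steps=2000)
print(xs_demo.shape, float(xs_demo.mean()))
```

The trajectory fluctuates around the deterministic steady state k_syn / k_deg = 100, which illustrates why starting a stochastic run from the deterministic steady state avoids spending computation on the initial transient.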

In [11]:
# specify the deterministic simulation time for steady state retrieval - 24h usually enough
tf_det = (0, 24)  # simulation time frame

sol = ode_sim(
    par,                # model parameters
    ode_with_circuit,   # ODE function for the cell with synthetic circuit
    cellmodel_auxil.x0_from_init_conds(init_conds, circuit_genes, circuit_miscs),   # initial condition VECTOR
    len(circuit_genes), len(circuit_miscs),     # number of synthetic genes and miscellaneous species in the circuit
    circuit_name2pos,                           # species name to vector position decoder for the synthetic gene circuit
    cellmodel_auxil.synth_gene_params_for_jax(par, circuit_genes),                  # jax.numpy array of certain synthetic gene parameters for efficient calculations
    tf_det, (tf_det[-1],),     # simulation time frame; time points at which the simulation will be recorded (only the final time point, as only the steady state is needed)
    rtol=rtol,      # relative tolerance for the ODE solver
    atol=atol       # absolute tolerance for the ODE solver
)

det_steady_state = sol.ys[-1,:]  # get the steady state from the last time point of the deterministic simulation (no need to convert to numpy array)

The initial condition should be fed into the function _tauleap_sim_prep_ that initialises the variables needed for the tau-leap simulation.

Another important argument of _tauleap_sim_prep_ is the _key_seeds_ array (almost any format, from a simple list to a _numpy.array_, is acceptable), which initialises the random number generators for the stochastic trajectories that will all be simulated in parallel, leveraging the GPU's parallelisation capacity. The number of simulated trajectories will be the same as the number of entries in _key_seeds_. Crucially, pseudo-random number generation in JAX is [always identical for the same key seed value](https://jax.readthedocs.io/en/latest/random-numbers.html), so all _key_seeds_ entries should be different.

For example, if we want to get _num_traj=10_ trajectories, it is a good idea to set _key_seeds_ to _jax.numpy.arange(10)_.
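The following standalone sketch illustrates why the seeds must be distinct. It uses only `jax.random.PRNGKey` and `jax.random.normal`; how the package itself turns seeds into keys is not shown here, and the helper `sample_from_seed` is hypothetical.

```python
import jax

def sample_from_seed(seed):
    # the same seed always yields the same key, and hence the same random numbers
    key = jax.random.PRNGKey(seed)
    return jax.random.normal(key, (4,))

draws = [sample_from_seed(seed) for seed in [0, 0, 1]]  # note the repeated seed

same_seed_identical = bool((draws[0] == draws[1]).all())
different_seed_identical = bool((draws[0] == draws[2]).all())
print(same_seed_identical, different_seed_identical)
```

A repeated seed would therefore produce a duplicate trajectory rather than an independent one, which is why consecutive integers are a convenient choice.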


In [12]:
# specify how many different stochastic trajectories we want to simulate
num_traj = 2

(
    # OUTPUTS
    mRNA_count_scales,          # scaling factors for mRNA counts
    S,                          # stochastic reaction stoichiometry matrix
    x0_tauleap,                 # initial condition for the tau-leap simulation 
    circuit_synpos2genename,    # position to gene name decoder for the synthetic gene circuit
    keys0                       # random number generator seeds for the stochastic trajectories
) = tauleap_sim_prep(
    # INPUTS
    par,                                    # model parameters
    len(circuit_genes), len(circuit_miscs), # number of synthetic genes and miscellaneous species in the circuit
    circuit_name2pos,                       # species name to vector position decoder for the synthetic gene circuit
    det_steady_state,                       # deterministic steady state
    key_seeds=np.arange(num_traj)          # random number generator seeds for the stochastic trajectories
)

With all preparations for stochastic simulations complete, we can now run the tau-leaping algorithm.

In [13]:
# tau-leaping simulation step (time over which the deterministic and the stochastic reactions are treated independently before being brought together)
tau = 1e-6

# deterministic ODE integration step (should be smaller than tau and fit within it an integer number of times)
tau_odestep=1e-7

# stochastic simulation time frame and recording points
tf_stoch = (tf_det[-1],40)  # simulation time frame
tau_savetimestep = 0.01  # time step at which the simulation will be saved

# run the simulation, get recorded times and states as jnp arrays; also get latest key seeds if you want to carry on the simulation afterwards
ts_stoch_jnp, xs_stoch_jnp, keys_stoch = tauleap_sim(
    par,            # model parameters
    circuit_v,      # circuit reaction propensity calculator
    x0_tauleap,     # initial condition for the tau-leap simulation as returned by tauleap_sim_prep()
    len(circuit_genes), len(circuit_miscs),                         # number of synthetic genes and miscellaneous species in the circuit
    circuit_name2pos,                                               # species name to vector position decoder for the synthetic gene circuit
    cellmodel_auxil.synth_gene_params_for_jax(par, circuit_genes),  # jax.numpy array of certain synthetic gene parameters for efficient calculations
    tf_stoch,                   # simulation time frame
    tau, tau_odestep,           # tau leap step size, deterministic ODE integration step size
    tau_savetimestep,           # time step at which the simulation will be saved
    mRNA_count_scales,          # scaling factors for mRNA counts
    S,                          # stochastic reaction stoichiometry matrix
    circuit_synpos2genename,    # position to gene name decoder for the synthetic gene circuit
    keys0,                      # random number generator seeds for the stochastic trajectories
    avg_dynamics=False          
)

# convert the jax.numpy arrays into regular numpy arrays
ts_stoch = np.array(ts_stoch_jnp)
xs_stoch = np.array(xs_stoch_jnp)
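
The comment in the cell above notes that the ODE integration step should fit within tau an integer number of times. Since a ratio such as `1e-6 / 1e-7` may not evaluate to exactly 10 in floating-point arithmetic, one way to check this before launching a long simulation is sketched below. The helper `ode_steps_per_tau` is hypothetical and not part of the package.

```python
import math


def ode_steps_per_tau(tau, tau_odestep, rel_tol=1e-9):
    # hypothetical helper: number of ODE integration steps per tau-leap step
    ratio = tau / tau_odestep
    n = round(ratio)
    # the ratio must be a whole number (up to rounding error) and at least one
    if n < 1 or not math.isclose(ratio, n, rel_tol=rel_tol):
        raise ValueError(
            f"tau={tau} is not an integer multiple of tau_odestep={tau_odestep}"
        )
    return n


n_odesteps = ode_steps_per_tau(1e-6, 1e-7)  # the values used in this notebook
```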

All of the simulation visualisation tools for deterministic simulations, save for the cell mass breakdown plot, are available for stochastic simulations as well. In these plots, semi-transparent lines represent individual stochastic trajectories, while the solid line represents the mean of all trajectories.
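
If the mean trajectory is needed outside of the plotting tools, it can be computed directly from the recorded states. The sketch below uses a synthetic stand-in array and assumes that states are stored with the axes ordered as (trajectory, time point, state variable); this layout is an assumption and should be checked against the actual shape of `xs_stoch`.

```python
import numpy as np

# Hypothetical stand-in for xs_stoch: 5 trajectories, 100 time points, 3 state variables
rng = np.random.default_rng(0)
xs_demo = 50.0 + rng.normal(scale=5.0, size=(5, 100, 3))

# average over the trajectory axis (assumed to be axis 0)
mean_traj = xs_demo.mean(axis=0)  # shape: (time point, state variable)
# spread between the trajectories at every time point
std_traj = xs_demo.std(axis=0)
```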

For the host cell's state, we get the following:

In [16]:
# figures for the native mRNA, protein and tRNA levels, as well as the intracellular antibiotic level (if present in the medium)
(
    nat_mrna_fig, nat_prot_fig, nat_trna_fig, 
    h_fig   # intracellular antibiotic (chloramphenicol) level figure - always zero here, as no antibiotic is present in the medium
) = cellmodel_auxil.plot_native_concentrations_multiple(
    ts_stoch, xs_stoch, # recorded simulation time points and states
    par,                # model parameters
    circuit_genes,      # synthetic gene names
    tspan=tf_stoch,     # time span of the simulation to be plotted - excludes the deterministic steady-state determination part
    simtraj_alpha=0.1   # transparency of the individual stochastic trajectory plot
)

# figures for:
(
    l_figure,       # cell growth rate
    e_figure,       # translation elongation rate
    Fr_figure,      # ribosomal gene transcription regulation function (between 0 and 1)
    ppGpp_figure,   # concentration of the alarmone ppGpp, which determines resource allocation in the cell by regulating ribosome synthesis
    nu_figure,      # tRNA aminoacylation rate
    D_figure        # resource competition denominator - captures the extent of competition for ribosomes in the cell
) = cellmodel_auxil.plot_phys_variables_multiple(
    ts_stoch, xs_stoch,             # recorded simulation time points and states
    par,                            # model parameters
    circuit_genes, circuit_miscs,   # synthetic gene and miscellaneous species names
    circuit_name2pos,               # name to position decoder for the synthetic gene circuit
    simtraj_alpha=0.1               # transparency of the individual stochastic trajectory plot
)
# show plots (ppGpp_figure and nu_figure are created above but not included in this grid)
bkplot.show(bklayouts.grid([[nat_mrna_fig, nat_prot_fig, None],
                            [nat_trna_fig, h_fig, l_figure],
                            [e_figure, Fr_figure, D_figure]]))

Finally, for the synthetic gene circuit plots, we get the following:

In [15]:
# figures for:
(
    het_mrna_fig,   # synthetic gene mRNA levels
    het_prot_fig,   # synthetic gene protein levels
    misc_fig        # synthetic circuit's miscellaneous species levels (none present in this particular circuit, hence an empty plot returned)
) = cellmodel_auxil.plot_circuit_concentrations_multiple(
    ts_stoch, xs_stoch,             # recorded simulation time points and states
    par,                            # model parameters
    circuit_genes, circuit_miscs,   # synthetic gene and miscellaneous species names
    circuit_name2pos,               # name to position decoder for the synthetic gene circuit
    circuit_styles,                 # plotting styles for the synthetic gene circuit
    simtraj_alpha=0.1               # transparency of the individual stochastic trajectory plot
)

# synthetic gene regulation
F_fig = cellmodel_auxil.plot_circuit_regulation_multiple(
    ts_stoch, xs_stoch,             # recorded simulation time points and states
    par,                            # model parameters
    circuit_F_calc,                 # transcription regulation function for the circuit
    circuit_genes, circuit_miscs,   # synthetic gene and miscellaneous species names
    circuit_name2pos,               # name to position decoder for the synthetic gene circuit
    circuit_styles,                 # plotting styles for the synthetic gene circuit
    simtraj_alpha=0.1               # transparency of the individual stochastic trajectory plot
)

# show plots
bkplot.show(bklayouts.grid([[het_mrna_fig, het_prot_fig, misc_fig],
                            [F_fig, None, None]]))