In [1]:
'''
D_VS_A_D.IPYNB - Investigate the dependence of the Resource Competition Denominator D on synthetic protein expression and degradation
'''
# By Kirill Sechkar

# PACKAGE IMPORTS ------------------------------------------------------------------------------------------------------
import functools
import pickle
import time
from math import pi

import numpy as np
import jax
import jax.numpy as jnp
from diffrax import diffeqsolve, Dopri5, Kvaerno3, ODETerm, SaveAt, PIDController, SteadyStateEvent
import pandas as pd
from bokeh import plotting as bkplot, models as bkmodels, layouts as bklayouts, palettes as bkpalettes, transform as bktransform
from bokeh import plotting as bkplot, models as bkmodels, layouts as bklayouts, io as bkio
from bokeh.colors import RGB as bkRGB

# set up jax
from jax.lib import xla_bridge
jax.config.update('jax_platform_name', 'cpu')  # run all JAX computations on the CPU
jax.config.update("jax_enable_x64", True)  # enable 64-bit (double precision) arrays
# report the backend in use via the public API rather than the internal jax.lib.xla_bridge module
print(jax.default_backend())

# set up bokeh: clear any previous output settings and render plots inline in the notebook
bkio.reset_output()
bkio.output_notebook() 

# OWN CODE IMPORTS -----------------------------------------------------------------------------------------------------
import synthetic_circuits as circuits
from cell_model import *

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


cpu


In [2]:
# DEFINE VMAPPABLE FUNCTION FOR GETTING STEADY STATES

def get_steadystates(par,  # dictionary with model parameters
                     ode_with_circuit,  # ODE function for the cell with the synthetic gene circuit
                     x0,  # initial condition vector
                     num_circuit_genes, num_circuit_miscs, circuit_name2pos, sgp4j,  # numbers of circuit genes and miscellaneous species, species name to vector position decoder, relevant synthetic gene parameters in jax.array form
                     t_to_steady, rtol, atol,  # simulation parameters: time until steady state, relative and absolute tolerances
                     stop_at_steady_state=False  # if True, allow the integration to end early once a steady state is detected
                     ):
    '''
    Integrate the cell model with its synthetic gene circuit from x0 and return the diffrax solution.

    Written so that it can be wrapped in jax.vmap over entries of par.

    :param par: dictionary with model parameters
    :param ode_with_circuit: right-hand side of the ODE, called as ode_with_circuit(t, y, args)
    :param x0: initial condition vector
    :param num_circuit_genes: number of genes in the synthetic circuit
    :param num_circuit_miscs: number of miscellaneous species in the synthetic circuit
    :param circuit_name2pos: decoder from species name to position in the state vector
    :param sgp4j: relevant synthetic gene parameters in jax.array form
    :param t_to_steady: end time of the simulation, assumed long enough for the system to reach steady state
    :param rtol: relative tolerance of the step size controller
    :param atol: absolute tolerance of the step size controller
    :param stop_at_steady_state: if True, SteadyStateEvent(rtol=0.001, atol=0.001) is passed to diffeqsolve
        so that integration may stop before t_to_steady. Must be a plain Python bool (it changes the call
        itself, so it cannot be vmapped). The default False integrates all the way to t_to_steady.
    :return: the diffrax solution object. No SaveAt is given, so diffrax's default saving applies
        (per the diffrax docs, only the state at the final time point).
    '''
    # ode_with_circuit already has the (t, y, args) signature diffrax expects, so it is wrapped directly
    term = ODETerm(ode_with_circuit)

    # arguments handed to ode_with_circuit as its third argument
    args = (
        par,  # model parameters
        circuit_name2pos,  # gene name - position in circuit vector decoder
        num_circuit_genes, num_circuit_miscs,  # number of genes and miscellaneous species in the circuit
        sgp4j  # relevant synthetic gene parameters in jax.array form
    )

    # Kvaerno3 solver (an implicit, stiff-capable method per the diffrax docs) with adaptive step sizes
    solver = Kvaerno3()
    stepsize_controller = PIDController(rtol=rtol, atol=atol)

    # steady-state early termination is opt-in; the keyword is only supplied when requested,
    # so the default call to diffeqsolve is exactly the fixed-horizon integration
    event_kwargs = {}
    if stop_at_steady_state:
        event_kwargs['discrete_terminating_event'] = SteadyStateEvent(rtol=0.001, atol=0.001)

    # solve the ODE
    sol = diffeqsolve(term, solver,
                      args=args,
                      t0=0, t1=t_to_steady, dt0=0.1, y0=x0,
                      max_steps=None,  # no step-count limit: integrate until t1 however many steps it takes
                      stepsize_controller=stepsize_controller,
                      **event_kwargs)

    return sol

In [3]:
# SET UP AND PARAMETERISE THE SYSTEM

# auxiliary tools for simulating the model and plotting simulation outcomes,
# together with the default parameter values and initial conditions
cellmodel_auxil = CellModelAuxiliary()
par = cellmodel_auxil.default_params()
init_conds = cellmodel_auxil.default_init_conds(par)

# add the 'oneconstitutive_cat_prot' synthetic circuit to the cell model
(ode_with_circuit, circuit_F_calc, par, init_conds,
 circuit_genes, circuit_miscs, circuit_name2pos, circuit_colours, _) = cellmodel_auxil.add_circuit(
    circuits.oneconstitutive_cat_prot_initialise,
    circuits.oneconstitutive_cat_prot_ode,
    circuits.oneconstitutive_cat_prot_F_calc,
    par, init_conds)

# culture medium nutrient quality
init_conds['s'] = 0.05

# override default parameter values for this study (applied in the listed order)
par.update({
    # rate of burdensome synthetic protein degradation by the protease
    'd_xtra': 0.01836,  # protease activity (1/h/nM)
    # cat gene parameters and chloramphenicol level
    'c_cat': 10.0,  # gene concentration (nM)
    'a_cat': 500.0,  # promoter strength (unitless)
    'h_ext': 10.5 * (10 ** 3),  # chloramphenicol concentration in the culture medium (nM)
    # default synthetic gene expression rate
    'c_xtra': 1.0,  # gene concentration (nM)
    'a_xtra': 1e5,  # promoter strength (unitless)
    # default protease expression rate
    'c_prot': 10.0,  # gene concentration (nM)
    'a_prot': 70.0,  # promoter strength (unitless)
})

In [4]:
# DEFINE PARAMETER RANGES TO CONSIDER

# burdensome gene: plasmid concentrations giving transcription rates c_xtra*a_xtra from 0 to 2000
num_c_xtras = 20
c_xtra_range = jnp.linspace(0, 2000 / par['a_xtra'], num_c_xtras)  # plasmid concentration range

# protease gene: concentrations from 0 to 2 nM
# NOTE(review): this range was described as "up to 2 times the default value", but the default c_prot
# set during parameterisation is 10 nM, so 0-2 nM is 0-0.2x that default -- TODO confirm the intended range
num_c_prots = 20
c_prot_range = jnp.linspace(0, 2, num_c_prots)  # protease concentration range

In [5]:
# DEFINE SIMULATION PARAMETERS

# integrate for 50 h, by which time the cell is assumed to be close to steady state
t_to_steady = 50
# relative and absolute tolerances for the ODE solver's step size controller
rtol, atol = 1e-6, 1e-6

In [6]:
# PREPARE FOR VMAPPED SIMULATIONS

# every (c_prot, c_xtra) combination, flattened so that position k of both vectors describes simulation k
c_prot_mesh, c_xtra_mesh = jnp.meshgrid(c_prot_range, c_xtra_range)
c_prot_mesh_ravel, c_xtra_mesh_ravel = c_prot_mesh.ravel(), c_xtra_mesh.ravel()

# vmap axes for parameters: map over axis 0 of the two swept parameters, broadcast all the others
par_vmap_axes = {parameter: (0 if parameter in ('c_xtra', 'c_prot') else None) for parameter in par.keys()}

# remember the scalar values, then replace them with the flattened grids for vmapping
# (par is modified in place, so re-running this cell on its own would record the grids as the defaults)
default_c_prot, default_c_xtra = par['c_prot'], par['c_xtra']
par['c_prot'], par['c_xtra'] = c_prot_mesh_ravel, c_xtra_mesh_ravel

In [7]:
# RUN SIMULATIONS

# repackage synthetic gene parameters into jax arrays for simulation
sgp4j = cellmodel_auxil.synth_gene_params_for_jax(par, circuit_genes)


def get_steadystates_for_par(par):
    '''Simulate one parameter set, holding the circuit, initial conditions and solver settings fixed.'''
    return get_steadystates(par,  # dictionary with model parameters
                            ode_with_circuit,  # ODE function for the cell with synthetic circuit
                            cellmodel_auxil.x0_from_init_conds(init_conds, circuit_genes, circuit_miscs),  # initial condition vector
                            len(circuit_genes), len(circuit_miscs),  # numbers of circuit genes and miscellaneous species
                            circuit_name2pos,  # species name to vector position decoder
                            sgp4j,  # synthetic gene parameters for calculating k values
                            t_to_steady,  # simulation time until steady state assumed to be reached
                            rtol, atol)  # relative and absolute tolerances


# vectorise over the swept entries of par and run all parameter combinations at once
vmapped_get_steadystates_for_par = jax.vmap(get_steadystates_for_par, in_axes=(par_vmap_axes,))
sols = vmapped_get_steadystates_for_par(par)

# copy saved time points and states to numpy for post-processing
ts, xs = np.array(sols.ts), np.array(sols.ys)


In [8]:
# GET PHYSIOLOGICAL VARIABLES FOR SIMULATION OUTCOMES

# swept parameter values of every simulation, converted to numpy once
c_prot_values = np.array(c_prot_mesh_ravel)
c_xtra_values = np.array(c_xtra_mesh_ravel)

# final-time-point values for every simulation, in flattened-grid order
es_ravel = np.zeros_like(c_prot_values)
F_rs_ravel = np.zeros_like(es_ravel)
Ts_ravel = np.zeros_like(es_ravel)
Ds_ravel = np.zeros_like(es_ravel)
Ds_nodeg_ravel = np.zeros_like(es_ravel)

for sim_idx in range(len(ts)):
    # this simulation's parameters: the vmapped vectors replaced by its own scalar values
    par_sim = par.copy()
    par_sim['c_prot'] = c_prot_values[sim_idx]
    par_sim['c_xtra'] = c_xtra_values[sim_idx]
    outputs = cellmodel_auxil.get_e_l_Fr_nu_psi_T_D_Dnodeg(ts[sim_idx, :], xs[sim_idx, :, :], par_sim,
                                                           circuit_genes, circuit_miscs, circuit_name2pos)
    es_array, _, F_rs_array, _, _, Ts_array, Ds_array, Ds_nodeg_array = outputs
    # keep only the last saved time point of each trajectory
    es_ravel[sim_idx] = es_array[-1]
    F_rs_ravel[sim_idx] = F_rs_array[-1]
    Ts_ravel[sim_idx] = Ts_array[-1]
    Ds_ravel[sim_idx] = Ds_array[-1]
    Ds_nodeg_ravel[sim_idx] = Ds_nodeg_array[-1]

# record the obtained values: back to the (c_xtra, c_prot) meshgrid layout, transposed so rows index c_prot
grid_shape = (len(c_xtra_range), len(c_prot_range))
es = es_ravel.reshape(grid_shape).T
Ts = Ts_ravel.reshape(grid_shape).T
F_rs = F_rs_ravel.reshape(grid_shape).T
Ds = Ds_ravel.reshape(grid_shape).T
Ds_nodeg = Ds_nodeg_ravel.reshape(grid_shape).T


def _pct_change_vs_origin(values):
    '''Absolute % change of every entry relative to values[0, 0] (here: no burden and no protease action).'''
    return np.abs((1 - values / values[0, 0]) * 100)


# find percentage changes in physiological variables relative to no burden and no protease action
es_changes = _pct_change_vs_origin(es)
Ts_changes = _pct_change_vs_origin(Ts)
F_rs_changes = _pct_change_vs_origin(F_rs)
Ds_changes = _pct_change_vs_origin(Ds)

In [9]:
# PLOT HEATMAPS OF CHANGES IN E, T, F_r

# specify heatmap settings
dimensions = (480, 420)  # dimensions of the plot (width, height)
colourbar_range = (0, 5)  # colour scale limits (% change)

# consider changes in different variables in turns
var_changes = [es_changes, Ts_changes, F_rs_changes]
var_descriptions = ['e', 'T', 'F_r']

# make axis labels, showing only the first, middle ((n+1)//2) and last grid values;
# the +1 is added to the number of grid points, not to the grid values themselves.
# NOTE: prot_labels / xtra_labels are not used by the plots in this cell.
prot_label_idxs = {0, (len(c_prot_range) + 1) // 2, len(c_prot_range) - 1}
prot_labels = [str(c_prot_range[i]) if i in prot_label_idxs else '' for i in range(len(c_prot_range))]
xtra_label_idxs = {0, (len(c_xtra_range) + 1) // 2, len(c_xtra_range) - 1}
xtra_labels = [str(c_xtra_range[i]) if i in xtra_label_idxs else '' for i in range(len(c_xtra_range))]

# heatmap categories: transcription rates (gene concentration x promoter strength) in scientific notation;
# they are the same for every variable, so they are computed once outside the plotting loop
str_xtra = [np.format_float_scientific(x * par['a_xtra'], precision=2) for x in c_xtra_range]
str_prot = [np.format_float_scientific(x * par['a_prot'], precision=2) for x in c_prot_range]

# plot
figs = []
for var_change, var_description in zip(var_changes, var_descriptions):
    # make a dataframe for plotting: rows = protease rates, columns = burdensome gene rates,
    # then stacked into one row per heatmap cell
    df_2d = pd.DataFrame(var_change, columns=str_xtra)
    df_2d.columns = df_2d.columns.astype('str')
    df_2d['c_prot a_prot'] = str_prot
    df_2d['c_prot a_prot'] = df_2d['c_prot a_prot'].astype('str')
    df_2d = df_2d.set_index('c_prot a_prot')
    df_2d.columns.name = 'c_xtra a_xtra'
    changes_df = pd.DataFrame(df_2d.stack(), columns=[var_description]).reset_index()

    # set up the graph
    fig = bkplot.figure(
        x_axis_label='Protease gene transcription rate',
        y_axis_label='Burdensome gene transcription rate',
        x_range=list(df_2d.index),
        y_range=list(df_2d.columns),
        width=dimensions[0], height=dimensions[1],
        tools="box_zoom,pan,hover,reset,save",
        tooltips=[('Prot. transc. rate = ', '@{c_prot a_prot}'),
                  ('Burd. transc. rate = ', '@{c_xtra a_xtra}'),
                  (var_description + ' change=', '@' + var_description)],
        title=var_description + ' abs. change from no synth. gene exp. burden',
    )
    figs.append(fig)
    # svg backend and minimal axis styling
    fig.output_backend = "svg"
    fig.grid.grid_line_color = None
    fig.axis.axis_line_color = None
    fig.axis.major_tick_line_color = None
    fig.axis.major_label_text_font_size = "8pt"
    fig.axis.major_label_standoff = 0
    fig.xaxis.major_label_orientation = pi / 2

    # plot the heatmap
    rects = fig.rect(x="c_prot a_prot", y="c_xtra a_xtra", source=changes_df,
                     width=1, height=1,
                     fill_color=bktransform.linear_cmap(var_description, bkpalettes.Plasma256,
                                                        low=colourbar_range[0], high=colourbar_range[1]),
                     line_color=None)
    # add colour bar
    fig.add_layout(rects.construct_color_bar(
        major_label_text_font_size="8pt",
        ticker=bkmodels.BasicTicker(desired_num_ticks=3),
        formatter=bkmodels.PrintfTickFormatter(format="%.3f%%"),
        label_standoff=6,
        border_line_color=None,
        padding=5
    ), 'right')

# show plots
bkplot.show(bklayouts.row(figs))

In [10]:
# PLOT EFFECT OF PROTEIN DEGRADATION ON D

# % by which D is lower than its no-degradation counterpart Ds_nodeg (negative if it is higher)
D_deg_changes = (1 - Ds / Ds_nodeg) * 100

# make axis labels, showing only the first, middle ((n+1)//2) and last grid values;
# the +1 is added to the number of grid points, not to the grid values themselves.
# NOTE: prot_labels / xtra_labels are not used by the plot in this cell.
prot_label_idxs = {0, (len(c_prot_range) + 1) // 2, len(c_prot_range) - 1}
prot_labels = [str(c_prot_range[i]) if i in prot_label_idxs else '' for i in range(len(c_prot_range))]
xtra_label_idxs = {0, (len(c_xtra_range) + 1) // 2, len(c_xtra_range) - 1}
xtra_labels = [str(c_xtra_range[i]) if i in xtra_label_idxs else '' for i in range(len(c_xtra_range))]

# make a dataframe for plotting: rows = protease rates, columns = burdensome gene rates,
# then stacked into one row per heatmap cell
str_xtra = [np.format_float_scientific(x * par['a_xtra'], precision=2) for x in c_xtra_range]
df_2d = pd.DataFrame(D_deg_changes, columns=str_xtra)
df_2d.columns = df_2d.columns.astype('str')
df_2d['c_prot a_prot'] = [np.format_float_scientific(x * par['a_prot'], precision=2) for x in c_prot_range]
df_2d['c_prot a_prot'] = df_2d['c_prot a_prot'].astype('str')
df_2d = df_2d.set_index('c_prot a_prot')
df_2d.columns.name = 'c_xtra a_xtra'
changes_df = pd.DataFrame(df_2d.stack(), columns=['D deg changes']).reset_index()

# set up the graph
D_fig = bkplot.figure(
    x_axis_label='Protease gene transcription rate',
    y_axis_label='Burdensome gene transcription rate',
    x_range=list(df_2d.index),
    y_range=list(df_2d.columns),
    width=dimensions[0], height=dimensions[1],
    tools="box_zoom,pan,hover,reset,save",
    tooltips=[('Prot. transc. rate = ', '@{c_prot a_prot}'),
              ('Burd. transc. rate = ', '@{c_xtra a_xtra}'),
              ('D change due to deg.=', '@{D deg changes}')],
    title='D change compared to no degradation',
)
# svg backend and minimal axis styling
D_fig.output_backend = "svg"
D_fig.grid.grid_line_color = None
D_fig.axis.axis_line_color = None
D_fig.axis.major_tick_line_color = None
D_fig.axis.major_label_text_font_size = "8pt"
D_fig.axis.major_label_standoff = 0
D_fig.xaxis.major_label_orientation = pi / 2

# plot the heatmap
rects = D_fig.rect(x="c_prot a_prot", y="c_xtra a_xtra", source=changes_df,
                   width=1, height=1,
                   fill_color=bktransform.linear_cmap('D deg changes', bkpalettes.Plasma256,
                                                      low=colourbar_range[0], high=colourbar_range[1]),
                   line_color=None)

# mark maximum switch gene expression on the heatmap: a dashed line at the first burdensome gene
# transcription rate (heatmap row, y axis) reaching the threshold. The loop runs over c_xtra_range,
# the array it indexes, so it stays correct when the two grids have different sizes.
transcrate_treshold = 2000.0
for i in range(len(c_xtra_range)):
    if c_xtra_range[i] * par['a_xtra'] >= transcrate_treshold:
        # NOTE(review): on the categorical y range, numeric location i presumably falls on the lower edge
        # of the i-th category -- confirm against Bokeh's categorical coordinate convention
        D_fig.add_layout(bkmodels.Span(location=i, dimension='width', line_color='white', line_width=5,
                                       line_dash='dashed'))
        break

# add colour bar
D_fig.add_layout(rects.construct_color_bar(
    major_label_text_font_size="8pt",
    ticker=bkmodels.BasicTicker(desired_num_ticks=3),
    formatter=bkmodels.PrintfTickFormatter(format="%.3f%%"),
    label_standoff=6,
    border_line_color=None,
    padding=5
), 'right')

# show plot
bkplot.show(D_fig)