In [1]:
'''
Notebook for FIGS8 - validation of the assumptions required to make analytical derivations
'''
# By Kirill Sechkar

# PACKAGE IMPORTS ------------------------------------------------------------------------------------------------------
import numpy as np
import jax
import jax.numpy as jnp
import functools
from diffrax import diffeqsolve, Dopri5, ODETerm, SaveAt, PIDController, SteadyStateEvent
import pandas as pd
import pickle
from bokeh import plotting as bkplot, models as bkmodels, layouts as bklayouts, palettes as bkpalettes, transform as bktransform
from math import pi
from bokeh import plotting as bkplot, models as bkmodels, layouts as bklayouts, io as bkio
from bokeh.colors import RGB as bkRGB
import time

# set up jax
from jax.lib import xla_bridge
jax.config.update('jax_platform_name', 'cpu')  # force JAX onto the CPU backend
jax.config.update("jax_enable_x64", True)  # enable 64-bit floating point in JAX
# report the backend JAX actually runs on; jax.default_backend() is the public API returning the
# same platform string as the deprecated xla_bridge.get_backend().platform
print(jax.default_backend())

# set up bokeh: clear any previous output target, then render plots inline in the notebook
bkio.reset_output()
bkio.output_notebook()

# OWN CODE IMPORTS -----------------------------------------------------------------------------------------------------
import synthetic_circuits as circuits
from cell_model import *
from get_steady_state import get_steady_state

  print(xla_bridge.get_backend().platform)
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


cpu


In [2]:
 # INITIALISE CELL MODEL, LOAD THE CIRCUIT

# set up the circuit
# auxiliary tools for simulating the cell model and plotting simulation outcomes
cellmodel_auxil = CellModelAuxiliary()
par = cellmodel_auxil.default_params()                # default parameter values
init_conds = cellmodel_auxil.default_init_conds(par)  # default initial conditions (computed from par)

# load the one-constitutive-gene CAT circuit; add_circuit returns the ODE for the cell with the circuit,
# circuit-specific helper functions, the updated parameter/initial-condition dictionaries and
# circuit bookkeeping (gene names, miscellaneous species, name-to-position decoder, plot styles)
(ode_with_circuit, circuit_F_calc, circuit_eff_m_het_div_k_het,
 par, init_conds,
 circuit_genes, circuit_miscs, circuit_name2pos, circuit_styles,
 _) = cellmodel_auxil.add_circuit(
    circuits.oneconstitutive_cat_initialise,
    circuits.oneconstitutive_cat_ode,
    circuits.oneconstitutive_cat_F_calc,
    circuits.oneconstitutive_eff_m_het_div_k_het,
    par, init_conds)


In [3]:
# PARAMETERISE CHLORAMPHENICOL RESISTANCE

# CAT gene concentration (nM); stored as an INITIAL CONDITION, not a parameter, because the
# author notes it can change during a simulation (e.g. be cut out by an integrase)
init_conds['cat_pb'] = 10.0

par.update({
    'a_cat': 500.0,               # promoter strength (unitless)
    'n_cat': 300.0,               # number of amino acids in CAT
    'h_ext': 10.5 * (10.0 ** 3),  # chloramphenicol concentration in the medium
})

In [4]:
# DEFINE PARAMETER RANGES TO CONSIDER

# synthetic gene concentrations c_b: evenly spaced from 0 up to the value at which a_b * c_b
# (the heterologous gene expression rate) reaches 2.5e5
num_c_bs = 15
c_b_range = jnp.linspace(start=0, stop=2.5e5 / par['a_b'], num=num_c_bs)

# medium nutrient qualities s (σ): evenly spaced over [0.08, 0.5]
num_ss = 15
s_range = jnp.linspace(start=0.08, stop=0.5, num=num_ss)

In [5]:
# DEFINE SIMULATION PARAMETERS

# the cell is assumed to be close to steady state after this much simulated time (h)
t_to_steady = 50
# ODE solver tolerances: relative, absolute
rtol, atol = 1e-6, 1e-6

In [6]:
# PREPARE FOR VMAPPED SIMULATIONS

# every (s, c_b) combination; with meshgrid's default 'xy' indexing both arrays have shape
# (len(c_b_range), len(s_range)), and ravel() flattens them in the same row-major order
s_mesh, c_b_mesh = jnp.meshgrid(s_range, c_b_range)
s_mesh_ravel, c_b_mesh_ravel = s_mesh.ravel(), c_b_mesh.ravel()

# jax.vmap in_axes specs mirroring the par / init_conds dictionaries: only the parameter c_b and
# the initial condition s are batched along axis 0; every other entry is shared across the batch
par_vmap_axes = {name: (0 if name == 'c_b' else None) for name in par.keys()}
init_cond_vmap_axes = {name: (0 if name == 's' else None) for name in init_conds.keys()}

# remember the current values, then swap in the flattened grids to be vmapped over.
# NOTE: par / init_conds are mutated in place, so re-running this cell without re-running the
# cells above would record the grids (not the original values) as the "defaults"
default_s, default_c_b = init_conds['s'], par['c_b']
init_conds['s'] = s_mesh_ravel
par['c_b'] = c_b_mesh_ravel

In [7]:
# RUN SIMULATIONS

# repackage synthetic gene parameters into jax arrays for simulation
sgp4j = cellmodel_auxil.synth_gene_params_for_jax(par, circuit_genes)


def get_steadystate_for_par_and_init_conds(par, init_conds):
    """Simulate the cell with the synthetic circuit until steady state is assumed to be reached.

    Args:
        par: dictionary of model parameters.
        init_conds: dictionary of initial conditions; converted to the initial state vector here.

    Returns:
        The result of get_steady_state; later cells read its .ts (time points) and .ys (states).
    """
    x0 = cellmodel_auxil.x0_from_init_conds(init_conds, circuit_genes, circuit_miscs)
    return get_steady_state(
        par,                                     # model parameters
        ode_with_circuit,                        # ODE for the cell with the synthetic circuit
        x0,                                      # initial state vector
        len(circuit_genes), len(circuit_miscs),  # numbers of circuit genes and miscellaneous species
        circuit_name2pos,                        # species name -> state-vector position decoder
        sgp4j,                                   # synthetic gene parameters for calculating k values
        t_to_steady,                             # simulated time after which steady state is assumed
        rtol, atol)                              # relative and absolute solver tolerances


# batch over the flattened (s, c_b) grid: the in_axes dictionaries mark which entries of
# par / init_conds are mapped along axis 0 and which are shared
vmapped_get_steadystate_for_par_and_init_conds = jax.vmap(
    get_steadystate_for_par_and_init_conds,
    in_axes=(par_vmap_axes, init_cond_vmap_axes))
sols = vmapped_get_steadystate_for_par_and_init_conds(par, init_conds)

# copy time points and state trajectories out of JAX into numpy arrays
ts = np.array(sols.ts)
xs = np.array(sols.ys)


In [8]:
# GET PHYSIOLOGICAL VARIABLES FOR SIMULATION OUTCOMES

# flattened grid of c_b values as a numpy array, converted once instead of on every loop iteration
c_b_values = np.array(c_b_mesh_ravel)

# values at the last saved time point (assumed steady state) for every simulated grid point
es_ravel = np.zeros_like(np.array(s_mesh_ravel))
F_rs_ravel = np.zeros_like(es_ravel)
Hs_ravel = np.zeros_like(es_ravel)
for i in range(ts.shape[0]):
    # parameter set for this grid point alone (par['c_b'] currently holds the whole grid)
    par_i = par.copy()
    par_i['c_b'] = c_b_values[i]
    es_array, _, F_rs_array, _, _, Ts_array, Ds_array, Ds_nodeg_array = cellmodel_auxil.get_e_l_Fr_nu_psi_T_D_Dnodeg(
        ts[i, :], xs[i, :, :], par_i,
        circuit_genes, circuit_miscs, circuit_name2pos,
        circuit_eff_m_het_div_k_het)
    es_ravel[i] = es_array[-1]
    F_rs_ravel[i] = F_rs_array[-1]
    # chloramphenicol concentration h at the final time point
    # NOTE(review): assumes h is stored at index 7 of the state vector - confirm against cell_model
    h = xs[i, -1, 7]
    Hs_ravel[i] = par_i['K_D'] / (h + par_i['K_D'])

# undo the meshgrid flattening (rows: c_b, columns: s), then transpose so that rows index s
# and columns index c_b; column 0 corresponds to c_b = 0, i.e. no synthetic gene burden
grid_shape = (len(c_b_range), len(s_range))
es = es_ravel.reshape(grid_shape).T
F_rs = F_rs_ravel.reshape(grid_shape).T
Hs = Hs_ravel.reshape(grid_shape).T


def _abs_pct_change_vs_no_burden(values):
    """Absolute % change of each entry relative to the first (c_b = 0) entry of its row."""
    baseline = values[:, :1]  # keep 2D so it broadcasts along each row
    return np.abs(100 * (values - baseline) / baseline)


# find percentage changes in physiological variables relative to no burden
es_changes = _abs_pct_change_vs_no_burden(es)
F_rs_changes = _abs_pct_change_vs_no_burden(F_rs)
Hs_changes = _abs_pct_change_vs_no_burden(Hs)

In [18]:
# PLOT HEATMAPS OF CHANGES IN ε, F_r, H

# specify heatmap settings
dimensions = (345, 290)  # dimensions of each plot (width, height)
colourbar_range = (0, 10)  # range of % changes spanned by the colour map

# consider changes in different variables in turn
var_changes = [es_changes, F_rs_changes, Hs_changes]
var_descriptions = ['ε', 'F_r', 'H']


def _sparse_tick_labels(values):
    """Return str(value) for the first, middle (index len // 2) and last entries, '' elsewhere."""
    n = len(values)
    shown = {0, n // 2, n - 1}
    return [str(values[k]) if k in shown else '' for k in range(n)]


# make axis labels; one shared helper so both axes use the same middle index
# NOTE(review): these lists are not used by the plotting code in this cell
prot_labels = _sparse_tick_labels(s_range)
xtra_labels = _sparse_tick_labels(c_b_range)

# a_b * c_b values formatted as the y-axis factors (identical for every variable, so built once)
str_b = [np.format_float_scientific(x * par['a_b'], precision=2) for x in c_b_range]

# plot
figs = []
for var_cntr in range(len(var_changes)):
    var_name = var_descriptions[var_cntr]

    # make a dataframe for plotting: rows indexed by s, columns by a_b, then stacked to long format
    df_2d = pd.DataFrame(var_changes[var_cntr], columns=str_b)
    df_2d.columns = df_2d.columns.astype('str')
    df_2d['s'] = [np.format_float_scientific(x, precision=2) for x in s_range]
    df_2d['s'] = df_2d['s'].astype('str')
    df_2d = df_2d.set_index('s')
    df_2d.columns.name = 'a_b'
    changes_df = pd.DataFrame(df_2d.stack(), columns=[var_name]).reset_index()

    # set up the graph
    fig = bkplot.figure(
        x_axis_label='σ (medium nutr. qual.)',
        y_axis_label='αb (synth. gene transc. rate), nM',
        x_range=list(df_2d.index),
        y_range=list(df_2d.columns),
        width=dimensions[0], height=dimensions[1],
        tools="box_zoom,pan,hover,reset,save",
        # field names are wrapped in braces so that non-ASCII names such as 'ε' are still
        # recognised as data fields by the tooltip template
        tooltips=[('σ = ', '@{s}'),
                  ('αb = ', '@{a_b}'),
                  (var_name + ' change=', '@{' + var_name + '}')],
        title=var_name + ' % change',
    )
    figs.append(fig)
    # svg backend
    fig.output_backend = "svg"
    fig.grid.grid_line_color = None
    fig.axis.axis_line_color = None
    fig.axis.major_tick_line_color = None
    fig.axis.major_label_text_font_size = "8pt"
    fig.axis.major_label_standoff = 0
    fig.xaxis.major_label_orientation = pi / 2

    # plot the heatmap
    rects = fig.rect(x="s", y="a_b", source=changes_df,
                     width=1, height=1,
                     fill_color=bktransform.linear_cmap(var_name, bkpalettes.Plasma256,
                                                        low=colourbar_range[0], high=colourbar_range[1]),
                     line_color=None)
    # add colour bar
    fig.add_layout(rects.construct_color_bar(
        major_label_text_font_size="8pt",
        ticker=bkmodels.BasicTicker(desired_num_ticks=3),
        formatter=bkmodels.PrintfTickFormatter(format="%.0f%%"),
        label_standoff=6,
        border_line_color=None,
        padding=5
    ), 'right')

# show plots
bkplot.show(bklayouts.row(figs))

In [15]:
# save the heatmaps in `figs` to a standalone HTML file (kept so the colour bar can be reused);
# the value of the last expression, the saved file's path, is displayed as the cell output
hmap_html_filename = 'hmap_for_bar.html'
bkplot.output_file(hmap_html_filename)
bkplot.save(figs)

'/mnt/c/Users/ersat/CODE/punisher/FigS8/hmap_for_bar.html'