In [134]:
'''
Notebook for supplementary FIG?E - tuning the punisher to work with partial synthetic gene mutations
'''
# By Kirill Sechkar

# PACKAGE IMPORTS ------------------------------------------------------------------------------------------------------
import numpy as np
import jax
import jax.numpy as jnp
import functools
from diffrax import diffeqsolve, Dopri5, ODETerm, SaveAt, PIDController, SteadyStateEvent
import pandas as pd
import pickle
from bokeh import plotting as bkplot, models as bkmodels, layouts as bklayouts, palettes as bkpalettes, transform as bktransform
from math import pi
from bokeh import plotting as bkplot, models as bkmodels, layouts as bklayouts, io as bkio
from bokeh.colors import RGB as bkRGB
from contourpy import contour_generator as cgen
import time

# set up jax
from jax.lib import xla_bridge
jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_enable_x64", True)
print(xla_bridge.get_backend().platform)

# set up bokeh
bkio.reset_output()
bkio.output_notebook() 

# OWN CODE IMPORTS -----------------------------------------------------------------------------------------------------
import synthetic_circuits as circuits
from cell_model import *
from get_steady_state import *
from Fig2.design_guidance_tools import *

cpu


  print(xla_bridge.get_backend().platform)


In [135]:
 # INITIALISE CELL MODEL, LOAD THE CIRCUIT

# initialise cell model
cellmodel_auxil = CellModelAuxiliary()  # auxiliary tools for simulating the model and plotting simulation outcomes
par = cellmodel_auxil.default_params()  # get default parameter values
init_conds = cellmodel_auxil.default_init_conds(par)  # get default initial conditions

# load synthetic gene circuit
ode_with_circuit, circuit_F_calc, circuit_eff_m_het_div_k_het,  \
    par, init_conds, circuit_genes, circuit_miscs, circuit_name2pos, circuit_styles, _ = cellmodel_auxil.add_circuit(
    circuits.punisher_b_initialise,
    circuits.punisher_b_ode,
    circuits.punisher_b_F_calc,
    circuits.punisher_sep_b_eff_m_het_div_k_het,
    par, init_conds)  # load the circuit

In [136]:
# SPECIFY THE CIRCUIT'S DEFAULT PARAMETERS

# BURDENSOME SYNTHETIC GENE
par['c_b'] = 1
par['a_b'] = 1e5

# PUNISHER
# switch gene conc
par['c_switch'] = 10.0  # gene concentration (nM)
par['a_switch'] = 400.0  # promoter strength (unitless)
par['d_switch']=0.01836
# integrase - expressed from the switch gene's operon, not its own gene => c_int, a_int irrelevant
par['k+_int'] = par['k+_switch']/80.0  # RBS weaker than for the switch gene
par['d_int'] = 0.0#0.01836 # rate of integrase degradation per protease molecule (1/nM/h)
# CAT (antibiotic resistance) gene
init_conds['cat_pb'] = 10.0  # gene concentration (nM) - INITIAL CONDITION< NOT PARAMETER as it can be cut out by the integrase
par['a_cat'] = 500.0  # promoter strength (unitless)
# synthetic protease gene
par['c_prot'] = 10.0  # gene concentration (nM)
par['a_prot'] = 25.0  # promoter strength (unitless)
init_conds['p_prot'] = 1500.0 # if zero at start, the punisher's triggered prematurely

# punisher's transcription regulation function
par['K_switch'] = 300.0  # Half-saturation constant for the self-activating switch gene promoter (nM)
par['eta_switch'] = 2 # Hill coefficient for the self-activating switch gene promoter (unitless)
par['baseline_switch'] = 0.025  # Baseline value of the switch gene's transcription activation function
par['baseline_switch_alt'] = 0
par['p_switch_ac_frac'] = 0.85  # active fraction of protein (i.e. share of molecules bound by the inducer)

# CULTURE MEDIUM
init_conds['s'] = 0.5
par['h_ext'] = 10.5 * (10.0 ** 3)

In [137]:
# get the cellular variables in steady state without burden
e, F_r, h, xis, chis = values_for_analytical(par, ode_with_circuit, init_conds,
                                             circuit_genes, circuit_miscs,
                                             circuit_name2pos,
                                             circuit_F_calc,
                                             circuit_eff_m_het_div_k_het)
# record the cellular variables
cellvars = {'e': e, 'F_r': F_r, # translation elongation rate and ribosome trnscription regulation
            'h': h, # intacellular chlorampenicol concentration
            # burden values
            'xi_a': xis['a'], 'xi_r': xis['r'], 'xi_other_genes': xis['other'], 'xi_cat': xis['cat'],
            'xi_switch_max': xis['switch (max)'], 'xi_int_max': xis['int (max)'], 'xi_prot': xis['prot'],
            # protein degradation correction factors for the switch protein and the integrase
            'chi_switch': chis['switch'], 'chi_int': chis['int']}

In [138]:
# DEFINE THE EXETENT OF PARTIAL MUTATION AND INDUCTION VALUES BEFORE AND AFTER TUNING
b_exp_reduction = 0.75
I_before = 0.87
I_after = 0.93

In [139]:
# DEFINE THE RANGES OF SWITCH/INTEGRASE/PROTEASE CONCS. AND PROPORTION OF INDUCER-BOUND SWITCH PROTEINS

# switch gene concentration range
c_range=np.linspace(7.5,15,80)

# bound (active) fraction range
ac_frac_range=np.linspace(0.8,1,80)

In [140]:
# CONCENTRATION VS ACTIVE FRACTION: FIND SWITCHING THRESHOLDS AND MINIMUM FOLD CHANGES

# get a mesh grid, then flatten its x and y coordinates into a single linear array
c_mesh, ac_frac_mesh = np.meshgrid(c_range, ac_frac_range)
c_mesh_ravel = c_mesh.ravel()
ac_frac_mesh_ravel = ac_frac_mesh.ravel()

# specify vmapping axes
par_vmapping_axes = {}
for key in par.keys():
    if (key == 'c_switch' or key == 'p_switch_ac_frac'):#or key == 'c_int'
        par_vmapping_axes[key] = 0
    else:
        par_vmapping_axes[key] = None
cellvars_vmapping_axes = {}
for key in cellvars.keys():
    if(key == 'xi_switch_max' or key == 'xi_int_max'):# or key == 'xi_prot'):
        cellvars_vmapping_axes[key] = 0
    else:
        cellvars_vmapping_axes[key] = None
    
# make a vmappable parameter dictionary
par_for_existence = par.copy()
par_for_existence['c_switch'] = jnp.array(c_mesh_ravel)
# par_for_existence['c_int'] = jnp.array(c_mesh_ravel)
# par_for_existence['c_prot'] = jnp.array(c_mesh_ravel)
par_for_existence['p_switch_ac_frac'] = jnp.array(ac_frac_mesh_ravel)
# make a vmappable cellular variable dictionary
cellvars_for_existence = cellvars.copy()
cellvars_for_existence['xi_switch_max'] = (cellvars['xi_switch_max']/par['c_switch']) * jnp.array(c_mesh_ravel)
cellvars_for_existence['xi_int_max'] = (cellvars['xi_int_max']/par['c_switch']) * jnp.array(c_mesh_ravel)
# cellvars_for_existence['xi_prot'] = (cellvars['xi_prot']/par['c_prot']) * jnp.array(c_mesh_ravel)

# make the checking function vmappable
vmapped_check_if_threshold_exists = jax.jit(jax.vmap(check_if_threshold_exists,
                                                     in_axes=(par_vmapping_axes, cellvars_vmapping_axes)))

# find for which parameter combinations the switching threshold exists
threshold_exists = np.array(vmapped_check_if_threshold_exists(par_for_existence, cellvars_for_existence))

# from now on, only consider parameter combinations where the threshold bifurcation point exists
indices_where_threshold_exists = []
for i in range(0, len(threshold_exists)):
    if (threshold_exists[i]):
        indices_where_threshold_exists.append(i)
c_mesh_ravel_exists = c_mesh_ravel[indices_where_threshold_exists]
ac_frac_mesh_ravel_exists = ac_frac_mesh_ravel[indices_where_threshold_exists]

# make a vmappable parameter dictionary
par_for_threshold_mfchanges = par.copy()
par_for_threshold_mfchanges['c_switch'] = jnp.array(c_mesh_ravel_exists)
# par_for_threshold_mfchanges['c_int'] = jnp.array(c_mesh_ravel_exists)
# par_for_threshold_mfchanges['c_prot'] = jnp.array(c_mesh_ravel_exists)
par_for_threshold_mfchanges['p_switch_ac_frac'] = jnp.array(ac_frac_mesh_ravel_exists)
# make a vmappable cellular variable dictionary
cellvars_for_threshold_mfchanges = cellvars.copy()
cellvars_for_threshold_mfchanges['xi_switch_max'] = (cellvars['xi_switch_max']/par['c_switch']) * jnp.array(c_mesh_ravel_exists)
cellvars_for_threshold_mfchanges['xi_int_max'] = (cellvars['xi_int_max']/par['c_switch']) * jnp.array(c_mesh_ravel_exists)
# cellvars_for_threshold_mfchanges['xi_prot'] = (cellvars['xi_prot']/par['c_prot']) * jnp.array(c_mesh_ravel_exists)

# make the threshold and minimum fold change retrieval function vmappable
vmapped_threshold_mfchanges = jax.jit(jax.vmap(threshold_mfchanges,
                                               in_axes=(par_vmapping_axes, cellvars_vmapping_axes)))

# find switching thresholds and minimum fold changes for the parameter combinations where the switching threshold exists
thresholds_mfchanges = np.array(vmapped_threshold_mfchanges(par_for_threshold_mfchanges, cellvars_for_threshold_mfchanges))
xi_thresholds = thresholds_mfchanges[:, 1]
mfchange_intact = thresholds_mfchanges[:, 4]

In [141]:
# CONCENTRATION VS ACTIVE FRACTION: FIND BURDEN CONTOURS

# fill the points where no threshold exists with INFS
xi_thresholds_for_contour_ravel = np.zeros(threshold_exists.shape)  # initialise
last_index_in_exist_list = 0
for i in range(0, len(xi_thresholds_for_contour_ravel)):
    if (i == indices_where_threshold_exists[last_index_in_exist_list]):
        xi_thresholds_for_contour_ravel[i] = xi_thresholds[last_index_in_exist_list]
        if (last_index_in_exist_list < len(indices_where_threshold_exists) - 1):
            last_index_in_exist_list += 1
    else:
        xi_thresholds_for_contour_ravel[i] = np.inf
xi_thresholds_for_contour = xi_thresholds_for_contour_ravel.reshape(len(ac_frac_range),
                                                                    len(c_range)).T

# create a contour generator
threshold_cgen = cgen(x=ac_frac_range, y=c_range,
                      z=xi_thresholds_for_contour)

# contours to be found: 1) all synth. genes functional; 2) just the CAT and protease genes functional
xi_native_prot_cat = cellvars['xi_a'] + cellvars['xi_r'] + cellvars['xi_cat'] + cellvars['xi_prot']
xi_with_all_genes = xi_native_prot_cat + cellvars['xi_other_genes']
xi_partial_mut = xi_native_prot_cat + cellvars['xi_other_genes']*(1-b_exp_reduction)
xi_contours = {'values': [xi_with_all_genes, xi_native_prot_cat, xi_partial_mut],
               'legends': ['Burdensome\ngene present', 'No burdensome\ngene', 'Partially disabling\nmutation'],
               'dashes': ['dashed','solid','dotted']}

# find burden contour lines
xi_contours['contour lines'] = []
for i in range(0, len(xi_contours['values'])):
    xi_contours['contour lines'].append(threshold_cgen.lines(xi_contours['values'][i]))

In [142]:
# CONCENTRATIONS VS ACTIVE FRACTIONS: PLOT

rect_widths_along_x_axis = np.zeros(len(ac_frac_range))
rect_widths_along_x_axis[0] = ac_frac_range[1] - ac_frac_range[0]
for i in range(1, len(ac_frac_range)):
    rect_widths_along_x_axis[i] = ((ac_frac_range[i] - ac_frac_range[i - 1]) -
                                   rect_widths_along_x_axis[i - 1] / 2) * 2
rect_heights_along_y_axis = np.zeros(len(c_range))
rect_heights_along_y_axis[0] = c_range[1] - c_range[0]
for i in range(1, len(c_range)):
    rect_heights_along_y_axis[i] = ((c_range[i] - c_range[i - 1]) - rect_heights_along_y_axis[i - 1] / 2) * 2
rect_widths_ravel_exists = np.zeros(ac_frac_mesh_ravel_exists.shape)
rect_heights_ravel_exists = np.zeros(c_mesh_ravel_exists.shape)
for i in range(0, len(ac_frac_mesh_ravel_exists)):
    baseline_where = np.argwhere(
        ac_frac_range == ac_frac_mesh_ravel_exists[i])  # locate the baseline value in the baseline range
    rect_widths_ravel_exists[i] = rect_widths_along_x_axis[baseline_where[0][0]]*1.25
    eta_where = np.argwhere(c_range == c_mesh_ravel_exists[i])  # locate the eta value in the eta range
    rect_heights_ravel_exists[i] = rect_heights_along_y_axis[eta_where[0][0]]*1.25

# make a dataframe for the heatmap of minimum fold changes
heatmap_df = pd.DataFrame({'c_s=c_int': c_mesh_ravel_exists, 'I': ac_frac_mesh_ravel_exists,
                           'mfchange_intact': mfchange_intact,
                           'rect_width': rect_widths_ravel_exists, 'rect_height': rect_heights_ravel_exists})


mfchange_intact_figure = bkplot.figure(
    frame_width=240,
    frame_height=180,
    x_axis_label="I (share of switch proteins bound by inducer)",
    y_axis_label="cs=ci (switch & integrase gene conc.), nM",
    x_range=(min(ac_frac_range), max(ac_frac_range)),
    y_range=(min(c_range), max(c_range)),
    #title="Integrase activity GF-change",
    tools='pan,box_zoom,reset,save'
)
# svg backend
mfchange_intact_figure.output_backend= "svg"
# plot the heatmap
rects = mfchange_intact_figure.rect(x="I", y="c_s=c_int", source=heatmap_df,
                                 width='rect_width', height='rect_height',
                                    fill_color=bktransform.linear_cmap('mfchange_intact',
                                                                       bkpalettes.Plasma256,
                                                                       low=2.5e4,
                                                                       high=7.5e4),
                                 line_width=0,line_alpha=0)
# add colour bar
mfchange_intact_figure.add_layout(rects.construct_color_bar(
    major_label_text_font_size="8pt",
    ticker=bkmodels.FixedTicker(ticks=[2.5e4, 5e4, 7.5e4]),
    formatter=bkmodels.PrintfTickFormatter(format="%e"),
    label_standoff=6,
    border_line_color=None,
    padding=5
), 'right')

# mark the line of same DNA concentrations as the default parameter
mfchange_intact_figure.line(x=[min(ac_frac_range), max(ac_frac_range)], y=[par['c_switch'], par['c_switch']],
            line_width=2, line_color='white', line_dash='solid')

# plot the burden contours
for i in range(0,len(xi_contours['values'])):
    for j in range(0,len(xi_contours['contour lines'][i])):
        mfchange_intact_figure.line(xi_contours['contour lines'][i][j][:, 0], xi_contours['contour lines'][i][j][:, 1],
                    line_dash=xi_contours['dashes'][i],
                    legend_label=xi_contours['legends'][i],
                    line_width=2, line_color='black')

# mark the point where the point after tuning lays
mfchange_intact_figure.scatter(marker='x',x=[I_after], y=[par['c_switch']], size=10, color='black',line_width=4)

# mark where the point before tuning lays
mfchange_intact_figure.scatter(marker='x',x=[I_before], y=[par['c_switch']], size=10, color=bkRGB(72,209,204),line_width=4)

# add and configure the legend
mfchange_intact_figure.legend.location = "top_right"
mfchange_intact_figure.legend.label_text_font_size = "8pt"
mfchange_intact_figure.legend.spacing = 5
mfchange_intact_figure.legend.padding = 2
mfchange_intact_figure.legend.margin = 2
mfchange_intact_figure.legend.glyph_width = 20
mfchange_intact_figure.legend.background_fill_alpha = 1

# font size
mfchange_intact_figure.xaxis.axis_label_text_font_size = "8pt"
mfchange_intact_figure.xaxis.major_label_text_font_size = "8pt"
mfchange_intact_figure.yaxis.axis_label_text_font_size = "8pt"
mfchange_intact_figure.yaxis.major_label_text_font_size = "8pt"

# show plot
bkplot.show(mfchange_intact_figure)
bkplot.output_file('partialmut_e.html')
bkplot.save(mfchange_intact_figure)

'/mnt/c/Users/ersat/CODE/punisher/partialmut/partialmut_e.html'