In [1]:
import matplotlib
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
from jax import random, vmap
import numpyro
numpyro.enable_x64()

from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import MCMC, NUTS
import jax.numpy

from scipy import stats
from tqdm import tqdm

# Report which backend JAX selected for this kernel.
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

# Capture the text output of `nvidia-smi` (IPython shell escape) and print it,
# unless it contains the word 'failed', in which case report that no GPU is connected.
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
    print('Not connected to a GPU')
else:
    print(gpu_info)
# NOTE(review): a `!` command runs in its own shell, so this assignment presumably never
# reaches the kernel process; the os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] assignment
# placed before the jax import is the one that takes effect — confirm and consider removing this line.
!XLA_PYTHON_CLIENT_PREALLOCATE=false

import pickle
from numpyro.infer import initialization

gpu




Fri Nov  4 10:18:48 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.32.00    Driver Version: 455.32.00    CUDA Version: 11.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 3090    On   | 00000000:01:00.0  On |                  N/A |
|  0%   36C    P2    22W / 420W |    647MiB / 24260MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
###############################
#
#  Set up a model to get K50 and deltaG from NGS counts
#
###############################

def multi_k_delg_model(counts,log10_K50_unfolded_t,log10_K50_unfolded_c,protease_t,protease_c):
    """NumPyro model fitting per-protease K50 values and a shared deltaG to NGS counts.

    Arguments (shapes as the code uses them):
        counts: raw NGS counts indexed [condition, sequence]. The 48 conditions are
            sliced as trypsin replicate 1 (rows 0-11), trypsin replicate 2 (12-23),
            chymotrypsin replicate 1 (24-35) and chymotrypsin replicate 2 (36-47).
        log10_K50_unfolded_t / log10_K50_unfolded_c: predicted log10 K50 of the
            unfolded state for trypsin / chymotrypsin, one value per sequence.
        protease_t / protease_c: protease concentrations for trypsin / chymotrypsin;
            each is resized to 24 columns (2 replicates x 12 concentrations).

    Sample sites, in order: log10_A0_t1, log10_A0_t2, log10_A0_c1, log10_A0_c2,
    log10_K50_t, log10_K50_c, counts (observed), deltaG, counts_deltaG (observed).
    The function returns nothing; it only registers the sample sites.
    """
    # n: number of sequences (second axis of counts)
    n = len(counts[0,:])
    # total_count: total number of NGS reads in each of the 48 conditions
    total_count = np.array([int(x) for x in np.sum(counts, axis=1)])
    # observed counts as a JAX array, shared by both likelihood terms
    observed = jax.numpy.array(counts)

    # kmax * t of the survival equation, held fixed
    kmax_times_t = 10**0.65

    def survival_one_protease(log10_K50, protease):
        # survival = exp(- kmax*t*[protease] / (K50 + [protease])), shape [n, 24]
        K50_grid = jax.numpy.resize(10.0**log10_K50, (24, n)).T
        protease_grid = jax.numpy.resize(protease, (n, 24))
        return jax.numpy.exp(-jax.numpy.outer(kmax_times_t, protease) / (K50_grid + protease_grid))

    def survival_all_conditions(log10_K50_for_t, log10_K50_for_c):
        # trypsin columns followed by chymotrypsin columns, then transposed to [48, n]
        both = jax.numpy.concatenate(
            [survival_one_protease(log10_K50_for_t, protease_t),
             survival_one_protease(log10_K50_for_c, protease_c)],
            axis=1)
        return both.T

    def normalized_fraction(survival):
        # scale each replicate's 12 conditions by its own initial fraction A0 ...
        scaled = jax.numpy.concatenate([survival[0:12,:] * 10**log10_A0_t1,
                                        survival[12:24,:] * 10**log10_A0_t2,
                                        survival[24:36,:] * 10**log10_A0_c1,
                                        survival[36:48,:] * 10**log10_A0_c2])
        # ... then normalise so that every condition (row) sums to 1
        return scaled / np.reshape(jax.numpy.sum(scaled, axis=1), (48, 1))

    ###############################
    #  Sample initial fractions (A0)
    ###############################
    # log10_A0_xy: log10 initial fraction of every sequence [n]
    # x: trypsin (t) or chymotrypsin (c); y: replicate (1 or 2)
    # Normal prior centred on a uniform library, log10(1/n), with sd 1
    A0_prior_mean = np.resize(np.log10(1/n), n)
    log10_A0_t1 = numpyro.sample("log10_A0_t1", dist.Normal(A0_prior_mean, 1))
    log10_A0_t2 = numpyro.sample("log10_A0_t2", dist.Normal(A0_prior_mean, 1))
    log10_A0_c1 = numpyro.sample("log10_A0_c1", dist.Normal(A0_prior_mean, 1))
    log10_A0_c2 = numpyro.sample("log10_A0_c2", dist.Normal(A0_prior_mean, 1))

    ###############################
    #  Sample K50_trypsin and K50_chymotrypsin
    ###############################
    # log10_K50_t/c: log10 K50 of every sequence [n], wide Normal(0, 4) prior
    log10_K50_t = numpyro.sample("log10_K50_t", dist.Normal(np.resize(0,n), 4))
    log10_K50_c = numpyro.sample("log10_K50_c", dist.Normal(np.resize(0,n), 4))

    # fraction: expected share of each sequence within each condition [48, n]
    fraction = normalized_fraction(survival_all_conditions(log10_K50_t, log10_K50_c))

    # likelihood: observed counts of each condition follow a multinomial over sequences
    numpyro.sample("counts", dist.Multinomial(total_count = total_count,probs=fraction),obs=observed)

    ###############################
    #  Sample deltaG
    ###############################
    # fixed log10 K50 of the folded state for trypsin / chymotrypsin, shape [1, 2]
    log10_K50_folded_t,log10_K50_folded_c = 2.25,1.75
    log10_K50_folded_tc = jax.numpy.array([[log10_K50_folded_t, log10_K50_folded_c]])

    # deltaG: folding stability shared by both proteases [n, 1], wide Normal(0, 6) prior
    deltaG = numpyro.sample("deltaG", dist.Normal(np.resize(0, (n, 1)), 6))

    # fraction_unfolded = 1/(1 + exp(deltaG/RT)), with 0.58 in the place of RT, shape [n, 1]
    fraction_unfolded = 1.0 / (1.0 + jax.numpy.exp(deltaG / 0.58))

    # 1/K50 = fraction_unfolded/K50,U + (1-fraction_unfolded)/K50,F, per protease, shape [n, 2]
    inv_K50_unfolded = 10.0 ** -( np.stack((log10_K50_unfolded_t, log10_K50_unfolded_c)).T )
    inv_K50_folded = 10.0 ** -log10_K50_folded_tc
    log10_theoretical_K50_tc = - jax.numpy.log10( (inv_K50_unfolded * fraction_unfolded) + inv_K50_folded * (1-fraction_unfolded) )

    # fraction_deltaG: same construction as `fraction`, using the K50 implied by deltaG [48, n]
    fraction_deltaG = normalized_fraction(
        survival_all_conditions(log10_theoretical_K50_tc[:,0], log10_theoretical_K50_tc[:,1]))

    # second likelihood term: the same observed counts under the deltaG-derived fractions
    numpyro.sample("counts_deltaG", dist.Multinomial(total_count = total_count,probs=fraction_deltaG),obs=observed)
    
    

In [3]:
###############################
#
#  Run NUTS/MCMC sampling of the model to obtain posterior samples
#
###############################

def obtain_samples(lib_name,df,num_warmup=100,num_samples=50,num_chains=1,seed=1):
    """Run NUTS on `multi_k_delg_model` for one library and return the posterior samples.

    Args:
        lib_name: library name; 'protease_' + lib_name is the key looked up in both
            STEP1 protease-concentration pickles.
        df: DataFrame whose columns 3..50 (positional) hold the 48 count columns and
            which has 'unfolded_K50t' and 'unfolded_K50c' columns.
        num_warmup: number of warm-up iterations (default 100, as before).
        num_samples: number of retained samples (default 50, as before).
        num_chains: number of MCMC chains (default 1, as before).
        seed: integer seed for the JAX PRNG key (default 1, as before).

    Returns:
        The dictionary returned by `MCMC.get_samples()`.
    """
    # get protease concentrations calibrated in STEP1
    # NOTE: pickle.load can execute arbitrary code; only load trusted STEP1 outputs.
    with open('STEP1_out_protease_concentration_trypsin', 'rb') as f:
        tryp_cons = pickle.load(f)
    with open('STEP1_out_protease_concentration_chymotrypsin', 'rb') as f:
        chymo_cons = pickle.load(f)

    protease_key = 'protease_'+lib_name
    protease_t, protease_c = tryp_cons[protease_key], chymo_cons[protease_key]

    # the key is split once and the second half is used for the run (unchanged from the original)
    rng_key = random.PRNGKey(seed)
    rng_key, rng_key_ = random.split(rng_key)

    # run the model
    kernel = NUTS(multi_k_delg_model, init_strategy=initialization.init_to_feasible())
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=num_chains)
    mcmc.run(
        rng_key_,
        # counts are passed as [condition, sequence]
        counts=np.array(df.iloc[:,3:51].T),
        log10_K50_unfolded_t=np.array(df['unfolded_K50t']),
        log10_K50_unfolded_c=np.array(df['unfolded_K50c']),
        protease_t=protease_t,
        protease_c=protease_c,
    )
    return mcmc.get_samples()


In [4]:
###############################
#
#  Summarize all data into one dataframe
#
###############################

def _expected_counts(samples1, site_K50, site_A0_rep1, site_A0_rep2, protease, totals, kmax_times_t):
    """Expected counts [n, 24] for one protease, computed from posterior medians."""
    # posterior medians: initial fractions of both replicates and K50, each [n]
    A0_rep1 = 10**np.asarray(np.percentile(samples1[site_A0_rep1], 50, axis=0))
    A0_rep2 = 10**np.asarray(np.percentile(samples1[site_A0_rep2], 50, axis=0))
    K50 = 10**np.asarray(np.percentile(samples1[site_K50], 50, axis=0))

    # survival = exp(- kmax*t*[protease]/(K50+[protease])), broadcast to [n, 24]
    protease_row = np.ravel(np.asarray(protease, dtype=float))[None, :]
    survival = np.exp(-kmax_times_t * protease_row / (K50[:, None] + protease_row))

    # nonnorm_fraction = initial ratio (A0) * survival, stacked as [24, n]
    nonnorm_fraction = np.concatenate([survival[:, 0:12].T * A0_rep1,
                                       survival[:, 12:24].T * A0_rep2])
    # normalise so that every condition (row) sums to 1
    fraction = nonnorm_fraction / nonnorm_fraction.sum(axis=1, keepdims=True)

    # expected count = total reads of the condition * fraction, returned as [n, 24]
    return totals * fraction.T


def _deltaG_from_K50(log10_K50, log10_K50_unfolded, log10_K50_folded, above_value, below_value):
    """Convert log10 K50 values to deltaG, with sentinels outside the dynamic range.

    deltaG = 0.58 * ln((1/K50,U - 1/K50,F) / (1/K50 - 1/K50,F) - 1) when
    K50,U < K50 < K50,F; `above_value` when K50 > K50,F; `below_value` otherwise
    (the same three-way branching as the original per-row loop).
    """
    log10_K50 = np.asarray(log10_K50, dtype=float)
    log10_K50_unfolded = np.asarray(log10_K50_unfolded, dtype=float)

    deltaG = np.full(log10_K50.shape, float(below_value))  # default: K50 is too low
    deltaG[log10_K50 > log10_K50_folded] = above_value     # K50 is too high

    # K50 is in the dynamic range; the formula is evaluated only on these rows,
    # where the argument of the logarithm is guaranteed to be positive
    in_range = (log10_K50_unfolded < log10_K50) & (log10_K50 < log10_K50_folded)
    inv_K50_folded = 10.0 ** -log10_K50_folded
    ratio = (10.0 ** -log10_K50_unfolded[in_range] - inv_K50_folded) / (10.0 ** -log10_K50[in_range] - inv_K50_folded)
    deltaG[in_range] = 0.58 * np.log(ratio - 1)
    return deltaG


def samples2sumdf(samples1,lib_name,df):
    """Summarise posterior samples and fit quality into one DataFrame (one row per row of df).

    Args:
        samples1: mapping with the sample sites 'log10_A0_t1', 'log10_A0_t2',
            'log10_A0_c1', 'log10_A0_c2', 'log10_K50_t', 'log10_K50_c' and 'deltaG',
            each with the samples along axis 0.
        lib_name: library name; 'protease_' + lib_name is the key looked up in both
            STEP1 protease-concentration pickles.
        df: DataFrame with 'name', 'dna_seq', 'unfolded_K50t', 'unfolded_K50c' and the
            48 count columns at positions 3..50 (trypsin in 3..26, chymotrypsin in 27..50).

    Returns:
        DataFrame with, per protease x in (t, c): log10_K50_x median / 95% interval,
        fitting_error_x, log10_K50unfolded_x and deltaG_x median / interval derived
        from K50; followed by the median / 95% interval of the sampled 'deltaG'.

    All per-row work is positional, so `df` may carry any index.
    """
    # get protease concentrations calibrated in STEP1
    # NOTE: pickle.load can execute arbitrary code; only load trusted STEP1 outputs.
    with open('STEP1_out_protease_concentration_trypsin', 'rb') as f:
        tryp_cons = pickle.load(f)
    with open('STEP1_out_protease_concentration_chymotrypsin', 'rb') as f:
        chymo_cons = pickle.load(f)
    protease_key = 'protease_'+lib_name
    protease_t, protease_c = tryp_cons[protease_key], chymo_cons[protease_key]

    # NGS counts as [sequence, condition] and total reads per condition [48]
    counts = df.iloc[:,3:51].to_numpy(dtype=float)
    totals = np.array([int(x) for x in counts.sum(axis=0)])

    kmax_times_t = 10**0.65

    dfsum = pd.DataFrame()
    dfsum['name'] = df['name']
    dfsum['dna_seq'] = df['dna_seq']

    # (tag, protease concentrations, condition columns, log10 K50 folded used for the K50 -> deltaG conversion)
    per_protease = (
        ('t', protease_t, slice(0, 24), 2.5),
        ('c', protease_c, slice(24, 48), 2),
    )
    for tag, protease, columns, k_folded in per_protease:
        site_K50 = 'log10_K50_'+tag

        ###############################
        #  Summarize data related to K50
        ###############################
        K50_median = np.asarray(np.percentile(samples1[site_K50], 50, axis=0))
        K50_high = np.asarray(np.percentile(samples1[site_K50], 97.5, axis=0))
        K50_low = np.asarray(np.percentile(samples1[site_K50], 2.5, axis=0))
        dfsum[site_K50] = K50_median
        dfsum[site_K50+'_95CI_high'] = K50_high
        dfsum[site_K50+'_95CI_low'] = K50_low
        dfsum[site_K50+'_95CI'] = K50_high - K50_low

        # fitting_error: absolute error between observed and expected counts, averaged over
        # the 24 conditions and normalised by the mean of the two no-protease counts
        # (columns 0 and 12 of this protease); zero no-protease counts give inf/nan
        observed = counts[:, columns]
        expected = _expected_counts(samples1, site_K50, 'log10_A0_'+tag+'1', 'log10_A0_'+tag+'2',
                                    protease, totals[columns], kmax_times_t)
        with np.errstate(divide='ignore', invalid='ignore'):
            dfsum['fitting_error_'+tag] = np.abs(observed - expected).sum(axis=1) / (observed[:, 0] + observed[:, 12]) / 12

        unfolded = np.array(df['unfolded_K50'+tag])
        dfsum['log10_K50unfolded_'+tag] = unfolded

        ###############################
        #  deltaG derived from K50 (median and both ends of the 95% interval)
        ###############################
        # sentinels are unchanged: 15/-15 for the median, 25/-5 for the upper end, 5/-25 for the lower end
        deltaG_median = _deltaG_from_K50(K50_median, unfolded, k_folded, 15, -15)
        deltaG_high = _deltaG_from_K50(K50_high, unfolded, k_folded, 25, -5)
        deltaG_low = _deltaG_from_K50(K50_low, unfolded, k_folded, 5, -25)
        dfsum['deltaG_'+tag] = deltaG_median
        dfsum['deltaG_'+tag+'_95CI_high'] = deltaG_high
        dfsum['deltaG_'+tag+'_95CI_low'] = deltaG_low
        dfsum['deltaG_'+tag+'_95CI'] = deltaG_high - deltaG_low

    ###############################
    #  Summarize data related to deltaG (combined)
    ###############################
    # the 'deltaG' site may carry a trailing length-1 axis; flatten to one value per sequence
    deltaG_high = np.ravel(np.percentile(samples1['deltaG'], 97.5, axis=0))
    deltaG_low = np.ravel(np.percentile(samples1['deltaG'], 2.5, axis=0))
    dfsum['deltaG'] = np.ravel(np.percentile(samples1['deltaG'], 50, axis=0))
    dfsum['deltaG_95CI_high'] = deltaG_high
    dfsum['deltaG_95CI_low'] = deltaG_low
    dfsum['deltaG_95CI'] = deltaG_high - deltaG_low

    return dfsum