In [1]:
# import library

import matplotlib
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
from jax import random, vmap
import pickle

import numpyro
numpyro.enable_x64()

from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import MCMC, NUTS
from numpyro.infer import initialization
import jax.numpy

from scipy import stats
from tqdm import tqdm

from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
    print('Not connected to a GPU')
else:
    print(gpu_info)
!XLA_PYTHON_CLIENT_PREALLOCATE=false



gpu
Thu Nov 10 22:33:23 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.32.00    Driver Version: 455.32.00    CUDA Version: 11.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 3090    On   | 00000000:01:00.0 Off |                  N/A |
|  0%   39C    P2    44W / 420W |    288MiB / 24265MiB |      1%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Pr

In [2]:
# get raw NGS counts
# The CSV path is relative to the notebook's working directory.
# Per the displayed output, the table holds one row per sequence with
# name / aa_seq / dna_seq followed by the v2_T01..v2_T24 and v2_C01..v2_C24 count columns.
df2 = pd.read_csv('../Raw_NGS_count_tables/NGS_count_lib2.csv')
df2

Unnamed: 0,name,aa_seq,dna_seq,v2_T01,v2_T02,v2_T03,v2_T04,v2_T05,v2_T06,v2_T07,...,v2_C15,v2_C16,v2_C17,v2_C18,v2_C19,v2_C20,v2_C21,v2_C22,v2_C23,v2_C24
0,1A32.pdb,SAGGSPEVQIAILTEQINNLNEHLRVHKKDHHSRRGLLKMVGKRRR...,TCTGCGGGTGGCTCTCCAGAAGTTCAGATTGCGATCCTGACCGAAC...,554.0,581.0,423.0,479.0,380.0,383.0,478.0,...,325.0,370.0,244.0,156.0,62.0,9.0,7.0,8.0,12.0,21.0
1,1A32.pdb_A45D,SAGGSPEVQIAILTEQINNLNEHLRVHKKDHHSRRGLLKMVGKRRR...,TCTGCGGGTGGTTCTCCGGAAGTTCAGATCGCTATTCTGACCGAAC...,458.0,534.0,379.0,370.0,331.0,255.0,368.0,...,213.0,303.0,195.0,143.0,58.0,5.0,1.0,9.0,13.0,14.0
2,1A32.pdb_A45E,SAGGSPEVQIAILTEQINNLNEHLRVHKKDHHSRRGLLKMVGKRRR...,TCTGCGGGTGGTTCTCCGGAAGTTCAGATCGCTATTCTGACCGAAC...,373.0,411.0,292.0,311.0,198.0,226.0,312.0,...,218.0,252.0,182.0,123.0,96.0,9.0,2.0,6.0,10.0,23.0
3,1A32.pdb_A45F,SAGGSPEVQIAILTEQINNLNEHLRVHKKDHHSRRGLLKMVGKRRR...,TCCGCTGGCGGTTCTCCGGAAGTTCAGATCGCTATTCTGACCGAAC...,132.0,156.0,92.0,119.0,59.0,52.0,85.0,...,81.0,78.0,36.0,29.0,3.0,2.0,2.0,0.0,1.0,2.0
4,1A32.pdb_A45G,SAGGSPEVQIAILTEQINNLNEHLRVHKKDHHSRRGLLKMVGKRRR...,TCTGCTGGTGGCTCTCCGGAAGTTCAGATCGCTATTCTGACCGAAC...,153.0,226.0,144.0,152.0,115.0,83.0,132.0,...,107.0,113.0,62.0,51.0,10.0,3.0,2.0,3.0,7.0,6.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
661330,Lib2_dummy_84966,-,-,53.0,81.0,59.0,41.0,51.0,45.0,53.0,...,56.0,60.0,33.0,44.0,47.0,50.0,10.0,10.0,13.0,9.0
661331,Lib2_dummy_84967,-,-,151.0,199.0,152.0,153.0,103.0,113.0,182.0,...,136.0,160.0,91.0,141.0,151.0,222.0,128.0,189.0,76.0,26.0
661332,Lib2_dummy_84968,-,-,109.0,160.0,94.0,127.0,67.0,95.0,155.0,...,86.0,138.0,80.0,89.0,99.0,141.0,74.0,98.0,55.0,80.0
661333,Lib2_dummy_84969,-,-,13.0,17.0,19.0,8.0,7.0,18.0,8.0,...,15.0,17.0,10.0,8.0,13.0,6.0,2.0,2.0,2.0,0.0


In [3]:
###############################
#
#  Set up a model to get K50 from raw NGS counts
#
###############################

def count2K50_model(counts,protease_t,protease_c):
    """NumPyro model that infers per-sequence K50 values from raw NGS counts.

    Every sequence starts at a replicate-specific initial fraction A0 and is
    depleted by protease according to

        survival = exp(-kmax*t * [protease] / (K50 + [protease]))

    with kmax*t fixed at 10**0.65. The row-normalised product A0 * survival
    gives the multinomial probabilities of the reads observed in each condition.

    Args:
        counts: raw NGS counts laid out as [# of conditions, # of sequences].
            Rows are ordered trypsin replicate 1, trypsin replicate 2,
            chymotrypsin replicate 1, chymotrypsin replicate 2.
        protease_t: trypsin concentration for every trypsin condition
            (replicate 1 followed by replicate 2), as obtained in STEP1.
        protease_c: chymotrypsin concentration for every chymotrypsin
            condition, same layout as ``protease_t``.

    Sample sites:
        log10_A0_t1, log10_A0_t2, log10_A0_c1, log10_A0_c2: log10 initial
            fraction of each sequence per protease/replicate, prior
            Normal(log10(1/n), 1).
        log10_K50_t, log10_K50_c: log10 K50 of each sequence, prior Normal(0, 4).
        counts: observed counts, Multinomial over sequences for each condition.

    Raises:
        ValueError: if a protease series cannot be split into two equal
            replicates, or if the number of rows in ``counts`` does not match
            the number of protease conditions.
    """
    counts = np.asarray(counts)
    protease_t = jax.numpy.ravel(jax.numpy.asarray(protease_t))
    protease_c = jax.numpy.ravel(jax.numpy.asarray(protease_c))

    # Validate the layout up front: a silent mismatch here would pair counts
    # with the wrong concentrations instead of failing.
    n_cond_t = protease_t.shape[0]
    n_cond_c = protease_c.shape[0]
    if n_cond_t % 2 or n_cond_c % 2:
        raise ValueError(
            "protease_t and protease_c must each hold two replicates of equal length; "
            "got lengths %d and %d" % (n_cond_t, n_cond_c))
    if counts.ndim != 2 or counts.shape[0] != n_cond_t + n_cond_c:
        raise ValueError(
            "counts must have shape [# of conditions, # of sequences] with %d conditions; "
            "got shape %s" % (n_cond_t + n_cond_c, counts.shape))
    half_t = n_cond_t // 2
    half_c = n_cond_c // 2

    # n: number of sequences
    n = counts.shape[1]
    # total_count: total number of NGS reads for each condition [# of conditions]
    total_count = np.array([int(total) for total in np.sum(counts, axis=1)])

    kmax_times_t = 10**0.65

    # log10_A0_xy: initial fraction for each sequence in log10 [# of sequences (n)]
    # x: trypsin (t) or chymotrypsin (c)
    # y: replicate (1 or 2)
    # The prior is centred on a uniform library (every sequence at 1/n).
    log10_uniform_fraction = np.full(n, np.log10(1/n))
    log10_A0_t1 = numpyro.sample("log10_A0_t1", dist.Normal(log10_uniform_fraction, 1))
    log10_A0_t2 = numpyro.sample("log10_A0_t2", dist.Normal(log10_uniform_fraction, 1))
    log10_A0_c1 = numpyro.sample("log10_A0_c1", dist.Normal(log10_uniform_fraction, 1))
    log10_A0_c2 = numpyro.sample("log10_A0_c2", dist.Normal(log10_uniform_fraction, 1))

    # log10_K50_t/c: log10 K50 values for each sequence [# of sequences (n)], wide normal prior
    log10_K50_t = numpyro.sample("log10_K50_t", dist.Normal(np.zeros(n), 4))
    log10_K50_c = numpyro.sample("log10_K50_c", dist.Normal(np.zeros(n), 4))

    def survival(log10_K50, protease):
        # Surviving ratio relative to the no-protease state,
        # laid out as [# of conditions, # of sequences] via broadcasting.
        concentration = protease[:, None]
        return jax.numpy.exp(-kmax_times_t * concentration / (10.0**log10_K50[None, :] + concentration))

    survival_t = survival(log10_K50_t, protease_t)
    survival_c = survival(log10_K50_c, protease_c)

    # nonnorm_fraction = initial ratio (A0) * survival [# of conditions, # of sequences];
    # each replicate has its own A0 vector.
    nonnorm_fraction = jax.numpy.concatenate([survival_t[:half_t] * 10**log10_A0_t1,
                                              survival_t[half_t:] * 10**log10_A0_t2,
                                              survival_c[:half_c] * 10**log10_A0_c1,
                                              survival_c[half_c:] * 10**log10_A0_c2])

    # fraction: nonnorm_fraction normalised so every condition (row) sums to 1
    fraction = nonnorm_fraction / jax.numpy.sum(nonnorm_fraction, axis=1, keepdims=True)

    # Fit the parameters by assuming the observed NGS counts of each condition
    # follow a multinomial distribution over the sequences.
    numpyro.sample("counts", dist.Multinomial(total_count=total_count, probs=fraction), obs=jax.numpy.array(counts))

In [4]:
# Build the working table: sequence name, DNA sequence and the count columns
# (positions 3..50 of df2), then keep only the sequences that have reads in
# every no-protease sample.

df = pd.concat([df2[['name', 'dna_seq']], df2.iloc[:, 3:51]], axis=1)
df = df.loc[df[['v2_T01', 'v2_T13', 'v2_C01', 'v2_C13']].gt(0).all(axis=1)].reset_index(drop=True)
df

Unnamed: 0,name,dna_seq,v2_T01,v2_T02,v2_T03,v2_T04,v2_T05,v2_T06,v2_T07,v2_T08,...,v2_C15,v2_C16,v2_C17,v2_C18,v2_C19,v2_C20,v2_C21,v2_C22,v2_C23,v2_C24
0,1A32.pdb,TCTGCGGGTGGCTCTCCAGAAGTTCAGATTGCGATCCTGACCGAAC...,554.0,581.0,423.0,479.0,380.0,383.0,478.0,431.0,...,325.0,370.0,244.0,156.0,62.0,9.0,7.0,8.0,12.0,21.0
1,1A32.pdb_A45D,TCTGCGGGTGGTTCTCCGGAAGTTCAGATCGCTATTCTGACCGAAC...,458.0,534.0,379.0,370.0,331.0,255.0,368.0,323.0,...,213.0,303.0,195.0,143.0,58.0,5.0,1.0,9.0,13.0,14.0
2,1A32.pdb_A45E,TCTGCGGGTGGTTCTCCGGAAGTTCAGATCGCTATTCTGACCGAAC...,373.0,411.0,292.0,311.0,198.0,226.0,312.0,340.0,...,218.0,252.0,182.0,123.0,96.0,9.0,2.0,6.0,10.0,23.0
3,1A32.pdb_A45F,TCCGCTGGCGGTTCTCCGGAAGTTCAGATCGCTATTCTGACCGAAC...,132.0,156.0,92.0,119.0,59.0,52.0,85.0,47.0,...,81.0,78.0,36.0,29.0,3.0,2.0,2.0,0.0,1.0,2.0
4,1A32.pdb_A45G,TCTGCTGGTGGCTCTCCGGAAGTTCAGATCGCTATTCTGACCGAAC...,153.0,226.0,144.0,152.0,115.0,83.0,132.0,99.0,...,107.0,113.0,62.0,51.0,10.0,3.0,2.0,3.0,7.0,6.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
634984,Lib2_dummy_84966,-,53.0,81.0,59.0,41.0,51.0,45.0,53.0,51.0,...,56.0,60.0,33.0,44.0,47.0,50.0,10.0,10.0,13.0,9.0
634985,Lib2_dummy_84967,-,151.0,199.0,152.0,153.0,103.0,113.0,182.0,226.0,...,136.0,160.0,91.0,141.0,151.0,222.0,128.0,189.0,76.0,26.0
634986,Lib2_dummy_84968,-,109.0,160.0,94.0,127.0,67.0,95.0,155.0,150.0,...,86.0,138.0,80.0,89.0,99.0,141.0,74.0,98.0,55.0,80.0
634987,Lib2_dummy_84969,-,13.0,17.0,19.0,8.0,7.0,18.0,8.0,6.0,...,15.0,17.0,10.0,8.0,13.0,6.0,2.0,2.0,2.0,0.0


In [5]:
# get protease concentrations calibrated in STEP1
# NOTE(review): pickle.load executes arbitrary code from the file it reads;
# only load STEP1 outputs that were produced locally / come from a trusted source.

with open('STEP1_out_protease_concentration_trypsin', 'rb') as p:
    protease_tryp = pickle.load(p)

with open('STEP1_out_protease_concentration_chymotrypsin', 'rb') as p:
    protease_chymo = pickle.load(p)

# Both pickles are indexed by key; the 'protease_v2' entry is taken from each.
# Presumably one concentration per condition of this library (24 per protease,
# as the model below expects) — TODO confirm against STEP1.
protease_con_v2t = protease_tryp['protease_v2']
protease_con_v2c = protease_chymo['protease_v2']

In [6]:
# run the model

# Fixed seed; the key is split and the second half is handed to the sampler.
rng_key = random.PRNGKey(1)
rng_key, rng_key_ = random.split(rng_key)

# NUTS kernel over count2K50_model, using the init_to_feasible initialisation strategy.
kernel = NUTS(count2K50_model, init_strategy=initialization.init_to_feasible())
# Single chain: 100 warm-up iterations followed by 50 retained samples.
mcmc = MCMC(kernel, num_warmup=100, num_samples=50, num_chains=1)

# Columns 2..49 of df (the count columns) are transposed before being passed,
# so the model receives counts as [# of conditions, # of sequences].
mcmc.run(rng_key_, counts=np.array(df.iloc[:,2:50].T),protease_t=protease_con_v2t,protease_c=protease_con_v2c)
# samples1 is indexed by sample-site name in the following cells (e.g. 'log10_K50_t').
samples1=mcmc.get_samples()

sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 150/150 [1:41:37<00:00, 40.65s/it, 1023 steps of size 3.95e-04. acc. prob=0.81]


In [7]:
###############################
#
#  Calculate expected counts based on K50 values
#
###############################

def _expected_counts(log10_K50, A0_rep1, A0_rep2, protease, total_count, kmax_times_t=10**0.65):
    """Expected NGS counts for one protease, given point estimates of the model parameters.

    Args:
        log10_K50: log10 K50 for each sequence [# of sequences].
        A0_rep1, A0_rep2: initial fraction of each sequence in replicate 1 / 2 [# of sequences].
        protease: protease concentration for each condition, replicate 1
            followed by replicate 2 [# of conditions].
        total_count: total number of reads in each of those conditions [# of conditions].
        kmax_times_t: kmax * t constant of the survival model.

    Returns:
        NumPy array [# of sequences, # of conditions] of expected counts.
    """
    log10_K50 = np.asarray(log10_K50)
    protease = np.ravel(np.asarray(protease))
    half = len(protease) // 2

    # survival = exp(- kmax*t*[protease]/(K50+[protease])), laid out as
    # [# of conditions, # of sequences] through broadcasting.
    concentration = protease[:, None]
    survival = np.exp(-kmax_times_t * concentration / (10**log10_K50[None, :] + concentration))

    # nonnorm_fraction = initial ratio (A0) * survival; each replicate has its own A0.
    nonnorm_fraction = np.concatenate([survival[:half] * A0_rep1, survival[half:] * A0_rep2])

    # Normalise every condition (row) to sum to 1. A NumPy sum keeps the result a
    # NumPy array instead of silently promoting it to a JAX array.
    fraction = nonnorm_fraction / nonnorm_fraction.sum(axis=1, keepdims=True)

    return np.asarray(total_count) * fraction.T


# counts: raw NGS counts [# of conditions (48), # of sequences]
counts = np.array(df.iloc[:,2:50].T)
# n: the number of sequences
n = len(counts[0,:])

kmax_times_t = 10**0.65

# total_count: total number of reads for each condition; the trypsin conditions come first
total_count = np.array([int(x) for x in np.sum(counts, axis=1)])
n_conditions_t = np.size(protease_con_v2t)

## trypsin challenge
# A0_xy: posterior-median initial fraction for each sequence [# of sequences]
A0_t1 = 10**np.percentile(samples1['log10_A0_t1'], 50, axis=0)
A0_t2 = 10**np.percentile(samples1['log10_A0_t2'], 50, axis=0)
# log10_K50_t: posterior-median log10 K50 for each sequence [# of sequences]
log10_K50_t = np.percentile(samples1['log10_K50_t'],50,axis=0)
# count_expected_t: expected count number based on K50 values [# of sequences (n), # of trypsin conditions]
count_expected_t = _expected_counts(log10_K50_t, A0_t1, A0_t2, protease_con_v2t,
                                    total_count[:n_conditions_t], kmax_times_t)

## chymotrypsin challenge
A0_c1 = 10**np.percentile(samples1['log10_A0_c1'], 50, axis=0)
A0_c2 = 10**np.percentile(samples1['log10_A0_c2'], 50, axis=0)
log10_K50_c = np.percentile(samples1['log10_K50_c'],50,axis=0)
# count_expected_c: expected count number based on K50 values [# of sequences (n), # of chymotrypsin conditions]
count_expected_c = _expected_counts(log10_K50_c, A0_c1, A0_c2, protease_con_v2c,
                                    total_count[n_conditions_t:], kmax_times_t)

In [8]:
###############################
#
#  Make dataframe including all info
#
###############################


def _fitting_error(observed, expected):
    """Per-sequence fitting error between observed and expected counts.

    Sum of absolute errors over all conditions of one protease, divided by the
    summed observed counts of the two no-protease samples (first column of each
    replicate) and by the number of conditions per replicate.

    Args:
        observed: observed counts [# of sequences, # of conditions].
        expected: expected counts, same shape.

    Returns:
        NumPy array [# of sequences].
    """
    observed = np.asarray(observed)
    expected = np.asarray(expected)
    per_replicate = observed.shape[1] // 2
    no_protease = observed[:, 0] + observed[:, per_replicate]
    # Vectorised over sequences; avoids a Python-level loop over every row.
    return np.abs(observed - expected).sum(axis=1) / no_protease / per_replicate


summary_columns = {'name': df['name'], 'dna_seq': df['dna_seq']}

# (protease suffix, first count column of that protease in df, expected counts)
for suffix, first_col, count_expected in (('t', 2, count_expected_t), ('c', 26, count_expected_c)):
    # Median and 95% interval of the posterior, taken in a single pass over the samples.
    ci_low, median, ci_high = np.percentile(samples1['log10_K50_' + suffix], [2.5, 50, 97.5], axis=0)
    summary_columns['log10_K50_' + suffix] = median
    summary_columns['log10_K50_' + suffix + '_95CI_high'] = ci_high
    summary_columns['log10_K50_' + suffix + '_95CI_low'] = ci_low
    # fitting_error_x : absolute error between the observed counts and the expected counts
    # for a given sequence (based on all model parameters related to that protease),
    # averaged over 24 conditions and normalized by the observed counts in the
    # no-protease samples for that sequence
    summary_columns['fitting_error_' + suffix] = _fitting_error(
        np.array(df.iloc[:, first_col:first_col + 24]), count_expected)

dfsum = pd.DataFrame(summary_columns)

dfsum

Unnamed: 0,name,dna_seq,log10_K50_t,log10_K50_t_95CI_high,log10_K50_t_95CI_low,fitting_error_t,log10_K50_c,log10_K50_c_95CI_high,log10_K50_c_95CI_low,fitting_error_c
0,1A32.pdb,TCTGCGGGTGGCTCTCCAGAAGTTCAGATTGCGATCCTGACCGAAC...,0.258874,0.272459,0.237274,0.062285,-0.783398,-0.773568,-0.813558,0.047019
1,1A32.pdb_A45D,TCTGCGGGTGGTTCTCCGGAAGTTCAGATCGCTATTCTGACCGAAC...,0.231431,0.272045,0.207189,0.042640,-0.717615,-0.700997,-0.730519,0.040713
2,1A32.pdb_A45E,TCTGCGGGTGGTTCTCCGGAAGTTCAGATCGCTATTCTGACCGAAC...,0.468634,0.487053,0.434487,0.042349,-0.556989,-0.533996,-0.587870,0.052453
3,1A32.pdb_A45F,TCCGCTGGCGGTTCTCCGGAAGTTCAGATCGCTATTCTGACCGAAC...,0.012309,0.046098,-0.027893,0.074535,-1.110422,-1.079111,-1.182161,0.051027
4,1A32.pdb_A45G,TCTGCTGGTGGCTCTCCGGAAGTTCAGATCGCTATTCTGACCGAAC...,-0.008602,0.042370,-0.036318,0.054017,-0.992854,-0.954309,-1.013008,0.036031
...,...,...,...,...,...,...,...,...,...,...
634984,Lib2_dummy_84966,-,0.380952,0.447801,0.326435,0.110450,0.448683,0.493519,0.413246,0.093621
634985,Lib2_dummy_84967,-,1.826327,1.882690,1.762826,0.096576,1.466183,1.483905,1.444545,0.077107
634986,Lib2_dummy_84968,-,1.714145,1.735496,1.673833,0.112793,1.604773,1.623680,1.588466,0.132030
634987,Lib2_dummy_84969,-,-0.000222,0.186991,-0.120801,0.157124,0.126274,0.208306,0.051873,0.202526


In [9]:
# Write the summary table to the working directory; index=False omits the row index.
dfsum.to_csv('STEP2_out_K50_lib2.csv',index=False)