In [1]:
# import libraries

import matplotlib
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
from jax import random, vmap
import numpyro
numpyro.enable_x64()
import pickle
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import MCMC, NUTS
import jax.numpy

from scipy import stats
from Bio.Seq import Seq
from tqdm import tqdm

import warnings
warnings.filterwarnings('ignore')

from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
    print('Not connected to a GPU')
else:
    print(gpu_info)
!XLA_PYTHON_CLIENT_PREALLOCATE=false

from sklearn.preprocessing import OneHotEncoder




gpu
Tue Oct 11 09:07:59 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.32.00    Driver Version: 455.32.00    CUDA Version: 11.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 3090    On   | 00000000:01:00.0  On |                  N/A |
|  0%   46C    P2    46W / 420W |    325MiB / 24260MiB |     77%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Pr

In [None]:
###############################
#
#  Set up a function to calculate K50,Unfolded 
#
###############################

In [None]:
def caluculation_unfoldedK(df, param_path='./STEP3_unfolded_model_prams', chain_num=0):
    """Compute unfolded-state log10 K50 values (K50,U) for trypsin and chymotrypsin.

    Each DNA sequence is translated, flanked with 'GGG' on both sides, padded
    with 'X' to a fixed length of 86, one-hot encoded over the alphabet
    'ACDEFGHIKLMNPQRSTVWXY', and passed through the unfolded-state model using
    the medians (over axis 0) of the parameters stored in the STEP3 pickle file.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns 'name' and 'dna_seq'. The frame is modified
        in place: the columns 'aa_seq', 'unfolded_K50t' and 'unfolded_K50c'
        are added (or overwritten).
    param_path : str, optional
        Path of the pickle file written in STEP3. The keys 'logstic_center_TC',
        'min_K50unfolded_TC', 'max_K50unfolded_TC' and 'protease_filter' are read.
    chain_num : int, optional
        Index into the first axis of each stored parameter array
        (presumably the MCMC chain -- TODO confirm against STEP3).

    Returns
    -------
    pandas.DataFrame
        A new frame with the columns 'name', 'dna_seq', 'unfolded_K50t'
        (trypsin) and 'unfolded_K50c' (chymotrypsin).

    Raises
    ------
    ValueError
        If a translated sequence including the six flanking Gly residues is
        longer than the fixed model input length (86).
    """
    # Fixed input length of the model: amino acids + 6x Gly, padded with 'X'.
    target_len = 86
    # bin_num: the number of logistic functions for the final activation.
    bin_num = 10

    ###############################
    # Get parameters obtained in STEP3
    ###############################
    # NOTE(review): pickle.load can execute arbitrary code. Only load a
    # parameter file that was produced by your own STEP3 run.
    with open(param_path, 'rb') as p:
        param = pickle.load(p)
    median_logstic_center_TC = np.median(param['logstic_center_TC'][chain_num], axis=0)
    median_min_k_unfolded_TC = np.median(param['min_K50unfolded_TC'][chain_num], axis=0)
    median_max_k_unfolded_TC = np.median(param['max_K50unfolded_TC'][chain_num], axis=0)
    median_protease_filter = np.median(param['protease_filter'][chain_num], axis=0)

    ###############################
    # Get one-hot-encoded amino acid sequences
    ###############################
    # Convert DNA sequences into amino acid sequences flanked by 'GGG'.
    df['aa_seq'] = ['GGG' + str(Seq(x).translate()) + 'GGG' for x in df['dna_seq']]

    # Pad every sequence symmetrically with 'X' up to target_len. When the
    # number of padding characters is odd, the extra 'X' goes on the right.
    aa_list = []
    for aa_seq in df['aa_seq']:
        add_aa_num = target_len - len(aa_seq)
        if add_aa_num < 0:
            raise ValueError(
                'Sequence of length %d exceeds the fixed model input length of %d: %s'
                % (len(aa_seq), target_len, aa_seq)
            )
        left = add_aa_num // 2
        right = add_aa_num - left
        aa_list.append('X' * left + aa_seq + 'X' * right)

    # Define the one-hot encoder over the 20 amino acids plus the padding symbol 'X'.
    aas = 'ACDEFGHIKLMNPQRSTVWXY'
    enc = OneHotEncoder()
    enc.fit(np.array(list(aas)).reshape(-1, 1))

    # Make one-hot vectors from the padded amino acid sequences.
    seqs = np.array([list(s) for s in aa_list])
    seq_input = []
    for s in tqdm(seqs):
        seq_input.append(enc.transform(s.reshape([-1, 1])).toarray().T)
    # seq_input: one-hot vectors representing amino acid sequences
    # [# of scrambles, # of 20 amino acids + 'X' (21), length of amino acids + 6x Gly (86)]
    seq_input = np.array(seq_input, dtype='float32')

    ###############################
    # Calculate K50,U from one-hot-encoded amino acid sequences
    ###############################
    # filtered_seqs: filtered one-hot vectors
    # [# of scrambles, trypsin/chymotrypsin (2), 86 - length of filter (9) + 1 (78)]
    filtered_seqs = jax.lax.conv_general_dilated(seq_input, median_protease_filter, [1], "VALID")

    # saturated_sites: local protease sensitivity with saturation
    # saturated_sites = logistic(PSSM(aa_site, site)); same shape as filtered_seqs.
    saturated_sites = jax.scipy.special.expit(filtered_seqs)

    # sum_saturated_sites: sum of local protease sensitivity over all sites
    # [# of scrambles, trypsin/chymotrypsin (2)]
    sum_saturated_sites = jax.numpy.sum(saturated_sites, axis=2)

    # Repeat the summed sensitivity once per logistic bin:
    # [# of scrambles, trypsin/chymotrypsin (2), # of bins (10)].
    # The summed array must be expanded here, not the per-site saturated_sites
    # array, otherwise the sum computed above is discarded.
    sum_saturated_sites = jax.numpy.broadcast_to(
        sum_saturated_sites[:, :, None],
        sum_saturated_sites.shape + (bin_num,),
    )

    # k_unfolded_TC: log10 K50 unfolded values for trypsin/chymotrypsin
    # [# of scrambles, trypsin/chymotrypsin (2)]
    # K50unfolded = max_K50,U + (min_K50,U - max_K50,U) * mean over bins of
    #               logistic(sum_saturated_sites - logistic_center)
    bin_activation = jax.scipy.special.expit(sum_saturated_sites - median_logstic_center_TC)
    k_unfolded_TC = (
        median_max_k_unfolded_TC
        + (median_min_k_unfolded_TC - median_max_k_unfolded_TC)
        * jax.numpy.sum(bin_activation, axis=2) / bin_num
    )

    ###############################
    #  Create dataframe with K50,U
    ###############################
    # Convert once to NumPy instead of indexing the JAX array element by element.
    k_unfolded_TC = np.asarray(k_unfolded_TC)
    df['unfolded_K50t'] = k_unfolded_TC[:, 0]
    df['unfolded_K50c'] = k_unfolded_TC[:, 1]
    df_unfolded = pd.concat([df['name'], df['dna_seq'], df['unfolded_K50t'], df['unfolded_K50c']], axis=1)
    return df_unfolded