In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
from Bio.PDB import DSSP, PDBParser
import pandas as pd
from scipy import stats

from Bio import SeqIO, Phylo
import numpy as np

from matplotlib import pyplot as plt

import random
import weighting_methods

from collections import defaultdict

In [3]:
import sys
sys.path.append('../../Tree_rooting/Code/')
import rooting_methods

In [64]:
def total_ent_fxn(weights, seqs, base=2, filter_gaps=False):
    """
    Compute the (optionally weighted) Shannon entropy of every row of ``seqs``.

    Parameters
    ----------
    weights : 1-D array-like of float, or None
        One weight per entry of a row; forwarded to ``np.bincount``. ``None``
        counts every entry once.
    seqs : 2-D array of non-negative ints
        Integer-encoded states. Each row is binned independently (as called
        from ``site_wise_entropy``, a row is one alignment column).
    base : float, optional
        Logarithm base passed to ``scipy.stats.entropy``.
    filter_gaps : bool, optional
        If True, state ``0`` is dropped before the entropy is computed. The
        caller is responsible for making sure state ``0`` is the gap symbol.

    Returns
    -------
    numpy.ndarray
        One entropy value per row of ``seqs``.
    """
    seqs = np.asarray(seqs)
    # Every row must be binned to the same width, otherwise apply_along_axis
    # cannot stack the per-row results. 21 (gap + 20 amino acids) is kept as
    # the minimum; alphabets with extra symbols (e.g. 'X') widen it.
    n_states = max(21, int(seqs.max()) + 1) if seqs.size else 21
    bin_counts = np.apply_along_axis(
        lambda row: np.bincount(row, weights=weights, minlength=n_states),
        axis=1, arr=seqs)
    if filter_gaps:
        bin_counts = bin_counts[:, 1:]
    # stats.entropy works column-wise, hence the transpose.
    all_ents = stats.entropy(bin_counts.T, base=base)
    return all_ents

def site_wise_entropy(fasta_records, weights=None, filter_gaps=False):
    """
    Compute the per-column entropy (log base 20) of an alignment.

    Parameters
    ----------
    fasta_records : sequence of records
        Aligned sequences; each record must expose a ``.seq`` attribute and
        all sequences must have the same length.
    weights : array-like of float, optional
        One weight per record, in the same order as ``fasta_records``.
        ``None`` (or an empty sequence) means uniform weights ``1/n_records``.
    filter_gaps : bool, optional
        If True, the gap symbol ``'-'`` is excluded from the entropy. The
        alignment must then contain at least one gap, because the gap is
        identified as the first symbol of the sorted alphabet.

    Returns
    -------
    numpy.ndarray
        One entropy value per alignment column.
    """
    seqs = np.array([list(str(record.seq)) for record in fasta_records])
    seqs_T = seqs.T
    # Encode symbols as integer indices into the sorted alphabet `order`.
    order, flat_array = np.unique(seqs_T.flatten(), return_inverse=True)
    if filter_gaps:
        assert order[0] == '-', "filter_gaps=True requires '-' to be the first symbol of the sorted alphabet"
    replaced_seqs_T = flat_array.reshape(seqs_T.shape)
    n_seqs = replaced_seqs_T.shape[1]
    # `not weights` is ambiguous for NumPy arrays, so test explicitly.
    if weights is None or len(weights) == 0:
        weights = np.full(n_seqs, 1./n_seqs)
    entropies = total_ent_fxn(weights, replaced_seqs_T, base=20, filter_gaps=filter_gaps)
    return entropies

In [5]:
# Input files for a single test protein. All paths are relative to this
# notebook and are built from the protein identifier.
protein_id = '1aoeA'
# PDB structure file.
pdb_loc = '../../Phylogenetic_couplings/Data/psicov150_aln_pdb/pdb/{}.pdb'.format(protein_id)
# Aligned FASTA (directory name suggests at most 1k sequences -- TODO confirm).
afa_loc = '../../Phylogenetic_couplings/Data/psicov150_aln_pdb/aln_fasta_max1k/{}.fasta'.format(protein_id)
# Newick tree for the alignment.
tree_loc = '../../Phylogenetic_couplings/Data/psicov150_aln_pdb/ref_root_trees/{}.newick'.format(protein_id)
# Per-site rate table (presumably IQ-TREE output, judging by the directory name).
rates_loc = '../../Phylogenetic_couplings/Data/psicov150_aln_pdb/aln_fasta_max1k_iqtree/{}.fasta.rate'.format(protein_id)

In [6]:
# Per-residue relative solvent accessibility (RSA) from DSSP, compared against
# the empirical per-site rates read from the rate table.
p = PDBParser()
# Label the structure with the actual protein identifier (previously the
# literal string 'protein_id' was passed).
structure_ref = p.get_structure(protein_id, pdb_loc)
model_ref = structure_ref[0]
# NOTE(review): the mkdssp path is an absolute, machine-specific location;
# it must be adjusted to run this notebook elsewhere.
dssp_ref = DSSP(model_ref, pdb_loc,
                dssp='/Users/adamhockenberry/workspace/xssp-3.0.5/mkdssp', acc_array='Wilke')
# Element 3 of each DSSP property tuple is the relative accessibility.
rsa_wilke = [i[3] for i in dssp_ref.property_list]

# The first 8 lines of the rate file are skipped (presumably a header block).
emp_df = pd.read_csv(rates_loc, sep='\t', skiprows=8)
emp_rates = list(emp_df['Rate'])
print(len(rsa_wilke), len(emp_rates), stats.spearmanr(emp_rates, rsa_wilke))

192 192 SpearmanrResult(correlation=0.43371292604002254, pvalue=3.2953996182491647e-10)


In [8]:
# Unweighted site-wise entropy (gaps excluded) for the alignment, and its
# Spearman correlation with RSA and with the empirical site rates.
fasta_records = list(SeqIO.parse(afa_loc, 'fasta'))
entropies_basic = site_wise_entropy(fasta_records, weights=None, filter_gaps=True)
print(len(entropies_basic), stats.spearmanr(entropies_basic, rsa_wilke))
print(len(entropies_basic), stats.spearmanr(entropies_basic, emp_rates))



192 SpearmanrResult(correlation=0.29152269395360059, pvalue=4.0862607356287777e-05)
192 SpearmanrResult(correlation=0.9003199351665262, pvalue=1.4567880456100776e-70)


**Okay, suppose you had a list of weights...**

In [9]:
# Read the tree, re-root it with rooting_methods.mp_root_adhock (presumably
# midpoint rooting -- TODO confirm), then compute sequence weights with the
# helpers in weighting_methods.
tree = Phylo.read(tree_loc, 'newick')
tree = rooting_methods.mp_root_adhock(tree)
# Tree-based weights (GSC), raw and normalised.
weights_dict_GSC = weighting_methods.GSC_adhock(tree)
# weights_dict_ACL, rooted_tree = weighting_methods.ACL_adhock(tree)
weights_dict_GSC_normed = weighting_methods.normalize_GSC_weights(weights_dict_GSC, tree)
# Alignment-based weights (HH), computed from the records rather than the tree.
weights_dict_HH = weighting_methods.HH_adhock(fasta_records)

In [21]:
# Build a weight vector ordered like the alignment records. Exactly one of
# the alternatives below should be active at a time.
ids = [record.id for record in fasta_records]

###For GSC
# The GSC dictionaries are keyed by tree clade; the last element of each
# value is used as that leaf's weight.
# weights = [weights_dict_GSC_normed[tree.find_any(seq_id)][-1] for seq_id in ids]
weights = [weights_dict_GSC[tree.find_any(seq_id)][-1] for seq_id in ids]

###For HH
# The HH dictionary is keyed directly by record id.
# weights = [weights_dict_HH[seq_id] for seq_id in ids]

In [22]:
# Weighted site-wise entropy. Entropies are in base 20, so 20**mean converts
# the mean per-site entropy back to an effective number of states.
entropies_weighted = site_wise_entropy(fasta_records, weights, filter_gaps=True)
# Weighted, then unweighted, for comparison: total entropy and effective states.
print(np.sum(entropies_weighted), 20**np.mean(entropies_weighted))
print(np.sum(entropies_basic), 20**np.mean(entropies_basic))

117.895739834 6.29339316296
112.605988506 5.79482682926


In [16]:
# Spearman correlation of the weighted entropies with RSA, with the empirical
# site rates, and with the unweighted entropies. Uses `entropies_weighted`,
# the name actually defined in this notebook (`weighted_ents` is never
# assigned and fails on a fresh kernel).
print(len(entropies_weighted), stats.spearmanr(entropies_weighted, rsa_wilke))
print(len(entropies_weighted), stats.spearmanr(entropies_weighted, emp_rates))
print(len(entropies_weighted), stats.spearmanr(entropies_weighted, entropies_basic))

192 SpearmanrResult(correlation=0.29948407754991035, pvalue=2.4436549080056641e-05)
192 SpearmanrResult(correlation=0.89205112859080393, pvalue=1.8852370615685438e-67)
192 SpearmanrResult(correlation=0.99148197379486214, pvalue=3.7121572848206839e-170)


In [18]:
# Sum of the per-sequence weights currently held in `weights`.
np.sum(weights)

166.94666030637563

# Making things complicated

In [None]:
import sys
sys.path.append('../../Tree_rooting/Code/')

import rooting_methods

# Report the number of leaves in the whole tree and under the root's first two
# child clades, before and after re-rooting with mp_root_adhock, to see how
# the re-rooting changes the split at the root.
tree = Phylo.read(tree_loc, 'newick')
print(len(tree.get_terminals()), len(tree.root.clades[0].get_terminals()), len(tree.root.clades[1].get_terminals()))
tree = rooting_methods.mp_root_adhock(tree)
print(len(tree.get_terminals()), len(tree.root.clades[0].get_terminals()), len(tree.root.clades[1].get_terminals()))

**Getting weights on said tree with said root**

In [None]:
import glob
import sys
sys.path.append('../../Tree_rooting/Code/')
import rooting_methods

In [None]:
# For every protein with a PDB file, correlate RSA with (a) unweighted
# entropy, (b) HH-weighted entropy and (c) the empirical site rates, and
# collect the Spearman coefficients in three parallel lists.
rsa_ent_basic = []
rsa_ent_weighted = []
rsa_sub = []
# One parser instance is enough for all structures.
p = PDBParser()
for pdb_loc in glob.glob('../../Phylogenetic_couplings/Data/psicov150_aln_pdb/pdb/*.pdb'):
    # Remove the '.pdb' suffix by slicing. str.strip('.pdb') strips any of the
    # characters '.', 'p', 'd', 'b' from BOTH ends and can eat into the id.
    protein_id = pdb_loc.split('/')[-1][:-len('.pdb')]
    # NOTE(review): 1vjkA is skipped explicitly; the reason is not recorded here.
    if protein_id == '1vjkA':
        continue
    afa_loc = '../../Phylogenetic_couplings/Data/psicov150_aln_pdb/aln_fasta_max1k/{}.fasta'.format(protein_id)
    tree_loc = '../../Phylogenetic_couplings/Data/psicov150_aln_pdb/ref_root_trees/{}.newick'.format(protein_id)
    rates_loc = '../../Phylogenetic_couplings/Data/psicov150_aln_pdb/aln_fasta_max1k_iqtree/{}.fasta.rate'.format(protein_id)

    # Label the structure with the real identifier, not the literal 'protein_id'.
    structure_ref = p.get_structure(protein_id, pdb_loc)
    model_ref = structure_ref[0]
    # NOTE(review): machine-specific absolute path to the mkdssp executable.
    dssp_ref = DSSP(model_ref, pdb_loc,
                    dssp='/Users/adamhockenberry/workspace/xssp-3.0.5/mkdssp', acc_array='Wilke')
    # Element 3 of each DSSP property tuple is the relative accessibility.
    rsa_wilke = [i[3] for i in dssp_ref.property_list]

    emp_df = pd.read_csv(rates_loc, sep='\t', skiprows=8)
    emp_rates = list(emp_df['Rate'])

    fasta_records = list(SeqIO.parse(afa_loc, 'fasta'))
    entropies_basic = site_wise_entropy(fasta_records, filter_gaps=True)

    # HH weights are keyed by record id; order them like the alignment.
    weights_dict_hh = weighting_methods.HH_adhock(fasta_records)
    ids = [record.id for record in fasta_records]
    weights = [weights_dict_hh[seq_id] for seq_id in ids]
    entropies_weighted = site_wise_entropy(fasta_records, weights=weights, filter_gaps=True)

    rsa_ent_basic.append(stats.spearmanr(entropies_basic, rsa_wilke)[0])
    rsa_ent_weighted.append(stats.spearmanr(entropies_weighted, rsa_wilke)[0])
    rsa_sub.append(stats.spearmanr(emp_rates, rsa_wilke)[0])
    print('########')
    print(np.sum(weights))
    print(20**np.mean(entropies_basic), 20**np.mean(entropies_weighted))

In [None]:
# Mean Spearman coefficient across proteins, in the order:
# unweighted entropy, weighted entropy, empirical rates.
print(*(np.mean(rhos) for rhos in (rsa_ent_basic, rsa_ent_weighted, rsa_sub)))

In [None]:
# Paired t-test across proteins: RSA correlation of the empirical rates versus
# that of the unweighted entropy.
stats.ttest_rel(rsa_sub, rsa_ent_basic)

In [None]:
# Paired t-test across proteins: RSA correlation of the weighted entropy
# versus that of the unweighted entropy.
stats.ttest_rel(rsa_ent_weighted, rsa_ent_basic)
# Non-parametric alternative (Wilcoxon signed-rank), kept for reference:
# stats.wilcoxon(rsa_ent_weighted, rsa_ent_basic)

In [None]:
# Distribution of the per-protein change in the RSA correlation when the
# entropy is weighted (weighted minus unweighted), in 20 bins.
fig, ax = plt.subplots()
ax.hist(np.array(rsa_ent_weighted)-np.array(rsa_ent_basic), 20)
ax.set_xlabel('Spearman rho difference (weighted - unweighted entropy vs. RSA)')
ax.set_ylabel('Number of proteins')
ax.set_title('Effect of sequence weighting on the entropy-RSA correlation');

In [None]:
# Distribution of the sequence weights currently held in `weights`, in 20
# bins. After the per-protein loop this is the last protein processed only.
fig, ax = plt.subplots()
ax.hist(weights, 20)
ax.set_xlabel('Sequence weight')
ax.set_ylabel('Number of sequences')
ax.set_title('Distribution of sequence weights');

In [69]:
# Choose one alignment/tree pair to work with. The two commented-out pairs
# below are alternative data sets kept for reference; the active pair is the
# 1aoeA alignment with its tree from the mp_root_trees directory.

# records = list(SeqIO.parse('../../Tree_rooting/Data/Tria_et_al_data/'
#                            'eukaryotes/ingroup/aln/KOG0018.faa.aln', 'fasta'))
# tree = Phylo.read('../../Tree_rooting/Data/Tria_et_al_data/'
#                   'eukaryotes/processed_trees/KOG0018.faa.aln.nwk.Rooted.MADAJH', 'newick')

# records = list(SeqIO.parse('../../Tree_rooting/Data/OMA_group_data/eukaryotes/aligned_OMA_groups/'
#                            'OMAGroup_833097.mafft.afa', 'fasta'))
# tree = Phylo.read('../../Tree_rooting/Data/OMA_group_data/eukaryotes/processed_OMA_trees/'
#                            'OMAGroup_833097.treefile.Rooted.MADAJH', 'newick')


records = list(SeqIO.parse('../../Phylogenetic_couplings/Data/psicov150_aln_pdb/'
                           'aln_fasta_max1k/1aoeA.fasta', 'fasta'))
tree = Phylo.read('../../Phylogenetic_couplings/Data/psicov150_aln_pdb/'
                  'mp_root_trees/1aoeA.newick', 'newick')

# Unweighted site-wise entropy (gaps excluded) as the baseline.
entropies_basic = site_wise_entropy(records, weights=None, filter_gaps=True)

In [77]:
# Entropies are in base 20, so 20**mean turns the mean per-site entropy into
# an effective number of states (unweighted case).
20**np.mean(entropies_basic)

5.7948268292616385

In [78]:
# Site-wise entropy using HH weights. The weights dictionary is keyed by
# record id, so the vector is built in the same order as `records`.
ids = [record.id for record in records]
weights_dict_HH = weighting_methods.HH_adhock(records)
weights = [weights_dict_HH[seq_id] for seq_id in ids]
entropies_hh = site_wise_entropy(records, weights=weights, filter_gaps=True)

In [79]:
# Effective number of states (20**mean base-20 entropy) with HH weights.
20**np.mean(entropies_hh)

6.3556819801933644

In [57]:
weights = [  1.11237791e-02,   4.08098069e-02,   6.06282603e-07,
         6.05743961e-07,   2.31976095e-04,   1.87718592e-03,
         7.70138229e-03,   4.02380357e-05,   1.08848637e-02,
         8.57081265e-03,   3.73401314e-03,   3.73401314e-03,
         6.05906983e-07,   4.03133942e-02,   1.94271297e-02,
         6.04188356e-07,   1.85553636e-02,   5.98658973e-07,
         2.96810725e-02,   3.94913406e-03,   3.94913406e-03,
         2.31553035e-04,   4.40487096e-02,   2.00777842e-02,
         6.08871816e-07,   6.01724558e-03,   6.08595677e-07,
         2.68214351e-04,   2.68214352e-04,   6.05925425e-07,
         6.04896644e-07,   6.06707521e-07,   5.99825741e-07,
         6.00355955e-07,   6.00349755e-07,   2.70077765e-03,
         6.05865986e-07,   7.70188351e-06,   2.28953833e-02,
         6.57131566e-03,   6.00722046e-07,   4.72817713e-03,
         5.69580642e-03,   5.97697507e-07,   5.97541336e-07,
         5.99100393e-07,   1.94363192e-04,   3.37188460e-06,
         3.18609530e-05,   6.03889231e-07,   5.97792603e-07,
         5.14009197e-04,   1.05037075e-05,   5.59695602e-03,
         6.81307108e-03,   1.96091657e-02,   5.15695439e-05,
         4.82801390e-03,   1.42955503e-05,   6.00421469e-07,
         5.99877257e-07,   6.04683582e-07,   1.94963763e-03,
         6.00898260e-07,   1.45835723e-02,   6.01115465e-07,
         5.98036989e-07,   6.08519249e-07,   3.97310163e-05,
         6.04439619e-07,   6.00563801e-07,   5.96742081e-07,
         6.01621625e-07,   5.99914836e-07,   5.99740812e-07,
         5.02555189e-04,   2.13142373e-04,   6.01053504e-07,
         5.12528716e-02,   6.00389441e-07,   5.99651091e-07,
         5.96507481e-07,   6.20997149e-03,   4.45215209e-03,
         4.34803821e-05,   8.09730728e-05,   1.04994239e-02,
         5.80231924e-03,   5.80231924e-03,   1.44247005e-02,
         6.02849804e-07,   6.01940786e-07,   5.96570208e-07,
         5.96740740e-07,   4.02232653e-04,   5.96834497e-07,
         5.97406980e-07,   5.98977874e-07,   5.96500002e-07,
         1.26729434e-03,   1.97195675e-05,   6.01566438e-07,
         5.96985651e-04,   2.92791096e-03,   3.36133766e-02,
         9.17171473e-05,   4.23494130e-03,   6.06058310e-07,
         1.62758507e-02,   1.23914498e-02,   5.97152914e-07,
         2.03403876e-04,   6.04841915e-07,   5.99773538e-07,
         1.64827575e-03,   6.00804643e-07,   1.44740169e-01,
         6.02034177e-07,   7.79359382e-04,   6.00147577e-07,
         6.00047339e-07,   1.22148977e-05,   6.00226568e-07,
         6.00000811e-07,   5.99222518e-07,   6.00287331e-07,
         1.91994774e-03,   3.35364740e-05,   5.97755111e-07,
         1.19946160e-02,   4.70022292e-03,   6.00249419e-07,
         5.63548633e-03,   2.18921363e-02,   6.29174283e-03,
         6.07901234e-07,   3.32447767e-04,   4.52036742e-04,
         9.58499999e-03,   1.77648300e-05,   1.84356911e-03,
         7.34244911e-03,   4.91302987e-02,   7.82753368e-04,
         7.82753378e-04,   9.62266468e-03,   7.74095686e-04,
         5.39878480e-02,   1.13731893e-02,   5.97040317e-07,
         1.94810293e-05,   5.99453853e-07,   5.98922248e-07,
         3.65553822e-05,   5.99089997e-07,   5.98686904e-07,
         1.22157737e-04,   6.06641815e-07,   3.73281027e-05,
         1.96272208e-03,   2.69763929e-02,   4.38084081e-04,
         6.00546589e-07,   5.99374788e-07,   3.18027759e-05,
         5.99715108e-07,   5.97619339e-07,   5.98530727e-07,
         6.00217329e-07,   9.62125620e-07,   1.71995780e-04,
         5.98629294e-07,   6.00306942e-07,   6.08274230e-07,
         6.00451730e-07,   5.98201855e-07,   5.99760388e-07,
         1.77261633e-05,   6.00796036e-07,   5.99954220e-07,
         6.01055062e-07,   5.99708246e-07,   5.99306764e-07,
         6.00593231e-07,   6.01654173e-07,   6.03286275e-07,
         4.44452712e-02,   6.06470546e-07,   5.97987535e-07,
         7.18276487e-03,   6.02300601e-07,   6.09717405e-07,
         4.52173063e-04,   6.00127990e-07,   5.97926590e-07,
         6.09468548e-07,   5.98162852e-07,   5.97494979e-07,
         6.01274832e-07,   5.99303071e-07,   6.08636410e-07,
         5.98790718e-07,   6.04627343e-07,   5.98509392e-07,
         6.06259295e-07,   6.06694897e-07,   5.98823250e-07,
         5.97499423e-07,   1.85839376e-03,   6.00583610e-07,
         6.06668617e-07,   2.67570216e-02,   5.97349273e-07,
         1.06739124e-04]

# Site-wise entropy using the hard-coded weight vector defined just above.
# NOTE(review): the variable name suggests these weights came from a
# maximum-entropy optimisation run elsewhere -- confirm their provenance.
entropies_maxent = site_wise_entropy(records, weights=weights, filter_gaps=True)

['-' 'A' 'C' 'D' 'E' 'F' 'G' 'H' 'I' 'K' 'L' 'M' 'N' 'P' 'Q' 'R' 'S' 'T'
 'V' 'W' 'X' 'Y']


In [60]:
# Effective number of states (20**mean base-20 entropy) with the hard-coded weights.
20**np.mean(entropies_maxent)

1.9732789218013993

# Trying a James-Stein estimator for shrinkage

In [126]:
# js.estimate <- function(prob, ct) {
#   if(ct<=1) {
#     #basically if we only observe a count of 1
#     #the variance goes to infinity and we get the uniform distribution.
#     return(rep(1/length(prob), length(prob)))
#   }
#   # MLE of prob estimate
#   mlvar <- prob*(1-prob)/(ct-1)
#   unif <- rep(1/length(prob), length(prob)) 
  
#   # Deviation from uniform
#   deviation <- sum((prob-unif)^2)
  
#   #take care of special case,if no difference it doesn't matter
#   if(deviation==0) return(prob)
  
#   lambda <- sum(mlvar)/deviation
#   #if despite  our best efforts we ended up with an NaN number-just return the uniform distribution.
#   if(is.nan(lambda)) return(unif)
  
#   #truncate
#   if(lambda>1) lambda <- 1
#   if(lambda<0) lambda <- 0
  
#   #Construct shrinkage estimator as convex combination of the two
#   lambda*unif + (1 - lambda)*prob
# }


# Python port of the R `js.estimate` reference above, on a toy two-state
# distribution: shrink `probs` towards the uniform distribution.
probs = np.array([0.75, 0.25])
possibilities = 2
counts = 1000

uniform = np.full(probs.shape[0], 1/possibilities)
# Squared distance of the observed frequencies from the shrinkage target.
deviation = np.sum((probs-uniform)**2.)
if counts <= 1:
    # No variance estimate is possible: shrink fully to the uniform target.
    print('think i should return the uniform distribution since there is no variance')
    lamb = 1.
elif deviation == 0:
    # Already at the target: shrinkage has no effect, avoid dividing by zero.
    print('do something special')
    lamb = 0.
else:
    # Variance of the maximum-likelihood frequency estimates.
    mlvar = probs*(1-probs)/(counts-1)
    # Shrinkage intensity, truncated to [0, 1] as in the R reference.
    lamb = min(max(np.sum(mlvar)/deviation, 0.), 1.)
print('this value should be between zero and 1:', lamb)
# Convex combination of the target and the observed frequencies.
new_probs = lamb*uniform + (1-lamb)*probs

this value should be between zero and 1: 0.003003003003


In [127]:
# Display the shrinkage intensity computed in the previous cell.
lamb

0.003003003003003003

In [128]:
# Display the uniform shrinkage target.
uniform

array([ 0.5,  0.5])

In [129]:
# Display the shrinkage intensity again (repeats an earlier cell).
lamb

0.003003003003003003

In [130]:
# Display the shrunken frequencies.
new_probs

array([ 0.74924925,  0.25075075])

In [195]:
def js_shrinkage_discrete(starting_freqs, prior_distribution, n_obs):
    """
    Shrink observed frequencies towards a prior with a James-Stein-type estimator.

    The result is ``lamb * prior + (1 - lamb) * starting_freqs`` where the
    shrinkage intensity ``lamb`` is the summed variance of the observed
    frequencies divided by their squared distance from the prior, truncated
    to the interval [0, 1].

    Parameters
    ----------
    starting_freqs : array-like of float
        Observed frequencies; must sum to 1 (within floating-point tolerance).
    prior_distribution : array-like of float
        Shrinkage target, same length as ``starting_freqs``.
    n_obs : int or float
        Number of observations the frequencies are based on.

    Returns
    -------
    numpy.ndarray
        The shrunken frequencies. The prior is returned when ``n_obs <= 1``;
        the observed frequencies are returned when they equal the prior.

    Raises
    ------
    AssertionError
        If the frequencies do not sum to 1 or the lengths differ.
    """
    starting_freqs = np.asarray(starting_freqs)
    prior_distribution = np.asarray(prior_distribution)
    # Exact equality with 1. rejects valid inputs because of rounding error.
    assert np.isclose(np.sum(starting_freqs), 1.), 'starting_freqs must sum to 1'
    assert len(starting_freqs) == len(prior_distribution)
    if n_obs <= 1:
        # No variance estimate is possible with a single observation.
        return prior_distribution
    mlvar = starting_freqs*(1-starting_freqs)/(n_obs-1)
    deviation = np.sum((starting_freqs-prior_distribution)**2.)
    if deviation == 0.:
        # Already at the target; also avoids dividing by zero below.
        return starting_freqs
    lamb = min(max(np.sum(mlvar)/deviation, 0.), 1.)
    final_freqs = lamb * prior_distribution + (1-lamb) * starting_freqs
    return final_freqs

In [196]:
# Build the unweighted per-column count matrix for `records`
# (rows = alignment columns, columns = symbols with the gap removed).
seqs = np.array([list(str(record.seq)) for record in records])
seqs_T = seqs.T
initial_shape = seqs_T.shape
flat_seqs = seqs_T.flatten()
# Encode symbols as integer indices into the sorted alphabet `order`.
order, flat_array = np.unique(flat_seqs, return_inverse=True)
assert order[0] == '-'
replaced_seqs_T = flat_array.reshape(initial_shape)
# Bin every column to a common width: at least 21 (gap + 20 amino acids), and
# wider when the alphabet has extra symbols, so the rows can be stacked.
bin_counts = np.apply_along_axis(lambda x: np.bincount(x, minlength=max(21, len(order))),
                                     axis=1, arr=replaced_seqs_T)
# Drop the gap column (index 0).
bin_counts = bin_counts[:,1:]
#     all_ents = stats.entropy(bin_counts.T, base=base)

In [205]:
# Shrink the observed frequencies of alignment column index 1 towards the
# uniform distribution, using the column's total count as the sample size.
starting = bin_counts[1] / bin_counts[1].sum()
uniform = np.full(len(starting), 1/len(starting))
final = js_shrinkage_discrete(starting, uniform, bin_counts[1].sum())

array([ 0.0766082 ,  0.0091508 ,  0.03913186,  0.00165553,  0.00165553,
        0.01664606,  0.00165553,  0.0316366 ,  0.0428795 ,  0.0541224 ,
        0.14781323,  0.04662713,  0.02788896,  0.0091508 ,  0.01289843,
        0.18528956,  0.26398986,  0.02788896,  0.00165553,  0.00165553])

In [188]:
# Observed (maximum-likelihood) frequencies for alignment column index 1.
probs = bin_counts[1]/np.sum(bin_counts[1])

In [189]:
# Effective number of states for the observed frequencies (20**base-20 entropy).
20**stats.entropy(probs, base=20)

9.1134519803483016

In [193]:
# Same shrinkage prototype as the toy example, applied to the observed
# frequencies `probs` of alignment column index 1 with a 20-state uniform target.
possibilities = 20
counts = np.sum(bin_counts[1])

uniform = np.full(probs.shape[0], 1/possibilities)
# Squared distance of the observed frequencies from the shrinkage target.
deviation = np.sum((probs-uniform)**2.)
if counts <= 1:
    # No variance estimate is possible: shrink fully to the uniform target.
    print('think i should return the uniform distribution since there is no variance')
    lamb = 1.
elif deviation == 0:
    # Already at the target: shrinkage has no effect, avoid dividing by zero.
    print('do something special')
    lamb = 0.
else:
    # Variance of the maximum-likelihood frequency estimates.
    mlvar = probs*(1-probs)/(counts-1)
    # Shrinkage intensity, truncated to [0, 1].
    lamb = min(max(np.sum(mlvar)/deviation, 0.), 1.)
# lamb = 1
print('this value should be between zero and 1:', lamb)
# Convex combination of the target and the observed frequencies.
new_probs = lamb*uniform + (1-lamb)*probs

this value should be between zero and 1: 0.0331106181129


In [194]:
# Effective number of states after shrinkage (20**base-20 entropy).
20**stats.entropy(new_probs, base=20)

9.7584136874264242

In [180]:
# Display the shrunken frequencies for this column.
new_probs

array([ 0.0766082 ,  0.0091508 ,  0.03913186,  0.00165553,  0.00165553,
        0.01664606,  0.00165553,  0.0316366 ,  0.0428795 ,  0.0541224 ,
        0.14781323,  0.04662713,  0.02788896,  0.0091508 ,  0.01289843,
        0.18528956,  0.26398986,  0.02788896,  0.00165553,  0.00165553])

In [181]:
# Display the observed frequencies for comparison with the shrunken ones.
probs

array([ 0.07751938,  0.00775194,  0.03875969,  0.        ,  0.        ,
        0.01550388,  0.        ,  0.03100775,  0.04263566,  0.05426357,
        0.15116279,  0.04651163,  0.02713178,  0.00775194,  0.01162791,
        0.18992248,  0.27131783,  0.02713178,  0.        ,  0.        ])

In [182]:
# Alternative smoothing: add a pseudocount of 5 to every symbol count of
# alignment column index 1 (the name suggests a symmetric Dirichlet prior).
dirichlet = bin_counts[1]+5

In [184]:
# Effective number of states for the pseudocount-smoothed counts.
# stats.entropy normalises its input, so raw counts can be passed directly.
20**stats.entropy(dirichlet, base=20)

13.549581922734784