In [33]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from math import gamma, factorial
from scipy.special import loggamma, logsumexp
from scipy.spatial.distance import pdist
sns.set()
from time import time
%matplotlib

Using matplotlib backend: Qt5Agg


In [34]:
#def stirling(x):
   # return (x-0.5)*np.log(x+1)-x+1.9189385332+(1/(12*(x+1)+0.5))

def dirichlet_multinomial_old(x,alpha):
    '''
    Loop-based Dirichlet-multinomial log-probability (reference version kept
    for the speed benchmark).

    x     : count vector of tetranucleotides (documented as (256,1))
    alpha : Dirichlet hyper-parameters (pseudo-counts), one per entry of x

    Returns log P(x | alpha), accumulated term by term in log space.
    formula: https://en.wikipedia.org/wiki/Dirichlet-multinomial_distribution
    '''
    total_count = np.sum(x)
    concentration = np.sum(alpha)  # alpha_0: overall strength of the pseudo-counts

    result = loggamma(total_count+1)+loggamma(concentration)-loggamma(total_count+concentration)

    # Add each category's contribution in index order.
    for k in range(len(x)):
        count, a_k = x[k], alpha[k]
        result += loggamma(count+a_k)-loggamma(count+1)-loggamma(a_k)

    return result


def bayes_null_old(x,y,alpha):
    '''Loop-based log-likelihood of x and y under independent
    Dirichlet-multinomials sharing the prior alpha: the sum of the two
    separate log-probabilities.'''
    log_px = dirichlet_multinomial_old(x, alpha)
    log_py = dirichlet_multinomial_old(y, alpha)
    return log_px + log_py


def bayes_model_old(x,y,alpha):
    '''
    Loop-based joint log-probability of count vectors x and y when both share
    a single multinomial parameter vector drawn from Dirichlet(alpha).

    Computed as the Dirichlet-multinomial log-probability of the pooled counts
    x+y plus a combinatorial correction that turns it into the joint
    probability of observing x and y separately:
        log[Lx! Ly! / (Lx+Ly)!] + sum_k log[(x_k+y_k)! / (x_k! y_k!)]

    x, y  : count vectors (any array-like of equal length)
    alpha : Dirichlet hyper-parameters, one per entry
    '''
    # Convert first: with plain Python lists `x + y` would concatenate instead
    # of adding element-wise, producing a vector twice as long.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    dirich = dirichlet_multinomial_old(x+y,alpha)
    Lx = np.sum(x)
    Ly = np.sum(y)
    correction = loggamma(Lx+1) + loggamma(Ly+1) - loggamma(Lx+Ly+1)
    for k in range(len(x)):
        correction += loggamma(x[k]+y[k]+1) - loggamma(x[k]+1) -loggamma(y[k]+1)
    return correction + dirich

def b_metric_old(x,y, alpha):
    '''Loop-based log Bayes factor: the shared-composition log-likelihood
    (bayes_model_old) minus the independent one (bayes_null_old).'''
    log_null = bayes_null_old(x, y, alpha)
    log_shared = bayes_model_old(x, y, alpha)
    return log_shared - log_null


def dirichlet_multinomial(x,alpha):
    '''
    Vectorised log-probability of count vector x under a Dirichlet-multinomial
    distribution with hyper-parameters alpha.

    x       : count vector of tetranucleotides (documented as (256,1)); any
              array-like of non-negative counts is accepted
    alpha   : Dirichlet hyper-parameters (pseudo-counts), same shape as x
    alpha_0 : (internal) strength of the pseudo-counts, i.e. sum(alpha)

    Returns log P(x | alpha); all calculations are done in log space.
    formula: https://en.wikipedia.org/wiki/Dirichlet-multinomial_distribution
    '''
    # Coerce to float arrays: plain lists would otherwise be concatenated by
    # `x + alpha`, and object-dtype arrays (e.g. DataFrame.values on mixed
    # columns) are generally not accepted by scipy.special ufuncs.
    x = np.asarray(x, dtype=float)
    alpha = np.asarray(alpha, dtype=float)

    n = np.sum(x)
    alpha_0 = np.sum(alpha)

    log_probability = loggamma(n+1)+loggamma(alpha_0)-loggamma(n+alpha_0)
    log_probability += np.sum(loggamma(x+alpha)-loggamma(x+1)-loggamma(alpha))

    return log_probability


def bayes_null(x,y,alpha):
    '''Log-likelihood of x and y under independent Dirichlet-multinomials that
    share the prior alpha: the sum of the two separate log-probabilities.'''
    log_px = dirichlet_multinomial(x, alpha)
    log_py = dirichlet_multinomial(y, alpha)
    return log_px + log_py


def bayes_model(x,y,alpha):
    '''
    Vectorised joint log-probability of count vectors x and y when both share
    a single multinomial parameter vector drawn from Dirichlet(alpha).

    Computed as the Dirichlet-multinomial log-probability of the pooled counts
    x+y plus a combinatorial correction that turns it into the joint
    probability of observing x and y separately:
        log[Lx! Ly! / (Lx+Ly)!] + sum_k log[(x_k+y_k)! / (x_k! y_k!)]

    x, y  : count vectors (any array-like of equal length)
    alpha : Dirichlet hyper-parameters, one per entry
    '''
    # Convert first: with plain Python lists `x + y` would concatenate instead
    # of adding element-wise, producing a vector twice as long.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    dirich = dirichlet_multinomial(x+y,alpha)
    Lx = np.sum(x)
    Ly = np.sum(y)
    correction = loggamma(Lx+1) + loggamma(Ly+1) - loggamma(Lx+Ly+1)
    correction += np.sum(loggamma(x+y+1) - loggamma(x+1) -loggamma(y+1))
    return correction + dirich


def b_metric(x,y, alpha):
    '''Log Bayes factor: the shared-composition log-likelihood (bayes_model)
    minus the independent one (bayes_null). Positive values favour x and y
    sharing one composition.'''
    log_null = bayes_null(x, y, alpha)
    log_shared = bayes_model(x, y, alpha)
    return log_shared - log_null


In [35]:
# Prior pseudo-counts: mean tetranucleotide frequency (TNF) over the reference
# genomes in genomes_tn_conts.csv (10 reference genomes per the original note).
# numeric_only=True skips any non-numeric columns, matching the pandas<2
# default; pandas>=2 would raise TypeError on such columns instead.
# NOTE(review): alpha must line up one-to-one with the count columns of the
# contig table and be strictly positive (loggamma(0) is infinite) — confirm.
all_genomes = pd.read_csv('./genomes_tn_conts.csv')
alpha = all_genomes.mean(numeric_only=True).values

In [36]:
# Per-contig table for V. cholerae; the next cell filters it on the
# 'contamination' column and treats its last column as a label.
data = pd.read_csv('./v_cholerae_conts.csv')

In [37]:
# Keep only uncontaminated contigs (contamination == 0) as a NumPy array.
# NOTE(review): .values on mixed-dtype columns yields an object array — confirm
# every column is numeric.
group = data.loc[data['contamination'] == 0].values
# Last column is the per-contig label; the remaining columns form the count
# vectors fed to the metrics. NOTE(review): this assumes the 'contamination'
# column is not among those count columns — confirm against the CSV header.
labels = group[:,-1]
group = group[:,:-1]

# Speed benchmark

In [38]:
def timing(fun, repeat=1):
    '''Run ``fun()`` ``repeat`` times, print and return the fastest duration.

    Uses time.perf_counter (monotonic, highest available resolution) rather
    than time.time, which follows the system clock and can jump or be coarse.
    With the default repeat=1 this is a single-shot measurement as before;
    larger values give a less noisy best-of-N figure.

    Returns the best elapsed time in seconds.
    Raises ValueError if repeat < 1.
    '''
    from time import perf_counter

    if repeat < 1:
        raise ValueError('repeat must be >= 1')
    best = float('inf')
    for _ in range(repeat):
        t1 = perf_counter()
        fun()
        t2 = perf_counter()
        best = min(best, t2 - t1)
    print('time of execution: ', best)
    return best


In [39]:

def new(x=None, prior=None):
    '''Vectorised Dirichlet-multinomial log-probability, used as a benchmark body.

    x     : count vector; defaults to group[1] (second uncontaminated contig)
    prior : Dirichlet hyper-parameters; defaults to the global `alpha`

    Returns the log-probability (the original ended in a bare expression and
    returned None) so the result can be checked against old().
    '''
    if x is None:
        x = group[1]
    if prior is None:
        prior = alpha
    n = np.sum(x)
    alpha_0 = np.sum(prior)

    log_probability = loggamma(n+1)+loggamma(alpha_0)-loggamma(n+alpha_0)

    log_probability += np.sum(loggamma(x+prior)-loggamma(x+1)-loggamma(prior))

    return log_probability


In [40]:
def old(x=None, prior=None):
    '''Loop-based Dirichlet-multinomial log-probability, used as a benchmark body.

    x     : count vector; defaults to group[1] (second uncontaminated contig)
    prior : Dirichlet hyper-parameters; defaults to the global `alpha`

    Returns the log-probability (the original ended in a bare expression and
    returned None) so the result can be checked against new().
    '''
    if x is None:
        x = group[1]
    if prior is None:
        prior = alpha
    n = np.sum(x)
    alpha_0 = np.sum(prior)

    log_probability = loggamma(n+1)+loggamma(alpha_0)-loggamma(n+alpha_0)

    for k in range(len(x)):
        log_probability += loggamma(x[k]+prior[k])-loggamma(x[k]+1)-loggamma(prior[k])

    return log_probability


In [41]:
# Time one evaluation of the loop-based (old) log-probability.
timing(old)

time of execution:  0.0019412040710449219


In [42]:
# Time one evaluation of the vectorised (new) log-probability.
timing(new)

time of execution:  0.00015664100646972656


In [43]:
def _pairwise_b_distances(metric, n_contigs=100):
    '''Condensed pairwise distances between the first `n_contigs` uncontaminated contigs.

    metric    : callable(u, v, alpha) -> float, e.g. b_metric or b_metric_old
    n_contigs : number of leading uncontaminated contigs to compare
    Returns the condensed distance vector produced by scipy's pdist.
    '''
    clean = data.loc[data['contamination'] == 0].values
    counts = clean[:n_contigs, :-1]  # last column holds the labels, not counts
    return pdist(counts, lambda u, v: metric(u, v, alpha))


def test_old(n_contigs=100):
    '''Benchmark the loop-based b_metric_old over all pairs of the first
    `n_contigs` uncontaminated contigs; returns the distances.

    (The original also evaluated the *new* b_metric once here and discarded
    it, which mixed the two implementations in the old timing.)
    '''
    return _pairwise_b_distances(b_metric_old, n_contigs)


def test_new(n_contigs=100):
    '''Benchmark the vectorised b_metric over all pairs of the first
    `n_contigs` uncontaminated contigs; returns the distances.'''
    return _pairwise_b_distances(b_metric, n_contigs)

In [44]:
# Pairwise b_metric_old over 100 clean contigs (loop-based implementation).
timing(test_old)

time of execution:  41.28519034385681


In [45]:
# Pairwise b_metric over 100 clean contigs (vectorised implementation).
timing(test_new)

time of execution:  1.4435486793518066
