In [1]:
import autograd.numpy as np
import autograd.scipy as sp
from numpy.polynomial.hermite import hermgauss

import numpy.testing as np_test

import sys
sys.path.insert(0, '../libraries/')
sys.path.insert(0, '../../BNP_modeling/')

import structure_model_lib
import modeling_lib
import data_utils

# Draw data

In [2]:
# Fix the global NumPy seed so the random genotypes, the variational
# initialisation, and every Monte Carlo check further down are reproducible
# under Restart & Run All (the checks below use statistical tolerances and
# could otherwise fail intermittently).
np.random.seed(53453)

n_obs = 10
n_loci = 5

# Genotypes are drawn uniformly at random for now:
# an (n_obs, n_loci) matrix of integers in {0, 1, 2}.
g_obs_int = np.random.choice(3, size=(n_obs, n_loci))

# One-hot encode the genotypes along a new trailing axis of length 3.
g_obs = data_utils.get_one_hot(g_obs_int, 3)
assert g_obs.shape == (n_obs, n_loci, 3)

# Get prior

In [3]:
# Default prior hyperparameters (dict) together with the paragami pattern that
# describes their shapes and bounds.
(prior_params_dict,
 prior_params_paragami) = structure_model_lib.get_default_prior_params()

print(prior_params_paragami)

OrderedDict:
	[dp_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_beta] = NumericArrayPattern (1,) (lb=0.0, ub=inf)


# Get variational (VB) parameters

In [4]:
# Number of mixture components in the truncated approximation; the printed
# pattern shows k_approx - 1 stick proportions per observation.
k_approx = 12
(vb_params_dict,
 vb_params_paragami) = structure_model_lib.get_vb_params_paragami_object(
    n_obs, n_loci, k_approx)

print(vb_params_paragami)

OrderedDict:
	[pop_freq_beta_params] = NumericArrayPattern (5, 12, 2) (lb=0.0, ub=inf)
	[ind_mix_stick_propn_mean] = NumericArrayPattern (10, 11) (lb=-inf, ub=inf)
	[ind_mix_stick_propn_info] = NumericArrayPattern (10, 11) (lb=0.0001, ub=inf)


# Check get_e_log_logitnormal on array inputs against Monte Carlo estimates

In [5]:
# Number of Gauss-Hermite quadrature points used by the expectations below.
gh_deg = 8
# Build the rule from gh_deg (not a repeated literal) so changing the degree
# in one place actually changes the quadrature used everywhere.
gh_loc, gh_weights = hermgauss(gh_deg)

In [6]:
# Pull out the prior concentration and the logit-normal stick parameters
# used by the checks below.
dp_prior_alpha = prior_params_dict['dp_prior_alpha']
ind_mix_stick_propn_mean, ind_mix_stick_propn_info = (
    vb_params_dict[key]
    for key in ('ind_mix_stick_propn_mean', 'ind_mix_stick_propn_info'))

In [7]:
# Quadrature estimates of E[log v] and E[log(1 - v)], elementwise over the
# arrays; the next cells compare them with Monte Carlo draws of
# v = expit(x), x ~ N(mean, 1 / info).
e_log_v, e_log_1mv = structure_model_lib.ef.get_e_log_logitnormal(
    lognorm_means=ind_mix_stick_propn_mean,
    lognorm_infos=ind_mix_stick_propn_info,
    gh_loc=gh_loc,
    gh_weights=gh_weights,
)

In [8]:
# sample from logit-normal
num_draws = 10**5
samples = np.random.normal(ind_mix_stick_propn_mean,
                1/np.sqrt(ind_mix_stick_propn_info), size = (num_draws, n_obs, k_approx - 1))
logit_norm_samples = sp.special.expit(samples)

In [9]:
# Per-draw log v and log(1 - v); their means over axis 0 are the MC estimates.
e_log_samples, e_log_1m_samples = (np.log(logit_norm_samples),
                                   np.log(1 - logit_norm_samples))

In [10]:
# Absolute gap between the quadrature and Monte Carlo estimates, per element.
mc_e_log_v = e_log_samples.mean(axis=0)
mc_e_log_1mv = e_log_1m_samples.mean(axis=0)
diff1 = np.abs(e_log_v - mc_e_log_v)
diff2 = np.abs(e_log_1mv - mc_e_log_1mv)

In [11]:
# Each |quadrature - MC| gap must lie within a few Monte Carlo standard errors,
# std / sqrt(num_draws). Comparing against the raw per-draw std would be about
# sqrt(num_draws) times too loose to catch real errors. 5 standard errors keeps
# spurious failures approximately negligible even though all
# n_obs * (k_approx - 1) elements are checked at once.
# np.all(<=) also fails on NaN, which a "count of > == 0" test would let through.
n_std_err = 5
std_err1 = e_log_samples.std(axis=0) / np.sqrt(num_draws)
std_err2 = e_log_1m_samples.std(axis=0) / np.sqrt(num_draws)
assert np.all(diff1 <= n_std_err * std_err1), np.max(diff1 / std_err1)
assert np.all(diff2 <= n_std_err * std_err2), np.max(diff2 / std_err2)

In [12]:
print(diff1.max())  # largest per-element gap for E[log v]

0.003003696622445262


In [13]:
print(diff2.max())  # largest per-element gap for E[log(1 - v)]

0.003917942978577393


# Test the expected DP prior term against a Monte Carlo estimate

In [14]:
# get dp prior
e_dp_prior = modeling_lib.get_e_logitnorm_dp_prior(ind_mix_stick_propn_mean,
                                            ind_mix_stick_propn_info,
                                            dp_prior_alpha, gh_loc, gh_weights)

In [15]:
# sample from logit-normal
num_draws = 10**5
samples = np.random.normal(ind_mix_stick_propn_mean,
                1/np.sqrt(ind_mix_stick_propn_info), size = (num_draws, n_obs, k_approx - 1))
logit_norm_samples = sp.special.expit(samples)

In [16]:
# samples of the dp_prior
# Per-draw value of sum_{i,k} (alpha - 1) * log(1 - v_ik): reduce the stick
# axis (2) first, then the observation axis (1), leaving one value per draw.
log_1m_sticks = np.log(1 - logit_norm_samples)
dp_prior_samples = ((dp_prior_alpha - 1) * log_1m_sticks).sum(axis=2).sum(axis=1)

In [17]:
# Monte Carlo mean and per-draw standard deviation of the DP prior term.
dp_prior_samples_mean, dp_prior_samples_std = (dp_prior_samples.mean(),
                                               dp_prior_samples.std())

In [18]:
# Quadrature value, Monte Carlo mean, and Monte Carlo standard error.
mc_std_err = dp_prior_samples_std / np.sqrt(num_draws)
for quantity in (e_dp_prior, dp_prior_samples_mean, mc_std_err):
    print(quantity)

[-230.94815387]
-230.92738325796037
0.03267212975421389


In [19]:
# The Monte Carlo mean must lie within 3 standard errors of the quadrature value.
np_test.assert_allclose(
    actual=dp_prior_samples_mean,
    desired=e_dp_prior,
    atol=3 * dp_prior_samples_std / np.sqrt(num_draws),
)

# Stick-breaking entropy (computed and displayed; not yet checked against a reference)

In [20]:
# Stick-breaking entropy under the current variational parameters
# (displayed only; no reference check is performed).
modeling_lib.get_stick_breaking_entropy(
    ind_mix_stick_propn_mean, ind_mix_stick_propn_info, gh_loc, gh_weights)

-47.319134678422415

# Expected log cluster probabilities

In [21]:
# Expected log cluster probabilities computed from the stick parameters
# with Gauss-Hermite quadrature.
e_log_cluster_prob = modeling_lib.get_e_log_cluster_probabilities(
    ind_mix_stick_propn_mean, ind_mix_stick_propn_info, gh_loc, gh_weights)

In [22]:
# One row per observation, one column per component (k_approx); check it
# rather than only eyeballing the displayed shape.
assert e_log_cluster_prob.shape == (n_obs, k_approx)
e_log_cluster_prob.shape

(10, 12)

# Beta entropy term

In [23]:
# Variational Beta parameters; the last axis holds the two shape parameters
# (a, b), which are passed to np.random.beta further below.
pop_freq_beta_params = vb_params_dict['pop_freq_beta_params']

In [24]:
# Flatten the (locus, component) axes so each row is one Beta's (a, b) pair.
lk = int(np.prod(pop_freq_beta_params.shape[:2]))
beta_entropy = structure_model_lib.ef.beta_entropy(
    tau=pop_freq_beta_params.reshape((lk, 2)))

In [25]:
# Display the entropy of all Beta factors (checked against a summed Monte Carlo
# estimate below).
beta_entropy

-7.576814396426089

In [26]:
# Expected layout: one (a, b) pair per locus and component.
assert pop_freq_beta_params.shape == (n_loci, k_approx, 2)
pop_freq_beta_params.shape

(5, 12, 2)

In [27]:
# Monte Carlo draws from every Beta factor; shape (num_draws, n_loci, k_approx).
beta_a = pop_freq_beta_params[..., 0]
beta_b = pop_freq_beta_params[..., 1]
beta_samples = np.random.beta(beta_a, beta_b,
                              size=(num_draws, n_loci, k_approx))

In [28]:
# Log density of each draw under its own Beta, summed over components (axis 2)
# and loci (axis 1) and negated: one Monte Carlo entropy sample per draw.
beta_log_density = sp.stats.beta.logpdf(beta_samples,
                                        pop_freq_beta_params[:, :, 0],
                                        pop_freq_beta_params[:, :, 1])
beta_entropy_samples = -beta_log_density.sum(axis=2).sum(axis=1)

In [29]:
# beta_entropy, Monte Carlo mean, and Monte Carlo standard error, one per line.
print(beta_entropy,
      beta_entropy_samples.mean(),
      beta_entropy_samples.std() / np.sqrt(num_draws),
      sep='\n')

-7.576814396426089
-7.557031179979996
0.009450833993950044


In [30]:
# The Monte Carlo mean entropy must be within 3 standard errors of beta_entropy.
beta_entropy_std_err = beta_entropy_samples.std() / np.sqrt(num_draws)
np_test.assert_allclose(beta_entropy_samples.mean(), beta_entropy,
                        atol=3 * beta_entropy_std_err)

In [31]:
# Check get_e_log_beta against Monte Carlo estimates of E[log p] and E[log(1 - p)]

In [32]:
# Elementwise E[log p] and E[log(1 - p)] for each Beta factor.
(e_log_beta,
 e_log_1mbeta) = modeling_lib.get_e_log_beta(tau=pop_freq_beta_params)

In [33]:
# One expectation per (locus, component) pair.
assert e_log_beta.shape == (n_loci, k_approx)
e_log_beta.shape

(5, 12)

In [34]:
# Same layout as e_log_beta: one expectation per (locus, component) pair.
assert e_log_1mbeta.shape == (n_loci, k_approx)
e_log_1mbeta.shape

(5, 12)

In [35]:
# Per-draw log p and log(1 - p); their means over axis 0 are the MC estimates.
e_log_beta_samples, e_log_1mbeta_samples = (np.log(beta_samples),
                                            np.log(1 - beta_samples))

In [36]:
# Absolute gap between get_e_log_beta and the Monte Carlo means, per element.
diff1, diff2 = (np.abs(e_log_beta - e_log_beta_samples.mean(axis=0)),
                np.abs(e_log_1mbeta - e_log_1mbeta_samples.mean(axis=0)))

In [37]:
# Each |get_e_log_beta - MC| gap must lie within a few Monte Carlo standard
# errors, std / sqrt(num_draws). Comparing against the raw per-draw std would be
# about sqrt(num_draws) times too loose to catch real errors. 5 standard errors
# keeps spurious failures approximately negligible even though all
# n_loci * k_approx elements are checked at once.
# np.all(<=) also fails on NaN, which a "count of > == 0" test would let through.
n_std_err = 5
beta_std_err1 = e_log_beta_samples.std(axis=0) / np.sqrt(num_draws)
beta_std_err2 = e_log_1mbeta_samples.std(axis=0) / np.sqrt(num_draws)
assert np.all(diff1 <= n_std_err * beta_std_err1), np.max(diff1 / beta_std_err1)
assert np.all(diff2 <= n_std_err * beta_std_err2), np.max(diff2 / beta_std_err2)

In [38]:
print(diff1.max())  # largest per-element gap for E[log p]

0.009431124266855995


In [39]:
print(diff2.max())  # largest per-element gap for E[log(1 - p)]

0.007277247168444134
