In [3]:
import autograd 

import autograd.numpy as np
import autograd.scipy as sp
from numpy.polynomial.hermite import hermgauss

import numpy.testing as np_test

import sys
sys.path.insert(0, '../libraries/')
sys.path.insert(0, '../../BNP_modeling/')

import structure_model_lib
import modeling_lib
import data_utils

# Draw simulated data

In [4]:
# Seed NumPy's global RNG so the random draws made in later cells
# (e.g. np.random.choice for the simulated observations) are reproducible.
rng_seed = 25465
np.random.seed(rng_seed)

In [5]:
# Size of the simulated problem.
n_obs, n_loci = 10, 5

# Placeholder observations drawn uniformly at random (not sampled from the
# model): an (n_obs, n_loci) integer matrix with entries in {0, 1, 2}.
g_obs_int = np.random.choice(3, size=(n_obs, n_loci))

# One-hot encode the three categories along a new trailing axis.
g_obs = data_utils.get_one_hot(g_obs_int, 3)
assert g_obs.shape == (n_obs, n_loci, 3)

# Get default prior parameters

In [6]:
# Default prior parameter values, plus the paragami pattern describing them.
# Printing the pattern lists each prior parameter with its shape and bounds.
prior_params_dict, prior_params_paragami = (
    structure_model_lib.get_default_prior_params())
print(prior_params_paragami)

OrderedDict:
	[dp_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_beta] = NumericArrayPattern (1,) (lb=0.0, ub=inf)


# Get VB parameters

In [7]:
# Number of components in the variational approximation. In the printed
# pattern, pop_freq_beta_params has k_approx entries along its second axis
# and the ind_mix_stick_propn_* arrays have k_approx - 1 columns.
k_approx = 12
# Keep only the pattern; the first return value is discarded.
_, vb_params_paragami = structure_model_lib.get_vb_params_paragami_object(
    n_obs, n_loci, k_approx)
print(vb_params_paragami)

OrderedDict:
	[pop_freq_beta_params] = NumericArrayPattern (5, 12, 2) (lb=0.0, ub=inf)
	[ind_mix_stick_propn_mean] = NumericArrayPattern (10, 11) (lb=-inf, ub=inf)
	[ind_mix_stick_propn_info] = NumericArrayPattern (10, 11) (lb=0.0001, ub=inf)


In [8]:
# Arbitrary VB parameter values drawn from the pattern (not fitted values).
vb_params_dict = vb_params_paragami.random()

In [11]:
# Unpack the prior parameters into individual names.
dp_prior_alpha, allele_prior_alpha, allele_prior_beta = (
    prior_params_dict[key]
    for key in ('dp_prior_alpha', 'allele_prior_alpha', 'allele_prior_beta'))

# Unpack the VB parameters into individual names.
ind_mix_stick_propn_mean, ind_mix_stick_propn_info, pop_freq_beta_params = (
    vb_params_dict[key]
    for key in ('ind_mix_stick_propn_mean', 'ind_mix_stick_propn_info',
                'pop_freq_beta_params'))

# Expectations of log(p) and log(1 - p) under the beta distributions
# parameterized by pop_freq_beta_params.
e_log_p, e_log_1mp = modeling_lib.get_e_log_beta(pop_freq_beta_params)

In [12]:
# Gauss-Hermite quadrature nodes and weights, passed to the KL and the
# conditional log-likelihood computations below.
gh_deg = 8
gh_loc, gh_weights = hermgauss(deg=gh_deg)


In [13]:
def get_kl_from_z_nat_param(g_obs, vb_params_dict, prior_params_dict,
                            gh_loc, gh_weights,
                            z_nat_param):
    """Evaluate the KL with ``e_z`` derived from ``z_nat_param``.

    ``z_nat_param`` is mapped to ``e_z`` through
    ``structure_model_lib.get_z_opt_from_loglik_cond_z`` and then passed
    explicitly to ``structure_model_lib.get_kl``. This exposes the KL as a
    function of ``z_nat_param`` (positional index 5), which is what the
    ``autograd.grad(..., argnum=5)`` call differentiates.

    Returns:
        Whatever ``structure_model_lib.get_kl`` returns for these inputs.
    """
    return structure_model_lib.get_kl(
        g_obs, vb_params_dict, prior_params_dict, gh_loc, gh_weights,
        e_z=structure_model_lib.get_z_opt_from_loglik_cond_z(z_nat_param))


In [14]:
# Gradient of the KL with respect to z_nat_param, which is argument index 5
# in get_kl_from_z_nat_param's signature.
kl_z_nat_param_grad = autograd.grad(
    get_kl_from_z_nat_param,
    argnum=5)

In [15]:
# z natural parameters computed by get_loglik_cond_z at the current VB
# parameters. The gradient check below is evaluated at this point.
z_opt_nat_param = structure_model_lib.get_loglik_cond_z(
    g_obs, e_log_p, e_log_1mp,
    ind_mix_stick_propn_mean, ind_mix_stick_propn_info,
    gh_loc, gh_weights)

In [16]:
# Gradient of the KL with respect to the z natural parameters,
# evaluated at z_opt_nat_param.
grad = kl_z_nat_param_grad(
    g_obs, vb_params_dict, prior_params_dict,
    gh_loc, gh_weights,
    z_opt_nat_param)

In [17]:
# Check that the KL gradient with respect to the z natural parameters is
# numerically zero at z_opt_nat_param. A bare comparison only displays False
# and lets Restart & Run All finish, so assert it explicitly (using the
# numpy.testing import). Keep the comparison as the last expression so the
# cell still displays the boolean.
grad_tol = 1e-8
max_abs_grad = np.max(np.abs(grad))
np_test.assert_array_less(max_abs_grad, grad_tol)
max_abs_grad < grad_tol

True