In [2]:
import choice_probabilities_mcmc as cpmcmc
import numpy as np
import seaborn as sns
import scipy.stats as scps
import matplotlib.pyplot as plt
import pandas as pd
import statsmodels as statm

In [3]:
# Seed numpy's global RNG up front so the simulated dataset and the MCMC
# chains below can be reproduced on Restart & Run All.
# NOTE(review): assumes choice_probabilities_mcmc draws from numpy's global
# RNG (np.random / scipy.stats defaults) -- confirm.
SEED = 42
np.random.seed(SEED)
cl = cpmcmc.choice_probabilities_analytic_mh()

In [4]:
# Which trained network to load. The signature, time stamp and model number
# show up in the model directory printed by `cl.model_path`; model_checkpoint
# presumably selects the checkpoint inside that directory (confirm).
MODEL_SETTINGS = {
    'model_num': 0,
    'model_time': '09_03_18_17_28_21',
    'model_signature': '_choice_probabilities_analytic_',
    'model_checkpoint': 'final',
}
for attr_name, attr_value in MODEL_SETTINGS.items():
    setattr(cl, attr_name, attr_value)

In [5]:
# Ground-truth settings for the simulated dataset: v and a (presumably DDM
# drift and boundary separation) and n_samples, the number of simulated
# choices that later cells divide n_choice_lower by.
DATA_SIM_SETTINGS = {'v': -1, 'a': 1, 'n_samples': 1000}
for param_name, param_value in DATA_SIM_SETTINGS.items():
    cl.data_sim_params[param_name] = param_value

In [6]:
# Make paths
# Build the checkpoint and model paths; the model directory displayed below
# contains model_signature, model_time and model_num, so it is presumably
# derived from the cl.model_* settings.
# NOTE(review): the library returns an absolute, user-specific path (see the
# output) -- consider making the base directory configurable.
cl.make_checkpoint_path()
cl.make_model_path()
cl.model_path

'/home/afengler/git_repos/nn_likelihoods/keras_models/dnnregressor_choice_probabilities_analytic_09_03_18_17_28_21/model_0'

In [7]:
# Make dataset
# Simulate the observed data. Later cells read cl.data_sim['n_choice_lower']
# and cl.data_sim['n_samples'], so this presumably fills cl.data_sim from
# cl.data_sim_params -- confirm in choice_probabilities_mcmc.
cl.make_data_set()

datapoint 0 generated
label 0 generated


In [8]:
# Get predictor
# Load the trained Keras network (presumably from cl.model_path); it is
# presumably what method = 'dnn' in the sampler below evaluates -- TODO confirm.
cl.get_dnn_keras()

In [9]:
# Sanity check of the log posterior; the result is a (1, 1) array (see output).
# NOTE(review): the parameter values it is evaluated at are not visible here --
# presumably the simulation parameters; confirm.
cl.get_log_posterior()

array([[-500.40198737]])

In [None]:
# Metropolis-Hastings run using the DNN likelihood (method 'dnn').
# The initial proposal covariance is 0.1 * I (2x2), presumably ordered (v, a).
cl.mcmc_params['n_samples'] = 10000
cl.mcmc_params['cov_init'] = 0.1 * np.eye(2)
dnn_sampler_kwargs = {
    'method': 'dnn',
    'variance_scale_param': 0.5,
    'variance_epsilon': 0.05,
    'write_to_file': True,
    'print_steps': False,
}
cl.metropolis_hastings_custom(**dnn_sampler_kwargs)

0
1000
2000
3000
4000
5000


In [11]:
# Sampler bookkeeping. 'acc_cnt' presumably counts accepted MH proposals, so
# acc_cnt / cl.mcmc_params['n_samples'] would give the acceptance rate.
# NOTE(review): the run above is `In [None]` and its progress output stops at
# 5000 of 10000 -- check the chain finished before interpreting this.
cl.chain_stats

{'acc_cnt': 1401}

In [None]:
def v_a_curve(x = 0.5, sign = 1, v_max = 10, step = 0.01):
    """Return (v, a) pairs that map onto a given lower-choice proportion.

    For every drift ``v`` on the grid ``np.arange(step * sign, v_max * sign,
    step * sign)`` the boundary separation is ``a = log((1 - x) / x) / v``.
    Presumably this inverts P(lower) = 1 / (1 + exp(v * a)) of an unbiased
    DDM with unit noise -- confirm against the simulator's parameterisation.

    Parameters
    ----------
    x : float
        Observed proportion of lower-boundary choices; must lie strictly in
        (0, 1). At x = 1 the log-odds are -inf (numpy warns), and a plain
        Python x = 0 raises ZeroDivisionError.
    sign : int, default 1
        Direction of the drift grid (+1 or -1). ``a`` is positive only when
        ``sign`` matches the sign of ``log((1 - x) / x)``, i.e. use -1 when
        x > 0.5.
    v_max : float, default 10
        Magnitude of the (exclusive) end of the drift grid.
    step : float, default 0.01
        Spacing of the drift grid.

    Returns
    -------
    pandas.DataFrame
        Columns ``['v', 'a']`` with one row per grid point and a RangeIndex.
        The defaults reproduce the original 999-point grid.

    Raises
    ------
    ValueError
        If ``sign`` is 0 (the grid would have a zero step).
    """
    if sign == 0:
        raise ValueError('sign must be non-zero (use +1 or -1)')
    v = np.arange(step * sign, v_max * sign, step * sign)
    # Vectorised over the whole grid: replaces 999 row-wise .loc writes and
    # drops the hard-coded row count, which left zero rows (or appended extra
    # ones) whenever the grid length was not exactly 999.
    a = np.log((1 - x) / x) / v
    return pd.DataFrame({'v': v, 'a': a}, columns = ['v', 'a'])

def v_a_curve_prime(x = 0.5, v_star = 1, v_max = 10, step = 0.01):
    """Return the tangent line to the ``v_a_curve`` at drift ``v_star``.

    With L = log((1 - x) / x) the curve is a(v) = L / v, whose tangent at
    v_star is a(v) = -(L / v_star**2) * v + 2 * L / v_star. The line is
    evaluated on ``np.arange(step * s, v_max * s, step * s)`` with
    s = sign(v_star).

    Parameters
    ----------
    x : float
        Observed proportion of lower-boundary choices, strictly in (0, 1).
    v_star : float, default 1
        Drift at which the tangent touches the curve; must be non-zero.
    v_max : float, default 10
        Magnitude of the (exclusive) end of the drift grid.
    step : float, default 0.01
        Spacing of the drift grid.

    Returns
    -------
    pandas.DataFrame
        Columns ``['v', 'a']`` with one row per grid point and a RangeIndex.
        The defaults reproduce the original 999-point grid.

    Raises
    ------
    ValueError
        If ``v_star`` is 0 (a = L / v has no tangent there).
    """
    if v_star == 0:
        raise ValueError('v_star must be non-zero: the curve a = L / v has no tangent at v = 0')
    direction = np.sign(v_star)
    v = np.arange(step * direction, v_max * direction, step * direction)
    log_odds = np.log((1 - x) / x)
    # Slope is d(L / v)/dv = -L / v**2 at v_star; the intercept 2 * L / v_star
    # makes the line pass through (v_star, L / v_star).
    a = -(log_odds / np.power(v_star, 2)) * v + 2 * log_odds / v_star
    return pd.DataFrame({'v': v, 'a': a}, columns = ['v', 'a'])

In [None]:
# Posterior of (v, a) under the DNN likelihood, with the curve of (v, a)
# pairs matching the observed lower-choice proportion drawn in red.
chain_nn = cl.chain.copy()
# Explicit sample index for the trace plots. `Index.get_values()` was
# deprecated in pandas 0.25 and removed in 1.0; np.asarray works everywhere.
chain_nn['id'] = np.asarray(chain_nn.index)

p_lower_nn = cl.data_sim['n_choice_lower'] / cl.data_sim['n_samples']
# a = log((1 - x) / x) / v is positive only when v has the sign of the
# log-odds, i.e. v < 0 when more than half of the choices are lower. Deriving
# the sign from the data keeps the curve valid if data_sim_params['v'] is
# changed (a hard-coded sign = -1 only suits negative-drift simulations).
curve_sign_nn = -1 if p_lower_nn > 0.5 else 1
curve = v_a_curve(x = p_lower_nn, sign = curve_sign_nn)
# Clip the curve to the sampled region so it does not stretch the KDE axes.
in_chain_box = ((curve['a'] > chain_nn['a'].max() * 0 + chain_nn['a'].min())
                & (curve['a'] < chain_nn['a'].max())
                & (curve['v'] > chain_nn['v'].min())
                & (curve['v'] < chain_nn['v'].max()))
curve = curve.loc[in_chain_box].copy()

# x / y passed as keywords: seaborn >= 0.12 rejects positional data variables.
g = sns.jointplot(x = 'v', y = 'a', data = chain_nn, kind = 'kde', space = 0, color = 'g')
g.ax_joint.plot(curve['v'], curve['a'], 'r-')
g.ax_marg_x.set_title('Posterior of (v, a): DNN likelihood')
plt.show()

In [None]:
# Trace plot of v for the DNN-likelihood chain; the title and axis label make
# the notebook's four otherwise look-alike trace plots distinguishable.
ax = sns.lineplot(x = 'id', y = 'v', data = chain_nn)
ax.set_title('Trace of v: DNN likelihood')
ax.set_xlabel('sample index')
plt.show()

In [None]:
# Trace plot of a for the DNN-likelihood chain; the title and axis label make
# the notebook's four otherwise look-alike trace plots distinguishable.
ax = sns.lineplot(x = 'id', y = 'a', data = chain_nn)
ax.set_title('Trace of a: DNN likelihood')
ax.set_xlabel('sample index')
plt.show()

In [None]:
# Reference run: sample the posterior with the analytic WFPT likelihood.
# Priors: v ~ Normal(0, 10^2), a ~ Uniform(0, 10).
# NOTE(review): cl.priors is only assigned here, so the DNN chain above ran
# with whatever priors `cl` had by default -- confirm both runs use the same
# priors before comparing the two posteriors.
cl.priors = dict(v = scps.norm(loc = 0, scale = 10),
                 a = scps.uniform(loc = 0, scale = 10))
wfpt_mcmc_settings = (('n_samples', 50000), ('cov_init', 0.1 * np.eye(2)))
for setting_name, setting_value in wfpt_mcmc_settings:
    cl.mcmc_params[setting_name] = setting_value
cl.metropolis_hastings(method = 'wfpt')

In [None]:
# Posterior of (v, a) under the analytic WFPT likelihood, with the curve of
# (v, a) pairs matching the observed lower-choice proportion drawn in red.
chain_wfpt = cl.chain.copy()
# Explicit sample index for the trace plots. `Index.get_values()` was
# deprecated in pandas 0.25 and removed in 1.0; np.asarray works everywhere.
chain_wfpt['id'] = np.asarray(chain_wfpt.index)

p_lower_wfpt = cl.data_sim['n_choice_lower'] / cl.data_sim['n_samples']
# a = log((1 - x) / x) / v is positive only when v has the sign of the
# log-odds, i.e. v < 0 when x > 0.5. Relying on the default sign = +1 puts
# the whole curve at v > 0, a < 0 whenever x > 0.5 (presumably the case for
# the v = -1 simulation), so the clipping below would typically remove it
# and no reference line would be drawn.
curve_sign_wfpt = -1 if p_lower_wfpt > 0.5 else 1
curve = v_a_curve(x = p_lower_wfpt, sign = curve_sign_wfpt)
# Clip the curve to the sampled region on all four sides so it does not
# stretch the KDE axes.
in_chain_box = ((curve['a'] > chain_wfpt['a'].min())
                & (curve['a'] < chain_wfpt['a'].max())
                & (curve['v'] > chain_wfpt['v'].min())
                & (curve['v'] < chain_wfpt['v'].max()))
curve = curve.loc[in_chain_box].copy()

# x / y passed as keywords: seaborn >= 0.12 rejects positional data variables.
g = sns.jointplot(x = 'v', y = 'a', data = chain_wfpt, kind = 'kde', space = 0, color = 'g')
g.ax_joint.plot(curve['v'], curve['a'], 'r-')
g.ax_marg_x.set_title('Posterior of (v, a): analytic WFPT likelihood')
plt.show()

In [None]:
# Shape (rows, columns) of the chain produced by the WFPT run above;
# presumably one row per stored MCMC sample -- confirm.
cl.chain.shape

In [None]:
# Trace plot of v for the analytic-WFPT chain; the title and axis label make
# the notebook's four otherwise look-alike trace plots distinguishable.
ax = sns.lineplot(x = 'id', y = 'v', data = chain_wfpt)
ax.set_title('Trace of v: analytic WFPT likelihood')
ax.set_xlabel('sample index')
plt.show()

In [None]:
# Trace plot of a for the analytic-WFPT chain; the title and axis label make
# the notebook's four otherwise look-alike trace plots distinguishable.
ax = sns.lineplot(x = 'id', y = 'a', data = chain_wfpt)
ax.set_title('Trace of a: analytic WFPT likelihood')
ax.set_xlabel('sample index')
plt.show()