In [28]:
import choice_probabilities_mcmc as cpmcmc
import numpy as np
import seaborn as sns
import scipy.stats as scps
import matplotlib.pyplot as plt
import pandas as pd
import statsmodels as statm
from statsmodels.tsa.stattools import acf

In [29]:
cl = cpmcmc.choice_probabilities_analytic_mh()

In [30]:
# Model parameters
cl.model_num = 0
cl.model_time = '09_03_18_17_28_21'
cl.model_signature = '_choice_probabilities_analytic_'
cl.model_checkpoint = 'final'

In [31]:
# Data simulation parameters
cl.data_sim_params['v'] = -1
cl.data_sim_params['a'] = 1
cl.data_sim_params['n_samples'] = 1000

In [32]:
# Make paths
cl.make_checkpoint_path()
cl.make_model_path()
cl.model_path

'/home/afengler/git_repos/nn_likelihoods/keras_models/dnnregressor_choice_probabilities_analytic_09_03_18_17_28_21/model_0'

In [33]:
# Make dataset
cl.make_data_set()

datapoint 0 generated
label 0 generated


In [34]:
# Get predictor
cl.get_dnn_keras()

In [35]:
cl.get_log_posterior()

array([[-508.76459733]])

In [36]:
cl.mcmc_params['n_samples'] = 1000
cl.mcmc_params['cov_init'] = np.array([[0.1, 0.0], [0, 0.1]])
my_chain, acc_samples = cl.metropolis_hastings_custom(method = 'dnn',
                                                      variance_scale_param = 0.4,
                                                      variance_epsilon = 0.05,
                                                      write_to_file = True,
                                                      print_steps = False)

0


In [37]:
cl.chain.loc[cl.chain['log_posterior'].idxmax()]

v               -0.991700
a                1.106986
log_posterior   -6.470191
Name: 51, dtype: float64

In [38]:
# Get autocorrelation
chain_autocorrelations = acf(cl.chain['a'], nlags = 80)
n_eff_samples = cl.mcmc_params['n_samples'] / (1 + 2 * np.sum(chain_autocorrelations))

# N effective samples
n_eff_samples

20.349684518467303

In [None]:
def v_a_curve(x = 0.5, sign = 1):
    curve = pd.DataFrame(np.zeros((999, 2)), columns = ['v', 'a'])
    cnt = 0
    for v_tmp in np.arange(0.01 * sign, 10 * sign, 0.01 * sign):
        a_tmp = np.log((1 - x) / x) / v_tmp
        curve.loc[cnt] = [v_tmp, a_tmp]
        cnt += 1
    return curve

def v_a_curve_prime(x = 0.5, v_star = 1):
    curve_prime = pd.DataFrame(np.zeros((999, 2)), columns = ['v', 'a'])
    cnt = 0
    for v_tmp in np.arange(0.01 * np.sign(v_star), 10 * np.sign(v_star), 0.01 * np.sign(v_star)):
        a_tmp = - (np.log((1 - x) / x) / np.power(v_star, 2)) * v_tmp + (2 * (np.log((1-x)/x)) / v_star)
        curve_prime.loc[cnt] = [v_tmp, a_tmp]
        cnt += 1
        
    return curve_prime

In [None]:
chain_nn = cl.chain.copy()
chain_nn['id'] = chain_nn.index.get_values()

curve = v_a_curve(x = cl.data_sim['n_choice_lower'] / cl.data_sim['n_samples'], 
                  sign = -1)
curve = curve.loc[curve['a'] < np.max(chain_nn['a'])].copy()
curve = curve.loc[curve['a'] > np.min(chain_nn['a'])].copy()
curve = curve.loc[curve['v'] < np.max(chain_nn['v'])].copy()
curve = curve.loc[curve['v'] > np.min(chain_nn['v'])].copy()


#curve_prime = v_a_curve_prime(x = cl.data_sim['n_choice_lower'] / cl.data_sim['n_samples'], v_star = 1)
#curve_prime = curve_prime.loc[curve_prime['a'] < np.max(chain_nn['a'])].copy()
#curve_prime = curve_prime.loc[curve_prime['v'] < np.max(chain_nn['v'])].copy()

g = sns.jointplot('v', 'a', data = chain_nn, kind = 'kde', space = 0, color = 'g')
x0, x1 = g.ax_joint.get_xlim()
y0, y1 = g.ax_joint.get_ylim()
lims = [max(x0, y0), min(x1, y1)]
g.ax_joint.plot(curve['v'], curve['a'], 'r-')
#g.ax_joint.plot(curve_prime['v'], curve_prime['a'], 'b-')
plt.show()

In [None]:
# trace plot v
ax = sns.lineplot(x = 'id', y = 'v', data = chain_nn)

In [None]:
# trace plot a
ax = sns.lineplot(x = 'id', y = 'a', data = chain_nn)

In [None]:
# Now sample from actual model
cl.priors = {'v': scps.norm(loc = 0, scale = 10),
             'a': scps.uniform(loc = 0, scale = 10)}

cl.mcmc_params['n_samples'] = 50000
cl.mcmc_params['cov_init'] = np.array([[0.1, 0.0], [0, 0.1]])
cl.metropolis_hastings(method = 'wfpt')

In [None]:
chain_wfpt = cl.chain.copy()
chain_wfpt['id'] = chain_wfpt.index.get_values()

curve = v_a_curve(x = cl.data_sim['n_choice_lower'] / cl.data_sim['n_samples'])
curve = curve.loc[curve['a'] < np.max(chain_wfpt['a'])].copy()
curve = curve.loc[curve['v'] < np.max(chain_wfpt['v'])].copy()

# Drawing posterior plot with 
g = sns.jointplot('v', 'a', data = chain_wfpt, kind = 'kde', space = 0, color = 'g')
x0, x1 = g.ax_joint.get_xlim()
y0, y1 = g.ax_joint.get_ylim()
lims = [max(x0, y0), min(x1, y1)]
g.ax_joint.plot(curve['v'], curve['a'], 'r-')
plt.show()

In [None]:
cl.chain.shape

In [None]:
# trace plot v
ax = sns.lineplot(x = 'id', y = 'v', data = chain_wfpt)

In [None]:
# trace plot a
ax = sns.lineplot(x = 'id', y = 'a', data = chain_wfpt)

In [None]:
cl.chain.loc[cl.chain['log_posterior'].idxmax()][]

In [44]:
my_chain.loc[3]

v                  0.074192
a                  0.354625
log_posterior   -143.343700
Name: 3, dtype: float64

In [None]:
# Run experiment: Parameter recovery with MAP for DNN vs. NF_Likelihood

# Make sampler instance
cl2 = cpmcmc.choice_probabilities_analytic_mh()

# Model parameters
cl2.model_num = 0
cl2.model_time = '09_03_18_17_28_21'
cl2.model_signature = '_choice_probabilities_analytic_'
cl2.model_checkpoint = 'final'

# Make paths
cl2.make_checkpoint_path()
cl2.make_model_path()
cl2.model_path

# Attach DNN
cl2.get_dnn_keras()

In [None]:
# Main experiment 

# Main specification of experiment parameters
n_experiments = 150
cl.data_sim_params['n_samples'] = 5000
cl.mcmc_params['n_samples'] = 50000
cnt = 0
model_types = ['dnn', 'wfpt']

# Storage data
exp_chains = []
exp_data = pd.DataFrame(columns = ['experiment_id', 
                                   'data_v', 
                                   'data_a',
                                   'dnn_n_eff_samples',
                                   'dnn_map_loglik', 
                                   'dnn_map_a', 
                                   'dnn_map_v', 
                                   'nf_n_eff_samples',
                                   'nf_map_loglik',
                                   'nf_map_a',
                                   'nf_map_v', 
                                   ])

data_id = 0
while cnt < n_experiments:
        # Sample parameters for simulation
        v_tmp = np.random.uniform([-2, 2])
        a_tmp = np.random.uniform([0.5, 3])

        # Print info:
        print('Experiment: ', cnt)
        print('Model Type: ', model)
        print('v: ', v_tmp)
        print('a: ', a_tmp)

        # Data simulation parameters
        cl.data_sim_params['v'] = v_tmp
        cl.data_sim_params['a'] = a_tmp
        cl.data_sim_params['n_samples'] = n_samples

        # Make dataset
        cl.make_data_set()

        chain_dnn, _ = cl.metropolis_hastings_custom(method = 'dnn',
                                                              variance_scale_param = 0.4,
                                                              variance_epsilon = 0.05,
                                                              write_to_file = True,
                                                              print_steps = False)
        
        chain_wfpt, _ = cl.metropolis_hastings_custom(method = 'wfpt',
                                                               variance_scale_param = 0.4,
                                                               variance_epsilon = 0.05,
                                                               write_to_file = True,
                                                               print_steps = False)
        
        # Get number of effective samples
        # ------------------------------------------
        
        # dnn
        chain_autocorrelations_dnn = acf(chain_dnn['a'], nlags = 80)
        n_eff_samples_dnn = cl.mcmc_params['n_samples'] / (1 + 2 * np.sum(chain_autocorrelations_dnn))

        # wfpt
        chain_autcorrelations_wfpt = acf(chain_wfpt['a'], nlags = 80)
        n_eff_samples_wfpt = cl.mcmc_params['n_samples'] / (1 + 2 * np.sum(chain_autocorrelations_wfpt))
        
        
        # ------------------------------------------
        
        # Compute map
        # ------------------------------------------
        
        # dnn
        map_dnn = chain_dnn.loc[chain_dnn['log_posterior'].idxmax()]
        
        # wfpt
        map_wfpt = chain_wfpt.loc[chain_wfpt['log_posterior'].idxmax()]
        
        # ------------------------------------------
        
        # Store data
        # ------------------------------------------
        exp_data.loc[data_id]['experiment_id'] = cnt
        exp_data.loc[data_id]['data_v'] = cl.data_sim_params['v']
        exp_data.loc[data_id]['data_a'] = cl.data_sim_params['a']
        exp_data.loc[data_id]['dnn_n_eff_samples'] = n_eff_samples_dnn
        exp_data.loc[data_id]['dnn_map_loglik'] = map_dnn['log_posterior']
        exp_data.loc[data_id]['dnn_map_a'] = map_dnn['a'] 
        exp_data.loc[data_id]['dnn_map_v'] = map_dnn['v']
        exp_data.loc[data_id]['nf_n_eff_samples'] = n_eff_samples_wfpt
        exp_data.loc[data_id]['nf_map_loglik'] = map_wfpt['log_posterior']
        exp_data.loc[data_id]['nf_map_a'] = map_wfpt['a']
        exp_data.loc[data_id]['nf_map_v'] = map_wfpt['v']
        # ------------------------------------------
        
        # Store data
        exp_data.to_csv()
        chain_dnn.to_csv()
        chain_wfpt.to_csv()
        data_id += 1
        
    cnt += 1
    

In [46]:
list(my_chain.loc[my_chain['log_posterior'].idxmax()])
        

[-0.9916995913404392, 1.1069856964425269, -6.47019096120014]

In [47]:
my_chain.loc[my_chain['log_posterior'].idxmax()]

v               -0.991700
a                1.106986
log_posterior   -6.470191
Name: 51, dtype: float64

In [48]:
my_chain.loc[1]['v']  = 10

In [49]:
my_chain

Unnamed: 0,v,a,log_posterior
0,0.100000,1.000000,-164.629249
1,10.000000,1.000000,-164.629249
2,0.100000,1.000000,-164.629249
3,0.074192,0.354625,-143.343700
4,-0.285528,0.247464,-119.384262
5,-0.285528,0.247464,-119.384262
6,-0.317472,0.811239,-80.916285
7,-0.449896,1.398384,-28.784087
8,-0.557401,2.341446,-10.849456
9,-1.297820,0.926523,-7.536741
