# HDDM models results summary

Imports

In [None]:
from cmdstanpy import CmdStanModel
import os
import numpy as np
import pandas as pd
from datetime import datetime
import pickle
import cmdstanpy
from scipy import stats
import seaborn as sns
import arviz as az
import matplotlib.pyplot as plt
from scipy.stats import halfnorm, norm
import math

## Read the MCMC fit object from csv files

In [None]:
# Results directory root; the fit is read from f'{path}/{model}/'.
path = '../plgrid_results/cond_models/sonata'
# Alternative result roots:
# path = '../plgrid_results/adapt/results'
# path = '../plgrid_results/ncond_models/stahl/acc'
model = 'drift_boundary_pre2_tbb_normal'

In [None]:
# Rebuild the CmdStan fit from its CSV output and wrap it for ArviZ.
fit = cmdstanpy.from_csv(f"{path}/{model}/")
inference_data = az.from_cmdstanpy(posterior=fit)

In [None]:
# print(fit.diagnose())

In [None]:
# `inference_data` is already created from `fit` in the loading cell
# (az.from_cmdstanpy(posterior=fit)); converting the same fit again is redundant.

## Results summary

### Save draws

In [None]:
# fit_df = fit.draws_pd()
# fit_df.to_csv(f'{path}/{model}/results/{model}_samples.csv')

### Create and save the summary

Summary columns: mean, standard deviation, percentiles (2nd and 98th), effective sample size, R-hat, and Savage–Dickey Bayes factors (BF01).

In [None]:
def calculate_bayes_factor_participant_level(
        fit_df,
        parameters_list,
        prior_distribution,
        hierarchical=False,
        n_prior_samples=100000,
):
    """Compute BF01 (evidence for a zero effect) for each participant-level parameter.

    The prior density at 0 is estimated once, from a KDE of draws from the
    hierarchical prior, and shared by every parameter in ``parameters_list``.

    Parameters
    ----------
    fit_df : pandas.DataFrame
        Posterior draws, one column per parameter.
    parameters_list : list of str
        Column names of the participant-level parameters to evaluate.
    prior_distribution : dict or list of dict
        Hyper-prior specification passed to ``calculate_hierarchical_prior_kde``.
    hierarchical : bool, default False
        Not used; kept so existing call sites keep working.
    n_prior_samples : int, default 100000
        Number of prior draws (per hyper-prior) used to build the prior KDE.

    Returns
    -------
    numpy.ndarray
        1-D array of BF01 values, in the order of ``parameters_list``.
    """
    if len(parameters_list) == 0:
        # Nothing to evaluate: skip the expensive prior sampling and KDE fit.
        return np.array([])

    prior_density_at_zero = calculate_hierarchical_prior_kde(prior_distribution, N=n_prior_samples)(0)

    bf_01_values = [
        calculate_bayes_factor(fit_df, parameter, prior_density_at_zero)
        for parameter in parameters_list
    ]
    # Each value is a length-1 array (KDE evaluated at 0), so flatten to 1-D.
    return np.array(bf_01_values).flatten()

In [None]:
def calculate_hierarchical_prior_kde(hyper_prior_distributions, N = 100000):
    """Fit a KDE to the implied prior of a participant-level effect.

    For each hyper-prior, a group mean is drawn from a normal, a group SD is drawn
    from the SD hyper-prior, and then one participant-level value is drawn from
    Normal(mean, sd) per (mean, sd) pair.

    Parameters
    ----------
    hyper_prior_distributions : dict or list of dict
        Each dict has
        ``'mean': {'loc': float, 'scale': float}`` (normal hyper-prior on the mean) and
        ``'sd': (sd_type, params)``. With ``sd_type == 'gamma'``, ``params['shape']``
        and ``params['scale']`` parameterise a gamma. Any other type is sampled as a
        half-normal with ``params['shape']`` used as the location and
        ``params['scale']`` as the scale. When a list is given, the samples from all
        hyper-priors are pooled into one KDE.
    N : int, default 100000
        Number of draws per hyper-prior.

    Returns
    -------
    scipy.stats.gaussian_kde
        Density estimate of the pooled prior samples.
    """
    if isinstance(hyper_prior_distributions, list):
        hyper_priors = hyper_prior_distributions
    else:
        hyper_priors = [hyper_prior_distributions]

    sample_chunks = []
    for hyper_prior in hyper_priors:
        mean_hyper_prior = hyper_prior['mean']
        sd_type, sd_hyper_prior = hyper_prior['sd']

        mean_samples = np.random.normal(loc=mean_hyper_prior['loc'], scale=mean_hyper_prior['scale'], size=(N,))
        if sd_type == 'gamma':
            sd_samples = np.random.gamma(shape=sd_hyper_prior['shape'], scale=sd_hyper_prior['scale'], size=(N,))
        else:
            # Half-normal; the spec stores its location under the 'shape' key.
            sd_samples = halfnorm.rvs(loc=sd_hyper_prior['shape'], scale=sd_hyper_prior['scale'], size=(N,))

        # Vectorised: one Normal(mean_i, sd_i) draw per pair, instead of N scalar calls.
        sample_chunks.append(np.random.normal(loc=mean_samples, scale=sd_samples))

    prior_samples = np.concatenate(sample_chunks)

    # Prior density of hierarchical effect parameters
    return stats.gaussian_kde(prior_samples)

In [None]:
def calculate_bayes_factor(fit_df, parameter, prior_kde):
    """Return the Savage-Dickey BF01 (evidence for the null) for one parameter.

    BF01 = posterior density at 0 / prior density at 0. The posterior density is
    a Gaussian KDE of the draws in ``fit_df[parameter]``. Despite its name,
    ``prior_kde`` is the prior density value at 0, not a KDE object.

    Returns the result of dividing a length-1 KDE evaluation by ``prior_kde``.
    """
    posterior_kde = stats.gaussian_kde(fit_df[parameter].to_numpy())
    posterior_density_at_zero = posterior_kde(0)
    return posterior_density_at_zero / prior_kde

In [None]:
def get_summary_with_bayes_factor(fit, priors_dict, variables_to_track):
    """Return ``fit.summary(percentiles=(2, 98))`` with an added ``Bayes_factor`` column.

    For every entry of ``priors_dict`` whose variable exists in the fit, BF01 is
    computed per flattened column (e.g. ``participants_alpha_ne[3]``). Rows without
    a prior keep NaN in ``Bayes_factor``.

    Parameters
    ----------
    fit : cmdstanpy fit
        Needs ``draws_pd()`` and ``summary(percentiles=...)``.
    priors_dict : dict
        ``{variable_name: [prior_spec, is_hierarchical]}``.
        - Hierarchical entries: ``prior_spec`` is a hyper-prior spec for
          ``calculate_hierarchical_prior_kde``. Only columns from
          ``variables_to_track`` are used.
        - Non-hierarchical entries: ``prior_spec`` is the prior density at 0,
          or a list of densities, which is averaged.
    variables_to_track : list of str
        Flattened column names eligible for participant-level Bayes factors.
    """
    fit_df = fit.draws_pd()
    bayes_factors = dict()

    def base_name(column):
        # 'participants_alpha_ne[3]' -> 'participants_alpha_ne'; exact matching on the
        # base name avoids 'participants_alpha_ne' also capturing
        # 'participants_alpha_ne_pre_acc[...]' (substring match made results order-dependent).
        return column.split('[', 1)[0]

    for parameter_name, (prior_spec, is_hierarchical) in priors_dict.items():
        if is_hierarchical:
            parameters_list = [variable for variable in variables_to_track if base_name(variable) == parameter_name]
            if not parameters_list:
                # Variable not in this model: skip (and avoid the costly prior KDE).
                continue

            participants_bf_01 = calculate_bayes_factor_participant_level(
                fit_df,
                parameters_list=parameters_list,
                prior_distribution=prior_spec,
                hierarchical=True
            )
            bayes_factors.update(zip(parameters_list, participants_bf_01))
        else:
            columns = [column for column in fit_df.columns if base_name(column) == parameter_name]
            if not columns:
                # Previously a KeyError here discarded every Bayes factor.
                continue

            prior_density = np.mean(prior_spec) if isinstance(prior_spec, list) else prior_spec
            # Scalars give one column ('alpha_ne'); vectors give 'alpha_ern[1]', ...
            for column in columns:
                bf_01 = calculate_bayes_factor(
                    fit_df,
                    parameter=column,
                    prior_kde=prior_density,
                )
                bayes_factors[column] = bf_01[0]

    summary_df = fit.summary(percentiles=(2, 98))
    bayes_factors_df = pd.DataFrame.from_dict(bayes_factors, orient='index', columns=['Bayes_factor'])
    result_df = pd.concat([summary_df, bayes_factors_df], axis=1)

    return result_df

Define the priors used to compute the Bayes factors

In [None]:
# Priors for the Bayes factors in get_summary_with_bayes_factor (model = drift_boundary).
# Each value is [prior_spec, is_hierarchical]:
#   * is_hierarchical=True  -> prior_spec is a hyper-prior dict, or a list of them (pooled):
#       'mean': {'loc', 'scale'} of a normal hyper-prior on the group mean;
#       'sd': (type, params) -- 'gamma' uses params 'shape'/'scale'; any other type
#             is sampled as a half-normal with params['shape'] used as its LOCATION.
#   * is_hierarchical=False -> prior_spec is the prior density at 0 (a float), or a
#       list of densities that get_summary_with_bayes_factor averages.
priors_dict = {
    'participants_alpha_cond':[{
        'mean': {'loc':0, "scale":1},
        'sd': ('gamma', {'shape':1, "scale":1}),
    }, True],
    'participants_alpha_ne':[{
        'mean': {'loc':0, "scale":0.2},
        'sd': ('normal', {'shape':0, "scale":.5}),
    }, True],
    'participants_alpha_ne_pre_acc':[{
        'mean': {'loc':0, "scale":0.2},
        'sd': ('normal', {'shape':0, "scale":.5}),
    }, True],
    'participants_alpha_ern':[[{
        'mean': {'loc':0, "scale":0.2},
        'sd': ('normal', {'shape':0, "scale":.5}),
    }, {
        'mean': {'loc':0, "scale":0.2},
        'sd': ('normal', {'shape':0, "scale":.5}),
    }], True],
    'participants_alpha_crn':[[{
        'mean': {'loc':0, "scale":0.2},
        'sd': ('normal', {'shape':0, "scale":.5}),
    }, {
        'mean': {'loc':0, "scale":0.2},
        'sd': ('normal', {'shape':0, "scale":.5}),
    }], True],
    'participants_delta_cond': [{
        'mean': {'loc':0, "scale":2},
        'sd': ('gamma', {'shape':1, "scale":1}),
    }, True],
    'participants_delta_ne': [{
        'mean': {'loc':0, "scale":.5},
        'sd': ('normal', {'shape':0, "scale":.5}),
    }, True],
    'participants_delta_ne_pre_acc':[{
        'mean': {'loc':0, "scale":0.5},
        'sd': ('normal', {'shape':0, "scale":.5}),
    }, True],
    # NOTE(review): mean scale 0.2 here, but the population-level 'delta_ern' prior
    # below uses scale 0.5 -- confirm which one matches the Stan model.
    'participants_delta_ern':[[{
        'mean': {'loc':0, "scale":0.2},
        'sd': ('normal', {'shape':0, "scale":.5}),
    }, {
        'mean': {'loc':0, "scale":0.2},
        'sd': ('normal', {'shape':0, "scale":.5}),
    }], True],
    # NOTE(review): same 0.2 vs 0.5 mismatch with 'delta_crn' below -- confirm.
    'participants_delta_crn':[[{
        'mean': {'loc':0, "scale":0.2},
        'sd': ('normal', {'shape':0, "scale":.5}),
    }, {
        'mean': {'loc':0, "scale":0.2},
        'sd': ('normal', {'shape':0, "scale":.5}),
    }], True],
    # Population-level effects: density of N(0, scale) evaluated at 0.
    'alpha_cond': [stats.norm.pdf(0, loc=0, scale=1), False],
    'alpha_ne': [stats.norm.pdf(0, loc=0, scale=0.2), False],
    'alpha_ern': [[stats.norm.pdf(0, loc=0, scale=0.2), stats.norm.pdf(0, loc=0, scale=0.2)], False],
    'alpha_crn': [[stats.norm.pdf(0, loc=0, scale=0.2), stats.norm.pdf(0, loc=0, scale=0.2)], False],
    'delta_cond': [stats.norm.pdf(0, loc=0, scale=2), False],
    'delta_ne': [stats.norm.pdf(0, loc=0, scale=.5), False],
    'delta_ern': [[stats.norm.pdf(0, loc=0, scale=0.5), stats.norm.pdf(0, loc=0, scale=0.5)], False],
    'delta_crn': [[stats.norm.pdf(0, loc=0, scale=0.5), stats.norm.pdf(0, loc=0, scale=0.5)], False],
    # Boundary (alpha) pre-accuracy / interaction effects.
    'alpha_pre_acc': [stats.norm.pdf(0, loc=0, scale=0.2), False],
    'alpha_ne_pre_acc': [stats.norm.pdf(0, loc=0, scale=0.2), False],
    'alpha_ne_cond': [stats.norm.pdf(0, loc=0, scale=0.2), False],
    'alpha_pre_acc_cond':[stats.norm.pdf(0, loc=0, scale=0.2), False],
    'alpha_ne_pre_acc_cond':[stats.norm.pdf(0, loc=0, scale=0.2), False],
    # Drift (delta) pre-accuracy / interaction effects.
    'delta_pre_acc': [stats.norm.pdf(0, loc=0, scale=.5), False],
    'delta_ne_pre_acc': [stats.norm.pdf(0, loc=0, scale=.5), False],
    'delta_ne_cond': [stats.norm.pdf(0, loc=0, scale=.5), False],
    'delta_pre_acc_cond':[stats.norm.pdf(0, loc=0, scale=.5), False],
    'delta_ne_pre_acc_cond':[stats.norm.pdf(0, loc=0, scale=.5), False],
}


# priors_dict = {
#     'participants_alpha_cond':[{
#         'mean': {'loc':0, "scale":1},
#         'sd': {'shape':1, "scale":1},
#     }, True],
#     'participants_alpha_ne':[{
#         'mean': {'loc':0, "scale":0.2},
#         'sd': {'shape':.3, "scale":1},
#     }, True],
#     'participants_alpha_ne_pre_acc':[{
#         'mean': {'loc':0, "scale":0.2},
#         'sd': {'shape':.3, "scale":1},
#     }, True],
#     'participants_alpha_ern':[[{
#         'mean': {'loc':0, "scale":0.2},
#         'sd': {'shape':.3, "scale":1},
#     }, {
#         'mean': {'loc':0, "scale":0.2},
#         'sd': {'shape':.3, "scale":1},
#     }], True],
#     'participants_alpha_crn':[[{
#         'mean': {'loc':0, "scale":0.2},
#         'sd': {'shape':.3, "scale":1},
#     }, {
#         'mean': {'loc':0, "scale":0.2},
#         'sd': {'shape':.3, "scale":1},
#     }], True],
#     'participants_delta_cond': [{
#         'mean': {'loc':0, "scale":2},
#         'sd': {'shape':1, "scale":1},
#     }, True],
#     'participants_delta_ne': [{
#         'mean': {'loc':0, "scale":.5},
#         'sd': {'shape':.3, "scale":1},
#     }, True],
#     'participants_delta_ne_pre_acc':[{
#         'mean': {'loc':0, "scale":0.5},
#         'sd': {'shape':.3, "scale":1},
#     }, True],
#     'participants_delta_ern':[[{
#         'mean': {'loc':0, "scale":0.2},
#         'sd': {'shape':.3, "scale":1},
#     }, {
#         'mean': {'loc':0, "scale":0.2},
#         'sd': {'shape':.3, "scale":1},
#     }], True],
#     'participants_delta_crn':[[{
#         'mean': {'loc':0, "scale":0.2},
#         'sd': {'shape':.3, "scale":1},
#     }, {
#         'mean': {'loc':0, "scale":0.2},
#         'sd': {'shape':.3, "scale":1},
#     }], True],
#     'alpha_cond': [stats.norm.pdf(0, loc=0, scale=1), False],
#     'alpha_ne': [stats.norm.pdf(0, loc=0, scale=0.2), False],
#     'alpha_ern': [[stats.norm.pdf(0, loc=0, scale=0.2), stats.norm.pdf(0, loc=0, scale=0.2)], False],
#     'alpha_crn': [[stats.norm.pdf(0, loc=0, scale=0.2), stats.norm.pdf(0, loc=0, scale=0.2)], False],
#     'delta_cond': [stats.norm.pdf(0, loc=0, scale=2), False],
#     'delta_ne': [stats.norm.pdf(0, loc=0, scale=.5), False],
#     'delta_ern': [[stats.norm.pdf(0, loc=0, scale=0.5), stats.norm.pdf(0, loc=0, scale=0.5)], False],
#     'delta_crn': [[stats.norm.pdf(0, loc=0, scale=0.5), stats.norm.pdf(0, loc=0, scale=0.5)], False],
    
#     'alpha_pre_acc': [stats.norm.pdf(0, loc=0, scale=0.2), False],
#     'alpha_ne_pre_acc': [stats.norm.pdf(0, loc=0, scale=0.2), False],
#     'alpha_ne_cond': [stats.norm.pdf(0, loc=0, scale=0.2), False],
#     'alpha_pre_acc_cond':[stats.norm.pdf(0, loc=0, scale=0.2), False],
#     'alpha_ne_pre_acc_cond':[stats.norm.pdf(0, loc=0, scale=0.2), False],
   
#     'delta_pre_acc': [stats.norm.pdf(0, loc=0, scale=.5), False],
#     'delta_ne_pre_acc': [stats.norm.pdf(0, loc=0, scale=.5), False],
#     'delta_ne_cond': [stats.norm.pdf(0, loc=0, scale=.5), False],
#     'delta_pre_acc_cond':[stats.norm.pdf(0, loc=0, scale=.5), False],
#     'delta_ne_pre_acc_cond':[stats.norm.pdf(0, loc=0, scale=.5), False],
# }
# priors_dict = {
#     'participants_alpha_ne':[{
#         'mean': {'loc':0, "scale":0.2},
#         'sd': {'shape':.3, "scale":1},
#     }, True],
#     'participants_alpha_ne_pre_acc':[{
#         'mean': {'loc':0, "scale":0.2},
#         'sd': {'shape':.3, "scale":1},
#     }, True],
#     'participants_alpha_ern':[[{
#         'mean': {'loc':0, "scale":0.2},
#         'sd': {'shape':.3, "scale":1},
#     }, {
#         'mean': {'loc':0, "scale":0.2},
#         'sd': {'shape':.3, "scale":1},
#     }], True],
#     'participants_alpha_crn':[[{
#         'mean': {'loc':0, "scale":0.2},
#         'sd': {'shape':.3, "scale":1},
#     }, {
#         'mean': {'loc':0, "scale":0.2},
#         'sd': {'shape':.3, "scale":1},
#     }], True],
#     'participants_delta_ne': [{
#         'mean': {'loc':0, "scale":.5},
#         'sd': {'shape':.3, "scale":1},
#     }, True],
#     'participants_delta_ne_pre_acc':[{
#         'mean': {'loc':0, "scale":0.5},
#         'sd': {'shape':.3, "scale":1},
#     }, True],
#     'participants_delta_ern':[[{
#         'mean': {'loc':0, "scale":0.2},
#         'sd': {'shape':.3, "scale":1},
#     }, {
#         'mean': {'loc':0, "scale":0.2},
#         'sd': {'shape':.3, "scale":1},
#     }], True],
#     'participants_delta_crn':[[{
#         'mean': {'loc':0, "scale":0.2},
#         'sd': {'shape':.3, "scale":1},
#     }, {
#         'mean': {'loc':0, "scale":0.2},
#         'sd': {'shape':.3, "scale":1},
#     }], True],
#     'alpha_ne': [stats.norm.pdf(0, loc=0, scale=0.2), False],
#     'alpha_ern': [[stats.norm.pdf(0, loc=0, scale=0.2), stats.norm.pdf(0, loc=0, scale=0.2)], False],
#     'alpha_crn': [[stats.norm.pdf(0, loc=0, scale=0.2), stats.norm.pdf(0, loc=0, scale=0.2)], False],
#     'delta_ne': [stats.norm.pdf(0, loc=0, scale=.5), False],
#     'delta_ern': [[stats.norm.pdf(0, loc=0, scale=0.5), stats.norm.pdf(0, loc=0, scale=0.5)], False],
#     'delta_crn': [[stats.norm.pdf(0, loc=0, scale=0.5), stats.norm.pdf(0, loc=0, scale=0.5)], False],
    
#     'alpha_pre_acc': [stats.norm.pdf(0, loc=0, scale=0.2), False],
#     'alpha_ne_pre_acc': [stats.norm.pdf(0, loc=0, scale=0.2), False],
   
#     'delta_pre_acc': [stats.norm.pdf(0, loc=0, scale=.5), False],
#     'delta_ne_pre_acc': [stats.norm.pdf(0, loc=0, scale=.5), False],
# }

In [None]:
# Sampler diagnostics and bookkeeping columns are not model parameters.
model_variables = list(fit.method_variables().keys()) + ['chain__', 'iter__', 'draw__']
variables_to_track = fit.draws_pd().columns.to_numpy()
variables_to_track = [variable for variable in variables_to_track if ('log_lik' not in variable and variable not in model_variables)]

# Make sure the output directory exists. Otherwise to_csv fails, and before this fix
# that failure was caught below and discarded the already-computed Bayes factors.
os.makedirs(f'{path}/{model}/results', exist_ok=True)

# Only the Bayes-factor computation is best-effort; file writes are not masked.
try:
    summary = get_summary_with_bayes_factor(fit, priors_dict, variables_to_track)
except Exception as e:
    print(f'{type(e).__name__}: {e}')
    print('Saving summary without bayes factors')
    summary = fit.summary()
    summary.to_csv(f'{path}/{model}/results/{model}_summary.csv')
else:
    summary.to_csv(f'{path}/{model}/results/{model}_summary_with_bf_plgrid_test.csv')

### Calculate WAIC

In [None]:
def waic(log_likelihood, pointwise=False):
    """Calculate the Watanabe-Akaike information criterion (WAIC).

    Uses the variance-based effective number of parameters (p_WAIC2) from
    http://www.stat.columbia.edu/~gelman/research/published/waic_understand3.pdf

    Parameters
    ----------
    log_likelihood : array-like, shape (n_draws, n_observations)
        Pointwise log-likelihood: one row per posterior draw, one column per observation.
    pointwise : bool, default False
        If True, also return the per-observation WAIC values.

    Returns
    -------
    dict
        'lppd' (log pointwise predictive density), 'p_waic', 'waic', 'waic_se'
        (standard error of the WAIC), plus 'pointwise_waic' when ``pointwise`` is True.

    Raises
    ------
    ValueError
        If ``log_likelihood`` is not 2-D or has no draws or no observations.
    """
    from scipy.special import logsumexp

    log_likelihood = np.asarray(log_likelihood, dtype=float)
    if log_likelihood.ndim != 2 or 0 in log_likelihood.shape:
        raise ValueError(
            f'log_likelihood must be a non-empty (n_draws, n_observations) array, got shape {log_likelihood.shape}'
        )
    n_draws, n_observations = log_likelihood.shape

    # log(mean_s exp(ll_s)) computed stably: exp() of very negative log-likelihoods
    # underflows to 0 and would give log(0) = -inf.
    pointwise_lppd = logsumexp(log_likelihood, axis=0) - np.log(n_draws)
    lppd = np.sum(pointwise_lppd)

    pointwise_p_waic = np.var(log_likelihood, axis=0)  # per-observation p_WAIC2
    p_waic = np.sum(pointwise_p_waic)

    pointwise_waic = -2 * pointwise_lppd + 2 * pointwise_p_waic
    waic_value = -2 * lppd + 2 * p_waic
    waic_se = np.sqrt(n_observations * np.var(pointwise_waic))

    model_statistics = {'lppd': lppd,
                        'p_waic': p_waic,
                        'waic': waic_value,
                        'waic_se': waic_se}
    if pointwise:
        model_statistics['pointwise_waic'] = pointwise_waic
    return model_statistics

In [None]:
fit_df = fit.draws_pd()

# Columns holding the pointwise log-likelihood (one per observation).
log_likelihood_columns = list(filter(lambda column: 'log_lik' in column, fit_df.columns))
log_likelihood = fit_df.loc[:, log_likelihood_columns].to_numpy()

model_statistics = waic(log_likelihood)
model_statistics_df = pd.DataFrame.from_dict(model_statistics, orient='index', columns=['value'])
display(model_statistics_df)

model_statistics_df.to_csv(f'{path}/{model}/results/{model}_model_statistics_1-chains.csv')

## Plots

In [None]:
# Everything that is neither a sampler diagnostic/bookkeeping column nor log_lik.
model_variables = [*fit.method_variables().keys(), 'chain__', 'iter__', 'draw__']
variables_to_track = [
    column
    for column in fit.draws_pd().columns.to_numpy()
    if 'log_lik' not in column and column not in model_variables
]
# Population-level (non-participant) parameters only.
base_variables_to_track = [column for column in variables_to_track if 'participant' not in column]

# model_variables = list(fit.method_variables().keys()) + ['chain__', 'iter__', 'draw__']
# variables_to_track = fit.draws_pd().columns.to_numpy()
# variables_to_track = [variable for variable in variables_to_track if ('log_lik' not in variable and variable not in model_variables)]
# # base_variables_to_track = [variable for variable in variables_to_track if 'participant' not in variable] + ['chain__', 'iter__'] 
# base_variables_to_track = [variable for variable in variables_to_track if 'participant' not in variable]


# fit_df = fit.draws_pd()
# fit_df_params = fit_df[base_variables_to_track]
# melted_fit_df_params = pd.melt(fit_df_params, id_vars=['chain__', 'iter__'], var_name='parameter', value_name='value')

### Draw chains

In [None]:
# az.style.use("arviz-doc")

# cm = 1/2.54
# dpi = 300

# az.rcParams["plot.max_subplots"] = 500
# plt.rcParams['figure.dpi'] = dpi
# plt.rcParams['ytick.labelsize'] = 20
# plt.rcParams['xtick.labelsize'] = 25
# plt.rcParams['axes.labelsize'] = 20
# plt.rcParams['axes.titlesize'] = 25
# plt.rcParams["font.size"] = 25
# plt.rcParams["axes.edgecolor"] = ".15"
# plt.rcParams["axes.linewidth"]  = 0.5
# plt.rcParams['ytick.major.size'] = 5
# plt.rcParams['ytick.major.width'] = 1
# plt.rcParams['figure.constrained_layout.use'] = True
             
# fig = plt.figure()

# var_names = ['participants_ter']
# az.plot_trace(
#     inference_data, 
#     # var_names=var_names, 
#     divergences=None,
#     # figsize=(14*cm, 3*cm*len(var_names))
# )

# fig.tight_layout()
# # plt.savefig(f'{path}/{model}/results/{model}_traceplots', bbox_inches='tight')

In [None]:
az.style.use("arviz-doc")

cm = 1/2.54
dpi = 300

az.rcParams["plot.max_subplots"] = 500
plt.rcParams['figure.dpi'] = dpi
plt.rcParams['ytick.labelsize'] = 10
plt.rcParams['xtick.labelsize'] = 10
plt.rcParams['axes.labelsize'] = 15
plt.rcParams['axes.titlesize'] = 10
plt.rcParams["font.size"] = 10
plt.rcParams["axes.edgecolor"] = ".15"
plt.rcParams["axes.linewidth"] = 0.5
plt.rcParams['ytick.major.size'] = 5
plt.rcParams['ytick.major.width'] = 1
plt.rcParams['figure.constrained_layout.use'] = True
plt.rcParams['lines.linewidth'] = .5
plt.rcParams['axes.titleweight'] = 'normal'

var_names_participants = [
    'participants_ter',
    'participants_alpha',
    'participants_delta',
    'participants_alpha_cond',
    'participants_delta_cond',
    'participants_alpha_ne',
    'participants_delta_ne',
    'participants_alpha_ne_pre_acc',
    'participants_delta_ne_pre_acc'
]

var_names_main = [
    'participants_alpha_ern',
    'participants_alpha_crn',
    'participants_delta_ern',
    'participants_delta_crn',
]

var_names_population = base_variables_to_track

posterior_variable_names = set(inference_data.posterior.data_vars)

for index, parameters in enumerate([var_names_participants, var_names_main, var_names_population]):
    # az.plot_trace selects whole variables, so flattened column labels such as
    # 'alpha_ern[1]' are reduced to their base name (duplicates dropped, order kept).
    # Names this model does not define are skipped instead of raising a KeyError.
    var_names = [
        name
        for name in dict.fromkeys(parameter.split('[', 1)[0] for parameter in parameters)
        if name in posterior_variable_names
    ]
    if not var_names:
        continue

    axes = az.plot_trace(
        inference_data,
        var_names=var_names,
        divergences=None,
        figsize=(15*cm, 3.5*len(var_names)*cm)
    )

    # Save the figure arviz drew on. (A separately created plt.figure() stays empty,
    # and calling tight_layout on it did nothing to these plots.)
    fig = np.ravel(axes)[0].get_figure()
    fig.savefig(f'{path}/{model}/results/{model}_traceplots_{index}', bbox_inches='tight')

In [None]:
# g = sns.relplot(
#     data=melted_fit_df_params,
#     x="iter__", 
#     y="value",
#     hue="chain__", 
#     col="parameter",
#     kind="line",
#     col_wrap=4,
#     height=5, 
#     aspect=1.1, 
#     facet_kws=dict(sharex=True, sharey=False),
#     palette='colorblind'
# )

# fig = plt.gcf()

# plt.savefig(f'{path}/{model}/results/{model}_chains-6.png', bbox_inches='tight')

In [None]:
# sns.histplot(
#     data = fit_df,
#     x = 'lp__', 
#     kde=True
# )

### Draw base parameters distributions

In [None]:
cm = 1/2.54
dpi = 300

az.rcParams["plot.max_subplots"] = 500
plt.rcParams['figure.dpi'] = dpi
plt.rcParams['ytick.labelsize'] = 5
plt.rcParams['xtick.labelsize'] = 6
plt.rcParams['axes.labelsize'] = 5
plt.rcParams['axes.titlesize'] = 6
plt.rcParams["font.size"] = 5
plt.rcParams["axes.edgecolor"] = ".15"
plt.rcParams["axes.linewidth"] = 0.5
plt.rcParams['ytick.major.size'] = 5
plt.rcParams['ytick.major.width'] = 1
plt.rcParams['figure.constrained_layout.use'] = True
plt.rcParams['lines.linewidth'] = .7
plt.rcParams['axes.titleweight'] = 'normal'
plt.rcParams['axes.titlepad'] = .1


variables_sd = ['ter_sd', 'alpha_sd', 'delta_sd', 'alpha_cond_sd', 'delta_cond_sd', 'alpha_ne_sd', 'delta_ne_sd', 'alpha_ne_pre_acc_sd', 'delta_ne_pre_acc_sd']
variables_main = ['ter', 'alpha', 'delta']
effects = ['alpha_cond', 'alpha_ne', 'alpha_ne_pre_acc', 'alpha_ne_cond', 'alpha_pre_acc_cond', 'alpha_ne_pre_acc_cond', 'delta_cond', 'delta_ne', 'delta_ne_pre_acc', 'delta_ne_cond', 'delta_pre_acc_cond', 'delta_ne_pre_acc_cond',]
main_effects = ['alpha_ern', 'alpha_crn', 'delta_ern', 'delta_crn']


def plot_posteriors(inference_data, variables, ref_val=None, save=False, max_cols=6):
    """Plot posterior densities (mean and 95% HDI) on a grid of at most ``max_cols`` columns.

    Parameters
    ----------
    inference_data : arviz.InferenceData
        Must contain a ``posterior`` group.
    variables : list of str
        Posterior variable names. Names not present in the posterior are skipped.
    ref_val : float, optional
        Reference value passed to ``az.plot_posterior`` (e.g. 0 for effects).
    save : falsy or label
        When truthy, save to f'{path}/{model}/results/{model}_posteriors_{save}',
        using the notebook-level ``path`` and ``model``.
    max_cols : int, default 6
        Maximum number of panels per row.

    Returns
    -------
    matplotlib.figure.Figure or None
        The drawn figure, or None when none of ``variables`` is in the posterior.
    """
    available = set(inference_data.posterior.data_vars)
    variables = [variable for variable in variables if variable in available]
    if not variables:
        return None

    n_cols = min(len(variables), max_cols)
    n_rows = math.ceil(len(variables) / max_cols)

    axes = az.plot_posterior(
        inference_data,
        var_names=variables,
        grid=(n_rows, n_cols),
        hdi_prob=0.95,
        figsize=(2.5*n_cols*cm, 3*cm*n_rows),
        point_estimate='mean',
        ref_val=ref_val,
        backend_kwargs={'gridspec_kw': {'wspace': 0.45, 'hspace': 0.35}, 'subplot_kw': {'box_aspect': 1.2}}
    )

    # np.ravel handles both an array of axes and a single Axes (one variable).
    fig = np.ravel(axes)[0].get_figure()
    fig.tight_layout()
    if save:
        fig.savefig(f'{path}/{model}/results/{model}_posteriors_{save}', bbox_inches='tight')
    return fig


for index, (variables, ref) in enumerate(zip([variables_sd, variables_main, effects, main_effects], [None, None, 0, 0])):
    plot_posteriors(
        inference_data,
        variables,
        ref,
        save=index+1
    )

In [None]:
# model_variables = list(fit.method_variables().keys()) + ['chain__', 'iter__', 'draw__']
# variables_to_track = fit.draws_pd().columns.to_numpy()
# variables_to_track = [variable for variable in variables_to_track if ('log_lik' not in variable and variable not in model_variables)]
# base_variables_to_track = [variable for variable in variables_to_track if 'participant' not in variable] 

# fit_df = fit.draws_pd()
# fit_df_params = fit_df[base_variables_to_track]
# melted_fit_df_params = pd.melt(fit_df_params, id_vars=['chain__', 'iter__'], var_name='parameter', value_name='value')

In [None]:
# def plot_violin_with_colored_ci(data, **kwargs):
#     subset = data
#     lower = np.percentile(subset['value'], 5)
#     upper = np.percentile(subset['value'], 95)
#     mean = np.mean(subset['value'])
    
#     violin = sns.violinplot(x='value', data=data, fill=False, inner='quart', **kwargs)
#     path_data = violin.get_children()[0].get_paths()
#     x = path_data[0].vertices[:, 0]
#     y = path_data[0].vertices[:, 1]    

#     plt.fill_between(x, y, where=(x >= lower) & (x <= upper), color='blue', alpha=0.5)
    
#     for l in violin.lines:
#         l.set_linestyle('--')
#         l.set_linewidth(1.5)
#         l.set_color('black')
    
#     plt.axvline(mean, color='r', linestyle='-')

# g = sns.FacetGrid(
#     melted_fit_df_params, 
#     col="parameter", 
#     sharex=False, 
#     sharey=False, 
#     col_wrap=4,
#     palette='colorblind',
#     aspect=2,
#     height=2,
# )
# g.set_titles(col_template="{col_name}")
# g.map_dataframe(plot_violin_with_colored_ci)
# g.set_ylabels("")
# g.set_xlabels("")

# g.fig.tight_layout()

# fig = plt.gcf()
# plt.savefig(f'{path}/{model}/results/{model}_distributions_6-chains.png', bbox_inches='tight')

### Pair-plots of base parameters
Used to check for correlations between the base parameters.

In [None]:
# base_variables_to_track = [variable for variable in variables_to_track if 'participant' not in variable]
# fit_df = fit.draws_pd()
# fit_df_params = fit_df[base_variables_to_track]

# g = sns.PairGrid(fit_df_params)
# g.map_upper(sns.histplot)
# g.map_lower(sns.kdeplot, fill=True)
# g.map_diag(sns.histplot, kde=True)

# g.fig.tight_layout()
# fig = plt.gcf()
# plt.savefig(f'{path}/{model}/results/{model}_pair_plots_6-chains.png', bbox_inches='tight')

### Bayes factors distributions

In [None]:
summary_copy = summary.copy()
# Base parameter name: drop the index suffix, e.g. 'participants_alpha_ne[12]' ->
# 'participants_alpha_ne'. The former letters/underscore regex stopped at the
# first digit, so a name containing digits was truncated.
summary_copy['parameter_name'] = summary_copy.index.str.split('[', n=1).str[0]
summary_copy.groupby(['parameter_name']).agg(['min', 'max', 'median']).to_csv(f'{path}/{model}/results/{model}_parameters_aggegation_1-chains.csv')

# Participant-level rows that received a Bayes factor, sorted for plotting.
is_random_effect = (
    summary_copy['Bayes_factor'].notna()
    & summary_copy['parameter_name'].str.contains('participants', case=False, na=False)
)
# reset_index() returns a new frame (no in-place edit of a filtered slice).
summary_copy_random_effects = (
    summary_copy[is_random_effect]
    .reset_index()
    .sort_values(by=['parameter_name', 'Bayes_factor'])
)

In [None]:
sns.set_palette("colorblind")

# One panel per participant-level parameter: each participant's BF01 as a point.
g = sns.FacetGrid(
    summary_copy_random_effects,
    col='parameter_name',
    col_wrap=3,
    sharex=False,
    sharey=False,
    aspect=1,
    height=6,
)
g.map(sns.pointplot, 'Bayes_factor', 'index')

g.set_yticklabels([])
g.set_axis_labels("Bayes factor", "ID")
g.set_titles(col_template="{col_name}")


def add_vertical_lines(x, color, linestyle):
    """Draw a vertical reference line at ``x`` on the current axes."""
    plt.axvline(x=x, color=color, linestyle=linestyle)


# Dashed reference lines at BF01 = 0.5, 1, 10 and 5.
for _bf_threshold, _bf_color in ((0.5, 'g'), (1, 'black'), (10, 'r'), (5, 'r')):
    g.map(add_vertical_lines, color=_bf_color, linestyle='--', x=_bf_threshold)

g.fig.tight_layout()

fig = plt.gcf()
plt.savefig(f'{path}/{model}/results/{model}_random_effects_bfs_1-chains.png', bbox_inches='tight')

### Effects distributions

In [None]:
model_variables = list(fit.method_variables().keys()) + ['chain__', 'iter__', 'draw__']

# Parse the draws once and reuse them (draws_pd() was previously called twice).
fit_df = fit.draws_pd()
variables_to_track = [
    variable for variable in fit_df.columns.to_numpy()
    if 'log_lik' not in variable and variable not in model_variables
] + ['chain__', 'iter__']

fit_df_params = fit_df[variables_to_track]
melted_fit_df_params = pd.melt(fit_df_params, id_vars=['chain__', 'iter__'], var_name='parameter', value_name='value')
# Base name without the index suffix; unlike the former regex, this keeps digits in names.
melted_fit_df_params['parameter_name'] = melted_fit_df_params['parameter'].str.split('[', n=1).str[0]

is_random_effect = (
    melted_fit_df_params['value'].notna()
    & melted_fit_df_params['parameter_name'].str.contains('participants', case=False, na=False)
)
# reset_index() returns a new frame instead of modifying a filtered slice in place.
random_effects = melted_fit_df_params[is_random_effect].reset_index()

# Posterior mean of each flattened parameter, attached as 'value_mean'.
mean_values = random_effects.groupby('parameter')['value'].mean().reset_index()
merged_df = pd.merge(random_effects, mean_values, on='parameter', suffixes=('', '_mean'))

# Order participants within each parameter by their posterior mean.
grouped_df_sorted = merged_df.sort_values(by=['parameter_name', 'value_mean'])

In [None]:
# Participant-level rows with a Bayes factor (unsorted), previewed below.
summary_copy_random_effects = summary_copy.loc[
    summary_copy['Bayes_factor'].notna()
    & summary_copy['parameter_name'].str.contains('participants', case=False, na=False)
].reset_index()
summary_copy_random_effects.head()

In [None]:
fit_df_params.head(n=5)  # preview the first rows of the tracked draws

In [None]:
# set plotting parameters
cm = 1/2.54
dpi = 300
plt.rcParams['figure.dpi'] = dpi
plt.rcParams["font.family"] = "Times New Roman"
plt.rcParams['ytick.labelsize'] = 20
plt.rcParams['xtick.labelsize'] = 20
plt.rcParams['axes.labelsize'] = 20
plt.rcParams['axes.titlesize'] = 20
plt.rcParams["font.size"] = 20
plt.rcParams["axes.edgecolor"] = ".15"
plt.rcParams["axes.linewidth"]  = 2
plt.rcParams['ytick.major.size'] = 5
plt.rcParams['ytick.major.width'] = 1

sns.set_style("ticks")
palette = sns.color_palette("colorblind")
             
fig = plt.figure(figsize=(15*cm, 7*cm), dpi=dpi)

# Define which colors to use from the palette
blue = palette[0]  
green = palette[2] 
red = palette[3]

# parameter_name_order = [
#     'participants_alpha', 
#     'participants_alpha_cond',  
#     'participants_alpha_ern',
#     'participants_alpha_crn',
#     'participants_alpha_ne',
#     'participants_alpha_ne_pre_acc',
#     'participants_delta', 
#     'participants_delta_cond',
#     'participants_delta_ern',
#     'participants_delta_crn',
#     'participants_delta_ne',
#     'participants_delta_ne_pre_acc', 
# ]

parameter_name_order = [
    'participants_alpha', 
    'participants_alpha_cond',  
    'participants_alpha_ern',
    'participants_alpha_crn',
    'participants_delta', 
    'participants_delta_cond',
    'participants_delta_ern',
    'participants_delta_crn',
    'participants_alpha_ne',
    'participants_alpha_ne_pre_acc',
    'participants_delta_ne',
    'participants_delta_ne_pre_acc', 
]

g = sns.FacetGrid(
    grouped_df_sorted, 
    col='parameter_name', 
    col_wrap=4, 
    sharey=False,
    sharex=False,
    aspect=1.5,
    height=4,
    col_order=parameter_name_order
)

# Define a custom plotting function to color error bars
def plot_effect_sizes(x, y, **kwargs):
    ax = plt.gca()
    data = kwargs.pop('data')
    errorbar = kwargs.pop('errorbar', None)
    
    sns.pointplot(x=x, y=y, data=data, errorbar=errorbar, color='k', alpha=1)

    for line in ax.get_lines():
        x_data, y_data = line.get_data()

        if len(x_data) == 2:
            if np.min(y_data) < 0 < np.max(y_data):
                line.set_color(red)
            else:
                line.set_color(green)
            line.set_alpha(0.7)
            plt.setp(line, zorder=10)  
        else:
            plt.setp(line, zorder=10000)
    
    ax.xaxis.set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.yaxis.set_visible(True)
            
g.map_dataframe(plot_effect_sizes, 'parameter', 'value', errorbar=("pi", 95))


g.set_axis_labels("ID", "Mean effect")
g.set_titles(col_template="{col_name}")

def add_vertical_lines(x, color, linestyle):
    line = plt.axvline(x=x, color=color, linestyle=linestyle)
    plt.setp(line, zorder=1000)

def add_horizontal_lines(y, color, linestyle):
    line = plt.axhline(y=y, color=color, linestyle=linestyle)
    plt.setp(line, zorder=1000)

    
g.map(add_horizontal_lines, color=red, linestyle='-', y=0)

g.fig.tight_layout()
fig = plt.gcf()
plt.savefig(f'{path}/{model}/results/{model}_random_effects_distributions.png', bbox_inches='tight')

----

In [None]:
# Draws of a single participant-level parameter for a closer look.
a = grouped_df_sorted.loc[
    (grouped_df_sorted['parameter_name'] == 'participants_delta_ern')
    & (grouped_df_sorted['parameter'] == 'participants_delta_ern[110]')
]
a

In [None]:
# `d` is defined here, so this cell no longer depends on the next cell having run first.
d = a['value'].to_numpy().flatten()
sns.histplot(d)

In [None]:
d = a['value'].to_numpy().flatten()

mean = np.mean(d)
std_dev = np.std(d, ddof=1)  # ddof=1 for sample standard deviation
n = len(d)

# t-based interval for the MEAN of the draws. NOTE: MCMC draws are autocorrelated,
# so this understates the uncertainty of the mean; it is not a posterior credible
# interval (for that, see the percentile cell below).
confidence_level = 0.96  # same coverage as the 2nd-98th percentile summaries
alpha = 1 - confidence_level
dof = n - 1
t_critical = stats.t.ppf(1 - alpha/2, df=dof)

# Calculate margin of error
margin_of_error = t_critical * std_dev / np.sqrt(n)

# Calculate confidence interval
lower_bound = mean - margin_of_error
upper_bound = mean + margin_of_error

# The label is derived from confidence_level (it previously said 98% for a 96% interval).
print(f"{confidence_level:.0%} Confidence Interval: ({lower_bound}, {upper_bound})")

In [None]:
# 2nd, 50th and 98th percentiles of `d`: the median and the bounds of a central
# 96% interval (same level as confidence_level in the t-interval cell).
# The names q1/q3 are kept for compatibility, but these are NOT quartiles.
q1 = np.percentile(d, 2)
median = np.percentile(d, 50)
q3 = np.percentile(d, 98)

print(f"2nd Percentile: {q1}")
print(f"Median (50th Percentile): {median}")
print(f"98th Percentile: {q3}")

In [None]:
# One row per participant: point estimate of participants_delta_ern with the 95%
# percentile interval of its values. This matches the ("pi", 95) intervals in the
# facet grid. errorbar="ci" would instead bootstrap a CI of the mean of the
# values, which is slow and says nothing about posterior spread.
plt.figure(figsize=(10, 20))

sns.pointplot(
    data=grouped_df_sorted[grouped_df_sorted['parameter_name'] == 'participants_delta_ern'],
    y="parameter",
    x="value",
    errorbar=("pi", 95),
)

plt.show()

### Test priors

In [None]:
# Hyper-prior settings keyed by parameter name. Each value is [hyper_priors, flag]:
# hyper_priors is one dict or a list of dicts of the form
#   {'mean': {'loc', 'scale'}, 'sd': {'shape', 'scale'}}
# ('mean' feeds np.random.normal, 'sd' feeds np.random.gamma in the cells below).
# NOTE(review): the boolean flag is not read in this section; confirm its meaning.
priors_dict = dict(
    participants_delta_ne=[
        dict(mean=dict(loc=0, scale=0.5), sd=dict(shape=0.3, scale=1)),
        True,
    ],
    participants_alpha_ne_pre_acc=[
        dict(mean=dict(loc=0, scale=0.2), sd=dict(shape=1, scale=1)),
        True,
    ],
    participants_alpha_ern=[
        [dict(mean=dict(loc=0, scale=0.2), sd=dict(shape=0.3, scale=1)) for _ in range(2)],
        True,
    ],
)

In [None]:
N = 10_000  # Monte Carlo draws per hyper-prior component
# Element 0 of the entry: the list of hyper-prior components for participants_alpha_ern.
hyper_prior_distributions = priors_dict['participants_alpha_ern'][0]

In [None]:
# Monte Carlo draws from the hierarchical prior, pooled over ALL components in
# hyper_prior_distributions (the later delta_ern prior-sampling cell does the same;
# using only component [0] silently ignores the rest):
#   mu ~ Normal(mean.loc, mean.scale), sigma ~ Gamma(sd.shape, sd.scale), theta ~ Normal(mu, sigma)
prior_samples = []
for hyper_prior in hyper_prior_distributions:
    mean_hyper_prior = hyper_prior['mean']
    sd_hyper_prior = hyper_prior['sd']

    mean_samples = np.random.normal(loc=mean_hyper_prior['loc'], scale=mean_hyper_prior['scale'], size=(N,))
    sd_samples = np.random.gamma(shape=sd_hyper_prior['shape'], scale=sd_hyper_prior['scale'], size=(N,))

    # One theta per (mu, sigma) pair in a single vectorized call instead of
    # N scalar calls in a Python loop.
    prior_samples.append(np.random.normal(loc=mean_samples, scale=sd_samples))
prior_samples = np.concatenate(prior_samples) if prior_samples else np.array([])

x_vals = np.linspace(-1, 1, 1000)

# KDE of the sampled hierarchical prior, and of the posterior draws of one participant.
prior_density = stats.gaussian_kde(prior_samples)

parameter_kde = stats.gaussian_kde(fit.draws_pd()['participants_alpha_ern[110]'])

In [None]:
# Reference curve over [-1, 1]: the Normal(0, 0.2) density, i.e. the group-mean
# hyper-prior configured for participants_alpha_ern in priors_dict.
x_vals = np.linspace(-1, 1, 1000)
y_sample2_values = stats.norm.pdf(x_vals, loc=0, scale=0.2)

plt.figure(figsize=(8, 6))
plt.plot(x_vals, y_sample2_values, label='Normal(0, 0.2) pdf')
plt.legend()

In [None]:
# Zoom into [-0.05, 0.05]: prior KDE, posterior KDE and a Normal reference
# density. Each curve gets its own label so the legend is readable.
x_zoom = np.linspace(-0.05, 0.05, 1000)

y_vals = prior_density(x_zoom)
y_sample_vals = parameter_kde(x_zoom)
# NOTE(review): the full-range reference curve and the Bayes-factor cell use
# scale=0.2 for this prior; confirm 0.5 is intended here.
y_sample2_values = stats.norm.pdf(x_zoom, loc=0, scale=0.5)

plt.figure(figsize=(8, 6))
plt.plot(x_zoom, y_vals, label='prior KDE')
plt.plot(x_zoom, y_sample_vals, label='posterior KDE', color='r')
plt.plot(x_zoom, y_sample2_values, label='Normal(0, 0.5) pdf')
plt.legend()

In [None]:
# Histogram of the gamma-distributed group-SD draws (sd_samples) from the prior-sampling cell.
sns.histplot(sd_samples)

In [None]:
# Savage-Dickey density ratio: BF01 = posterior density at 0 / prior density at 0
# (evidence for the null). The first denominator is the sampled hierarchical-prior
# KDE; the second is a plain Normal(0, 0.2) prior density, for comparison.
posterior_at_zero = parameter_kde(0)
for prior_at_zero in (prior_density(0), stats.norm.pdf(0, loc=0, scale=0.2)):
    bayes_factor_01 = posterior_at_zero / prior_at_zero
    print(bayes_factor_01)

In [None]:
# Histogram of the posterior draws of one participant-level effect.
sns.histplot(fit.draws_pd()['participants_delta_ern[209]'])

# NOTE(review): the author's rule-of-thumb notes below. What quantity the
# thresholds 10 / 25 refer to (e.g. a Bayes factor) is not stated; confirm.
# up to 10 - it is
# up to 25 - maybe
# above 25 - there is no

In [None]:
# Equal-weight mixture of Normal(0, 0.2) and Normal(0, 0.5), approximated by
# pooling N draws from each component and fitting a Gaussian KDE.
N = 100000
mean_samples_1 = np.random.normal(loc=0, scale=0.2, size=(N,))
mean_samples_2 = np.random.normal(loc=0, scale=0.5, size=(N,))

prior_samples = np.concatenate((mean_samples_1, mean_samples_2))

# 1000 grid points are plenty for a smooth curve. Evaluating the 2N-sample KDE
# on a 100000-point grid would cost ~2e10 kernel evaluations.
x_vals = np.linspace(-1, 1, 1000)

# KDE of the pooled mixture samples
prior_density = stats.gaussian_kde(prior_samples)

In [None]:
# Plot the prior KDE over the x_vals grid.
y_vals = prior_density(x_vals)

plt.figure(figsize=(8, 6))
plt.plot(x_vals, y_vals, label='prior KDE')
plt.legend()

In [None]:
# Number of pooled prior draws (prior_samples is 1-D).
prior_samples.shape[0]

In [None]:
# KDE-based prior density at 0 (gaussian_kde.__call__ is evaluate).
prior_density.evaluate(0)

In [None]:
# Exact density at 0 of an equal-weight mixture of Normal(0, 0.2) and Normal(0, 0.5)
# (the average of the two component pdfs). Compare with the KDE-based prior_density(0).
prior_kde = np.mean([stats.norm.pdf(0, loc=0, scale=sd) for sd in (0.2, 0.5)])

In [None]:
# Display the analytic mixture-prior density at 0.
prior_kde

In [None]:
# Hyper-prior settings keyed by parameter name (redefined here with a
# participants_delta_ern entry). Each value is [hyper_priors, flag]; see the
# 'mean' (normal loc/scale) and 'sd' (gamma shape/scale) usage in the cells below.
# NOTE(review): the boolean flag is not read in this section; confirm its meaning.
priors_dict = dict(
    participants_delta_ne=[
        dict(mean=dict(loc=0, scale=0.5), sd=dict(shape=0.3, scale=1)),
        True,
    ],
    participants_alpha_ne_pre_acc=[
        dict(mean=dict(loc=0, scale=0.2), sd=dict(shape=1, scale=1)),
        True,
    ],
    participants_delta_ern=[
        [dict(mean=dict(loc=0, scale=0.2), sd=dict(shape=0.3, scale=1)) for _ in range(2)],
        True,
    ],
)

In [None]:
N = 10_000  # Monte Carlo draws per hyper-prior component
# Element 0 of the entry: the list of hyper-prior components for participants_delta_ern.
hyper_prior_distributions = priors_dict['participants_delta_ern'][0]

In [None]:
# Display the selected hyper-prior components.
hyper_prior_distributions

In [None]:
# Monte Carlo draws from the hierarchical prior, pooled over all components:
#   mu ~ Normal(mean.loc, mean.scale), sigma ~ Gamma(sd.shape, sd.scale), theta ~ Normal(mu, sigma)
prior_samples = []
for hyper_prior in hyper_prior_distributions:
    mean_hyper_prior = hyper_prior['mean']
    sd_hyper_prior = hyper_prior['sd']

    mean_samples = np.random.normal(loc=mean_hyper_prior['loc'], scale=mean_hyper_prior['scale'], size=(N,))
    sd_samples = np.random.gamma(shape=sd_hyper_prior['shape'], scale=sd_hyper_prior['scale'], size=(N,))

    # One theta per (mu, sigma) pair in a single vectorized call. This draws in
    # the same order as N scalar np.random.normal calls, without the Python loop.
    prior_samples.append(np.random.normal(loc=mean_samples, scale=sd_samples))
prior_samples = np.concatenate(prior_samples) if prior_samples else np.array([])

In [None]:
# Gaussian KDE of the pooled hierarchical-prior draws.
prior_density = stats.gaussian_kde(dataset=prior_samples)

# Gaussian KDE of one participant's posterior draws from the fitted model.
parameter_kde = stats.gaussian_kde(dataset=fit.draws_pd()['participants_delta_ern[114]'])

In [None]:
# Zoom into [-0.05, 0.05] to compare the prior and posterior KDEs near 0.
# Each curve gets its own label so the legend is readable.
x_zoom = np.linspace(-0.05, 0.05, 2000)

y_vals = prior_density(x_zoom)
y_sample_vals = parameter_kde(x_zoom)

plt.figure(figsize=(8, 6))
plt.plot(x_zoom, y_vals, label='prior KDE')
plt.plot(x_zoom, y_sample_vals, label='posterior KDE', color='r')
plt.legend()

In [None]:
# Histogram of the raw posterior draws of participants_delta_ern[114] (the same column parameter_kde is fit on).
sns.histplot(fit.draws_pd()['participants_delta_ern[114]'])