In [336]:
import numpy as np
import pandas as pd

from methods.model_fitting_utilities import softmax_neg_log_likelihood
from scipy import stats
from scipy.optimize import minimize

In [337]:
summary = pd.read_csv('./data/model_fitting_outputs/summary_data.csv', sep=';')
final_judgements = pd.read_csv('./data/model_fitting_outputs/final_judgements.csv')
summary['information_gained'] = (summary['prior_entropy'] - summary['posterior_entropy']) / summary['prior_entropy']

In [338]:

print(f'Final judgements shape: {final_judgements.shape}')
posteriors = {}

for model_name in summary.model_name.unique():
    posteriors[model_name] = pd.read_csv(f'./data/model_fitting_outputs/{model_name}/posteriors.csv')
    print(f'{model_name} posteriors shape: {posteriors[model_name].shape}')

Final judgements shape: (15625, 1092)
normative posteriors shape: (15625, 1092)
LC_discrete posteriors shape: (15625, 1092)
LC_discrete_attention posteriors shape: (15625, 1092)
change_discrete posteriors shape: (15625, 1092)


In [339]:
print('Datapoints per experiment:')
print('Total:', f'N={summary.pid.nunique()}', f'length={summary.shape[0]}')
print('Experiment 2:', f'N={summary[summary.experiment == "experiment_1"].pid.nunique()}', f'length={summary[summary.experiment == "experiment_1"].shape[0]}')
print('Experiment 2:', f'N={summary[summary.experiment == "experiment_2"].pid.nunique()}', f'length={summary[summary.experiment == "experiment_2"].shape[0]}')
print('Experiment 3:', f'N={summary[summary.experiment == "experiment_3"].pid.nunique()}', f'length={summary[summary.experiment == "experiment_3"].shape[0]}')

# Split data into 3 dataset
summary_1 = summary[summary.experiment == "experiment_1"]
summary_2 = summary[summary.experiment == "experiment_2"]
summary_3 = summary[summary.experiment == "experiment_3"]


Datapoints per experiment:
Total: N=302 length=4344
Experiment 2: N=60 length=480
Experiment 2: N=121 length=1936
Experiment 3: N=121 length=1928


In [340]:
summary_3

Unnamed: 0,utid,pid,experiment,difficulty,scenario,model_name,ground_truth,posterior_map,posterior_judgement,prior_judgement,prior_entropy,posterior_entropy,model_specs,information_gained
0,3_566feba6b937e400052d33b2_finance_congruent,566feba6b937e400052d33b2,experiment_3,congruent,finance,normative,[ 1. 1. -0.5 -0.5 1. 1. ],[ 1. 1. -0.5 -0.5 1. 1. ],[ 0. 1. -0.5 -0.5 1. 1. ],[ 1. 1. -0.5 -0.5 1. 1. ],9.390294,7.981602e-36,,1.000000
1,3_566feba6b937e400052d33b2_finance_congruent,566feba6b937e400052d33b2,experiment_3,congruent,finance,LC_discrete,[ 1. 1. -0.5 -0.5 1. 1. ],[1. 0.5 0. 0. 0.5 1. ],[ 0. 1. -0.5 -0.5 1. 1. ],[ 1. 1. -0.5 -0.5 1. 1. ],9.390294,6.192261e-05,,0.999993
2,3_566feba6b937e400052d33b2_finance_congruent,566feba6b937e400052d33b2,experiment_3,congruent,finance,LC_discrete_attention,[ 1. 1. -0.5 -0.5 1. 1. ],[ 1. 1. -0.5 -0.5 1. 1. ],[ 0. 1. -0.5 -0.5 1. 1. ],[ 1. 1. -0.5 -0.5 1. 1. ],9.390294,9.390294e+00,,0.000000
3,3_566feba6b937e400052d33b2_dampened_generic,566feba6b937e400052d33b2,experiment_3,generic,dampened,normative,[-1. 0.5 0. 1. 0. 0. ],[-1. 0.5 0. 1. 0. 0. ],[0.5 0. 0.5 0. 0. 0. ],,13.931326,9.939385e-01,,0.928654
4,3_566feba6b937e400052d33b2_dampened_generic,566feba6b937e400052d33b2,experiment_3,generic,dampened,LC_discrete,[-1. 0.5 0. 1. 0. 0. ],[-1. -1. 0. 1. 0. 0.],[0.5 0. 0.5 0. 0. 0. ],,13.931326,6.574633e-03,,0.999528
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4339,3_61717173748006894b2b54ff_neg_chain_generic_2,61717173748006894b2b54ff,experiment_3,generic_2,neg_chain,change_discrete,[-1 0 0 -1 0 0],[-0.5 -1. 0. -0.5 0.5 -0.5],[-0.5 1. 0. -0.5 0. 0. ],,13.931559,1.301234e+00,,0.906598
4340,3_6176966806de000024ed0ddf_crime_congruent,6176966806de000024ed0ddf,experiment_3,congruent,crime,change_discrete,[-1. 0.5 0.5 1. 0. 1. ],[ 1. 1. 1. 1. 0. -1.],[ 1. 1. 1. 0. 0. -1.],[-1. 0.5 0.5 1. 0. 1. ],13.728335,4.208897e-01,,0.969342
4341,3_6176966806de000024ed0ddf_dampened_generic,6176966806de000024ed0ddf,experiment_3,generic,dampened,change_discrete,[-1. -0.5 0. -1. 0. 0. ],[ 0. -1. 0.5 -1. 1. -1. ],[ 0. 0. -0.5 -0.5 0. 0.5],,13.931559,2.997061e-04,,0.999978
4342,3_6176966806de000024ed0ddf_finance_incongruent,6176966806de000024ed0ddf,experiment_3,incongruent,finance,change_discrete,[ 0. -0.5 0.5 -1. 1. 0. ],[ 0. -1. 1. -0.5 1. 0. ],[ 0.5 -1. -0.5 -0.5 1. 1. ],[ 0.5 1. 0. 0. -0.5 -1. ],13.753367,3.559470e-03,,0.999741


In [341]:
congruent = summary_3[summary_3.difficulty == 'congruent']
incongruent = summary_3[summary_3.difficulty == 'incongruent']

congruent_norm = congruent[congruent.model_name == 'normative'].information_gained
incongruent_norm = incongruent[incongruent.model_name == 'normative'].information_gained

congruent_lc = congruent[congruent.model_name == 'LC_discrete'].information_gained
incongruent_lc = incongruent[incongruent.model_name == 'LC_discrete'].information_gained

congruent_lc_a = congruent[congruent.model_name == 'LC_discrete_attention'].information_gained
incongruent_lc_a = incongruent[incongruent.model_name == 'LC_discrete_attention'].information_gained

congruent_change = congruent[congruent.model_name == 'change_discrete'].information_gained
incongruent_change = incongruent[incongruent.model_name == 'change_discrete'].information_gained

print('Information gained')
print('Normative')
print(f'Congruent: mean={congruent_norm.mean().round(4)}, std={congruent_norm.std().round(4)}')
print(f'Incongruent: mean={incongruent_norm.mean().round(4)}, std={incongruent_norm.std().round(4)}')
print(stats.ttest_ind(congruent_norm, incongruent_norm))

print('LC_discrete')
print(f'Congruent: mean={congruent_lc.mean().round(4)}, std={congruent_lc.std().round(4)}')
print(f'Incongruent: mean={incongruent_lc.mean().round(4)}, std={incongruent_lc.std().round(4)}')
print(stats.ttest_ind(congruent_norm, incongruent_norm))

print('LC_discrete_attention')
print(f'Congruent: mean={congruent_lc_a.mean().round(4)}, std={congruent_lc_a.std().round(4)}')
print(f'Incongruent: mean={incongruent_lc_a.mean().round(4)}, std={incongruent_lc_a.std().round(4)}')
print(stats.ttest_ind(congruent_lc_a, incongruent_lc_a))

print('Discrete Change')
print(f'Congruent: mean={congruent_change.mean().round(4)}, std={congruent_change.std().round(4)}')
print(f'Incongruent: mean={incongruent_change.mean().round(4)}, std={incongruent_change.std().round(4)}')
print(stats.ttest_ind(congruent_change, incongruent_change))

Information gained
Normative
Congruent: mean=0.9963, std=0.0243
Incongruent: mean=0.992, std=0.0316
Ttest_indResult(statistic=1.188732138968508, pvalue=0.2357298236523205)
LC_discrete
Congruent: mean=0.9953, std=0.0215
Incongruent: mean=0.9915, std=0.025
Ttest_indResult(statistic=1.188732138968508, pvalue=0.2357298236523205)
LC_discrete_attention
Congruent: mean=0.8817, std=0.27
Incongruent: mean=0.9117, std=0.2214
Ttest_indResult(statistic=-0.9414486846448392, pvalue=0.347429561901356)
Discrete Change
Congruent: mean=0.8789, std=0.2665
Incongruent: mean=0.911, std=0.211
Ttest_indResult(statistic=-1.0329441239542303, pvalue=0.30267852362438635)


In [342]:
# Select only data not links

judgements_arr = final_judgements[final_judgements.columns[6:]].to_numpy()

# Recover posteriors for each modes
posteriors_arr = {}
for model_name in summary.model_name.unique():
    posteriors_arr[model_name] = posteriors[model_name][posteriors[model_name].columns[6:]].to_numpy()


In [343]:
discretenans = np.argwhere(np.isnan(posteriors_arr['LC_discrete']))
temp=0

gennanidx = np.unique(discretenans[:, 1])

model_names = ['normative', 'LC_discrete', 'LC_discrete_attention', 'change_discrete']
nLL = np.zeros(len(model_names))
optim_results = [None for i in model_names]
for i, model in enumerate(model_names):
    selection = np.delete(judgements_arr, gennanidx, axis=1)
    dataset = np.delete(posteriors_arr[model], gennanidx, axis=1)
    #nLL[i] = softmax_neg_log_likelihood(temp, dataset, judgements_arr)
    optim_results[i] = minimize(softmax_neg_log_likelihood, 
                              0, 
                              args=(dataset, selection))

In [344]:
print('Uniform negative log likelihood (baseline) = 10487.1 \n')
for i, model in enumerate(model_names):
    print(f'Model name: {model}')
    print(f'Negative Log Likelihood = {optim_results[i].fun}')
    print(f'Temperature = {optim_results[i].x[0]}')
    print(f'Optim. result: {optim_results[i].message} \n')

Uniform negative log likelihood (baseline) = 10487.1 

Model name: normative
Negative Log Likelihood = 10446.786038761678
Temperature = 4.493196817730923
Optim. result: Optimization terminated successfully. 

Model name: LC_discrete
Negative Log Likelihood = 10462.94930681456
Temperature = 3.387430856637707
Optim. result: Optimization terminated successfully. 

Model name: LC_discrete_attention
Negative Log Likelihood = 10457.622909812366
Temperature = 4.119688694925813
Optim. result: Optimization terminated successfully. 

Model name: change_discrete
Negative Log Likelihood = 10159.79672505864
Temperature = 7.107513697001946
Optim. result: Optimization terminated successfully. 



In [518]:
# Reintroduce the log likelihoods to the datasets

parameters_per_model = {
    'normative' : 1,
    'LC_discrete': 1,
    'LC_discrete_attention': 1,
    'change_discrete': 1
}




summary = summary.sort_values('utid')

col_wo_nans = [posteriors['normative'].columns[i] for i in range(6, len(posteriors['normative'].columns)) if i not in nanidx]

summary['sample_BIC'] = np.nan
summary['sample_nLL'] = np.nan
summary['sample_temp'] = np.nan

summaries = [None for mpdel in model_names]
trial_lvl_nll = {}
for i, model in enumerate(model_names):
    utid = [posteriors[model].columns[i] for i in range(6, len(posteriors[model].columns)) if i not in nanidx]

    temp = optim_results[i].x[0]
    selection = np.delete(judgements_arr, nanidx, axis=1)
    dataset = np.delete(posteriors_arr[model], nanidx, axis=1)
    
    trial_lvl_nll[model] = softmax_neg_log_likelihood(temp, dataset, selection, return_selection=True)

    df = pd.DataFrame(index=posteriors[model].columns[6:].sort_values())
    df['nLL'] = np.nan
    df.loc[utid, 'nLL'] = trial_lvl_nll[model]



    summary.loc[summary.model_name == model, 'sample_nLL' ] = df['nLL'].to_numpy()
    summary.loc[summary.model_name == model, 'sample_temp' ] = temp

    nLL = summary[summary.model_name == model].sample_nLL
    n = (~nLL.isna()).sum()
    bic = parameters_per_model[model] * np.log(n) + 2 * nLL.sum()

    summary.loc[summary.model_name == model, 'sample_BIC'] = bic





In [519]:
summary.sample_BIC.nunique()

4

In [520]:
summary[['model_name', 'sample_BIC', 'sample_temp']].groupby('model_name').mean().sort_values('sample_BIC')

Unnamed: 0_level_0,sample_BIC,sample_temp
model_name,Unnamed: 1_level_1,Unnamed: 2_level_1
change_discrete,20365.36667,7.107514
LC_discrete,20932.887027,3.387431
normative,20939.211412,4.493197
LC_discrete_attention,20960.878054,4.119689


In [521]:
summary[summary.difficulty.isin(['generic', 'generic_2'])].groupby('model_name').mean().sort_values('sample_nLL')

Unnamed: 0_level_0,prior_entropy,posterior_entropy,model_specs,information_gained,sample_nLL,sample_temp,sample_BIC,part_nLL,part_temp,part_BIC
model_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
change_discrete,13.931559,1.550242,,0.888724,9.383367,7.107514,20365.36667,9.29566,11.718907,71.519525
normative,13.931326,0.035066,,0.997483,9.60629,4.493197,20939.211412,9.227425,1.459443,68.465633
LC_discrete_attention,13.931326,1.494304,,0.892738,9.63194,4.119689,20960.878054,8.73426,9.932897,67.840082
LC_discrete,13.931326,0.024315,,0.998255,9.644349,3.387431,20932.887027,9.268751,1.692588,67.830419


In [522]:
summary[summary.difficulty.isin(['congruent', 'incongruent', 'implausible'])].groupby('model_name').sum().sort_values('sample_nLL')

Unnamed: 0_level_0,prior_entropy,posterior_entropy,model_specs,information_gained,sample_nLL,sample_temp,sample_BIC,part_nLL,part_temp,part_BIC
model_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
change_discrete,8292.794879,766.257207,0.0,545.315767,5647.022114,4285.830759,12280320.0,5582.680269,3112.510327,49702.827927
LC_discrete,6222.163715,24.347932,0.0,598.67384,5814.373045,2042.620807,12622530.0,5592.380031,843.918207,47217.98816
LC_discrete_attention,6222.163715,571.93193,0.0,547.515698,5824.716853,2484.172283,12639410.0,5363.485075,4295.290855,47993.166573
normative,6222.163715,27.763321,0.0,600.363741,5826.272683,2709.397681,12626340.0,5667.547531,736.451699,47693.639092


In [523]:
summary[summary.experiment == 'experiment_1'].groupby('model_name').sum().sort_values('sample_nLL')

Unnamed: 0_level_0,prior_entropy,posterior_entropy,model_specs,information_gained,sample_nLL,sample_temp,sample_BIC,part_nLL,part_temp,part_BIC
model_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
change_discrete,1671.787062,220.40383,0.0,104.179529,1111.567863,852.901644,2443844.0,1092.933401,1002.154045,4809.107093
normative,1671.759094,4.769764,0.0,119.657623,1141.476966,539.183618,2512705.0,1048.354263,265.156554,4412.10379
LC_discrete_attention,1671.759094,224.663982,0.0,103.873467,1150.227935,494.362643,2515305.0,984.1041,1811.991302,4373.789891
LC_discrete,1671.759094,3.513318,0.0,119.747812,1155.619881,406.491703,2511946.0,1096.256654,154.845879,4584.400106


In [524]:
summary[summary.experiment == 'experiment_2'].groupby('model_name').sum().sort_values('sample_nLL')

Unnamed: 0_level_0,prior_entropy,posterior_entropy,model_specs,information_gained,sample_nLL,sample_temp,sample_BIC,part_nLL,part_temp,part_BIC
model_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
change_discrete,6678.365158,593.23124,0.0,437.05411,4553.712038,3440.036629,9856837.0,4539.689472,908.864464,40189.515774
LC_discrete,5437.59093,11.034828,0.0,480.028501,4674.6779,1639.516535,10131520.0,4580.013795,362.486638,38446.230829
LC_discrete_attention,5437.59093,484.434737,0.0,441.116097,4675.306468,1993.929328,10145060.0,4447.572015,2969.530786,39452.576122
normative,5437.59093,19.717606,0.0,482.268062,4676.470906,2174.70726,10134580.0,4547.680465,593.393756,38317.44372


In [525]:
summary[summary.experiment == 'experiment_3'].groupby('model_name').sum().sort_values('sample_nLL')

Unnamed: 0_level_0,prior_entropy,posterior_entropy,model_specs,information_gained,sample_nLL,sample_temp,sample_BIC,part_nLL,part_temp,part_BIC
model_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
change_discrete,6671.585582,698.288502,0.0,431.558588,4513.908306,3425.821602,9816107.0,4439.861124,6861.723767,39248.135591
LC_discrete,5841.644045,21.519739,0.0,480.056261,4632.651526,1632.741673,10089650.0,4392.916113,1144.105882,36949.449371
normative,5841.644045,20.21307,0.0,480.222297,4648.162706,2165.720866,10092700.0,4528.358958,582.81238,38032.992134
LC_discrete_attention,5841.644045,584.582047,0.0,433.718514,4651.409496,1985.689951,10103140.0,4150.456438,4311.357818,36933.560234


In [526]:
summary[summary.model_name == 'normative'].groupby('pid').sum().sort_values('sample_nLL')

Unnamed: 0_level_0,prior_entropy,posterior_entropy,model_specs,information_gained,sample_nLL,sample_temp,sample_BIC,part_nLL,part_temp,part_BIC
pid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
5c5af8dfe7f42600017cde68,27.862652,5.462265e-13,0.0,2.000000,14.830574,8.986394,41878.422825,19.313255,-0.000256,81.253020
5c665e1948985d0001f09675,27.862652,1.690113e-97,0.0,2.000000,14.831340,8.986394,41878.422825,19.313255,-0.000256,81.253020
5c405aee40de0b0001437db0,27.862652,9.926065e-01,0.0,1.928750,14.831342,8.986394,41878.422825,4.130263,33.737373,20.521052
5d2a4351379e510001491ee4,27.862652,4.525247e-05,0.0,1.999997,14.831342,8.986394,41878.422825,19.313255,-0.000256,81.253020
5ed2a11ff7ce322295acbf2e,27.862652,4.293558e-02,0.0,1.996918,19.310773,8.986394,41878.422825,19.313255,-0.000256,81.253020
...,...,...,...,...,...,...,...,...,...,...
60199372bd48490e12a0d92e,47.829902,8.840850e-04,0.0,3.999912,38.649078,17.972787,83756.845649,38.626510,-0.001023,325.012079
60aa4b9356c591511cc09f5f,48.310158,2.955075e-45,0.0,4.000000,38.649078,17.972787,83756.845649,38.626510,-0.001023,325.012079
6021d74808e9e926a4ff44c6,47.970930,4.532577e-22,0.0,4.000000,38.649078,17.972787,83756.845649,38.626510,-0.001023,325.012079
60649bf146ce5a76f4bef1ce,45.124767,2.928355e-29,0.0,4.000000,38.649078,17.972787,83756.845649,38.626510,-0.001023,325.012079


In [527]:
summary.groupby('model_name').mean()

Unnamed: 0_level_0,prior_entropy,posterior_entropy,model_specs,information_gained,sample_nLL,sample_temp,sample_BIC,part_nLL,part_temp,part_BIC
model_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
LC_discrete,11.925409,0.033304,,0.997075,9.652167,3.387431,20932.887027,9.288917,1.532692,73.782362
LC_discrete_attention,11.925409,1.192333,,0.902035,9.647278,4.119689,20960.878054,8.823326,8.372818,74.364573
change_discrete,13.832171,1.397342,,0.899069,9.373101,7.107514,20365.36667,9.274847,8.078032,77.575284
normative,11.925409,0.041161,,0.996453,9.637303,4.493197,20939.211412,9.322646,1.327222,74.366979


In [528]:
posteriors_wo_links = {}
for i, model in enumerate(model_names):
    posteriors_wo_links[model] = posteriors[model][posteriors[model].columns[6:]]
    sorted_cols = posteriors_wo_links[model].columns.sort_values()
    posteriors_wo_links[model] = posteriors_wo_links[model][sorted_cols]
    print(sorted_cols.shape)


(1086,)
(1086,)
(1086,)
(1086,)


In [529]:
# Data per participants

final_judgements_numpy = final_judgements.to_numpy()
utids = summary.utid.unique()

pid_range = 26

pid_posteriors = {}
pid_selections = {}
pid_utids = {}


for i, model in enumerate(model_names):
    pid_posteriors[model] = {}
    pid_selections[model] = {}
    current_id = None
    current_col_idx = []
    length_1 = 0
    length_2 = 0

    all_utids = []
    collected_utids = []
    all_current_idx = []
    for j, utid in enumerate(posteriors_wo_links[model].columns):
        
        all_utids.append(utid)

       
        if not current_id:
            current_id = utid[:pid_range]
                    
        if utid[:pid_range] == current_id:
            current_col_idx.append(utid)
            collected_utids.append(utid)
        else:
            pid_posteriors[model][current_id] = posteriors_wo_links[model][current_col_idx]
            pid_selections[model][current_id] = final_judgements[current_col_idx]
            pid_utids[current_id] = current_col_idx

            current_col_idx = [utid]
            current_id = utid[:pid_range]

        if utid == posteriors_wo_links[model].columns[-1]:
            pid_posteriors[model][current_id] = posteriors_wo_links[model][current_col_idx]
            pid_selections[model][current_id] = final_judgements[current_col_idx]
            pid_utids[current_id] = current_col_idx
            


In [530]:
summary

Unnamed: 0,utid,pid,experiment,difficulty,scenario,model_name,ground_truth,posterior_map,posterior_judgement,prior_judgement,prior_entropy,posterior_entropy,model_specs,information_gained,sample_nLL,sample_temp,sample_BIC,part_nLL,part_temp,part_BIC
2898,1_56da8da8c5b248000ae2adaf_ccause_generic,56da8da8c5b248000ae2adaf,experiment_1,generic,ccause,normative,[1 1 0 0 0 0],[1. 1. 0. 0. 0. 0.],[0.5 0.5 0. 0. 0. 0. ],,13.931326,5.460554e-43,,1.000000,9.662270,4.493197,20939.211412,9.656627,-0.000128,40.62651
3258,1_56da8da8c5b248000ae2adaf_ccause_generic,56da8da8c5b248000ae2adaf,experiment_1,generic,ccause,change_discrete,[1 1 0 0 0 0],[ 0.5 1. -0.5 -1. 1. 0.5],[0.5 0.5 0. 0. 0. 0. ],,13.931559,1.072596e-01,,0.992301,9.668182,7.107514,20365.366670,9.656627,-0.000128,42.62651
2899,1_56da8da8c5b248000ae2adaf_ccause_generic,56da8da8c5b248000ae2adaf,experiment_1,generic,ccause,LC_discrete,[1 1 0 0 0 0],[1. 1. 0. 0.5 0. 0.5],[0.5 0.5 0. 0. 0. 0. ],,13.931326,8.608257e-33,,1.000000,9.658456,3.387431,20932.887027,9.656627,-0.000128,40.62651
2900,1_56da8da8c5b248000ae2adaf_ccause_generic,56da8da8c5b248000ae2adaf,experiment_1,generic,ccause,LC_discrete_attention,[1 1 0 0 0 0],[1. 1. 0. 0. 0. 0.],[0.5 0.5 0. 0. 0. 0. ],,13.931326,2.739902e-18,,1.000000,9.660495,4.119689,20960.878054,9.656627,-0.000128,42.62651
2902,1_56da8da8c5b248000ae2adaf_pos_chain_generic_2,56da8da8c5b248000ae2adaf,experiment_1,generic_2,pos_chain,LC_discrete,[1 0 0 1 0 0],[1. 1. 0. 1. 0. 0.],[1. 0.5 0. 0.5 0. 0.5],,13.931326,6.197995e-33,,1.000000,9.658456,3.387431,20932.887027,9.656627,-0.000128,40.62651
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1441,3_6176966806de000024ed0ddf_finance_incongruent,6176966806de000024ed0ddf,experiment_3,incongruent,finance,LC_discrete,[ 0. -0.5 0.5 -1. 1. 0. ],[ 0. -0.5 0.5 -0.5 0.5 0. ],[ 0.5 -1. -0.5 -0.5 1. 1. ],[ 0.5 1. 0. 0. -0.5 -1. ],10.350805,2.351528e-05,,0.999998,9.658456,3.387431,20932.887027,9.656627,-0.000256,81.25302
1445,3_6176966806de000024ed0ddf_neg_chain_generic_2,6176966806de000024ed0ddf,experiment_3,generic_2,neg_chain,LC_discrete_attention,[-1 0 0 -1 0 0],[-1. 0.5 0. -1. 0. 0. ],[-1. 1. 1. -0.5 0. -1. ],,13.931326,4.645406e+00,,0.666550,9.657437,4.119689,20960.878054,9.656627,-0.000256,85.25302
1444,3_6176966806de000024ed0ddf_neg_chain_generic_2,6176966806de000024ed0ddf,experiment_3,generic_2,neg_chain,LC_discrete,[-1 0 0 -1 0 0],[-1. 1. 0. -1. 0. 0.],[-1. 1. 1. -0.5 0. -1. ],,13.931326,3.380795e-17,,1.000000,9.658456,3.387431,20932.887027,9.656627,-0.000256,81.25302
1443,3_6176966806de000024ed0ddf_neg_chain_generic_2,6176966806de000024ed0ddf,experiment_3,generic_2,neg_chain,normative,[-1 0 0 -1 0 0],[-1. 0. 0. -1. 0. 0.],[-1. 1. 1. -0.5 0. -1. ],,13.931326,2.979968e-29,,1.000000,9.662270,4.493197,20939.211412,9.656627,-0.000256,81.25302


In [531]:
pid_posteriors['LC_discrete_attention']['3_615373f88839471e0ab77399']

Unnamed: 0,3_615373f88839471e0ab77399_crime_congruent,3_615373f88839471e0ab77399_dampened_generic,3_615373f88839471e0ab77399_finance_incongruent,3_615373f88839471e0ab77399_pos_chain_generic_2
0,0.0,0.0,0.0,0.0
1,0.0,0.0,0.0,0.0
2,0.0,0.0,0.0,0.0
3,0.0,0.0,0.0,0.0
4,0.0,0.0,0.0,0.0
...,...,...,...,...
15620,0.0,0.0,0.0,0.0
15621,0.0,0.0,0.0,0.0
15622,0.0,0.0,0.0,0.0
15623,0.0,0.0,0.0,0.0


In [532]:
optim_results_part = {}
temp=0

broken_trials = {}

model_names = ['normative', 'LC_discrete', 'LC_discrete_attention', 'change_discrete']
for i, model in enumerate(model_names):
    length_out = 0
    length_in = 0
    optim_results_part[model] = {}
    broken_trials[model] = []

    part_posteriors = pid_posteriors[model]
    part_selection = pid_selections[model]
    for j, part in enumerate(part_posteriors.keys()):

        local_utids = part_posteriors[part].columns

        length_in += part_posteriors[part].shape[1]

        data = part_posteriors[part].to_numpy()
        raw_select = part_selection[part].to_numpy()
        
        discretenans = np.argwhere(np.isnan(data))
        nanidx = np.unique(discretenans[:, 1])

        if nanidx.size > 0:
            broken_trials[model] = broken_trials[model] + [pid_utids[part][idx] for idx in nanidx]

        selection = np.delete(raw_select, nanidx, axis=1).astype(bool)

        dataset = np.delete(data, nanidx, axis=1)
        optim_results_part[model][part] = minimize(softmax_neg_log_likelihood, 
                                                   0, 
                                                   args=(dataset, selection),
                                                   bounds=[(-2000, 2000)])

        length_out += softmax_neg_log_likelihood(optim_results_part[model][part].x[0], dataset, selection, return_selection=True).size
    print(model, length_in, length_out)

  softmax_unnorm = np.exp(dataset * temp)
  softmax = softmax_unnorm / softmax_unnorm.sum(axis=0).reshape((1, dataset.shape[1]))
  log_likelihood = np.log(judgements_likelihood)


normative 1086 1086


  df = fun(x) - f0


LC_discrete 1086 1084
LC_discrete_attention 1086 1086
change_discrete 1086 1086


In [533]:
old_summary = summary.copy()

In [534]:
summary = summary.sort_values('utid')
summary['part_nLL'] = np.nan
summary['part_temp'] = np.nan
summary['part_BIC'] = np.nan

df_models = {}

model_names = ['normative', 'LC_discrete', 'LC_discrete_attention', 'change_discrete']
for model in model_names:

    part_posteriors = pid_posteriors[model]
    part_selection = pid_selections[model]

    trials_nll = []
    part_temps = []
    part_bics = []
    part_utids = []

    utid = [posteriors[model].columns[i] for i in range(6, len(posteriors[model].columns.sort_values())) if posteriors[model].columns[i] not in broken_trials[model]]


    for part in optim_results_part[model].keys():
        if '1_test' in part:
            part_id = 'test_625034503'
        else:
            part_id = part[2:]

        local_utids = [utid_loc for utid_loc in part_posteriors[part].columns.to_list() if utid_loc not in broken_trials[model]]

        temp = optim_results_part[model][part].x[0]

        data = part_posteriors[part].to_numpy()
        raw_select = part_selection[part].to_numpy()
        
        discretenans = np.argwhere(np.isnan(data))
        nanidx = np.unique(discretenans[:, 1])

        selection = np.delete(raw_select, nanidx, axis=1).astype(bool)
        dataset = np.delete(data, nanidx, axis=1)

        part_trial_lvl_nLL = softmax_neg_log_likelihood(temp, dataset, selection, return_selection=True)
        #print(part_trial_lvl_nLL.size)
        trials_nll.append(part_trial_lvl_nLL.flatten())


        part_temps.append(temp*np.ones(part_trial_lvl_nLL.size))
        #summary.loc[summary.pid == part_id & summary.index[summary.model_name == model], 'part_temp'] = temp
        
        bic = parameters_per_model[model] * part_trial_lvl_nLL.size + 2 * np.sum(part_trial_lvl_nLL)
        #summary.loc[summary.pid == part_id & summary.index[summary.model_name == model], 'part_BIC'] = bic
        part_bics.append(bic * np.ones(part_trial_lvl_nLL.size))

        part_utids.append(local_utids)
        
    
    utids = np.concatenate(part_utids).flatten()
    trials_nll_as_flat_array = np.concatenate(trials_nll).flatten()
    part_temps_as_flat_array = np.concatenate(part_temps).flatten()
    part_bics_as_flat_array = np.concatenate(part_bics).flatten()
    df = pd.DataFrame(index=posteriors[model].columns[6:].sort_values())
    df['nLL'] = np.nan
    df['temps'] = np.nan
    df['bics'] = np.nan
    df.loc[utids, 'nLL'] = trials_nll_as_flat_array
    df.loc[utids, 'temps'] = part_temps_as_flat_array
    df.loc[utids, 'bics'] = part_bics_as_flat_array

    df_models[model] = df

    summary.loc[summary.model_name == model, 'part_nLL' ] = df['nLL'].to_numpy()
    print((summary[summary.model_name == model].utid == df.index).sum())
    summary.loc[summary.model_name == model, 'part_temp' ] = df['temps'].to_numpy()
    summary.loc[summary.model_name == model, 'part_BIC' ] = df['bics'].to_numpy()

    


1086
1086
1086
1086


In [535]:
pid_utids['3_615373f88839471e0ab77399']

['3_615373f88839471e0ab77399_crime_congruent',
 '3_615373f88839471e0ab77399_dampened_generic',
 '3_615373f88839471e0ab77399_finance_incongruent',
 '3_615373f88839471e0ab77399_pos_chain_generic_2']

In [536]:
summary[summary.model_name == 'LC_discrete_attention'][['utid', 'part_nLL', 'model_name', 'part_temp', 'part_BIC']].tail(24)

Unnamed: 0,utid,part_nLL,model_name,part_temp,part_BIC
1376,3_615230e2db85b00b50240ab8_crime_congruent,9.656627,LC_discrete_attention,-0.000256,81.25302
1379,3_615230e2db85b00b50240ab8_dampened_generic,9.656627,LC_discrete_attention,-0.000256,81.25302
1382,3_615230e2db85b00b50240ab8_finance_incongruent,9.656627,LC_discrete_attention,-0.000256,81.25302
1385,3_615230e2db85b00b50240ab8_pos_chain_generic_2,9.656627,LC_discrete_attention,-0.000256,81.25302
1388,3_615373f88839471e0ab77399_crime_congruent,9.656627,LC_discrete_attention,-0.000256,81.25302
1391,3_615373f88839471e0ab77399_dampened_generic,9.656627,LC_discrete_attention,-0.000256,81.25302
1394,3_615373f88839471e0ab77399_finance_incongruent,9.656627,LC_discrete_attention,-0.000256,81.25302
1397,3_615373f88839471e0ab77399_pos_chain_generic_2,9.656627,LC_discrete_attention,-0.000256,81.25302
1400,3_615ec387ec57223c894f6fc2_crime_congruent,1.386293,LC_discrete_attention,8.557953,66.438062
1403,3_615ec387ec57223c894f6fc2_dampened_generic,9.944246,LC_discrete_attention,8.557953,66.438062


In [537]:
df_models['LC_discrete_attention'].tail(25)

Unnamed: 0,nLL,temps,bics
3_6151d444ac15808872e32dc8_pos_chain_generic_2,9.652534,2.303739,81.248103
3_615230e2db85b00b50240ab8_crime_congruent,9.656627,-0.000256,81.25302
3_615230e2db85b00b50240ab8_dampened_generic,9.656627,-0.000256,81.25302
3_615230e2db85b00b50240ab8_finance_incongruent,9.656627,-0.000256,81.25302
3_615230e2db85b00b50240ab8_pos_chain_generic_2,9.656627,-0.000256,81.25302
3_615373f88839471e0ab77399_crime_congruent,9.656627,-0.000256,81.25302
3_615373f88839471e0ab77399_dampened_generic,9.656627,-0.000256,81.25302
3_615373f88839471e0ab77399_finance_incongruent,9.656627,-0.000256,81.25302
3_615373f88839471e0ab77399_pos_chain_generic_2,9.656627,-0.000256,81.25302
3_615ec387ec57223c894f6fc2_crime_congruent,1.386293,8.557953,66.438062


In [538]:
summary[summary.difficulty.isin(['generic', 'generic_2'])].groupby('model_name').mean().sort_values('part_BIC')

Unnamed: 0_level_0,prior_entropy,posterior_entropy,model_specs,information_gained,sample_nLL,sample_temp,sample_BIC,part_nLL,part_temp,part_BIC
model_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
LC_discrete_attention,13.931326,1.494304,,0.892738,9.63194,4.119689,20960.878054,8.73426,9.932897,64.349399
LC_discrete,13.931326,0.024315,,0.998255,9.644349,3.387431,20932.887027,9.268751,1.692588,67.830419
change_discrete,13.931559,1.550242,,0.888724,9.383367,7.107514,20365.36667,9.29566,11.718907,68.028842
normative,13.931326,0.035066,,0.997483,9.60629,4.493197,20939.211412,9.227425,1.459443,68.465633


In [539]:
summary[summary.difficulty.isin(['congruent', 'incongruent', 'implausible'])].groupby('model_name').mean().sort_values('part_BIC')

Unnamed: 0_level_0,prior_entropy,posterior_entropy,model_specs,information_gained,sample_nLL,sample_temp,sample_BIC,part_nLL,part_temp,part_BIC
model_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
LC_discrete_attention,10.318679,0.950053,,0.909495,9.659564,4.119689,20960.878054,8.894668,7.123202,75.593974
change_discrete,13.752562,1.27497,,0.907347,9.364879,7.107514,20365.36667,9.258176,5.161709,78.429234
LC_discrete,10.318679,0.040512,,0.99613,9.658427,3.387431,20932.887027,9.305125,1.40419,78.565704
normative,10.318679,0.046042,,0.995628,9.662144,4.493197,20939.211412,9.398918,1.221313,79.093929


In [515]:
summary[summary.difficulty.isin(['implausible'])].groupby('model_name').mean().sort_values('part_BIC')

Unnamed: 0_level_0,prior_entropy,posterior_entropy,model_specs,information_gained,sample_nLL,sample_temp,sample_BIC,part_nLL,part_temp,part_BIC
model_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
normative,10.365001,0.033846,,0.996786,9.662075,4.493197,20939.211412,9.34352,1.22602,79.168272
LC_discrete,10.365001,0.006407,,0.999377,9.65843,3.387431,20932.887027,9.607463,0.748939,79.686178
LC_discrete_attention,10.365001,0.848413,,0.920947,9.659696,4.119689,20967.868311,9.137015,6.135394,81.513587
change_discrete,13.755299,1.172199,,0.914878,9.272491,7.107514,20372.356926,9.48652,1.877819,83.03619


In [516]:
old_summary.head()

Unnamed: 0,utid,pid,experiment,difficulty,scenario,model_name,ground_truth,posterior_map,posterior_judgement,prior_judgement,prior_entropy,posterior_entropy,model_specs,information_gained,sample_nLL,sample_temp,sample_BIC,part_nLL,part_temp,part_BIC
2898,1_56da8da8c5b248000ae2adaf_ccause_generic,56da8da8c5b248000ae2adaf,experiment_1,generic,ccause,normative,[1 1 0 0 0 0],[1. 1. 0. 0. 0. 0.],[0.5 0.5 0. 0. 0. 0. ],,13.931326,5.460554000000001e-43,,1.0,9.66227,4.493197,20939.211412,9.656627,-0.000128,40.62651
3258,1_56da8da8c5b248000ae2adaf_ccause_generic,56da8da8c5b248000ae2adaf,experiment_1,generic,ccause,change_discrete,[1 1 0 0 0 0],[ 0.5 1. -0.5 -1. 1. 0.5],[0.5 0.5 0. 0. 0. 0. ],,13.931559,0.1072596,,0.992301,9.668182,7.107514,20372.356926,9.656627,-0.000128,40.62651
2899,1_56da8da8c5b248000ae2adaf_ccause_generic,56da8da8c5b248000ae2adaf,experiment_1,generic,ccause,LC_discrete,[1 1 0 0 0 0],[1. 1. 0. 0.5 0. 0.5],[0.5 0.5 0. 0. 0. 0. ],,13.931326,8.608257000000001e-33,,1.0,9.658456,3.387431,20932.887027,9.656627,-0.000128,40.62651
2900,1_56da8da8c5b248000ae2adaf_ccause_generic,56da8da8c5b248000ae2adaf,experiment_1,generic,ccause,LC_discrete_attention,[1 1 0 0 0 0],[1. 1. 0. 0. 0. 0.],[0.5 0.5 0. 0. 0. 0. ],,13.931326,2.739902e-18,,1.0,9.660495,4.119689,20967.868311,9.656627,-0.000128,40.62651
2902,1_56da8da8c5b248000ae2adaf_pos_chain_generic_2,56da8da8c5b248000ae2adaf,experiment_1,generic_2,pos_chain,LC_discrete,[1 0 0 1 0 0],[1. 1. 0. 1. 0. 0.],[1. 0.5 0. 0.5 0. 0.5],,13.931326,6.197995e-33,,1.0,9.658456,3.387431,20932.887027,9.656627,-0.000128,40.62651


In [517]:
best_fit = {}

pids = []
for pid in summary.pid.unique():
    pids.append(pid)

pids = pd.Series(pids)
cols_interest = ['utid', 'pid', 'model_name', 'posterior_map', 'posterior_judgement', 'sample_BIC', 'sample_temp', 'part_nLL', 'part_temp', 'part_BIC']
summary[summary.pid == pids[250]].sort_values('model_name')[['utid','model_name', 'part_nLL', 'part_temp', 'part_BIC']]

Unnamed: 0,utid,model_name,part_nLL,part_temp,part_BIC
823,3_602d29200724e36ae7a9def2_crime_congruent,LC_discrete,9.9482,8.573676,66.410926
826,3_602d29200724e36ae7a9def2_dampened_generic,LC_discrete,9.934539,8.573676,66.410926
829,3_602d29200724e36ae7a9def2_finance_incongruent,LC_discrete,1.374524,8.573676,66.410926
832,3_602d29200724e36ae7a9def2_pos_chain_generic_2,LC_discrete,9.9482,8.573676,66.410926
824,3_602d29200724e36ae7a9def2_crime_congruent,LC_discrete_attention,0.693146,9.656565,52.171431
827,3_602d29200724e36ae7a9def2_dampened_generic,LC_discrete_attention,10.349711,9.656565,52.171431
830,3_602d29200724e36ae7a9def2_finance_incongruent,LC_discrete_attention,0.693146,9.656565,52.171431
833,3_602d29200724e36ae7a9def2_pos_chain_generic_2,LC_discrete_attention,10.349711,9.656565,52.171431
4136,3_602d29200724e36ae7a9def2_crime_congruent,change_discrete,1.826643,9.896022,54.720901
4137,3_602d29200724e36ae7a9def2_dampened_generic,change_discrete,0.580626,9.896022,54.720901


In [450]:
summary[summary.pid == '615373f88839471e0ab77399'].part_temp.unique()

array([-2.55795385e-04,  8.55811011e+00])