In [1]:
# On terminal: conda activate python38

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import ipynb.fs.defs.functions as fct
import pickle

import warnings
warnings.filterwarnings("error")

In [2]:
# Load IDs
with open('uniqueIDs.pkl', 'rb') as f:
    uniqueIDs = pickle.load(f)

# Fit mod3 for each participant

In [3]:
# Model settings

# Functions
value_fct = fct.rescorla_wagner_2LR_FB
dec_fct = fct.my_softmax

# Store everything
mod_info = {}
mod_info['name'] = 'model3'
mod_info['value_fct'] = value_fct.__name__
mod_info['dec_fct'] = dec_fct.__name__
mod_info['param_names'] = ['v0', 'alpha_rew', 'alpha_pun', 'beta']
#print(mod['value_fct'].__name__)

# save
all_users_folder = 'data/all_users/mod3/'
file_name = all_users_folder+'mod_parameters.pkl'
with open(file_name, 'wb') as f:
    pickle.dump(mod_info, f)

In [4]:
run_ = True

# Folder
all_users_folder = 'data/all_users/mod3/'
    
if run_:
    
    # Parameter range for initial guess 
    # order of mod['param_names']:
    # ['v0', 'alpha_rew', 'alpha_pun', 'beta']
    param_lower_bound = [-10, 0, 0, 0]
    param_upper_bound = [ 10, 1, 1, 20]

    # Fit
    all_users = {}
    p_hit_per_trial = pd.DataFrame([])
    ev_per_trial = pd.DataFrame([])
    
    for n_part,ID in enumerate(uniqueIDs): 
        
        print(ID)

        # Get data
        user_folder = 'data/user_' + ID + '/'
        df2_cf = pd.read_pickle(user_folder + 'df2_cf.pkl')
        isHit_all_cues, fbs_all_cues, trialNo_all_cues = fct.extract_hits_fbs(df2_cf)

        # Create a new Model object
        mod = fct.Model(mod_name = mod_info['name'],
                     value_fct = value_fct, 
                     dec_fct = dec_fct, 
                     param_names = mod_info['param_names'])

        # Input data to model
        mod.set_data(ID, fbs_all_cues, isHit_all_cues, trialNo_all_cues)

        # Fit model
        mod.fit(param_lower_bound, param_upper_bound, n_iterations=5)

        # Nested dictionnary user data
        all_users[n_part] = {}
        all_users[n_part]['ID']=mod.ID
        all_users[n_part]['nLL']=mod.nLL
        all_users[n_part]['Ntrials']=mod.Ntrials
        all_users[n_part]['Nparams']=len(mod.param_names)
        for i in range(0,len(mod.param_names)):
            all_users[n_part][mod.param_names[i]]=mod.param_values[i]
        
        # Concatenated model predictions: p hit
        tmp = pd.DataFrame(mod.p_hit).transpose()
        tmp.columns = tmp.columns+1
        tmp = tmp.reset_index().rename(columns={'index': 'Cue'})
        tmp.insert(0,'ID',ID)
        p_hit_per_trial = pd.concat([tmp, p_hit_per_trial], axis=0)
        
        # Concatenated model predictions: EVs
        tmp = pd.DataFrame(mod.v).transpose()
        tmp.columns = tmp.columns+1
        tmp = tmp.reset_index().rename(columns={'index': 'Cue'})
        tmp.insert(0,'ID',ID)
        ev_per_trial = pd.concat([tmp, ev_per_trial], axis=0)
    
    # Save mod LLs and parameter values
    mod_fit = pd.DataFrame(all_users).transpose()
    mod_fit.to_pickle(all_users_folder+'mod_param_fits.pkl')
    
    # Save mod predictions
    p_hit_per_trial = p_hit_per_trial.sort_values(by='ID').reset_index(drop=True)
    p_hit_per_trial.to_pickle(all_users_folder+'mod_p_hit_per_trial.pkl')
    ev_per_trial = ev_per_trial.sort_values(by='ID').reset_index(drop=True)
    ev_per_trial.to_pickle(all_users_folder+'mod_ev_per_trial.pkl')
    

001
003
006
007
008
009
010
012
014
023
028
032
033
034
035
036
038
044
045
049
050
054
055
057
058
060
072
073
075
076
077
078
080
084
085
086
088
090
093
096
100
101
103
104
106
107
108
109
112
114
115
118
119
120
121
122
130
132
133
136
137
138
142
143
145
147
149
151
155
158
159
165
170
172
173
175
176
177
178
181
185
186
187
189
199
204
207
208
211
213
217
221
229
230
236
237
239
240
244
251
252
258
259
263
269
271
280
282
283
284
286
287
291
292
293
301
302
303
306
313
314
315
319
321
322
325
326
327
329
335
336
337
339
341
345
349
351
360
361
362
368
375
376
381
384
390
391
393
395
397
400
405
406
412
413
414
421
422
423
427
437
438
440
446
450
453
462
469
470
471


In [5]:
mod.v

{'HR': array([-0.80422084,  4.9532101 ,  4.99962281,  2.97453916,  4.98367204,
         2.96181867,  4.98356949,  3.77180091,  1.02234448,  1.00018013,
         4.96775602,  4.99974007,  3.78469669,  4.99020301,  1.03216645,
        -0.18944379, -0.35359439, -0.35359439, -0.35359439, -0.48450187,
        -0.48450187, -0.48450187, -0.48450187, -0.48450187, -0.48450187,
        -0.48450187, -0.48450187, -0.48450187]),
 'LR': array([-0.80422084, -0.84386926, -0.84386926, -0.84386926, -0.84386926,
        -0.84386926, -0.84386926, -0.84386926, -0.84386926, -0.84386926,
        -0.84386926, -0.84386926, -0.84386926, -0.84386926, -0.84386926,
         0.98513591,  4.96763474,  4.99973909,  1.03224333,  1.00025993,
         0.59517528,  0.99673656, -0.21769856,  0.9901837 ,  0.9901837 ,
         0.58713965, -0.54434555,  4.95530505]),
 'HP': array([-0.80422084, -1.65393327, -1.65393327, -1.65393327, -1.65393327,
        -1.65393327, -1.65393327, -1.65393327, -1.65393327, -1.65393327,
        

In [6]:
mod.PEs

{'HR': array([ 5.80422084,  0.0467899 , -9.99962281,  2.02546084, -9.98367204,
         2.03818133, -5.98356949, -2.77180091, -0.02234448,  3.99981987,
         0.03224398, -5.99974007,  1.21530331, -3.99020301, -6.03216645,
        -0.81055621,         nan,         nan, -0.64640561,         nan,
                nan,         nan,         nan,         nan,         nan,
                nan,         nan,         nan]),
 'LR': array([-0.19577916,         nan,         nan,         nan,         nan,
                nan,         nan,         nan,         nan,         nan,
                nan,         nan,         nan,         nan,  1.84386926,
         4.01486409,  0.03236526, -3.99973909, -0.03224333, -2.00025993,
         0.40482472, -5.99673656,  1.21769856,         nan, -1.9901837 ,
        -5.58713965,  5.54434555,         nan]),
 'HP': array([-4.19577916,         nan,         nan,         nan,         nan,
                nan,         nan,         nan,         nan,         nan,
        