In [1]:
# On terminal: conda activate python38

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import ipynb.fs.defs.functions as fct
import pickle

import warnings
# Turn every warning into a raised exception for the rest of the kernel session.
# NOTE(review): presumably so that numerical warnings during fitting are not
# silently ignored; this also makes library deprecation warnings fatal — confirm intent.
warnings.filterwarnings("error")

In [3]:
# Read the pickled participant IDs from the working directory
with open('uniqueIDs.pkl', 'rb') as ids_file:
    uniqueIDs = pickle.load(ids_file)

# Fit mod8 for each participant

In [4]:
# Model settings

# Value-update rule and decision rule used by this model
value_fct = fct.rescorla_wagner_shrinking_alpha
dec_fct = fct.my_softmax_shrinking_press_bias

# Description of the model; the functions are recorded by name so the
# dictionary can be pickled without the function objects themselves.
mod_info = {
    'name': 'model8',
    'value_fct': value_fct.__name__,
    'dec_fct': dec_fct.__name__,
    'param_names': ['v0', 'alpha_t', 'beta', 'pi_t'],
}

# Write the description next to the other group-level outputs
all_users_folder = 'data/all_users/mod8/'
file_name = all_users_folder + 'mod_parameters.pkl'
with open(file_name, 'wb') as f:
    pickle.dump(mod_info, f)

In [5]:
# NOTE(review): this replaces the IDs loaded from 'uniqueIDs.pkl' above with a
# fixed two-element list, so only these IDs are fitted below. Presumably a
# debugging shortcut — remove this cell to fit every loaded ID.
uniqueIDs = ['001', '003']

In [None]:
# Everything below (fitting and saving) only runs when this flag is True.
run_ = True

# Folder that receives the group-level outputs
all_users_folder = 'data/all_users/mod8/'


def _per_cue_frame(per_cue_values, participant_id):
    """Reshape one participant's values into a table with 'ID' and 'Cue' columns.

    Parameters
    ----------
    per_cue_values : anything accepted by ``pd.DataFrame``
        The column labels of ``pd.DataFrame(per_cue_values)`` end up in the
        'Cue' column; its row labels, shifted by +1, become the remaining
        column labels (presumably 1-based trial positions — confirm against
        ``fct.Model``).
    participant_id : value written into every row of the 'ID' column.

    Returns
    -------
    pd.DataFrame with 'ID' as first column and 'Cue' as second column.
    """
    frame = pd.DataFrame(per_cue_values).transpose()
    frame.columns = frame.columns + 1
    frame = frame.reset_index().rename(columns={'index': 'Cue'})
    frame.insert(0, 'ID', participant_id)
    return frame


def _stack_and_save(frames, path):
    """Concatenate per-participant tables, order them by 'ID', pickle and return them.

    The sort is stable so that the row order within a participant is the
    order produced by ``_per_cue_frame`` rather than implementation-defined.
    """
    stacked = (pd.concat(frames, axis=0)
                 .sort_values(by='ID', kind='stable')
                 .reset_index(drop=True))
    stacked.to_pickle(path)
    return stacked


if run_:

    if len(uniqueIDs) == 0:
        # Without this check the failure would only surface later, at concatenation.
        raise ValueError('uniqueIDs is empty: there is no participant to fit.')

    # Parameter range for initial guess
    # order of mod_info['param_names']:
    # ['v0', 'alpha_t', 'beta', 'pi_t']
    param_lower_bound = [-5, 0, 0, -10]
    param_upper_bound = [5, 1, 15, 10]

    # Attribute of the fitted model -> pickle file that receives the
    # corresponding group-level table. Model predictions come first,
    # behaviour last. 'shrink_alpha' gets its own table and file: it was
    # previously read from 'shrink_pi' and appended to the shrink_PIs table,
    # which duplicated those rows and never saved shrink_alpha.
    output_files = {
        'p_hit': 'mod_p_hit_per_trial.pkl',
        'v': 'mod_ev_per_trial.pkl',
        'PEs': 'PEs.pkl',
        'shrink_pi': 'shrink_PIs.pkl',
        'shrink_alpha': 'shrink_alphas.pkl',
        'trialNo_all_cues': 'trialsNos.pkl',
        'fbs_all_cues': 'fbs.pkl',
        'isHit_all_cues': 'hits.pkl',
    }

    # Fit
    all_users = {}
    # Per-participant tables are collected in lists and concatenated once
    # after the loop instead of growing a DataFrame on every iteration.
    frames = {attr: [] for attr in output_files}

    for n_part, ID in enumerate(uniqueIDs):

        # Get data
        user_folder = 'data/user_' + ID + '/'
        df2_cf = pd.read_pickle(user_folder + 'df2_cf.pkl')
        isHit_all_cues, fbs_all_cues, trialNo_all_cues = fct.extract_hits_fbs(df2_cf)

        # Create a new Model object
        mod = fct.Model(mod_name=mod_info['name'],
                        value_fct=value_fct,
                        dec_fct=dec_fct,
                        param_names=mod_info['param_names'])

        # Input data to model
        mod.set_data(ID, fbs_all_cues, isHit_all_cues, trialNo_all_cues)

        # Fit model
        mod.fit(param_lower_bound, param_upper_bound, n_iterations=5)

        # Nested dictionary of per-participant fit results
        all_users[n_part] = {
            'ID': mod.ID,
            'nLL': mod.nLL,
            'Ntrials': mod.Ntrials,
            'Nparams': len(mod.param_names),
        }
        for i, param_name in enumerate(mod.param_names):
            all_users[n_part][param_name] = mod.param_values[i]

        # Model predictions and behaviour of this participant, one table each
        for attr in output_files:
            frames[attr].append(_per_cue_frame(getattr(mod, attr), ID))

    # Save mod LLs and parameter values
    mod_fit = pd.DataFrame(all_users).transpose()
    mod_fit.to_pickle(all_users_folder + 'mod_param_fits.pkl')

    # Save mod predictions and behaviour (one pickle per table)
    tables = {attr: _stack_and_save(frames[attr], all_users_folder + file_name_)
              for attr, file_name_ in output_files.items()}

    # Keep the group-level tables available under their usual names
    p_hit_per_trial = tables['p_hit']
    ev_per_trial = tables['v']
    PEs = tables['PEs']
    shrink_PIs = tables['shrink_pi']
    shrink_alphas = tables['shrink_alpha']
    trialsNos = tables['trialNo_all_cues']
    fbs = tables['fbs_all_cues']
    hits = tables['isHit_all_cues']
    

In [6]:
mod.ID

'001'

ERROR! Session/line number was not unique in database. History logging moved to new session 1114


In [7]:
mod.shrink_pi

[]

In [11]:
mod.shrink_alpha

[]