In [85]:
import numpy as np
import pandas as pd
import pickle
import random
from scipy.optimize import minimize
from scipy.stats import pearsonr, bernoulli
import matplotlib.pyplot as plt
import seaborn as sns
import math
import ptitprince as pt # rainplot

## Plot

In [88]:
def fig_modelpred_on_behav(ev_per_trial, p_hit_per_trial, all_users_folder):
    """Plot group-mean model predictions per cue next to behaviour.

    Top panel: model probability of a hit per cue-trial (coloured by cue),
    overlaid in gray with the behavioural hit rate averaged in sliding windows.
    Bottom panel: model expected value per cue-trial.

    Parameters
    ----------
    ev_per_trial : pandas.DataFrame
        Model expected values with columns 'ID', 'Cue' and one column per
        cue-trial. Not modified by this function.
    p_hit_per_trial : pandas.DataFrame
        Model hit probabilities in the same layout. Not modified.
    all_users_folder : str
        Folder prefix holding ``hit_perc_per_t/hit_perc_per_t_w16.pkl``.
    """
    # Timepoints of the behavioural sliding average
    window_size = 16
    N_trials = 112
    timepoints = ['t' + str(t + 1) for t in range(int(N_trials / window_size))]

    # Behaviour: mean/var of the windowed hit rate per cue code
    hit_perc_per_t = pd.read_pickle(all_users_folder + 'hit_perc_per_t/hit_perc_per_t_w' + str(window_size) + '.pkl')
    hit_perc_per_t = hit_perc_per_t.drop(columns='ID')
    stats_hits_perc_per_t = hit_perc_per_t.groupby('Code').agg(['mean', 'var']).T.swaplevel(axis=0)

    # Model predictions: drop 'ID' on a copy (no inplace=True) so the caller's
    # frames are left intact -- the in-place drop made a second call raise KeyError.
    stats_ev_per_trial = ev_per_trial.drop(columns='ID').groupby('Cue').agg(['mean', 'var']).T.swaplevel(axis=0)
    stats_ev_per_trial.columns = 'Cue_' + stats_ev_per_trial.columns

    stats_p_hit_per_trial = p_hit_per_trial.drop(columns='ID').groupby('Cue').agg(['mean', 'var']).T.swaplevel(axis=0)
    stats_p_hit_per_trial.columns = 'Cue_' + stats_p_hit_per_trial.columns

    # Plot
    f, axs = plt.subplots(2, 1, figsize=(10, 6), dpi=80)
    plt.subplots_adjust(hspace = 0.4)

    # Model x-axis: cue-trial 1..N
    x_mod = np.arange(len(stats_ev_per_trial.loc['mean'])) + 1
    # Behaviour x-axis: window t is drawn at cue-trial 4*(t+1)
    # NOTE(review): presumably 4 = window_size / 4 cues (trials per cue per window) -- confirm
    x_behav = np.arange(len(timepoints)) * 4 + 4

    for cue, color, alpha in zip(['Cue_HP', 'Cue_LP', 'Cue_LR', 'Cue_HR'], ['red', 'red', 'green', 'green'], [0.6, 0.3, 0.3, 0.6]):
        # Hit probabilities
        axs[0].plot(x_mod, stats_p_hit_per_trial.loc['mean'][cue], '.-', color=color, alpha=alpha)
        # Behaviour: sliding average
        behav_lines = axs[0].plot(x_behav, stats_hits_perc_per_t.loc['mean'][cue], '.-', color='gray', alpha=alpha)
        # Expected values
        axs[1].plot(x_mod, stats_ev_per_trial.loc['mean'][cue], '.-', color=color, alpha=alpha)

    for ax in axs:
        ax.grid(axis='x', color='0.95')
        ax.set_xticks(x_mod)

    axs[0].set_title('Model probability of hit', fontsize=16)
    axs[1].set_title('Model expected values', fontsize=16)

    axs[0].set_ylabel("Hit [%]", fontsize=16)
    axs[1].set_ylabel("Values", fontsize=16)

    axs[0].set_ylim([0, 1])
    axs[1].set_ylim([-3, 3])

    # A single legend entry for the gray behavioural curves (labels as a list, not a set)
    axs[0].legend(behav_lines, ['Behaviour'])

    plt.show()

In [87]:
def plot_model_parameters(data_mod, param_names):
    """Raincloud plot (half violin + jittered points + box) of each fitted parameter.

    Parameters
    ----------
    data_mod : pandas.DataFrame
        Table with one column per parameter name.
    param_names : list of str
        Columns of ``data_mod`` to plot, one panel each.

    Raises
    ------
    ValueError
        If ``param_names`` is empty.
    """
    Nparam = len(param_names)
    if Nparam == 0:
        raise ValueError('param_names must contain at least one parameter name')

    # Up to 3 panels in one row, 4 as a 2x2 grid, otherwise rows of 3 panels
    # (as many rows as needed -- parameters beyond the 6th used to be dropped).
    if Nparam <= 3:
        nrows, ncols = 1, Nparam
    elif Nparam == 4:
        nrows, ncols = 2, 2
    else:
        nrows, ncols = math.ceil(Nparam / 3), 3

    # squeeze=False always returns a 2-D array of Axes, so a single code path
    # serves every layout, including a single parameter.
    f, axs = plt.subplots(nrows, ncols, figsize=(ncols*6, nrows*5), squeeze=False)
    axs = axs.reshape(-1)
    plt.subplots_adjust(hspace = 0.3)

    pal = sns.color_palette(n_colors=1)

    for ax, param in zip(axs, param_names):
        # Flatten to a plain list (also handles a one-column DataFrame selection)
        y = np.ravel(data_mod[param].to_numpy()).tolist()
        # plot clouds
        pt.half_violinplot(ax=ax, x = y, palette = pal, bw = .2, cut = 0., scale = "area", width = .6, inner = None, orient = 'h')
        # add rain
        sns.stripplot(ax=ax, x = y, palette = pal, edgecolor = "white", size = 5, jitter = 1, zorder = 0, orient = 'h', alpha=.35)
        sns.boxplot(x = y, saturation=1, showfliers=False, width=0.15, boxprops={'zorder': 3, 'facecolor': 'none'}, ax=ax)
        # Makeup
        ax.set_title('Parameter: ' + param, fontsize=18)
        ax.set_xlabel('Best fit parameter value', fontsize=16)

    # Hide the empty panels of an incomplete last row
    for ax in axs[Nparam:]:
        ax.set_visible(False)

    plt.show()

In [2]:
def plot_correlation(title, x, y, data, ax, text_pos, xlabel, ylabel, xlim, ylim):
    """Scatter ``data[y]`` against ``data[x]`` with a regression line and Pearson r.

    Parameters
    ----------
    title : str
        Axes title.
    x, y : str
        Column names in ``data`` for the horizontal and vertical axes.
    data : pandas.DataFrame
        Source table.
    ax : matplotlib Axes
        Axes to draw into.
    text_pos : sequence of two floats
        Position of the r/p annotation, in axes coordinates.
    xlabel, ylabel : str
        Axis labels.
    xlim, ylim : sequence of two floats
        (lower, upper) axis limits.
    """
    ax.set_title(title, fontsize = 22)

    # Scatter plot and linear regression line
    sns.scatterplot(x = x, y = y, data = data, ax = ax)
    sns.regplot(x = x, y = y, data = data, ax = ax)

    # Reference line at y = 0
    ax.axhline(y=0, color='k', linestyle=':')

    # Pearson r and p on complete (x, y) pairs only: a single NaN in either
    # column would otherwise turn both statistics into NaN.
    complete = data[[x, y]].dropna()
    r, p = pearsonr(complete[x], complete[y])

    # Write stats on fig
    if p < 0.01:
        stats_text = 'r={:.2f} \np<0.01'.format(r)
    else:
        stats_text = 'r={:.2f} \np={:.2g}'.format(r, p)
    ax.text(text_pos[0], text_pos[1], stats_text, transform=ax.transAxes, fontsize=16)

    # Properties
    ax.set_ylabel(ylabel, fontsize=15)
    ax.set_ylim(bottom=ylim[0], top=ylim[1])

    ax.set_xlabel(xlabel, fontsize=15)
    ax.set_xlim(left=xlim[0], right=xlim[1])

In [43]:
def plot_param_recov_correlations(df):
    """Parameter-recovery scatter plots: simulated vs fitted value, one panel per parameter.

    Parameters
    ----------
    df : pandas.DataFrame
        Long format with columns 'simID', 'Type' ('Sim' or 'Fit') followed by
        one column per parameter (the layout built by ``reformat_param_lists``).

    Raises
    ------
    ValueError
        If ``df`` has no parameter columns after 'simID' and 'Type'.
    """
    # Parameter columns follow 'simID' and 'Type'
    param_names = df.columns[2:].tolist()
    Nparam = len(param_names)
    if Nparam == 0:
        raise ValueError('df has no parameter columns after simID and Type')

    # Up to 3 small panels in one row, 4 as 2x2, otherwise rows of 3 panels
    # (as many rows as needed -- more than 6 parameters used to raise IndexError).
    if Nparam <= 3:
        nrows, ncols, figsize = 1, Nparam, (Nparam*4, 3)
    elif Nparam == 4:
        nrows, ncols, figsize = 2, 2, (2*6, 2*5)
    else:
        ncols = 3
        nrows = math.ceil(Nparam / ncols)
        figsize = (ncols*6, nrows*5)

    # squeeze=False always returns a 2-D array of Axes
    f, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    axs = axs.reshape(-1)
    plt.subplots_adjust(hspace = 0.3)

    for ax, param in zip(axs, param_names):
        # Same integer-rounded limits on both axes so perfect recovery lies on the diagonal
        lims = [math.floor(df[param].min()), math.ceil(df[param].max())]

        # One row per simulation with columns 'Fit' and 'Sim'
        df_tmp = df.pivot(index = 'simID', columns = 'Type', values=param)

        plot_correlation(title=param, x='Fit', y='Sim',
                         data=df_tmp, ax=ax, text_pos=[.2, .8],
                         xlabel='Fitted', ylabel='Simulated', xlim=lims, ylim=lims)

    # Hide the empty panels of an incomplete last row
    for ax in axs[Nparam:]:
        ax.set_visible(False)

    plt.show()

In [44]:
def plot_param_recov_conf_matrix(df):
    """Heatmap of Pearson correlations between every simulated and every fitted parameter.

    Parameters
    ----------
    df : pandas.DataFrame
        Long format with columns 'simID', 'Type' ('Sim' or 'Fit') followed by
        one column per parameter (the layout built by ``reformat_param_lists``).
    """
    f, ax = plt.subplots(1, 1, figsize=(4, 3))

    # Parameter columns follow 'simID' and 'Type'
    param_names = df.columns[2:].tolist()
    Nparam = len(param_names)

    # Split once (outside the double loop) and align fits to simulations by
    # simID, so each correlation pairs a fit with the simulation it came from
    # regardless of row order.
    sim = df[df['Type'] == 'Sim'].set_index('simID')[param_names]
    fit = df[df['Type'] == 'Fit'].set_index('simID')[param_names].reindex(sim.index)

    # np.nan, not np.NaN: the capitalised alias was removed in NumPy 2.0
    confusion_mat = np.full((Nparam, Nparam), np.nan)
    for i_sim, param_sim_name in enumerate(param_names):
        for i_fit, param_fit_name in enumerate(param_names):
            corr, _ = pearsonr(sim[param_sim_name], fit[param_fit_name])
            confusion_mat[i_sim, i_fit] = corr

    # Transposed so simulated parameters run along x and fitted ones along y
    sns.heatmap(confusion_mat.T, vmin=-1, vmax=1, annot=True, annot_kws={"fontsize": 12, "weight": "normal"}, cmap='coolwarm', ax=ax)
    ax.set_xticklabels(param_names, fontsize = 12)
    ax.set_yticklabels(param_names, fontsize = 12)
    ax.set_xlabel('Simulated', fontsize = 14)
    ax.set_ylabel('Fitted', fontsize = 14)
    plt.show()
    

## Model class

In [4]:
class Model:
    """Trial-by-trial learning model fitted by maximum likelihood.

    A model couples two user-supplied functions:

    * ``value_fct(vt=, isHit=, fb=, param_names=, param_values=)`` returns
      ``(new_value, prediction_error)`` for one trial;
    * ``dec_fct(vt, trial, param_names, param_values)`` returns the
      probability of a hit on that trial.

    Per-cue data are dicts mapping a cue name to a per-trial sequence. Every
    internal per-trial container has one slot per trial of its cue.
    """

    # Probabilities are clipped to [_P_EPS, 1 - _P_EPS] before taking logs so
    # a saturated decision function (p == 0 or 1) cannot produce inf/NaN
    # negative log-likelihoods, which would derail the optimiser.
    _P_EPS = 1e-10

    def __init__(self, mod_name, value_fct, dec_fct, param_names):
        self.mod_name = mod_name
        self.value_fct = value_fct
        self.dec_fct = dec_fct
        self.param_names = param_names
        self.param_values = []
        self.fbs_all_cues = []
        self.p_hit_all_cues = []
        self.vt_all_cues = []
        self.pe_all_cues = []
        self.trialNo_all_cues = []
        self.isHit_all_cues = []
        self.Ntrials_per_cue = []
        self.nLLs_all_cues = []
        self.total_nLL_per_cue = []
        self.nLL = []

    def __repr__(self):
        return f'Model("{self.mod_name}","{self.param_names}")'

    def set_part_data(self, ID, fbs_all_cues, isHit_all_cues, trialNo_all_cues):
        """Load one participant's observed data and reset all per-trial outputs."""
        self.ID = ID
        self.fbs_all_cues = fbs_all_cues
        self.isHit_all_cues = isHit_all_cues
        self.trialNo_all_cues = trialNo_all_cues
        self.p_hit_all_cues = self.get_empty_dict(fbs_all_cues)
        self.vt_all_cues = self.get_empty_dict(fbs_all_cues)
        self.pe_all_cues = self.get_empty_dict(fbs_all_cues)
        self.nLLs_all_cues = self.get_empty_dict(fbs_all_cues)
        self.total_nLL_per_cue = dict.fromkeys(fbs_all_cues.keys())

    def set_dataset(self, fbs_all_cues, trialNo_all_cues):
        """Load a feedback/trial schedule for simulation; hits start empty."""
        self.fbs_all_cues = fbs_all_cues
        self.trialNo_all_cues = trialNo_all_cues
        self.p_hit_all_cues = self.get_empty_dict(fbs_all_cues)
        self.vt_all_cues = self.get_empty_dict(fbs_all_cues)
        self.pe_all_cues = self.get_empty_dict(fbs_all_cues)
        self.isHit_all_cues = self.get_empty_dict(fbs_all_cues)
        self.nLLs_all_cues = self.get_empty_dict(fbs_all_cues)
        self.total_nLL_per_cue = dict.fromkeys(fbs_all_cues.keys())

    def set_param_values(self, param_values):
        """Set parameter values, ordered like ``self.param_names``."""
        self.param_values = param_values

    def get_empty_dict(self, fbs_all_cues):
        """Return ``{cue: object array of None, one slot per trial of that cue}``.

        Also stores the largest per-cue trial count in ``self.Ntrials_per_cue``.
        Each cue is sized from its own data; the global ``random`` state is
        not touched (sizing used to come from a randomly chosen cue).
        """
        empty = {cue: np.empty(len(fbs), dtype=object) for cue, fbs in fbs_all_cues.items()}
        self.Ntrials_per_cue = max((len(fbs) for fbs in fbs_all_cues.values()), default=0)
        return empty

    def compute_nLL_per_cue(self, p_hit_all_cues, hits_all_cues):
        """Fill ``nLLs_all_cues`` (per trial) and ``total_nLL_per_cue`` (per-cue sum).

        Uses the Bernoulli (logistic-regression) cost
        ``-y*log(p) - (1-y)*log(1-p)``, with ``p`` clipped away from 0 and 1.
        """
        eps = self._P_EPS
        for cue, p_hit in p_hit_all_cues.items():
            for ind, f_x in enumerate(p_hit):
                y = hits_all_cues[cue][ind]
                f_x = np.clip(f_x, eps, 1 - eps)
                self.nLLs_all_cues[cue][ind] = - y * np.log(f_x) - (1 - y) * np.log(1 - f_x)
            self.total_nLL_per_cue[cue] = sum(self.nLLs_all_cues[cue])

    def total_nLL(self, param_values):
        """Run the model with ``param_values`` on the loaded data and return the summed nLL."""
        self.set_param_values(param_values)
        self.run_model(sim_behav=0)
        self.compute_nLL_per_cue(self.p_hit_all_cues, self.isHit_all_cues)
        return sum(self.total_nLL_per_cue.values())

    def fit(self, param_lower_bounds, param_upper_bounds, n_iterations, method='TNC'):
        """Minimise ``total_nLL`` from ``n_iterations`` random starting points.

        Parameters
        ----------
        param_lower_bounds, param_upper_bounds : sequence of float
            Box bounds, ordered like ``self.param_names``.
        n_iterations : int
            Number of random restarts (at least 1).
        method : str, default 'TNC'
            ``scipy.optimize.minimize`` method.

        Returns
        -------
        best_params : numpy.ndarray, shape (n_params,)
        nLL : float
            Negative log-likelihood at ``best_params``; also stored in
            ``self.nLL``. The model is left evaluated at ``best_params``.

        Raises
        ------
        ValueError
            If ``n_iterations`` < 1.
        """
        if n_iterations < 1:
            raise ValueError('n_iterations must be at least 1')

        bounds = [[low, up] for low, up in zip(param_lower_bounds, param_upper_bounds)]

        mat_min_nLL = []
        mat_best_params = []
        for _ in range(n_iterations):
            # Random start inside the bounds, as a 1-D array: minimize rejects a
            # 2-D x0 (which the old list of 1-element arrays produced).
            initial_guess = np.array([np.random.uniform(low, up) for low, up in bounds])
            result = minimize(self.total_nLL, initial_guess, method = method, bounds = bounds,
                              options={'xtol': 1e-8, 'disp': False})
            mat_min_nLL.append(result.fun)
            mat_best_params.append(result.x)

        best_params = mat_best_params[int(np.argmin(mat_min_nLL))]

        # Re-evaluate so the stored trajectories correspond to best_params
        self.nLL = self.total_nLL(best_params)
        return best_params, self.nLL

    def run_model(self, sim_behav):
        """Run the model forward over every cue's trials.

        On each trial: p(hit) from ``dec_fct``; if ``sim_behav == 1``, a hit is
        drawn from Bernoulli(p) into ``isHit_all_cues``; then ``value_fct``
        gives the prediction error (stored) and the value for the next trial.
        The first value of every cue is ``v0`` (a free parameter if listed
        in ``param_names``, else 0).
        """
        v0 = 0 if 'v0' not in self.param_names else self.param_values[self.param_names.index('v0')]

        for cue, fbs in self.fbs_all_cues.items():
            n_trials = len(fbs)
            if n_trials == 0:
                continue
            vt_cue = self.vt_all_cues[cue]
            vt_cue[0] = v0

            for t, trial in zip(range(n_trials), self.trialNo_all_cues[cue]):
                # make a decision
                self.p_hit_all_cues[cue][t] = self.dec_fct(vt_cue[t], trial, self.param_names, self.param_values)

                if sim_behav == 1:
                    # ravel(...)[0] accepts a scalar or a 1-element-array probability;
                    # [0] avoids float() on a size-1 array (deprecated in NumPy).
                    p = np.ravel(self.p_hit_all_cues[cue][t])[0]
                    self.isHit_all_cues[cue][t] = float(bernoulli.rvs(p, size=1)[0])

                # compute PE and new value
                vt, pe = self.value_fct(vt = vt_cue[t], isHit = self.isHit_all_cues[cue][t], fb = fbs[t],
                                        param_names = self.param_names, param_values = self.param_values)
                self.pe_all_cues[cue][t] = pe

                # the update after the last trial has no slot to go to
                if t < n_trials - 1:
                    vt_cue[t + 1] = vt

## Fitting, simulation and data-loading helpers

In [5]:
def _per_trial_frame(per_cue, ID):
    """One row per cue: columns 'ID', 'Cue', then 1..N for the cue's trials."""
    frame = pd.DataFrame(per_cue).transpose()
    frame.columns = frame.columns + 1
    frame = frame.reset_index().rename(columns = {'index': 'Cue'})
    frame.insert(0, 'ID', ID)
    return frame


def _concat_sorted_by_id(frames):
    """Concatenate per-participant frames once and order rows by ID (stable sort)."""
    if not frames:
        return pd.DataFrame(columns=['ID', 'Cue'])
    combined = pd.concat(frames, axis=0)
    return combined.sort_values(by='ID', kind='stable').reset_index(drop=True)


def fit_model_to_data(uniqueIDs, mod_info, param_lower_bounds, param_upper_bounds, all_users_folder, n_iterations=5):
    """Fit a model to every participant and pickle the fits and per-trial outputs.

    For each ID, loads ``data/user_<ID>/df2_cf.pkl``, fits a fresh ``Model``
    built from ``mod_info`` (keys 'name', 'value_fct', 'dec_fct',
    'param_names') and collects its per-trial outputs.

    Writes to ``all_users_folder``:

    * ``mod_param_fits.pkl`` -- one row per participant: ID, nLL, Nparams and
      one column per parameter;
    * ``mod_<x>_per_trial.pkl`` for x in p_hit, ev, pe, isHit, trialNo, fbs --
      one row per (ID, Cue), columns 1..N indexing the cue-trial, rows
      ordered by ID.

    Parameters
    ----------
    uniqueIDs : iterable of str
        Participant IDs.
    mod_info : dict
        Model specification (see above).
    param_lower_bounds, param_upper_bounds : sequence of float
        Fitting bounds, ordered like ``mod_info['param_names']``.
    all_users_folder : str
        Output folder prefix.
    n_iterations : int, default 5
        Random restarts per participant, passed to ``Model.fit``.
    """
    # Output file stem -> Model attribute holding that per-trial output
    per_trial_outputs = {
        'p_hit': 'p_hit_all_cues',
        'ev': 'vt_all_cues',
        'pe': 'pe_all_cues',
        'isHit': 'isHit_all_cues',
        'trialNo': 'trialNo_all_cues',
        'fbs': 'fbs_all_cues',
    }

    all_users = {}
    # Collect frames in lists and concatenate once at the end (no quadratic concat)
    frames = {key: [] for key in per_trial_outputs}

    for n_part, ID in enumerate(uniqueIDs):

        # Get data
        user_folder = 'data/user_' + ID + '/'
        df2_cf = pd.read_pickle(user_folder + 'df2_cf.pkl')
        isHit_all_cues, fbs_all_cues, trialNo_all_cues = extract_hits_fbs(df2_cf)

        # Fresh model per participant
        mod = Model(mod_name = mod_info['name'],
                    value_fct = mod_info['value_fct'],
                    dec_fct = mod_info['dec_fct'],
                    param_names = mod_info['param_names'])
        mod.set_part_data(ID, fbs_all_cues, isHit_all_cues, trialNo_all_cues)
        mod.fit(param_lower_bounds, param_upper_bounds, n_iterations=n_iterations)

        # Per-participant summary (fit leaves param_values at the best fit)
        user = {'ID': mod.ID, 'nLL': mod.nLL, 'Nparams': len(mod.param_names)}
        for name, value in zip(mod.param_names, mod.param_values):
            user[name] = value
        all_users[n_part] = user

        for key, attr in per_trial_outputs.items():
            frames[key].append(_per_trial_frame(getattr(mod, attr), ID))

    # Save mod LLs and parameter values
    mod_fit = pd.DataFrame(all_users).transpose()
    mod_fit.to_pickle(all_users_folder + 'mod_param_fits.pkl')

    # Save mod predictions
    for key, frame_list in frames.items():
        per_trial = _concat_sorted_by_id(frame_list)
        per_trial.to_pickle(all_users_folder + 'mod_' + key + '_per_trial.pkl')
    

In [6]:
def reformat_param_lists(all_sim_param_values, all_fit_param_values, mod_info):
    """Stack simulated and fitted parameter sets into one long DataFrame.

    Parameters
    ----------
    all_sim_param_values, all_fit_param_values : sequence of sequences
        One parameter vector per simulation, ordered like
        ``mod_info['param_names']``; entry i of both lists belongs to simulation i.
    mod_info : dict
        Must contain 'param_names'.

    Returns
    -------
    pandas.DataFrame
        Columns 'simID', 'Type' ('Sim' or 'Fit') and one per parameter; rows
        ordered by simID with the 'Sim' row before the 'Fit' row.
    """
    def _tagged(values, kind):
        # simID is the position of the parameter vector in its list
        frame = pd.DataFrame(values, columns = mod_info['param_names'])
        frame.insert(0, 'Type', kind)
        return frame.rename_axis('simID').reset_index()

    df = pd.concat([_tagged(all_sim_param_values, 'Sim'), _tagged(all_fit_param_values, 'Fit')])
    # Stable sort: keeps 'Sim' before 'Fit' within each simID (quicksort does not guarantee it)
    return df.sort_values(by='simID', kind='stable').reset_index(drop=True)

In [7]:
def extract_hits_fbs(df, cues=('HR', 'LR', 'HP', 'LP')):
    """Split one participant's trial table by cue.

    Parameters
    ----------
    df : pandas.DataFrame
        Trial table with columns 'Code', 'isHit', 'FBs' and 'Trial'. Not modified.
    cues : iterable of str, default ('HR', 'LR', 'HP', 'LP')
        Cue names; rows are selected where ``Code == 'Cue_' + cue``.

    Returns
    -------
    isHit_all_cues, fbs_all_cues, trialNo_all_cues : dict[str, list]
        For each cue, the 'isHit', 'FBs' and 'Trial' values in row order.
    """
    isHit_all_cues = {}
    fbs_all_cues = {}
    trialNo_all_cues = {}

    for cue in cues:
        # Read-only selection; the unused 'CueTrial' column is no longer
        # inserted into this filtered slice.
        cue_data = df[df['Code'] == ('Cue_' + cue)]
        isHit_all_cues[cue] = cue_data['isHit'].tolist()
        fbs_all_cues[cue] = cue_data['FBs'].tolist()
        trialNo_all_cues[cue] = cue_data['Trial'].tolist()

    return isHit_all_cues, fbs_all_cues, trialNo_all_cues

In [8]:
def extract_rand_dataset(ids_path='uniqueIDs.pkl', data_folder='data/'):
    """Load the feedback and trial-number schedule of one randomly chosen participant.

    Parameters
    ----------
    ids_path : str, default 'uniqueIDs.pkl'
        Pickle file holding the list of participant IDs. Only load trusted
        files: unpickling can execute arbitrary code.
    data_folder : str, default 'data/'
        Folder containing ``user_<ID>/df2_cf.pkl`` for each participant.

    Returns
    -------
    fbs_all_cues, trialNo_all_cues : dict[str, list]
        Per-cue feedbacks and trial numbers from ``extract_hits_fbs``; the
        observed hits are discarded.
    """
    # Load IDs
    with open(ids_path, 'rb') as f:
        uniqueIDs = pickle.load(f)

    # Pick one participant (no need to shuffle the whole list)
    ID = random.choice(uniqueIDs)
    df2_cf = pd.read_pickle(data_folder + 'user_' + ID + '/df2_cf.pkl')
    _, fbs_all_cues, trialNo_all_cues = extract_hits_fbs(df2_cf)

    return fbs_all_cues, trialNo_all_cues

In [66]:
def sim_fit_model(mod_info, param_lower_bounds, param_upper_bounds, Nsim, all_users_folder, n_iterations=10):
    """Parameter recovery: simulate behaviour with random parameters, then refit it.

    For each of ``Nsim`` simulations, a random participant's feedback
    schedule is loaded and parameters are drawn uniformly within the
    bounds. Hits are simulated with those parameters and the model is refitted.

    The simulated and fitted parameters are saved (via
    ``reformat_param_lists``) to ``<all_users_folder>sim_refit/<name>.pkl``.

    Parameters
    ----------
    mod_info : dict
        Keys 'name', 'value_fct', 'dec_fct', 'param_names'.
    param_lower_bounds, param_upper_bounds : sequence of float
        Bounds for both generation and fitting.
    Nsim : int
        Number of simulate-and-refit runs.
    all_users_folder : str
        Output folder prefix.
    n_iterations : int, default 10
        Random restarts per fit, passed to ``Model.fit``.

    Returns
    -------
    Model
        The model object, left in the state of the last simulation.
    """
    mod = Model(mod_name = mod_info['name'],
                value_fct = mod_info['value_fct'],
                dec_fct = mod_info['dec_fct'],
                param_names = mod_info['param_names'])
    print(repr(mod))

    all_sim_param_values = []
    all_fit_param_values = []
    bounds = [[low, up] for low, up in zip(param_lower_bounds, param_upper_bounds)]

    # Progress is printed about every quarter of the run (never for the first 5);
    # max(1, ...) keeps the modulo well defined for tiny Nsim.
    report_every = max(1, int(round(Nsim / 4)))

    for sim_id in range(Nsim):

        if sim_id > 4 and sim_id % report_every == 0:
            print(sim_id)

        # extract random dataset and set
        fbs_all_cues, trialNo_all_cues = extract_rand_dataset()
        mod.set_dataset(fbs_all_cues, trialNo_all_cues)

        # Generating parameters drawn as scalars (not 1-element arrays) so the
        # model functions receive plain numbers.
        generating_param_values = np.array([np.random.uniform(low, up) for low, up in bounds])
        mod.set_param_values(generating_param_values)

        # simulate behaviour, then refit it
        mod.run_model(sim_behav=1)
        best_params, _ = mod.fit(param_lower_bounds, param_upper_bounds, n_iterations = n_iterations)

        all_sim_param_values.append(generating_param_values)
        all_fit_param_values.append(best_params)

    # Reformat
    df = reformat_param_lists(all_sim_param_values, all_fit_param_values, mod_info)

    # Save
    filename = all_users_folder + 'sim_refit/' + mod_info['name'] + '.pkl'
    df.to_pickle(filename)
    print('DONE. Saved to: ', filename.strip())

    return mod

## Value function

In [73]:
def rescorla_wagner_2LR_1t(vt, isHit, fb, param_names, param_values):
    """Single-trial Rescorla-Wagner update with separate learning rates by feedback sign.

    Both 'alpha_rew' and 'alpha_pun' are required parameters; a missing one
    raises ValueError from ``list.index``.

    * ``isHit == 0``: value unchanged, prediction error NaN.
    * ``isHit == 1``: ``pe = fb - vt``; the value moves by ``alpha_rew * pe``
      when ``fb > 0``, by ``alpha_pun * pe`` when ``fb < 0``, and not at all
      otherwise (``fb == 0``).
    * any other ``isHit`` (e.g. NaN): value unchanged, ``pe = fb - vt``.

    NOTE(review): the fb == 0 and non-0/1 isHit cases report a prediction
    error without applying it -- confirm this is intended.

    Returns
    -------
    (vt, pe)
        Updated value and prediction error.
    """
    alpha_rew = param_values[param_names.index('alpha_rew')]
    alpha_pun = param_values[param_names.index('alpha_pun')]

    prediction_error = fb - vt

    if isHit == 0:
        return vt, np.nan
    if isHit != 1:
        return vt, prediction_error

    if fb > 0:
        learning_rate = alpha_rew
    elif fb < 0:
        learning_rate = alpha_pun
    else:
        # neither rewarded nor punished: no update
        return vt, prediction_error

    return vt + learning_rate * prediction_error, prediction_error

In [67]:
def rescorla_wagner_1t(vt, isHit, fb, param_names, param_values):
    """Single-trial Rescorla-Wagner value update, applied on hit trials only.

    Parameters
    ----------
    vt : float
        Current value.
    isHit : number
        1 for a hit; any other value (0, NaN, ...) means no update.
    fb : float
        Feedback received on the trial.
    param_names, param_values : sequences
        Free parameters; 'alpha' (learning rate) defaults to 0.1 when absent.

    Returns
    -------
    (vt, pe)
        On a hit, ``pe = fb - vt`` and ``vt + alpha * pe``; otherwise the value
        is unchanged and ``pe`` is NaN.
    """
    # Learning rate: fixed at 0.1 unless fitted as a free parameter
    alpha = 0.1 if 'alpha' not in param_names else param_values[param_names.index('alpha')]

    # Default for non-hit trials; previously an isHit other than 0/1 (e.g. NaN)
    # left `pe` unassigned and raised UnboundLocalError at the return.
    pe = np.nan
    if isHit == 1:
        pe = fb - vt
        vt = vt + alpha * pe

    return vt, pe

## Decision function

In [71]:
def my_softmax_1t(vt, trial, param_names, param_values):
    """Probability of a hit from a logistic (two-option softmax) choice rule.

    The hit option is worth ``vt + pi + pi_t * shrink`` and the no-hit option
    0. ``shrink = (56 - trial_in_block) / 56`` decays the bias ``pi_t`` over
    each 56-trial block (trials after 56 are re-counted from 1).

    Optional free parameters: 'beta' (inverse temperature, default 1),
    'pi' (constant press bias, default 0) and 'pi_t' (decaying press bias,
    default 0).
    """
    def _param(name, default):
        # value of a free parameter, or its fixed default when not fitted
        return param_values[param_names.index(name)] if name in param_names else default

    beta = _param('beta', 1)
    pi = _param('pi', 0)
    pi_t = _param('pi_t', 0)

    # Trials after 56 belong to the second block: count them from 1 again
    trial_in_block = trial - 56 if trial > 56 else trial
    shrink = (56 - trial_in_block) / 56

    hit_value = vt + pi + pi_t * shrink
    return 1 / (1 + np.exp(beta * (0 - hit_value)))

In [93]:
def my_softmax_1t_2pit(vt, trial, param_names, param_values):
    """Probability of a hit with a separate decaying press bias for each 56-trial block.

    The hit option is worth ``vt + pi_t * shrink`` and the no-hit option 0,
    where ``pi_t`` is 'pi_t_1' on trials 1-56 and 'pi_t_2' after trial 56,
    and ``shrink = (56 - trial_in_block) / 56``.

    Optional free parameters: 'beta' (inverse temperature, default 1),
    'pi_t_1' and 'pi_t_2' (default 0 each).
    """
    def _param(name, default):
        # value of a free parameter, or its fixed default when not fitted
        return param_values[param_names.index(name)] if name in param_names else default

    beta = _param('beta', 1)
    pi_t_1 = _param('pi_t_1', 0)
    pi_t_2 = _param('pi_t_2', 0)

    # Second block starts after trial 56: re-count from 1 and use its own bias
    in_second_block = trial > 56
    trial_in_block = trial - 56 if in_second_block else trial
    pi_t = pi_t_2 if in_second_block else pi_t_1
    shrink = (56 - trial_in_block) / 56

    hit_value = vt + pi_t * shrink
    return 1 / (1 + np.exp(beta * (0 - hit_value)))