In [13]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.optimize import minimize
import seaborn as sns
import ptitprince as pt # rainplot
from scipy.stats import pearsonr, bernoulli

# Data cleaning

In [10]:
def extract_hits_fbs(df, cues=('HR', 'LR', 'HP', 'LP')):
    """Split the trial table into per-cue lists of hits, feedbacks and trial numbers.

    Rows belonging to a cue are those whose ``'Code'`` column equals
    ``'Cue_' + cue``. Row order within each cue follows the order in ``df``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``'Code'``, ``'isHit'``, ``'FBs'`` and
        ``'Trial'``. It is not modified.
    cues : iterable of str, optional
        Cue labels to extract. Defaults to the four cues
        ``('HR', 'LR', 'HP', 'LP')``.

    Returns
    -------
    isHit_all_cues, fbs_all_cues, trialNo_all_cues : dict
        Three dictionaries keyed by cue label, each mapping to a plain list
        with the values of ``'isHit'``, ``'FBs'`` and ``'Trial'`` respectively.
        A cue with no matching row maps to an empty list.
    """
    isHit_all_cues = {}
    fbs_all_cues = {}
    trialNo_all_cues = {}

    for cue in cues:
        # Select the rows of this cue. No helper column is added to the
        # selection: the former per-cue trial counter was never read and made
        # the function fail on frames that already had a 'CueTrial' column.
        cue_rows = df.loc[df['Code'] == ('Cue_' + cue)]

        # Store everything in dictionaries
        isHit_all_cues[cue] = cue_rows['isHit'].tolist()
        fbs_all_cues[cue] = cue_rows['FBs'].tolist()
        trialNo_all_cues[cue] = cue_rows['Trial'].tolist()

    return isHit_all_cues, fbs_all_cues, trialNo_all_cues

# Value functions

In [8]:
def rescorla_wagner_noV0_1trial(vt_m_1, isHit, fb, param_names, param_values):
    """One-trial Rescorla-Wagner update of the expected value.

    Parameters
    ----------
    vt_m_1 : float
        Expected value before the update.
    isHit : number
        1 when a response was made (feedback was received); any other value
        (0, NaN, ...) is treated as "no feedback received".
    fb : float
        Feedback obtained on the trial; only used when ``isHit == 1``.
    param_names : list of str
        Parameter names; must contain ``'alpha'``.
    param_values : sequence of float
        Parameter values, aligned with ``param_names``.

    Returns
    -------
    vt : float
        Updated value: ``vt_m_1 + alpha * (fb - vt_m_1)`` after a hit,
        otherwise ``vt_m_1`` unchanged.
    pe : float
        Prediction error ``fb - vt_m_1`` after a hit, ``np.nan`` otherwise.
    """
    # Free parameters
    alpha = param_values[param_names.index('alpha')]

    # Default outcome: no feedback, so the value is carried forward and there
    # is no prediction error. Starting from this default keeps `vt` defined
    # even when isHit is neither 0 nor 1 (e.g. NaN for a missing trial).
    vt = vt_m_1
    pe = np.nan

    # If hit, feedback is received and the value moves towards it
    if isHit == 1:
        pe = fb - vt_m_1
        vt = vt_m_1 + alpha * pe

    return vt, pe

# Decision functions

In [28]:
def my_softmax_nobeta_1trial(vt, param_names, param_values):
    """Probability of a hit as the logistic function of the value, without temperature.

    Parameters
    ----------
    vt : float or array-like
        Expected value(s).
    param_names, param_values :
        Unused; kept so that all decision functions share one signature.

    Returns
    -------
    p_hit : float or numpy.ndarray
        ``exp(vt) / (exp(vt) + 1)``.
    """
    # Local import so the function does not rely on an extra top-level import.
    from scipy.special import expit

    # expit is the logistic sigmoid evaluated without overflow; the naive
    # exp(x) / (exp(x) + 1) becomes inf / inf = nan for large positive x.
    p_hit = expit(vt)

    return p_hit

In [29]:
def my_softmax_1trial(vt, param_names, param_values):
    """Probability of a hit as the logistic function of ``beta * vt``.

    Parameters
    ----------
    vt : float or array-like
        Expected value(s).
    param_names : list of str
        Parameter names; must contain ``'beta'``.
    param_values : sequence of float
        Parameter values, aligned with ``param_names``.

    Returns
    -------
    p_hit : float or numpy.ndarray
        ``exp(beta * vt) / (exp(beta * vt) + 1)``.
    """
    # Local import so the function does not rely on an extra top-level import.
    from scipy.special import expit

    # Free parameters
    beta = param_values[param_names.index('beta')]

    # expit is the logistic sigmoid evaluated without overflow; the naive
    # exp(x) / (exp(x) + 1) becomes inf / inf = nan for large positive x.
    p_hit = expit(beta * vt)

    return p_hit

# Model

In [45]:
class Model:
    """Trial-by-trial learning model: a value function plus a decision function.

    ``value_fct(v_prev, isHit_prev, fb_prev, param_names, param_values)`` must
    return ``(v, pe)``; ``dec_fct(v, param_names, param_values)`` must return
    the probability of a hit. Data are held per cue in dictionaries keyed by
    cue label (feedbacks, hits and trial numbers share the same keys).
    """

    def __init__(self, mod_name, value_fct, dec_fct, param_names):
        """Store the model definition and initialise all result containers."""
        self.value_fct = value_fct
        self.dec_fct = dec_fct
        self.param_names = param_names
        self.mod_name = mod_name
        self.param_values = []
        self.gen_param_values = []
        self.values = []
        self.v = []
        self.p_hit = []
        self.cue_nLLs = []
        self.total_cue_nLL = []
        self.fbs_all_cues = []
        self.isHit_all_cues = []
        self.trialNo_all_cues = []
        self.nLL = []
        self.ID = []
        self.Ntrials = []
        self.PEs = []
        self.shrink_pi = []
        self.shrink_alpha = []

    def set_param_values(self, param_values):
        """Set the current parameter values (aligned with ``self.param_names``)."""
        self.param_values = param_values

    def set_data(self, ID, fbs_all_cues, isHit_all_cues, trialNo_all_cues):
        """Attach one data set: an identifier and three per-cue dictionaries."""
        self.ID = ID
        self.fbs_all_cues = fbs_all_cues
        self.isHit_all_cues = isHit_all_cues
        self.trialNo_all_cues = trialNo_all_cues

    def compute_ll_per_cue(self, p_hit_all_cues, isHit_all_cues):
        """Negative log-likelihood of the observed choices, cue by cue.

        Parameters
        ----------
        p_hit_all_cues : dict
            Cue -> sequence of predicted hit probabilities, one per trial.
        isHit_all_cues : dict
            Cue -> sequence of observed choices (1 = hit, anything else = no hit).

        Returns
        -------
        total_nLL_all_cues : dict
            Cue -> summed negative log-likelihood.
        nLLs_all_cues : dict
            Cue -> array of per-trial negative log-likelihoods.
        Ntrials_per_cue : dict
            Cue -> number of trials.
        """
        # Smallest positive normal float: floor for probabilities so that
        # log() never sees an exact zero.
        almost_zero = np.finfo(float).tiny

        nLLs_all_cues = {}
        total_nLL_all_cues = {}
        Ntrials_per_cue = {}

        for cue, p_hit in p_hit_all_cues.items():
            hits = np.asarray(isHit_all_cues[cue])
            # Only the trials that have an observed choice are scored
            p_hit = np.asarray(p_hit, dtype=float)[:len(hits)]

            # Likelihood of the observed choice: p(hit) on hits, 1 - p(hit) otherwise
            p_choice = np.where(hits == 1, p_hit, 1 - p_hit)
            p_choice = np.where(p_choice < almost_zero, almost_zero, p_choice)

            nLLs = -np.log(p_choice)
            nLLs_all_cues[cue] = nLLs
            total_nLL_all_cues[cue] = nLLs.sum()
            Ntrials_per_cue[cue] = len(nLLs)

        return total_nLL_all_cues, nLLs_all_cues, Ntrials_per_cue

    def fit(self, param_lower_bound, param_upper_bound, n_iterations, method='Powell'):
        """Fit the parameters by minimising the negative log-likelihood.

        The optimiser is restarted ``n_iterations`` times from starting points
        drawn uniformly between the bounds (using ``np.random.rand``); the
        restart with the lowest objective wins. ``compute_nLL`` is evaluated
        one last time at the winning parameters, so ``self.param_values``,
        ``self.nLL`` and the per-cue attributes describe the best fit.

        Parameters
        ----------
        param_lower_bound, param_upper_bound : sequence of float
            Bounds, aligned with ``self.param_names``.
        n_iterations : int
            Number of random restarts (at least 1).
        method : str, optional
            Passed to ``scipy.optimize.minimize``.

        Returns
        -------
        best_params : numpy.ndarray
            Parameter values of the best restart.
        nLL : float
            Negative log-likelihood at ``best_params``.

        Raises
        ------
        ValueError
            If ``n_iterations`` is smaller than 1.
        """
        if n_iterations < 1:
            raise ValueError("n_iterations must be at least 1")

        lower = np.asarray(param_lower_bound, dtype=float)
        upper = np.asarray(param_upper_bound, dtype=float)

        # Sequence of (low, up) pairs, one per parameter
        bounds = list(zip(lower, upper))

        mat_min_nLL = []
        mat_best_params = []

        for _ in range(n_iterations):
            # Starting point: a random sample from the domain
            initial_guess = lower + np.random.rand(len(lower)) * (upper - lower)

            result = minimize(self.compute_nLL, initial_guess, method=method, bounds=bounds,
                              options={'xtol': 1e-8, 'disp': False})

            mat_min_nLL.append(result.fun)
            # atleast_1d: keep the parameter vector indexable even for one parameter
            mat_best_params.append(np.atleast_1d(result.x))

        # Best restart
        ind = int(np.argmin(mat_min_nLL))
        best_params = mat_best_params[ind]

        # Re-evaluate at the optimum so that the stored attributes match it
        nLL = self.compute_nLL(best_params)

        return best_params, nLL

    def simulate_behaviour(self, fbs_all_cues, trialNo_all_cues, param_values, param_names):
        """Generate choices from the model and store them as the model's data.

        For every cue the first trial is a no-hit with value ``v0`` (0 unless
        ``'v0'`` is among ``param_names``). Every later trial updates the value
        with ``value_fct`` and yields a hit when ``dec_fct`` returns a
        probability of at least 0.5.

        NOTE(review): choices are thresholded at 0.5, not sampled from the
        probability, so the simulated behaviour is deterministic.

        Parameters
        ----------
        fbs_all_cues : dict
            Cue -> sequence of feedbacks.
        trialNo_all_cues : dict
            Cue -> sequence of trial numbers (stored unchanged).
        param_values : sequence of float
            Generative parameter values; also kept in ``self.gen_param_values``.
        param_names : list of str
            Names aligned with ``param_values``.

        Returns
        -------
        None
            The result is stored through ``set_data`` with an empty ID;
            ``self.isHit_all_cues`` maps each cue to a list of 0/1 choices.
        """
        # Store parameter values used for behaviour generation
        self.gen_param_values = param_values

        # Starting value; the names passed in are the ones aligned with param_values
        if 'v0' in param_names:
            v0 = param_values[param_names.index('v0')]
        else:
            v0 = 0

        isHit_all_cues = {}

        for cue, fbs in fbs_all_cues.items():
            n_trials = len(fbs)
            vts = np.full(n_trials, np.nan)
            hits = np.full(n_trials, np.nan)

            # Prior: start at v0 with no response on the first trial
            if n_trials > 0:
                vts[0] = v0
                hits[0] = 0

            for t in range(1, n_trials):
                # Value entering trial t, given what happened on trial t-1
                vt, _pe = self.value_fct(vts[t-1], hits[t-1], fbs[t-1], param_names, param_values)
                p_hit = self.dec_fct(vt, param_names, param_values)
                vts[t] = vt
                hits[t] = 1 if p_hit >= 0.5 else 0

            # Keep the whole choice sequence of the cue, as a list of ints
            isHit_all_cues[cue] = hits.astype(int).tolist()

        self.set_data('', fbs_all_cues, isHit_all_cues, trialNo_all_cues)

    def compute_nLL(self, param_values):
        """Negative log-likelihood of the stored data under ``param_values``.

        Runs the value and decision functions over every cue, scores the
        observed choices and stores the intermediate results on the instance
        (``v``/``values``, ``p_hit``, ``PEs``, ``cue_nLLs``, ``total_cue_nLL``,
        ``nLL``, ``Ntrials``). ``self.param_values`` is set to ``param_values``.

        Parameters
        ----------
        param_values : sequence of float
            Parameter values, aligned with ``self.param_names``.

        Returns
        -------
        nLL : float
            Negative log-likelihood summed over all trials of all cues.
        """
        fbs_all_cues = self.fbs_all_cues
        isHit_all_cues = self.isHit_all_cues

        self.set_param_values(param_values)

        # Starting value: free parameter if the model has one, else 0
        if 'v0' in self.param_names:
            v0 = param_values[self.param_names.index('v0')]
        else:
            v0 = 0

        p_hit_all_cues = {}
        vt_all_cues = {}
        pe_all_cues = {}

        for cue, fbs in fbs_all_cues.items():
            hits = isHit_all_cues[cue]
            n_trials = len(fbs)

            vts = np.full(n_trials, np.nan)
            pes = np.full(n_trials, np.nan)
            p_hits = np.full(n_trials, np.nan)

            if n_trials > 0:
                vts[0] = v0
                # The first choice is predicted from the starting value; leaving
                # this entry as NaN would turn the whole likelihood into NaN.
                p_hits[0] = self.dec_fct(vts[0], self.param_names, self.param_values)

            for t in range(1, n_trials):
                # Value entering trial t, given what happened on trial t-1
                vt, pe = self.value_fct(vts[t-1], hits[t-1], fbs[t-1], self.param_names, self.param_values)
                vts[t] = vt
                pes[t] = pe
                p_hits[t] = self.dec_fct(vt, self.param_names, self.param_values)

            p_hit_all_cues[cue] = p_hits
            vt_all_cues[cue] = vts
            pe_all_cues[cue] = pes

        # nLL per cue
        total_nLL_all_cues, nLLs_all_cues, Ntrials_per_cue = self.compute_ll_per_cue(p_hit_all_cues, isHit_all_cues)

        # Total negative LL (sum over cues, i.e. product of likelihoods)
        nLL = sum(total_nLL_all_cues.values())

        # Total number of trials
        Ntrials = sum(Ntrials_per_cue.values())

        # Store results
        self.v = vt_all_cues
        self.values = vt_all_cues
        self.p_hit = p_hit_all_cues
        self.cue_nLLs = nLLs_all_cues
        self.total_cue_nLL = total_nLL_all_cues
        self.nLL = nLL
        self.Ntrials = Ntrials
        self.PEs = pe_all_cues

        return nLL