In [1]:
from scipy.optimize import minimize
from load_data import load_clean_data
import numpy as np
import matplotlib.pyplot as plt

In [68]:
def optimize(fname, 
             bounds, 
             Data, 
             niter,
             toplot=False,
            ):
             


    outcomes = np.full([niter, len(bounds)+1], np.nan)
    optimcurve = np.full(niter, np.nan)
    for i in range(niter):
        
        # random starting point based on maximum bounds
        params0 = np.array([bound[1] * np.random.rand() for bound in bounds])
            
        # compute the function value at the starting point
        llh0 = fname(params0, Data)
        
        # run the optimizer with constraints    
        result = minimize(fun=fname, x0=params0, args=Data, tol=1e-4, bounds=bounds)
        x = result.x
        bestllh = fname(x, Data)
        outcomes[i, :] = [bestllh] + [xi for xi in x]    
        optimcurve[i] = min(outcomes[:(i+1), 0])
        print(f'{(i/niter)*100}%', end='\r')

    # find the global minimum out of all outcomes
    i = np.argwhere(outcomes[:, 0] == np.min(outcomes[:, 0]))
    bestparameters = outcomes[i[0], 1:].flatten()
    bestllh = -outcomes[i[0], 0].flatten()[0]
    
    # plot the best llh found by the optimizer as a function of iteration number.
    if toplot:
        plt.figure()
        plt.plot(range(niter), np.round(optimcurve, 6), 'o-')
        plt.xlabel('iteration')
        plt.ylabel('best minimum')
        #plt.title(str(nblocks) + ' blocks')
    
    return(bestparameters, bestllh)

In [69]:
""" 

Args:

Alpha (float): learning rate bounded from 0 to 1
K (int): working memory capacity
prior (): working memory prior weight
beta (int): inverse temperature average value is 5 . Fixed to 100
epsilon (float): noise bounded from 0 to 1
phi (float): decay bounded from 0 to 1
pers (float): neglect of negative values bounded from 0 to 1
T (int): number of trials per block
actions (array): all actions chosen per block
reward (array): all rewards recieved per block

"""

def likelihood_RLWM(data, K, alpha, prior, epsilon, phi, pers):
    """This function calculate the probability 
    """
    choiceProb = []
    for block in data:
        actions, rewards, stimulus, set_size = block
        set_size = set_size[0]
        beta = 100
        action_options = 3
        T = len(actions)
        Wwm = prior*(np.min([1, K/set_size]))
        neg_alpha = (1-pers)*alpha
        Q = np.ones((set_size, action_options))/action_options
        W = np.ones((set_size, action_options))/action_options
        W0 = W.copy()
   

        for a, r, s in zip(actions, rewards, stimulus):


            W = W + phi*(W0-W)



            Prl = np.exp(beta * Q[s, :])
            Prl = Prl/np.sum(Prl)

            Pwm = np.exp(beta * W[s, :])
            Pwm = Pwm/np.sum(Pwm)

            Pall = Wwm*Pwm + (1-Wwm)*Prl
            Pall = (1-epsilon)*Pall + epsilon*(1/action_options)

            choiceProb.append(Pall[a])

            if r==0:
                Q[s,a] = Q[s,a] + neg_alpha*(r-Q[s,a])

            else:
                Q[s, a] = Q[s, a] + alpha*(r-Q[s, a])


            W[s,a] = r
    NegLL = -np.sum(np.log(choiceProb))
    return NegLL


In [70]:
data = load_clean_data()


10.21 percent trials removed


In [91]:
all_subj = dict()
for subj in np.unique(data['subj']):
    print(subj)
    subj_idx = np.where(data['subj']==subj)[0]
    test_blocks = np.unique([data['block'][i] for i in subj_idx])
    alldata = []
    for k in test_blocks:
        block_idx = np.where(data['block']==k)[0]
        rew = np.array([i for j, i in enumerate(data['correct']) if j in subj_idx and j in block_idx])
        act = np.array([i for j, i in enumerate(data['action']) if j in subj_idx and j in block_idx])
        stim = np.array([i-1 for j, i in enumerate(data['stimulus']) if j in subj_idx and j in block_idx])
        setsize = np.array([i for j, i in enumerate(data['set_size']) if j in subj_idx and j in block_idx])
        indivdata = np.array([act, rew, stim, setsize])
        alldata.append(indivdata)

    set_size_capac=dict()
    for K in [2, 3, 4, 5, 6]:
        fun = lambda x, Data : likelihood_RLWM(Data, K, x[0], x[1], x[2], x[3], x[4])
        bnds = ((0, 1), (0, 1), (0,1), (0, 1), (0,1))
        test= optimize(fun, bnds, alldata, niter=40, toplot=False)
        set_size_capac[K] = test
    all_subj[subj]= min(set_size_capac.items(), key=lambda x: x[1][-1]) 

1
27.5%999999999999%%
37.5%999999999999%%
47.5%999999999999%%
57.5%999999999999%%
67.5%999999999999%%
77.5%999999999999%%
87.5%999999999999%%
97.5%999999999999%%
10.5%999999999999%%
11.5%999999999999%%
12.5%999999999999%%
13.5%999999999999%%
14.5%999999999999%%
15.5%999999999999%%
16.5%999999999999%%
17.5%999999999999%%
18.5%999999999999%%
19.5%999999999999%%
20.5%999999999999%%
21.5%999999999999%%
22.5%999999999999%%
23.5%999999999999%%
24.5%999999999999%%
25.5%999999999999%%
26.5%999999999999%%
27.5%999999999999%%
28.5%999999999999%%
29.5%999999999999%%
30.5%999999999999%%
31.5%999999999999%%
32.5%999999999999%%
33.5%999999999999%%
34.5%999999999999%%
35.5%999999999999%%
36.5%999999999999%%
37.5%999999999999%%
38.5%999999999999%%
39.5%999999999999%%
40.5%999999999999%%
41.5%999999999999%%
42.5%999999999999%%
43.5%999999999999%%
44.5%999999999999%%
45.5%999999999999%%
46.5%999999999999%%
47.5%999999999999%%
48.5%999999999999%%
49.5%999999999999%%
50.5%999999999999%%
51.5%999999999999%

In [92]:
import os
os.system('say "your program has finished"')

0

In [93]:
all_subj

{1: (3,
  (array([0.59467532, 0.77790786, 0.        , 0.17304316, 0.46342024]),
   -5.493061443340547)),
 2: (6,
  (array([0.04133561, 0.17077175, 0.09496387, 0.40270581, 0.61583219]),
   -298.3505971597374)),
 3: (5,
  (array([0.01853035, 0.50136906, 0.01055701, 0.25991959, 1.        ]),
   -219.18323046076154)),
 4: (6,
  (array([0.01043966, 0.27978401, 0.01236253, 0.35656031, 1.        ]),
   -380.84435011575107)),
 5: (6,
  (array([0.02378643, 0.11890255, 0.10801104, 0.07936705, 0.8196913 ]),
   -323.0503584726911)),
 6: (6,
  (array([0.01563116, 0.39247313, 0.03389419, 0.35821868, 1.        ]),
   -308.48582469549376)),
 7: (4,
  (array([0.01703591, 0.24011961, 0.03844884, 0.62937204, 1.        ]),
   -380.78822461057393)),
 8: (6,
  (array([0.00885617, 0.23020278, 0.01034611, 0.09403739, 0.99738388]),
   -373.0912211341481)),
 9: (2,
  (array([0.00721199, 0.76971808, 0.02469827, 0.29880052, 1.        ]),
   -371.04980183086)),
 10: (6,
  (array([0.00880776, 0.19683101, 0.19253825

In [96]:
import pickle

In [97]:
# with open('subj_fit.p', 'wb') as fp:
#     pickle.dump(all_subj, fp, protocol=pickle.HIGHEST_PROTOCOL)

In [98]:
with open('subj_fit.p', 'rb') as fp:
    test = pickle.load(fp)


In [99]:
test

{1: (3,
  (array([0.59467532, 0.77790786, 0.        , 0.17304316, 0.46342024]),
   -5.493061443340547)),
 2: (6,
  (array([0.04133561, 0.17077175, 0.09496387, 0.40270581, 0.61583219]),
   -298.3505971597374)),
 3: (5,
  (array([0.01853035, 0.50136906, 0.01055701, 0.25991959, 1.        ]),
   -219.18323046076154)),
 4: (6,
  (array([0.01043966, 0.27978401, 0.01236253, 0.35656031, 1.        ]),
   -380.84435011575107)),
 5: (6,
  (array([0.02378643, 0.11890255, 0.10801104, 0.07936705, 0.8196913 ]),
   -323.0503584726911)),
 6: (6,
  (array([0.01563116, 0.39247313, 0.03389419, 0.35821868, 1.        ]),
   -308.48582469549376)),
 7: (4,
  (array([0.01703591, 0.24011961, 0.03844884, 0.62937204, 1.        ]),
   -380.78822461057393)),
 8: (6,
  (array([0.00885617, 0.23020278, 0.01034611, 0.09403739, 0.99738388]),
   -373.0912211341481)),
 9: (2,
  (array([0.00721199, 0.76971808, 0.02469827, 0.29880052, 1.        ]),
   -371.04980183086)),
 10: (6,
  (array([0.00880776, 0.19683101, 0.19253825