In [1]:
from matplotlib import pyplot as plt
from scipy.special import softmax
from itertools import product
import numpy as np
from utils import *
import time
import seaborn as sns
sns.set_context('talk')

%load_ext autoreload
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
## human learning curves used as the fitting target:
## hB = blocked-condition means, hI = interleaved-condition means (one value per trial)
import pandas as pd
hdf = pd.read_csv('../human_data.csv')
hB,hI = hdf[['blocked mean','interleaved mean']].to_numpy().T

In [3]:
### RUN EXP
def run_batch_exp(ns,args):
  """
  Simulate `ns` independent subjects under a single parameter setting.

  args : dict with
    'sch' -> passed to SEM(...)
    'exp' -> keyword arguments for Task().generate_experiment(...)
  returns: list of length `ns`; each entry is the dict returned by
    SEM.run_exp, with the generated experiment stored under key 'exp'.
  """
  def simulate_subject():
    # construction order (Task, then SEM) kept as-is in case either draws randomness
    task = Task()
    model = SEM(args['sch'])
    experiment, _ = task.generate_experiment(**args['exp'])
    result = model.run_exp(experiment)
    result['exp'] = experiment
    return result

  return [simulate_subject() for _ in range(ns)]

In [4]:
def get_sm(xth, normalize=False):
  """
  Extract one subject's outputs for the two candidate nodes at layers 2 and 3.

  xth : array indexed [trial, layer, node] (x_t_hat from one subject).
  normalize : if True, rescale each trial's pair of values to sum to 1
    (2AFC normalization). The default (False) returns the raw values,
    which is what this notebook has been using.

  returns: array of shape [layer, trial, candidate], where index 0 is
    layer 2 (nodes 5, 6) and index 1 is layer 3 (nodes 7, 8).
  """
  candidates = {2:(5,6), 3:(7,8)}
  per_layer = []
  for layer, node_pair in candidates.items():
    # fancy indexing returns a copy, so xth is never modified
    y = xth[:, layer, list(node_pair)]
    if normalize:
      # NOTE: yields NaN for a trial where both candidate values are 0
      y = y / y.sum(-1, keepdims=True)
    per_layer.append(y)
  return np.array(per_layer)

def get_acc(data):
  """
  Value the model assigned to the node that actually occurred, for the
  layer-2 and layer-3 transitions of every trial (read from get_sm).

  data : dict with
    'xth' -> [trial, layer, node] model outputs
    'exp' -> experiment array; columns 3 and 4 hold the observed
             layer-2 / layer-3 nodes (presumably 5/6 and 7/8 -- TODO confirm)
  returns: array of shape [layer, trial] (index 0 = layer 2, 1 = layer 3).
  """
  choice_vals = get_sm(data['xth'])
  experiment = data['exp']
  trial_idx = np.arange(choice_vals.shape[1])
  # observed node minus the first candidate's id -> 0/1 column in the pair
  return np.array([
    choice_vals[k][trial_idx, experiment[:, k+3] - (5 + 2*k)]
    for k in range(2)
  ])

In [5]:
# Parameter grid for the search. np.linspace with explicit counts replaces
# float-step np.arange, whose length depends on floating-point rounding of
# (stop - start) / step (the reason for the old 1.151 / 0.251 stop values).
alfaL = np.linspace(0.85, 1.15, 7)    # 'concentration'
betaL = np.linspace(0.001, 0.201, 5)  # used for both 'stickiness_wi' and 'stickiness_bt'
lmdaL = np.linspace(0.01, 0.185, 8)   # 'sparsity'

param_setsize = len(alfaL)*len(betaL)*len(lmdaL)
print(param_setsize)
print(alfaL,betaL,lmdaL)

expargs = {
  'condition':'blocked',
  'n_train':160,
  'n_test':40
}
ns = 5  # simulated subjects per parameter setting

280
[0.85 0.9  0.95 1.   1.05 1.1  1.15] [0.001 0.051 0.101 0.151 0.201] [0.01  0.035 0.06  0.085 0.11  0.135 0.16  0.185]


In [None]:
# Grid search: simulate `ns` subjects per parameter combination and score each
# combination by (a) MSE between the subject-averaged model curve and the human
# blocked-condition curve hB, and (b) mean accuracy on the final test trials.
L = []
for idx,(alfa,beta,lmda) in enumerate(product(alfaL,betaL,lmdaL)):
  print(idx/param_setsize,alfa,beta,lmda)
  schargs = {
    'concentration':alfa,
    'stickiness_wi':beta,  # one stickiness value is used for both
    'stickiness_bt':beta,
    'sparsity':lmda
  }
  args = {
    'sch':schargs,
    'exp':expargs
  }
  dataL = run_batch_exp(ns,args)
  acc = np.array([get_acc(data) for data in dataL]) # sub,layer,trial
  acc = acc.mean(1) # mean over layers -> sub,trial
  # Test accuracy over ALL simulated subjects. The old `acc[(2,3),-40:]` indexed
  # axis 0, which is subjects after the layer mean, so it averaged only subjects 2 and 3.
  test_acc = acc[:,-expargs['n_test']:].mean()
  mB = acc.mean(0) # mean over subjects -> trial
  # guard against a length mismatch being silently broadcast into the MSE
  assert mB.shape == hB.shape, (mB.shape, hB.shape)
  mse = np.mean((mB - hB)**2)
  ## record
  L.append({**schargs,'mse':mse,'testacc':test_acc})


0.0 0.85 0.001 0.01
0.0035714285714285713 0.85 0.001 0.035
0.007142857142857143 0.85 0.001 0.060000000000000005
0.010714285714285714 0.85 0.001 0.085
0.014285714285714285 0.85 0.001 0.11
0.017857142857142856 0.85 0.001 0.135
0.02142857142857143 0.85 0.001 0.16000000000000003
0.025 0.85 0.001 0.18500000000000003
0.02857142857142857 0.85 0.051000000000000004 0.01
0.03214285714285714 0.85 0.051000000000000004 0.035
0.03571428571428571 0.85 0.051000000000000004 0.060000000000000005
0.039285714285714285 0.85 0.051000000000000004 0.085
0.04285714285714286 0.85 0.051000000000000004 0.11
0.04642857142857143 0.85 0.051000000000000004 0.135
0.05 0.85 0.051000000000000004 0.16000000000000003
0.05357142857142857 0.85 0.051000000000000004 0.18500000000000003
0.05714285714285714 0.85 0.101 0.01
0.060714285714285714 0.85 0.101 0.035
0.06428571428571428 0.85 0.101 0.060000000000000005
0.06785714285714285 0.85 0.101 0.085
0.07142857142857142 0.85 0.101 0.11
0.075 0.85 0.101 0.135
0.07857142857142857 0.

In [None]:
# Rank parameter sets by fit to the human curve (lowest MSE first), breaking
# ties by higher test accuracy. This needs one multi-key sort. Sorting by
# 'testacc' and then re-sorting by 'mse' with the default (non-stable)
# quicksort does not guarantee the first ordering survives among ties.
gsdf = pd.DataFrame(L).sort_values(['mse','testacc'],ascending=[True,False])
best_param = gsdf.iloc[0]

# Best fit: re-simulate the lowest-MSE parameter set and compare it with the human blocked-condition curve

In [None]:
# Re-simulate the best-fitting parameter set and overlay its subject-averaged
# curve on the human blocked-condition data.
print(best_param)
# Select by label, not position: best_param[:4] silently depends on the
# column order of the grid-search records.
schargs_prime = {
  key: best_param[key]
  for key in ('concentration','stickiness_wi','stickiness_bt','sparsity')
}
args = {
  'sch':schargs_prime,
  'exp':expargs
}
dataL = run_batch_exp(ns,args)
acc = np.array([get_acc(data) for data in dataL]) # sub,layer,trial
acc = acc.mean(1) # mean over layers -> sub,trial
mB = acc.mean(0) # mean over subjects -> trial
## plot
fig,ax = plt.subplots()
ax.plot(mB,c='b',lw=3,label='model (best fit)')
ax.scatter(range(len(hB)),hB,c='k',s=10,zorder=99,label='human (blocked)')
ax.axhline(0.5,c='k',lw=0.5)  # 0.5 reference (chance for a two-alternative choice)
for v in [40,80,160]:  # presumably curriculum/phase boundaries -- TODO confirm
  ax.axvline(v,c='k',lw=0.5)
ax.set_xlabel('trial')
ax.set_ylabel('accuracy')
ax.legend();

# Analysis: how fit quality (MSE) varies across the grid-search parameters

In [None]:
# Marginal fit quality: mean MSE (shaded band = +/- standard error) across the
# grid for each value of one parameter. 'stickiness_wi' is not plotted because
# the search always set it equal to 'stickiness_bt'.
pL = ['concentration','stickiness_bt','sparsity']
f,axa = plt.subplots(1,len(pL),figsize=(20,4),sharey=True)
for ax,param in zip(axa,pL):
  grouped = gsdf.groupby(param)['mse']
  mean_mse = grouped.mean()
  sem_mse = grouped.std()/np.sqrt(grouped.count())
  ax.plot(mean_mse)
  ax.fill_between(mean_mse.index,mean_mse+sem_mse,mean_mse-sem_mse,alpha=0.5)
  ax.set_xlabel(param)
  ax.set_ylabel('mse')

In [None]:
# Pairwise fit quality: for each parameter `outvar` (one row of panels), plot
# mean MSE against each other parameter `invar`, with one labelled line per
# value of `outvar`. The lines were previously unlabelled and indistinguishable.
f,axa = plt.subplots(3,2,figsize=(30,24),sharey=True)
for oi,outvar in enumerate(pL):
  invars = [p for p in pL if p != outvar]
  for ii,invar in enumerate(invars):
    ax = axa[oi,ii]
    mse_by_pair = gsdf.groupby([outvar,invar])['mse'].mean()
    for oval,d in mse_by_pair.groupby(level=outvar):
      ax.plot(d.index.get_level_values(invar),d.values,label='%s=%.3g'%(outvar,oval))
    ax.set_title('%s for different %s'%(invar,outvar))
    ax.set_xlabel(invar)
    ax.set_ylabel('mse')
    ax.legend(fontsize='small')