# implement line sweeps of params
- all params fixed except one, which is swept over a range
- specifically looking at test accuracy as a function of one param


In [1]:
from matplotlib import pyplot as plt
from scipy.special import softmax
from itertools import product
import numpy as np
from utils import *
import time
import seaborn as sns
sns.set_context('talk')

%load_ext autoreload
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
## timestamp and dir for saving
import os
# wall-clock nanoseconds since the epoch; perf_counter_ns() has an
# undefined reference point, so its values are not timestamps and are
# only meaningful as differences within one session
tstamp = time.time_ns()
# makedirs also creates figures/scratch_folders when it is missing;
# it still raises FileExistsError if this run's folder already exists
os.makedirs('figures/scratch_folders/%i'%tstamp)

In [3]:
## import human data for fitting
import pandas as pd
hdf = pd.read_csv('../human_data.csv')
humanB_acc,humanI_acc = hdf.loc[:,('blocked mean','interleaved mean')].values.T

In [4]:
def get_sm(xth,norm=True):
  """ 
  pull out the 2afc node pair for layer 2 and layer 3
  xth: x_t_hat from subject, indexed [trial,layer,node]
  norm: if True, softmax over the two candidate nodes of each trial
   (original note: for when prediction is done with multiple schemas)
  return: [layer2/3,trial,node56/78]
  """
  # (layer, candidate node pair): layer 2 -> nodes 5/6, layer 3 -> nodes 7/8
  layer_nodes = ((2,(5,6)),(3,(7,8)))
  per_layer = []
  for layer,pair in layer_nodes:
    vals = xth[:,layer,pair] # [trial,2]
    per_layer.append(softmax(vals,1) if norm else vals)
  return np.array(per_layer)

def get_acc(data):
  """ 
  2afc softmax value of one node per trial on the
  layer 2 and layer 3 transitions
  data: dict with 'xth' (passed to get_sm) and 'exp', an integer array
   whose columns 3 and 4 give the node to score for layer 2 (5 or 6)
   and layer 3 (7 or 8) -- presumably the observed node, TODO confirm
  return: [layer2/3,trial]
  """
  ysm = get_sm(data['xth']) # [layer2/3,trial,node]
  rows = range(ysm.shape[1])
  picked = []
  # first_node is the lowest node id of the layer's pair; subtracting it
  # maps node ids 5/6 and 7/8 onto softmax columns 0/1
  for layer_idx,first_node in enumerate((5,7)):
    target = data['exp'][:,layer_idx+3]
    picked.append(ysm[layer_idx][rows,target-first_node])
  return np.array(picked)


In [5]:
### RUN EXP
def run_batch_exp(ns,args):
  """ run ns repetitions (seeds) of a single
  task_condition / param config
  args: dict; args['sch'] is passed to SEM,
   args['exp'] is expanded as kwargs of Task.generate_experiment
  return: list of ns data dicts from SEM.run_exp,
   each with the generated experiment stored under 'exp'
  """
  def one_seed():
    # fresh task and model for every repetition
    task,sem = Task(),SEM(args['sch'])
    # second return value of generate_experiment is not used here
    exp,_ = task.generate_experiment(**args['exp'])
    out = sem.run_exp(exp)
    out['exp'] = exp
    return out
  return [one_seed() for _ in range(ns)]

def run_batch_exp_curr2(ns,args,currL=('blocked','interleaved')):
  """ loop over task conditions, running ns seeds of each
  ns: number of seeds per condition
  args: dict with 'sch' and 'exp' sub-dicts (see run_batch_exp);
   left unmodified -- every condition runs on a copy in which
   args['exp']['condition'] is overridden
  currL: iterable of condition names
  return: (acc,dataD)
   acc [task_condition,seed,trial], averaged over layer 2/3
   dataD {condition: list of per-seed data dicts}
  """
  accL = []
  dataD = {}
  for curr in currL:
    # copy rather than write into the caller's dicts, so a sweep does
    # not leave args['exp']['condition'] stuck at the last condition
    cond_args = dict(args)
    cond_args['exp'] = dict(args['exp'],condition=curr)
    ## extract other data here
    dataL = run_batch_exp(ns,cond_args)
    dataD[curr] = dataL
    ##
    # get_acc gives [layer,trial] per seed; mean over the layer axis
    acc = np.array([get_acc(data) for data in dataL]).mean(1)
    accL.append(acc)
  return np.array(accL),dataD

In [6]:
## default params
# kwargs for Task.generate_experiment
expargs = dict(
  condition='blocked',
  n_train=160,
  n_test=40,
)
# schema / model params handed to SEM
schargs = dict(
    concentration=1.4,
    stickiness_wi=5000,
    stickiness_bt=5, # 100
    sparsity=0.08,
    pvar=2,
    lrate=0.8,
    lratep=1,
)
args = dict(sch=schargs,exp=expargs)
# tag built from the schema params: "<name>_<value to 3 decimals>" joined by "-"
param_str = "-".join("%s_%.3f"%(name,value) for name,value in schargs.items())
param_str

'concentration_1.400-stickiness_wi_5000.000-stickiness_bt_5.000-sparsity_0.080-pvar_2.000-lrate_0.800-lratep_1.000'

### main

In [7]:
ns = 7
L = []
condL = ['blocked','interleaved','early','middle','late']

# accuracy per condition [curr,seeds,trials]; dataD keeps the raw per-seed data
model_seed_acc,dataD = run_batch_exp_curr2(ns=ns,args=args,currL=condL)

# average over seeds -> [curr,trials]
model_acc = np.mean(model_seed_acc,axis=1)
# mean over the final 40 trials (presumably the test phase: n_test is 40)
model_testacc = np.mean(model_acc[:,-40:],axis=1)
  


# model state

In [8]:
# list the fields recorded for one run (first seed of the chosen condition)
curr='blocked'
dataD[curr][0].keys()

dict_keys(['zt', 'xth', 'priors', 'likesL2', 'postL2', 'exp'])

## difficulty visualizing priors/likelihoods of a given model:
- how to collapse between schemas?
- each model instance might have different number of schemas

## issue for concatenating between trials:
- when a new schema is forked on trial t, the list of schemas for that trial is longer than for earlier trials, so the per-trial lists cannot be stacked into a rectangular array.

In [9]:
dataL = dataD['interleaved']
k = 'postL2' # priors, likesL2, postL2
for seed in range(ns):
  # the per-trial lists can differ in length between trials (ragged);
  # dtype=object lets np.array build them without the ragged-sequence
  # warning on older NumPy or the ValueError on NumPy >= 1.24
  print(np.array(dataL[seed][k],dtype=object).shape) # [ntrials,nschemas]

(200, 1, 2)
(200, 1, 2)
(200, 1, 2)
(200, 1, 2)
(200, 1)
(200, 1)
(200, 1)


  after removing the cwd from sys.path.


In [10]:
# show the raw per-trial values stored under k for the last seed in dataL
k = 'priors' # priors, likesL2, postL2
seed = -1
dataL[seed][k]

[[5001, 1.4],
 [5005, 1.4],
 [5, 5004, 1.4],
 [5009, 4, 1.4],
 [9, 5008, 1.4],
 [5013, 8, 1.4],
 [13, 5012, 1.4],
 [5017, 12, 1.4],
 [17, 5016, 1.4],
 [5021, 16, 1.4],
 [21, 5020, 1.4],
 [5025, 20, 1.4],
 [25, 5024, 1.4],
 [5029, 24, 1.4],
 [29, 5028, 1.4],
 [5033, 28, 1.4],
 [33, 5032, 1.4],
 [5037, 32, 1.4],
 [37, 5036, 1.4],
 [5041, 36, 1.4],
 [41, 5040, 1.4],
 [5045, 40, 1.4],
 [45, 5044, 1.4],
 [5049, 44, 1.4],
 [49, 5048, 1.4],
 [5053, 48, 1.4],
 [53, 5052, 1.4],
 [5057, 52, 1.4],
 [57, 5056, 1.4],
 [5061, 56, 1.4],
 [61, 5060, 1.4],
 [5065, 60, 1.4],
 [65, 5064, 1.4],
 [5069, 64, 1.4],
 [69, 5068, 1.4],
 [5073, 68, 1.4],
 [73, 5072, 1.4],
 [5077, 72, 1.4],
 [77, 5076, 1.4],
 [5081, 76, 1.4],
 [81, 5080, 1.4],
 [5085, 80, 1.4],
 [85, 5084, 1.4],
 [5089, 84, 1.4],
 [89, 5088, 1.4],
 [5093, 88, 1.4],
 [93, 5092, 1.4],
 [5097, 92, 1.4],
 [97, 5096, 1.4],
 [5101, 96, 1.4],
 [101, 5100, 1.4],
 [5105, 100, 1.4],
 [105, 5104, 1.4],
 [5109, 104, 1.4],
 [109, 5108, 1.4],
 [5113, 108, 1.4]

### plt 

In [11]:
# sanity check before plotting: seed-averaged accuracy is [curr,trials]
model_acc.shape

(5, 200)