# run simulations 

In [1]:
import pdb, os, pickle,json 
import math 
import pandas as pd
import numpy as np 
from scipy.stats import ttest_rel,ttest_ind, ttest_1samp 
from matplotlib import pyplot as plt 
import seaborn as sns 
from sklearn.metrics import adjusted_rand_score 

from humanUtils import * 
from cswHumanDatabase import load_final_df, load_dfs 
from analysis import * 
from utils import * 

%load_ext autoreload
%autoreload 2

sns.set_context('talk')
%matplotlib inline
plt.rcParams['font.size'] = 22

In [2]:
# Curriculum conditions handed to the simulation runner. Presumably the integer
# `cond` index used in the saved arrays/dataframes is the position in this list
# (0 = blocked ... 4 = late) — TODO confirm against run_fn.
condL = 'blocked interleaved early middle late'.split()

# model dfs

## read gridsearch results, rerun and save top model experiment
* run experiments with best params of each gridsearch
* save out acc, exp, zt, xth and per-seed alpha of each

## sims info
* version 5: skipt1 = false; no variability across seeds 
* version 1: skipt1 = true; no variability across seeds
* version 2: skipt1 = true; variable concentration across seeds 
* version 9: skipt1 = true starting in 2nd block for blocked, always false for interleaved; variable concentration across seeds

In [3]:
def get_gs_info(gsname, args=None):
  """Return the simulation settings that belong to a gridsearch version.

  Parameters
  ----------
  gsname : str
    Gridsearch name; compared against the notebook constants V2 and V9.
  args : dict, optional
    Model argument dict (as produced by `get_argsD`). It is only read for V2
    and V9, where args["sch"]["concentration"] becomes the mean of the
    per-seed concentration distribution. When omitted, the notebook-global
    `args` is used (the previous behaviour). That global is only correct if
    the caller has just set it for this same gridsearch, so prefer passing
    it explicitly.

  Returns
  -------
  dict with keys
    'concentration_info' : dict with concentration_mean/sd/lb/ub for V2 and
                           V9, otherwise None
    'run_fn'             : run_batch_exp_curr_v8and9 for V9, otherwise
                           run_batch_exp_curr
  """
  # V2 and V9 vary concentration across seeds; the other versions keep it fixed.
  if gsname in (V9, V2):
    if args is None:
      # Backward-compatible fallback: the original implicitly read this global.
      args = globals()['args']
    concentration_info = {
      'concentration_mean': args["sch"]["concentration"],
      'concentration_sd': 0.3,
      'concentration_lb': 0,
      'concentration_ub': np.inf,
    }
  else:
    concentration_info = None
  # Version 9 uses a dedicated runner; all other versions share the default one.
  run_fn = run_batch_exp_curr_v8and9 if gsname == V9 else run_batch_exp_curr
  return {"concentration_info": concentration_info, 'run_fn': run_fn}
  

In [4]:
# Gridsearch identifiers. Each one names a data/model/<gsname>-summdf.csv summary
# and the <gsname>-*.npy outputs written by the recompute loop.
V1 = 'gs0317'
V2 = 'gs0317_concentrationAcross'
V5 = 'gs0317_version5_skipt1False'
V9 = 'gs0317_version9_concentrationAcross_skipt1False_except_blocked'
# Gridsearches to process, in order. To rerun just one, narrow this list
# (e.g. to [V9]).
gsnameL = [V5, V1, V2, V9]

In [5]:
# Rerun the best-fitting parameter set of each gridsearch and cache its outputs
# as .npy files under data/model/. With recompute = False the cell only reports
# the best parameters; later cells read the cached files.
recompute = True
ns = 50  # seeds per curriculum condition
TOP_K = 1  # rows kept from the sorted gridsearch; only the first one is rerun
N_TRIALS, N_TSTEPS = 200, 5  # per-seed layout used when reshaping exp/zt/xth


def _stack_per_seed(batch_data, key, n_cond, n_seeds, transform=None):
  """Stack batch_data[c][s][key] over conditions c and seeds s.

  `transform`, if given, is applied to each per-seed array before stacking.
  The result is reshaped to (n_cond, n_seeds, N_TRIALS, N_TSTEPS).
  """
  per_seed = []
  for c in range(n_cond):
    for s in range(n_seeds):
      arr = batch_data[c][s][key]
      per_seed.append(arr if transform is None else transform(arr))
  return np.stack(per_seed).reshape(n_cond, n_seeds, N_TRIALS, N_TSTEPS)


for gsname in gsnameL:
  # read gridsearch summary dataframe
  gsdf = pd.read_csv('data/model/%s-summdf.csv' % gsname).astype(float)

  # best model = lowest summed MSE over the b_* and i_* fit columns
  # (presumably the blocked / interleaved conditions)
  gsdf.loc[:, "mse-bi"] = gsdf.b_mse1 + gsdf.b_mse2 + gsdf.i_mse1 + gsdf.i_mse2
  topKgsdf = gsdf.sort_values('mse-bi').iloc[:TOP_K]
  args, paramstr = get_argsD(topKgsdf.iloc[0])
  print(gsname)
  print()
  print(paramstr)
  print()
  print(topKgsdf.loc[:, ("b_mse1", "b_mse2", "i_mse1", "i_mse2", "mse-bi")].round(4).to_dict())

  if not recompute:
    continue

  # gridsearch-specific settings (concentration variability, run function);
  # note: get_gs_info reads the `args` assigned just above
  gs_params = get_gs_info(gsname)
  concentration_info = gs_params['concentration_info']
  run_fn = gs_params['run_fn']
  # run: exp_batch_data is indexed [curr][seeds] -> {data}
  exp_batch_data, alpha_per_seed, _, _ = run_fn(ns, args, condL, concentration_info)

  # unpack into (condition, seed, trial, tstep) arrays and save
  n_cond = len(condL)
  exp = _stack_per_seed(exp_batch_data, 'exp', n_cond, ns)
  zt = _stack_per_seed(exp_batch_data, 'zt', n_cond, ns)
  xth = _stack_per_seed(exp_batch_data, 'xth', n_cond, ns, transform=lambda a: a.argmax(-1))
  acc = unpack_acc(exp_batch_data, mean_over_tsteps=True)
  outputs = {
    'acc': acc,
    'exp': exp,
    'zt': zt,
    'xth': xth,
    'alpha': np.array(alpha_per_seed),
  }
  for suffix, arr in outputs.items():
    # np.save appends the .npy extension to the path
    np.save(f'data/model/{gsname}-{suffix}', arr)
    

gs0317_version5_skipt1False

concentration_4.775-stickiness_wi_96.792-stickiness_bt_96.792-sparsity_0.051-pvar_0.000-lrate_1.000-lratep_1.000-decay_rate_1.000-beta2_0.000-skipt1_0.000-ppd_allsch_0.000

{'b_mse1': {748: 0.0545}, 'b_mse2': {748: 0.0079}, 'i_mse1': {748: 0.0108}, 'i_mse2': {748: 0.0187}, 'mse-bi': {748: 0.0919}}
gs0317_version5_skipt1False
gs0317

concentration_1.703-stickiness_wi_1.848-stickiness_bt_1.848-sparsity_0.211-pvar_0.000-lrate_1.000-lratep_1.000-decay_rate_1.000-beta2_0.000-skipt1_1.000-ppd_allsch_0.000

{'b_mse1': {45355: 0.012}, 'b_mse2': {45355: 0.0058}, 'i_mse1': {45355: 0.0107}, 'i_mse2': {45355: 0.0121}, 'mse-bi': {45355: 0.0406}}
gs0317
gs0317_concentrationAcross

concentration_3.604-stickiness_wi_5.057-stickiness_bt_5.057-sparsity_0.436-pvar_0.000-lrate_1.000-lratep_1.000-decay_rate_1.000-beta2_0.000-skipt1_1.000-ppd_allsch_0.000

{'b_mse1': {22765: 0.011}, 'b_mse2': {22765: 0.005}, 'i_mse1': {22765: 0.0111}, 'i_mse2': {22765: 0.0099}, 'mse-bi': {22765:

## make accuracy dataframe
* load saved np files from all gridsearches
* extract accuracy and format dataframe

In [6]:
# Long-format accuracy table: one row per (gridsearch, condition, seed, trial).
# Loop extents come from the saved arrays themselves, so this cell no longer
# depends on `ns` (or on a hard-coded 5 x 200 layout) from the recompute cell.
TEST_TRIAL_START = 160  # trials with index >= this are flagged in the 'test' column
L = []
for gsname in [V5, V1, V2, V9]:
  acc = np.load(f"data/model/{gsname}-acc.npy")
  alpha = np.load(f"data/model/{gsname}-alpha.npy")
  # alpha may be empty along its second axis (presumably for gridsearches
  # without per-seed concentration); those rows get an empty list, as before.
  # Checked once per gridsearch instead of once per row.
  has_alpha = alpha.ndim >= 2 and alpha.shape[1] > 0
  n_cond, n_seed, n_trial = acc.shape[:3]
  # np.ndindex visits indices in the same C order as the nested loops did
  for cix, six, tix in np.ndindex(n_cond, n_seed, n_trial):
    L.append({
      'gs': gsname,
      'alpha': alpha[cix, six] if has_alpha else [],
      'cond': cix,
      'seed': six,
      'trial': tix,
      'acc': acc[cix, six, tix],
    })
model_acc_df = pd.DataFrame(L)
model_acc_df['test'] = model_acc_df.trial >= TEST_TRIAL_START

In [7]:
# model_acc_df.to_csv(f"data_csv/model/acc_df.csv")

## make states dataframe
* as above, but one row per time step: load the saved exp, zt, xth (and alpha) arrays from all gridsearches

In [8]:
# Long-format state table: one row per (gridsearch, condition, seed, trial, tstep).
# Loop extents come from the saved exp array, so this cell no longer depends on
# `ns` (or on a hard-coded 5 x 200 x 5 layout) from the recompute cell.
L = []
for gsname in [V5, V1, V2, V9]:
  exp = np.load(f"data/model/{gsname}-exp.npy")
  zt = np.load(f"data/model/{gsname}-zt.npy")
  xth = np.load(f"data/model/{gsname}-xth.npy")
  alpha = np.load(f"data/model/{gsname}-alpha.npy")
  # alpha may be empty along its second axis (presumably for gridsearches
  # without per-seed concentration); those rows get an empty list, as before.
  # Checked once per gridsearch instead of once per row.
  has_alpha = alpha.ndim >= 2 and alpha.shape[1] > 0
  # np.ndindex visits indices in the same C order as the nested loops did
  for cix, six, trix, tstep in np.ndindex(*exp.shape[:4]):
    L.append({
      'gs': gsname,
      'alpha': alpha[cix, six] if has_alpha else [],
      'cond': cix,
      'seed': six,
      'trial': trix,
      'tstep': tstep,
      'exp': exp[cix, six, trix, tstep],
      'zt': zt[cix, six, trix, tstep],
      'xth': xth[cix, six, trix, tstep],
    })

model_states_df = pd.DataFrame(L)
