In [1]:
import os
import numpy as np
import torch as tr

from glob import glob as glob
import pandas as pd

from CSWSEM import *
from matplotlib import pyplot as plt

In [2]:
hdf = pd.read_csv('gsdata/humandf.csv')
humanB = hdf.loc[:,'blocked mean']
humanI = hdf.loc[:,'interleaved mean']

In [3]:
gsname = 'gs2'

In [4]:
def make_gsdf(gsname,save=False,debug=False):
  gs_dir = "gsdata/%s/"%gsname
  fpathL = glob(gs_dir+'*')
  df_L = []
  ### initialize gsdf
  if debug: fpathL[:10]
  for fpath in fpathL[:10]:
    batch_df = pd.read_csv(fpath)
    df_L.append(batch_df)
  gsdf = pd.concat(df_L)
  ### gsdf manipulations
  gsdf.index = np.arange(len(gsdf))
  gsdf = gsdf.drop(columns=['Unnamed: 0','like','prior'])
  gsdf.loc[:,'accuracy'] = 1-gsdf.loc[:,'loss']
  gsdf.loc[(gsdf.trial>=160),'phase'] = 'test'
  gsdf.loc[(gsdf.trial<160),'phase'] = 'train'
  gsdf.loc[:,'period'] = gsdf.trial%5
  
  if save:
    gsdf.to_csv('gsdata/%s.csv'%gsname)
    print('saved %s.csv'%gsname)
  return gsdf

In [5]:
gsdf = make_gsdf(gsname,save=1,debug=0)
gsdf = pd.read_csv('gsdata/%s.csv'%gsname)
print('max delta_time',gsdf.delta_time.max())

saved gs2.csv
max delta_time 55.34397864341736


In [6]:
def build_dataD(gsdf,metric='loss',verb=True):
  """ extracts and reformats a column of gsdf
  D[model+param_str][cond] = [seed,time]
  """
  paramL = ['learn_rate','alfa','lmda']
  dataD = {}
  for nosplit,m_df in gsdf.groupby(['nosplit']):  
    if nosplit: model='LSTM'
    else: model='SEM'
    for p,p_df in m_df.groupby(paramL):
      param_str = "-".join([str(p_i) for p_i in p])
      dataD_key = "%s-%s"%(model,param_str)
      dataD[dataD_key] = {}
      for cond,c_df in p_df.groupby('condition'):
        if verb: print(model,param_str,cond)
        sgroup = c_df.groupby('seed')
        seed_arr = -np.ones([len(sgroup),200])
        for s_idx,(s,s_df) in enumerate(sgroup):
          seed_arr[s_idx] = s_df.loc[:,metric]
        dataD[dataD_key][cond] = seed_arr
  return dataD

# loss analysis

In [7]:
# D[model-param_str][cond] = [seed,time]
lossD = build_dataD(gsdf,metric='loss',verb=False)
print('num model conditions',len(lossD))
lossD.keys()

num model conditions 20


dict_keys(['SEM-0.01-0.01-0.1', 'SEM-0.05-1.0-1.0', 'SEM-0.05-10.0-0.1', 'SEM-0.05-100.0-0.01', 'SEM-0.05-10000.0-0.01', 'SEM-0.05-100000.0-100000.0', 'SEM-0.1-0.01-100000.0', 'SEM-0.1-0.1-0.01', 'SEM-0.1-1.0-0.1', 'SEM-0.1-1000.0-0.1', 'LSTM-0.01-0.01-0.1', 'LSTM-0.05-1.0-1.0', 'LSTM-0.05-10.0-0.1', 'LSTM-0.05-100.0-0.01', 'LSTM-0.05-10000.0-0.01', 'LSTM-0.05-100000.0-100000.0', 'LSTM-0.1-0.01-100000.0', 'LSTM-0.1-0.1-0.01', 'LSTM-0.1-1.0-0.1', 'LSTM-0.1-1000.0-0.1'])

In [8]:
def plt_loss_(ax,loss_arr,tag=None):
  Nseeds,_ = loss_arr.shape
  acc_arr = 1-loss_arr
  for acc_seed in acc_arr:
    ax.plot(acc_seed,lw=.05,c='k')
  M = acc_arr.mean(0)
  S = acc_arr.std(0)/np.sqrt(Nseeds)
  ax.fill_between(range(200),M-S,M+S,alpha=.5,color='b')
  ax.plot(M,lw=3,c='b')
  ax.set_ylim(0.2,1)
  return None

def plt_loss(cond_dict,mse_dict,title):
  f,ax = plt.subplots(2,1,figsize=(8,4),sharex=True)
  for idx,(cond,arr) in enumerate(cond_dict.items()):
    plt_loss_(ax[idx],arr)
    ax[idx].plot(hdf.loc[:,"%s mean"%cond],color='red',lw=3)
    ax[idx].set_title("%s mse%f"%(title,mse_dict[cond]))
  total_mse = np.sum([i for i in mse_dict.values()])
  plt.savefig('figures/%s/acc/mse%.4f-acc-%s.png'%(gs_name,total_mse,title))
  plt.close('all')

In [9]:
def calc_mse(cond_dict):
  D = {}
  for cond,loss_arr in cond_dict.items():
    acc_arr = 1-loss_arr
    semM = acc_arr.mean(0)
    humanM = hdf.loc[:,'%s mean'%cond]
    D[cond] = np.mean((semM-humanM)**2)
  return D

In [None]:
L = []
for model_param,cond_dict in lossD.items():
  mse = calc_mse(cond_dict)
  title = model_param
  mse = calc_mse(cond_dict)
  print(model_param,mse)
  L.append({'model':model_param,'mse':mse})
  plt_loss(cond_dict,mse,title)

SEM-0.01-0.01-0.1 {'blocked': 0.012748805862756743, 'interleaved': 0.05016689576544213}


In [None]:
assert False

# schema inference analysis


In [None]:
curr_D = build_dataD(gsdf,metric='curriculum',verb=False)
actsch_D = build_dataD(gsdf,metric='active_schema',verb=False)
k_ = list(actsch_D.keys())[0]

In [None]:
cond = 'blocked'
curr_arr = curr_D[k_][cond]
actsch_arr = actsch_D[k_][cond]

actsch_arr
curr_arr.shape