In [1]:
import os
import numpy as np
import torch as tr

from glob import glob as glob
import pandas as pd


from CSWSEM import *
from matplotlib import pyplot as plt

In [2]:
# Notebook-level switches.
# `debug` is not referenced in the cells visible here.
debug = True
# Grid-search name: selects 'gsdata/<gs_name>.csv' as input and
# 'figures/gs-<gs_name>/' as the figure output folder.
gs_name = 'absem'

In [3]:
# Human reference data; the table holds one '<condition> mean' column per condition.
hdf = pd.read_csv('gsdata/humandf.csv')
humanB,humanI = (hdf.loc[:,col] for col in ('blocked mean','interleaved mean'))

In [4]:
# Disabled alternative: presumably rebuilds the grid-search dataframe via
# utils_analysis.make_gsdf instead of reading the saved CSV -- TODO confirm.
# from utils_analysis import make_gsdf
# make_gsdf(gs_name,save=False)
# Load the grid-search results table for the run selected by `gs_name`.
gsdf = pd.read_csv('gsdata/%s.csv'%gs_name)

In [5]:
def build_dataD(gsdf,metric='loss',verb=True):
  """ Reshape the long-format grid-search table into per-model arrays.

  Returns D["<model>-<param_str>"][cond] = float array [seed,time], where
  model is 'LSTM' when `nosplit` is truthy and 'SEM' otherwise, and
  param_str is the learn_rate, alfa and lmda values joined with '-'.

  gsdf: dataframe with columns nosplit, learn_rate, alfa, lmda,
    condition, seed and `metric`
  metric: name of the column whose values fill the arrays
  verb: if True, print (model, param_str, cond) for every array built

  The time axis length is taken from the data (rows per seed, in the
  order they appear in gsdf). Raises ValueError if the seeds of one
  condition do not all have the same number of rows.
  """
  paramL = ['learn_rate','alfa','lmda']
  dataD = {}
  # Group by the bare column name, not a one-element list: with a list key
  # newer pandas yields 1-tuples such as (False,), which are always truthy
  # and would label every model 'LSTM'.
  for nosplit,m_df in gsdf.groupby('nosplit'):
    model = 'LSTM' if nosplit else 'SEM'
    for p,p_df in m_df.groupby(paramL):
      param_str = "-".join([str(p_i) for p_i in p])
      dataD_key = "%s-%s"%(model,param_str)
      dataD[dataD_key] = {}
      for cond,c_df in p_df.groupby('condition'):
        if verb: print(model,param_str,cond)
        # one row of metric values per seed, in seed-sorted order
        seed_rows = [
          s_df.loc[:,metric].to_numpy(dtype=float)
          for _,s_df in c_df.groupby('seed')
        ]
        row_lengths = {len(row) for row in seed_rows}
        if len(row_lengths) > 1:
          raise ValueError(
            "seeds of %s / %s have unequal numbers of rows: %s"
            %(dataD_key,cond,sorted(row_lengths)))
        dataD[dataD_key][cond] = np.stack(seed_rows)

  return dataD

In [6]:
# D[model-param_str][cond] = [seed,time]
lossD = build_dataD(gsdf,metric='loss',verb=False)
print('num model conditions',len(lossD))
lossD.keys()

num model conditions 41


dict_keys(['SEM-0.005-0.1-0.1', 'SEM-0.005-0.1-1.0', 'SEM-0.005-0.1-10.0', 'SEM-0.005-1.0-0.1', 'SEM-0.005-1.0-1.0', 'SEM-0.005-1.0-10.0', 'SEM-0.005-10.0-0.1', 'SEM-0.005-10.0-1.0', 'SEM-0.005-10.0-10.0', 'SEM-0.01-0.1-0.1', 'SEM-0.01-0.1-1.0', 'SEM-0.01-0.1-10.0', 'SEM-0.01-1.0-0.1', 'SEM-0.01-1.0-1.0', 'SEM-0.01-1.0-10.0', 'SEM-0.01-10.0-0.1', 'SEM-0.01-10.0-1.0', 'SEM-0.01-10.0-10.0', 'SEM-0.05-0.1-0.1', 'SEM-0.05-0.1-1.0', 'SEM-0.05-0.1-10.0', 'SEM-0.05-1.0-0.1', 'SEM-0.05-1.0-1.0', 'SEM-0.05-1.0-10.0', 'SEM-0.05-10.0-0.1', 'SEM-0.05-10.0-1.0', 'SEM-0.05-10.0-10.0', 'SEM-0.1-0.1-0.1', 'SEM-0.1-0.1-1.0', 'SEM-0.1-0.1-10.0', 'SEM-0.1-1.0-0.1', 'SEM-0.1-1.0-1.0', 'SEM-0.1-1.0-10.0', 'SEM-0.1-10.0-0.1', 'SEM-0.1-10.0-1.0', 'SEM-0.1-10.0-10.0', 'LSTM-0.005-0.0-0.0', 'LSTM-0.005-99.0-99.0', 'LSTM-0.01-99.0-99.0', 'LSTM-0.05-99.0-99.0', 'LSTM-0.1-99.0-99.0'])

In [7]:
def plt_loss_(ax,loss_arr,tag=None):
  """ Draw accuracy (1-loss) curves for one condition on `ax`.

  ax: matplotlib-style axes (needs plot, fill_between, set_ylim)
  loss_arr: array [seed,time] of loss values
  tag: unused; kept so existing calls keep working

  Plots one thin black line per seed, then the across-seed mean in blue
  with a +/- standard-error band. The x-range of the band follows the
  width of `loss_arr`. The y-axis is fixed to (0.2, 1).
  """
  Nseeds,Ntime = loss_arr.shape
  acc_arr = 1-loss_arr
  for acc_seed in acc_arr:
    ax.plot(acc_seed,lw=.05,c='k')
  M = acc_arr.mean(0)
  # standard error of the mean across seeds
  S = acc_arr.std(0)/np.sqrt(Nseeds)
  ax.fill_between(range(Ntime),M-S,M+S,alpha=.5,color='b')
  ax.plot(M,lw=3,c='b')
  ax.set_ylim(0.2,1)
  return None

def plt_loss(cond_dict,mse_dict,title):
  """ Save a figure comparing model accuracy to the human curve per condition.

  cond_dict: {cond: loss array [seed,time]}; one subplot row per condition
  mse_dict: {cond: mse} shown in each subplot title; the values are summed
    into the file name
  title: label used in the subplot titles and in the file name

  Each row shows the model curves (via plt_loss_) with the
  '<cond> mean' column of the global `hdf` overlaid in red. The figure is
  written to 'figures/gs-<gs_name>/mse<total>-acc-<title>.png'; the folder
  is created if missing. Only the figure made here is closed.
  """
  # at least one row so an empty cond_dict still yields a (blank) figure
  n_rows = max(len(cond_dict),1)
  f,axes = plt.subplots(n_rows,1,figsize=(8,2*n_rows),sharex=True,squeeze=False)
  for ax,(cond,arr) in zip(axes[:,0],cond_dict.items()):
    plt_loss_(ax,arr)
    ax.plot(hdf.loc[:,"%s mean"%cond],color='red',lw=3)
    ax.set_title("%s mse%f"%(title,mse_dict[cond]))
  total_mse = np.sum(list(mse_dict.values()))
  fig_path = 'figures/gs-%s/mse%.4f-acc-%s.png'%(gs_name,total_mse,title)
  # savefig does not create missing folders
  os.makedirs(os.path.dirname(fig_path),exist_ok=True)
  f.savefig(fig_path)
  plt.close(f)

In [8]:

def calc_mse(cond_dict):
  """ Mean squared error between model and human accuracy, per condition.

  cond_dict: {cond: loss array [seed,time]}
  Returns {cond: mse}, where the model curve is the across-seed mean of
  1-loss and the human curve is the '<cond> mean' column of the global `hdf`.
  """
  mse_by_cond = {}
  for cond in cond_dict:
    model_acc = (1-cond_dict[cond]).mean(0)
    human_acc = hdf.loc[:,'%s mean'%cond]
    sq_err = (model_acc-human_acc)**2
    mse_by_cond[cond] = np.mean(sq_err)
  return mse_by_cond



- TODO: what should the combined `model` + `param_str` key be called? `model_inst`?

In [9]:
""" cond_dict {
    blocked: [seeds,time], 
    interleaved: [seeds,time],
    } 
"""

# For every model/parameter setting: score it against the human curves,
# print and record the per-condition mse, and save the comparison figure.
L = []
for model_param,cond_dict in lossD.items():
  # mse = {cond: mean squared error vs. human accuracy}
  mse = calc_mse(cond_dict)
  title = model_param
  print(model_param,mse)
  L.append({'model':model_param,'mse':mse})
  plt_loss(cond_dict,mse,title)



SEM-0.005-0.1-0.1 {'blocked': 0.04194981667053622, 'interleaved': 0.020167913829712322}
SEM-0.005-0.1-1.0 {'blocked': 0.0419498165709319, 'interleaved': 0.020167913934828963}
SEM-0.005-0.1-10.0 {'blocked': 0.041949816681296076, 'interleaved': 0.020167914054492378}
SEM-0.005-1.0-0.1 {'blocked': 0.04288687175828076, 'interleaved': 0.01935368293900291}
SEM-0.005-1.0-1.0 {'blocked': 0.04288687157526585, 'interleaved': 0.019353682956438183}
SEM-0.005-1.0-10.0 {'blocked': 0.04296016531907916, 'interleaved': 0.019353682853627548}
SEM-0.005-10.0-0.1 {'blocked': 0.04579710595363744, 'interleaved': 0.018019643879956734}
SEM-0.005-10.0-1.0 {'blocked': 0.04579710596622344, 'interleaved': 0.018019643941767537}
SEM-0.005-10.0-10.0 {'blocked': 0.044664300127181464, 'interleaved': 0.017916765507246996}
SEM-0.01-0.1-0.1 {'blocked': 0.013734347109991284, 'interleaved': 0.04342186657763194}
SEM-0.01-0.1-1.0 {'blocked': 0.013734347109991284, 'interleaved': 0.04342186636997123}
SEM-0.01-0.1-10.0 {'blocked'