In [1]:
import os
import numpy as np
import torch as tr

from glob import glob as glob
import pandas as pd


from CSWSEM import *
from matplotlib import pyplot as plt

In [2]:
gs_name = 'absem'
debug = True

In [3]:
hdf = pd.read_csv('gsdata/humandf.csv')
humanB = hdf.loc[:,'blocked mean']
humanI = hdf.loc[:,'interleaved mean']

In [4]:
def make_gsdf(gs_name):
  gs_dir = "gsdata/%s/"%gs_name
  fpathL = glob(gs_dir+'*')
  seed_df_L = []
#   if debug: fpathL=fpathL[:300]
  for fpath in fpathL:
    condition = fpath.split('/')[-1].split('__')[1].split('_')[1]
    seed_df = pd.read_csv(fpath)
    seed_df.loc[:,'model'] = ['LSTM','SEM'][sum(seed_df.loc[:,'nosplit']==1)>0]
    seed_df.loc[:,'condition'] = condition
    seed_df_L.append(seed_df)
  gsdf = pd.concat(seed_df_L)
  gsdf.index = np.arange(len(gsdf))
  gsdf.drop(columns=['Unnamed: 0','like','prior'])
  return gsdf

In [5]:
gsdf = make_gsdf(gs_name)

In [6]:
def build_datadict(gsdf,metric='loss',verb=True):
  """ D[model][param_str][cond] = [seed,time]
  """
  dataD = {
    'LSTM':{},
    'SEM':{},
       }
  for nosplit,m_df in gsdf.groupby(['nosplit']):  
    if nosplit: model='LSTM'
    else: model='SEM'
    for p,p_df in m_df.groupby(['learn_rate','alfa','lmda']):
      param_str = "-".join([str(p_i) for p_i in p])
      dataD[model][param_str] = {}
      for c,c_df in p_df.groupby('condition'):
        cond=c
        if verb: print(model,param_str,cond)
        sgroup = c_df.groupby('seed')
        seed_arr = -np.ones([len(sgroup),200])
        for s_idx,(s,s_df) in enumerate(sgroup):
          seed_arr[s_idx] = s_df.loc[:,metric]
        dataD[model][param_str][cond] = seed_arr

  return dataD



In [7]:
# D[model][param_str][cond] = [seed,time]
lossD = build_datadict(gsdf,metric='loss')
lossD['SEM'].keys()

SEM 0.005-0.1-0.1 blocked
SEM 0.005-0.1-0.1 interleaved
SEM 0.005-0.1-1.0 blocked
SEM 0.005-0.1-1.0 interleaved
SEM 0.005-0.1-10.0 blocked
SEM 0.005-0.1-10.0 interleaved
SEM 0.005-1.0-0.1 blocked
SEM 0.005-1.0-0.1 interleaved
SEM 0.005-1.0-1.0 blocked
SEM 0.005-1.0-1.0 interleaved
SEM 0.005-1.0-10.0 blocked
SEM 0.005-1.0-10.0 interleaved
SEM 0.005-10.0-0.1 blocked
SEM 0.005-10.0-0.1 interleaved
SEM 0.005-10.0-1.0 blocked
SEM 0.005-10.0-1.0 interleaved
SEM 0.005-10.0-10.0 blocked
SEM 0.005-10.0-10.0 interleaved
SEM 0.01-0.1-0.1 blocked
SEM 0.01-0.1-0.1 interleaved
SEM 0.01-0.1-1.0 blocked
SEM 0.01-0.1-1.0 interleaved
SEM 0.01-0.1-10.0 blocked
SEM 0.01-0.1-10.0 interleaved
SEM 0.01-1.0-0.1 blocked
SEM 0.01-1.0-0.1 interleaved
SEM 0.01-1.0-1.0 blocked
SEM 0.01-1.0-1.0 interleaved
SEM 0.01-1.0-10.0 blocked
SEM 0.01-1.0-10.0 interleaved
SEM 0.01-10.0-0.1 blocked
SEM 0.01-10.0-0.1 interleaved
SEM 0.01-10.0-1.0 blocked
SEM 0.01-10.0-1.0 interleaved
SEM 0.01-10.0-10.0 blocked
SEM 0.01-10.0-10.

dict_keys(['0.005-0.1-0.1', '0.005-0.1-1.0', '0.005-0.1-10.0', '0.005-1.0-0.1', '0.005-1.0-1.0', '0.005-1.0-10.0', '0.005-10.0-0.1', '0.005-10.0-1.0', '0.005-10.0-10.0', '0.01-0.1-0.1', '0.01-0.1-1.0', '0.01-0.1-10.0', '0.01-1.0-0.1', '0.01-1.0-1.0', '0.01-1.0-10.0', '0.01-10.0-0.1', '0.01-10.0-1.0', '0.01-10.0-10.0', '0.05-0.1-0.1', '0.05-0.1-1.0', '0.05-0.1-10.0', '0.05-1.0-0.1', '0.05-1.0-1.0', '0.05-1.0-10.0', '0.05-10.0-0.1', '0.05-10.0-1.0', '0.05-10.0-10.0', '0.1-0.1-0.1', '0.1-0.1-1.0', '0.1-0.1-10.0', '0.1-1.0-0.1', '0.1-1.0-1.0', '0.1-1.0-10.0', '0.1-10.0-0.1', '0.1-10.0-1.0', '0.1-10.0-10.0'])

In [15]:
def plt_loss_(ax,loss_arr,tag=None):
  Nseeds,_ = loss_arr.shape
  acc_arr = 1-loss_arr
  for acc_seed in acc_arr:
    ax.plot(acc_seed,lw=.05,c='k')
  M = acc_arr.mean(0)
  S = acc_arr.std(0)/np.sqrt(Nseeds)
  ax.fill_between(range(200),M-S,M+S,alpha=.5,color='b')
  ax.plot(M,lw=3,c='b')
  ax.set_ylim(0.2,1)
  return None

def plt_loss(cond_dict,mse,title):
  f,ax = plt.subplots(2,1,figsize=(8,4),sharex=True)
  ax[0].set_title("%s mse%f"%(title,mse))
  for idx,(cond,arr) in enumerate(cond_dict.items()):
    plt_loss_(ax[idx],arr)
    ax[idx].plot(hdf.loc[:,"%s mean"%cond],color='red',lw=3)
  plt.savefig('figures/gs-%s/mse%.2f-acc-%s.png'%(gs_name,mse,title))
  plt.close('all')

In [16]:
def calc_mse(cond_dict):
  """ cond_dict {
    blocked: [seeds,time], 
    interleaved: [seeds,time],
    } 
  """
  SE = 0
  for cond,arr in cond_dict.items():
    cond
    semM = arr.mean(0)
    humanM = hdf.loc[:,'%s mean'%cond]
    SE += (semM-humanM)**2
  MSE = SE.mean()
  return MSE

In [17]:
L = []
for model,param_dict in lossD.items():
  for param_str,cond_dict in param_dict.items():
    mse = calc_mse(cond_dict)
    title = "%s-%s"%(model,param_str)
    mse = calc_mse(cond_dict)
    print(model,param_str,mse)
    L.append({'model':"%s-%s"%(model,param_str),'mse':mse})
    plt_loss(cond_dict,mse,title)

  

LSTM 0.005-0.0-0.0 0.49285185376425344
LSTM 0.005-99.0-99.0 0.6391365750578102
LSTM 0.01-99.0-99.0 0.7250341531595785
LSTM 0.05-99.0-99.0 0.7363769193153294
LSTM 0.1-99.0-99.0 0.5917241100906452
SEM 0.005-0.1-0.1 0.5779811795149611
SEM 0.005-0.1-1.0 0.5779811803475874
SEM 0.005-0.1-10.0 0.5779811800884102
SEM 0.005-1.0-0.1 0.5728305050278659
SEM 0.005-1.0-1.0 0.5728305056928334
SEM 0.005-1.0-10.0 0.5724554439827115
SEM 0.005-10.0-0.1 0.5597487156971898
SEM 0.005-10.0-1.0 0.559748715805421
SEM 0.005-10.0-10.0 0.5631627462769173
SEM 0.01-0.1-0.1 0.7873259106637198
SEM 0.01-0.1-1.0 0.7873259101238321
SEM 0.01-0.1-10.0 0.7870725771734496
SEM 0.01-1.0-0.1 0.7830009061601015
SEM 0.01-1.0-1.0 0.7826497370948207
SEM 0.01-1.0-10.0 0.782772108054179
SEM 0.01-10.0-0.1 0.7629264128783803
SEM 0.01-10.0-1.0 0.762978151334186
SEM 0.01-10.0-10.0 0.763144493468082
SEM 0.05-0.1-0.1 0.8220804540590416
SEM 0.05-0.1-1.0 0.8246639522624001
SEM 0.05-0.1-10.0 0.8240511515062338
SEM 0.05-1.0-0.1 0.799415692859

In [11]:
msedf = pd.DataFrame(L)
msedf.sort_values('mse')

Unnamed: 0,model,mse
38,SEM-0.1-10.0-0.1,0.130224
39,SEM-0.1-10.0-1.0,0.130224
40,SEM-0.1-10.0-10.0,0.130224
35,SEM-0.1-1.0-0.1,0.131928
37,SEM-0.1-1.0-10.0,0.131928
36,SEM-0.1-1.0-1.0,0.131928
34,SEM-0.1-0.1-10.0,0.132135
33,SEM-0.1-0.1-1.0,0.132135
32,SEM-0.1-0.1-0.1,0.132135
0,LSTM-0.005-0.0-0.0,0.492852
