# Gridsearch analysis pipeline
* prior to running this script:
    * run gridsearch
    * make (gsdf) a dataframe with the gridsearch data 
* this script
    * further processes the gridsearch data
        * e.g. sum MSE of blocked and interleaved
    * takes a slice of the gridsearch data (e.g. top 40 models by MSE)
    * runs an analysis pipeline on those models:
        * get model param and re-run simulations locally to generate full simulation dataset
        * plot model response accuracy and latent cause inference metrics
        
* the goal of this notebook is to find the "top 5 MSE models" for further analysis

In [1]:
from matplotlib import pyplot as plt
from scipy.special import softmax
from itertools import product
import numpy as np
##
from modelUtils import *
from model import *
from analysis import *
##
import time
import seaborn as sns
from glob import glob as glob
sns.set_context('talk')

%load_ext autoreload
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
# name of the gridsearch run: prefix of the summary csv that is loaded and of the saved figure files
gsname = "gs0317"

In [3]:
# load the gridsearch summary table for this run and cast every column to float
summ_path = 'data/gridsearch/%s-summdf.csv' % gsname
gsdf = pd.read_csv(summ_path).astype(float)
# single score per model: blocked + interleaved MSE, summed over both timesteps
gsdf["mse-bi"] = gsdf["b_mse1"] + gsdf["b_mse2"] + gsdf["i_mse1"] + gsdf["i_mse2"]

In [4]:
""" gsdf contains, for each model param, summary metrics computed during gridsearch. 
in this case, the metrics I used for last gridsearch were MSE on blocked and interleaved
for timesteps 1 and 2 separately. to change the metrics computed on gridsearch data, 
see make_gsdf.py file.
"""
# bare expression at the end of the cell: the notebook renders the dataframe as the cell output
gsdf

Unnamed: 0.1,Unnamed: 0,concentration,stickiness_wi,stickiness_bt,sparsity,pvar,lrate,lratep,decay_rate,skipt1,b_mse1,b_mse2,i_mse1,i_mse2,mse-bi
0,0.0,3.188511,55.476881,55.476881,0.076389,0.0,1.0,1.0,1.0,1.0,0.017370,0.008111,0.018926,0.035066,0.079473
1,0.0,5.382954,31.265678,31.265678,0.079461,0.0,1.0,1.0,1.0,1.0,0.017390,0.008007,0.157696,0.111348,0.294442
2,0.0,7.066637,86.027263,86.027263,0.031159,0.0,1.0,1.0,1.0,1.0,0.017483,0.008079,0.036447,0.039223,0.101231
3,0.0,8.963971,79.949012,79.949012,0.139976,0.0,1.0,1.0,1.0,1.0,0.017090,0.007819,0.021896,0.033605,0.080410
4,0.0,11.006649,80.848997,80.848997,0.171395,0.0,1.0,1.0,1.0,1.0,0.017116,0.007527,0.020142,0.033693,0.078478
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
229441,0.0,82.195917,81.953818,81.953818,0.255394,0.0,1.0,1.0,1.0,1.0,0.016507,0.007600,0.016687,0.027510,0.068304
229442,0.0,82.759646,89.498883,89.498883,0.442347,0.0,1.0,1.0,1.0,1.0,0.015715,0.006876,0.016384,0.027793,0.066768
229443,0.0,84.616677,95.672076,95.672076,0.263176,0.0,1.0,1.0,1.0,1.0,0.016648,0.007419,0.016326,0.027662,0.068055
229444,0.0,87.982409,97.648769,97.648769,0.381902,0.0,1.0,1.0,1.0,1.0,0.015968,0.007162,0.016574,0.027802,0.067506


### plt single condition

In [5]:
# output folder for every figure saved by save_and_close (trailing slash required: the path is built by concatenation)
figures_dir = "figures/top40/"
def save_and_close(param_str,rank,mse):
  """ 
  title the current matplotlib figure, save it as a png inside figures_dir,
  then close all open figures.

  param_str: label of the figure. the first 120 chars are used as the title,
    the full string is written in the filename
  rank: integer written in the filename (formatted with %i)
  mse: float written in the filename with 3 decimals
  """
  import os
  # savefig does not create missing folders: make sure the target exists
  os.makedirs(figures_dir,exist_ok=True)
  plt.title(param_str[:120])
  plt.savefig(figures_dir+"%s-%i-mse%.3f-%s.png"%(gsname,rank,mse,param_str))
  plt.close('all')
  
def pipeline(row,ns=50,rank=None):
  """ 
  each step of the pipeline involves 
    running a csw experiment on a param contained in row
    calculating a metric on exp_batch_data
    plotting and saving

  row: one row of the gridsearch dataframe; must contain 'mse-bi' plus whatever
    get_argsD reads from it
  ns: forwarded as first argument of run_batch_exp_curr
  rank: position of this model in the ranking, written in the figure filenames.
    when left to None the value of the global variable `rank` is used, which is
    how older notebook cells (loop variable named `rank`) provide it
  raises NameError if rank is not given and no global `rank` exists
  """
  ## resolve rank before the (slow) simulations, so a missing rank fails right away
  if rank is None:
    try:
      rank = globals()['rank']
    except KeyError:
      raise NameError(
        "pipeline() needs a rank: pass rank=... or define a global `rank`"
        ) from None
  mse = dict(row)['mse-bi'].round(3)
  ## RUN EXP
  args,paramstr = get_argsD(row)
  exp_batch_data = run_batch_exp_curr(ns,args,condL) # [curr],[seeds],{data}
  ## one figure per metric, saved with the tag as filename prefix:
  ## accuracy, LC counts violin, adjusted rand
  plot_steps = [
    ("acc",plt_acc),
    ("lc",plt_LC_violins),
    ("ar",plt_arscores),
  ]
  for tag,plot_fn in plot_steps:
    plot_fn(exp_batch_data)
    save_and_close("%s-%s"%(tag,paramstr),rank,mse)

# top 40 MSE analyses

In [6]:
# number of best models (lowest summed MSE) kept for the analysis pipeline
TOP_K = 40
## sort all models by ascending 'mse-bi' and keep the first TOP_K rows
topKgsdf = gsdf.sort_values(by='mse-bi').head(TOP_K)
topKgsdf

Unnamed: 0.1,Unnamed: 0,concentration,stickiness_wi,stickiness_bt,sparsity,pvar,lrate,lratep,decay_rate,skipt1,b_mse1,b_mse2,i_mse1,i_mse2,mse-bi
132051,0.0,1.524765,1.083961,1.083961,0.152714,0.0,1.0,1.0,1.0,1.0,0.011904,0.005975,0.015239,0.013535,0.046652
11273,0.0,1.500351,1.3754,1.3754,0.112368,0.0,1.0,1.0,1.0,1.0,0.017242,0.008,0.012483,0.013572,0.051297
182117,0.0,0.4928,1.287879,1.287879,0.039337,0.0,1.0,1.0,1.0,1.0,0.017515,0.00812,0.015717,0.01154,0.052892
162102,0.0,0.745254,1.226039,1.226039,0.005049,0.0,1.0,1.0,1.0,1.0,0.017633,0.008259,0.014934,0.012492,0.053319
42124,0.0,2.07062,1.277997,1.277997,0.004873,0.0,1.0,1.0,1.0,1.0,0.017686,0.007982,0.017241,0.010732,0.053641
23167,0.0,2.080948,1.229211,1.229211,0.035416,0.0,1.0,1.0,1.0,1.0,0.01749,0.00807,0.01568,0.012436,0.053676
142388,0.0,1.435624,1.592164,1.592164,0.005674,0.0,1.0,1.0,1.0,1.0,0.017565,0.008233,0.016835,0.013342,0.055975
66734,0.0,1.277693,1.644488,1.644488,0.114348,0.0,1.0,1.0,1.0,1.0,0.017188,0.00796,0.016675,0.014957,0.05678
20383,0.0,2.284206,1.346571,1.346571,0.166261,0.0,1.0,1.0,1.0,1.0,0.0153,0.006448,0.022764,0.012369,0.056882
186200,0.0,1.665763,1.182903,1.182903,0.031025,0.0,1.0,1.0,1.0,1.0,0.017502,0.008014,0.0164,0.015097,0.057013


In [7]:
## LOOP
# run the analysis pipeline once per selected model, in ascending 'mse-bi' order (rank 0 = lowest)
# NOTE: the loop variable has to be named `rank`: pipeline() is called without it and
# reads `rank` from the notebook's global scope when naming the saved figures
for rank,(idx,row) in enumerate(topKgsdf.iterrows()):
  print(rank) # progress indicator
  pipeline(row,ns=50)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
