In [1]:
from utils import *
from matplotlib import pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set_context('talk')

import itertools

In [2]:
# task setup: which experiment to simulate, 'nback' or 'sternberg' (see get_exp_vars)
TASK = "nback"

In [3]:
def get_exp_vars(ss_cond,task=TASK):
  """
  Return the simulation function and set size for one task condition.

  ss_cond: index into the task's set-size list (0 or 1 for both tasks).
  task: 'nback' (set sizes [2,3]) or 'sternberg' (set sizes [5,10]);
    defaults to the notebook-level TASK.
  returns: (sim_fn, set_size), where sim_fn is get_nback_idxs or
    get_sternberg_idxs (presumably provided by `from utils import *`).
  raises: ValueError for an unrecognised task (previously this fell
    through and failed with UnboundLocalError).
  """
  if task == 'nback':
    set_size_L = [2,3]
    sim_fn = get_nback_idxs
  elif task == 'sternberg':
    set_size_L = [5,10]
    sim_fn = get_sternberg_idxs
  else:
    raise ValueError("unknown task: {!r}".format(task))
  return sim_fn,set_size_L[ss_cond]

# analyses

In [4]:
""" 
compute fns
"""

def compute_roc(dist2target,dist2nontarget):
  """
  Empirical ROC-style curve from two samples of distances.

  Both samples are histogrammed over their shared [min, max] range with
  100 equal-width bins. Each histogram is normalised and cumulatively
  summed, giving each sample's empirical CDF at the 100 upper bin edges
  (the thresholds).

  returns: (roc_x, roc_y), the target CDF and the nontarget CDF, each of
    length 100.
  NOTE(review): roc_x is the target CDF. If "distance below threshold =>
  match", the hit rate is on x and the false-alarm rate on y. That is the
  transpose of the conventional ROC orientation; confirm this is intended.
  """
  # shared range so both CDFs are evaluated at the same thresholds
  lower = min(np.min(dist2target),np.min(dist2nontarget))
  upper = max(np.max(dist2target),np.max(dist2nontarget))
  # histograms
  hist_target,_ = np.histogram(dist2target,
                    range=(lower,upper),bins=100)
  hist_nontarget,_ = np.histogram(dist2nontarget,
                    range=(lower,upper),bins=100)
  # normalised cumulative counts = empirical CDFs
  roc_x = np.cumsum(hist_target/hist_target.sum())
  roc_y = np.cumsum(hist_nontarget/hist_nontarget.sum())
  return roc_x,roc_y

def compute_auc(dist2target,dist2nontarget,scale=100.0):
  """
  Area under the curve produced by compute_roc (roc_y plotted against roc_x).

  The curve is anchored at the origin and integrated with the trapezoid
  rule over roc_x. The result is multiplied by `scale`; the default of 100
  keeps it on the 0-100 axis used by plt_auc_bar.

  Identical distributions give 0.5*scale (chance). Given compute_roc's
  orientation, values below 0.5*scale mean target distances tend to be
  smaller than nontarget distances.
  """
  roc_x,roc_y = compute_roc(dist2target,dist2nontarget)
  # prepend the (0,0) point: threshold below every observed distance
  xs = np.concatenate([[0.0],roc_x])
  ys = np.concatenate([[0.0],roc_y])
  # trapezoidal integral of ys d(xs); written out to avoid the
  # np.trapz -> np.trapezoid rename across numpy versions
  area = np.sum(np.diff(xs)*(ys[1:]+ys[:-1])/2.0)
  return scale*area

def get_hist(dist,lower=0,upper=3.0,bins=250):
  """
  Histogram counts of `dist` over [lower, upper] with `bins` equal-width bins.

  Values outside [lower, upper] are not counted (np.histogram behaviour).
  The original body histogrammed an undefined global `dist2target` and
  ignored `bins`; it now uses its arguments.
  returns: integer count array of length `bins`.
  """
  hist,_ = np.histogram(dist,range=(lower,upper),bins=bins)
  return hist

In [5]:
""" 
plt funs 
"""

def plt_auc_bar(dist2target,dist2nontarget,plt_kwargs=None,ax=None):
  """
  Draw one bar whose height is compute_auc(dist2target, dist2nontarget).

  plt_kwargs: dict with a required key 'cond' (bar x position) and an
    optional key 'ax' (axes used when `ax` is None). Remaining keys are
    forwarded to ax.bar. The caller's dict is not modified.
  ax: matplotlib axes to draw on; takes precedence over plt_kwargs['ax'].
  raises: ValueError if no axes is supplied; KeyError if 'cond' is missing.
  """
  # work on a copy so the caller's kwargs dict survives reuse
  bar_kwargs = dict(plt_kwargs or {})
  kw_ax = bar_kwargs.pop('ax',None)
  if ax is None:
    ax = kw_ax
  if ax is None:
    raise ValueError("no axes given: pass ax= or plt_kwargs['ax']")
  x_plt = bar_kwargs.pop('cond')
  auc = compute_auc(dist2target,dist2nontarget)
  ax.bar(x_plt,auc,**bar_kwargs)
  # compute_auc reports on a 0-100 scale by default
  ax.set_ylim(0,105)
  ax.set_ylabel('auc')
  return None

def plt_roc(dist2target,dist2lure,plt_kwargs=None,ax=None):
  """
  Plot the compute_roc curve (lure CDF against target CDF) plus the chance diagonal.

  plt_kwargs: dict with an optional key 'ax' (axes used when `ax` is None).
    Remaining keys are forwarded to ax.plot. The caller's dict is not modified.
  ax: matplotlib axes to draw on; takes precedence over plt_kwargs['ax'].
  raises: ValueError if no axes is supplied.
  NOTE(review): carried over from the original ("confirm that this is an
  ROC, assuming that distance is the signal for detecting match"). With
  that assumption, x holds the hit rate and y the false-alarm rate, i.e.
  the conventional ROC transposed.
  """
  rocx,rocy = compute_roc(dist2target,dist2lure)
  # work on a copy so the caller's kwargs dict survives reuse
  line_kwargs = dict(plt_kwargs or {})
  kw_ax = line_kwargs.pop('ax',None)
  if ax is None:
    ax = kw_ax
  if ax is None:
    raise ValueError("no axes given: pass ax= or plt_kwargs['ax']")
  ax.plot(rocx,rocy,**line_kwargs)
  ax.plot([0,1],[0,1],c='k',ls='--')
  ax.set_xlabel('target cdf')
  ax.set_ylabel('lure cdf')
  return None

In [6]:
"""
main wrapper
"""
def grid_search_plt(cmean,cvar,cdim=20,nitr_gs=100):
  """
  Simulate both set-size conditions for one set of context parameters and
  save a 2x2 summary figure.

  Panels: [0] and [1] are target-vs-lure distance histograms for ss_cond 0
  and 1; [2] holds both ROC curves; [3] holds one AUC bar per condition.
  cmean, cvar, cdim: forwarded to sim_context_dist (from utils) as
    cmean=, cvar=, cdim=.
  nitr_gs: forwarded to sim_context_dist as nitr=.
  Side effect: writes figures/<TASK>/full_fig-K<cdim>M<cmean>N<cvar>.png,
  creating the directory if needed.
  """
  import os

  f,ax = plt.subplots(2,2,figsize=(20,10));ax=ax.reshape(-1)
  cond_colors = ['g','r']

  for ss_cond in [0,1]:
    # run context simulation
    sim_fn,set_size = get_exp_vars(ss_cond)
    dist2target,dist2lure = sim_context_dist(
      set_size,sim_fn,
      cvar=cvar,cmean=cmean,cdim=cdim,
      nitr=nitr_gs
    )
    # plt hist; alpha keeps the lure histogram from hiding the target one
    ax[ss_cond].hist(dist2target,bins=100,range=(0,3),alpha=0.6,label='target')
    ax[ss_cond].hist(dist2lure,bins=100,range=(0,3),alpha=0.6,label='lure')
    ax[ss_cond].legend()
    # plt roc
    plt_roc_kwargs = {'color':cond_colors[ss_cond],'ax':ax[2]}
    plt_roc(dist2target,dist2lure,plt_roc_kwargs)
    # plt auc
    plt_auc_kwargs = {'cond':ss_cond,'ax':ax[3],'color':cond_colors[ss_cond]}
    plt_auc_bar(dist2target,dist2lure,plt_auc_kwargs)

  ## saving
  # build the title from this call's parameters, not from globals K, M, N
  # leaked by the grid-search loop cell (breaks on a fresh kernel)
  fig_title = "K{}M{}N{}".format(cdim,cmean,cvar)
  f.suptitle(fig_title)
  out_dir = os.path.join('figures',TASK)
  os.makedirs(out_dir,exist_ok=True)
  f.savefig(os.path.join(out_dir,'full_fig-%s.png'%fig_title))
  # close only this figure; avoids memory growth across the grid search
  plt.close(f)
  return None



# gridsearch plot

In [7]:
## gridsearch over context parameters
## K -> cdim, M -> cmean, N -> cvar (kept as notebook globals)
ML = [0.2,0.3,0.4]
NL = [0.05,0.1,0.2,0.3]
KL = [5,10,20,25]
# nested loops visit the same order as itertools.product(KL,ML,NL)
for K in KL:
  for M in ML:
    for N in NL:
      print(M,N,K,'\n')
      grid_search_plt(cmean=M,cvar=N,cdim=K)
  

0.2 0.05 5 

0.2 0.1 5 

0.2 0.2 5 

0.2 0.3 5 

0.3 0.05 5 

0.3 0.1 5 

0.3 0.2 5 

0.3 0.3 5 

0.4 0.05 5 

0.4 0.1 5 

0.4 0.2 5 

0.4 0.3 5 

0.2 0.05 10 

0.2 0.1 10 

0.2 0.2 10 

0.2 0.3 10 

0.3 0.05 10 

0.3 0.1 10 

0.3 0.2 10 

0.3 0.3 10 

0.4 0.05 10 

0.4 0.1 10 

0.4 0.2 10 

0.4 0.3 10 

0.2 0.05 20 

0.2 0.1 20 

0.2 0.2 20 

0.2 0.3 20 

0.3 0.05 20 

0.3 0.1 20 

0.3 0.2 20 

0.3 0.3 20 

0.4 0.05 20 

0.4 0.1 20 

0.4 0.2 20 

0.4 0.3 20 

0.2 0.05 25 

0.2 0.1 25 

0.2 0.2 25 

0.2 0.3 25 

0.3 0.05 25 

0.3 0.1 25 

0.3 0.2 25 

0.3 0.3 25 

0.4 0.05 25 

0.4 0.1 25 

0.4 0.2 25 

0.4 0.3 25 

