In [1]:
import itertools
import os

import numpy as np
from utils import *  # project helpers (presumably provides fast_n_sphere); star import kept for compatibility
from matplotlib import pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set_context('talk')
from scipy.spatial import distance 

In [2]:
def sim_sternberg_context(set_size,nitr=1,metric='cosine',cvar=.1,cmean=.1,cdim=20):  
  """Simulate context drift over two back-to-back Sternberg trials.

  Each trial consists of `set_size` study items followed by one probe, so a
  simulated experiment spans ``2 * (set_size + 1)`` context states, sampled
  with `fast_n_sphere` (imported from utils). The context at the very last
  step (the probe of trial 2, i.e. the "current" context) is compared with:
    * the study items of trial 1 -- the "old" list (trial 1's probe excluded);
    * the study items of trial 2 -- the "new" list (the current probe excluded).

  Args:
    set_size: number of study items per trial.
    nitr: number of independent context samples (simulated experiments).
    metric: metric name forwarded to `scipy.spatial.distance.cdist`.
    cvar, cmean, cdim: forwarded to `fast_n_sphere` as `var`, `mean`, `dim`.

  Returns:
    (dist2old, dist2new): two 1-D arrays, each of length ``nitr * set_size``,
    holding the distances from the current context to every old / new study
    item, with the results of all iterations concatenated in order.
  """
  steps_per_trial = set_size + 1  # study items + probe
  n_steps = 2 * steps_per_trial
  probe_step = n_steps - 1  # probe of trial 2 == current context
  # study items of trial 1 (trial 1's probe sits at steps_per_trial - 1 and is skipped)
  old_items = np.arange(0, steps_per_trial - 1)
  # study items of trial 2 (stops before the current probe)
  new_items = np.arange(steps_per_trial, probe_step)

  old_parts, new_parts = [], []
  for _ in range(nitr):
    # NOTE(review): assumes fast_n_sphere(...)[0] is an (n_steps, cdim) array -- confirm in utils
    ctx = fast_n_sphere(n_steps=n_steps, dim=cdim, var=cvar, mean=cmean)[0]
    probe = [ctx[probe_step, :]]
    new_parts.append(distance.cdist(probe, ctx[new_items, :], metric=metric)[0])
    old_parts.append(distance.cdist(probe, ctx[old_items, :], metric=metric)[0])

  return np.concatenate(old_parts), np.concatenate(new_parts)

In [3]:
def plt_roc(ax,dist_target,dist_nontarget,label,idx=0):
  """Plot an empirical ROC-style curve from two samples of distances.

  Both samples are binned into the same 100 equal-width bins spanning their
  pooled [min, max]. The curve plots, at each bin's upper edge d,
      x = cumulative fraction of dist_target    up to d
      y = cumulative fraction of dist_nontarget up to d
  i.e. it assumes a *small* distance is the signal for detecting "match".
  The curve is anchored at (0, 0) (threshold below every sample) and ends at
  (1, 1). A dashed chance diagonal is also drawn on `ax`.
  NOTE(review): original author flagged "confirm that this is an ROC" --
  which sample ends up on which axis depends on how the caller unpacks them.

  Args:
    ax: matplotlib Axes (only its `plot` method is used).
    dist_target: 1-D array-like of distances, plotted on the x axis.
    dist_nontarget: 1-D array-like of distances, plotted on the y axis.
    label: legend label for the curve.
    idx: picks the curve colour from ['r', 'g']; indices outside that list
      fall back to the axes' default colour cycle instead of raising.

  Returns:
    None.

  Raises:
    ValueError: if either distance sample is empty.
  """
  colors = ['r','g']
  dist_target = np.asarray(dist_target)
  dist_nontarget = np.asarray(dist_nontarget)
  if dist_target.size == 0 or dist_nontarget.size == 0:
    raise ValueError("plt_roc needs non-empty dist_target and dist_nontarget")
  # shared bin range so both cumulative curves are evaluated at the same thresholds
  bin_range = (
    min(dist_target.min(), dist_nontarget.min()),
    max(dist_target.max(), dist_nontarget.max()),
  )
  hist_target,_ = np.histogram(dist_target,range=bin_range,bins=100)
  hist_nontarget,_ = np.histogram(dist_nontarget,range=bin_range,bins=100)
  # prepend 0 so the curve starts at the origin; dividing the cumsum by the
  # total makes the last point exactly 1
  cum_target = np.concatenate([[0.0], np.cumsum(hist_target) / hist_target.sum()])
  cum_nontarget = np.concatenate([[0.0], np.cumsum(hist_nontarget) / hist_nontarget.sum()])
  line_kw = {'label': label}
  if -len(colors) <= idx < len(colors):
    line_kw['color'] = colors[idx]
  ax.plot(cum_target, cum_nontarget, **line_kw)
  # chance diagonal
  ax.plot([0,1],[0,1],c='k',ls='--')
  return None

In [4]:
def runexp_and_plt(M,N,K,ssL=(4,9),nitr=20000,sim_fn=sim_sternberg_context):
  """Run the 'dprime experiment' for one context parametrization and save a figure.

  For every set size in `ssL`, one panel shows histograms of the two distance
  samples returned by `sim_fn`; a final, shared panel overlays one curve per
  set size drawn by `plt_roc`. The figure is written to
  ``figures/stern/K{K}M{M}N{N}.png`` (the directory is created if missing)
  and then closed, so nothing is displayed inline.

  Args:
    M: context mean, passed to `sim_fn` as `cmean`.
    N: context variance, passed to `sim_fn` as `cvar`.
    K: context dimensionality, passed to `sim_fn` as `cdim`.
    ssL: iterable of set sizes to simulate (any length).
    nitr: number of simulated experiments per set size.
    sim_fn: simulator called as
      ``sim_fn(set_size=, nitr=, cvar=, cmean=, cdim=)`` returning two
      distance arrays.

  Returns:
    None.
  """
  set_sizes = list(ssL)
  # one histogram panel per set size, plus one shared ROC panel on the right
  n_panels = len(set_sizes) + 1
  fig,axarr = plt.subplots(1,n_panels,figsize=(20*n_panels/3,6),squeeze=False)
  axarr = axarr[0]
  roc_ax = axarr[-1]
  for idx,ss in enumerate(set_sizes):
    ax = axarr[idx]
    ax.set_title("set-size=%i"%ss)
    # run simulations
    # NB flipped call fn for stern (original note): the default sim_fn returns
    # (dist2old, dist2new), so 'target' here is the old-list sample.
    # NOTE(review): confirm this labelling is intended.
    dist_target,dist_nontarget = sim_fn(
      set_size=ss,nitr=nitr,
      cvar=N,cmean=M,cdim=K
    )
    # distance histograms for this set size
    ax.hist(dist_target,bins=100,label='target')
    ax.hist(dist_nontarget,bins=100,label='nontarget')
    ax.set_xlabel('distance')
    ax.legend()
    # overlay this set size's curve on the shared panel
    plt_roc(roc_ax,dist_target,dist_nontarget,"set-size=%i"%ss,idx)
  if set_sizes:
    roc_ax.legend()
  fig_title = "K{}M{}N{}".format(K,M,N)
  out_path = 'figures/stern/%s.png'%fig_title
  # savefig does not create missing directories
  os.makedirs(os.path.dirname(out_path),exist_ok=True)
  fig.savefig(out_path)
  # close only this figure so the parameter sweep does not accumulate figures
  plt.close(fig)
  return None

In [5]:
# Parameter grid: context mean (M -> cmean), context variance (N -> cvar),
# context dimensionality (K -> cdim).
ML = [0.2,0.3,0.4]
NL = [0.2,0.3,0.4]
KL = [5,10,20,25] 
# Sweep K slowest and N fastest (same order as itertools.product(KL, ML, NL));
# one saved figure per (K, M, N) combination.
for K in KL:
  for M in ML:
    for N in NL:
      print(M, N, K)
      runexp_and_plt(M, N, K, nitr=10000)
      print()

0.2 0.2 5

0.2 0.3 5

0.2 0.4 5

0.3 0.2 5

0.3 0.3 5

0.3 0.4 5

0.4 0.2 5

0.4 0.3 5

0.4 0.4 5

0.2 0.2 10

0.2 0.3 10

0.2 0.4 10

0.3 0.2 10

0.3 0.3 10

0.3 0.4 10

0.4 0.2 10

0.4 0.3 10

0.4 0.4 10

0.2 0.2 20

0.2 0.3 20

0.2 0.4 20

0.3 0.2 20

0.3 0.3 20

0.3 0.4 20

0.4 0.2 20

0.4 0.3 20

0.4 0.4 20

0.2 0.2 25

0.2 0.3 25

0.2 0.4 25

0.3 0.2 25

0.3 0.3 25

0.3 0.4 25

0.4 0.2 25

0.4 0.3 25

0.4 0.4 25

