In [1]:
from utils import *
from matplotlib import pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set_context('talk')
from scipy.spatial import distance 
import itertools

In [2]:
def sim_nback_context(set_size,delta=1,nitr=1,metric='cosine',cvar=.1,cmean=.1,cdim=20,verb=False,exp_size=5):
  """Sample context distances for n-back trials (similarity = EM).

  Each iteration draws a fresh context trajectory of `exp_size` steps with
  `fast_n_sphere` and measures, with `scipy.spatial.distance.cdist` using
  `metric`, the distance from the context at the current (last) step to:
    - the target step, `set_size` steps back, and
    - the nontarget step, located at `target + delta`.

  Params:
    set_size: n-back lag of the target. It must leave the target strictly
      before the current step: 1 <= set_size <= exp_size - 1.
    delta: offset of the nontarget from the target; controls the nontarget
      distance distribution. target + delta must also be a step strictly
      before the current one.
    nitr: number of independent trajectories (samples) to draw.
    metric: metric name forwarded to `distance.cdist`.
    cvar, cmean, cdim: forwarded to `fast_n_sphere` as var, mean, dim.
    verb: if True, print the time indices used.
    exp_size: number of context steps per trajectory (default 5, the value
      previously hard-coded).

  Returns:
    (dist2target, dist2nontarget): arrays with one distance per iteration.

  Raises:
    ValueError: if the target or nontarget index is outside [0, current).
  """
  current_t = exp_size - 1
  target_t = current_t - set_size
  nontarget_t = target_t + delta
  if verb:
    print('ss',set_size,'current_t',current_t,
          'target',target_t,'nontarget',nontarget_t)
  # Guard the indices. A negative index would silently wrap to the END of
  # the trajectory (e.g. set_size=5 -> target_t=-1 -> the current step
  # itself, distance 0). An index equal to current_t would compare the
  # current context with itself.
  for role, t in (('target', target_t), ('nontarget', nontarget_t)):
    if not 0 <= t < current_t:
      raise ValueError(
        '%s index %i out of range [0, %i) for set_size=%i, delta=%i, exp_size=%i'
        % (role, t, current_t, set_size, delta, exp_size))

  dist2targetL = []
  dist2nontargetL = []
  for _ in range(nitr):
    # NOTE(review): assumes fast_n_sphere(...)[0] is an (exp_size, cdim)
    # array with one context vector per step -- confirm against utils.
    C = fast_n_sphere(n_steps=exp_size, dim=cdim, var=cvar, mean=cmean)[0]
    # cdist of one row against one row -> shape (1, 1); [0] keeps shape (1,)
    dist2target_ = distance.cdist([C[current_t,:]],[C[target_t,:]],metric=metric)[0]
    dist2nontarget_ = distance.cdist([C[current_t,:]],[C[nontarget_t,:]],metric=metric)[0]
    dist2targetL.append(dist2target_)
    dist2nontargetL.append(dist2nontarget_)

  dist2target = np.concatenate(dist2targetL)
  dist2nontarget = np.concatenate(dist2nontargetL)
  return dist2target,dist2nontarget


In [3]:
def score_condition(dist_target,dist_nontarget,bins=100):
  """Histogram overlap between target and nontarget distance samples.

  Both samples are binned into `bins` equal-width bins over their shared
  [min, max] range. The overlap is the sum, over bins, of the smaller of
  the two counts, divided by the number of binned target samples. A score
  of 0 means the histograms are disjoint. A score of 1 means they are
  identical.

  Params:
    dist_target, dist_nontarget: numpy arrays of distances with equal length.
    bins: number of histogram bins (default 100, the value previously
      hard-coded).

  Returns:
    Overlap as a Python float in [0, 1].
  """
  assert len(dist_target) == len(dist_nontarget)
  lower = np.min([dist_target.min(),dist_nontarget.min()])
  upper = np.max([dist_target.max(),dist_nontarget.max()])
  hist_target = np.histogram(dist_target,range=(lower,upper),bins=bins)[0]
  hist_nontarget = np.histogram(dist_nontarget,range=(lower,upper),bins=bins)[0]
  # histogram intersection
  overlap = np.minimum(hist_target,hist_nontarget).sum()
  # The range spans both samples, so every target lands in a bin and the
  # denominator is positive (empty input already fails at .min()). A single
  # division therefore covers the zero-overlap case too, always as a float.
  return float(overlap / hist_target.sum())

# dist_target,dist_nontarget = sim_nback_context(
#   set_size=3,nitr=1000,verb=True,cvar=.05,cmean=.3,cdim=5)
# score_condition(dist_target,dist_nontarget)

In [4]:
def plt_roc(ax,dist_target,dist_nontarget,label):
  """Draw an ROC-style curve for two distance samples onto `ax`.

  Both samples are binned into 100 equal-width bins over their shared
  [min, max] range. The curve plots the cumulative fraction of target
  distances (x) against the cumulative fraction of nontarget distances (y)
  as the threshold moves up through the bins. Every call also draws a
  dashed black diagonal from (0, 0) to (1, 1).

  Params:
    ax: axes to draw on; only `ax.plot` is used.
    dist_target, dist_nontarget: numpy arrays of distances.
    label: legend label for the curve.

  Returns:
    None.
  """
  shared_range = (np.minimum(dist_target.min(),dist_nontarget.min()),
                  np.maximum(dist_target.max(),dist_nontarget.max()))
  cdfs = []
  for sample in (dist_target,dist_nontarget):
    counts = np.histogram(sample,range=shared_range,bins=100)[0]
    cdfs.append(np.cumsum(counts/counts.sum()))
  ax.plot(cdfs[0],cdfs[1],label=label)
  # chance diagonal
  ax.plot([0,1],[0,1],c='k',ls='--')

In [7]:
def runexp_and_plt(M,N,K,ssL=(2,3),nitr=20000,sim_fn=sim_nback_context):
  """Run the context simulation for each set size and save one figure.

  For each set size `ss` in `ssL`, `sim_fn` is called as
  sim_fn(set_size=ss, nitr=nitr, cvar=N, cmean=M, cdim=K) and must return
  (dist_target, dist_nontarget).

  The figure contains:
    - one panel per set size with overlaid target/nontarget histograms,
      titled with the overlap from `score_condition`, and
    - a final panel with the `plt_roc` curves for all set sizes.

  The figure is saved to figures/nback/K{K}M{M}N{N}.png (the directory is
  created if missing), and then that figure is closed.

  Params:
    M: forwarded to sim_fn as cmean.
    N: forwarded to sim_fn as cvar.
    K: forwarded to sim_fn as cdim.
    ssL: set sizes to simulate; any number of them is supported.
    nitr: samples per set size.
    sim_fn: simulation function (see the call above).

  Returns:
    None.
  """
  import os
  ncols = len(ssL)+1
  # squeeze=False keeps a 2-D array of axes even for a single panel
  fig,axarr = plt.subplots(1,ncols,figsize=(20*ncols/3,6),squeeze=False)
  hist_axes = axarr[0,:-1]
  roc_ax = axarr[0,-1]
  for ax,ss in zip(hist_axes,ssL):
    # run simulations
    dist_target,dist_nontarget = sim_fn(
      set_size=ss,nitr=nitr,
      cvar=N,cmean=M,cdim=K
    )
    # overlaid histograms; alpha keeps the first visible under the second
    ax.hist(dist_target,bins=100,alpha=.5,label='target')
    ax.hist(dist_nontarget,bins=100,alpha=.5,label='nontarget')
    ax.legend()
    plt_roc(roc_ax,dist_target,dist_nontarget,"set-size=%i"%ss)
    score = score_condition(dist_target,dist_nontarget)
    # a single formatted string (a bare comma-separated expression here
    # would build a tuple and render as its repr)
    ax.set_title('ss=%i overlap %.2f'%(ss,score))
  roc_ax.legend()
  fig_title = "K{}M{}N{}".format(K,M,N)
  out_dir = os.path.join('figures','nback')
  os.makedirs(out_dir,exist_ok=True)
  fig.savefig(os.path.join(out_dir,'%s.png'%fig_title))
  # close only this figure, not unrelated ones the notebook may hold
  plt.close(fig)
  return None

In [None]:
# Parameter sweep: one saved figure per (K, M, N) combination.
ML = [0.2,0.3,0.4]        # M values, forwarded by runexp_and_plt as cmean
NL = [0.05,0.1,0.2,0.3]   # N values, forwarded as cvar
KL = [5,10,20,25]         # K values, forwarded as cdim
# K outermost, N innermost -- same order as itertools.product(KL,ML,NL)
for K in KL:
  for M in ML:
    for N in NL:
      print(K,M,N)
      runexp_and_plt(M=M,N=N,K=K,nitr=10000)
      print()

5 0.2 0.05

5 0.2 0.1


In [None]:


# M,N = 0.2,0.2
# ax = plt.gca()
# for ss in [2,3,4]:
#   dist_target,dist_nontarget = sim_nback_context(ss,nitr=1000,cvar=N,cmean=M)
#   plt_roc(ax,dist_target,dist_nontarget,"set-size=%i"%ss)
# plt.legend()