# Grid Search

Use KL divergence to find the best-fit parameters for each model.

## Important Note on the Fitting
It's worth noting that this fitting is stochastic -- we don't have a closed-form likelihood on which to evaluate the human behavior, so instead we match the distribution of human data to the distribution of model data.  As a consequence, while the fitting procedure as a whole is probably fine, I don't have a ton of confidence in the parameter values themselves.  By this I mean that if we were to run the grid search again with new samples, I think we would find the same patterns of behavior and the same relative ordering of the models.  (I've actually done this a couple of times with slightly different constraints, which is why I'm confident in this.)  That being said, I don't know if we would find the same parameters, and I don't know if new samples drawn with these parameters would produce consistent behavior. This problem is solvable (just run 10x more batches for the grid search), but it's not particularly important for our purposes, so I've ignored it.

*The key takeaway from this is that if we want to make changes to the model for some reason, we would have to re-do the grid-search!*

What you don't want is to be in a situation where you're looking at effects that have evaporated for some reason under new simulations.  The fix to that is to re-run the grid search (maybe with a larger sample size) and check again.
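
One rough way to see how much the loss moves between samples is to resample the simulated batches with replacement and recompute the loss each time. The sketch below is illustrative only: it runs on synthetic 0/1 accuracy data rather than the real simulations, and `bootstrap_loss`, the batch count, and the standard-deviation floor are all assumptions, not part of the original analysis.

```python
import numpy as np

def kl_gauss(mu_1, mu_2, sd_1, sd_2):
    # KL(p || q) for p = N(mu_1, sd_1) and q = N(mu_2, sd_2)
    return (np.log(sd_2) - np.log(sd_1)
            + (sd_1 ** 2 + (mu_1 - mu_2) ** 2) / (2 * sd_2 ** 2) - 0.5)

def bootstrap_loss(human_mu, human_sd, model_acc, n_boot=200, seed=0):
    """Resample batches (rows of model_acc) and recompute the summed KL.

    model_acc has shape (n_batches, n_trials). Hypothetical helper.
    """
    rng = np.random.default_rng(seed)
    n_batches = model_acc.shape[0]
    losses = np.empty(n_boot)
    for b in range(n_boot):
        rows = rng.integers(0, n_batches, size=n_batches)
        sample = model_acc[rows]
        mu = sample.mean(axis=0)
        # floor the sd so a trial with no variance does not give log(0)
        sd = np.maximum(sample.std(axis=0, ddof=1), 1e-3)
        losses[b] = np.sum(kl_gauss(human_mu, mu, human_sd, sd))
    return losses

# Synthetic stand-in for human summary statistics and model batches
rng = np.random.default_rng(1)
n_trials = 50
human_mu = np.linspace(0.5, 0.8, n_trials)
human_sd = np.sqrt(human_mu * (1 - human_mu))
model_acc = (rng.random((40, n_trials)) < human_mu).astype(float)

losses = bootstrap_loss(human_mu, human_sd, model_acc)
print('loss mean = {:.2f}, sd = {:.2f}'.format(losses.mean(), losses.std()))
```

If the spread of `losses` is comparable to the gap between two candidate parameter settings, the ranking of those settings is probably not reliable without more batches.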

## Note on the data
The data for the simulations is not included in the GitHub repository.  The uncompressed original is close to 10 GB and is several thousand files.  Email me for a copy if you need it.

## Load Libraries

In [1]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import pandas as pd

from glob import glob
import json

import re
import statsmodels.api as sm
import statsmodels.formula.api as smf

import os

sns.set_context('talk')

## Helper functions


### KL divergence

Here, we take the sample distribution of the model and the empirical distribution of the human subjects, and assume that their empirical means and variances are the mean and variance of a normal approximation to the population distribution.  Then we use the KL divergence to fit the model.

Let $p(x) = N(x; \mu_1, \sigma_1)$ be the human data and $q(x) = N(x; \mu_2, \sigma_2)$ be the model we are using to approximate the human data.  We are minimizing the KL divergence

$$KL(p||q) = -\int p(x) \ln \frac{q(x)}{p(x)} \, dx$$

which, for Gaussians, is equal to

$$KL(p||q) = \ln \frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma^2_2} - \frac{1}{2}$$




In [2]:
def kl_pq(mu_1, mu_2, sd_1, sd_2):
    # p(x) = N(x; mu_1, sd_1)
    # q(x) = N(x; mu_2, sd_2)
    
    return np.log(sd_2) - np.log(sd_1) + \
        (sd_1 ** 2 + (mu_1 - mu_2)**2) / (2 * (sd_2**2) ) - 0.5
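
As a sanity check on the closed form, the integral can be approximated numerically and compared with `kl_pq`. This is a minimal sketch: `kl_numeric`, the grid width, and the example means and standard deviations are assumptions chosen for illustration.

```python
import numpy as np

def kl_pq(mu_1, mu_2, sd_1, sd_2):
    # closed-form KL(p || q), p = N(mu_1, sd_1), q = N(mu_2, sd_2)
    return np.log(sd_2) - np.log(sd_1) + \
        (sd_1 ** 2 + (mu_1 - mu_2) ** 2) / (2 * (sd_2 ** 2)) - 0.5

def log_normal_pdf(x, mu, sd):
    # log density of N(mu, sd), used to avoid taking log of an underflowed pdf
    return -0.5 * ((x - mu) / sd) ** 2 - np.log(sd * np.sqrt(2 * np.pi))

def kl_numeric(mu_1, mu_2, sd_1, sd_2, n=200001, width=12.0):
    # Riemann sum of p(x) * ln(p(x) / q(x)) on a wide grid centred on p
    x = np.linspace(mu_1 - width * sd_1, mu_1 + width * sd_1, n)
    log_p = log_normal_pdf(x, mu_1, sd_1)
    log_q = log_normal_pdf(x, mu_2, sd_2)
    dx = x[1] - x[0]
    return np.sum(np.exp(log_p) * (log_p - log_q)) * dx

closed = kl_pq(0.72, 0.60, 0.45, 0.49)
numeric = kl_numeric(0.72, 0.60, 0.45, 0.49)
print(closed, numeric)
```

The two values should agree to several decimal places, and `kl_pq` returns 0 when the two Gaussians are identical.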

## Load data -- Trial by Trial

In [3]:
# human behavior
trial_by_trial = pd.read_csv('./data/BehDataTrialXTrial.csv', index_col=1)
trial_by_trial

t,Unnamed: 0,Blocked Mean,Interleaved Mean,Early Mean,Middle Mean,Late Mean,Instructed Mean,Blocked StdDev,Interleaved StdDev,Early StdDev,Middle StdDev,Late StdDev,Instructed StdDev
0,0,0.450360,0.470288,0.577810,0.450856,0.497531,0.437500,0.501249,0.487302,0.497887,0.502753,0.503535,0.500000
1,1,0.531633,0.532713,0.622421,0.511164,0.598901,0.579710,0.496320,0.494410,0.488719,0.505432,0.495362,0.497222
2,2,0.720819,0.562386,0.641971,0.640693,0.535714,0.555556,0.448816,0.492848,0.480277,0.484374,0.504077,0.500391
3,3,0.735830,0.558931,0.673140,0.534179,0.510692,0.430769,0.440724,0.500549,0.459289,0.503967,0.499688,0.499038
4,4,0.818325,0.628748,0.794226,0.542948,0.530864,0.671053,0.388218,0.473767,0.408293,0.501980,0.504043,0.472953
...,...,...,...,...,...,...,...,...,...,...,...,...,...
195,195,0.920087,0.629630,0.970402,0.561983,0.712500,0.688889,0.274456,0.473969,0.169234,0.503595,0.461168,0.468179
196,196,0.943584,0.624803,0.907396,0.806818,0.687831,0.729167,0.231785,0.487467,0.294442,0.392854,0.468723,0.449093
197,197,0.950606,0.701003,0.904768,0.740498,0.823214,0.750000,0.219428,0.456047,0.294962,0.445260,0.389212,0.436931
198,198,0.910809,0.666543,0.909506,0.783144,0.661184,0.512821,0.288121,0.477685,0.291931,0.401216,0.477308,0.506370


In [4]:
#
path = './data/'


# Load files one at a time
tXt_Loss = []
sim = 0
files = glob(path + 'trial*csv')
print(files)

def get_error(file):

    model_data = pd.read_csv(file)
    model_grouped = model_data.groupby(['Condition', 't'])

    
    MSE = 0
    KL = 0

    for cond in ['Blocked', 'Interleaved', 'Instructed']:
        mu_model = model_grouped.mean().loc[cond]['Accuracy'].values
        sd_model = model_grouped.std().loc[cond]['Accuracy'].values

        mu_beh = trial_by_trial.loc[:, '{} Mean'.format(cond)].values
        sd_beh = trial_by_trial.loc[:, '{} StdDev'.format(cond)].values

        MSE += np.mean((mu_beh - mu_model) ** 2)
        KL  += np.sum(kl_pq(mu_beh, mu_model, sd_beh, sd_model))

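    # NOTE: the keys are swapped relative to the variables: 'KL' holds the
    # summed MSE and 'MSE' holds the summed KL divergence. The outputs shown
    # in this notebook were produced with this mapping.
    return {
            'KL': MSE,
            'MSE': KL,
        }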

def get_model(file):
    if 'nosplit' in file:
        if 'MLP' in file:
            return 'MLP'
        else:
            return 'LSTM'
    else:
        if 'MLP' in file:
            return 'SEM-MLP'
        else:
            return 'SEM'
        
# helper function
def check_number(_str):
    if _str in '0 1 2 3 4 5 6 7 8 9'.split():
        return True
    return False


def pull_from_start(file_string, start_index, end_crit = '_'):
    # pulls all of the characters from start_index up to (not including)
    # the stop character, "_" by default
    output_string = ''
    idx = start_index
    while file_string[idx] != end_crit:
        output_string += file_string[idx]
        idx += 1
    return output_string

for file in tqdm(files):
    
    idx = file.find('nhidden') + 7
    n_hidden = pull_from_start(file, idx)
    if n_hidden == 'None':
        n_hidden = 10
    else:
        n_hidden = int(pull_from_start(file, idx))
    
    # learning rate -- there is no reliable end point for this field
    # other than the "_" character
    idx = file.find('_lr') + 3
    lr = float(pull_from_start(file, idx))

    # n_epochs
    idx = file[idx:].find('_n') + idx + 2
    n = int(pull_from_start(file, idx))

    # dropout
    idx = file[idx:].find('_d') + idx + 2
    d = float(pull_from_start(file, idx))

    # loglmda
    idx = file.find('_loglmda') + 9
    loglamda =float(pull_from_start(file, idx))

    # logalpha
    idx = file.find('_logalfa') + 9
    idx2 = file.find('_loglmda')
    logalpha = float(file[idx:idx2])       
    
    # epsilon
    idx = file.find('_e') + 2
    epsilon = float(pull_from_start(file, idx))

    _tXt = get_error(file)
    _tXt['lr'] = lr
    _tXt['n_epochs'] = n
    _tXt['dropout'] = d
    _tXt['epsilon'] = epsilon
    _tXt['loglamda'] = loglamda
    _tXt['logalpha'] = logalpha
    _tXt['sim'] = sim
    _tXt['n_hidden'] = n_hidden
    _tXt['model'] = get_model(file)

    tXt_Loss.append(_tXt)
    sim += 1
    
#     break
tXt_Loss = pd.DataFrame(tXt_Loss)
tXt_Loss.index = range(len(tXt_Loss))



['./data/trial_X_trial_MLP_nhiddenNone_e1e-05_lr0.005_n8_d0.0_logalfa_-208.0_loglmda_208.0__nosplit_online.csv', './data/trial_X_trial_VanillaLSTM_nhiddenNone_e1e-05_lr0.05_n2_d0.0_logalfa_16.0_loglmda_8.0__online.csv', './data/trial_X_trial_VanillaLSTM_nhiddenNone_e1e-05_lr0.005_n2_d0.0_logalfa_-208.0_loglmda_208.0__nosplit_online.csv']


100%|██████████| 3/3 [00:00<00:00, 12.09it/s]
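
The index arithmetic in the loop above can also be expressed as a single regular expression with named groups. This is a sketch under the assumption that every filename follows the same field order as the three printed above; `FIELDS` and `parse_params` are hypothetical names, and the default of 10 hidden units for `None` mirrors the loop.

```python
import re

# One named group per hyperparameter, in the order they appear in the filename
FIELDS = re.compile(
    r'nhidden(?P<n_hidden>[^_]+)'
    r'_e(?P<epsilon>[^_]+)'
    r'_lr(?P<lr>[^_]+)'
    r'_n(?P<n_epochs>\d+)'
    r'_d(?P<dropout>[^_]+)'
    r'_logalfa_(?P<logalpha>[^_]+)'
    r'_loglmda_(?P<loglamda>[^_]+)_'
)

def parse_params(filename):
    """Return the hyperparameters encoded in a simulation filename."""
    match = FIELDS.search(filename)
    if match is None:
        raise ValueError('unrecognized filename: {}'.format(filename))
    raw = match.groupdict()
    return {
        'n_hidden': 10 if raw['n_hidden'] == 'None' else int(raw['n_hidden']),
        'epsilon': float(raw['epsilon']),
        'lr': float(raw['lr']),
        'n_epochs': int(raw['n_epochs']),
        'dropout': float(raw['dropout']),
        'logalpha': float(raw['logalpha']),
        'loglamda': float(raw['loglamda']),
    }

mlp = parse_params('./data/trial_X_trial_MLP_nhiddenNone_e1e-05_lr0.005_n8'
                   '_d0.0_logalfa_-208.0_loglmda_208.0__nosplit_online.csv')
sem = parse_params('./data/trial_X_trial_VanillaLSTM_nhiddenNone_e1e-05_lr0.05_n2'
                   '_d0.0_logalfa_16.0_loglmda_8.0__online.csv')
print(mlp)
print(sem)
```

A filename that does not match raises a `ValueError` instead of running past the end of the string, which is what the `while` loop in `pull_from_start` would do if the stop character were missing.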


In [5]:
min_vals = tXt_Loss.groupby('model')['KL'].min()
min_vals

model
LSTM    0.052784
MLP     0.099126
SEM     0.027546
Name: KL, dtype: float64

In [6]:
#LSTM
model = 'LSTM'
tXt_Loss.loc[(tXt_Loss.model == model) & (tXt_Loss.KL == min_vals[model])]

Unnamed: 0,KL,MSE,lr,n_epochs,dropout,epsilon,loglamda,logalpha,sim,n_hidden,model
2,0.052784,1282.533262,0.005,2,0.0,1e-05,208.0,-208.0,2,10,LSTM


In [7]:
#SEM
model = 'SEM'
tXt_Loss.loc[(tXt_Loss.model == model) & (tXt_Loss.KL == min_vals[model])]

Unnamed: 0,KL,MSE,lr,n_epochs,dropout,epsilon,loglamda,logalpha,sim,n_hidden,model
1,0.027546,366.362387,0.05,2,0.0,1e-05,8.0,16.0,1,10,SEM


In [8]:
#MLP
model = 'MLP'
tXt_Loss.loc[(tXt_Loss.model == model) & (tXt_Loss.KL == min_vals[model])]

Unnamed: 0,KL,MSE,lr,n_epochs,dropout,epsilon,loglamda,logalpha,sim,n_hidden,model
0,0.099126,2178.732861,0.005,8,0.0,1e-05,208.0,-208.0,0,10,MLP
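
The three cells above select the best row per model by comparing the loss column to its minimum. An alternative is `groupby(...).idxmin()`, which returns exactly one row label per model even when two settings tie. The table below is a toy example with made-up numbers, not the real grid search results.

```python
import pandas as pd

# Toy fits table (values invented for illustration)
fits = pd.DataFrame({
    'model': ['LSTM', 'LSTM', 'MLP', 'SEM', 'SEM'],
    'KL':    [0.30, 0.10, 0.40, 0.25, 0.05],
    'lr':    [0.005, 0.05, 0.005, 0.005, 0.05],
})

# Row label of the smallest loss within each model, then look those rows up
best_rows = fits.groupby('model')['KL'].idxmin()
best = fits.loc[best_rows]
print(best)
```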


In [9]:
def get_data(model):
    theta = tXt_Loss.loc[(tXt_Loss.model == model) & (tXt_Loss.KL == min_vals[model])]
    lr, n_epochs, logalpha, loglamda = \
        theta.loc[:, 'lr n_epochs logalpha loglamda'.split()].values[0]
    
    model_name = ['VanillaLSTM', 'MLP'][model == 'MLP']
    no_split_tag = ['_nosplit', ''][model == 'SEM']
    
    args = [lr, int(n_epochs), float(logalpha), float(loglamda),no_split_tag]
    
    file = '{}trial_X_trial_{}_nhiddenNone_e1e-05_lr'.format(path, model_name)
    file += '{}_n{}_d0.0_logalfa_{}_loglmda_{}_{}_online_instructed.csv'.format(*args)
#     return file
    return pd.read_csv(file)


In [10]:
get_data('SEM')
# ii = len(file)  - 30
# print(file[:ii])
# glob(file[:ii] + "*")


FileNotFoundError: [Errno 2] No such file or directory: './data/trial_X_trial_VanillaLSTM_nhiddenNone_e1e-05_lr0.05_n2_d0.0_logalfa_16.0_loglmda_8.0__online_instructed.csv'
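
The path requested here ends in `_online_instructed.csv`, while the files found by `glob` earlier end in `_online.csv`. One hedged way to make the lookup more forgiving is to try a list of suffixes and, on failure, report which files share the stem. `find_fit_file` and its suffix list are hypothetical, and whether falling back to the `_online.csv` file is appropriate for this analysis is an assumption to check against the full dataset.

```python
import os
import tempfile
from glob import glob

def find_fit_file(path, stem,
                  suffixes=('_online_instructed.csv', '_online.csv')):
    """Return the first existing '<path>/<stem><suffix>' (hypothetical helper)."""
    for suffix in suffixes:
        candidate = os.path.join(path, stem + suffix)
        if os.path.exists(candidate):
            return candidate
    # Nothing matched: list files sharing the stem to help with debugging
    nearby = sorted(glob(os.path.join(path, stem + '*')))
    raise FileNotFoundError(
        'No file for stem {!r}; files sharing the stem: {}'.format(stem, nearby))

# Small demonstration in a temporary directory with an empty placeholder file
with tempfile.TemporaryDirectory() as tmp:
    stem = 'trial_X_trial_demo'
    open(os.path.join(tmp, stem + '_online.csv'), 'w').close()
    found_name = os.path.basename(find_fit_file(tmp, stem))
    try:
        find_fit_file(tmp, 'missing_stem')
        missing_raised = False
    except FileNotFoundError:
        missing_raised = True

print(found_name)
```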

In [None]:
def plot_acc_by_time(model, condition):
    df = get_data(model)
    sns.relplot(data=df[df.Condition==condition],
            x='t', y='Accuracy', kind='line', aspect=2, height=3)
    plt.title(condition)
    plt.xlabel('Story')
    plt.ylim(0, 1.0)
    plt.xlim(0, 200)
    return plt.gcf()

def plot_pe_by_time(model, condition):
    df = get_data(model)
    sns.relplot(data=df[df.Condition==condition],
            x='t', y='pe', kind='line', aspect=2, height=3)
    plt.title(condition)
    plt.xlabel('Story')
    plt.ylabel('Prediction Error')
    plt.xlim(0, 200)
    plt.ylim(0, 1)
    return plt.gcf()

In [None]:
df = get_data("SEM")
sns.relplot(data=df, hue='Condition',
        x='t', y='Accuracy', kind='line', aspect=2, height=3)
plt.title('SEM')

In [None]:
df = get_data("LSTM")
sns.relplot(data=df, hue='Condition',
        x='t', y='Accuracy', kind='line', aspect=2, height=3)
plt.title('LSTM')

In [None]:
df = get_data("MLP")
sns.relplot(data=df, hue='Condition',
        x='t', y='Accuracy', kind='line', aspect=2, height=3)
plt.title('MLP')

## Model fitting: Blocked/Interleaved Only

In [11]:
# Load files one at a time
tXt_Loss_BIonly = []
sim = 0
files = glob(path + 'trial*csv')

def get_error(file):

    model_data = pd.read_csv(file)
    model_grouped = model_data.groupby(['Condition', 't'])

    
    MSE = 0
    KL = 0

    for cond in ['Blocked', 'Interleaved']:
        mu_model = model_grouped.mean().loc[cond]['Accuracy'].values
        sd_model = model_grouped.std().loc[cond]['Accuracy'].values

        mu_beh = trial_by_trial.loc[:, '{} Mean'.format(cond)].values
        sd_beh = trial_by_trial.loc[:, '{} StdDev'.format(cond)].values

        MSE += np.mean((mu_beh - mu_model) ** 2)
        KL  += np.sum(kl_pq(mu_beh, mu_model, sd_beh, sd_model))

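    # NOTE: as in the first definition of get_error, the keys are swapped:
    # 'KL' holds the summed MSE and 'MSE' holds the summed KL divergence.
    return {
            'KL': MSE,
            'MSE': KL,
        }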


for file in tqdm(files):
    
    idx = file.find('nhidden') + 7
    n_hidden = pull_from_start(file, idx)
    if n_hidden == 'None':
        n_hidden = 10
    else:
        n_hidden = int(pull_from_start(file, idx))
    
    # learning rate -- there is no reliable end point for this field
    # other than the "_" character
    idx = file.find('_lr') + 3
    lr = float(pull_from_start(file, idx))

    # n_epochs
    idx = file[idx:].find('_n') + idx + 2
    n = int(pull_from_start(file, idx))

    # dropout
    idx = file[idx:].find('_d') + idx + 2
    d = float(pull_from_start(file, idx))

    # loglmda
    idx = file.find('_loglmda') + 9
    loglamda =float(pull_from_start(file, idx))

    # logalpha
    idx = file.find('_logalfa') + 9
    idx2 = file.find('_loglmda')
    logalpha = float(file[idx:idx2])       
    
    # epsilon
    idx = file.find('_e') + 2
    epsilon = float(pull_from_start(file, idx))

    _tXt = get_error(file)
    _tXt['lr'] = lr
    _tXt['n_epochs'] = n
    _tXt['dropout'] = d
    _tXt['epsilon'] = epsilon
    _tXt['loglamda'] = loglamda
    _tXt['logalpha'] = logalpha
    _tXt['sim'] = sim
    _tXt['n_hidden'] = n_hidden
    _tXt['model'] = get_model(file)

    tXt_Loss_BIonly.append(_tXt)
    sim += 1
    
#     break
tXt_Loss_BIonly = pd.DataFrame(tXt_Loss_BIonly)
tXt_Loss_BIonly.index = range(len(tXt_Loss_BIonly))




100%|██████████| 3/3 [00:00<00:00, 13.89it/s]
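
The only difference between this cell and the earlier fit is the list of conditions, so both could share one function that takes the conditions as an argument. The sketch below runs on synthetic data; `get_error_for` is a hypothetical name, and in this sketch each dictionary key holds the quantity it is named after.

```python
import numpy as np
import pandas as pd

def kl_pq(mu_1, mu_2, sd_1, sd_2):
    return np.log(sd_2) - np.log(sd_1) + \
        (sd_1 ** 2 + (mu_1 - mu_2) ** 2) / (2 * (sd_2 ** 2)) - 0.5

def get_error_for(model_data, human, conditions):
    """Summed MSE and KL over the given conditions (hypothetical helper)."""
    grouped = model_data.groupby(['Condition', 't'])['Accuracy']
    mu_all, sd_all = grouped.mean(), grouped.std()
    mse, kl = 0.0, 0.0
    for cond in conditions:
        mu_model = mu_all.loc[cond].values
        sd_model = sd_all.loc[cond].values
        mu_beh = human['{} Mean'.format(cond)].values
        sd_beh = human['{} StdDev'.format(cond)].values
        mse += np.mean((mu_beh - mu_model) ** 2)
        kl += np.sum(kl_pq(mu_beh, mu_model, sd_beh, sd_model))
    return {'MSE': mse, 'KL': kl}

# Synthetic stand-ins for the simulation output and the human summary table
rng = np.random.default_rng(0)
conds = ['Blocked', 'Interleaved', 'Instructed']
n_t, n_batches = 5, 6
sim_data = pd.DataFrame([
    {'Condition': c, 't': t, 'Accuracy': rng.uniform(0, 1)}
    for c in conds for _ in range(n_batches) for t in range(n_t)
])
human = pd.DataFrame(
    {'{} Mean'.format(c): rng.uniform(0.3, 0.9, n_t) for c in conds})
for c in conds:
    human['{} StdDev'.format(c)] = rng.uniform(0.3, 0.5, n_t)

err_all = get_error_for(sim_data, human, conds)
err_bi = get_error_for(sim_data, human, ['Blocked', 'Interleaved'])
print(err_all, err_bi)
```

Because each condition contributes a non-negative term, the Blocked/Interleaved-only loss can never exceed the loss over all three conditions for the same file.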


In [12]:
min_vals_BIonly = tXt_Loss_BIonly.groupby('model')['MSE'].min()
min_vals_BIonly

model
LSTM     967.359407
MLP     2121.687083
SEM      206.624449
Name: MSE, dtype: float64

In [13]:

def get_data(model, fits_df, metric='KL'):
    
    min_vals = fits_df.groupby('model')[metric].min()
    theta = fits_df.loc[(fits_df.model == model) & (fits_df[metric] == min_vals[model])]
    lr, n_epochs, logalpha, loglamda = \
        theta.loc[:, 'lr n_epochs logalpha loglamda'.split()].values[0]
    print(lr, n_epochs, logalpha, loglamda)
    
    model_name = ['VanillaLSTM', 'MLP'][model == 'MLP']
    no_split_tag = ['_nosplit', ''][model == 'SEM']
    
    args = [lr, int(n_epochs), float(logalpha), float(loglamda),no_split_tag]
    
    file = '{}trial_X_trial_{}_nhiddenNone_e1e-05_lr'.format(path, model_name)
    file += '{}_n{}_d0.0_logalfa_{}_loglmda_{}_{}_online_instructed.csv'.format(*args)

    return pd.read_csv(file)



In [14]:
df = get_data("MLP", tXt_Loss_BIonly)
sns.relplot(data=df, hue='Condition',
        x='t', y='Accuracy', kind='line', aspect=2, height=3)
plt.title('MLP')

0.005 8.0 -208.0 208.0


FileNotFoundError: [Errno 2] No such file or directory: './data/trial_X_trial_MLP_nhiddenNone_e1e-05_lr0.005_n8_d0.0_logalfa_-208.0_loglmda_208.0__nosplit_online_instructed.csv'

In [15]:
df = get_data("SEM", tXt_Loss_BIonly)
sns.relplot(data=df, hue='Condition',
        x='t', y='Accuracy', kind='line', aspect=2, height=3)
plt.title('SEM')

0.05 2.0 16.0 8.0


FileNotFoundError: [Errno 2] No such file or directory: './data/trial_X_trial_VanillaLSTM_nhiddenNone_e1e-05_lr0.05_n2_d0.0_logalfa_16.0_loglmda_8.0__online_instructed.csv'

In [16]:
df = get_data("LSTM", tXt_Loss_BIonly)
sns.relplot(data=df, hue='Condition',
        x='t', y='Accuracy', kind='line', aspect=2, height=3)
plt.title('LSTM')

0.005 2.0 -208.0 208.0


FileNotFoundError: [Errno 2] No such file or directory: './data/trial_X_trial_VanillaLSTM_nhiddenNone_e1e-05_lr0.005_n2_d0.0_logalfa_-208.0_loglmda_208.0__nosplit_online_instructed.csv'