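This notebook reproduces the parameter search over the number of topics (`n_topics`), the video window size, and the recall window size. Each parameter combination is scored by correlating the per-subject movie/recall model correlations with `hand_rec`.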

In [None]:
from scipy.stats import pearsonr as corr
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import matplotlib as mpl
import hypertools as hyp
import numpy as np
from scipy import optimize
from scipy.signal import resample
from fastdtw import fastdtw
from scipy.spatial.distance import correlation
from scipy.stats import pearsonr
from multiprocessing import Pool

%matplotlib inline
# Default figure size for subsequent plots in this notebook.
plt.rc('figure', figsize=(12, 8))

## Load data

In [None]:
# Scene-level annotation table (one row per scene, with TR start/end columns).
movie_text = pd.read_excel('../../../data/raw/Sherlock_Segments_1000_NN_2017.xlsx')
# Propagate each 'Scene Segments' value down to the following blank rows
# (presumably the label is only filled in on a segment's first row -- confirm).
# Assign the result back instead of calling fillna(..., inplace=True) on a column
# selection: that chained in-place form is not guaranteed to update `movie_text`
# (and never does under pandas Copy-on-Write), and `method=` is deprecated.
movie_text['Scene Segments'] = movie_text['Scene Segments'].ffill()

## First try a simple grid search

In [None]:
def _sliding_windows(items, wsize):
    """Return one comma-joined window per element of ``items``.

    Window ``i`` joins ``items[i:i + wsize]``, so consecutive windows overlap
    and the last ``wsize - 1`` windows are shorter than ``wsize``.
    """
    return [','.join(items[idx:idx + wsize]) for idx in range(len(items))]


def _expand_scenes_to_trs(scene_model, scenes, n_trs):
    """Stretch a per-scene model (row ``i`` <-> row ``i`` of ``scenes``) to one row per TR.

    TR ``t`` takes the row of the first scene whose inclusive [start, end] TR
    range contains ``t``. A TR covered by no scene falls back to row 0.
    """
    starts = scenes['Start Time (TRs, 1.5s)'].to_numpy(dtype=float)
    ends = scenes['End Time (TRs, 1.5s)'].to_numpy(dtype=float)
    rows = []
    for tr in range(n_trs):
        # NaN bounds compare False, so a scene row with a missing bound never matches.
        hits = np.flatnonzero((starts <= tr) & (tr <= ends))
        rows.append(scene_model[hits[0] if hits.size else 0, :])
    return np.array(rows)


def _load_recall_sentences(sub):
    """Read subject ``sub``'s transcript and return its string cells (sentences).

    The file is parsed with ``sep='.'`` and no header, and only the first row is
    kept, so each period-delimited piece becomes one cell. The final cell is
    dropped (presumably the empty text after the last period -- TODO confirm),
    and non-string cells such as NaN are skipped.
    """
    path = '../../../data/raw/NN' + str(sub) + ' transcript.txt'
    # on_bad_lines='skip' is the replacement for error_bad_lines=False,
    # which pandas 2.0 removed.
    cells = pd.read_csv(path, header=None, sep='.', on_bad_lines='skip',
                        encoding='latin-1').values.tolist()[0][:-1]
    return [cell for cell in cells if isinstance(cell, str)]


def get_models(movie, movie_wsize=50, n_components=100, recall_wsize=5, warp=True,
               n_trs=1976, scenes=None):
    """Build LDA topic trajectories for the movie and for every subject's recall.

    Parameters
    ----------
    movie : list of str
        One text sample per row of ``scenes``.
    movie_wsize : int
        Number of consecutive movie samples joined into each sliding window.
    n_components : int
        Number of LDA topics.
    recall_wsize : int
        Number of consecutive recall sentences joined into each sliding window.
    warp : bool
        If True, align each recall model to the movie model with fastdtw
        (correlation distance) before returning.
    n_trs : int
        Number of time points (TRs) every returned model has.
    scenes : pandas.DataFrame, optional
        Table with 'Start Time (TRs, 1.5s)' and 'End Time (TRs, 1.5s)' columns,
        one row per ``movie`` sample. Defaults to the notebook-level ``movie_text``.

    Returns
    -------
    tuple
        If ``warp`` is True: ``(movie_models, recall_models)``, two lists with one
        ``n_trs``-row array per subject. The movie model is warped separately for
        each subject.

        If ``warp`` is False: ``(movie_model, recall_models)``, where ``movie_model``
        is a single ``n_trs``-row array. NOTE: ``grid_search`` zips the two
        returned values, which pairs them per subject only when ``warp=True``.
    """
    if scenes is None:
        scenes = movie_text

    # Overlapping movie windows; they are also passed as the topic-model corpus.
    movie_w = _sliding_windows(movie, movie_wsize)

    # vectorizer parameters
    vectorizer = {
        'model' : 'CountVectorizer',
        'params' : {
            'stop_words' : 'english'
        }
    }

    # topic model parameters (fixed random_state for reproducible fits)
    semantic = {
        'model' : 'LatentDirichletAllocation',
        'params' : {
            'n_components' : n_components,
            'learning_method' : 'batch',
            'random_state' : 0
        }
    }

    movie_model = hyp.tools.format_data(movie_w, vectorizer=vectorizer, semantic=semantic, corpus=movie_w)[0]
    # Descriptions are per scene, not per TR, so stretch the model onto the TR grid.
    movie_model = _expand_scenes_to_trs(movie_model, scenes, n_trs)

    # Subjects NN1..NN17, each as overlapping windows of `recall_wsize` sentences.
    recalls = [_sliding_windows(_load_recall_sentences(sub), recall_wsize) for sub in range(1, 18)]

    # Same corpus (the movie windows) as the movie model above.
    recall_models = hyp.tools.format_data(recalls, vectorizer=vectorizer, semantic=semantic, corpus=movie_w)
    recall_models_rs = [resample(x, n_trs) for x in recall_models]

    if not warp:
        return movie_model, recall_models_rs

    movie_models_dtw = []
    recall_models_dtw = []
    for recall_model in recall_models_rs:
        _, path = fastdtw(movie_model, recall_model, dist=correlation)
        movie_idx, recall_idx = zip(*path)
        movie_models_dtw.append(movie_model[list(movie_idx), :])
        recall_models_dtw.append(recall_model[list(recall_idx), :])
    # Warped sequences follow the DTW path length; resample back to n_trs rows.
    return ([resample(m, n_trs) for m in movie_models_dtw],
            [resample(r, n_trs) for r in recall_models_dtw])

In [None]:
# create a list of text samples from the scene descriptions / details to train the topic model
# One text sample per scene-table row: join every annotation column from
# 'Scene Details - A Level ' through 'Words on Screen ', with blanks as ''.
annotation_cols = movie_text.loc[:, 'Scene Details - A Level ':'Words on Screen ']
movie = [', '.join(row) for row in annotation_cols.fillna('').values.tolist()]
movie_models, recall_models = get_models(movie)

In [None]:
# Candidate values for each hyperparameter.
n_topics = [5, 10, 25, 50, 100]
movie_wsizes = [5, 10, 25, 50, 100]
recall_wsizes = [5, 10, 25, 50, 100]

# Cartesian product ordered (n_topics, movie window, recall window),
# with the recall window varying fastest.
param_grid = []
for a in n_topics:
    for b in movie_wsizes:
        param_grid.extend((a, b, c) for c in recall_wsizes)

# One value per subject (17 entries), correlated against the per-subject
# movie/recall model correlations.
hand_rec = [27, 24, 32, 33, 32, 39, 30, 39, 28, 40, 34, 38, 47, 38, 27, 37, 39]

In [None]:
def grid_search(movie, a, b, c):
    """Score one (n_topics, movie window, recall window) setting.

    Builds models with ``get_models`` (default ``warp=True``). For each subject,
    it correlates the flattened movie and recall models. It then returns scipy's
    ``pearsonr`` result for those per-subject correlations against ``hand_rec``.
    """
    movie_models, recall_models = get_models(movie, n_components=a, movie_wsize=b, recall_wsize=c)
    per_subject = []
    for movie_model, recall_model in zip(movie_models, recall_models):
        r_value = pearsonr(movie_model.ravel(), recall_model.ravel())[0]
        per_subject.append(r_value)
    return pearsonr(per_subject, hand_rec)

In [None]:
# Evaluate every (n_topics, movie window, recall window) combination in param_grid order.
# The loop variable is named `result` rather than `corr` so it does not shadow
# the `corr` alias for scipy's pearsonr imported at the top of the notebook.
corrs = []
for a, b, c in param_grid:
    result = grid_search(movie, a, b, c)
    corrs.append(result)
    print(a, b, c, result)

In [None]:
# Persist the grid-search results; np.save adds the .npy extension to this path.
grid_search_results_path = '../../../data/processed/grid_search_results'
np.save(grid_search_results_path, corrs)

In [None]:
# Refit at one chosen setting (K=100 topics, video window 50, recall window 10).
# Presumably a well-scoring cell of the grid above -- TODO confirm against `corrs`.
movie_models, recall_models = get_models(movie, n_components=100, movie_wsize=50, recall_wsize=10)
movie_rec_corr = [pearsonr(m.ravel(), r.ravel())[0] for m, r in zip(movie_models, recall_models)]
# Only a cell's final expression is displayed, so print this fit explicitly.
print(pearsonr(movie_rec_corr, hand_rec))
# x/y must be keywords: seaborn >= 0.12 no longer accepts them positionally.
joint = sns.jointplot(x=np.array(movie_rec_corr), y=np.array(hand_rec), kind='reg')
joint.set_axis_labels('Movie/recall model correlation', 'hand_rec');

In [None]:
sns.set_context('paper')
# Embed fonts as TrueType (Type 42) rather than Type 3 in the saved PDF.
mpl.rcParams['pdf.fonttype'] = 42

# param_grid varies the recall window fastest, so a C-order reshape of the
# grid-search correlation coefficients gives a (n_topics, movie window, recall window) array:
# one panel per topic count, rows = video window, columns = recall window.
grid_stats = np.array([result[0] for result in corrs]).reshape(
    len(n_topics), len(movie_wsizes), len(recall_wsizes))

# Size this figure locally instead of changing the global rc default.
# squeeze=False keeps `axes` 2-D even for a single panel.
fig, axes = plt.subplots(1, len(n_topics), figsize=(15, 3), squeeze=False)
for ax, k, stats in zip(axes[0], n_topics, grid_stats):
    # Fixed vmin/vmax so all panels share one colour scale.
    sns.heatmap(stats, vmin=0, vmax=.75, xticklabels=recall_wsizes, yticklabels=movie_wsizes,
                ax=ax, cbar_kws={'label': 'Correlation'})
    ax.set_title(r'Number of topics ($K$): %s' % str(k))
    ax.set_xlabel(r'Recall window width ($\rho$)')
    ax.set_ylabel(r'Video window width ($\omega$)')
fig.tight_layout()
fig.savefig('../../../parameter_search.pdf')