In [1]:
import os
from utils import load_config
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd

# Project configuration returned by the local `utils.load_config` helper.
# Cells below read config['directories']['stimuli_dir'] from it; where the
# helper gets its values is not visible in this notebook.
config = load_config()

In [2]:
# Story identifier. It is interpolated into every file name used below:
# surprisals_<story>.txt (input), word2tokens_<story>.csv (input) and
# surprisals_<story>.csv (output).
story_name = 'cable_spool_fort'

In [3]:
# Read the raw surprisal table: every line of the text file is a
# comma-separated list of floats and becomes one list in `surprisals`.
# Any blank field is stored as NaN. This covers the lone '\n' left after a
# trailing comma as well as an empty string (e.g. a trailing comma on a last
# line that has no newline), which float() would reject with ValueError.
surprisals = []
processed_stimuli_dir = config['directories']['stimuli_dir']
surprisals_txt_path = os.path.join(processed_stimuli_dir, f'surprisals_{story_name}.txt')
with open(surprisals_txt_path, 'r') as f:
  # Iterate the file lazily; the line index was never used.
  for line in tqdm(f):
    fields = line.split(',')
    # float() tolerates surrounding whitespace, so the '\n' glued to the
    # last numeric field of a line needs no special handling.
    surprisals.append([float(field) if field.strip() else np.nan for field in fields])

0it [00:00, ?it/s]

In [4]:
# Load the per-word table for this story (the cells below use its
# `start`, `end` and `word` columns) and display it.
word2tokens_path = os.path.join(processed_stimuli_dir, f'word2tokens_{story_name}.csv')
word2tokens = pd.read_csv(word2tokens_path)
word2tokens

Unnamed: 0,start,end,word
0,1,2,the
1,2,3,cable
2,3,5,spool
3,5,6,fort
4,6,7,by
...,...,...,...
2023,2483,2484,it
2024,2484,2485,was
2025,2485,2486,all
2026,2486,2487,right


In [5]:
# Lay the ragged per-line lists out in a square, NaN-filled grid.
# Entry k of list r is written to cell [r, r + 1 + k]; entries that would
# fall past the last column are dropped.
n_subtokens = len(surprisals)
surprisals_arr = np.full([n_subtokens, n_subtokens], np.nan)
for row_idx, row_values in tqdm(enumerate(surprisals)):
  first_col = row_idx + 1
  # number of values that still fit between first_col and the right edge
  n_fit = min(len(row_values), n_subtokens - first_col)
  surprisals_arr[row_idx, first_col:first_col + n_fit] = row_values[:n_fit]

0it [00:00, ?it/s]

In [6]:
# Fold every word's sub-token columns [start, end) into the word's first
# column by summing them row-wise. The plain sum propagates NaN, so a row
# that lacks a value for any of the word's sub-tokens ends up NaN.
# NOTE: this overwrites surprisals_arr in place; running the cell a second
# time adds the other sub-token columns onto the stored sum again.
for start, end in tqdm(zip(word2tokens['start'], word2tokens['end'])):
  surprisals_arr[:, start] = surprisals_arr[:, start:end].sum(axis=1)

0it [00:00, ?it/s]

In [7]:
# Keep only the column at each word's `start` index (where the per-word
# sums were stored); afterwards there is one column per row of word2tokens.
start_idx = word2tokens['start']
surprisals_arr = np.take(surprisals_arr, start_idx, axis=1)

In [8]:
# align context length
# Rotate each word column upward by that word's `start` index, so the rows
# that sat above that index wrap around to the bottom of the column.
word_starts = word2tokens['start']
n_words = surprisals_arr.shape[-1]
for col in tqdm(range(n_words)):
    shift = word_starts[col]
    # same result as concatenating column[shift:] with column[:shift]
    surprisals_arr[:, col] = np.roll(surprisals_arr[:, col], -shift)

# Reverse the row order: the wrapped rows now come first, nearest row first.
# Presumably row k then holds the value for a context of k + 1 sub-tokens.
surprisals_arr = surprisals_arr[::-1, :]

  0%|          | 0/2028 [00:00<?, ?it/s]

In [9]:
# Keep at most the first 1024 rows, then transpose so that each row of the
# array corresponds to one word and each column to one retained row.
surprisals_arr = surprisals_arr[:1024].T

In [10]:
# Wrap the array in a DataFrame with columns surp_1 .. surp_N plus a leading
# `word` column. N is taken from the array itself rather than hard-coded to
# 1024: an array with fewer columns would otherwise make the DataFrame
# constructor raise on the column-count mismatch.
n_surp_columns = surprisals_arr.shape[1]
surp_columns = [f'surp_{k}' for k in range(1, n_surp_columns + 1)]
surp_table = pd.DataFrame(surprisals_arr, columns=surp_columns)
# Assignment aligns on the index, so row r receives word r of word2tokens.
surp_table['word'] = word2tokens['word']

# Move `word` to the front.
surp_table = surp_table[['word'] + surp_columns]

In [11]:
# Write the per-word table next to the input files (without the index
# column) and display it.
surp_csv_path = os.path.join(processed_stimuli_dir, f'surprisals_{story_name}.csv')
surp_table.to_csv(surp_csv_path, index=False)
surp_table

Unnamed: 0,word,surp_1,surp_2,surp_3,surp_4,surp_5,surp_6,surp_7,surp_8,surp_9,...,surp_1015,surp_1016,surp_1017,surp_1018,surp_1019,surp_1020,surp_1021,surp_1022,surp_1023,surp_1024
0,the,3.278111,,,,,,,,,...,,,,,,,,,,
1,cable,11.345116,11.064468,,,,,,,,...,,,,,,,,,,
2,spool,13.657824,11.441582,11.902241,,,,,,,...,,,,,,,,,,
3,fort,10.903999,9.666206,10.646751,10.894592,10.396004,,,,,...,,,,,,,,,,
4,by,6.007278,8.172611,7.454346,7.837749,7.468010,7.569530,,,,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2023,it,4.697765,4.365662,4.280426,4.322281,4.339119,4.360588,4.398544,4.375565,4.339180,...,4.249870,4.254539,4.253876,4.246719,4.248413,4.254135,4.226738,4.240074,4.244377,4.128273
2024,was,3.677078,2.158566,2.175110,2.170673,2.246403,2.185013,2.190273,2.214771,2.254070,...,1.833939,1.823006,1.833786,1.827431,1.823845,1.791626,1.821510,1.826088,1.815361,1.676491
2025,all,5.337856,4.310524,4.726067,4.976852,4.978607,5.022606,4.968163,5.008003,5.027847,...,4.429268,4.412842,4.429955,4.421516,4.407539,4.428787,4.416733,4.406685,4.390083,4.546486
2026,right,6.806097,4.488609,4.571327,4.870041,4.971634,4.936180,4.852280,4.880608,4.936584,...,4.976105,4.969650,5.014633,4.891335,4.972343,4.983078,4.899208,4.923599,4.937271,4.799240
