In [1]:
import sys, os
# project root; defined once so the paths below stay in sync
PROJECT_DIR = "C:/users/kristijan/documents/projects/jhu1/"
sys.path.append(PROJECT_DIR)
sys.path.append(os.path.join(PROJECT_DIR, "data"))
os.chdir(PROJECT_DIR)
import pandas as pd
import numpy as np
import plotly.express as px
import seaborn as sns
from stimuli import prefixes, prompts_word_prediction
import json

# Word prediction experiment

Can we study the working memory (WM) of a pretrained language model by looking at its outputs directly?
- How many items does GPT-2 hold in memory before it starts forgetting?
- (What kind of prompt is needed for this?)
- (Does the semantic structure of the list composition affect its WM performance?)

## Dataset

In [2]:
# load the word lists
with open("./data/toronto.json") as f:
    stim = json.load(f)

### The Toronto noun pool

I randomly sample from the Toronto noun pool (which comes with frequency and concreteness estimates).  
The original data look like this:  

In [3]:
trt = pd.read_csv("./data/toronto_freq.txt", sep="\t")
trt

Unnamed: 0,word,imagery,k-f freq,concreteness
0,ABSENCE,3.0,53,3.7
1,ACCORD,1.8,9,2.6
2,ACCOUNT,2.6,117,4.0
3,ACID,4.4,13,6.1
4,ACRE,4.6,9,5.7
...,...,...,...,...
473,WORKER,5.4,30,6.3
474,WORSHIP,4.7,36,2.8
475,WRINKLE,6.0,2,5.2
476,WRITER,4.8,73,6.2


For now, I create word lists of 3, 5, and 10 words.  
I randomly sample 20 lists per length, giving 60 word lists in total.  
Some example word lists:

In [4]:
# show every 10th list, i.e. two examples per list length
stim[0::10]

['series, singer, affair.',
 'sentence, teacher, array.',
 'theater, manner, sorrow, angel, illness.',
 'pepper, unit, standard, beggar, meantime.',
 'major, hunter, princess, journal, jersey, marble, fortune, limit, spirit, mother.',
 'frontier, widow, amount, darkness, tiger, mayor, echo, mountain, olive, silver.']
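The sampling described above could be sketched roughly as follows. This is not the code that produced `toronto.json`; the function name `sample_word_lists`, the seed, the lower-casing, and the small `toy_pool` are assumptions for illustration only.

```python
import random

def sample_word_lists(pool, lengths=(3, 5, 10), n_per_length=20, seed=0):
    """Sample comma-separated word lists; words do not repeat within a list."""
    rng = random.Random(seed)  # fixed seed so the sketch is reproducible
    lists = []
    for n_words in lengths:
        for _ in range(n_per_length):
            words = rng.sample(pool, n_words)  # sampling without replacement
            lists.append(", ".join(w.lower() for w in words) + ".")
    return lists

# Hypothetical stand-in for the "word" column of the noun pool
toy_pool = ["ABSENCE", "ACCORD", "ACCOUNT", "ACID", "ACRE", "WORKER",
            "WORSHIP", "WRINKLE", "WRITER", "SERIES", "SINGER", "AFFAIR"]

word_lists = sample_word_lists(toy_pool)
print(len(word_lists))  # 3 lengths x 20 lists each
```

With the real data, `pool` would presumably be something like `trt["word"].tolist()`.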

### Prefixes and prompts

The prefixes I am currently considering (carried over from before):  
(simple, syntax-no-meaning, explicit, random, complex)

In [5]:
prefixes  

{'spl': 'This is a list:',
 'syn': 'This is a mist:',
 'exp': 'Memorize this list:',
 'rnd': 'Penguin was that were:',
 'cpl': 'Before going to the store, Mary composed this shopping list:'}

And some prompts:  
(simple, item-specific, random, complex)  

In [6]:
prompts_word_prediction

{'cnt': 'The list contains:',
 'end': 'This list started with:',
 'rnd': 'At it goals with was:',
 'cpl': 'The first item on her shopping list was the following:'}
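As a rough sketch of how a prefix, a word list, and a prompt might be assembled into a single model input: the dictionaries below copy the values shown above, while `build_input` and the choice to join the parts with single spaces are assumptions, not necessarily how the actual inputs were built.

```python
from itertools import product

# Local copies of the dictionaries shown above (renamed to avoid shadowing)
demo_prefixes = {
    "spl": "This is a list:",
    "syn": "This is a mist:",
    "exp": "Memorize this list:",
    "rnd": "Penguin was that were:",
    "cpl": "Before going to the store, Mary composed this shopping list:",
}
demo_prompts = {
    "cnt": "The list contains:",
    "end": "This list started with:",
    "rnd": "At it goals with was:",
    "cpl": "The first item on her shopping list was the following:",
}

def build_input(prefix, word_list, prompt):
    # Assumption: the three parts are separated by single spaces
    return " ".join([prefix, word_list, prompt])

# All (prefix key, prompt key) pairs
combos = list(product(demo_prefixes, demo_prompts))
example = build_input(demo_prefixes["spl"], "series, singer, affair.",
                      demo_prompts["cnt"])
print(len(combos))
print(example)
```

Five prefixes and four prompts give 20 pairs; with 60 word lists that would be 1200 inputs per sampling setting.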

### Current output

In [7]:
dat = pd.read_csv("./output/word_prediction.txt")

I run all combinations of prefixes and prompts.  
I consider two sampling options:  
- grab only the top candidate  
- grab the top 5 candidates  
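The two options amount to taking the `k` most probable next tokens for `k = 1` and `k = 5`. A minimal sketch with NumPy, assuming we already have a vector of next-token logits; `toy_vocab`, the logit values, and the helper `top_k` are made up for illustration and are not GPT-2 output.

```python
import numpy as np

def top_k(probs, vocab, k):
    """Return the k most probable (token, probability) pairs, best first."""
    probs = np.asarray(probs)
    order = np.argsort(probs)[::-1][:k]  # indices sorted by descending probability
    return [(vocab[i], float(probs[i])) for i in order]

# Hypothetical toy vocabulary and logits (not real model output)
toy_vocab = ["\n", '"', " a", " the", " $", " apple"]
logits = np.array([2.0, 1.8, 0.3, 0.1, -0.2, -1.0])

# Softmax turns logits into a probability distribution
probs = np.exp(logits - logits.max())
probs /= probs.sum()

top1 = top_k(probs, toy_vocab, 1)
top5 = top_k(probs, toy_vocab, 5)
print(top1)
print(top5)
```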

For example, for the input  
$\texttt{This is a list:} \space L \space \texttt{The list contains:}$  
(where $L$ is the word list), the top-1 candidate seems to be the line break:

In [8]:
dat.head(10).loc[:, ["Unnamed: 0", "pref", "prompt", "prediction", "prob", "top-k"]]

Unnamed: 0.1,Unnamed: 0,pref,prompt,prediction,prob,top-k
0,0,spl,cnt,\n,0.5311000347137451,1
1,1,spl,cnt,\n,0.5332574248313904,1
2,2,spl,cnt,\n,0.4934692978858948,1
3,3,spl,cnt,\n,0.5984312295913696,1
4,4,spl,cnt,\n,0.4882455170154571,1
5,5,spl,cnt,\n,0.6243464946746826,1
6,6,spl,cnt,\n,0.6836918592453003,1
7,7,spl,cnt,\n,0.5130649209022522,1
8,8,spl,cnt,\n,0.6700116395950317,1
9,9,spl,cnt,\n,0.5511378049850464,1


Or consider:  
$\texttt{Before going to the store, Mary composed this shopping list:} \space L \space \texttt{The first item on her shopping list was the following:}$

Let's look at the `top-1` scenario first:

In [9]:
tmp = dat.loc[(dat["top-k"] == 1)]
tmp.tail(10).loc[:, ["Unnamed: 0", "pref", "prompt", "prediction", "prob", "top-k"]]

Unnamed: 0.1,Unnamed: 0,pref,prompt,prediction,prob,top-k
1190,1190,cpl,cpl,"""",0.3913321495056152,1
1191,1191,cpl,cpl,"""",0.3763377666473388,1
1192,1192,cpl,cpl,\n,0.3462581038475036,1
1193,1193,cpl,cpl,\n,0.3721664249897003,1
1194,1194,cpl,cpl,\n,0.3956116139888763,1
1195,1195,cpl,cpl,"""",0.4603332579135895,1
1196,1196,cpl,cpl,\n,0.3521767258644104,1
1197,1197,cpl,cpl,"""",0.4250360429286957,1
1198,1198,cpl,cpl,"""",0.4197111129760742,1
1199,1199,cpl,cpl,\n,0.5021361708641052,1


The same "noise" (line breaks and quotation marks) also shows up in the `top-5` scenario:

In [10]:
tmp = dat.loc[(dat["top-k"] == 5)]
tmp.tail(10).loc[:, ["Unnamed: 0", "pref", "prompt", "prediction", "prob", "top-k"]]

Unnamed: 0.1,Unnamed: 0,pref,prompt,prediction,prob,top-k
2390,2390,cpl,cpl,"""\n a the $","[0.39133214950561523, 0.3469286262989044, 0.04...",5
2391,2391,cpl,cpl,"""\n a the sunset","[0.37633776664733887, 0.2485126554965973, 0.08...",5
2392,2392,cpl,cpl,"\n "" a $ diamond","[0.34625810384750366, 0.33191266655921936, 0.0...",5
2393,2393,cpl,cpl,"\n "" a the I","[0.3721664249897003, 0.34648990631103516, 0.07...",5
2394,2394,cpl,cpl,"\n "" a $ A","[0.39561161398887634, 0.32124248147010803, 0.0...",5
2395,2395,cpl,cpl,"""\n a the '","[0.4603332579135895, 0.28773531317710876, 0.05...",5
2396,2396,cpl,cpl,"\n "" a $ the","[0.3521767258644104, 0.2822273075580597, 0.082...",5
2397,2397,cpl,cpl,"""\n a the The","[0.4250360429286957, 0.3691312372684479, 0.039...",5
2398,2398,cpl,cpl,"""\n a\n\n the","[0.4197111129760742, 0.412942498922348, 0.0481...",5
2399,2399,cpl,cpl,"\n "" a the\n\n","[0.5021361708641052, 0.2106030136346817, 0.104...",5
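To go beyond eyeballing the last ten rows, one could tally the top-1 predictions per prefix/prompt pair. The sketch below uses a small hypothetical frame `toy` with the same column names as `dat`; the rows are invented, and treating `"\n"` and `'"'` as "noise" tokens is an assumption (the real file may also store the line break differently).

```python
import pandas as pd

# Hypothetical stand-in for `dat` (same column names, invented rows)
toy = pd.DataFrame({
    "pref": ["spl"] * 3 + ["cpl"] * 3,
    "prompt": ["cnt"] * 3 + ["cpl"] * 3,
    "prediction": ["\n", "\n", "\n", '"', "\n", '"'],
    "top-k": [1] * 6,
})

top1 = toy.loc[toy["top-k"] == 1]

# Count how often each token is the top candidate per prefix/prompt pair
counts = (top1.groupby(["pref", "prompt", "prediction"])
              .size()
              .rename("n")
              .reset_index())

# Share of top-1 predictions that are line breaks or quotation marks
noise_tokens = {"\n", '"'}
noise_share = top1["prediction"].isin(noise_tokens).mean()

print(counts)
print(noise_share)
```

Applied to the real data, `toy` would be replaced by `dat`.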
