In [1]:
import os
import pickle
import numpy as np
import pandas as pd

In [2]:
# Directory containing the cached, pickled activation files. The cells below
# list it with os.listdir and load individual files from it.
working_dir = "/om2/user/ckauf/.result_caching/neural_nlp.models.wrapper.core.ActivationsExtractorHelper._from_sentences_stored"

## Example of what the structure looks like

In [3]:
# Load a single cached file (distilgpt2, 'Original' condition, passage
# '384sentences7') to inspect what the stored object looks like.
example_filename = "identifier=distilgpt2,stimuli_identifier=Pereira2018-Original-384sentences7.pkl"
file = os.path.join(working_dir, example_filename)
with open(file, "rb") as f:
    # NOTE: unpickling can execute arbitrary code; only load trusted cache files.
    result = pickle.load(f)

# The unpickled object is indexed by key; 'data' holds the assembly shown below.
data = result["data"]
data

<xarray.NeuroidAssembly (presentation: 4, neuroid: 5376)>
array([[-0.045079,  0.1112  ,  0.058343, ...,  0.358238, -0.26599 ,  0.280199],
       [ 0.024145,  0.090933,  0.113967, ..., -0.012236,  0.10889 ,  0.254299],
       [-0.112562, -0.150726,  0.18517 , ...,  0.196315,  0.178203,  0.134934],
       [ 0.028349, -0.096964,  0.173206, ...,  0.366918, -0.123713,  0.491447]],
      dtype=float32)
Coordinates:
  * presentation       (presentation) MultiIndex
  - stimulus_sentence  (presentation) object 'a banana is a long fruit that grows in bunchs with a soft edible inside' ... 'unripe bananas and plantains are staple foods and often cooked like potatoes'
  - sentence_num       (presentation) int64 0 1 2 3
  * neuroid            (neuroid) MultiIndex
  - neuroid_num        (neuroid) int64 0 1 2 3 4 5 6 7 ... 23 24 25 26 27 28 29
  - model              (neuroid) object 'distilgpt2' ... 'distilgpt2'
  - layer              (neuroid) object 'drop' 'drop' 'drop' ... 'drop' 'drop'
  - neuroid

In [39]:
# Shape follows the (presentation, neuroid) dimensions of the assembly shown
# above: number of sentences in the passage x number of hidden-layer activation
# values per sentence.
data.shape

(4, 5376)

## Create dictionary
### Structure:

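One separate dictionary per model (glove, lm_1b, distilgpt2), nested as follows:
* outer key = condition identifier (e.g. `Original`, `Scr1`)
* inner key = passage identifier (e.g. `384sentences7`)
* value = associated unpickled file (it holds an assembly like the one above under `'data'`)

(An earlier iteration used a single flat key of the form sentenceIdentifier _ conditionIdentifier; that variant is left commented out in the code.)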

In [33]:
# One nested dictionary per model: {condition: {passage: unpickled result}}.
glove_dic = {}
lm1b_dic = {}
distilgpt2_dic = {}

# Filename marker identifying each model, paired with the dictionary it fills.
model_dics = (
    ("identifier=glove", glove_dic),
    ("identifier=lm_1b", lm1b_dic),
    ("identifier=distilgpt2", distilgpt2_dic),
)

for ind, filename in enumerate(os.listdir(working_dir)):
    # Progress indicator: report every 500th directory entry.
    if ind % 500 == 0:
        print(ind)

    for marker, model_dic in model_dics:
        if marker not in filename:
            continue
        # The last comma-separated field looks like e.g.
        # "stimuli_identifier=Pereira2018-Original-384sentences7.pkl":
        # the final "-" token (minus extension) is the passage, the one before
        # it is the condition.
        stimuli_parts = filename.split(",")[-1].split("-")
        passage = stimuli_parts[-1].split(".")[0]
        condition = stimuli_parts[-2]
        file = os.path.join(working_dir, filename)
        with open(file, "rb") as f:
            result = pickle.load(f)
        # Create the per-condition dictionary on first use, then store the result.
        model_dic.setdefault(condition, {})[passage] = result
print("Done")

0
500
1000
1500
2000
2500
3000
3500
4000
Done


In [37]:
# Top-level keys of the glove dictionary are the condition names.
print(glove_dic.keys())
# Number of passages stored under the 'Original' condition.
print(len(glove_dic['Original']))

dict_keys(['Original', 'Scr1', 'Scr3', 'Scr5', 'Scr7', 'backward', 'lowPMI', 'random'])
168


In [31]:
# NOTE: the output shown for this cell is stale. It comes from an earlier
# iteration (presumably the flat passage_condition keying that is now commented
# out in the loading cell) and is kept as a sanity check. With the nested
# condition -> passage structure this prints the number of conditions instead.
print(len(glove_dic))

1344


## Quick number sanity-check
* 8 conditions × 168 passage encodings each, so 8 * 168 = 1344
* 168 * 4 = 672 (Do all passages have 4 sentences?)
* 243 + 384 = 627 sentences in total (No, some passages must have fewer than 4!)

In [29]:
# Tally how many glove activation files exist for each condition.
cond_dict = {}

for ind, filename in enumerate(os.listdir(working_dir)):
    if "identifier=glove" not in filename:
        continue
    # The condition is the second-to-last "-" token of the last ","-separated field.
    stimuli_field = filename.split(",")[-1]
    condition = stimuli_field.split("-")[-2]
    cond_dict[condition] = cond_dict.get(condition, 0) + 1
print(cond_dict)

{'Original': 168, 'Scr1': 168, 'Scr3': 168, 'Scr5': 168, 'Scr7': 168, 'backward': 168, 'lowPMI': 168, 'random': 168}


__________________________________________________________________________________________________________________

__________________________________________________________________________________________________________________


# Get correlations

### Questions
* Each assembly has size passage_len × hidden_layer_size, so should the correlation value be computed from a matrix norm?
* Alternatively, we could make independent plots for each layer.

__________________________________________________________________________________________________________________

### Quick change: Better dictionary structure: sentence to condition

In [43]:
# One nested dictionary per model, keyed the other way round from the
# condition -> passage dictionaries: {passage: {condition: unpickled result}}.
glove_dic_senttocond = {}
lm1b_dic_senttocond = {}
distilgpt2_dic_senttocond = {}

# Filename marker identifying each model, paired with the dictionary it fills.
model_dics_senttocond = (
    ("identifier=glove", glove_dic_senttocond),
    ("identifier=lm_1b", lm1b_dic_senttocond),
    ("identifier=distilgpt2", distilgpt2_dic_senttocond),
)

for ind, filename in enumerate(os.listdir(working_dir)):
    # Progress indicator: report every 500th directory entry.
    if ind % 500 == 0:
        print(ind)

    for marker, model_dic in model_dics_senttocond:
        if marker not in filename:
            continue
        # The last comma-separated field looks like e.g.
        # "stimuli_identifier=Pereira2018-Original-384sentences7.pkl":
        # the final "-" token (minus extension) is the passage, the one before
        # it is the condition.
        stimuli_parts = filename.split(",")[-1].split("-")
        passage = stimuli_parts[-1].split(".")[0]
        condition = stimuli_parts[-2]
        file = os.path.join(working_dir, filename)
        with open(file, "rb") as f:
            # NOTE: unpickling can execute arbitrary code; only load trusted cache files.
            result = pickle.load(f)
        # The outer level is keyed by passage, so the existence check must be on
        # the passage. Testing `condition` against this dict never matches, which
        # would reset the passage's entry for every file and leave only one
        # condition per passage.
        model_dic.setdefault(passage, {})[condition] = result
print("Done")

0
500
1000
1500
2000
2500
3000
3500
4000
Done


In [44]:
# Top-level keys are the passage identifiers parsed from the filenames.
print(glove_dic_senttocond.keys())

dict_keys(['243sentences1', '243sentences10', '243sentences11', '243sentences12', '243sentences13', '243sentences14', '243sentences15', '243sentences16', '243sentences17', '243sentences18', '243sentences19', '243sentences2', '243sentences20', '243sentences21', '243sentences22', '243sentences23', '243sentences24', '243sentences25', '243sentences26', '243sentences27', '243sentences28', '243sentences29', '243sentences3', '243sentences30', '243sentences31', '243sentences32', '243sentences33', '243sentences34', '243sentences35', '243sentences36', '243sentences37', '243sentences38', '243sentences39', '243sentences4', '243sentences40', '243sentences41', '243sentences42', '243sentences43', '243sentences44', '243sentences45', '243sentences46', '243sentences47', '243sentences48', '243sentences49', '243sentences5', '243sentences50', '243sentences51', '243sentences52', '243sentences53', '243sentences54', '243sentences55', '243sentences56', '243sentences57', '243sentences58', '243sentences59', '243

In [None]:
# old code
#cols = df.corr(method='pearson').nlargest(k,target_col)[target_col].index
#cm = df[cols].corr(method='pearson')
#plt.figure(figsize=(15,8))
#plt.tight_layout()
#sns.heatmap(cm, annot=True, cmap = 'viridis')