# Notebook to extract hidden-state and attention-head activations from RoBERTa model predictions

In [1]:
import glob
import os
import warnings

import pandas as pd
from tqdm import tqdm

from model import RoBertaExtractor
from tokenizer import tokenize

In [7]:
def check_folder(path):
    """Create `path` and any missing parent folders (best effort).

    Args:
        path: Folder to create. An empty path is ignored.

    Returns:
        None. Nothing happens when the folder already exists.

    Failures reported by the operating system (e.g. permission denied, or
    a component of `path` existing as a regular file) do not raise: they are
    reported through a `RuntimeWarning` instead of being swallowed silently.
    """
    if not path:
        # Nothing to create. The former recursive implementation only left
        # this case by exhausting the recursion limit.
        return
    try:
        # exist_ok=True makes the call idempotent and tolerant of another
        # process creating the same folder concurrently.
        os.makedirs(path, exist_ok=True)
    except OSError as error:
        warnings.warn(
            "Could not create folder '{}': {}".format(path, error),
            RuntimeWarning,
        )

Defining variables:

In [16]:
# Glob pattern for the text input (one file per run).
template = "/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/text/english/text_english_run*.txt"
# Language passed to the tokenizer and to the extractor.
language = "english"
# Alternative value: 'sequential'.
prediction_type = "sentence"


In [31]:
# Parallel lists: position i in every list configures the i-th model that the
# extraction loop processes.
pretrained_roberta_models = ["roberta-base"]
names = ["roberta-base"]
config_paths = [None]
# One output folder per model; the placeholder is filled with the language.
saving_path_folders = [
    "/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/stimuli-representations/{}/roberta-base".format(
        language
    ),
]
prediction_types = ["sentence"]

In [29]:
# Display the configured model names.
# NOTE(review): the saved output of this cell (['bert-base-cased']) does not
# match the `names` list assigned in the configuration cell (['roberta-base']);
# re-run the notebook from a fresh kernel to refresh it.
names

['bert-base-cased']

Creating iterator for each run:

In [4]:
# Collect every file matching the input pattern, ordered by path.
paths = glob.glob(template)
paths.sort()

In [5]:
# Tokenize each run's text file; one entry per path, in the same order.
iterator_list = [
    tokenize(run_path, language, train=False)
    for run_path in paths
]

100%|██████████| 135/135 [00:00<00:00, 237513.02it/s]
100%|██████████| 135/135 [00:00<00:00, 252331.12it/s]
100%|██████████| 176/176 [00:00<00:00, 311607.22it/s]
100%|██████████| 173/173 [00:00<00:00, 319232.11it/s]
100%|██████████| 177/177 [00:00<00:00, 328491.95it/s]
100%|██████████| 216/216 [00:00<00:00, 332343.97it/s]
100%|██████████| 196/196 [00:00<00:00, 245471.36it/s]
100%|██████████| 145/145 [00:00<00:00, 145010.51it/s]
100%|██████████| 207/207 [00:00<00:00, 291153.90it/s]

Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.





## Activation extraction

In [32]:
# Extract the activations of every configured model on every run and save
# them as CSV files (token-level, CLS and SEP activations separately).
# NOTE(review): the extractor class is referenced by the name imported at the
# top of the notebook (`RoBertaExtractor`); the former spelling
# `RobertaExtractor` was never imported and raised NameError on a fresh
# kernel. Confirm the spelling against model.py.
for index, roberta_model in enumerate(pretrained_roberta_models):
    output_folder = saving_path_folders[index]
    extractor = RoBertaExtractor(
        roberta_model,
        language,
        names[index],
        prediction_types[index],
        output_hidden_states=True,
        output_attentions=True,
        config_path=config_paths[index],
    )
    print(extractor.name, ' - Extracting activations ...')
    # The output folder is the same for all runs of a model: create it once.
    check_folder(output_folder)
    # Wrap the list itself (not the enumerate object) so that tqdm knows the
    # number of runs and can display a real progress bar.
    for run_index, iterator in enumerate(tqdm(iterator_list)):
        print("############# Run {} #############".format(run_index))
        extracted = extractor.extract_activations(iterator, language)
        hidden_states_activations = extracted[0]
        attention_heads_activations = extracted[1]
        (cls_hidden_states_activations, cls_attention_activations) = extracted[2]
        (sep_hidden_states_activations, sep_attention_activations) = extracted[3]
        # Put hidden-state and attention columns side by side in one frame.
        activations = pd.concat([hidden_states_activations, attention_heads_activations], axis=1)
        cls_activations = pd.concat([cls_hidden_states_activations, cls_attention_activations], axis=1)
        sep_activations = pd.concat([sep_hidden_states_activations, sep_attention_activations], axis=1)

        # Output files are numbered from 1, whereas run_index starts at 0.
        run_number = run_index + 1
        activations.to_csv(os.path.join(output_folder, 'activations_run{}.csv'.format(run_number)), index=False)
        cls_activations.to_csv(os.path.join(output_folder, 'cls_run{}.csv'.format(run_number)), index=False)
        sep_activations.to_csv(os.path.join(output_folder, 'sep_run{}.csv'.format(run_number)), index=False)
        
        

0it [00:00, ?it/s]

bert-base-cased  - Extracting activations ...
############# Run 0 #############


1it [00:59, 59.39s/it]

############# Run 1 #############


2it [02:01, 60.08s/it]

############# Run 2 #############


3it [03:48, 74.41s/it]

############# Run 3 #############


4it [04:55, 72.02s/it]

############# Run 4 #############


5it [06:02, 70.66s/it]

############# Run 5 #############


6it [07:21, 73.09s/it]

############# Run 6 #############


7it [09:14, 85.13s/it]

############# Run 7 #############


7it [10:05, 86.55s/it]


KeyboardInterrupt: 