# Notebook to extract hidden-state and attention-head activations from RoBERTa model predictions

In [17]:
import os
import glob
import gc
import numpy as np
import pandas as pd
from tqdm import tqdm
from model import RobertaExtractor
from tokenizer import tokenize


import torch
from sklearn.preprocessing import StandardScaler
from utils import set_seed
from numpy import linalg as la


In [18]:
def check_folder(path):
    """Create the directory `path`, including any missing parent directories.

    Does nothing when `path` is empty or already is an existing directory.

    Args:
        path: directory to create (absolute or relative).

    Raises:
        OSError: if the directory cannot be created, e.g. because of missing
            permissions or because a regular file already occupies part of
            the path. Such failures were previously swallowed silently.
    """
    if not path:
        # An empty path names no directory; os.makedirs('') would raise.
        return
    # exist_ok=True keeps the call idempotent: an already existing directory
    # is not an error, while genuine failures still surface to the caller.
    os.makedirs(path, exist_ok=True)

In [19]:
def transform(activations, path, name, run_index, n_layers_hidden=13, n_layers_attention=12, hidden_size=768, orders=(None,)):
    """Normalise each layer-sized block of columns and save the result as CSV.

    The columns of `activations` are treated as
    `n_layers_hidden + n_layers_attention` consecutive blocks of `hidden_size`
    columns. Every block is normalised independently, the blocks are stacked
    back side by side in their original order, and for each entry of `orders`
    one file is written to `<path>_norm-<order>/<name>_run<run_index + 1>.csv`.

    Args:
        activations: pandas DataFrame whose number of columns equals
            `(n_layers_hidden + n_layers_attention) * hidden_size`.
        path: prefix of the output folder; `'_norm-<order>'` is appended.
        name: prefix of the output file name.
        run_index: zero-based run index; the file name uses `run_index + 1`.
        n_layers_hidden: number of hidden-state blocks.
        n_layers_attention: number of attention blocks.
        hidden_size: number of columns per block.
        orders: iterable of normalisation modes applied one after the other.
            `None` (the default, and the only mode used so far) centres each
            column; `'std'` centres and scales each column with
            StandardScaler's `with_std=True`; any other value centres each
            column and then divides the block by the mean row norm computed
            with `numpy.linalg.norm(block, ord=value, axis=1)`.

    Raises:
        AssertionError: if the column count does not match the expected
            number of blocks times `hidden_size`.
    """
    n_blocks = n_layers_hidden + n_layers_attention
    # Fetch the underlying array once instead of once per block.
    values = activations.values
    assert values.shape[1] == n_blocks * hidden_size
    for order in orders:
        blocks = []
        for block_index in range(n_blocks):
            start = block_index * hidden_size
            block = values[:, start:start + hidden_size]
            # Columns are always centred; scaling by the standard deviation
            # is only enabled for the 'std' mode.
            scaler = StandardScaler(with_mean=True, with_std=(order == 'std'))
            block = scaler.fit_transform(block)
            if order is not None and order != 'std':
                # Rescale the whole block by its average row norm.
                block = block / np.mean(la.norm(block, ord=order, axis=1))
            blocks.append(block)
        new_data = pd.DataFrame(np.hstack(blocks), columns=activations.columns)
        new_path = path + '_norm-' + str(order).replace('np.', '')
        check_folder(new_path)
        new_data.to_csv(os.path.join(new_path, name + '_run{}.csv'.format(run_index + 1)), index=False)


Defining variables:

In [20]:
# Glob pattern for the text input; the `*` stands for the run identifier in the file name.
template = '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/text/english/text_english_run*.txt' # path to text input
# Language label: passed to the tokenizer and the extractor, and inserted into the output folder paths.
language = 'english'
# NOTE(review): not referenced by the visible cells below, which use the `prediction_types` list instead.
prediction_type = 'sentence'  # alternative value: 'sequential'


Creating an iterator for each run:

In [25]:
# Gather the files matching the template and order them lexicographically;
# the position in this list is later used as the run index.
paths = glob.glob(template)
paths.sort()

In [26]:
# Tokenize every run file (train=False), keeping the results in the same order as `paths`.
iterator_list = []
for run_path in paths:
    iterator_list.append(tokenize(run_path, language, train=False))

100%|██████████| 135/135 [00:00<00:00, 172263.78it/s]
100%|██████████| 135/135 [00:00<00:00, 208479.76it/s]
100%|██████████| 176/176 [00:00<00:00, 218014.62it/s]
100%|██████████| 173/173 [00:00<00:00, 243331.52it/s]
100%|██████████| 177/177 [00:00<00:00, 186436.92it/s]
100%|██████████| 216/216 [00:00<00:00, 205761.90it/s]
100%|██████████| 196/196 [00:00<00:00, 199292.99it/s]
100%|██████████| 145/145 [00:00<00:00, 183461.26it/s]
100%|██████████| 207/207 [00:00<00:00, 259480.25it/s]


## Activation extraction

In [27]:
# One extraction setting per entry: the value later passed to RobertaExtractor
# as `number_of_sentence_before`. All other per-setting lists are derived from
# this one so they cannot drift out of sync.
number_of_sentence_before_list = [0, 1, 2, 5, 7, 10, 15, 20]
n_settings = len(number_of_sentence_before_list)

number_of_sentence_list = [1] * n_settings
number_of_sentence_after_list = [0] * n_settings
pretrained_roberta_models = ['roberta-base'] * n_settings
config_paths = [None] * n_settings
prediction_types = ['sentence'] * n_settings

# Tag shared by the model name and its output folder, e.g. 'pre-5_1_post-0'.
context_tags = [
    'pre-{}_{}_post-{}'.format(n_before, n_current, n_after)
    for n_before, n_current, n_after in zip(
        number_of_sentence_before_list,
        number_of_sentence_list,
        number_of_sentence_after_list,
    )
]

# e.g. 'roberta-base_pre-5_1_post-0'
names = [
    '{}_{}'.format(model_name, tag)
    for model_name, tag in zip(pretrained_roberta_models, context_tags)
]

# Output folders live under the language-specific representations directory,
# e.g. '.../stimuli-representations/english/roberta_pre-5_1_post-0'.
representations_root = '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/stimuli-representations/{}'.format(language)
saving_path_folders = [
    '{}/roberta_{}'.format(representations_root, tag)
    for tag in context_tags
]

In [28]:
# For every extraction setting: build the extractor, run it over each run's
# iterator, and save the normalised hidden-state + attention activations.
for index, roberta_model in enumerate(pretrained_roberta_models):
    extractor = RobertaExtractor(roberta_model,
                                 language,
                                 names[index],
                                 prediction_types[index],
                                 output_hidden_states=True,
                                 output_attentions=True,
                                 config_path=config_paths[index],
                                 max_length=512,
                                 number_of_sentence=number_of_sentence_list[index],
                                 number_of_sentence_before=number_of_sentence_before_list[index],
                                 number_of_sentence_after=number_of_sentence_after_list[index])
    print(extractor.name, ' - Extracting activations ...')
    # Wrap the list itself (not the enumerate object) so tqdm knows the total
    # number of runs and can display a percentage and remaining time.
    for run_index, iterator in enumerate(tqdm(iterator_list)):
        # Reclaim memory left over from the previous run before extracting again.
        gc.collect()
        # The log shows the zero-based index; saved files use run_index + 1.
        print("############# Run {} #############".format(run_index))
        extracted = extractor.extract_activations(iterator, language)
        # extracted[0] holds the hidden-state activations and extracted[1] the
        # attention-head activations. extracted[2] and extracted[3] are the
        # (hidden-state, attention) pairs for the CLS and SEP tokens; they are
        # not saved by this notebook, so they are not concatenated here.
        activations = pd.concat([extracted[0], extracted[1]], axis=1)

        transform(
            activations,
            saving_path_folders[index],
            'activations',
            run_index=run_index,
            n_layers_hidden=13,
            n_layers_attention=12,
            hidden_size=768)

        # Drop the references so the large frames can be freed before the next run.
        del activations
        del extracted
        

        

0it [00:00, ?it/s]

distilroberta_pre-0_1_post-0  - Extracting activations ...
############# Run 0 #############


1it [00:37, 37.36s/it]

############# Run 1 #############


2it [01:06, 34.89s/it]

############# Run 2 #############


3it [01:38, 34.11s/it]

############# Run 3 #############


4it [02:09, 32.95s/it]

############# Run 4 #############


5it [02:38, 31.98s/it]

############# Run 5 #############


6it [03:16, 33.66s/it]

############# Run 6 #############


7it [04:32, 46.33s/it]

############# Run 7 #############


8it [05:10, 43.85s/it]

############# Run 8 #############


9it [05:51, 39.08s/it]
0it [00:00, ?it/s]

distilroberta_pre-1_1_post-0  - Extracting activations ...
############# Run 0 #############


1it [00:28, 28.23s/it]

############# Run 1 #############


2it [00:58, 28.74s/it]

############# Run 2 #############


3it [01:31, 30.02s/it]

############# Run 3 #############


4it [02:02, 30.32s/it]

############# Run 4 #############


5it [02:32, 30.27s/it]

############# Run 5 #############


6it [03:08, 31.90s/it]

############# Run 6 #############


7it [03:58, 37.46s/it]

############# Run 7 #############


8it [04:38, 38.12s/it]

############# Run 8 #############


9it [05:14, 34.97s/it]
0it [00:00, ?it/s]

distilroberta_pre-2_1_post-0  - Extracting activations ...
############# Run 0 #############


1it [00:31, 31.13s/it]

############# Run 1 #############


2it [01:00, 30.75s/it]

############# Run 2 #############


3it [01:34, 31.58s/it]

############# Run 3 #############


4it [02:05, 31.43s/it]

############# Run 4 #############


5it [02:35, 30.95s/it]

############# Run 5 #############


6it [03:09, 31.92s/it]

############# Run 6 #############


7it [04:10, 40.51s/it]

############# Run 7 #############


8it [04:39, 37.24s/it]

############# Run 8 #############


9it [05:15, 35.10s/it]
0it [00:00, ?it/s]

distilroberta_pre-5_1_post-0  - Extracting activations ...
############# Run 0 #############


1it [00:29, 29.54s/it]

############# Run 1 #############


2it [01:02, 30.56s/it]

############# Run 2 #############


3it [01:37, 31.89s/it]

############# Run 3 #############


4it [02:10, 32.34s/it]

############# Run 4 #############


5it [02:43, 32.29s/it]

############# Run 5 #############


6it [03:51, 43.07s/it]

############# Run 6 #############


7it [04:28, 41.45s/it]

############# Run 7 #############


8it [05:00, 38.61s/it]

############# Run 8 #############


9it [05:39, 37.71s/it]
0it [00:00, ?it/s]

distilroberta_pre-7_1_post-0  - Extracting activations ...
############# Run 0 #############


1it [00:30, 30.12s/it]

############# Run 1 #############


2it [01:02, 30.81s/it]

############# Run 2 #############


3it [01:38, 32.36s/it]

############# Run 3 #############


4it [02:43, 42.21s/it]

############# Run 4 #############


5it [03:30, 43.62s/it]

############# Run 5 #############


6it [04:09, 42.09s/it]

############# Run 6 #############


7it [04:47, 41.05s/it]

############# Run 7 #############


8it [05:20, 38.50s/it]

############# Run 8 #############


9it [05:59, 39.98s/it]
0it [00:00, ?it/s]

distilroberta_pre-10_1_post-0  - Extracting activations ...
############# Run 0 #############


1it [00:31, 31.18s/it]

############# Run 1 #############


2it [01:26, 38.37s/it]

############# Run 2 #############


3it [02:23, 44.16s/it]

############# Run 3 #############


4it [02:58, 41.38s/it]

############# Run 4 #############


5it [03:32, 39.13s/it]

############# Run 5 #############


6it [04:12, 39.41s/it]

############# Run 6 #############


7it [04:52, 39.53s/it]

############# Run 7 #############


8it [05:25, 37.57s/it]

############# Run 8 #############


9it [06:24, 42.69s/it]
0it [00:00, ?it/s]

distilroberta_pre-15_1_post-0  - Extracting activations ...
############# Run 0 #############


1it [00:46, 46.31s/it]

############# Run 1 #############


2it [01:20, 42.65s/it]

############# Run 2 #############


3it [01:58, 41.14s/it]

############# Run 3 #############


4it [02:33, 39.44s/it]

############# Run 4 #############


5it [03:08, 38.12s/it]

############# Run 5 #############


6it [03:49, 38.84s/it]

############# Run 6 #############


7it [04:29, 39.29s/it]

############# Run 7 #############


8it [05:25, 44.32s/it]

############# Run 8 #############


9it [06:09, 41.11s/it]
0it [00:00, ?it/s]

distilroberta_pre-20_1_post-0  - Extracting activations ...
############# Run 0 #############


1it [00:33, 33.20s/it]

############# Run 1 #############


2it [01:09, 34.22s/it]

############# Run 2 #############


3it [01:49, 35.90s/it]

############# Run 3 #############


4it [02:27, 36.36s/it]

############# Run 4 #############


5it [03:02, 36.16s/it]

############# Run 5 #############


6it [04:24, 49.76s/it]

############# Run 6 #############


7it [05:06, 47.56s/it]

############# Run 7 #############


8it [05:41, 43.88s/it]

############# Run 8 #############


9it [06:24, 42.77s/it]
