# Notebook to extract hidden-state and attention-head activations from DistilBERT model predictions

In [11]:
import gc
import glob
import os
import re

import numpy as np
import pandas as pd
import torch
from numpy import linalg as la
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

from model import DistilBertExtractor
from tokenizer import tokenize
from utils import set_seed


In [13]:
def check_folder(path):
    """Create ``path`` and any missing parent directories (like ``mkdir -p``).

    Calling it on an existing directory is a no-op; an empty ``path`` is ignored.

    Raises:
        OSError: if the directory cannot be created (e.g. permission denied, or
            ``path`` already exists as a regular file). Errors are deliberately
            not swallowed: callers write files into this folder right away, and
            a silently missing folder would only fail later with a less
            informative error.
    """
    if not path:
        return
    os.makedirs(path, exist_ok=True)

In [14]:
def transform(activations, path, name, run_index, n_layers_hidden=13, n_layers_attention=12, hidden_size=768, orders=(None,)):
    """Center (and optionally scale/normalize) activations layer by layer, then save them as CSV.

    The columns of ``activations`` are split into consecutive chunks of
    ``hidden_size`` columns, one per layer: ``n_layers_hidden`` hidden-state
    chunks followed by ``n_layers_attention`` attention chunks. Each chunk is
    processed independently. The chunks are then stacked back together under the
    original column names, and the result is written to
    ``<path>_norm-<order>/<name>_run<run_index + 1>.csv`` for every ``order``.

    Args:
        activations: DataFrame with exactly
            ``(n_layers_hidden + n_layers_attention) * hidden_size`` columns.
        path: base output folder; a ``_norm-<order>`` suffix is appended.
        name: file-name prefix (e.g. ``'activations'``).
        run_index: 0-based run index; output files are numbered from 1.
        n_layers_hidden: number of hidden-state layer chunks.
        n_layers_attention: number of attention layer chunks.
        hidden_size: number of columns per layer chunk.
        orders: processing variants to produce, one output folder per entry:
            ``None`` -> mean-centering only (default);
            ``'std'`` -> mean-centering and scaling to unit variance;
            any other value -> mean-centering, then division by the mean row
            norm computed with ``numpy.linalg.norm(..., ord=order, axis=1)``.

    Raises:
        AssertionError: if the column count does not match the layer layout.
    """
    n_layers = n_layers_hidden + n_layers_attention
    # Fetch the underlying array once instead of once per layer.
    values = activations.values
    assert values.shape[1] == n_layers * hidden_size, (
        'Expected {} columns ({} layers x {}), got {}'.format(
            n_layers * hidden_size, n_layers, hidden_size, values.shape[1]))
    for order in orders:
        with_std = order == 'std'
        matrices = []
        for layer in range(n_layers):
            matrix = values[:, layer * hidden_size:(layer + 1) * hidden_size]
            scaler = StandardScaler(with_mean=True, with_std=with_std)
            matrix = scaler.fit_transform(matrix)
            if order is not None and order != 'std':
                mean_norm = np.mean(la.norm(matrix, ord=order, axis=1))
                # If every row is identical, all norms are zero after centering;
                # dividing would fill the output with NaN/inf, so skip it.
                if mean_norm > 0:
                    matrix = matrix / mean_norm
            matrices.append(matrix)
        new_data = pd.DataFrame(np.hstack(matrices), columns=activations.columns)
        new_path = path + '_norm-' + str(order).replace('np.', '')
        check_folder(new_path)
        new_data.to_csv(os.path.join(new_path, name + '_run{}.csv'.format(run_index + 1)), index=False)


Defining input variables (text-file path template and language):

In [15]:
# Language of the stimuli; it also selects the matching text folder and file names below.
language = 'english'
# Glob pattern matching one text file per run (run1, run2, ...). It is derived
# from `language` so that changing the language cannot leave a stale path behind.
template = '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/text/{0}/text_{0}_run*.txt'.format(language)

Creating one tokenized iterator per run:

In [16]:
#template = '/Users/alexpsq/Code/Parietal/data/text_english_run*.txt' # path to text input


In [17]:
def _run_sort_key(path):
    """Order run files by their integer run number, so 'run2' sorts before 'run10'."""
    match = re.search(r'run(\d+)', os.path.basename(path))
    # Files without a numeric run id (not expected) go last, in lexicographic order.
    return (0, int(match.group(1)), path) if match else (1, 0, path)


# A plain string sort would place 'run10' before 'run2'. The position in `paths`
# becomes the run index used to name the output files, so sort numerically.
paths = sorted(glob.glob(template), key=_run_sort_key)
assert paths, 'No text file matches the template {}'.format(template)

In [18]:
# Tokenize every run's text file, keeping the same order as `paths`.
iterator_list = []
for run_path in paths:
    iterator_list.append(tokenize(run_path, language, train=False))

100%|██████████| 135/135 [00:00<00:00, 153408.57it/s]
100%|██████████| 135/135 [00:00<00:00, 103591.48it/s]
100%|██████████| 176/176 [00:00<00:00, 194723.69it/s]
100%|██████████| 173/173 [00:00<00:00, 199783.75it/s]
100%|██████████| 177/177 [00:00<00:00, 218607.72it/s]
100%|██████████| 216/216 [00:00<00:00, 249578.42it/s]
100%|██████████| 196/196 [00:00<00:00, 233281.38it/s]
100%|██████████| 145/145 [00:00<00:00, 155463.72it/s]
100%|██████████| 207/207 [00:00<00:00, 189733.59it/s]


In [19]:
# For 'uncased' models
#iterator_list = [ [sent.lower() for sent in run ] for run in iterator_list ] 

## Activation extraction

In [20]:
# Values passed as `number_of_sentence_before` to the extractor; one model
# variant is extracted per value (0, 1, 2 and 5 were used in earlier runs).
number_of_sentence_before_list = [7, 10, 15, 20]
n_models = len(number_of_sentence_before_list)

# Every per-model list has exactly `n_models` entries, so the lists indexed
# together in the extraction loop stay aligned.
pretrained_distilbert_models = ['distilbert-base-cased'] * n_models
config_paths = [None] * n_models
prediction_types = ['sentence'] * n_models
number_of_sentence_list = [1] * n_models
number_of_sentence_after_list = [0] * n_models

# Names ('<model>_pre-<before>_<n>_post-<after>') and output folders are derived
# from the settings above rather than typed by hand, so they cannot drift out of sync.
names = [
    '{}_pre-{}_{}_post-{}'.format(model, n_before, n_sentence, n_after)
    for model, n_before, n_sentence, n_after in zip(
        pretrained_distilbert_models,
        number_of_sentence_before_list,
        number_of_sentence_list,
        number_of_sentence_after_list)
]
saving_root = '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/stimuli-representations/{}'.format(language)
saving_path_folders = [os.path.join(saving_root, model_folder) for model_folder in names]

In [21]:
# For each model variant, extract activations for every run and save the
# mean-centered hidden states as one CSV file per run.
for index, distilbert_model in enumerate(pretrained_distilbert_models):
    extractor = DistilBertExtractor(
        distilbert_model,
        language,
        names[index],
        prediction_types[index],
        output_hidden_states=True,
        output_attentions=False,  # only hidden states are saved below
        config_path=config_paths[index],
        max_length=512,
        number_of_sentence=number_of_sentence_list[index],
        number_of_sentence_before=number_of_sentence_before_list[index],
        number_of_sentence_after=number_of_sentence_after_list[index])
    print(extractor.name, ' - Extracting activations ...')
    # An `enumerate` object has no length, so pass `total` explicitly; tqdm can
    # then show "k/N" progress and an ETA instead of a bare counter.
    for run_index, iterator in tqdm(enumerate(iterator_list), total=len(iterator_list)):
        gc.collect()
        print("############# Run {} #############".format(run_index))
        activations = extractor.extract_activations(iterator, language)
        # Element 0 holds the hidden-state activations; other elements are not used here.
        hidden_states_activations = activations[0]

        transform(
            hidden_states_activations,
            saving_path_folders[index],
            'activations',
            run_index=run_index,
            # NOTE(review): 7 is presumably the embedding output plus the 6
            # DistilBERT transformer layers -- confirm against the extractor.
            n_layers_hidden=7,
            n_layers_attention=0,
            hidden_size=768)

        # Drop references to the large frames before the next run.
        del activations
        del hidden_states_activations
        

0it [00:00, ?it/s]

distilbert-base-cased_pre-7_1_post-0  - Extracting activations ...
############# Run 0 #############


1it [00:44, 44.43s/it]

############# Run 1 #############


2it [01:22, 42.64s/it]

############# Run 2 #############


3it [01:44, 36.27s/it]

############# Run 3 #############


4it [02:04, 31.35s/it]

############# Run 4 #############


5it [02:27, 28.91s/it]

############# Run 5 #############


6it [04:33, 58.08s/it]

############# Run 6 #############


7it [04:59, 48.50s/it]

############# Run 7 #############


8it [05:50, 49.15s/it]

############# Run 8 #############


9it [06:31, 43.47s/it]
0it [00:00, ?it/s]

distilbert-base-cased_pre-10_1_post-0  - Extracting activations ...
############# Run 0 #############


1it [00:16, 16.88s/it]

############# Run 1 #############


2it [01:40, 36.92s/it]

############# Run 2 #############


3it [02:57, 48.94s/it]

############# Run 3 #############


4it [03:17, 40.38s/it]

############# Run 4 #############


5it [03:37, 34.17s/it]

############# Run 5 #############


6it [04:32, 40.49s/it]

############# Run 6 #############


7it [05:11, 40.00s/it]

############# Run 7 #############


8it [05:29, 33.43s/it]

############# Run 8 #############


9it [05:54, 39.36s/it]
0it [00:00, ?it/s]

distilbert-base-cased_pre-15_1_post-0  - Extracting activations ...
############# Run 0 #############


1it [00:17, 18.00s/it]

############# Run 1 #############


2it [00:37, 18.39s/it]

############# Run 2 #############


3it [00:59, 19.52s/it]

############# Run 3 #############


4it [01:19, 19.73s/it]

############# Run 4 #############


5it [01:40, 19.96s/it]

############# Run 5 #############


6it [02:04, 21.32s/it]

############# Run 6 #############


7it [02:27, 21.90s/it]

############# Run 7 #############


8it [03:10, 28.11s/it]

############# Run 8 #############


9it [04:00, 26.72s/it]
0it [00:00, ?it/s]

distilbert-base-cased_pre-20_1_post-0  - Extracting activations ...
############# Run 0 #############


1it [00:18, 18.59s/it]

############# Run 1 #############


2it [00:39, 19.22s/it]

############# Run 2 #############


3it [01:02, 20.34s/it]

############# Run 3 #############


4it [01:23, 20.69s/it]

############# Run 4 #############


5it [01:43, 20.55s/it]

############# Run 5 #############


6it [02:08, 21.77s/it]

############# Run 6 #############


7it [02:33, 22.78s/it]

############# Run 7 #############


8it [02:53, 21.94s/it]

############# Run 8 #############


9it [03:18, 22.04s/it]
