# Notebook to extract hidden-state and attention-head activations from DistilBERT model predictions

In [7]:
import os
import glob
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from model import DistilBertExtractor
from sklearn.preprocessing import StandardScaler
from tokenizer import tokenize
from utils import set_seed
from numpy import linalg as la


In [2]:
def check_folder(path):
    """Create the directory `path`, including any missing parent folders.

    Nothing happens when `path` is empty/None or already an existing
    directory. Creation is best-effort: if the folder cannot be created
    (e.g. permission denied, or a regular file occupies part of the path)
    a RuntimeWarning is emitted instead of raising, so this helper never
    propagates an OSError to its caller.

    Args:
        path (str): Directory to create.

    Returns:
        None
    """
    if not path or os.path.isdir(path):
        return
    try:
        # exist_ok=True tolerates the folder being created by someone else
        # between the isdir() check above and this call.
        os.makedirs(path, exist_ok=True)
    except OSError as error:
        import warnings
        warnings.warn(
            "Could not create folder '{}': {}".format(path, error),
            RuntimeWarning,
        )

In [3]:
def transform(activations, path, name, run_index, n_layers_hidden=13, n_layers_attention=12, hidden_size=768, orders=(None,)):
    """Normalise activations layer by layer and write them to CSV.

    The columns of `activations` are treated as
    `n_layers_hidden + n_layers_attention` consecutive blocks of
    `hidden_size` columns. Each block is normalised independently, the
    blocks are stacked back side by side (keeping the original column
    names), and the result is saved to
    `<path>_norm-<order>/<name>_run<run_index + 1>.csv`.

    For every `order` in `orders`:
        * None  -> each column is mean-centred only;
        * 'std' -> each column is mean-centred and scaled to unit variance;
        * other -> each column is mean-centred, then the block is divided
          by the mean of its row norms (`numpy.linalg.norm` with
          `ord=order`, `axis=1`).

    Args:
        activations (pandas.DataFrame): Activations, one row per sample.
        path (str): Prefix of the output folder ('_norm-<order>' is appended).
        name (str): Prefix of the output file name.
        run_index (int): Zero-based run index; the file name uses run_index + 1.
        n_layers_hidden (int): Number of hidden-state blocks.
        n_layers_attention (int): Number of attention blocks.
        hidden_size (int): Number of columns per block.
        orders (iterable): Normalisations to apply; one output folder is
            written per entry. Defaults to `(None,)` (mean-centring only).

    Raises:
        AssertionError: If the number of columns does not equal
            `(n_layers_hidden + n_layers_attention) * hidden_size`.
    """
    values = activations.values
    n_blocks = n_layers_hidden + n_layers_attention
    assert values.shape[1] == n_blocks * hidden_size, (
        'Expected {} columns ({} blocks of {}), got {}.'.format(
            n_blocks * hidden_size, n_blocks, hidden_size, values.shape[1]
        )
    )
    for order in orders:
        scale_to_unit_variance = order == 'std'
        blocks = []
        for start in range(0, n_blocks * hidden_size, hidden_size):
            block = values[:, start:start + hidden_size]
            scaler = StandardScaler(with_mean=True, with_std=scale_to_unit_variance)
            block = scaler.fit_transform(block)
            if order is not None and not scale_to_unit_variance:
                # Rescale the whole block by its average row norm.
                block = block / np.mean(la.norm(block, ord=order, axis=1))
            blocks.append(block)
        new_data = pd.DataFrame(np.hstack(blocks), columns=activations.columns)
        new_path = path + '_norm-' + str(order).replace('np.', '')
        check_folder(new_path)
        new_data.to_csv(os.path.join(new_path, name + '_run{}.csv'.format(run_index + 1)), index=False)


Defining variables:

In [11]:
# Glob pattern for the per-run text files used as model input.
# NOTE(review): a later cell reassigns `template` to a different path before
# it is passed to glob, so this value is not the one actually used - confirm
# which location is intended.
template = '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/text/english/text_english_run*.txt' # path to text input
# Language identifier passed to the tokenizer and the extractor, and used to
# build the output folder paths.
language = 'english'

Creating an iterator for each run:

In [12]:
# Overrides the `template` assigned in the earlier configuration cell with an
# absolute path on a local machine; this is the value the glob call below sees.
# NOTE(review): hard-coded user-specific path - adjust before running elsewhere.
template = '/Users/alexpsq/Code/Parietal/data/text_english_run*.txt' # path to text input


In [13]:
# Expand the glob pattern into the list of run files, sorted by path string.
paths = glob.glob(template)
paths.sort()

In [14]:
# Tokenize every run file; one entry per path, in the same order as `paths`.
iterator_list = list(map(lambda run_path: tokenize(run_path, language, train=False), paths))

100%|██████████| 135/135 [00:00<00:00, 523318.89it/s]
100%|██████████| 135/135 [00:00<00:00, 675693.37it/s]
100%|██████████| 176/176 [00:00<00:00, 687334.73it/s]
100%|██████████| 173/173 [00:00<00:00, 864856.49it/s]
100%|██████████| 177/177 [00:00<00:00, 548295.28it/s]
100%|██████████| 216/216 [00:00<00:00, 943718.40it/s]
100%|██████████| 196/196 [00:00<00:00, 597444.47it/s]
100%|██████████| 145/145 [00:00<00:00, 568386.99it/s]
100%|██████████| 207/207 [00:00<00:00, 845395.26it/s]


In [62]:
# For 'uncased' models
#iterator_list = [ [sent.lower() for sent in run ] for run in iterator_list ] 

[['Once , when I was six years old , I saw a magnificent picture in a book about the primeval forest called ‘ Real - life Stories . ’',
  'It showed a boa constrictor swallowing a wild animal .',
  'Here is a copy of the drawing .',
  'It said in the book : “ Boa constrictors swallow their prey whole , without chewing .',
  'Then they are not able to move , and they sleep for the six months it takes for digestion . ”',
  'So I thought a lot about the adventures of the jungle and , in turn , I managed , with a coloured pencil , to make my first drawing .',
  'My Drawing Number one .',
  'It looked like this : I showed my masterpiece to the grownups and I asked them if my drawing frightened them .',
  'They answered me : “ Why would anyone be frightened by a hat ? ”',
  'My drawing was not of a hat .',
  'It showed a boa constrictor digesting an elephant .',
  'I then drew the inside of the boa constrictor , so that the grownups could understand .',
  'They always need to have things exp

## Activation extraction

In [13]:
# One extraction configuration per context size: the lists below are parallel
# and are indexed together by the extraction loop.

# Number of sentences given as context before the sentence of interest.
number_of_sentence_before_list = [0, 1, 2, 5, 7, 10, 15, 20]
number_of_sentence_after_list = [0 for _ in number_of_sentence_before_list]
number_of_sentence_list = [1 for _ in number_of_sentence_before_list]

pretrained_distilbert_models = ['distilbert-base-cased' for _ in number_of_sentence_before_list]
config_paths = [None for _ in number_of_sentence_before_list]
prediction_types = ['sentence' for _ in number_of_sentence_before_list]

# Configuration names encode the context size: pre-<before>_1_post-0.
names = [
    'distilbert-base-cased_pre-{}_1_post-0'.format(n_before)
    for n_before in number_of_sentence_before_list
]

# Each configuration is saved in a folder named after it, under the
# language-specific stimuli-representations directory.
saving_path_folders = [
    '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/stimuli-representations/{}/{}'.format(language, config_name)
    for config_name in names
]

In [14]:
# For every model configuration, extract the hidden-state activations of each
# run, normalise them with `transform` and write one CSV per run.
for index, distilbert_model in enumerate(pretrained_distilbert_models):
    extractor = DistilBertExtractor(
        distilbert_model,
        language,
        names[index],
        prediction_types[index],
        output_hidden_states=True,
        output_attentions=False,
        config_path=config_paths[index],
        max_length=512,
        number_of_sentence=number_of_sentence_list[index],
        number_of_sentence_before=number_of_sentence_before_list[index],
        number_of_sentence_after=number_of_sentence_after_list[index],
    )
    print(extractor.name, ' - Extracting activations ...')
    # `total` lets tqdm show a real progress bar; enumerate() has no length.
    for run_index, iterator in tqdm(enumerate(iterator_list), total=len(iterator_list)):
        print("############# Run {} #############".format(run_index))
        activations = extractor.extract_activations(iterator, language)
        # Only the first element (hidden states) is used here. Earlier,
        # commented-out versions of this cell also read activations[1]
        # (attention heads), activations[2] (CLS) and activations[3] (SEP).
        hidden_states_activations = activations[0]

        # NOTE(review): n_layers_hidden=13 must match the number of
        # hidden_size-wide column blocks returned by the extractor, otherwise
        # the assertion inside `transform` fails - confirm this value for the
        # DistilBERT checkpoint used here.
        transform(
            hidden_states_activations,
            saving_path_folders[index],
            'activations',
            run_index=run_index,
            n_layers_hidden=13,
            n_layers_attention=0,
            hidden_size=768,
        )

        # Drop the references before the next run so the large frames can be
        # garbage-collected.
        del activations
        del hidden_states_activations
        

0it [00:00, ?it/s]

bert-base-cased_pre-15_1_post-0  - Extracting activations ...
############# Run 0 #############


1it [05:01, 301.20s/it]

############# Run 1 #############


2it [10:30, 309.72s/it]

############# Run 2 #############


3it [16:47, 329.75s/it]

############# Run 3 #############


4it [22:50, 339.74s/it]

############# Run 4 #############


5it [28:09, 333.71s/it]

############# Run 5 #############


6it [34:25, 346.16s/it]

############# Run 6 #############


7it [40:47, 357.10s/it]

############# Run 7 #############


8it [46:05, 345.43s/it]

############# Run 8 #############


9it [53:16, 355.17s/it]
0it [00:00, ?it/s]

bert-base-cased_pre-20_1_post-0  - Extracting activations ...
############# Run 0 #############


1it [05:01, 301.66s/it]

############# Run 1 #############


2it [10:32, 310.45s/it]

############# Run 2 #############


3it [16:29, 324.52s/it]

############# Run 3 #############


4it [22:10, 329.46s/it]

############# Run 4 #############


5it [27:36, 328.26s/it]

############# Run 5 #############


6it [34:04, 346.34s/it]

############# Run 6 #############


7it [40:36, 359.79s/it]

############# Run 7 #############


8it [45:59, 348.78s/it]

############# Run 8 #############


9it [53:37, 357.53s/it]
