# Notebook to extract hidden-state and attention-head activations from BERT model predictions

In [1]:
import glob
import os
import re

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from model import BertExtractor
from tokenizer import tokenize
from utils import set_seed

In [2]:
def check_folder(path):
    """Create the folder `path`, including any missing parent folders.

    Does nothing when `path` is empty or already an existing directory, so it
    is safe to call repeatedly before writing files into the folder.

    Args:
        path (str): folder to create.

    Raises:
        OSError: if the folder cannot be created (e.g. permission denied, or a
            non-directory file already exists at `path`). Errors are not
            silenced on purpose: the caller writes into this folder right
            afterwards and would otherwise fail with a less informative error.
    """
    if not path:
        # os.makedirs('') raises; an empty path means "current directory".
        return
    os.makedirs(path, exist_ok=True)

Defining the input text files and the language:

In [3]:
language = 'english'
# Path to the input texts, one file per run ('*' matches the run number).
template = '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/text/{0}/text_{0}_run*.txt'.format(language)

In [4]:
# Fine-tuned checkpoints from which we want to retrieve the activations.
_FINE_TUNED_ROOT = '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/models/BERT'
pretrained_bert_models = [
    '{}/{}/fine_tuned'.format(_FINE_TUNED_ROOT, checkpoint)
    for checkpoint in (
        'SENTENCE-CLASSIFICATION_SST-2_bert_base_cased',
        'SENTENCE-CLASSIFICATION_COLA_bert_base_cased',
        'POS_CONLL2003_bert_base_cased',
        'NER_CONLL2003_bert_base_cased',
    )
]
# Each checkpoint folder name is split on '_', e.g.
# ['SENTENCE-CLASSIFICATION', 'SST-2', 'bert', 'base', 'cased'].
infos = [os.path.basename(os.path.dirname(path)).split('_') for path in pretrained_bert_models]
# Output names look like 'bert_SENTENCE-CLASSIFICATION_SST-2_fine_tuned'.
names = [
    '_'.join([info[2], info[0], info[1], os.path.basename(path)])
    for info, path in zip(infos, pretrained_bert_models)
]
# The config file is looked up next to the 'fine_tuned' folder.
config_paths = [os.path.join(os.path.dirname(path), 'config.yml') for path in pretrained_bert_models]
saving_path_folders = [
    '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/stimuli-representations/{language}/{name}'.format(language=language, name=name)
    for name in names
]
prediction_types = ['sentence'] * len(pretrained_bert_models)

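Adding the base `bert-base-cased` model, once with sentence-level and once with sequential predictions (the iterators for each run are created further below):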

In [5]:
# Add the base `bert-base-cased` model (no config file) twice: once with
# sentence-level predictions and once with sequential predictions.
pretrained_bert_models.extend(['bert-base-cased'] * 2)
names.extend(['bert-base-cased'] * 2)
config_paths.extend([None] * 2)
saving_path_folders.extend(
    '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/stimuli-representations/{}/{}'.format(language, folder)
    for folder in ('bert-base-cased', 'bert-base-cased_sequential')
)
prediction_types.extend(['sentence', 'sequential'])

In [6]:
# Override the lists above: only re-extract the base model with sentence-level
# predictions, into a separate "_duplicata" folder.
prediction_types = ['sentence']
config_paths = [None]
pretrained_bert_models = ['bert-base-cased']
names = list(pretrained_bert_models)
saving_path_folders = [
    '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/stimuli-representations/{}/{}_duplicata'.format(language, 'bert-base-cased')
]

In [7]:
# Sanity check: model names used by the cells below.
list(names)

['bert-base-cased']

In [8]:
# Alternative `template` pointing to another copy of the texts (kept for reference);
# uncomment to override the `template` defined above.
#template = '/Users/alexpsq/Code/Parietal/data/text_english_run*.txt' # path to text input


In [9]:
def _run_number(path):
    """Sort key ordering run files by the integer after 'run' in their name.

    A plain lexicographic sort would put 'run10' before 'run2'; the files are
    saved below as run{index + 1}, so the order must follow the run number.
    Files without a run number are placed last, ordered by path.
    """
    match = re.search(r'run(\d+)', os.path.basename(path))
    return (int(match.group(1)) if match else float('inf'), path)


paths = sorted(glob.glob(template), key=_run_number)
# Fail early: an empty list would make every extraction loop below silently do nothing.
assert paths, 'No input text found for template {}'.format(template)

In [10]:
# One tokenized iterator per run, in the same order as `paths`.
iterator_list = list(tokenize(run_path, language, train=False) for run_path in paths)

100%|██████████| 135/135 [00:00<00:00, 238815.28it/s]
100%|██████████| 135/135 [00:00<00:00, 251434.74it/s]
100%|██████████| 176/176 [00:00<00:00, 280150.86it/s]
100%|██████████| 173/173 [00:00<00:00, 274023.64it/s]
100%|██████████| 177/177 [00:00<00:00, 296127.57it/s]
100%|██████████| 216/216 [00:00<00:00, 330645.86it/s]
100%|██████████| 196/196 [00:00<00:00, 178713.82it/s]
100%|██████████| 145/145 [00:00<00:00, 234725.62it/s]
100%|██████████| 207/207 [00:00<00:00, 332651.70it/s]

Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.





## Activation extraction

In [11]:
# Sequential predictions with a varying left context. Each configuration is
# named 'pre-{before}_{current}_post-{after}' after the number of sentences
# passed to BertExtractor (number_of_sentence_before / _sentence / _after).
number_of_sentence_before_list = [0, 1, 5, 10, 15, 20]
number_of_sentence_list = [1] * len(number_of_sentence_before_list)
number_of_sentence_after_list = [0] * len(number_of_sentence_before_list)

# Every other per-model list is derived from the context sizes above, so that
# adding or removing a configuration cannot desynchronize the parallel lists.
pretrained_bert_models = ['bert-base-cased'] * len(number_of_sentence_before_list)
names = [
    'bert-base-cased_seq_pre-{}_{}_post-{}'.format(before, current, after)
    for before, current, after in zip(number_of_sentence_before_list,
                                      number_of_sentence_list,
                                      number_of_sentence_after_list)
]
config_paths = [None] * len(names)
saving_path_folders = [
    '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/stimuli-representations/{}/{}'.format(language, name)
    for name in names
]
prediction_types = ['sequential'] * len(names)

In [12]:
def _save_run_activations(activations, saving_folder, run_number):
    """Concatenate the activations extracted for one run and write them to CSV.

    Args:
        activations: value returned by ``BertExtractor.extract_activations``,
            unpacked as ``(hidden_states, attention_heads,
            (cls_hidden_states, cls_attention), (sep_hidden_states, sep_attention))``;
            each pair is concatenated column-wise with ``pd.concat``.
        saving_folder (str): existing folder in which the files are written.
        run_number (int): 1-based run number used in the file names.

    Writes ``activations_run{n}.csv``, ``cls_run{n}.csv`` and ``sep_run{n}.csv``.
    The large frames only live in this function, so they are released as soon
    as it returns, before the next run is extracted.
    """
    hidden_states, attention_heads = activations[0], activations[1]
    cls_hidden_states, cls_attention = activations[2]
    sep_hidden_states, sep_attention = activations[3]
    frames = {
        'activations': pd.concat([hidden_states, attention_heads], axis=1),
        'cls': pd.concat([cls_hidden_states, cls_attention], axis=1),
        'sep': pd.concat([sep_hidden_states, sep_attention], axis=1),
    }
    for prefix, frame in frames.items():
        frame.to_csv(os.path.join(saving_folder, '{}_run{}.csv'.format(prefix, run_number)), index=False)


for index, bert_model in enumerate(pretrained_bert_models):
    extractor = BertExtractor(bert_model,
                              language,
                              names[index],
                              prediction_types[index],
                              output_hidden_states=True,
                              output_attentions=True,
                              config_path=config_paths[index],
                              max_length=512,
                              number_of_sentence=number_of_sentence_list[index],
                              number_of_sentence_before=number_of_sentence_before_list[index],
                              number_of_sentence_after=number_of_sentence_after_list[index])
    print(extractor.name, ' - Extracting activations ...')
    # The output folder only depends on the model: create it once, not once per run.
    check_folder(saving_path_folders[index])
    # Wrapping the list (rather than `enumerate`) lets tqdm know the total number of runs.
    for run_index, iterator in enumerate(tqdm(iterator_list, desc=extractor.name)):
        print("############# Run {} #############".format(run_index))
        _save_run_activations(extractor.extract_activations(iterator, language),
                              saving_path_folders[index],
                              run_index + 1)
        

0it [00:00, ?it/s]

bert-base-cased_seq_pre-0_1_post-0  - Extracting activations ...
############# Run 0 #############


1it [03:35, 215.96s/it]

############# Run 1 #############


2it [08:28, 238.87s/it]

############# Run 2 #############


3it [13:23, 255.90s/it]

############# Run 3 #############


4it [18:22, 268.72s/it]

############# Run 4 #############


5it [23:14, 275.57s/it]

############# Run 5 #############


6it [28:33, 288.59s/it]

############# Run 6 #############


7it [34:06, 302.02s/it]

############# Run 7 #############


8it [38:30, 290.64s/it]

############# Run 8 #############


9it [44:13, 294.85s/it]
0it [00:00, ?it/s]

bert-base-cased_seq_pre-1_1_post-0  - Extracting activations ...
############# Run 0 #############


1it [03:54, 234.94s/it]

############# Run 1 #############


2it [08:53, 253.98s/it]

############# Run 2 #############


3it [14:00, 269.86s/it]

############# Run 3 #############


4it [18:41, 273.29s/it]

############# Run 4 #############


5it [23:14, 273.08s/it]

############# Run 5 #############


6it [27:47, 273.24s/it]

############# Run 6 #############


7it [32:16, 271.94s/it]

############# Run 7 #############


8it [35:59, 257.20s/it]

############# Run 8 #############


9it [41:09, 274.34s/it]
0it [00:00, ?it/s]

bert-base-cased_seq_pre-5_1_post-0  - Extracting activations ...
############# Run 0 #############


1it [04:40, 280.65s/it]

############# Run 1 #############


2it [10:01, 292.72s/it]

############# Run 2 #############


3it [15:44, 307.85s/it]

############# Run 3 #############


4it [20:37, 303.47s/it]

############# Run 4 #############


5it [25:24, 298.35s/it]

############# Run 5 #############


6it [31:02, 310.20s/it]

############# Run 6 #############


7it [37:30, 333.73s/it]

############# Run 7 #############


8it [43:02, 333.21s/it]

############# Run 8 #############


9it [49:30, 330.05s/it]
0it [00:00, ?it/s]

bert-base-cased_seq_pre-10_1_post-0  - Extracting activations ...
############# Run 0 #############


1it [05:01, 301.35s/it]

############# Run 1 #############


2it [10:34, 310.92s/it]

############# Run 2 #############


3it [17:17, 338.52s/it]

############# Run 3 #############


4it [23:43, 352.79s/it]

############# Run 4 #############


5it [29:45, 355.42s/it]

############# Run 5 #############


6it [36:01, 361.56s/it]

############# Run 6 #############


7it [43:05, 380.37s/it]

############# Run 7 #############


8it [49:13, 376.68s/it]

############# Run 8 #############


9it [56:14, 374.99s/it]
0it [00:00, ?it/s]

bert-base-cased_seq_pre-15_1_post-0  - Extracting activations ...
############# Run 0 #############


1it [06:04, 364.15s/it]

############# Run 1 #############


2it [12:55, 378.43s/it]

############# Run 2 #############


3it [19:30, 383.33s/it]

############# Run 3 #############


4it [26:05, 386.91s/it]

############# Run 4 #############


5it [32:25, 384.58s/it]

############# Run 5 #############


6it [39:28, 396.30s/it]

############# Run 6 #############


7it [46:42, 407.56s/it]

############# Run 7 #############


8it [53:12, 402.15s/it]

############# Run 8 #############


9it [1:00:14, 401.57s/it]
0it [00:00, ?it/s]

bert-base-cased_seq_pre-20_1_post-0  - Extracting activations ...
############# Run 0 #############


1it [06:38, 398.47s/it]

############# Run 1 #############


2it [14:12, 415.22s/it]

############# Run 2 #############


3it [21:50, 427.86s/it]

############# Run 3 #############


4it [28:48, 425.10s/it]

############# Run 4 #############


5it [34:20, 397.16s/it]

############# Run 5 #############


6it [42:02, 416.64s/it]

############# Run 6 #############


7it [49:48, 431.43s/it]

############# Run 7 #############


8it [56:47, 427.67s/it]

############# Run 8 #############


9it [1:04:46, 431.84s/it]


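### Generate control activations

Control models: `bert-base-cased` with either its embeddings (layer 0) or a single encoder layer (layers 1-12) replaced by random weights, repeated for several seeds. Only the columns of the randomized layer are saved.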

In [76]:
# Settings for the control models. The seed is appended to `name` when
# building the output folders and the printed model names.
bert_model = 'bert-base-cased'
language = 'english'
prediction_type = 'sentence'
name = 'bert-base-cased_control_'
seeds = [24, 213, 1111, 61, 183]
saving_path_folder = os.path.join(
    '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/stimuli-representations',
    language,
)

In [77]:
def randomize_layer(model, layer_nb):
    """Randomize all weights of one encoder layer and set its biases to zero.

    The input ``layer_nb`` is 1-based (from 1 to the number of encoder layers)
    to be coherent with the rest of the analysis, where layer 0 denotes the
    embeddings; it is converted to a 0-based index in the function.

    Every weight is replaced by ``torch.rand_like`` draws (uniform on [0, 1))
    and every bias by zeros. This covers the attention query/key/value
    projections, the attention output dense + LayerNorm, the intermediate dense
    layer and the output dense + LayerNorm. Modules are processed in a fixed
    order so that a given seed always produces the same weights.

    Args:
        model: BERT model exposing ``model.encoder.layer`` (indexable, with ``len``).
        layer_nb (int): 1-based index of the layer to randomize.

    Returns:
        The same ``model`` object, modified in place.

    Raises:
        IndexError: if ``layer_nb`` is outside ``[1, len(model.encoder.layer)]``.
            Without this check ``layer_nb=0`` would silently randomize the last
            layer (Python index -1).
    """
    n_layers = len(model.encoder.layer)
    if not 1 <= layer_nb <= n_layers:
        raise IndexError('layer_nb must be in [1, {}], got {}'.format(n_layers, layer_nb))
    layer = model.encoder.layer[layer_nb - 1]
    # The order fixes the sequence of RNG draws, hence the weights obtained for a given seed.
    modules = [
        layer.attention.self.query,
        layer.attention.self.key,
        layer.attention.self.value,
        layer.attention.output.dense,
        layer.attention.output.LayerNorm,
        layer.intermediate.dense,
        layer.output.dense,
        layer.output.LayerNorm,
    ]
    for module in modules:
        module.weight = torch.nn.parameter.Parameter(torch.rand_like(module.weight))
        module.bias = torch.nn.parameter.Parameter(torch.zeros_like(module.bias))
    return model

In [78]:
def _randomize_module(module):
    """Replace ``module.weight`` with uniform [0, 1) draws, then ``module.bias`` with zeros."""
    module.weight = torch.nn.parameter.Parameter(torch.rand_like(module.weight))
    module.bias = torch.nn.parameter.Parameter(torch.zeros_like(module.bias))


def _encoder_layer(model, layer_nb):
    """Return encoder layer ``layer_nb`` (1-based, as in the rest of the analysis).

    Raises:
        IndexError: if ``layer_nb`` is outside ``[1, len(model.encoder.layer)]``;
            otherwise ``layer_nb=0`` would silently select the last layer.
    """
    n_layers = len(model.encoder.layer)
    if not 1 <= layer_nb <= n_layers:
        raise IndexError('layer_nb must be in [1, {}], got {}'.format(n_layers, layer_nb))
    return model.encoder.layer[layer_nb - 1]


def randomize_attention_query(model, layer_nb):
    """Randomize attention query weights of a given layer (1-based) and put bias to zero.

    Returns the same ``model`` object, modified in place.
    """
    _randomize_module(_encoder_layer(model, layer_nb).attention.self.query)
    return model


def randomize_attention_key(model, layer_nb):
    """Randomize attention key weights of a given layer (1-based) and put bias to zero.

    Returns the same ``model`` object, modified in place.
    """
    _randomize_module(_encoder_layer(model, layer_nb).attention.self.key)
    return model


def randomize_attention_value(model, layer_nb):
    """Randomize attention value weights of a given layer (1-based) and put bias to zero.

    Returns the same ``model`` object, modified in place.
    """
    _randomize_module(_encoder_layer(model, layer_nb).attention.self.value)
    return model


def randomize_attention_output_dense(model, layer_nb):
    """Randomize attention output dense weights of a given layer (1-based) and put bias to zero.

    Returns the same ``model`` object, modified in place.
    """
    _randomize_module(_encoder_layer(model, layer_nb).attention.output.dense)
    return model


def randomize_intermediate_dense(model, layer_nb):
    """Randomize intermediate dense weights of a given layer (1-based) and put bias to zero.

    Returns the same ``model`` object, modified in place.
    """
    _randomize_module(_encoder_layer(model, layer_nb).intermediate.dense)
    return model


def randomize_outptut_dense(model, layer_nb):
    """Randomize output dense weights of a given layer (1-based) and put bias to zero.

    The misspelled name is kept for backward compatibility; prefer
    ``randomize_output_dense``.
    Returns the same ``model`` object, modified in place.
    """
    _randomize_module(_encoder_layer(model, layer_nb).output.dense)
    return model


# Correctly spelled alias.
randomize_output_dense = randomize_outptut_dense


In [79]:
def randomize_embeddings(model):
    """Re-draw the embedding weights uniformly on [0, 1) and zero the LayerNorm bias.

    The word, position and token-type embedding tables and the embedding
    LayerNorm weight are drawn in that order (so a seed always yields the same
    weights); only the embedding LayerNorm has a bias, which is set to zero.

    Args:
        model: BERT model exposing ``model.embeddings``.

    Returns:
        The same ``model`` object, modified in place.
    """
    embeddings = model.embeddings
    as_param = torch.nn.parameter.Parameter
    randomized = (
        embeddings.word_embeddings,
        embeddings.position_embeddings,
        embeddings.token_type_embeddings,
        embeddings.LayerNorm,
    )
    for module in randomized:
        module.weight = as_param(torch.rand_like(module.weight))
    embeddings.LayerNorm.bias = as_param(torch.zeros_like(embeddings.LayerNorm.bias))
    return model

In [80]:
def _control_layer_columns(layer, prefix=''):
    """Names of the columns belonging to `layer` in the extracted activations.

    Layer 0 (embeddings) only has hidden-state columns; layers >= 1 also have
    attention-head columns. `prefix` is '' for token activations and 'CLS-' /
    'SEP-' for the special tokens. The sizes (768 hidden units, 12 heads of 64
    units, 1-based) are assumed to match bert-base column naming by the extractor.
    """
    columns = ['{}hidden_state-layer-{}-{}'.format(prefix, layer, i) for i in range(1, 769)]
    if layer > 0:
        columns += ['{}attention-layer-{}-head-{}-{}'.format(prefix, layer, head, i)
                    for head in range(1, 13) for i in range(1, 65)]
    return columns


def _save_control_run(activations, layer, save_path, run_number):
    """Keep only `layer`'s columns of one run's activations and write them to CSV.

    `activations` is the value returned by ``BertExtractor.extract_activations``,
    unpacked as ``(hidden_states, attention_heads, (cls_hidden, cls_attention),
    (sep_hidden, sep_attention))``. Writes ``activations_run{n}.csv``,
    ``cls_run{n}.csv`` and ``sep_run{n}.csv`` into the existing `save_path`.
    The large frames are local, so they are freed when this function returns.
    """
    hidden_states, attention_heads = activations[0], activations[1]
    cls_hidden_states, cls_attention = activations[2]
    sep_hidden_states, sep_attention = activations[3]
    frames = {
        'activations': pd.concat([hidden_states, attention_heads], axis=1)[_control_layer_columns(layer)],
        'cls': pd.concat([cls_hidden_states, cls_attention], axis=1)[_control_layer_columns(layer, 'CLS-')],
        'sep': pd.concat([sep_hidden_states, sep_attention], axis=1)[_control_layer_columns(layer, 'SEP-')],
    }
    print('\tSaving in {}.'.format(save_path))
    for prefix, frame in frames.items():
        frame.to_csv(os.path.join(save_path, '{}_run{}.csv'.format(prefix, run_number)), index=False)


for seed in seeds:
    # NOTE(review): the seed is set once per seed value, not once per layer, so the random
    # weights of layer k also depend on every RNG draw made for layers < k (and possibly on
    # draws made while BertExtractor builds the model). Regenerating a single layer therefore
    # requires replaying the previous ones. Kept as-is to preserve the original seeding scheme.
    set_seed(seed)
    for layer in range(13):
        extractor = BertExtractor(bert_model, language, name, prediction_type, output_hidden_states=True, output_attentions=True, config_path=None)
        if layer == 0:
            extractor.model = randomize_embeddings(extractor.model)
        else:
            extractor.model = randomize_layer(extractor.model, layer)
        print(extractor.name + str(seed), ' - Extracting activations for layer {}...'.format(layer))
        # The output folder only depends on (seed, layer): create it once, not once per run.
        save_path = os.path.join(saving_path_folder, name + str(seed) + '_layer-{}'.format(layer))
        check_folder(save_path)
        for run_index, iterator in enumerate(tqdm(iterator_list, desc='seed {} layer {}'.format(seed, layer))):
            print("############# Run {} #############".format(run_index))
            _save_control_run(extractor.extract_activations(iterator, language), layer, save_path, run_index + 1)


0it [00:00, ?it/s]

bert-base-cased  - Extracting activations for layer 0...
############# Run 0 #############


1it [00:23, 23.37s/it]

############# Run 1 #############


2it [00:50, 24.54s/it]

############# Run 2 #############


3it [01:23, 26.94s/it]

############# Run 3 #############


3it [01:44, 34.92s/it]


KeyboardInterrupt: 