# Notebook to extract hidden-state and attention-head activations from BERT model predictions

In [66]:
import glob
import os
import random
import re

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from model import BertExtractor
from tokenizer import tokenize

In [41]:
def check_folder(path):
    """Create the folder ``path`` and any missing parent folders, if needed.

    Does nothing when ``path`` is empty or already exists as a directory, so it can be
    called repeatedly.

    Args:
        path (str): folder to create (absolute or relative).

    Raises:
        OSError: if the folder cannot be created (e.g. permission denied, or ``path`` already
            exists as a regular file). Failing here makes the problem visible immediately
            instead of later, when files are written into the missing folder.
    """
    if not path:
        # os.makedirs('') raises; an empty path means "current directory", which always exists.
        return
    os.makedirs(path, exist_ok=True)

Defining variables:

In [42]:
language = 'english'
# Path pattern of the input text files (one file per run). It is built from `language` so that
# switching language cannot silently keep reading the text of another language.
template = '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/text/{0}/text_{0}_run*.txt'.format(language)

In [43]:
# Fine-tuned BERT checkpoints from which we want to retrieve the activations.
pretrained_bert_models = [
    '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/models/BERT/SENTENCE-CLASSIFICATION_SST-2_bert_base_cased/fine_tuned',
    '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/models/BERT/SENTENCE-CLASSIFICATION_COLA_bert_base_cased/fine_tuned',
    '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/models/BERT/POS_CONLL2003_bert_base_cased/fine_tuned',
    '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/models/BERT/NER_CONLL2003_bert_base_cased/fine_tuned',
]
# Each checkpoint's parent folder looks like '<TASK>_<DATASET>_bert_base_cased'; split it into its parts.
infos = [os.path.basename(os.path.dirname(model_path)).split('_') for model_path in pretrained_bert_models]
# Display name '<arch>_<TASK>_<DATASET>_<checkpoint folder>', e.g. 'bert_SENTENCE-CLASSIFICATION_SST-2_fine_tuned'.
names = [
    '_'.join((parts[2], parts[0], parts[1], os.path.basename(model_path)))
    for parts, model_path in zip(infos, pretrained_bert_models)
]
# The config file sits next to the checkpoint folder.
config_paths = [os.path.join(os.path.dirname(model_path), 'config.yml') for model_path in pretrained_bert_models]
saving_path_folders = [
    '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/stimuli-representations/{}/{}'.format(language, model_name)
    for model_name in names
]
prediction_types = ['sentence'] * len(pretrained_bert_models)

Adding the `bert-base-cased` baselines (no fine-tuning), with the 'sentence' and 'sequential' prediction types:

In [44]:
# Add the `bert-base-cased` baselines (no fine-tuning), with prediction types 'sentence' and 'sequential'.
# The guard keeps this cell idempotent: re-running it must not append the baselines a second time,
# which would make the extraction loop below process them twice.
if 'bert-base-cased' not in pretrained_bert_models:
    pretrained_bert_models += ['bert-base-cased', 'bert-base-cased']
    names += ['bert-base-cased', 'bert-base-cased']
    config_paths += [None, None]  # no local config file: the model is loaded by its name
    saving_path_folders += [
        '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/stimuli-representations/{}/bert-base-cased'.format(language),
        '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/stimuli-representations/{}/bert-base-cased_sequential'.format(language)]
    prediction_types += ['sentence', 'sequential']

In [45]:
#pretrained_bert_models = ['bert-base-cased']
#names = ['bert-base-cased']
#config_paths = [None]
#saving_path_folders = [
#    '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/stimuli-representations/{}/bert-base-cased_duplicata'.format(language)]
#prediction_types = ['sentence']

In [46]:
# Display the model names built above (one entry per element of `pretrained_bert_models`).
names

['bert_SENTENCE-CLASSIFICATION_SST-2_fine_tuned',
 'bert_SENTENCE-CLASSIFICATION_COLA_fine_tuned',
 'bert_POS_CONLL2003_fine_tuned',
 'bert_NER_CONLL2003_fine_tuned',
 'bert-base-cased',
 'bert-base-cased']

In [49]:
#template = '/Users/alexpsq/Code/Parietal/data/text_english_run*.txt' # path to text input


In [50]:
def _run_number(path):
    """Sort key for run files: the integer N of a '..._run<N>.txt' name, then the path itself.

    Files without a run number are placed last (ties broken by path).
    """
    match = re.search(r'run(\d+)', os.path.basename(path))
    return (int(match.group(1)) if match else float('inf'), path)


# Sort runs numerically (run2 before run10). A plain lexicographic sort would misorder runs once
# there are 10 or more, and `run_index + 1` is later used to name the output files.
paths = sorted(glob.glob(template), key=_run_number)
assert paths, 'No input text file matches {}'.format(template)

In [51]:
# Tokenize every run's text once: one entry per run, in the same order as `paths`.
iterator_list = list(map(lambda text_path: tokenize(text_path, language, train=False), paths))

100%|██████████| 135/135 [00:00<00:00, 69128.44it/s]
100%|██████████| 135/135 [00:00<00:00, 168973.75it/s]
100%|██████████| 176/176 [00:00<00:00, 206604.40it/s]
100%|██████████| 173/173 [00:00<00:00, 207912.49it/s]
100%|██████████| 177/177 [00:00<00:00, 98330.04it/s]
100%|██████████| 216/216 [00:00<00:00, 225927.60it/s]
100%|██████████| 196/196 [00:00<00:00, 197189.63it/s]
100%|██████████| 145/145 [00:00<00:00, 247930.73it/s]
100%|██████████| 207/207 [00:00<00:00, 661249.75it/s]

Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.





## Activation extraction

In [11]:
# The per-model settings are parallel lists built in the cells above; make sure they line up
# before starting the (long) extraction.
assert (len(pretrained_bert_models) == len(names) == len(prediction_types)
        == len(config_paths) == len(saving_path_folders)), 'Per-model settings lists have different lengths.'

for model_path, model_name, model_prediction_type, model_config_path, output_folder in zip(
        pretrained_bert_models, names, prediction_types, config_paths, saving_path_folders):
    extractor = BertExtractor(model_path, language, model_name, model_prediction_type,
                              output_hidden_states=True, output_attentions=True, config_path=model_config_path)
    print(extractor.name, ' - Extracting activations ...')
    # The output folder only depends on the model: create it once, not once per run.
    check_folder(output_folder)
    for run_index, iterator in tqdm(enumerate(iterator_list), total=len(iterator_list)):
        print("############# Run {} #############".format(run_index))
        # extract_activations returns: token hidden states, token attention heads,
        # (CLS hidden states, CLS attention heads), (SEP hidden states, SEP attention heads).
        raw_activations = extractor.extract_activations(iterator, language)
        hidden_states_activations, attention_heads_activations = raw_activations[0], raw_activations[1]
        cls_hidden_states_activations, cls_attention_activations = raw_activations[2]
        sep_hidden_states_activations, sep_attention_activations = raw_activations[3]

        # Side-by-side concatenation: hidden-state columns followed by attention-head columns.
        activations = pd.concat([hidden_states_activations, attention_heads_activations], axis=1)
        cls_activations = pd.concat([cls_hidden_states_activations, cls_attention_activations], axis=1)
        sep_activations = pd.concat([sep_hidden_states_activations, sep_attention_activations], axis=1)

        run_number = run_index + 1  # output files are numbered from 1, like the input runs
        activations.to_csv(os.path.join(output_folder, 'activations_run{}.csv'.format(run_number)), index=False)
        cls_activations.to_csv(os.path.join(output_folder, 'cls_run{}.csv'.format(run_number)), index=False)
        sep_activations.to_csv(os.path.join(output_folder, 'sep_run{}.csv'.format(run_number)), index=False)
        
        

0it [00:00, ?it/s]

bert_SENTENCE-CLASSIFICATION_SST-2_fine_tuned  - Extracting activations ...
############# Run 0 #############


1it [01:09, 69.20s/it]

############# Run 1 #############


2it [02:22, 70.45s/it]

############# Run 2 #############


3it [04:24, 86.02s/it]

############# Run 3 #############


4it [05:45, 84.24s/it]

############# Run 4 #############


5it [07:00, 81.70s/it]

############# Run 5 #############


6it [09:17, 98.26s/it]

############# Run 6 #############


7it [10:48, 96.06s/it]

############# Run 7 #############


8it [12:00, 88.75s/it]

############# Run 8 #############


9it [14:17, 95.24s/it] 
0it [00:00, ?it/s]

bert_SENTENCE-CLASSIFICATION_COLA_fine_tuned  - Extracting activations ...
############# Run 0 #############


1it [01:10, 70.00s/it]

############# Run 1 #############


2it [02:24, 71.22s/it]

############# Run 2 #############


3it [03:48, 75.06s/it]

############# Run 3 #############


4it [05:21, 80.63s/it]

############# Run 4 #############


5it [06:41, 80.33s/it]

############# Run 5 #############


6it [08:16, 84.65s/it]

############# Run 6 #############


7it [10:14, 94.86s/it]

############# Run 7 #############


8it [11:26, 87.83s/it]

############# Run 8 #############


9it [13:02, 86.95s/it]
0it [00:00, ?it/s]

bert_POS_CONLL2003_fine_tuned  - Extracting activations ...
############# Run 0 #############


1it [01:44, 104.64s/it]

############# Run 1 #############


2it [02:57, 95.14s/it] 

############# Run 2 #############


3it [04:26, 93.30s/it]

############# Run 3 #############


4it [05:50, 90.42s/it]

############# Run 4 #############


5it [07:10, 87.44s/it]

############# Run 5 #############


6it [08:44, 89.40s/it]

############# Run 6 #############


7it [10:12, 88.96s/it]

############# Run 7 #############


8it [11:52, 92.18s/it]

############# Run 8 #############


9it [13:29, 89.89s/it]
0it [00:00, ?it/s]

bert_NER_CONLL2003_fine_tuned  - Extracting activations ...
############# Run 0 #############


1it [01:07, 67.39s/it]

############# Run 1 #############


2it [02:56, 79.76s/it]

############# Run 2 #############


3it [04:18, 80.46s/it]

############# Run 3 #############


4it [05:39, 80.74s/it]

############# Run 4 #############


5it [06:59, 80.64s/it]

############# Run 5 #############


6it [08:56, 91.42s/it]

############# Run 6 #############


7it [10:28, 91.55s/it]

############# Run 7 #############


8it [11:41, 85.91s/it]

############# Run 8 #############


9it [13:41, 91.31s/it]


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=435779157.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=213450.0, style=ProgressStyle(descripti…

0it [00:00, ?it/s]


bert-base-cased  - Extracting activations ...
############# Run 0 #############


1it [01:10, 70.46s/it]

############# Run 1 #############


2it [02:24, 71.63s/it]

############# Run 2 #############


3it [04:29, 87.57s/it]

############# Run 3 #############


4it [05:50, 85.48s/it]

############# Run 4 #############


5it [07:09, 83.57s/it]

############# Run 5 #############


6it [09:13, 95.90s/it]

############# Run 6 #############


7it [10:46, 94.93s/it]

############# Run 7 #############


8it [11:59, 88.27s/it]

############# Run 8 #############


9it [13:30, 90.11s/it]
0it [00:00, ?it/s]

bert-base-cased  - Extracting activations ...
############# Run 0 #############


1it [05:57, 357.72s/it]

############# Run 1 #############


2it [13:10, 380.18s/it]

############# Run 2 #############


3it [21:04, 408.49s/it]

############# Run 3 #############


4it [28:31, 419.96s/it]

############# Run 4 #############


5it [35:44, 423.80s/it]

############# Run 5 #############


6it [43:58, 444.81s/it]

############# Run 6 #############


7it [53:02, 474.81s/it]

############# Run 7 #############


8it [1:00:08, 460.12s/it]

############# Run 8 #############


9it [1:08:50, 458.95s/it]


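### Generate control activations

Control models: for each seed and each layer (0 = embeddings, 1–12 = encoder layers), a fresh `bert-base-cased` model has that single layer replaced by random weights (biases set to zero). Only that layer's activations are saved.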

In [76]:
# Settings for the randomized "control" models built from the pretrained `bert-base-cased`.
language = 'english'
bert_model = 'bert-base-cased'
prediction_type = 'sentence'
name = 'bert-base-cased_control_'  # prefix; the seed number is appended to it
seeds = [24, 213, 1111, 61, 183]
saving_path_folder = os.path.join(
    '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/stimuli-representations',
    language)

In [77]:
def randomize_layer(model, layer_nb):
    """Randomize all weights of one encoder layer and set its biases to zero.

    Each weight of the layer (attention query/key/value, attention output dense, attention
    output LayerNorm, intermediate dense, output dense, output LayerNorm) is replaced by
    ``torch.rand_like`` values, i.e. uniform noise in [0, 1). Each bias is replaced by zeros.
    The model is modified in place and also returned.

    Args:
        model: BERT-like model exposing ``model.encoder.layer`` (presumably a ``transformers``
            BertModel — TODO confirm against BertExtractor).
        layer_nb (int): 1-based layer index (1 to 12 for bert-base), to be coherent with the rest
            of the analysis. It is transformed into a 0-based index internally.

    Returns:
        The same ``model`` object.

    Raises:
        ValueError: if ``layer_nb`` is not in ``1..len(model.encoder.layer)``.
    """
    n_layers = len(model.encoder.layer)
    if not 1 <= layer_nb <= n_layers:
        # Without this check layer_nb=0 would map to index -1 and silently randomize the last layer.
        raise ValueError('layer_nb must be between 1 and {}, got {}.'.format(n_layers, layer_nb))
    layer = model.encoder.layer[layer_nb - 1]
    # The order matters: with a fixed seed, each parameter gets the same random draws only as
    # long as this sequence is unchanged (zeros_like does not consume random numbers).
    submodules = [
        layer.attention.self.query,
        layer.attention.self.key,
        layer.attention.self.value,
        layer.attention.output.dense,
        layer.attention.output.LayerNorm,
        layer.intermediate.dense,
        layer.output.dense,
        layer.output.LayerNorm,
    ]
    for module in submodules:
        module.weight = torch.nn.parameter.Parameter(torch.rand_like(module.weight))
        module.bias = torch.nn.parameter.Parameter(torch.zeros_like(module.bias))
    return model

In [78]:
def _checked_encoder_layer(model, layer_nb):
    """Return encoder layer ``layer_nb`` (1-based, as in the rest of the analysis) of ``model``.

    Raises:
        ValueError: if ``layer_nb`` is not in ``1..len(model.encoder.layer)``. Without this check
            ``layer_nb=0`` would map to index -1 and silently select the last layer.
    """
    n_layers = len(model.encoder.layer)
    if not 1 <= layer_nb <= n_layers:
        raise ValueError('layer_nb must be between 1 and {}, got {}.'.format(n_layers, layer_nb))
    return model.encoder.layer[layer_nb - 1]


def _randomize_weight_zero_bias(module):
    """Replace ``module.weight`` by ``torch.rand_like`` values (uniform in [0, 1)) and ``module.bias`` by zeros."""
    module.weight = torch.nn.parameter.Parameter(torch.rand_like(module.weight))
    module.bias = torch.nn.parameter.Parameter(torch.zeros_like(module.bias))


def randomize_attention_query(model, layer_nb):
    """Randomize attention query weights of a given layer (1-based) and put bias to zero.

    The model is modified in place and returned. Raises ValueError for an out-of-range layer.
    """
    _randomize_weight_zero_bias(_checked_encoder_layer(model, layer_nb).attention.self.query)
    return model


def randomize_attention_key(model, layer_nb):
    """Randomize attention key weights of a given layer (1-based) and put bias to zero.

    The model is modified in place and returned. Raises ValueError for an out-of-range layer.
    """
    _randomize_weight_zero_bias(_checked_encoder_layer(model, layer_nb).attention.self.key)
    return model


def randomize_attention_value(model, layer_nb):
    """Randomize attention value weights of a given layer (1-based) and put bias to zero.

    The model is modified in place and returned. Raises ValueError for an out-of-range layer.
    """
    _randomize_weight_zero_bias(_checked_encoder_layer(model, layer_nb).attention.self.value)
    return model


def randomize_attention_output_dense(model, layer_nb):
    """Randomize the attention output dense weights of a given layer (1-based) and put bias to zero.

    The model is modified in place and returned. Raises ValueError for an out-of-range layer.
    """
    _randomize_weight_zero_bias(_checked_encoder_layer(model, layer_nb).attention.output.dense)
    return model


def randomize_intermediate_dense(model, layer_nb):
    """Randomize the intermediate dense weights of a given layer (1-based) and put bias to zero.

    The model is modified in place and returned. Raises ValueError for an out-of-range layer.
    """
    _randomize_weight_zero_bias(_checked_encoder_layer(model, layer_nb).intermediate.dense)
    return model


def randomize_outptut_dense(model, layer_nb):
    """Randomize the output dense weights of a given layer (1-based) and put bias to zero.

    The model is modified in place and returned. Raises ValueError for an out-of-range layer.
    """
    _randomize_weight_zero_bias(_checked_encoder_layer(model, layer_nb).output.dense)
    return model


# Correctly spelled alias; the misspelled name above is kept so existing calls keep working.
randomize_output_dense = randomize_outptut_dense


In [79]:
def randomize_embeddings(model):
    """Replace the embedding block of ``model`` by random values.

    The word, position and token-type embedding tables and the embedding LayerNorm weight are
    redrawn with ``torch.rand_like`` (uniform values in [0, 1)). The LayerNorm bias is set to zero.
    The model is modified in place and also returned.
    """
    embeddings = model.embeddings
    # Tables are redrawn in a fixed order so that a given seed always produces the same values.
    for table_name in ('word_embeddings', 'position_embeddings', 'token_type_embeddings'):
        table = getattr(embeddings, table_name)
        table.weight = torch.nn.parameter.Parameter(torch.rand_like(table.weight))
    norm = embeddings.LayerNorm
    norm.weight = torch.nn.parameter.Parameter(torch.rand_like(norm.weight))
    norm.bias = torch.nn.parameter.Parameter(torch.zeros_like(norm.bias))
    return model

In [80]:
def set_seed(seed):
    """Seed Python's ``random``, NumPy's and PyTorch's (CPU and all CUDA devices) generators with ``seed``.

    Defined here so the notebook does not depend on a ``set_seed`` left over in the kernel.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


N_LAYERS = 12      # encoder layers of bert-base; layer 0 below stands for the embeddings
N_HEADS = 12       # attention heads per layer
HEAD_DIM = 64      # values per attention head
HIDDEN_SIZE = 768  # hidden-state values per layer (N_HEADS * HEAD_DIM)


def _layer_columns(prefix, layer):
    """Return the activation columns that belong to ``layer``, for a given column prefix.

    Args:
        prefix (str): '' for token activations, 'CLS-' or 'SEP-' for the special-token tables.
        layer (int): 0 for the embeddings (hidden state only), 1..N_LAYERS for encoder layers
            (hidden state followed by every attention head).

    Returns:
        list of str: column names, in the order they are saved.
    """
    columns = ['{}hidden_state-layer-{}-{}'.format(prefix, layer, i) for i in range(1, HIDDEN_SIZE + 1)]
    if layer > 0:
        columns += ['{}attention-layer-{}-head-{}-{}'.format(prefix, layer, head, i)
                    for head in range(1, N_HEADS + 1) for i in range(1, HEAD_DIM + 1)]
    return columns


for seed in seeds:
    for layer in range(N_LAYERS + 1):
        # A fresh pretrained model for every layer: the randomization below modifies the model in place.
        extractor = BertExtractor(bert_model, language, name, prediction_type, output_hidden_states=True, output_attentions=True, config_path=None)
        # Seed right before randomizing so each (seed, layer) control can be regenerated on its own,
        # regardless of which layers were generated before it or of any RNG use while building the extractor.
        set_seed(seed)
        if layer == 0:
            extractor.model = randomize_embeddings(extractor.model)
        else:
            extractor.model = randomize_layer(extractor.model, layer)
        print(extractor.name + str(seed), ' - Extracting activations for layer {}...'.format(layer))

        # Only the randomized layer is kept. The column lists and the output folder do not depend on the run.
        token_columns = _layer_columns('', layer)
        cls_columns = _layer_columns('CLS-', layer)
        sep_columns = _layer_columns('SEP-', layer)
        save_path = os.path.join(saving_path_folder, '{}{}_layer-{}'.format(name, seed, layer))
        check_folder(save_path)

        for run_index, iterator in tqdm(enumerate(iterator_list), total=len(iterator_list)):
            print("############# Run {} #############".format(run_index))
            raw_activations = extractor.extract_activations(iterator, language)
            hidden_states_activations, attention_heads_activations = raw_activations[0], raw_activations[1]
            cls_hidden_states_activations, cls_attention_activations = raw_activations[2]
            sep_hidden_states_activations, sep_attention_activations = raw_activations[3]

            activations = pd.concat([hidden_states_activations, attention_heads_activations], axis=1)[token_columns]
            cls_activations = pd.concat([cls_hidden_states_activations, cls_attention_activations], axis=1)[cls_columns]
            sep_activations = pd.concat([sep_hidden_states_activations, sep_attention_activations], axis=1)[sep_columns]

            print('\tSaving in {}.'.format(save_path))
            run_number = run_index + 1  # output files are numbered from 1, like the input runs
            activations.to_csv(os.path.join(save_path, 'activations_run{}.csv'.format(run_number)), index=False)
            cls_activations.to_csv(os.path.join(save_path, 'cls_run{}.csv'.format(run_number)), index=False)
            sep_activations.to_csv(os.path.join(save_path, 'sep_run{}.csv'.format(run_number)), index=False)


0it [00:00, ?it/s]

bert-base-cased  - Extracting activations for layer 0...
############# Run 0 #############


1it [00:23, 23.37s/it]

############# Run 1 #############


2it [00:50, 24.54s/it]

############# Run 2 #############


3it [01:23, 26.94s/it]

############# Run 3 #############


3it [01:44, 34.92s/it]


KeyboardInterrupt: 