# Notebook to extract hidden-state and attention-head activations from BERT model predictions

In [1]:
import os
import glob
import torch
import gc
import numpy as np
import pandas as pd
from tqdm import tqdm
from model import BertExtractor
from sklearn.preprocessing import StandardScaler
from tokenizer import tokenize
from utils import set_seed
from numpy import linalg as la

In [2]:
def check_folder(path):
    """Create the directory `path`, together with any missing parent directories.

    Args:
        path: Directory to create. An empty path is ignored.

    Raises:
        OSError: If the directory cannot be created (for example because of
            missing permissions, or because `path` exists and is not a
            directory). An already existing directory is not an error.
    """
    # os.path.dirname() of a bare file name is '', which os.makedirs rejects.
    if not path:
        return
    # exist_ok=True also covers the case where another process creates the
    # directory between the check and the creation.
    os.makedirs(path, exist_ok=True)

In [3]:
def transform(activations, path, name, run_index, n_layers_hidden=13, n_layers_attention=12, hidden_size=768):
    """Scale each block of `hidden_size` activation columns and save the result as CSV.

    The columns of `activations` are cut into
    `n_layers_hidden + n_layers_attention` consecutive blocks of `hidden_size`
    columns. Each block is passed through its own StandardScaler and the
    blocks are then glued back together under the original column names.

    The only normalisation currently enabled is `None`, i.e. mean-centering
    without dividing by the standard deviation. The result is written to
    `<path>_norm-<order>/<name>_run<run_index + 1>.csv`.
    """
    n_blocks = n_layers_hidden + n_layers_attention
    assert activations.values.shape[1] == n_blocks * hidden_size
    block_bounds = [(block * hidden_size, (block + 1) * hidden_size) for block in range(n_blocks)]
    for order in [None]:
        scaled_blocks = []
        for lower, upper in block_bounds:
            block = activations.values[:, lower:upper]
            # The standard deviation is only divided out for the 'std' variant.
            scaler = StandardScaler(with_mean=True, with_std=(order == 'std'))
            block = scaler.fit(block).transform(block)
            if not (order is None or order == 'std'):
                # Any other value is a norm order: divide by the mean row norm.
                block = block / np.mean(la.norm(block, ord=order, axis=1))
            scaled_blocks.append(block)
        scaled = pd.DataFrame(np.hstack(scaled_blocks), columns=activations.columns)
        order_tag = str(order).replace('np.', '')
        output_folder = path + '_norm-' + order_tag
        check_folder(output_folder)
        file_name = name + '_run{}.csv'.format(run_index + 1)
        scaled.to_csv(os.path.join(output_folder, file_name), index=False)


Defining variables:

In [4]:
# Glob pattern for the input text files (presumably one file per run; `*` matches the run identifier).
template = '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/text/english/text_english_run*.txt' # path to text input
# Language of the text; passed to the tokenizer and the extractor, and used in the output paths.
language = 'english'

Creating iterator for each run:

In [5]:
#template = '/Users/alexpsq/Code/Parietal/data/text_english_run*.txt' # path to text input


In [6]:
# Sorting is lexicographic: it gives the run order as long as run numbers have
# a single digit ('run10' would sort before 'run2').
paths = sorted(glob.glob(template))
# Fail early: with no input file every later loop would silently do nothing.
if not paths:
    raise FileNotFoundError('No text file matches the pattern: {}'.format(template))

In [7]:
# Tokenize every text file; the result has one entry per file, in the order of `paths`.
iterator_list = list(
    tokenize(text_path, language, train=False)
    for text_path in paths
)

100%|██████████| 135/135 [00:00<00:00, 187060.14it/s]
100%|██████████| 135/135 [00:00<00:00, 207867.49it/s]
100%|██████████| 176/176 [00:00<00:00, 223155.23it/s]
100%|██████████| 173/173 [00:00<00:00, 227608.09it/s]
100%|██████████| 177/177 [00:00<00:00, 244368.60it/s]
100%|██████████| 216/216 [00:00<00:00, 248074.94it/s]
100%|██████████| 196/196 [00:00<00:00, 178055.79it/s]
100%|██████████| 145/145 [00:00<00:00, 197075.20it/s]
100%|██████████| 207/207 [00:00<00:00, 145114.65it/s]


In [8]:
# Display the tokenized sentences of every run (original casing) as a visual sanity check.
iterator_list

[['Once , when I was six years old , I saw a magnificent picture in a book about the primeval forest called ‘ Real - life Stories . ’',
  'It showed a boa constrictor swallowing a wild animal .',
  'Here is a copy of the drawing .',
  'It said in the book : “ Boa constrictors swallow their prey whole , without chewing .',
  'Then they are not able to move , and they sleep for the six months it takes for digestion . ”',
  'So I thought a lot about the adventures of the jungle and , in turn , I managed , with a coloured pencil , to make my first drawing .',
  'My Drawing Number one .',
  'It looked like this : I showed my masterpiece to the grownups and I asked them if my drawing frightened them .',
  'They answered me : “ Why would anyone be frightened by a hat ? ”',
  'My drawing was not of a hat .',
  'It showed a boa constrictor digesting an elephant .',
  'I then drew the inside of the boa constrictor , so that the grownups could understand .',
  'They always need to have things exp

In [9]:
# Lower-case every sentence of every run; re-running this cell is harmless.
iterator_list = [
    [sentence.lower() for sentence in run_sentences]
    for run_sentences in iterator_list
]

In [10]:
# Display the sentences again to check that they are now lower-cased.
iterator_list

[['once , when i was six years old , i saw a magnificent picture in a book about the primeval forest called ‘ real - life stories . ’',
  'it showed a boa constrictor swallowing a wild animal .',
  'here is a copy of the drawing .',
  'it said in the book : “ boa constrictors swallow their prey whole , without chewing .',
  'then they are not able to move , and they sleep for the six months it takes for digestion . ”',
  'so i thought a lot about the adventures of the jungle and , in turn , i managed , with a coloured pencil , to make my first drawing .',
  'my drawing number one .',
  'it looked like this : i showed my masterpiece to the grownups and i asked them if my drawing frightened them .',
  'they answered me : “ why would anyone be frightened by a hat ? ”',
  'my drawing was not of a hat .',
  'it showed a boa constrictor digesting an elephant .',
  'i then drew the inside of the boa constrictor , so that the grownups could understand .',
  'they always need to have things exp

In [11]:
#import utils
#import seaborn as sns
#import matplotlib.pyplot as plt
#from transformers import BertTokenizer
#
#tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
#lengths = []
#
#for index in range(9):
#    batches, indexes = utils.batchify_per_sentence_with_pre_and_post_context(
#                iterator_list[index], 
#                1, 
#                20, 
#                0, 
#                'bert-base-uncased', 
#                max_length=512)
#    #lengths.append(np.array(sorted([len(item.split()) for item in batches])))
#    lengths.append(np.array(sorted([len(tokenizer.wordpiece_tokenizer.tokenize(item)) for item in batches])))
#
#    sns.boxplot(lengths[-1])
#    plt.show()
#    print()
#
#print(np.mean(np.array([np.mean(item) for item in lengths])))
#print(np.median(np.array([np.median(item) for item in lengths])))
#print(np.mean(np.array([np.median(item) for item in lengths])))

In [12]:
#import utils
#import seaborn as sns
#import matplotlib.pyplot as plt
#
#for index in range(9):
#    batches, indexes = utils.batchify_per_sentence_with_pre_and_post_context(
#                iterator_list[index], 
#                1, 
#                20, 
#                0, 
#                'bert-base-uncased', 
#                max_length=512)
#    print(len(batches))
#    print(sorted([len(item.split()) for item in batches]))
#    sns.boxplot(sorted([len(item.split()) for item in batches]))
#    plt.show()
#    print()

## Activation extraction

In [13]:
# One extraction configuration per (attention_length_before, attention_length_after) pair.
# These two lists are the only settings that vary between configurations; the
# model names and output folders below are derived from them so that they
# cannot drift out of sync.
attention_length_before_list = [1, 2, 5, 10, 50, 100, 200] + [1000] * 21
attention_length_after_list = [0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 7, 10, 12, 15, 17, 20, 25, 30, 35, 40, 50, 60, 70, 90, 1000]
assert len(attention_length_before_list) == len(attention_length_after_list), \
    'attention_length_before_list and attention_length_after_list must have the same length'
n_configurations = len(attention_length_before_list)

pretrained_bert_models = ['bert-base-uncased'] * n_configurations
# The trailing '<before>-<after>' of each name encodes the attention lengths.
names = [
    'bert-base-uncased_pre-4_1_post-3_token-{}-{}'.format(length_before, length_after)
    for length_before, length_after in zip(attention_length_before_list, attention_length_after_list)
]
config_paths = [None] * n_configurations
# One output folder per configuration, named after the configuration.
saving_path_folders = [
    '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/stimuli-representations/{}/{}'.format(language, model_name)
    for model_name in names
]
prediction_types = ['token-level'] * n_configurations
number_of_sentence_list = [1] * n_configurations
number_of_sentence_before_list = [4] * n_configurations
number_of_sentence_after_list = [3] * n_configurations


In [None]:
# Extract and save the hidden-state activations of every configured model for every run.
for index, bert_model in enumerate(pretrained_bert_models):
    extractor = BertExtractor(bert_model,
                              language,
                              names[index],
                              prediction_types[index],
                              output_hidden_states=True,
                              output_attentions=False,
                              attention_length_before=attention_length_before_list[index],
                              attention_length_after=attention_length_after_list[index],
                              config_path=config_paths[index],
                              max_length=512,
                              number_of_sentence=number_of_sentence_list[index],
                              number_of_sentence_before=number_of_sentence_before_list[index],
                              number_of_sentence_after=number_of_sentence_after_list[index],
                             )
    print(extractor.name, ' - Extracting activations ...')
    # enumerate() has no length, so the total is given explicitly to let tqdm
    # show a percentage and a time estimate.
    for run_index, iterator in tqdm(enumerate(iterator_list), total=len(iterator_list)):
        # Reclaim the memory of the previous run before extracting the next one.
        gc.collect()
        # The run index printed here is 0-based; output files are numbered from 1.
        print("############# Run {} #############".format(run_index))
        activations = extractor.extract_activations(iterator, language)
        # Only the first element (hidden states) is used: attention outputs are
        # disabled above, hence n_layers_attention=0 below.
        hidden_states_activations = activations[0]

        transform(
            hidden_states_activations,
            saving_path_folders[index],
            'activations',
            run_index=run_index,
            n_layers_hidden=13,
            n_layers_attention=0,
            hidden_size=768)

        # Drop the references so the next gc.collect() can free the memory.
        del activations
        del hidden_states_activations
        

0it [00:00, ?it/s]

bert-base-uncased_pre-4_1_post-3_token-1-0  - Extracting activations ...
############# Run 0 #############


1it [02:25, 145.39s/it]

############# Run 1 #############


### Generate control activations

In [76]:
# Settings for the control extraction below (activations of models with randomized weights).
bert_model = 'bert-base-cased'
# NOTE(review): the sentences in `iterator_list` were lower-cased earlier in this notebook,
# whereas this is a cased model -- confirm that this is intended.
language = 'english'
# Prefix of every output folder name; the seed and the layer number are appended when saving.
name = 'bert-base-cased_control_'
prediction_type = 'sentence'
# Root folder under which one sub-folder per (seed, layer) pair is created.
saving_path_folder = '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/stimuli-representations/{}'.format(language)
# One complete set of control activations is generated for each seed.
seeds = [24, 213, 1111, 61, 183]

In [77]:
def randomize_layer(model, layer_nb):
    """Give one encoder layer random weights and zero biases.

    `layer_nb` is 1-based (1 to 12), to be coherent with the rest of the
    analysis; it is converted to a 0-based index here. A value of 0 would
    therefore select the last layer through negative indexing.

    Every weight (LayerNorm weights included) is redrawn with
    `torch.rand_like`, i.e. uniformly in [0, 1). The model is modified in
    place and returned.
    """
    encoder_layer = model.encoder.layer[layer_nb - 1]
    # The order is kept fixed: it determines which random draws each weight
    # receives under a given seed.
    submodules = (
        encoder_layer.attention.self.query,
        encoder_layer.attention.self.key,
        encoder_layer.attention.self.value,
        encoder_layer.attention.output.dense,
        encoder_layer.attention.output.LayerNorm,
        encoder_layer.intermediate.dense,
        encoder_layer.output.dense,
        encoder_layer.output.LayerNorm,
    )
    for submodule in submodules:
        submodule.weight = torch.nn.parameter.Parameter(torch.rand_like(submodule.weight))
        submodule.bias = torch.nn.parameter.Parameter(torch.zeros_like(submodule.bias))
    return model

In [78]:
def randomize_attention_query(model, layer_nb):
    """Give the attention query projection of one layer random weights and a zero bias.

    `layer_nb` is 1-based and is shifted to a 0-based index. Weights are drawn
    with `torch.rand_like` (uniform in [0, 1)). The model is modified in place
    and returned.
    """
    query = model.encoder.layer[layer_nb - 1].attention.self.query
    query.weight = torch.nn.parameter.Parameter(torch.rand_like(query.weight))
    query.bias = torch.nn.parameter.Parameter(torch.zeros_like(query.bias))
    return model

def randomize_attention_key(model, layer_nb):
    """Give the attention key projection of one layer random weights and a zero bias.

    `layer_nb` is 1-based and is shifted to a 0-based index. Weights are drawn
    with `torch.rand_like` (uniform in [0, 1)). The model is modified in place
    and returned.
    """
    key = model.encoder.layer[layer_nb - 1].attention.self.key
    key.weight = torch.nn.parameter.Parameter(torch.rand_like(key.weight))
    key.bias = torch.nn.parameter.Parameter(torch.zeros_like(key.bias))
    return model

def randomize_attention_value(model, layer_nb):
    """Give the attention value projection of one layer random weights and a zero bias.

    `layer_nb` is 1-based and is shifted to a 0-based index. Weights are drawn
    with `torch.rand_like` (uniform in [0, 1)). The model is modified in place
    and returned.
    """
    value = model.encoder.layer[layer_nb - 1].attention.self.value
    value.weight = torch.nn.parameter.Parameter(torch.rand_like(value.weight))
    value.bias = torch.nn.parameter.Parameter(torch.zeros_like(value.bias))
    return model

def randomize_attention_output_dense(model, layer_nb):
    """Give the attention output dense network of one layer random weights and a zero bias.

    `layer_nb` is 1-based and is shifted to a 0-based index. Weights are drawn
    with `torch.rand_like` (uniform in [0, 1)). The model is modified in place
    and returned.
    """
    dense = model.encoder.layer[layer_nb - 1].attention.output.dense
    dense.weight = torch.nn.parameter.Parameter(torch.rand_like(dense.weight))
    dense.bias = torch.nn.parameter.Parameter(torch.zeros_like(dense.bias))
    return model


def randomize_intermediate_dense(model, layer_nb):
    """Give the intermediate dense network of one layer random weights and a zero bias.

    `layer_nb` is 1-based and is shifted to a 0-based index. Weights are drawn
    with `torch.rand_like` (uniform in [0, 1)). The model is modified in place
    and returned.
    """
    dense = model.encoder.layer[layer_nb - 1].intermediate.dense
    dense.weight = torch.nn.parameter.Parameter(torch.rand_like(dense.weight))
    dense.bias = torch.nn.parameter.Parameter(torch.zeros_like(dense.bias))
    return model

def randomize_outptut_dense(model, layer_nb):
    """Give the output dense network of one layer random weights and a zero bias.

    The function name is misspelled ("outptut"); it is kept for backward
    compatibility and `randomize_output_dense` is provided as the correctly
    spelled alias.

    Args:
        model: Model exposing `encoder.layer[...].output.dense`.
        layer_nb: 1-based layer number; it is shifted to a 0-based index.

    Returns:
        The same model object, modified in place. Weights are drawn with
        `torch.rand_like` (uniform in [0, 1)).
    """
    dense = model.encoder.layer[layer_nb - 1].output.dense
    dense.weight = torch.nn.parameter.Parameter(torch.rand_like(dense.weight))
    dense.bias = torch.nn.parameter.Parameter(torch.zeros_like(dense.bias))
    return model


# Correctly spelled name for new code; both names refer to the same function.
randomize_output_dense = randomize_outptut_dense


In [79]:
def randomize_embeddings(model):
    """Give the embedding tables random weights and reset the embedding LayerNorm.

    The word, position and token-type embedding weights, as well as the
    LayerNorm weight, are redrawn with `torch.rand_like` (uniform in [0, 1));
    the LayerNorm bias is set to zero. The model is modified in place and
    returned.
    """
    embeddings = model.embeddings
    # The order is kept fixed: it determines which random draws each table
    # receives under a given seed.
    tables = (
        embeddings.word_embeddings,
        embeddings.position_embeddings,
        embeddings.token_type_embeddings,
    )
    for table in tables:
        table.weight = torch.nn.parameter.Parameter(torch.rand_like(table.weight))
    layer_norm = embeddings.LayerNorm
    layer_norm.weight = torch.nn.parameter.Parameter(torch.rand_like(layer_norm.weight))
    layer_norm.bias = torch.nn.parameter.Parameter(torch.zeros_like(layer_norm.bias))
    return model

In [80]:
def _layer_columns(layer, prefix=''):
    """Return the names of the columns that belong to one layer.

    Args:
        layer: Layer number; 0 is the embedding layer, which has no attention heads.
        prefix: Column-name prefix: '' for the regular activations, 'CLS-' or 'SEP-'
            for the special-token activations.

    Returns:
        The 768 hidden-state column names, followed for `layer > 0` by the
        12 heads x 64 attention column names (all indices are 1-based).
    """
    columns = ['{}hidden_state-layer-{}-{}'.format(prefix, layer, i) for i in range(1, 769)]
    if layer > 0:
        columns += ['{}attention-layer-{}-head-{}-{}'.format(prefix, layer, head, i)
                    for head in range(1, 13) for i in range(1, 65)]
    return columns


# For every seed and every layer (0 = embeddings, 1-12 = encoder layers): load a
# fresh model, randomize that single layer, and save that layer's activations.
for seed in seeds:
    # NOTE(review): the seed is set once per seed, not once per layer, so the
    # random weights of a layer presumably depend on the layers randomized
    # before it -- confirm against `set_seed`.
    set_seed(seed)
    for layer in range(13):
        # A new extractor is built for each layer so that exactly one layer is randomized.
        extractor = BertExtractor(bert_model, language, name, prediction_type, output_hidden_states=True, output_attentions=True, config_path=None)
        if layer == 0:
            extractor.model = randomize_embeddings(extractor.model)
        else:
            extractor.model = randomize_layer(extractor.model, layer)
        print(extractor.name + str(seed), ' - Extracting activations for layer {}...'.format(layer))

        # The output folder only depends on the seed and the layer: build it once per layer.
        save_path = os.path.join(saving_path_folder, name + str(seed) + '_layer-{}'.format(layer))
        check_folder(save_path)

        # enumerate() has no length, so the total is given explicitly to let
        # tqdm show a percentage and a time estimate.
        for run_index, iterator in tqdm(enumerate(iterator_list), total=len(iterator_list)):
            # The run index printed here is 0-based; output files are numbered from 1.
            print("############# Run {} #############".format(run_index))
            activations = extractor.extract_activations(iterator, language)
            hidden_states_activations = activations[0]
            attention_heads_activations = activations[1]
            (cls_hidden_states_activations, cls_attention_activations) = activations[2]
            (sep_hidden_states_activations, sep_attention_activations) = activations[3]
            activations = pd.concat([hidden_states_activations, attention_heads_activations], axis=1)
            cls_activations = pd.concat([cls_hidden_states_activations, cls_attention_activations], axis=1)
            sep_activations = pd.concat([sep_hidden_states_activations, sep_attention_activations], axis=1)

            # Keep only the columns of the layer that was randomized.
            activations = activations[_layer_columns(layer)]
            cls_activations = cls_activations[_layer_columns(layer, prefix='CLS-')]
            sep_activations = sep_activations[_layer_columns(layer, prefix='SEP-')]

            print('\tSaving in {}.'.format(save_path))
            activations.to_csv(os.path.join(save_path, 'activations_run{}.csv'.format(run_index + 1)), index=False)
            cls_activations.to_csv(os.path.join(save_path, 'cls_run{}.csv'.format(run_index + 1)), index=False)
            sep_activations.to_csv(os.path.join(save_path, 'sep_run{}.csv'.format(run_index + 1)), index=False)


0it [00:00, ?it/s]

bert-base-cased  - Extracting activations for layer 0...
############# Run 0 #############


1it [00:23, 23.37s/it]

############# Run 1 #############


2it [00:50, 24.54s/it]

############# Run 2 #############


3it [01:23, 26.94s/it]

############# Run 3 #############


3it [01:44, 34.92s/it]


KeyboardInterrupt: 