# Notebook to extract hidden-states and attention heads activations from bert model predictions

In [1]:
import os
import glob
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from model import GPT2Extractor
from tokenizer import tokenize

In [2]:
def check_folder(path):
    """Create adequate folders if necessary."""
    try:
        if not os.path.isdir(path):
            check_folder(os.path.dirname(path))
            os.mkdir(path)
    except:
        pass

Defining variables:

In [3]:
template = '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/text/english/text_english_run*.txt' # path to text input
language = 'english'

Creating iterator for each run:

In [4]:
pretrained_gpt2_models = ['gpt2']
names = ['gpt2']
config_paths = [None]
saving_path_folders = [
    '/Users/alexpsq/Code/Parietal/data/stimuli-representations/{}/gpt2'.format(language)]
#saving_path_folders = [
#    '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/stimuli-representations/{}/gpt2'.format(language)]
prediction_types = ['sentence']


In [5]:
names

['gpt2']

In [6]:
template = '/Users/alexpsq/Code/Parietal/data/text_english_run*.txt' # path to text input


In [7]:
paths = sorted(glob.glob(template))

In [8]:
iterator_list = [tokenize(path, language, train=False) for path in paths]

100%|██████████| 135/135 [00:00<00:00, 569648.93it/s]
100%|██████████| 135/135 [00:00<00:00, 813550.34it/s]
100%|██████████| 176/176 [00:00<00:00, 994875.34it/s]
100%|██████████| 173/173 [00:00<00:00, 418946.07it/s]
100%|██████████| 177/177 [00:00<00:00, 912029.25it/s]
100%|██████████| 216/216 [00:00<00:00, 652245.98it/s]
100%|██████████| 196/196 [00:00<00:00, 690246.50it/s]
100%|██████████| 145/145 [00:00<00:00, 457617.82it/s]
100%|██████████| 207/207 [00:00<00:00, 1087996.15it/s]

Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.





## Activation extraction

In [9]:
output_attentions = False
output_hidden_states = True

In [10]:
for index, gpt2_model in enumerate(pretrained_gpt2_models):
    extractor = GPT2Extractor(gpt2_model, language, names[index], prediction_types[index], output_hidden_states=output_hidden_states, output_attentions=output_attentions, config_path=config_paths[index])
    print(extractor.name, ' - Extracting activations ...')
    for run_index, iterator in tqdm(enumerate(iterator_list)):
        print("############# Run {} #############".format(run_index))
        check_folder(saving_path_folders[index])
        activations  = extractor.extract_activations(iterator, language)
        hidden_states_activations = activations[0]
        attention_heads_activations = activations[1]
        activations = pd.concat([hidden_states_activations, attention_heads_activations], axis=1)
        
        activations.to_csv(os.path.join(saving_path_folders[index], 'activations_run{}.csv'.format(run_index + 1)), index=False)
        

0it [00:00, ?it/s]

gpt2  - Extracting activations ...
############# Run 0 #############


0it [00:05, ?it/s]


TypeError: cannot concatenate object of type '<class 'list'>'; only Series and DataFrame objs are valid

In [11]:
hidden_states_activations

Unnamed: 0,hidden_state-layer-0-1,hidden_state-layer-0-2,hidden_state-layer-0-3,hidden_state-layer-0-4,hidden_state-layer-0-5,hidden_state-layer-0-6,hidden_state-layer-0-7,hidden_state-layer-0-8,hidden_state-layer-0-9,hidden_state-layer-0-10,...,hidden_state-layer-12-759,hidden_state-layer-12-760,hidden_state-layer-12-761,hidden_state-layer-12-762,hidden_state-layer-12-763,hidden_state-layer-12-764,hidden_state-layer-12-765,hidden_state-layer-12-766,hidden_state-layer-12-767,hidden_state-layer-12-768
0,-0.222231,-0.199346,0.102250,0.072660,-0.109572,-0.251374,-0.427017,-0.192565,-0.090373,-0.170221,...,-0.080227,0.069681,-0.214584,-0.155730,0.703937,-0.156755,0.002902,-0.229858,-0.056037,-0.035686
1,-0.046362,-0.132434,0.061459,0.199972,0.023228,0.044123,-0.422747,0.118858,-0.056940,-0.028383,...,0.193726,-0.091036,0.351631,0.080826,0.024973,0.143238,0.126042,0.072917,-0.112462,-0.029984
2,-0.118786,-0.107309,0.112027,0.168924,-0.044425,-0.127225,-0.485738,-0.087842,0.077962,-0.014586,...,-0.471415,0.179074,0.452323,0.054650,0.266642,0.126591,0.120691,0.058829,-0.294023,-0.180088
3,-0.019110,-0.148983,0.221716,0.134502,-0.158677,-0.114787,-0.451582,-0.186877,0.073118,-0.092988,...,-0.230266,0.223334,0.257051,0.072100,0.021015,0.158723,0.245941,-0.148598,0.192446,0.064503
4,-0.176048,-0.148281,0.341300,0.020825,-0.146095,-0.004921,-0.512484,-0.072568,-0.079359,-0.025802,...,-0.089730,0.417009,0.170452,0.164576,0.559334,0.247522,0.036557,0.004062,-0.024033,0.180207
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1888,-0.117341,0.028862,0.136164,0.208377,-0.114606,-0.003114,-0.489909,-0.084179,-0.006406,-0.203908,...,-0.025312,0.264754,0.181911,0.096496,1.213056,0.510090,-0.240733,-0.259916,-0.091030,0.613524
1889,-0.023117,0.019310,0.138129,0.104024,-0.043709,-0.025883,-0.515242,-0.089623,0.032116,-0.067125,...,-0.085490,0.127606,0.355455,-0.004345,1.160623,0.390218,-0.194015,-0.346576,-0.213058,0.623512
1890,-0.005001,-0.147860,0.263442,0.080779,0.019132,0.074155,-0.558621,-0.244804,0.123008,-0.094988,...,-0.145380,0.182243,0.245503,0.058192,2.074650,0.414970,-0.519282,-0.641491,-0.046649,0.676222
1891,-0.145681,0.085838,0.187571,0.059699,0.041959,-0.155613,-0.619032,-0.030259,0.111349,-0.023952,...,-0.152615,0.269339,0.417994,0.230248,1.897090,0.376180,-0.449664,-0.705842,-0.185300,0.792126


### Generate control activations

In [76]:
bert_model = 'bert-base-cased'
language = 'english'
name = 'bert-base-cased_control'
prediction_type = 'sentence'
saving_path_folder = '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/stimuli-representations/{}'.format(language)


In [77]:
def randomize_layer(model, layer_nb):
    """Randomize layer weights and put bias to zero.
    The input "layer_nb" goes from 1 to 12 to be coherent with the rest of the analysis.
    It is then transfomed in the function.
    """
    layer_nb = layer_nb - 1
    model.encoder.layer[layer_nb].attention.self.query.weight = torch.nn.parameter.Parameter(torch.rand_like(model.encoder.layer[layer_nb].attention.self.query.weight))
    model.encoder.layer[layer_nb].attention.self.query.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.encoder.layer[layer_nb].attention.self.query.bias))
    model.encoder.layer[layer_nb].attention.self.key.weight = torch.nn.parameter.Parameter(torch.rand_like(model.encoder.layer[layer_nb].attention.self.key.weight))
    model.encoder.layer[layer_nb].attention.self.key.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.encoder.layer[layer_nb].attention.self.key.bias))
    model.encoder.layer[layer_nb].attention.self.value.weight = torch.nn.parameter.Parameter(torch.rand_like(model.encoder.layer[layer_nb].attention.self.value.weight))
    model.encoder.layer[layer_nb].attention.self.value.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.encoder.layer[layer_nb].attention.self.value.bias))
    model.encoder.layer[layer_nb].attention.output.dense.weight = torch.nn.parameter.Parameter(torch.rand_like(model.encoder.layer[layer_nb].attention.output.dense.weight))
    model.encoder.layer[layer_nb].attention.output.dense.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.encoder.layer[layer_nb].attention.output.dense.bias))
    model.encoder.layer[layer_nb].attention.output.LayerNorm.weight = torch.nn.parameter.Parameter(torch.rand_like(model.encoder.layer[layer_nb].attention.output.LayerNorm.weight))
    model.encoder.layer[layer_nb].attention.output.LayerNorm.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.encoder.layer[layer_nb].attention.output.LayerNorm.bias))
    model.encoder.layer[layer_nb].intermediate.dense.weight = torch.nn.parameter.Parameter(torch.rand_like(model.encoder.layer[layer_nb].intermediate.dense.weight))
    model.encoder.layer[layer_nb].intermediate.dense.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.encoder.layer[layer_nb].intermediate.dense.bias))
    model.encoder.layer[layer_nb].output.dense.weight = torch.nn.parameter.Parameter(torch.rand_like(model.encoder.layer[layer_nb].output.dense.weight))
    model.encoder.layer[layer_nb].output.dense.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.encoder.layer[layer_nb].output.dense.bias))
    model.encoder.layer[layer_nb].output.LayerNorm.weight = torch.nn.parameter.Parameter(torch.rand_like(model.encoder.layer[layer_nb].output.LayerNorm.weight))
    model.encoder.layer[layer_nb].output.LayerNorm.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.encoder.layer[layer_nb].output.LayerNorm.bias))
    return model

In [78]:
def randomize_attention_query(model, layer_nb):
    """Randomize attention query weights of a given layer and put bias to zero.
    """
    layer_nb = layer_nb - 1
    model.encoder.layer[layer_nb].attention.self.query.weight = torch.nn.parameter.Parameter(torch.rand_like(model.encoder.layer[layer_nb].attention.self.query.weight))
    model.encoder.layer[layer_nb].attention.self.query.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.encoder.layer[layer_nb].attention.self.query.bias))
    return model

def randomize_attention_key(model, layer_nb):
    """Randomize attention key weights of a given layer and put bias to zero.
    """
    layer_nb = layer_nb - 1
    model.encoder.layer[layer_nb].attention.self.key.weight = torch.nn.parameter.Parameter(torch.rand_like(model.encoder.layer[layer_nb].attention.self.key.weight))
    model.encoder.layer[layer_nb].attention.self.key.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.encoder.layer[layer_nb].attention.self.key.bias))
    return model

def randomize_attention_value(model, layer_nb):
    """Randomize attention value weights of a given layer and put bias to zero.
    """
    layer_nb = layer_nb - 1
    model.encoder.layer[layer_nb].attention.self.value.weight = torch.nn.parameter.Parameter(torch.rand_like(model.encoder.layer[layer_nb].attention.self.value.weight))
    model.encoder.layer[layer_nb].attention.self.value.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.encoder.layer[layer_nb].attention.self.value.bias))
    return model

def randomize_attention_output_dense(model, layer_nb):
    """Randomize attention dense network weights of a given layer and put bias to zero.
    """
    layer_nb = layer_nb - 1
    model.encoder.layer[layer_nb].attention.output.dense.weight = torch.nn.parameter.Parameter(torch.rand_like(model.encoder.layer[layer_nb].attention.output.dense.weight))
    model.encoder.layer[layer_nb].attention.output.dense.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.encoder.layer[layer_nb].attention.output.dense.bias))
    return model


def randomize_intermediate_dense(model, layer_nb):
    """Randomize intermediate dense network weights of a given layer and put bias to zero.
    """
    layer_nb = layer_nb - 1
    model.encoder.layer[layer_nb].intermediate.dense.weight = torch.nn.parameter.Parameter(torch.rand_like(model.encoder.layer[layer_nb].intermediate.dense.weight))
    model.encoder.layer[layer_nb].intermediate.dense.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.encoder.layer[layer_nb].intermediate.dense.bias))
    return model

def randomize_outptut_dense(model, layer_nb):
    """Randomize output dense network weights of a given layer and put bias to zero.
    """
    layer_nb = layer_nb - 1
    model.encoder.layer[layer_nb].output.dense.weight = torch.nn.parameter.Parameter(torch.rand_like(model.encoder.layer[layer_nb].output.dense.weight))
    model.encoder.layer[layer_nb].output.dense.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.encoder.layer[layer_nb].output.dense.bias))
    return model


In [79]:
def randomize_embeddings(model):
    """Randomize embeddings weights and put bias to zero.
    """
    model.embeddings.word_embeddings.weight = torch.nn.parameter.Parameter(torch.rand_like(model.embeddings.word_embeddings.weight))
    model.embeddings.position_embeddings.weight = torch.nn.parameter.Parameter(torch.rand_like(model.embeddings.position_embeddings.weight))
    model.embeddings.token_type_embeddings.weight = torch.nn.parameter.Parameter(torch.rand_like(model.embeddings.token_type_embeddings.weight))
    model.embeddings.LayerNorm.weight = torch.nn.parameter.Parameter(torch.rand_like(model.embeddings.LayerNorm.weight))
    model.embeddings.LayerNorm.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.embeddings.LayerNorm.bias))
    return model

In [80]:
for layer in range(13):
    extractor = BertExtractor(bert_model, language, name, prediction_type, output_hidden_states=True, output_attentions=True, config_path=None)
    if layer==0:
        extractor.model = randomize_embeddings(extractor.model)
    else:
        extractor.model = randomize_layer(extractor.model, layer)
    print(extractor.name, ' - Extracting activations for layer {}...'.format(layer))
    for run_index, iterator in tqdm(enumerate(iterator_list)):
        print("############# Run {} #############".format(run_index))
        check_folder(saving_path_folders[index])
        activations  = extractor.extract_activations(iterator, language)
        hidden_states_activations = activations[0]
        attention_heads_activations = activations[1]
        (cls_hidden_states_activations, cls_attention_activations) = activations[2]
        (sep_hidden_states_activations, sep_attention_activations) = activations[3]
        activations = pd.concat([hidden_states_activations, attention_heads_activations], axis=1)
        cls_activations = pd.concat([cls_hidden_states_activations, cls_attention_activations], axis=1)
        sep_activations = pd.concat([sep_hidden_states_activations, sep_attention_activations], axis=1)

        # activations
        heads = np.arange(1, 13)
        columns_to_retrieve = ['hidden_state-layer-{}-{}'.format(layer, i) for i in range(1, 769)]
        if layer > 0:
            columns_to_retrieve += ['attention-layer-{}-head-{}-{}'.format(layer, head, i) for head in heads for i in range(1, 65)]
        activations = activations[columns_to_retrieve]

        # CLS
        heads = np.arange(1, 13)
        columns_to_retrieve = ['CLS-hidden_state-layer-{}-{}'.format(layer, i) for i in range(1, 769)]
        if layer > 0:
            columns_to_retrieve += ['CLS-attention-layer-{}-head-{}-{}'.format(layer, head, i) for head in heads for i in range(1, 65)]
        cls_activations = cls_activations[columns_to_retrieve]

        # SEP
        heads = np.arange(1, 13)
        columns_to_retrieve = ['SEP-hidden_state-layer-{}-{}'.format(layer, i) for i in range(1, 769)]
        if layer > 0:
            columns_to_retrieve += ['SEP-attention-layer-{}-head-{}-{}'.format(layer, head, i) for head in heads for i in range(1, 65)]
        sep_activations = sep_activations[columns_to_retrieve]
    
        check_folder(os.path.join(saving_path_folder, name + '_layer-{}'.format(layer)))
        activations.to_csv(os.path.join(saving_path_folder, name + '_layer-{}'.format(layer), 'activations_run{}.csv'.format(run_index + 1)), index=False)
        cls_activations.to_csv(os.path.join(saving_path_folder, name + '_layer-{}'.format(layer), 'cls_run{}.csv'.format(run_index + 1)), index=False)
        sep_activations.to_csv(os.path.join(saving_path_folder, name + '_layer-{}'.format(layer), 'sep_run{}.csv'.format(run_index + 1)), index=False)



0it [00:00, ?it/s]

bert-base-cased  - Extracting activations for layer 0...
############# Run 0 #############


1it [00:23, 23.37s/it]

############# Run 1 #############


2it [00:50, 24.54s/it]

############# Run 2 #############


3it [01:23, 26.94s/it]

############# Run 3 #############


3it [01:44, 34.92s/it]


KeyboardInterrupt: 