# Notebook to extract hidden-state and attention-head activations from BERT model predictions

In [1]:
import os
import glob
import pandas as pd
from tqdm import tqdm
from model import BertExtractor
from tokenizer import tokenize

Defining variables:

In [2]:
# Extraction settings shared by every cell below.
language = 'english'
prediction_type = 'sentence'

# Glob pattern selecting the input text files.
template = '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/text/english/*.txt'

In [3]:
# Paths to the models whose activations we want to retrieve.
pretrained_bert_models = [
    '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/models/BERT/NER_CONLL2003_bert-base-cased/fine-tuned',
]

# Split the name of each model's parent folder on '_'
# (e.g. 'NER_CONLL2003_bert-base-cased' -> ['NER', 'CONLL2003', 'bert-base-cased']).
infos = [
    os.path.basename(os.path.dirname(model_path)).split('_')
    for model_path in pretrained_bert_models
]

# Name of each model: third, first and second folder-name parts, then the last path component.
names = [
    '_'.join([parts[2], parts[0], parts[1], os.path.basename(model_path)])
    for parts, model_path in zip(infos, pretrained_bert_models)
]

# A 'config.yml' file is looked up in the parent folder of each model path.
config_paths = [
    os.path.join(os.path.dirname(model_path), 'config.yml')
    for model_path in pretrained_bert_models
]

# One output folder per model, built from the language and the model name.
saving_path_folders = [
    '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/stimuli-representations/{}/{}'.format(language, model_name)
    for model_name in names
]

In [4]:
## Testing configuration
# Local test setup: a single text file and a model name ('bert-base-cased')
# instead of a fine-tuned checkpoint path. Running this cell replaces the
# values assigned to these variables earlier in the notebook.

prediction_type = 'sentence'
language = 'english'
template = '/Users/alexpsq/Code/NeuroSpin/LePetitPrince/data/text/english/text_en.txt'  # path to text input

# One entry per model; the four lists are indexed together by the extraction loop.
pretrained_bert_models = ['bert-base-cased']  # model from which we want to retrieve the activations
names = ['bert-base-cased']
config_paths = [None]
saving_path_folders = ['/Users/alexpsq/Code/Parietal/']

Creating an iterator for each run:

In [5]:
# Collect every file matching the template, in sorted order.
paths = glob.glob(template)
paths.sort()

In [6]:
# Tokenize each input file (train=False); one entry per path, in the same order as `paths`.
iterator_list = [
    tokenize(text_path, language, train=False)
    for text_path in paths
]

100%|██████████| 1597/1597 [00:00<00:00, 2121058.74it/s]

Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.





## Activation extraction

In [8]:
# For every model, extract the hidden-state and attention-head activations of each
# run and write them as CSV files into that model's saving folder.
for index, bert_model in enumerate(pretrained_bert_models):
    extractor = BertExtractor(bert_model, language, names[index], prediction_type, output_hidden_states=True, output_attentions=True, config_path=config_paths[index])
    print(extractor.name, ' - Extracting activations ...')

    # Create the output folder up front so the first `to_csv` call does not fail
    # when the folder does not exist yet (an empty folder string means the current directory).
    saving_folder = saving_path_folders[index]
    if saving_folder:
        os.makedirs(saving_folder, exist_ok=True)

    # Wrap the list itself (not the `enumerate` object) so tqdm can read its length
    # and display a real progress bar with total and ETA.
    for run_index, iterator in enumerate(tqdm(iterator_list)):
        # `extract_activations` returns two objects exposing `.to_csv`
        # (presumably pandas DataFrames — confirm in `model.BertExtractor`).
        hidden_states_activations, attention_heads_activations = extractor.extract_activations(iterator, language)
        hidden_states_activations.to_csv(os.path.join(saving_folder, 'hidden_states_run{}.csv'.format(run_index)))
        attention_heads_activations.to_csv(os.path.join(saving_folder, 'attention_heads_run{}.csv'.format(run_index)))

0it [00:00, ?it/s]

bert-base-cased  - Extracting activations ...


1it [12:33, 753.82s/it]
