# Notebook to extract hidden-states and attention heads activations from bert model predictions

In [1]:
import os
import glob
import torch
import gc
import numpy as np
import pandas as pd
from tqdm import tqdm
from model import BertExtractor
from sklearn.preprocessing import StandardScaler
from tokenizer import tokenize
from utils import set_seed
from numpy import linalg as la
import matplotlib.pyplot as plt

In [2]:
def check_folder(path):
    """Create adequate folders if necessary."""
    try:
        if not os.path.isdir(path):
            check_folder(os.path.dirname(path))
            os.mkdir(path)
    except:
        pass

In [3]:
def transform(activations, path, name, run_index, n_layers_hidden=13, n_layers_attention=12, hidden_size=768):
    assert activations.values.shape[1] == (n_layers_hidden + n_layers_attention) * hidden_size
    indexes = [[index*hidden_size, (index+1)*hidden_size] for index in range(n_layers_hidden + n_layers_attention)]
    for order in [2]:
        matrices = []
        for i, index in enumerate(indexes):
            matrix = activations.values[:, index[0]:index[1]]
            #with_std = True if order=='std' else False
            #scaler = StandardScaler(with_mean=True, with_std=with_std)
            #scaler.fit(matrix)
            #matrix = scaler.transform(matrix)
            if order is not None and order != 'std':
                matrix = matrix / np.mean(la.norm(matrix, ord=order, axis=1))
            matrices.append(matrix)
        matrices = np.hstack(matrices)
        new_data = pd.DataFrame(matrices, columns=activations.columns)
        new_path = path + '_norm-' + str(order).replace('np.', '')
        check_folder(new_path)
        new_data.to_csv(os.path.join(new_path, name + '_run{}.csv'.format(run_index + 1)), index=False)


Defining variables:

In [4]:
template = '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/text/english/text_english_run*.txt' # path to text input
language = 'english'

Creating iterator for each run:

In [5]:
template = '/Users/alexpsq/Code/Parietal/data/text_english_run*.txt' # path to text input


In [6]:
paths = sorted(glob.glob(template))

In [7]:
iterator_list = [tokenize(path, language, train=False) for path in paths]

100%|██████████| 135/135 [00:00<00:00, 420676.85it/s]
100%|██████████| 135/135 [00:00<00:00, 597290.13it/s]
100%|██████████| 176/176 [00:00<00:00, 555453.35it/s]
100%|██████████| 173/173 [00:00<00:00, 735917.44it/s]
100%|██████████| 177/177 [00:00<00:00, 782288.52it/s]
100%|██████████| 216/216 [00:00<00:00, 802453.20it/s]
100%|██████████| 196/196 [00:00<00:00, 882063.93it/s]
100%|██████████| 145/145 [00:00<00:00, 750832.20it/s]
100%|██████████| 207/207 [00:00<00:00, 981040.60it/s]


In [8]:
iterator_list

[['Once , when I was six years old , I saw a magnificent picture in a book about the primeval forest called ‘ Real - life Stories . ’',
  'It showed a boa constrictor swallowing a wild animal .',
  'Here is a copy of the drawing .',
  'It said in the book : “ Boa constrictors swallow their prey whole , without chewing .',
  'Then they are not able to move , and they sleep for the six months it takes for digestion . ”',
  'So I thought a lot about the adventures of the jungle and , in turn , I managed , with a coloured pencil , to make my first drawing .',
  'My Drawing Number one .',
  'It looked like this : I showed my masterpiece to the grownups and I asked them if my drawing frightened them .',
  'They answered me : “ Why would anyone be frightened by a hat ? ”',
  'My drawing was not of a hat .',
  'It showed a boa constrictor digesting an elephant .',
  'I then drew the inside of the boa constrictor , so that the grownups could understand .',
  'They always need to have things exp

In [9]:
iterator_list = [[sent.lower() for sent in text] for text in iterator_list]

In [10]:
iterator_list

[['once , when i was six years old , i saw a magnificent picture in a book about the primeval forest called ‘ real - life stories . ’',
  'it showed a boa constrictor swallowing a wild animal .',
  'here is a copy of the drawing .',
  'it said in the book : “ boa constrictors swallow their prey whole , without chewing .',
  'then they are not able to move , and they sleep for the six months it takes for digestion . ”',
  'so i thought a lot about the adventures of the jungle and , in turn , i managed , with a coloured pencil , to make my first drawing .',
  'my drawing number one .',
  'it looked like this : i showed my masterpiece to the grownups and i asked them if my drawing frightened them .',
  'they answered me : “ why would anyone be frightened by a hat ? ”',
  'my drawing was not of a hat .',
  'it showed a boa constrictor digesting an elephant .',
  'i then drew the inside of the boa constrictor , so that the grownups could understand .',
  'they always need to have things exp

In [11]:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
for i in range(9):
    a = len(' '.join(iterator_list[i]).split())
    b = len(tokenizer.wordpiece_tokenizer.tokenize(' '.join(iterator_list[i])))
    print(f'run {i} ->', 100*(a-b)/b)

run 0 -> -4.48814926878467
run 1 -> -3.592814371257485
run 2 -> -2.0051194539249146
run 3 -> -3.3243486073674755
run 4 -> -2.6402640264026402
run 5 -> -3.1426269137792104
run 6 -> -1.4569000404694457
run 7 -> -1.688374336710082
run 8 -> -0.8232065856526852


In [12]:
import random 
sequence = 'once , when i was six years old , i saw a magnificent picture in a book about the primeval forest called ‘ real - life stories . ’'
print(tokenizer.wordpiece_tokenizer.tokenize(sequence))
len(shuffle_words(sequence, start_at=5))

['once', ',', 'when', 'i', 'was', 'six', 'years', 'old', ',', 'i', 'saw', 'a', 'magnificent', 'picture', 'in', 'a', 'book', 'about', 'the', 'prime', '##val', 'forest', 'called', '‘', 'real', '-', 'life', 'stories', '.', '’']


NameError: name 'shuffle_words' is not defined

In [None]:
batch_tmp, index_tmp = shuffle_sentence_context(
    iterator_list[0][:2], 
    context_size=6, 
    pretrained_model='bert-base-uncased',
    seed=1111)
for b in batch_tmp:
    print(b)
    print()

In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') # to replace with tokenizer of interest

for i, j in enumerate(batch_tmp):
    j  ='[CLS] ' + j + ' [SEP]'
    print(tokenizer.wordpiece_tokenizer.tokenize(j)[index_tmp[i][0]+1:1+index_tmp[i][1]])
#len(tokenizer.wordpiece_tokenizer.tokenize(' '.join(context[i]) + ' ' + ' '.join(words[:i+1])))



In [None]:
import random
iterator_list[0][:3]

In [None]:
batch_tmp, index_tmp = batchify_sentences(
    iterator_list[0][:3],
    'bert-base-uncased',
    1, 
    1, 
    0,
    10,
    transformation='shuffle',
    vocabulary=None,
    dictionary=None,
    seed=111
)
for i, j in enumerate(batch_tmp):
    #j  ='[CLS] ' + j + ' [SEP]'
    #print(index_tmp[i])
    #print(tokenizer.wordpiece_tokenizer.tokenize(j))
    print('##', i, ' - ', j)
    #print('##', i, ' - ', tokenizer.wordpiece_tokenizer.tokenize(j)[index_tmp[i][0]:index_tmp[i][1]])
    #print(tokenizer.wordpiece_tokenizer.tokenize(j)[1+index_tmp[i][0]:1+index_tmp[i][1]])
    #print()
    #print()
#print(batch_tmp[-2])
#print(batch_tmp[-1])

In [None]:
batch_tmp

In [None]:
l = [0,0,0,3, 4, 5]
l[len(l):len(l)]

In [None]:
#tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
#tokenizer.wordpiece_tokenizer.tokenize('he mustn ’ t bite you ...')

In [None]:
#import utils
#import seaborn as sns
#import matplotlib.pyplot as plt
#from transformers import BertTokenizer
#
#tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
#lengths = []
#
#for index in range(9):
#    batches, indexes = utils.batchify_per_sentence_with_pre_and_post_context(
#                iterator_list[index], 
#                1, 
#                12, 
#                0, 
#                'bert-base-uncased', 
#                max_length=512)
#    #lengths.append(np.array(sorted([len(item.split()) for item in batches])))
#    lengths.append(np.array(sorted([len(tokenizer.wordpiece_tokenizer.tokenize(item)) for item in batches])))
#
#    #sns.boxplot(lengths[-1])
#    #plt.show()
#    #print()
#
#print(np.mean(np.array([np.mean(item) for item in lengths])))
#print(np.median(np.array([np.median(item) for item in lengths])))
#print(np.mean(np.array([np.median(item) for item in lengths])))

In [None]:
#import utils
#import seaborn as sns
#import matplotlib.pyplot as plt
#
#for index in range(9):
#    batches, indexes = utils.batchify_per_sentence_with_pre_and_post_context(
#                iterator_list[index], 
#                1, 
#                10, 
#                0, 
#                'bert-base-uncased', 
#                max_length=512)
#    #print(len(batches))
#    #print(sorted([len(item.split()) for item in batches]))
#    #sns.boxplot(sorted([len(item.split()) for item in batches]))
#    #plt.show()
#    #print()
#    indexes_tmp = []
#    for i in range(len(indexes)):
#        if type(indexes[i])==list and type(indexes[i][0])==list:
#            indexes_tmp.append(indexes[i][-1])
#        else:
#            if i > 0:
#                indexes_tmp.append((
#                indexes[i][-1-2][0], 
#                indexes[i][-1-2][1]))
#            else:
#                indexes_tmp.append(None)
#
#    indexes_tmp[0] = (indexes[0][0][0], indexes[0][-1][1])
#    print(indexes_tmp[0])

## Activation extraction

In [11]:
pretrained_bert_models = ['bert-base-uncased']
names = [
    'bert-base-uncased_pre-2_1_post-0_token-8-0'

         ]
config_paths = [None] * 34
saving_path_folders = [
    '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/stimuli-representations/{}/bert-base-uncased_pre-2_1_post-0_token-8-0'.format(language),
    
]
prediction_types = ['shuffle'] * 34
number_of_sentence_list = [1] * 34
number_of_sentence_before_list = [1]
number_of_sentence_after_list = [0] * 21
attention_length_before_list = [1] * 34
attention_length_after_list = [0] * 34

stop_attention_at_sent_before_list = [None] * 34
stop_attention_before_sent_list = [0] * 34


In [14]:
for index, bert_model in enumerate(pretrained_bert_models):
    extractor = BertExtractor(bert_model, 
                              language, 
                              names[index], 
                              prediction_types[index], 
                              output_hidden_states=True, 
                              output_attentions=False, 
                              attention_length_before=attention_length_before_list[index],
                              attention_length_after=attention_length_after_list[index],
                              config_path=config_paths[index], 
                              max_length=512, 
                              number_of_sentence=number_of_sentence_list[index], 
                              number_of_sentence_before=number_of_sentence_before_list[index], 
                              number_of_sentence_after=number_of_sentence_after_list[index],
                              stop_attention_at_sent_before=stop_attention_at_sent_before_list[index],
                              stop_attention_before_sent=stop_attention_before_sent_list[index],
                             )
    print(extractor.name, ' - Extracting activations ...')
    for run_index, iterator in tqdm(enumerate(iterator_list)):
        gc.collect()
        print("############# Run {} #############".format(run_index))
        activations  = extractor.extract_activations(iterator, language)
        hidden_states_activations = activations[0]
        attention_heads_activations = activations[1]
        #(cls_hidden_states_activations, cls_attention_activations) = activations[2]
        #(sep_hidden_states_activations, sep_attention_activations) = activations[3]
        #activations = pd.concat([hidden_states_activations, attention_heads_activations], axis=1)
        #cls_activations = pd.concat([cls_hidden_states_activations, cls_attention_activations], axis=1)
        #sep_activations = pd.concat([sep_hidden_states_activations, sep_attention_activations], axis=1)

        #transform(
        #    hidden_states_activations, 
        #    saving_path_folders[index], 
        #    'activations', 
        #    run_index=run_index,
        #    n_layers_hidden=13,
        #    n_layers_attention=0, 
        #    hidden_size=768)

        #transform(cls_activations, saving_path_folders[index], 'cls')
        #transform(sep_activations, saving_path_folders[index], 'sep')
        #check_folder(saving_path_folders[index])
        #hidden_states_activations.to_csv(os.path.join(saving_path_folders[index], 'activations_run{}.csv'.format(run_index + 1)), index=False)
        #cls_activations.to_csv(os.path.join(saving_path_folders[index], 'cls_run{}.csv'.format(run_index + 1)), index=False)
        #sep_activations.to_csv(os.path.join(saving_path_folders[index], 'sep_run{}.csv'.format(run_index + 1)), index=False)
        del activations
        #del cls_activations
        #del sep_activations
        del hidden_states_activations
        #del attention_heads_activations
        #del cls_hidden_states_activations
        #del cls_attention_activations
        #del sep_hidden_states_activations
        #del sep_attention_activations


0it [00:00, ?it/s]

bert-base-uncased_pre-2_1_post-0_token-8-0  - Extracting activations ...
############# Run 0 #############
Careful last sentence selected...
Careful last sentence selected...
Careful last sentence selected...
Careful last sentence selected...
Careful last sentence selected...
Careful last sentence selected...
Careful last sentence selected...
Careful last sentence selected...
Careful last sentence selected...
Careful last sentence selected...
Careful last sentence selected...
Careful last sentence selected...
Careful last sentence selected...
Careful last sentence selected...
Careful last sentence selected...
Careful last sentence selected...
Careful last sentence selected...
Careful last sentence selected...
Careful last sentence selected...
Careful last sentence selected...
Careful last sentence selected...
Careful last sentence selected...
Careful last sentence selected...
Careful last sentence selected...
Careful last sentence selected...
Careful last sentence selected...
Careful l

['six']
[0] [CLS]
[1] when
[2] ,
[3] i
[4] once
[5] was
[6] six
[7] years
[8] old
[9] ,
[10] i
[11] saw
[12] a
[13] magnificent
[14] picture
[15] in
[16] a
[17] book
[18] about
[19] the
[20, 21] prime##val
[22] forest
[23] called
[24] ‘
[25] real
[26] -
[27] life
[28] stories
[29] .
[30] ’
[31] [SEP]
7 7
Batch number:  6  -  [CLS] when , once was i six years old , i saw a magnificent picture in a book about the primeval forest called ‘ real - life stories . ’ [SEP]
['[CLS]', 'when', ',', 'once', 'was', 'i', 'six', 'years', 'old', ',', 'i', 'saw', 'a', 'magnificent', 'picture', 'in', 'a', 'book', 'about', 'the', 'prime', '##val', 'forest', 'called', '‘', 'real', '-', 'life', 'stories', '.', '’', '[SEP]']
indexes: (7, 8) ['years']

['years']
[0] [CLS]
[1] when
[2] ,
[3] once
[4] was
[5] i
[6] six
[7] years
[8] old
[9] ,
[10] i
[11] saw
[12] a
[13] magnificent
[14] picture
[15] in
[16] a
[17] book
[18] about
[19] the
[20, 21] prime##val
[22] forest
[23] called
[24] ‘
[25] real
[26] -
[27]

0it [01:13, ?it/s]

['about']
[0] [CLS]
[1] six
[2] ,
[3] i
[4] was
[5] magnificent
[6] saw
[7] picture
[8] years
[9] ,
[10] when
[11] i
[12] once
[13] a
[14] old
[15] in
[16] a
[17] book
[18] about
[19] the
[20, 21] prime##val
[22] forest
[23] called
[24] ‘
[25] real
[26] -
[27] life
[28] stories
[29] .
[30] ’
[31] [SEP]
19 19
Batch number:  18  -  [CLS] a , years old i six was saw , once when magnificent i a in picture book about the primeval forest called ‘ real - life stories . ’ [SEP]
['[CLS]', 'a', ',', 'years', 'old', 'i', 'six', 'was', 'saw', ',', 'once', 'when', 'magnificent', 'i', 'a', 'in', 'picture', 'book', 'about', 'the', 'prime', '##val', 'forest', 'called', '‘', 'real', '-', 'life', 'stories', '.', '’', '[SEP]']
indexes: (19, 20) ['the']

['the']
[0] [CLS]
[1] a
[2] ,
[3] years
[4] old
[5] i
[6] six
[7] was
[8] saw
[9] ,
[10] once
[11] when
[12] magnificent
[13] i
[14] a
[15] in
[16] picture
[17] book
[18] about
[19] the
[20, 21] prime##val
[22] forest
[23] called
[24] ‘
[25] real
[26] -
[




TypeError: unsupported operand type(s) for +: 'NoneType' and 'int'

In [None]:
iterator_list[1]

### Generate control activations

In [None]:
bert_model = 'bert-base-cased'
language = 'english'
name = 'bert-base-cased_control_'
prediction_type = 'sentence'
saving_path_folder = '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/stimuli-representations/{}'.format(language)
seeds = [24, 213, 1111, 61, 183]

In [None]:
def randomize_layer(model, layer_nb):
    """Randomize layer weights and put bias to zero.
    The input "layer_nb" goes from 1 to 12 to be coherent with the rest of the analysis.
    It is then transfomed in the function.
    """
    layer_nb = layer_nb - 1
    model.encoder.layer[layer_nb].attention.self.query.weight = torch.nn.parameter.Parameter(torch.rand_like(model.encoder.layer[layer_nb].attention.self.query.weight))
    model.encoder.layer[layer_nb].attention.self.query.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.encoder.layer[layer_nb].attention.self.query.bias))
    model.encoder.layer[layer_nb].attention.self.key.weight = torch.nn.parameter.Parameter(torch.rand_like(model.encoder.layer[layer_nb].attention.self.key.weight))
    model.encoder.layer[layer_nb].attention.self.key.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.encoder.layer[layer_nb].attention.self.key.bias))
    model.encoder.layer[layer_nb].attention.self.value.weight = torch.nn.parameter.Parameter(torch.rand_like(model.encoder.layer[layer_nb].attention.self.value.weight))
    model.encoder.layer[layer_nb].attention.self.value.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.encoder.layer[layer_nb].attention.self.value.bias))
    model.encoder.layer[layer_nb].attention.output.dense.weight = torch.nn.parameter.Parameter(torch.rand_like(model.encoder.layer[layer_nb].attention.output.dense.weight))
    model.encoder.layer[layer_nb].attention.output.dense.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.encoder.layer[layer_nb].attention.output.dense.bias))
    model.encoder.layer[layer_nb].attention.output.LayerNorm.weight = torch.nn.parameter.Parameter(torch.rand_like(model.encoder.layer[layer_nb].attention.output.LayerNorm.weight))
    model.encoder.layer[layer_nb].attention.output.LayerNorm.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.encoder.layer[layer_nb].attention.output.LayerNorm.bias))
    model.encoder.layer[layer_nb].intermediate.dense.weight = torch.nn.parameter.Parameter(torch.rand_like(model.encoder.layer[layer_nb].intermediate.dense.weight))
    model.encoder.layer[layer_nb].intermediate.dense.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.encoder.layer[layer_nb].intermediate.dense.bias))
    model.encoder.layer[layer_nb].output.dense.weight = torch.nn.parameter.Parameter(torch.rand_like(model.encoder.layer[layer_nb].output.dense.weight))
    model.encoder.layer[layer_nb].output.dense.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.encoder.layer[layer_nb].output.dense.bias))
    model.encoder.layer[layer_nb].output.LayerNorm.weight = torch.nn.parameter.Parameter(torch.rand_like(model.encoder.layer[layer_nb].output.LayerNorm.weight))
    model.encoder.layer[layer_nb].output.LayerNorm.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.encoder.layer[layer_nb].output.LayerNorm.bias))
    return model

In [None]:
def randomize_attention_query(model, layer_nb):
    """Randomize attention query weights of a given layer and put bias to zero.
    """
    layer_nb = layer_nb - 1
    model.encoder.layer[layer_nb].attention.self.query.weight = torch.nn.parameter.Parameter(torch.rand_like(model.encoder.layer[layer_nb].attention.self.query.weight))
    model.encoder.layer[layer_nb].attention.self.query.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.encoder.layer[layer_nb].attention.self.query.bias))
    return model

def randomize_attention_key(model, layer_nb):
    """Randomize attention key weights of a given layer and put bias to zero.
    """
    layer_nb = layer_nb - 1
    model.encoder.layer[layer_nb].attention.self.key.weight = torch.nn.parameter.Parameter(torch.rand_like(model.encoder.layer[layer_nb].attention.self.key.weight))
    model.encoder.layer[layer_nb].attention.self.key.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.encoder.layer[layer_nb].attention.self.key.bias))
    return model

def randomize_attention_value(model, layer_nb):
    """Randomize attention value weights of a given layer and put bias to zero.
    """
    layer_nb = layer_nb - 1
    model.encoder.layer[layer_nb].attention.self.value.weight = torch.nn.parameter.Parameter(torch.rand_like(model.encoder.layer[layer_nb].attention.self.value.weight))
    model.encoder.layer[layer_nb].attention.self.value.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.encoder.layer[layer_nb].attention.self.value.bias))
    return model

def randomize_attention_output_dense(model, layer_nb):
    """Randomize attention dense network weights of a given layer and put bias to zero.
    """
    layer_nb = layer_nb - 1
    model.encoder.layer[layer_nb].attention.output.dense.weight = torch.nn.parameter.Parameter(torch.rand_like(model.encoder.layer[layer_nb].attention.output.dense.weight))
    model.encoder.layer[layer_nb].attention.output.dense.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.encoder.layer[layer_nb].attention.output.dense.bias))
    return model


def randomize_intermediate_dense(model, layer_nb):
    """Randomize intermediate dense network weights of a given layer and put bias to zero.
    """
    layer_nb = layer_nb - 1
    model.encoder.layer[layer_nb].intermediate.dense.weight = torch.nn.parameter.Parameter(torch.rand_like(model.encoder.layer[layer_nb].intermediate.dense.weight))
    model.encoder.layer[layer_nb].intermediate.dense.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.encoder.layer[layer_nb].intermediate.dense.bias))
    return model

def randomize_outptut_dense(model, layer_nb):
    """Randomize output dense network weights of a given layer and put bias to zero.
    """
    layer_nb = layer_nb - 1
    model.encoder.layer[layer_nb].output.dense.weight = torch.nn.parameter.Parameter(torch.rand_like(model.encoder.layer[layer_nb].output.dense.weight))
    model.encoder.layer[layer_nb].output.dense.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.encoder.layer[layer_nb].output.dense.bias))
    return model


In [None]:
def randomize_embeddings(model):
    """Randomize embeddings weights and put bias to zero.
    """
    model.embeddings.word_embeddings.weight = torch.nn.parameter.Parameter(torch.rand_like(model.embeddings.word_embeddings.weight))
    model.embeddings.position_embeddings.weight = torch.nn.parameter.Parameter(torch.rand_like(model.embeddings.position_embeddings.weight))
    model.embeddings.token_type_embeddings.weight = torch.nn.parameter.Parameter(torch.rand_like(model.embeddings.token_type_embeddings.weight))
    model.embeddings.LayerNorm.weight = torch.nn.parameter.Parameter(torch.rand_like(model.embeddings.LayerNorm.weight))
    model.embeddings.LayerNorm.bias = torch.nn.parameter.Parameter(torch.zeros_like(model.embeddings.LayerNorm.bias))
    return model

In [None]:
for seed in seeds:
    set_seed(seed)
    for layer in range(13):
        extractor = BertExtractor(bert_model, language, name, prediction_type, output_hidden_states=True, output_attentions=True, config_path=None)
        if layer==0:
            extractor.model = randomize_embeddings(extractor.model)
        else:
            extractor.model = randomize_layer(extractor.model, layer)
        print(extractor.name + str(seed), ' - Extracting activations for layer {}...'.format(layer))
        for run_index, iterator in tqdm(enumerate(iterator_list)):
            print("############# Run {} #############".format(run_index))
            activations  = extractor.extract_activations(iterator, language)
            hidden_states_activations = activations[0]
            attention_heads_activations = activations[1]
            (cls_hidden_states_activations, cls_attention_activations) = activations[2]
            (sep_hidden_states_activations, sep_attention_activations) = activations[3]
            activations = pd.concat([hidden_states_activations, attention_heads_activations], axis=1)
            cls_activations = pd.concat([cls_hidden_states_activations, cls_attention_activations], axis=1)
            sep_activations = pd.concat([sep_hidden_states_activations, sep_attention_activations], axis=1)

            # activations
            heads = np.arange(1, 13)
            columns_to_retrieve = ['hidden_state-layer-{}-{}'.format(layer, i) for i in range(1, 769)]
            if layer > 0:
                columns_to_retrieve += ['attention-layer-{}-head-{}-{}'.format(layer, head, i) for head in heads for i in range(1, 65)]
            activations = activations[columns_to_retrieve]

            # CLS
            heads = np.arange(1, 13)
            columns_to_retrieve = ['CLS-hidden_state-layer-{}-{}'.format(layer, i) for i in range(1, 769)]
            if layer > 0:
                columns_to_retrieve += ['CLS-attention-layer-{}-head-{}-{}'.format(layer, head, i) for head in heads for i in range(1, 65)]
            cls_activations = cls_activations[columns_to_retrieve]

            # SEP
            heads = np.arange(1, 13)
            columns_to_retrieve = ['SEP-hidden_state-layer-{}-{}'.format(layer, i) for i in range(1, 769)]
            if layer > 0:
                columns_to_retrieve += ['SEP-attention-layer-{}-head-{}-{}'.format(layer, head, i) for head in heads for i in range(1, 65)]
            sep_activations = sep_activations[columns_to_retrieve]

            save_path = os.path.join(saving_path_folder, name + str(seed) + '_layer-{}'.format(layer))
            check_folder(save_path)
            print('\tSaving in {}.'.format(save_path))
            activations.to_csv(os.path.join(save_path, 'activations_run{}.csv'.format(run_index + 1)), index=False)
            cls_activations.to_csv(os.path.join(save_path, 'cls_run{}.csv'.format(run_index + 1)), index=False)
            sep_activations.to_csv(os.path.join(save_path, 'sep_run{}.csv'.format(run_index + 1)), index=False)


# Test activation extraction 

In [15]:
import utils 
import random
from transformers import BertTokenizer

config = {
    'number_of_sentence': 1, 
    'number_of_sentence_before': 3, 
    'number_of_sentence_after': 0, 
    'attention_length_before': 3, 
    'attention_length_after': 0,
}



tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
extractor_full = BertExtractor('bert-base-uncased', 
                              'english', 
                              'test', 
                              'sentence', 
                              output_hidden_states=True, 
                              output_attentions=False, 
                              attention_length_before=config['attention_length_before'],
                              attention_length_after=0,
                              config_path=None, 
                              number_of_sentence=config['number_of_sentence'], 
                              number_of_sentence_before=config['number_of_sentence_before'], 
                              number_of_sentence_after=config['number_of_sentence_after'], 
                             )
extractor_masked = BertExtractor('bert-base-uncased', 
                              'english', 
                              'test', 
                              'control-context', 
                              output_hidden_states=True, 
                              output_attentions=False, 
                              attention_length_before=config['attention_length_before'],
                              attention_length_after=0,
                              config_path=None, 
                              number_of_sentence=config['number_of_sentence'], 
                              number_of_sentence_before=config['number_of_sentence_before'], 
                              number_of_sentence_after=config['number_of_sentence_after'], 
                             )
extractor_shuffle = BertExtractor('bert-base-uncased', 
                              'english', 
                              'test', 
                              'shuffle', 
                              output_hidden_states=True, 
                              output_attentions=False, 
                              attention_length_before=config['attention_length_before'],
                              attention_length_after=0,
                              config_path=None, 
                              number_of_sentence=config['number_of_sentence'], 
                              number_of_sentence_before=config['number_of_sentence_before'], 
                              number_of_sentence_after=config['number_of_sentence_after'], 
                             )

In [12]:
# Full sentences
batches_full, indexes_full = utils.batchify_per_sentence_with_pre_and_post_context(
            iterator_list[0], 
            config['number_of_sentence'], 
            config['number_of_sentence_before'], 
            config['number_of_sentence_after'], 
            'bert-base-uncased',
        )

# Tokens are masked
batches_masked, indexes_masked = utils.batchify_per_sentence_with_pre_and_post_context(
            iterator_list[0], 
            config['number_of_sentence'], 
            config['number_of_sentence_before'], 
            config['number_of_sentence_after'], 
            'bert-base-uncased', 
            )

# Shuffling
batches_shuffle, indexes_shuffle = utils.batchify_sentences(
            iterator_list[0], 
            config['number_of_sentence'], 
            config['number_of_sentence_before'], 
            config['number_of_sentence_after'], 
            pretrained_model='bert-base-uncased', 
            past_context_size=config['attention_length_before'],
            future_context_size=config['attention_length_after'],
            transformation='shuffle',
            vocabulary=None,
            dictionary=None,
            seed=1111,
            )

In [13]:
# Preprocessing full
indexes_full_tmp = []
for i in range(len(indexes_full)):
    if type(indexes_full[i])==list and type(indexes_full[i][0])==list:
        indexes_full_tmp.append(indexes_full[i][-1])
    else:
        if i > 0:
            indexes_full_tmp.append((
            indexes_full[i][-config['number_of_sentence']-config['number_of_sentence_after']][0], 
            indexes_full[i][-config['number_of_sentence']-config['number_of_sentence_after']][1]))
        else:
            indexes_full_tmp.append(None)


            
# Preprocessing masked
indexes_masked_tmp = []
# If beginning and end indexes of each sentences are recorded, we only keep the sentence(s) of interest
for i in range(len(indexes_masked)):
    if type(indexes_masked[i])==list and type(indexes_masked[i][0])==list:
        indexes_masked_tmp.append(indexes_masked[i][-1])
    else:
        if i > 0:
            indexes_masked_tmp.append((
            indexes_masked[i][-config['number_of_sentence']-config['number_of_sentence_after']][0], 
            indexes_masked[i][-config['number_of_sentence']-config['number_of_sentence_after']][1]))
        else:
            indexes_masked_tmp.append(None)

#if config['number_of_sentence_before']==0:
#    indexes_masked_tmp[0] = (indexes_masked[0][0][0][0], indexes_masked[0][-1][1])
#else:
#    indexes_masked_tmp[0] = (indexes_masked[0][0][0], indexes_masked[0][-1][1])


# Preprocessing shuffle




In [21]:
# activation generation full
output = []
for index, batch in enumerate(batches_full):
    batch = batch.strip() # Remove trailing character

    batch = '[CLS] ' + batch + ' [SEP]'
    tokenized_text = tokenizer.wordpiece_tokenizer.tokenize(batch)
    inputs_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokenized_text)])
    mapping = utils.match_tokenized_to_untokenized(tokenized_text, batch)
    #print(mapping)

    attention_mask = torch.tensor([[1 for x in tokenized_text]])
    
    #print('input shape: ', inputs_ids.shape)
    #print(batch)
    #print(tokenized_text)
    
    with torch.no_grad():
        encoded_layers = extractor_full.model(inputs_ids, attention_mask=attention_mask) # last_hidden_state, pooler_output, hidden_states, attentions

        hidden_states_activations_ = np.vstack(encoded_layers[2]) # retrieve all the hidden states (dimension = layer_count * len(tokenized_text) * feature_count)

        #print(len(encoded_layers[2]))
        #print('output shape:', hidden_states_activations_.shape)
        
        new_activations = []
        key_start = None
        key_stop = None
        
        #print('Mapping:')
        #for key in mapping.keys():
        #    print(batch.split()[key], ''.join([tokenized_text[i] for i in mapping[key]]))
        #print('A priori Token of interest:', tokenized_text[indexes_full_tmp[index][0]:indexes_full_tmp[index][1]])
            
        for key_, value in mapping.items(): 
            if (value[0] - 1) == (indexes_full_tmp[index][0]): #because we added [CLS] token at the beginning
                key_start = key_
        for key_, value in mapping.items(): 
            if value[-1] == (indexes_full_tmp[index][1]): #because we added [CLS] token at the beginning
                key_stop = key_
                
        #print(key_start, key_stop)
        #print('Extracting sentence:')
        print(' '.join([tokenizer.decode(tokenizer.convert_tokens_to_ids([tokenized_text[word] for word in mapping[index]])) for index in range(key_start, key_stop + 1)]))
        #print('dimension match:', len(tokenized_text)==hidden_states_activations_.shape[1])
        output.append(' '.join([tokenizer.decode(tokenizer.convert_tokens_to_ids([tokenized_text[word] for word in mapping[index]])) for index in range(key_start, key_stop + 1)]))
        
        for word_index in range(key_start, key_stop + 1): # len(mapping.keys()) - 1
            word_activation = []
            word_activation.append([hidden_states_activations_[:,index, :] for index in mapping[word_index]])
            word_activation = np.vstack(word_activation)
            new_activations.append(np.mean(word_activation, axis=0).reshape(1,-1))
        
        #print(np.vstack(new_activations).shape)
        #if input()!='':
        #    break


assert ' '.join(output) == ' '.join(iterator_list[0])


once , when i was six years old , i saw a magnificent picture in a book about the primeval forest called ‘ real - life stories . ’
it showed a boa constrictor swallowing a wild animal .
here is a copy of the drawing .
it said in the book : “ boa constrictors swallow their prey whole , without chewing .
then they are not able to move , and they sleep for the six months it takes for digestion . ”
so i thought a lot about the adventures of the jungle and , in turn , i managed , with a coloured pencil , to make my first drawing .
my drawing number one .
it looked like this : i showed my masterpiece to the grownups and i asked them if my drawing frightened them .
they answered me : “ why would anyone be frightened by a hat ? ”
my drawing was not of a hat .
it showed a boa constrictor digesting an elephant .
i then drew the inside of the boa constrictor , so that the grownups could understand .
they always need to have things explained .
my drawing number two looked like this : the grownups 

and perhaps with a hint of sadness , he added : “ straight ahead you can ' t go far ... ”


In [22]:
# activation generation masked
output = []
for index_batch, batch in enumerate(batches_masked):
    batch = batch.strip() # Remove trailing character

    batch = '[CLS] ' + batch + ' [SEP]'
    tokenized_text = tokenizer.wordpiece_tokenizer.tokenize(batch)
    inputs_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokenized_text)])
    inputs_ids = torch.cat(inputs_ids.size(1) * [inputs_ids])
    attention_mask =  torch.diag_embed(torch.tensor([[0 for x in tokenized_text]]))

    for i in range(min(len(tokenized_text), config['attention_length_before'])):
        attention_mask = torch.add(attention_mask, torch.diag_embed(torch.tensor([[1 for x in range(len(tokenized_text) - i)]]), offset=-i))
    for i in range(1, min(len(tokenized_text), config['attention_length_after'] + 1)):
        attention_mask = torch.add(attention_mask, torch.diag_embed(torch.tensor([[1 for x in range(len(tokenized_text) - i)]]), offset=i))
    mapping = utils.match_tokenized_to_untokenized(tokenized_text, batch)
    #print(mapping)

    attention_mask = attention_mask.squeeze(0)

    beg = indexes_masked_tmp[index_batch][0] + 1 # because of the special token at the beginning
    end = indexes_masked_tmp[index_batch][1] + 1 # because of special token

    inputs_ids = inputs_ids[beg:end, :]
    attention_mask = attention_mask[beg:end, :]
    
    #print('input shape: ', inputs_ids.shape)
    #print('attention_mask shape: ', attention_mask.shape)
    #print(batch)
    #print(tokenized_text)
    #print('Mapping:')
    #for key in mapping.keys():
    #    print(batch.split()[key], ''.join([tokenized_text[i] for i in mapping[key]]))
    #print('A priori Token of interest:', tokenized_text[beg:end])

    with torch.no_grad():
        encoded_layers = extractor_masked.model(inputs_ids, attention_mask=attention_mask) # last_hidden_state, pooler_output, hidden_states, attentions

        #print('output shape at each layer:', encoded_layers[2][0].shape)
        hidden_states_activations_ = np.vstack([torch.cat([encoded_layers[2][layer][i,len(tokenized_text) - encoded_layers[2][layer].size(0) + i - 1,:].unsqueeze(0) for i in range(encoded_layers[2][layer].size(0))], dim=0).unsqueeze(0).detach().numpy() for layer in range(len(encoded_layers[2]))]) # retrieve all the hidden states (dimension = layer_count * len(tokenized_text) * feature_count)
        #print('output shape after concat:', hidden_states_activations_.shape)
        hidden_states_activations_ = np.concatenate([np.zeros((hidden_states_activations_.shape[0], indexes_masked_tmp[index_batch][0] + 1 , hidden_states_activations_.shape[-1])), hidden_states_activations_, np.zeros((hidden_states_activations_.shape[0], len(tokenized_text) - indexes_masked_tmp[index_batch][1] - 1, hidden_states_activations_.shape[-1]))], axis=1)

        #print(len(encoded_layers[2]))
        #print('output shape after filling:', hidden_states_activations_.shape)
        
        new_activations = []
        key_start = None
        key_stop = None
        
        assert indexes_masked_tmp[index_batch][0]==indexes_full_tmp[index_batch][0]
        assert indexes_masked_tmp[index_batch][1]==indexes_full_tmp[index_batch][1]
        
        for key_, value in mapping.items(): 
            if (value[0] - 1) == (indexes_masked_tmp[index_batch][0]): #because we added [CLS] token at the beginning
                key_start = key_
        for key_, value in mapping.items(): 
            if value[-1] == (indexes_masked_tmp[index_batch][1]): #because we added [CLS] token at the beginning
                key_stop = key_
                
        #print(key_start, key_stop)
        #print('Extracting sentence:')
        print(' '.join([tokenizer.decode(tokenizer.convert_tokens_to_ids([tokenized_text[word] for word in mapping[index]])) for index in range(key_start, key_stop + 1)]))
        #print('dimension match:', len(tokenized_text)==hidden_states_activations_.shape[1])
        output.append(' '.join([tokenizer.decode(tokenizer.convert_tokens_to_ids([tokenized_text[word] for word in mapping[index]])) for index in range(key_start, key_stop + 1)]))
        
        for word_index in range(key_start, key_stop + 1): # len(mapping.keys()) - 1
            word_activation = []
            word_activation.append([hidden_states_activations_[:,index, :] for index in mapping[word_index]])
            word_activation = np.vstack(word_activation)
            new_activations.append(np.mean(word_activation, axis=0).reshape(1,-1))
        
        #print(np.vstack(new_activations).shape)
        #if input()!='':
        #    break

assert ' '.join(output) == ' '.join(iterator_list[0])


once , when i was six years old , i saw a magnificent picture in a book about the primeval forest called ‘ real - life stories . ’
it showed a boa constrictor swallowing a wild animal .
here is a copy of the drawing .
it said in the book : “ boa constrictors swallow their prey whole , without chewing .
then they are not able to move , and they sleep for the six months it takes for digestion . ”
so i thought a lot about the adventures of the jungle and , in turn , i managed , with a coloured pencil , to make my first drawing .
my drawing number one .
it looked like this : i showed my masterpiece to the grownups and i asked them if my drawing frightened them .
they answered me : “ why would anyone be frightened by a hat ? ”
my drawing was not of a hat .
it showed a boa constrictor digesting an elephant .
i then drew the inside of the boa constrictor , so that the grownups could understand .
they always need to have things explained .
my drawing number two looked like this : the grownups 

and perhaps with a hint of sadness , he added : “ straight ahead you can ' t go far ... ”


In [16]:
# activation generation shuffle
output = []
for index_batch, batch in enumerate(batches_shuffle):
    batch = batch.strip() # Remove trailing character

    batch = '[CLS] ' + batch + ' [SEP]'
    tokenized_text = tokenizer.wordpiece_tokenizer.tokenize(batch)
    inputs_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokenized_text)])

    #print('input shape: ', inputs_ids.shape)
    #print(batch)
    #print(tokenized_text)
    
    mapping = utils.match_tokenized_to_untokenized(tokenized_text, batch)

    #print('Mapping:')
    #for key in mapping.keys():
    #    print(batch.split()[key], ''.join([tokenized_text[i] for i in mapping[key]]))
    #print('A priori Token of interest:', tokenized_text[indexes_shuffle[index_batch][0]:indexes_shuffle[index_batch][1]])
         
    with torch.no_grad():
        encoded_layers = extractor_shuffle.model(inputs_ids) # last_hidden_state, pooler_output, hidden_states, attentions

        hidden_states_activations_ = np.vstack(encoded_layers[2]) # retrieve all the hidden states (dimension = layer_count * len(tokenized_text) * feature_count)

        #print('nb of layer:', len(encoded_layers[2]))
        #print('output shape:', hidden_states_activations_.shape)
        
        new_activations = []
        key_start = None
        key_stop = None
        
        for key_, value in mapping.items(): 
            if (value[0] - 1) == (indexes_shuffle[index_batch][0]): #because we added [CLS] token at the beginning
                key_start = key_
        for key_, value in mapping.items(): 
            if value[-1] == (indexes_shuffle[index_batch][1]): #because we added [CLS] token at the beginning
                key_stop = key_
                
        #print(key_start, key_stop)
        #print('Extracting sentence:')
        print(' '.join([tokenizer.decode(tokenizer.convert_tokens_to_ids([tokenized_text[word] for word in mapping[index]])) for index in range(key_start, key_stop + 1)]))
        #print('dimension match:', len(tokenized_text)==hidden_states_activations_.shape[1])
        output.append(' '.join([tokenizer.decode(tokenizer.convert_tokens_to_ids([tokenized_text[word] for word in mapping[index]])) for index in range(key_start, key_stop + 1)]))
        
        for word_index in range(key_start, key_stop + 1): # len(mapping.keys()) - 1
            word_activation = []
            word_activation.append([hidden_states_activations_[:,index, :] for index in mapping[word_index]])
            word_activation = np.vstack(word_activation)
            new_activations.append(np.mean(word_activation, axis=0).reshape(1,-1))
        
        #print(np.vstack(new_activations).shape)
        #if input()!='':
        #    break


assert ' '.join(output) == ' '.join(iterator_list[0])


once
,
when
i
was
six
years
old
,
i
saw
a
magnificent
picture
in
a
book
about
the
primeval
forest
called
‘
real
-
life
stories
.
’
it
showed
a
boa
constrictor
swallowing
a
wild
animal
.
here
is
a
copy
of
the
drawing
.
it
said
in
the
book
:
“
boa
constrictors
swallow
their
prey
whole
,
without
chewing
.
then
they
are
not
able
to
move
,
and
they
sleep
for
the
six
months
it
takes
for
digestion
.
”
so
i
thought
a
lot
about
the
adventures
of
the
jungle
and
,
in
turn
,
i
managed
,
with
a
coloured
pencil
,
to
make
my
first
drawing
.
my
drawing
number
one
.
it
looked
like
this
:
i
showed
my
masterpiece
to
the
grownups
and
i
asked
them
if
my
drawing
frightened
them
.
they
answered
me
:
“
why
would
anyone
be
frightened
by
a
hat
?
”
my
drawing
was
not
of
a
hat
.
it
showed
a
boa
constrictor
digesting
an
elephant
.
i
then
drew
the
inside
of
the
boa
constrictor
,
so
that
the
grownups
could
understand
.
they
always
need
to
have
things
explained
.
my
drawing
number
two
looked
like
this
:
the
grownups


gravely
:
“
that
doesn
’
t
matter
;
where
i
live
,
everything
is
so
small
!
”
and
perhaps
with
a
hint
of
sadness
,
he
added
:
“
straight
ahead
you
can
'
t
go
far
...
”


In [24]:
def transform_sentence_and_context(
    iterator, 
    past_context_size, 
    future_context_size, 
    pretrained_model,
    transformation='shuffle',
    vocabulary=None,
    dictionary=None,
    select=None,
    seed=1111):
    """ DEF...
    """
    random.seed(seed)
    punctuation = ['.', '!', '?', '...', '\'', ',', ';', ':', '/', '-', '"', '‘', '’', '(', ')', '{', '}', '[', ']', '`', '“', '”', '—']
    if select is None:
        words = ' '.join(iterator).split()
    else:
        words = iterator[select].split()
        
    all_words = ' '.join(iterator).split()
    words_before = [] if select is None else ' '.join(iterator[:select]).split()
    supp_before = [len([word for word in all_words[max(j+len(words_before)+1-past_context_size, 0):j+len(words_before)+1] if word in punctuation]) for j in range(len(words))] # we do not count punctuation in the number of words to shuffle
    supp_after = [len([word for word in all_words[j+len(words_before)+1:min(j+len(words_before)+1+future_context_size, len(all_words))] if word in punctuation]) for j in range(len(words))] # we do not count punctuation in the number of words to shuffle

    # For each word, we compute the index of the other words to transform
    # We transform past context. Change conditions "i<j" and ... to something else if needed
    index_words_list_before = [[i for i, item in enumerate(all_words) if item not in punctuation if ((i!=(j+len(words_before))) and  (i <= j+len(words_before)-past_context_size-supp_before[j]))] for j in range(len(words))] # '<=' because context_size of 1 is the current word
    index_words_list_after = [[i for i, item in enumerate(all_words) if item not in punctuation if ((i!=(j+len(words_before))) and (i>j+len(words_before)+future_context_size+supp_after[j]))] for j in range(len(words))] # '<=' because context_size of 1 is the current word

    # Create the new array of sentences with original words 
    new_words = np.tile(np.array(all_words.copy()), (len(words), 1))

    for i in range(len(new_words)):
        if len(index_words_list_before[i])>0: # if there are words to change...
            if transformation=='shuffle':
                # Replace words that need to be shuffled by the random sampling (except fix point and punctuation)
                new_order = random.sample(index_words_list_before[i], len(index_words_list_before[i]))
                if len(index_words_list_before[i])>1:
                    while new_order==index_words_list_before[i]:
                        new_order = random.sample(index_words_list_before[i], len(index_words_list_before[i]))
                new_words[i, index_words_list_before[i]] = new_words[i, new_order]
            elif transformation=='pos_replacement':
                # Replace words that need to be replaced by words with same POS (except fix point and punctuation)
                new_words[i, index_words_list_before[i]] = pick_pos_word(new_words[i, index_words_list_before[i]], dictionary)
            elif transformation=='random_replacement':
                # Replace words that need to be replaced by random words (except fix point and punctuation)
                new_words[i, index_words_list_before[i]] = pick_random_word(new_words[i, index_words_list_before[i]], vocabulary)
        if len(index_words_list_after[i])>0: # if there are words to change...
            if transformation=='shuffle':
                new_order = random.sample(index_words_list_after[i], len(index_words_list_after[i]))
                if len(index_words_list_after[i])>1:
                    while new_order==index_words_list_after[i]:
                        new_order = random.sample(index_words_list_after[i], len(index_words_list_after[i]))
                new_words[i, index_words_list_after[i]] = new_words[i, new_order]
            elif transformation=='pos_replacement':
                new_words[i, index_words_list_after[i]] = pick_pos_word(new_words[i, index_words_list_after[i]], dictionary)
            elif transformation=='random_replacement':
                new_words[i, index_words_list_after[i]] = pick_random_word(new_words[i, index_words_list_after[i]], vocabulary)

    # Convert array to list
    new_words = list(new_words)
    new_words = [list(item) for item in new_words]
    batch_tmp = []
    index_tmp = []
    tokenizer = BertTokenizer.from_pretrained(pretrained_model) # to replace with tokenizer of interest
    # adding transformed context to each sentence
    for i, sentence in enumerate(new_words):
        batch_tmp.append(' '.join(sentence).strip())
        # Determining associated indexes
        tmp1 = ' '.join(sentence[:i+len(words_before)])
        tmp2 = ' '.join(sentence[:i+len(words_before)+1])
        index_tmp.append((len(tokenizer.wordpiece_tokenizer.tokenize(tmp1.strip())), 
                     len(tokenizer.wordpiece_tokenizer.tokenize(tmp2.strip()))
                    )) # to replace with tokenizer of interest and arguments
    print(batch_tmp)
    return batch_tmp, index_tmp

In [25]:
iterator= iterator_list[0]
number_of_sentence=config['number_of_sentence']
number_sentence_before=config['number_of_sentence_before']
number_sentence_after=config['number_of_sentence_after']
pretrained_model='bert-base-uncased'
past_context_size=config['attention_length_before']
future_context_size=config['attention_length_after']
transformation='shuffle'
vocabulary=None
dictionary=None
seed=1111
max_length=512


iterator = [item.strip() for item in iterator]
max_length -= 2 # for special tokens
assert number_of_sentence > 0
tokenizer = BertTokenizer.from_pretrained(pretrained_model) # to replace with tokenizer of interest

batch = []
indexes = []
sentence_count = 0
n = len(iterator)


print('entering while loop...')
# rest of the iterator + context 
while sentence_count < n:
    start = max(sentence_count - number_sentence_before, 0)
    stop = min(sentence_count + number_of_sentence, n)
    stop_post_context = min(stop + number_sentence_after, n)
    token_count = len(tokenizer.wordpiece_tokenizer.tokenize(' '.join(iterator[start:stop_post_context]))) # to replace with tokenizer of interest and arguments
    if token_count > max_length:
        raise ValueError('Cannot fit context with additional sentence. You should reduce context length.')
    # computing batch and indexes
    print(len(iterator[start:stop_post_context]))
    print(iterator[start:stop_post_context])
    
    batch_tmp, index_tmp = transform_sentence_and_context(
        iterator[start:stop_post_context], 
        past_context_size=past_context_size,
        future_context_size=future_context_size, 
        pretrained_model=pretrained_model,
        transformation=transformation,
        vocabulary=vocabulary,
        dictionary=dictionary,
        select=stop-start-1,
        seed=seed
    )        
    batch += batch_tmp
    indexes += index_tmp
    sentence_count = stop
    print(stop-start-1)
    if input()!='':
        break

#for b in batch:
#    print(b)
#print(indexes)

entering while loop...
1
['once , when i was six years old , i saw a magnificent picture in a book about the primeval forest called ‘ real - life stories . ’']
['once , i stories in years real about , was forest a the i magnificent old primeval saw book picture called life ‘ six - when a . ’', 'once , the magnificent in life i years , about old forest i saw called was when real picture a stories six ‘ a - book primeval . ’', 'once , when life saw book in i , a picture the called magnificent primeval years six old was about forest a ‘ stories - real i . ’', 'when , once i about magnificent i life , stories real in years old forest a six saw primeval a the was ‘ picture - called book . ’', 'when , i once was years a in , real old i primeval called a forest the saw picture magnificent six stories ‘ book - about life . ’', 'i , when once was six in stories , old i magnificent picture about a real forest primeval called book life saw ‘ a - years the . ’', 'i , six was when once years book ,

In [17]:
batches_shuffle

['once , i stories in years real about , was forest a the i magnificent old primeval saw book picture called life ‘ six - when a . ’',
 'once , the magnificent in life i years , about old forest i saw called was when real picture a stories six ‘ a - book primeval . ’',
 'once , when life saw book in i , a picture the called magnificent primeval years six old was about forest a ‘ stories - real i . ’',
 'when , once i about magnificent i life , stories real in years old forest a six saw primeval a the was ‘ picture - called book . ’',
 'when , i once was years a in , real old i primeval called a forest the saw picture magnificent six stories ‘ book - about life . ’',
 'i , when once was six in stories , old i magnificent picture about a real forest primeval called book life saw ‘ a - years the . ’',
 'i , six was when once years book , a real stories i called picture the life forest in saw old a ‘ primeval - magnificent about . ’',
 'was , once six i years when old , picture the about l

In [27]:
iterator_list[0]

['once , when i was six years old , i saw a magnificent picture in a book about the primeval forest called ‘ real - life stories . ’',
 'it showed a boa constrictor swallowing a wild animal .',
 'here is a copy of the drawing .',
 'it said in the book : “ boa constrictors swallow their prey whole , without chewing .',
 'then they are not able to move , and they sleep for the six months it takes for digestion . ”',
 'so i thought a lot about the adventures of the jungle and , in turn , i managed , with a coloured pencil , to make my first drawing .',
 'my drawing number one .',
 'it looked like this : i showed my masterpiece to the grownups and i asked them if my drawing frightened them .',
 'they answered me : “ why would anyone be frightened by a hat ? ”',
 'my drawing was not of a hat .',
 'it showed a boa constrictor digesting an elephant .',
 'i then drew the inside of the boa constrictor , so that the grownups could understand .',
 'they always need to have things explained .',
 '