# Notebook to evaluate model entropy

In [1]:
from transformers import BertTokenizer, GPT2Tokenizer,  GPT2LMHeadModel, GPT2Tokenizer, BertForMaskedLM, RobertaForMaskedLM, RobertaTokenizer
import scipy
import pandas as pd
import numpy as np
import os
import torch
import glob

In [2]:
from GPT2.tokenizer import tokenize
from LSTM.tokenizer import unk_transform
#from LSTM.model import LSTMExtractor
from LSTM.data import Dictionary

In [3]:
from GPT2 import utils as utils_gpt2
from BERT import utils as utils_bert
from ROBERTA import utils as utils_roberta


### Functions

In [4]:
def entropy(pk, axis=0):
    """Shannon entropy, in bits, of probability vector(s) ``pk`` along ``axis``.

    Accepts a numpy array or a CPU torch tensor (converted with np.asarray),
    so it works whether or not the probabilities were produced as a tensor.
    Zero-probability entries contribute 0 (the limit of -p*log p) instead of
    turning the result into nan via 0 * log2(0).

    NOTE: redefined later in the Evaluation section; that later definition is
    the one in effect once its cell has run.
    """
    pk = np.asarray(pk)
    # entr(x) = -x*ln(x) with entr(0) = 0; dividing by ln(2) converts nats to bits.
    return np.sum(scipy.special.entr(pk), axis=axis) / np.log(2)

In [5]:
def eval_output(out):
    """Sum of per-position output entropies (bits) for one model forward pass.

    Arguments:
        - out: model output whose first element holds the logits, e.g. a
          tensor of shape (1, n_positions, vocab_size) — TODO confirm
    Returns:
        - total entropy over all positions (sum of entropy(softmax(row)))

    NOTE: redefined later in the Evaluation section to take a 2-D numpy array.
    """
    # Detach and drop the batch dimension once, instead of once per position.
    logits = out[0].detach().squeeze(0)
    per_position = [
        entropy(scipy.special.softmax(logits[position]))
        for position in range(logits.shape[0])
    ]
    return np.sum(per_position)

### Model instantiation

In [14]:
# GPT-2 small checkpoint ('gpt2'): LM-head model and its matching tokenizer.
model_base, t_base = (
    GPT2LMHeadModel.from_pretrained('gpt2'),
    GPT2Tokenizer.from_pretrained('gpt2'),
)

In [7]:
# GPT-2 medium checkpoint ('gpt2-medium'): LM-head model and its matching tokenizer.
model_medium, t_medium = (
    GPT2LMHeadModel.from_pretrained('gpt2-medium'),
    GPT2Tokenizer.from_pretrained('gpt2-medium'),
)

In [8]:
# Cased BERT-base with its masked-LM head, plus the matching tokenizer.
model_bert, t_bert = (
    BertForMaskedLM.from_pretrained('bert-base-cased'),
    BertTokenizer.from_pretrained('bert-base-cased'),
)

In [4]:
# RoBERTa-base with its masked-LM head, plus the matching tokenizer.
model_roberta, t_roberta = (
    RobertaForMaskedLM.from_pretrained('roberta-base'),
    RobertaTokenizer.from_pretrained('roberta-base'),
)

In [None]:
# TODO: build the LSTM model here once `LSTMExtractor` is imported again (its
# import is commented out in the imports cell) and its constructor arguments are
# known. The previous placeholder `LSTMExtractor(...)` always raised NameError.
# LSTM entropies are currently read from precomputed activations in the next cell.
model_lstm = None

In [None]:
# Precomputed LSTM outputs for run 1 (model configuration is encoded in the
# folder name); only the 'entropy' column is kept for the comparison.
data = pd.read_csv(
    'data/stimuli-representations/english/'
    'LSTM_embedding-size_600_nhid_300_nlayers_1_dropout_02_wiki_kristina_english/'
    'activations_run1.csv'
)
lstm_result = data['entropy']

### Data retrieval 

In [5]:
# Stimulus language; passed to the text tokenizer and to the LSTM vocabulary loader below.
language = "english"

In [6]:
# Directory holding the per-run text files. Set the LPP_TEXT_DIR environment
# variable to use a local copy instead of the NeuroSpin server path.
TEXT_DIR = os.environ.get(
    'LPP_TEXT_DIR',
    '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/text/english',
)
# Glob pattern matching one text file per run (path to text input).
template = os.path.join(TEXT_DIR, 'text_english_run*.txt')


In [7]:
# Lexicographic sort: run order is correct only while run numbers have a single digit.
paths = sorted(glob.glob(template))
# Fail here with a clear message rather than with an IndexError on iterator_list[0] later.
assert paths, 'No text files match {}'.format(template)

In [8]:
# Tokenize every run's text file (train=False); presumably each entry is a sequence of sentence strings.
iterator_list = [tokenize(text_path, language, train=False) for text_path in paths]

100%|██████████| 135/135 [00:00<00:00, 233112.82it/s]
100%|██████████| 135/135 [00:00<00:00, 202950.19it/s]
100%|██████████| 176/176 [00:00<00:00, 273103.04it/s]
100%|██████████| 173/173 [00:00<00:00, 270348.21it/s]
100%|██████████| 177/177 [00:00<00:00, 220033.14it/s]
100%|██████████| 216/216 [00:00<00:00, 252429.55it/s]
100%|██████████| 196/196 [00:00<00:00, 308242.81it/s]
100%|██████████| 145/145 [00:00<00:00, 197139.09it/s]
100%|██████████| 207/207 [00:00<00:00, 75893.44it/s]

Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.
Tokenizing...
Preprocessing...
Preprocessed.
Tokenized.





In [18]:
# Sanity check: number of space-separated tokens in run 1.
res = ' '.join(iterator_list[0]).split(' ')
print(len(res))

2015


In [None]:
# Directory holding the LSTM training vocabulary. Set the LPP_VOCAB_DIR
# environment variable to use a local copy instead of the NeuroSpin server path.
vocab_path = os.environ.get(
    'LPP_VOCAB_DIR',
    '/neurospin/unicog/protocols/IRMf/LePetitPrince_Pallier_2018/LePetitPrince/data/text/english/lstm_training',
)
vocab = Dictionary(vocab_path, language)

In [None]:
# One flat word list per run, each word passed through unk_transform with the
# LSTM vocabulary (presumably mapping out-of-vocabulary words to an unknown token).
iterator_list_lstm = [
    [
        unk_transform(word, vocab)
        for sentence in run_sentences
        for word in sentence.strip().split(' ')
    ]
    for run_sentences in iterator_list
]


In [232]:
def batchity(iterator, context_length, pretrained_bert, max_length=512, tokenizer=None):
    """Batchify sentences so each batch fits BERT's input and keeps left context.

    The first batch packs as many leading sentences as fit. Each following
    batch starts with whole previous sentences totalling at least
    ``context_length`` word-piece tokens (fewer if the text start is reached),
    then adds as many new sentences as fit. If that context leaves no room for
    the next sentence, the oldest context sentences are dropped.

    Arguments:
        - iterator: iterable of sentence strings
        - context_length: int, minimum number of left-context word-piece tokens
        - pretrained_bert: str, passed to BertTokenizer.from_pretrained
        - max_length: int, model input size including 2 special tokens
        - tokenizer: optional already-loaded tokenizer exposing
          ``wordpiece_tokenizer.tokenize``; loaded from ``pretrained_bert`` if None
    Returns:
        - batch: list of str, one space-joined string per batch
        - indexes: list of (int, int): token count of the batch's context prefix,
          and token count of the whole batch
    Raises:
        - AssertionError: if context_length >= max_length - 2
        - ValueError: if a single sentence does not fit in max_length - 2 tokens
    """
    iterator = [item.strip() for item in iterator]
    max_length -= 2  # for special tokens
    assert context_length < max_length
    if tokenizer is None:
        tokenizer = BertTokenizer.from_pretrained(pretrained_bert)

    def _n_tokens(text):
        # Number of word-piece tokens in `text`.
        return len(tokenizer.wordpiece_tokenizer.tokenize(text))

    # Tokenize each sentence once instead of repeatedly inside the loops.
    counts = [_n_tokens(sentence) for sentence in iterator]
    n = len(iterator)

    batch = []
    indexes = []

    # First batch: no context available yet.
    sentence_count = 0
    token_count = 0
    while sentence_count < n and token_count + counts[sentence_count] < max_length:
        token_count += counts[sentence_count]
        sentence_count += 1
    batch.append(' '.join(iterator[:sentence_count]))
    indexes.append((0, _n_tokens(batch[-1])))

    while sentence_count < n:
        start = sentence_count  # first new (non-context) sentence of this batch
        # Walk back over whole sentences until enough context is gathered,
        # stopping at the start of the text (no negative/wrap-around indices).
        context_start = start
        context_tokens = 0
        while context_tokens < context_length and context_start > 0:
            context_start -= 1
            context_tokens += counts[context_start]
        # Without room for the next sentence no progress is possible (the old
        # code looped forever): drop the oldest context sentences first.
        while context_start < start and context_tokens + counts[start] >= max_length:
            context_tokens -= counts[context_start]
            context_start += 1
        if context_tokens + counts[start] >= max_length:
            raise ValueError(
                'sentence {} has {} word-piece tokens, which does not fit in '
                'max_length - 2 = {}'.format(start, counts[start], max_length))
        token_count = context_tokens
        while sentence_count < n and token_count + counts[sentence_count] < max_length:
            token_count += counts[sentence_count]
            sentence_count += 1
        batch.append(' '.join(iterator[context_start:sentence_count]))
        indexes.append((_n_tokens(' '.join(iterator[context_start:start])), _n_tokens(batch[-1])))
    return batch, indexes

In [43]:
# Round-trip probe: word-piece tokenize a sample containing a curly quote, map to ids,
# and decode back to text to see how BERT's tokenizer handles it.
t_bert.decode(t_bert.convert_tokens_to_ids(t_bert.wordpiece_tokenizer.tokenize(' Straight ahead ... ”')))

'Straight ahead... ”'

In [40]:
# Same sample through GPT-2's tokenizer: as the output shows, the curly quote ” is
# split into two pieces, so word-level alignment must merge them.
t_base.tokenize(' Straight ahead ... ”')

['ĠStraight', 'Ġahead', 'Ġ...', 'ĠâĢ', 'Ŀ']

In [None]:
# Number of space-separated words in each sentence of run 1; used below to
# aggregate the LSTM's per-row entropies into per-sentence sums.
result = [len(line.strip().split(' ')) for line in iterator_list[0]]

In [None]:
# Sum the LSTM entropies sentence by sentence: `result` holds the word count of
# each sentence, consumed in order from the flat `lstm_result` Series.
# Both accumulators are initialised here; previously they came from hidden kernel state.
out_lstm = []
offset = 0
for n_words in result:
    # .iloc makes the slice explicitly positional (plain [] on a Series may be label-based).
    out_lstm.append(np.sum(lstm_result.iloc[offset:offset + n_words]))
    offset += n_words

### Evaluation

In [9]:
# Accumulator for per-model entropy rows used to build the summary table at the end.
# NOTE(review): no cell in this notebook appends to it.
results = list()

In [10]:
def eval_output(out):
    """Total entropy (bits) of the softmax distributions of a logit matrix.

    Arguments:
        - out: 2-D array of logits, one row per word (as built with np.vstack
          in the evaluation cells below)
    Returns:
        - float, sum over rows of entropy(softmax(row))
    """
    # Softmax each row independently (axis=-1): one distribution per word,
    # equivalent to the former per-row loop but in a single vectorized call.
    probabilities = scipy.special.softmax(out, axis=-1)
    return np.sum(entropy(probabilities, axis=-1))

def entropy(pk, axis=0):
    """Shannon entropy, in bits, of probability vector(s) ``pk`` along ``axis``.

    Zero-probability entries contribute 0 (the limit of -p*log p) instead of
    turning the result into nan via 0 * log2(0), which happens whenever a
    softmax underflows to exactly 0.
    """
    pk = np.asarray(pk)
    # entr(x) = -x*ln(x) with entr(0) = 0; dividing by ln(2) converts nats to bits.
    return np.sum(scipy.special.entr(pk), axis=axis) / np.log(2)

In [57]:
# Per-batch summed entropy (bits) of GPT-2 base output distributions on run 1.
# NOTE(review): GPT-2's logits at position t predict token t+1, so the entropy
# attributed to word w is the uncertainty about the word after w — confirm intended.
entropies_gpt2_base = []
batches, indexes = utils_gpt2.batchify_per_sentence_with_context(iterator_list[0], 1, 15, 'gpt2', max_length=512)
for index, batch in enumerate(batches):
    batch = batch.strip()
    tokenized_text = t_base.tokenize(batch, add_prefix_space=True)
    inputs_ids = torch.tensor([t_base.convert_tokens_to_ids(tokenized_text)])
    attention_mask = torch.ones_like(inputs_ids)
    # Inference only: do not build the autograd graph.
    with torch.no_grad():
        out_gpt2_base = model_base(inputs_ids, attention_mask=attention_mask)
    mapping = utils_gpt2.match_tokenized_to_untokenized(tokenized_text, batch)

    out = out_gpt2_base[0].detach().squeeze(0).numpy()
    # indexes[index][0] is presumably the token offset where the new (non-context)
    # part of the batch starts; find the word whose first token sits there.
    start_keys = [key_ for key_, value in mapping.items() if value[0] == indexes[index][0]]
    if not start_keys:
        raise ValueError('batch {}: no word starts at token {}'.format(index, indexes[index][0]))
    key = start_keys[-1]
    # One row per word: average the logits of the word's sub-tokens.
    activations = np.vstack([
        out[list(mapping[word_index])].mean(axis=0).reshape(1, -1)
        for word_index in range(key, len(mapping))
    ])
    entropies_gpt2_base.append(eval_output(activations))
    print(entropies_gpt2_base[-1])

1479.7416
81.449234
133.16344
84.96507
65.91492
57.852043
67.326385
67.71767
107.38642
57.373596
45.209236
48.960426
194.10556
86.80297
85.67359
138.64224
36.772038
80.569664
76.45224
144.53545
40.44752
176.63167
53.160664
63.024975
121.399635
98.385025
98.946335
35.67877
41.66576
32.041798
35.979294
67.63603
25.65775
44.13664
96.6409
80.69574
94.12051
41.170517
224.02808
66.01508
75.64002
148.27258
115.84273
100.31509
95.7048
25.196476
78.75001
161.28549
213.25603
44.87957
31.980438
146.24057
59.030235
81.71396
9.538782
70.57346
84.76148
49.91638
32.467304
24.687328
24.578554
46.58152
50.78563
27.506622
105.17947
76.638115
25.125643
169.71748
67.78848
196.56427
57.977867
139.05779
68.33978
29.470459
86.83928
38.392086
66.54461
74.31881
10.809681
41.35556
64.11928
54.646378
96.48307
87.59054
240.96822
47.792496
15.6032505
30.054953
44.417683
78.10411
35.53222
33.70642
47.416634
19.873497
19.407288
16.716162
91.17547
61.873688
83.09551
30.125158
124.356224
33.502968
145.25267
68.33378
1

In [81]:
# Per-batch summed entropy (bits) of BERT-base-cased output distributions on run 1.
# NOTE(review): BERT is a masked LM but no token is masked here, so each position's
# distribution is computed while the model sees that very token; this may explain
# values far below GPT-2's — confirm this is the intended measure.
entropies_bert = []
batches, indexes = utils_bert.batchify_per_sentence_with_context(iterator_list[0], 1, 15, 'bert-base-cased', max_length=512)
for index, batch in enumerate(batches):
    batch = '[CLS] ' + batch.strip() + ' [SEP]'
    tokenized_text = t_bert.wordpiece_tokenizer.tokenize(batch)
    inputs_ids = torch.tensor([t_bert.convert_tokens_to_ids(tokenized_text)])
    attention_mask = torch.ones_like(inputs_ids)
    # Inference only: do not build the autograd graph.
    with torch.no_grad():
        out_bert = model_bert(inputs_ids, attention_mask=attention_mask)
    mapping = utils_bert.match_tokenized_to_untokenized(tokenized_text, batch)

    out = out_bert[0].detach().squeeze(0).numpy()
    # The leading [CLS] shifts every token position by one: the first new word's first
    # token is at indexes[index][0] + 1, and (assuming indexes[index][1] is an exclusive
    # end) the last word's last token is at indexes[index][1].
    start_keys = [key_ for key_, value in mapping.items() if (value[0] - 1) == indexes[index][0]]
    stop_keys = [key_ for key_, value in mapping.items() if value[-1] == indexes[index][1]]
    if not start_keys or not stop_keys:
        raise ValueError('batch {}: word span {} not found in token mapping'.format(index, indexes[index]))
    key_start, key_stop = start_keys[-1], stop_keys[-1]
    # One row per word: average the logits of the word's sub-tokens.
    activations = np.vstack([
        out[list(mapping[word_index])].mean(axis=0).reshape(1, -1)
        for word_index in range(key_start, key_stop + 1)
    ])
    entropies_bert.append(eval_output(activations))
    print(entropies_bert[-1])

51.726067
0.51056916
15.8018265
1.6382804
0.89195096
0.073676035
2.8023162
0.24247593
4.106795
1.2153734
0.16375905
3.4561903
1.9005369
0.33084705
2.488223
12.904313
0.10350615
1.3819168
2.1053207
6.3975186
0.2907766
7.2758303
0.33432707
0.9509853
1.6948833
8.109937
2.5655208
3.175221
1.1034843
1.9702392
2.962284
0.38524824
0.048696566
1.1148791
0.7471783
0.17330718
0.9248373
0.14171925
14.549717
6.0737476
2.1092517
7.7870703
0.2677342
4.1574235
3.08417
2.1414485
7.6469812
8.539354
2.9270263
2.006554
0.90666777
6.918694
3.5429232
7.5012136
0.06419663
5.3476486
12.037139
0.16684225
0.13371281
0.5327476
0.03672949
1.7643049
0.17249109
0.6117816
8.716273
0.18557273
1.754605
7.043815
0.748481
15.958378
0.90590805
5.558101
1.0060807
1.8724229
4.915218
3.2075346
0.7280161
3.3455396
0.3391056
2.4124892
0.25327185
0.08924365
1.2096369
0.16785105
2.4065912
1.503155
0.10804603
0.07097508
0.6218133
0.04384697
1.8683732
0.9556374
5.203377
1.3145803
0.050640065
1.9549698
2.306864
2.104768
2.25039
0

In [11]:
# Per-batch summed entropy (bits) of RoBERTa-base output distributions on run 1.
# NOTE(review): when last run this cell failed with "IndexError: string index out of
# range"; its origin is not visible here — check the batchify/mapping utils.
entropies_roberta = []
batches, indexes = utils_roberta.batchify_per_sentence_with_pre_and_post_context(iterator_list[0], 1, 0, 0, 'roberta-base', max_length=512)
for index, batch in enumerate(batches):
    batch = '<s> ' + batch.strip() + ' </s>'
    tokenized_text = t_roberta.tokenize(batch, add_prefix_space=True)
    inputs_ids = torch.tensor([t_roberta.convert_tokens_to_ids(tokenized_text)])
    attention_mask = torch.ones_like(inputs_ids)
    # Inference only: do not build the autograd graph.
    with torch.no_grad():
        out_roberta = model_roberta(inputs_ids, attention_mask=attention_mask)
    mapping = utils_roberta.match_tokenized_to_untokenized(tokenized_text, batch)

    out = out_roberta[0].detach().squeeze(0).numpy()
    # The leading <s> shifts every token position by one: the first new word's first
    # token is at indexes[index][0] + 1, and (assuming indexes[index][1] is an exclusive
    # end) the last word's last token is at indexes[index][1].
    start_keys = [key_ for key_, value in mapping.items() if (value[0] - 1) == indexes[index][0]]
    stop_keys = [key_ for key_, value in mapping.items() if value[-1] == indexes[index][1]]
    if not start_keys or not stop_keys:
        raise ValueError('batch {}: word span {} not found in token mapping'.format(index, indexes[index]))
    key_start, key_stop = start_keys[-1], stop_keys[-1]
    # One row per word: average the logits of the word's sub-tokens.
    activations = np.vstack([
        out[list(mapping[word_index])].mean(axis=0).reshape(1, -1)
        for word_index in range(key_start, key_stop + 1)
    ])
    entropies_roberta.append(eval_output(activations))
    print(entropies_roberta[-1])

['Once , when I was six years old , I saw a magnificent picture in a book about the primeval forest called ‘ Real - life Stories . ’']
0 0 135
--


IndexError: string index out of range

In [17]:
# Probe: does tokenizing an empty string with add_prefix_space=True fail? When last
# run it raised IndexError; catch and report it so Restart & Run All can continue.
try:
    n_empty_tokens = len(t_roberta.tokenize('', add_prefix_space=True))
except IndexError as error:
    n_empty_tokens = None
    print('IndexError: {}'.format(error))
n_empty_tokens

IndexError: string index out of range

In [73]:
# Debug dump: each run-1 sentence wrapped in RoBERTa's <s> ... </s> markers, then its tokens.
for line in iterator_list[0]:
    line = f'<s> {line.strip()} </s>'
    print(line)
    print(t_roberta.tokenize(line, add_prefix_space=True))

<s> Once , when I was six years old , I saw a magnificent picture in a book about the primeval forest called ‘ Real - life Stories . ’ </s>
['<s>', 'ĠOnce', 'Ġ,', 'Ġwhen', 'ĠI', 'Ġwas', 'Ġsix', 'Ġyears', 'Ġold', 'Ġ,', 'ĠI', 'Ġsaw', 'Ġa', 'Ġmagnificent', 'Ġpicture', 'Ġin', 'Ġa', 'Ġbook', 'Ġabout', 'Ġthe', 'Ġprime', 'val', 'Ġforest', 'Ġcalled', 'ĠâĢ', 'ĺ', 'ĠReal', 'Ġ-', 'Ġlife', 'ĠStories', 'Ġ.', 'ĠâĢ', 'Ļ', '</s>']
<s> It showed a boa constrictor swallowing a wild animal . </s>
['<s>', 'ĠIt', 'Ġshowed', 'Ġa', 'Ġbo', 'a', 'Ġconst', 'rict', 'or', 'Ġswallowing', 'Ġa', 'Ġwild', 'Ġanimal', 'Ġ.', '</s>']
<s> Here is a copy of the drawing . </s>
['<s>', 'ĠHere', 'Ġis', 'Ġa', 'Ġcopy', 'Ġof', 'Ġthe', 'Ġdrawing', 'Ġ.', '</s>']
<s> It said in the book : “ Boa constrictors swallow their prey whole , without chewing . </s>
['<s>', 'ĠIt', 'Ġsaid', 'Ġin', 'Ġthe', 'Ġbook', 'Ġ:', 'ĠâĢ', 'ľ', 'ĠBo', 'a', 'Ġconst', 'rict', 'ors', 'Ġswallow', 'Ġtheir', 'Ġprey', 'Ġwhole', 'Ġ,', 'Ġwithout', 'Ġchewing', 'Ġ.'

In [34]:
35.969994
0.15918078
8.293572
0.16216268
0.4701601
0.09935355
4.9531755
0.41039103
0.78353435
1.0025072
0.027972687
1.6763369
1.3630848
0.07719013
7.683632

torch.Size([1, 32, 50257])
32
torch.Size([1, 32, 50257])
32
torch.Size([1, 32, 28996])
32
torch.Size([1, 34, 50265])
34
torch.Size([1, 13, 50257])
13
torch.Size([1, 13, 50257])
13
torch.Size([1, 16, 28996])
16
torch.Size([1, 15, 50265])
15


KeyboardInterrupt: 

In [24]:
# Summary table with one column per model. Nothing in this notebook appends to
# `results`, so on a fresh kernel it is empty and the sums would silently be 0;
# fail loudly instead of relying on leftover kernel state.
assert results, 'results is empty: append one row of [GPT2-base, GPT2-medium, BERT-base-cased, ROBERTA-base] entropies first'
result = pd.DataFrame(results, columns=['GPT2-base', 'GPT2-medium', 'BERT-base-cased', 'ROBERTA-base'])

In [61]:
# Total entropy per model (column-wise sum over all rows).
result.sum(axis='index')

GPT2-base          13650.756184
GPT2-medium        12541.113115
BERT-base-cased     2469.644808
ROBERTA-base         466.423326
dtype: float64