In [75]:
import torch
from transformers import BertForTokenClassification
from transformers import BertTokenizerFast
import pandas as pd
import re
from tqdm import tqdm

In [76]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
label_all_tokens = False

In [77]:
df = pd.read_csv("ner.csv")

labels = [i.split() for i in df['labels'].values.tolist()]

unique_labels = set()

for lb in labels:
  [unique_labels.add(i) for i in lb if i not in unique_labels]
 
labels_to_ids = {k: v for v, k in enumerate(sorted(unique_labels))}
ids_to_labels = {v: k for v, k in enumerate(sorted(unique_labels))}

In [78]:
from transformers import BertForTokenClassification

class BertModel(torch.nn.Module):

    def __init__(self):

        super(BertModel, self).__init__()

        self.bert = BertForTokenClassification.from_pretrained('bert-base-cased', num_labels=len(unique_labels))

    def forward(self, input_id, mask, label):

        output = self.bert(input_ids=input_id, attention_mask=mask, labels=label, return_dict=False)

        return output

In [101]:
def align_word_ids(texts):
  
    tokenized_inputs = tokenizer(texts, padding='max_length', max_length=512, truncation=True)

    word_ids = tokenized_inputs.word_ids()

    previous_word_idx = None
    label_ids = []

    for word_idx in word_ids:

        if word_idx is None:
            label_ids.append(-100)

        elif word_idx != previous_word_idx:
            try:
                label_ids.append(1)
            except:
                label_ids.append(-100)
        else:
            try:
                label_ids.append(1 if label_all_tokens else -100)
            except:
                label_ids.append(-100)
        previous_word_idx = word_idx

    return label_ids


def evaluate_one_text(model, sentence):


    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    if use_cuda:
        model = model.cuda()

    text = tokenizer(sentence, padding='max_length', max_length = 512, truncation=True, return_tensors="pt")

    mask = text['attention_mask'].to(device)
    input_id = text['input_ids'].to(device)
    label_ids = torch.Tensor(align_word_ids(sentence)).unsqueeze(0).to(device)

    logits = model(input_id, mask, None)
    logits_clean = logits[0][label_ids != -100]

    predictions = logits_clean.argmax(dim=1).tolist()
    prediction_label = [ids_to_labels[i] for i in predictions]
    return [sentence.split(' '), prediction_label]

In [102]:
saved_model = torch.load('/Volumes/Drive/GitHub/DaVinciCodeTheTrackOfRobertLangdon/models/small-bert-fine-tuned', map_location=torch.device('cpu'))
res = evaluate_one_text(saved_model, 'Bill Gates is the founder of Microsoft That is all trewtewiui i know')

In [97]:
with open('TheDaVinciCode.txt', 'r') as f:
    data = f.read()
    data = data.replace("\n", " ")
    data = data.replace("\t", " ")
    data = data.replace("\r", " ")
    data = data.replace("  ", " ")
    data = data.replace("   ", " ")
    data = data.replace("    ", " ")
    data = data.replace("     ", " ")

f.close()

In [82]:
# Get every chapter in the data with case sensitivity.
paragraphs = data.split("CHAPTER")

# Merge the paragraphs if the first element is not a number.
tmp = []
for i in range(1, len(paragraphs)-1):
    if not paragraphs[i][1].isdigit():
        tmp.append(paragraphs[i-1] + paragraphs[i])
        paragraphs.remove(paragraphs[i])

sentences = []
for paragraph in paragraphs:
    sentence = re.split("(?<!\w\.\w.)(?<![A-Z]\.)(?<![A-Z][a-z]\.)(?<=\.|\?)|!", paragraph)
    sentences.extend(sentence)

print('number of sentences: ', len(sentences))

number of sentences:  13712


In [83]:
sentences = sentences[:3000]

In [106]:
# go three sentences by three in for loop through sentences
full = []
for i in tqdm(range(0, len(sentences)-8, 8)):
    full.append(evaluate_one_text(saved_model, sentences[i] + sentences[i+1] + sentences[i+2]))

100%|██████████| 374/374 [01:42<00:00,  3.65it/s]


In [108]:
result

("The Da Vinci Code  Dan Brown Prologue  Louvre Museum, Paris 10:46 P.M.  Renowned curator Jacques Sauniere staggered through the vaulted archway of the museum's Grand Gallery. He lunged for the nearest painting he could see, a Caravaggio. Grabbing the gilded frame, the seventy- six-year-old man heaved the masterpiece toward himself until it tore from the wall and Sauniere collapsed backward in a heap beneath the canvas.",
 ['O',
  'I-per',
  'I-per',
  'B-per',
  'B-per',
  'I-per',
  'O',
  'B-geo',
  'I-org',
  'O',
  'B-geo',
  'O',
  'O',
  'B-tim',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'B-per',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
 

In [107]:
# go through results two by two
for i in range(0,len(full)-1,2):
    

0
2
4
6
8
10
12
14
16
18
20
22
24
26
28
30
32
34
36
38
40
42
44
46
48
50
52
54
56
58
60
62
64
66
68
70
72
74
76
78
80
82
84
86
88
90
92
94
96
98
100
102
104
106
108
110
112
114
116
118
120
122
124
126
128
130
132
134
136
138
140
142
144
146
148
150
152
154
156
158
160
162
164
166
168
170
172
174
176
178
180
182
184
186
188
190
192
194
196
198
200
202
204
206
208
210
212
214
216
218
220
222
224
226
228
230
232
234
236
238
240
242
244
246
248
250
252
254
256
258
260
262
264
266
268
270
272
274
276
278
280
282
284
286
288
290
292
294
296
298
300
302
304
306
308
310
312
314
316
318
320
322
324
326
328
330
332
334
336
338
340
342
344
346
348
350
352
354
356
358
360
362
364
366
368
370
372
