In [17]:
import pandas as pd
dataframe = pd.read_csv("../../data/ner.csv")
dataframe.head

<bound method NDFrame.head of             Sentence #                                           Sentence  \
0          Sentence: 1  Thousands of demonstrators have marched throug...   
1          Sentence: 2  Families of soldiers killed in the conflict jo...   
2          Sentence: 3  They marched from the Houses of Parliament to ...   
3          Sentence: 4  Police put the number of marchers at 10,000 wh...   
4          Sentence: 5  The protest comes on the eve of the annual con...   
...                ...                                                ...   
47954  Sentence: 47955  Indian border security forces are accusing the...   
47955  Sentence: 47956  Indian officials said no one was injured in Sa...   
47956  Sentence: 47957  Two more landed in fields belonging to a nearb...   
47957  Sentence: 47958  They say not all of the rockets exploded upon ...   
47958  Sentence: 47959    Indian forces said they responded to the attack   

                                             

In [18]:
from collections import Counter


unique_labels = set()
label_counter = Counter()
for list_of_tags in dataframe["Tag"]:
  for tag in eval(list_of_tags):
    unique_labels.add(tag)
    label_counter[tag] += 1

In [19]:
from pprint import pprint
pprint(label_counter)

Counter({'O': 887908,
         'B-geo': 37644,
         'B-tim': 20333,
         'B-org': 20143,
         'I-per': 17251,
         'B-per': 16990,
         'I-org': 16784,
         'B-gpe': 15870,
         'I-geo': 7414,
         'I-tim': 6528,
         'B-art': 402,
         'B-eve': 308,
         'I-art': 297,
         'I-eve': 253,
         'B-nat': 201,
         'I-gpe': 198,
         'I-nat': 51})


In [20]:
labels_to_ids = {k: v for v, k in enumerate(sorted(unique_labels))}
ids_to_labels = {v: k for v, k in enumerate(sorted(unique_labels))}
print(labels_to_ids)
import pickle
with open("../../data/ids_to_labels.pickle", "wb") as f:
    pickle.dump(ids_to_labels, f)
with open("../../data/labels_to_ids.pickle", "wb") as f:
    pickle.dump(labels_to_ids, f)
with open("../../data/unique_labels.pickle", "wb") as f:
    pickle.dump(unique_labels, f)

{'B-art': 0, 'B-eve': 1, 'B-geo': 2, 'B-gpe': 3, 'B-nat': 4, 'B-org': 5, 'B-per': 6, 'B-tim': 7, 'I-art': 8, 'I-eve': 9, 'I-geo': 10, 'I-gpe': 11, 'I-nat': 12, 'I-org': 13, 'I-per': 14, 'I-tim': 15, 'O': 16}


In [5]:
import transformers
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained('dslim/bert-base-NER')

def ids_to_tokens(input):
    return tokenizer.convert_ids_to_tokens(input)

In [6]:
def align_label(texts, labels):
    tokenized_inputs = tokenizer(texts, padding='max_length', max_length=512, truncation=True)

    word_ids = tokenized_inputs.word_ids()

    previous_word_idx = None
    label_ids = []

    for word_idx in word_ids:
        if word_idx is None:
            label_ids.append(-100)
        elif word_idx != previous_word_idx:
            try:
                label_ids.append(labels_to_ids[labels[word_idx]])
            except:
                label_ids.append(-100)
        else:
            try:
                label_ids.append(labels_to_ids[labels[word_idx]] if True else -100)
            except:
                label_ids.append(-100)
        previous_word_idx = word_idx

    return label_ids

In [7]:
import torch

class DataSequence(torch.utils.data.Dataset):
    def __init__(self, df):
        lb = df['Tag'].values.tolist()
        txt = df['Sentence'].values.tolist()
        self.texts = [tokenizer(str(i),
                               padding='max_length', max_length = 512, truncation=True, return_tensors="pt") for i in txt]
        self.labels = [align_label(i,j) for i,j in zip(txt, lb)]

    def __len__(self):
        return len(self.labels)

    def get_batch_data(self, idx):
        return self.texts[idx]

    def get_batch_labels(self, idx):
        return torch.LongTensor(self.labels[idx])

    def __getitem__(self, idx):
        batch_data = self.get_batch_data(idx)
        batch_labels = self.get_batch_labels(idx)

        return batch_data, batch_labels

In [8]:
import numpy as np

df_train, df_val, df_test = np.split(dataframe[:100].sample(frac=1, random_state=42),
                            [int(.8 * len(dataframe[:100])), int(.9 * len(dataframe[:100]))])

In [9]:
from transformers import BertForTokenClassification

class BertModel(torch.nn.Module):
    def __init__(self):
        super(BertModel, self).__init__()
        self.bert = BertForTokenClassification.from_pretrained('dslim/bert-base-NER', num_labels=len(unique_labels), ignore_mismatched_sizes=True)

    def forward(self,input_ids, label=None):
        output = self.bert(labels=label, input_ids = input_ids, return_dict=False)
        return output

In [13]:
from transformers import pipeline
model = BertModel()

model.load_state_dict(torch.load('bert_trainedNEREnglish',map_location=torch.device('cpu')))

Some weights of BertForTokenClassification were not initialized from the model checkpoint at dslim/bert-base-NER and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([9, 768]) in the checkpoint and torch.Size([17, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([9]) in the checkpoint and torch.Size([17]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


<All keys matched successfully>

In [11]:
df_test

Unnamed: 0,Sentence #,Sentence,POS,Tag
91,Sentence: 92,Gunmen have shot and killed a Roman Catholic n...,"['NNS', 'VBP', 'VBN', 'CC', 'VBN', 'DT', 'NNP'...","['O', 'O', 'O', 'O', 'O', 'O', 'B-org', 'I-org..."
74,Sentence: 75,"Harcourt , of the Australian Trade Commission ...","['NNP', ',', 'IN', 'DT', 'NNP', 'NNP', 'NNP', ...","['B-per', 'O', 'O', 'O', 'B-org', 'I-org', 'I-..."
86,Sentence: 87,Indonesian police have arrested three men in c...,"['JJ', 'NNS', 'VBP', 'VBN', 'CD', 'NNS', 'IN',...","['B-gpe', 'O', 'O', 'O', 'O', 'O', 'O', 'O', '..."
82,Sentence: 83,"In recent weeks , AU officials say Sudanese tr...","['IN', 'JJ', 'NNS', ',', 'NNP', 'NNS', 'VBP', ...","['O', 'O', 'O', 'O', 'B-org', 'O', 'O', 'B-gpe..."
20,Sentence: 21,Local news reports said at least five mortar s...,"['JJ', 'NN', 'NNS', 'VBD', 'IN', 'JJS', 'CD', ...","['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', ..."
60,Sentence: 61,They say the company has produced some shoddy ...,"['PRP', 'VBP', 'DT', 'NN', 'VBZ', 'VBN', 'DT',...","['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', ..."
71,Sentence: 72,It was made a permanent body in 1995 to provid...,"['PRP', 'VBD', 'VBN', 'DT', 'JJ', 'NN', 'IN', ...","['O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-tim', '..."
14,Sentence: 15,An official with the German firm Bilfinger Ber...,"['DT', 'NN', 'IN', 'DT', 'JJ', 'NN', 'NNP', 'N...","['O', 'O', 'O', 'O', 'B-gpe', 'O', 'B-org', 'I..."
92,Sentence: 93,Some witnesses to the Sunday shooting said the...,"['DT', 'NNS', 'TO', 'DT', 'NNP', 'NN', 'VBD', ...","['O', 'O', 'O', 'O', 'B-tim', 'O', 'O', 'O', '..."
51,Sentence: 52,A senior Pakistani military official says Paki...,"['DT', 'JJ', 'JJ', 'JJ', 'NN', 'VBZ', 'NNP', '...","['O', 'O', 'B-gpe', 'O', 'O', 'O', 'B-gpe', 'O..."


In [14]:
output = []
j = 0
for sentence, answer in zip(df_test["Sentence"], df_test["Tag"]):
    B = np.asarray([tokenizer(sentence.replace("-", ""))["input_ids"]]).reshape(1,1,-1)
    logits = model(torch.as_tensor(np.array(B))[0])[0]
    print(logits.shape)
    for i in range(logits.shape[0]):
        #print(logits[i])
        logits_clean = logits[i].argmax(dim=1)
        words = sentence.replace("-", "").split()
        print(logits_clean)
        #for i in range(len(eval(answer))-1):
        #    print(words[i], eval(answer)[i])
        i = 0
        tokenized_sentence = ids_to_tokens(tokenizer(sentence.replace("-", ""))["input_ids"])
        #print([ids_to_labels[x.item()] for x in logits_clean])
        k = 0
        for el in logits_clean[1:-1]:
            if i == len(words):
                break
            elem = logits_clean[k]
            if i+1 <= len(logits_clean) and tokenized_sentence[i][:2] == "##":
                if elem.item() == "O":
                    label = logits_clean[i]
                else:
                    label = elem.item()
                output.append((words[i], ids_to_labels[label], eval(answer)[i]))
                k+=2
            else:
                output.append((words[i], ids_to_labels[elem.item()], eval(answer)[i]))
                k+=1
            i += 1

    print(output)
    j+=1
    if j == 2:
        break

torch.Size([1, 34, 17])
tensor([16, 16, 16, 16, 16, 16, 16, 16,  3,  3, 16, 16, 16, 16, 16, 16, 16, 16,
        16, 16, 16,  2,  2,  2,  2,  2, 10, 10, 10, 10, 16,  2, 16, 16])
[('Gunmen', 'O', 'O'), ('have', 'O', 'O'), ('shot', 'O', 'O'), ('and', 'O', 'O'), ('killed', 'O', 'O'), ('a', 'O', 'O'), ('Roman', 'O', 'B-org'), ('Catholic', 'B-gpe', 'I-org'), ('nun', 'B-gpe', 'O'), ('and', 'O', 'O'), ('her', 'O', 'O'), ('bodyguard', 'O', 'O'), ('at', 'O', 'O'), ('the', 'O', 'O'), ('hospital', 'O', 'O'), ('where', 'O', 'O'), ('she', 'O', 'O'), ('worked', 'O', 'O'), ('in', 'O', 'O'), ('Islamistcontrolled', 'O', 'O'), ('Mogadishu', 'B-geo', 'B-geo'), (',', 'B-geo', 'O'), ('Somalia', 'B-geo', 'B-geo'), ('.', 'B-geo', 'O')]
torch.Size([1, 24, 17])
tensor([16,  6, 16, 16, 16,  5, 13, 13, 16, 16, 16,  5,  5, 16, 16, 16, 16, 16,
        16, 16, 16, 16, 16, 16])
[('Gunmen', 'O', 'O'), ('have', 'O', 'O'), ('shot', 'O', 'O'), ('and', 'O', 'O'), ('killed', 'O', 'O'), ('a', 'O', 'O'), ('Roman', 'O', 'B-or