In [1]:
import warnings
warnings.filterwarnings('ignore')
import os
import pandas as pd

def read_from_file(file_path:str):
    words, poss, chunks, tags = [], [], [], []
    sentences_label = ['Sentence:0']
    sent_id = 1

    for line in open(file_path):
        if line in ['\n', '\r\n']:
            if len(sentences_label) != 1:
                sentences_label.pop()
                sentences_label.append('Sentence:{}'.format(sent_id))
            sent_id += 1

        elif line.startswith('-DOCSTART-'):
            sent_id = sent_id - 1 if len(sentences_label) != 1 else sent_id

        else:
            word, pos, chunk, tag = line.strip().split()
            assert (len(line.split()) == 4)
            words.append(word)
            poss.append(pos)
            chunks.append(chunk)
            tags.append(tag)
            sentences_label.append(None)
    sentences_label.pop()
    assert (len(sentences_label)) == len(tags)

    dataset = {'Sentence #': sentences_label,
               'Word': word,
               'Pos': poss,
               'Chunk': chunks,
               'Tag': tags
               }
    data = pd.DataFrame(dataset, columns=['Sentence #', 'Word', 'Pos', 'Chunk', 'Tag'])
    data = data.fillna(method='ffill')
    print('file {}, tokens {}'.format(file_path, len(data['Word'])))
    return data


data = read_from_file('D:/NER/CRF/datasets/datasets/train.txt')
data

file D:/NER/CRF/datasets/datasets/train.txt, tokens 203621


Unnamed: 0,Sentence #,Word,Pos,Chunk,Tag
0,Sentence:0,2,NNP,B-NP,B-ORG
1,Sentence:0,2,VBZ,B-VP,O
2,Sentence:0,2,JJ,B-NP,B-MISC
3,Sentence:0,2,NN,I-NP,O
4,Sentence:0,2,TO,B-VP,O
...,...,...,...,...,...
203616,Sentence:14040,2,CD,I-NP,O
203617,Sentence:14041,2,NN,B-NP,B-ORG
203618,Sentence:14041,2,CD,I-NP,O
203619,Sentence:14041,2,NNP,I-NP,B-ORG


In [1]:
import warnings
warnings.filterwarnings('ignore')
import os
import pandas as pd
from utils import *
from sklearn_crfsuite import CRF
from sklearn.externals import joblib
from sklearn.model_selection import cross_val_predict
from sklearn_crfsuite.metrics import flat_classification_report
import eli5

ner_dataset_dir = 'datasets/datasets/train.txt'
data = read_from_file(ner_dataset_dir)
test_dir = 'datasets/datasets/test.txt'
test_data = read_from_file(test_dir)

getter = SentenceGetter(data)
sentences = getter.sentences
X = [sent2features(s) for s in sentences]
Y = [sent2labels(s) for s in sentences]

test_getter = SentenceGetter(test_data)
test_sentences = test_getter.sentences
test_X = [sent2features(s) for s in test_sentences]
test_Y = [sent2labels(s) for s in test_sentences]



0                 EU
1            rejects
2             German
3               call
4                 to
5            boycott
6            British
7               lamb
8                  .
9              Peter
10         Blackburn
11          BRUSSELS
12        1996-08-22
13               The
14          European
15        Commission
16              said
17                on
18          Thursday
19                it
20         disagreed
21              with
22            German
23            advice
24                to
25         consumers
26                to
27              shun
28           British
29              lamb
             ...    
203591            77
203592             .
203593        SOCCER
203594             -
203595       ENGLISH
203596        SOCCER
203597       RESULTS
203598             .
203599        LONDON
203600    1996-08-30
203601       Results
203602            of
203603       English
203604        league
203605       matches
203606            on
203607       

In [2]:

# def train():
#     crf = CRF(algorithm='lbfgs', c1=10, c2=0.1, max_iterations=100, all_possible_transitions=True)
#     pred = cross_val_predict(estimator=crf, X=X, y=Y, cv=3)
#     report = flat_classification_report(y_pred=pred, y_true=Y)
#     crf.fit(X, Y)
#
#     print(report)
#     eli5.show_weights(crf, top=5, show=['transition_features'])
#     eli5.show_weights(crf, top=10, feature_re='^word\.is',
#                       horizontal_layout=False, show=['targets'])
#
#     if not os.path.exists('models'):
#         os.makedirs('models')
#     joblib.dump(crf, 'models/crf.pkl')



crf = joblib.load(filename='models/crf.pkl')
pred = crf.predict(test_X)
report = flat_classification_report(y_pred=pred, y_true=test_Y)
print(report)
f_out = 'output/crf.output.txt'
with open(f_out, 'w') as f:
    for s, s_pred in zip(sentences, pred):
        for w, p in zip(s, s_pred):
            f.write('{}\t{}\t{}\n'.format(w[0], w[2], p))

eli5.show_weights(crf, top=5, show=['transition_features'])
eli5.show_weights(crf, top=10, feature_re='^word\.is',
                      horizontal_layout=False, show=['targets'])

              precision    recall  f1-score   support

       B-LOC       0.69      0.68      0.69      1668
      B-MISC       0.76      0.56      0.65       702
       B-ORG       0.66      0.59      0.62      1661
       B-PER       0.70      0.77      0.74      1617
       I-LOC       0.80      0.50      0.62       257
      I-MISC       0.55      0.55      0.55       216
       I-ORG       0.58      0.64      0.61       835
       I-PER       0.77      0.91      0.84      1156
           O       0.98      0.98      0.98     38323

    accuracy                           0.93     46435
   macro avg       0.72      0.69      0.70     46435
weighted avg       0.93      0.93      0.93     46435



From \ To,O,B-LOC,I-LOC,B-MISC,I-MISC,B-ORG,I-ORG,B-PER,I-PER
O,5.741,3.364,-0.134,3.418,0.0,3.733,-0.723,3.82,-0.415
B-LOC,1.045,-0.273,6.137,0.0,0.0,0.0,0.0,-2.22,-0.344
I-LOC,0.08,0.0,5.13,0.0,0.0,0.0,0.0,-1.437,0.0
B-MISC,1.039,-0.839,0.0,-0.449,5.727,0.0,0.0,-0.291,-0.359
I-MISC,0.176,-0.316,0.0,0.0,5.419,0.0,0.0,-0.923,-0.075
B-ORG,1.057,-0.824,0.0,-0.345,0.0,0.0,6.423,-1.543,-0.222
I-ORG,0.144,-1.383,0.0,-0.701,0.0,-0.502,5.206,-2.137,-0.456
B-PER,0.608,-0.526,0.0,-0.54,0.0,-0.777,0.0,-2.058,5.919
I-PER,-0.18,-0.215,0.0,-0.233,0.0,0.0,0.0,-1.441,3.526


In [3]:
eli5.show_weights(crf, top=10, feature_re='^word\.is',
                      horizontal_layout=False, show=['targets'])

Weight?,Feature
-3.372,word.isupper()
-4.779,word.istitle()

Weight?,Feature
1.366,word.isupper()
0.798,word.istitle()

Weight?,Feature
0.099,word.istitle()

Weight?,Feature
1.586,word.isupper()
0.736,word.istitle()

Weight?,Feature
1.542,word.isdigit()
0.126,word.istitle()

Weight?,Feature
2.187,word.isupper()
0.124,word.istitle()

Weight?,Feature
0.381,word.istitle()
0.021,word.isupper()
-0.655,word.isdigit()

Weight?,Feature
1.307,word.istitle()
0.056,word.isupper()

Weight?,Feature
0.452,word.istitle()
-3.589,word.isupper()
