# CRF for Entity Extraction on WikiNER (English)
WikiNER is a dataset of annotated sentences for entity extraction taken from Wikipedia. In this notebook we train and evaluate a CRF model on the English data to recognize entities such as persons, locations and organizations in text.

We use the `sklearn-crfsuite` package to implement our model and `seqeval` to compute the F1-score.

---

In [1]:
import os
from utils import dataio, modelutils
from pprint import pprint
from sklearn_crfsuite import CRF
from sklearn.model_selection import train_test_split
from seqeval.metrics import classification_report

## Data Preparation

We load the dataset from the `data/` directory. For each token, the dataset reports the word, its Part-of-Speech tag and its entity tag.

In [2]:
file_path = os.path.join('data', 'wikiner-en-wp3-raw.txt')
sentences, tags, output_labels = dataio.load_wikiner(file_path)

Read 142153 sentences.
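The helper `dataio.load_wikiner` is not shown in this notebook. As a rough illustration of the per-token layout described above, the sketch below assumes that each line of the raw file holds one sentence whose tokens are written as `word|POS|tag` and separated by spaces. The function name `parse_wikiner_line` and the example line are hypothetical, and the real loader may differ.

```python
def parse_wikiner_line(line):
    """Split one raw line into (word, POS) pairs and a list of entity tags.

    Assumes tokens look like 'word|POS|tag' and are separated by whitespace.
    """
    tokens, tags = [], []
    for item in line.split():
        # rsplit keeps any '|' that belongs to the word itself intact
        word, pos, tag = item.rsplit("|", 2)
        tokens.append((word, pos))
        tags.append(tag)
    return tokens, tags


# Made-up example line in the assumed format
example = "Daniel|NNP|I-PER Guerin|NNP|I-PER wrote|VBD|O"
tokens, tags = parse_wikiner_line(example)
print(tokens)
print(tags)
```

Keeping the `(word, POS)` pairs separate from the tags matches the shape of `sentences` and `tags` used in the rest of the notebook.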


In [3]:
print("Labels:", output_labels)

Labels: {'B-PER', 'B-MISC', 'B-LOC', 'I-PER', 'I-ORG', 'I-MISC', 'I-LOC', 'O', 'B-ORG'}


In [4]:
print("Sentence Example:")
pprint(sentences[1])
print("="*30)
print(tags[1])

Sentence Example:
[('In', 'IN'),
 ('the', 'DT'),
 ('end', 'NN'),
 (',', ','),
 ('for', 'IN'),
 ('anarchist', 'JJ'),
 ('historian', 'JJ'),
 ('Daniel', 'NNP'),
 ('Guerin', 'NNP'),
 ('"', 'LQU'),
 ('Some', 'DT'),
 ('anarchists', 'NNS'),
 ('are', 'VBP'),
 ('more', 'RBR'),
 ('individualistic', 'JJ'),
 ('than', 'IN'),
 ('social', 'JJ'),
 (',', ','),
 ('some', 'DT'),
 ('more', 'JJR'),
 ('social', 'JJ'),
 ('than', 'IN'),
 ('individualistic', 'JJ'),
 ('.', '.')]
['O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']


---

## Feature Engineering

In this section, we build the feature vector for each token. It is composed of:
* A constant bias term;
* The lowercase token string*;
* The token suffixes (last three and last two characters);
* Whether the token is capitalized*;
* Whether the token is uppercase*;
* Whether the token consists only of digits;
* The complete Part-of-Speech tag of the token*;
* A more general Part-of-Speech tag (the first two characters of the tag)*;
* Whether the token is the first of the sentence;
* Whether the token is the last of the sentence.

\* also for the previous and next tokens, if present.  

> Note: categorical features are one-hot encoded.
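To make the one-hot note concrete, here is a minimal sketch of how a feature dictionary with string, boolean and numeric values can be flattened into named numeric indicators. The function `expand_features` is hypothetical and only illustrates the idea; it is not the code used inside `sklearn-crfsuite`, and the exact naming of the expanded features there may differ.

```python
def expand_features(features):
    """Flatten a feature dict into {indicator_name: weight} (illustrative only)."""
    expanded = {}
    for name, value in features.items():
        if isinstance(value, bool):
            # Booleans become a 0/1 weight on the feature name
            expanded[name] = 1.0 if value else 0.0
        elif isinstance(value, str):
            # Each distinct string value gets its own indicator feature
            expanded[f"{name}:{value}"] = 1.0
        else:
            # Numeric values are kept as weights
            expanded[name] = float(value)
    return expanded


token = {"bias": 1.0, "word.lower()": "the", "word.istitle()": False}
expanded = expand_features(token)
print(expanded)
```

With this view, `'word.lower()': 'the'` and `'word.lower()': 'end'` are two unrelated binary features, which is why the vocabulary size drives the number of model parameters.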

In [5]:
def word_features(sentence, idx):
    """Extract features related to a word and its neighbours"""
    word, pos = sentence[idx]
    
    features = {
        'bias': 1.0,
        'word.lower()': word.lower(),
        'word[-3:]': word[-3:],
        'word[-2:]': word[-2:],
        'word.isupper()': word.isupper(),
        'word.istitle()': word.istitle(),
        'word.isdigit()': word.isdigit(),
        'postag': pos,
        'postag[:2]': pos[:2],
    }
    if idx > 0:
        word1, pos1 = sentence[idx-1]
        features.update({
            '-1:word.lower()': word1.lower(),
            '-1:word.istitle()': word1.istitle(),
            '-1:word.isupper()': word1.isupper(),
            '-1:postag': pos1,
            '-1:postag[:2]': pos1[:2],
        })
    else:
        features['BOS'] = True
        
    if idx < len(sentence)-1:
        word1, pos1 = sentence[idx+1]
        features.update({
            '+1:word.lower()': word1.lower(),
            '+1:word.istitle()': word1.istitle(),
            '+1:word.isupper()': word1.isupper(),
            '+1:postag': pos1,
            '+1:postag[:2]': pos1[:2],
        })
    else:
        features['EOS'] = True
                
    return features


def sentence_features(sentence):
    return tuple(word_features(sentence, index) for index in range(len(sentence)))

In [6]:
X = [sentence_features(sentence) for sentence in sentences]

In [7]:
print("Token features example:")
pprint(X[1][1])
print("="*30)
print(tags[1][1])

Token features example:
{'+1:postag': 'NN',
 '+1:postag[:2]': 'NN',
 '+1:word.istitle()': False,
 '+1:word.isupper()': False,
 '+1:word.lower()': 'end',
 '-1:postag': 'IN',
 '-1:postag[:2]': 'IN',
 '-1:word.istitle()': True,
 '-1:word.isupper()': False,
 '-1:word.lower()': 'in',
 'bias': 1.0,
 'postag': 'DT',
 'postag[:2]': 'DT',
 'word.isdigit()': False,
 'word.istitle()': False,
 'word.isupper()': False,
 'word.lower()': 'the',
 'word[-2:]': 'he',
 'word[-3:]': 'the'}
O


## Training

In [8]:
X_train, X_test, y_train, y_test = train_test_split(X, tags, test_size=0.2, 
                                                    random_state=3791)
X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train, 
                                                      test_size=0.2, 
                                                      random_state=3791)
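Splitting twice with `test_size=0.2` leaves roughly 64% of the sentences for training, 16% for validation and 20% for testing. The sketch below works out the approximate counts for the 142153 sentences loaded above. It assumes the held-out count is rounded up at each split; the helper `split_sizes` is hypothetical and only does the arithmetic.

```python
import math


def split_sizes(n_samples, test_size):
    """Return (n_kept, n_held_out), rounding the held-out part up (assumption)."""
    n_held_out = math.ceil(n_samples * test_size)
    return n_samples - n_held_out, n_held_out


n_total = 142153
n_rest, n_test = split_sizes(n_total, 0.2)     # first split: test set
n_train, n_valid = split_sizes(n_rest, 0.2)    # second split: validation set

for name, count in [("train", n_train), ("valid", n_valid), ("test", n_test)]:
    print(f"{name}: {count} sentences ({count / n_total:.1%})")
```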

In [9]:
%%time

crf = CRF(
    algorithm = 'lbfgs',
    c1 = 0.1,
    c2 = 0.5,
    max_iterations = 800,
    all_possible_transitions = True,
    verbose = False
)

crf.fit(X_train, y_train, X_dev=X_valid, y_dev=y_valid)

Wall time: 22min 2s




CRF(algorithm='lbfgs', all_possible_transitions=True, c1=0.1, c2=0.5,
    keep_tempfiles=None, max_iterations=800)

---

## Evaluation

We evaluate:
* **Memory consumption** using the attribute `crf.size_`;
* **Latency in prediction** using the function `time.process_time()`;
* **F1-score** _on entities_ on the training and test sets using `seqeval`.
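The helper `modelutils.compute_prediction_latency` is not shown in this notebook. A minimal sketch of a latency measurement based on `time.process_time()` could look like the following; the function name `mean_prediction_latency` and the choice to average over sequences are assumptions, not the notebook's actual implementation.

```python
import time


def mean_prediction_latency(sequences, predict_one):
    """Average CPU time, in seconds, that `predict_one` spends per sequence."""
    start = time.process_time()
    for sequence in sequences:
        predict_one(sequence)
    elapsed = time.process_time() - start
    return elapsed / len(sequences)


# Stand-in predictor so the sketch runs without a trained model
dummy_sequences = [["In", "the", "end"], ["Daniel", "Guerin"]]
latency = mean_prediction_latency(dummy_sequences, lambda seq: ["O"] * len(seq))
print(f"{latency:.3} s per sequence")
```

With a trained model, the same function could be called as `mean_prediction_latency(X_test, crf.predict_single)`. Note that `time.process_time()` counts CPU time only, so time spent sleeping or waiting on I/O is excluded.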

In [10]:
print('Model size: {:0.2f}M'.format(crf.size_ / 1000000))

Model size: 9.71M


In [11]:
print(f'Model latency in prediction: {modelutils.compute_prediction_latency(X_test, crf):.3} s')

Model latency in prediction: 0.000278 s


In [12]:
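# Report entity-level metrics on both the training and the test split
datasets = [('Training Set', X_train, y_train), ('Test Set', X_test, y_test)]

for title, X_split, y_true in datasets:  # distinct names keep the global X intact
    y_pred = crf.predict(X_split)
    print(title)
    print(classification_report(y_true, y_pred, digits=3))
    print('\n')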

Training Set
           precision    recall  f1-score   support

      LOC      0.861     0.898     0.879     54998
     MISC      0.850     0.783     0.815     47222
      PER      0.932     0.950     0.941     61603
      ORG      0.879     0.784     0.829     31794

micro avg      0.885     0.868     0.876    195617
macro avg      0.884     0.868     0.875    195617



Test Set
           precision    recall  f1-score   support

      ORG      0.809     0.704     0.753      9885
      LOC      0.813     0.846     0.829     17267
     MISC      0.767     0.696     0.730     14562
      PER      0.889     0.921     0.905     19345

micro avg      0.828     0.811     0.819     61059
macro avg      0.825     0.811     0.817     61059
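The scores above are computed on whole entities rather than on single tokens: a predicted entity counts as correct only if its type and both boundaries match the gold annotation. The pure-Python sketch below illustrates that idea with a micro-averaged F1. The functions `extract_entities` and `entity_f1` are hypothetical and are not `seqeval`'s implementation; the sketch assumes that an `I-` tag with no matching open entity starts a new one, as in the tag sequences shown earlier.

```python
def extract_entities(tags):
    """Return the set of (type, start, end) spans found in a tag sequence."""
    entities, start, current = set(), None, None
    for i, tag in enumerate(list(tags) + ["O"]):  # sentinel closes the last span
        prefix, _, etype = tag.partition("-")
        # A span begins on B-, or on I- when no span of that type is open
        begins = prefix == "B" or (prefix == "I" and etype != current)
        if current is not None and (prefix == "O" or begins):
            entities.add((current, start, i))
            current = None
        if begins:
            start, current = i, etype
    return entities


def entity_f1(y_true, y_pred):
    """Micro-averaged F1 over exactly matching entity spans."""
    tp = n_true = n_pred = 0
    for true_tags, pred_tags in zip(y_true, y_pred):
        gold, guess = extract_entities(true_tags), extract_entities(pred_tags)
        tp += len(gold & guess)
        n_true += len(gold)
        n_pred += len(guess)
    precision = tp / n_pred if n_pred else 0.0
    recall = tp / n_true if n_true else 0.0
    total = precision + recall
    return 2 * precision * recall / total if total else 0.0


gold_tags = [["O", "I-PER", "I-PER", "O", "I-LOC"]]
pred_tags = [["O", "I-PER", "O", "O", "I-LOC"]]
score = entity_f1(gold_tags, pred_tags)
print(score)
```

In this toy example four of the five tokens are tagged correctly, yet only one of the two entities is recovered exactly, so the entity-level F1 is 0.5. This is why entity-level scores are usually lower than token-level accuracy.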





---