## sklearn-crfsuite

This notebook trains a basic CRF model for Named Entity Recognition on the CoNLL 2002 data, following the sklearn-crfsuite tutorial (https://github.com/TeamHG-Memex/sklearn-crfsuite/blob/master/docs/CoNLL2002.ipynb), and then inspects the model's weights to see what it has learned.

In [1]:
import sys
sys.path.insert(0, '..')

import nltk
import sklearn_crfsuite

from eli5 import explain_weights, format_as_text
from eli5.formatters import fields

Load the training and test data (the Spanish part of the CoNLL 2002 corpus):

In [2]:
%%time
train_sents = list(nltk.corpus.conll2002.iob_sents('esp.train'))
test_sents = list(nltk.corpus.conll2002.iob_sents('esp.testb'))

CPU times: user 2.71 s, sys: 97.2 ms, total: 2.8 s
Wall time: 2.81 s


Extract features for each token: the lowercased word, word suffixes, POS tags, title/upper/digit flags, and similar features of the neighbouring words:

In [3]:
def word2features(sent, i):
    """Return the feature dict for token i of sent, a list of (token, postag, label) tuples."""
    word = sent[i][0]
    postag = sent[i][1]

    # Features of the current token
    features = {
        'bias': 1.0,
        'word.lower()': word.lower(),
        'word[-3:]': word[-3:],
        'word[-2:]': word[-2:],
        'word.isupper()': word.isupper(),
        'word.istitle()': word.istitle(),
        'word.isdigit()': word.isdigit(),
        'postag': postag,
        'postag[:2]': postag[:2],
    }
    if i > 0:
        # Features of the previous token
        word1 = sent[i-1][0]
        postag1 = sent[i-1][1]
        features.update({
            '-1:word.lower()': word1.lower(),
            '-1:word.istitle()': word1.istitle(),
            '-1:word.isupper()': word1.isupper(),
            '-1:postag': postag1,
            '-1:postag[:2]': postag1[:2],
        })
    else:
        # First token: mark the beginning of the sentence
        features['BOS'] = True

    if i < len(sent)-1:
        # Features of the next token
        word1 = sent[i+1][0]
        postag1 = sent[i+1][1]
        features.update({
            '+1:word.lower()': word1.lower(),
            '+1:word.istitle()': word1.istitle(),
            '+1:word.isupper()': word1.isupper(),
            '+1:postag': postag1,
            '+1:postag[:2]': postag1[:2],
        })
    else:
        # Last token: mark the end of the sentence
        features['EOS'] = True

    return features


def sent2features(sent):
    return [word2features(sent, i) for i in range(len(sent))]

def sent2labels(sent):
    return [label for token, postag, label in sent]

def sent2tokens(sent):
    return [token for token, postag, label in sent]
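
The helpers above assume that each sentence is a list of `(token, postag, label)` tuples. As a small illustration, here is how `sent2tokens` and `sent2labels` behave on a hand-made sentence. The sentence `toy_sent` is hypothetical and is not taken from the corpus; the two helpers are repeated only so that the snippet runs on its own.

```python
# Hypothetical toy sentence in the (token, POS tag, IOB label) layout
# that the helper functions expect; it is not taken from CoNLL 2002.
toy_sent = [
    ('Juan', 'NP', 'B-PER'),
    ('vive', 'VMI', 'O'),
    ('en', 'SP', 'O'),
    ('Madrid', 'NP', 'B-LOC'),
    ('.', 'Fp', 'O'),
]


def sent2labels(sent):
    # Keep only the IOB label of every token
    return [label for token, postag, label in sent]


def sent2tokens(sent):
    # Keep only the surface form of every token
    return [token for token, postag, label in sent]


print(sent2tokens(toy_sent))
print(sent2labels(toy_sent))
```

The token list and the label list always have the same length, which is what the CRF needs: one label per feature dict.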

In [4]:
sent2features(train_sents[0])[0]

{'+1:postag': 'Fpa',
 '+1:postag[:2]': 'Fp',
 '+1:word.istitle()': False,
 '+1:word.isupper()': False,
 '+1:word.lower()': '(',
 'BOS': True,
 'bias': 1.0,
 'postag': 'NP',
 'postag[:2]': 'NP',
 'word.isdigit()': False,
 'word.istitle()': True,
 'word.isupper()': False,
 'word.lower()': 'melbourne',
 'word[-2:]': 'ne',
 'word[-3:]': 'rne'}
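
The weight tables further down list names such as `word.lower():madrid`, while the feature dict only has the key `word.lower()`. The sketch below shows one way to picture the mapping: string values are folded into the feature name with weight 1.0, and booleans and numbers keep their name and become a float weight. The function `to_crfsuite_attributes` is a hypothetical helper written for this illustration, not part of sklearn-crfsuite, and it only approximates what the library does internally.

```python
def to_crfsuite_attributes(features):
    """Approximate flattening of a feature dict into {attribute_name: weight}.

    Hypothetical helper for illustration only; the real conversion happens
    inside the CRFsuite bindings.
    """
    attributes = {}
    for name, value in features.items():
        if isinstance(value, bool):
            # Check bool before numbers: bool is a subclass of int in Python
            attributes[name] = 1.0 if value else 0.0
        elif isinstance(value, (int, float)):
            attributes[name] = float(value)
        elif isinstance(value, str):
            # String values become part of the attribute name
            attributes[f"{name}:{value}"] = 1.0
        else:
            raise TypeError(f"unsupported value for feature {name!r}: {value!r}")
    return attributes


# A few entries copied from the feature dict shown above
example = {
    'bias': 1.0,
    'word.lower()': 'melbourne',
    'word.istitle()': True,
    'word.isupper()': False,
    'BOS': True,
}
attrs = to_crfsuite_attributes(example)
for name, weight in attrs.items():
    print(name, weight)
```

Under this reading, a weight reported for `word.lower():madrid` applies only to tokens whose lowercased form is exactly `madrid`, while a weight for `word.istitle()` applies to every title-cased token.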

In [5]:
%%time
X_train = [sent2features(s) for s in train_sents]
y_train = [sent2labels(s) for s in train_sents]

X_test = [sent2features(s) for s in test_sents]
y_test = [sent2labels(s) for s in test_sents]

CPU times: user 1.63 s, sys: 158 ms, total: 1.78 s
Wall time: 1.81 s


Train a CRF model:

In [6]:
%%time
crf = sklearn_crfsuite.CRF(
    algorithm='lbfgs',
    c1=0.1, 
    c2=0.1, 
    max_iterations=20, 
    all_possible_transitions=False,
)
crf.fit(X_train, y_train)

CPU times: user 14.3 s, sys: 219 ms, total: 14.5 s
Wall time: 14.6 s
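
In the call above, `c1` is the coefficient for L1 regularization and `c2` is the coefficient for L2 regularization. As a rough illustration of what these two numbers control, the sketch below computes an elastic-net style penalty for a weight vector. It assumes the penalty has the common form `c1 * sum(|w|) + c2 * sum(w**2)`; the exact scaling used inside CRFsuite may differ, and `elastic_net_penalty` is a hypothetical name used only here.

```python
def elastic_net_penalty(weights, c1, c2):
    """Penalty added to the training loss, assuming the form
    c1 * sum(|w|) + c2 * sum(w**2). Illustrative only."""
    l1_term = sum(abs(w) for w in weights)
    l2_term = sum(w * w for w in weights)
    return c1 * l1_term + c2 * l2_term


# Hypothetical weight vector, not taken from the trained model
weights = [1.0, -2.0, 0.0]
penalty = elastic_net_penalty(weights, c1=0.1, c2=0.1)
print(round(penalty, 6))
```

Larger `c1` values tend to push more weights to exactly zero, while larger `c2` values shrink all weights towards zero without necessarily removing them.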


Check CRF weights (transition weights and state weights):

In [7]:
expl = explain_weights(crf, top=20)
expl

| From \ To | B-LOC | O | B-ORG | B-PER | I-PER | B-MISC | I-ORG | I-LOC | I-MISC |
|---|---|---|---|---|---|---|---|---|---|
| B-LOC | -0.06 | 0.03 | 0.0 | -0.221 | 0.0 | 0.0 | 0.0 | 4.345 | 0.0 |
| O | 2.488 | 3.704 | 2.857 | 2.262 | 0.0 | 2.24 | 0.0 | 0.0 | 0.0 |
| B-ORG | -0.217 | 0.265 | -0.955 | -0.381 | 0.0 | -0.535 | 5.029 | 0.0 | 0.0 |
| B-PER | -0.832 | -0.336 | -1.224 | -1.169 | 4.528 | 0.0 | 0.0 | 0.0 | 0.0 |
| I-PER | -0.277 | -0.81 | 0.0 | -0.728 | 3.611 | 0.0 | 0.0 | 0.0 | 0.0 |
| B-MISC | -0.326 | -0.651 | -0.408 | -0.354 | 0.0 | 0.0 | 0.0 | 0.0 | 5.87 |
| I-ORG | -1.926 | -0.42 | -1.703 | -0.622 | 0.0 | -0.794 | 5.551 | 0.0 | 0.0 |
| I-LOC | -0.608 | -0.297 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3.744 | 0.0 |
| I-MISC | -0.986 | -0.724 | -0.929 | -0.612 | 0.0 | -0.421 | 0.0 | 0.0 | 5.874 |

y=B-LOC top features

| Weight | Feature |
|---|---|
| +2.308 | word.istitle() |
| +2.215 | -1:word.lower():en |
| +0.968 | word[-2:]:id |
| +0.921 | word[-3:]:rid |
| +0.920 | word.lower():madrid |
| +0.786 | word[-2:]:ia |
| +0.686 | word[-2:]:ña |
| +0.675 | word[-2:]:ís |
| +0.665 | word[-3:]:ona |
| +0.646 | word.lower():españa |

y=O top features

| Weight | Feature |
|---|---|
| +3.644 | postag[:2]:Fp |
| +3.577 | BOS |
| +3.140 | bias |
| +1.998 | postag:CC |
| +1.998 | postag[:2]:CC |
| +1.838 | postag[:2]:Fc |
| +1.838 | word[-3:]:, |
| +1.838 | postag:Fc |
| +1.838 | word.lower():, |
| +1.838 | word[-2:]:, |

y=B-ORG top features

| Weight | Feature |
|---|---|
| +2.413 | word.lower():efe |
| +2.156 | word.isupper() |
| +1.133 | word[-2:]:FE |
| +1.117 | word[-3:]:EFE |
| +1.103 | word.lower():gobierno |
| +0.960 | -1:word.lower():del |
| +0.882 | word.istitle() |
| +0.850 | word[-3:]:rno |
| +0.778 | -1:word.lower():al |
| +0.741 | word[-2:]:PP |

y=B-PER top features

| Weight | Feature |
|---|---|
| +2.805 | word.istitle() |
| +0.785 | postag[:2]:NP |
| +0.785 | postag:NP |
| +0.602 | +1:postag:VMI |
| +0.591 | -1:word.lower():a |
| +0.590 | +1:postag[:2]:VM |
| … 4353 more positive … | |
| … 429 more negative … | |
| -0.564 | word[-2:]:ia |
| -0.578 | word.lower():la |

y=I-PER top features

| Weight | Feature |
|---|---|
| +1.837 | -1:word.istitle() |
| +1.267 | word[-2:]:ez |
| +1.015 | word.istitle() |
| +0.692 | -1:word.lower():josé |
| +0.536 | -1:postag[:2]:AQ |
| +0.536 | -1:postag:AQ |
| +0.505 | -1:postag[:2]:VM |
| +0.498 | -1:word.lower():juan |
| +0.425 | -1:word.lower():maría |
| +0.353 | -1:postag:VMI |

y=B-MISC top features

| Weight | Feature |
|---|---|
| +1.894 | word.isupper() |
| +0.804 | word.istitle() |
| +0.664 | -1:word.lower():la |
| +0.519 | -1:word.lower():" |
| +0.519 | -1:postag[:2]:Fe |
| +0.519 | -1:postag:Fe |
| +0.460 | word.lower():" |
| +0.460 | postag[:2]:Fe |
| +0.460 | postag:Fe |
| +0.460 | word[-2:]:" |

y=I-ORG top features

| Weight | Feature |
|---|---|
| +1.515 | -1:word.istitle() |
| +0.938 | -1:word.lower():de |
| +0.723 | -1:postag[:2]:SP |
| +0.723 | -1:postag:SP |
| +0.486 | word[-2:]:id |
| +0.481 | word[-3:]:rid |
| +0.471 | word.lower():madrid |
| +0.435 | -1:word.lower():real |
| +0.390 | +1:word.lower():( |
| +0.390 | +1:postag:Fpa |

y=I-LOC top features

| Weight | Feature |
|---|---|
| +1.302 | -1:word.istitle() |
| +0.801 | -1:word.lower():de |
| +0.641 | word[-2:]:de |
| +0.615 | word[-3:]:de |
| +0.538 | -1:word.lower():san |
| +0.403 | -1:word.lower():la |
| +0.364 | -1:postag[:2]:SP |
| +0.364 | -1:postag:SP |
| +0.316 | word[-2:]:la |
| +0.296 | word.istitle() |

y=I-MISC top features

| Weight | Feature |
|---|---|
| +0.729 | -1:word.istitle() |
| +0.562 | +1:postag[:2]:Fe |
| +0.562 | +1:word.lower():" |
| +0.562 | +1:postag:Fe |
| +0.424 | word[-2:]:es |
| +0.343 | -1:word.lower():liga |
| +0.341 | -1:word.lower():de |
| +0.313 | word[-2:]:el |
| +0.274 | -1:word.lower():copa |
| +0.252 | +1:postag[:2]:Z |

The explanation can also be formatted as plain text, which is useful when working in a console:

In [8]:
print(format_as_text(expl))

Explained as: CRF

Transition features:
          B-LOC       O    B-ORG    B-PER    I-PER    B-MISC    I-ORG    I-LOC    I-MISC
------  -------  ------  -------  -------  -------  --------  -------  -------  --------
B-LOC    -0.060   0.030    0.000   -0.221    0.000     0.000    0.000    4.345     0.000
O         2.488   3.704    2.857    2.262    0.000     2.240    0.000    0.000     0.000
B-ORG    -0.217   0.265   -0.955   -0.381    0.000    -0.535    5.029    0.000     0.000
B-PER    -0.832  -0.336   -1.224   -1.169    4.528     0.000    0.000    0.000     0.000
I-PER    -0.277  -0.810    0.000   -0.728    3.611     0.000    0.000    0.000     0.000
B-MISC   -0.326  -0.651   -0.408   -0.354    0.000     0.000    0.000    0.000     5.870
I-ORG    -1.926  -0.420   -1.703   -0.622    0.000    -0.794    5.551    0.000     0.000
I-LOC    -0.608  -0.297    0.000    0.000    0.000     0.000    0.000    3.744     0.000
I-MISC   -0.986  -0.724   -0.929   -0.612    0.000    -0.421    0.000 