In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
# Use matplotlib's built-in 'ggplot' style for any figures drawn in this notebook.
plt.style.use('ggplot')

In [11]:
from itertools import chain

import nltk
import sklearn
import scipy.stats
from sklearn.metrics import make_scorer
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import RandomizedSearchCV

import sklearn_crfsuite
from sklearn_crfsuite import scorers
from sklearn_crfsuite import metrics
import eli5

In [12]:
# List the CoNLL-2002 files shipped with NLTK: Spanish ('esp.*') and Dutch ('ned.*')
# train / testa / testb splits.
nltk.corpus.conll2002.fileids()

['esp.testa', 'esp.testb', 'esp.train', 'ned.testa', 'ned.testb', 'ned.train']

In [13]:
%%time
# Materialize the Spanish CoNLL-2002 splits as lists of sentences:
# 'esp.train' is used for fitting, 'esp.testb' is the held-out evaluation set.
train_sents, test_sents = (
    list(nltk.corpus.conll2002.iob_sents(split_id))
    for split_id in ('esp.train', 'esp.testb')
)

Wall time: 1.19 s


In [24]:
# Peek at one test sentence: each token is a (word, POS tag, IOB entity label) triple.
test_sents[0]

[('La', 'DA', 'B-LOC'),
 ('Coruña', 'NC', 'I-LOC'),
 (',', 'Fc', 'O'),
 ('23', 'Z', 'O'),
 ('may', 'NC', 'O'),
 ('(', 'Fpa', 'O'),
 ('EFECOM', 'NP', 'B-ORG'),
 (')', 'Fpt', 'O'),
 ('.', 'Fp', 'O')]

In [26]:
# A longer training sentence, including a multi-token PER entity (B-PER followed by I-PER tags).
train_sents[2]

[('El', 'DA', 'O'),
 ('Abogado', 'NC', 'B-PER'),
 ('General', 'AQ', 'I-PER'),
 ('del', 'SP', 'I-PER'),
 ('Estado', 'NC', 'I-PER'),
 (',', 'Fc', 'O'),
 ('Daryl', 'VMI', 'B-PER'),
 ('Williams', 'NC', 'I-PER'),
 (',', 'Fc', 'O'),
 ('subrayó', 'VMI', 'O'),
 ('hoy', 'RG', 'O'),
 ('la', 'DA', 'O'),
 ('necesidad', 'NC', 'O'),
 ('de', 'SP', 'O'),
 ('tomar', 'VMN', 'O'),
 ('medidas', 'NC', 'O'),
 ('para', 'SP', 'O'),
 ('proteger', 'VMN', 'O'),
 ('al', 'SP', 'O'),
 ('sistema', 'NC', 'O'),
 ('judicial', 'AQ', 'O'),
 ('australiano', 'AQ', 'O'),
 ('frente', 'RG', 'O'),
 ('a', 'SP', 'O'),
 ('una', 'DI', 'O'),
 ('página', 'NC', 'O'),
 ('de', 'SP', 'O'),
 ('internet', 'NC', 'O'),
 ('que', 'PR', 'O'),
 ('imposibilita', 'VMI', 'O'),
 ('el', 'DA', 'O'),
 ('cumplimiento', 'NC', 'O'),
 ('de', 'SP', 'O'),
 ('los', 'DA', 'O'),
 ('principios', 'NC', 'O'),
 ('básicos', 'AQ', 'O'),
 ('de', 'SP', 'O'),
 ('la', 'DA', 'O'),
 ('Ley', 'NC', 'B-MISC'),
 ('.', 'Fp', 'O')]

In [15]:
def _context_features(token, prefix):
    """Features for a neighbouring token, with every key prefixed by ``prefix``.

    ``token`` is a ``(word, postag, ...)`` tuple; only the first two fields are read.
    ``prefix`` is e.g. ``'-1:'`` for the previous token or ``'+1:'`` for the next one.
    """
    word, postag = token[0], token[1]
    return {
        prefix + 'word.lower()': word.lower(),
        prefix + 'word.istitle()': word.istitle(),
        prefix + 'word.isupper()': word.isupper(),
        prefix + 'postag': postag,
        prefix + 'postag[:2]': postag[:2],
    }


def word2features(sent, i):
    """Build the CRF feature dict for the token at position ``i`` of ``sent``.

    Parameters
    ----------
    sent : sequence of tuples
        Sentence as ``(word, postag, ...)`` tuples, e.g. the
        ``(word, postag, iob_label)`` triples from ``conll2002.iob_sents``.
        Only the first two fields of each tuple are read.
    i : int
        Index of the token to featurize, expected in ``range(len(sent))``.

    Returns
    -------
    dict
        Feature name -> str / bool / float value:
        * lexical and POS features of the token itself (lowercased form,
          3- and 2-character suffixes, case/digit flags, full and 2-char POS tag);
        * the same lexical/POS features (without suffix/digit flags) for the
          previous and next tokens, under ``'-1:'`` / ``'+1:'`` key prefixes;
        * ``'BOS'`` / ``'EOS'`` set to True when there is no previous / next token.
    """
    word, postag = sent[i][0], sent[i][1]

    features = {
        'bias': 1.0,  # constant feature; its weight acts as a per-label bias
        'word.lower()': word.lower(),
        'word[-3:]': word[-3:],
        'word[-2:]': word[-2:],
        'word.isupper()': word.isupper(),
        'word.istitle()': word.istitle(),
        'word.isdigit()': word.isdigit(),
        'postag': postag,
        'postag[:2]': postag[:2],  # coarse POS class, e.g. 'Fp' for both 'Fpa' and 'Fpt'
    }

    # Left context: previous token's features, or a beginning-of-sentence flag.
    if i > 0:
        features.update(_context_features(sent[i - 1], '-1:'))
    else:
        features['BOS'] = True

    # Right context: next token's features, or an end-of-sentence flag.
    if i < len(sent) - 1:
        features.update(_context_features(sent[i + 1], '+1:'))
    else:
        features['EOS'] = True

    return features


def sent2features(sent):
    """Featurize a whole sentence: one ``word2features`` dict per token, in order."""
    token_features = []
    for position in range(len(sent)):
        token_features.append(word2features(sent, position))
    return token_features

def sent2labels(sent):
    """Return the IOB label (third field) of each ``(token, postag, label)`` triple."""
    iob_tags = []
    for _word, _pos, iob_tag in sent:
        iob_tags.append(iob_tag)
    return iob_tags

def sent2tokens(sent):
    """Return the surface token (first field) of each ``(token, postag, label)`` triple."""
    words = []
    for surface, _pos, _iob in sent:
        words.append(surface)
    return words

In [16]:
# Sanity-check the feature dict produced for the first token of the first training sentence.
sent2features(train_sents[0])[0]

{'bias': 1.0,
 'word.lower()': 'melbourne',
 'word[-3:]': 'rne',
 'word[-2:]': 'ne',
 'word.isupper()': False,
 'word.istitle()': True,
 'word.isdigit()': False,
 'postag': 'NP',
 'postag[:2]': 'NP',
 'BOS': True,
 '+1:word.lower()': '(',
 '+1:word.istitle()': False,
 '+1:word.isupper()': False,
 '+1:postag': 'Fpa',
 '+1:postag[:2]': 'Fp'}

In [17]:
%%time
# Inputs: one list of per-token feature dicts per sentence.
# Targets: the matching list of IOB labels per sentence.
X_train = list(map(sent2features, train_sents))
y_train = list(map(sent2labels, train_sents))

X_test = list(map(sent2features, test_sents))
y_test = list(map(sent2labels, test_sents))

Wall time: 823 ms


In [18]:
%%time
# CRF hyperparameters (see the sklearn-crfsuite CRF docs for details).
crf_config = {
    'algorithm': 'lbfgs',              # L-BFGS training
    'c1': 0.1,                         # L1 regularization coefficient
    'c2': 0.1,                         # L2 regularization coefficient
    'max_iterations': 100,
    'all_possible_transitions': True,  # also create transition features unseen in training
}
crf = sklearn_crfsuite.CRF(**crf_config)
crf.fit(X_train, y_train)

Wall time: 26.2 s




CRF(algorithm='lbfgs', all_possible_transitions=True, c1=0.1, c2=0.1,
    keep_tempfiles=None, max_iterations=100)

In [19]:
# Inspect the trained model: a From->To label-transition weight table, followed by the
# highest-weighted state features for each label (top=30 limits the features listed).
eli5.show_weights(crf, top=30)

From \ To,O,B-LOC,I-LOC,B-MISC,I-MISC,B-ORG,I-ORG,B-PER,I-PER
O,3.784,1.846,-5.968,1.805,-5.137,2.755,-5.23,2.549,-5.357
B-LOC,-0.275,0.578,5.696,-1.946,-1.079,-1.694,-0.923,-0.046,-1.668
I-LOC,-0.472,-1.594,4.877,-1.977,-1.233,-1.614,-1.379,-1.853,-1.912
B-MISC,-0.144,-0.175,-1.582,-3.028,6.833,0.299,-1.276,-0.787,-1.335
I-MISC,-0.914,-2.24,-2.557,-1.037,6.753,-1.687,-1.826,-0.785,-2.009
B-ORG,0.325,0.267,-1.967,-2.581,-1.376,-2.108,7.501,-0.549,-1.627
I-ORG,-0.566,-2.2,-2.536,-2.486,-1.639,-0.537,7.206,-0.601,-2.272
B-PER,-0.576,-0.841,-1.434,-2.325,-1.068,-1.504,-0.714,-2.825,6.405
I-PER,-1.864,0.301,-2.455,-2.966,-1.703,-2.512,-1.501,-1.733,4.709

Weight?,Feature,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0
Weight?,Feature,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
Weight?,Feature,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2
Weight?,Feature,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3
Weight?,Feature,Unnamed: 2_level_4,Unnamed: 3_level_4,Unnamed: 4_level_4,Unnamed: 5_level_4,Unnamed: 6_level_4,Unnamed: 7_level_4,Unnamed: 8_level_4
Weight?,Feature,Unnamed: 2_level_5,Unnamed: 3_level_5,Unnamed: 4_level_5,Unnamed: 5_level_5,Unnamed: 6_level_5,Unnamed: 7_level_5,Unnamed: 8_level_5
Weight?,Feature,Unnamed: 2_level_6,Unnamed: 3_level_6,Unnamed: 4_level_6,Unnamed: 5_level_6,Unnamed: 6_level_6,Unnamed: 7_level_6,Unnamed: 8_level_6
Weight?,Feature,Unnamed: 2_level_7,Unnamed: 3_level_7,Unnamed: 4_level_7,Unnamed: 5_level_7,Unnamed: 6_level_7,Unnamed: 7_level_7,Unnamed: 8_level_7
Weight?,Feature,Unnamed: 2_level_8,Unnamed: 3_level_8,Unnamed: 4_level_8,Unnamed: 5_level_8,Unnamed: 6_level_8,Unnamed: 7_level_8,Unnamed: 8_level_8
+4.868,BOS,,,,,,,
+4.582,word[-3:]:R.,,,,,,,
+4.582,word.lower():r.,,,,,,,
+4.156,word[-3:]:B,,,,,,,
+4.156,word.lower():b,,,,,,,
+4.156,word[-2:]:B,,,,,,,
+3.938,bias,,,,,,,
+3.663,word.lower():mayo,,,,,,,
+3.639,word.lower():c,,,,,,,
+3.613,postag[:2]:Fp,,,,,,,

Weight?,Feature
+4.868,BOS
+4.582,word[-3:]:R.
+4.582,word.lower():r.
+4.156,word[-3:]:B
+4.156,word.lower():b
+4.156,word[-2:]:B
+3.938,bias
+3.663,word.lower():mayo
+3.639,word.lower():c
+3.613,postag[:2]:Fp

Weight?,Feature
+4.897,-1:word.lower():cantabria
+4.811,word.lower():líbano
+4.063,-1:word.lower():celebrarán
+3.932,+1:word.lower():finalizaron
+3.927,word.lower():estrecho
+3.917,word.lower():asturias
+3.640,-1:word.lower():nuboso
+3.559,word.lower():retiro
+3.547,word.lower():asunción
+3.505,word.lower():bruselas

Weight?,Feature
+4.229,-1:word.lower():calle
+3.617,-1:word.lower():estadio
+3.268,-1:word.lower():carcedo
+3.116,-1:word.lower():plaza
+3.041,-1:word.lower():santa
+2.932,-1:word.lower():sierra
+2.839,-1:word.lower():salesianos
+2.794,+1:word.lower():deseo
+2.767,-1:word.lower():avenida
+2.564,-1:word.lower():ciudad

Weight?,Feature
+4.724,word.lower():justicia
+4.598,word.lower():competencia
+4.545,word.lower():diversia
+3.945,word.lower():exteriores
+3.792,word.lower():agricultura
+3.672,word.lower():cc2305001730
+3.669,word.lower():internet
+3.653,word.lower():derecho
+3.536,word.lower():feria
+3.510,word.lower():cultura

Weight?,Feature
+3.804,-1:word.lower():1.9
+3.202,+1:word.lower():surrealismo
+2.928,-1:word.lower():xviii
+2.634,-1:word.lower():4444
+2.422,-1:word.lower():around
+2.388,word[-2:]:a.
+2.356,-1:word.lower():carlos.
+2.329,-1:word.lower():ibex
+2.247,+1:word.lower():ojos
+2.146,+1:word.lower():adoptará

Weight?,Feature
+9.811,word.lower():efe-cantabria
+8.587,word.lower():psoe-progresistas
+4.903,word.lower():xfera
+4.760,word.lower():telefónica
+4.675,word[-2:]:-e
+4.409,word.lower():petrobras
+4.278,word.lower():coag-extremadura
+4.223,word.isupper()
+4.190,word.lower():esquerra
+4.151,word.lower():terra

Weight?,Feature
+6.026,-1:word.lower():l
+4.043,-1:word.lower():rasd
+3.749,-1:word.lower():ag
+3.174,-1:word.lower():sports
+2.995,-1:word.lower():antena
+2.760,+1:word.lower():adelantó
+2.691,-1:word.lower():cardinal
+2.613,-1:word.lower():soir
+2.495,-1:word.lower():guerson
+2.458,-1:word.lower():caja

Weight?,Feature
+4.262,-1:word.lower():según
+4.189,word.lower():valedor
+3.922,word.lower():ania
+3.708,word.lower():orduña
+3.583,-1:word.lower():efe
+3.535,word.lower():salva
+3.397,word.lower():mcmanaman
+3.355,word.lower():martina
+3.286,word.lower():reinas
+3.193,word.lower():mas

Weight?,Feature
+3.294,-1:word.lower():juanito
+3.100,-1:word.lower():txon
+3.085,-1:word.lower():juli
+2.854,-1:word.istitle()
+2.810,-1:word.lower():bitxon
+2.757,-1:word.lower():peridis
+2.745,word.lower():gándara
+2.741,+1:word.lower():gándara
+2.612,-1:word.lower():maría
+2.534,-1:word.lower():marqués


In [20]:
# Labels to score: every class the CRF learned except the outside tag 'O'.
# As the sample sentences above show, 'O' covers most tokens, so leaving it in
# would dominate the averaged metrics below.
# Filtering (rather than list.remove('O')) keeps the learned class order, is safe to
# re-run, and does not raise ValueError if the model has no 'O' class.
labels = [label for label in crf.classes_ if label != 'O']
labels

['B-LOC', 'B-ORG', 'B-PER', 'I-PER', 'B-MISC', 'I-ORG', 'I-LOC', 'I-MISC']

In [21]:
# Tag the held-out sentences, then compute token-level F1 over the entity labels only
# ('O' was excluded from `labels`), averaging per-label F1 weighted by support.
y_pred = crf.predict(X_test)
metrics.flat_f1_score(
    y_test,
    y_pred,
    labels=labels,
    average='weighted',
)

0.7964686316443963

In [22]:
# Group B- and I- results per entity type: sort by the entity suffix ('-LOC', '-MISC', ...)
# first and by the B/I prefix second, so B-X is immediately followed by I-X.
sorted_labels = sorted(labels, key=lambda tag: (tag[1:], tag[0]))
report = metrics.flat_classification_report(
    y_test,
    y_pred,
    labels=sorted_labels,
    digits=3,
)
print(report)



              precision    recall  f1-score   support

       B-LOC      0.810     0.784     0.797      1084
       I-LOC      0.690     0.637     0.662       325
      B-MISC      0.731     0.569     0.640       339
      I-MISC      0.699     0.589     0.639       557
       B-ORG      0.807     0.832     0.820      1400
       I-ORG      0.852     0.786     0.818      1104
       B-PER      0.850     0.884     0.867       735
       I-PER      0.893     0.943     0.917       634

   micro avg      0.813     0.787     0.799      6178
   macro avg      0.791     0.753     0.770      6178
weighted avg      0.809     0.787     0.796      6178

