In [2]:
from itertools import chain

import nltk
import sklearn
import scipy.stats
from sklearn.metrics import make_scorer
from sklearn.cross_validation import cross_val_score
from sklearn.grid_search import RandomizedSearchCV

import sklearn_crfsuite
from sklearn_crfsuite import scorers
from sklearn_crfsuite import metrics

# Named-Entity Recognition example for the Marvin framework

## Data Handler
### Acquisitor and Cleaning Action

In [3]:
import os

# The corpus is downloaded into the Marvin data directory. nltk's corpus
# readers only search the directories listed in `nltk.data.path`, so that
# directory is registered explicitly; otherwise loading the corpus can fail
# with a LookupError even though the download succeeded.
marvin_data_path = os.environ["MARVIN_DATA_PATH"]
if marvin_data_path not in nltk.data.path:
    nltk.data.path.append(marvin_data_path)

nltk.download(info_or_id='conll2002', download_dir=marvin_data_path)

# Each sentence is materialised as a list so it can be indexed and iterated
# more than once by the feature-extraction cells.
train_sents = list(nltk.corpus.conll2002.iob_sents('esp.train'))
test_sents = list(nltk.corpus.conll2002.iob_sents('esp.testb'))

[nltk_data] Downloading package conll2002 to
[nltk_data]     /home/erick/marvin/data...
[nltk_data]   Package conll2002 is already up-to-date!


### Training Preparator Action (features preparation)

In [71]:
def word2features(sent, i):
    """Build the feature dict for the token at position `i` of `sent`.

    `sent` is a sequence whose items expose the token text at index 0 and
    the POS tag at index 1. The result describes the token itself plus its
    left and right neighbours; a missing neighbour is flagged with 'BOS'
    (no previous token) or 'EOS' (no next token) instead.
    """
    current = sent[i]
    token, tag = current[0], current[1]

    # Features of the token itself.
    features = {
        'bias': 1.0,
        'word.lower()': token.lower(),
        'word[-3:]': token[-3:],
        'word[-2:]': token[-2:],
        'word.isupper()': token.isupper(),
        'word.istitle()': token.istitle(),
        'word.isdigit()': token.isdigit(),
        'postag': tag,
        'postag[:2]': tag[:2],
    }

    # Left context: previous token, or the beginning-of-sentence marker.
    if i <= 0:
        features['BOS'] = True
    else:
        previous = sent[i - 1]
        prev_token, prev_tag = previous[0], previous[1]
        features['-1:word.lower()'] = prev_token.lower()
        features['-1:word.istitle()'] = prev_token.istitle()
        features['-1:word.isupper()'] = prev_token.isupper()
        features['-1:postag'] = prev_tag
        features['-1:postag[:2]'] = prev_tag[:2]

    # Right context: next token, or the end-of-sentence marker.
    if i >= len(sent) - 1:
        features['EOS'] = True
    else:
        following = sent[i + 1]
        next_token, next_tag = following[0], following[1]
        features['+1:word.lower()'] = next_token.lower()
        features['+1:word.istitle()'] = next_token.istitle()
        features['+1:word.isupper()'] = next_token.isupper()
        features['+1:postag'] = next_tag
        features['+1:postag[:2]'] = next_tag[:2]

    return features


def sent2features(sent):
    """Return one `word2features` dict per position of `sent`, in order."""
    per_token = []
    for position in range(len(sent)):
        per_token.append(word2features(sent, position))
    return per_token

def sent2labels(sent):
    """Return the label of every (token, postag, label) triple in `sent`."""
    tags = []
    for _token, _postag, label in sent:
        tags.append(label)
    return tags

def sent2tokens(sent):
    """Return the token of every (token, postag, label) triple in `sent`."""
    words = []
    for token, _postag, _label in sent:
        words.append(token)
    return words


# Training split: per-sentence feature dicts and the matching label sequences.
X_train = list(map(sent2features, train_sents))
y_train = list(map(sent2labels, train_sents))

# Test split, prepared with exactly the same functions.
X_test = list(map(sent2features, test_sents))
y_test = list(map(sent2labels, test_sents))

In [75]:
# Inspect one test sentence: a list of (token, POS tag, IOB label) tuples.
test_sents[8]

[(u'En', u'SP', u'O'),
 (u'declaraciones', u'NC', u'O'),
 (u'a', u'SP', u'O'),
 (u'Efe', u'NC', u'B-ORG'),
 (u',', u'Fc', u'O'),
 (u'el', u'DA', u'O'),
 (u'alcalde', u'NC', u'O'),
 (u'de', u'SP', u'O'),
 (u'Ar\xe9valo', u'NC', u'B-LOC'),
 (u',', u'Fc', u'O'),
 (u'Francisco', u'AQ', u'B-PER'),
 (u'Le\xf3n', u'NC', u'I-PER'),
 (u'(', u'Fpa', u'O'),
 (u'PSOE', u'NP', u'B-ORG'),
 (u')', u'Fpt', u'O'),
 (u',', u'Fc', u'O'),
 (u'lament\xf3', u'VMI', u'O'),
 (u'la', u'DA', u'O'),
 (u'"', u'Fe', u'O'),
 (u'tardanza', u'NC', u'O'),
 (u'"', u'Fe', u'O'),
 (u'de', u'SP', u'O'),
 (u'la', u'DA', u'O'),
 (u'Consejer\xeda', u'NC', u'B-ORG'),
 (u'de', u'SP', u'I-ORG'),
 (u'Fomento', u'NC', u'I-ORG'),
 (u'en', u'SP', u'O'),
 (u'la', u'DA', u'O'),
 (u'entrega', u'NC', u'O'),
 (u'de', u'SP', u'O'),
 (u'las', u'DA', u'O'),
 (u'llaves', u'NC', u'O'),
 (u'de', u'SP', u'O'),
 (u'las', u'DA', u'O'),
 (u'viviendas', u'NC', u'O'),
 (u'a', u'SP', u'O'),
 (u'sus', u'DP', u'O'),
 (u'leg\xedtimos', u'AQ', u'O'),
 (

### Training Action

In [5]:
# Hyper-parameters for the CRF, gathered in one place so they are easy to
# review. NOTE(review): the c1/c2 values look like the result of an earlier
# hyper-parameter search — confirm their provenance.
crf_params = dict(
    algorithm='lbfgs',
    c1=0.10789964607864502,
    c2=0.082422264927260847,
    max_iterations=100,
    all_possible_transitions=True,
)

crf = sklearn_crfsuite.CRF(**crf_params)
crf.fit(X_train, y_train)

CRF(algorithm='lbfgs', all_possible_states=None,
  all_possible_transitions=True, averaging=None, c=None, c1=0.107899646079,
  c2=0.0824222649273, calibration_candidates=None, calibration_eta=None,
  calibration_max_trials=None, calibration_rate=None,
  calibration_samples=None, delta=None, epsilon=None, error_sensitive=None,
  gamma=None, keep_tempfiles=None, linesearch=None, max_iterations=100,
  max_linesearch=None, min_freq=None, model_filename=None,
  num_memories=None, pa_type=None, period=None, trainer_cls=None,
  variance=None, verbose=False)

### Evaluator Action

In [6]:
# Predict a label sequence for every test sentence.
y_pred = crf.predict(X_test)

# Score every label the model knows except 'O'.
labels = list(crf.classes_)
labels.remove('O')

metrics.flat_f1_score(y_test, y_pred, labels=labels, average='weighted')

0.79760762520912232

In [65]:
# Tag a single test sentence and group its words by predicted label.
#
# The prediction is stored under its own name instead of reusing `y_pred`:
# `y_pred` holds the predictions for the whole test set and is compared
# against `y_test` elsewhere in the notebook, so overwriting it here would
# break those cells when the notebook is run top to bottom.
index = 2
sentence_pred = crf.predict([X_test[index]])[0]

sentence = []   # lower-cased words of the sentence, in order
entities = {}   # predicted label -> list of words tagged with it
for token, label in zip(X_test[index], sentence_pred):
    word = token["word.lower()"]
    sentence.append(word)

    # 'O' is skipped so only entity labels end up in `entities`.
    if label != "O":
        entities.setdefault(label, []).append(word)

In [70]:
# Show the sentence and the words grouped under each predicted label.
# Every print call takes exactly one parenthesised argument, which produces
# the same output under Python 2 (print statement) and Python 3 (print
# function), matching the print(...) style used by the other cells.
print('Sentence: ' + ' '.join(sentence))
print('')
print('Entities found:')
for k, v in entities.items():
    print(k + " -> " + ' '.join(v))

Sentence: las reservas " on line " de billetes aéreos a través de internet aumentaron en españa un 300 por ciento en el primer trimestre de este año con respecto al mismo período de 1999 , aseguró hoy iñigo garcía aranda , responsable de comunicación de savia amadeus .

Entities found:
B-MISC -> internet
I-LOC -> amadeus
B-PER -> iñigo
B-LOC -> españa savia
I-PER -> garcía aranda


In [8]:
# Order labels by everything after the first character, then by that first
# character, so the B- and I- variants of one entity type sit next to each
# other in the report.
sorted_labels = list(labels)
sorted_labels.sort(key=lambda tag: (tag[1:], tag[0]))

# NOTE(review): expects `y_pred` to hold predictions for all of `X_test`.
report = metrics.flat_classification_report(
    y_test, y_pred, labels=sorted_labels, digits=3
)
print(report)

             precision    recall  f1-score   support

      B-LOC      0.806     0.784     0.795      1084
      I-LOC      0.697     0.631     0.662       325
     B-MISC      0.749     0.555     0.637       339
     I-MISC      0.743     0.582     0.653       557
      B-ORG      0.807     0.835     0.821      1400
      I-ORG      0.841     0.800     0.820      1104
      B-PER      0.845     0.887     0.865       735
      I-PER      0.894     0.940     0.916       634

avg / total      0.812     0.788     0.798      6178



### Prediction

In [9]:
from collections import Counter

def print_transitions(trans_features):
    """Print one aligned row per ((label_from, label_to), weight) item."""
    row_format = "%-6s -> %-7s %0.6f"
    for transition, weight in trans_features:
        source_label, target_label = transition
        print(row_format % (source_label, target_label, weight))

# Rank all learned transitions once, highest weight first, then show both ends.
ranked_transitions = Counter(crf.transition_features_).most_common()

print("Top likely transitions:")
print_transitions(ranked_transitions[:20])

print("\nTop unlikely transitions:")
print_transitions(ranked_transitions[-20:])

Top likely transitions:
B-ORG  -> I-ORG   7.509195
I-ORG  -> I-ORG   7.286281
B-MISC -> I-MISC  6.657995
I-MISC -> I-MISC  6.611249
B-PER  -> I-PER   6.455306
B-LOC  -> I-LOC   5.592603
I-LOC  -> I-LOC   4.883487
I-PER  -> I-PER   4.879621
O      -> O       3.864632
O      -> B-ORG   2.824896
O      -> B-PER   2.406219
O      -> B-LOC   1.918864
O      -> B-MISC  1.825806
B-LOC  -> B-LOC   0.410911
B-ORG  -> O       0.384361
I-PER  -> B-LOC   0.362426
B-ORG  -> B-LOC   0.236403
B-MISC -> B-ORG   0.136150
B-MISC -> O       -0.157301
B-LOC  -> B-PER   -0.219753

Top unlikely transitions:
B-LOC  -> B-MISC  -2.094940
B-ORG  -> B-ORG   -2.117205
B-ORG  -> I-LOC   -2.122777
I-ORG  -> I-PER   -2.205134
I-ORG  -> B-LOC   -2.272456
I-PER  -> I-LOC   -2.403253
I-MISC -> B-LOC   -2.433610
B-PER  -> B-MISC  -2.511643
I-ORG  -> B-MISC  -2.595139
I-ORG  -> I-LOC   -2.655252
B-ORG  -> B-MISC  -2.673500
I-PER  -> B-ORG   -2.686785
I-MISC -> I-LOC   -2.837655
I-PER  -> B-MISC  -2.930608
B-PER  -> B-PER

In [10]:
def print_state_features(state_features):
    """Print one 'weight label attribute' row per ((attr, label), weight) item."""
    row_format = "%0.6f %-8s %s"
    for feature, weight in state_features:
        attribute, tag = feature
        print(row_format % (weight, tag, attribute))

# Rank all learned state features once, highest weight first, then show both ends.
ranked_state_features = Counter(crf.state_features_).most_common()

print("Top positive:")
print_state_features(ranked_state_features[:30])

print("\nTop negative:")
print_state_features(ranked_state_features[-30:])

Top positive:
10.050768 B-ORG    word.lower():efe-cantabria
8.978537 B-ORG    word.lower():psoe-progresistas
6.302095 I-ORG    -1:word.lower():l
5.194187 B-ORG    word.lower():petrobras
5.180269 O        BOS
5.077409 B-ORG    word.lower():xfera
4.992312 B-LOC    -1:word.lower():cantabria
4.930629 B-MISC   word.lower():diversia
4.914647 B-ORG    word[-2:]:-e
4.868154 B-ORG    word.lower():coag-extremadura
4.806764 B-ORG    word.lower():telefónica
4.787065 O        word.lower():r.
4.787065 O        word[-3:]:R.
4.695655 B-MISC   word.lower():justicia
4.656129 B-MISC   word.lower():competencia
4.637028 I-LOC    -1:word.lower():calle
4.615657 B-ORG    +1:word.lower():plasencia
4.615615 B-ORG    -1:word.lower():distancia
4.560780 B-ORG    word.lower():terra
4.543435 B-LOC    word.lower():líbano
4.527419 I-ORG    -1:word.lower():rasd
4.516271 B-PER    -1:word.lower():según
4.513310 B-LOC    -1:word.lower():celebrarán
4.506090 B-ORG    word.isupper()
4.440726 B-ORG    word.lower():esquerra
4.