In [2]:
# !pip install sklearn_crfsuite eli5

Collecting sklearn_crfsuite
  Downloading https://files.pythonhosted.org/packages/25/74/5b7befa513482e6dee1f3dd68171a6c9dfc14c0eaa00f885ffeba54fe9b0/sklearn_crfsuite-0.3.6-py2.py3-none-any.whl
Collecting eli5
[?25l  Downloading https://files.pythonhosted.org/packages/97/2f/c85c7d8f8548e460829971785347e14e45fa5c6617da374711dec8cb38cc/eli5-0.10.1-py2.py3-none-any.whl (105kB)
[K     |████████████████████████████████| 112kB 672kB/s eta 0:00:01
Collecting python-crfsuite>=0.8.3 (from sklearn_crfsuite)
[?25l  Downloading https://files.pythonhosted.org/packages/da/05/5cd3eb8dbbe3c787e3cf84d5767d95198298f7951bf8e40c46ebd8c80a32/python_crfsuite-0.9.6-cp37-cp37m-manylinux1_x86_64.whl (749kB)
[K     |████████████████████████████████| 757kB 1.4MB/s eta 0:00:01     |███████████████████████▏        | 542kB 1.4MB/s eta 0:00:01
[?25hCollecting tabulate (from sklearn_crfsuite)
[?25l  Downloading https://files.pythonhosted.org/packages/c2/fd/202954b3f0eb896c53b7b6f07390851b1fd2ca84aa95880d7ae4f434

In [7]:
import eli5
import nltk
import scipy.stats
import sklearn
import sklearn_crfsuite

from itertools import chain
from sklearn.metrics import make_scorer
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import RandomizedSearchCV
from sklearn_crfsuite import scorers
from sklearn_crfsuite import metrics

### Загрузим данные:

In [19]:
train_sents = list(nltk.corpus.conll2002.iob_sents('esp.train'))
test_sents = list(nltk.corpus.conll2002.iob_sents('esp.testb'))
train_sents[0]

[('Melbourne', 'NP', 'B-LOC'),
 ('(', 'Fpa', 'O'),
 ('Australia', 'NP', 'B-LOC'),
 (')', 'Fpt', 'O'),
 (',', 'Fc', 'O'),
 ('25', 'Z', 'O'),
 ('may', 'NC', 'O'),
 ('(', 'Fpa', 'O'),
 ('EFE', 'NC', 'B-ORG'),
 (')', 'Fpt', 'O'),
 ('.', 'Fp', 'O')]

### Добавим фичи для каждого слова, чтобы обучить CRF (смотри лекцию:)):

In [109]:
def word2features(sent, i):
    word = sent[i][0]
    postag = sent[i][1]

    features = {
        'bias': 1.0,
        'postag': postag,
        'word_len': len(word),
        'is_title': word.istitle(),
        'is_upper': word.isupper(),
        'post_start': False,
        'pre_last': False,
        'is_num': word.isnumeric(),
        'word[-3:]': word[-3:],
        'word[-2:]': word[-2:],
    }
    if i > 1:
        word2 = sent[i-2][0]
        features.update({
            '-2:word.istitle()': word2.istitle(),
            '-2:word.isuper()': word2.isupper(),
        })
    elif i > 0:
        word1 = sent[i-1][0]
        postag1 = sent[i-1][1]
        features['post_start'] = True
        features.update({
            '-1:word.istitle()': word1.istitle(),
            '-1:word.isuper()': word1.isupper(),
            '-1:word_len': len(word1),
            '-1:postag': sent[i-1][1],
            '-1:word[-3:]': word1[-3:],
            '-1:word[-2:]': word1[-2:],
        })
    else:
        features['BOS'] = True
        
    if i < len(sent) - 2:
        word2 = sent[i+2][0]
        features.update({
            '+2:word.istitle()': word2.istitle(),
            '+2:word.isuper()': word2.isupper(),
        })
    elif i < len(sent) - 1:
        word1 = sent[i+1][0]
        postag1 = sent[i+1][1]
        features['pre_last'] = True
        features.update({
            '+1:word.istitle()': word1.istitle(),
            '+1:word.isuper()': word1.isupper(),
            '+1:word[-3:]': word1[-3:],
            '+1:word[-2:]': word1[-2:],
            '+1:word_len': len(word1),
            '+1:postag': sent[i+1][1],
        })
    else:
        features['EOS'] = True

    return features


def sent2features(sent):
    return [word2features(sent, i) for i in range(len(sent))]

def sent2labels(sent):
    return [label for token, postag, label in sent]

def sent2tokens(sent):
    return [token for token, postag, label in sent]

X_train = [sent2features(s) for s in train_sents]
y_train = [sent2labels(s) for s in train_sents]

X_test = [sent2features(s) for s in test_sents]
y_test = [sent2labels(s) for s in test_sents]

### Посмотрим на пример фичей для одного слова:

In [110]:
X_train[0][1]

{'bias': 1.0,
 'postag': 'Fpa',
 'word_len': 1,
 'is_title': False,
 'is_upper': False,
 'post_start': True,
 'pre_last': False,
 'is_num': False,
 'word[-3:]': '(',
 'word[-2:]': '(',
 '-1:word.istitle()': True,
 '-1:word.isuper()': False,
 '-1:word_len': 9,
 '-1:postag': 'NP',
 '-1:word[-3:]': 'rne',
 '-1:word[-2:]': 'ne',
 '+2:word.istitle()': False,
 '+2:word.isuper()': False}

### Обучим CRF:

In [111]:
%%time
crf = sklearn_crfsuite.CRF(
    algorithm='lbfgs',
    c1= 0.1,
    c2= 0.1,
    max_iterations=250,
    all_possible_transitions=True
)
crf.fit(X_train, y_train)

CPU times: user 5min 56s, sys: 698 ms, total: 5min 57s
Wall time: 5min 59s


CRF(algorithm='lbfgs', all_possible_states=None, all_possible_transitions=True,
    averaging=None, c=None, c1=0.1, c2=0.1, calibration_candidates=None,
    calibration_eta=None, calibration_max_trials=None, calibration_rate=None,
    calibration_samples=None, delta=None, epsilon=None, error_sensitive=None,
    gamma=None, keep_tempfiles=None, linesearch=None, max_iterations=250,
    max_linesearch=None, min_freq=None, model_filename=None, num_memories=None,
    pa_type=None, period=None, trainer_cls=None, variance=None, verbose=False)

### Посмотрим на веса признаков:

In [112]:
eli5.show_weights(crf, top=30)

From \ To,O,B-LOC,I-LOC,B-MISC,I-MISC,B-ORG,I-ORG,B-PER,I-PER
O,3.85,0.438,-5.95,1.079,-5.768,0.714,-6.057,0.565,-5.915
B-LOC,0.19,-2.198,3.885,-4.355,-3.185,-5.373,-3.824,-3.422,-3.233
I-LOC,-0.146,-4.8,4.291,-4.003,-3.001,-4.77,-3.775,-4.972,-3.39
B-MISC,-1.57,-4.691,-4.138,-5.402,3.259,-4.583,-4.786,-4.653,-3.906
I-MISC,-0.467,-5.18,-3.295,-2.4,4.878,-4.673,-4.119,-3.447,-3.575
B-ORG,-0.046,-3.178,-3.827,-4.952,-3.984,-5.977,3.622,-3.134,-3.883
I-ORG,-0.947,-6.288,-4.063,-4.606,-3.931,-5.012,3.105,-3.706,-4.525
B-PER,-0.846,-4.315,-3.528,-4.839,-3.45,-5.229,-4.171,-5.885,4.717
I-PER,-0.553,-2.8,-3.013,-4.431,-2.961,-5.178,-3.987,-3.802,3.887

Weight?,Feature,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0
Weight?,Feature,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
Weight?,Feature,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2
Weight?,Feature,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3
Weight?,Feature,Unnamed: 2_level_4,Unnamed: 3_level_4,Unnamed: 4_level_4,Unnamed: 5_level_4,Unnamed: 6_level_4,Unnamed: 7_level_4,Unnamed: 8_level_4
Weight?,Feature,Unnamed: 2_level_5,Unnamed: 3_level_5,Unnamed: 4_level_5,Unnamed: 5_level_5,Unnamed: 6_level_5,Unnamed: 7_level_5,Unnamed: 8_level_5
Weight?,Feature,Unnamed: 2_level_6,Unnamed: 3_level_6,Unnamed: 4_level_6,Unnamed: 5_level_6,Unnamed: 6_level_6,Unnamed: 7_level_6,Unnamed: 8_level_6
Weight?,Feature,Unnamed: 2_level_7,Unnamed: 3_level_7,Unnamed: 4_level_7,Unnamed: 5_level_7,Unnamed: 6_level_7,Unnamed: 7_level_7,Unnamed: 8_level_7
Weight?,Feature,Unnamed: 2_level_8,Unnamed: 3_level_8,Unnamed: 4_level_8,Unnamed: 5_level_8,Unnamed: 6_level_8,Unnamed: 7_level_8,Unnamed: 8_level_8
+9.904,BOS,,,,,,,
+7.469,word[-3:]:Día,,,,,,,
+6.199,word[-3:]:B,,,,,,,
+6.199,word[-2:]:B,,,,,,,
+6.058,word[-3:]:R.,,,,,,,
+5.875,word[-3:]:zas,,,,,,,
+5.364,word[-3:]:apa,,,,,,,
+5.333,word[-3:]:Y,,,,,,,
+5.333,word[-2:]:Y,,,,,,,
+4.696,postag:CS,,,,,,,

Weight?,Feature
+9.904,BOS
+7.469,word[-3:]:Día
+6.199,word[-3:]:B
+6.199,word[-2:]:B
+6.058,word[-3:]:R.
+5.875,word[-3:]:zas
+5.364,word[-3:]:apa
+5.333,word[-3:]:Y
+5.333,word[-2:]:Y
+4.696,postag:CS

Weight?,Feature
+4.908,word[-3:]:San
+4.731,word[-3:]:AEL
+4.694,word[-3:]:joz
+4.393,word[-3:]:la.
+4.035,word[-3:]:sil
+3.906,word[-3:]:uta
+3.880,BOS
+3.868,word[-2:]:UU
+3.866,word[-3:]:RFA
+3.832,word[-3:]:gro

Weight?,Feature
+3.967,word[-3:]:ies
+3.853,word[-3:]:neo
+3.738,word[-3:]:ose
+3.661,word[-3:]:PSN
+3.661,word[-2:]:SN
+3.648,word[-3:]:eba
+3.615,word[-3:]:tol
+3.543,word[-3:]:oan
+3.538,word[-3:]:San
+3.481,word[-3:]:Mil

Weight?,Feature
+5.270,word[-3:]:iga
+4.795,word[-3:]:Net
+4.584,word[-3:]:mio
+4.573,word[-3:]:IFE
+4.173,word[-3:]:Pro
+4.162,word[-3:]:Ley
+4.141,word[-3:]:.it
+4.052,word[-3:]:rta
+4.048,word[-3:]:PAs
+3.949,word[-3:]:-PP

Weight?,Feature
+3.621,word[-3:]:Oro
+3.512,word[-3:]:mic
+3.428,postag:AO
+3.382,word[-3:]:San
+3.271,word[-3:]:dre
+3.255,word[-2:]:a.
+3.150,postag:RG
+3.129,word[-3:]:mio
+3.116,word[-3:]:uro
+3.091,word[-3:]:000

Weight?,Feature
+6.839,word[-3:]:bca
+6.464,word[-2:]:-e
+5.641,word[-3:]:CiU
+5.641,word[-2:]:iU
+4.262,word[-2:]:Gs
+4.262,word[-3:]:NGs
+4.233,word[-3:]:lGo
+4.233,word[-2:]:Go
+4.224,word[-2:]:iA
+4.224,word[-3:]:UiA

Weight?,Feature
+5.138,word[-3:]:xto
+3.779,word[-3:]:San
+3.530,word[-2:]:L
+3.530,word[-3:]:L
+3.487,word[-3:]:00
+3.378,word[-3:]:ics
+3.347,word[-3:]:API
+3.138,word[-3:]:ews
+3.126,word[-2:]:PI
+3.094,is_num

Weight?,Feature
+5.509,word[-3:]:man
+4.849,BOS
+4.693,word[-3:]:món
+3.986,word[-3:]:yes
+3.983,word[-3:]:dal
+3.929,word[-2:]:'o
+3.929,word[-3:]:o'o
+3.901,word[-3:]:vic
+3.777,word[-3:]:uan
+3.747,word[-3:]:lza

Weight?,Feature
+3.941,word[-3:]:jía
+3.741,word[-3:]:oña
+3.337,word[-3:]:uco
+2.952,word[-3:]:rdt
+2.952,word[-2:]:dt
+2.898,word[-3:]:os.
+2.877,word[-2:]:ez
+2.873,word[-3:]:xon
+2.731,word[-3:]:jal
+2.696,word[-3:]:ñón


### Посчитаем предсказание на тесте:

In [113]:
labels = list(crf.classes_)
labels.remove('O')
labels

['B-LOC', 'B-ORG', 'B-PER', 'I-PER', 'B-MISC', 'I-ORG', 'I-LOC', 'I-MISC']

In [114]:
y_pred = crf.predict(X_test)
metrics.flat_f1_score(y_test, y_pred, average='weighted', labels=labels)

0.7116576057585307

### А теперь отдельно для каждого тэга:

In [115]:
sorted_labels = sorted(
    labels,
    key=lambda name: (name[1:], name[0])
)
print(metrics.flat_classification_report(
    y_test, y_pred, labels=sorted_labels, digits=3
))

              precision    recall  f1-score   support

       B-LOC      0.694     0.693     0.693      1084
       I-LOC      0.529     0.498     0.513       325
      B-MISC      0.623     0.419     0.501       339
      I-MISC      0.559     0.436     0.490       557
       B-ORG      0.752     0.786     0.769      1400
       I-ORG      0.762     0.778     0.770      1104
       B-PER      0.773     0.750     0.761       735
       I-PER      0.841     0.896     0.868       634

   micro avg      0.726     0.708     0.717      6178
   macro avg      0.692     0.657     0.671      6178
weighted avg      0.719     0.708     0.712      6178



### Посмотрим на наиболее и наименее вероятные переходы модели: 

In [116]:
from collections import Counter

def print_transitions(trans_features):
    for (label_from, label_to), weight in trans_features:
        print("%-6s -> %-7s %0.6f" % (label_from, label_to, weight))

print("Top likely transitions:")
print_transitions(Counter(crf.transition_features_).most_common(20))

print("\nTop unlikely transitions:")
print_transitions(Counter(crf.transition_features_).most_common()[-20:])

Top likely transitions:
I-MISC -> I-MISC  4.877970
B-PER  -> I-PER   4.716772
I-LOC  -> I-LOC   4.291041
I-PER  -> I-PER   3.887080
B-LOC  -> I-LOC   3.885361
O      -> O       3.850426
B-ORG  -> I-ORG   3.622206
B-MISC -> I-MISC  3.258901
I-ORG  -> I-ORG   3.104825
O      -> B-MISC  1.078870
O      -> B-ORG   0.714269
O      -> B-PER   0.565233
O      -> B-LOC   0.437593
B-LOC  -> O       0.189898
B-ORG  -> O       -0.046128
I-LOC  -> O       -0.146324
I-MISC -> O       -0.466654
I-PER  -> O       -0.553293
B-PER  -> O       -0.845622
I-ORG  -> O       -0.946583

Top unlikely transitions:
B-MISC -> B-LOC   -4.691146
I-LOC  -> B-ORG   -4.770356
B-MISC -> I-ORG   -4.786298
I-LOC  -> B-LOC   -4.800120
B-PER  -> B-MISC  -4.838547
B-ORG  -> B-MISC  -4.951894
I-LOC  -> B-PER   -4.972026
I-ORG  -> B-ORG   -5.011607
I-PER  -> B-ORG   -5.177793
I-MISC -> B-LOC   -5.179780
B-PER  -> B-ORG   -5.229351
B-LOC  -> B-ORG   -5.372815
B-MISC -> B-MISC  -5.401639
O      -> I-MISC  -5.768159
B-PER  -> B