<a href="https://colab.research.google.com/github/TurkuNLP/intro-to-nlp/blob/master/sequence_labeling_crf.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Sequence labeling with CRF

This notebook provides an example of how to perform sequence labeling with a Conditional Random Field (CRF) model.

(This notebook is based on this [notebook for CoNLL 2002 Spanish tagging](https://github.com/TeamHG-Memex/sklearn-crfsuite/blob/master/docs/CoNLL2002.ipynb))

---

## Setup

We'll be using the [sklearn-crfsuite](https://sklearn-crfsuite.readthedocs.io/en/latest/) Python package, which wraps the [CRFsuite](http://www.chokkan.org/software/crfsuite/) CRF implementation.

(Specifically, we'll install a fork that fixes a [versioning incompatibility issue](https://github.com/TeamHG-Memex/sklearn-crfsuite/issues/60) with the scikit-learn integration.)

In [1]:
!pip install --quiet git+https://github.com/MeMartijn/updated-sklearn-crfsuite.git#egg=sklearn_crfsuite



To take advantage of the integration of CRFsuite into the [scikit-learn](https://scikit-learn.org/stable/) machine learning library, we'll use parts of that library to support data loading and evaluation. (Its basic functionality is similar to what the `datasets` and `evaluate` libraries that you should know by now provide, so there's no need to memorize the specific details of these libraries.)

In [2]:
from itertools import chain

import nltk
import sklearn
import scipy.stats
import sklearn.metrics

from nltk.corpus.reader import ConllCorpusReader

from sklearn.metrics import make_scorer
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import RandomizedSearchCV

import sklearn_crfsuite
from sklearn_crfsuite import scorers
from sklearn_crfsuite import metrics
from sklearn_crfsuite.utils import flatten

---

## Load data

We'll use the CoNLL'03 English dataset, which you should be familiar with. We've placed the original data on dl.turkunlp.org, and you can download it with `wget` as usual:

In [3]:
!wget -nc --quiet http://dl.turkunlp.org/TKO_7095_2023/eng.train
!wget -nc --quiet http://dl.turkunlp.org/TKO_7095_2023/eng.testa
!wget -nc --quiet http://dl.turkunlp.org/TKO_7095_2023/eng.testb

We'll use the NLTK `ConllCorpusReader` to read in the words, parts of speech, and named entity "chunks":

In [4]:
train_data = ConllCorpusReader('', 'eng.train', ['words', 'pos', 'ignore', 'chunk'])
devel_data = ConllCorpusReader('', 'eng.testa', ['words', 'pos', 'ignore', 'chunk'])
# The test set is eng.testb (eng.testa is the development set)
test_data = ConllCorpusReader('', 'eng.testb', ['words', 'pos', 'ignore', 'chunk'])

# The comprehension here drops empty sentences used to denote document boundaries
train_sentences = [s for s in train_data.iob_sents() if s]
devel_sentences = [s for s in devel_data.iob_sents() if s]
test_sentences = [s for s in test_data.iob_sents() if s]

This gives us the sentences as a simple Python list, where each sentence is a list of tuples of token features (form, POS tag, NE tag).

Note that unlike in `datasets`, the data is here in "human-readable" string form rather than as integer IDs.

In [5]:
train_sentences[0]

[('EU', 'NNP', 'I-ORG'),
 ('rejects', 'VBZ', 'O'),
 ('German', 'JJ', 'I-MISC'),
 ('call', 'NN', 'O'),
 ('to', 'TO', 'O'),
 ('boycott', 'VB', 'O'),
 ('British', 'JJ', 'I-MISC'),
 ('lamb', 'NN', 'O'),
 ('.', '.', 'O')]
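Because the sentences are plain lists of tuples, standard Python tools are enough to inspect them. The following is a minimal sketch that counts how often each NE tag occurs. It runs on a small hand-written sample in the same format (`sample_sentences` and `label_counts` are illustrative names, not part of the notebook); passing `train_sentences` instead should give the tag distribution of the training data.

```python
from collections import Counter

# Toy sample in the same (form, POS tag, NE tag) format as train_sentences
sample_sentences = [
    [('EU', 'NNP', 'I-ORG'), ('rejects', 'VBZ', 'O'), ('German', 'JJ', 'I-MISC'),
     ('call', 'NN', 'O'), ('.', '.', 'O')],
    [('Peter', 'NNP', 'I-PER'), ('Blackburn', 'NNP', 'I-PER')],
]

def label_counts(sentences):
    """Count NE tags over all tokens of all sentences."""
    return Counter(label for sentence in sentences for _, _, label in sentence)

counts = label_counts(sample_sentences)
print(counts.most_common())
```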

---

## Create features

Here is a very basic function that returns explicitly defined features for a token:

In [6]:
def token_features(tokens, index):
    token = tokens[index][0]
    pos_tag = tokens[index][1]
    
    features = {
        'bias': 1.0,
        'token.text()': token,
        'token.istitle()': token.istitle(),
        'token.isdigit()': token.isdigit(),
        'pos_tag': pos_tag,
    }
               
    return features

Some convenience functions to get features for all tokens in a sentence as well as all labels and all token texts.

In [7]:
def sentence_features(tokens):
    return [token_features(tokens, i) for i in range(len(tokens))]

def sentence_labels(tokens):
    return [label for token, postag, label in tokens]

def sentence_tokens(tokens):
    return [token for token, postag, label in tokens]

Check features for a single token

In [8]:
sentence_features(train_sentences[0])[0]

{'bias': 1.0,
 'token.text()': 'EU',
 'token.istitle()': False,
 'token.isdigit()': False,
 'pos_tag': 'NNP'}
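The features above describe only the current token. A common extension is to also include features of the neighbouring tokens, so that the model can see some local context. The sketch below shows one way this might look; `token_features_with_context` and the feature names it adds (`-1:...`, `+1:...`, `BOS`, `EOS`) are illustrative choices, not something the notebook defines.

```python
def token_features_with_context(tokens, index):
    """Hypothetical variant of token_features that adds neighbour features."""
    token, pos_tag = tokens[index][0], tokens[index][1]

    features = {
        'bias': 1.0,
        'token.text()': token,
        'token.istitle()': token.istitle(),
        'token.isdigit()': token.isdigit(),
        'pos_tag': pos_tag,
    }

    if index > 0:
        # Features of the previous token
        features['-1:token.lower()'] = tokens[index - 1][0].lower()
        features['-1:pos_tag'] = tokens[index - 1][1]
    else:
        # Mark the beginning of the sentence
        features['BOS'] = True

    if index < len(tokens) - 1:
        # Features of the next token
        features['+1:token.lower()'] = tokens[index + 1][0].lower()
        features['+1:pos_tag'] = tokens[index + 1][1]
    else:
        # Mark the end of the sentence
        features['EOS'] = True

    return features

example = [('EU', 'NNP', 'I-ORG'), ('rejects', 'VBZ', 'O'), ('German', 'JJ', 'I-MISC')]
first = token_features_with_context(example, 0)
last = token_features_with_context(example, 2)
```

To try this out, the function could be swapped in for `token_features` inside `sentence_features`. Whether it helps on this data would need to be checked on the development set.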

Get features and labels for training and development data

In [9]:
X_train = [sentence_features(s) for s in train_sentences]
y_train = [sentence_labels(s) for s in train_sentences]

X_devel = [sentence_features(s) for s in devel_sentences]
y_devel = [sentence_labels(s) for s in devel_sentences]

---

## Train and evaluate model

Instantiate the CRF. Here `c1` and `c2` are the coefficients for L1 and L2 regularization. For details on the model hyperparameters, see the [CRFsuite documentation](https://www.chokkan.org/software/crfsuite/).

In [10]:
crf = sklearn_crfsuite.CRF(
    algorithm='lbfgs', 
    c1=0.1, 
    c2=0.1, 
    max_iterations=100, 
    all_possible_transitions=True
)
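As a rough guide to what `c1` and `c2` control, the training objective with both penalties can be written as below. The notation is assumed for illustration: $w$ is the vector of feature weights and $(x^{(i)}, y^{(i)})$ are the training sentences with their label sequences. The exact scaling of the penalty terms is a convention of the implementation, so treat this as a sketch rather than CRFsuite's definition.

$$
\min_{w} \; -\sum_{i} \log p\left(y^{(i)} \mid x^{(i)}; w\right) + c_1 \lVert w \rVert_1 + c_2 \lVert w \rVert_2^2
$$

Larger `c1` pushes more weights to exactly zero (fewer active features), while larger `c2` shrinks all weights towards zero.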

Train the model

In [11]:
crf.fit(X_train, y_train)

Get the set of named entity tags for evaluation, excluding the "outside" tag `O`:

In [12]:
labels = list(crf.classes_)
labels.remove('O')
labels

['I-ORG', 'I-MISC', 'I-PER', 'I-LOC', 'B-LOC', 'B-MISC', 'B-ORG']

Predict labels for development set

In [13]:
y_pred = crf.predict(X_devel)

Evaluate predictions against gold standard development set labels

In [14]:
sorted_labels = sorted(labels, key=lambda name: (name[1:], name[0]))

print(sklearn.metrics.classification_report(
    flatten(y_devel),
    flatten(y_pred),
    labels=sorted_labels
))



              precision    recall  f1-score   support

       B-LOC       0.00      0.00      0.00         0
       I-LOC       0.87      0.82      0.85      2094
      B-MISC       0.00      0.00      0.00         4
      I-MISC       0.93      0.76      0.83      1264
       B-ORG       0.00      0.00      0.00         0
       I-ORG       0.81      0.76      0.78      2092
       I-PER       0.87      0.91      0.89      3149

   micro avg       0.86      0.83      0.84      8603
   macro avg       0.50      0.46      0.48      8603
weighted avg       0.86      0.83      0.84      8603
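Note that this report is computed per token, not per entity mention. The tags in this data follow the IOB1 convention: an entity normally starts with `I-`, and `B-` is used only when an entity directly follows another entity of the same type. That is why the `B-` tags have little or no support here. As a minimal sketch of mention-level evaluation, the code below extracts entity spans from an IOB1 tag sequence and compares gold and predicted spans for a single sentence. The helper names `iob1_spans` and `span_prf` are hypothetical, and the example tags are made up.

```python
def iob1_spans(tags):
    """Return (type, start, end) entity spans from IOB1 tags; end is exclusive."""
    spans = []
    start, current = None, None
    for i, tag in enumerate(tags):
        if tag == 'O':
            prefix, etype = 'O', None
        else:
            prefix, etype = tag.split('-', 1)
        # Close the open span if the type changes, we hit O, or a B- tag starts a new one
        if current is not None and (etype != current or prefix == 'B'):
            spans.append((current, start, i))
            start, current = None, None
        # Open a new span if we are on an entity token and none is open
        if etype is not None and current is None:
            start, current = i, etype
    if current is not None:
        spans.append((current, start, len(tags)))
    return spans

def span_prf(gold_tags, pred_tags):
    """Exact-match precision, recall and F1 over entity spans of one sentence."""
    gold, pred = set(iob1_spans(gold_tags)), set(iob1_spans(pred_tags))
    tp = len(gold & pred)
    precision = tp / len(pred) if pred else 0.0
    recall = tp / len(gold) if gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

gold = ['I-PER', 'I-PER', 'O', 'I-LOC']
pred = ['I-PER', 'O', 'O', 'I-LOC']
print(iob1_spans(gold))
print(span_prf(gold, pred))
```

In this example the predicted `PER` span is too short, so it does not count as a match even though one of its tokens is tagged correctly. A corpus-level score would need to pool the span counts over all sentences rather than average per-sentence scores.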





---

## Analyze classifier

We can look at `transition_features_` to analyze the transition weights learned by the model:

In [15]:
from collections import Counter

def print_transitions(trans_features):
    for (label_from, label_to), weight in trans_features:
        print("%-6s -> %-7s %0.6f" % (label_from, label_to, weight))

print("Most likely transitions:")
print_transitions(Counter(crf.transition_features_).most_common(10))

print("\nLeast likely transitions:")
print_transitions(Counter(crf.transition_features_).most_common()[-10:])

Most likely transitions:
B-ORG  -> B-ORG   5.427859
I-MISC -> B-MISC  4.697286
I-ORG  -> I-ORG   3.499751
I-MISC -> I-MISC  3.493192
I-PER  -> I-PER   3.042392
B-MISC -> I-MISC  3.006700
I-LOC  -> B-LOC   2.983642
I-LOC  -> I-LOC   1.332836
O      -> I-MISC  1.299373
O      -> O       1.245926

Least likely transitions:
I-MISC -> I-ORG   -2.181237
B-MISC -> I-ORG   -2.267659
B-ORG  -> O       -2.469499
I-MISC -> I-LOC   -2.534070
I-PER  -> I-MISC  -3.184330
I-ORG  -> I-LOC   -3.394631
I-PER  -> I-LOC   -3.760992
I-LOC  -> I-PER   -4.260426
I-LOC  -> I-ORG   -4.329928
I-PER  -> I-ORG   -5.490100
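To see how such transition weights affect predictions, it helps to recall that a linear-chain CRF scores a whole label sequence as the sum of per-token scores and transition weights, and decoding picks the highest-scoring sequence. The following is a simplified sketch of Viterbi-style decoding over toy numbers. It is not CRFsuite's actual implementation, and the function `viterbi` and all weights in the example are made up for illustration.

```python
def viterbi(emissions, transitions, labels):
    """Find the best label sequence.

    emissions: one dict per token mapping label -> score (missing = 0.0)
    transitions: dict mapping (label_from, label_to) -> weight (missing = 0.0)
    """
    # Best score of any path ending in each label at the first token
    best = [{label: emissions[0].get(label, 0.0) for label in labels}]
    back = []
    for scores in emissions[1:]:
        current, pointers = {}, {}
        for to in labels:
            # Best previous label for reaching `to` at this position
            prev = max(labels, key=lambda frm: best[-1][frm] + transitions.get((frm, to), 0.0))
            current[to] = best[-1][prev] + transitions.get((prev, to), 0.0) + scores.get(to, 0.0)
            pointers[to] = prev
        best.append(current)
        back.append(pointers)
    # Follow the back-pointers from the best final label
    path = [max(labels, key=lambda label: best[-1][label])]
    for pointers in reversed(back):
        path.append(pointers[path[-1]])
    return list(reversed(path))

labels = ['O', 'I-PER', 'I-LOC']
# Toy per-token scores: the second token on its own slightly prefers I-LOC
emissions = [{'I-PER': 2.0}, {'I-LOC': 1.0, 'I-PER': 0.5}, {'O': 4.0}]
# Toy transition weights: staying in I-PER is rewarded, I-PER -> I-LOC is penalized
transitions = {('I-PER', 'I-PER'): 3.0, ('I-PER', 'I-LOC'): -3.8, ('O', 'O'): 1.2}

with_transitions = viterbi(emissions, transitions, labels)
without_transitions = viterbi(emissions, {}, labels)
print(with_transitions)
print(without_transitions)
```

With the transition weights the second token is labeled `I-PER`, while without them it is labeled `I-LOC`. This mirrors the pattern in the learned weights above, where continuing an entity of the same type is rewarded and switching directly between entity types is penalized.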


Similarly, we can use `state_features_` to look at the weights of the explicit features we introduced:

In [17]:
def print_state_features(state_features):
    for (attr, label), weight in state_features:
        print("%0.6f %-8s %s" % (weight, label, attr))    

print("Top positive:")
print_state_features(Counter(crf.state_features_).most_common(10))

print("\nTop negative:")
print_state_features(Counter(crf.state_features_).most_common()[-10:])

Top positive:
9.048645 O        token.text():Minister
8.736580 I-MISC   token.text():GMT
8.716987 I-MISC   token.text():DUTCH
8.449610 O        token.text():President
8.298575 I-LOC    token.text():AMSTERDAM
8.130931 I-ORG    token.text():1860
8.032552 I-LOC    token.text():BONN
7.993696 I-ORG    token.text():OSCE
7.771414 I-PER    token.text():Inzamam-ul-Haq
7.749895 I-LOC    token.text():ATHENS

Top negative:
-2.377139 O        token.text():serie
-2.407187 O        token.text():de
-2.409392 O        pos_tag:NNPS
-2.496532 O        token.text():154
-2.781862 O        token.text():95
-2.858090 O        token.text():04
-3.306775 B-MISC   bias
-3.451022 O        token.text():NEW
-3.565962 O        token.text():ST
-4.336524 B-LOC    bias
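The global top-10 lists are dominated by a few labels, so it can be useful to look at the strongest features for each label separately. Below is a minimal sketch that groups a mapping of `(attribute, label) -> weight`, the same shape as `crf.state_features_` as used above, by label. The helper `top_features_by_label` is a hypothetical name, and the small `toy_state_features` dict just reuses a few of the values printed above.

```python
from collections import defaultdict

def top_features_by_label(state_features, n=2):
    """Group (attribute, label) -> weight entries by label, keeping the n largest weights."""
    grouped = defaultdict(list)
    for (attr, label), weight in state_features.items():
        grouped[label].append((weight, attr))
    return {label: sorted(items, reverse=True)[:n] for label, items in grouped.items()}

# A few entries copied from the output above, for illustration only
toy_state_features = {
    ('token.text():Minister', 'O'): 9.048645,
    ('token.text():President', 'O'): 8.449610,
    ('token.text():NEW', 'O'): -3.451022,
    ('token.text():AMSTERDAM', 'I-LOC'): 8.298575,
    ('token.text():BONN', 'I-LOC'): 8.032552,
    ('bias', 'B-LOC'): -4.336524,
}

top = top_features_by_label(toy_state_features, n=2)
for label, items in top.items():
    for weight, attr in items:
        print("%0.6f %-8s %s" % (weight, label, attr))
```

With the trained model, calling `top_features_by_label(crf.state_features_, n=10)` should give a per-label view of the same weights.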
