# Lab 7: Sequence Labelling

This will be our last lab of the unit -- after this we focus on the coursework. You may wish to split this lab over the next two weeks.

In this lab we will implement an HMM for part-of-speech (POS) tagging, then see how to use a library to run an HMM and a CRF for named entity recognition.

### Outcomes
* Be able to implement the main parts of a supervised HMM.
* Understand what the steps of Viterbi are doing.
* Know how to use a CRF and specify the features it uses.

### Overview

The first part of the notebook loads a POS dataset from the NLTK library.
The second part implements and tests an HMM POS tagger.
The third part uses CRFs for the task of named entity recognition (to be covered in week 8).

## Background

### Hidden Markov Model (HMM)

In this lab we will use Hidden Markov Models (HMMs) to model sequential data. HMMs are based on the Markov assumption, which states that the present state $y_n$ is sufficient to predict the next state $y_{n+1}$, so the past states $y_{0:n-1}$ can be forgotten.

Often, the states we are interested in cannot be observed directly -- they are 'hidden'. As we will see in Exercise 2, the part-of-speech (POS) tags are hidden states that we want to predict. We can only observe the words, and have to use them to infer the tags. 
An HMM is specified by the following components:

* A set of $N$ states.

* A transition probability matrix $A$, where each element $a_{ij}$ is the probability of moving from state $i$ to state $j$, such that $\sum^{N}_{j=1} a_{ij} = 1$ for all $i$.

* An emission probability distribution $P(x_n | y_n)$: the probability of observation $x_n$ being generated from state $y_n$.

* An initial probability distribution $\pi$ over states, where $\pi_i$ is the probability that the Markov chain starts in state $i$, and $\sum^{N}_{i=1} \pi_i = 1$.
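
To make these components concrete, here is a minimal sketch of a toy HMM with two states and three word types. All the numbers are made up for illustration, and the name `B` for the emission matrix is an assumption (the lab does not name it). The function computes the joint probability of one state sequence and one observation sequence by multiplying the start, transition and emission probabilities.

```python
import numpy as np

# Hypothetical toy HMM: 2 states (0 = NOUN, 1 = VERB) and 3 word types.
pi = np.array([0.8, 0.2])            # initial state probabilities
A = np.array([[0.3, 0.7],            # A[i, j] = P(next state j | current state i)
              [0.6, 0.4]])
B = np.array([[0.5, 0.4, 0.1],       # B[i, w] = P(word w | state i)
              [0.1, 0.3, 0.6]])

# Every distribution must sum to 1.
assert np.isclose(pi.sum(), 1.0)
assert np.allclose(A.sum(axis=1), 1.0)
assert np.allclose(B.sum(axis=1), 1.0)


def joint_probability(states, observations):
    """P(states, observations) under the toy HMM defined above."""
    prob = pi[states[0]] * B[states[0], observations[0]]
    for n in range(1, len(states)):
        prob *= A[states[n - 1], states[n]] * B[states[n], observations[n]]
    return prob


p = joint_probability([0, 1, 0], [0, 2, 1])
print(f'Joint probability: {p:.5f}')
```

Tagging amounts to searching for the state sequence that maximises this joint probability for a fixed sequence of words.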

## Part of Speech (POS) Tagging

Part-of-speech (POS) tagging enables the extraction of meaningful information about words in a sentence and their relation to neighbouring words. Parts of speech are useful features for labeling named entities like people or organisations in information extraction. A word’s part of speech can even play a role in speech recognition or speech synthesis, e.g., the word 'content' is pronounced CONtent when it is a noun and conTENT when it is an adjective.

POS tagging is the process of assigning a POS tag to each token in a text. The input to a tagging algorithm is a sequence of tokenised words and the output is a sequence of tags, one per token. Many words can have different POS tags in different contexts, so the goal is to find the correct tag for a particular situation. For example, "book" can be a verb ("book that flight") or a noun ("hand me that book"). POS tagging resolves these ambiguities by choosing the proper tag for the context. POS tagging is a specific case of the more generic NLP task of sequence labelling.


# 1. Preparing the POS Tagging Data

For POS tagging, we are going to start with the [Brown corpus](https://www.nltk.org/nltk_data/), which contains many different sources of English text (books, essays, newspaper articles, government documents...) collected and hand-labelled by linguists in 1967.

If you would like to try out POS tagging in another language, just uncomment the code below to switch to the [NLTK Indian corpus](https://www.nltk.org/_modules/nltk/corpus/reader/indian.html), which contains datasets for POS tagging in Bangla, Hindi, Marathi and Telugu.

In [1]:
import nltk

# If you want to try another language, try Hindi:
# nltk.download('indian')  # the dataset
# from nltk.corpus import indian
# nltk_data = list(indian.tagged_sents('hindi.pos'))

nltk.download('brown')  # download Brown corpus
nltk.download('universal_tagset')   # download the POS tags data
from nltk.corpus import brown
nltk_data = list(brown.tagged_sents(tagset='universal'))

[nltk_data] Downloading package brown to
[nltk_data]     C:\Users\griev\AppData\Roaming\nltk_data...
[nltk_data]   Package brown is already up-to-date!
[nltk_data] Downloading package universal_tagset to
[nltk_data]     C:\Users\griev\AppData\Roaming\nltk_data...
[nltk_data]   Package universal_tagset is already up-to-date!


First we need to split the dataset into training and test sets. Run the following cells to achieve that and to prepare the dataset in the correct format.

In [2]:
from sklearn.model_selection import train_test_split
import numpy as np

train_set, test_set = train_test_split(
    nltk_data,
    train_size=0.80,  # use 80% as the training data
    test_size=0.20,
    random_state=101
)
print(f'Number of training sentences: {len(train_set)}')
print(f'Number of test sentences: {len(test_set)}')

# Separate the labels from the text
train_toks = []  # each item in the list is a list of tokens in a document
train_tags = []  # each item in the list is a list of corresponding tags
for tagged_sentence in train_set:
    sentence_toks = []
    sentence_tags = []
    for token, tag in tagged_sentence:
        sentence_toks.append(token)
        sentence_tags.append(tag)

    train_toks.append(sentence_toks)
    train_tags.append(sentence_tags)

test_toks = []
test_tags = []
for tagged_sentence in test_set:
    sentence_toks = []
    sentence_tags = []
    for token, tag in tagged_sentence:
        sentence_toks.append(token)
        sentence_tags.append(tag)
    test_toks.append(sentence_toks)
    test_tags.append(sentence_tags)

print(f'Number of training sentences in train_toks: {len(train_toks)}')
print(f'Number of test sentences in test_toks: {len(test_toks)}')

Number of training sentences: 45872
Number of test sentences: 11468
Number of training sentences in train_toks: 45872
Number of test sentences in test_toks: 11468


Let's see how many tokens the training and test sets contain:

In [3]:
# Count the tagged tokens in the training and test sets
print(f'Number of tagged tokens in the training set: {len([ tok for sent in train_toks for tok in sent ])}')
print(f'Number of tagged tokens in the test set: {len([ tok for sent in test_toks for tok in sent ])}')

Number of tagged tokens in the training set: 927092
Number of tagged tokens in the test set: 234100


Let's explore the different types of tags by running the next cell.

In [4]:
unique_tags = {tag for sent in train_tags for tag in sent}
print(f'Number of possible tags: {len(unique_tags)}')
print(f'Possible tags: {unique_tags}')

Number of possible tags: 12
Possible tags: {'DET', 'ADV', 'ADP', 'X', 'NOUN', '.', 'PRT', 'NUM', 'PRON', 'CONJ', 'ADJ', 'VERB'}


TODO 1.1: Find out what the tags mean at https://github.com/slavpetrov/universal-pos-tags.

The next cell shows an example sentence from the dataset.

In [5]:
print('Sentence example: {}'.format(train_set[3]))

Sentence example: [('many', 'ADJ'), ('of', 'ADP'), ('their', 'DET'), ('gifted', 'ADJ'), ('members', 'NOUN'), ('were', 'VERB'), ('prominent', 'ADJ'), ('in', 'ADP'), ('the', 'DET'), ('Vatican', 'NOUN'), ('as', 'ADP'), ('physicians', 'NOUN'), (',', '.'), ('musicians', 'NOUN'), (',', '.'), ('bankers', 'NOUN'), ('.', '.')]


Let's build a vocabulary and convert each token to its index:

In [6]:
from gensim.corpora import Dictionary

# Convert the tokens to IDs in a vocabulary ready for input to our models
dictionary = Dictionary(train_toks + test_toks)

train_toks_encoded = [dictionary.doc2idx(sent) for sent in train_toks]
test_toks_encoded = [dictionary.doc2idx(sent) for sent in test_toks]
print(f'Example sentence: {train_toks_encoded[3]}')

V = len(dictionary.values())  # vocabulary
print(f'Size of vocabulary is {V}')

Example sentence: [41, 28, 46, 40, 42, 47, 45, 23, 31, 37, 38, 44, 11, 43, 11, 39, 0]
Size of vocabulary is 56057


Now, we convert the tags to their indexes:

In [7]:
from sklearn.preprocessing import LabelEncoder

# Convert the tags from their names to numbers
tag_encoder = LabelEncoder()
tag_encoder.fit([tag for sentence in train_tags for tag in sentence])
train_tags_encoded = [tag_encoder.transform(sentence) for sentence in train_tags]
test_tags_encoded = [tag_encoder.transform(sentence) for sentence in test_tags]

num_tags = len(tag_encoder.classes_)

# 2. Implementing the HMM

The two main components of HMMs are the transition model and the observation (or emission) model. The transition model estimates $P(tag_{t+1}|tag_t)$, the probability of the next tag given the current tag. For discrete features, such as tokens, the observation model is the same as the naïve Bayes classifier that we covered in lab 2. It estimates $P(word_t|tag_t)$, the probability of observing a word given the current tag.
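
With fully labelled training data, both models can be estimated by counting and then normalising. As a sketch, using a count notation $C(\cdot)$ that is introduced here only for illustration:

$$
\hat{a}_{ij} = \frac{C(i \rightarrow j)}{\sum_{k=1}^{N} C(i \rightarrow k)}, \qquad
\hat{P}(w \mid i) = \frac{C(i, w)}{\sum_{w'} C(i, w')}, \qquad
\hat{\pi}_i = \frac{C_{\text{start}}(i)}{\sum_{k=1}^{N} C_{\text{start}}(k)}
$$

Here $C(i \rightarrow j)$ is the number of times tag $j$ directly follows tag $i$, $C(i, w)$ is the number of times word $w$ is labelled with tag $i$, and $C_{\text{start}}(i)$ is the number of sentences that begin with tag $i$.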

Let's start by implementing the transition matrix. 

TODO 2.1: Count the state (tag) transitions and starting state (tag) occurrences in the training set and store the counts in the `transitions` and `start_states` matrices below. In `transitions`, rows correspond to the state at time $t-1$ and columns to the following state at time $t$.

In [11]:
from tqdm import tqdm  # progress bars

transitions = np.zeros((num_tags, num_tags))
start_states = np.zeros(num_tags)

for sentence_tags in tqdm(train_tags_encoded):
    for i, tag in enumerate(sentence_tags):
        if i==0:
            start_states[tag] += 1
            continue
        # WRITE YOUR OWN CODE HERE



TODO 2.2: Normalise the transition and start state counts to estimate the conditional probabilities in the transition matrix $A$ and starting state probabilities $\pi$.

In [None]:
### WRITE YOUR CODE HERE
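
One detail to watch when normalising is a row of counts that is entirely zero, which would cause a division by zero. The sketch below works on a made-up count matrix rather than the lab data. The function name `normalise_rows` and the optional add-one smoothing are assumptions for illustration, not part of the lab's required solution.

```python
import numpy as np

# Hypothetical counts: rows are states, columns are outcomes.
toy_counts = np.array([[2., 6., 2.],
                       [0., 0., 0.],   # a state that was never observed
                       [1., 1., 2.]])


def normalise_rows(counts, smoothing=0.0):
    """Turn each row of counts into a probability distribution."""
    counts = counts + smoothing
    row_sums = counts.sum(axis=1, keepdims=True)
    # Rows that sum to zero are divided by 1 instead, so they stay all zero.
    safe_sums = np.where(row_sums == 0, 1.0, row_sums)
    return counts / safe_sums


toy_probs = normalise_rows(toy_counts)
toy_probs_smoothed = normalise_rows(toy_counts, smoothing=1.0)
print(toy_probs)
print(toy_probs_smoothed)
```

Using `keepdims=True` keeps the row sums as a column vector, so the division broadcasts across each row. With smoothing, the unseen state gets a uniform distribution instead of all zeros.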


Now let's compute the emission/observation model.

TODO 2.3: Count the number of occurrences of each word type given each tag.

In [None]:
observations = np.zeros((num_tags, V))

for i, sentence_toks in enumerate(tqdm(train_toks_encoded)):  # wrap the list so tqdm knows the total
    sentence_tags = train_tags_encoded[i]
    for j, tok in enumerate(sentence_toks):
        tag = sentence_tags[j]
        # WRITE YOUR OWN CODE HERE


TODO 2.4: Normalise the observation counts to obtain the observation probabilities.

In [None]:
#WRITE YOUR OWN CODE HERE


To predict the most likely sequence of tags, we use the Viterbi algorithm, defined in the next cell.

In [None]:
def viterbi(observed_seq, num_tags, start_probs, transition_probs, observation_probs):
    eps = 1e-7

    num_obs = len(observed_seq)

    # Initialise the Viterbi table V and the backpointers
    V = np.zeros((num_obs, num_tags))
    backpointer = np.zeros((num_obs, num_tags), dtype=int)  # integer state indexes

    # For the first data point in the sequence:
    V[0, :] = start_probs * observation_probs[:, observed_seq[0]]

    # Run Viterbi forward for t > 0
    for t in range(1, num_obs):

        for state in range(num_tags):
            # probabilities for all the sequences leading to this state at time t
            seq_prob = V[t-1, :] * transition_probs[:, state]

            # Choose the most likely sequence
            max_seq_prob = np.max(seq_prob)
            best_previous_state = np.argmax(seq_prob)

            # Calculate the probability of the most likely sequence leading to this state at time t, including the current observation.
            # Add eps to help with numerical issues.
            V[t, state] = (max_seq_prob + eps) * (observation_probs[state, observed_seq[t]] + eps)

            backpointer[t, state] = best_previous_state

    t = num_obs - 1

    # Initialise the sequence of predicted states
    state_seq = np.zeros(num_obs, dtype=int)

    # Get the most likely final state:
    state_seq[t] = np.argmax(V[t, :])

    # Backtrack until the first observation
    for t in range(len(observed_seq)-1, 0, -1):
        state_seq[t-1] = backpointer[t, state_seq[t]]

    return state_seq
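
The `viterbi` function above multiplies probabilities directly, so the values in the table shrink quickly on long sentences. A common alternative is to add log-probabilities instead. The following is a minimal sketch of that variant, not a replacement for the lab's function: the name `viterbi_log`, the value of `eps` and the toy parameters are all assumptions made for illustration.

```python
import numpy as np


def viterbi_log(observed_seq, start_probs, transition_probs, observation_probs, eps=1e-12):
    """Viterbi decoding in log space; returns the most likely state sequence."""
    # eps avoids log(0) for events that never occurred in training.
    log_start = np.log(start_probs + eps)
    log_trans = np.log(transition_probs + eps)
    log_obs = np.log(observation_probs + eps)

    num_obs = len(observed_seq)
    num_states = len(start_probs)
    scores = np.zeros((num_obs, num_states))
    backpointer = np.zeros((num_obs, num_states), dtype=int)

    scores[0] = log_start + log_obs[:, observed_seq[0]]
    for t in range(1, num_obs):
        # candidates[i, j]: best score ending in state i at t-1, then moving to state j
        candidates = scores[t - 1][:, None] + log_trans
        backpointer[t] = np.argmax(candidates, axis=0)
        scores[t] = np.max(candidates, axis=0) + log_obs[:, observed_seq[t]]

    # Backtrack from the most likely final state.
    state_seq = np.zeros(num_obs, dtype=int)
    state_seq[-1] = np.argmax(scores[-1])
    for t in range(num_obs - 1, 0, -1):
        state_seq[t - 1] = backpointer[t, state_seq[t]]
    return state_seq


# Hypothetical toy HMM with 2 states and 3 observation types.
toy_start = np.array([0.6, 0.4])
toy_trans = np.array([[0.7, 0.3],
                      [0.4, 0.6]])
toy_obs = np.array([[0.5, 0.4, 0.1],
                    [0.1, 0.3, 0.6]])
toy_path = viterbi_log([0, 1, 2], toy_start, toy_trans, toy_obs)
print(toy_path)
```

Because the logarithm is monotonic, taking the maximum over sums of log-probabilities selects the same path as taking the maximum over products of probabilities.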

TODO 2.6: Use the viterbi function to estimate the most likely sequence of states on the test set.

In [None]:
predictions = []
for sentence in tqdm(test_toks_encoded):
    # WRITE YOUR OWN CODE HERE

In [None]:
# Convert the sequence of tag IDs to tag names
predicted_tags = []
for sequence in tqdm(predictions):
    predicted_tags.append(tag_encoder.inverse_transform(sequence))

TODO 2.7: Run the code below to print some example predictions. What kinds of errors does the method make? 

In [None]:
# print some examples:
examples = [2, 334, 4983, 2389]
for eg in examples:
    print(f'Tokens:      {test_toks[eg]}')
    print(f'Gold tag:    {test_tags[eg]}')
    print(f'Predictions: {predicted_tags[eg]}')

In [None]:
# compute accuracy

from sklearn.metrics import accuracy_score

all_predictions = [tag for sentence in predictions for tag in sentence]
all_targets = [tag for sentence in test_tags_encoded for tag in sentence]

acc = accuracy_score(all_targets, all_predictions)
print(f'Accuracy = {acc}')

# 3. Named Entity Recognition (NER) with CRF

Named entity recognition is the task of identifying entities from text, such as people, locations, organisations and times. It is usually modelled as a sequence labelling task. However, many entities span more than one token. To show where these spans start and end, we therefore tag each token as either 'outside' (not part of an entity), 'beginning' or 'inside' (continuation of an entity span). Beginning and inside tags also have the entity type, e.g., "B-Person" or "I-Location". 
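
As a small illustration of this tagging scheme, the sketch below converts entity spans into per-token tags. The helper `spans_to_bio`, the example sentence and the `(start, end, entity_type)` span format with an exclusive end index are all hypothetical.

```python
def spans_to_bio(num_tokens, spans):
    """Convert (start, end, entity_type) spans, with end exclusive, into BIO tags."""
    tags = ['O'] * num_tokens
    for start, end, entity_type in spans:
        tags[start] = f'B-{entity_type}'      # first token of the entity
        for i in range(start + 1, end):
            tags[i] = f'I-{entity_type}'      # continuation tokens
    return tags


tokens = ['Ada', 'Lovelace', 'visited', 'New', 'York', '.']
bio_tags = spans_to_bio(len(tokens), [(0, 2, 'PER'), (3, 5, 'LOC')])
for token, tag in zip(tokens, bio_tags):
    print(f'{token}\t{tag}')
```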

We will learn more about NER next week, so you can continue the lab next week if you prefer. Here we will use an NER dataset to learn about CRFs.

Let's load some NER data consisting of English news articles.

In [None]:
from datasets import load_dataset

cache_dir = "./data_cache"

# The data is already divided into training and test sets.
# Load the training set:
train_dataset = load_dataset(
    "conll2003",
    split="train",
    #ignore_verifications=True,
    cache_dir=cache_dir,
)
print(f"Training dataset with {len(train_dataset)} instances loaded")

In [None]:
# Load the test set:
test_dataset = load_dataset(
    "conll2003",
    split="test",
    #ignore_verifications=True,
    cache_dir=cache_dir,
)
print(f"Test dataset with {len(test_dataset)} instances loaded")

Let's take a look at one of the instances in the training set:

In [None]:
train_dataset[0]

As well as NER tags, the dataset includes POS tags and Chunk tags, which identify the grammatical phrases in the sentence.

The tags are all stored by their indexes. The mapping of the POS tags is:

```
{'"': 0, "''": 1, '#': 2, '$': 3, '(': 4, ')': 5, ',': 6, '.': 7, ':': 8, '``': 9, 'CC': 10, 'CD': 11, 'DT': 12,
 'EX': 13, 'FW': 14, 'IN': 15, 'JJ': 16, 'JJR': 17, 'JJS': 18, 'LS': 19, 'MD': 20, 'NN': 21, 'NNP': 22, 'NNPS': 23,
 'NNS': 24, 'NN|SYM': 25, 'PDT': 26, 'POS': 27, 'PRP': 28, 'PRP$': 29, 'RB': 30, 'RBR': 31, 'RBS': 32, 'RP': 33,
 'SYM': 34, 'TO': 35, 'UH': 36, 'VB': 37, 'VBD': 38, 'VBG': 39, 'VBN': 40, 'VBP': 41, 'VBZ': 42, 'WDT': 43,
 'WP': 44, 'WP$': 45, 'WRB': 46}
```

The mapping from NER tags to indexes is:

```
{'O': 0, 'B-PER': 1, 'I-PER': 2, 'B-ORG': 3, 'I-ORG': 4, 'B-LOC': 5, 'I-LOC': 6, 'B-MISC': 7, 'I-MISC': 8}
```
 

In [None]:
ner_tag_mapping = {0: 'O', 1:'B-PER', 2:'I-PER', 3:'B-ORG', 4:'I-ORG', 5:'B-LOC', 6:'I-LOC', 7:'B-MISC', 8:'I-MISC'}

Let's put the NER data in the right format for NLTK's CRFTagger class:

In [None]:
train_set = [list(zip(s['tokens'], [ner_tag_mapping[tok] for tok in s['ner_tags']])) for s in train_dataset][:-1]
test_set = [list(zip(s['tokens'], [ner_tag_mapping[tok] for tok in s['ner_tags']])) for s in test_dataset][:-1]
test_tokens = [s['tokens'] for s in test_dataset][:-1]
test_tags = [[ner_tag_mapping[tok] for tok in s['ner_tags']] for s in test_dataset][:-1]

Now, let's train a CRF tagger on our training set. The method you need to use from NLTK is the [train method of the conditional random field (CRF)](https://www.nltk.org/_modules/nltk/tag/crf.html). You need to call the constructor with default arguments, then the train() function.

TODO 3.1: Write a function to train and return a CRF named entity recogniser.

In [None]:
import nltk

# Train a CRF NER tagger
def train_CRF_NER_tagger(train_set):
    ### WRITE YOUR OWN CODE HERE


tagger = train_CRF_NER_tagger(train_set)
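
A CRF scores tags using features of the tokens, and one of the outcomes of this lab is to specify those features. The sketch below shows what a hand-written feature function might look like. The function name and the particular features are hypothetical. It assumes that the tagger accepts a function taking the list of tokens and a position and returning a list of string features, which is how the NLTK source linked above describes the `feature_func` argument of `CRFTagger`; please check that source before relying on it.

```python
def simple_ner_features(tokens, idx):
    """Return a list of string features for the token at position idx."""
    token = tokens[idx]
    features = [f'WORD_{token.lower()}', f'SUF_{token[-3:].lower()}']

    # Word-shape features, which are often informative for names.
    if token[:1].isupper():
        features.append('CAPITALISED')
    if token.isupper():
        features.append('ALL_CAPS')
    if any(ch.isdigit() for ch in token):
        features.append('HAS_DIGIT')

    # Context features from the neighbouring tokens.
    if idx > 0:
        features.append(f'PREV_{tokens[idx - 1].lower()}')
    else:
        features.append('BOS')  # beginning of sentence
    if idx < len(tokens) - 1:
        features.append(f'NEXT_{tokens[idx + 1].lower()}')
    else:
        features.append('EOS')  # end of sentence
    return features


example_features = simple_ner_features(['EU', 'rejects', 'German', 'call'], 2)
print(example_features)

# Assumed usage (not run here, requires python-crfsuite):
# custom_tagger = nltk.tag.CRFTagger(feature_func=simple_ner_features)
```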

Get some predictions from the tagger:

In [None]:
predicted_tags = tagger.tag_sents(test_tokens)

Let's see how well the tagger is performing. In NER, we evaluate performance by finding correctly matched entities, rather than correctly tagged tokens. Only an exact entity match counts as correct. We therefore compare the predicted entity spans with the gold-labelled entity spans in the test set, count the true positives, false positives and false negatives, and compute precision, recall and F1 score from those counts.
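
As a minimal sketch of span-level scoring for a single entity type, assume each span is a `(start, end, sentence_index)` tuple. The function name `precision_recall_f1` and the toy spans are hypothetical. A predicted span is a true positive only if exactly the same tuple appears in the gold set.

```python
def precision_recall_f1(gold_spans, predicted_spans):
    """Exact-match precision, recall and F1 for two collections of spans."""
    gold = set(gold_spans)
    predicted = set(predicted_spans)
    true_positives = len(gold & predicted)

    precision = true_positives / len(predicted) if predicted else 0.0
    recall = true_positives / len(gold) if gold else 0.0
    if precision + recall == 0:
        return precision, recall, 0.0
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


# Toy example: the second predicted span has the wrong end index,
# so it counts as a false positive and the gold span as a false negative.
toy_gold = [(0, 2, 0), (3, 4, 0), (1, 2, 1)]
toy_pred = [(0, 2, 0), (3, 5, 0), (1, 2, 1), (4, 5, 1)]
prec, rec, f1 = precision_recall_f1(toy_gold, toy_pred)
print(f'precision={prec:.3f} recall={rec:.3f} f1={f1:.3f}')
```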

The code below contains a function that extracts a list of spans from the tagged sentences. The next function calls `extract_spans()` and computes the precision, recall and F1 scores. However, the function is incomplete.

Run the `cal_span_level_f1()` function below to compute span-level F1 scores for the predictions. Have a look at the results. Which types of entity are being recognised well and which very poorly?

In [None]:
import numpy as np  # needed for np.mean in cal_span_level_f1


def extract_spans(tagged_sents):
    """
    Extract a list of tagged spans for each named entity type, 
    where each span is represented by a tuple containing the 
    start token index, the end token index (exclusive) and the sentence index.
    
    returns: a dictionary containing a list of spans for each entity type.
    """
    spans = {}
        
    for sidx, sent in enumerate(tagged_sents):
        start = -1
        entity_type = None
        for i, (tok, lab) in enumerate(sent):
            if lab.startswith('B-'):
                # A new entity begins here, so first close any entity that is still open
                if start >= 0:
                    spans.setdefault(entity_type, []).append((start, end, sidx))
                start = i
                end = i + 1
                entity_type = lab[2:]
            elif lab.startswith('I-'):
                end = i + 1
            elif lab == 'O' and start >= 0:
                spans.setdefault(entity_type, []).append((start, end, sidx))
                start = -1
        # Sometimes an entity runs to the last token in the sentence, so we still have to add the span to the list
        if start >= 0:
            spans.setdefault(entity_type, []).append((start, end, sidx))
            
    return spans


def cal_span_level_f1(test_sents, test_sents_with_pred):
    # get a list of spans from the test set labels
    gold_spans = extract_spans(test_sents)

    # get a list of spans predicted by our tagger
    pred_spans = extract_spans(test_sents_with_pred)
    
    # compute the metrics for each class:
    f1_per_class = []
    
    ne_types = gold_spans.keys()  # get the list of named entity types (not the tags)
    
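    for ne_type in ne_types:
        # spans predicted for this type (empty if the tagger never predicted this type)
        pred_spans_for_type = pred_spans.get(ne_type, [])

        # count true positives and false positives among the predicted spans
        true_pos = 0
        false_pos = 0
        
        for span in pred_spans_for_type:
            if span in gold_spans[ne_type]:
                true_pos += 1
            else:
                false_pos += 1
                
        # count the gold spans that the tagger missed
        false_neg = 0
        for span in gold_spans[ne_type]:
            if span not in pred_spans_for_type:
                false_neg += 1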
                
        if true_pos + false_pos == 0:
            precision = 0
        else:
            precision = true_pos / float(true_pos + false_pos)
            
        if true_pos + false_neg == 0:
            recall = 0
        else:
            recall = true_pos / float(true_pos + false_neg)
        
        if precision + recall == 0:
            f1 = 0
        else:
            f1 = 2 * precision * recall / (precision + recall)
            
        f1_per_class.append(f1)
        print(f'F1 score for class {ne_type} = {f1}')
        
    print(f'Macro-average f1 score = {np.mean(f1_per_class)}')

cal_span_level_f1(test_set, predicted_tags)
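
To see how strict exact matching is, here is a minimal sketch of the same calculation on invented spans. It is a simplification of the function above: it assumes a single entity type and stores the spans in sets rather than lists, and all of the span values are hypothetical.

```python
# Toy example of exact-match span scoring; the spans below are invented for illustration.
# Each span is (start, end, sentence_index), with `end` exclusive.
gold = {(0, 2, 0), (3, 4, 0), (1, 2, 1)}
pred = {(0, 2, 0),   # exact match: true positive
        (3, 5, 0),   # overlaps a gold span but the end boundary is wrong: false positive
        (4, 5, 1)}   # no gold span here at all: false positive

true_pos = len(gold & pred)    # spans found in both sets
false_pos = len(pred - gold)   # predicted spans with no exact gold match
false_neg = len(gold - pred)   # gold spans the predictions missed

precision = true_pos / (true_pos + false_pos)
recall = true_pos / (true_pos + false_neg)
f1 = 2 * precision * recall / (precision + recall)

print(f'TP={true_pos}, FP={false_pos}, FN={false_neg}')
print(f'precision={precision:.3f}, recall={recall:.3f}, F1={f1:.3f}')
```

The prediction `(3, 5, 0)` is nearly right, yet it is penalised twice: once as a false positive, and once because the gold span `(3, 4, 0)` it overlaps is counted as a false negative.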

We can try to help the CRF tagger by adding some more features. Part-of-speech tags often provide useful information for identifying entities. The code below defines a modified CRFTagger class that overrides the ```_get_features()``` method, which extracts the features from the tokens. 

TODO 3.2: Add in the previous and next words as features. Be careful with the start and end of the sequence where there is no previous or next word.
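
The short sketch below shows why the sequence boundaries need care. It uses a hypothetical toy sentence and only demonstrates how Python list indexing behaves at the first and last positions. It is not a feature extractor.

```python
# Hypothetical toy sentence, used only to show how Python indexing behaves at the boundaries.
tokens = ['Paris', 'is', 'lovely']

# At the first position, idx - 1 is -1, which Python treats as "last element"
# rather than raising an error, so a naive lookup silently returns the wrong word.
idx = 0
wrapped = tokens[idx - 1]
print(wrapped)  # 'lovely', which is not a previous word

# At the last position, idx + 1 is out of range and raises an IndexError.
idx = len(tokens) - 1
try:
    tokens[idx + 1]
    raised = False
except IndexError:
    raised = True
print(raised)
```

The first case is the more dangerous of the two, because the code runs without complaint and quietly adds a misleading feature.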

In [None]:
import re, unicodedata

class CustomCRFTagger(nltk.tag.CRFTagger):
    _current_tokens = None
    
    def _get_features(self, tokens, idx):
            """
            Extract basic features about this word including
                - Current word
                - is it capitalized?
                - Does it have punctuation?
                - Does it have a number?
                - Suffixes up to length 3

            Note: we might also include features over the previous word, next word, etc.

            :return: a list which contains the features
            :rtype: list(str)
            """
            token = tokens[idx]

            feature_list = []

            if not token:
                return feature_list

            # Capitalization
            if token[0].isupper():
                feature_list.append("CAPITALIZATION")

            # Number
            if re.search(self._pattern, token) is not None:
                feature_list.append("HAS_NUM")

            # Punctuation
            punc_cat = {"Pc", "Pd", "Ps", "Pe", "Pi", "Pf", "Po"}
            if all(unicodedata.category(x) in punc_cat for x in token):
                feature_list.append("PUNCTUATION")

            # Suffix up to length 3
            if len(token) > 1:
                feature_list.append("SUF_" + token[-1:])
            if len(token) > 2:
                feature_list.append("SUF_" + token[-2:])
            if len(token) > 3:
                feature_list.append("SUF_" + token[-3:])

                
            # Current word
            feature_list.append("WORD_" + token)
            
            ### WRITE YOUR OWN CODE HERE ###

            
            ####

            return feature_list
                

TODO 3.3: Train your custom CRF tagger, then test it below. How does it compare to the default tagger? Why did adding the new features change the performance in this way? 

The results show how important it is to understand your choice of features.

In [None]:
# Train a CRF NER tagger
def train_CustomCRF_NER_tagger(train_set):
    ### WRITE YOUR OWN CODE HERE


tagger = train_CustomCRF_NER_tagger(train_set)

In [None]:
predicted_tags = tagger.tag_sents(test_tokens)
cal_span_level_f1(test_set, predicted_tags)

OPTIONAL TODO 3.4: POS tags can be used as features for tasks like NER. Complete the code below to define another custom CRF tagger that also includes POS tags as features. 

In [None]:
# *** Improve the CRF NER tagger using parts of speech (see lab 5) as additional features.
class CRFTaggerWithPOS(CustomCRFTagger):
    _current_tokens = None
    
    def _get_features(self, tokens, index):
        """
        Extract the features for a token and append the POS tag as an additional feature.
        """
        basic_features = super()._get_features(tokens, index)
        
        # Get the pos tags for the current sentence and save it
        if tokens != self._current_tokens:
            self._pos_tagged_tokens = nltk.pos_tag(tokens)
            self._current_tokens = tokens
            
            
        ### WRITE YOUR OWN CODE HERE

        ###
        
        return basic_features

In [None]:
# Train a CRF NER tagger
def train_CRF_NER_tagger_with_POS(train_set):
    ### WRITE YOUR OWN CODE HERE


tagger = train_CRF_NER_tagger_with_POS(train_set)

In [None]:
predicted_tags = tagger.tag_sents(test_tokens)
cal_span_level_f1(test_set, predicted_tags)