In [1]:
import json, re, os, nltk, itertools, sklearn
import numpy as np
from nltk.corpus import treebank
from collections import Counter
import sklearn.metrics as metr

Get the data

In [2]:
corpus = list(treebank.tagged_words())

Filter for uninteresting tags

In [3]:
corpus = [(w, t) for w, t in corpus if t not in ['-NONE-', '-RRB-', '-LRB-']]

In [4]:
tags = [t for w, t in corpus]
words = [w for w, t in corpus]

In [6]:
print('{0} words, tag pairs in the dataset'.format(len(corpus)))

93838 words, tag pairs in the dataset


In [7]:
tag_distribution = Counter([t for w, t in corpus])
print(len(tag_distribution.keys()), 'tags') 

43 tags


In [8]:
tag_distribution.most_common(5)

[('NN', 13166), ('IN', 9857), ('NNP', 9410), ('DT', 8165), ('NNS', 6047)]

# Parameter Estimation

In [9]:
tagset = list(set([t for w, t in corpus]))
print(tagset)

['IN', '.', 'EX', 'RBS', ':', 'VBN', 'DT', 'NNP', 'TO', 'WRB', 'PRP', 'UH', 'JJR', 'CD', 'POS', '#', 'PDT', 'NNS', 'WDT', 'PRP$', 'WP', 'VB', 'CC', 'NNPS', 'VBZ', 'VBG', 'SYM', 'VBP', 'MD', '$', 'JJS', 'RBR', '``', 'WP$', 'JJ', 'VBD', ',', "''", 'RB', 'NN', 'RP', 'LS', 'FW']


In [15]:
index_to_vocabulary = list(set(words))

In [10]:
class Transition(dict):
    def __init__(self, tags, delta=0.5):
        self.delta = delta
        # Add a dummy tag START after each end-of-sentence tag '.'
        t_with_start = [[t] if t != '.' else [t, 'START'] for t in tags]
        self.tags = [t for tag in t_with_start for t in tag]
        # Find the conditional probability (a | b) which is the number of times a follows b. 
        self.bigram = Counter([(a, b) for b, a in zip(self.tags[:-1], self.tags[1:])])
        for key in self.bigram.keys():
            self[key] = self.bigram[key]
    
    def __call__(self, key):
        if key in self:
            return self[key] + self.delta
        else:
            return self.delta

In [11]:
class Emission(dict):
    def __init__(self, words, tags, sigma=0.1):
        self.sigma = sigma
        self.bigram = Counter([(a, b) for a, b in zip(words, tags)])
        for key in self.bigram.keys():
            self[key] = self.bigram[key]
            
    def __call__(self, key):
        if key in self:
            return (self[key] + self.sigma)
        else:
            return self.sigma

In [12]:
transition = Transition(tags, 0.5)

In [13]:
sorted([((t, 'VB'), transition((t, 'VB'))) for t in tagset], key = lambda x: -x[1])[0:10]

[(('DT', 'VB'), 607.5),
 (('IN', 'VB'), 283.5),
 (('VBN', 'VB'), 238.5),
 (('JJ', 'VB'), 219.5),
 (('NN', 'VB'), 171.5),
 (('NNS', 'VB'), 115.5),
 (('PRP$', 'VB'), 101.5),
 (('TO', 'VB'), 94.5),
 (('RB', 'VB'), 93.5),
 (('PRP', 'VB'), 91.5)]

'VB' usually followed by a determiner

In [16]:
emission = Emission(words, tags, 0.1)

In [17]:
sorted([((w, 'NN'), emission((w, 'NN'))) for w in index_to_vocabulary], key = lambda x: -x[1])[0:10]

[(('%', 'NN'), 445.1),
 (('company', 'NN'), 260.1),
 (('year', 'NN'), 212.1),
 (('market', 'NN'), 176.1),
 (('trading', 'NN'), 144.1),
 (('stock', 'NN'), 136.1),
 (('president', 'NN'), 133.1),
 (('program', 'NN'), 127.1),
 (('share', 'NN'), 116.1),
 (('government', 'NN'), 83.1)]

## Train a tagger

In [18]:
start = 0
sequences = []
for i in range(len(corpus)):
    if corpus[i][1] == '.':
        sequences.append(corpus[start:i+1])
        start = i+1

In [41]:
n = len(sequences)
X = [[w for w, tag in sequence] for sequence in sequences]
Y = [[tag for w, tag in sequence] for sequence in sequences]
print(n)

3874


In [32]:
def tagger(sequence, states, transition, emission):
    T = len(sequence)
    N = len(states)
    treillis = np.zeros((N, T))
    max_came_from = {}
    
    # Initialization
    for s in range(N): 
        treillis[s, 0] = transition((states[s], 'START')) * emission((sequence[0], states[s]))
        max_came_from[s, 0] = 'START'
    
    # Recursion
    for t in range(1, T):
        for s in range(N):
            em = emission((sequence[t], states[s]))
            inputs = [transition((states[s], states[k])) * em * treillis[k, t-1] for k in range(N)]
            treillis[s, t] = np.max(inputs)
            max_came_from[s, t] = (np.argmax(inputs), t-1)
    
    # Reconstruct the path
    prev = (np.argmax(treillis[:, -1]), T-1)
    best_sequence = []
    while prev != 'START':
        best_sequence.append(states[prev[0]])
        prev = max_came_from[prev]
    return best_sequence[::-1]

In [25]:
arr = np.asarray([[4, 2, 9], [1, 6, 8], [12, 1, 1]])

In [26]:
arr

array([[ 4,  2,  9],
       [ 1,  6,  8],
       [12,  1,  1]])

In [31]:
np.argmax(arr[:,0])

2

In [37]:
tagger(X[0][:-5], tagset, transition, emission)

['NNP', 'NNP', ',', 'CD', 'NNS', 'JJ', ',', 'MD', 'VB', 'DT', 'NN', 'IN', 'DT']

## Evaluation

In [40]:
folds = sklearn.cross_validation.KFold(n, n_folds=5, shuffle=True)

In [88]:
preds = []
for train_index, test_index in folds:
    X_train = [X[i] for i in train_index]
    y_train = [Y[i] for i in train_index]
    X_test = [X[i] for i in test_index]
    y_test = [Y[i] for i in test_index]
    
    transition = Transition(itertools.chain(*y_train), 0.5)
    emission = Emission(itertools.chain(*X_train), itertools.chain(*y_train), 0.1)
    
    true_pred = [(y, tagger(x, tagset, transition, emission)) for x, y in zip(X_test, y_test)]
    preds.append(true_pred)



In [83]:
def metrics(true_pred):
    rearranged = [(a, b) for true, pred in true_pred for a, b in zip(true, pred)]
    counts = Counter(rearranged)
    prec = {}
    rec = {}
    acc = 0
    for tag1 in tagset:
        TP = counts[(tag1, tag1)]
        FP = sum([counts[(tag2, tag1)] for tag2 in tagset]) - TP
        FN = sum([counts[(tag1, tag2)] for tag2 in tagset]) - TP
        try:
            prec[tag1] = TP / (TP + FP)
        except ZeroDivisionError:
            prec[tag1] = 1
        try:
            rec[tag1] = TP / (TP + FN)
        except ZeroDivisionError:
            rec[tag1] = 1
        acc += TP
    return [acc / len(rearranged), prec, rec]

In [89]:
metr = [metrics(true_pred) for true_pred in preds]

In [91]:
mean_accuracy = np.mean([acc for acc, prec, rec in metr])
print('Mean accuracy: %f' % mean_accuracy)

Mean accuracy: 0.867019


In [98]:
mean_precision = np.mean([np.mean(list(prec.values())) for acc, prec, rec in metr])
print('Mean precision: %f' % mean_precision)

Mean precision: 0.941576


In [99]:
mean_recall = np.mean([np.mean(list(rec.values())) for acc, prec, rec in metr])
print('Mean recall: %f' % mean_recall)

Mean recall: 0.605487


In [100]:
ts = ["NN", "VB", "JJR", "NNP"]

In [101]:
for t in ts:
    print('%s has precision %f, recall %f' % (t, np.mean([prec[t] for acc, prec, rec in metr]), np.mean([rec[t] for acc, prec, rec in metr])))

NN has precision 0.722528, recall 0.957182
VB has precision 0.937887, recall 0.791853
JJR has precision 0.771601, recall 0.205762
NNP has precision 0.866466, recall 0.910429


## Optimization

In [110]:
def metrics2(true_pred):
    acc, prec, rec = metrics(true_pred)
    return acc, np.mean(list(prec.values())), np.mean(list(rec.values()))

In [112]:
sigmas = [0.001, 0.01, 0.1, 1, 10, 100]
deltas = [0.005, 0.05, 0.5, 5, 50, 500]

In [113]:
perfs = []

In [114]:
j = 0
for train_index, test_index in folds:
    X_train = [X[i] for i in train_index]
    y_train = [Y[i] for i in train_index]
    X_test = [X[i] for i in test_index]
    y_test = [Y[i] for i in test_index]
    
    for sigma in sigmas:
        for delta in deltas:
            transition = Transition(itertools.chain(*y_train), delta)
            emission = Emission(itertools.chain(*X_train), itertools.chain(*y_train), sigma)
    
            true_pred = [(y, tagger(x, tagset, transition, emission)) for x, y in zip(X_test, y_test)]
            
            perfs.append((sigma, delta, metrics2(true_pred)))
            
            print('%d done.' % j)
            
            j += 1

0 done.
1 done.
2 done.
3 done.
4 done.
5 done.
6 done.
7 done.
8 done.
9 done.
10 done.
11 done.
12 done.
13 done.
14 done.
15 done.
16 done.
17 done.
18 done.
19 done.
20 done.
21 done.
22 done.
23 done.
24 done.
25 done.
26 done.
27 done.
28 done.
29 done.
30 done.
31 done.
32 done.
33 done.
34 done.
35 done.
36 done.
37 done.
38 done.
39 done.
40 done.
41 done.
42 done.
43 done.
44 done.
45 done.
46 done.
47 done.
48 done.
49 done.
50 done.
51 done.
52 done.
53 done.
54 done.
55 done.
56 done.
57 done.
58 done.
59 done.
60 done.
61 done.
62 done.
63 done.
64 done.
65 done.
66 done.
67 done.
68 done.
69 done.
70 done.
71 done.
72 done.
73 done.
74 done.
75 done.
76 done.
77 done.
78 done.
79 done.
80 done.
81 done.
82 done.
83 done.
84 done.
85 done.
86 done.
87 done.
88 done.
89 done.
90 done.
91 done.
92 done.
93 done.
94 done.
95 done.
96 done.
97 done.
98 done.
99 done.
100 done.
101 done.
102 done.
103 done.
104 done.
105 done.
106 done.
107 done.
108 done.
109 done.
110 done.




In [115]:
perfs

[(0.001, 0.005, (0.9086522997748472, 0.9637471083347966, 0.75421292436970089)),
 (0.001, 0.05, (0.9089739466066259, 0.96361559133511432, 0.75573491245473601)),
 (0.001, 0.5, (0.9092419856331082, 0.96283441661295421, 0.75889626399700849)),
 (0.001, 5, (0.9098316714913691, 0.96119091651371047, 0.77426509084181572)),
 (0.001, 50, (0.9098852792966656, 0.93529529313050686, 0.78804027060899862)),
 (0.001, 500, (0.9040420285193524, 0.91418427340072128, 0.79877464553060162)),
 (0.01, 0.005, (0.9008791680068619, 0.96345147947662935, 0.7046560628907903)),
 (0.01, 0.05, (0.9009327758121582, 0.96346125946927696, 0.70470121981108835)),
 (0.01, 0.5, (0.9014688538651228, 0.96302815547564802, 0.70768825916982836)),
 (0.01, 5, (0.9037203816875737, 0.96282225143620959, 0.732476084212066)),
 (0.01, 50, (0.9082234373324756, 0.93654102923930838, 0.77919796074245729)),
 (0.01, 500, (0.9037739894928701, 0.91414710117111042, 0.79849831912979219)),
 (0.1, 0.005, (0.8709124048461456, 0.94108040670221516, 0.6440

In [120]:
[rec for sigma, delta, (acc, prec, rec) in perfs]

[0.75421292436970089,
 0.75573491245473601,
 0.75889626399700849,
 0.77426509084181572,
 0.78804027060899862,
 0.79877464553060162,
 0.7046560628907903,
 0.70470121981108835,
 0.70768825916982836,
 0.732476084212066,
 0.77919796074245729,
 0.79849831912979219,
 0.64402744077429885,
 0.64408866119664776,
 0.64469160766032008,
 0.65404475211826019,
 0.72966775989789323,
 0.79528788961263475,
 0.51765960233352171,
 0.51773045631214443,
 0.5184423245842501,
 0.52579457160626963,
 0.58833386039422708,
 0.74865889044191236,
 0.32553388377209203,
 0.32551588391609088,
 0.3258187880404847,
 0.32891123853843962,
 0.35793373897646874,
 0.54950798487910602,
 0.20567483803985925,
 0.20568383796785983,
 0.20568383796785983,
 0.20573661998483087,
 0.20779960274117579,
 0.25482624854428609,
 0.68161532537342595,
 0.68201213210797096,
 0.68437143059892347,
 0.71196266030091138,
 0.7388330000692066,
 0.74679185723948072,
 0.62691673157204419,
 0.62692577349037759,
 0.63025702113376625,
 0.6578456034819