In [16]:
import nltk
nltk.download('brown')
nltk.download('universal_tagset')

[nltk_data] Downloading package brown to /root/nltk_data...
[nltk_data]   Package brown is already up-to-date!
[nltk_data] Downloading package universal_tagset to /root/nltk_data...
[nltk_data]   Package universal_tagset is already up-to-date!


True

In [0]:
import numpy as np
from nltk.corpus import brown

def prepare_nltk_tagged_sent(sent):
    return tuple(zip(*sent))

sents = np.random.permutation(brown.tagged_sents(tagset='universal'))

TRAINING_SIZE = 1000
TEST_SIZE = 10

training_data = list(map(prepare_nltk_tagged_sent, sents[:TRAINING_SIZE]))

test_data = list(map(prepare_nltk_tagged_sent,
                    sents[TRAINING_SIZE:TRAINING_SIZE+TEST_SIZE]))


In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LSTMPoSTagger(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, word2idx, tag2idx):
        super().__init__()

        # Params
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = len(word2idx)
        self.tagset_size = len(tag2idx)

        # Units
        self.word_embeddings = nn.Embedding(
            num_embeddings = self.vocab_size,
            embedding_dim = self.embedding_dim
        )
        self.lstm = nn.LSTM(
            input_size = self.embedding_dim,
            hidden_size = self.hidden_dim
        )
        self.hidden2tag = nn.Linear(
            in_features = self.hidden_dim,
            out_features = self.tagset_size
        )

    def forward(self, sentence):
        sentlen = len(sentence)

        embeds = self.word_embeddings(sentence).view(sentlen, 1, -1)
        lstm_out, _ = self.lstm(embeds)

        tag_space = self.hidden2tag(lstm_out.view(sentlen, -1))
        tag_scores = F.log_softmax(tag_space, dim=1)
        return tag_scores

class Ent2Idx:
    def __init__(self, e2idict):
        self.e2i = e2idict
        self.i2e = sorted(e2idict, key = lambda x: e2idict[x], reverse=True)

    def __len__(self):
        return len(self.e2i)

    def index(self, item):
        return self.e2i[item]

    def entity(self, idx):
        return self.i2e[idx]

    def indices(self, seq):
        e2i = self.e2i
        return torch.tensor([e2i[x] for x in seq], dtype=torch.long)

    def entities(self, vec):
        i2e = self.i2e
        return list(map(lambda x: i2e[x], vec))


def build_seq2vec_indices(dataset):
  word2idx = dict()
  tag2idx = dict()

  for sent, tags in dataset:
   for word in sent:
     if word not in word2idx:
       word2idx[word] = len(word2idx)
   for tag in tags:
     if tag not in tag2idx:
       tag2idx[tag] = len(tag2idx)
  return (Ent2Idx(word2idx), Ent2Idx(tag2idx))


EMBED_SIZE = 32
HID_DIM = 32
# Use both training and test set to build indices
word2idx, tag2idx = build_seq2vec_indices(training_data + test_data)

model = LSTMPoSTagger(EMBED_SIZE, HID_DIM, word2idx, tag2idx).cuda()
# model.cuda()

In [28]:
import torch.optim as optim

loss_function = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
EPOCHS = 150

for epoch in range(EPOCHS):
  print(epoch, end=' ')
  for sentence, tags in training_data:
      model.zero_grad()
      tag_scores = model(word2idx.indices(sentence).cuda())
      loss = loss_function(tag_scores, tag2idx.indices(tags).cuda())
      loss.backward()
      optimizer.step()

0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 

In [34]:

def print_diff(text, expected, got):
  text = ["Text:"] + list(text)
  expected = ["Expected:"] + list(expected)
  got = ["Got:"] + list(got)

  column_sizes = [max(map(len, (t, e, g)))
                  for t, e, g in zip(text, expected, got)]

  def print_row(row):
    for t, sz in zip(row, column_sizes):
      print("{:{}}".format(t, sz), end=' ')
    print()

  for row in (text, expected, got):
    print_row(row)

  total = len(expected)
  correct = sum(map(lambda x: x[0] == x[1], zip(expected, got)))

  print("Correct: {}/{}".format(correct, total))
  print()

with torch.no_grad():
  for sent, tags in test_data:
    raw = model(word2idx.indices(sent).cuda())
    predicted_tags = tag2idx.entities(raw.argmax(1))
    print_diff(sent, tags, predicted_tags)

Text:     It   may  appear that we   were cruel and  callous ,    but  no  one  had  time to  spend sympathizing with poor Isaac --   except the Reverend .    
Expected: PRON VERB VERB   ADP  PRON VERB ADJ   CONJ ADJ     .    CONJ DET NOUN VERB NOUN PRT VERB  VERB         ADP  ADJ  NOUN  .    ADP    DET NOUN     .    
Got:      .    CONJ CONJ   NOUN .    CONJ CONJ  VERB CONJ    PRON VERB ADJ ADP  CONJ ADP  X   CONJ  DET          NOUN DET  ADP   PRON ADP    ADJ ADP      PRON 
Correct: 1/27

Text:     The money's here ,    all of   it   ''   .    
Expected: DET PRT     ADV  .    PRT ADP  PRON .    .    
Got:      ADJ ADP     NUM  PRON X   NOUN .    PRON PRON 
Correct: 0/10

Text:     This he   failed to  do   ,    asserting that he   did  not know it   to  be   in   his file .    
Expected: DET  PRON VERB   PRT VERB .    VERB      ADP  PRON VERB ADV VERB PRON PRT VERB ADP  DET NOUN .    
Got:      ADJ  .    NUM    X   CONJ PRON CONJ      NOUN .    CONJ NUM CONJ .    X   CONJ NOUN ADJ ADJ

In [72]:
from math import log, inf

with torch.no_grad():
  print("Closer to 1 - better")
  for sent, tags in test_data:
    prediction = model(word2idx.indices(sent).cuda())
    loss = loss_function(prediction, tag2idx.indices(tags).cuda())
    print(loss.neg().exp().item())

Closer to 1 - better
0.4138205051422119
0.4111807346343994
0.4694533050060272
0.21486970782279968
0.545297384262085
0.19068923592567444
0.8440349698066711
0.11418286710977554
0.31773969531059265
0.1121540442109108
