## Task
Neural network models for named entity recognition on the CoNLL-2003 corpus in IOB format.

Chosen model: model with at least two convolutional layers and dense layers (or CRF). 

### Download data

In [1]:
!wget https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.train -P ./data/
!wget https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.testb -P ./data/
!wget https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.testa -P ./data/

--2020-05-19 16:08:29--  https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.train
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.244.133
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.244.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3281528 (3,1M) [text/plain]
Saving to: ‘./data/eng.train.2’


2020-05-19 16:08:37 (425 KB/s) - ‘./data/eng.train.2’ saved [3281528/3281528]

--2020-05-19 16:08:37--  https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.testb
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.244.133
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.244.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 748096 (731K) [text/plain]
Saving to: ‘./data/eng.testb.2’


2020-05-19 16:08:38 (1,18 MB/s) - ‘./data/eng.testb.2’ saved [748096/748096]

--2020-05-19 16:08:38--  htt

In [38]:
# !pip install allennlp

Collecting allennlp
[?25l  Downloading https://files.pythonhosted.org/packages/bb/bb/041115d8bad1447080e5d1e30097c95e4b66e36074277afce8620a61cee3/allennlp-0.9.0-py3-none-any.whl (7.6MB)
[K     |████████████████████████████████| 7.6MB 13.4MB/s 
Collecting unidecode
[?25l  Downloading https://files.pythonhosted.org/packages/d0/42/d9edfed04228bacea2d824904cae367ee9efd05e6cce7ceaaedd0b0ad964/Unidecode-1.1.1-py2.py3-none-any.whl (238kB)
[K     |████████████████████████████████| 245kB 58.0MB/s 
[?25hCollecting tensorboardX>=1.2
[?25l  Downloading https://files.pythonhosted.org/packages/35/f1/5843425495765c8c2dd0784a851a93ef204d314fc87bcc2bbb9f662a3ad1/tensorboardX-2.0-py2.py3-none-any.whl (195kB)
[K     |████████████████████████████████| 204kB 52.5MB/s 
Collecting numpydoc>=0.8.0
  Downloading https://files.pythonhosted.org/packages/b0/70/4d8c3f9f6783a57ac9cc7a076e5610c0cc4a96af543cafc9247ac307fbfe/numpydoc-0.9.2.tar.gz
Collecting spacy<2.2,>=2.1.0
[?25l  Downloading https://files.py

Read data in Conll-2003 format

In [2]:
from allennlp.data.dataset_readers.conll2003 import Conll2003DatasetReader

In [3]:
reader = Conll2003DatasetReader()
train_instances = reader.read('data/eng.train')
dev_instances = reader.read('data/eng.testa')
test_instances = reader.read('data/eng.testb')


14041it [00:01, 11247.61it/s]
3250it [00:00, 8317.74it/s]
3453it [00:00, 8746.67it/s] 


In [4]:
len(train_instances), len(dev_instances), len(test_instances)

(14041, 3250, 3453)

In [5]:
from allennlp.models import Model
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.training.metrics import SpanBasedF1Measure
from allennlp.training.trainer import Trainer
from allennlp.nn import util as nn_util 

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_sequence

In [7]:
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from typing import Dict, Optional, List

In [8]:
from allennlp.data.vocabulary import Vocabulary   
from allennlp.data.iterators import BucketIterator

In [9]:
torch.manual_seed(11)

<torch._C.Generator at 0x7ff32b083a70>

In [10]:
USE_GPU = torch.cuda.is_available()

### Model
The model uses pretrained Glove word embeddings, char-CNN with 3 filter sizes for char-level embeddings and 3-layer CNN for contextual word embeddings (based on the model proposed in https://arxiv.org/pdf/1707.05928.pdf).

In [11]:
class CharCNNEmbedder(nn.Module):
    def __init__(self,
             char_map: Dict[str, int],
             char_embedding_dim: int,
             char_n_filters: int,
             char_filter_sizes: List[int],
             char_out_dim: int,
             dropout: float) -> None:
        
        super().__init__()

        self.char_map = char_map
        self.char_embedder = nn.Embedding(len(char_map), char_embedding_dim, padding_idx=char_map['<pad>'])
        self.conv_0 = nn.utils.weight_norm(nn.Conv1d(char_embedding_dim, char_n_filters, char_filter_sizes[0]))
        self.conv_1 = nn.utils.weight_norm(nn.Conv1d(char_embedding_dim, char_n_filters, char_filter_sizes[1]))
        self.conv_2 = nn.utils.weight_norm(nn.Conv1d(char_embedding_dim, char_n_filters, char_filter_sizes[2]))
        self.linear = nn.Linear(in_features=len(char_filter_sizes) * char_n_filters, out_features = char_out_dim)
        self.dropout = nn.Dropout(dropout)

    def _tokens_to_char_indices(self, token, max_token_len=20):
        token = str(token)
        if len(token) <  max_token_len:
            token_char_indices = [self.char_map[ch] if ch in self.char_map.keys() else self.char_map['<unk>'] for ch in token] + \
                                [self.char_map['<pad>']]*(max_token_len - len(token))
        else:
            token_char_indices = [self.char_map[ch] if ch in self.char_map.keys() else self.char_map['<unk>'] for ch in token[:max_token_len]]
        return token_char_indices

    def forward(self, tokens_sequences):
        tokens_tensor, tokens_lengths = pad_packed_sequence(pack_sequence([torch.LongTensor([self._tokens_to_char_indices(token) for token in item['words']]) 
                                                      for item in tokens_sequences], enforce_sorted=False))
        tokens_tensor = nn_util.move_to_device(tokens_tensor.permute(1,0,2), 0 if USE_GPU else -1)
        embedded = self.char_embedder(tokens_tensor.reshape(-1,tokens_tensor.size(2)))

        conved_0 = F.relu(self.conv_0(embedded.permute(0,2,1)))
        conved_1 = F.relu(self.conv_1(embedded.permute(0,2,1)))
        conved_2 = F.relu(self.conv_2(embedded.permute(0,2,1)))

        pooled_0 = F.max_pool1d(conved_0, conved_0.size(2))
        pooled_1 = F.max_pool1d(conved_1, conved_1.size(2))
        pooled_2 = F.max_pool1d(conved_2, conved_2.size(2))

        linear_out = self.linear(self.dropout(torch.cat((pooled_0, pooled_1, pooled_2), dim=1).permute(0,2,1))).squeeze()
        return linear_out


In [12]:
class NerCNNModel(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 embedder: TextFieldEmbedder,
                 embedding_dim: int, 
                 n_filters: int, 
                 filter_sizes: List[int], 
                 dropout: float,
                 char_map: Dict[str, int],
                 char_embedding_dim: int,
                 char_n_filters: int,
                 char_filter_sizes: List[int],
                 char_out_dim: int) -> None:
        super().__init__(vocab)

        self._embedder = embedder
        self._char_embedder = CharCNNEmbedder(char_map, char_embedding_dim, char_n_filters, char_filter_sizes, char_out_dim, dropout)
        self._conv_0 = nn.utils.weight_norm(nn.Conv1d(embedding_dim + char_embedding_dim, n_filters, filter_sizes[0], padding = 1))
        self._conv_1 = nn.utils.weight_norm(nn.Conv1d(n_filters, n_filters, filter_sizes[1], padding = 1))
        self._conv_2 = nn.utils.weight_norm(nn.Conv1d(n_filters, n_filters, filter_sizes[2], padding = 1))
        self._classifier = nn.Linear(in_features=n_filters + embedding_dim + char_embedding_dim, out_features=vocab.get_vocab_size('labels')) # emb = concat(word_conv,word_emb,char_emb)
        self._dropout = nn.Dropout(dropout)

        self._f1 = SpanBasedF1Measure(vocab, 'labels')

    def forward(self,
                tokens: Dict[str, torch.Tensor],
                tags: Optional[torch.Tensor] = None,
                metadata = List[Dict[str, List[str]]]) -> Dict[str, torch.Tensor]:
        mask = get_text_field_mask(tokens)

        embedded = self._embedder(tokens)
        embedded = embedded.permute(0, 2, 1)

        char_embedded = self._char_embedder(metadata)
        char_embedded = char_embedded.reshape(embedded.size(0), embedded.size(2), -1).permute(0,2,1)

        conved_0 = self._dropout(F.relu(self._conv_0(torch.cat((char_embedded, embedded), dim=1))))
        conved_1 = self._dropout(F.relu(self._conv_1(conved_0)))
        conved_2 = self._dropout(F.relu(self._conv_2(conved_1)))

        classified = self._classifier(torch.cat((conved_2, char_embedded, embedded), dim=1).permute(0,2,1))

        self._f1(classified, tags, mask)

        output: Dict[str, torch.Tensor] = {}
        output['logits'] = classified

        if tags is not None:
            output["loss"] = sequence_cross_entropy_with_logits(classified, tags, mask)

        return output

    def get_metrics(self, reset: bool = True) -> Dict[str, float]:
        return self._f1.get_metric(reset)

In [13]:
from allennlp.data.vocabulary import Vocabulary   
from allennlp.data.iterators import BucketIterator

BATCH_SIZE = 64
MAX_VOCAB_SIZE = 100000

vocab = Vocabulary.from_instances(train_instances, max_vocab_size=MAX_VOCAB_SIZE)

iterator = BucketIterator(batch_size=BATCH_SIZE, 
                          sorting_keys=[("tokens", "num_tokens")])
iterator.index_with(vocab)

EMBEDDING_DIM = 50 # dimension of pretrained word embeddings

token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                            pretrained_file = "http://nlp.stanford.edu/data/glove.6B.zip",
                            embedding_dim=EMBEDDING_DIM)
word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})

100%|██████████| 14041/14041 [00:00<00:00, 54162.23it/s]


In [14]:
# vocabulary for char embeddings
CHARS = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~'
char_to_id = dict([(char,i+1) for i, char in enumerate(list(CHARS))])
char_to_id['<unk>'] = len(list(CHARS))
char_to_id['<pad>'] = 0

In [15]:
N_FILTERS = 200 # number of filters for word CNN
FILTER_SIZES = [3,3,3] # filter sizes for word CNN
DROPOUT = 0.1 
CHAR_EMBEDDING_DIM = 50 # dimension of char embeddings
CHAR_N_FILTERS = 50 # number of filters for char CNN
CHAR_FILTER_SIZES = [2,3,4] # filter sizes for char CNN
CHAR_OUT_DIM = 50 # dimension of output char-CNN embeddings

In [16]:
model = NerCNNModel(
    vocab,
    word_embeddings, 
    EMBEDDING_DIM,
    N_FILTERS,
    FILTER_SIZES,
    DROPOUT,
    char_to_id,
    CHAR_EMBEDDING_DIM,
    CHAR_N_FILTERS,
    CHAR_FILTER_SIZES,
    CHAR_OUT_DIM
)

In [18]:
if USE_GPU: model.cuda()
else: model

In [19]:
model

NerCNNModel(
  (_embedder): BasicTextFieldEmbedder(
    (token_embedder_tokens): Embedding()
  )
  (_char_embedder): CharCNNEmbedder(
    (char_embedder): Embedding(96, 50, padding_idx=0)
    (conv_0): Conv1d(50, 50, kernel_size=(2,), stride=(1,))
    (conv_1): Conv1d(50, 50, kernel_size=(3,), stride=(1,))
    (conv_2): Conv1d(50, 50, kernel_size=(4,), stride=(1,))
    (linear): Linear(in_features=150, out_features=50, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (_conv_0): Conv1d(100, 200, kernel_size=(3,), stride=(1,), padding=(1,))
  (_conv_1): Conv1d(200, 200, kernel_size=(3,), stride=(1,), padding=(1,))
  (_conv_2): Conv1d(200, 200, kernel_size=(3,), stride=(1,), padding=(1,))
  (_classifier): Linear(in_features=300, out_features=8, bias=True)
  (_dropout): Dropout(p=0.1, inplace=False)
)

### Training

In [20]:
optimizer = optim.SGD(model.parameters(), lr=10e-3)

In [21]:
NUM_EPOCHS = 100
PATIENCE = 10

trainer = Trainer(model=model,
                  optimizer=optimizer,
                  iterator=iterator,
                  train_dataset=train_instances,
                  validation_dataset=dev_instances,
                  patience=PATIENCE,
                  num_epochs=NUM_EPOCHS,
                  cuda_device=0 if USE_GPU else -1)

In [22]:
metrics = trainer.train()

precision-LOC: 0.0039, recall-LOC: 0.0003, f1-measure-LOC: 0.0005, precision-PER: 0.0000, recall-PER: 0.0000, f1-measure-PER: 0.0000, precision-ORG: 0.0239, recall-ORG: 0.0040, f1-measure-ORG: 0.0068, precision-MISC: 0.0142, recall-MISC: 0.0041, f1-measure-MISC: 0.0063, precision-overall: 0.0161, recall-overall: 0.0017, f1-measure-overall: 0.0031, loss: 1.1429 ||: 100%|██████████| 220/220 [01:11<00:00,  3.07it/s]
precision-LOC: 0.0000, recall-LOC: 0.0000, f1-measure-LOC: 0.0000, precision-PER: 0.0000, recall-PER: 0.0000, f1-measure-PER: 0.0000, precision-ORG: 0.0000, recall-ORG: 0.0000, f1-measure-ORG: 0.0000, precision-MISC: 0.0000, recall-MISC: 0.0000, f1-measure-MISC: 0.0000, precision-overall: 0.0000, recall-overall: 0.0000, f1-measure-overall: 0.0000, loss: 0.8563 ||: 100%|██████████| 51/51 [00:04<00:00, 10.74it/s]
precision-PER: 0.0000, recall-PER: 0.0000, f1-measure-PER: 0.0000, precision-LOC: 0.3105, recall-LOC: 0.0527, f1-measure-LOC: 0.0900, precision-ORG: 0.3664, recall-ORG:

precision-PER: 0.1752, recall-PER: 0.1688, f1-measure-PER: 0.1720, precision-LOC: 0.4596, recall-LOC: 0.4948, f1-measure-LOC: 0.4765, precision-ORG: 0.2576, recall-ORG: 0.1767, f1-measure-ORG: 0.2096, precision-MISC: 0.0000, recall-MISC: 0.0000, f1-measure-MISC: 0.0000, precision-overall: 0.3118, recall-overall: 0.2452, f1-measure-overall: 0.2745, loss: 0.4222 ||: 100%|██████████| 51/51 [00:03<00:00, 14.21it/s]
precision-PER: 0.1581, recall-PER: 0.1776, f1-measure-PER: 0.1672, precision-LOC: 0.4328, recall-LOC: 0.4144, f1-measure-LOC: 0.4234, precision-ORG: 0.1898, recall-ORG: 0.1713, f1-measure-ORG: 0.1801, precision-MISC: 0.0000, recall-MISC: 0.0000, f1-measure-MISC: 0.0000, precision-overall: 0.2613, recall-overall: 0.2219, f1-measure-overall: 0.2400, loss: 0.4359 ||: 100%|██████████| 220/220 [00:58<00:00,  3.78it/s]
precision-PER: 0.2147, recall-PER: 0.2199, f1-measure-PER: 0.2173, precision-LOC: 0.4934, recall-LOC: 0.4682, f1-measure-LOC: 0.4804, precision-ORG: 0.2079, recall-ORG:

precision-PER: 0.3117, recall-PER: 0.4427, f1-measure-PER: 0.3658, precision-LOC: 0.5472, recall-LOC: 0.5528, f1-measure-LOC: 0.5500, precision-ORG: 0.2516, recall-ORG: 0.2754, f1-measure-ORG: 0.2630, precision-MISC: 0.1006, recall-MISC: 0.0140, f1-measure-MISC: 0.0245, precision-overall: 0.3610, recall-overall: 0.3684, f1-measure-overall: 0.3647, loss: 0.3387 ||: 100%|██████████| 220/220 [00:58<00:00,  3.77it/s]
precision-PER: 0.3916, recall-PER: 0.5011, f1-measure-PER: 0.4396, precision-ORG: 0.2714, recall-ORG: 0.3565, f1-measure-ORG: 0.3082, precision-MISC: 0.1312, recall-MISC: 0.0228, f1-measure-MISC: 0.0388, precision-LOC: 0.6195, recall-LOC: 0.5389, f1-measure-LOC: 0.5764, precision-overall: 0.4105, recall-overall: 0.4059, f1-measure-overall: 0.4082, loss: 0.3099 ||: 100%|██████████| 51/51 [00:03<00:00, 16.12it/s]
precision-PER: 0.3179, recall-PER: 0.4476, f1-measure-PER: 0.3718, precision-LOC: 0.5533, recall-LOC: 0.5555, f1-measure-LOC: 0.5544, precision-ORG: 0.2625, recall-ORG:

precision-PER: 0.4671, recall-PER: 0.5282, f1-measure-PER: 0.4958, precision-ORG: 0.2752, recall-ORG: 0.4609, f1-measure-ORG: 0.3446, precision-MISC: 0.3385, recall-MISC: 0.1432, f1-measure-MISC: 0.2012, precision-LOC: 0.6361, recall-LOC: 0.5623, f1-measure-LOC: 0.5969, precision-overall: 0.4345, recall-overall: 0.4638, f1-measure-overall: 0.4487, loss: 0.2695 ||: 100%|██████████| 51/51 [00:02<00:00, 17.68it/s]
precision-PER: 0.3928, recall-PER: 0.5358, f1-measure-PER: 0.4533, precision-LOC: 0.5821, recall-LOC: 0.5994, f1-measure-LOC: 0.5906, precision-ORG: 0.2941, recall-ORG: 0.3387, f1-measure-ORG: 0.3148, precision-MISC: 0.2723, recall-MISC: 0.1056, f1-measure-MISC: 0.1522, precision-overall: 0.4133, recall-overall: 0.4392, f1-measure-overall: 0.4259, loss: 0.2882 ||: 100%|██████████| 220/220 [00:54<00:00,  4.07it/s]
precision-PER: 0.4682, recall-PER: 0.5478, f1-measure-PER: 0.5049, precision-ORG: 0.2652, recall-ORG: 0.4832, f1-measure-ORG: 0.3425, precision-MISC: 0.3251, recall-MIS

precision-PER: 0.4694, recall-PER: 0.5982, f1-measure-PER: 0.5260, precision-ORG: 0.3453, recall-ORG: 0.3961, f1-measure-ORG: 0.3690, precision-MISC: 0.3859, recall-MISC: 0.2184, f1-measure-MISC: 0.2790, precision-LOC: 0.5984, recall-LOC: 0.6382, f1-measure-LOC: 0.6177, precision-overall: 0.4662, recall-overall: 0.5004, f1-measure-overall: 0.4827, loss: 0.2534 ||: 100%|██████████| 220/220 [00:53<00:00,  4.15it/s]
precision-PER: 0.5440, recall-PER: 0.5505, f1-measure-PER: 0.5472, precision-ORG: 0.2854, recall-ORG: 0.5444, f1-measure-ORG: 0.3745, precision-MISC: 0.4563, recall-MISC: 0.3113, f1-measure-MISC: 0.3701, precision-LOC: 0.7037, recall-LOC: 0.5482, f1-measure-LOC: 0.6163, precision-overall: 0.4687, recall-overall: 0.5113, f1-measure-overall: 0.4891, loss: 0.2379 ||: 100%|██████████| 51/51 [00:02<00:00, 17.67it/s]
precision-PER: 0.4790, recall-PER: 0.5952, f1-measure-PER: 0.5308, precision-ORG: 0.3442, recall-ORG: 0.4102, f1-measure-ORG: 0.3743, precision-MISC: 0.3851, recall-MIS

precision-PER: 0.6075, recall-PER: 0.6368, f1-measure-PER: 0.6218, precision-ORG: 0.3337, recall-ORG: 0.5213, f1-measure-ORG: 0.4069, precision-MISC: 0.4993, recall-MISC: 0.3839, f1-measure-MISC: 0.4341, precision-LOC: 0.7033, recall-LOC: 0.6130, f1-measure-LOC: 0.6550, precision-overall: 0.5290, recall-overall: 0.5641, f1-measure-overall: 0.5460, loss: 0.2113 ||: 100%|██████████| 51/51 [00:02<00:00, 17.67it/s]
precision-PER: 0.5419, recall-PER: 0.6383, f1-measure-PER: 0.5862, precision-ORG: 0.3821, recall-ORG: 0.4495, f1-measure-ORG: 0.4130, precision-MISC: 0.4443, recall-MISC: 0.3211, f1-measure-MISC: 0.3728, precision-LOC: 0.6245, recall-LOC: 0.6623, f1-measure-LOC: 0.6428, precision-overall: 0.5100, recall-overall: 0.5484, f1-measure-overall: 0.5285, loss: 0.2237 ||: 100%|██████████| 220/220 [00:54<00:00,  4.03it/s]
precision-PER: 0.6346, recall-PER: 0.6129, f1-measure-PER: 0.6236, precision-ORG: 0.3554, recall-ORG: 0.5004, f1-measure-ORG: 0.4156, precision-MISC: 0.4479, recall-MIS

precision-PER: 0.5924, recall-PER: 0.6705, f1-measure-PER: 0.6290, precision-ORG: 0.4255, recall-ORG: 0.4797, f1-measure-ORG: 0.4510, precision-MISC: 0.4820, recall-MISC: 0.4017, f1-measure-MISC: 0.4382, precision-LOC: 0.6474, recall-LOC: 0.6866, f1-measure-LOC: 0.6664, precision-overall: 0.5489, recall-overall: 0.5847, f1-measure-overall: 0.5662, loss: 0.2037 ||: 100%|██████████| 220/220 [00:53<00:00,  4.13it/s]
precision-PER: 0.6918, recall-PER: 0.6178, f1-measure-PER: 0.6527, precision-ORG: 0.3726, recall-ORG: 0.5541, f1-measure-ORG: 0.4456, precision-MISC: 0.5249, recall-MISC: 0.5033, f1-measure-MISC: 0.5138, precision-LOC: 0.6904, recall-LOC: 0.6761, f1-measure-LOC: 0.6832, precision-overall: 0.5674, recall-overall: 0.6037, f1-measure-overall: 0.5850, loss: 0.1907 ||: 100%|██████████| 51/51 [00:02<00:00, 18.25it/s]
precision-PER: 0.5948, recall-PER: 0.6641, f1-measure-PER: 0.6275, precision-LOC: 0.6548, recall-LOC: 0.6933, f1-measure-LOC: 0.6735, precision-ORG: 0.4272, recall-ORG:

precision-PER: 0.6794, recall-PER: 0.7329, f1-measure-PER: 0.7051, precision-ORG: 0.4344, recall-ORG: 0.5183, f1-measure-ORG: 0.4726, precision-MISC: 0.5573, recall-MISC: 0.5434, f1-measure-MISC: 0.5502, precision-LOC: 0.7178, recall-LOC: 0.6908, f1-measure-LOC: 0.7040, precision-overall: 0.6100, recall-overall: 0.6420, f1-measure-overall: 0.6256, loss: 0.1730 ||: 100%|██████████| 51/51 [00:03<00:00, 14.58it/s]
precision-PER: 0.6335, recall-PER: 0.7085, f1-measure-PER: 0.6689, precision-LOC: 0.6734, recall-LOC: 0.7126, f1-measure-LOC: 0.6924, precision-ORG: 0.4717, recall-ORG: 0.5257, f1-measure-ORG: 0.4973, precision-MISC: 0.5181, recall-MISC: 0.4654, f1-measure-MISC: 0.4903, precision-overall: 0.5859, recall-overall: 0.6250, f1-measure-overall: 0.6048, loss: 0.1831 ||: 100%|██████████| 220/220 [01:00<00:00,  3.66it/s]
precision-PER: 0.7026, recall-PER: 0.6976, f1-measure-PER: 0.7001, precision-ORG: 0.4332, recall-ORG: 0.5004, f1-measure-ORG: 0.4644, precision-MISC: 0.5383, recall-MIS

precision-PER: 0.6632, recall-PER: 0.7256, f1-measure-PER: 0.6930, precision-ORG: 0.4903, recall-ORG: 0.5463, f1-measure-ORG: 0.5168, precision-MISC: 0.5497, recall-MISC: 0.5166, f1-measure-MISC: 0.5326, precision-LOC: 0.6933, recall-LOC: 0.7301, f1-measure-LOC: 0.7112, precision-overall: 0.6089, recall-overall: 0.6482, f1-measure-overall: 0.6279, loss: 0.1702 ||: 100%|██████████| 220/220 [00:55<00:00,  3.97it/s]
precision-PER: 0.7222, recall-PER: 0.7253, f1-measure-PER: 0.7237, precision-LOC: 0.7125, recall-LOC: 0.7246, f1-measure-LOC: 0.7185, precision-ORG: 0.4441, recall-ORG: 0.5541, f1-measure-ORG: 0.4930, precision-MISC: 0.5875, recall-MISC: 0.5716, f1-measure-MISC: 0.5794, precision-overall: 0.6261, recall-overall: 0.6626, f1-measure-overall: 0.6438, loss: 0.1615 ||: 100%|██████████| 51/51 [00:03<00:00, 16.95it/s]
precision-PER: 0.6698, recall-PER: 0.7208, f1-measure-PER: 0.6944, precision-LOC: 0.6899, recall-LOC: 0.7361, f1-measure-LOC: 0.7122, precision-ORG: 0.4977, recall-ORG:

precision-PER: 0.7276, recall-PER: 0.7497, f1-measure-PER: 0.7385, precision-ORG: 0.4574, recall-ORG: 0.5921, f1-measure-ORG: 0.5161, precision-MISC: 0.6386, recall-MISC: 0.5846, f1-measure-MISC: 0.6104, precision-LOC: 0.7473, recall-LOC: 0.7197, f1-measure-LOC: 0.7332, precision-overall: 0.6461, recall-overall: 0.6792, f1-measure-overall: 0.6622, loss: 0.1537 ||: 100%|██████████| 51/51 [00:03<00:00, 16.64it/s]
precision-PER: 0.6869, recall-PER: 0.7398, f1-measure-PER: 0.7124, precision-ORG: 0.5203, recall-ORG: 0.5697, f1-measure-ORG: 0.5439, precision-MISC: 0.5750, recall-MISC: 0.5462, f1-measure-MISC: 0.5603, precision-LOC: 0.7064, recall-LOC: 0.7521, f1-measure-LOC: 0.7285, precision-overall: 0.6319, recall-overall: 0.6695, f1-measure-overall: 0.6501, loss: 0.1581 ||: 100%|██████████| 220/220 [00:59<00:00,  3.72it/s]
precision-PER: 0.7383, recall-PER: 0.7383, f1-measure-PER: 0.7383, precision-ORG: 0.4699, recall-ORG: 0.5764, f1-measure-ORG: 0.5177, precision-MISC: 0.5764, recall-MIS

precision-PER: 0.7024, recall-PER: 0.7530, f1-measure-PER: 0.7268, precision-LOC: 0.7181, recall-LOC: 0.7632, f1-measure-LOC: 0.7400, precision-ORG: 0.5414, recall-ORG: 0.5901, f1-measure-ORG: 0.5647, precision-MISC: 0.5967, recall-MISC: 0.5852, f1-measure-MISC: 0.5909, precision-overall: 0.6484, recall-overall: 0.6877, f1-measure-overall: 0.6675, loss: 0.1488 ||: 100%|██████████| 220/220 [01:00<00:00,  3.65it/s]
precision-PER: 0.7365, recall-PER: 0.7633, f1-measure-PER: 0.7497, precision-LOC: 0.7451, recall-LOC: 0.7686, f1-measure-LOC: 0.7567, precision-ORG: 0.5118, recall-ORG: 0.5802, f1-measure-ORG: 0.5439, precision-MISC: 0.6350, recall-MISC: 0.6171, f1-measure-MISC: 0.6260, precision-overall: 0.6696, recall-overall: 0.7009, f1-measure-overall: 0.6849, loss: 0.1433 ||: 100%|██████████| 51/51 [00:04<00:00, 11.73it/s]
precision-PER: 0.7038, recall-PER: 0.7523, f1-measure-PER: 0.7272, precision-LOC: 0.7164, recall-LOC: 0.7633, f1-measure-LOC: 0.7391, precision-ORG: 0.5367, recall-ORG:

In [23]:
metrics

{'best_epoch': 98,
 'peak_cpu_memory_MB': 3085.74,
 'training_duration': '1:41:26.789174',
 'training_start_epoch': 0,
 'training_epochs': 99,
 'epoch': 99,
 'training_precision-LOC': 0.727428949107733,
 'training_recall-LOC': 0.7707282913165266,
 'training_f1-measure-LOC': 0.7484529071743806,
 'training_precision-PER': 0.7054023635340462,
 'training_recall-PER': 0.7596969696969696,
 'training_f1-measure-PER': 0.7315436241610239,
 'training_precision-ORG': 0.544151708164447,
 'training_recall-ORG': 0.5946843853820598,
 'training_f1-measure-ORG': 0.5682969234257571,
 'training_precision-MISC': 0.5944014294222751,
 'training_recall-MISC': 0.5805700988947062,
 'training_f1-measure-MISC': 0.5874043555031871,
 'training_precision-overall': 0.6524720317574882,
 'training_recall-overall': 0.6924549980850249,
 'training_f1-measure-overall': 0.6718691936082741,
 'training_loss': 0.14487948392263866,
 'training_cpu_memory_MB': 3085.74,
 'validation_precision-LOC': 0.773037542662116,
 'validation

In [26]:
with open("./tmp/model.th", 'wb') as f:
    torch.save(model.state_dict(), f)
vocab.save_to_files("./tmp/vocabulary")

### Predictions

In [17]:
# vocab = Vocabulary.from_files("./tmp/vocabulary")

with open("/tmp/model.th", 'rb') as f:
    model.load_state_dict(torch.load(f))
if USE_GPU: 
    model.cuda()

In [18]:
from allennlp.predictors.predictor import Predictor
from allennlp.common.util import JsonDict, sanitize
from allennlp.data import Instance

In [21]:
class CoNLL03Predictor(Predictor):
    def predict_instance(self, instance: Instance) -> JsonDict:
        outputs = self._model.forward_on_instance(instance)
        label_vocab = self._model.vocab.get_index_to_token_vocabulary('labels')

        outputs['tokens'] = [str(token) for token in instance.fields['tokens'].tokens]
        outputs['predicted'] = [label_vocab[l] for l in outputs['logits'].argmax(1)]
        outputs['labels'] = instance.fields['tags'].labels

        return sanitize(outputs)

In [28]:
from allennlp.data.iterators import DataIterator
from tqdm import tqdm
import numpy as np
from typing import Iterable

class TagsPredictor:
    def __init__(self, model: Model, iterator: DataIterator,
                 cuda_device: int=-1) -> None:
        self.model = model
        self.iterator = iterator
        self.cuda_device = cuda_device
        
    def _extract_data(self, batch) -> np.ndarray:
        out_dict = self.model(**batch)
        return out_dict
    
    def predict(self, ds: Iterable[Instance]) -> np.ndarray:
        pred_generator = self.iterator(ds, num_epochs=1, shuffle=False)
        self.model.eval()
        pred_generator_tqdm = tqdm(pred_generator,
                                   total=self.iterator.get_num_batches(ds))
        preds = []
        with torch.no_grad():
            for batch in pred_generator_tqdm:
                batch = nn_util.move_to_device(batch, self.cuda_device)
                preds.append(self._extract_data(batch))
        return preds

In [22]:
predictor = CoNLL03Predictor(model, reader, frozen=True)

In [49]:
print(test_instances[10].fields['tokens'].tokens)

[Takuya, Takagi, scored, the, winner, in, the, 88th, minute, ,, rising, to, head, a, Hiroshige, Yanagimoto, cross, towards, the, Syrian, goal, which, goalkeeper, Salem, Bitar, appeared, to, have, covered, but, then, allowed, to, slip, into, the, net, .]


In [23]:
tags_pred = predictor.predict_instance(test_instances[10])

In [24]:
print(tags_pred['labels'])

['I-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-PER', 'I-PER', 'O', 'O', 'O', 'I-MISC', 'O', 'O', 'O', 'I-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']


In [51]:
print(tags_pred['predicted'])

['I-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-PER', 'I-PER', 'O', 'O', 'O', 'I-MISC', 'O', 'O', 'O', 'I-MISC', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']


In [29]:
from allennlp.data.iterators import BasicIterator
seq_iterator = BasicIterator(batch_size=BATCH_SIZE)
seq_iterator.index_with(vocab)

In [30]:
predictor = TagsPredictor(model, seq_iterator, cuda_device=0 if USE_GPU else -1)

In [31]:
test_preds = predictor.predict(test_instances)


  0%|          | 0/54 [00:00<?, ?it/s][A
  2%|▏         | 1/54 [00:00<00:11,  4.55it/s][A
  6%|▌         | 3/54 [00:00<00:09,  5.60it/s][A
  7%|▋         | 4/54 [00:00<00:12,  4.16it/s][A
  9%|▉         | 5/54 [00:00<00:10,  4.75it/s][A
 11%|█         | 6/54 [00:01<00:08,  5.57it/s][A
 13%|█▎        | 7/54 [00:01<00:07,  6.03it/s][A
 17%|█▋        | 9/54 [00:01<00:06,  7.12it/s][A
 20%|██        | 11/54 [00:01<00:05,  7.50it/s][A
 22%|██▏       | 12/54 [00:01<00:05,  7.90it/s][A
 24%|██▍       | 13/54 [00:01<00:05,  7.29it/s][A
 26%|██▌       | 14/54 [00:01<00:05,  7.17it/s][A
 28%|██▊       | 15/54 [00:02<00:05,  6.68it/s][A
 30%|██▉       | 16/54 [00:02<00:05,  6.71it/s][A
 31%|███▏      | 17/54 [00:02<00:05,  6.20it/s][A
 33%|███▎      | 18/54 [00:02<00:05,  6.50it/s][A
 35%|███▌      | 19/54 [00:02<00:05,  6.42it/s][A
 37%|███▋      | 20/54 [00:02<00:05,  6.65it/s][A
 39%|███▉      | 21/54 [00:03<00:05,  6.09it/s][A
 41%|████      | 22/54 [00:03<00:05,  6.35it/s

In [42]:
mean_test_loss = sum([pred['loss'] for pred in test_preds])/len(test_preds)
mean_test_loss

tensor(0.2731)