## Sentence POS tagging using ELMo embeddings

This tutorial has been taken from: https://allennlp.org/tutorials

In [1]:
# We use the AllenNLP library to work with ELMo embeddings. It is the library open-sourced by the authors of the ELMo paper.
# To install AllenNLP, uncomment the following line:
#!pip install allennlp

In [2]:
from typing import Iterator, List, Dict
import torch
import torch.optim as optim
import numpy as np
from allennlp.data import Instance
from allennlp.data.fields import TextField, SequenceLabelField
from allennlp.data.dataset_readers import DatasetReader
from allennlp.common.file_utils import cached_path
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token
from allennlp.data.vocabulary import Vocabulary
from allennlp.models import Model
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.training.metrics import CategoricalAccuracy
from allennlp.data.iterators import BucketIterator
from allennlp.training.trainer import Trainer
from allennlp.predictors import SentenceTaggerPredictor

In [3]:
# Seed PyTorch's random number generator so that repeated runs start from the same state
torch.manual_seed(1)

<torch._C.Generator at 0x7feb087272f0>

Each example is represented as an Instance. The sentence text is stored in an object of type TextField, and its sequence of POS tags in an object of type SequenceLabelField.

## Defining dataset and model classes

In [13]:
class PosDatasetReader(DatasetReader):
    """
    DatasetReader for PoS tagging data, one sentence per line, like

        The###DET dog###NN ate###V the###DET apple###NN

    Every whitespace-separated item on a line is a ``word###tag`` pair.
    Blank lines are skipped.
    """
    def __init__(self, token_indexers: Dict[str, TokenIndexer] = None) -> None:
        """
        :param token_indexers: indexers to attach to every sentence field; when
            omitted (or empty), a single ``SingleIdTokenIndexer`` is registered
            under the ``"tokens"`` key.
        """
        super().__init__(lazy=False)
        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}

    def text_to_instance(self, tokens: List[Token], tags: List[str] = None) -> Instance:
        """
        Build an Instance from one tokenised sentence.

        The tokens are stored in a TextField under ``"sentence"``. When ``tags``
        is given (and non-empty), a SequenceLabelField tied to that sentence is
        stored under ``"labels"``.
        """
        sentence_field = TextField(tokens, self.token_indexers)
        fields = {"sentence": sentence_field}

        if tags:
            label_field = SequenceLabelField(labels=tags, sequence_field=sentence_field)
            fields["labels"] = label_field

        return Instance(fields)

    def _read(self, file_path: str) -> Iterator[Instance]:
        """
        Yield one Instance per non-empty line of ``file_path``.

        :param file_path: path of a text file in the format shown in the class docstring.
        """
        # Decode explicitly as UTF-8 instead of relying on the platform default
        with open(file_path, encoding="utf-8") as f:
            for line in f:
                pairs = line.strip().split()
                if not pairs:
                    # A blank line would leave zip() with nothing to unpack below
                    continue
                sentence, tags = zip(*(pair.split("###") for pair in pairs))
                # zip() returns tuples; hand the tags over as a list, as annotated
                yield self.text_to_instance([Token(word) for word in sentence], list(tags))

The model consists of an embedding layer, followed by an LSTM, and then a feedforward layer. In the following class we define the model, which at its base is built on torch.nn.Module.

In [5]:
class LstmTagger(Model):
    """
    Sequence tagger: the sentence is embedded, passed through a sequence encoder,
    and every encoder output is projected onto the 'labels' vocabulary by a linear layer.
    """
    def __init__(self,
                 word_embeddings: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 vocab: Vocabulary) -> None:
        """
        :param word_embeddings: embedder applied to the ``sentence`` input.
        :param encoder: sequence encoder applied to the embedded sentence.
        :param vocab: vocabulary; the size of its 'labels' namespace fixes the output width.
        """
        super().__init__(vocab)
        self.word_embeddings = word_embeddings
        self.encoder = encoder

        encoder_dim = encoder.get_output_dim()
        num_tags = vocab.get_vocab_size('labels')
        self.hidden2tag = torch.nn.Linear(in_features=encoder_dim,
                                          out_features=num_tags)
        self.accuracy = CategoricalAccuracy()

    def forward(self,
                sentence: Dict[str, torch.Tensor],
                labels: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """
        Compute tag logits for a batch and, when ``labels`` is given, the loss.

        :return: dict with ``"tag_logits"``, plus ``"loss"`` if ``labels`` is not None.
        """
        # The mask tells padded positions apart from real tokens, since sentences
        # are often padded to a common length
        mask = get_text_field_mask(sentence)
        encoded = self.encoder(self.word_embeddings(sentence), mask)
        tag_logits = self.hidden2tag(encoded)

        if labels is None:
            return {"tag_logits": tag_logits}

        # Update the running accuracy before computing the masked loss
        self.accuracy(tag_logits, labels, mask)
        loss = sequence_cross_entropy_with_logits(tag_logits, labels, mask)
        return {"tag_logits": tag_logits, "loss": loss}

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the accumulated accuracy; ``reset`` is forwarded to the metric."""
        current_accuracy = self.accuracy.get_metric(reset)
        return {"accuracy": current_accuracy}

In [15]:
# Read the training and validation sets with the PoS reader, then build the
# vocabulary from the instances of both splits

reader = PosDatasetReader()

train_path = cached_path(
    'https://raw.githubusercontent.com/allenai/allennlp/master/tutorials/tagger/training.txt')
train_dataset = reader.read(train_path)

validation_path = cached_path(
    'https://raw.githubusercontent.com/allenai/allennlp/master/tutorials/tagger/validation.txt')
validation_dataset = reader.read(validation_path)

all_instances = train_dataset + validation_dataset
vocab = Vocabulary.from_instances(all_instances)

2it [00:00, 2881.69it/s]
93B [00:00, 288087.35B/s]            
2it [00:00, 6605.20it/s]
100%|██████████| 4/4 [00:00<00:00, 37365.74it/s]


In [16]:
# Width of the token embedding vectors and of the LSTM hidden state
EMBEDDING_DIM = 6
HIDDEN_DIM = 6
# One embedding row per entry of the 'tokens' vocabulary namespace
token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                            embedding_dim=EMBEDDING_DIM)
# The "tokens" key is the same key PosDatasetReader uses for its default token indexer
word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})
# Wrap a batch-first torch LSTM so that it can be passed to LstmTagger as its encoder
lstm = PytorchSeq2SeqWrapper(torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, batch_first=True))
model = LstmTagger(word_embeddings, lstm, vocab)

In [17]:
# Pick GPU 0 when CUDA is available; otherwise -1 marks that no GPU is used
cuda_device = 0 if torch.cuda.is_available() else -1
if cuda_device != -1:
    model = model.cuda(cuda_device)

In [18]:
# Plain SGD over all model parameters
optimizer = optim.SGD(model.parameters(), lr=0.1)

# The BucketIterator below sorts instances on their sequence length so that instances with similar sequence lengths
# are batched together for optimization
iterator = BucketIterator(batch_size=2, sorting_keys=[("sentence", "num_tokens")])

In [19]:
# Attach the vocabulary to the iterator (presumably used to index instances when batching -- confirm)
iterator.index_with(vocab)

## Train the model

In [20]:
# Wire model, optimizer, batching iterator and both datasets into a Trainer.
# NOTE(review): the recorded output of trainer.train() in this notebook reports
# 'training_epochs': 999, which does not match num_epochs=50 below -- the stored
# output looks stale; confirm the intended number of epochs.
trainer = Trainer(model=model,
                  optimizer=optimizer,
                  iterator=iterator,
                  train_dataset=train_dataset,
                  validation_dataset=validation_dataset,
                  patience=10,  # presumably the early-stopping patience in epochs -- confirm
                  num_epochs=50,
                  cuda_device=cuda_device)

In [21]:
# Run training; the returned metrics dict is shown as the cell output
trainer.train()

accuracy: 0.3333, loss: 1.1685 ||: 100%|██████████| 1/1 [00:00<00:00, 84.26it/s]
accuracy: 0.3333, loss: 1.1592 ||: 100%|██████████| 1/1 [00:00<00:00, 499.74it/s]
accuracy: 0.3333, loss: 1.1604 ||: 100%|██████████| 1/1 [00:00<00:00, 262.34it/s]
accuracy: 0.3333, loss: 1.1516 ||: 100%|██████████| 1/1 [00:00<00:00, 633.77it/s]
accuracy: 0.3333, loss: 1.1529 ||: 100%|██████████| 1/1 [00:00<00:00, 257.92it/s]
accuracy: 0.3333, loss: 1.1445 ||: 100%|██████████| 1/1 [00:00<00:00, 414.25it/s]
accuracy: 0.3333, loss: 1.1458 ||: 100%|██████████| 1/1 [00:00<00:00, 229.20it/s]
accuracy: 0.3333, loss: 1.1379 ||: 100%|██████████| 1/1 [00:00<00:00, 518.71it/s]
accuracy: 0.3333, loss: 1.1391 ||: 100%|██████████| 1/1 [00:00<00:00, 257.19it/s]
accuracy: 0.3333, loss: 1.1316 ||: 100%|██████████| 1/1 [00:00<00:00, 614.28it/s]
accuracy: 0.3333, loss: 1.1329 ||: 100%|██████████| 1/1 [00:00<00:00, 276.30it/s]
accuracy: 0.3333, loss: 1.1259 ||: 100%|██████████| 1/1 [00:00<00:00, 665.66it/s]
accuracy: 0.3333,

accuracy: 0.4444, loss: 1.0520 ||: 100%|██████████| 1/1 [00:00<00:00, 545.28it/s]
accuracy: 0.4444, loss: 1.0534 ||: 100%|██████████| 1/1 [00:00<00:00, 261.91it/s]
accuracy: 0.4444, loss: 1.0517 ||: 100%|██████████| 1/1 [00:00<00:00, 664.50it/s]
accuracy: 0.4444, loss: 1.0531 ||: 100%|██████████| 1/1 [00:00<00:00, 221.73it/s]
accuracy: 0.4444, loss: 1.0513 ||: 100%|██████████| 1/1 [00:00<00:00, 667.56it/s]
accuracy: 0.4444, loss: 1.0528 ||: 100%|██████████| 1/1 [00:00<00:00, 275.04it/s]
accuracy: 0.4444, loss: 1.0510 ||: 100%|██████████| 1/1 [00:00<00:00, 524.42it/s]
accuracy: 0.4444, loss: 1.0524 ||: 100%|██████████| 1/1 [00:00<00:00, 238.22it/s]
accuracy: 0.4444, loss: 1.0507 ||: 100%|██████████| 1/1 [00:00<00:00, 665.87it/s]
accuracy: 0.4444, loss: 1.0521 ||: 100%|██████████| 1/1 [00:00<00:00, 264.06it/s]
accuracy: 0.4444, loss: 1.0504 ||: 100%|██████████| 1/1 [00:00<00:00, 620.37it/s]
accuracy: 0.4444, loss: 1.0518 ||: 100%|██████████| 1/1 [00:00<00:00, 250.62it/s]
accuracy: 0.4444

accuracy: 0.4444, loss: 1.0374 ||: 100%|██████████| 1/1 [00:00<00:00, 225.78it/s]
accuracy: 0.4444, loss: 1.0357 ||: 100%|██████████| 1/1 [00:00<00:00, 590.58it/s]
accuracy: 0.4444, loss: 1.0370 ||: 100%|██████████| 1/1 [00:00<00:00, 251.58it/s]
accuracy: 0.4444, loss: 1.0353 ||: 100%|██████████| 1/1 [00:00<00:00, 641.43it/s]
accuracy: 0.4444, loss: 1.0366 ||: 100%|██████████| 1/1 [00:00<00:00, 220.16it/s]
accuracy: 0.4444, loss: 1.0349 ||: 100%|██████████| 1/1 [00:00<00:00, 690.65it/s]
accuracy: 0.4444, loss: 1.0362 ||: 100%|██████████| 1/1 [00:00<00:00, 213.32it/s]
accuracy: 0.4444, loss: 1.0345 ||: 100%|██████████| 1/1 [00:00<00:00, 756.41it/s]
accuracy: 0.4444, loss: 1.0357 ||: 100%|██████████| 1/1 [00:00<00:00, 273.60it/s]
accuracy: 0.4444, loss: 1.0341 ||: 100%|██████████| 1/1 [00:00<00:00, 645.97it/s]
accuracy: 0.4444, loss: 1.0353 ||: 100%|██████████| 1/1 [00:00<00:00, 266.02it/s]
accuracy: 0.4444, loss: 1.0336 ||: 100%|██████████| 1/1 [00:00<00:00, 562.09it/s]
accuracy: 0.4444

accuracy: 0.4444, loss: 1.0060 ||: 100%|██████████| 1/1 [00:00<00:00, 535.33it/s]
accuracy: 0.4444, loss: 1.0069 ||: 100%|██████████| 1/1 [00:00<00:00, 264.41it/s]
accuracy: 0.4444, loss: 1.0051 ||: 100%|██████████| 1/1 [00:00<00:00, 722.04it/s]
accuracy: 0.4444, loss: 1.0061 ||: 100%|██████████| 1/1 [00:00<00:00, 255.11it/s]
accuracy: 0.4444, loss: 1.0042 ||: 100%|██████████| 1/1 [00:00<00:00, 740.78it/s]
accuracy: 0.4444, loss: 1.0051 ||: 100%|██████████| 1/1 [00:00<00:00, 269.56it/s]
accuracy: 0.4444, loss: 1.0033 ||: 100%|██████████| 1/1 [00:00<00:00, 783.69it/s]
accuracy: 0.4444, loss: 1.0042 ||: 100%|██████████| 1/1 [00:00<00:00, 229.47it/s]
accuracy: 0.4444, loss: 1.0024 ||: 100%|██████████| 1/1 [00:00<00:00, 513.57it/s]
accuracy: 0.4444, loss: 1.0033 ||: 100%|██████████| 1/1 [00:00<00:00, 251.62it/s]
accuracy: 0.4444, loss: 1.0014 ||: 100%|██████████| 1/1 [00:00<00:00, 547.92it/s]
accuracy: 0.4444, loss: 1.0023 ||: 100%|██████████| 1/1 [00:00<00:00, 273.53it/s]
accuracy: 0.4444

accuracy: 0.4444, loss: 0.9396 ||: 100%|██████████| 1/1 [00:00<00:00, 235.52it/s]
accuracy: 0.4444, loss: 0.9370 ||: 100%|██████████| 1/1 [00:00<00:00, 706.11it/s]
accuracy: 0.4444, loss: 0.9375 ||: 100%|██████████| 1/1 [00:00<00:00, 224.37it/s]
accuracy: 0.4444, loss: 0.9349 ||: 100%|██████████| 1/1 [00:00<00:00, 731.99it/s]
accuracy: 0.4444, loss: 0.9355 ||: 100%|██████████| 1/1 [00:00<00:00, 245.17it/s]
accuracy: 0.4444, loss: 0.9329 ||: 100%|██████████| 1/1 [00:00<00:00, 612.40it/s]
accuracy: 0.4444, loss: 0.9334 ||: 100%|██████████| 1/1 [00:00<00:00, 257.86it/s]
accuracy: 0.4444, loss: 0.9307 ||: 100%|██████████| 1/1 [00:00<00:00, 769.17it/s]
accuracy: 0.4444, loss: 0.9313 ||: 100%|██████████| 1/1 [00:00<00:00, 259.00it/s]
accuracy: 0.4444, loss: 0.9286 ||: 100%|██████████| 1/1 [00:00<00:00, 779.47it/s]
accuracy: 0.4444, loss: 0.9292 ||: 100%|██████████| 1/1 [00:00<00:00, 246.67it/s]
accuracy: 0.4444, loss: 0.9264 ||: 100%|██████████| 1/1 [00:00<00:00, 551.88it/s]
accuracy: 0.4444

accuracy: 0.6667, loss: 0.8006 ||: 100%|██████████| 1/1 [00:00<00:00, 586.86it/s]
accuracy: 0.6667, loss: 0.8005 ||: 100%|██████████| 1/1 [00:00<00:00, 265.87it/s]
accuracy: 0.6667, loss: 0.7972 ||: 100%|██████████| 1/1 [00:00<00:00, 762.74it/s]
accuracy: 0.6667, loss: 0.7971 ||: 100%|██████████| 1/1 [00:00<00:00, 267.02it/s]
accuracy: 0.6667, loss: 0.7937 ||: 100%|██████████| 1/1 [00:00<00:00, 553.34it/s]
accuracy: 0.6667, loss: 0.7936 ||: 100%|██████████| 1/1 [00:00<00:00, 272.39it/s]
accuracy: 0.6667, loss: 0.7903 ||: 100%|██████████| 1/1 [00:00<00:00, 809.71it/s]
accuracy: 0.6667, loss: 0.7901 ||: 100%|██████████| 1/1 [00:00<00:00, 193.43it/s]
accuracy: 0.6667, loss: 0.7868 ||: 100%|██████████| 1/1 [00:00<00:00, 570.42it/s]
accuracy: 0.6667, loss: 0.7867 ||: 100%|██████████| 1/1 [00:00<00:00, 251.13it/s]
accuracy: 0.6667, loss: 0.7833 ||: 100%|██████████| 1/1 [00:00<00:00, 701.27it/s]
accuracy: 0.6667, loss: 0.7831 ||: 100%|██████████| 1/1 [00:00<00:00, 252.65it/s]
accuracy: 0.6667

accuracy: 0.6667, loss: 0.6238 ||: 100%|██████████| 1/1 [00:00<00:00, 259.74it/s]
accuracy: 0.7778, loss: 0.6217 ||: 100%|██████████| 1/1 [00:00<00:00, 681.67it/s]
accuracy: 0.6667, loss: 0.6202 ||: 100%|██████████| 1/1 [00:00<00:00, 249.07it/s]
accuracy: 0.7778, loss: 0.6181 ||: 100%|██████████| 1/1 [00:00<00:00, 733.53it/s]
accuracy: 0.7778, loss: 0.6166 ||: 100%|██████████| 1/1 [00:00<00:00, 213.14it/s]
accuracy: 0.7778, loss: 0.6146 ||: 100%|██████████| 1/1 [00:00<00:00, 624.99it/s]
accuracy: 0.7778, loss: 0.6130 ||: 100%|██████████| 1/1 [00:00<00:00, 277.47it/s]
accuracy: 0.7778, loss: 0.6110 ||: 100%|██████████| 1/1 [00:00<00:00, 807.06it/s]
accuracy: 0.7778, loss: 0.6093 ||: 100%|██████████| 1/1 [00:00<00:00, 262.87it/s]
accuracy: 0.7778, loss: 0.6074 ||: 100%|██████████| 1/1 [00:00<00:00, 621.10it/s]
accuracy: 0.7778, loss: 0.6057 ||: 100%|██████████| 1/1 [00:00<00:00, 226.14it/s]
accuracy: 0.7778, loss: 0.6038 ||: 100%|██████████| 1/1 [00:00<00:00, 618.08it/s]
accuracy: 0.7778

accuracy: 1.0000, loss: 0.4474 ||: 100%|██████████| 1/1 [00:00<00:00, 611.24it/s]
accuracy: 1.0000, loss: 0.4453 ||: 100%|██████████| 1/1 [00:00<00:00, 232.29it/s]
accuracy: 1.0000, loss: 0.4439 ||: 100%|██████████| 1/1 [00:00<00:00, 581.41it/s]
accuracy: 1.0000, loss: 0.4418 ||: 100%|██████████| 1/1 [00:00<00:00, 210.48it/s]
accuracy: 1.0000, loss: 0.4405 ||: 100%|██████████| 1/1 [00:00<00:00, 476.63it/s]
accuracy: 1.0000, loss: 0.4384 ||: 100%|██████████| 1/1 [00:00<00:00, 205.01it/s]
accuracy: 1.0000, loss: 0.4371 ||: 100%|██████████| 1/1 [00:00<00:00, 632.91it/s]
accuracy: 1.0000, loss: 0.4350 ||: 100%|██████████| 1/1 [00:00<00:00, 196.00it/s]
accuracy: 1.0000, loss: 0.4337 ||: 100%|██████████| 1/1 [00:00<00:00, 517.82it/s]
accuracy: 1.0000, loss: 0.4316 ||: 100%|██████████| 1/1 [00:00<00:00, 181.59it/s]
accuracy: 1.0000, loss: 0.4303 ||: 100%|██████████| 1/1 [00:00<00:00, 482.60it/s]
accuracy: 1.0000, loss: 0.4283 ||: 100%|██████████| 1/1 [00:00<00:00, 206.39it/s]
accuracy: 1.0000

accuracy: 1.0000, loss: 0.2958 ||: 100%|██████████| 1/1 [00:00<00:00, 229.67it/s]
accuracy: 1.0000, loss: 0.2940 ||: 100%|██████████| 1/1 [00:00<00:00, 678.03it/s]
accuracy: 1.0000, loss: 0.2932 ||: 100%|██████████| 1/1 [00:00<00:00, 284.48it/s]
accuracy: 1.0000, loss: 0.2914 ||: 100%|██████████| 1/1 [00:00<00:00, 575.03it/s]
accuracy: 1.0000, loss: 0.2906 ||: 100%|██████████| 1/1 [00:00<00:00, 252.79it/s]
accuracy: 1.0000, loss: 0.2889 ||: 100%|██████████| 1/1 [00:00<00:00, 737.91it/s]
accuracy: 1.0000, loss: 0.2881 ||: 100%|██████████| 1/1 [00:00<00:00, 233.37it/s]
accuracy: 1.0000, loss: 0.2863 ||: 100%|██████████| 1/1 [00:00<00:00, 606.46it/s]
accuracy: 1.0000, loss: 0.2855 ||: 100%|██████████| 1/1 [00:00<00:00, 264.32it/s]
accuracy: 1.0000, loss: 0.2838 ||: 100%|██████████| 1/1 [00:00<00:00, 671.20it/s]
accuracy: 1.0000, loss: 0.2830 ||: 100%|██████████| 1/1 [00:00<00:00, 233.87it/s]
accuracy: 1.0000, loss: 0.2813 ||: 100%|██████████| 1/1 [00:00<00:00, 657.41it/s]
accuracy: 1.0000

accuracy: 1.0000, loss: 0.1911 ||: 100%|██████████| 1/1 [00:00<00:00, 668.95it/s]
accuracy: 1.0000, loss: 0.1910 ||: 100%|██████████| 1/1 [00:00<00:00, 260.61it/s]
accuracy: 1.0000, loss: 0.1895 ||: 100%|██████████| 1/1 [00:00<00:00, 601.68it/s]
accuracy: 1.0000, loss: 0.1893 ||: 100%|██████████| 1/1 [00:00<00:00, 252.85it/s]
accuracy: 1.0000, loss: 0.1878 ||: 100%|██████████| 1/1 [00:00<00:00, 553.41it/s]
accuracy: 1.0000, loss: 0.1877 ||: 100%|██████████| 1/1 [00:00<00:00, 257.19it/s]
accuracy: 1.0000, loss: 0.1862 ||: 100%|██████████| 1/1 [00:00<00:00, 676.06it/s]
accuracy: 1.0000, loss: 0.1861 ||: 100%|██████████| 1/1 [00:00<00:00, 234.36it/s]
accuracy: 1.0000, loss: 0.1847 ||: 100%|██████████| 1/1 [00:00<00:00, 725.16it/s]
accuracy: 1.0000, loss: 0.1846 ||: 100%|██████████| 1/1 [00:00<00:00, 280.11it/s]
accuracy: 1.0000, loss: 0.1831 ||: 100%|██████████| 1/1 [00:00<00:00, 709.70it/s]
accuracy: 1.0000, loss: 0.1830 ||: 100%|██████████| 1/1 [00:00<00:00, 262.60it/s]
accuracy: 1.0000

accuracy: 1.0000, loss: 0.1286 ||: 100%|██████████| 1/1 [00:00<00:00, 255.95it/s]
accuracy: 1.0000, loss: 0.1275 ||: 100%|██████████| 1/1 [00:00<00:00, 664.60it/s]
accuracy: 1.0000, loss: 0.1276 ||: 100%|██████████| 1/1 [00:00<00:00, 249.13it/s]
accuracy: 1.0000, loss: 0.1265 ||: 100%|██████████| 1/1 [00:00<00:00, 680.78it/s]
accuracy: 1.0000, loss: 0.1267 ||: 100%|██████████| 1/1 [00:00<00:00, 259.34it/s]
accuracy: 1.0000, loss: 0.1256 ||: 100%|██████████| 1/1 [00:00<00:00, 660.10it/s]
accuracy: 1.0000, loss: 0.1257 ||: 100%|██████████| 1/1 [00:00<00:00, 269.85it/s]
accuracy: 1.0000, loss: 0.1247 ||: 100%|██████████| 1/1 [00:00<00:00, 706.23it/s]
accuracy: 1.0000, loss: 0.1248 ||: 100%|██████████| 1/1 [00:00<00:00, 182.28it/s]
accuracy: 1.0000, loss: 0.1237 ||: 100%|██████████| 1/1 [00:00<00:00, 477.60it/s]
accuracy: 1.0000, loss: 0.1239 ||: 100%|██████████| 1/1 [00:00<00:00, 237.97it/s]
accuracy: 1.0000, loss: 0.1228 ||: 100%|██████████| 1/1 [00:00<00:00, 737.40it/s]
accuracy: 1.0000

accuracy: 1.0000, loss: 0.0907 ||: 100%|██████████| 1/1 [00:00<00:00, 668.41it/s]
accuracy: 1.0000, loss: 0.0908 ||: 100%|██████████| 1/1 [00:00<00:00, 257.71it/s]
accuracy: 1.0000, loss: 0.0901 ||: 100%|██████████| 1/1 [00:00<00:00, 597.73it/s]
accuracy: 1.0000, loss: 0.0902 ||: 100%|██████████| 1/1 [00:00<00:00, 221.69it/s]
accuracy: 1.0000, loss: 0.0895 ||: 100%|██████████| 1/1 [00:00<00:00, 696.96it/s]
accuracy: 1.0000, loss: 0.0897 ||: 100%|██████████| 1/1 [00:00<00:00, 241.66it/s]
accuracy: 1.0000, loss: 0.0890 ||: 100%|██████████| 1/1 [00:00<00:00, 664.92it/s]
accuracy: 1.0000, loss: 0.0891 ||: 100%|██████████| 1/1 [00:00<00:00, 242.67it/s]
accuracy: 1.0000, loss: 0.0884 ||: 100%|██████████| 1/1 [00:00<00:00, 682.67it/s]
accuracy: 1.0000, loss: 0.0885 ||: 100%|██████████| 1/1 [00:00<00:00, 276.54it/s]
accuracy: 1.0000, loss: 0.0878 ||: 100%|██████████| 1/1 [00:00<00:00, 576.06it/s]
accuracy: 1.0000, loss: 0.0880 ||: 100%|██████████| 1/1 [00:00<00:00, 233.59it/s]
accuracy: 1.0000

accuracy: 1.0000, loss: 0.0681 ||: 100%|██████████| 1/1 [00:00<00:00, 250.69it/s]
accuracy: 1.0000, loss: 0.0676 ||: 100%|██████████| 1/1 [00:00<00:00, 680.23it/s]
accuracy: 1.0000, loss: 0.0678 ||: 100%|██████████| 1/1 [00:00<00:00, 263.02it/s]
accuracy: 1.0000, loss: 0.0673 ||: 100%|██████████| 1/1 [00:00<00:00, 593.51it/s]
accuracy: 1.0000, loss: 0.0674 ||: 100%|██████████| 1/1 [00:00<00:00, 267.73it/s]
accuracy: 1.0000, loss: 0.0669 ||: 100%|██████████| 1/1 [00:00<00:00, 514.13it/s]
accuracy: 1.0000, loss: 0.0670 ||: 100%|██████████| 1/1 [00:00<00:00, 228.87it/s]
accuracy: 1.0000, loss: 0.0666 ||: 100%|██████████| 1/1 [00:00<00:00, 698.35it/s]
accuracy: 1.0000, loss: 0.0667 ||: 100%|██████████| 1/1 [00:00<00:00, 249.81it/s]
accuracy: 1.0000, loss: 0.0662 ||: 100%|██████████| 1/1 [00:00<00:00, 767.06it/s]
accuracy: 1.0000, loss: 0.0663 ||: 100%|██████████| 1/1 [00:00<00:00, 218.10it/s]
accuracy: 1.0000, loss: 0.0659 ||: 100%|██████████| 1/1 [00:00<00:00, 556.57it/s]
accuracy: 1.0000

accuracy: 1.0000, loss: 0.0530 ||: 100%|██████████| 1/1 [00:00<00:00, 634.16it/s]
accuracy: 1.0000, loss: 0.0531 ||: 100%|██████████| 1/1 [00:00<00:00, 252.58it/s]
accuracy: 1.0000, loss: 0.0528 ||: 100%|██████████| 1/1 [00:00<00:00, 678.91it/s]
accuracy: 1.0000, loss: 0.0529 ||: 100%|██████████| 1/1 [00:00<00:00, 250.93it/s]
accuracy: 1.0000, loss: 0.0526 ||: 100%|██████████| 1/1 [00:00<00:00, 649.78it/s]
accuracy: 1.0000, loss: 0.0527 ||: 100%|██████████| 1/1 [00:00<00:00, 241.20it/s]
accuracy: 1.0000, loss: 0.0523 ||: 100%|██████████| 1/1 [00:00<00:00, 575.35it/s]
accuracy: 1.0000, loss: 0.0524 ||: 100%|██████████| 1/1 [00:00<00:00, 252.15it/s]
accuracy: 1.0000, loss: 0.0521 ||: 100%|██████████| 1/1 [00:00<00:00, 744.60it/s]
accuracy: 1.0000, loss: 0.0522 ||: 100%|██████████| 1/1 [00:00<00:00, 236.51it/s]
accuracy: 1.0000, loss: 0.0518 ||: 100%|██████████| 1/1 [00:00<00:00, 534.44it/s]
accuracy: 1.0000, loss: 0.0519 ||: 100%|██████████| 1/1 [00:00<00:00, 265.88it/s]
accuracy: 1.0000

accuracy: 1.0000, loss: 0.0431 ||: 100%|██████████| 1/1 [00:00<00:00, 239.29it/s]
accuracy: 1.0000, loss: 0.0429 ||: 100%|██████████| 1/1 [00:00<00:00, 626.11it/s]
accuracy: 1.0000, loss: 0.0429 ||: 100%|██████████| 1/1 [00:00<00:00, 266.03it/s]
accuracy: 1.0000, loss: 0.0427 ||: 100%|██████████| 1/1 [00:00<00:00, 557.09it/s]
accuracy: 1.0000, loss: 0.0428 ||: 100%|██████████| 1/1 [00:00<00:00, 247.17it/s]
accuracy: 1.0000, loss: 0.0425 ||: 100%|██████████| 1/1 [00:00<00:00, 553.12it/s]
accuracy: 1.0000, loss: 0.0426 ||: 100%|██████████| 1/1 [00:00<00:00, 231.96it/s]
accuracy: 1.0000, loss: 0.0424 ||: 100%|██████████| 1/1 [00:00<00:00, 657.52it/s]
accuracy: 1.0000, loss: 0.0424 ||: 100%|██████████| 1/1 [00:00<00:00, 237.37it/s]
accuracy: 1.0000, loss: 0.0422 ||: 100%|██████████| 1/1 [00:00<00:00, 646.37it/s]
accuracy: 1.0000, loss: 0.0423 ||: 100%|██████████| 1/1 [00:00<00:00, 202.93it/s]
accuracy: 1.0000, loss: 0.0420 ||: 100%|██████████| 1/1 [00:00<00:00, 603.84it/s]
accuracy: 1.0000

accuracy: 1.0000, loss: 0.0357 ||: 100%|██████████| 1/1 [00:00<00:00, 619.18it/s]
accuracy: 1.0000, loss: 0.0358 ||: 100%|██████████| 1/1 [00:00<00:00, 264.19it/s]
accuracy: 1.0000, loss: 0.0356 ||: 100%|██████████| 1/1 [00:00<00:00, 702.68it/s]
accuracy: 1.0000, loss: 0.0357 ||: 100%|██████████| 1/1 [00:00<00:00, 252.27it/s]
accuracy: 1.0000, loss: 0.0355 ||: 100%|██████████| 1/1 [00:00<00:00, 620.37it/s]
accuracy: 1.0000, loss: 0.0356 ||: 100%|██████████| 1/1 [00:00<00:00, 260.69it/s]
accuracy: 1.0000, loss: 0.0354 ||: 100%|██████████| 1/1 [00:00<00:00, 631.39it/s]
accuracy: 1.0000, loss: 0.0354 ||: 100%|██████████| 1/1 [00:00<00:00, 254.97it/s]
accuracy: 1.0000, loss: 0.0352 ||: 100%|██████████| 1/1 [00:00<00:00, 723.03it/s]
accuracy: 1.0000, loss: 0.0353 ||: 100%|██████████| 1/1 [00:00<00:00, 234.14it/s]
accuracy: 1.0000, loss: 0.0351 ||: 100%|██████████| 1/1 [00:00<00:00, 648.17it/s]
accuracy: 1.0000, loss: 0.0352 ||: 100%|██████████| 1/1 [00:00<00:00, 229.10it/s]
accuracy: 1.0000

accuracy: 1.0000, loss: 0.0305 ||: 100%|██████████| 1/1 [00:00<00:00, 243.80it/s]
accuracy: 1.0000, loss: 0.0303 ||: 100%|██████████| 1/1 [00:00<00:00, 684.45it/s]
accuracy: 1.0000, loss: 0.0304 ||: 100%|██████████| 1/1 [00:00<00:00, 254.12it/s]
accuracy: 1.0000, loss: 0.0303 ||: 100%|██████████| 1/1 [00:00<00:00, 720.92it/s]
accuracy: 1.0000, loss: 0.0303 ||: 100%|██████████| 1/1 [00:00<00:00, 219.00it/s]
accuracy: 1.0000, loss: 0.0302 ||: 100%|██████████| 1/1 [00:00<00:00, 682.56it/s]
accuracy: 1.0000, loss: 0.0302 ||: 100%|██████████| 1/1 [00:00<00:00, 245.65it/s]
accuracy: 1.0000, loss: 0.0301 ||: 100%|██████████| 1/1 [00:00<00:00, 809.71it/s]
accuracy: 1.0000, loss: 0.0301 ||: 100%|██████████| 1/1 [00:00<00:00, 260.02it/s]
accuracy: 1.0000, loss: 0.0300 ||: 100%|██████████| 1/1 [00:00<00:00, 760.94it/s]
accuracy: 1.0000, loss: 0.0300 ||: 100%|██████████| 1/1 [00:00<00:00, 270.79it/s]
accuracy: 1.0000, loss: 0.0299 ||: 100%|██████████| 1/1 [00:00<00:00, 586.12it/s]
accuracy: 1.0000

accuracy: 1.0000, loss: 0.0263 ||: 100%|██████████| 1/1 [00:00<00:00, 756.41it/s]
accuracy: 1.0000, loss: 0.0263 ||: 100%|██████████| 1/1 [00:00<00:00, 248.82it/s]
accuracy: 1.0000, loss: 0.0262 ||: 100%|██████████| 1/1 [00:00<00:00, 558.05it/s]
accuracy: 1.0000, loss: 0.0263 ||: 100%|██████████| 1/1 [00:00<00:00, 259.15it/s]
accuracy: 1.0000, loss: 0.0262 ||: 100%|██████████| 1/1 [00:00<00:00, 592.00it/s]
accuracy: 1.0000, loss: 0.0262 ||: 100%|██████████| 1/1 [00:00<00:00, 221.00it/s]
accuracy: 1.0000, loss: 0.0261 ||: 100%|██████████| 1/1 [00:00<00:00, 695.46it/s]
accuracy: 1.0000, loss: 0.0261 ||: 100%|██████████| 1/1 [00:00<00:00, 231.05it/s]
accuracy: 1.0000, loss: 0.0260 ||: 100%|██████████| 1/1 [00:00<00:00, 691.90it/s]
accuracy: 1.0000, loss: 0.0261 ||: 100%|██████████| 1/1 [00:00<00:00, 219.34it/s]
accuracy: 1.0000, loss: 0.0259 ||: 100%|██████████| 1/1 [00:00<00:00, 434.96it/s]
accuracy: 1.0000, loss: 0.0260 ||: 100%|██████████| 1/1 [00:00<00:00, 205.00it/s]
accuracy: 1.0000

accuracy: 1.0000, loss: 0.0231 ||: 100%|██████████| 1/1 [00:00<00:00, 247.52it/s]
accuracy: 1.0000, loss: 0.0231 ||: 100%|██████████| 1/1 [00:00<00:00, 588.67it/s]
accuracy: 1.0000, loss: 0.0231 ||: 100%|██████████| 1/1 [00:00<00:00, 251.44it/s]
accuracy: 1.0000, loss: 0.0230 ||: 100%|██████████| 1/1 [00:00<00:00, 511.25it/s]
accuracy: 1.0000, loss: 0.0230 ||: 100%|██████████| 1/1 [00:00<00:00, 234.73it/s]
accuracy: 1.0000, loss: 0.0229 ||: 100%|██████████| 1/1 [00:00<00:00, 548.28it/s]
accuracy: 1.0000, loss: 0.0230 ||: 100%|██████████| 1/1 [00:00<00:00, 236.18it/s]
accuracy: 1.0000, loss: 0.0229 ||: 100%|██████████| 1/1 [00:00<00:00, 482.44it/s]
accuracy: 1.0000, loss: 0.0229 ||: 100%|██████████| 1/1 [00:00<00:00, 214.74it/s]
accuracy: 1.0000, loss: 0.0228 ||: 100%|██████████| 1/1 [00:00<00:00, 641.43it/s]
accuracy: 1.0000, loss: 0.0229 ||: 100%|██████████| 1/1 [00:00<00:00, 250.24it/s]
accuracy: 1.0000, loss: 0.0228 ||: 100%|██████████| 1/1 [00:00<00:00, 610.79it/s]
accuracy: 1.0000

accuracy: 1.0000, loss: 0.0205 ||: 100%|██████████| 1/1 [00:00<00:00, 620.28it/s]
accuracy: 1.0000, loss: 0.0205 ||: 100%|██████████| 1/1 [00:00<00:00, 248.39it/s]
accuracy: 1.0000, loss: 0.0205 ||: 100%|██████████| 1/1 [00:00<00:00, 553.48it/s]
accuracy: 1.0000, loss: 0.0205 ||: 100%|██████████| 1/1 [00:00<00:00, 238.95it/s]
accuracy: 1.0000, loss: 0.0204 ||: 100%|██████████| 1/1 [00:00<00:00, 742.35it/s]
accuracy: 1.0000, loss: 0.0204 ||: 100%|██████████| 1/1 [00:00<00:00, 241.41it/s]
accuracy: 1.0000, loss: 0.0204 ||: 100%|██████████| 1/1 [00:00<00:00, 678.47it/s]
accuracy: 1.0000, loss: 0.0204 ||: 100%|██████████| 1/1 [00:00<00:00, 214.22it/s]
accuracy: 1.0000, loss: 0.0203 ||: 100%|██████████| 1/1 [00:00<00:00, 556.79it/s]
accuracy: 1.0000, loss: 0.0203 ||: 100%|██████████| 1/1 [00:00<00:00, 247.28it/s]
accuracy: 1.0000, loss: 0.0203 ||: 100%|██████████| 1/1 [00:00<00:00, 637.82it/s]
accuracy: 1.0000, loss: 0.0203 ||: 100%|██████████| 1/1 [00:00<00:00, 206.73it/s]
accuracy: 1.0000

accuracy: 1.0000, loss: 0.0184 ||: 100%|██████████| 1/1 [00:00<00:00, 260.27it/s]
accuracy: 1.0000, loss: 0.0184 ||: 100%|██████████| 1/1 [00:00<00:00, 679.35it/s]
accuracy: 1.0000, loss: 0.0184 ||: 100%|██████████| 1/1 [00:00<00:00, 274.26it/s]
accuracy: 1.0000, loss: 0.0183 ||: 100%|██████████| 1/1 [00:00<00:00, 678.36it/s]
accuracy: 1.0000, loss: 0.0184 ||: 100%|██████████| 1/1 [00:00<00:00, 249.68it/s]
accuracy: 1.0000, loss: 0.0183 ||: 100%|██████████| 1/1 [00:00<00:00, 684.00it/s]
accuracy: 1.0000, loss: 0.0183 ||: 100%|██████████| 1/1 [00:00<00:00, 215.34it/s]
accuracy: 1.0000, loss: 0.0183 ||: 100%|██████████| 1/1 [00:00<00:00, 599.87it/s]
accuracy: 1.0000, loss: 0.0183 ||: 100%|██████████| 1/1 [00:00<00:00, 261.91it/s]
accuracy: 1.0000, loss: 0.0182 ||: 100%|██████████| 1/1 [00:00<00:00, 634.54it/s]
accuracy: 1.0000, loss: 0.0182 ||: 100%|██████████| 1/1 [00:00<00:00, 231.76it/s]
accuracy: 1.0000, loss: 0.0182 ||: 100%|██████████| 1/1 [00:00<00:00, 569.49it/s]
accuracy: 1.0000

{'best_epoch': 999,
 'peak_cpu_memory_MB': 252.37504,
 'training_duration': '0:00:22.650910',
 'training_start_epoch': 0,
 'training_epochs': 999,
 'epoch': 999,
 'training_accuracy': 1.0,
 'training_loss': 0.018094655126333237,
 'training_cpu_memory_MB': 252.37504,
 'validation_accuracy': 1.0,
 'validation_loss': 0.018039997667074203,
 'best_validation_accuracy': 1.0,
 'best_validation_loss': 0.018039997667074203}

## Prediction

In [22]:
# Tag a sample sentence: take the highest-scoring label index for each token
# and map it back to its tag string through the model's vocabulary
predictor = SentenceTaggerPredictor(model, dataset_reader=reader)
prediction = predictor.predict("The dog ate the apple")
tag_logits = prediction['tag_logits']
tag_ids = np.argmax(tag_logits, axis=-1)
predicted_tags = []
for tag_id in tag_ids:
    predicted_tags.append(model.vocab.get_token_from_index(tag_id, 'labels'))
print(predicted_tags)

Spacy models 'en_core_web_sm' not found.  Downloading and installing.


[38;5;2m✔ Download and installation successful[0m
You can now load the model via spacy.load('en_core_web_sm')
[38;5;2m✔ Linking successful[0m
/Users/skhurana/wwc/venv/lib/python3.7/site-packages/en_core_web_sm -->
/Users/skhurana/wwc/venv/lib/python3.7/site-packages/spacy/data/en_core_web_sm
You can now load the model via spacy.load('en_core_web_sm')
['DET', 'NN', 'V', 'DET', 'NN']


## Saving and reloading

In [23]:
# Persist the model weights (state dict) to /tmp/model.th ...
with open("/tmp/model.th", 'wb') as f:
    torch.save(model.state_dict(), f)
# ... and the vocabulary to the /tmp/vocabulary directory
vocab.save_to_files("/tmp/vocabulary")

In [24]:
# Rebuild the model from the saved artefacts. Fresh embedding and LSTM modules are
# created here (instead of reusing the already-trained `word_embeddings` and `lstm`
# objects), so the comparison at the end really checks that the saved weights
# were restored from disk.
vocab2 = Vocabulary.from_files("/tmp/vocabulary")
token_embedding2 = Embedding(num_embeddings=vocab2.get_vocab_size('tokens'),
                             embedding_dim=EMBEDDING_DIM)
word_embeddings2 = BasicTextFieldEmbedder({"tokens": token_embedding2})
lstm2 = PytorchSeq2SeqWrapper(torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, batch_first=True))
model2 = LstmTagger(word_embeddings2, lstm2, vocab2)

with open("/tmp/model.th", 'rb') as f:
    # Load the tensors onto the CPU first; the model is moved to the GPU just
    # below when one is in use
    model2.load_state_dict(torch.load(f, map_location="cpu"))

if cuda_device > -1:
    model2.cuda(cuda_device)

# The reloaded model must reproduce the logits of the original model
predictor2 = SentenceTaggerPredictor(model2, dataset_reader=reader)
tag_logits2 = predictor2.predict("The dog ate the apple")['tag_logits']
np.testing.assert_array_almost_equal(tag_logits2, tag_logits)

# Using ELMo embeddings

In [35]:
from allennlp.modules.elmo import Elmo, batch_to_ids

# Remote locations of the pretrained ELMo options (JSON) and weights (HDF5) files
options_file = "https://allennlp.s3.amazonaws.com/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
weight_file = "https://allennlp.s3.amazonaws.com/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"


In [36]:
# Build the ELMo module from the options and weights files, with dropout set to 0.
# NOTE(review): the positional 2 is presumably the number of output representations --
# the 'elmo_representations' list displayed further down holds two tensors; confirm
# against the Elmo API.
elmo = Elmo(options_file, weight_file, 2, dropout=0)

100%|██████████| 336/336 [00:00<00:00, 1770459.98B/s]
100%|██████████| 374434792/374434792 [00:38<00:00, 9815985.01B/s] 


In [39]:
# Two pre-tokenised sentences of different lengths (3 and 2 tokens)
sentences = [['First', 'sentence', '.'], ['Another', '.']]
# Convert the token lists into the ids that the Elmo module takes as input
character_ids = batch_to_ids(sentences)

# Forward pass; the result is a dict, displayed in the next cell
embeddings = elmo(character_ids)


In [40]:
# Display the result: a dict holding 'elmo_representations' (a list of tensors) and 'mask'
embeddings
# These embeddings can now be used as the input to a model and trained along with it

{'elmo_representations': [tensor([[[ 0.1474, -0.1475,  0.1376,  ...,  0.0270, -0.4051, -0.0498],
           [ 0.2394,  0.0769,  0.4126,  ..., -0.1671, -0.1707,  0.3884],
           [-0.7602, -0.4944, -0.5355,  ..., -0.0803,  0.0361,  0.1128]],
  
          [[ 0.2603, -0.4437,  0.2726,  ..., -0.0830, -0.1522, -0.1361],
           [-0.7772, -0.4294, -0.2651,  ..., -0.0803,  0.0361,  0.1128],
           [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]],
         grad_fn=<CopySlices>),
  tensor([[[ 0.1474, -0.1475,  0.1376,  ...,  0.0270, -0.4051, -0.0498],
           [ 0.2394,  0.0769,  0.4126,  ..., -0.1671, -0.1707,  0.3884],
           [-0.7602, -0.4944, -0.5355,  ..., -0.0803,  0.0361,  0.1128]],
  
          [[ 0.2603, -0.4437,  0.2726,  ..., -0.0830, -0.1522, -0.1361],
           [-0.7772, -0.4294, -0.2651,  ..., -0.0803,  0.0361,  0.1128],
           [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]],
         grad_fn=<CopySlices>)],
 'mask': tensor([[1, 1,

### Interactively

In [41]:
# Import the submodule explicitly: a bare `import scipy` does not guarantee that
# `scipy.spatial.distance` (used for the cosine distances below) has been loaded
import scipy.spatial.distance
from allennlp.commands.elmo import ElmoEmbedder

In [47]:
# Embed three tokenised sentences that all contain the word "bank".
# Note that `elmo` is rebound here to an ElmoEmbedder instance.
elmo = ElmoEmbedder()
sentence1 = ["I", "went", "to", "the", "river", "bank", "today"]
sentence2 = "I went to the bank today to withdraw cash".split(' ')
sentence3 = "I sat by the bank today for breakfast".split(' ')
v1, v2, v3 = [elmo.embed_sentence(tokens)
              for tokens in (sentence1, sentence2, sentence3)]

In [45]:
# Shape is (number of layers, number of tokens, vector size); sentence1 has
# 7 tokens, which matches the middle dimension of the output below
v1.shape

(3, 7, 1024)

In [46]:
# Compute distance metric between the embeddings at the last layer (index 2)
# for the word "bank": token 5 of sentence1 vs. token 4 of sentence2
scipy.spatial.distance.cosine(v1[2][5], v2[2][4])

0.1516461968421936

In [48]:
# Same last-layer comparison for "bank": token 5 of sentence1 vs. token 4 of sentence3
scipy.spatial.distance.cosine(v1[2][5], v3[2][4])

0.2314092516899109