In [5]:
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token
from allennlp.data import Instance
from allennlp.data.fields import TextField, SequenceLabelField


In [9]:
TextField([(Token("hello"))], {"tokens": SingleIdTokenIndexer()}).tokens


[hello]

In [1]:


from typing import Iterator, List, Dict

import torch
import torch.optim as optim
import numpy as np

from allennlp.data import Instance
from allennlp.data.fields import TextField, SequenceLabelField

from allennlp.data.dataset_readers import DatasetReader

from allennlp.common.file_utils import cached_path

from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token

from allennlp.data.vocabulary import Vocabulary

from allennlp.models import Model

from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits

from allennlp.training.metrics import CategoricalAccuracy

from allennlp.data.iterators import BucketIterator

from allennlp.training.trainer import Trainer

from allennlp.predictors import SentenceTaggerPredictor

torch.manual_seed(1)

@DatasetReader.register('pos-tutorial')
class PosDatasetReader(DatasetReader):
    """
    DatasetReader for PoS tagging data, one sentence per line, like

        The###DET dog###NN ate###V the###DET apple###NN
    """

    def __init__(self, token_indexers: Dict[str, TokenIndexer] = None) -> None:
        super().__init__(lazy=False)
        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}

    def text_to_instance(self, tokens: List[Token], tags: List[str] = None) -> Instance:
        sentence_field = TextField(tokens, self.token_indexers)
        fields = {"sentence": sentence_field}

        if tags:
            label_field = SequenceLabelField(labels=tags, sequence_field=sentence_field)
            fields["labels"] = label_field

        return Instance(fields)

    def _read(self, file_path: str) -> Iterator[Instance]:
        with open(file_path) as f:
            for line in f:
                pairs = line.strip().split()
                sentence, tags = zip(*(pair.split("###") for pair in pairs))
                yield self.text_to_instance([Token(word) for word in sentence], tags)


@Model.register('lstm-tagger')                
class LstmTagger(Model):

    def __init__(self,

                 word_embeddings: TextFieldEmbedder,

                 encoder: Seq2SeqEncoder,

                 vocab: Vocabulary) -> None:

        super().__init__(vocab)
        self.word_embeddings = word_embeddings
        self.encoder = encoder

        self.hidden2tag = torch.nn.Linear(in_features=encoder.get_output_dim(),
                                          out_features=vocab.get_vocab_size('labels'))

        self.accuracy = CategoricalAccuracy()

    def forward(self,
                sentence: Dict[str, torch.Tensor],
                labels: torch.Tensor = None) -> Dict[str, torch.Tensor]:

        mask = get_text_field_mask(sentence)

        embeddings = self.word_embeddings(sentence)

        encoder_out = self.encoder(embeddings, mask)

        tag_logits = self.hidden2tag(encoder_out)
        output = {"tag_logits": tag_logits}

        if labels is not None:
            self.accuracy(tag_logits, labels, mask)
            output["loss"] = sequence_cross_entropy_with_logits(tag_logits, labels, mask)

        return output

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {"accuracy": self.accuracy.get_metric(reset)}


Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


93B [00:00, 27927.99B/s]             
2it [00:00, 2897.62it/s]
93B [00:00, 18824.88B/s]             
2it [00:00, 2066.67it/s]
100%|██████████| 4/4 [00:00<00:00, 18620.66it/s]
accuracy: 0.3333, loss: 1.1685 ||: 100%|██████████| 1/1 [00:00<00:00, 18.67it/s]
accuracy: 0.3333, loss: 1.1592 ||: 100%|██████████| 1/1 [00:00<00:00, 506.86it/s]
accuracy: 0.3333, loss: 1.1604 ||: 100%|██████████| 1/1 [00:00<00:00, 163.97it/s]
accuracy: 0.3333, loss: 1.1516 ||: 100%|██████████| 1/1 [00:00<00:00, 622.49it/s]
accuracy: 0.3333, loss: 1.1529 ||: 100%|██████████| 1/1 [00:00<00:00, 204.66it/s]
accuracy: 0.3333, loss: 1.1445 ||: 100%|██████████| 1/1 [00:00<00:00, 548.78it/s]
accuracy: 0.3333, loss: 1.1458 ||: 100%|██████████| 1/1 [00:00<00:00, 132.70it/s]
accuracy: 0.3333, loss: 1.1379 ||: 100%|██████████| 1/1 [00:00<00:00, 474.52it/s]
accuracy: 0.3333, loss: 1.1391 ||: 100%|██████████| 1/1 [00:00<00:00, 142.51it/s]
accuracy: 0.3333, loss: 1.1316 ||: 100%|██████████| 1/1 [00:00<00:00, 81.12it/s]
accurac

accuracy: 0.4444, loss: 1.0378 ||: 100%|██████████| 1/1 [00:00<00:00, 200.97it/s]
accuracy: 0.4444, loss: 1.0361 ||: 100%|██████████| 1/1 [00:00<00:00, 256.39it/s]
accuracy: 0.4444, loss: 1.0374 ||: 100%|██████████| 1/1 [00:00<00:00, 91.05it/s]
accuracy: 0.4444, loss: 1.0357 ||: 100%|██████████| 1/1 [00:00<00:00, 544.29it/s]
accuracy: 0.4444, loss: 1.0370 ||: 100%|██████████| 1/1 [00:00<00:00, 282.29it/s]
accuracy: 0.4444, loss: 1.0353 ||: 100%|██████████| 1/1 [00:00<00:00, 590.58it/s]
accuracy: 0.4444, loss: 1.0366 ||: 100%|██████████| 1/1 [00:00<00:00, 183.01it/s]
accuracy: 0.4444, loss: 1.0349 ||: 100%|██████████| 1/1 [00:00<00:00, 536.97it/s]
accuracy: 0.4444, loss: 1.0362 ||: 100%|██████████| 1/1 [00:00<00:00, 269.61it/s]
accuracy: 0.4444, loss: 1.0345 ||: 100%|██████████| 1/1 [00:00<00:00, 625.55it/s]
accuracy: 0.4444, loss: 1.0357 ||: 100%|██████████| 1/1 [00:00<00:00, 220.89it/s]
accuracy: 0.4444, loss: 1.0341 ||: 100%|██████████| 1/1 [00:00<00:00, 627.42it/s]
accuracy: 0.4444,

accuracy: 0.4444, loss: 0.9415 ||: 100%|██████████| 1/1 [00:00<00:00, 250.81it/s]
accuracy: 0.4444, loss: 0.9390 ||: 100%|██████████| 1/1 [00:00<00:00, 525.87it/s]
accuracy: 0.4444, loss: 0.9396 ||: 100%|██████████| 1/1 [00:00<00:00, 263.31it/s]
accuracy: 0.4444, loss: 0.9370 ||: 100%|██████████| 1/1 [00:00<00:00, 593.93it/s]
accuracy: 0.4444, loss: 0.9375 ||: 100%|██████████| 1/1 [00:00<00:00, 177.35it/s]
accuracy: 0.4444, loss: 0.9349 ||: 100%|██████████| 1/1 [00:00<00:00, 286.40it/s]
accuracy: 0.4444, loss: 0.9355 ||: 100%|██████████| 1/1 [00:00<00:00, 142.73it/s]
accuracy: 0.4444, loss: 0.9329 ||: 100%|██████████| 1/1 [00:00<00:00, 332.30it/s]
accuracy: 0.4444, loss: 0.9334 ||: 100%|██████████| 1/1 [00:00<00:00, 184.62it/s]
accuracy: 0.4444, loss: 0.9307 ||: 100%|██████████| 1/1 [00:00<00:00, 605.76it/s]
accuracy: 0.4444, loss: 0.9313 ||: 100%|██████████| 1/1 [00:00<00:00, 239.41it/s]
accuracy: 0.4444, loss: 0.9286 ||: 100%|██████████| 1/1 [00:00<00:00, 580.13it/s]
accuracy: 0.4444

accuracy: 0.6667, loss: 0.6275 ||: 100%|██████████| 1/1 [00:00<00:00, 158.14it/s]
accuracy: 0.7778, loss: 0.6253 ||: 100%|██████████| 1/1 [00:00<00:00, 666.19it/s]
accuracy: 0.6667, loss: 0.6238 ||: 100%|██████████| 1/1 [00:00<00:00, 206.54it/s]
accuracy: 0.7778, loss: 0.6217 ||: 100%|██████████| 1/1 [00:00<00:00, 140.93it/s]
accuracy: 0.6667, loss: 0.6202 ||: 100%|██████████| 1/1 [00:00<00:00, 258.57it/s]
accuracy: 0.7778, loss: 0.6181 ||: 100%|██████████| 1/1 [00:00<00:00, 563.37it/s]
accuracy: 0.7778, loss: 0.6166 ||: 100%|██████████| 1/1 [00:00<00:00, 85.42it/s]
accuracy: 0.7778, loss: 0.6146 ||: 100%|██████████| 1/1 [00:00<00:00, 582.95it/s]
accuracy: 0.7778, loss: 0.6130 ||: 100%|██████████| 1/1 [00:00<00:00, 81.39it/s]
accuracy: 0.7778, loss: 0.6110 ||: 100%|██████████| 1/1 [00:00<00:00, 564.13it/s]
accuracy: 0.7778, loss: 0.6093 ||: 100%|██████████| 1/1 [00:00<00:00, 136.37it/s]
accuracy: 0.7778, loss: 0.6074 ||: 100%|██████████| 1/1 [00:00<00:00, 631.96it/s]
accuracy: 0.7778, 

accuracy: 1.0000, loss: 0.2966 ||: 100%|██████████| 1/1 [00:00<00:00, 457.69it/s]
accuracy: 1.0000, loss: 0.2958 ||: 100%|██████████| 1/1 [00:00<00:00, 158.19it/s]
accuracy: 1.0000, loss: 0.2940 ||: 100%|██████████| 1/1 [00:00<00:00, 211.00it/s]
accuracy: 1.0000, loss: 0.2932 ||: 100%|██████████| 1/1 [00:00<00:00, 155.13it/s]
accuracy: 1.0000, loss: 0.2914 ||: 100%|██████████| 1/1 [00:00<00:00, 350.55it/s]
accuracy: 1.0000, loss: 0.2906 ||: 100%|██████████| 1/1 [00:00<00:00, 84.35it/s]
accuracy: 1.0000, loss: 0.2889 ||: 100%|██████████| 1/1 [00:00<00:00, 201.84it/s]
accuracy: 1.0000, loss: 0.2881 ||: 100%|██████████| 1/1 [00:00<00:00, 149.42it/s]
accuracy: 1.0000, loss: 0.2863 ||: 100%|██████████| 1/1 [00:00<00:00, 79.00it/s]
accuracy: 1.0000, loss: 0.2855 ||: 100%|██████████| 1/1 [00:00<00:00, 198.40it/s]
accuracy: 1.0000, loss: 0.2838 ||: 100%|██████████| 1/1 [00:00<00:00, 504.30it/s]
accuracy: 1.0000, loss: 0.2830 ||: 100%|██████████| 1/1 [00:00<00:00, 116.54it/s]
accuracy: 1.0000, 

accuracy: 1.0000, loss: 0.1285 ||: 100%|██████████| 1/1 [00:00<00:00, 482.88it/s]
accuracy: 1.0000, loss: 0.1286 ||: 100%|██████████| 1/1 [00:00<00:00, 126.44it/s]
accuracy: 1.0000, loss: 0.1275 ||: 100%|██████████| 1/1 [00:00<00:00, 243.91it/s]
accuracy: 1.0000, loss: 0.1276 ||: 100%|██████████| 1/1 [00:00<00:00, 122.34it/s]
accuracy: 1.0000, loss: 0.1265 ||: 100%|██████████| 1/1 [00:00<00:00, 351.52it/s]
accuracy: 1.0000, loss: 0.1267 ||: 100%|██████████| 1/1 [00:00<00:00, 230.04it/s]
accuracy: 1.0000, loss: 0.1256 ||: 100%|██████████| 1/1 [00:00<00:00, 495.66it/s]
accuracy: 1.0000, loss: 0.1257 ||: 100%|██████████| 1/1 [00:00<00:00, 95.72it/s]
accuracy: 1.0000, loss: 0.1247 ||: 100%|██████████| 1/1 [00:00<00:00, 568.87it/s]
accuracy: 1.0000, loss: 0.1248 ||: 100%|██████████| 1/1 [00:00<00:00, 104.51it/s]
accuracy: 1.0000, loss: 0.1237 ||: 100%|██████████| 1/1 [00:00<00:00, 629.02it/s]
accuracy: 1.0000, loss: 0.1239 ||: 100%|██████████| 1/1 [00:00<00:00, 93.40it/s]
accuracy: 1.0000, 

accuracy: 1.0000, loss: 0.0907 ||: 100%|██████████| 1/1 [00:00<00:00, 710.54it/s]
accuracy: 1.0000, loss: 0.0908 ||: 100%|██████████| 1/1 [00:00<00:00, 281.42it/s]
accuracy: 1.0000, loss: 0.0901 ||: 100%|██████████| 1/1 [00:00<00:00, 614.10it/s]
accuracy: 1.0000, loss: 0.0902 ||: 100%|██████████| 1/1 [00:00<00:00, 258.91it/s]
accuracy: 1.0000, loss: 0.0895 ||: 100%|██████████| 1/1 [00:00<00:00, 315.88it/s]
accuracy: 1.0000, loss: 0.0897 ||: 100%|██████████| 1/1 [00:00<00:00, 94.43it/s]
accuracy: 1.0000, loss: 0.0890 ||: 100%|██████████| 1/1 [00:00<00:00, 600.99it/s]
accuracy: 1.0000, loss: 0.0891 ||: 100%|██████████| 1/1 [00:00<00:00, 101.10it/s]
accuracy: 1.0000, loss: 0.0884 ||: 100%|██████████| 1/1 [00:00<00:00, 253.16it/s]
accuracy: 1.0000, loss: 0.0885 ||: 100%|██████████| 1/1 [00:00<00:00, 95.15it/s]
accuracy: 1.0000, loss: 0.0878 ||: 100%|██████████| 1/1 [00:00<00:00, 680.78it/s]
accuracy: 1.0000, loss: 0.0880 ||: 100%|██████████| 1/1 [00:00<00:00, 143.24it/s]
accuracy: 1.0000, 

accuracy: 1.0000, loss: 0.0531 ||: 100%|██████████| 1/1 [00:00<00:00, 259.48it/s]
accuracy: 1.0000, loss: 0.0528 ||: 100%|██████████| 1/1 [00:00<00:00, 610.08it/s]
accuracy: 1.0000, loss: 0.0529 ||: 100%|██████████| 1/1 [00:00<00:00, 232.09it/s]
accuracy: 1.0000, loss: 0.0526 ||: 100%|██████████| 1/1 [00:00<00:00, 546.49it/s]
accuracy: 1.0000, loss: 0.0527 ||: 100%|██████████| 1/1 [00:00<00:00, 237.37it/s]
accuracy: 1.0000, loss: 0.0523 ||: 100%|██████████| 1/1 [00:00<00:00, 323.66it/s]
accuracy: 1.0000, loss: 0.0524 ||: 100%|██████████| 1/1 [00:00<00:00, 90.71it/s]
accuracy: 1.0000, loss: 0.0521 ||: 100%|██████████| 1/1 [00:00<00:00, 527.92it/s]
accuracy: 1.0000, loss: 0.0522 ||: 100%|██████████| 1/1 [00:00<00:00, 215.20it/s]
accuracy: 1.0000, loss: 0.0518 ||: 100%|██████████| 1/1 [00:00<00:00, 87.98it/s]
accuracy: 1.0000, loss: 0.0519 ||: 100%|██████████| 1/1 [00:00<00:00, 225.79it/s]
accuracy: 1.0000, loss: 0.0516 ||: 100%|██████████| 1/1 [00:00<00:00, 99.03it/s]
accuracy: 1.0000, l

accuracy: 1.0000, loss: 0.0356 ||: 100%|██████████| 1/1 [00:00<00:00, 441.65it/s]
accuracy: 1.0000, loss: 0.0357 ||: 100%|██████████| 1/1 [00:00<00:00, 250.27it/s]
accuracy: 1.0000, loss: 0.0355 ||: 100%|██████████| 1/1 [00:00<00:00, 238.12it/s]
accuracy: 1.0000, loss: 0.0356 ||: 100%|██████████| 1/1 [00:00<00:00, 196.78it/s]
accuracy: 1.0000, loss: 0.0354 ||: 100%|██████████| 1/1 [00:00<00:00, 330.44it/s]
accuracy: 1.0000, loss: 0.0354 ||: 100%|██████████| 1/1 [00:00<00:00, 113.95it/s]
accuracy: 1.0000, loss: 0.0352 ||: 100%|██████████| 1/1 [00:00<00:00, 398.85it/s]
accuracy: 1.0000, loss: 0.0353 ||: 100%|██████████| 1/1 [00:00<00:00, 123.95it/s]
accuracy: 1.0000, loss: 0.0351 ||: 100%|██████████| 1/1 [00:00<00:00, 666.29it/s]
accuracy: 1.0000, loss: 0.0352 ||: 100%|██████████| 1/1 [00:00<00:00, 178.53it/s]
accuracy: 1.0000, loss: 0.0350 ||: 100%|██████████| 1/1 [00:00<00:00, 242.56it/s]
accuracy: 1.0000, loss: 0.0351 ||: 100%|██████████| 1/1 [00:00<00:00, 223.28it/s]
accuracy: 1.0000

accuracy: 1.0000, loss: 0.0263 ||: 100%|██████████| 1/1 [00:00<00:00, 235.61it/s]
accuracy: 1.0000, loss: 0.0262 ||: 100%|██████████| 1/1 [00:00<00:00, 363.65it/s]
accuracy: 1.0000, loss: 0.0262 ||: 100%|██████████| 1/1 [00:00<00:00, 103.16it/s]
accuracy: 1.0000, loss: 0.0261 ||: 100%|██████████| 1/1 [00:00<00:00, 261.31it/s]
accuracy: 1.0000, loss: 0.0261 ||: 100%|██████████| 1/1 [00:00<00:00, 272.82it/s]
accuracy: 1.0000, loss: 0.0260 ||: 100%|██████████| 1/1 [00:00<00:00, 438.64it/s]
accuracy: 1.0000, loss: 0.0261 ||: 100%|██████████| 1/1 [00:00<00:00, 298.51it/s]
accuracy: 1.0000, loss: 0.0259 ||: 100%|██████████| 1/1 [00:00<00:00, 594.09it/s]
accuracy: 1.0000, loss: 0.0260 ||: 100%|██████████| 1/1 [00:00<00:00, 243.08it/s]
accuracy: 1.0000, loss: 0.0259 ||: 100%|██████████| 1/1 [00:00<00:00, 577.25it/s]
accuracy: 1.0000, loss: 0.0259 ||: 100%|██████████| 1/1 [00:00<00:00, 287.85it/s]
accuracy: 1.0000, loss: 0.0258 ||: 100%|██████████| 1/1 [00:00<00:00, 648.27it/s]
accuracy: 1.0000

accuracy: 1.0000, loss: 0.0205 ||: 100%|██████████| 1/1 [00:00<00:00, 180.41it/s]
accuracy: 1.0000, loss: 0.0204 ||: 100%|██████████| 1/1 [00:00<00:00, 593.25it/s]
accuracy: 1.0000, loss: 0.0204 ||: 100%|██████████| 1/1 [00:00<00:00, 244.24it/s]
accuracy: 1.0000, loss: 0.0204 ||: 100%|██████████| 1/1 [00:00<00:00, 521.36it/s]
accuracy: 1.0000, loss: 0.0204 ||: 100%|██████████| 1/1 [00:00<00:00, 212.25it/s]
accuracy: 1.0000, loss: 0.0203 ||: 100%|██████████| 1/1 [00:00<00:00, 576.85it/s]
accuracy: 1.0000, loss: 0.0203 ||: 100%|██████████| 1/1 [00:00<00:00, 240.93it/s]
accuracy: 1.0000, loss: 0.0203 ||: 100%|██████████| 1/1 [00:00<00:00, 581.98it/s]
accuracy: 1.0000, loss: 0.0203 ||: 100%|██████████| 1/1 [00:00<00:00, 142.25it/s]
accuracy: 1.0000, loss: 0.0202 ||: 100%|██████████| 1/1 [00:00<00:00, 333.07it/s]
accuracy: 1.0000, loss: 0.0203 ||: 100%|██████████| 1/1 [00:00<00:00, 232.65it/s]
accuracy: 1.0000, loss: 0.0202 ||: 100%|██████████| 1/1 [00:00<00:00, 532.61it/s]
accuracy: 1.0000

[38;5;2m✔ Download and installation successful[0m
You can now load the model via spacy.load('en_core_web_sm')
[38;5;2m✔ Linking successful[0m
/home/kir/miniconda3/envs/allennlp/lib/python3.6/site-packages/en_core_web_sm
-->
/home/kir/miniconda3/envs/allennlp/lib/python3.6/site-packages/spacy/data/en_core_web_sm
You can now load the model via spacy.load('en_core_web_sm')
['DET', 'NN', 'V', 'DET', 'NN']


In [None]:

reader = PosDatasetReader()

train_dataset = reader.read(cached_path(
    'https://raw.githubusercontent.com/allenai/allennlp'
    '/master/tutorials/tagger/training.txt'))
validation_dataset = reader.read(cached_path(
    'https://raw.githubusercontent.com/allenai/allennlp'
    '/master/tutorials/tagger/validation.txt'))

vocab = Vocabulary.from_instances(train_dataset + validation_dataset)

EMBEDDING_DIM = 6
HIDDEN_DIM = 6

token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                            embedding_dim=EMBEDDING_DIM)
word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})

lstm = PytorchSeq2SeqWrapper(torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, batch_first=True))

model = LstmTagger(word_embeddings, lstm, vocab)

if torch.cuda.is_available():
    cuda_device = 0

    model = model.cuda(cuda_device)
else:

    cuda_device = -1

optimizer = optim.SGD(model.parameters(), lr=0.1)

iterator = BucketIterator(batch_size=2, sorting_keys=[("sentence", "num_tokens")])

iterator.index_with(vocab)

trainer = Trainer(model=model,
                  optimizer=optimizer,
                  iterator=iterator,
                  train_dataset=train_dataset,
                  validation_dataset=validation_dataset,
                  patience=10,
                  num_epochs=1000,
                  cuda_device=cuda_device)

trainer.train()

predictor = SentenceTaggerPredictor(model, dataset_reader=reader)

tag_logits = predictor.predict("The dog ate the apple")['tag_logits']

tag_ids = np.argmax(tag_logits, axis=-1)

print([model.vocab.get_token_from_index(i, 'labels') for i in tag_ids])

# Here's how to save the model.
with open("/tmp/model.th", 'wb') as f:
    torch.save(model.state_dict(), f)

vocab.save_to_files("/tmp/vocabulary")

# And here's how to reload the model.
vocab2 = Vocabulary.from_files("/tmp/vocabulary")

model2 = LstmTagger(word_embeddings, lstm, vocab2)

with open("/tmp/model.th", 'rb') as f:
    model2.load_state_dict(torch.load(f))

if cuda_device > -1:
    model2.cuda(cuda_device)

predictor2 = SentenceTaggerPredictor(model2, dataset_reader=reader)
tag_logits2 = predictor2.predict("The dog ate the apple")['tag_logits']
np.testing.assert_array_almost_equal(tag_logits2, tag_logits)



In [1]:
import unicodedata

import string


from typing import Iterator, List, Dict

import torch
import torch.optim as optim
import numpy as np

from allennlp.data import Instance
from allennlp.data.fields import TextField, SequenceLabelField

from allennlp.data.dataset_readers import DatasetReader

from allennlp.common.file_utils import cached_path

from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token

from allennlp.data.vocabulary import Vocabulary

from allennlp.models import Model

from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits

from allennlp.training.metrics import CategoricalAccuracy

from allennlp.data.iterators import BucketIterator

from allennlp.training.trainer import Trainer

from allennlp.predictors import SentenceTaggerPredictor

torch.manual_seed(1)


all_letters = string.ascii_letters + " .,;'"
# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
        and c in all_letters
    )


@DatasetReader.register('names-tagger-reader')
class NamesDatasetReader(DatasetReader):
    """
    DatasetReader for names  
    """

    def __init__(self, token_indexers: Dict[str, TokenIndexer] = None) -> None:
        super().__init__(lazy=False)
        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}

    def text_to_instance(self, tokens: List[Token], tags: List[str] = None) -> Instance:
        name_field = TextField(tokens, self.token_indexers)
        fields = {"name": name_field}

        if tags:
            label_field = SequenceLabelField(labels=tags, sequence_field=name_field)
            fields["label"] = label_field

        return Instance(fields)


    def _read(self, file_path: str) -> Iterator[Instance]:
        with open(cached_path(file_path)) as f:
            for i in range(10):
                names_list = list()
                tags_list = list()
                for line in f:
                    name, tag = unicodeToAscii(line).strip().split()
                    names_list.append(name)
                    tags_list.append(tag)
                    yield self.text_to_instance([Token(name) for name in names_list], tags_list)

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [2]:
reader = NamesDatasetReader()

train_dataset = reader.read('data/train.txt')
validation_dataset = reader.read('data/val.txt')



vocab = Vocabulary.from_instances(train_dataset + validation_dataset)

0it [00:00, ?it/s]

NameError: name 'tag' is not defined