In [9]:
import os
#os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import torch

from typing import *

from allennlp.modules.seq2seq_encoders import *
from allennlp.training.metrics import Perplexity
from allennlp.data import DatasetReader, Instance, Field
from allennlp.data.fields import LabelField, TextField
from allennlp.data.token_indexers import *
from allennlp.data.tokenizers import *
from allennlp.data.data_loaders import *
from allennlp.modules.text_field_embedders import *
from allennlp.modules.token_embedders import *

import torch
from allennlp.data import Vocabulary, TextFieldTensors
from allennlp.models import Model
from allennlp.modules import TextFieldEmbedder, Seq2SeqEncoder
from allennlp.data.token_indexers import *
from allennlp.data.tokenizers import *
from allennlp.nn import util

from allennlp.data import DataLoader, DatasetReader, Instance, Vocabulary
from allennlp.data.data_loaders import MultiProcessDataLoader
from allennlp.models import Model
from allennlp.modules.seq2vec_encoders import BagOfEmbeddingsEncoder
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.training.trainer import GradientDescentTrainer, Trainer
from allennlp.training.optimizers import AdamOptimizer

from allennlp.predictors import *

import glob
import os

import numpy as np

In [142]:
# Root directory that contains the corpus splits (e.g. a `train` sub-directory).
# Override with the LM_DATA_ROOT environment variable to run on another machine;
# the default is the original location.
DATA_ROOT = os.environ.get('LM_DATA_ROOT', '/local1/d0/447-data')

# Marker token the default input tokenizer prepends to every sequence.
START_TOKEN = "@@START@@"

In [179]:
#@DatasetReader.register("all-data")
class TextReader(DatasetReader):
    """Dataset reader for character-level language modelling.

    Each non-empty line of every file under ``DATA_ROOT/<file_root>`` becomes one
    instance with:

    * ``text``   -- the line tokenized by ``tokenizer_in`` (by default characters
      preceded by ``START_TOKEN``), optionally with its last token dropped;
    * ``labels`` -- the line tokenized by ``tokenizer_out`` (by default plain
      characters), only when ``include_labels`` is true.

    With the default tokenizers and ``truncate_last_in=True`` the two sequences
    have the same length and ``labels[i]`` is the character that follows
    ``text[i]``.

    Args:
        tokenizer_in: Tokenizer for the model input. Defaults to a
            ``CharacterTokenizer`` that prepends ``START_TOKEN``.
        tokenizer_out: Tokenizer for the targets. Defaults to a plain
            ``CharacterTokenizer``.
        token_indexer_in: Indexer for the input field. Defaults to
            ``SingleIdTokenIndexer('tokens')``.
        token_indexer_out: Indexer for the target field. Defaults to
            ``SingleIdTokenIndexer('labels')``.
        max_tokens: If truthy, both sequences are cut to at most this many tokens.
        truncate_last_in: Drop the last input token before any ``max_tokens`` cut.
        include_labels: Whether to add the ``labels`` field to each instance.
        **kwargs: Forwarded to ``DatasetReader.__init__``.
    """

    def __init__(self, tokenizer_in: Tokenizer = None, tokenizer_out: Tokenizer = None, 
                 token_indexer_in = None, token_indexer_out = None, 
                 max_tokens: int = None, truncate_last_in: bool = True, 
                 include_labels: bool = True, **kwargs):
        super().__init__(**kwargs)
        self.tokenizer_in = tokenizer_in or CharacterTokenizer(start_tokens=[START_TOKEN])
        self.tokenizer_out = tokenizer_out or CharacterTokenizer()
        self.token_indexers_in = {"tokens": token_indexer_in or SingleIdTokenIndexer('tokens')}
        # Fixed: this previously fell back on `token_indexer_in`, so a caller-supplied
        # `token_indexer_out` was ignored and a custom input indexer leaked into the labels.
        self.token_indexers_out = {"labels": token_indexer_out or SingleIdTokenIndexer('labels')}
        self.max_tokens = max_tokens
        self.truncate_last_in = truncate_last_in
        self.include_labels = include_labels

    def text_to_instance(self, text: str) -> Instance:  # type: ignore
        """Build one instance (``text`` and, optionally, ``labels``) from a raw string."""
        tokens_in = self.tokenizer_in.tokenize(text)
        if self.truncate_last_in:
            # Drop the final input token; with the default start token this keeps
            # the input the same length as the target sequence.
            tokens_in = tokens_in[:-1]
        tokens_out = self.tokenizer_out.tokenize(text)
        if self.max_tokens:
            tokens_in = tokens_in[: self.max_tokens]
            tokens_out = tokens_out[: self.max_tokens]
        fields: Dict[str, Field] = {"text": TextField(tokens_in, self.token_indexers_in)}
        if self.include_labels:
            fields["labels"] = TextField(tokens_out, self.token_indexers_out)
        return Instance(fields)

    def _read(self, file_root: str) -> Iterable[Instance]:
        """Yield one instance per non-empty line of every file under ``DATA_ROOT/file_root``."""
        pattern = os.path.join(DATA_ROOT, file_root, '**/*')
        # glob returns paths in arbitrary order; sort so instance order is reproducible.
        for filename in sorted(glob.glob(pattern, recursive=True)):
            # A recursive '**/*' also matches directories, which cannot be opened.
            if not os.path.isfile(filename):
                continue
            # Explicit encoding so reading does not depend on the process locale.
            with open(filename, encoding='utf-8') as file:
                for line in file:
                    line = line.strip()
                    if not line:
                        continue
                    yield self.text_to_instance(line)

In [180]:
#@Model.register('lstm_lm')
class LSTMLM(Model):
    """Sequence language model: embed -> seq2seq encoder -> linear projection.

    At every input position the model outputs a distribution over the ``labels``
    vocabulary namespace. When targets are supplied it also returns the
    cross-entropy loss over non-padded positions and feeds it to a
    ``Perplexity`` metric.

    Args:
        vocab: Vocabulary; the size of its ``labels`` namespace sets the output width.
        embedder: Maps the ``text`` field tensors to embeddings.
        encoder: Seq2seq encoder applied to the embeddings; its output dimension
            sets the classifier's input width.
    """

    def __init__(
        self, vocab: Vocabulary, embedder: TextFieldEmbedder, encoder: Seq2SeqEncoder):
        super().__init__(vocab)
        self.embedder = embedder
        self.encoder = encoder
        self.vocab_size = vocab.get_vocab_size('labels')
        self.classifier = torch.nn.Linear(encoder.get_output_dim(), self.vocab_size)
        self.perplexity = Perplexity()

    def forward(self, text: TextFieldTensors, labels: TextFieldTensors = None) -> Dict[str, torch.Tensor]:
        """Run the model on a batch.

        Args:
            text: Input field tensors.
            labels: Optional target field tensors, read as ``labels['labels']['tokens']``
                and expected to line up position-by-position with ``text``.

        Returns:
            A dict with ``probs`` (softmax over the label vocabulary at every
            position) and, when ``labels`` is given, ``loss``.
        """
        # Shape: (batch_size, num_tokens, embedding_dim)
        embedded_text = self.embedder(text)
        # Shape: (batch_size, num_tokens)
        mask = util.get_text_field_mask(text)
        # Shape: (batch_size, num_tokens, encoding_dim)
        encoded_text = self.encoder(embedded_text, mask)
        # Shape: (batch_size, num_tokens, vocab_size)
        logits = self.classifier(encoded_text)
        # Shape: (batch_size, num_tokens, vocab_size)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        output = {}
        output['probs'] = probs
        if labels is not None:
            ignore_index = -100
            # Mark padded positions so they do not contribute to the loss.
            # masked_fill returns a new tensor, so the caller's batch is left
            # untouched (the previous in-place assignment overwrote it).
            targets = labels['labels']['tokens'].masked_fill(~mask, ignore_index)
            loss = torch.nn.functional.cross_entropy(
                logits.reshape(-1, self.vocab_size),
                targets.reshape(-1),
                ignore_index=ignore_index,
            )
            output["loss"] = loss
            self.perplexity(loss)
        return output

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the running perplexity, optionally resetting the accumulator."""
        return {"perplexity": self.perplexity.get_metric(reset)}

In [181]:
def build_vocab(file_root: str = 'train') -> Vocabulary:
    """Build a vocabulary from every instance of one data split.

    The reader is created with ``truncate_last_in=False`` so the input sequence
    keeps its final token; presumably this is so the last character of each line
    is also counted in the input namespace -- confirm before changing.

    Args:
        file_root: Split directory under ``DATA_ROOT`` to scan. Defaults to
            ``'train'``, the split that was previously hard-coded.

    Returns:
        The ``Vocabulary`` built from all instances of that split.
    """
    reader = TextReader(truncate_last_in=False)
    loader = MultiProcessDataLoader(reader, file_root, batch_size=8, shuffle=False)
    return Vocabulary.from_instances(loader.iter_instances())

In [182]:
# NOTE: `device` is only referenced by the disabled `model.to(device)` line below;
# training itself runs on CPU because the trainer is given cuda_device=-1.
device = torch.device('cuda')

# Model widths. The embedding width must equal the encoder's input size, so both
# are taken from one constant instead of repeating the literal.
EMBEDDING_DIM = 50
HIDDEN_DIM = 50

reader = TextReader()

train_loader = MultiProcessDataLoader(reader, 'train', batch_size=8, shuffle=True)
vocab = build_vocab()

# The input embedder indexes the 'tokens' namespace (the default namespace of
# get_vocab_size), named explicitly here.
vocab_size = vocab.get_vocab_size('tokens')
embedder = BasicTextFieldEmbedder(
    {'tokens': Embedding(embedding_dim=EMBEDDING_DIM, num_embeddings=vocab_size)}
)
encoder = LstmSeq2SeqEncoder(EMBEDDING_DIM, HIDDEN_DIM)
model = LSTMLM(vocab, embedder, encoder)
#model.to(device)

# Instances must be indexed with the vocabulary before batches can be produced.
train_loader.index_with(vocab)
parameters = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
trainer = GradientDescentTrainer(
    model=model,
    data_loader=train_loader,
    num_epochs=200,
    optimizer=AdamOptimizer(parameters),
    # No validation loader is passed, so this metric is not evaluated in this run.
    validation_metric='-perplexity',
    cuda_device=-1,
)
trainer.train()


loading instances: |          | 0/? [00:00<?, ?it/s]

loading instances: |          | 0/? [00:00<?, ?it/s]

building vocab: |          | 0/? [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

{'best_epoch': 199,
 'peak_worker_0_memory_MB': 2903.859375,
 'peak_gpu_0_memory_MB': 0.87255859375,
 'training_duration': '0:00:22.599617',
 'training_start_epoch': 0,
 'training_epochs': 199,
 'epoch': 199,
 'training_perplexity': 2.1682035641246893,
 'training_loss': 0.7738989740610123,
 'training_worker_0_memory_MB': 2903.859375,
 'training_gpu_0_memory_MB': 0.87255859375}

In [183]:

#@Predictor.register('my_predictor')
class MyPredictor(Predictor):
    """Predictor that feeds a raw sentence through the dataset reader and the model."""

    def predict(self, sentence):
        """Wrap ``sentence`` in the JSON payload and return ``predict_json``'s result."""
        payload = {'sentence': sentence}
        return self.predict_json(payload)

    def _json_to_instance(self, json_dict):
        """Turn the ``"sentence"`` entry of ``json_dict`` into an instance via the reader."""
        return self._dataset_reader.text_to_instance(json_dict["sentence"])


In [191]:
# Predict without labels and without dropping the last input token, so the final
# position's distribution is for the character that follows the prompt.
pred = MyPredictor(model, TextReader(truncate_last_in=False, include_labels=False))
outputs = pred.predict('LST')

# Pair every label-vocabulary token with its probability at the last position,
# ordered from most to least likely.
last_step_probs = outputs['probs'][-1]
pairs = sorted(
    (
        (vocab.get_token_from_index(index, 'labels'), probability)
        for index, probability in enumerate(last_step_probs)
    ),
    key=lambda pair: pair[1],
    reverse=True,
)

# Show the ten most likely next characters.
for item in pairs[:10]:
    print(item)

('M', 0.3894152343273163)
('e', 0.23108302056789398)
('L', 0.08222412317991257)
('(', 0.04794950783252716)
('T', 0.04179971665143967)
('a', 0.02521764487028122)
('S', 0.023489294573664665)
(':', 0.022975564002990723)
('h', 0.021479301154613495)
('o', 0.020070254802703857)
