In [10]:
import itertools
import os
import shutil

import torch
import torch.optim as optim
from allennlp.data import DatasetReader
from allennlp.data.dataset_readers.seq2seq import Seq2SeqDatasetReader
from allennlp.data.iterators import BasicIterator
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers.character_tokenizer import CharacterTokenizer
from allennlp.data.tokenizers.word_tokenizer import WordTokenizer
from allennlp.data.vocabulary import Vocabulary
from allennlp.models.archival import archive_model
from allennlp.models.archival import load_archive
from allennlp.models.encoder_decoders.simple_seq2seq import SimpleSeq2Seq
from allennlp.modules.attention import LinearAttention, BilinearAttention, DotProductAttention
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper, StackedSelfAttentionEncoder
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.nn.activations import Activation
from allennlp.predictors import SimpleSeq2SeqPredictor
from allennlp.training.trainer import Trainer
from tqdm import tqdm as tqdm

In [11]:
# GPU ordinal shared by the cells that place the model and run the trainer.
CUDA_DEVICE = 0

# Source strings are split into individual characters. No target tokenizer is
# passed, so targets fall back to the reader's default.
# NOTE(review): confirm the default target tokenization matches the data format.
reader = Seq2SeqDatasetReader(source_tokenizer=CharacterTokenizer())

# Read the three splits in the same order as before: train, dev, test.
train_dataset, validation_dataset, test_dataset = (
    reader.read(split_path)
    for split_path in ('data/train.txt', 'data/dev.txt', 'data/test.txt')
)

# The vocabulary is built from the train and dev instances only; the test
# split is deliberately left out.
vocab = Vocabulary.from_instances([*train_dataset, *validation_dataset])

20000it [00:00, 24330.91it/s]
2000it [00:00, 48780.62it/s]
2000it [00:00, 9523.82it/s]
100%|█████████████████████████████████████████████████████████████████████████| 22000/22000 [00:00<00:00, 99999.99it/s]


In [13]:
# Model hyper-parameters, named once so the dependent sizes below stay in sync.
EMBEDDING_DIM = 20
ENCODER_HIDDEN_DIM = 100                      # per LSTM direction
ENCODER_OUTPUT_DIM = 2 * ENCODER_HIDDEN_DIM   # bidirectional => 200
MAX_DECODING_STEPS = 21                       # upper bound on generated tokens

embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                      embedding_dim=EMBEDDING_DIM)
source_embedder = BasicTextFieldEmbedder({"tokens": embedding})

# The previous `dropout=0.3` argument was dropped: torch.nn.LSTM only applies
# dropout *between* stacked layers, so with the default single layer it did
# nothing except emit a UserWarning. The network itself is unchanged. To
# actually regularise the encoder, stack layers (num_layers > 1) or apply
# dropout outside the LSTM.
encoder = PytorchSeq2SeqWrapper(
    torch.nn.LSTM(EMBEDDING_DIM, ENCODER_HIDDEN_DIM,
                  batch_first=True, bidirectional=True))

# Attention scores the decoder state against the encoder outputs; both sides
# are ENCODER_OUTPUT_DIM wide here (the printed model shows a 200-unit decoder cell).
attention = LinearAttention(ENCODER_OUTPUT_DIM, ENCODER_OUTPUT_DIM)

model = SimpleSeq2Seq(vocab, source_embedder, encoder,
                      max_decoding_steps=MAX_DECODING_STEPS,
                      attention=attention,
                      beam_size=5,
                      scheduled_sampling_ratio=0.5,
                      use_bleu=True)

optimizer = optim.Adam(model.parameters(), lr=0.01)

# One very large batch size: the training log shows 5 batches per epoch.
iterator = BasicIterator(batch_size=4096)
iterator.index_with(vocab)

In [14]:
# Move the model to the GPU selected by CUDA_DEVICE. As the cell's last
# expression, the returned module is also displayed (see the printed layers).
model.to(device=CUDA_DEVICE)

SimpleSeq2Seq(
  (_source_embedder): BasicTextFieldEmbedder(
    (token_embedder_tokens): Embedding()
  )
  (_encoder): PytorchSeq2SeqWrapper(
    (_module): LSTM(20, 100, batch_first=True, dropout=0.3, bidirectional=True)
  )
  (_attention): LinearAttention()
  (_target_embedder): Embedding()
  (_decoder_cell): LSTMCell(220, 200)
  (_output_projection_layer): Linear(in_features=200, out_features=141, bias=True)
)

In [16]:
# Gather the training configuration in one place before building the trainer.
trainer_options = dict(
    # what to train, and on which data
    model=model,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_dataset,
    validation_dataset=validation_dataset,
    cuda_device=CUDA_DEVICE,
    # schedule and checkpointing; everything is written under 'log'
    # NOTE(review): model_save_interval is presumably a wall-clock interval in
    # seconds rather than an epoch count — confirm against the Trainer docs.
    num_epochs=100,
    model_save_interval=10,
    num_serialized_models_to_keep=3,
    serialization_dir='log',
)
trainer = Trainer(**trainer_options)

# Kept as the final expression so whatever train() returns is shown by the cell.
trainer.train()

loss: 3.9660 ||: 100%|███████████████████████████████████████████████████████████████████| 5/5 [00:02<00:00,  2.03it/s]
BLEU: 0.0000, loss: 3.2764 ||: 100%|█████████████████████████████████████████████████████| 1/1 [00:04<00:00,  4.05s/it]
loss: 3.2246 ||: 100%|███████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00,  2.93it/s]
BLEU: 0.0136, loss: 3.1169 ||: 100%|█████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.39s/it]
loss: 3.1057 ||: 100%|███████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00,  2.94it/s]
BLEU: 0.0000, loss: 3.0275 ||: 100%|█████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.71s/it]
loss: 3.0314 ||: 100%|███████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00,  2.63it/s]
BLEU: 0.0000, loss: 3.0279 ||: 100%|█████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.36s/it]
loss: 2.9909 ||: 100%|██████████████████

KeyboardInterrupt: 

In [39]:
# Assemble everything the archive step needs inside the serialization directory.
# The copy is done in Python rather than with `!cp`, so the cell does not depend
# on a Unix shell being available on the host.
# NOTE(review): simple_seq2seq.json must describe the same architecture as the
# model trained in this notebook, or loading the archived weights will not match.
shutil.copyfile("simple_seq2seq.json", os.path.join("log", "config.json"))

# The vocabulary is saved explicitly next to the config before archiving.
vocab.save_to_files(os.path.join("log", "vocabulary"))

# Bundle the contents of 'log' into an archive; the next cell loads
# log/model.tar.gz.
archive_model("log")

In [None]:
# Restore the archived model onto the same device as the rest of the notebook.
# The device is taken from CUDA_DEVICE instead of a second hard-coded 0, so
# changing the constant in one place moves training and evaluation together.
mymodel = load_archive('log/model.tar.gz', cuda_device=CUDA_DEVICE)

In [44]:
# Put the restored model into evaluation mode before any prediction is made.
# This is deliberately not the last line, so the cell displays no module repr.
mymodel.model.eval()

# Build the reader from the "dataset_reader" section of the archived config
# (rather than reusing the `reader` object created for training), then load
# the test split with it.
dataset_reader_params = mymodel.config["dataset_reader"]
dataset_reader = DatasetReader.from_params(params=dataset_reader_params)
test_data = dataset_reader.read(file_path='data/test.txt')

2000it [00:00, 6711.44it/s]


In [45]:
# Exact-match accuracy on the test split: a prediction counts only when the
# whole predicted string equals the whole gold string.
correct = 0
total = 0
for instance in tqdm(test_data):
    # Drop the first and last target tokens before joining — presumably the
    # start/end sentinels added by the reader (verify against the reader config).
    gold = ''.join(map(str, instance["target_tokens"][1:-1]))

    # One forward pass per instance; the predicted tokens are joined into a string.
    output = mymodel.model.forward_on_instance(instance)
    pred = ''.join(map(str, output["predicted_tokens"]))

    correct += int(gold == pred)
    total += 1

print(correct / total)

100%|██████████████████████████████████████████████████████████████████████████████| 2000/2000 [00:46<00:00, 42.60it/s]

0.4465





In [46]:
# Single-word sanity check: turn one raw string into an instance with the
# archived reader, then show the model's full output dict for it.
sample_instance = dataset_reader.text_to_instance("думала")
mymodel.model.forward_on_instance(sample_instance)

{'class_log_probabilities': array([-0.27276325, -2.9633589 , -3.3173828 , -3.5159779 , -4.247253  ],
       dtype=float32),
 'predictions': array([[18, 21, 17,  5,  8, 16,  3],
        [18, 21, 17,  5,  8,  3,  3],
        [18, 21, 17,  5,  8,  5,  3],
        [18, 21, 17,  5, 13, 16,  3],
        [18, 21, 17,  5, 13,  3,  3]], dtype=int64),
 'predicted_tokens': ['д', 'у', 'м', 'а', 'т', 'ь']}