<a href="https://colab.research.google.com/github/RtjShreyD/Eng-Mandarin/blob/master/Train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# The AllenNLP APIs imported below (allennlp.data.iterators, WordTokenizer,
# allennlp.models.encoder_decoders) were removed in AllenNLP 1.0, so pin 0.9.
!pip install allennlp==0.9.0

In [0]:
# Mount Google Drive at /content/drive. main() below reads its Tatoeba TSV
# files from, and writes the trained model/vocabulary to, a folder under
# '/content/drive/My Drive'.
import google.colab.drive as drive

drive.mount('/content/drive')

In [3]:
import itertools

import torch
import torch.optim as optim
from allennlp.data.dataset_readers.seq2seq import Seq2SeqDatasetReader
from allennlp.data.iterators import BucketIterator
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers.character_tokenizer import CharacterTokenizer
from allennlp.data.tokenizers.word_tokenizer import WordTokenizer
from allennlp.data.vocabulary import Vocabulary
from allennlp.nn.activations import Activation
from allennlp.models.encoder_decoders.simple_seq2seq import SimpleSeq2Seq
from allennlp.modules.attention import LinearAttention, BilinearAttention, DotProductAttention
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper, StackedSelfAttentionEncoder
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.predictors import SimpleSeq2SeqPredictor
from allennlp.training.trainer import Trainer

# Model sizes used by main().
EN_EMBEDDING_DIM = 256  # English (source) word-embedding size; also the encoder input size
ZH_EMBEDDING_DIM = 256  # Mandarin (target) character-embedding size
HIDDEN_DIM = 256        # hidden size of the self-attention encoder
# Passed to Trainer(cuda_device=...); AllenNLP uses -1 for CPU, a GPU index otherwise.
CUDA_DEVICE = -1

def main():
    """Train an English -> Mandarin seq2seq model, show sample output, and save it.

    Reads the Tatoeba train/dev TSV files from Google Drive and builds a
    vocabulary. Trains a ``SimpleSeq2Seq`` model (self-attention encoder,
    dot-product attention, beam search). Prints predictions for the first 10
    dev instances, then writes the model weights (``model.th``) and the
    vocabulary directory into the same Drive folder.
    """
    data_dir = '/content/drive/My Drive/Eng_Mandarin/data'

    # English source is split into words, Mandarin target into single
    # characters. The target gets its own vocabulary namespace.
    reader = Seq2SeqDatasetReader(
        source_tokenizer=WordTokenizer(),
        target_tokenizer=CharacterTokenizer(),
        source_token_indexers={'tokens': SingleIdTokenIndexer()},
        target_token_indexers={'tokens': SingleIdTokenIndexer(namespace='target_tokens')})

    train_dataset = reader.read(data_dir + '/tatoeba.eng_cmn.train.tsv')
    validation_dataset = reader.read(data_dir + '/tatoeba.eng_cmn.dev.tsv')

    # Tokens seen fewer than 3 times are left out of both vocabularies.
    # NOTE(review): the vocabulary is built from train + dev data, so dev-only
    # tokens leak into it; consider using train_dataset alone.
    vocab = Vocabulary.from_instances(train_dataset + validation_dataset,
                                      min_count={'tokens': 3, 'target_tokens': 3})

    en_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                             embedding_dim=EN_EMBEDDING_DIM)
    source_embedder = BasicTextFieldEmbedder({"tokens": en_embedding})

    # Alternatives that were left here as commented-out code: an LSTM encoder
    # (PytorchSeq2SeqWrapper(torch.nn.LSTM(EN_EMBEDDING_DIM, HIDDEN_DIM,
    # batch_first=True))), and LinearAttention / BilinearAttention in place of
    # DotProductAttention.
    encoder = StackedSelfAttentionEncoder(
        input_dim=EN_EMBEDDING_DIM,
        hidden_dim=HIDDEN_DIM,
        projection_dim=128,
        feedforward_hidden_dim=128,
        num_layers=1,
        num_attention_heads=8)
    attention = DotProductAttention()

    # Upper bound on generated target tokens (characters here) when decoding
    # without a gold target.
    max_decoding_steps = 20

    model = SimpleSeq2Seq(vocab, source_embedder, encoder, max_decoding_steps,
                          target_embedding_dim=ZH_EMBEDDING_DIM,
                          target_namespace='target_tokens',
                          attention=attention,
                          beam_size=8,
                          use_bleu=True)
    optimizer = optim.Adam(model.parameters())

    # Batch instances of similar source length together to reduce padding.
    iterator = BucketIterator(batch_size=32, sorting_keys=[("source_tokens", "num_tokens")])
    iterator.index_with(vocab)

    # NOTE(review): no serialization_dir is passed, so the Trainer keeps no
    # checkpoints. Presumably it cannot restore the best epoch's weights when
    # patience stops training early, so the weights saved below are from the
    # last epoch run. Confirm against the AllenNLP Trainer docs.
    trainer = Trainer(model=model,
                      optimizer=optimizer,
                      iterator=iterator,
                      train_dataset=train_dataset,
                      validation_dataset=validation_dataset,
                      num_epochs=10,
                      patience=1,
                      cuda_device=CUDA_DEVICE)
    trainer.train()

    # Switch to inference mode explicitly (disables any dropout layers) rather
    # than relying on whatever mode the Trainer left the model in.
    model.eval()
    predictor = SimpleSeq2SeqPredictor(model, reader)

    for instance in itertools.islice(validation_dataset, 10):
        print('SOURCE:', instance.fields['source_tokens'].tokens)
        print('GOLD:', instance.fields['target_tokens'].tokens)
        print('PRED:', predictor.predict_instance(instance)['predicted_tokens'])
        print('////////////////////////////////////////////\n')

    with open(data_dir + '/model.th', 'wb') as f:
        torch.save(model.state_dict(), f)
    # Saving the vocabulary does not need the model file handle, so it is done
    # after the file is closed.
    vocab.save_to_files(data_dir + '/vocabulary')


# Entry point. When this cell runs in a notebook kernel, __name__ is
# '__main__', so training starts immediately.
if __name__ == '__main__':
    main()


0it [00:00, ?it/s][A
286it [00:00, 2854.45it/s][A
556it [00:00, 2802.96it/s][A
697it [00:00, 644.84it/s] [A
951it [00:00, 830.79it/s][A
1237it [00:01, 1055.34it/s][A
1513it [00:01, 1295.20it/s][A
1772it [00:01, 1523.14it/s][A
2045it [00:01, 1755.24it/s][A
2310it [00:01, 1952.34it/s][A
2565it [00:01, 2098.10it/s][A
2816it [00:01, 2183.92it/s][A
3064it [00:01, 2243.27it/s][A
3318it [00:01, 2323.58it/s][A
3585it [00:01, 2416.11it/s][A
3839it [00:02, 2451.35it/s][A
4113it [00:02, 2529.26it/s][A
4375it [00:02, 2551.10it/s][A
4635it [00:02, 2508.52it/s][A
4897it [00:02, 2538.46it/s][A
5154it [00:02, 2491.28it/s][A
5406it [00:02, 2296.36it/s][A
5655it [00:02, 2350.51it/s][A
5916it [00:02, 2421.08it/s][A
6181it [00:02, 2483.89it/s][A
6450it [00:03, 2540.78it/s][A
6716it [00:03, 2574.65it/s][A
6980it [00:03, 2593.66it/s][A
7250it [00:03, 2623.11it/s][A
7531it [00:03, 2674.88it/s][A
7806it [00:03, 2695.40it/s][A
8077it [00:03, 2672.09it/s][A
8347it [00:03, 2678.7

SOURCE: [@start@, I, have, to, go, to, sleep, ., @end@]
GOLD: [@start@, 我, 该, 去, 睡, 觉, 了, 。, @end@]
PRED: ['我', '去', '睡', '觉', '了', '。']
////////////////////////////////////////////

SOURCE: [@start@, I, just, do, n't, know, what, to, say, ., @end@]
GOLD: [@start@, 我, 就, 是, 不, 知, 道, 說, 些, 什, 麼, 。, @end@]
PRED: ['我', '不', '知', '道', '該', '怎', '麼', '知', '道', '。']
////////////////////////////////////////////

SOURCE: [@start@, I, may, give, up, soon, and, just, nap, instead, ., @end@]
GOLD: [@start@, 也, 许, 我, 会, 马, 上, 放, 弃, 然, 后, 去, 睡, 一, 觉, 。, @end@]
PRED: ['我', '可', '能', '忍', '受', '不', '舒', '服', '。']
////////////////////////////////////////////

SOURCE: [@start@, I, 'm, going, to, go, ., @end@]
GOLD: [@start@, 我, 要, 走, 了, 。, @end@]
PRED: ['我', '要', '去', '。']
////////////////////////////////////////////

SOURCE: [@start@, That, 's, MY, line, !, @end@]
GOLD: [@start@, 那, 是, 我, 的, 台, 词, ！, @end@]
PRED: ['那', '是', '胡', '扯', '!']
////////////////////////////////////////////

SOURCE: [@start@,