In [1]:
import argparse
import os
import time
import numpy as np
import json

import torch
from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from random import shuffle, seed

from dataset import SequencePairDataset
from model.encoder_decoder import EncoderDecoder
from evaluate import evaluate
from utils import to_np, trim_seqs
import bcolz, pickle
from tensorboardX import SummaryWriter
from tqdm import tqdm
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction


## Load Data

In [2]:
seed(0)

In [3]:
with open("data/cleanup.golden.jsonl", "r") as f:
    data = [json.loads(line) for line in f.readlines()]
shuffle(data)

In [4]:
raw_train_data = data[:int(len(data)*0.8)]
raw_test_data = data[int(len(data)*0.8):]

In [5]:
def parse_raw_data(raw_data):
    train_data = []
    for dp in raw_data:
        train_data.append((
            dp['natural'],
            dp['raw_ltl']
        ))
    return train_data

In [6]:
train_data = parse_raw_data(raw_train_data)
test_data = parse_raw_data(raw_test_data)

In [7]:
CUDA = True
train_dataset = SequencePairDataset(train_data,
                                    use_cuda=CUDA,
                                    is_val=False,
                                    use_extended_vocab=True)

val_dataset = SequencePairDataset(test_data,lang=train_dataset.lang,
                                    use_cuda=CUDA,
                                    is_val=True,
                                    use_extended_vocab=True)

In [8]:
BATCH_SIZE = 32
train_data_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_data_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

## Load Glove

In [9]:
GLOVE = True

In [10]:
def vectors_for_input_language(lang):
    # source for this code: https://medium.com/@martinpella/how-to-use-pre-trained-word-embeddings-in-pytorch-71ca59249f76
    glove_path = 'glove/'
    vectors = bcolz.open(glove_path + '6B.50.dat')[:]
    words = pickle.load(open(glove_path + '6B.50_words.pkl', 'rb'))
    word2idx = pickle.load(open(glove_path + '6B.50_idx.pkl', 'rb'))

    glove = {w: vectors[word2idx[w]] for w in words}

    target_vocab = lang.idx_to_tok

    emb_dim = 50

    matrix_len = len(target_vocab)
    weights_matrix = np.zeros((matrix_len, emb_dim))
    words_found = 0
    i = 0

    for word in target_vocab:
        try:
            # print(target_vocab[word])
            weights_matrix[i] = glove[target_vocab[word]]
            # print(i)
            # print(word)
            words_found += 1
            i += 1
        except KeyError:
            weights_matrix[i] = np.random.normal(scale=0.6, size=(emb_dim,))

    return weights_matrix


In [11]:
GLOVE = True
if GLOVE:
    glove_map = vectors_for_input_language(train_dataset.lang)
    # glove_encoder = TestEncoderRNN(
    #     input_lang.n_words, embed_size, hidden_size, glove_map)

In [16]:
len(train_dataset.lang.idx_to_tok)

195

In [12]:
glove_map.shape

(195, 50)

## Build CopyNet Encoder Decoder Model

In [13]:
MAX_LEN = 64 # same with the dataset length
EMBED_SIZE = 50
HIDDEN_SIZE = 256
MODEL_NAME = "TEST"
encoder_decoder = EncoderDecoder(train_dataset.lang,
                                    max_length=MAX_LEN,
                                    embedding_size=EMBED_SIZE,
                                    hidden_size= HIDDEN_SIZE,
                                    decoder_type = 'copy')
encoder_decoder = encoder_decoder.cuda() if CUDA else encoder_decoder

In [17]:
encoder_decoder.encoder.embedding.weight.data.shape

torch.Size([195, 50])

### replace the randomly initialized embedding with glove embedding

In [21]:
encoder_decoder.encoder.embedding.weight.data.copy_(torch.from_numpy(glove_map))

tensor([[ 0.5152,  0.8012, -0.1373,  ..., -1.1572, -0.0800,  0.0821],
        [ 0.4180,  0.2497, -0.4124,  ..., -0.1841, -0.1151, -0.7858],
        [ 0.6805, -0.0393,  0.3019,  ..., -0.0733, -0.0647, -0.2604],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')

In [22]:
from train import train

In [23]:
epochs = 50
writer = SummaryWriter('./logs/%s_%s' % (MODEL_NAME, str(int(time.time()))))
train(encoder_decoder,
        train_data_loader,
        model_name=MODEL_NAME,
        val_data_loader=val_data_loader,
        keep_prob=1.0,
        teacher_forcing_schedule=np.arange(1.0, 0.0, -1.0/epochs),
        lr=0.001,
        max_length=encoder_decoder.decoder.max_length,
        writer=writer,)

epoch 0


100%|██████████| 85/85 [00:14<00:00,  5.89it/s]
  input_variable = Variable(
  target_variable = Variable(target_idxs[order, :], volatile=True)
100%|██████████| 22/22 [00:01<00:00, 15.20it/s]

Sample output:
              [[1, 6, 8, 28, 14, 17, 6, 13, 2]]
              [1, 6, 8, 28, 14, 17, 6, 12, 2]
val loss: 5.92090, val score: 0.24815
----------------------------------------------------------------------------------------------------
epoch 1



100%|██████████| 85/85 [00:12<00:00,  6.96it/s]
100%|██████████| 22/22 [00:01<00:00, 15.29it/s]

Sample output:
              [[1, 6, 8, 28, 14, 17, 6, 13, 2]]
              [1, 6, 8, 28, 14, 17, 6, 12, 2]
val loss: 11.76815, val score: 0.38848
----------------------------------------------------------------------------------------------------
epoch 2



100%|██████████| 85/85 [00:11<00:00,  7.21it/s]
100%|██████████| 22/22 [00:01<00:00, 15.30it/s]

Sample output:
              [[1, 6, 8, 28, 14, 17, 6, 13, 2]]
              [1, 6, 8, 28, 14, 17, 6, 12, 2]
val loss: 28.51410, val score: 0.51699
----------------------------------------------------------------------------------------------------
epoch 3



100%|██████████| 85/85 [00:11<00:00,  7.12it/s]
100%|██████████| 22/22 [00:01<00:00, 14.24it/s]

Sample output:
              [[1, 6, 8, 28, 14, 17, 6, 13, 2]]
              [1, 6, 8, 28, 14, 17, 6, 13, 2]
val loss: 60.49424, val score: 0.60118
----------------------------------------------------------------------------------------------------
epoch 4



100%|██████████| 85/85 [00:11<00:00,  7.11it/s]
100%|██████████| 22/22 [00:01<00:00, 15.40it/s]

Sample output:
              [[1, 6, 8, 28, 14, 17, 6, 13, 2]]
              [1, 6, 8, 28, 14, 17, 6, 13, 2]
val loss: 124.37651, val score: 0.67208
----------------------------------------------------------------------------------------------------
epoch 5



100%|██████████| 85/85 [00:11<00:00,  7.15it/s]
100%|██████████| 22/22 [00:01<00:00, 15.52it/s]

Sample output:
              [[1, 6, 8, 28, 14, 17, 6, 13, 2]]
              [1, 6, 8, 28, 14, 17, 6, 13, 2]
val loss: 159.89110, val score: 0.69276
----------------------------------------------------------------------------------------------------
epoch 6



100%|██████████| 85/85 [00:11<00:00,  7.39it/s]
100%|██████████| 22/22 [00:01<00:00, 15.49it/s]

Sample output:
              [[1, 6, 8, 28, 14, 17, 6, 13, 2]]
              [1, 6, 8, 28, 14, 17, 6, 13, 2]
val loss: -10096.18736, val score: 0.81241
----------------------------------------------------------------------------------------------------
epoch 7



100%|██████████| 85/85 [00:11<00:00,  7.18it/s]
100%|██████████| 22/22 [00:01<00:00, 15.33it/s]

Sample output:
              [[1, 6, 8, 28, 14, 17, 6, 13, 2]]
              [1, 6, 8, 28, 14, 17, 6, 13, 2]
val loss: -160.96440, val score: 0.87149
----------------------------------------------------------------------------------------------------
epoch 8



100%|██████████| 85/85 [00:11<00:00,  7.12it/s]
100%|██████████| 22/22 [00:01<00:00, 15.42it/s]

Sample output:
              [[1, 6, 8, 28, 14, 17, 6, 13, 2]]
              [1, 6, 8, 28, 14, 17, 6, 13, 2]
val loss: -158.89162, val score: 0.88479
----------------------------------------------------------------------------------------------------
epoch 9



100%|██████████| 85/85 [00:11<00:00,  7.21it/s]
100%|██████████| 22/22 [00:01<00:00, 15.43it/s]

Sample output:
              [[1, 6, 8, 28, 14, 17, 6, 13, 2]]
              [1, 6, 8, 28, 14, 17, 6, 13, 2]
val loss: -134.46446, val score: 0.91433
----------------------------------------------------------------------------------------------------
epoch 10



100%|██████████| 85/85 [00:11<00:00,  7.13it/s]
100%|██████████| 22/22 [00:01<00:00, 14.08it/s]

Sample output:
              [[1, 6, 8, 28, 14, 17, 6, 13, 2]]
              [1, 6, 8, 28, 14, 17, 6, 13, 2]
val loss: -164.46170, val score: 0.91433
----------------------------------------------------------------------------------------------------
epoch 11



100%|██████████| 85/85 [00:11<00:00,  7.11it/s]
100%|██████████| 22/22 [00:01<00:00, 15.05it/s]

Sample output:
              [[1, 6, 8, 28, 14, 17, 6, 13, 2]]
              [1, 6, 8, 28, 14, 17, 6, 13, 2]
val loss: -116.49741, val score: 0.92171
----------------------------------------------------------------------------------------------------
epoch 12



100%|██████████| 85/85 [00:11<00:00,  7.17it/s]
100%|██████████| 22/22 [00:01<00:00, 15.34it/s]

Sample output:
              [[1, 6, 8, 28, 14, 17, 6, 13, 2]]
              [1, 6, 8, 28, 14, 17, 6, 13, 2]
val loss: -90.29626, val score: 0.94092
----------------------------------------------------------------------------------------------------
epoch 13



100%|██████████| 85/85 [00:11<00:00,  7.42it/s]
100%|██████████| 22/22 [00:01<00:00, 14.66it/s]

Sample output:
              [[1, 6, 8, 28, 14, 17, 6, 13, 2]]
              [1, 6, 8, 28, 14, 17, 6, 13, 2]
val loss: -91.09345, val score: 0.94978
----------------------------------------------------------------------------------------------------
epoch 14



100%|██████████| 85/85 [00:12<00:00,  7.04it/s]
100%|██████████| 22/22 [00:01<00:00, 15.41it/s]

Sample output:
              [[1, 6, 8, 28, 14, 17, 6, 13, 2]]
              [1, 6, 8, 28, 14, 17, 6, 13, 2]
val loss: -87.12136, val score: 0.95569
----------------------------------------------------------------------------------------------------
epoch 15



100%|██████████| 85/85 [00:11<00:00,  7.19it/s]
100%|██████████| 22/22 [00:01<00:00, 15.35it/s]

Sample output:
              [[1, 6, 8, 28, 14, 17, 6, 13, 2]]
              [1, 6, 8, 28, 14, 17, 6, 13, 2]
val loss: -86.06262, val score: 0.96012
----------------------------------------------------------------------------------------------------
epoch 16



100%|██████████| 85/85 [00:11<00:00,  7.26it/s]
100%|██████████| 22/22 [00:01<00:00, 15.42it/s]

Sample output:
              [[1, 6, 8, 28, 14, 17, 6, 13, 2]]
              [1, 6, 8, 28, 14, 17, 6, 13, 2]
val loss: -94.31878, val score: 0.94830
----------------------------------------------------------------------------------------------------
epoch 17



100%|██████████| 85/85 [00:11<00:00,  7.17it/s]
100%|██████████| 22/22 [00:01<00:00, 15.29it/s]

Sample output:
              [[1, 6, 8, 28, 14, 17, 6, 13, 2]]
              [1, 6, 8, 28, 14, 17, 6, 13, 2]
val loss: -86.36654, val score: 0.95569
----------------------------------------------------------------------------------------------------
epoch 18



100%|██████████| 85/85 [00:12<00:00,  7.02it/s]
100%|██████████| 22/22 [00:01<00:00, 14.19it/s]

Sample output:
              [[1, 6, 8, 28, 14, 17, 6, 13, 2]]
              [1, 6, 8, 28, 14, 17, 6, 13, 2]
val loss: -84.47589, val score: 0.95864
----------------------------------------------------------------------------------------------------
epoch 19



100%|██████████| 85/85 [00:11<00:00,  7.37it/s]
100%|██████████| 22/22 [00:01<00:00, 15.40it/s]

Sample output:
              [[1, 6, 8, 28, 14, 17, 6, 13, 2]]
              [1, 6, 8, 28, 14, 17, 6, 13, 2]
val loss: -88.74404, val score: 0.96012
----------------------------------------------------------------------------------------------------
epoch 20



100%|██████████| 85/85 [00:11<00:00,  7.20it/s]
100%|██████████| 22/22 [00:01<00:00, 15.20it/s]

Sample output:
              [[1, 6, 8, 28, 14, 17, 6, 13, 2]]
              [1, 6, 8, 28, 14, 17, 6, 13, 2]
val loss: -75.66516, val score: 0.96603
----------------------------------------------------------------------------------------------------
epoch 21



100%|██████████| 85/85 [00:12<00:00,  7.01it/s]
100%|██████████| 22/22 [00:01<00:00, 15.19it/s]

Sample output:
              [[1, 6, 8, 28, 14, 17, 6, 13, 2]]
              [1, 6, 8, 28, 14, 17, 6, 13, 2]
val loss: -76.78558, val score: 0.96603
----------------------------------------------------------------------------------------------------
epoch 22



100%|██████████| 85/85 [00:11<00:00,  7.19it/s]
100%|██████████| 22/22 [00:01<00:00, 15.62it/s]

Sample output:
              [[1, 6, 8, 28, 14, 17, 6, 13, 2]]
              [1, 6, 8, 28, 14, 17, 6, 13, 2]
val loss: -86.79571, val score: 0.95864
----------------------------------------------------------------------------------------------------
epoch 23



100%|██████████| 85/85 [00:11<00:00,  7.18it/s]
100%|██████████| 22/22 [00:01<00:00, 15.24it/s]

Sample output:
              [[1, 6, 8, 28, 14, 17, 6, 13, 2]]
              [1, 6, 8, 28, 14, 17, 6, 13, 2]
val loss: -78.93122, val score: 0.96012
----------------------------------------------------------------------------------------------------
epoch 24



100%|██████████| 85/85 [00:11<00:00,  7.16it/s]
100%|██████████| 22/22 [00:01<00:00, 15.61it/s]

Sample output:
              [[1, 6, 8, 28, 14, 17, 6, 13, 2]]
              [1, 6, 8, 28, 14, 17, 6, 13, 2]
val loss: -75.84130, val score: 0.97194
----------------------------------------------------------------------------------------------------
epoch 25



100%|██████████| 85/85 [00:12<00:00,  6.98it/s]
100%|██████████| 22/22 [00:01<00:00, 14.28it/s]

Sample output:
              [[1, 6, 8, 28, 14, 17, 6, 13, 2]]
              [1, 6, 8, 28, 14, 17, 6, 13, 2]
val loss: -77.53520, val score: 0.97046
----------------------------------------------------------------------------------------------------
epoch 26



100%|██████████| 85/85 [00:11<00:00,  7.30it/s]
100%|██████████| 22/22 [00:01<00:00, 15.25it/s]

Sample output:
              [[1, 6, 8, 28, 14, 17, 6, 13, 2]]
              [1, 6, 8, 28, 14, 17, 6, 13, 2]
val loss: -75.83220, val score: 0.97046
----------------------------------------------------------------------------------------------------
epoch 27



100%|██████████| 85/85 [00:11<00:00,  7.20it/s]
100%|██████████| 22/22 [00:01<00:00, 15.41it/s]

Sample output:
              [[1, 6, 8, 28, 14, 17, 6, 13, 2]]
              [1, 6, 8, 28, 14, 17, 6, 13, 2]
val loss: -77.01484, val score: 0.97046
----------------------------------------------------------------------------------------------------
epoch 28



100%|██████████| 85/85 [00:11<00:00,  7.18it/s]
100%|██████████| 22/22 [00:01<00:00, 15.22it/s]

Sample output:
              [[1, 6, 8, 28, 14, 17, 6, 13, 2]]
              [1, 6, 8, 28, 14, 17, 6, 13, 2]
val loss: -74.73236, val score: 0.96307
----------------------------------------------------------------------------------------------------
epoch 29



 32%|███▏      | 27/85 [00:03<00:07,  7.36it/s]


KeyboardInterrupt: 