**Declaration**: Most code of this work is from https://github.com/tensorlayer/seq2seq-chatbot. I changed some of data processing code and modified the model in order to make it work on my laptop. An important goal of this final project is to compare the impacts of pretrained embedding and trained-from-scratch embedding to the results. My work is mainly focused on introducing the GloVe embedding matrix, in which I have to modify the code from data prepratation and totally rewrite the embedding layer part of the model. 

In [1]:
import os
import urllib
import zipfile
import nltk
import numpy as np
import tensorflow as tf
import pickle
import re
import string

# Load GloVe

In [2]:
EMBEDDING_DIMENSION=100 # Available dimensions for 6B data is 50, 100, 200, 300
data_directory = './'

In [3]:
if not os.path.isdir(data_directory):
    os.path.makedirs(data_directory)
    
glove_weights_file_path = os.path.join(data_directory, f'glove.6B.{EMBEDDING_DIMENSION}d.txt')

In [4]:
if not os.path.isfile(glove_weights_file_path):
    # Glove embedding weights can be downloaded from https://nlp.stanford.edu/projects/glove/
    glove_fallback_url = 'http://nlp.stanford.edu/data/glove.6B.zip'
    local_zip_file_path = os.path.join(data_directory, os.path.basename(glove_fallback_url))
    if not os.path.isfile(local_zip_file_path):
        print(f'Retreiving glove weights from {glove_fallback_url}')
        urllib.request.urlretrieve(glove_fallback_url, local_zip_file_path)
    with zipfile.ZipFile(local_zip_file_path, 'r') as z:
        print(f'Extracting glove weights from {local_zip_file_path}')
        z.extractall(path=data_directory)

In [5]:
PAD_TOKEN = 0
glove_data_directory = '.'

word2idx = { 'PAD': PAD_TOKEN } # dict so we can lookup indices for tokenising our text later from string to sequence of integers
weights = []

with open (glove_data_directory + '/' +'glove.6B.100d.txt', 'r') as file:
    for index, line in enumerate(file):
        values = line.split() # Word and weights separated by space
        word = values[0] # Word is first symbol on each line
        word_weights = np.asarray(values[1:], dtype=np.float32) # Remainder of line is weights for word
        word2idx[word] = index + 1 # PAD is our zeroth index so shift by one
        weights.append(word_weights)
        
        if index + 1 == 40_000:
            # Limit vocabulary to top 40k terms
            break

EMBEDDING_DIMENSION = len(weights[0])
# Insert the PAD weights at index 0 now we know the embedding dimension
weights.insert(0, np.random.randn(EMBEDDING_DIMENSION))

# Append unknown and pad to end of vocab and initialize as random
UNKNOWN_TOKEN = len(weights)
word2idx['UNK'] = UNKNOWN_TOKEN
weights.append(np.random.randn(EMBEDDING_DIMENSION))

# Construct our final vocab
weights = np.asarray(weights, dtype=np.float32)

VOCAB_SIZE = weights.shape[0]

# Data Preparation

In [6]:
EN_WHITELIST = '0123456789abcdefghijklmnopqrstuvwxyz ' # space is included in whitelist
EN_BLACKLIST = '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~\''

limit = {
        'maxq' : 25,
        'minq' : 2,
        'maxa' : 25,
        'mina' : 2
        }
UNK = 'UNK'

In [7]:
'''
    1. Read from 'movie-lines.txt'
    2. Create a dictionary with ( key = line_id, value = text )
'''
def get_id2line():
    lines=open('../../../nlp_data/movie_lines.txt', encoding='utf-8', errors='ignore').read().split('\n')
    id2line = {}
    for line in lines:
        _line = line.split(' +++$+++ ')
        if len(_line) == 5:
            id2line[_line[0]] = _line[4]
    return id2line

'''
    1. Read from 'movie_conversations.txt'
    2. Create a list of [list of line_id's]
'''
def get_conversations():
    conv_lines = open('../../../nlp_data/movie_conversations.txt', encoding='utf-8', errors='ignore').read().split('\n')
    convs = [ ]
    for line in conv_lines[:-1]:
        _line = line.split(' +++$+++ ')[-1][1:-1].replace("'","").replace(" ","")
        convs.append(_line.split(','))
    return convs

'''
    1. Get each conversation
    2. Get each line from conversation
    3. Save each conversation to file
'''
def extract_conversations(convs,id2line,path=''):
    idx = 0
    for conv in convs:
        f_conv = open(path + str(idx)+'.txt', 'w')
        for line_id in conv:
            f_conv.write(id2line[line_id])
            f_conv.write('\n')
        f_conv.close()
        idx += 1

'''
    Get lists of all conversations as Questions and Answers
    1. [questions]
    2. [answers]
'''
def gather_dataset(convs, id2line):
    questions = []; answers = []

    for conv in convs:
        if len(conv) %2 != 0:
            conv = conv[:-1]
        for i in range(len(conv)):
            if i%2 == 0:
                questions.append(id2line[conv[i]])
            else:
                answers.append(id2line[conv[i]])

    return questions, answers


'''
    We need 4 files
    1. train.enc : Encoder input for training
    2. train.dec : Decoder input for training
    3. test.enc  : Encoder input for testing
    4. test.dec  : Decoder input for testing
'''
def prepare_seq2seq_files(questions, answers, path='',TESTSET_SIZE = 30000):

    # open files
    train_enc = open(path + 'train.enc','w')
    train_dec = open(path + 'train.dec','w')
    test_enc  = open(path + 'test.enc', 'w')
    test_dec  = open(path + 'test.dec', 'w')

    # choose 30,000 (TESTSET_SIZE) items to put into testset
    test_ids = random.sample([i for i in range(len(questions))],TESTSET_SIZE)

    for i in range(len(questions)):
        if i in test_ids:
            test_enc.write(questions[i]+'\n')
            test_dec.write(answers[i]+ '\n' )
        else:
            train_enc.write(questions[i]+'\n')
            train_dec.write(answers[i]+ '\n' )
        if i%10000 == 0:
            print('\n>> written {} lines'.format(i))

    # close files
    train_enc.close()
    train_dec.close()
    test_enc.close()
    test_dec.close()



'''
 remove anything that isn't in the vocabulary
    return str(pure en)
'''
def filter_line(line, whitelist):
    return ''.join([ ch for ch in line if ch in whitelist ])

def add_space_punct(line):
    return line.translate(str.maketrans({key: " {0} ".format(key) for key in string.punctuation}))

def replace_didt(line):
    return re.sub("\sdidn't\s", ' did not ', line)


'''
 filter too long and too short sequences
    return tuple( filtered_ta, filtered_en )
'''
def filter_data(qseq, aseq):
    filtered_q, filtered_a = [], []
    raw_data_len = len(qseq)

    assert len(qseq) == len(aseq)

    for i in range(raw_data_len):
        qlen, alen = len(qseq[i].split()), len(aseq[i].split())
        if qlen >= limit['minq'] and qlen <= limit['maxq']:
            if alen >= limit['mina'] and alen <= limit['maxa']:
                filtered_q.append(qseq[i])
                filtered_a.append(aseq[i])

    # print the fraction of the original data, filtered
    filt_data_len = len(filtered_q)
    filtered = int((raw_data_len - filt_data_len)*100/raw_data_len)
    print(str(filtered) + '% filtered from original data')

    return filtered_q, filtered_a



'''
 read list of words, create index to word,
  word to index dictionaries
    return tuple( vocab->(word, count), idx2w, w2idx )
'''
def index_(tokenized_sentences, vocab_size):
    # get frequency distribution
    freq_dist = nltk.FreqDist(itertools.chain(*tokenized_sentences))
    # get vocabulary of 'vocab_size' most used words
    vocab = freq_dist.most_common(vocab_size)
    # index2word
    index2word = ['_'] + [UNK] + [ x[0] for x in vocab ]
    # word2index
    word2index = dict([(w,i) for i,w in enumerate(index2word)] )
    return index2word, word2index, freq_dist

'''
 filter based on number of unknowns (words not in vocabulary)
  filter out the worst sentences
'''
def filter_unk(qtokenized, atokenized, w2idx):
    data_len = len(qtokenized)

    filtered_q, filtered_a = [], []

    for qline, aline in zip(qtokenized, atokenized):
        unk_count_q = len([ w for w in qline if w not in w2idx ])
        unk_count_a = len([ w for w in aline if w not in w2idx ])
        if unk_count_a <= 2:
            if unk_count_q > 0:
                if unk_count_q/len(qline) > 0.2:
                    pass
            filtered_q.append(qline)
            filtered_a.append(aline)

    # print the fraction of the original data, filtered
    filt_data_len = len(filtered_q)
    filtered = int((data_len - filt_data_len)*100/data_len)
    print(str(filtered) + '% filtered from original data')

    return filtered_q, filtered_a




'''
 create the final dataset :
  - convert list of items to arrays of indices
  - add zero padding
      return ( [array_en([indices]), array_ta([indices]) )
'''
def zero_pad(qtokenized, atokenized, w2idx):
    # num of rows
    data_len = len(qtokenized)

    # numpy arrays to store indices
    idx_q = np.zeros([data_len, limit['maxq']], dtype=np.int32)
    idx_a = np.zeros([data_len, limit['maxa']], dtype=np.int32)

    for i in range(data_len):
        q_indices = pad_seq(qtokenized[i], w2idx, limit['maxq'])
        a_indices = pad_seq(atokenized[i], w2idx, limit['maxa'])

        #print(len(idx_q[i]), len(q_indices))
        #print(len(idx_a[i]), len(a_indices))
        idx_q[i] = np.array(q_indices)
        idx_a[i] = np.array(a_indices)

    return idx_q, idx_a


'''
 replace words with indices in a sequence
  replace with unknown if word not in lookup
    return [list of indices]
'''
def pad_seq(seq, lookup, maxlen):
    indices = []
    for word in seq:
        if word in lookup:
            indices.append(lookup[word])
        else:
            indices.append(lookup[UNK])
    return indices + [0]*(maxlen - len(seq))




import numpy as np
from random import sample

'''
 split data into train (70%), test (15%) and valid(15%)
    return tuple( (trainX, trainY), (testX,testY), (validX,validY) )
'''
def split_dataset(x, y, ratio = [0.7, 0.15, 0.15] ):
    # number of examples
    data_len = len(x)
    lens = [ int(data_len*item) for item in ratio ]

    trainX, trainY = x[:lens[0]], y[:lens[0]]
    testX, testY = x[lens[0]:lens[0]+lens[1]], y[lens[0]:lens[0]+lens[1]]
    validX, validY = x[-lens[-1]:], y[-lens[-1]:]

    return (trainX,trainY), (testX,testY), (validX,validY)


'''
 generate batches from dataset
    yield (x_gen, y_gen)
    TODO : fix needed
'''
def batch_gen(x, y, batch_size):
    # infinite while
    while True:
        for i in range(0, len(x), batch_size):
            if (i+1)*batch_size < len(x):
                yield x[i : (i+1)*batch_size ].T, y[i : (i+1)*batch_size ].T

'''
 generate batches, by random sampling a bunch of items
    yield (x_gen, y_gen)
'''
def rand_batch_gen(x, y, batch_size):
    while True:
        sample_idx = sample(list(np.arange(len(x))), batch_size)
        yield x[sample_idx].T, y[sample_idx].T


'''
 a generic decode function
    inputs : sequence, lookup
'''
def decode(sequence, lookup, separator=''): # 0 used for padding, is ignored
    return separator.join([ lookup[element] for element in sequence if element ])

def load_data(PATH=''):
    # read data control dictionaries
    with open(PATH + 'metadata.pkl', 'rb') as f:
        metadata = pickle.load(f)
    # read numpy arrays
    idx_q = np.load(PATH + 'idx_q.npy')
    idx_a = np.load(PATH + 'idx_a.npy')
    return metadata, idx_q, idx_a

In [8]:
id2line = get_id2line()
convs = get_conversations()
questions, answers = gather_dataset(convs,id2line)

# change to lower case (just for en)
questions = [ line.lower() for line in questions ]
answers = [ line.lower() for line in answers ]

# replace didn't to did not
questions = [ replace_didt(line) for line in questions ]
answers = [ replace_didt(line) for line in answers ]

# pad punctuation
questions = [ add_space_punct(line) for line in questions ]
answers = [ add_space_punct(line) for line in answers ]

# filter out too long or too short sequences
qlines, alines = filter_data(questions, answers)

26% filtered from original data


In [12]:
for q,a in zip(qlines[141:145], alines[141:145]):
    print('q : [{0}]; a : [{1}]'.format(q,a))

q : [pick you up friday ,  then]; a : [oh ,  right .   friday . ]
q : [the night i take you to places you ' ve never been before .   and back . ]; a : [like where ?   the 7 - eleven on burnside ?  do you even know my name ,  screwboy ? ]
q : [you hate me don ' t you ? ]; a : [i don ' t really think you warrant that strong an emotion . ]
q : [then say you ' ll spend dollar night at the track with me . ]; a : [and why would i do that ? ]


In [9]:
# convert list of [lines of text] into list of [list of words ]
print('\n>> Segment lines into words')
qtokenized = [ [w.strip() for w in wordlist.split() if w] for wordlist in qlines ]
atokenized = [ [w.strip() for w in wordlist.split() if w] for wordlist in alines ]


>> Segment lines into words


In [10]:
# filter out sentences with too many unknowns
print('\n >> Filter Unknowns')
qtokenized, atokenized = filter_unk(qtokenized, atokenized, word2idx)
print('\n Final dataset len : ' + str(len(qtokenized)))
print('\n >> Zero Padding')
idx_q, idx_a = zero_pad(qtokenized, atokenized, word2idx)
print('\n >> Save numpy arrays to disk')
# save them
np.save('idx_q.npy', idx_q)
np.save('idx_a.npy', idx_a)

# let us now save the necessary dictionaries
metadata = {
        'w2idx' : word2idx,
#         'idx2w' : idx2w,
        'limit' : limit
#         'freq_dist' : freq_dist
            }

# write to disk : data control dictionaries
with open('metadata.pkl', 'wb') as f:
    pickle.dump(metadata, f)

# count of unknowns
unk_count = (idx_q == UNKNOWN_TOKEN).sum() + (idx_a == UNKNOWN_TOKEN).sum()
# count of words
word_count = (idx_q != UNKNOWN_TOKEN).sum() + (idx_a != UNKNOWN_TOKEN).sum()

print('% unknown : {0}'.format(100 * (unk_count/word_count)))
print('Dataset count : ' + str(idx_q.shape[0]))



 >> Filter Unknowns
0% filtered from original data

 Final dataset len : 100684

 >> Zero Padding

 >> Save numpy arrays to disk
% unknown : 0.7643293233365639
Dataset count : 100684


# Model

In [11]:
import tensorflow as tf
import tensorlayer as tl
from tensorlayer.layers import *

import tensorflow as tf
import numpy as np
import time

In [12]:
VOCAB_LENGTH = len(word2idx)

In [13]:
metadata, idx_q, idx_a = load_data(PATH='./')          # Cornell Moive
(trainX, trainY), (testX, testY), (validX, validY) = split_dataset(idx_q, idx_a)

In [14]:
trainX = trainX.tolist()
trainY = trainY.tolist()
testX = testX.tolist()
testY = testY.tolist()
validX = validX.tolist()
validY = validY.tolist()

trainX = tl.prepro.remove_pad_sequences(trainX)
trainY = tl.prepro.remove_pad_sequences(trainY)
testX = tl.prepro.remove_pad_sequences(testX)
testY = tl.prepro.remove_pad_sequences(testY)
validX = tl.prepro.remove_pad_sequences(validX)
validY = tl.prepro.remove_pad_sequences(validY)

In [15]:
idx2word = list(word2idx.keys())

In [16]:
xseq_len = len(trainX)#.shape[-1]
yseq_len = len(trainY)#.shape[-1]
assert xseq_len == yseq_len
batch_size = 32
n_step = int(xseq_len/batch_size)
emb_dim = 100 

unk_id = word2idx['UNK']   
pad_id = word2idx['PAD']     

start_id = VOCAB_LENGTH 
end_id = VOCAB_LENGTH + 1

word2idx.update({'start_id': start_id})
word2idx.update({'end_id': end_id})
idx2word = idx2word + ['start_id', 'end_id']
xvocab_size = len(idx2word)

VOCAB_LENGTH = VOCAB_LENGTH + 2


In [17]:
print("encode_seqs", [idx2word[id] for id in trainX[10]])
target_seqs = tl.prepro.sequences_add_end_id([trainY[10]], end_id=end_id)[0]
    # target_seqs = tl.prepro.remove_pad_sequences([target_seqs], pad_id=pad_id)[0]
print("target_seqs", [idx2word[id] for id in target_seqs])
decode_seqs = tl.prepro.sequences_add_start_id([trainY[10]], start_id=start_id, remove_last=False)[0]
    # decode_seqs = tl.prepro.remove_pad_sequences([decode_seqs], pad_id=pad_id)[0]
print("decode_seqs", [idx2word[id] for id in decode_seqs])
target_mask = tl.prepro.sequences_get_mask([target_seqs])[0]
print("target_mask", target_mask)
print(len(target_seqs), len(decode_seqs), len(target_mask))

encode_seqs ['hi', '.']
target_seqs ['looks', 'like', 'things', 'worked', 'out', 'tonight', ',', 'huh', '?', 'end_id']
decode_seqs ['start_id', 'looks', 'like', 'things', 'worked', 'out', 'tonight', ',', 'huh', '?']
target_mask [1 1 1 1 1 1 1 1 1 1]
10 10 10


###============= model
def model(encode_seqs, decode_seqs, is_train=True, reuse=False):
    with tf.variable_scope("model", reuse=reuse):
        # for chatbot, you can use the same embedding layer,
        # for translation, you may want to use 2 seperated embedding layers
        with tf.variable_scope("embedding") as vs:
            net_encode = EmbeddingInputlayer(
                inputs = encode_seqs,
                vocabulary_size = xvocab_size,
                embedding_size = emb_dim,
                name = 'seq_embedding')
            vs.reuse_variables()
            tl.layers.set_name_reuse(True) # remove if TL version == 1.8.0+
            net_decode = EmbeddingInputlayer(
                inputs = decode_seqs,
                vocabulary_size = xvocab_size,
                embedding_size = emb_dim,
                name = 'seq_embedding')
        net_rnn = Seq2Seq(net_encode, net_decode,
                cell_fn = tf.contrib.rnn.BasicLSTMCell,
                n_hidden = emb_dim,
                initializer = tf.random_uniform_initializer(-0.1, 0.1),
                encode_sequence_length = retrieve_seq_length_op2(encode_seqs),
                decode_sequence_length = retrieve_seq_length_op2(decode_seqs),
                initial_state_encode = None,
                dropout = (0.5 if is_train else None),
                n_layer = 3,
                return_seq_2d = True,
                name = 'seq2seq')
        net_out = DenseLayer(net_rnn, n_units=xvocab_size, act=tf.identity, name='output')
    return net_out, net_rnn

In [18]:
glove_weights_initializer = tf.constant_initializer(weights)
# embedding_weights = tf.get_variable(
#     name='embedding_weights', 
#     shape=(VOCAB_LENGTH, EMBEDDING_DIMENSION), 
#     initializer=glove_weights_initializer,
#     trainable=False)
# embedding = tf.nn.embedding_lookup(embedding_weights, features['word_indices'])

## This part is for testing the embedding input layer (don't run)

In [None]:
x = tf.placeholder(tf.int32, shape=[batch_size])

In [None]:
aa = tl.layers.EmbeddingInputlayer(inputs=x, vocabulary_size=VOCAB_LENGTH,
                                                       embedding_size=emb_dim, E_init=glove_weights_initializer,
                                                        E_init_args={'trainable': False},
                                                       name='embed')

In [None]:
sess = tf.InteractiveSession()

In [None]:
tl.layers.initialize_global_variables(sess)

In [None]:
sess.run(aa.outputs, feed_dict={x: [20,30, 23, 1,2,3,4,5,6,7,78,8,9,9,5,4,3,2,2,2,23,3,4,5,2,3,3,3,4,5,6,7]})

## Run the following

In [19]:
###============= model
def model(encode_seqs, decode_seqs, is_train=True, reuse=False):
    with tf.variable_scope("model", reuse=reuse):
        # for chatbot, you can use the same embedding layer,
        # for translation, you may want to use 2 seperated embedding layers
        with tf.variable_scope("embedding") as vs:
            net_encode = tl.layers.EmbeddingInputlayer(inputs=encode_seqs, vocabulary_size=VOCAB_LENGTH,
                                                       embedding_size=emb_dim, E_init=glove_weights_initializer,
                                                        E_init_args={'trainable': False},
                                                       name='embed')
            vs.reuse_variables()
            tl.layers.set_name_reuse(True) # remove if TL version == 1.8.0+
            net_decode =tl.layers.EmbeddingInputlayer(inputs=decode_seqs, vocabulary_size=VOCAB_LENGTH,
                                                      E_init=glove_weights_initializer, embedding_size=emb_dim,
                                                      E_init_args={'trainable': False},
                                                      name='embed')
        net_rnn = Seq2Seq(net_encode, net_decode,
                cell_fn = tf.contrib.rnn.BasicLSTMCell,
                n_hidden = emb_dim,
                initializer = tf.random_uniform_initializer(-0.1, 0.1),
                encode_sequence_length = retrieve_seq_length_op2(encode_seqs),
                decode_sequence_length = retrieve_seq_length_op2(decode_seqs),
                initial_state_encode = None,
                dropout = (0.5 if is_train else None),
                n_layer = 3,
                return_seq_2d = True,
                name = 'seq2seq')
        net_out = DenseLayer(net_rnn, n_units=xvocab_size, act=tf.identity, name='output')
    return net_out, net_rnn

In [20]:
# model for training
with tf.device('/device:GPU:0'):
    encode_seqs = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name="encode_seqs")
    decode_seqs = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name="decode_seqs")
    target_seqs = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name="target_seqs")
    target_mask = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name="target_mask") # tl.prepro.sequences_get_mask()
net_out, _ = model(encode_seqs, decode_seqs, is_train=True, reuse=False)

# model for inferencing
with tf.device('/device:GPU:0'):
    encode_seqs2 = tf.placeholder(dtype=tf.int64, shape=[1, None], name="encode_seqs")
    decode_seqs2 = tf.placeholder(dtype=tf.int64, shape=[1, None], name="decode_seqs")
net, net_rnn = model(encode_seqs2, decode_seqs2, is_train=False, reuse=True)
y = tf.nn.softmax(net.outputs)

loss = tl.cost.cross_entropy_seq_with_mask(logits=net_out.outputs, target_seqs=target_seqs, input_mask=target_mask, return_details=False, name='cost')

net_out.print_params(False)
# original was 0.0001
lr = 0.0001
train_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss)

sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True))
tl.layers.initialize_global_variables(sess)
tl.files.load_and_assign_npz(sess=sess, name='n2.npz', network=net)

[TL] EmbeddingInputlayer model/embedding/embed: (40004, 100)
Instructions for updating: TensorLayer relies on TensorFlow to check name reusing

[TL] EmbeddingInputlayer model/embedding/embed: (40004, 100)
[TL] [*] Seq2Seq model/seq2seq: n_hidden: 100 cell_fn: BasicLSTMCell dropout: 0.5 n_layer: 3
[TL] DynamicRNNLayer model/seq2seq/encode: n_hidden: 100, in_dim: 3 in_shape: (32, ?, 100) cell_fn: BasicLSTMCell dropout: 0.5 n_layer: 3
[TL]        batch_size (concurrent processes): 32
[TL] DynamicRNNLayer model/seq2seq/decode: n_hidden: 100, in_dim: 3 in_shape: (32, ?, 100) cell_fn: BasicLSTMCell dropout: 0.5 n_layer: 3
[TL]        batch_size (concurrent processes): 32
[TL] DenseLayer  model/output: 40004 No Activation
[TL] EmbeddingInputlayer model/embedding/embed: (40004, 100)
[TL] EmbeddingInputlayer model/embedding/embed: (40004, 100)
[TL] [*] Seq2Seq model/seq2seq: n_hidden: 100 cell_fn: BasicLSTMCell dropout: None n_layer: 3
[TL] DynamicRNNLayer model/seq2seq/encode: n_hidden: 100, i

<tensorlayer.layers.dense.base_dense.DenseLayer at 0x7fc330e00ba8>

In [None]:
###============= train
n_epoch = 65
for epoch in range(n_epoch):
    epoch_time = time.time()
    ## shuffle training data
    from sklearn.utils import shuffle
    trainX, trainY = shuffle(trainX, trainY, random_state=0)
    ## train an epoch
    total_err, n_iter = 0, 0
    for X, Y in tl.iterate.minibatches(inputs=trainX, targets=trainY, batch_size=batch_size, shuffle=False):
        step_time = time.time()

        X = tl.prepro.pad_sequences(X)
        _target_seqs = tl.prepro.sequences_add_end_id(Y, end_id=end_id)
        _target_seqs = tl.prepro.pad_sequences(_target_seqs)

        _decode_seqs = tl.prepro.sequences_add_start_id(Y, start_id=start_id, remove_last=False)
        _decode_seqs = tl.prepro.pad_sequences(_decode_seqs)
        _target_mask = tl.prepro.sequences_get_mask(_target_seqs)

        ## you can view the data here
        # for i in range(len(X)):
        #     print(i, [idx2word[id] for id in X[i]])
        #     # print(i, [idx2word[id] for id in Y[i]])
        #     print(i, [idx2word[id] for id in _target_seqs[i]])
        #     print(i, [idx2word[id] for id in _decode_seqs[i]])
        #     print(i, _target_mask[i])
        #     print(len(_target_seqs[i]), len(_decode_seqs[i]), len(_target_mask[i]))
        # exit()

        _, err = sess.run([train_op, loss],
                        {encode_seqs: X,
                        decode_seqs: _decode_seqs,
                        target_seqs: _target_seqs,
                        target_mask: _target_mask})

        if n_iter % 200 == 0:
            print("Epoch[%d/%d] step:[%d/%d] loss:%f took:%.5fs" % (epoch, n_epoch, n_iter, n_step, err, time.time() - step_time))

        total_err += err; n_iter += 1

        ###============= inference
        if n_iter % 1000 == 0:
            seeds = ["happy birthday have a nice day",
                    "how was it going"]
            for seed in seeds:
                print("Query >", seed)
                seed_id = [word2idx[w] for w in seed.split(" ")]
                for _ in range(5):  # 1 Query --> 5 Reply
                    # 1. encode, get state
                    state = sess.run(net_rnn.final_state_encode,
                                    {encode_seqs2: [seed_id]})
                    # 2. decode, feed start_id, get first word
                    #   ref https://github.com/zsdonghao/tensorlayer/blob/master/example/tutorial_ptb_lstm_state_is_tuple.py
                    o, state = sess.run([y, net_rnn.final_state_decode],
                                    {net_rnn.initial_state_decode: state,
                                    decode_seqs2: [[start_id]]})
                    w_id = tl.nlp.sample_top(o[0], top_k=3)
                    w = idx2word[w_id]
                    # 3. decode, feed state iteratively
                    sentence = [w]
                    for _ in range(30): # max sentence length
                        o, state = sess.run([y, net_rnn.final_state_decode],
                                        {net_rnn.initial_state_decode: state,
                                        decode_seqs2: [[w_id]]})
                        w_id = tl.nlp.sample_top(o[0], top_k=2)
                        w = idx2word[w_id]
                        if w_id == end_id:
                            break
                        sentence = sentence + [w]
                    print(" >", ' '.join(sentence))

    print("Epoch[%d/%d] averaged loss:%f took:%.5fs" % (epoch, n_epoch, total_err/n_iter, time.time()-epoch_time))

    tl.files.save_npz(net.all_params, name='n2.npz', sess=sess)

Epoch[0/65] step:[0/2202] loss:3.664974 took:0.91226s
Epoch[0/65] step:[200/2202] loss:3.815032 took:0.11539s
Epoch[0/65] step:[400/2202] loss:4.102279 took:0.11165s
Epoch[0/65] step:[600/2202] loss:3.887945 took:0.11302s
Epoch[0/65] step:[800/2202] loss:4.050809 took:0.11463s
Query > happy birthday have a nice day
 > i don ' t know .
 > i ' ll get you .
 > i don ' t know .
 > i don ' t want to be . . .
 > i don ' t know .
Query > how was it going
 > the UNK .
 > the UNK .
 > i don ' t know , i ' m not going to get it .
 > the UNK . . .
 > i don ' t know .
Epoch[0/65] step:[1000/2202] loss:3.798570 took:0.09714s
Epoch[0/65] step:[1200/2202] loss:4.014073 took:0.11517s
Epoch[0/65] step:[1400/2202] loss:3.820731 took:0.12036s
Epoch[0/65] step:[1600/2202] loss:3.702559 took:0.10918s
Epoch[0/65] step:[1800/2202] loss:3.670743 took:0.11417s
Query > happy birthday have a nice day
 > you ' re not going to be a UNK , i ' m sorry .
 > you ' re a little man . i ' m not going to get it .
 > you '

Epoch[5/65] step:[2000/2202] loss:3.790045 took:0.12017s
Epoch[5/65] step:[2200/2202] loss:3.754567 took:0.11691s
Epoch[5/65] averaged loss:3.862785 took:249.16664s
[TL] [*] n2.npz saved
Epoch[6/65] step:[0/2202] loss:3.780042 took:0.10063s
Epoch[6/65] step:[200/2202] loss:3.617723 took:0.11164s
Epoch[6/65] step:[400/2202] loss:3.896684 took:0.10613s
Epoch[6/65] step:[600/2202] loss:3.986490 took:0.11406s
Epoch[6/65] step:[800/2202] loss:3.823431 took:0.11175s
Query > happy birthday have a nice day
 > you ' re not going to be a little .
 > you ' re not a UNK .
 > you ' re not going to be a UNK .
 > you ' re not going to be a UNK .
 > what ?
Query > how was it going
 > i ' ll be a little UNK .
 > it ' s a UNK , i ' m sorry .
 > i don ' t know .
 > i don ' t know .
 > the UNK , i ' m not going .
Epoch[6/65] step:[1000/2202] loss:3.601669 took:0.11531s
Epoch[6/65] step:[1200/2202] loss:3.860496 took:0.10877s
Epoch[6/65] step:[1400/2202] loss:3.916353 took:0.09829s
Epoch[6/65] step:[1600/2

Epoch[11/65] step:[2200/2202] loss:3.870764 took:0.10632s
Epoch[11/65] averaged loss:3.860251 took:249.22431s
[TL] [*] n2.npz saved
Epoch[12/65] step:[0/2202] loss:4.058467 took:0.11452s
Epoch[12/65] step:[200/2202] loss:3.688139 took:0.11110s
Epoch[12/65] step:[400/2202] loss:3.804825 took:0.11146s
Epoch[12/65] step:[600/2202] loss:3.846092 took:0.10244s
Epoch[12/65] step:[800/2202] loss:3.774758 took:0.11497s
Query > happy birthday have a nice day
 > you ' re not a UNK .
 > you ' re a good man , i ' m sorry , i ' m not .
 > what ?
 > you ' re a UNK .
 > i don ' t know , i ' m sorry .
Query > how was it going
 > it ' s not a good UNK .
 > the UNK .
 > the UNK .
 > the one .
 > the UNK .
Epoch[12/65] step:[1000/2202] loss:4.088830 took:0.11143s
Epoch[12/65] step:[1200/2202] loss:3.552500 took:0.11755s
Epoch[12/65] step:[1400/2202] loss:3.947554 took:0.11105s
Epoch[12/65] step:[1600/2202] loss:3.780357 took:0.11254s
Epoch[12/65] step:[1800/2202] loss:3.729454 took:0.11279s
Query > happy

Epoch[17/65] step:[2000/2202] loss:3.743751 took:0.11776s
Epoch[17/65] step:[2200/2202] loss:3.577564 took:0.10563s
Epoch[17/65] averaged loss:3.856284 took:249.23881s
[TL] [*] n2.npz saved
Epoch[18/65] step:[0/2202] loss:3.876742 took:0.11508s
Epoch[18/65] step:[200/2202] loss:4.006594 took:0.11732s
Epoch[18/65] step:[400/2202] loss:3.979013 took:0.11693s
Epoch[18/65] step:[600/2202] loss:4.172915 took:0.10982s
Epoch[18/65] step:[800/2202] loss:3.706369 took:0.10623s
Query > happy birthday have a nice day
 > i ' m sorry .
 > what ' s the difference , you ' re a UNK .
 > you ' re not a good man .
 > i ' ll be right .
 > you ' re not a UNK .
Query > how was it going
 > i ' m not going to be a UNK .
 > the UNK .
 > the UNK .
 > i don ' t know . . .
 > it ' s not a UNK . . .
Epoch[18/65] step:[1000/2202] loss:3.653560 took:0.10366s
Epoch[18/65] step:[1200/2202] loss:3.501284 took:0.11472s
Epoch[18/65] step:[1400/2202] loss:4.006666 took:0.11854s
Epoch[18/65] step:[1600/2202] loss:3.901409

Epoch[23/65] step:[2000/2202] loss:3.710427 took:0.10627s
Epoch[23/65] step:[2200/2202] loss:4.034122 took:0.11080s
Epoch[23/65] averaged loss:3.854452 took:248.89103s
[TL] [*] n2.npz saved
Epoch[24/65] step:[0/2202] loss:4.066487 took:0.10864s
Epoch[24/65] step:[200/2202] loss:3.671767 took:0.11196s
Epoch[24/65] step:[400/2202] loss:3.816467 took:0.11911s
Epoch[24/65] step:[600/2202] loss:3.632607 took:0.11500s
Epoch[24/65] step:[800/2202] loss:3.962574 took:0.10969s
Query > happy birthday have a nice day
 > you don ' t want to be able . i ' m not sure .
 > i ' m not sure .
 > i don ' t know .
 > what ' s the matter ?
 > i don ' t want to talk .
Query > how was it going
 > the UNK .
 > the UNK .
 > i don ' t know .
 > i don ' t know .
 > the UNK .
Epoch[24/65] step:[1000/2202] loss:3.745245 took:0.11489s
Epoch[24/65] step:[1200/2202] loss:3.788352 took:0.11197s
Epoch[24/65] step:[1400/2202] loss:3.658774 took:0.12243s
Epoch[24/65] step:[1600/2202] loss:3.712565 took:0.12244s
Epoch[24/

Epoch[29/65] step:[2000/2202] loss:3.543935 took:0.11567s
Epoch[29/65] step:[2200/2202] loss:3.696616 took:0.11405s
Epoch[29/65] averaged loss:3.852783 took:249.38370s
[TL] [*] n2.npz saved
Epoch[30/65] step:[0/2202] loss:3.559965 took:0.11733s
Epoch[30/65] step:[200/2202] loss:3.584791 took:0.09599s
Epoch[30/65] step:[400/2202] loss:4.087370 took:0.11158s
Epoch[30/65] step:[600/2202] loss:3.698358 took:0.10300s
Epoch[30/65] step:[800/2202] loss:4.002708 took:0.12215s
Query > happy birthday have a nice day
 > you ' re not going to be a little man . i don ' t know .
 > what ?
 > you ' ll have a little UNK .
 > you ' re a good man .
 > i ' m not going to be .
Query > how was it going
 > i ' m sorry , i ' m sorry .
 > it ' s a UNK .
 > i don ' t know , you ' re not going to get it .
 > the one .
 > i ' ll get you .
Epoch[30/65] step:[1000/2202] loss:3.731169 took:0.10716s
Epoch[30/65] step:[1200/2202] loss:3.769906 took:0.11054s
Epoch[30/65] step:[1400/2202] loss:3.792693 took:0.12267s
Ep