In [1]:
import tensorlayer as tl
import tensorflow as tf
from tensorlayer.layers import *

import numpy as np
import time
import pickle

In [2]:
def load_data(PATH='./'):
    """Load the vocabulary metadata and the index-encoded question/answer arrays.

    Parameters
    ----------
    PATH : str
        Location of ``metadata.pkl``, ``idx_q.npy`` and ``idx_a.npy``.
        If it names an existing directory, a trailing separator is added
        when missing, so ``'data'`` and ``'data/'`` behave the same.
        Otherwise it is used verbatim as a filename prefix (the historical
        behaviour of plain string concatenation).

    Returns
    -------
    tuple
        ``(metadata, idx_q, idx_a)``: the unpickled metadata object and the
        two arrays returned by ``np.load``.
    """
    import os  # local import: only this helper needs it

    # os.path.join(PATH, '') appends a separator only if one is missing.
    prefix = os.path.join(PATH, '') if os.path.isdir(PATH) else PATH

    # read data control dictionaries
    # SECURITY: pickle.load can execute arbitrary code while unpickling;
    # only point this at metadata files from a trusted source.
    with open(prefix + 'metadata.pkl', 'rb') as f:
        metadata = pickle.load(f)
    # read numpy arrays
    idx_q = np.load(prefix + 'idx_q.npy')
    idx_a = np.load(prefix + 'idx_a.npy')
    return metadata, idx_q, idx_a

In [3]:
# Load the vocabulary metadata and the index-encoded question/answer arrays
# from the notebook's working directory.
metadata, idx_q, idx_a = load_data(PATH='./') 

In [4]:
# --- hyper-parameters -------------------------------------------------------
batch_size = 32
emb_dim = 1024

# --- vocabulary lookups shipped with the dataset ----------------------------
idx2w = metadata['idx2w']   # list: index -> word
w2idx = metadata['w2idx']   # dict: word -> index
xvocab_size = len(idx2w)    # 8002 (0~8001)

# --- reserved token ids -----------------------------------------------------
pad_id = w2idx['_']     # 0
unk_id = w2idx['unk']   # 1

In [5]:
# Append the two decoder control tokens to the vocabulary.
#
# Everything is derived from the untouched `metadata['idx2w']` list instead of
# the mutable `xvocab_size` / `idx2w` globals, so re-running this cell is
# idempotent: the ids stay put and the tokens are not appended a second time.
start_id = len(metadata['idx2w'])  # 8002
end_id = start_id + 1              # 8003

# Re-assigning the same keys to the same ids is harmless on a re-run.
w2idx.update({'start_id': start_id, 'end_id': end_id})
# List concatenation builds a new list; metadata['idx2w'] itself is not modified.
idx2w = metadata['idx2w'] + ['start_id', 'end_id']

# Encoder and decoder share one vocabulary, now two entries larger.
xvocab_size = yvocab_size = len(idx2w)

In [6]:
def model(encode_seqs, decode_seqs, is_train=True, reuse=False):
    """Build the seq2seq network: shared embedding -> 3-layer LSTM Seq2Seq -> dense output.

    Reads the module-level ``xvocab_size`` and ``emb_dim``.

    Parameters
    ----------
    encode_seqs : tensor of token ids fed to the encoder embedding.
    decode_seqs : tensor of token ids fed to the decoder embedding.
    is_train : bool
        When True the Seq2Seq layer is built with ``dropout=0.5``;
        otherwise dropout is ``None``.
    reuse : bool
        Passed to the outer ``tf.variable_scope("model", ...)`` so a second
        call can share the variables created by the first.

    Returns
    -------
    tuple
        ``(net_out, net_rnn)``: the final ``DenseLayer`` (``xvocab_size``
        units, identity activation) and the ``Seq2Seq`` layer it is stacked on.
    """
    with tf.variable_scope("model", reuse=reuse):
        # for chatbot, you can use the same embedding layer,
        # for translation, you may want to use 2 separate embedding layers
        with tf.variable_scope("embedding") as vs:
            net_encode = EmbeddingInputlayer(
                inputs = encode_seqs,
                vocabulary_size = xvocab_size,
                embedding_size = emb_dim,
                name = 'seq_embedding')
            # Switch the scope to reuse mode so the decoder embedding below,
            # created under the same name, shares the encoder's variables.
            vs.reuse_variables()
#             tl.layers.set_name_reuse(True) # remove if TL version == 1.8.0+
            net_decode = EmbeddingInputlayer(
                inputs = decode_seqs,
                vocabulary_size = xvocab_size,
                embedding_size = emb_dim,
                name = 'seq_embedding')
        # The hidden size is set equal to the embedding size (emb_dim).
        net_rnn = Seq2Seq(net_encode, net_decode,
                cell_fn = tf.contrib.rnn.BasicLSTMCell,
                n_hidden = emb_dim,
                initializer = tf.random_uniform_initializer(-0.1, 0.1),
                encode_sequence_length = retrieve_seq_length_op2(encode_seqs),
                decode_sequence_length = retrieve_seq_length_op2(decode_seqs),
                initial_state_encode = None,
                dropout = (0.5 if is_train else None),
                n_layer = 3,
                return_seq_2d = True,
                name = 'seq2seq')
        # Identity activation: the layer outputs raw scores; the inference
        # cell applies tf.nn.softmax to them separately.
        net_out = DenseLayer(net_rnn, n_units=xvocab_size, act=tf.identity, name='output')
    return net_out, net_rnn

In [7]:
# Training graph: placeholders fixed to `batch_size` rows with a variable
# second dimension (None).
# NOTE(review): target_seqs and target_mask are not consumed anywhere in the
# visible cells — presumably left over from the training loss; confirm.
with tf.device('/device:GPU:0'):
    encode_seqs = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name="encode_seqs")
    decode_seqs = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name="decode_seqs")
    target_seqs = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name="target_seqs")
    target_mask = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name="target_mask") # tl.prepro.sequences_get_mask()
# First call to model(): reuse=False creates the variables, is_train=True enables dropout.
# Note this call sits outside the tf.device block above.
net_out, _ = model(encode_seqs, decode_seqs, is_train=True, reuse=False)

[TL] EmbeddingInputlayer model/embedding/seq_embedding: (8004, 1024)
[TL] EmbeddingInputlayer model/embedding/seq_embedding: (8004, 1024)
[TL] [*] Seq2Seq model/seq2seq: n_hidden: 1024 cell_fn: BasicLSTMCell dropout: 0.5 n_layer: 3
[TL] DynamicRNNLayer model/seq2seq/encode: n_hidden: 1024, in_dim: 3 in_shape: (32, ?, 1024) cell_fn: BasicLSTMCell dropout: 0.5 n_layer: 3
[TL]        batch_size (concurrent processes): 32
[TL] DynamicRNNLayer model/seq2seq/decode: n_hidden: 1024, in_dim: 3 in_shape: (32, ?, 1024) cell_fn: BasicLSTMCell dropout: 0.5 n_layer: 3
[TL]        batch_size (concurrent processes): 32
[TL] DenseLayer  model/output: 8004 No Activation


In [8]:
# Inference graph: single-row placeholders (shape [1, None]) so one sentence
# can be fed at a time. The name strings repeat those of the training
# placeholders.
with tf.device('/device:GPU:0'):
    encode_seqs2 = tf.placeholder(dtype=tf.int64, shape=[1, None], name="encode_seqs")
    decode_seqs2 = tf.placeholder(dtype=tf.int64, shape=[1, None], name="decode_seqs")
# Second call to model(): reuse=True shares the variables created by the
# training graph; is_train=False builds the network without dropout.
net, net_rnn = model(encode_seqs2, decode_seqs2, is_train=False, reuse=True)
# Turn the dense layer's raw scores into a probability distribution over the vocabulary.
y = tf.nn.softmax(net.outputs)

[TL] EmbeddingInputlayer model/embedding/seq_embedding: (8004, 1024)
[TL] EmbeddingInputlayer model/embedding/seq_embedding: (8004, 1024)
[TL] [*] Seq2Seq model/seq2seq: n_hidden: 1024 cell_fn: BasicLSTMCell dropout: None n_layer: 3
[TL] DynamicRNNLayer model/seq2seq/encode: n_hidden: 1024, in_dim: 3 in_shape: (1, ?, 1024) cell_fn: BasicLSTMCell dropout: None n_layer: 3
[TL]        batch_size (concurrent processes): 1
[TL] DynamicRNNLayer model/seq2seq/decode: n_hidden: 1024, in_dim: 3 in_shape: (1, ?, 1024) cell_fn: BasicLSTMCell dropout: None n_layer: 3
[TL]        batch_size (concurrent processes): 1
[TL] DenseLayer  model/output: 8004 No Activation


In [9]:
# Create the session. allow_soft_placement lets TensorFlow fall back to another
# device when an op cannot run on the requested GPU; log_device_placement logs
# where each op was placed.
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True))
# Initialise every variable before the checkpoint overwrites them.
# tl.layers.initialize_global_variables is deprecated in favour of
# tf.global_variables_initializer, so call the TensorFlow API directly.
sess.run(tf.global_variables_initializer())
# Restore the trained weights from 'n.npz' into the inference network.
tl.files.load_and_assign_npz(sess=sess, name='n.npz', network=net)

Instructions for updating: This API is deprecated in favor of `tf.global_variables_initializer`

[TL] [*] Load n.npz SUCCESS!


<tensorlayer.layers.dense.base_dense.DenseLayer at 0x7fcec936df60>

In [10]:
def answer_me_bot(inputs=""):
    """Print the chatbot's reply to ``inputs``.

    The text is lower-cased, split on whitespace and mapped to vocabulary ids,
    encoded once, then decoded one token at a time by feeding the previous
    token and the recurrent state back into the decoder.

    Uses the module-level ``sess``, ``net_rnn``, ``y``, ``encode_seqs2``,
    ``decode_seqs2``, ``w2idx``, ``idx2w``, ``unk_id``, ``start_id`` and
    ``end_id``.

    Parameters
    ----------
    inputs : str
        The user's utterance. Words missing from ``w2idx`` are encoded as
        ``unk_id``; blank input is encoded as a single ``unk_id`` token.

    Returns
    -------
    None
        The reply is printed as ``" > <words>"``.
    """
    # split() without an argument collapses runs of whitespace, so doubled,
    # leading or trailing spaces no longer produce '' tokens.
    tokens = inputs.lower().split()
    # Out-of-vocabulary words map to unk_id instead of raising KeyError.
    seed_id = [w2idx.get(w, unk_id) for w in tokens]
    if not seed_id:
        # Keep at least one encoder time step for empty/blank input.
        seed_id = [unk_id]

    # 1. encode the whole query once and keep the final encoder state
    state = sess.run(net_rnn.final_state_encode,
                     {encode_seqs2: [seed_id]})

    # 2. decode, feeding the previous token and state back in each step
    w_id = start_id
    sentence = []
    for step in range(31):  # first word + at most 30 more (max sentence length)
        o, state = sess.run([y, net_rnn.final_state_decode],
                            {net_rnn.initial_state_decode: state,
                             decode_seqs2: [[w_id]]})
        # The first word is drawn from the top 3 candidates, later words from the top 2.
        w_id = tl.nlp.sample_top(o[0], top_k=3 if step == 0 else 2)
        # Check for the end token on every step, including the first, so the
        # literal 'end_id' never ends up in the printed reply.
        if w_id == end_id:
            break
        sentence.append(idx2w[w_id])
    print(" >", ' '.join(sentence))
    
    

In [16]:
# Smoke test of the restored model.
# NOTE(review): the reply is produced via tl.nlp.sample_top, which presumably
# samples among the top-k candidates, so the printed text may differ between
# runs — confirm.
answer_me_bot("test me")

 > i want you to tell me what you said
