In [2]:
import numpy as np
import tensorflow as tf
from  tensorflow.keras.preprocessing.sequence import pad_sequences

# Advanced Seq2Seq

## Problems of Naive Seq2Seq

之前的 Seq2Seq 在训练时每一步都会把 ground truth 喂给 Seq2Seq 中的 decoder（即 teacher forcing）。但事实上，**decoder 应该依赖自己前面产生的 token 来进行下一个 token 的生成，这样才能在 inference 阶段自主地产生句子，因为 inference 阶段是没有 ground truth 的**。


## Implementation  
这段代码实现了在 decode 阶段用 decoder 自己产生的 **previous state 和 previous prediction 来进行下一个 token 的预测**。主要使用了 `tf.nn.raw_rnn` 接口来实现，它允许自定义一个 `loop_fn` 来控制 RNN 的循环。
但是存在一个问题：虽然在 test 阶段不再需要 ground truth，但仍然需要我们提供要产生的序列长度才能进行 test。可以进一步进行优化（自定义 RNN cell，并且加入长度相关的 mask），使得在 inference 阶段只需要 encoder 端的信息就能够产生歌词。


## Result Comparison
通过和 Naive Seq2Seq 的对比，可以看到在相同的步数（1000），Naive Seq2Seq 已经能够完全正确预测下一句歌词，但是 Advanced Seq2Seq 不行，因为 **自己产生的token** 很有可能是不对的，前一步的错误会导致后面的崩坏。

## Optimization
进一步的优化方向主要有三个：
1. 增大数据集，只有一首歌词，仅仅是个 toy model
2. 加入 Attention 机制，具体就是在 decoder 阶段可以考虑对 encoder 的 outputs 进行 Attention 操作
3. 加入 [Scheduled Sampling](https://arxiv.org/abs/1506.03099) 



In [3]:
# Load the dataset: the lines of "time.txt" alternate between an input
# sentence (even line index -> x) and the sentence that follows it
# (odd line index -> y).
x = []
y = []
# Decode explicitly as UTF-8 so reading the Chinese lyrics does not depend on
# the platform's default locale encoding.
with open("time.txt", "r", encoding="utf-8") as f:
    for line_index, line in enumerate(f):
        destination = x if line_index % 2 == 0 else y
        destination.append(line.strip())

In [5]:
from tensorflow.keras.preprocessing.text import Tokenizer
import jieba

# Fit the vocabulary on the jieba word segmentation of every sentence
# (all input sentences first, then all target sentences). Words the
# tokenizer has not seen are represented by the '<UNK>' token.
tokenizer = Tokenizer(oov_token='<UNK>')
words = [word for sentence in x for word in jieba.cut(sentence)]
words += [word for sentence in y for word in jieba.cut(sentence)]

tokenizer.fit_on_texts(words)

Building prefix dict from the default dictionary ...
Loading model from cache /var/folders/0l/3x73_lfs6czgjngbxbtn1vrh0000gn/T/jieba.cache
Loading model cost 0.893 seconds.
Prefix dict has been built succesfully.


In [6]:
from tensorflow.contrib.rnn import LSTMCell, LSTMStateTuple

In [26]:
def _sentences_to_ids(sentences):
    """Convert raw sentences into lists of vocabulary ids.

    Each sentence is segmented with jieba and every segment is passed to
    ``tokenizer.texts_to_sequences``. Segments for which the tokenizer
    returns an empty list are dropped; for the remaining segments only the
    first id is kept. Sentences that produce no segments at all are skipped.

    Args:
        sentences: iterable of strings.

    Returns:
        A list with one list of integer ids per kept sentence.
    """
    encoded = []
    for sentence in sentences:
        per_segment_ids = tokenizer.texts_to_sequences(list(jieba.cut(sentence)))
        if len(per_segment_ids) > 0:
            encoded.append([ids[0] for ids in per_segment_ids if ids != []])
    return encoded


# NOTE: despite the names, `train` holds the encoded input sentences (x) and
# `test` the encoded target sentences (y).
train = _sentences_to_ids(x)
test = _sentences_to_ids(y)

In [8]:
# Size of the embedding table and of the output projection: the number of
# distinct words recorded by the tokenizer plus two extra ids.
num_known_words = len(tokenizer.word_docs)
vocab_size = num_known_words + 2  # pad + eos

In [70]:
# Build the TensorFlow 1.x graph: a bidirectional LSTM encoder followed by an
# LSTM decoder driven by tf.nn.raw_rnn. At every decoding step the decoder is
# fed the embedding of its own previous prediction instead of the ground truth.
input_embedding_size = 16

encoder_hidden_units = 8
# The decoder state is initialised from the concatenated forward + backward
# encoder states, hence twice the units of a single encoder direction.
decoder_hidden_units = encoder_hidden_units * 2

graph = tf.Graph()
with graph.as_default():
    # Token-id inputs and their lengths; the encoder below is run with time_major=True.
    encoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='encoder_inputs')
    encoder_inputs_length = tf.placeholder(shape=(None,), dtype=tf.int32, name='encoder_inputs_length')

    decoder_targets = tf.placeholder(shape=(None, None), dtype=tf.int32, name='decoder_targets')
    # Number of decoding steps per example; it controls when the raw_rnn loop stops
    # and therefore has to be fed for prediction as well as for training.
    decoder_lengths = tf.placeholder(shape=(None,), dtype=tf.int32, name='decoder_targets_length')
    
    # A single embedding matrix used for the encoder inputs and for the decoder's fed-back predictions.
    embeddings = tf.Variable(tf.random_uniform([vocab_size, input_embedding_size], -1.0, 1.0), dtype=tf.float32)

    encoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, encoder_inputs)
    with tf.name_scope("encoder"):
        # NOTE(review): the same cell object is passed as both cell_fw and cell_bw below;
        # confirm that sharing one cell between the two directions is intended.
        encoder_cell = LSTMCell(encoder_hidden_units)

        ((encoder_fw_outputs,
      encoder_bw_outputs), 
         (encoder_fw_final_state,
      encoder_bw_final_state)) = (
        tf.nn.bidirectional_dynamic_rnn(cell_fw=encoder_cell,
                                        cell_bw=encoder_cell,
                                        inputs=encoder_inputs_embedded,
                                        sequence_length=encoder_inputs_length,
                                        dtype=tf.float32, time_major=True)
        )
    
        # concatenate forward and backward outputs along the feature axis
        encoder_outputs = tf.concat((encoder_fw_outputs, encoder_bw_outputs), 2)
        # concatenate forward and backward cell states (c)
        encoder_final_state_c = tf.concat(
            (encoder_fw_final_state.c, encoder_bw_final_state.c), 1)
        # concatenate forward and backward hidden states (h)
        encoder_final_state_h = tf.concat(
            (encoder_fw_final_state.h, encoder_bw_final_state.h), 1)
        # repackage as an LSTMStateTuple so it can be used as the decoder's initial state
        encoder_final_state = LSTMStateTuple(
            c=encoder_final_state_c,
            h=encoder_final_state_h
        )
        
    with tf.name_scope("decoder"):
        decoder_cell = LSTMCell(decoder_hidden_units)
        # encoder_inputs is fed time-major, so its dynamic shape unpacks to (max_time, batch_size).
        encoder_max_time, batch_size = tf.unstack(tf.shape(encoder_inputs))
        # Disabled alternative: derive the decoder lengths from the encoder lengths
        # instead of feeding them through the decoder_lengths placeholder.
#         decoder_lengths = encoder_inputs_length + 3 # +2 additional steps, +1 leading <EOS> token for decoder inputs
    with tf.name_scope("output"):
        # Projection from decoder outputs to vocabulary logits (used both inside the loop and for the loss).
        W = tf.Variable(tf.random_uniform([decoder_hidden_units, vocab_size], -1, 1), dtype=tf.float32)
        b = tf.Variable(tf.zeros([vocab_size]), dtype=tf.float32)
    
    # NOTE: although the op is named 'EOS', this slice is all zeros, i.e. token id 0,
    # which is the same id batch() uses as padding. It serves as the first decoder input.

    eos_time_slice = tf.zeros([batch_size], dtype=tf.int32, name='EOS')
#     pad_time_slice = tf.zeros([batch_size], dtype=tf.int32, name='PAD')

    eos_step_embedded = tf.nn.embedding_lookup(embeddings, eos_time_slice)
#     pad_step_embedded = tf.nn.embedding_lookup(embeddings, pad_time_slice)
    
    # Loop transition function is a mapping (time, previous_cell_output, previous_cell_state, previous_loop_state)
    #               ->     (elements_finished, input, cell_state, output, loop_state).
    def loop_fn_initial():
        """Return the raw_rnn loop values for the very first step.

        The first input is the id-0 embedding and the initial cell state is the
        encoder's final state; no output or loop state is emitted yet.
        """
        initial_elements_finished = (0 >= decoder_lengths)  # False for every example whose decoder length is positive
        initial_input = eos_step_embedded
        initial_cell_state = encoder_final_state
        initial_cell_output = None
        initial_loop_state = None  # we don't need to pass any additional information
        return (initial_elements_finished,
                initial_input,
                initial_cell_state,
                initial_cell_output,
                initial_loop_state)
    
    def loop_fn_transition(time, previous_output, previous_state, previous_loop_state):
        """Return the raw_rnn loop values for every step after the first.

        The next input is the embedding of the token predicted from the previous
        cell output; the cell state and output are passed through unchanged.
        """

        def get_next_input():
            """Embed the argmax token predicted from the previous cell output."""
            # Wx + b -> logits
            output_logits = tf.add(tf.matmul(previous_output, W), b)
            prediction = tf.argmax(output_logits, axis=1)
            next_input = tf.nn.embedding_lookup(embeddings, prediction)
            return next_input

        elements_finished = (time >= decoder_lengths) # this operation produces boolean tensor of [batch_size]
                                                      # defining if corresponding sequence has ended

        finished = tf.reduce_all(elements_finished) # -> boolean scalar
        # Once every sequence in the batch is finished, feed the id-0 embedding;
        # otherwise feed the embedding of the previous prediction.
        input = tf.cond(finished, lambda: eos_step_embedded, get_next_input)
        state = previous_state
        output = previous_output
        loop_state = None

        return (elements_finished, 
                input,
                state,
                output,
                loop_state)
    def loop_fn(time, previous_output, previous_state, previous_loop_state):
        """Dispatch to the initial or the transition loop function.

        This is the function handed to tf.nn.raw_rnn; the first call is
        recognised by previous_state being None.
        """
        if previous_state is None:    # time == 0
            assert previous_output is None and previous_state is None
            return loop_fn_initial()
        else:
            return loop_fn_transition(time, previous_output, previous_state, previous_loop_state)

    decoder_outputs_ta, decoder_final_state, _ = tf.nn.raw_rnn(decoder_cell, loop_fn)
    decoder_outputs = decoder_outputs_ta.stack()
    print(decoder_outputs) # [max_steps, batch_size, hidden_dim] 
    
    # flatten outputs to [max_steps * batch_sizes, hidden_dim]
    decoder_max_steps, decoder_batch_size, decoder_dim = tf.unstack(tf.shape(decoder_outputs))
    decoder_outputs_flat = tf.reshape(decoder_outputs, (-1, decoder_dim))
    # logits = Wx+ b
    decoder_logits_flat = tf.add(tf.matmul(decoder_outputs_flat, W), b)
    print(decoder_logits_flat.shape)
    # reshape back to [max_step, batch_size, vocab]
    decoder_logits = tf.reshape(decoder_logits_flat, (decoder_max_steps, decoder_batch_size, vocab_size))
    # get prediction
    decoder_prediction = tf.argmax(decoder_logits, 2)
    print(decoder_targets.shape) # 
    print(vocab_size)
    print(decoder_logits.shape)
    # NOTE: the labels and logits must cover the same number of time steps. Presumably the
    # number of decoding steps is max(decoder_lengths), so the fed decoder_targets must be
    # padded to exactly that length; a mismatch shows up as an error such as
    # "logits_size=[338,210] labels_size=[364,210]".
    print(tf.one_hot(decoder_targets, depth=vocab_size, dtype=tf.float32))
    stepwise_cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
    labels=tf.one_hot(decoder_targets, depth=vocab_size, dtype=tf.float32),
    logits=decoder_logits,
)

    # Mean over every (time step, example) position.
    # NOTE(review): padded target positions are not masked out of the loss.
    loss = tf.reduce_mean(stepwise_cross_entropy)
    train_op = tf.train.AdamOptimizer().minimize(loss)

Tensor("TensorArrayStack/TensorArrayGatherV3:0", shape=(?, ?, 16), dtype=float32)
(?, 210)
(?, ?)
210
(?, ?, 210)
Tensor("one_hot:0", shape=(?, ?, 210), dtype=float32)


In [16]:
def batch(inputs, max_sequence_length=None):
    """
    Pad a list of integer sequences into one time-major int32 matrix.

    Args:
        inputs:
            list of sentences (integer lists); may be empty
        max_sequence_length:
            integer specifying how large should `max_time` dimension be.
            If None, the length of the longest sequence is used
            (0 when `inputs` is empty)

    Returns:
        inputs_time_major:
            input sentences transformed into time-major matrix
            (shape [max_time, batch_size]) padded with 0s
        sequence_lengths:
            batch-sized list of integers specifying amount of active
            time steps in each input sequence

    Raises:
        IndexError:
            if a sequence is longer than `max_sequence_length`
    """
    sequence_lengths = [len(seq) for seq in inputs]
    batch_size = len(inputs)

    if max_sequence_length is None:
        # default=0 keeps an empty batch from failing inside max().
        max_sequence_length = max(sequence_lengths, default=0)

    inputs_batch_major = np.zeros(shape=[batch_size, max_sequence_length], dtype=np.int32)  # == PAD

    for row, (seq, length) in enumerate(zip(inputs, sequence_lengths)):
        if length > max_sequence_length:
            # Same exception type the element-wise copy used to raise, with a readable message.
            raise IndexError(
                "sequence %d has length %d, which exceeds max_sequence_length %d"
                % (row, length, max_sequence_length))
        if length:
            # Copy the whole sequence at once; the rest of the row stays 0 (PAD).
            inputs_batch_major[row, :length] = seq

    # [batch_size, max_time] -> [max_time, batch_size]
    inputs_time_major = inputs_batch_major.swapaxes(0, 1)

    return inputs_time_major, sequence_lengths

In [27]:
# Pad the encoded input sentences into a time-major matrix and keep their true lengths.
train_time, train_lengths = batch(train)

In [54]:
# Show the padded input shape and the longest input length, one per line.
print(train_time.shape, max(train_lengths), sep="\n")

(10, 26)
10


In [56]:
# Pad the encoded target sentences the same way, then show the padded shape
# and the longest target length, one per line.
test_time, test_lengths = batch(test)
print(test_time.shape, max(test_lengths), sep="\n")

(11, 26)
11


In [51]:
PAD = 0
EOS = 1 + len(tokenizer.word_index)
# Decoder targets: the padded, time-major target sentences. batch() is given a
# fresh list holding the same sequences; the returned lengths are discarded.
test_targets_, _ = batch(list(test))

In [58]:
# Show the shape of the padded decoder targets.
print(test_targets_.shape)

(11, 26)


In [77]:
# Reverse lookup table (id -> word) derived from the tokenizer's word -> id index.
word2idx = tokenizer.word_index
id2word = {index: word for word, index in word2idx.items()}
def translate(word_indexs):
    """Turn a sequence of vocabulary ids back into one string.

    Every id is looked up in ``id2word``; an id with no entry (or with an
    empty entry) is rendered as "<UNK>". The pieces are concatenated without
    a separator.

    Args:
        word_indexs: iterable of integer ids.

    Returns:
        The concatenated words as a single string.
    """
    return "".join(id2word.get(idx) or "<UNK>" for idx in word_indexs)

In [82]:
import random

epochs = 10000

# Both feed dicts are loop-invariant (the whole dataset is fed at every step),
# so they are built once instead of on every iteration.
train_feed = {
    encoder_inputs: train_time,
    encoder_inputs_length: train_lengths,
    decoder_targets: test_targets_,
    decoder_lengths: test_lengths,
}
# Prediction needs no targets, only the number of decoding steps per example;
# here that is the input length plus one.
predict_feed = {
    encoder_inputs: train_time,
    encoder_inputs_length: train_lengths,
    decoder_lengths: [length + 1 for length in train_lengths],
}

with tf.Session(graph=graph) as sess:
    sess.run(tf.global_variables_initializer())
    for e in range(epochs):
        loss_value, _ = sess.run([loss, train_op], feed_dict=train_feed)
        if e % 50 == 0:
            pred = sess.run(decoder_prediction, feed_dict=predict_feed)

            # Report one randomly chosen example; axis 1 of the time-major
            # matrices is the example axis, so a column is one sentence.
            rand_idx = random.randrange(train_time.shape[1])
            print("epoch: %d loss : %f" %( e, loss_value))
            print("input: ", translate(train_time[:, rand_idx]))
            print("predict: ", translate(pred[:, rand_idx]))
            print("ground truth: ",translate(test_time[:, rand_idx]) )
            print("-----------------------------------------")
#             print("Input: ")
            
            
    

epoch: 0 loss : 5.367332
input:  我发现我早已经长大<UNK><UNK><UNK><UNK>
predict:  尽如刀尽如刀如刀独自如刀<UNK><UNK><UNK><UNK>
ground truth:  我发现我早不说谎话<UNK><UNK><UNK><UNK>
-----------------------------------------
epoch: 50 loss : 5.194360
input:  我渐渐长大成人<UNK><UNK><UNK><UNK><UNK><UNK><UNK>
predict:  起跑的hiphip时间时间时间时间时间时间时间
ground truth:  眼看着世界沉沦<UNK><UNK><UNK><UNK><UNK><UNK><UNK>
-----------------------------------------
epoch: 100 loss : 4.877698
input:  脱掉了曾经<UNK><UNK><UNK><UNK><UNK><UNK><UNK>
predict:  母亲的的的时间时间时间时间时间时间时间
ground truth:  认为很时尚的大肥裤子<UNK><UNK><UNK><UNK><UNK>
-----------------------------------------
epoch: 150 loss : 4.447875
input:  我闭上双眼祈祷<UNK><UNK><UNK><UNK><UNK><UNK>
predict:  我我我以前忘时间时间时间时间时间时间
ground truth:  我知道努力学习以后才能把歌写好<UNK><UNK>
-----------------------------------------
epoch: 200 loss : 4.064038
input:  经典的就像是oldschoolflow<UNK><UNK><UNK>
predict:  从从质疑质疑质疑质疑质疑质疑时间时间时间
ground truth:  从不用质疑<UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK>
-----------------------------------------
epoch: 2

epoch: 2000 loss : 0.833663
input:  我闭上双眼祈祷<UNK><UNK><UNK><UNK><UNK><UNK>
predict:  我知道努力学习以后才能<UNK><UNK><UNK><UNK><UNK><UNK>
ground truth:  我知道努力学习以后才能把歌写好<UNK><UNK>
-----------------------------------------
epoch: 2050 loss : 0.798718
input:  小霸王游戏机<UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK>
predict:  陪我度过<UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK>
ground truth:  陪我度过了一段时期<UNK><UNK><UNK><UNK><UNK>
-----------------------------------------
epoch: 2100 loss : 0.764761
input:  特别的怀念<UNK><UNK><UNK><UNK><UNK><UNK><UNK>
predict:  回到童年童年童年<UNK><UNK><UNK><UNK><UNK><UNK><UNK>
ground truth:  回到童年<UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK>
-----------------------------------------
epoch: 2150 loss : 0.731806
input:  一路走汗在流再回首已没有<UNK><UNK>
predict:  那双扶着我不跌倒的手<UNK><UNK><UNK>
ground truth:  那双扶着我不跌倒的手<UNK><UNK><UNK>
-----------------------------------------
epoch: 2200 loss : 0.699868
input:  再没骗自己的理由<UNK><UNK><UNK><UNK><UNK>
predict:  时间如刀不再温柔温柔如刀<UNK><UNK><UNK><UNK><UNK>
ground truth:  时间如刀不再温柔<UNK

epoch: 4000 loss : 0.134781
input:  小霸王游戏机<UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK>
predict:  陪我度过<UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK>
ground truth:  陪我度过了一段时期<UNK><UNK><UNK><UNK><UNK>
-----------------------------------------
epoch: 4050 loss : 0.129450
input:  小霸王游戏机<UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK>
predict:  陪我度过<UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK>
ground truth:  陪我度过了一段时期<UNK><UNK><UNK><UNK><UNK>
-----------------------------------------
epoch: 4100 loss : 0.124378
input:  校门口美味的路边摊还在不在<UNK><UNK>
predict:  包里塞的小吃那是外婆给的<UNK><UNK>
ground truth:  包里塞的小吃那是外婆给的爱<UNK>
-----------------------------------------
epoch: 4150 loss : 0.119552
input:  再没骗自己的理由<UNK><UNK><UNK><UNK><UNK>
predict:  时间如刀不再温柔温柔<UNK><UNK><UNK><UNK><UNK><UNK>
ground truth:  时间如刀不再温柔<UNK><UNK><UNK><UNK><UNK><UNK><UNK>
-----------------------------------------
epoch: 4200 loss : 0.114957
input:  街边的落日和小时候的复读机<UNK><UNK><UNK>
predict:  度过了末日但我要比以前<UNK><UNK><UNK>
ground truth:  度过了末日但我要比以前有出息<UNK>
------------

epoch: 6000 loss : 0.034075
input:  不用问我的来路<UNK><UNK><UNK><UNK><UNK>
predict:  贫民窟的艺术家艺术家的艺术家<UNK><UNK><UNK><UNK><UNK>
ground truth:  贫民窟的艺术家<UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK>
-----------------------------------------
epoch: 6050 loss : 0.033074
input:  那心酸的滋味我只能独自体会<UNK><UNK>
predict:  想要回到过去但是时光不能倒退倒退倒退<UNK><UNK>
ground truth:  想要回到过去但是时光不能倒退<UNK><UNK><UNK><UNK>
-----------------------------------------
epoch: 6100 loss : 0.032107
input:  经典的就像是oldschoolflow<UNK><UNK><UNK>
predict:  从不用质疑质疑质疑回不去早已质疑<UNK><UNK><UNK>
ground truth:  从不用质疑<UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK>
-----------------------------------------
epoch: 6150 loss : 0.031172
input:  当我没日没夜工作从长水起飞<UNK><UNK><UNK>
predict:  时间反复催促让我忘了疲惫<UNK><UNK><UNK>
ground truth:  时间反复催促让我忘了疲惫<UNK><UNK><UNK>
-----------------------------------------
epoch: 6200 loss : 0.030269
input:  再没骗自己的理由<UNK><UNK><UNK><UNK><UNK>
predict:  时间如刀不再温柔温柔<UNK><UNK><UNK><UNK><UNK><UNK>
ground truth:  时间如刀不再温柔<UNK><UNK><UNK><UNK><UNK><UNK><UNK>
-

epoch: 7950 loss : 0.011557
input:  课桌上面摆着老师送的铅笔刀<UNK><UNK><UNK>
predict:  她教我如何起跑起跑起跑是质疑疼<UNK><UNK><UNK>
ground truth:  她教我如何起跑<UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK>
-----------------------------------------
epoch: 8000 loss : 0.011259
input:  小霸王游戏机<UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK>
predict:  陪我度过<UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK>
ground truth:  陪我度过了一段时期<UNK><UNK><UNK><UNK><UNK>
-----------------------------------------
epoch: 8050 loss : 0.010969
input:  所以我努力的跑<UNK><UNK><UNK><UNK><UNK>
predict:  把一切全部看透看透从前<UNK><UNK><UNK><UNK><UNK>
ground truth:  把一切全部看透<UNK><UNK><UNK><UNK><UNK><UNK><UNK>
-----------------------------------------
epoch: 8100 loss : 0.010688
input:  脱掉了曾经<UNK><UNK><UNK><UNK><UNK><UNK><UNK>
predict:  认为很时尚的<UNK><UNK><UNK><UNK><UNK><UNK><UNK>
ground truth:  认为很时尚的大肥裤子<UNK><UNK><UNK><UNK><UNK>
-----------------------------------------
epoch: 8150 loss : 0.010414
input:  当我没日没夜工作从长水起飞<UNK><UNK><UNK>
predict:  时间反复催促让我忘了疲惫<UNK><UNK><UNK>
ground truth:  时间反