In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from tensorflow.python.layers import core as core_layers
import tensorflow as tf
import numpy as np
import time
import myResidualCell
import jieba
from bleu import BLEU
import random
import cPickle
import matplotlib.pyplot as plt

def get_initializer(matrix):
    """Build an initializer callable that always yields ``matrix``.

    The returned function accepts the arguments a variable initializer is
    called with (``shape``, ``dtype``, ``partition_info`` and any extra
    keywords), ignores all of them and hands back ``matrix`` unchanged.
    """
    def _initializer(shape, dtype=None, partition_info=None, **kwargs):
        # All arguments are deliberately unused: the value is fixed up front.
        return matrix

    return _initializer

class Seq2Seq:
    def __init__(self, dp, rnn_size, n_layers, encoder_embedding_dim, decoder_embedding_dim, max_infer_length,
                 sess=None, lr=0.0001, grad_clip=5.0, beam_width=5, force_teaching_ratio=1.0, beam_penalty=1.0,
                 residual=False, output_keep_prob=0.5, input_keep_prob=0.9, cell_type='lstm', reverse=False,
                 encoder_pre_embedding=None, decoder_pre_embedding=None, decay_scheme='luong234', is_save=True,
                 emb_fix=False):
        """Store the hyper-parameters, build the graph and initialise variables.

        Args:
            dp: data provider; this class reads ``X_w2id``, ``Y_w2id``,
                ``Y_id2w``, ``num_steps`` and ``pad_sentence_batch`` from it.
            rnn_size: number of units of every recurrent cell and of the
                attention layer.
            n_layers: layers per encoder direction (the decoder stacks
                ``2 * n_layers`` cells).
            encoder_embedding_dim: width of the source embedding.
            decoder_embedding_dim: width of the target embedding.
            max_infer_length: maximum number of decoding iterations at inference.
            sess: ``tf.Session`` to run in. When ``None`` a new session is
                created here, at construction time, instead of once when the
                class is defined.
            lr: initial learning rate.
            grad_clip: global-norm threshold used to clip gradients.
            beam_width: beam size of the beam-search decoder.
            force_teaching_ratio: the scheduled-sampling probability is
                ``1 - force_teaching_ratio``.
            beam_penalty: length penalty weight of the beam-search decoder.
            residual: wrap every cell in ``myResidualCell.ResidualWrapper``;
                requires both embedding widths to equal ``rnn_size``.
            output_keep_prob: keep probability stored as ``_output_keep_prob``
                for callers to feed while training.
            input_keep_prob: keep probability stored as ``_input_keep_prob``
                for callers to feed while training.
            cell_type: ``'lstm'`` selects the layer-norm LSTM cell with Bahdanau
                attention; any other value selects the GRU block cell with
                Luong attention.
            reverse: when true, decoded id sequences are reversed before being
                converted to strings.
            encoder_pre_embedding: optional initial value of the source embedding.
            decoder_pre_embedding: optional initial value of the target embedding.
            decay_scheme: ``'luong10'`` or anything else (treated as the
                default ``'luong234'`` schedule).
            is_save: stored on the instance; not read inside this class.
            emb_fix: when true, pre-initialised embeddings are not trainable.

        Raises:
            AssertionError: if ``residual`` is set and an embedding width
                differs from ``rnn_size``.
        """
        self.rnn_size = rnn_size
        self.n_layers = n_layers
        self.grad_clip = grad_clip
        self.dp = dp
        self.encoder_embedding_dim = encoder_embedding_dim
        self.decoder_embedding_dim = decoder_embedding_dim
        self.encoder_pre_embedding = encoder_pre_embedding
        self.decoder_pre_embedding = decoder_pre_embedding
        self.beam_width = beam_width
        self.beam_penalty = beam_penalty
        self.max_infer_length = max_infer_length
        self.residual = residual
        self.decay_scheme = decay_scheme
        if self.residual:
            assert encoder_embedding_dim == rnn_size
            assert decoder_embedding_dim == rnn_size
        self.reverse = reverse
        self.emb_fix = emb_fix
        self.cell_type = cell_type
        self.force_teaching_ratio = force_teaching_ratio
        self._output_keep_prob = output_keep_prob
        self._input_keep_prob = input_keep_prob
        # A default of ``tf.Session()`` in the signature would be evaluated a
        # single time when the class statement runs, so every model would share
        # one session tied to the graph that was the default at that moment.
        self.sess = tf.Session() if sess is None else sess
        self.lr = lr
        self.build_graph()
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)
        # The summary variables are created after both the initializer run and
        # the Saver, so they are neither initialised above nor checkpointed;
        # they receive their values through the assign ops in ``update_ops``.
        self.summary_placeholders, self.update_ops, self.summary_op = self.setup_summary()
        self.is_save = is_save
        
    # end constructor

    def build_graph(self):
        """Assemble the whole computation graph in dependency order.

        Vocabulary symbols are resolved first, then the placeholders, the
        encoder, the training decoder and the beam-search inference decoder
        (which re-enters the ``'decode'`` variable scope with ``reuse=True``),
        and finally the loss/optimizer.
        """
        self.register_symbols()
        self.add_input_layer()
        self.add_encoder_layer()
        with tf.variable_scope('decode'):
            self.add_decoder_for_training()
        with tf.variable_scope('decode', reuse=True):
            self.add_decoder_for_inference()
        # The prefix and greedy decoders are currently disabled. While they
        # stay commented out, the tensors they would define
        # (``predicting_prefix_ids``, ``prefix_score``, ``greedy_output``, ...)
        # do not exist on the instance.
        #with tf.variable_scope('decode', reuse=True):
        #    self.add_decoder_for_prefix_inference()
        #with tf.variable_scope('decode', reuse=True):
        #    self.add_decoder_for_greedy_inference()
        self.add_backward_path()
    # end method

    def add_input_layer(self):
        """Create the feed placeholders and the global step counter.

        ``X``/``Y`` are int32 matrices with both dimensions left open, the two
        ``*_seq_len`` placeholders are int32 vectors, and the two keep
        probabilities are float32 placeholders without a fixed shape.
        ``batch_size`` is read dynamically from the first dimension of ``X``.
        """
        self.X = tf.placeholder(dtype=tf.int32, shape=[None, None], name="X")
        self.Y = tf.placeholder(dtype=tf.int32, shape=[None, None], name="Y")
        self.X_seq_len = tf.placeholder(dtype=tf.int32, shape=[None], name="X_seq_len")
        self.Y_seq_len = tf.placeholder(dtype=tf.int32, shape=[None], name="Y_seq_len")
        self.input_keep_prob = tf.placeholder(dtype=tf.float32, name="input_keep_prob")
        self.output_keep_prob = tf.placeholder(dtype=tf.float32, name="output_keep_prob")
        # Dynamic batch size, taken from whatever is fed into X.
        self.batch_size = tf.shape(self.X)[0]
        self.global_step = tf.Variable(initial_value=0, trainable=False, name="global_step")
    # end method

    def single_cell(self, reuse=False):
        """Create one recurrent cell wrapped with dropout (and residual if enabled).

        Args:
            reuse: forwarded to ``LayerNormBasicLSTMCell`` when ``cell_type`` is
                ``'lstm'``; it is not passed to the GRU cell.

        Returns:
            The cell, wrapped in ``DropoutWrapper`` and, when ``self.residual``
            is set, additionally in ``myResidualCell.ResidualWrapper``.
        """
        if self.cell_type == 'lstm':
            cell = tf.contrib.rnn.LayerNormBasicLSTMCell(self.rnn_size, reuse=reuse)
        else:
            cell = tf.contrib.rnn.GRUBlockCell(self.rnn_size)
        # DropoutWrapper's positional order is (cell, input_keep_prob,
        # output_keep_prob). Passing the placeholders by keyword keeps each rate
        # on the side it is named after; passing them positionally as
        # (output, input) silently swapped the two during training.
        cell = tf.contrib.rnn.DropoutWrapper(
            cell,
            input_keep_prob=self.input_keep_prob,
            output_keep_prob=self.output_keep_prob)
        if self.residual:
            cell = myResidualCell.ResidualWrapper(cell)
        return cell
    
    def add_encoder_layer(self):
        """Embed the source ids and run them through a bidirectional RNN stack.

        Sets ``encoder_inputs`` (embedded ``X``), ``encoder_out`` (forward and
        backward outputs concatenated on axis 2) and ``encoder_state`` (a tuple
        that interleaves the per-layer forward and backward final states).
        """
        embedding_shape = [len(self.dp.X_w2id), self.encoder_embedding_dim]
        if self.encoder_pre_embedding is not None:
            # Start from the supplied matrix; freeze it when emb_fix is set.
            initializer = get_initializer(self.encoder_pre_embedding)
            trainable = not self.emb_fix
        else:
            initializer = tf.random_uniform_initializer(-1.0, 1.0)
            trainable = True
        encoder_embedding = tf.get_variable(
            'encoder_embedding', embedding_shape, tf.float32,
            initializer=initializer, trainable=trainable)

        self.encoder_inputs = tf.nn.embedding_lookup(encoder_embedding, self.X)
        forward_stack = tf.contrib.rnn.MultiRNNCell(
            [self.single_cell() for _ in range(self.n_layers)])
        backward_stack = tf.contrib.rnn.MultiRNNCell(
            [self.single_cell() for _ in range(self.n_layers)])
        bi_outputs, bi_states = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=forward_stack,
            cell_bw=backward_stack,
            inputs=self.encoder_inputs,
            sequence_length=self.X_seq_len,
            dtype=tf.float32,
            scope='bidirectional_rnn')
        self.encoder_out = tf.concat(bi_outputs, 2)
        # Order: layer 0 forward, layer 0 backward, layer 1 forward, ...
        self.encoder_state = tuple(
            bi_states[direction][layer_id]
            for layer_id in range(self.n_layers)
            for direction in (0, 1))

    def processed_decoder_input(self):
        """Return ``Y`` shifted right: a ``<GO>`` column followed by ``Y`` without its last column."""
        # Drop the final time step of every target row.
        without_last = tf.strided_slice(self.Y, [0, 0], [self.batch_size, -1], [1, 1])
        go_column = tf.fill([self.batch_size, 1], self._y_go)
        return tf.concat([go_column, without_last], 1)

    def add_attention_for_training(self):
        """Build ``self.decoder_cell`` for training: an attention-wrapped stack of ``2 * n_layers`` cells.

        Attention runs over ``encoder_out`` masked by ``X_seq_len``; Bahdanau
        (``normalize=True``) is used for ``'lstm'``, Luong (``scale=True``)
        otherwise. Alignment history is recorded.
        """
        memory_kwargs = dict(
            num_units=self.rnn_size,
            memory=self.encoder_out,
            memory_sequence_length=self.X_seq_len)
        if self.cell_type == 'lstm':
            attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(normalize=True, **memory_kwargs)
        else:
            attention_mechanism = tf.contrib.seq2seq.LuongAttention(scale=True, **memory_kwargs)

        decoder_stack = tf.nn.rnn_cell.MultiRNNCell(
            [self.single_cell() for _ in range(2 * self.n_layers)])
        self.decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
            cell=decoder_stack,
            attention_mechanism=attention_mechanism,
            alignment_history=True,
            attention_layer_size=self.rnn_size)

    def add_decoder_for_training(self):
        """Build the scheduled-sampling training decoder.

        Creates the ``decoder_embedding`` variable, decodes for at most
        ``max(Y_seq_len)`` steps starting from the encoder state, and sets
        ``training_logits``, ``train_final_state`` and ``init_prefix_state``
        (an alias of the final state).
        """
        self.add_attention_for_training()

        target_vocab_size = len(self.dp.Y_w2id)
        embedding_shape = [target_vocab_size, self.decoder_embedding_dim]
        if self.decoder_pre_embedding is not None:
            # Start from the supplied matrix; freeze it when emb_fix is set.
            initializer = get_initializer(self.decoder_pre_embedding)
            trainable = not self.emb_fix
        else:
            initializer = tf.random_uniform_initializer(-1.0, 1.0)
            trainable = True
        decoder_embedding = tf.get_variable(
            'decoder_embedding', embedding_shape, tf.float32,
            initializer=initializer, trainable=trainable)

        embedded_inputs = tf.nn.embedding_lookup(decoder_embedding, self.processed_decoder_input())
        training_helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
            inputs=embedded_inputs,
            sequence_length=self.Y_seq_len,
            embedding=decoder_embedding,
            sampling_probability=1 - self.force_teaching_ratio,
            time_major=False)
        # Attention state starts at zero; the RNN state starts at the encoder state.
        initial_state = self.decoder_cell.zero_state(self.batch_size, tf.float32).clone(
            cell_state=self.encoder_state)
        projection = core_layers.Dense(target_vocab_size)
        training_decoder = tf.contrib.seq2seq.BasicDecoder(
            cell=self.decoder_cell,
            helper=training_helper,
            initial_state=initial_state,
            output_layer=projection)
        decoder_output, self.train_final_state, _ = tf.contrib.seq2seq.dynamic_decode(
            decoder=training_decoder,
            impute_finished=True,
            maximum_iterations=tf.reduce_max(self.Y_seq_len))
        self.training_logits = decoder_output.rnn_output
        self.init_prefix_state = self.train_final_state

    def add_attention_for_inference(self):
        """Build ``self.decoder_cell`` for beam search.

        The encoder outputs, encoder state and source lengths are tiled
        ``beam_width`` times (stored as ``*_tiled`` attributes) and attention is
        built over the tiled memory. Cells are created with ``reuse=True`` and
        no alignment history is requested.
        """
        tile = tf.contrib.seq2seq.tile_batch
        self.encoder_out_tiled = tile(self.encoder_out, self.beam_width)
        self.encoder_state_tiled = tile(self.encoder_state, self.beam_width)
        self.X_seq_len_tiled = tile(self.X_seq_len, self.beam_width)

        memory_kwargs = dict(
            num_units=self.rnn_size,
            memory=self.encoder_out_tiled,
            memory_sequence_length=self.X_seq_len_tiled)
        if self.cell_type == 'lstm':
            attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(normalize=True, **memory_kwargs)
        else:
            attention_mechanism = tf.contrib.seq2seq.LuongAttention(scale=True, **memory_kwargs)

        decoder_stack = tf.nn.rnn_cell.MultiRNNCell(
            [self.single_cell(reuse=True) for _ in range(2 * self.n_layers)])
        self.decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
            cell=decoder_stack,
            attention_mechanism=attention_mechanism,
            attention_layer_size=self.rnn_size)
        
    def add_attention_for_greedy_inference(self):
        """Build ``self.decoder_cell`` for greedy decoding.

        Same construction as the training attention cell: attention over the
        untiled ``encoder_out`` with alignment history recorded.
        """
        memory_kwargs = dict(
            num_units=self.rnn_size,
            memory=self.encoder_out,
            memory_sequence_length=self.X_seq_len)
        if self.cell_type == 'lstm':
            attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(normalize=True, **memory_kwargs)
        else:
            attention_mechanism = tf.contrib.seq2seq.LuongAttention(scale=True, **memory_kwargs)

        decoder_stack = tf.nn.rnn_cell.MultiRNNCell(
            [self.single_cell() for _ in range(2 * self.n_layers)])
        self.decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
            cell=decoder_stack,
            attention_mechanism=attention_mechanism,
            alignment_history=True,
            attention_layer_size=self.rnn_size)
        
    def add_decoder_for_greedy_inference(self):
        """Build the greedy decoder.

        Sets ``greedy_output`` (sampled ids), ``greedy_final_state`` and
        ``alignment_history`` (the stacked attention alignments). Decoding
        starts from ``<GO>`` and runs for at most ``max_infer_length`` steps.
        """
        self.add_attention_for_greedy_inference()

        decoder_embedding = tf.get_variable('decoder_embedding')
        go_tokens = tf.tile(tf.constant([self._y_go], dtype=tf.int32), [self.batch_size])
        greedy_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            embedding=decoder_embedding,
            start_tokens=go_tokens,
            end_token=self._y_eos)
        initial_state = self.decoder_cell.zero_state(self.batch_size, tf.float32).clone(
            cell_state=self.encoder_state)
        projection = core_layers.Dense(len(self.dp.Y_w2id))
        greedy_decoder = tf.contrib.seq2seq.BasicDecoder(
            cell=self.decoder_cell,
            helper=greedy_helper,
            initial_state=initial_state,
            output_layer=projection)
        decoder_output, self.greedy_final_state, _ = tf.contrib.seq2seq.dynamic_decode(
            decoder=greedy_decoder,
            impute_finished=False,
            maximum_iterations=self.max_infer_length)

        self.greedy_output = decoder_output.sample_id
        self.alignment_history = self.greedy_final_state.alignment_history.stack()
        
    def add_decoder_for_inference(self):
        """Build the beam-search decoder used by ``infer`` and ``batch_infer``.

        Sets ``predicting_ids`` (the decoder's ``predicted_ids``) and ``score``
        (the beam-search scores). Decoding starts from ``<GO>``, stops on
        ``<EOS>`` and runs for at most ``max_infer_length`` steps.
        """
        self.add_attention_for_inference()
        # Build the beam-expanded initial state once and hand it to the decoder,
        # rather than constructing the same zero-state subgraph twice and
        # discarding one copy.
        beam_initial_state = self.decoder_cell.zero_state(
            self.batch_size * self.beam_width, tf.float32).clone(
                cell_state=self.encoder_state_tiled)
        predicting_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
            cell=self.decoder_cell,
            embedding=tf.get_variable('decoder_embedding'),
            start_tokens=tf.tile(tf.constant([self._y_go], dtype=tf.int32), [self.batch_size]),
            end_token=self._y_eos,
            initial_state=beam_initial_state,
            beam_width=self.beam_width,
            output_layer=core_layers.Dense(len(self.dp.Y_w2id), _reuse=True),
            length_penalty_weight=self.beam_penalty)
        predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
            decoder=predicting_decoder,
            impute_finished=False,
            maximum_iterations=self.max_infer_length)
        self.predicting_ids = predicting_decoder_output.predicted_ids
        self.score = predicting_decoder_output.beam_search_decoder_output.scores
        
    def add_decoder_for_prefix_inference(self):
        """Build a beam-search decoder that continues from a fed target prefix.

        The initial decoder state is derived from ``init_prefix_state`` (the
        final state of the training decoder), and the first decoder input is
        the embedding of the ids fed through the ``prefix_go`` placeholder.
        Sets ``prefix_go``, ``predicting_prefix_ids`` and ``prefix_score``.
        """
        self.add_attention_for_inference()

        # Only the RNN state and the attention vector are expanded to the beam.
        prefix_cell_state = tf.contrib.seq2seq.tile_batch(self.init_prefix_state.cell_state, self.beam_width)
        prefix_attention = tf.contrib.seq2seq.tile_batch(self.init_prefix_state.attention, self.beam_width)
        # NOTE(review): time, alignments and alignment_history are copied from
        # the training state without tile_batch, and the inference cell is built
        # without alignment_history -- confirm these are compatible with the
        # beam-expanded batch before re-enabling this decoder.
        prefix_time = self.init_prefix_state.time
        prefix_alignments = self.init_prefix_state.alignments
        prefix_alignment_history = self.init_prefix_state.alignment_history
        
        init_state = tf.contrib.seq2seq.AttentionWrapperState(cell_state=prefix_cell_state, 
                                                      attention=prefix_attention, 
                                                      time=prefix_time,
                                                      alignments=prefix_alignments,
                                                      alignment_history=prefix_alignment_history)
        predicting_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
            cell = self.decoder_cell,
            embedding = tf.get_variable('decoder_embedding'),
            start_tokens = tf.tile(tf.constant([self._y_go], dtype=tf.int32), [self.batch_size]),
            end_token = self._y_eos,
            initial_state = init_state,
            beam_width = self.beam_width,
            output_layer = core_layers.Dense(len(self.dp.Y_w2id), _reuse=True),
            length_penalty_weight = self.beam_penalty)
        # One id per batch row: the token the continuation should start from.
        self.prefix_go = tf.placeholder(tf.int32, [None])
        prefix_go_beam = tf.tile(tf.expand_dims(self.prefix_go, 1), [1, self.beam_width])
        prefix_emb = tf.nn.embedding_lookup(tf.get_variable('decoder_embedding'), prefix_go_beam)
        # Overwrites a private attribute of BeamSearchDecoder so that decoding
        # begins from the prefix token instead of <GO>.
        # NOTE(review): relies on a TensorFlow-internal name -- verify against
        # the installed version.
        predicting_decoder._start_inputs = prefix_emb
        predicting_prefix_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
            decoder = predicting_decoder,
            impute_finished = False,
            maximum_iterations = self.max_infer_length)
        self.predicting_prefix_ids = predicting_prefix_decoder_output.predicted_ids
        self.prefix_score = predicting_prefix_decoder_output.beam_search_decoder_output.scores

    def add_backward_path(self):
        """Define the losses, the decayed learning rate and the training op.

        Sets ``loss`` (masked sequence loss), ``batch_loss`` (the same loss
        without averaging across the batch), ``learning_rate`` and
        ``train_op`` (Adam on globally-clipped gradients; also increments
        ``global_step``).
        """
        target_mask = tf.sequence_mask(
            self.Y_seq_len, tf.reduce_max(self.Y_seq_len), dtype=tf.float32)
        loss_inputs = dict(logits=self.training_logits, targets=self.Y, weights=target_mask)
        self.loss = tf.contrib.seq2seq.sequence_loss(**loss_inputs)
        self.batch_loss = tf.contrib.seq2seq.sequence_loss(average_across_batch=False, **loss_inputs)

        trainable = tf.trainable_variables()
        raw_gradients = tf.gradients(self.loss, trainable)
        clipped, _ = tf.clip_by_global_norm(raw_gradients, self.grad_clip)

        # The constant must be in place first: the decay schedule reads
        # self.learning_rate as its base value and then replaces it.
        self.learning_rate = tf.constant(self.lr)
        self.learning_rate = self.get_learning_rate_decay(self.decay_scheme)

        optimizer = tf.train.AdamOptimizer(self.learning_rate)
        self.train_op = optimizer.apply_gradients(
            zip(clipped, trainable), global_step=self.global_step)

    def register_symbols(self):
        """Cache the ids of the ``<GO>``, ``<EOS>``, ``<PAD>`` and ``<UNK>`` tokens.

        Source-vocabulary ids go to ``_x_go``/``_x_eos``/``_x_pad``/``_x_unk``
        and target-vocabulary ids to the matching ``_y_*`` attributes. A missing
        token raises ``KeyError``.
        """
        special_tokens = ('<GO>', '<EOS>', '<PAD>', '<UNK>')

        source_vocab = self.dp.X_w2id
        self._x_go, self._x_eos, self._x_pad, self._x_unk = [
            source_vocab[token] for token in special_tokens]

        target_vocab = self.dp.Y_w2id
        self._y_go, self._y_eos, self._y_pad, self._y_unk = [
            target_vocab[token] for token in special_tokens]
    
    def infer(self, input_word):
        """Decode a single source sequence with beam search.

        Args:
            input_word: iterable of source tokens; tokens missing from the
                source vocabulary are mapped to ``<UNK>``.

        Returns:
            A list with one decoded string per beam. Each is cut at the first
            ``<EOS>`` (if any) and reversed when ``self.reverse`` is set; ids
            without a vocabulary entry are rendered as ``&``.
        """
        input_indices = [self.dp.X_w2id.get(char, self._x_unk) for char in input_word]
        # Dropout is disabled at inference: both keep probabilities are 1.
        feed = {
            self.X: [input_indices],
            self.X_seq_len: [len(input_indices)],
            self.output_keep_prob: 1,
            self.input_keep_prob: 1,
        }
        out_indices = self.sess.run(self.predicting_ids, feed)

        eos_id = self.dp.Y_w2id['<EOS>']
        outputs = []
        for beam in range(out_indices.shape[-1]):
            ids = out_indices[0, :, beam]
            if eos_id in ids:
                ids = ids.tolist()
                ids = ids[:ids.index(eos_id)]
            if self.reverse:
                ids = ids[::-1]
            outputs.append(''.join([self.dp.Y_id2w.get(i, u'&') for i in ids]))
        return outputs
    
    def batch_infer(self, input_words):
        """Decode several source sequences at once, keeping beam 0 of each.

        Args:
            input_words: list of token iterables; unknown tokens map to
                ``<UNK>`` and the batch is padded with the source ``<PAD>`` id.

        Returns:
            A list of decoded strings, one per input, taken from beam index 0,
            cut at the first ``<EOS>`` and reversed when ``self.reverse`` is set.
        """
        input_indices = [
            [self.dp.X_w2id.get(char, self._x_unk) for char in input_word]
            for input_word in input_words]
        input_indices, lengths = self.dp.pad_sentence_batch(input_indices, self._x_pad)
        # Dropout is disabled at inference: both keep probabilities are 1.
        feed = {
            self.X: input_indices,
            self.X_seq_len: lengths,
            self.output_keep_prob: 1,
            self.input_keep_prob: 1,
        }
        out_indices = self.sess.run(self.predicting_ids, feed)

        eos_id = self.dp.Y_w2id['<EOS>']
        outputs = []
        for row in range(out_indices.shape[0]):
            ids = out_indices[row, :, 0]   # layout: (batch, length, beam)
            if eos_id in ids:
                ids = ids.tolist()
                ids = ids[:ids.index(eos_id)]
            if self.reverse:
                ids = ids[::-1]
            outputs.append(''.join([self.dp.Y_id2w.get(i, u'&') for i in ids]))
        assert len(outputs) == len(input_words)
        return outputs
    
    def prefix_infer(self, input_word, prefix):
        """Decode continuations of ``prefix`` for one source sequence.

        Requires the prefix decoder (``add_decoder_for_prefix_inference``) to
        have been added to the graph, since it reads ``predicting_prefix_ids``
        and ``prefix_go``.

        Args:
            input_word: iterable of source tokens (unknown ones map to ``<UNK>``).
            prefix: non-empty target-side prefix; its tokens are fed as ``Y``
                and its last token seeds the continuation.

        Returns:
            A list with one string per beam. The continuation is cut at the
            first ``<EOS>``; when ``self.reverse`` is set it is reversed and the
            prefix is appended after it, otherwise the prefix comes first.

        Raises:
            IndexError: if ``prefix`` is empty.
        """
        input_indices_X = [self.dp.X_w2id.get(char, self._x_unk) for char in input_word]
        input_indices_Y = [self.dp.Y_w2id.get(char, self._y_unk) for char in prefix]

        # The continuation starts from the last prefix token.
        prefix_go = [input_indices_Y[-1]]
        out_indices = self.sess.run(self.predicting_prefix_ids, {
            self.X: [input_indices_X], self.X_seq_len: [len(input_indices_X)],
            self.Y: [input_indices_Y], self.Y_seq_len: [len(input_indices_Y)],
            self.prefix_go: prefix_go, self.input_keep_prob: 1, self.output_keep_prob: 1})

        eos_id = self.dp.Y_w2id['<EOS>']
        outputs = []
        for idx in range(out_indices.shape[-1]):
            ot = out_indices[0, :, idx]
            if eos_id in ot:
                ot = ot.tolist()
                ot = ot[:ot.index(eos_id)]
            # Reverse regardless of whether <EOS> was produced, as infer() and
            # batch_infer() do; previously a beam that never emitted <EOS> was
            # left unreversed.
            if self.reverse:
                ot = ot[::-1]
            decoded = ''.join([self.dp.Y_id2w.get(i, u'&') for i in ot])
            if self.reverse:
                output_str = decoded + prefix
            else:
                output_str = prefix + decoded
            outputs.append(output_str)
        return outputs
        
    
    def restore(self, path):
        """Load the checkpoint at ``path`` into the session and report success."""
        self.saver.restore(self.sess, path)
        message = 'restore %s success' % path
        print(message)
        
    def get_learning_rate_decay(self, decay_scheme='luong234'):
        """Return the learning-rate tensor for the chosen step-decay schedule.

        The rate stays at ``self.learning_rate`` until ``start_decay_step`` and
        is halved in a staircase every ``decay_steps`` steps afterwards.

        Args:
            decay_scheme: ``"luong10"`` starts after half of the training steps
                and halves 10 times over the remainder; any other value starts
                after two thirds of the steps and halves 4 times.

        Returns:
            A tensor holding the learning rate for the current ``global_step``.
        """
        num_train_steps = self.dp.num_steps
        if decay_scheme == "luong10":
            start_decay_step = int(num_train_steps / 2)
            decay_times = 10
        else:
            start_decay_step = int(num_train_steps * 2 / 3)
            decay_times = 4
        remain_steps = num_train_steps - start_decay_step
        # With very few training steps the quotient truncates to 0, and
        # exponential_decay would then divide by zero (NaN or zero learning
        # rate). Decay at least every step.
        decay_steps = max(1, int(remain_steps / decay_times))
        decay_factor = 0.5
        return tf.cond(
            self.global_step < start_decay_step,
            lambda: self.learning_rate,
            lambda: tf.train.exponential_decay(
                self.learning_rate,
                (self.global_step - start_decay_step),
                decay_steps, decay_factor, staircase=True),
            name="learning_rate_decay_cond")
    
    def setup_summary(self):
        """Create the scalar summaries and the ops used to update them.

        Three variables (train loss, test loss, BLEU score) are exposed as
        scalar summaries, together with the current learning rate.

        Returns:
            ``(summary_placeholders, update_ops, summary_op)``: one float32
            placeholder and one assign op per summary variable, in the order
            train loss, test loss, BLEU score, plus the merged summary op.
        """
        summary_vars = []
        for tag in ('Train_loss', 'Test_loss', 'BLEU_score'):
            value_holder = tf.Variable(0.)
            tf.summary.scalar(tag, value_holder)
            summary_vars.append(value_holder)

        tf.summary.scalar('lr_rate', self.learning_rate)

        summary_placeholders = [tf.placeholder(tf.float32) for _ in summary_vars]
        update_ops = [
            variable.assign(placeholder)
            for variable, placeholder in zip(summary_vars, summary_placeholders)]
        summary_op = tf.summary.merge_all()
        return summary_placeholders, update_ops, summary_op
    
    def add_visual_summary(self):
        """Create the attention alignment tensor and its image summary.

        Sets ``attention_alignment`` (the stacked alignment history of the
        training decoder) and ``attention_summary``.
        """
        self.attention_alignment = self.train_final_state.alignment_history.stack()
        # Move the stacked (first) axis last, then append a channel axis so the
        # result is 4-D as an image summary expects.
        batch_first = tf.transpose(self.attention_alignment, [1, 2, 0])
        attention_images = tf.expand_dims(batch_first, -1)
        # Scale the attention weights up to the 0..255 pixel range.
        attention_images = attention_images * 255
        self.attention_summary = tf.summary.image("attention_images", attention_images)
        
    def write_attention_summary(self, summary_writer, X, Y, step=1):
        """Run the attention image summary for one (X, Y) pair and write it.

        Requires ``add_visual_summary`` to have been called, since it reads
        ``attention_summary`` and ``attention_alignment``.

        Args:
            summary_writer: object whose ``add_summary(summary, step)`` is called.
            X: iterable of source tokens.
            Y: sequence of target tokens; reversed first when ``self.reverse``.
            step: global step recorded with the summary (defaults to 1).

        Returns:
            The evaluated attention alignment array.
        """
        if self.reverse:
            Y = Y[::-1]
        input_indices = [self.dp.X_w2id.get(char, self._x_unk) for char in X]
        # Unknown target tokens must fall back to the *target* <UNK> id; the
        # source-vocabulary id may denote a different token on the target side.
        output_indices = [self.dp.Y_w2id.get(char, self._y_unk) for char in Y]

        summary_str, attention_alignment = self.sess.run(
            [self.attention_summary, self.attention_alignment],
            {self.X: [input_indices],
             self.Y: [output_indices],
             self.X_seq_len: [len(input_indices)],
             self.Y_seq_len: [len(output_indices)],
             self.output_keep_prob: 1,
             self.input_keep_prob: 1})
        summary_writer.add_summary(summary_str, step)
        return attention_alignment
    
    def show_attention(self, X, Y, is_show=True):
        if self.reverse:
            Y = Y[::-1]
        input_indices = [self.dp.X_w2id.get(char, self._x_unk) for char in X]
        output_indices = [self.dp.Y_w2id.get(char, self._x_unk) for char in Y]

        attention_alignment = self.sess.run(self.attention_alignment, {self.X: [input_indices],
                                                     self.Y: [output_indices],
                                                     self.X_seq_len: [len(input_indices)],
                                                     self.Y_seq_len: [len(output_indices)],
                                                     self.output_keep_prob:1,
                                                     self.input_keep_prob:1})
        attention_alignment = attention_alignment.transpose((1,0,2))[0]  # (batch, tgt_seq_len, src_seq_len)
        assert attention_alignment.shape[0] == len(Y)
        assert attention_alignment.shape[1] == len(X)
        if is_show:
            for i,yw in enumerate(Y):
                print yw, " :",
                for j,x in enumerate(X):
                    print "%.2f%s" % (attention_alignment[i][j], x),
                print ""
        return attention_alignment

In [2]:
class Seq2Seq_DP:
    """Data provider: splits id sequences into train/test and serves padded batches."""

    def __init__(self, X_indices, Y_indices, X_w2id, Y_w2id, BATCH_SIZE, n_epoch):
        """Split the data and record vocabulary and batching information.

        The first 10% of the pairs (rounded down) become the test split and
        the rest the training split.

        Args:
            X_indices: list of source id sequences.
            Y_indices: list of target id sequences, same length as ``X_indices``.
            X_w2id: source token -> id mapping; must contain ``'<PAD>'``.
            Y_w2id: target token -> id mapping; must contain ``'<PAD>'``.
            BATCH_SIZE: number of pairs per batch.
            n_epoch: number of epochs, used to compute ``num_steps``.

        Raises:
            AssertionError: if the two index lists differ in length.
        """
        assert len(X_indices) == len(Y_indices)
        num_test = int(len(X_indices) * 0.1)
        self.n_epoch = n_epoch
        self.X_train = np.array(X_indices[num_test:])
        self.Y_train = np.array(Y_indices[num_test:])
        self.X_test = np.array(X_indices[:num_test])
        self.Y_test = np.array(Y_indices[:num_test])
        self.num_batch = len(self.X_train) // BATCH_SIZE
        self.num_steps = self.num_batch * self.n_epoch
        self.batch_size = BATCH_SIZE
        self.X_w2id = X_w2id
        self.X_id2w = dict((idx, word) for word, idx in X_w2id.items())
        self.Y_w2id = Y_w2id
        self.Y_id2w = dict((idx, word) for word, idx in Y_w2id.items())
        self._x_pad = self.X_w2id['<PAD>']
        self._y_pad = self.Y_w2id['<PAD>']
        # Single parenthesised argument: prints the same text under the print
        # statement and the print function.
        print('Train_data: %d | Test_data: %d | Batch_size: %d | Num_batch: %d | X_vocab_size: %d | Y_vocab_size: %d' % (
            len(self.X_train), len(self.X_test), BATCH_SIZE, self.num_batch, len(self.X_w2id), len(self.Y_w2id)))

    def next_batch(self, X, Y):
        """Yield shuffled, padded batches; a trailing incomplete batch is dropped.

        Args:
            X: array of source sequences (indexable by a permutation array).
            Y: array of target sequences aligned with ``X``.

        Yields:
            ``(padded_X, padded_Y, X_lens, Y_lens)`` where the first two are
            numpy arrays and the last two are lists of the unpadded lengths.
        """
        r = np.random.permutation(len(X))
        X = X[r]
        Y = Y[r]
        for i in range(0, len(X) - len(X) % self.batch_size, self.batch_size):
            X_batch = X[i : i + self.batch_size]
            Y_batch = Y[i : i + self.batch_size]
            padded_X_batch, X_batch_lens = self.pad_sentence_batch(X_batch, self._x_pad)
            padded_Y_batch, Y_batch_lens = self.pad_sentence_batch(Y_batch, self._y_pad)
            yield (np.array(padded_X_batch),
                   np.array(padded_Y_batch),
                   X_batch_lens,
                   Y_batch_lens)

    def sample_test_batch(self):
        """Return the first ``batch_size`` test pairs, padded, in ``next_batch`` layout."""
        padded_X_batch, X_batch_lens = self.pad_sentence_batch(self.X_test[: self.batch_size], self._x_pad)
        padded_Y_batch, Y_batch_lens = self.pad_sentence_batch(self.Y_test[: self.batch_size], self._y_pad)
        return np.array(padded_X_batch), np.array(padded_Y_batch), X_batch_lens, Y_batch_lens

    def pad_sentence_batch(self, sentence_batch, pad_int):
        """Right-pad every sequence to the length of the longest one.

        Args:
            sentence_batch: iterable of id sequences (lists or array rows).
            pad_int: id appended as padding.

        Returns:
            ``(padded_seqs, seq_lens)``: the padded sequences as lists and the
            original lengths. An empty batch yields ``([], [])``.
        """
        # Copy each row into a plain list. When all sequences share one length,
        # np.array() produces a 2-D array whose rows are ndarrays, and
        # ``ndarray + []`` is a broadcast error rather than a concatenation.
        sentences = [list(sentence) for sentence in sentence_batch]
        if not sentences:
            # e.g. an empty test split; max() of an empty list would raise.
            return [], []
        seq_lens = [len(sentence) for sentence in sentences]
        max_sentence_len = max(seq_lens)
        padded_seqs = [sentence + [pad_int] * (max_sentence_len - len(sentence))
                       for sentence in sentences]
        return padded_seqs, seq_lens


In [3]:
class Seq2Seq_util:
    """Training / evaluation driver around a seq2seq model and its data provider.

    Attributes read from ``dp``: ``X_train``, ``Y_train``, ``X_test``, ``Y_test``,
    ``X_id2w``, ``Y_id2w``, ``num_batch``, ``n_epoch``, ``next_batch`` and
    ``sample_test_batch``.
    Attributes read from ``model``: the placeholders ``X``, ``Y``, ``X_seq_len``,
    ``Y_seq_len``, ``output_keep_prob``, ``input_keep_prob``; the ops ``loss``,
    ``train_op``, ``global_step``, ``update_ops``, ``summary_op``; plus ``sess``,
    ``saver``, ``is_save``, ``reverse``, ``infer``, ``batch_infer`` and ``restore``.
    """

    def __init__(self, dp, model, summary_path='Summary', display_freq=3):
        """Store the collaborators.

        Args:
            dp: data provider.
            model: the seq2seq model.
            summary_path: sub-directory (under ``train_dir``) that ``fit`` writes to.
            display_freq: number of progress reports printed per epoch.
        """
        self.display_freq = display_freq
        self.dp = dp
        self.model = model
        self.summary_path = summary_path

    def _eval_loss(self, X_batch, Y_batch, X_batch_lens, Y_batch_lens):
        """Run the loss op on one batch with both keep-probabilities fed as 1."""
        return self.model.sess.run(self.model.loss, {self.model.X: X_batch,
                                                     self.model.Y: Y_batch,
                                                     self.model.X_seq_len: X_batch_lens,
                                                     self.model.Y_seq_len: Y_batch_lens,
                                                     self.model.output_keep_prob: 1,
                                                     self.model.input_keep_prob: 1})

    def train(self, epoch):
        """Run one training epoch and return the mean training loss.

        About ``display_freq`` times per epoch the running training loss and the
        loss on a fixed test batch are printed, followed by sample inferences.

        Args:
            epoch: 1-based epoch number, used only in the progress message.
        """
        avg_loss = 0.0
        tic = time.time()
        # fit() sets self.n_epoch; fall back to the data provider when train()
        # is called directly.
        n_epoch = getattr(self, 'n_epoch', None)
        if n_epoch is None:
            n_epoch = self.dp.n_epoch
        # Clamp to >= 1: with fewer batches than display_freq the integer
        # quotient is 0 and `local_step % 0` raises ZeroDivisionError.
        display_every = max(1, self.dp.num_batch // max(1, self.display_freq))
        test_batch = self.dp.sample_test_batch()
        for local_step, (X_train_batch, Y_train_batch, X_train_batch_lens, Y_train_batch_lens) in enumerate(
            self.dp.next_batch(self.dp.X_train, self.dp.Y_train)):
            self.model.step, _, loss = self.model.sess.run(
                [self.model.global_step, self.model.train_op, self.model.loss],
                {self.model.X: X_train_batch,
                 self.model.Y: Y_train_batch,
                 self.model.X_seq_len: X_train_batch_lens,
                 self.model.Y_seq_len: Y_train_batch_lens,
                 self.model.output_keep_prob: self.model._output_keep_prob,
                 self.model.input_keep_prob: self.model._input_keep_prob})
            avg_loss += loss
            if local_step % display_every == 0:
                val_loss = self._eval_loss(*test_batch)
                print("Epoch %d/%d | Batch %d/%d | Train_loss: %.3f | Test_loss: %.3f | Time_cost:%.3f" % (epoch, n_epoch, local_step, self.dp.num_batch, avg_loss / (local_step + 1), val_loss, time.time()-tic))
                self.cal()
                tic = time.time()

        return avg_loss / max(1, self.dp.num_batch)

    def test(self):
        """Return the mean loss over all full test batches (NaN when there are none)."""
        total_loss = 0.0
        n_batches = 0
        for batch in self.dp.next_batch(self.dp.X_test, self.dp.Y_test):
            total_loss += self._eval_loss(*batch)
            n_batches += 1
        if n_batches == 0:
            # No batch was produced (test set smaller than one batch).
            return float('nan')
        return total_loss / n_batches

    def fit(self, train_dir, is_bleu):
        """Train for ``dp.n_epoch`` epochs, logging and checkpointing after each one.

        Per epoch: train, evaluate test loss (and BLEU when ``is_bleu``), write
        a TensorFlow summary, dump the metric history to ``res.pkl`` and, if
        ``model.is_save`` is set, save a checkpoint.

        Args:
            train_dir: parent directory; output goes to ``train_dir/summary_path``.
            is_bleu: whether to compute BLEU on the test set after each epoch.
        """
        self.n_epoch = self.dp.n_epoch
        test_loss_list = []
        train_loss_list = []
        time_cost_list = []
        bleu_list = []
        out_dir = os.path.join(train_dir, self.summary_path)
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        print("Writing to %s" % out_dir)
        checkpoint_prefix = os.path.join(out_dir, "model")
        self.summary_writer = tf.summary.FileWriter(out_dir, self.model.sess.graph)
        for epoch in range(1, self.n_epoch+1):
            tic = time.time()
            train_loss = self.train(epoch)
            train_loss_list.append(train_loss)
            test_loss = self.test()
            test_loss_list.append(test_loss)
            toc = time.time()
            time_cost_list.append((toc - tic))
            if is_bleu:
                bleu = self.test_bleu()
                bleu_list.append(bleu)
                print("Epoch %d/%d | Train_loss: %.3f | Test_loss: %.3f | Bleu: %.3f" % (epoch, self.n_epoch, train_loss, test_loss, bleu))
            else:
                bleu = 0.0
                print("Epoch %d/%d | Train_loss: %.3f | Test_loss: %.3f" % (epoch, self.n_epoch, train_loss, test_loss))

            # Push each scalar into its summary variable before evaluating summary_op.
            stats = [train_loss, test_loss, bleu]
            for i, value in enumerate(stats):
                self.model.sess.run(self.model.update_ops[i], feed_dict={
                    self.model.summary_placeholders[i]: float(value)
                })
            summary_str = self.model.sess.run(self.model.summary_op)
            self.summary_writer.add_summary(summary_str, self.model.step + 1)
            # Rewrite the full history every epoch; `with` closes the handle.
            with open(os.path.join(out_dir, "res.pkl"), 'wb') as res_file:
                cPickle.dump((train_loss_list, test_loss_list, time_cost_list, bleu_list), res_file)
            if self.model.is_save:
                path = self.model.saver.save(self.model.sess, checkpoint_prefix, global_step=epoch)
                print("Saved model checkpoint to %s" % path)

    def show(self, sent, id2w):
        """Join the tokens of ``sent`` into one string; unknown ids render as u'&'."""
        return "".join([id2w.get(idx, u'&') for idx in sent])

    def cal(self, n_example=5):
        """Print inferences for the first few test and train examples.

        ``n_example`` is split into ``n_example // 2`` training examples and the
        remainder test examples; fewer are shown if a dataset is smaller.
        """
        train_n_example = int(n_example / 2)
        test_n_example = n_example - train_n_example
        for idx in range(min(test_n_example, len(self.dp.X_test))):
            example = self.show(self.dp.X_test[idx], self.dp.X_id2w)
            y = self.show(self.dp.Y_test[idx], self.dp.Y_id2w)
            o = self.model.infer(example)[0]
            print('TestInput: %s | Output: %s | GroundTruth: %s' % (example, o, y))
        for idx in range(min(train_n_example, len(self.dp.X_train))):
            example = self.show(self.dp.X_train[idx], self.dp.X_id2w)
            y = self.show(self.dp.Y_train[idx], self.dp.Y_id2w)
            o = self.model.infer(example)[0]
            print('TrainInput: %s | Output: %s | GroundTruth: %s' % (example, o, y))
        print("")

    def test_bleu(self, gram=2, batch_size=256):
        """Return the mean sentence-level BLEU over the full batches of the test set.

        Each reference is the target sequence without its last token
        (presumably <EOS> -- confirm against the data pipeline), reversed back
        when ``model.reverse`` is set. Empty hypotheses score 0.0. Test
        examples beyond the last full batch are not scored.

        Returns:
            Mean score, or NaN when the test set is smaller than ``batch_size``.
        """
        all_score = []
        n_batch = len(self.dp.X_test) // batch_size
        for i in range(n_batch):
            start = i * batch_size
            input_sents = [self.show(X_test, self.dp.X_id2w)
                           for X_test in self.dp.X_test[start:start + batch_size]]
            outputs = self.model.batch_infer(input_sents)
            for j in range(batch_size):
                hypothesis = outputs[j]
                if len(hypothesis) < 1:
                    all_score.append(0.0)
                    continue
                ref_ids = self.dp.Y_test[start + j][:-1]
                if self.model.reverse:
                    ref_ids = ref_ids[::-1]
                refer4bleu = [[' '.join([self.dp.Y_id2w.get(w, u'&') for w in ref_ids])]]
                candi = [' '.join(hypothesis)]
                all_score.append(BLEU(candi, refer4bleu, gram=gram))
        if not all_score:
            # np.mean([]) would also give NaN but emits a RuntimeWarning.
            return float('nan')
        return np.mean(all_score)

    def show_res(self, path):
        """Plot train loss, test loss and BLEU from a ``res.pkl`` written by ``fit``.

        NOTE(review): unpickling executes arbitrary code; only load trusted files.
        """
        # Binary mode matches the 'wb' used when the file is written.
        with open(path, 'rb') as res_file:
            res = cPickle.load(res_file)
        plt.figure(1)
        plt.title('The results')
        l1, = plt.plot(res[0], 'g')
        l2, = plt.plot(res[1], 'r')
        l3, = plt.plot(res[3], 'b')  # res[2] holds the per-epoch time cost
        plt.legend(handles = [l1, l2, l3], labels = ["Train_loss","Test_loss","BLEU"], loc = 'best')
        plt.show()

    def test_all(self, path, epoch_range, is_bleu=True):
        """Restore a series of checkpoints and plot their test loss (and BLEU).

        Args:
            path: checkpoint prefix; the epoch number is appended to it.
            epoch_range: sequence whose first and last items bound the epochs;
                the last item is exclusive.
            is_bleu: also compute BLEU for every checkpoint.
        """
        val_loss_list = []
        bleu_list = []
        for i in range(epoch_range[0], epoch_range[-1]):
            self.model.restore(path + str(i))
            val_loss_list.append(self.test())
            if is_bleu:
                bleu_list.append(self.test_bleu())
        plt.figure(1)
        plt.title('The results')
        l1, = plt.plot(val_loss_list,'r')
        l2, = plt.plot(bleu_list,'b')
        plt.legend(handles = [l1, l2], labels = ["Test_loss","BLEU"], loc = 'best')
        plt.show()
        
    

In [4]:
def config2name(rnn_size, n_layers, cell_type, residual, reverse=False, emb=False, emb_fix=False):
    """Build a run name that encodes the model configuration.

    The base ``rnn_size-<n>-n_layers-<n>-cell-<type>`` is followed by optional
    ``-residual`` and ``-reverse`` suffixes and then by ``-fix_emb`` or
    ``-emb``; ``emb_fix`` takes precedence over ``emb``.
    """
    parts = ['rnn_size-%d-n_layers-%d-cell-%s' % (rnn_size, n_layers, cell_type)]
    if residual:
        parts.append('-residual')
    if reverse:
        parts.append('-reverse')
    if emb_fix:
        parts.append('-fix_emb')
    elif emb:
        parts.append('-emb')
    return ''.join(parts)

In [5]:
train_dir ='a2c_model/'
# Pickles are opened in binary mode and closed deterministically via `with`.
# NOTE(review): unpickling executes arbitrary code; only load trusted files.
with open('data/a2c_X_Y_indices_no_unk.pkl', 'rb') as pkl_file:
    X_indices, Y_indices = cPickle.load(pkl_file)
with open('data/a2c_Xw2id_Yw2id_Xid2w_Yid2w.pkl', 'rb') as pkl_file:
    X_w2id, Y_w2id, X_id2w, Y_id2w = cPickle.load(pkl_file)
# Drop the last token of every source sequence (presumably <EOS> -- confirm).
X_indices = [x[:-1] for x in X_indices]
# Reverse each target sequence but keep its last token in final position.
Y_indices = [y[:-1][::-1]+[y[-1],] for y in Y_indices]
with open('data/pre_embedding.pkl', 'rb') as pkl_file:
    encoder_emb = cPickle.load(pkl_file)
print(encoder_emb.shape)

(4451, 1024)


In [6]:
# Preview the first ten target sequences as text.
for target_ids in Y_indices[:10]:
    target_text = "".join(Y_id2w[idx] for idx in target_ids)
    print(target_text)

，了灭覆国六<EOS>
；了一统下天，了灭覆国六<EOS>
。起耸然巍殿宫房阿<EOS>
，伸延折曲西向后然<EOS>
。阳咸到通直一，伸延折曲西向后然<EOS>
。墙宫了进流<EOS>
，楼高座一步五<EOS>
，连勾心中拱木的叠层，抱环差参势地的同不借凭自各<EOS>
。下上争互在像<EOS>
，着旋盘<EOS>


In [7]:
BATCH_SIZE = 128
NUM_EPOCH = 10
# Single-point hyper-parameter grid. The loop variables stay defined after the
# loops and are reused by later cells.
for cell_type in ['lstm']:
    for rnn_size in [1024]:
        for n_layers in [1]:
            for residual in [True]:
                dp = Seq2Seq_DP(X_indices, Y_indices, X_w2id, Y_w2id, BATCH_SIZE, n_epoch=NUM_EPOCH)
                # Build the model inside its own graph and session.
                g = tf.Graph()
                sess = tf.Session(graph=g)
                with sess.as_default(), sess.graph.as_default():
                    model = Seq2Seq(
                        dp=dp,
                        rnn_size=rnn_size,
                        n_layers=n_layers,
                        encoder_embedding_dim=rnn_size,
                        encoder_pre_embedding=encoder_emb,
                        decoder_embedding_dim=rnn_size,
                        cell_type=cell_type,
                        max_infer_length=30,
                        residual=residual,
                        reverse=True,
                        emb_fix=True,
                        is_save=True,
                        sess=sess,
                    )



Train_data: 323805 | Test_data: 35978 | Batch_size: 128 | Num_batch: 2529 | X_vocab_size: 4451 | Y_vocab_size: 4408


In [None]:
# The summary/checkpoint sub-directory name is derived from the configuration.
util = Seq2Seq_util(
    dp=dp,
    model=model,
    summary_path=config2name(rnn_size, n_layers, cell_type, residual,
                             reverse=True, emb=True, emb_fix=True),
)
# Training is disabled in this run; uncomment to train from scratch.
#util.fit(train_dir=train_dir, is_bleu=True)

In [None]:
# Test-set BLEU of checkpoints 6..10 of the '-emb' run.
BLEU_list1 = []
for i in range(6, 11):
    checkpoint = 'a2c_model/rnn_size-1024-n_layers-1-cell-lstm-residual-reverse-emb/model-%d' % i
    model.restore(checkpoint)
    s = util.test_bleu(batch_size=256)
    BLEU_list1 += [s]

INFO:tensorflow:Restoring parameters from a2c_model/rnn_size-1024-n_layers-1-cell-lstm-residual-reverse-emb/model-6
restore a2c_model/rnn_size-1024-n_layers-1-cell-lstm-residual-reverse-emb/model-6 success


In [None]:
# Test-set BLEU of checkpoints 6..10 of the '-fix_emb' run.
BLEU_list2 = []
for i in range(6, 11):
    checkpoint = 'a2c_model/rnn_size-1024-n_layers-1-cell-lstm-residual-reverse-fix_emb/model-%d' % i
    model.restore(checkpoint)
    s = util.test_bleu(batch_size=256)
    BLEU_list2 += [s]

In [None]:
# Print both BLEU series, then plot them together
# (red dashed = BLEU_list1, blue dashed = BLEU_list2).
for bleu_series in (BLEU_list1, BLEU_list2):
    print(bleu_series)
plt.figure()
for bleu_series, line_style in ((BLEU_list1, 'r--'), (BLEU_list2, 'b--')):
    plt.plot(bleu_series, line_style)
plt.show()

In [None]:
# Restore a single checkpoint and report its test-set BLEU.
# NOTE(review): unlike the cells above, this path has no run sub-directory or
# 'model-' prefix -- confirm it still points at an existing checkpoint.
checkpoint_path = 'a2c_model/rnn_size-1024-n_layers-1-cell-lstm-residual-reverse-10'
model.restore(checkpoint_path)
s = util.test_bleu(batch_size=256)
print(s)

# 测试可视化

In [None]:
# Restore a checkpoint and print sample inferences for a few test/train examples.
# NOTE(review): this path has no 'a2c_model/' prefix, unlike the other restore
# calls in this notebook -- confirm it is intended.
checkpoint_path = 'rnn_size-1024-n_layers-1-cell-lstm-residual-reverse-10'
model.restore(checkpoint_path)
util.cal()

In [None]:
# Infer once, then log the attention for the SAME source sentence. The
# original passed a different string (with an extra character) to
# write_attention_summary, so the logged attention did not correspond to the
# sentence that produced Y.
sentence = u'不为则死！'
Y = model.infer(sentence)[0]
print(Y)
summary_writer = tf.summary.FileWriter('visual_test/', model.sess.graph)
model.add_visual_summary()
att = model.write_attention_summary(summary_writer, sentence, Y)

In [None]:
# Write the attention summary for the same sentence/output pair ten more times.
n_repeats = 10
for _ in range(n_repeats):
    model.write_attention_summary(summary_writer, u'不为则死！', Y)

In [None]:
# Show the attention for the last ten source sentences.
# X and Y are rebound to strings on every iteration.
for i, x in enumerate(X_indices[-10:]):
    X = "".join(dp.X_id2w[idx] for idx in x)
    Y = model.infer(X)[0]
    model.show_attention(X, Y)