In [14]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from tensorflow.python.layers import core as core_layers
import tensorflow as tf
import numpy as np
import time
import myResidualCell
from tensorflow.python.ops import array_ops
import jieba
from bleu import BLEU
import random
import cPickle
import matplotlib.pyplot as plt


class CVRAE:
    def __init__(self, dp, rnn_size, n_layers, latent_dim, encoder_embedding_dim, decoder_embedding_dim, max_infer_length,
                 sess=tf.Session(), lr=0.001, grad_clip=5.0, beam_width=5, force_teaching_ratio=1.0, beam_penalty=1.0,
                residual=False, output_keep_prob=0.5, input_keep_prob=0.9, cell_type='lstm', reverse=False,
                 condition_embedding_dim=4, num_c=3,
                latent_weight=0.1, beta_decay_period=10, beta_decay_offset=5, decay_scheme='luong234', is_save=True):
        
        self.rnn_size = rnn_size
        self.latent_dim = latent_dim
        self.n_layers = n_layers
        self.grad_clip = grad_clip
        self.dp = dp
        self.step = 0
        self.encoder_embedding_dim = encoder_embedding_dim
        self.decoder_embedding_dim = decoder_embedding_dim
        self.beam_width = beam_width
        self.num_c = num_c
        self.condition_embedding_dim = condition_embedding_dim
        self.latent_weight = latent_weight
        self.beam_penalty = beam_penalty
        self.max_infer_length = max_infer_length
        self.residual = residual
        self.is_save = is_save
        self.decay_scheme = decay_scheme
        if self.residual:
            assert encoder_embedding_dim == rnn_size
            assert decoder_embedding_dim == rnn_size
        self.reverse = reverse
        self.cell_type = cell_type
        self.force_teaching_ratio = force_teaching_ratio
        self._output_keep_prob = output_keep_prob
        self._input_keep_prob = input_keep_prob
        self.beta_decay_period = beta_decay_period
        self.beta_decay_offset = beta_decay_offset
        self.sess = sess
        self.lr=lr
        self.build_graph()
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(tf.global_variables(), max_to_keep = 10)
        self.summary_placeholders, self.update_ops, self.summary_op = self.setup_summary()
        
    # end constructor

    def build_graph(self):
        self.register_symbols()
        self.add_input_layer()
        self.add_encoder_layer()
        self.add_stochastic_layer()
        self.add_decoder_hidden()
        with tf.variable_scope('decode'):
            self.add_decoder_for_training()
        with tf.variable_scope('decode', reuse=True):
            self.add_decoder_for_inference()
        with tf.variable_scope('decode', reuse=True):
            self.add_decoder_for_prefix_inference()
        self.add_backward_path()
    # end method

    def add_input_layer(self):
        self.X = tf.placeholder(tf.int32, [None, None], name="X")
        self.Y = tf.placeholder(tf.int32, [None, None], name="Y")
        self.C = tf.placeholder(tf.int32, [None], name='Condition')
        self.X_seq_len = tf.placeholder(tf.int32, [None], name="X_seq_len")
        self.Y_seq_len = tf.placeholder(tf.int32, [None], name="Y_seq_len")
        self.input_keep_prob = tf.placeholder(tf.float32,name="input_keep_prob")
        self.output_keep_prob = tf.placeholder(tf.float32,name="output_keep_prob")
        self.batch_size = tf.shape(self.X)[0]
        self.B = tf.placeholder(tf.float32, name='Beta_deterministic_warmup')
        self.global_step = tf.Variable(0, name="global_step", trainable=False)
    # end method

    def single_cell(self, reuse=False):
        if self.cell_type == 'lstm':
             cell = tf.contrib.rnn.LayerNormBasicLSTMCell(self.rnn_size, reuse=reuse)
        else:
            cell = tf.contrib.rnn.GRUBlockCell(self.rnn_size)    
        cell = tf.contrib.rnn.DropoutWrapper(cell, self.output_keep_prob, self.input_keep_prob)
        if self.residual:
            cell = myResidualCell.ResidualWrapper(cell)
        return cell
    
    def add_encoder_layer(self):
        encoder_embedding = tf.get_variable('encoder_embedding', [len(self.dp.X_w2id), self.encoder_embedding_dim],
                                             tf.float32, tf.random_uniform_initializer(-1.0, 1.0))
        self.c_embedding = tf.get_variable('condition_embedding', [self.num_c, self.condition_embedding_dim],
                                             tf.float32, tf.random_uniform_initializer(-1.0, 1.0))
        self.c_inputs = tf.nn.embedding_lookup(self.c_embedding, self.C)
        self.encoder_inputs = tf.nn.embedding_lookup(encoder_embedding, self.X)
        bi_encoder_output, bi_encoder_state = tf.nn.bidirectional_dynamic_rnn(
            cell_fw = tf.contrib.rnn.MultiRNNCell([self.single_cell() for _ in range(self.n_layers)]), 
            cell_bw = tf.contrib.rnn.MultiRNNCell([self.single_cell() for _ in range(self.n_layers)]),
            inputs = self.encoder_inputs,
            sequence_length = self.X_seq_len,
            dtype = tf.float32,
            scope = 'bidirectional_rnn')
        #print bi_encoder_state
        if self.cell_type == 'lstm':
            self.encoder_out = tf.concat([bi_encoder_state[0][-1][1],bi_encoder_state[1][-1][1]], -1)
        else:
            self.encoder_out = tf.concat([bi_encoder_state[0][-1],bi_encoder_state[1][-1]], -1)
        
    def add_stochastic_layer(self):
        # reparametrization trick
        self.z_mu = tf.layers.dense(self.encoder_out, self.latent_dim)
        self.z_lgs2 = tf.layers.dense(self.encoder_out, self.latent_dim)
        noise = tf.random_normal(tf.shape(self.z_lgs2))
        self._z = self.z_mu + tf.exp(0.5 * self.z_lgs2) * noise
        self.z = tf.concat([self._z, self.c_inputs], -1)
        
    def add_decoder_hidden(self):
        hidden_state_list = []
        for i in range(self.n_layers * 2):
            if self.cell_type == 'gru':
                hidden_state_list.append(tf.layers.dense(self.z, self.rnn_size))
            else:
                hidden_state_list.append(tf.contrib.rnn.LSTMStateTuple(tf.layers.dense(self.z, self.rnn_size), tf.layers.dense(self.z, self.rnn_size))) 
        self.decoder_init_state = tuple(hidden_state_list)
        
    def processed_decoder_input(self):
        main = tf.strided_slice(self.Y, [0, 0], [self.batch_size, -1], [1, 1]) # remove last char
        decoder_input = tf.concat([tf.fill([self.batch_size, 1], self._y_go), main], 1)
        return decoder_input

    def add_decoder_for_training(self):
        self.decoder_cell = tf.contrib.rnn.MultiRNNCell([self.single_cell() for _ in range(2 * self.n_layers)])
        decoder_embedding = tf.get_variable('decoder_embedding', [len(self.dp.Y_w2id), self.decoder_embedding_dim],
                                             tf.float32, tf.random_uniform_initializer(-1.0, 1.0))
        emb = tf.nn.embedding_lookup(decoder_embedding, self.processed_decoder_input())
        inputs = tf.expand_dims(self.z, 1)
        inputs = tf.tile(inputs, [1, tf.shape(emb)[1], 1])
        inputs = tf.concat([emb, inputs],2) 
        training_helper = tf.contrib.seq2seq.TrainingHelper(
            inputs = inputs,
            sequence_length = self.Y_seq_len,
            time_major = False)
        training_decoder = tf.contrib.seq2seq.BasicDecoder(
            cell = self.decoder_cell,
            helper = training_helper,
            initial_state = self.decoder_init_state, #self.decoder_cell.zero_state(self.batch_size, tf.float32),
            output_layer = core_layers.Dense(len(self.dp.Y_w2id)))
        training_decoder_output, training_final_state, _ = tf.contrib.seq2seq.dynamic_decode(
            decoder = training_decoder,
            impute_finished = True,
            maximum_iterations = tf.reduce_max(self.Y_seq_len))
        self.training_logits = training_decoder_output.rnn_output
        self.init_prefix_state = training_final_state

    def add_decoder_for_inference(self):   
        decoder_embedding = tf.get_variable('decoder_embedding')
        self.beam_f = (lambda ids: tf.concat([tf.nn.embedding_lookup(decoder_embedding, ids), 
                                    tf.tile(tf.expand_dims(self.z, 1), 
                                            [1,int(tf.nn.embedding_lookup(decoder_embedding, ids).get_shape()[1]), 1]) if len(ids.get_shape()) !=1 
                                             else self.z], -1))

        predicting_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
            cell = self.decoder_cell,
            embedding = self.beam_f, 
            start_tokens = tf.tile(tf.constant([self._y_go], dtype=tf.int32), [self.batch_size]),
            end_token = self._y_eos,
            initial_state = tf.contrib.seq2seq.tile_batch(self.decoder_init_state, self.beam_width),#self.decoder_cell.zero_state(self.batch_size * self.beam_width, tf.float32),
            beam_width = self.beam_width,
            output_layer = core_layers.Dense(len(self.dp.Y_w2id), _reuse=True),
            length_penalty_weight = self.beam_penalty)
        predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
            decoder = predicting_decoder,
            impute_finished = False,
            maximum_iterations = self.max_infer_length)
        self.predicting_ids = predicting_decoder_output.predicted_ids
        self.score = predicting_decoder_output.beam_search_decoder_output.scores
        
    def add_decoder_for_prefix_inference(self):   
        predicting_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
            cell = self.decoder_cell,
            embedding = self.beam_f,
            start_tokens = tf.tile(tf.constant([self._y_go], dtype=tf.int32), [self.batch_size]),
            end_token = self._y_eos,
            initial_state = tf.contrib.seq2seq.tile_batch(self.init_prefix_state, self.beam_width),
            beam_width = self.beam_width,
            output_layer = core_layers.Dense(len(self.dp.Y_w2id), _reuse=True),
            length_penalty_weight = self.beam_penalty)
        
        self.prefix_go = tf.placeholder(tf.int32, [None])
        prefix_go_beam = tf.tile(tf.expand_dims(self.prefix_go, 1), [1, self.beam_width])
        prefix_emb = self.beam_f(prefix_go_beam)
        predicting_decoder._start_inputs = prefix_emb
        predicting_prefix_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
            decoder = predicting_decoder,
            impute_finished = False,
            maximum_iterations = self.max_infer_length)
        self.predicting_prefix_ids = predicting_prefix_decoder_output.predicted_ids
        self.prefix_score = predicting_prefix_decoder_output.beam_search_decoder_output.scores

    def add_backward_path(self):
        masks = tf.sequence_mask(self.Y_seq_len, tf.reduce_max(self.Y_seq_len), dtype=tf.float32)
        self.reconstruct_loss = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits,
                                                     targets = self.Y,
                                                     weights = masks)
        self.batch_reconstruct_loss = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits,
                                                     targets = self.Y,
                                                     weights = masks,
                                                     average_across_batch=False)
        self.kl_loss = tf.reduce_mean(-0.5 * tf.reduce_sum(1 + self.z_lgs2 - tf.square(self.z_mu) - tf.exp(self.z_lgs2), 1))
        self.loss = self.reconstruct_loss + self.B * self.latent_weight * self.kl_loss
        
        params = tf.trainable_variables()
        gradients = tf.gradients(self.loss, params)
        clipped_gradients, _ = tf.clip_by_global_norm(gradients, self.grad_clip)
        self.learning_rate = tf.constant(self.lr)
        self.learning_rate = self.get_learning_rate_decay(self.decay_scheme)  # decay
        self.train_op = tf.train.AdamOptimizer(self.learning_rate).apply_gradients(zip(clipped_gradients, params), global_step=self.global_step)

    def register_symbols(self):
        self._x_go = self.dp.X_w2id['<GO>']
        self._x_eos = self.dp.X_w2id['<EOS>']
        self._x_pad = self.dp.X_w2id['<PAD>']
        self._x_unk = self.dp.X_w2id['<UNK>']
        
        self._y_go = self.dp.Y_w2id['<GO>']
        self._y_eos = self.dp.Y_w2id['<EOS>']
        self._y_pad = self.dp.Y_w2id['<PAD>']
        self._y_unk = self.dp.Y_w2id['<UNK>']
    
    def infer(self, input_word, c):
        input_word = list(jieba.cut(input_word))
        if self.reverse:
            input_word = input_word[::-1]
        input_indices = [self.dp.X_w2id.get(char, self._x_unk) for char in input_word]
        out_indices = self.sess.run(self.predicting_ids, {
            self.X: [input_indices], self.C:[c], self.X_seq_len: [len(input_indices)], self.output_keep_prob:1, self.input_keep_prob:1})
        outputs = []
        for idx in range(out_indices.shape[-1]):
            eos_id = self.dp.Y_w2id['<EOS>']
            ot = out_indices[0,:,idx]
            if eos_id in ot:
                ot = ot.tolist()
                ot = ot[:ot.index(eos_id)]
            if self.reverse:
                ot = ot[::-1]
            output_str = ''.join([self.dp.Y_id2w.get(i, u'&') for i in ot])
            outputs.append(output_str)
        return outputs
    
    def prefix_infer(self, input_word, prefix, c):
        input_word = list(jieba.cut(input_word))
        prefix = list(jieba.cut(prefix))
        input_indices_X = [self.dp.X_w2id.get(char, self._x_unk) for char in input_word]
        input_indices_Y = [self.dp.Y_w2id.get(char, self._y_unk) for char in prefix]
        
        prefix_go = []
        prefix_go.append(input_indices_Y[-1]) 
        out_indices, scores = self.sess.run([self.predicting_prefix_ids, self.prefix_score], {
            self.X: [input_indices_X], self.X_seq_len: [len(input_indices_X)], self.C:[c], self.Y:[input_indices_Y], self.Y_seq_len:[len(input_indices_Y)],
            self.prefix_go: prefix_go, self.input_keep_prob:1, self.output_keep_prob:1})
        
        outputs = []
        for idx in range(out_indices.shape[-1]):
            eos_id = self.dp.Y_w2id['<EOS>']
            ot = out_indices[0,:,idx]
            if eos_id in ot:
                ot = ot.tolist()
                ot = ot[:ot.index(eos_id)]
                if self.reverse:
                    ot = ot[::-1]
            if self.reverse:
                output_str = ''.join([self.dp.Y_id2w.get(i, u'&') for i in ot]) + prefix
            else:
                output_str = prefix + ''.join([self.dp.Y_id2w.get(i, u'&') for i in ot])
            outputs.append(output_str)
        return outputs
    
    def xToz(self, input_word):
        input_word = list(jieba.cut(input_word))
        if self.reverse:
            input_word = input_word[::-1]
        input_indices = [self.dp.X_w2id.get(char, self._x_unk) for char in input_word]
        z = self.sess.run(self._z, {self.X: [input_indices], self.X_seq_len: [len(input_indices)], self.output_keep_prob:1, self.input_keep_prob:1})
        return z
    # end method
    
    def zTox(self, z, c):
        out_indices = self.sess.run(self.predicting_ids, {self.batch_size:z.shape[0],
            self._z:z, self.C:[c], self.output_keep_prob:1, self.input_keep_prob:1})
        outputs = []
        for idx in range(out_indices.shape[-1]):
            eos_id = self.dp.Y_w2id['<EOS>']
            ot = out_indices[0,:,idx]
            if eos_id in ot:
                ot = ot.tolist()
                ot = ot[:ot.index(eos_id)]
            if self.reverse:
                ot = ot[::-1]
            output_str = ''.join([self.dp.Y_id2w.get(i, u'&') for i in ot])
            outputs.append(output_str)
        return outputs
    
    def generate(self, c, batch_size = 3):
        c_batch = [c for _ in range(batch_size)]
        out_indices = self.sess.run(self.predicting_ids, { self.batch_size:batch_size,
            self._z:np.random.randn(batch_size, self.latent_dim), self.C: c_batch, self.output_keep_prob:1, self.input_keep_prob:1})
        outputs = []
        for idx in range(out_indices.shape[0]):
            eos_id = self.dp.Y_w2id['<EOS>']
            ot = out_indices[idx,:,0]   # The 0th beam of each batch 
            if eos_id in ot:
                ot = ot.tolist()
                ot = ot[:ot.index(eos_id)]
            if self.reverse:
                ot = ot[::-1]
            output_str = ''.join([self.dp.Y_id2w.get(i, u'&') for i in ot])
            outputs.append(output_str)
        return outputs
    
    def restore(self, path):
        self.saver.restore(self.sess, path)
        print 'restore %s success' % path
        
    def get_learning_rate_decay(self, decay_scheme='luong234'):
        num_train_steps = self.dp.num_steps
        if decay_scheme == "luong10":
            start_decay_step = int(num_train_steps / 2)
            remain_steps = num_train_steps - start_decay_step
            decay_steps = int(remain_steps / 10)  # decay 10 times
            decay_factor = 0.5
        else:
            start_decay_step = int(num_train_steps * 2 / 3)
            remain_steps = num_train_steps - start_decay_step
            decay_steps = int(remain_steps / 4)  # decay 4 times
            decay_factor = 0.5
        return tf.cond(
            self.global_step < start_decay_step,
            lambda: self.learning_rate,
            lambda: tf.train.exponential_decay(
                self.learning_rate,
                (self.global_step - start_decay_step),
                decay_steps, decay_factor, staircase=True),
            name="learning_rate_decay_cond")
    
    def setup_summary(self):
        train_loss = tf.Variable(0.)
        tf.summary.scalar('Train_loss', train_loss) 
        train_KL_loss = tf.Variable(0.)
        tf.summary.scalar('Train_KL_loss', train_KL_loss)
        train_r_loss = tf.Variable(0.)
        tf.summary.scalar('Train_R_loss', train_r_loss)
        test_loss = tf.Variable(0.)
        tf.summary.scalar('Test_loss', test_loss) 
        test_KL_loss = tf.Variable(0.)
        tf.summary.scalar('Test_KL_loss', test_KL_loss)
        test_r_loss = tf.Variable(0.)
        tf.summary.scalar('Test_R_loss', test_r_loss)
        beta = tf.Variable(0.)
        tf.summary.scalar('Beta', beta)
        tf.summary.scalar('lr_rate', self.learning_rate)
        tf.summary.histogram("z_mu", self.z_mu)
        tf.summary.histogram("z_ls2", self.z_lgs2)
        tf.summary.histogram("z", self.z)
        
        summary_vars = [train_loss, train_KL_loss, train_r_loss, test_loss, test_KL_loss, test_r_loss, beta]
        summary_placeholders = [tf.placeholder(tf.float32) for _ in xrange(len(summary_vars))]
        update_ops = [summary_vars[i].assign(summary_placeholders[i]) for i in xrange(len(summary_vars))]
        summary_op = tf.summary.merge_all()
        return summary_placeholders, update_ops, summary_op

In [15]:
class CVRAE_DP:
    def __init__(self, X_indices, Y_indices, C, X_w2id, Y_w2id, BATCH_SIZE, n_epoch):
        assert len(X_indices) == len(Y_indices)
        num_test = int(len(X_indices) * 0.1)
        self.n_epoch = n_epoch
        self.X_train = np.array(X_indices[num_test:])
        self.Y_train = np.array(Y_indices[num_test:])
        self.C_train = np.array(C[num_test:])
        self.X_test = np.array(X_indices[:num_test])
        self.Y_test = np.array(Y_indices[:num_test])
        self.C_test = np.array(C[num_test:])
        self.num_batch = int(len(self.X_train) / BATCH_SIZE)
        self.num_steps = self.num_batch * self.n_epoch
        self.batch_size = BATCH_SIZE
        self.X_w2id = X_w2id
        self.X_id2w = dict(zip(X_w2id.values(), X_w2id.keys()))
        self.Y_w2id = Y_w2id
        self.Y_id2w = dict(zip(Y_w2id.values(), Y_w2id.keys()))
        self._x_pad = self.X_w2id['<PAD>']
        self._y_pad = self.Y_w2id['<PAD>']
        print 'Train_data: %d | Test_data: %d | Batch_size: %d | Num_batch: %d | X_vocab_size: %d | Y_vocab_size: %d' % (len(self.X_train), len(self.X_test), BATCH_SIZE, self.num_batch, len(self.X_w2id), len(self.Y_w2id))
        
    def next_batch(self, X, Y, C):
        r = np.random.permutation(len(X))
        X = X[r]
        Y = Y[r]
        C = C[r]
        for i in range(0, len(X) - len(X) % self.batch_size, self.batch_size):
            X_batch = X[i : i + self.batch_size]
            Y_batch = Y[i : i + self.batch_size]
            c_batch = C[i : i + self.batch_size]
            padded_X_batch, X_batch_lens = self.pad_sentence_batch(X_batch, self._x_pad)
            padded_Y_batch, Y_batch_lens = self.pad_sentence_batch(Y_batch, self._y_pad)
            yield (np.array(padded_X_batch),
                   np.array(padded_Y_batch),
                   c_batch,
                   X_batch_lens,
                   Y_batch_lens)
    
    def sample_test_batch(self):
        padded_X_batch, X_batch_lens = self.pad_sentence_batch(self.X_test[: self.batch_size], self._x_pad)
        padded_Y_batch, Y_batch_lens = self.pad_sentence_batch(self.Y_test[: self.batch_size], self._y_pad)
        c_batch = self.C_test[: self.batch_size]
        return np.array(padded_X_batch), np.array(padded_Y_batch), c_batch, X_batch_lens, Y_batch_lens
        
    def pad_sentence_batch(self, sentence_batch, pad_int):
        padded_seqs = []
        seq_lens = []
        max_sentence_len = max([len(sentence) for sentence in sentence_batch])
        for sentence in sentence_batch:
            padded_seqs.append(sentence + [pad_int] * (max_sentence_len - len(sentence)))
            seq_lens.append(len(sentence))
        return padded_seqs, seq_lens
    
   


In [16]:
import scipy.interpolate as si
from scipy import interpolate


def BetaGenerator(epoches, beta_decay_period, beta_decay_offset):
    points = [[0,0], [0, beta_decay_offset],[0, beta_decay_offset + 0.33 * beta_decay_period], [1, beta_decay_offset + 0.66*beta_decay_period],[1, beta_decay_offset + beta_decay_period], [1, epoches] ];
    points = np.array(points)
    x = points[:,0]
    y = points[:,1]
    t = range(len(points))
    ipl_t = np.linspace(0.0, len(points) - 1, 100)
    x_tup = si.splrep(t, x, k=3)
    y_tup = si.splrep(t, y, k=3)
    x_list = list(x_tup)
    xl = x.tolist()
    x_list[1] = xl + [0.0, 0.0, 0.0, 0.0]
    y_list = list(y_tup)
    yl = y.tolist()
    y_list[1] = yl + [0.0, 0.0, 0.0, 0.0]
    x_i = si.splev(ipl_t, x_list)
    y_i = si.splev(ipl_t, y_list)
    return interpolate.interp1d(y_i, x_i)

class CVRAE_util:
    def __init__(self, dp, model, display_freq=3):
        self.display_freq = display_freq
        self.dp = dp
        self.model = model
        self.summary_cnt = 0
        self.betaG = BetaGenerator(self.dp.n_epoch*self.dp.num_batch, self.model.beta_decay_period*self.dp.num_batch, self.model.beta_decay_offset*self.dp.num_batch)
        
    def train(self, epoch):
        avg_loss = 0.0
        avg_r_loss = 0.0
        avg_kl_loss = 0.0
        tic = time.time()
        X_test_batch, Y_test_batch, C_test_batch, X_test_batch_lens, Y_test_batch_lens = self.dp.sample_test_batch()
        for local_step, (X_train_batch, Y_train_batch, C_train_batch, X_train_batch_lens, Y_train_batch_lens) in enumerate(
            self.dp.next_batch(self.dp.X_train, self.dp.Y_train, self.dp.C_train)):
            beta = 0.001 + self.betaG(self.model.step) # add small value to avoid points to scatter
            self.model.step, _, loss, r_loss, kl_loss = self.model.sess.run([self.model.global_step, self.model.train_op, 
                                                            self.model.loss, self.model.reconstruct_loss, self.model.kl_loss], 
                                          {self.model.X: X_train_batch,
                                           self.model.Y: Y_train_batch,
                                           self.model.C: C_train_batch,
                                           self.model.X_seq_len: X_train_batch_lens,
                                           self.model.Y_seq_len: Y_train_batch_lens,
                                           self.model.output_keep_prob:self.model._output_keep_prob,
                                           self.model.input_keep_prob:self.model._input_keep_prob,
                                          self.model.B:beta})
            avg_loss += loss
            avg_r_loss += r_loss
            avg_kl_loss += kl_loss
            # summary
            if local_step % 10 == 0:
                self.summary_cnt += 1
                val_loss, val_r_loss, val_kl_loss = self.model.sess.run([self.model.loss, self.model.reconstruct_loss, self.model.kl_loss], 
                                               {self.model.X: X_test_batch,
                                                     self.model.Y: Y_test_batch,
                                                     self.model.C: C_test_batch,
                                                     self.model.X_seq_len: X_test_batch_lens,
                                                     self.model.Y_seq_len: Y_test_batch_lens,
                                                     self.model.output_keep_prob:1,
                                                     self.model.input_keep_prob:1,
                                                     self.model.B:beta})
                stats = [avg_loss/(local_step+1), avg_kl_loss/(local_step+1), avg_r_loss/(local_step+1),
                         val_loss, val_kl_loss, val_r_loss, beta]
                for i in xrange(len(stats)):
                    self.model.sess.run(self.model.update_ops[i], feed_dict={
                        self.model.summary_placeholders[i]: float(stats[i])
                    })
                summary_str = self.model.sess.run(self.model.summary_op, {
                    self.model.X: X_test_batch, 
                    self.model.X_seq_len: X_test_batch_lens,
                    self.model.C: C_test_batch,
                    self.model.output_keep_prob:1,
                    self.model.input_keep_prob:1})
                self.summary_writer.add_summary(summary_str, self.summary_cnt)
                
            if local_step % (self.dp.num_batch / self.display_freq) == 0:
                val_loss, val_r_loss, val_kl_loss = self.model.sess.run([self.model.loss, self.model.reconstruct_loss, self.model.kl_loss], 
                                               {self.model.X: X_test_batch,
                                                     self.model.Y: Y_test_batch,
                                                     self.model.C: C_test_batch,
                                                     self.model.X_seq_len: X_test_batch_lens,
                                                     self.model.Y_seq_len: Y_test_batch_lens,
                                                     self.model.output_keep_prob:1,
                                                     self.model.input_keep_prob:1,
                                                     self.model.B:beta})
                print "Epoch %d/%d | Batch %d/%d | Train_loss: %.3f = %.3f + %.3f | Test_loss: %.3f = %.3f + %.3f | Time_cost:%.3f" % (epoch, self.n_epoch, local_step, self.dp.num_batch, 
                                                                                                                                       avg_loss / (local_step + 1),
                                                                                                                                       avg_r_loss / (local_step + 1),
                                                                                                                                       avg_kl_loss / (local_step + 1),
                                                                                                                                       val_loss, val_r_loss, val_kl_loss, time.time()-tic)
                self.cal()
                tic = time.time()
        return avg_loss / self.dp.num_batch, avg_r_loss / self.dp.num_batch, avg_kl_loss / self.dp.num_batch
    
    def test(self):
        avg_loss = 0.0
        avg_r_loss = 0.0
        avg_kl_loss = 0.0
        beta = 0.001 + self.betaG(self.model.step) # add small value to avoid points to scatter
        for local_step, (X_test_batch, Y_test_batch, C_test_batch, X_test_batch_lens, Y_test_batch_lens) in enumerate(
            self.dp.next_batch(self.dp.X_test, self.dp.Y_test, self.dp.C_test)):
            val_loss, val_r_loss, val_kl_loss = self.model.sess.run([self.model.loss, self.model.reconstruct_loss, self.model.kl_loss], 
                                                                   {self.model.X: X_test_batch,
                                                                         self.model.Y: Y_test_batch,
                                                                         self.model.C: C_test_batch,
                                                                         self.model.X_seq_len: X_test_batch_lens,
                                                                         self.model.Y_seq_len: Y_test_batch_lens,
                                                                         self.model.output_keep_prob:1,
                                                                         self.model.input_keep_prob:1,
                                                                         self.model.B:beta})
            avg_loss += val_loss
            avg_r_loss += val_r_loss
            avg_kl_loss += val_kl_loss
        return avg_loss / (local_step + 1), avg_r_loss / (local_step + 1), avg_kl_loss / (local_step + 1)
    
    def fit(self, train_dir, is_bleu):
        self.n_epoch = self.dp.n_epoch
        test_loss_list = []
        train_loss_list = []
        test_r_loss_list = []
        train_r_loss_list = []
        test_kl_loss_list = []
        train_kl_loss_list = []
        time_cost_list = []
        bleu_list = []
        timestamp = str(int(time.time()))
        out_dir = os.path.abspath(os.path.join(train_dir, "runs", timestamp))
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        print "Writing to %s" % out_dir
        checkpoint_prefix = os.path.join(out_dir, "model")
        self.summary_writer = tf.summary.FileWriter(os.path.join(out_dir, 'Summary'), self.model.sess.graph)
        for epoch in range(1, self.n_epoch+1):
            tic = time.time()
            train_loss, train_r_loss, train_kl_loss = self.train(epoch)
            train_loss_list.append(train_loss)
            train_r_loss_list.append(train_r_loss)
            train_kl_loss_list.append(train_kl_loss)
            
            test_loss, test_r_loss, test_kl_loss = self.test()
            test_loss_list.append(test_loss)
            test_r_loss_list.append(test_r_loss)
            test_kl_loss_list.append(test_kl_loss)
            toc = time.time()
            time_cost_list.append((toc - tic))
            if is_bleu:
                bleu = self.test_bleu()
                bleu_list.append(bleu)
                print "Epoch %d/%d | Train_loss: %.3f = %.3f + %.3f | Test_loss: %.3f = %.3f + %.3f | Bleu: %.3f" % (epoch, self.n_epoch, train_loss, train_r_loss, train_kl_loss, test_loss, test_r_loss, test_kl_loss, bleu)
            else:
                bleu = 0.0
                print "Epoch %d/%d | Train_loss: %.3f = %.3f + %.3f | Test_loss: %.3f = %.3f + %.3f" % (epoch, self.n_epoch, train_loss, train_r_loss, train_kl_loss, test_loss, test_r_loss, test_kl_loss)
            if self.model.is_save:
                cPickle.dump((train_loss_list, train_r_loss_list, train_kl_loss_list, test_loss_list, test_r_loss_list, test_kl_loss_list, time_cost_list, bleu_list), open(os.path.join(out_dir,"res.pkl"),'wb'))
                path = self.model.saver.save(self.model.sess, checkpoint_prefix, global_step=epoch)
                print "Saved model checkpoint to %s" % path
    
    def show(self, sent, id2w):
        return "".join([id2w.get(idx, u'&') for idx in sent])
    
    def cal(self, n_example=5):
        train_n_example = int(n_example / 2)
        test_n_example = n_example - train_n_example
        for _ in range(test_n_example):
            example = self.show(self.dp.X_test[_], self.dp.X_id2w)
            y = self.show(self.dp.Y_test[_], self.dp.Y_id2w)
            o = self.model.infer(example, c=0)[0]
            print 'Test_Input: %s | Output: %s | GroundTruth: %s' % (example, o, y)
        for _ in range(train_n_example):
            example = self.show(self.dp.X_train[_], self.dp.X_id2w)
            y = self.show(self.dp.Y_train[_], self.dp.Y_id2w)
            o = self.model.infer(example, c=0)[0]
            print 'Train_Input: %s | Output: %s | GroundTruth: %s' % (example, o, y) 
        o = self.model.generate(c=0)
        print 'generate top:'
        for oo in o:
            print '【',oo,'】'
        print ""
        o = self.model.generate(c=1)
        print 'generate mid:'
        for oo in o:
            print '【',oo,'】'
        print ""
        o = self.model.generate(c=2)
        print 'generate down:'
        for oo in o:
            print '【',oo,'】'
        print ""
        
    def test_bleu(self, N=300, gram=4):
        all_score = []
        for i in range(N):
            input_indices = self.show(self.dp.X_test[i], self.dp.X_id2w)
            o = self.model.infer(input_indices)[0]
            refer4bleu = [[' '.join([self.dp.Y_id2w.get(w, u'&') for w in self.dp.Y_test[i]])]]
            candi = [' '.join(w for w in o)]
            score = BLEU(candi, refer4bleu, gram=gram)
            all_score.append(score)
        return np.mean(all_score)
    
    def show_res(self, path):
        res = cPickle.load(open(path))
        plt.figure(1)
        plt.title('The train results') 
        l1, = plt.plot(res[0], 'g')
        l2, = plt.plot(res[1], 'r')
        l3, = plt.plot(res[2], 'b')
        plt.legend(handles = [l1, l2, l3], labels = ["Train_loss","Train_r_loss","Train_kl_loss"], loc = 'best')
        plt.show()
        
        plt.figure(1)
        plt.title('The test results') 
        l4, = plt.plot(res[3], 'g')
        l5, = plt.plot(res[4], 'r')
        l6, = plt.plot(res[5], 'r')
        l7, = plt.plot(res[-1], 'b')
        plt.legend(handles = [l4, l5, l6, l7], labels = ["Test_loss","Test_r_loss","Test_kl_loss","BLEU"], loc = 'best')
        plt.show()
        
    def test_all(self, path, epoch_range, is_bleu=True):
        val_loss_list = []
        bleu_list = []
        for i in range(epoch_range[0], epoch_range[-1]):
            self.model.restore(path + str(i))
            val_loss = self.test()
            val_loss_list.append(val_loss)
            if is_bleu:
                bleu_score = self.test_bleu()
                bleu_list.append(bleu_score)
        plt.figure(1)
        plt.title('The results') 
        l1, = plt.plot(val_loss_list,'r')
        l2, = plt.plot(bleu_list,'b')
        plt.legend(handles = [l1, l2], labels = ["Test_loss","BLEU"], loc = 'best')
        plt.show()
        
    

In [4]:
import cPickle
top, mid, down = cPickle.load(open('JD/JD_top_mid_down_indices.pkl'))
#data = cPickle.load(open('JD/JD_small_indices.pkl'))

w2id, id2w = cPickle.load(open('JD/JD_w2id_id2w.pkl'))
train_dir = 'char_vae_model/'
print len(w2id)
#print len(top), len(mid), len(down)

data_Y = top + mid + down
data_X = [data[:-1] for data in data_Y]
data_C = []
for t in top:
    data_C.append(0)
for t in mid:
    data_C.append(1)
for t in down:
    data_C.append(2)

9191


In [17]:
BATCH_SIZE = 256
NUM_EPOCH = 15

dp = CVRAE_DP(data_X, data_Y, data_C, w2id, w2id, BATCH_SIZE, n_epoch=NUM_EPOCH)
g = tf.Graph() 
sess = tf.Session(graph=g) 
with sess.as_default():
    with sess.graph.as_default():
        model = CVRAE(
            dp = dp,
            rnn_size = 1024,
            latent_dim = 16,
            n_layers = 1,
            encoder_embedding_dim = 512,
            decoder_embedding_dim = 512,
            cell_type='gru',
            max_infer_length=31,
            residual=False,
            is_save=True,
            beam_width=5,
            sess= sess
        )
        
util = CVRAE_util(dp=dp, model=model)
#util.fit(train_dir=train_dir, is_bleu=False)


Train_data: 526028 | Test_data: 58447 | Batch_size: 256 | Num_batch: 2054 | X_vocab_size: 9191 | Y_vocab_size: 9191


# 测试

In [6]:
import random

candi = []
candi += random.sample(top, 10)
candi += random.sample(down, 10)
candi += random.sample(mid, 10)

candi = ["".join([dp.X_id2w[idx] for idx in c[:-1]]) for c in candi]
for c in candi:
    print c

货收到了，我很满意。手镯很好。下次还会来的。
很不错的一款项链，非常满意。
东西很好 就是有点大啊
收到宝贝。看着不错也很漂亮。
看着价位很不值，戴上还可以。
质量太垃圾了
相当给力，店主送的宝物俺很喜欢！！！将继续关注啊。
项链很漂亮，做工精致光泽度好，喜欢
夹子很好夹起来也很好看下次会再买
很满意的说
这个买回来真心好丑，<UNK>还是错的，下次不买了
项链和描述的一样，但是 快递小弟很讨厌
？？？？？？？假的，真是一分钱一分货，一点也没错
<UNK>们，被骗了，太垃圾了
靠 哥们算是被你们 <UNK>爹了
难看差评难看差评难看差评
掉色严重，两天就花了，不建议购买
不满意，态度不行。
啥东西。 不好 上当了
买的是女款 却给了个男款
快递太慢了，一份价钱一份货吧
这个一般！
好，非常好看，有<UNK><UNK>
没用就不见了
一般吧，只是一面水晶，后面<UNK>了不太喜欢
链子不怎么样，连个首饰盒都没有。
质量还可以，还送了小礼物。上面的<UNK>可能<UNK>。
……<UNK>
物流很难
下个<UNK>


In [26]:
def linear_transfer(good, bad, candi_list, cls):
    z_good = model.xToz(good)
    z_bad = model.xToz(bad)
    o_good = model.zTox(z_good)[0]
    o_bad = model.zTox(z_bad)[0]
    z = z_bad - z_good
    print 'good:%s | bad:%s' % (o_good, o_bad)
    for c in candi_list:
        z_c = model.xToz(c)
        o_r = model.zTox(z_c, cls)[0]
        o_tb = model.zTox(z_c + z, cls)[0]
        o_tg = model.zTox(z_c - z, cls)[0]
        print "-----------------------------------------------------------"
        print "origin: %s | reconstruct: %s" % (c, o_r)
        print "transfer_g: %s | transfer_b: %s " % (o_tg, o_tb)
        
def interpolations(A, B, cls, n=10):
    z1 = model.xToz(A)
    z2 = model.xToz(B)
    z_list = [z1+((z2-z1)/n)*i for i in range(1,n+1)]
    print "-----------------------------------------------------------"
    print A
    for z in z_list:
        print model.zTox(z, cls)[0]
    print B

def senti_transfer(A):
    z = model.xToz(A)
    o1 = model.zTox(z, 0)[0]
    o2 = model.zTox(z, 1)[0]
    o3 = model.zTox(z, 2)[0]
    print "origin: %s | top: %s | mid: %s | down %s" % (A, o1, o2, o3)
    
def generate():
    o = model.generate(c=0)
    print 'generate top:'
    for oo in o:
        print '【',oo,'】'
    print ""
    o = model.generate(c=1)
    print 'generate mid:'
    for oo in o:
        print '【',oo,'】'
    print ""
    o = model.generate(c=2)
    print 'generate down:'
    for oo in o:
        print '【',oo,'】'
    print ""

In [27]:
for i in range(12, 15):
    model.restore('/root/VAE/char_vae_model/runs/1513646690/model-%d' % i)
    generate()
    for s in candi:
        senti_transfer(s)
    

INFO:tensorflow:Restoring parameters from /root/VAE/char_vae_model/runs/1513646690/model-9
restore /root/VAE/char_vae_model/runs/1513646690/model-9 success
generate top:
【 质量有保证，喜欢的不得了，很好 】
【 包装不错，戒指非常满意。 】
【 发货很给力，宝贝收到质量很好，店家人很好很满意的一次购物！ 】

generate mid:
【 卖家非常非常非常非常非常好，款式很时尚，非常满意的一次网购哦！值得推荐 】
【 <UNK>是假的，很大方。欺骗消费者 】
【 很精致很小也很小的啊。 】

generate down:
【 虽然简单，但还是不错的 】
【 项链不错哦，女朋友说是真好的说~ 】
【 很好，款式新颖，带着很大方，非常漂亮，给个好评 】

origin: 货收到了，我很满意。手镯很好。下次还会来的。 | top: 货收到了，快递很给力。戴着很好。满意的一次购物。 | mid: 货收到了，快递很给力。戴着很好的。。。 | down 货收到了，快递很给力。戴着很好的。。。
origin: 很不错的一款项链，非常满意。 | top: 很不错的一款宝贝，非常满意。 | mid: 很不错的一款宝贝，非常满意。 | down 很不错的一款宝贝，非常满意。
origin: 东西很好 就是有点大啊 | top: 东西很好 就是有点大啊 | mid: 东西很好 就是有点大啊 | down 东西很好 就是有点大啊
origin: 收到宝贝。看着不错也很漂亮。 | top: 收到货了。很漂亮哦。也很漂亮 | mid: 收到货了。看起来很漂亮。也不错 | down 收到项链还可以。不满意。
origin: 看着价位很不值，戴上还可以。 | top: 总体来说还好，带上也很好看？ | mid: 还阔以，感觉也不是很好看 | down 项链有点细，总体还算满意。
origin: 质量太垃圾了 | top: 质量太垃圾了， | mid: 质量太垃圾了， | down 质量太垃圾了，
origin: 相当给力，店主送的宝物俺很喜欢！！！将继续关注啊。 | top: 店主人很棒，给同事买的生日礼物！她很喜欢！客服也有

origin: <UNK>们，被骗了，太垃圾了 | top: 超级好看，这款项链很好看，快递也很快， | mid: 超级无敌<UNK>，真的是醉了，很丑！ | down 超级无敌丑，真的是<UNK>，<UNK>了！
origin: 靠 哥们算是被你们 <UNK>爹了 | top: 好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好好 | mid: <UNK> 让我等了好久才来评价 | down 差评 我等了好久才来评价 <UNK>
origin: 难看差评难看差评难看差评 | top: 满意，同事们都说好看 | mid: 假货。同事们都说好看 | down 假货。跟图片上一样
origin: 掉色严重，两天就花了，不建议购买 | top: 太差劲了，链子细得可怜，不建议购买 | mid: 太容易刮花了，而且还不对称，不值得购买 | down 太差劲了，一扯就褪色了，劝别买
origin: 不满意，态度不行。 | top: 不错，两天就到了。 | mid: 不满意，送货太慢了。 | down 不满意，送货太慢了。
origin: 啥东西。 不好 上当了 | top: 质量不好，物流慢，5天才到！ | mid: 质量不好，送货速度太慢！ | down 质量不好，差评 差评差评差评
origin: 买的是女款 却给了个男款 | top: 我要的是情侣号 结果给我发错了 | mid: 说的25号 结果是18厘米的 | down 说的25号 结果是18厘米的
origin: 快递太慢了，一份价钱一份货吧 | top: 送货速度太慢了，10天才到。 | mid: 送货速度太慢了，10号才到。 | down 快递太慢了，4天才到2天
origin: 这个一般！ | top: 这个还好吧！ | mid: 东西一般吧！ | down 东西一般吧！
origin: 好，非常好看，有<UNK><UNK> | top: 好看，很有个性，要是<UNK>的话就更好了 | mid: 好看，很有个性，要是<UNK>的话就更好了 | down 不好，很容易断，而且又大又短又<UNK>
origin: 没用就不见了 | top: 哈哈哈哈，很强势 | mid: 将就用，不耐用 | down 一扎就断了，
origin: 一般吧，只

origin: 物流很难 | top: 很好看，很喜欢 | mid: <UNK> | down <UNK>
origin: 下个<UNK> | top: 和图上的不一样 | mid: 和图片上的不一样 | down 和图片上完全不一样
INFO:tensorflow:Restoring parameters from /root/VAE/char_vae_model/runs/1513646690/model-14
restore /root/VAE/char_vae_model/runs/1513646690/model-14 success
generate top:
【 挺好的，就是大了点 】
【 质量挺好的，颜色也很正 】
【 还可以，就是有点小贵 】

generate mid:
【 刚买的时候还没带过几天就掉色了 】
【 感觉不值这个价！！！！！！！！！！！！！！！！！！！！！！！！！！！ 】
【 假的是假的。。。。。。。。。。。。。。。。。。。。。。。。。。 】

generate down:
【 ******************************* 】
【 没有想像中的那么好 】
【 差评 差评 差评 差评 差评 差评差评差评差评差评差评差评差评差评 】

origin: 货收到了，我很满意。手镯很好。下次还会来的。 | top: 质量不错，卖家服务态度好，物流也快，下次还会再来的 | mid: 还可以，就是快递太慢了，等了四天才到 | down 还可以，就是快递太慢了，等的花儿都谢了
origin: 很不错的一款项链，非常满意。 | top: 质量很好，做工精致，款式漂亮，价格实惠，值得购买。 | mid: 不咋滴。。。。。。。。。。。。。。。。。。。。。。。。。。。。 | down 不咋地。。。。。。。。。。。。。。。。。。。。。。。。。。。。
origin: 东西很好 就是有点大啊 | top: 东西挺好的 女朋友很喜欢 就是物流有点慢 | mid: 东西挺好的 就是链子有点细 | down 东西挺好的 就是不知道会不会掉色
origin: 收到宝贝。看着不错也很漂亮。 | top: 还可以，就是不知道会不会掉色 | mid: 还可以，就是不知道会不会掉色 | down 还可以就是不知道时间长了会不会掉色
origin: 看着价位很不

In [13]:
for z in candi[:10]:
    print 'origin:', z
    senti_transfer(z, 1)

origin: 货收到了，我很满意。手镯很好。下次还会来的。


ValueError: Cannot feed value of shape (1, 16) for Tensor u'concat_1:0', which has shape '(?, 20)'