In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from tensorflow.python.layers import core as core_layers
import tensorflow as tf
import numpy as np
import time
import myResidualCell
import jieba
from bleu import BLEU
import random
import cPickle
import matplotlib.pyplot as plt


class CLM:
    """Conditional language model built on the TF1 ``tf.contrib.seq2seq`` API.

    A condition id is embedded (the embedding table has 3 rows), projected to the
    decoder's initial state and concatenated to every decoder input. The same
    decoder cell is used for teacher-forced training and for beam-search
    continuation of a prefix (see ``infer``).

    Args:
        dp: data provider; this class reads ``dp.X_w2id``, ``dp.X_id2w`` and
            ``dp.num_steps``.
        rnn_size: number of units of each recurrent cell.
        n_layers: number of stacked recurrent cells.
        decoder_embedding_dim: width of the token embedding.
        condition_embedding_dim: width of the condition embedding.
        max_infer_length: maximum number of beam-search decoding steps.
        sess: ``tf.Session`` to use. When ``None`` a new session is created at
            construction time (not at class-definition time).
        lr: base learning rate handed to Adam before decay.
        grad_clip: global-norm clipping threshold for the gradients.
        beam_width: beam size used for inference.
        force_teaching_ratio: stored on the instance; not used by the graph code
            in this class.
        beam_penalty: ``length_penalty_weight`` of the beam-search decoder.
        residual: wrap every cell in ``myResidualCell.ResidualWrapper``.
        output_keep_prob: dropout keep probability fed for cell outputs during
            training (stored as ``_output_keep_prob``).
        input_keep_prob: dropout keep probability fed for cell inputs during
            training (stored as ``_input_keep_prob``).
        cell_type: ``'lstm'`` selects ``LayerNormBasicLSTMCell``; any other value
            selects ``GRUBlockCell``.
        reverse: when true, ``infer`` reverses the prefix before feeding it and
            reverses the generated tokens before prepending them to the prefix.
        is_save: stored on the instance for callers that decide whether to
            checkpoint.
        decay_scheme: ``'luong10'`` or anything else for the 4-step schedule
            (see ``get_learning_rate_decay``).
    """

    def __init__(self, dp, rnn_size, n_layers, decoder_embedding_dim, condition_embedding_dim, max_infer_length,
                 sess=None, lr=0.001, grad_clip=5.0, beam_width=5, force_teaching_ratio=1.0, beam_penalty=1.0,
                 residual=False, output_keep_prob=0.5, input_keep_prob=0.9, cell_type='lstm', reverse=False,
                 is_save=True, decay_scheme='luong234'):
        self.rnn_size = rnn_size
        self.n_layers = n_layers
        self.grad_clip = grad_clip
        self.dp = dp
        self.decoder_embedding_dim = decoder_embedding_dim
        self.beam_width = beam_width
        self.beam_penalty = beam_penalty
        self.max_infer_length = max_infer_length
        self.residual = residual
        self.decay_scheme = decay_scheme
        self.condition_embedding_dim = condition_embedding_dim
        self.reverse = reverse
        self.cell_type = cell_type
        self.force_teaching_ratio = force_teaching_ratio
        self._output_keep_prob = output_keep_prob
        self._input_keep_prob = input_keep_prob
        self.is_save = is_save
        # A default of ``tf.Session()`` in the signature would be evaluated once,
        # when the class is defined, and then shared by every instance.
        self.sess = sess if sess is not None else tf.Session()
        self.lr = lr
        self.build_graph()
        # Statement order below is deliberate: the summary variables are created
        # after the initializer has run and after the Saver captured its variable
        # list, so they are neither initialized here nor checkpointed.
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=15)
        self.summary_placeholders, self.update_ops, self.summary_op = self.setup_summary()
    # end constructor

    def build_graph(self):
        """Assemble placeholders, condition projection, both decoders and the optimizer."""
        self.register_symbols()
        self.add_input_layer()
        self.add_condition()
        with tf.variable_scope('decode'):
            self.add_decoder_for_training()
        # Second pass over the same scope so inference shares the training weights.
        with tf.variable_scope('decode', reuse=True):
            self.add_decoder_for_prefix_inference()
        self.add_backward_path()
    # end method

    def add_input_layer(self):
        """Create the input placeholders and the global step counter."""
        self.X = tf.placeholder(tf.int32, [None, None], name="X")
        self.X_seq_len = tf.placeholder(tf.int32, [None], name="X_seq_len")
        self.input_keep_prob = tf.placeholder(tf.float32, name="input_keep_prob")
        self.output_keep_prob = tf.placeholder(tf.float32, name="output_keep_prob")
        self.C = tf.placeholder(tf.int32, [None], name='Condition')
        self.batch_size = tf.shape(self.X)[0]
        self.global_step = tf.Variable(0, name="global_step", trainable=False)
    # end method

    def single_cell(self, reuse=False):
        """Return one recurrent cell with dropout (and optionally a residual wrapper).

        ``reuse`` is only forwarded to the LSTM cell; the GRU branch ignores it.
        """
        if self.cell_type == 'lstm':
            cell = tf.contrib.rnn.LayerNormBasicLSTMCell(self.rnn_size, reuse=reuse)
        else:
            cell = tf.contrib.rnn.GRUBlockCell(self.rnn_size)
        # Keyword arguments on purpose: DropoutWrapper's positional order is
        # (cell, input_keep_prob, output_keep_prob), so passing the two
        # placeholders positionally as (output, input) swapped them.
        cell = tf.contrib.rnn.DropoutWrapper(cell,
                                             input_keep_prob=self.input_keep_prob,
                                             output_keep_prob=self.output_keep_prob)
        if self.residual:
            cell = myResidualCell.ResidualWrapper(cell)
        return cell

    def processed_decoder_input(self):
        """Return ``X`` shifted right by one step: ``<GO>`` prepended, last column dropped."""
        main = tf.strided_slice(self.X, [0, 0], [self.batch_size, -1], [1, 1])  # remove last char
        decoder_input = tf.concat([tf.fill([self.batch_size, 1], self._x_go), main], 1)
        return decoder_input

    def add_condition(self):
        """Embed the condition id and project it to one initial state per layer."""
        self.condition_embedding = tf.get_variable('condition_embedding', [3, self.condition_embedding_dim],
                                                   tf.float32, tf.random_uniform_initializer(-1.0, 1.0))
        self.c_inputs = tf.nn.embedding_lookup(self.condition_embedding, self.C)
        hidden_state_list = []
        for _ in range(self.n_layers):
            if self.cell_type == 'gru':
                hidden_state_list.append(tf.layers.dense(self.c_inputs, self.rnn_size))
            else:
                # Two separate projections, created in (c, h) order so the
                # auto-generated dense layer names stay the same as before.
                c_state = tf.layers.dense(self.c_inputs, self.rnn_size)
                h_state = tf.layers.dense(self.c_inputs, self.rnn_size)
                hidden_state_list.append(tf.contrib.rnn.LSTMStateTuple(c_state, h_state))
        self.decoder_init_state = tuple(hidden_state_list)

    def add_decoder_for_training(self):
        """Build the teacher-forced decoder; expose its logits and final state.

        The final state is kept as ``init_prefix_state`` and is what the
        beam-search decoder starts from.
        """
        self.decoder_cell = tf.nn.rnn_cell.MultiRNNCell([self.single_cell() for _ in range(self.n_layers)])
        decoder_embedding = tf.get_variable('word_embedding', [len(self.dp.X_w2id), self.decoder_embedding_dim],
                                            tf.float32, tf.random_uniform_initializer(-1.0, 1.0))
        emb = tf.nn.embedding_lookup(decoder_embedding, self.processed_decoder_input())
        # Repeat the condition embedding along time and append it to every token embedding.
        inputs = tf.expand_dims(self.c_inputs, 1)
        inputs = tf.tile(inputs, [1, tf.shape(emb)[1], 1])
        inputs = tf.concat([emb, inputs], 2)
        training_helper = tf.contrib.seq2seq.TrainingHelper(
            inputs=inputs,
            sequence_length=self.X_seq_len,
            time_major=False)
        training_decoder = tf.contrib.seq2seq.BasicDecoder(
            cell=self.decoder_cell,
            helper=training_helper,
            initial_state=self.decoder_init_state,
            output_layer=core_layers.Dense(len(self.dp.X_w2id)))
        training_decoder_output, training_final_state, _ = tf.contrib.seq2seq.dynamic_decode(
            decoder=training_decoder,
            impute_finished=True,
            maximum_iterations=tf.reduce_max(self.X_seq_len))
        self.training_logits = training_decoder_output.rnn_output
        self.init_prefix_state = training_final_state

    def add_decoder_for_prefix_inference(self):
        """Build the beam-search decoder that continues a prefix.

        The decoder starts from the state the training decoder reached on ``X``
        and its first input is the embedding of ``prefix_go`` rather than
        ``<GO>``.
        """
        decoder_embedding = tf.get_variable('word_embedding')

        def beam_f(ids):
            # Token embedding with the condition embedding appended on the last
            # axis. Rank-1 ids use c_inputs directly; otherwise c_inputs is tiled
            # along the second axis of the looked-up embedding.
            emb = tf.nn.embedding_lookup(decoder_embedding, ids)
            if len(ids.get_shape()) != 1:
                cond = tf.tile(tf.expand_dims(self.c_inputs, 1), [1, int(emb.get_shape()[1]), 1])
            else:
                cond = self.c_inputs
            return tf.concat([emb, cond], -1)

        self.beam_f = beam_f
        predicting_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
            cell=self.decoder_cell,
            embedding=self.beam_f,
            start_tokens=tf.tile(tf.constant([self._x_go], dtype=tf.int32), [self.batch_size]),
            end_token=self._x_eos,
            initial_state=tf.contrib.seq2seq.tile_batch(self.init_prefix_state, self.beam_width),
            beam_width=self.beam_width,
            output_layer=core_layers.Dense(len(self.dp.X_w2id), _reuse=True),
            length_penalty_weight=self.beam_penalty)

        self.prefix_go = tf.placeholder(tf.int32, [None])
        prefix_go_beam = tf.tile(tf.expand_dims(self.prefix_go, 1), [1, self.beam_width])
        prefix_emb = self.beam_f(prefix_go_beam)
        # Overrides a private attribute of BeamSearchDecoder so that decoding
        # starts from the last prefix token instead of the <GO> start tokens.
        predicting_decoder._start_inputs = prefix_emb
        predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
            decoder=predicting_decoder,
            impute_finished=False,
            maximum_iterations=self.max_infer_length)
        self.prefix_infer_outputs = predicting_decoder_output.predicted_ids
        self.score = predicting_decoder_output.beam_search_decoder_output.scores

    def add_backward_path(self):
        """Define the masked sequence loss, clipped gradients and the Adam train op."""
        masks = tf.sequence_mask(self.X_seq_len, tf.reduce_max(self.X_seq_len), dtype=tf.float32)
        self.loss = tf.contrib.seq2seq.sequence_loss(logits=self.training_logits,
                                                     targets=self.X,
                                                     weights=masks)
        self.batch_loss = tf.contrib.seq2seq.sequence_loss(logits=self.training_logits,
                                                           targets=self.X,
                                                           weights=masks,
                                                           average_across_batch=False)
        params = tf.trainable_variables()
        gradients = tf.gradients(self.loss, params)
        clipped_gradients, _ = tf.clip_by_global_norm(gradients, self.grad_clip)
        # The constant must exist first: get_learning_rate_decay reads it.
        self.learning_rate = tf.constant(self.lr)
        self.learning_rate = self.get_learning_rate_decay(self.decay_scheme)  # decay
        self.train_op = tf.train.AdamOptimizer(self.learning_rate).apply_gradients(
            zip(clipped_gradients, params), global_step=self.global_step)

    def register_symbols(self):
        """Cache the ids of the special tokens from the data provider's vocabulary."""
        self._x_go = self.dp.X_w2id['<GO>']
        self._x_eos = self.dp.X_w2id['<EOS>']
        self._x_pad = self.dp.X_w2id['<PAD>']
        self._x_unk = self.dp.X_w2id['<UNK>']

    def infer(self, input_word, C, batch_size=1, is_show=True):
        """Continue ``input_word`` with beam search under condition ``C``.

        Args:
            input_word: prefix string; it is split into single characters and
                unknown characters map to ``<UNK>``.
            C: condition id fed for every row of the batch.
            batch_size: number of identical copies of the prefix to feed; only
                the first row of the result is decoded.
            is_show: unused, kept for interface compatibility.

        Returns:
            One string per beam: the prefix joined with the generated tokens,
            cut at the first ``<EOS>``. Ids missing from ``X_id2w`` are rendered
            as ``'<-1>'``.
        """
        #input_word = jieba.cut(input_word)
        prefix_chars = [char for char in input_word]
        if self.reverse:
            prefix_chars = prefix_chars[::-1]
        prefix_ids = [self.dp.X_w2id.get(char, self._x_unk) for char in prefix_chars]
        input_indices = [prefix_ids] * batch_size
        lengths = [len(prefix_ids)] * batch_size
        prefix_go = [ids[-1] for ids in input_indices]
        out_indices, _scores = self.sess.run([self.prefix_infer_outputs, self.score], {
            self.X: input_indices,
            # One condition per batch row; a single-element list only matched batch_size == 1.
            self.C: [C] * batch_size,
            self.X_seq_len: lengths,
            self.prefix_go: prefix_go,
            self.input_keep_prob: 1,
            self.output_keep_prob: 1})

        eos_id = self.dp.X_w2id['<EOS>']
        outputs = []
        for beam in range(out_indices.shape[-1]):
            token_ids = out_indices[0, :, beam].tolist()
            if eos_id in token_ids:
                token_ids = token_ids[:token_ids.index(eos_id)]
            # Reverse regardless of whether <EOS> was produced; previously beams
            # without <EOS> were joined to the prefix in generation order.
            if self.reverse:
                token_ids = token_ids[::-1]
            generated = ''.join([self.dp.X_id2w.get(i, '<-1>') for i in token_ids])
            if self.reverse:
                outputs.append(generated + input_word)
            else:
                outputs.append(input_word + generated)
        return outputs

    def restore(self, path):
        """Load variables from the checkpoint at ``path`` into the session."""
        self.saver.restore(self.sess, path)
        print('restore %s success' % path)

    def get_learning_rate_decay(self, decay_scheme='luong234'):
        """Return a learning-rate tensor that halves in steps late in training.

        ``'luong10'`` starts at half of ``dp.num_steps`` and spreads 10 halvings
        over the remainder; any other value starts at two thirds and spreads 4
        halvings over the remainder.
        """
        num_train_steps = self.dp.num_steps
        if decay_scheme == "luong10":
            start_decay_step = int(num_train_steps / 2)
            n_decays = 10
        else:
            start_decay_step = int(num_train_steps * 2 / 3)
            n_decays = 4
        remain_steps = num_train_steps - start_decay_step
        # At least 1: very short runs would otherwise divide by zero inside exponential_decay.
        decay_steps = max(1, int(remain_steps / n_decays))
        decay_factor = 0.5
        base_lr = self.learning_rate
        return tf.cond(
            self.global_step < start_decay_step,
            lambda: base_lr,
            lambda: tf.train.exponential_decay(
                base_lr,
                (self.global_step - start_decay_step),
                decay_steps, decay_factor, staircase=True),
            name="learning_rate_decay_cond")

    def setup_summary(self):
        """Create scalar summaries and the assign ops used to feed them.

        Returns:
            ``(summary_placeholders, update_ops, summary_op)`` where index 0, 1, 2
            correspond to train loss, test loss and BLEU score.
        """
        train_loss = tf.Variable(0.)
        tf.summary.scalar('Train_loss', train_loss)

        test_loss = tf.Variable(0.)
        tf.summary.scalar('Test_loss', test_loss)

        bleu_score = tf.Variable(0.)
        tf.summary.scalar('BLEU_score', bleu_score)

        tf.summary.scalar('lr_rate', self.learning_rate)

        summary_vars = [train_loss, test_loss, bleu_score]
        summary_placeholders = [tf.placeholder(tf.float32) for _ in summary_vars]
        update_ops = [var.assign(ph) for var, ph in zip(summary_vars, summary_placeholders)]
        summary_op = tf.summary.merge_all()
        return summary_placeholders, update_ops, summary_op

In [None]:
class CLM_DP:
    """Data provider: train/test split, shuffled mini-batches and padding.

    The first 10% of the data (``int(len(X_indices) * 0.1)`` items) is the test
    split, the rest is the training split. Sentences are kept as Python lists
    inside a 1-D object array so that they can be shuffled with fancy indexing
    and still padded with list concatenation.

    Args:
        X_indices: list of sentences, each a list of token ids.
        C: condition id per sentence, aligned with ``X_indices``.
        X_w2id: token -> id vocabulary; must contain ``'<PAD>'``.
        BATCH_SIZE: number of sentences per batch.
        n_epoch: number of epochs, used to derive ``num_steps``.
    """

    def __init__(self, X_indices, C, X_w2id, BATCH_SIZE, n_epoch):
        num_test = int(len(X_indices) * 0.1)
        self.n_epoch = n_epoch
        self.X_train = self._as_sentence_array(X_indices[num_test:])
        self.X_test = self._as_sentence_array(X_indices[:num_test])
        self.C_train = np.array(C[num_test:])
        # Same slice as X_test, so every test sentence keeps its own condition.
        self.C_test = np.array(C[:num_test])
        self.num_batch = len(self.X_train) // BATCH_SIZE
        self.num_steps = self.num_batch * self.n_epoch
        self.batch_size = BATCH_SIZE
        self.X_w2id = X_w2id
        self.X_id2w = dict(zip(X_w2id.values(), X_w2id.keys()))
        self._x_pad = self.X_w2id['<PAD>']
        print('Train_data: %d | Test_data: %d | Batch_size: %d | Num_batch: %d | X_vocab_size: %d ' % (
            len(self.X_train), len(self.X_test), BATCH_SIZE, self.num_batch, len(self.X_w2id)))

    @staticmethod
    def _as_sentence_array(sentences):
        """Return a 1-D object array whose elements are the sentences as lists.

        ``np.array(list_of_lists)`` would build a 2-D numeric array whenever all
        sentences happen to have the same length, which breaks list-based padding.
        """
        arr = np.empty(len(sentences), dtype=object)
        for i, sentence in enumerate(sentences):
            arr[i] = list(sentence)
        return arr

    def next_batch(self, X, C):
        """Yield ``(padded_X, lengths, C_batch)`` over a fresh random permutation.

        Only full batches are produced; a trailing partial batch is dropped.
        """
        order = np.random.permutation(len(X))
        X = X[order]
        C = C[order]
        n_full = len(X) - len(X) % self.batch_size
        for start in range(0, n_full, self.batch_size):
            stop = start + self.batch_size
            padded_X_batch, X_batch_lens = self.pad_sentence_batch(X[start:stop], self._x_pad)
            yield (np.array(padded_X_batch),
                   X_batch_lens,
                   C[start:stop])

    def sample_test_batch(self):
        """Return the first ``batch_size`` test sentences as one padded batch."""
        padded_X_batch, X_batch_lens = self.pad_sentence_batch(self.X_test[: self.batch_size], self._x_pad)
        c_batch = self.C_test[: self.batch_size]
        return np.array(padded_X_batch), X_batch_lens, c_batch

    def pad_sentence_batch(self, sentence_batch, pad_int):
        """Right-pad every sentence with ``pad_int`` to the longest one in the batch.

        Returns:
            ``(padded_seqs, seq_lens)``: a list of equal-length lists and the list
            of original lengths. An empty batch yields two empty lists.
        """
        seq_lens = [len(sentence) for sentence in sentence_batch]
        max_sentence_len = max(seq_lens) if seq_lens else 0
        padded_seqs = [list(sentence) + [pad_int] * (max_sentence_len - len(sentence))
                       for sentence in sentence_batch]
        return padded_seqs, seq_lens


In [None]:
class CLM_util:
    """Training / evaluation driver around a model and its data provider.

    Args:
        dp: data provider exposing ``next_batch``, ``sample_test_batch``,
            ``X_train``, ``C_train``, ``X_test``, ``C_test``, ``X_id2w``,
            ``num_batch`` and ``n_epoch``.
        model: model exposing the placeholders, ops, ``sess``, ``saver``,
            ``infer`` and ``restore`` used below.
        display_freq: how many progress reports to print per epoch
            (approximately; the interval is ``num_batch // display_freq``).
    """

    def __init__(self, dp, model, display_freq=3):
        self.display_freq = display_freq
        self.dp = dp
        self.model = model

    def train(self, epoch):
        """Run one epoch over the training split and return the mean batch loss.

        Every ``num_batch // display_freq`` batches (at least every batch) the
        loss on a fixed test batch is printed together with sample continuations.
        """
        model = self.model
        avg_loss = 0.0
        tic = time.time()
        X_test_batch, X_test_batch_lens, C_test_batch = self.dp.sample_test_batch()
        # Never 0: with fewer batches than display_freq the modulo below would divide by zero.
        display_every = max(1, self.dp.num_batch // max(1, self.display_freq))
        for local_step, (X_train_batch, X_train_batch_lens, C_train_batch) in enumerate(
                self.dp.next_batch(self.dp.X_train, self.dp.C_train)):
            model.step, _, loss = model.sess.run(
                [model.global_step, model.train_op, model.loss],
                {model.X: X_train_batch,
                 model.C: C_train_batch,
                 model.X_seq_len: X_train_batch_lens,
                 model.output_keep_prob: model._output_keep_prob,
                 model.input_keep_prob: model._input_keep_prob})
            avg_loss += loss
            if local_step % display_every == 0:
                val_loss = model.sess.run(model.loss, {model.X: X_test_batch,
                                                       model.C: C_test_batch,
                                                       model.X_seq_len: X_test_batch_lens,
                                                       model.output_keep_prob: 1,
                                                       model.input_keep_prob: 1})
                # Epoch count comes from the data provider so train() also works
                # when it is called without going through fit().
                print("Epoch %d/%d | Batch %d/%d | Train_loss: %.3f | Test_loss: %.3f | Time_cost:%.3f" % (
                    epoch, self.dp.n_epoch, local_step, self.dp.num_batch,
                    avg_loss / (local_step + 1), val_loss, time.time() - tic))
                self.cal()
                tic = time.time()
        return avg_loss / max(1, self.dp.num_batch)

    def test(self):
        """Return the mean loss over all full batches of the test split (0.0 if none)."""
        model = self.model
        avg_loss = 0.0
        n_batches = 0
        for X_test_batch, X_test_batch_lens, C_test_batch in self.dp.next_batch(
                self.dp.X_test, self.dp.C_test):
            avg_loss += model.sess.run(model.loss, {model.X: X_test_batch,
                                                    model.C: C_test_batch,
                                                    model.X_seq_len: X_test_batch_lens,
                                                    model.output_keep_prob: 1,
                                                    model.input_keep_prob: 1})
            n_batches += 1
        return avg_loss / max(1, n_batches)

    def fit(self, train_dir, is_bleu):
        """Train for ``dp.n_epoch`` epochs, logging summaries and checkpoints.

        Args:
            train_dir: output directory (created if missing). Summaries go to
                ``<train_dir>/Summary``; when ``model.is_save`` is true a
                checkpoint ``model-<epoch>`` and ``res.pkl`` are written after
                every epoch.
            is_bleu: also compute ``test_bleu()`` after every epoch.
        """
        self.n_epoch = self.dp.n_epoch
        test_loss_list = []
        train_loss_list = []
        time_cost_list = []
        bleu_list = []
        out_dir = train_dir
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        print("Writing to %s" % out_dir)
        checkpoint_prefix = os.path.join(out_dir, "model")
        self.summary_writer = tf.summary.FileWriter(os.path.join(out_dir, 'Summary'), self.model.sess.graph)
        for epoch in range(1, self.n_epoch + 1):
            tic = time.time()
            train_loss = self.train(epoch)
            train_loss_list.append(train_loss)
            test_loss = self.test()
            test_loss_list.append(test_loss)
            toc = time.time()
            time_cost_list.append((toc - tic))
            if is_bleu:
                bleu = self.test_bleu()
                bleu_list.append(bleu)
                print("Epoch %d/%d | Train_loss: %.3f | Test_loss: %.3f | Bleu: %.3f" % (
                    epoch, self.n_epoch, train_loss, test_loss, bleu))
            else:
                bleu = 0.0
                print("Epoch %d/%d | Train_loss: %.3f | Test_loss: %.3f" % (
                    epoch, self.n_epoch, train_loss, test_loss))

            print('============================================')
            # Push the epoch statistics into the summary variables, then emit them.
            stats = [train_loss, test_loss, bleu]
            for i in range(len(stats)):
                self.model.sess.run(self.model.update_ops[i], feed_dict={
                    self.model.summary_placeholders[i]: float(stats[i])
                })
            summary_str = self.model.sess.run(self.model.summary_op)
            self.summary_writer.add_summary(summary_str, epoch)
            if self.model.is_save:
                with open(os.path.join(out_dir, "res.pkl"), 'wb') as f:
                    cPickle.dump((train_loss_list, test_loss_list, time_cost_list, bleu_list), f)
                path = self.model.saver.save(self.model.sess, checkpoint_prefix, global_step=epoch)
                print("Saved model checkpoint to %s" % path)

    def show(self, sent, id2w):
        """Join the tokens of ``sent`` into a string; unknown ids become ``'&'``."""
        return "".join([id2w.get(idx, u'&') for idx in sent])

    def _print_continuations(self, label, pool, n_example, c):
        """Print model continuations for up to ``n_example`` random sentences of ``pool``."""
        # Sample indices rather than the pool itself: works for lists and numpy
        # arrays alike and cannot ask for more items than the pool holds.
        for idx in random.sample(range(len(pool)), min(n_example, len(pool))):
            example = self.show(pool[idx][:-1], self.dp.X_id2w)
            if len(example) < 3:
                continue
            length = random.randint(1, len(example) - 2)
            o = self.model.infer(example[:length], C=c)[0]
            print('%s_Input: %s | Output: %s | GroundTruth: %s' % (label, example[:length], o, example))

    def cal(self, n_example=5):
        """Print sample continuations for each of the three conditions.

        ``n_example // 2`` sentences are drawn from the training split and the
        remainder from the test split; each is cut at a random position and the
        best beam of ``model.infer`` is shown next to the full sentence.
        """
        train_n_example = n_example // 2
        test_n_example = n_example - train_n_example
        headers = ('top:', 'middle:', 'down:')
        for c in range(3):
            print(headers[c])
            self._print_continuations('Train', self.dp.X_train, train_n_example, c)
            self._print_continuations('Test', self.dp.X_test, test_n_example, c)
            print("")

    def test_bleu(self, N=300, gram=4):
        """Mean BLEU over the first ``N`` test sentences (0.0 when none are scored).

        Each sentence without its last token is given to ``model.infer`` together
        with that sentence's condition, and the best beam is scored against the
        full sentence.

        NOTE(review): the ``BLEU(candi, refer4bleu, gram=...)`` call is taken from
        the previously disabled version of this method; confirm it against the
        ``bleu`` module.
        """
        all_score = []
        for i in range(min(N, len(self.dp.X_test))):
            sentence = self.dp.X_test[i]
            prefix = self.show(sentence[:-1], self.dp.X_id2w)
            if not prefix:
                continue  # infer needs at least one prefix token
            o = self.model.infer(prefix, C=int(self.dp.C_test[i]))[0]
            refer4bleu = [[' '.join([self.dp.X_id2w.get(w, u'&') for w in sentence])]]
            candi = [' '.join(w for w in o)]
            all_score.append(BLEU(candi, refer4bleu, gram=gram))
        return float(np.mean(all_score)) if all_score else 0.0

    def show_res(self, path):
        """Plot train loss, test loss and BLEU from a ``res.pkl`` written by ``fit``."""
        # NOTE(review): unpickling can execute arbitrary code; only load files written by fit().
        with open(path, 'rb') as f:
            res = cPickle.load(f)
        plt.figure(1)
        plt.title('The results')
        l1, = plt.plot(res[0], 'g')
        l2, = plt.plot(res[1], 'r')
        l3, = plt.plot(res[3], 'b')
        plt.legend(handles=[l1, l2, l3], labels=["Train_loss", "Test_loss", "BLEU"], loc='best')
        plt.show()

    def test_all(self, path, epoch_range, is_bleu=True):
        """Evaluate a series of checkpoints and plot test loss (and BLEU).

        Checkpoint ``path + str(i)`` is restored for every ``i`` in
        ``range(epoch_range[0], epoch_range[-1])``; the upper bound is exclusive.
        """
        val_loss_list = []
        bleu_list = []
        for i in range(epoch_range[0], epoch_range[-1]):
            self.model.restore(path + str(i))
            val_loss_list.append(self.test())
            if is_bleu:
                bleu_list.append(self.test_bleu())
        plt.figure(1)
        plt.title('The results')
        l1, = plt.plot(val_loss_list, 'r')
        l2, = plt.plot(bleu_list, 'b')
        plt.legend(handles=[l1, l2], labels=["Test_loss", "BLEU"], loc='best')
        plt.show()
        
    

In [None]:

# Load the three pre-tokenized corpora and the vocabulary.
# NOTE(review): unpickling can execute arbitrary code; only load files you produced yourself.
with open('3JD_top_mid_down_indices.pkl', 'rb') as f:
    top, mid, down = cPickle.load(f)
with open('3JD_w2id_id2w.pkl', 'rb') as f:
    w2id, id2w = cPickle.load(f)

# One condition id per sentence, in the order top (0), mid (1), down (2).
data_C = [0] * len(top) + [1] * len(mid) + [2] * len(down)



In [None]:
# Run configuration.
BATCH_SIZE = 128
NUM_EPOCH = 15
train_dir = 'JD_rnn/fmodel/'

# Concatenation order must match the order used to build data_C.
X_indices = top + mid + down
dp = CLM_DP(X_indices, data_C, w2id, BATCH_SIZE, n_epoch=NUM_EPOCH)

# Build the model inside its own graph, with the session installed as default.
g = tf.Graph()
sess = tf.Session(graph=g)
with sess.as_default(), sess.graph.as_default():
    model = CLM(dp=dp,
                rnn_size=1024,
                n_layers=1,
                decoder_embedding_dim=1000,
                condition_embedding_dim=24,
                cell_type='gru',
                max_infer_length=40,
                is_save=True,
                sess=sess)

util = CLM_util(dp=dp, model=model)




In [None]:
# Train for dp.n_epoch epochs, writing output under train_dir; BLEU evaluation is skipped.
util.fit(train_dir=train_dir, is_bleu=False)