In [1]:
import time
from distutils.version import LooseVersion
import numpy as np
import tensorflow as tf
from tensorflow.python.layers.core import Dense

assert LooseVersion(tf.__version__) >= LooseVersion("1.1")
print("Tensorflow Version: {}".format(tf.__version__))

Tensorflow Version: 1.4.0


In [2]:
# 配置参数
class ModelConfig():
    encoder_hidden_layers = [50, 50]
    decoder_hidden_layers = [50, 50]
    dropout_prob = 0.5
    encoder_embedding_size = 15
    decoder_embedding_size = 15
    

class TrainConfig():
    epochs = 10
    every_checkpoint = 100
    learning_rate = 0.01
    max_grad_norm = 3


class Config():
    batch_size = 128
    infer_prob = 0.2
    
    source_path = "data/letters_source.txt"
    target_path = "data/letters_target.txt"
    
    train = TrainConfig()
    model = ModelConfig()

In [3]:
config = Config()

In [4]:
# 生成数据

class DataGen():
    
    def __init__(self, config):
        self.source_path = config.source_path
        self.target_path = config.target_path
        
        self.source_char_to_int = {}
        self.source_int_to_char = {}
        self.target_char_to_int = {}
        self.target_int_to_char = {}
        
        self.source_data = []
        self.target_data = []
        
        
    def read_data(self):
        with open(self.source_path, "r") as f:
            source_char_to_int, source_int_to_char, source_data = self.gen_vocab_dict(f.read())
        self.source_char_to_int = source_char_to_int
        self.source_int_to_char = source_int_to_char
        self.source_data = source_data
            
        with open(self.target_path, "r") as f:
            target_char_to_int, target_int_to_char, target_data = self.gen_vocab_dict(f.read(), True)
        self.target_char_to_int = target_char_to_int
        self.target_int_to_char = target_int_to_char
        self.target_data = target_data
            
    def gen_vocab_dict(self, string, is_target=False):
        special_words = ['<PAD>', '<UNK>', '<GO>',  '<EOS>']
        vocab = list(set(string))
        vocab.remove("\n")
        vocab = special_words + vocab

        
        int_to_char = {index: char for index, char in enumerate(vocab)}
        char_to_int = {char: index for index, char in int_to_char.items()}

        word_list = string.strip().split("\n")
        if is_target:
            data = [[char_to_int.get(char, '<UNK>') for char in word] + [char_to_int['<EOS>']] for word in word_list]
        else:
            data = [[char_to_int.get(char, '<UNK>') for char in word] for word in word_list]
        return char_to_int, int_to_char, data
    

In [5]:
dataGen = DataGen(config)
dataGen.read_data()
print("source data: {}".format(dataGen.source_data[0]))
print("target data: {}".format(dataGen.target_data[0]))
print("source data length: {}".format(len(dataGen.source_data)))
print("target data length: {}".format(len(dataGen.target_data)))

source data: [8, 26, 6, 10, 10]
target data: [6, 8, 10, 10, 26, 3]
source data length: 10000
target data length: 10000


In [6]:
# 定义模型
class Seq2SeqModel():
    
    def __init__(self, config, encoder_vocab_size, target_char_to_int, is_infer=False):
        self.inputs = tf.placeholder(tf.int32, [None, None], name="inputs")
        self.targets = tf.placeholder(tf.int32, [None, None], name="targets")
        self.dropout_prob = tf.placeholder(tf.float32, name="dropout_prob")
        self.source_sequence_length = tf.placeholder(tf.int32, [None], name="source_sequence_length")
        self.target_sequence_length = tf.placeholder(tf.int32, [None], name="target_sequence_length")
        self.target_max_length = tf.reduce_max(self.target_sequence_length, name='target_max_length')
        
        decoder_output = self.seq2seq(config, encoder_vocab_size, target_char_to_int, is_infer)
        
        if is_infer:
            self.infer_logits = tf.identity(decoder_output.sample_id, "infer_logits")
            
        else:
            self.logits = tf.identity(decoder_output.rnn_output, "logits")
            
            masks = tf.sequence_mask(self.target_sequence_length, self.target_max_length, dtype=tf.float32, name="mask")

            with tf.name_scope("loss"):
                self.loss = tf.contrib.seq2seq.sequence_loss(self.logits, self.targets, masks)

            with tf.name_scope("accuracy"):
                self.predictions = tf.argmax(self.logits, 2)
                correctness = tf.equal(tf.cast(self.predictions, dtype=tf.int32), self.targets)
                self.accu = tf.reduce_mean(tf.cast(correctness, "float"), name="accu")
        
    def encoder(self, config, encoder_vocab_size):
        encoder_embed_input = tf.contrib.layers.embed_sequence(self.inputs, encoder_vocab_size, config.model.encoder_embedding_size)
        
        def get_lstm_cell(hidden_size):
            lstm_cell = tf.nn.rnn_cell.LSTMCell(hidden_size, state_is_tuple=True, 
                                                initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=2))
            drop_cell = tf.nn.rnn_cell.DropoutWrapper(cell=lstm_cell, output_keep_prob=self.dropout_prob)
            
            return drop_cell
        
        cell = tf.nn.rnn_cell.MultiRNNCell([get_lstm_cell(hidden_size) for hidden_size in config.model.encoder_hidden_layers])
        outputs, final_state = tf.nn.dynamic_rnn(cell, encoder_embed_input, sequence_length=self.source_sequence_length, dtype=tf.float32)
        
        return outputs, final_state

    def decoder(self, config, encoder_state, target_char_to_int, is_infer):
        
        decoder_vocab_size = len(target_char_to_int)
        
        embeddings = tf.Variable(tf.random_uniform([decoder_vocab_size, config.model.decoder_embedding_size]))
        decoder_embed_input = tf.nn.embedding_lookup(embeddings, self.targets)
        
        def get_lstm_cell(hidden_size):
            lstm_cell = tf.nn.rnn_cell.LSTMCell(hidden_size, state_is_tuple=True, 
                                                initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=2))
            drop_cell = tf.nn.rnn_cell.DropoutWrapper(cell=lstm_cell, output_keep_prob=self.dropout_prob)
            
            return drop_cell
        
        cell = tf.nn.rnn_cell.MultiRNNCell([get_lstm_cell(hidden_size) for hidden_size in config.model.decoder_hidden_layers])
        
        # 定义有Dense方法生成的全连接层
        output_layer = Dense(decoder_vocab_size, kernel_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1))
        
        # 定义训练时的decode的代码
        with tf.variable_scope("decode"):
            # 得到help对象，帮助读取数据
            train_helper = tf.contrib.seq2seq.TrainingHelper(inputs=decoder_embed_input, sequence_length=self.target_sequence_length)
            
            # 构建decoder
            train_decoder = tf.contrib.seq2seq.BasicDecoder(cell, train_helper, encoder_state, output_layer)
            train_decoder_output, train_state, train_sequence_length = tf.contrib.seq2seq.dynamic_decode(train_decoder, impute_finished=True, 
                                                                                                         maximum_iterations=self.target_max_length)
            
        
        # 定义预测时的decode代码
        with tf.variable_scope("decode", reuse=True):
            # 解码时的第一个时间步上的输入，之后的时间步上的输入是上一时间步的输出
            start_tokens = tf.tile(tf.constant([target_char_to_int["<GO>"]], dtype=tf.int32), [config.batch_size], name="start_tokens")
            
            # 解码时按贪心法解码，按照最大条件概率来预测输出值，该方法需要输入启动词和结束词，启动词是个一维tensor，结束词是标量
            infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(embeddings, start_tokens, target_char_to_int["<EOS>"])
            infer_decoder = tf.contrib.seq2seq.BasicDecoder(cell, infer_helper, encoder_state, output_layer)
            infer_decoder_output, infer_state, infer_sequence_length = tf.contrib.seq2seq.dynamic_decode(infer_decoder,
                                                                                                           impute_finished=True,
                                                                                                          maximum_iterations=self.target_max_length)
            
        if is_infer:
            return infer_decoder_output
        
        return train_decoder_output
    
    def seq2seq(self, config, encoder_vocab_size, target_char_to_int, is_infer):
        """
        将encoder和decoder合并输出
        """
        encoder_output, encoder_state = self.encoder(config, encoder_vocab_size)
        
        decoder_output = self.decoder(config, encoder_state, target_char_to_int, is_infer)
        
        return decoder_output

In [7]:
# 定义其他的函数
def pad_batch(batch, char_to_int):
    sequence_length = [len(sequence) for sequence in batch]
    max_length = max(sequence_length)
    
    new_batch = [sequence + [char_to_int["<PAD>"]] * (max_length - len(sequence)) for sequence in batch]
    
    return sequence_length, max_length, new_batch    
    
def next_batch(source, target, batch_size, source_char_to_int, target_char_to_int):
    num_batches = len(source) // batch_size
    for i in range(num_batches):
        source_batch = source[i * batch_size: (i + 1) * batch_size]
        target_batch = target[i * batch_size: (i + 1) * batch_size]
        
        source_sequence_length, source_max_length, new_source_batch = pad_batch(source_batch, source_char_to_int)
        target_sequence_length, target_max_length, new_target_batch = pad_batch(target_batch, target_char_to_int)
        
        yield dict(source_batch=np.array(new_source_batch), target_batch=np.array(new_target_batch), 
                   source_sequence_length=np.array(source_sequence_length), target_sequence_length=np.array(target_sequence_length), 
                   target_max_length=target_max_length)

In [8]:
# 训练模型

class Engine():
    def __init__(self):
        self.config = Config()
        self.dataGen = DataGen(self.config)
        self.dataGen.read_data()
        self.sess = None
        self.global_step = 0
        
    def train_step(self, sess, train_op, train_model, params):
        
        feed_dict = {
            train_model.inputs: params["source_batch"],
            train_model.targets: params["target_batch"],
            train_model.dropout_prob: self.config.model.dropout_prob,
            train_model.source_sequence_length: params["source_sequence_length"],
            train_model.target_sequence_length: params["target_sequence_length"],
        }
        
        _, loss, accu = sess.run([train_op, train_model.loss, train_model.accu], feed_dict)
        
        return loss, accu
    
    def infer_step(self, sess, infer_model, params):
        
        feed_dict = {
            infer_model.inputs: params["source_batch"],
            infer_model.targets: params["target_batch"],
            infer_model.dropout_prob: 1.0,
            infer_model.source_sequence_length: params["source_sequence_length"],
            infer_model.target_sequence_length: params["target_sequence_length"],
        }
        
        logits = sess.run([infer_model.infer_logits], feed_dict)
        predictions = logits[0]
        
        prediction = [sequence[:end] for sequence in predictions for end in params["target_sequence_length"]]
        target = [sequence[:end] for sequence in params["target_batch"] for end in params["target_sequence_length"]]
        
        total = 0
        correct = 0
        for i in range(len(prediction)):
            for j in range(len(prediction[i])):
                if prediction[i][j] == target[i][j]:
                    correct += 1
            total += len(prediction[i])
            
        accu = correct / total
        
        return accu
    
    def run_epoch(self):
        config = self.config
        dataGen = self.dataGen
        
        source_data = dataGen.source_data
        target_data = dataGen.target_data
        
        train_split = int(len(source_data) * config.infer_prob)
        
        train_source_data = source_data[train_split:]
        infer_source_data = source_data[: train_split]
        
        train_target_data = target_data[train_split:]
        infer_target_data = target_data[: train_split]
        
        source_char_to_int = dataGen.source_char_to_int
        target_char_to_int = dataGen.target_char_to_int
        
        encoder_vocab_size = len(source_char_to_int)
        
        batch_size = config.batch_size
        
        with tf.Graph().as_default():
            session_conf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
            sess = tf.Session(config=session_conf)
            with sess.as_default():
                with tf.name_scope("train"):
                    with tf.variable_scope("seq2seq"):
                        train_model = Seq2SeqModel(config, encoder_vocab_size, target_char_to_int, is_infer=False)
                        
                with tf.name_scope("infer"):
                    with tf.variable_scope("seq2seq", reuse=True):
                        infer_model = Seq2SeqModel(config, encoder_vocab_size, target_char_to_int, is_infer=True)
                
                global_step = tf.Variable(0, name="global_step", trainable=False)
                
                optimizer = tf.train.AdamOptimizer(config.train.learning_rate)
                grads_and_vars = optimizer.compute_gradients(train_model.loss)
                grads_and_vars = [(tf.clip_by_norm(g, config.train.max_grad_norm), v) for g, v in grads_and_vars if g is not None]
                train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step, name="train_op")
                
                saver = tf.train.Saver(tf.global_variables())
                sess.run(tf.global_variables_initializer())
                
                print("初始化完成，开始训练模型")
                for i in range(config.train.epochs):
                    for params in next_batch(train_source_data, train_target_data, batch_size, source_char_to_int, target_char_to_int):
                        loss, accu = self.train_step(sess, train_op, train_model, params)
                        current_step = tf.train.global_step(sess, global_step)
                        print("step: {}  loss: {}  accu: {}".format(current_step, loss, accu))
                        
                        if current_step % config.train.every_checkpoint == 0:
                            accus = []
                            for params in next_batch(infer_source_data, infer_target_data, batch_size, source_char_to_int, target_char_to_int):
                                accu = self.infer_step(sess, infer_model, params)
                                accus.append(accu)
                            print("\n")
                            print("Evaluation accuracy: {}".format(sum(accus) / len(accus)))
                            print("\n")
                            saver.save(sess, "model/my-model", global_step=current_step)
                            
engine = Engine()
engine.run_epoch()

初始化完成，开始训练模型
step: 1  loss: 3.3972744941711426  accu: 0.396484375
step: 2  loss: 3.352038860321045  accu: 0.5068359375
step: 3  loss: 3.2333476543426514  accu: 0.537109375
step: 4  loss: 3.133819818496704  accu: 0.5107421875
step: 5  loss: 3.1084017753601074  accu: 0.494140625
step: 6  loss: 3.019141674041748  accu: 0.5361328125
step: 7  loss: 3.019252061843872  accu: 0.5205078125
step: 8  loss: 2.961327314376831  accu: 0.4794921875
step: 9  loss: 2.9175174236297607  accu: 0.55078125
step: 10  loss: 2.9328219890594482  accu: 0.515625
step: 11  loss: 2.8361661434173584  accu: 0.55859375
step: 12  loss: 2.825479507446289  accu: 0.490234375
step: 13  loss: 2.8294155597686768  accu: 0.470703125
step: 14  loss: 2.733366012573242  accu: 0.5234375
step: 15  loss: 2.662925958633423  accu: 0.560546875
step: 16  loss: 2.6590795516967773  accu: 0.5185546875
step: 17  loss: 2.6025102138519287  accu: 0.541015625
step: 18  loss: 2.5877397060394287  accu: 0.5234375
step: 19  loss: 2.548031806945801  

step: 157  loss: 0.5990403294563293  accu: 0.869140625
step: 158  loss: 0.5435983538627625  accu: 0.90625
step: 159  loss: 0.6049606800079346  accu: 0.8759765625
step: 160  loss: 0.5833047032356262  accu: 0.8701171875
step: 161  loss: 0.5891702175140381  accu: 0.873046875
step: 162  loss: 0.5502305626869202  accu: 0.90234375
step: 163  loss: 0.5205056667327881  accu: 0.8974609375
step: 164  loss: 0.5394227504730225  accu: 0.87890625
step: 165  loss: 0.5013298988342285  accu: 0.8857421875
step: 166  loss: 0.5243359804153442  accu: 0.8935546875
step: 167  loss: 0.5390182733535767  accu: 0.8798828125
step: 168  loss: 0.4949582517147064  accu: 0.8896484375
step: 169  loss: 0.46689775586128235  accu: 0.9033203125
step: 170  loss: 0.5204634666442871  accu: 0.8994140625
step: 171  loss: 0.5004429221153259  accu: 0.8994140625
step: 172  loss: 0.47173967957496643  accu: 0.8994140625
step: 173  loss: 0.4594154953956604  accu: 0.904296875
step: 174  loss: 0.48389893770217896  accu: 0.904296875
st

step: 308  loss: 0.08454903960227966  accu: 0.9912109375
step: 309  loss: 0.12060148268938065  accu: 0.98046875
step: 310  loss: 0.08520105481147766  accu: 0.98828125
step: 311  loss: 0.09579800814390182  accu: 0.984375
step: 312  loss: 0.08985547721385956  accu: 0.990234375
step: 313  loss: 0.08534594625234604  accu: 0.9873046875
step: 314  loss: 0.09979809075593948  accu: 0.982421875
step: 315  loss: 0.08885108679533005  accu: 0.98828125
step: 316  loss: 0.0866524800658226  accu: 0.9912109375
step: 317  loss: 0.08490806072950363  accu: 0.986328125
step: 318  loss: 0.08909434080123901  accu: 0.9873046875
step: 319  loss: 0.10428498685359955  accu: 0.98046875
step: 320  loss: 0.09092546999454498  accu: 0.9873046875
step: 321  loss: 0.12563425302505493  accu: 0.9873046875
step: 322  loss: 0.09819359332323074  accu: 0.984375
step: 323  loss: 0.08511895686388016  accu: 0.9873046875
step: 324  loss: 0.10679618269205093  accu: 0.98046875
step: 325  loss: 0.08973217010498047  accu: 0.9882812

step: 461  loss: 0.06919383257627487  accu: 0.9892578125
step: 462  loss: 0.06006919592618942  accu: 0.98828125
step: 463  loss: 0.04837641492486  accu: 0.99609375
step: 464  loss: 0.049837395548820496  accu: 0.9921875
step: 465  loss: 0.044343285262584686  accu: 0.9951171875
step: 466  loss: 0.04446916654706001  accu: 0.9931640625
step: 467  loss: 0.03809773176908493  accu: 0.9931640625
step: 468  loss: 0.04920855164527893  accu: 0.9951171875
step: 469  loss: 0.03982819616794586  accu: 0.99609375
step: 470  loss: 0.04347032308578491  accu: 0.994140625
step: 471  loss: 0.044612783938646317  accu: 0.9970703125
step: 472  loss: 0.04375750198960304  accu: 0.994140625
step: 473  loss: 0.034561432898044586  accu: 0.998046875
step: 474  loss: 0.04676268622279167  accu: 0.99609375
step: 475  loss: 0.04529407247900963  accu: 0.9970703125
step: 476  loss: 0.04496853053569794  accu: 0.9931640625
step: 477  loss: 0.050967857241630554  accu: 0.9921875
step: 478  loss: 0.04266749322414398  accu: 0.

step: 608  loss: 0.04323248937726021  accu: 0.9931640625
step: 609  loss: 0.03463099151849747  accu: 0.99609375
step: 610  loss: 0.027712874114513397  accu: 0.998046875
step: 611  loss: 0.055524345487356186  accu: 0.98828125
step: 612  loss: 0.031027166172862053  accu: 0.99609375
step: 613  loss: 0.04517051950097084  accu: 0.9892578125
step: 614  loss: 0.026956038549542427  accu: 0.9970703125
step: 615  loss: 0.027233174070715904  accu: 0.9970703125
step: 616  loss: 0.04304413124918938  accu: 0.9921875
step: 617  loss: 0.05577976256608963  accu: 0.9912109375
step: 618  loss: 0.03609348088502884  accu: 0.994140625
step: 619  loss: 0.042129334062337875  accu: 0.9951171875
step: 620  loss: 0.05439337342977524  accu: 0.9892578125
