In [1]:
import tensorflow as tf
import os
from six.moves import cPickle
import collections
import numpy as np
import codecs

### 读取和预处理数据

In [2]:
data_dir = "./data/sherlock/"


input_file = os.path.join(data_dir, "input.txt")
vocab_file = os.path.join(data_dir, "vocab.pkl")
tensor_file = os.path.join(data_dir, "data.npy")
    

with codecs.open(input_file, "r", encoding="utf-8") as f:
    data = f.read()
counter = collections.Counter(data)
counter_pairs = sorted(counter.items(), key=lambda x: -x[1])
chars, _ = zip(*counter_pairs)
vocab_size = len(chars)
vocab = dict(zip(chars, range(len(chars))))
tensor = np.array(list(map(vocab.get, data)))
    

### 把数据处理成batch

In [3]:
batch_size = 50
seq_length = 50
num_batches = int(tensor.size / (batch_size * seq_length))
if num_batches == 0:
    assert False, "Not enough data. Make seq_length and batch_size small."
tensor = tensor[:num_batches * batch_size * seq_length]
xdata = tensor
ydata = np.copy(tensor)
ydata[:-1] = xdata[1:]
ydata[-1] = xdata[0]
x_batches = np.split(xdata.reshape(batch_size, -1), num_batches, 1)
y_batches = np.split(ydata.reshape(batch_size, -1), num_batches, 1)
pointer = 0

### 定义RNN模型

In [6]:
from tensorflow.contrib import rnn
from tensorflow.contrib import legacy_seq2seq


class Model():
    def __init__(self, training=True):
        self.batch_size = 50
        self.seq_length = 50
        if not training:
            self.batch_size = 1
            self.seq_length = 1
            
        self.rnn_size = 128
        self.num_layers = 2
        self.input_keep_prob = 1.0
        self.output_keep_prob = 1.0
        self.grad_clip = 5.0
        self.training = 1

        cell_fn = rnn.BasicRNNCell

        cells = []
        for _ in range(self.num_layers):
            cell = cell_fn(self.rnn_size)
            # dropout
            if training and (self.input_keep_prob < 1.0 or self.output_keep_prob < 1.0):
                cell = rnn.DropoutWrapper(cell, input_keep_prob=self.input_keep_prob,
                                         output_keep_prob=self.output_keep_prob)
            cells.append(cell)

        self.cell = cell = rnn.MultiRNNCell(cells, state_is_tuple=True)

        # placeholder for input and output
        self.input_data = tf.placeholder(tf.int32, [self.batch_size, self.seq_length])
        self.output_data = tf.placeholder(tf.int32, [self.batch_size, self.seq_length])
        self.initial_state = cell.zero_state(self.batch_size, tf.float32)
        with tf.variable_scope("rnnlm"):
            softmax_w = tf.get_variable("softmax_w", [self.rnn_size, vocab_size])
            softmax_b = tf.get_variable("softmax_b", [vocab_size])
        embedding = tf.get_variable("embedding", [vocab_size, self.rnn_size])
        inputs = tf.nn.embedding_lookup(embedding, self.input_data)
        if training and self.output_keep_prob:
            inputs = tf.nn.dropout(inputs, self.output_keep_prob)

        inputs = tf.split(inputs, self.seq_length, 1)
        inputs = [tf.squeeze(input_, [1]) for input_ in inputs]

        def loop(prev, _):
            prev = tf.matmul(prev, softmax_w) + softmax_b
            prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
            return tf.nn.embedding_lookup(embedding, prev_symbol)

        outputs, last_state = legacy_seq2seq.rnn_decoder(inputs, self.initial_state, cell, 
                                                         loop_function=loop if not training else None, scope="rnnlm")
        output = tf.reshape(tf.concat(outputs, 1), [-1, self.rnn_size])
        self.logits = tf.matmul(output, softmax_w) + softmax_b
        self.probs = tf.nn.softmax(self.logits)
        loss = legacy_seq2seq.sequence_loss_by_example([self.logits],
                                                      [tf.reshape(self.output_data, [-1])],
                                                      [tf.ones([self.batch_size * self.seq_length])])
        with tf.name_scope("cost"):
            self.cost = tf.reduce_sum(loss) / batch_size / seq_length
        self.final_state = last_state
        self.lr = tf.Variable(0.0, trainable=False)
        # gradient clipping on trainable variables
        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars), self.grad_clip)

        # optimizer and train
        with tf.name_scope("optimizer"):
            optimizer = tf.train.AdamOptimizer(self.lr)
        self.train_op = optimizer.apply_gradients(zip(grads, tvars))
    
    def sample(self, sess, chars, vocab, num=200, prime="The "):
        state = sess.run(self.cell.zero_state(1, tf.float32))
        for char in prime[:-1]:
            x = np.zeros((1,1))
            x[0, 0] = vocab[char]
            feed = {self.input_data: x, self.initial_state: state}
            [state] = sess.run([self.final_state], feed)

        def weighted_pick(weights):
            t = np.cumsum(weights)
            s = np.sum(weights)
            return (int(np.searchsorted(t, np.random.rand(1)*s)))

        ret = prime
        char = prime[-1]
        for n in range(num):
            x = np.zeros((1,1))
            x[0, 0] = vocab[char]
            feed = {self.input_data: x, self.initial_state: state}
            [p, state] = sess.run([self.probs, self.final_state], feed)
            p = p[0]
            sample = weighted_pick(p)

            pred = chars[sample]
            ret += pred
            char = pred
        return ret



### 开始训练RNN

In [7]:
import time
num_epochs = 50
learning_rate = 0.002
decay_rate = 0.97
save_dir = "./save/"
save_every = 500


tf.reset_default_graph()
model = Model() 


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())   
    saver = tf.train.Saver(tf.global_variables())

    for e in range(num_epochs):
        sess.run(tf.assign(model.lr, learning_rate*(decay_rate ** e)))
        # reset the pointer to load from the beginning
        pointer = 0
        state = sess.run(model.initial_state)
        for b in range(num_batches):
            start = time.time()
            x, y = x_batches[pointer], y_batches[pointer]
            pointer += 1
            feed = {model.input_data: x, model.output_data: y}
            feed[model.initial_state] = state
                
            train_loss, state, _ = sess.run([model.cost, model.final_state, model.train_op], feed)
            end = time.time()
            print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch={:.3f}".format(
                e * num_batches + b, num_epochs * num_batches, e, train_loss, end-start))
            
            if ( e * num_batches + b ) % save_every == 0 or ( 
                e == num_epochs - 1 and b == num_batches - 1):
                checkpoint_path = os.path.join(save_dir, "model.ckpt")
                saver.save(sess, checkpoint_path, global_step=e*num_batches + b)
                print("model saved to {}".format(checkpoint_path))
                
        
        

0/67600 (epoch 0), train_loss = 4.570, time/batch=0.157
model saved to ./save/model.ckpt
1/67600 (epoch 0), train_loss = 4.346, time/batch=0.059
2/67600 (epoch 0), train_loss = 3.888, time/batch=0.070
3/67600 (epoch 0), train_loss = 3.604, time/batch=0.065
4/67600 (epoch 0), train_loss = 3.391, time/batch=0.072
5/67600 (epoch 0), train_loss = 3.303, time/batch=0.054
6/67600 (epoch 0), train_loss = 3.195, time/batch=0.055
7/67600 (epoch 0), train_loss = 3.135, time/batch=0.052
8/67600 (epoch 0), train_loss = 3.109, time/batch=0.061
9/67600 (epoch 0), train_loss = 3.043, time/batch=0.066
10/67600 (epoch 0), train_loss = 2.999, time/batch=0.052
11/67600 (epoch 0), train_loss = 3.015, time/batch=0.052
12/67600 (epoch 0), train_loss = 3.027, time/batch=0.063
13/67600 (epoch 0), train_loss = 3.009, time/batch=0.063
14/67600 (epoch 0), train_loss = 3.025, time/batch=0.056
15/67600 (epoch 0), train_loss = 2.966, time/batch=0.055
16/67600 (epoch 0), train_loss = 2.997, time/batch=0.076
17/67600

144/67600 (epoch 0), train_loss = 2.028, time/batch=0.100
145/67600 (epoch 0), train_loss = 2.071, time/batch=0.114
146/67600 (epoch 0), train_loss = 2.014, time/batch=0.109
147/67600 (epoch 0), train_loss = 2.026, time/batch=0.121
148/67600 (epoch 0), train_loss = 2.063, time/batch=0.094
149/67600 (epoch 0), train_loss = 1.960, time/batch=0.081
150/67600 (epoch 0), train_loss = 2.033, time/batch=0.063
151/67600 (epoch 0), train_loss = 2.002, time/batch=0.065
152/67600 (epoch 0), train_loss = 1.974, time/batch=0.067
153/67600 (epoch 0), train_loss = 2.044, time/batch=0.069
154/67600 (epoch 0), train_loss = 1.956, time/batch=0.067
155/67600 (epoch 0), train_loss = 2.043, time/batch=0.081
156/67600 (epoch 0), train_loss = 2.048, time/batch=0.103
157/67600 (epoch 0), train_loss = 2.024, time/batch=0.086
158/67600 (epoch 0), train_loss = 1.957, time/batch=0.092
159/67600 (epoch 0), train_loss = 2.016, time/batch=0.074
160/67600 (epoch 0), train_loss = 1.972, time/batch=0.064
161/67600 (epo

286/67600 (epoch 0), train_loss = 1.814, time/batch=0.083
287/67600 (epoch 0), train_loss = 1.841, time/batch=0.088
288/67600 (epoch 0), train_loss = 1.821, time/batch=0.101
289/67600 (epoch 0), train_loss = 1.888, time/batch=0.095
290/67600 (epoch 0), train_loss = 1.826, time/batch=0.064
291/67600 (epoch 0), train_loss = 1.782, time/batch=0.060
292/67600 (epoch 0), train_loss = 1.803, time/batch=0.082
293/67600 (epoch 0), train_loss = 1.806, time/batch=0.230
294/67600 (epoch 0), train_loss = 1.840, time/batch=0.087
295/67600 (epoch 0), train_loss = 1.832, time/batch=0.072
296/67600 (epoch 0), train_loss = 1.797, time/batch=0.072
297/67600 (epoch 0), train_loss = 1.780, time/batch=0.062
298/67600 (epoch 0), train_loss = 1.767, time/batch=0.061
299/67600 (epoch 0), train_loss = 1.854, time/batch=0.064
300/67600 (epoch 0), train_loss = 1.803, time/batch=0.071
301/67600 (epoch 0), train_loss = 1.828, time/batch=0.058
302/67600 (epoch 0), train_loss = 1.806, time/batch=0.061
303/67600 (epo

429/67600 (epoch 0), train_loss = 1.688, time/batch=0.075
430/67600 (epoch 0), train_loss = 1.766, time/batch=0.079
431/67600 (epoch 0), train_loss = 1.685, time/batch=0.107
432/67600 (epoch 0), train_loss = 1.705, time/batch=0.092
433/67600 (epoch 0), train_loss = 1.698, time/batch=0.060
434/67600 (epoch 0), train_loss = 1.683, time/batch=0.085
435/67600 (epoch 0), train_loss = 1.779, time/batch=0.067
436/67600 (epoch 0), train_loss = 1.727, time/batch=0.074
437/67600 (epoch 0), train_loss = 1.709, time/batch=0.065
438/67600 (epoch 0), train_loss = 1.736, time/batch=0.067
439/67600 (epoch 0), train_loss = 1.735, time/batch=0.069
440/67600 (epoch 0), train_loss = 1.686, time/batch=0.059
441/67600 (epoch 0), train_loss = 1.709, time/batch=0.065
442/67600 (epoch 0), train_loss = 1.677, time/batch=0.079
443/67600 (epoch 0), train_loss = 1.669, time/batch=0.064
444/67600 (epoch 0), train_loss = 1.673, time/batch=0.067
445/67600 (epoch 0), train_loss = 1.746, time/batch=0.062
446/67600 (epo

KeyboardInterrupt: 

In [None]:
saver.saver_def.filename_tensor_name

### 产生一些句子试一试

In [8]:
n = 500
prime = " "

tf.reset_default_graph()
model = Model(training=False)

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    saver = tf.train.Saver(tf.global_variables())
    ckpt = tf.train.get_checkpoint_state(save_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print(model.sample(sess, chars, vocab, n, prime).encode("utf-8"))


# with open(os.path.join(save_dir, "config.pkl"), "rb") as f:
#     saved_args = cPickle.load(f)
# with open(os.path.join(save_dir, "chars_vocab.pkl"), "rb") as f:
#     chars, vocab = cPickle.load(f)


INFO:tensorflow:Restoring parameters from ./save/model.ckpt-500
b' and interess."\n\n     "Louted-rooked to thone co we-with foo hervers the deadper, deroubn, for the tage.\n\n     "Boordinathere whot stanted an\n     nenlifer, frabobing and feong to ond\n     the\n     tell offred tushingow It rage wh\n     his musuppleds threr is gin come\'s the kevericelds, and a redorness, and a dees\n     rearse\n     pfors a saged and thrred to With on\n     the poring the wimed th he pater of mary he "What he glought upond upon cloopen ofred heriging up. To the strateen to the paces'
