In [60]:
import tensorflow as tf
import os
from six.moves import cPickle
import collections
import numpy as np
import codecs
import jieba

### Load the Corpus

In [61]:
FILE_PATH = './data/诛仙.txt'
# Whether or not use Chinese split words, if false, use single chars to feed
USE_SPLIT = True                  

#### Load the book as a string

In [62]:
corpus_raw = u""

with codecs.open(FILE_PATH, 'r', 'utf-8') as book_file:
    corpus_raw += book_file.read()

print("Corpus is {} characters long".format(len(corpus_raw)))

Corpus is 3126269 characters long


### Process Corpus
##### Create lookup tables

In [77]:
def create_lookup_tables(text, use_split=USE_SPLIT):
    """
    Create lookup tables for vocab
    :param text: The corpus text split into words
    :return: A tuple of dicts (vocab_to_int, int_to_vocab)
    """
    words = list(jieba.cut(text))
    vocab = set(words) if use_split else set(text)
    
    int_to_vocab = {key: word for key, word in enumerate(vocab)}
    vocab_to_int = {word: key for key, word in enumerate(vocab)}
    
    if use_split:
        text_index = [vocab_to_int[word] for word in words]
    else:
        text_index = [vocab_to_int[word] for word in text]
    
    return vocab_to_int, int_to_vocab, text_index

##### Process data

In [81]:
vocab_to_int, int_to_vocab, corpus_int = create_lookup_tables(corpus_raw)
print("Vocabulary size : {}, number of Chinese words in text : {}".format(len(corpus_int), len(vocab_to_int)))

Vocabulary size : 2050766, number of Chinese words in text : 38012


# Build the Network
### Batch the Data

In [82]:
def get_batches(int_text, batch_size, seq_length):
    """
    Return batches of input and target data
    :param int_text: text with words replaced by their ids
    :param batch_size: the size that each batch of data should be
    :param seq_length: the length of each sequence
    :return: batches of data as a numpy array
    """
    words_per_batch = batch_size * seq_length
    num_batches = len(int_text)//words_per_batch
    int_text = int_text[:num_batches*words_per_batch]
    y = np.array(int_text[1:] + [int_text[0]])
    x = np.array(int_text)
    
    x_batches = np.split(x.reshape(batch_size, -1), num_batches, axis=1)
    y_batches = np.split(y.reshape(batch_size, -1), num_batches, axis=1)
    
    batch_data = list(zip(x_batches, y_batches))
    
    return np.array(batch_data)

### Hyperparameters

In [83]:
num_epochs = 10000
batch_size = 512
rnn_size = 512
num_layers = 3
keep_prob = 0.7
embed_dim = 512
seq_length = 30
learning_rate = 0.001
save_dir = './save'

### Build the Graph

In [99]:
train_graph = tf.Graph()
with train_graph.as_default():    
    
    # Initialize input placeholders
    input_text = tf.placeholder(tf.int32, [None, None], name='input')
    targets = tf.placeholder(tf.int32, [None, None], name='targets')
    lr = tf.placeholder(tf.float32, name='learning_rate')
    
    # Calculate text attributes
    vocab_size = len(int_to_vocab)
    input_text_shape = tf.shape(input_text)
    
    # Build the RNN cell
    lstm = tf.contrib.rnn.BasicLSTMCell(num_units=rnn_size)
    drop_cell = tf.contrib.rnn.DropoutWrapper(lstm, output_keep_prob=keep_prob)
    cell = tf.contrib.rnn.MultiRNNCell([drop_cell] * num_layers)
    
    # Set the initial state
    initial_state = cell.zero_state(input_text_shape[0], tf.float32)
    initial_state = tf.identity(initial_state, name='initial_state')
    
    # Create word embedding as input to RNN
    embed = tf.contrib.layers.embed_sequence(input_text, vocab_size, embed_dim)
    
    # Build RNN
    outputs, final_state = tf.nn.dynamic_rnn(cell, embed, dtype=tf.float32)
    final_state = tf.identity(final_state, name='final_state')
    
    # Take RNN output and make logits
    logits = tf.contrib.layers.fully_connected(outputs, vocab_size, activation_fn=None)
    
    # Calculate the probability of generating each word
    probs = tf.nn.softmax(logits, name='probs')
    
    # Define loss function
    cost = tf.contrib.seq2seq.sequence_loss(
        logits,
        targets,
        tf.ones([input_text_shape[0], input_text_shape[1]])
    )
    
    # Learning rate optimizer
    optimizer = tf.train.AdamOptimizer(learning_rate)
    
    # Gradient clipping to avoid exploding gradients
    gradients = optimizer.compute_gradients(cost)
    capped_gradients = [(tf.clip_by_value(grad, -1., 1.), var) for grad, var in gradients if grad is not None]
    train_op = optimizer.apply_gradients(capped_gradients)
    

### Train the Network

In [102]:
import time

batches = get_batches(corpus_int, batch_size, seq_length)
num_batches = len(batches)
start_time = time.time()

print("Num Batches :", num_batches)

with tf.Session(graph=train_graph) as sess:
    sess.run(tf.global_variables_initializer())
    
    for epoch in range(num_epochs):
        state = sess.run(initial_state, {input_text: batches[0][0]})
        
        for batch_index, (x, y) in enumerate(batches):
            feed_dict = {
                input_text: x,
                targets: y,
                initial_state: state,
                lr: learning_rate
            }
            train_loss, state, _ = sess.run([cost, final_state, train_op], feed_dict)
            
            if batch_index % 5 == 0:
                time_elapsed = time.time() - start_time
                print('Epoch {:>3} Batch {:>4}/{}   train_loss = {:.3f}   time_elapsed = {:.3f}   time_remaining = {:.0f}'.format(
                    epoch + 1,
                    batch_index + 1,
                    len(batches),
                    train_loss,
                    time_elapsed,
                    ((num_batches * num_epochs)/((epoch + 1) * (batch_index + 1))) * time_elapsed - time_elapsed))
                
                # save model every 5 batches
                saver = tf.train.Saver()
                saver.save(sess, save_dir)
                print('Model Trained and Saved')
            

Num Batches : 133
Epoch   1 Batch    1/133   train_loss = 10.546
Model Trained and Saved


KeyboardInterrupt: 

# Generate Text
### Pick a Random Word

In [103]:
def pick_word(probabilities, int_to_vocab):
    """
    Pick the next word with some randomness
    :param probabilities: Probabilites of the next word
    :param int_to_vocab: Dictionary of word ids as the keys and words as the values
    :return: String of the predicted word
    """
    return np.random.choice(list(int_to_vocab.values()), 1, p=probabilities)[0]


### Load the Graph and Generate

In [104]:
gen_length = 1000
prime_words = '我'

loaded_graph = tf.Graph()
with tf.Session(graph=loaded_graph) as sess:
    # Load the saved model
    loader = tf.train.import_meta_graph(save_dir + '.meta')
    loader.restore(sess, save_dir)
    
    # Get tensors from loaded graph
    input_text = loaded_graph.get_tensor_by_name('input:0')
    initial_state = loaded_graph.get_tensor_by_name('initial_state:0')
    final_state = loaded_graph.get_tensor_by_name('final_state:0')
    probs = loaded_graph.get_tensor_by_name('probs:0')
    
    # Sentences generation setup
    gen_sentences = list(jieba.cut(prime_words)) if USE_SPLIT else prime_words.split()
    prev_state = sess.run(initial_state, {input_text: np.array([[1 for word in gen_sentences]])})
    
    # Generate sentences
    for n in range(gen_length):
        # Dynamic Input
        dyn_input = [[vocab_to_int[word] for word in gen_sentences[-seq_length:]]]
        dyn_seq_length = len(dyn_input[0])

        # Get Prediction
        probabilities, prev_state = sess.run(
            [probs, final_state],
            {input_text: dyn_input, initial_state: prev_state})
        
        # Get predict word
        word_probs = probabilities[0][dyn_seq_length-1]
        pred_word = pick_word(word_probs, int_to_vocab)

        gen_sentences.append(pred_word)
        
    # Remove tokens
    chapter_text = ''.join(gen_sentences)
        
    print(chapter_text)

INFO:tensorflow:Restoring parameters from ./save
我 海上 日常用品 温润 抛开 以手 云海 神众 爬起来 异常 振荡 之至 激动不已 英雄末路 木架 间手 连个 手下败将 清光 以目 向後看 与此同时 笑了起来 佳品 越斗 恍若 嘘声四起 仰观 带来 和睦 逐级 小灰似 敬若天神 清泉 越慢 退居 称谢 上红下 栽 端端正正 二物 黄泉 浓如墨 连林 十道 猜测 小灰挠 龙头 千次 修复 四指 不是故意 哐啷 飞奔而去 大洞 诸多 离闻 微一 光压 神宫 李洵 拗断 笑了笑 二百四十 任何人 放不下 白白净净 暗叹 互相残杀 檐角 山中 宾主 废话 似林 乎 山路 搭理 过时 塔 纵然 玩物 本事 粗鲁 空牌 所惊 第二排 侍立 镇族 决定 人海茫茫 千百年 公然 路滑 狂噬 闪出 不整 傻小子 貌美如花 光剑 地气 木板门 这比 冷着 放倒 莫测高深 低眉 熊掌 冷夜 头绪 小灰喝 后患 临别时 暗地 绷 灾噩 连累 一提 高过 手里 直非 竹林精舍 曾带 帮衬 弄 相别 咬紧 但事 怪力 举杯 回谷 人挥 甩进 有意无意 能以 煌煌 密传 一魂 撕裂 色彩 这么久 惊起 残光 巨啸 看书 夯量 火苗 出卖 平锋玉尺 能胜到 满场 如牛 香袋 所赐 满满当当 这具 何愁 逆流 鼓七鸣 告慰 直攻 预料中 独居 地贴 钟乳石 屡屡 爪子 礼道 鬼眸 一真 不得已而为之 御岩术 杀人狂 那如水 叱骂 合适 大笑不止 大事 满座 不信 另一端 浮云 蹦起 般的 大批 形若 弱女子 说不下去 暂避 地别 理 近来 悠然 留住 雪琪面 功业 神秘莫测 通入 欲用 厮斗声 远看 自治州 死亡 密语 勒紧 提心吊胆 电击 比较 全败 前面 冲前 得心应手 研究 既入 已亮 其实 心不死 不去 痛恨 沁入 紧绷绷 零乱 吼叫声 咒来 精至 共分 极其重要 铃中 恶战 大家 还令 互通 娇媚 倒行逆施 命是 低眉 通天彻地 我怪 若无其事 迸发出 自不必说 持开 拖来拖去 征站 抬头 事关重大 推著 自许 两次 敌得过 精神支柱 轰然 扶摇 不让 四分之三 帮助 生之力 面前 青土 真火 错路 学田 打赌 冠绝 旧识 心障 并不比 修到 重插 式微 少说 虚怀若谷 那涸 无耻之徒 吸食 雅观 不肯 主毒 或用 八人 欲言又止 人类 

# Save the text

In [108]:
import os

with open('generated_text.txt', "w") as text_file:
    text_file.write(chapter_text)