In [1]:
import tensorflow as tf
import numpy as np
import argparse
import os
import random
import time
import collections

# Notebook configuration: dataset choice and derived file locations.
generateNum = 5  # how many poems test() samples per call
# Dataset directory/file stem, e.g. shijing, songci, poetrySong.
# NOTE: `type` shadows the Python builtin inside this notebook; the name is kept
# because the path settings below (and possibly other code) read it.
type = "poetrySong"
trainPoems = "./dataset/%s/%s.txt" % (type, type)  # training corpus file
checkpointsPath = "./checkpoints/%s" % type  # directory holding model checkpoints


In [2]:
# Load the poem corpus, build the character vocabulary and encode every poem
# as a list of word IDs.
isEvaluate = False  # True: hold out part of the corpus as a test set
trainRatio = 0.8    # share of poems used for training when isEvaluate is True

# '[' is reserved as the start sign below; the other characters presumably mark
# annotations or missing text in the corpus, so poems containing them are dropped.
skipChars = ('_', '《', '[', '(', '（')

poems = []
# Each line is expected to be "title::author::poem".
# NOTE(review): the corpus is assumed to be UTF-8 encoded — confirm.
with open(trainPoems, "r", encoding="utf-8") as poemFile:
    for line in poemFile:  # every line is one poem
        title, author, poem = line.strip().split("::")
        poem = poem.replace(' ', '')
        if len(poem) < 10 or len(poem) > 512:  # drop very short / very long poems
            continue
        if any(ch in poem for ch in skipChars):
            continue
        poems.append('[' + poem + ']')  # wrap with start and end signs

# Count how often each character occurs across all poems.
wordFreq = collections.Counter()
for poem in poems:
    wordFreq.update(poem)

# Rare-word pruning is intentionally disabled. Per the original author's note,
# dropping words with frequency < 2 led to NaN values in the loss. It would also
# make wordToID[word] below fail for the pruned characters.

# ' ' gets frequency -1 so it sorts last. Presumably it is a padding token;
# test() also treats it as an end sign.
wordFreq[" "] = -1
# Stable sort by descending frequency. The resulting order defines the word IDs,
# so it must match the order used when the checkpoint was trained.
wordPairs = sorted(wordFreq.items(), key=lambda x: -x[1])
words, freq = zip(*wordPairs)
wordNum = len(words)

wordToID = dict(zip(words, range(wordNum)))  # word -> ID
poemsVector = [[wordToID[word] for word in poem] for poem in poems]  # poem -> ID list

if isEvaluate:  # evaluation needs separate train and test sets
    splitAt = int(len(poemsVector) * trainRatio)
    trainVector = poemsVector[:splitAt]
    testVector = poemsVector[splitAt:]
else:
    trainVector = poemsVector
    testVector = []
print("訓練樣本總數： %d" % len(trainVector))
print("測試樣本總數： %d" % len(testVector))

訓練樣本總數： 252478
測試樣本總數： 0


In [3]:
 def buildModel(wordNum, gtX, hidden_units = 128, layers = 2):
        """build rnn"""
        with tf.variable_scope("embedding"): #embedding
            embedding = tf.get_variable("embedding", [wordNum, hidden_units], dtype = tf.float32)
            inputbatch = tf.nn.embedding_lookup(embedding, gtX)

        basicCell = tf.contrib.rnn.BasicLSTMCell(hidden_units, state_is_tuple = True)
        stackCell = tf.contrib.rnn.MultiRNNCell([basicCell] * layers)
        initState = stackCell.zero_state(np.shape(gtX)[0], tf.float32)
        outputs, finalState = tf.nn.dynamic_rnn(stackCell, inputbatch, initial_state = initState)
        outputs = tf.reshape(outputs, [-1, hidden_units])

        with tf.variable_scope("softmax"):
            w = tf.get_variable("w", [hidden_units, wordNum])
            b = tf.get_variable("b", [wordNum])
            logits = tf.matmul(outputs, w) + b

        probs = tf.nn.softmax(logits)
        return logits, probs, stackCell, initState, finalState

In [4]:
def probsToWord(weights, words):
    """Draw one word at random, with probability proportional to `weights`.

    `weights` is any array-like of (presumably non-negative) scores, for
    example a (1, n) softmax output. It is flattened and need not sum to 1.
    Exactly one value is drawn from NumPy's global random generator per call.

    Returns the element of `words` at the sampled position.
    """
    cumulative = np.cumsum(weights)        # flattened running total
    total = cumulative[-1]
    threshold = np.random.rand(1) * total  # uniform point in [0, total)
    # First position whose running total reaches the threshold.
    position = np.searchsorted(cumulative, threshold)
    return words[position[0]]

In [5]:
def test(maxLength=512):
    """Generate, print and return `generateNum` poems from the latest checkpoint.

    Builds the inference graph with batch size 1 and restores the newest
    checkpoint in `checkpointsPath`. Each poem is then sampled one character
    at a time. Sampling starts from the '[' start sign and stops at the ']'
    end sign or the ' ' token.

    Args:
        maxLength: safety cap on the number of characters sampled per poem,
            so a model that never emits an end sign cannot loop forever.
            The default matches the 512-character upper bound used when
            filtering the training poems.

    Returns:
        list of str: the generated poems. A newline is inserted after every
        second '。', '？', '！' or '，'.

    Raises:
        FileNotFoundError: if no checkpoint is found in `checkpointsPath`.
    """
    print("generating...")
    sentenceMarks = ('。', '？', '！', '，')
    endMarks = (' ', ']')
    # A fresh graph makes this function safe to call repeatedly (e.g. when the
    # cell is re-run). Building the model twice in the default graph fails
    # because the "embedding"/"softmax" variables already exist.
    with tf.Graph().as_default():
        gtX = tf.placeholder(tf.int32, shape=[1, None])  # one sequence at a time
        _, probs, stackCell, initState, finalState = buildModel(wordNum, gtX)
        zeroStateOp = stackCell.zero_state(1, tf.float32)  # built once, not per poem
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            checkPoint = tf.train.get_checkpoint_state(checkpointsPath)
            if not (checkPoint and checkPoint.model_checkpoint_path):
                print("no checkpoint found!")
                # Raise instead of exit(): exit() would shut down the notebook
                # kernel, while an exception just fails this cell.
                raise FileNotFoundError("no checkpoint found in %s" % checkpointsPath)
            saver.restore(sess, checkPoint.model_checkpoint_path)
            print("restored %s" % checkPoint.model_checkpoint_path)

            zeroState = sess.run(zeroStateOp)
            generated = []
            for _ in range(generateNum):
                state = zeroState
                x = np.array([[wordToID['[']]])  # start from the start sign
                poem = ''
                sentenceNum = 0
                for _ in range(maxLength):
                    # Feed the previous character and the carried-over LSTM state.
                    nextProbs, state = sess.run([probs, finalState],
                                                feed_dict={gtX: x, initState: state})
                    word = probsToWord(nextProbs, words)
                    if word in endMarks:
                        break
                    poem += word
                    if word in sentenceMarks:
                        sentenceNum += 1
                        if sentenceNum % 2 == 0:  # two clauses per printed line
                            poem += '\n'
                    x = np.array([[wordToID[word]]])
                print(poem)
                generated.append(poem)
            return generated

In [6]:
# Sample `generateNum` poems from the latest checkpoint; the returned list is
# displayed as this cell's output.
test()

genrating...
INFO:tensorflow:Restoring parameters from ./checkpoints/poetrySong/poetrySong-158999
restored ./checkpoints/poetrySong/poetrySong-158999
戚行岁晚湘林叶，那幸兹风继道吾。
更看台台寻相看，从今幽绝契寒青。

鹓茸如洗素玉心，满体未从人苦余。
盥底汨岌初下日，插虹银面萦罗纹。
昨朝翠秀千里重，今秋实是十年春。
子孙柏草十二日，南望老翁竞老翁。
我今结父良不自，却作田家架上第。
见人戏尽未知死，满饭尚令马颠倒。

书样金钱不朅梅，嶙张孟乐得亲嘉。
三梦清笼赖今日，不炊挈酒可遗味。

老掖惟春算，桑绩失偶陪。
静从归上担，新栏佐来难。
直笑千钱陋，恩休日愈经。
三年历辞子，一刻着僧书。

竹外云堆城彩低，观鱼父老著翁翁。
交成想是春应好，笑趁双鱼石鬭嗤。



['戚行岁晚湘林叶，那幸兹风继道吾。\n更看台台寻相看，从今幽绝契寒青。\n',
 '鹓茸如洗素玉心，满体未从人苦余。\n盥底汨岌初下日，插虹银面萦罗纹。\n昨朝翠秀千里重，今秋实是十年春。\n子孙柏草十二日，南望老翁竞老翁。\n我今结父良不自，却作田家架上第。\n见人戏尽未知死，满饭尚令马颠倒。\n',
 '书样金钱不朅梅，嶙张孟乐得亲嘉。\n三梦清笼赖今日，不炊挈酒可遗味。\n',
 '老掖惟春算，桑绩失偶陪。\n静从归上担，新栏佐来难。\n直笑千钱陋，恩休日愈经。\n三年历辞子，一刻着僧书。\n',
 '竹外云堆城彩低，观鱼父老著翁翁。\n交成想是春应好，笑趁双鱼石鬭嗤。\n']