In [None]:
import tensorflow as tf
import numpy as np
import argparse
import os
import random
import time
import collections

# ---- optimisation ----
batchSize = 64                    # poems per training batch
learningRateBase = 0.001          # initial Adam learning rate
learningRateDecayStep = 1000      # global steps over which the rate shrinks by learningRateDecayRate
learningRateDecayRate = 0.95      # exponential decay factor for the learning rate

# ---- training loop / generation ----
epochNum = 10                     # number of passes over the training data
generateNum = 5                   # number of generated poems per time

# ---- data & checkpoints ----
# NOTE: `type` shadows the Python builtin; other cells read it by this name.
type = "poetrySong"               # dataset to use, shijing, songci, etc
trainPoems = "./dataset/%s/%s.txt" % (type, type)   # training file location
checkpointsPath = "./checkpoints/%s" % type          # checkpoints location
saveStep = 1000                   # save a checkpoint every saveStep global steps

# ---- evaluation ----
trainRatio = 0.8                  # fraction of poems kept for training when evaluating
evaluateCheckpointsPath = "./checkpoints/evaluate"

In [None]:
isEvaluate=False  # True: hold out (1 - trainRatio) of the poems as a test set

# Load poems. Each line of the dataset is expected to be "title::author::poem".
poems = []
# NOTE(review): assumes the dataset file is UTF-8 encoded; stating the encoding
# keeps the Chinese text readable where the platform default encoding differs.
with open(trainPoems, "r", encoding="utf-8") as poemFile:
    for line in poemFile:
        fields = line.strip().split("::")
        if len(fields) != 3:  # skip blank / malformed lines instead of crashing the cell
            continue
        title, author, poem = fields
        poem = poem.replace(' ', '')
        if len(poem) < 10 or len(poem) > 512:  # drop very short / very long poems
            continue
        # drop poems that contain annotation or bracket characters
        if '_' in poem or '《' in poem or '[' in poem or '(' in poem or '（' in poem:
            continue
        poems.append('[' + poem + ']')  # '[' and ']' mark the start and end of a poem

# Count character frequencies. Spaces were stripped from every poem above, so
# " " is added by hand with count -1: it sorts last, gets the largest ID, and
# is used as the padding token when batches are built.
wordFreq = collections.Counter()
for poem in poems:
    wordFreq.update(poem)
wordFreq[" "] = -1

# Rare characters are deliberately kept. An earlier attempt to erase characters
# seen fewer than 2 times left the vocabulary smaller than the original character
# set and was observed to produce NaN values in the loss.

# Most frequent character first -> lowest ID.
wordPairs = sorted(wordFreq.items(), key=lambda x: -x[1])
words, freq = zip(*wordPairs)
wordNum = len(words)
wordToID = dict(zip(words, range(wordNum)))  # character -> integer ID

poemsVector = [[wordToID[word] for word in poem] for poem in poems]  # poem -> ID list
if isEvaluate:  # evaluation needs a train/test split (in file order, not shuffled)
    splitIndex = int(len(poemsVector) * trainRatio)
    trainVector = poemsVector[:splitIndex]
    testVector = poemsVector[splitIndex:]
else:
    trainVector = poemsVector
    testVector = []

print("訓練樣本總數： %d" % len(trainVector))
print("測試樣本總數： %d" % len(testVector))

In [None]:
def generateBatch(isTrain=True, vectors=None, perBatch=None, padID=None):
        """Build one epoch of shuffled, padded (input, target) batches.

        Args:
            isTrain: choose `trainVector` (True) or `testVector` (False) when
                `vectors` is not given.
            vectors: optional list of poems, each a list of word IDs; overrides
                the isTrain selection.
            perBatch: poems per batch; defaults to the global `batchSize`.
            padID: ID used to pad shorter poems; defaults to `wordToID[" "]`.

        Returns:
            (X, Y): two equal-length lists of int32 arrays of shape
            (perBatch, longest poem in that batch). Y is X shifted left by one
            position, with the last column copied from X. Only full batches are
            produced; leftover poems that do not fill a batch are dropped.
        """
        if vectors is None:
            vectors = trainVector if isTrain else testVector
        if perBatch is None:
            perBatch = batchSize
        if padID is None:
            padID = wordToID[" "]

        # Shuffle a copy so calling this does not silently reorder the
        # module-level train/test lists (hidden notebook state).
        shuffled = list(vectors)
        random.shuffle(shuffled)

        # Number of *full* batches. The old `(len - 1) // batchSize` dropped an
        # entire batch whenever the dataset size was an exact multiple of it.
        batchNum = len(shuffled) // perBatch
        X, Y = [], []
        for i in range(batchNum):
            batch = shuffled[i * perBatch:(i + 1) * perBatch]
            maxLength = max(len(vector) for vector in batch)
            x = np.full((perBatch, maxLength), padID, np.int32)  # pad to batch max length
            for j, vector in enumerate(batch):
                x[j, :len(vector)] = vector
            # Target = input shifted left by one; copy first so x stays intact.
            y = np.copy(x)
            y[:, :-1] = x[:, 1:]
            X.append(x)
            Y.append(y)
        return X, Y

In [None]:
 def buildModel(wordNum, gtX, hidden_units = 128, layers = 2):
        """Build the embedding -> stacked LSTM -> softmax graph.

        Args:
            wordNum: vocabulary size; number of embedding rows and softmax outputs.
            gtX: int32 tensor of word IDs. Assumed to be (batch, time) with a
                known batch dimension, since np.shape(gtX)[0] is used as the
                batch size below (train() passes a [batchSize, None] placeholder).
            hidden_units: embedding width and LSTM state size.
            layers: number of stacked LSTM layers.

        Returns:
            (logits, probs, stackCell, initState, finalState). logits and probs
            have shape (-1, wordNum): the RNN outputs are flattened to
            (-1, hidden_units) before the softmax projection.
        """
        with tf.variable_scope("embedding"): #embedding
            embedding = tf.get_variable("embedding", [wordNum, hidden_units], dtype = tf.float32)
            inputbatch = tf.nn.embedding_lookup(embedding, gtX)

        # One distinct cell object per layer. Repeating a single object
        # ([cell] * layers) makes all layers share one weight set in early TF
        # releases and raises a variable-reuse ValueError in later TF 1.x.
        layerCells = [tf.contrib.rnn.BasicLSTMCell(hidden_units, state_is_tuple = True)
                      for _ in range(layers)]
        stackCell = tf.contrib.rnn.MultiRNNCell(layerCells)
        initState = stackCell.zero_state(np.shape(gtX)[0], tf.float32)
        outputs, finalState = tf.nn.dynamic_rnn(stackCell, inputbatch, initial_state = initState)
        outputs = tf.reshape(outputs, [-1, hidden_units])

        with tf.variable_scope("softmax"):
            w = tf.get_variable("w", [hidden_units, wordNum])
            b = tf.get_variable("b", [wordNum])
            logits = tf.matmul(outputs, w) + b

        probs = tf.nn.softmax(logits)
        return logits, probs, stackCell, initState, finalState

In [None]:
def train(reload=True):
        """Train the poem model for `epochNum` epochs, checkpointing as it goes.

        The graph is built in a private tf.Graph, so re-running this cell in the
        same kernel does not fail with "Variable ... already exists".

        Args:
            reload: if True and a checkpoint exists in `checkpointsPath`, restore
                it before training.

        Side effects:
            Creates `checkpointsPath` (including missing parents). Writes
            checkpoints named after the dataset `type` every `saveStep` global
            steps and once more when training finishes.
        """
        print("training...")
        graph = tf.Graph()
        with graph.as_default():
            gtX = tf.placeholder(tf.int32, shape=[batchSize, None])  # input word IDs
            gtY = tf.placeholder(tf.int32, shape=[batchSize, None])  # next-word targets

            logits, _, _, _, _ = buildModel(wordNum, gtX)
            targets = tf.reshape(gtY, [-1])

            # Per-token cross-entropy with uniform weights (padding positions included).
            loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
                [logits], [targets], [tf.ones_like(targets, dtype=tf.float32)])
            cost = tf.reduce_mean(loss)

            globalStep = tf.Variable(0, trainable=False)
            learningRate = tf.train.exponential_decay(learningRateBase, global_step=globalStep,
                                                      decay_steps=learningRateDecayStep,
                                                      decay_rate=learningRateDecayRate)

            trainableVariables = tf.trainable_variables()
            # Clip by global norm to prevent loss divergence caused by gradient explosion.
            grads, _ = tf.clip_by_global_norm(tf.gradients(cost, trainableVariables), 5)
            optimizer = tf.train.AdamOptimizer(learningRate)
            trainOP = optimizer.apply_gradients(zip(grads, trainableVariables))
            # Increment only after the update is applied. Otherwise the increment
            # and the learning-rate read of globalStep race inside one sess.run.
            with tf.control_dependencies([trainOP]):
                addGlobalStep = globalStep.assign_add(1)

            initOP = tf.global_variables_initializer()
            saver = tf.train.Saver()  # created after Adam's slot variables exist

        # makedirs also creates a missing "./checkpoints" parent (os.mkdir would raise).
        if not os.path.exists(checkpointsPath):
            os.makedirs(checkpointsPath)
        savePrefix = os.path.join(checkpointsPath, type)

        with tf.Session(graph=graph) as sess:
            sess.run(initOP)

            if reload:
                checkPoint = tf.train.get_checkpoint_state(checkpointsPath)
                # if have checkPoint, restore checkPoint
                if checkPoint and checkPoint.model_checkpoint_path:
                    saver.restore(sess, checkPoint.model_checkpoint_path)
                    print("restored %s" % checkPoint.model_checkpoint_path)
                else:
                    print("no checkpoint found!")

            gStep = None
            for epoch in range(epochNum):
                X, Y = generateBatch()
                epochSteps = len(X)  # number of batches this epoch
                for step, (x, y) in enumerate(zip(X, Y)):
                    _, lossValue, gStep = sess.run([trainOP, cost, addGlobalStep],
                                                   feed_dict={gtX: x, gtY: y})
                    print("epoch: %d, steps: %d/%d, loss: %3f" % (epoch + 1, step + 1, epochSteps, lossValue))
                    if gStep % saveStep == saveStep - 1:  # prevent save at the beginning
                        print("save model")
                        saver.save(sess, savePrefix, global_step=gStep)

            # Save the final weights too; otherwise up to saveStep - 1 steps since
            # the last periodic checkpoint would be lost.
            if gStep is not None and gStep % saveStep != saveStep - 1:
                print("save model")
                saver.save(sess, savePrefix, global_step=gStep)

In [None]:
train(reload=True)  # resume from the latest checkpoint in checkpointsPath if present