In [1]:
import tensorflow as tf
from tensorflow.contrib import seq2seq
import numpy as np
from matplotlib import pylab as plt
import os
from datasets_creator import Datasets_creator

%matplotlib inline

  from ._conv import register_converters as _register_converters


Instructions for updating:
Use the retry module or similar alternatives.


In [2]:
# Show the kernel's working directory: the relative paths used later in this
# notebook (the checkpoint path and the couplets text file) resolve against it.
# NOTE(review): `pwd` relies on IPython automagic; it is not plain Python.
pwd

'/Users/zhangshulin_work/Desktop/AI-Play/Git/couplets/tensorflow_imp_char'

In [3]:
# Hidden size handed to each BasicLSTMCell when sample() builds the network.
lstm_units = 256
# Number of LSTM cells stacked into the MultiRNNCell.
lstm_layers = 2
# Passed as the third argument to Datasets_creator below
# (presumably the maximum couplet length in characters — TODO confirm).
max_len = 30

In [4]:
# Checkpoint that sample() restores the trained weights from (relative to the working directory).
MODEL_PATH = './model_save/model.ckpt'

In [5]:
# Dataset helper built over the couplets text file.
# NOTE(review): the meaning of 20000 is defined inside datasets_creator
# (looks like a sample or vocabulary limit) — confirm there.
creator = Datasets_creator('../datasets/all_couplets.txt', 20000, max_len)
# Lookup tables between characters and integer indices; used to encode the
# prime text and to decode sampled indices back into characters.
char2index, index2char = creator.get_chars_dict()
# Vocabulary size: the one-hot depth and the width of the output layer in build_forward().
vocabs_size = len(char2index)

In [6]:
def build_inputs():
    """Create the placeholders that feed the graph.

    Returns:
        tuple: ``(input, label, keep_prob)`` placeholders. The first two are
        int32 with shape ``(None, None)``; ``keep_prob`` is float32 with no
        declared shape.
    """
    int_matrix_spec = dict(shape=(None, None), dtype=tf.int32)

    with tf.name_scope('input_placeholders'):
        tokens = tf.placeholder(name='input', **int_matrix_spec)
        targets = tf.placeholder(name='label', **int_matrix_spec)
        keep_prob = tf.placeholder(tf.float32, name='keep_prob')

    return tokens, targets, keep_prob

In [7]:
def build_lstm_cell(num_units, num_layers, keep_prob, batch_size):
    """Stack dropout-wrapped LSTM cells and create their all-zero start state.

    Args:
        num_units: hidden size given to every ``BasicLSTMCell``.
        num_layers: how many cells are stacked in the ``MultiRNNCell``.
        keep_prob: output keep probability given to each ``DropoutWrapper``.
        batch_size: batch size used to build the zero state.

    Returns:
        tuple: ``(cell, init_zero_state)`` — the stacked cell and its float32
        zero state for ``batch_size`` sequences.
    """
    with tf.name_scope('lstm_cell'):
        layers = []
        for _ in range(num_layers):
            base_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
            # Dropout is applied to each layer's output only.
            layers.append(
                tf.nn.rnn_cell.DropoutWrapper(base_cell, output_keep_prob=keep_prob)
            )

        stacked_cell = tf.nn.rnn_cell.MultiRNNCell(layers)
        zero_state = stacked_cell.zero_state(batch_size, tf.float32)

    return stacked_cell, zero_state

In [8]:
def build_lstm_layer(cell, embed_input, init_state):
    """Unroll ``cell`` over ``embed_input`` with ``tf.nn.dynamic_rnn``.

    Args:
        cell: RNN cell to unroll.
        embed_input: input tensor handed to ``dynamic_rnn``.
        init_state: state the recurrence starts from.

    Returns:
        tuple: ``(outputs, final_state)`` exactly as produced by ``dynamic_rnn``.
    """
    with tf.name_scope('lstm_layer'):
        rnn_outputs, last_state = tf.nn.dynamic_rnn(
            cell,
            embed_input,
            initial_state=init_state,
        )
        return rnn_outputs, last_state

In [9]:
def build_forward(cell, input, init_state):
    """Build the forward pass: one-hot encoding -> LSTM -> dense -> softmax.

    Args:
        cell: RNN cell to run over the encoded input.
        input: integer index tensor; one-hot encoded with depth ``vocabs_size``.
        init_state: initial state for the recurrence.

    Returns:
        tuple: ``(outputs, logits, final_state)`` where ``outputs`` is the
        softmax of ``logits`` and ``logits`` comes from the ``'fc_layer'``
        dense layer with ``vocabs_size`` units.
    """
    encoded = tf.one_hot(input, depth=vocabs_size, axis=-1)

    rnn_outputs, last_state = build_lstm_layer(cell, encoded, init_state)

    logits = tf.layers.dense(rnn_outputs, units=vocabs_size, name='fc_layer')
    probabilities = tf.nn.softmax(logits)

    return probabilities, logits, last_state

In [10]:
def pick_char_from_top_n(preds, vocab_size, top_n=5):
    """Sample one index from the ``top_n`` most probable entries of ``preds``.

    Indices 0 and 1 are never chosen: their probability is zeroed before the
    top-n cut (presumably reserved/special tokens — confirm against the
    dataset's character dictionary).

    Args:
        preds: array of probabilities that squeezes down to a 1-D vector of
            length ``vocab_size``. It is NOT modified.
        vocab_size: number of candidate indices passed to ``np.random.choice``.
        top_n: how many of the highest-probability entries stay eligible.

    Returns:
        The sampled index (a NumPy integer).

    Raises:
        ValueError: if no probability mass is left after filtering.
    """
    # np.squeeze returns a view, so work on a private float64 copy; otherwise
    # the zeroing below would silently rewrite the caller's prediction array.
    p = np.array(np.squeeze(preds), dtype=np.float64)

    p[:2] = 0
    # Keep only the top_n largest entries; everything else becomes impossible.
    p[np.argsort(p)[:-top_n]] = 0

    total = p.sum()
    if not total > 0:
        # Fail with a clear message instead of dividing by zero into NaNs.
        raise ValueError('no probability mass left to sample from after filtering')

    p /= total
    return np.random.choice(vocab_size, 1, p=p)[0]

In [11]:
def sample(prime, top_n=5, max_chars=200):
    """Generate text with the trained model, one character at a time.

    The inference graph is rebuilt from scratch and the weights are restored
    from ``MODEL_PATH`` on every call. The model is first fed the characters of
    ``prime`` (or index 0 when ``prime`` is empty — presumably a start token,
    confirm against the dataset dictionary). Characters are then sampled with
    ``pick_char_from_top_n`` and fed back in until the terminator ``'。'`` is
    produced.

    Args:
        prime: text used to warm up the LSTM state; may be empty. Characters
            missing from ``char2index`` are fed as index 1.
        top_n: number of most probable characters to sample from at each step.
        max_chars: upper bound on the number of generated characters, so a
            model that never emits ``'。'`` cannot loop forever.

    Returns:
        str: ``prime`` followed by the generated characters.
    """
    samples = list(prime)

    # Start from an empty default graph so repeated calls do not pile up ops.
    tf.reset_default_graph()

    input_pl, _, keep_prob_pl = build_inputs()
    # Batch size is fixed to 1 for sampling. The keep-probability placeholder
    # is wired into the cell so the value fed below is the one actually used.
    cell_op, init_state_op = build_lstm_cell(lstm_units, lstm_layers, keep_prob_pl, 1)
    outputs_op, _, final_state_op = build_forward(cell_op, input_pl, init_state_op)

    saver = tf.train.Saver()

    # Indices used to warm up the state: the prime text, or index 0 if empty.
    warmup_indices = [char2index.get(char, 1) for char in prime] or [0]

    # The context manager releases the session even if restore/run raises.
    with tf.Session() as sess:
        # Restoring assigns every variable, so no separate initializer is run.
        saver.restore(sess, MODEL_PATH)
        state = sess.run(init_state_op)

        def step(char_index, state):
            """Feed one character index; return (probabilities, next state)."""
            x = np.full((1, 1), char_index, dtype=np.int32)
            feed_dict = {
                input_pl: x,
                keep_prob_pl: 1.0,  # no dropout at inference time
                init_state_op: state,
            }
            return sess.run([outputs_op, final_state_op], feed_dict=feed_dict)

        for char_index in warmup_indices:
            outputs, state = step(char_index, state)

        for _ in range(max_chars):
            char_index = pick_char_from_top_n(outputs, vocabs_size, top_n)
            char = index2char[char_index]
            samples.append(char)

            # Checked for every sampled character, including the first one
            # that follows a non-empty prime.
            if char == '。':
                break

            outputs, state = step(char_index, state)

    return ''.join(samples)

In [12]:
# Generate a single couplet with no prime text, sampling among the 3 most
# probable characters at every step.
sample(prime='', top_n=3)

INFO:tensorflow:Restoring parameters from ./model_save/model.ckpt


'树上云天山水；风中天上水流。'

In [15]:
# Only log errors, which hides the INFO "Restoring parameters" line otherwise
# printed by each sample() call.
tf.logging.set_verbosity(tf.logging.ERROR)
# Draw ten couplets from an empty prime with a wider candidate pool (top 10).
# Sampling uses NumPy's global random state, so results differ between runs.
for _ in range(10):
    print(sample(prime='', top_n=10))

匼句成诗书墨句；挥毫一片画云香。
金龙舞虎腾绵路；紫燕腾萦唱大歌。
玉兔金龙腾盛世；春风玉兔庆新门。
婧头一点花前句；小子双声水涌心。
树色无心，清风拂海；春风拂翠，花地清馨。
晾国；中国。
玉兔踏春，春风万里；黄金踏雪，春色春秋。
匼口；飞空。
匼尽心头心上老；洞为酒下水中秋。
玉兔迎春，龙腾瑞子迎辉瑞；黄莺朗瑞，春色春华玉气歌。
