In [1]:
import random
import os

import keras
import numpy as np
from keras.callbacks import LambdaCallback
from keras.models import Input, Model, load_model
from keras.layers import LSTM, Dropout, Dense
from keras.optimizers import Adam

from data_utils import *


Using TensorFlow backend.


In [2]:
class PoetryModel(object):
    """Character-level LSTM that generates five-character quatrains.

    The corpus is read through ``preprocess_file(config)`` (from ``data_utils``,
    not visible here), which returns:

    * ``word2numF`` - callable mapping a character to its vocabulary index,
    * ``num2word``  - lookup from vocabulary index back to a character,
    * ``words``     - the vocabulary (only its length is used in this class),
    * ``files_content`` - the whole corpus as one string in which ``']'``
      separates poems.

    ``config`` must expose ``max_len``, ``weight_file``, ``learning_rate`` and
    ``batch_size``.
    """

    # The generated poems have 4 lines of 5 characters plus one punctuation
    # mark each, i.e. 6 characters per line and 24 characters in total.
    LINE_LENGTH = 6
    POEM_LENGTH = 24
    # File that collects the poems sampled while training.
    SAMPLE_LOG = 'out/out.txt'

    def __init__(self, config):
        """Load the corpus, then load saved weights or start training.

        Args:
            config: settings object (see the class docstring).
        """
        self.model = None
        self.do_train = True
        self.loaded_model = True
        self.config = config

        # Pre-process the corpus file.
        self.word2numF, self.num2word, self.words, self.files_content = preprocess_file(self.config)

        # List of poems; the split may leave an empty trailing entry, which is
        # why seed poems are always picked through _pick_seed_poem().
        self.poems = self.files_content.split(']')
        # Total number of poems (also used to size the training run).
        self.poems_num = len(self.poems)

        # Load the saved model when the weight file exists, otherwise train.
        if os.path.exists(self.config.weight_file) and self.loaded_model:
            self.model = load_model(self.config.weight_file)
        else:
            self.train()

    def build_model(self):
        """Build and compile the two-layer LSTM network into ``self.model``."""
        print('building model')

        # One-hot input: (max_len time steps, vocabulary size).
        input_tensor = Input(shape=(self.config.max_len, len(self.words)))
        lstm = LSTM(512, return_sequences=True)(input_tensor)
        dropout = Dropout(0.6)(lstm)
        lstm = LSTM(256)(dropout)
        dropout = Dropout(0.6)(lstm)
        dense = Dense(len(self.words), activation='softmax')(dropout)
        self.model = Model(inputs=input_tensor, outputs=dense)
        optimizer = Adam(lr=self.config.learning_rate)
        self.model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])

    def sample(self, preds, temperature=1.0):
        """Draw one vocabulary index from ``preds`` after temperature scaling.

        The probabilities are raised to ``1 / temperature`` and renormalised:

        * ``temperature == 1`` samples from the model output unchanged,
        * ``temperature < 1`` sharpens the distribution (likely characters
          become even more likely, so the output is more conservative),
        * ``temperature > 1`` flattens it (more diverse output).

        Args:
            preds: 1-D sequence of probabilities, one per vocabulary entry.
            temperature: scaling factor; a non-positive value falls back to
                greedy selection of the most probable index.

        Returns:
            The sampled index as a plain ``int``.
        """
        preds = np.asarray(preds).astype('float64')
        if temperature <= 0:
            return int(np.argmax(preds))

        scaled = np.power(preds, 1. / temperature)
        total = np.sum(scaled)
        # Very small temperatures can underflow every entry to zero; sampling
        # from that would fail, so pick the most probable index instead.
        if not np.isfinite(total) or total <= 0:
            return int(np.argmax(preds))

        probs = scaled / total
        return int(np.random.choice(len(probs), p=probs))

    def generate_sample_result(self, epoch, logs=None):
        """Keras ``on_epoch_end`` hook: sample poems on every 4th epoch.

        The samples are printed and appended to ``SAMPLE_LOG``.

        Args:
            epoch: epoch index passed by Keras.
            logs: metrics dict passed by Keras (unused).
        """
        if epoch % 4 != 0:
            return

        # Make sure the output directory exists before appending to the log.
        log_dir = os.path.dirname(self.SAMPLE_LOG)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        print("\n==================Epoch {}=====================".format(epoch))
        with open(self.SAMPLE_LOG, 'a', encoding='utf-8') as f:
            f.write('==================Epoch {}=====================\n'.format(epoch))
            for diversity in [0.7, 1.0, 1.3]:
                print("------------Diversity {}--------------".format(diversity))
                generate = self.predict_random(temperature=diversity)
                print(generate)
                # predict_random returns None when it cannot generate.
                if generate is not None:
                    f.write(generate + '\n')

    def _pick_seed_poem(self, min_len):
        """Return a random corpus poem with at least ``min_len`` characters.

        Returns ``None`` when no poem is long enough.
        """
        candidates = [poem for poem in self.poems if len(poem) >= min_len]
        if not candidates:
            return None
        return random.choice(candidates)

    def predict_random(self, temperature=1):
        """Generate a poem seeded with the opening of a random corpus poem.

        Returns:
            The generated poem, or ``None`` when the model is missing or no
            corpus poem has ``max_len`` characters.
        """
        if not self.model:
            print('model not loaded')
            return

        poem = self._pick_seed_poem(self.config.max_len)
        if poem is None:
            print('no poem is long enough to seed generation')
            return
        sentence = poem[: self.config.max_len]
        return self.predict_sen(sentence, temperature=temperature)

    def predict_first(self, char, temperature=1):
        """Generate a poem that starts with ``char``.

        The model needs ``max_len`` characters of context, so the tail of a
        random corpus poem is placed in front of ``char`` as the initial input.

        Returns:
            ``char`` followed by the predicted characters, or ``None`` when
            the model is missing or no suitable seed poem exists.
        """
        if not self.model:
            print('model not loaded')
            return

        max_len = self.config.max_len
        poem = self._pick_seed_poem(max_len - 1)
        if poem is None:
            print('no poem is long enough to seed generation')
            return

        # Last (max_len - 1) characters of the seed poem + the given character.
        sentence = (poem + char)[-max_len:]
        generate = str(char)
        # Predict the remaining characters of the poem directly.
        generate += self._preds(sentence, length=self.POEM_LENGTH - 1, temperature=temperature)
        return generate

    def predict_sen(self, text, temperature=1):
        """Generate a poem from its first ``max_len`` characters.

        In this notebook ``text`` is the first line of the poem including its
        trailing comma. Only the last ``max_len`` characters of ``text`` are
        used.

        Returns:
            The seed followed by the predicted characters, or ``None`` when
            the model is missing or ``text`` is too short.
        """
        if not self.model:
            print('model not loaded')
            return
        max_len = self.config.max_len
        if len(text) < max_len:
            print('length should not be less than ', max_len)
            return

        sentence = text[-max_len:]
        print('the first line:', sentence)
        generate = str(sentence)
        generate += self._preds(sentence, length=self.POEM_LENGTH - max_len, temperature=temperature)
        return generate

    def predict_hide(self, text, temperature=1):
        """Generate an acrostic poem whose four lines start with ``text``.

        Args:
            text: exactly four characters, one per line of the poem.

        Returns:
            The generated poem, or ``None`` when the model is missing, the
            input is not four characters long, or no suitable seed poem exists.
        """
        if not self.model:
            print('model not loaded')
            return
        if len(text) != 4:
            print('藏头诗的输入必须是4个字！')
            return

        max_len = self.config.max_len
        poem = self._pick_seed_poem(max_len - 1)
        if poem is None:
            print('no poem is long enough to seed generation')
            return

        # Last (max_len - 1) characters of the seed poem + the first head
        # character form the initial model input.
        sentence = (poem + text[0])[-max_len:]
        print('first line = ', sentence)

        generate = ''
        for position, head in enumerate(text):
            if position:
                # Slide the context window over the next head character.
                sentence = sentence[1:] + head
            generate += head
            # Let the model fill in the rest of this line.
            for _ in range(self.LINE_LENGTH - 1):
                next_char = self._pred(sentence, temperature)
                sentence = sentence[1:] + next_char
                generate += next_char

        return generate

    def _preds(self, sentence, length=23, temperature=1):
        """Predict ``length`` characters that follow ``sentence``.

        Internal helper.

        Args:
            sentence: context string; only its first ``max_len`` characters
                are used.
            length: number of characters to predict.
            temperature: sampling temperature, see :meth:`sample`.

        Returns:
            The predicted characters as one string (the context is not
            included).
        """
        sentence = sentence[:self.config.max_len]
        generate = ''
        for _ in range(length):
            pred = self._pred(sentence, temperature)
            generate += pred
            # Slide the context window forward by one character.
            sentence = sentence[1:] + pred
        return generate

    def _pred(self, sentence, temperature=1):
        """Predict the single character that follows ``sentence``.

        Internal helper. Only the last ``max_len`` characters are used.

        Raises:
            ValueError: if ``sentence`` has fewer than ``max_len`` characters.
        """
        max_len = self.config.max_len
        if len(sentence) < max_len:
            raise ValueError(
                '_pred needs at least {} characters of context, got {}'.format(max_len, len(sentence))
            )

        sentence = sentence[-max_len:]
        # One-hot encode the context: (1, max_len, vocabulary size).
        x_pred = np.zeros((1, max_len, len(self.words)))
        for t, char in enumerate(sentence):
            x_pred[0, t, self.word2numF(char)] = 1.
        preds = self.model.predict(x_pred, verbose=0)[0]
        next_index = self.sample(preds, temperature=temperature)
        return self.num2word[next_index]

    def data_generator(self):
        """Yield one-hot training pairs forever.

        Every item is a single-sample batch ``(x_vec, y_vec)``: ``x_vec`` has
        shape ``(1, max_len, vocab)`` and encodes ``max_len`` consecutive
        corpus characters, ``y_vec`` has shape ``(1, vocab)`` and encodes the
        character that follows them. Windows that touch the poem separator
        ``']'`` are skipped. After the last window the generator starts over
        from the beginning of the corpus.

        Raises:
            ValueError: if the corpus contains no usable window.
        """
        max_len = self.config.max_len
        content = self.files_content
        vocab_size = len(self.words)

        # Last start index for which both the window and its target exist.
        last_start = len(content) - max_len - 1
        if last_start < 0:
            raise ValueError('corpus is shorter than max_len + 1 characters')

        i = 0
        yielded_in_pass = False
        while True:
            if i > last_start:
                # A full pass without a single sample would loop forever.
                if not yielded_in_pass:
                    raise ValueError('corpus contains no window free of the poem separator')
                i = 0
                yielded_in_pass = False

            x = content[i: i + max_len]
            y = content[i + max_len]
            i += 1

            if ']' in x or y == ']':
                continue

            y_vec = np.zeros(shape=(1, vocab_size), dtype=bool)
            y_vec[0, self.word2numF(y)] = True

            x_vec = np.zeros(shape=(1, max_len, vocab_size), dtype=bool)
            for t, char in enumerate(x):
                x_vec[0, t, self.word2numF(char)] = True

            yield x_vec, y_vec
            yielded_in_pass = True

    def train(self):
        """Train the model, building it first when necessary.

        The generator yields one sample per step and ``steps_per_epoch`` is
        ``config.batch_size``, so one "epoch" covers ``batch_size`` samples.
        The epoch count is derived from the corpus size accordingly. The model
        is checkpointed to ``config.weight_file`` and sample poems are
        produced by :meth:`generate_sample_result`.
        """
        print('training')
        number_of_epoch = len(self.files_content) - (self.config.max_len + 1) * self.poems_num
        number_of_epoch /= self.config.batch_size
        number_of_epoch = int(number_of_epoch / 1.5)
        print('epoches = ', number_of_epoch)
        print('poems_num = ', self.poems_num)
        print('len(self.files_content) = ', len(self.files_content))

        if not self.model:
            self.build_model()

        self.model.fit_generator(
            generator=self.data_generator(),
            verbose=True,
            steps_per_epoch=self.config.batch_size,
            epochs=number_of_epoch,
            callbacks=[
                keras.callbacks.ModelCheckpoint(self.config.weight_file, save_weights_only=False),
                LambdaCallback(on_epoch_end=self.generate_sample_result)
            ]
        )



In [3]:
from config import Config
# The Config class itself (not an instance) is passed as the settings object.
# Constructing PoetryModel loads config.weight_file when that file exists;
# otherwise the constructor starts training immediately.
model = PoetryModel(Config)

print('model loaded')

training
epoches =  34858
poems_num =  24027
len(self.files_content) =  1841397
building model


Epoch 1/34858


 1/32 [..............................] - ETA: 1:54 - loss: 8.6236 - acc: 0.0000e+00

 2/32 [>.............................] - ETA: 1:01 - loss: 8.6221 - acc: 0.0000e+00

 3/32 [=>............................] - ETA: 43s - loss: 8.6244 - acc: 0.0000e+00 

 4/32 [==>...........................] - ETA: 34s - loss: 8.6234 - acc: 0.0000e+00

 5/32 [===>..........................] - ETA: 28s - loss: 8.6231 - acc: 0.0000e+00

 6/32 [====>.........................] - ETA: 24s - loss: 8.6234 - acc: 0.0000e+00

 7/32 [=====>........................] - ETA: 22s - loss: 8.6228 - acc: 0.0000e+00




















































------------Diversity 0.7--------------
the first line: 孟月摄提贞，


孟月摄提贞，挥愍告汤葭届齑纛睨鷕匡呶醐潭秤袒裔穗
------------Diversity 1.0--------------
the first line: 萧萧度阊阖，


萧萧度阊阖，啻惚满暾楷抨，堕凹梗瞩纂褷匄螺狂枸抱
------------Diversity 1.3--------------
the first line: 南过三湘去，


南过三湘去，春挤蠖逢漓溃信蒐咤措侪页劳暌踌策颢郸
Epoch 2/34858


 1/32 [..............................] - ETA: 16s - loss: 8.5532 - acc: 0.0000e+00

 2/32 [>.............................] - ETA: 14s - loss: 8.5840 - acc: 0.0000e+00

 3/32 [=>............................] - ETA: 13s - loss: 8.5968 - acc: 0.0000e+00

 4/32 [==>...........................] - ETA: 12s - loss: 8.5690 - acc: 0.2500    

 5/32 [===>..........................] - ETA: 11s - loss: 8.5796 - acc: 0.2000

 6/32 [====>.........................] - ETA: 11s - loss: 8.5878 - acc: 0.1667

 7/32 [=====>........................] - ETA: 10s - loss: 8.5955 - acc: 0.1429



















































Epoch 3/34858


 1/32 [..............................] - ETA: 13s - loss: 9.2363 - acc: 0.0000e+00

 2/32 [>.............................] - ETA: 12s - loss: 5.7837 - acc: 0.0000e+00

 3/32 [=>............................] - ETA: 11s - loss: 6.7843 - acc: 0.0000e+00

 4/32 [==>...........................] - ETA: 11s - loss: 6.5340 - acc: 0.0000e+00

 5/32 [===>..........................] - ETA: 11s - loss: 6.6011 - acc: 0.0000e+00

 6/32 [====>.........................] - ETA: 10s - loss: 6.9615 - acc: 0.0000e+00

 7/32 [=====>........................] - ETA: 10s - loss: 7.2894 - acc: 0.0000e+00



















































Epoch 4/34858


 1/32 [..............................] - ETA: 11s - loss: 9.2351 - acc: 0.0000e+00

 2/32 [>.............................] - ETA: 11s - loss: 9.4414 - acc: 0.0000e+00

 3/32 [=>............................] - ETA: 11s - loss: 9.5090 - acc: 0.0000e+00

 4/32 [==>...........................] - ETA: 11s - loss: 9.4205 - acc: 0.0000e+00

 5/32 [===>..........................] - ETA: 10s - loss: 9.4928 - acc: 0.0000e+00

 6/32 [====>.........................] - ETA: 10s - loss: 8.2015 - acc: 0.1667    

 7/32 [=====>........................] - ETA: 10s - loss: 8.3304 - acc: 0.1429



















































Epoch 5/34858


 1/32 [..............................] - ETA: 13s - loss: 9.4139 - acc: 0.0000e+00

 2/32 [>.............................] - ETA: 12s - loss: 9.2006 - acc: 0.0000e+00

 3/32 [=>............................] - ETA: 12s - loss: 9.2132 - acc: 0.0000e+00

 4/32 [==>...........................] - ETA: 12s - loss: 7.6089 - acc: 0.0000e+00

 5/32 [===>..........................] - ETA: 11s - loss: 7.9426 - acc: 0.0000e+00

 6/32 [====>.........................] - ETA: 11s - loss: 8.0778 - acc: 0.0000e+00

 7/32 [=====>........................] - ETA: 11s - loss: 8.2296 - acc: 0.0000e+00



KeyboardInterrupt: 

In [9]:
# Acrostic poems: each of the four lines starts with one character of the phrase.
for _ in range(3):
    print(model.predict_hide('争云日夏'))

first line =  翁夜往还。争
争音常开台，云来清子恩。日天扉青家，夏作浮音为。
first line =  啄江海隅。争
争空谁上尽，云云中林翠。日落危西烟，夏更无长塞。
first line =  珠坠还结。争
争独望云落，云华北山山。日远仙入还，夏红游长无。


In [10]:
# Continue a poem from a user-supplied first line.
for _ in range(3):
    print(model.predict_sen('山为斜好几，'))

the first line: 山为斜好几，
山为斜好几，风外风玉正。东云水赏叶，先松句断采。
the first line: 山为斜好几，
山为斜好几，隐公帝碧自。开夜知孤满，下且露落鸟。
the first line: 山为斜好几，
山为斜好几，六池如中田。阙露奇雪前，然十盛空不。


In [11]:
# Generate poems that start with a given first character.
for _ in range(3):
    print(model.predict_first('山'))

山家光出观，隐黄戎识移。愿传兰重弦，飞方来凤为。
山迹几星道，寒行极幽直。方朝蝉家复，人经识子木。
山溪二屡正，归飞情尽宅。山未子华帝，花云新酒三。


In [12]:
# Seed with the opening of a random corpus poem and compare sampling temperatures.
for temp in (0.5, 1, 1.5):
    print(model.predict_random(temperature=temp))

the first line: 十载别仙峰，
十载别仙峰，不春幽思入。山不春兰知，光三落台平。
the first line: 已沐识坚贞，
已沐识坚贞，薄欢月坐终。旗国去向仙，采成赠金露。
the first line: 水尔何如此，
水尔何如此，良不枝愿宁。中鹤四刺疑，境暮衣可独。
