In [1]:
# 基于序列标注的思路对对联。

In [11]:
import codecs 
from keras.models import Model
from keras.layers import *
from keras.callbacks import Callback
from keras import backend as K

In [2]:
# --- data / vocabulary ---
maxlen = 16      # lines with more than this many tokens are dropped when loading
min_count = 2    # characters seen fewer times than this are left out of the vocabulary

# --- model / training ---
char_size = 128  # width of the character embedding
batch_size = 64  # upper bound on the number of samples per generated batch

In [3]:
def read_data(txt_name, max_len=None):
    """Load a space-tokenised, UTF-8 encoded text file as a list of token lists.

    Each line of the file becomes one list of tokens (split on a single
    space after stripping surrounding whitespace). Lines with more tokens
    than the limit are discarded.

    Args:
        txt_name: path of the text file to read.
        max_len: optional maximum number of tokens per line. When omitted,
            the module-level ``maxlen`` setting is used, which is the
            original behaviour.

    Returns:
        A list of lists of tokens, in file order.
    """
    limit = maxlen if max_len is None else max_len
    # Use a context manager so the file handle is always closed; the
    # previous version left the handle open.
    with codecs.open(txt_name, encoding='utf-8') as handle:
        raw = handle.read()
    rows = [line.strip().split(' ') for line in raw.strip().split('\n')]
    # Drop over-long couplet lines.
    return [tokens for tokens in rows if len(tokens) <= limit]

In [4]:
# Root folder of the couplet corpus; the train/ and test/ sub-folders each
# hold an in.txt (model inputs) and an out.txt (model targets).
couplet_base_path = "/Users/zhouwencheng/Desktop/Grass/data/txt/couplet"

train_x_path = couplet_base_path + "/train/in.txt"
train_y_path = couplet_base_path + "/train/out.txt"
test_x_path = couplet_base_path + "/test/in.txt"
test_y_path = couplet_base_path + "/test/out.txt"

# Read the four files in the same order as before: train x, train y, test x, test y.
x_train_txt, y_train_txt, x_test_txt, y_test_txt = map(
    read_data,
    (train_x_path, train_y_path, test_x_path, test_y_path),
)

In [5]:
# Count how often every token occurs across all four corpora.
chars = {}
for corpus in (x_train_txt, y_train_txt, x_test_txt, y_test_txt):
    for line in corpus:
        for token in line:
            if token in chars:
                chars[token] += 1
            else:
                chars[token] = 1

# Keep only tokens that occur at least `min_count` times.
chars = {token: count for token, count in chars.items() if count >= min_count}

# Ids start at 1, which leaves id 0 free for unknown tokens.
id2char = {index: token for index, token in enumerate(chars, 1)}
char2id = {token: index for index, token in id2char.items()}

def string2id(s):
    """Translate every element of `s` into its vocabulary id.

    Elements that are not present in `char2id` are mapped to 0 (<unk>).
    """
    ids = []
    for ch in s:
        ids.append(char2id.get(ch, 0))
    return ids

# Encode every tokenised line as a list of vocabulary ids.
x_train = [string2id(line) for line in x_train_txt]
y_train = [string2id(line) for line in y_train_txt]
x_test = [string2id(line) for line in x_test_txt]
y_test = [string2id(line) for line in y_test_txt]

In [7]:
import numpy as np

# Group the samples by length: each dict maps a length to a two-element
# list [inputs, targets].
train_dict = {}
test_dict = {}

for pos, x_ids in enumerate(x_train):
    bucket = train_dict.setdefault(len(x_ids), [[], []])
    bucket[0].append(x_ids)
    bucket[1].append(y_train[pos])

for pos, x_ids in enumerate(x_test):
    bucket = test_dict.setdefault(len(x_ids), [[], []])
    bucket[0].append(x_ids)
    bucket[1].append(y_test[pos])

# Convert the per-length Python lists into numpy arrays.
for bucket in train_dict.values():
    bucket[0] = np.array(bucket[0])
    bucket[1] = np.array(bucket[1])

for bucket in test_dict.values():
    bucket[0] = np.array(bucket[0])
    bucket[1] = np.array(bucket[1])

In [8]:
def data_generator(data, batch_limit=None):
    """Yield (inputs, targets) batches forever, each batch from one length bucket.

    A bucket is picked with probability proportional to the number of
    samples it holds; rows are then drawn at random (with replacement)
    from that bucket.

    Args:
        data: dict mapping a sentence length to ``[inputs, targets]``, where
            both entries are numpy arrays indexed by sample along axis 0.
        batch_limit: optional maximum number of rows per batch. When omitted,
            the module-level ``batch_size`` setting is used, which is the
            original behaviour.

    Yields:
        ``(inputs[rows], targets[rows, :, None])`` - the targets get a
        trailing axis of size 1 via ``np.expand_dims(..., 2)``.

    Raises:
        ValueError: on the first ``next()`` call if ``data`` is empty.
    """
    limit = batch_size if batch_limit is None else batch_limit
    # Sample over the actual dict keys. The earlier version used
    # `position + 1` as the key, which raised KeyError (or paired the
    # probabilities with the wrong buckets) unless the keys were exactly
    # 1..N in insertion order.
    lengths = list(data.keys())
    if not lengths:
        raise ValueError("data_generator needs at least one length bucket")
    weights = np.array([float(len(data[key][0])) for key in lengths])
    weights = weights / weights.sum()
    while True:
        # Pick a length at random, then random rows of that length, so that
        # every sample in the batch has the same number of tokens.
        key = lengths[np.random.choice(len(lengths), p=weights)]
        inputs, targets = data[key]
        size = min(limit, len(inputs))
        rows = np.random.choice(len(inputs), size=size)
        np.random.shuffle(rows)
        yield inputs[rows], np.expand_dims(targets[rows], 2)

In [9]:
def gated_resnet(x, ksize=3):
    """Gated convolution with a residual connection.

    A Conv1D with twice the input channels produces a gate `g` (first half
    of the channels) and a candidate `h` (second half). The result is
    ``x * sigmoid(g) + h * sigmoid(-g)``.

    Args:
        x: input tensor; its last dimension is the channel count.
        ksize: kernel size of the convolution.

    Returns:
        A tensor produced by a Lambda layer combining `x` and the conv output.
    """
    channels = K.int_shape(x)[-1]
    conv_out = Conv1D(channels * 2, ksize, padding='same')(x)

    def _gate_and_mix(tensors):
        residual, projected = tensors
        gate = projected[..., :channels]
        candidate = projected[..., channels:]
        return residual * K.sigmoid(gate) + candidate * K.sigmoid(-gate)

    return Lambda(_gate_and_mix)([x, conv_out])


In [12]:
# One class / embedding row per vocabulary entry, plus one for id 0.
vocab_size = len(chars) + 1

x_in = Input(shape=(None,))
x = Embedding(vocab_size, char_size)(x_in)
x = Dropout(0.25)(x)

# Stack of six gated residual convolution blocks.
for block_index in range(6):
    x = gated_resnet(x)

# Per-position distribution over the vocabulary.
x = Dense(vocab_size, activation='softmax')(x)

model = Model(x_in, x)
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy')
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            (None, None)         0                                            
__________________________________________________________________________________________________
embedding_2 (Embedding)         (None, None, 128)    945408      input_2[0][0]                    
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, None, 128)    0           embedding_2[0][0]                
__________________________________________________________________________________________________
conv1d_1 (Conv1D)               (None, None, 256)    98560       dropout_2[0][0]                  
____________________________________________________________________________________________

In [13]:
import logging
from logging import handlers
 
class Logger(object):
    """Thin wrapper that configures a named logger writing to a rotating file.

    The wrapped logger is available as ``self.logger``. The logger is looked
    up by ``filename`` via ``logging.getLogger``, so creating several
    ``Logger`` objects for the same file shares one underlying logger.
    """

    # Mapping from the short level names accepted by __init__ to logging levels.
    level_relations = {
        'debug':logging.DEBUG,
        'info':logging.INFO,
        'warning':logging.WARNING,
        'error':logging.ERROR,
        'crit':logging.CRITICAL
    }

    def __init__(self,filename,level='info',when='D',backCount=3,fmt='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'):
        """Configure the logger.

        Args:
            filename: log file path; also used as the logger name.
            level: one of the keys of ``level_relations``.
            when: rotation unit passed to ``TimedRotatingFileHandler``
                (S seconds, M minutes, H hours, D days, W weekday,
                midnight).
            backCount: number of rotated backup files to keep; older ones
                are deleted by the handler.
            fmt: format string for log records.
        """
        import os  # local import so this cell does not depend on a later cell

        self.logger = logging.getLogger(filename)
        format_str = logging.Formatter(fmt)  # record format
        self.logger.setLevel(self.level_relations.get(level))  # log level

        # Create the target directory if needed; the file handler opens the
        # file immediately and would otherwise fail on a missing folder.
        log_dir = os.path.dirname(filename)
        if log_dir and not os.path.isdir(log_dir):
            os.makedirs(log_dir)

        # getLogger returns the same object for the same name, so re-running
        # the cell used to stack a new handler each time and duplicate every
        # log line. Attach the file handler only once per target file.
        target = os.path.abspath(filename)
        already_attached = any(
            isinstance(existing, handlers.TimedRotatingFileHandler)
            and getattr(existing, 'baseFilename', None) == target
            for existing in self.logger.handlers
        )
        if not already_attached:
            # File handler that starts a new file at the given time interval.
            th = handlers.TimedRotatingFileHandler(filename=filename,when=when,backupCount=backCount,encoding='utf-8')
            th.setFormatter(format_str)  # format used in the file
            self.logger.addHandler(th)
# if __name__ == '__main__':
#     log = Logger('all.log',level='debug')
#     log.logger.debug('debug')
#     log.logger.info('info')
#     log.logger.warning('警告')
#     log.logger.error('报错')
#     log.logger.critical('严重') 

# Notebook-wide logger used by the training callback.
log = Logger(filename='log/501_couplet.log', level='debug')

In [14]:
def couplet_match(s):
    """Predict a matching line for `s`, print both lines and return the prediction.

    Args:
        s: the input line, one character per position.

    Returns:
        The predicted line as a string with one character per input position.
    """
    x = np.array([string2id(s)])
    probs = model.predict(x)[0]
    # Prior knowledge: the output may not repeat the input character at the
    # same position, so zero out that probability.
    for pos, char_id in enumerate(x[0]):
        probs[pos, char_id] = 0.
    # Ignore column 0 (<unk>) when taking the arg-max, then shift back to ids.
    best_ids = probs[:, 1:].argmax(axis=1) + 1
    r = ''.join(id2char[char_id] for char_id in best_ids)
    print(u'上联：%s，下联：%s' % (s, r))
    return r

In [15]:
import logging

# File the training callback saves the model to.
save_path = "/Users/zhouwencheng/Desktop/Grass/data/model/102_s2s_keras/501_couplet" + "/c_d_couplet.h5"

# Basic root-logger configuration: output file and log level.
logging.basicConfig(level=logging.INFO, filename='log/501_couplet.log')

class Evaluate(Callback):
    """Callback that logs sample predictions and checkpoints the best model.

    After every epoch it logs the prediction for a few fixed input lines
    and, when the validation loss is no worse than the best seen so far,
    saves the model to ``save_path``.
    """

    # Fixed input lines used to eyeball progress after each epoch.
    SAMPLE_LINES = (u'晚风摇树树还挺', u'今天天气不错', u'鱼跃此时海', u'只有香如故')

    def __init__(self):
        # Let the base Callback set up its own attributes first.
        super(Evaluate, self).__init__()
        self.lowest = 1e10  # best (lowest) validation loss seen so far

    def on_epoch_end(self, epoch, logs=None):
        """Log sample predictions and save the model if val_loss improved.

        Args:
            epoch: index of the epoch that just finished (unused).
            logs: metrics dict supplied by Keras; ``val_loss`` is only
                present when validation is enabled.
        """
        for first_line in self.SAMPLE_LINES:
            log.logger.info(couplet_match(first_line))

        val_loss = (logs or {}).get('val_loss')
        if val_loss is None:
            # Without a validation loss there is nothing to compare, so skip
            # checkpointing instead of failing on the missing key.
            log.logger.warning("val_loss not available; skip saving model")
            return

        # Keep the best result.
        if val_loss <= self.lowest:
            self.lowest = val_loss
            log.logger.info("val_loss:"+str(self.lowest))
            model.save(save_path)
            log.logger.info("save model at:"+save_path)

In [None]:
import os
from keras.models import load_model
# Resume from a previously saved model when one exists.
if os.path.exists(save_path):
    model = load_model(save_path)

evaluator = Evaluate()
train_batches = data_generator(train_dict)
valid_batches = data_generator(test_dict)

model.fit_generator(
    train_batches,
    steps_per_epoch=1000,
    epochs=100,
    validation_data=valid_batches,
    validation_steps=100,
    callbacks=[evaluator],
)

  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


Epoch 1/100
上联：晚风摇树树还挺，下联：明雨醉风风不飞
上联：今天天气不错，下联：大月国人无明
上联：鱼跃此时海，下联：鸟飞一处人
上联：只有香如故，下联：能无月似春
Epoch 2/100
上联：晚风摇树树还挺，下联：春雨落花花不浓
上联：今天天气不错，下联：古海人人无明
上联：鱼跃此时海，下联：花开其处人
上联：只有香如故，下联：无无道不新
Epoch 3/100
上联：晚风摇树树还挺，下联：秋雨照花山不闲
上联：今天天气不错，下联：大地国人无明
上联：鱼跃此时海，下联：花开一处人
上联：只有香如故，下联：方无月不新
Epoch 4/100
上联：晚风摇树树还挺，下联：寒雨照花花不深
上联：今天天气不错，下联：此地人人无清
上联：鱼跃此时海，下联：马开万处人
上联：只有香如故，下联：何无月若天
Epoch 5/100
上联：晚风摇树树还挺，下联：春雨落花花不浓
上联：今天天气不错，下联：大地月人无清
上联：鱼跃此时海，下联：鸟开一处人
上联：只有香如故，下联：何无月不春
Epoch 6/100

In [64]:
import logging
from logging import handlers
 
class Logger(object):
    """Thin wrapper that configures a named logger writing to screen and file.

    The wrapped logger is available as ``self.logger``. The logger is looked
    up by ``filename`` via ``logging.getLogger``, so creating several
    ``Logger`` objects for the same file shares one underlying logger.
    """

    # Mapping from the short level names accepted by __init__ to logging levels.
    level_relations = {
        'debug':logging.DEBUG,
        'info':logging.INFO,
        'warning':logging.WARNING,
        'error':logging.ERROR,
        'crit':logging.CRITICAL
    }

    def __init__(self,filename,level='info',when='D',backCount=3,fmt='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'):
        """Configure the logger.

        Args:
            filename: log file path; also used as the logger name.
            level: one of the keys of ``level_relations``.
            when: rotation unit passed to ``TimedRotatingFileHandler``
                (S seconds, M minutes, H hours, D days, W weekday,
                midnight).
            backCount: number of rotated backup files to keep; older ones
                are deleted by the handler.
            fmt: format string for log records.
        """
        import os  # local import so this cell does not depend on another cell

        self.logger = logging.getLogger(filename)
        format_str = logging.Formatter(fmt)  # record format
        self.logger.setLevel(self.level_relations.get(level))  # log level

        # Create the target directory if needed; the file handler opens the
        # file immediately and would otherwise fail on a missing folder.
        log_dir = os.path.dirname(filename)
        if log_dir and not os.path.isdir(log_dir):
            os.makedirs(log_dir)

        # getLogger returns the same object for the same name, so re-running
        # the cell used to stack new handlers each time and duplicate every
        # log line. Attach each kind of handler only once.
        current = list(self.logger.handlers)

        # Exact type check: FileHandler subclasses StreamHandler.
        if not any(type(existing) is logging.StreamHandler for existing in current):
            sh = logging.StreamHandler()  # output to the screen
            sh.setFormatter(format_str)   # format used on screen
            self.logger.addHandler(sh)

        target = os.path.abspath(filename)
        has_file_handler = any(
            isinstance(existing, handlers.TimedRotatingFileHandler)
            and getattr(existing, 'baseFilename', None) == target
            for existing in current
        )
        if not has_file_handler:
            # File handler that starts a new file at the given time interval.
            th = handlers.TimedRotatingFileHandler(filename=filename,when=when,backupCount=backCount,encoding='utf-8')
            th.setFormatter(format_str)  # format used in the file
            self.logger.addHandler(th)
if __name__ == '__main__':
    # Smoke test: emit one record at every level into all.log.
    log = Logger(filename='all.log', level='debug')
    demo_logger = log.logger
    demo_logger.debug('debug')
    demo_logger.info('info')
    demo_logger.warning('警告')
    demo_logger.error('报错')
    demo_logger.critical('严重')

# log = Logger('log/501_couplet.log',level='debug')


In [68]:
# Scratch check of the root logger.
# NOTE(review): logging.exception() is meant to be called from inside an
# `except` block; called here with no active exception, the traceback part of
# the record is just "NoneType: None".
logging.exception("5")

E0814 11:16:40.479709 4321588096 <ipython-input-68-c2744b8de12c>:1] 我面试35
NoneType: None


In [69]:
#     on_epoch_end: logs include `acc` and `loss`, and
#         optionally include `val_loss`
#         (if validation is enabled in `fit`), and `val_acc`
#         (if validation and accuracy monitoring are enabled).
#     on_batch_begin: logs include `size`,
#         the number of samples in the current batch.
#     on_batch_end: logs include `loss`, and optionally `acc`
#         (if accuracy monitoring is enabled).

2019-08-14 11:19:26,251 - <ipython-input-69-6c8b667f265e>[line:33] - DEBUG: debug
I0814 11:19:26.251430 4321588096 <ipython-input-69-6c8b667f265e>:33] debug
2019-08-14 11:19:26,265 - <ipython-input-69-6c8b667f265e>[line:34] - INFO: info
I0814 11:19:26.265645 4321588096 <ipython-input-69-6c8b667f265e>:34] info
W0814 11:19:26.270969 4321588096 <ipython-input-69-6c8b667f265e>:35] 警告
2019-08-14 11:19:26,275 - <ipython-input-69-6c8b667f265e>[line:36] - ERROR: 报错
E0814 11:19:26.275985 4321588096 <ipython-input-69-6c8b667f265e>:36] 报错
2019-08-14 11:19:26,280 - <ipython-input-69-6c8b667f265e>[line:37] - CRITICAL: 严重
E0814 11:19:26.280014 4321588096 <ipython-input-69-6c8b667f265e>:37] CRITICAL - 严重
2019-08-14 11:19:26,285 - <ipython-input-69-6c8b667f265e>[line:38] - ERROR: error
E0814 11:19:26.285053 4321588096 <ipython-input-69-6c8b667f265e>:38] error
