# MLM实现对联模型

In [1]:
# Locations of the couplet corpus: "in" files hold the first lines,
# "out" files hold the matching second lines.
train_in, train_out = "couplet/train/in.txt", "couplet/train/out.txt"
test_in, test_out = "couplet/test/in.txt", "couplet/test/out.txt"

In [2]:
# Define a helper that loads a text file as a list of lines
from typing import List

def load_data(filename: str) -> List[str]:
    """Read a UTF-8 text file and return its lines.

    Args:
        filename: Path of the text file to read.

    Returns:
        One string per line, without the newline characters. Any other
        whitespace on a line is kept as-is.
    """
    # The corpus is Chinese text; be explicit about the encoding rather than
    # relying on the platform default.
    with open(filename, encoding='utf-8') as fd:
        lines = fd.read().split('\n')
    # A file that ends with a newline would otherwise contribute one spurious
    # empty entry at the end of the list.
    if lines and lines[-1] == '':
        lines.pop()
    return lines
    
# Replace each path with the lines it points to. The isinstance guard keeps
# this cell safe to re-run: once a name already holds the loaded list it is
# left alone instead of being handed back to load_data as if it were a path.
if isinstance(train_in, str):
    train_in = load_data(train_in)
if isinstance(train_out, str):
    train_out = load_data(train_out)

In [3]:
# Load the ALBERT vocabulary and build the tokenizer
import os
from bert4keras.tokenizers import Tokenizer, load_vocab

# The config, checkpoint and vocabulary all live in the same checkpoint
# directory, so derive the three paths from that single prefix.
config_path, check_point_path, vocab_path = (
    'bert_models/albert_base_google_zh_additional_36k_steps/' + leaf
    for leaf in ('albert_config.json', 'albert_model.ckpt', 'vocab.txt')
)

# Build the tokenizer from the checkpoint's vocabulary file.
token_dict = load_vocab(vocab_path)
tokenizer = Tokenizer(token_dict=token_dict)

Using TensorFlow backend.


In [4]:
from bert4keras.snippets import DataGenerator, sequence_padding
from tensorflow.keras.utils import to_categorical

MAXLEN = 50  # maximum length used when encoding a line

class CoupletData(DataGenerator):
    """Batch generator that turns (first line, second line) pairs into model inputs.

    Each item of the underlying data is unpacked as an ``(x, y)`` pair. ``x``
    is encoded into token ids and segment ids (the model inputs); ``y`` is
    encoded into token ids that become the per-position labels.
    """

    def __iter__(self, random=False):
        """Yield ``([token_ids, segment_ids], one_hot_labels)`` batches.

        Args:
            random: Forwarded to ``self.sample`` to control sample order.

        Yields:
            A two-element list with the padded token-id and segment-id arrays,
            followed by the labels one-hot encoded over ``len(token_dict)``
            classes.
        """
        batch_token_ids, batch_segment_ids, batch_label = [], [], []
        for is_end, data in self.sample(random=random):
            x, y = data
            token_id, segment_id = tokenizer.encode(x, maxlen=MAXLEN)
            token_id_label, _ = tokenizer.encode(y, maxlen=MAXLEN)

            batch_token_ids.append(token_id)
            batch_segment_ids.append(segment_id)
            batch_label.append(token_id_label)

            # Emit a batch once it is full, or when the data runs out.
            if len(batch_segment_ids) == self.batch_size or is_end:
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids = sequence_padding(batch_segment_ids)
                # Pad (or cut) the labels to the width of the inputs rather
                # than to their own longest entry, so the label array always
                # lines up position-for-position with the input array even if
                # a pair's two lines tokenise to different lengths.
                batch_label = sequence_padding(
                    batch_label, length=batch_token_ids.shape[1]
                )
                yield [batch_token_ids, batch_segment_ids], to_categorical(batch_label, num_classes=len(token_dict))
                batch_token_ids, batch_segment_ids, batch_label = [], [], []

In [5]:
# Show the first ten entries of train_in as a quick sanity check.
train_in[:10]

['晚 风 摇 树 树 还 挺 ',
 '愿 景 天 成 无 墨 迹 ',
 '丹 枫 江 冷 人 初 去 ',
 '忽 忽 几 晨 昏 ， 离 别 间 之 ， 疾 病 间 之 ， 不 及 终 年 同 静 好 ',
 '闲 来 野 钓 人 稀 处 ',
 '毋 人 负 我 ， 毋 我 负 人 ， 柳 下 虽 和 有 介 称 ， 先 生 字 此 ， 可 以 谥 此 ',
 '投 石 向 天 跟 命 斗 ',
 '深 院 落 滕 花 ， 石 不 点 头 龙 不 语 ',
 '不 畏 鸿 门 传 汉 祚 ',
 '新 居 落 成 创 业 始 ']

In [6]:
# Smoke test: run the generator over the first 100 pairs and print the shape
# of every array in each batch (token ids, segment ids, one-hot labels).
for (token_id, segment_id), label in CoupletData(zip(train_in[:100], train_out[:100]), batch_size=32):
    for batch_array in (token_id, segment_id, label):
        print(batch_array.shape)

(32, 29)
(32, 29)
(32, 29, 21128)
(32, 22)
(32, 22)
(32, 22, 21128)
(32, 25)
(32, 25)
(32, 25, 21128)
(4, 18)
(4, 18)
(4, 18, 21128)


In [7]:
from bert4keras.models import build_transformer_model

# Build the ALBERT model from the config and checkpoint paths defined earlier,
# with the with_mlm option switched on.
model = build_transformer_model(
    config_path=config_path,
    checkpoint_path=check_point_path,
    model='albert',
    with_mlm=True,
)

In [8]:
# Print the layer-by-layer summary of the model built above.
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
Input-Token (InputLayer)        (None, None)         0                                            
__________________________________________________________________________________________________
Input-Segment (InputLayer)      (None, None)         0                                            
__________________________________________________________________________________________________
Embedding-Token (Embedding)     multiple             16226304    Input-Token[0][0]                
                                                                 MLM-Norm[0][0]                   
__________________________________________________________________________________________________
Embedding-Segment (Embedding)   (None, None, 768)    1536        Input-Segment[0][0]        

In [54]:
from keras.callbacks import Callback

from bert4keras.snippets import to_array

def next_couplet(text: str):
    """Produce the second line of a couplet for the given first line.

    Args:
        text: The first line of the couplet.

    Returns:
        The decoded text of the highest-scoring token at every position of
        the model's prediction for ``text``.
    """
    # Use the same length limit as the training generator instead of a
    # separate hard-coded number, so the two cannot drift apart.
    token_id, segment_id = tokenizer.encode(text, maxlen=MAXLEN)
    token_id, segment_id = to_array([token_id], [segment_id])
    # A single-item batch goes in, so take the first (only) prediction.
    y_pred = model.predict([token_id, segment_id])[0]
    return tokenizer.decode(y_pred.argmax(-1))

class EvalCallback(Callback):
    """Save the best weights seen so far and print sample couplets each epoch.

    "Best" means the lowest ``loss`` value reported in the epoch logs.
    """

    def __init__(self):
        super().__init__()
        # Lowest loss observed so far; starts high so the first epoch wins.
        self.lowest = 1e8

    def on_epoch_end(self, epoch, logs=None):
        """Checkpoint on a new lowest loss, then print sample predictions.

        Args:
            epoch: Index of the epoch that just finished (unused).
            logs: Metrics for the epoch; only the ``'loss'`` entry is read.
                May be ``None`` or lack ``'loss'``, in which case nothing is
                saved.
        """
        loss = (logs or {}).get('loss')
        if loss is not None and loss < self.lowest:
            # Record the new best value. The logs themselves are left
            # untouched so the reported loss stays the real one.
            self.lowest = loss
            weights_path = 'weights/couplet-albert-mlm-best.weights'
            # Make sure the target directory exists before writing to it.
            os.makedirs(os.path.dirname(weights_path), exist_ok=True)
            model.save_weights(weights_path)

        self.just_show()

    def just_show(self):
        """Print the model's answer for a fixed set of first lines."""
        first = ['今日天气多云多美丽',
                 '珍藏惟有诗三卷',
                 '狂笔一挥天地动',
                 '推窗问月诗何在',
                 '彩屏如画，望秀美崤函，花团锦簇']

        for each in first:
            print(" -", each)
            print("--", next_couplet(each))
            print()
            

In [10]:
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy

# Train with Adam at a learning rate of 1e-5 against a categorical
# cross-entropy loss (the generator yields one-hot labels).
model.compile(
    optimizer=Adam(learning_rate=1e-5),
    loss=CategoricalCrossentropy(),
)

In [None]:
# Materialise the pairs into a list. A bare zip object is a one-shot iterator
# with no length: it can be walked only once, whereas the training generator
# is iterated repeatedly during fitting.
train_data = CoupletData(data=list(zip(train_in, train_out)))

In [59]:
# 20 epochs of 1000 generator steps each; EvalCallback runs after every epoch.
model.fit(
    train_data.forfit(),
    epochs=20,
    steps_per_epoch=1000,
    callbacks=[EvalCallback()],
)

Epoch 1/20
 - 今日天气多云多美丽
-- 今年春光有月有和谐

 - 珍藏惟有诗三卷
-- 珍藏不无酒一分

 - 狂笔一挥天地动
-- 大心百载古今行

 - 推窗问月诗何在
-- 倚月吟风梦自来

 - 彩屏如画，望秀美崤函，花团锦簇
-- 春水如诗，看和明画业，画韵辉流

Epoch 2/20
 - 今日天气多云多美丽
-- 今年人光有月有和谐

 - 珍藏惟有诗三卷
-- 喜乐不为酒一行

 - 狂笔一挥天地动
-- 清风千载古今行

 - 推窗问月诗何在
-- 对酒吟风梦自来

 - 彩屏如画，望秀美崤函，花团锦簇
-- 春韵似诗，看和谐画业，画韵辉流

Epoch 3/20
 - 今日天气多云多美丽
-- 今年风光有月有和明

 - 珍藏惟有诗三卷
-- 不藏不无酒一行

 - 狂笔一挥天地动
-- 大心三载世今行

 - 推窗问月诗何在
-- 入月吟风梦自来

 - 彩屏如画，望秀美崤函，花团锦簇
-- 春韵如云，看和明画业，画韵辉流

Epoch 4/20
 - 今日天气多云多美丽
-- 今年人风有月有和谐

 - 珍藏惟有诗三卷
-- 雅雅不为酒一杯

 - 狂笔一挥天地动
-- 狂风三载日今欢

 - 推窗问月诗何在
-- 对酒吟花酒自来

 - 彩屏如画，望秀美崤函，花团锦簇
-- 彩月如诗，看和明画月，月韵春流

Epoch 5/20
 - 今日天气多云多美丽
-- 今年人光有月有和谐

 - 珍藏惟有诗三卷
-- 苦醉不无酒一杯

 - 狂笔一挥天地动
-- 大心千步古今飞

 - 推窗问月诗何在
-- 入笔吟风酒自来

 - 彩屏如画，望秀美崤函，花团锦簇
-- 彩笔如诗，看和明大业，花舞春流

Epoch 6/20
 - 今日天气多云多美丽
-- 今年人光有月少和谐

 - 珍藏惟有诗三卷
-- 珍乐当为酒一杯

 - 狂笔一挥天地动
-- 清心千载古今新

 - 推窗问月诗何在
-- 对笔吟花梦自来

 - 彩屏如画，望秀美崤函，花团锦簇
-- 彩韵如诗，看和谐文业，气韵辉流

Epoch 7/20
 - 今日天气多云多美丽
-- 今朝人光有月尽和明

 - 珍藏惟有诗三卷
-- 喜乐当无酒一杯

 - 狂笔一挥天地动
-- 高心千卷古今流

 - 推窗问月诗何在
-- 入月观风梦自来

 - 彩屏如画，望秀美崤函，花团锦

<keras.callbacks.callbacks.History at 0x7f048dbc3f70>

In [61]:
# Measure the average time of one next_couplet call over repeated runs.
%timeit next_couplet('彩屏如画，望秀美崤函，花团锦簇')

7.97 ms ± 403 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [62]:
# Time a single next_couplet call and display its result.
%time next_couplet('彩屏如画，望秀美崤函，花团锦簇')

CPU times: user 9.42 ms, sys: 29 ms, total: 38.5 ms
Wall time: 37.5 ms


'彩气似春，看和谐华业，鸟舞龙腾'