In [1]:
import tensorflow as tf
import matplotlib
from tensorflow import keras
from bert.tokenization import FullTokenizer
from tensorflow.keras.layers import Lambda
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import multiprocessing
for model in [tf,keras,matplotlib,np,pd]:
    print(model.__name__,model.__version__)

tensorflow 2.0.0-alpha0
tensorflow.python.keras.api._v2.keras 2.2.4-tf
matplotlib 3.1.1
numpy 1.16.4
pandas 0.24.2


In [2]:
#导入数据
def read_tsv(path):
    with open(path, "r", encoding="utf8") as file:
        data = [d.strip().split("\t") for d in file.readlines()]
    return data[1:]
def split_x_y(data):
    x_train = [d[1] for d in data]
    y_train = [int(d[0]) for d in data]
    return x_train, y_train
def load_data():
    data_dir = "./data"
    train_data = read_tsv(os.path.join(data_dir, 'train.tsv'))[:2000]
    test_data = read_tsv(os.path.join(data_dir, 'test.tsv'))[:500]
    dev_data = read_tsv(os.path.join(data_dir, 'dev.tsv'))[:300]
    x_train, y_train = split_x_y(train_data)
    x_test, y_test = split_x_y(test_data)
    x_dev, y_dev = split_x_y(dev_data)
    assert len(x_train) == len(y_train)
    assert len(x_test) == len(y_test)
    assert len(x_dev) == len(y_dev)
    return x_train, y_train, x_test, y_test, x_dev, y_dev
x_train, y_train, x_test, y_test, x_dev, y_dev = load_data()

In [3]:
x_train[:5]

['选择珠江花园的原因就是方便，有电动扶梯直接到达海边，周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般，但还算整洁。 泳池在大堂的屋顶，因此很小，不过女儿倒是喜欢。 包的早餐是西式的，还算丰富。 服务吗，一般',
 '15.4寸笔记本的键盘确实爽，基本跟台式机差不多了，蛮喜欢数字小键盘，输数字特方便，样子也很美观，做工也相当不错',
 '房间太小。其他的都一般。。。。。。。。。',
 '1.接电源没有几分钟,电源适配器热的不行. 2.摄像头用不起来. 3.机盖的钢琴漆，手不能摸，一摸一个印. 4.硬盘分区不好办.',
 '今天才知道这书还有第6卷,真有点郁闷:为什么同一套书有两种版本呢?当当网是不是该跟出版社商量商量,单独出个第6卷,让我们的孩子不会有所遗憾。']

In [4]:
y_train[:5]

[1, 1, 0, 0, 1]

In [5]:
def tokenize_data(input_str_batch, max_seq_len, model_dir):
    tokenizer = FullTokenizer(vocab_file=os.path.join(model_dir, "vocab.txt"), do_lower_case=True)
    input_ids_batch = []
    token_type_ids_batch = []
    for input_str in input_str_batch:
        input_tokens = tokenizer.tokenize(input_str)
        input_tokens = ["[CLS]"] + input_tokens + ["[SEP]"]

        print("input_tokens len:", len(input_tokens))

        input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
        if len(input_tokens) > max_seq_len:
            input_ids = input_ids[:max_seq_len]
        else:
            input_ids = input_ids + [0] * (max_seq_len - len(input_tokens))
        # token_type_ids = [0] * len(input_tokens) + [0] * (max_seq_len - len(input_tokens))
        token_type_ids = [0] * max_seq_len
        input_ids_batch.append(input_ids)
        token_type_ids_batch.append(token_type_ids)
    return input_ids_batch, token_type_ids_batch

In [6]:
#embedding sentence
max_seq_len = len([d for d in x_train if len(d) > 256])
model_dir = r"./data/models/chinese_L-12_H-768_A-12"
# pool = multiprocessing.Pool(3)
# results = []
# for data in [x_train, x_test, x_dev]:
#     result = pool.apply_async(tokenize_data, args=[data, max_seq_len, model_dir])
#     results.append(result.get())
# pool.close()
# pool.join()
# input_train_ids_batch, token_train_type_ids_batch = results[0]
# input_test_ids_batch, token_test_type_ids_batch = results[1]
# input_dev_ids_batch, token_dev_type_ids_batch = results[2]

input_train_ids_batch, token_train_type_ids_batch = tokenize_data(x_train,max_seq_len,model_dir)
input_test_ids_batch, token_test_type_ids_batch = tokenize_data(x_test,max_seq_len,model_dir)
input_dev_ids_batch, token_dev_type_ids_batch = tokenize_data(x_dev,max_seq_len,model_dir)

input_train_ids = np.array(input_train_ids_batch, dtype=np.int32)
token_train_type_ids = np.array(token_train_type_ids_batch, dtype=np.int32)
input_test_ids = np.array(input_test_ids_batch, dtype=np.int32)
token_test_type_ids = np.array(token_test_type_ids_batch, dtype=np.int32)
input_dev_ids = np.array(input_dev_ids_batch, dtype=np.int32)
token_dev_type_ids = np.array(token_dev_type_ids_batch, dtype=np.int32)

input_tokens len: 105
input_tokens len: 57
input_tokens len: 22
input_tokens len: 63
input_tokens len: 72
input_tokens len: 44
input_tokens len: 61
input_tokens len: 73
input_tokens len: 59
input_tokens len: 126
input_tokens len: 62
input_tokens len: 59
input_tokens len: 38
input_tokens len: 97
input_tokens len: 53
input_tokens len: 58
input_tokens len: 163
input_tokens len: 24
input_tokens len: 62
input_tokens len: 91
input_tokens len: 31
input_tokens len: 193
input_tokens len: 46
input_tokens len: 70
input_tokens len: 97
input_tokens len: 189
input_tokens len: 28
input_tokens len: 28
input_tokens len: 365
input_tokens len: 184
input_tokens len: 100
input_tokens len: 61
input_tokens len: 45
input_tokens len: 74
input_tokens len: 48
input_tokens len: 181
input_tokens len: 223
input_tokens len: 195
input_tokens len: 49
input_tokens len: 64
input_tokens len: 65
input_tokens len: 33
input_tokens len: 44
input_tokens len: 117
input_tokens len: 37
input_tokens len: 47
input_tokens len: 26
i

input_tokens len: 494
input_tokens len: 77
input_tokens len: 164
input_tokens len: 1062
input_tokens len: 81
input_tokens len: 41
input_tokens len: 112
input_tokens len: 66
input_tokens len: 36
input_tokens len: 102
input_tokens len: 83
input_tokens len: 66
input_tokens len: 26
input_tokens len: 53
input_tokens len: 42
input_tokens len: 69
input_tokens len: 53
input_tokens len: 51
input_tokens len: 155
input_tokens len: 141
input_tokens len: 65
input_tokens len: 45
input_tokens len: 73
input_tokens len: 191
input_tokens len: 25
input_tokens len: 175
input_tokens len: 23
input_tokens len: 74
input_tokens len: 64
input_tokens len: 20
input_tokens len: 40
input_tokens len: 60
input_tokens len: 182
input_tokens len: 89
input_tokens len: 116
input_tokens len: 60
input_tokens len: 62
input_tokens len: 809
input_tokens len: 40
input_tokens len: 50
input_tokens len: 193
input_tokens len: 121
input_tokens len: 46
input_tokens len: 101
input_tokens len: 112
input_tokens len: 212
input_tokens len

input_tokens len: 38
input_tokens len: 102
input_tokens len: 86
input_tokens len: 90
input_tokens len: 20
input_tokens len: 52
input_tokens len: 80
input_tokens len: 25
input_tokens len: 33
input_tokens len: 45
input_tokens len: 38
input_tokens len: 40
input_tokens len: 54
input_tokens len: 168
input_tokens len: 178
input_tokens len: 33
input_tokens len: 167
input_tokens len: 184
input_tokens len: 172
input_tokens len: 156
input_tokens len: 172
input_tokens len: 53
input_tokens len: 102
input_tokens len: 69
input_tokens len: 124
input_tokens len: 38
input_tokens len: 61
input_tokens len: 181
input_tokens len: 37
input_tokens len: 40
input_tokens len: 100
input_tokens len: 93
input_tokens len: 91
input_tokens len: 63
input_tokens len: 96
input_tokens len: 119
input_tokens len: 332
input_tokens len: 50
input_tokens len: 170
input_tokens len: 60
input_tokens len: 95
input_tokens len: 86
input_tokens len: 175
input_tokens len: 47
input_tokens len: 196
input_tokens len: 55
input_tokens len:

input_tokens len: 515
input_tokens len: 65
input_tokens len: 76
input_tokens len: 565
input_tokens len: 195
input_tokens len: 70
input_tokens len: 106
input_tokens len: 31
input_tokens len: 35
input_tokens len: 50
input_tokens len: 78
input_tokens len: 101
input_tokens len: 54
input_tokens len: 51
input_tokens len: 191
input_tokens len: 45
input_tokens len: 167
input_tokens len: 29
input_tokens len: 39
input_tokens len: 61
input_tokens len: 95
input_tokens len: 110
input_tokens len: 107
input_tokens len: 56
input_tokens len: 124
input_tokens len: 53
input_tokens len: 137
input_tokens len: 46
input_tokens len: 76
input_tokens len: 166
input_tokens len: 27
input_tokens len: 171
input_tokens len: 31
input_tokens len: 919
input_tokens len: 41
input_tokens len: 64
input_tokens len: 161
input_tokens len: 185
input_tokens len: 44
input_tokens len: 70
input_tokens len: 62
input_tokens len: 103
input_tokens len: 59
input_tokens len: 42
input_tokens len: 59
input_tokens len: 28
input_tokens len:

input_tokens len: 184
input_tokens len: 137
input_tokens len: 38
input_tokens len: 68
input_tokens len: 47
input_tokens len: 37
input_tokens len: 71
input_tokens len: 225
input_tokens len: 77
input_tokens len: 62
input_tokens len: 107
input_tokens len: 28
input_tokens len: 88
input_tokens len: 209
input_tokens len: 66
input_tokens len: 69
input_tokens len: 178
input_tokens len: 60
input_tokens len: 203
input_tokens len: 35
input_tokens len: 38
input_tokens len: 134
input_tokens len: 43
input_tokens len: 55
input_tokens len: 61
input_tokens len: 109
input_tokens len: 58
input_tokens len: 78
input_tokens len: 71
input_tokens len: 89
input_tokens len: 57
input_tokens len: 62
input_tokens len: 89
input_tokens len: 180
input_tokens len: 44
input_tokens len: 200
input_tokens len: 178
input_tokens len: 78
input_tokens len: 39
input_tokens len: 70
input_tokens len: 53
input_tokens len: 63
input_tokens len: 62
input_tokens len: 392
input_tokens len: 40
input_tokens len: 28
input_tokens len: 151

input_tokens len: 576
input_tokens len: 100
input_tokens len: 30
input_tokens len: 235
input_tokens len: 44
input_tokens len: 71
input_tokens len: 182
input_tokens len: 96
input_tokens len: 68
input_tokens len: 189
input_tokens len: 114
input_tokens len: 47
input_tokens len: 80
input_tokens len: 170
input_tokens len: 96
input_tokens len: 124
input_tokens len: 183
input_tokens len: 237
input_tokens len: 62
input_tokens len: 135
input_tokens len: 37
input_tokens len: 90
input_tokens len: 95
input_tokens len: 177
input_tokens len: 44
input_tokens len: 67
input_tokens len: 50
input_tokens len: 40
input_tokens len: 35
input_tokens len: 76
input_tokens len: 61
input_tokens len: 44
input_tokens len: 310
input_tokens len: 38
input_tokens len: 32
input_tokens len: 136
input_tokens len: 429
input_tokens len: 69
input_tokens len: 45
input_tokens len: 37
input_tokens len: 141
input_tokens len: 23
input_tokens len: 89
input_tokens len: 179
input_tokens len: 122
input_tokens len: 43
input_tokens len

In [7]:
def load_keras_model(model_dir, max_seq_len):
    # keras 加载BERT
    from bert import BertModelLayer
    from bert.loader import StockBertConfig, load_stock_weights

    bert_config_file = os.path.join(model_dir, "bert_config.json")
    bert_ckpt_file = os.path.join(model_dir, "bert_model.ckpt")

    with tf.io.gfile.GFile(bert_config_file, "r") as reader:
        bc = StockBertConfig.from_json_string(reader.read())
        l_bert = BertModelLayer.from_params(bc.to_bert_model_layer_params(), name="bert")

    l_input_ids = keras.layers.Input(shape=(max_seq_len,), dtype='int32', name="input_ids")
    l_token_type_ids = keras.layers.Input(shape=(max_seq_len,), dtype='int32', name="token_type_ids")

    l = l_bert([l_input_ids, l_token_type_ids])
    l = Lambda(lambda x: x[:, 0])(l)
    output = keras.layers.Dense(1, activation=keras.activations.sigmoid)(l)

    model = keras.Model(inputs=[l_input_ids, l_token_type_ids], outputs=output)

    model.build(input_shape=[(None, max_seq_len),
                             (None, max_seq_len)])

    load_stock_weights(l_bert, bert_ckpt_file)

    return model
model = load_keras_model(model_dir, max_seq_len)

W0722 09:25:41.856828 13968 deprecation.py:323] From c:\users\xiaoi\desktop\tensorflow2.0\venv\lib\site-packages\bert\loader.py:113: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.


Done loading 197 BERT weights from: ./data/models/chinese_L-12_H-768_A-12\bert_model.ckpt into <bert.model.BertModelLayer object at 0x00000219ADBF92B0> (prefix:bert)


In [8]:
model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-5),
                  loss=keras.losses.binary_crossentropy,
                  metrics=['accuracy'])
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_ids (InputLayer)          [(None, 95)]         0                                            
__________________________________________________________________________________________________
token_type_ids (InputLayer)     [(None, 95)]         0                                            
__________________________________________________________________________________________________
bert (BertModelLayer)           (None, 95, 768)      101677056   input_ids[0][0]                  
                                                                 token_type_ids[0][0]             
__________________________________________________________________________________________________
lambda (Lambda)                 (None, 768)          0           bert[0][0]                   

In [9]:
logdir = r'.\data\models\bert_model'
if not os.path.exists(logdir):
    os.mkdir(logdir)

_callbacks = [
    keras.callbacks.TensorBoard(logdir),#tensotBoard 路径
    keras.callbacks.ModelCheckpoint(os.path.join(logdir,"bert_classify_model.h5"),
                                   save_best_only=True), #保存最好的模型
    keras.callbacks.EarlyStopping(patience=5)
]

history = model.fit(x=(input_train_ids, token_train_type_ids),
              y=y_train,
              batch_size=16,
              epochs=10,
              validation_data=((input_dev_ids, token_dev_type_ids), y_dev),
            callbacks = _callbacks)

Train on 2000 samples, validate on 300 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10

KeyboardInterrupt: 

In [None]:
def plot_learning_curver(history):
    pd.DataFrame(history.history).plot(figsize = (8,5))
    plt.grid(True)
    plt.gca().set_ylim(0,1)
    plt.show()
plot_learning_curver(history)

In [None]:
model.evaluate((input_test_ids, token_test_type_ids))