In [1]:
# Mount Google Drive at /content/gdrive
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
# Install bert4keras, gsutil and sentencepiece, and use gsutil to copy the
# mT5 "small" pretrained model directory and the mc4 sentencepiece model
# into the current working directory.
! pip3 install bert4keras
! pip3 install gsutil
! gsutil cp -r gs://t5-data/pretrained_models/mt5/small .
! gsutil cp -r gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model .
! pip3 install sentencepiece

In [None]:
# Original author's note: TensorFlow 2.x is required in order to use the GPU!
# Pin TensorFlow and Keras to fixed versions.
! pip3 install tensorflow==2.4.1
! pip3 install keras==2.3.1

In [None]:
# Download the pretrained Chinese BERT model archive and unzip it
! wget https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip
! unzip chinese_L-12_H-768_A-12.zip

In [None]:
# Unzip the dataset and weights archive. The archive itself is not downloaded
# by this notebook, so it must already be present at /content.

! unzip /content/t5_in_bert4keras-main.zip

In [None]:
# Install flask-ngrok, which provides run_with_ngrok used by the serving cell
! pip3 install flask-ngrok

In [None]:
#! -*- coding: utf-8 -*-
# bert做Seq2Seq任务，采用UNILM方案
# 介绍链接：https://kexue.fm/archives/6933

from __future__ import print_function
import os
os.environ['TF_KERAS'] = '1'
import json
import numpy as np
import flask
from flask import Flask, request, render_template
from flask_ngrok import run_with_ngrok
from tqdm import tqdm
from bert4keras.backend import keras, K
from bert4keras.layers import Loss
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer, load_vocab, SpTokenizer
from bert4keras.optimizers import Adam
from bert4keras.snippets import sequence_padding, open
from bert4keras.snippets import DataGenerator, AutoRegressiveDecoder
from keras.models import Model

# Basic T5 parameter: maximum length used when tokenizing source text for mT5
max_c_len = 256
# Basic UniLM parameter: total length budget; the UniLM decoder's generate()
# gives the source text `maxlen` minus the decoder's own maxlen
maxlen = 256

# BERT files (config, checkpoint, vocabulary) used to build the UniLM model
UniLM_config_path = '/content/chinese_L-12_H-768_A-12/bert_config.json'
UniLM_checkpoint_path = '/content/chinese_L-12_H-768_A-12/bert_model.ckpt'
dict_path = '/content/chinese_L-12_H-768_A-12/vocab.txt'
# mT5 files: model config, checkpoint, sentencepiece model and keep-tokens list
mt5_config_path = '/content/small/t5_small.json'
mt5_checkpoint_path = '/content/small/model.ckpt-1000000'
mt5_spm_path = '/content/t5_in_bert4keras-main/tokenizer/sentencepiece_cn.model'
mt5_keep_tokens_path = '/content/t5_in_bert4keras-main/tokenizer/sentencepiece_cn_keep_tokens.json'

# Load the mT5 tokenizer
mt5_tokenizer = SpTokenizer(mt5_spm_path, token_start=None, token_end='</s>')
# `open` is the one imported from bert4keras.snippets; use it as a context
# manager so the keep-tokens file is closed as soon as it has been read
# instead of leaking the handle.
with open(mt5_keep_tokens_path) as keep_tokens_file:
    mt5_keep_tokens = json.load(keep_tokens_file)
# Load the UniLM tokenizer
token_dict, UniLM_keep_tokens = load_vocab(
    dict_path=dict_path,
    simplified=True,
    startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]'],
)
UniLM_tokenizer = Tokenizer(token_dict, do_lower_case=True)

class mt5_CrossEntropy(Loss):
    """Cross-entropy loss for the mT5 decoder with masked positions excluded.

    Predictions are scored against the target ids shifted by one position,
    and the per-token losses are averaged over the unmasked positions.
    """
    def compute_loss(self, inputs, mask=None):
        decoder_ids, decoder_probs = inputs
        # The prediction at position t is compared with the token at t + 1.
        targets = decoder_ids[:, 1:]
        predictions = decoder_probs[:, :-1]
        # mask[1] is the mask attached to the second input (the decoder
        # output); it is trimmed to line up with the shifted predictions.
        weights = K.cast(mask[1], K.floatx())[:, :-1]
        per_token = K.sparse_categorical_crossentropy(targets, predictions)
        return K.sum(per_token * weights) / K.sum(weights)
# mT5 model setup
# return_keras_model=False returns the bert4keras model wrapper, from which
# the encoder, decoder and full model are taken below.
t5 = build_transformer_model(
    config_path=mt5_config_path,
    checkpoint_path=mt5_checkpoint_path,
    keep_tokens=mt5_keep_tokens,
    model='t5.1.1',
    return_keras_model=False,
    name='T5',
)
# encoder / decoder are used separately at inference time by mt5_AutoTitle
encoder = t5.encoder
decoder = t5.decoder
mt5_model = t5.model
mt5_model.summary()
# Attach the loss layer to the model's second input and its first output, then
# rebuild the Keras model around that layer's output.
# NOTE(review): the argument 1 presumably selects which of the two inputs the
# loss layer passes through as its output — confirm against bert4keras Loss.
output = mt5_CrossEntropy(1)([mt5_model.inputs[1], mt5_model.outputs[0]])
mt5_model = Model(mt5_model.inputs, output)
mt5_model.compile(optimizer=Adam(2e-4))

class UniLM_CrossEntropy(Loss):
    """Cross-entropy loss for UniLM that ignores the input (source) part.

    The segment ids act as the loss weights, so only the positions they
    mark contribute to the averaged loss.
    """
    def compute_loss(self, inputs, mask=None):
        token_ids, segment_ids, token_probs = inputs
        # The prediction at position t is compared with the token at t + 1.
        targets = token_ids[:, 1:]
        predictions = token_probs[:, :-1]
        # Segment ids happen to mark exactly the part that must be predicted.
        weights = segment_ids[:, 1:]
        per_token = K.sparse_categorical_crossentropy(targets, predictions)
        return K.sum(per_token * weights) / K.sum(weights)

# UniLM model setup
UniLM_model = build_transformer_model(
    UniLM_config_path, # model configuration file
    UniLM_checkpoint_path, # pretrained model weights
    application='unilm', # what the model is used for
    keep_tokens=UniLM_keep_tokens,  # keep only the tokens listed in keep_tokens, trimming the original vocabulary
)
# Attach the loss layer to the model's inputs and outputs, then rebuild the
# Keras model around that layer's output.
# NOTE(review): the argument 2 presumably selects the third input (the model
# prediction) as the layer's output — confirm against bert4keras Loss.
output = UniLM_CrossEntropy(2)(UniLM_model.inputs + UniLM_model.outputs)
UniLM_model = Model(UniLM_model.inputs, output)
UniLM_model.compile(optimizer=Adam(1e-5))
UniLM_model.summary()

# Flask app for the demo
app = Flask(__name__)
# Register the app with flask-ngrok before it is started with app.run()
run_with_ngrok(app)

class mt5_AutoTitle(AutoRegressiveDecoder):
    """Seq2seq decoder that generates text with the mT5 encoder/decoder pair.
    """
    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, states):
        # The only input is the encoder output for the source text.
        encoded_context = inputs[0]
        step_probas = decoder.predict([encoded_context, output_ids])
        # Keep just the distribution at the last decoded position.
        return step_probas[:, -1]

    def generate(self, text, topk=1):
        source_ids, _unused = mt5_tokenizer.encode(text, maxlen=max_c_len)
        source_batch = np.array([source_ids])
        encoded_context = encoder.predict(source_batch)[0]
        # Decode with beam search of width `topk`.
        best_ids = self.beam_search([encoded_context], topk)
        return mt5_tokenizer.decode(list(map(int, best_ids)))
# T5 has a rather puzzling setting: its <bos> token id is 0, i.e. <bos> and
# <pad> are both 0, which is why decoding starts from id 0.
mt5_autotitle = mt5_AutoTitle(start_id=0, end_id=mt5_tokenizer._token_end_id, maxlen=128)

class UniLM_AutoTitle(AutoRegressiveDecoder):
    """Seq2seq decoder that generates text with the UniLM model.
    """
    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, states):
        prompt_tokens, prompt_segments = inputs
        # Append what has been generated so far to the prompt; generated
        # positions get segment id 1.
        full_tokens = np.concatenate([prompt_tokens, output_ids], 1)
        generated_segments = np.ones_like(output_ids)
        full_segments = np.concatenate([prompt_segments, generated_segments], 1)
        last_step_model = self.last_token(UniLM_model)
        return last_step_model.predict([full_tokens, full_segments])

    def generate(self, text, topk=1):
        # The source text gets whatever remains of the total budget after
        # reserving this decoder's own maxlen for the output.
        source_budget = maxlen - self.maxlen
        token_ids, segment_ids = UniLM_tokenizer.encode(text, maxlen=source_budget)
        # Decode with beam search of width `topk`.
        best_ids = self.beam_search([token_ids, segment_ids], topk=topk)
        return UniLM_tokenizer.decode(best_ids)
# UniLM decoder settings: no explicit start id; decoding ends at the
# tokenizer's end-token id.
UniLM_autotitle = UniLM_AutoTitle(start_id=None, end_id=UniLM_tokenizer._token_end_id, maxlen=128)

@app.route('/')
def index():
    """Serve the demo front page rendered from the ``index.html`` template."""
    page = render_template('index.html')
    return page


@app.route('/predict', methods=['POST'])
def predict():
    """Generate a summary for the posted text with the requested model.

    Expects a JSON object in the request body with two string fields:

    * ``input_text`` -- the text to summarise; must not be empty or blank.
    * ``model`` -- ``'unilm'`` (case-insensitive) selects the UniLM decoder;
      any other value selects the mT5 decoder.

    Returns:
        * 200 with ``{'response': {'summary': ..., 'model': ...}}`` on success.
        * 400 with ``{'message': ...}`` when the request is malformed or the
          input is empty (client errors used to be reported as 500).
        * 500 with ``{'message': ...}`` when generation itself fails.
    """
    def error_response(message, status):
        # Same JSON error envelope for every failure path.
        return app.response_class(
            response=json.dumps({'message': message}),
            status=status,
            mimetype='application/json',
        )

    # silent=True yields None instead of raising when the body is missing or
    # is not valid JSON, so a bad request can be answered with a clear 400.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return error_response('Request body must be a JSON object', 400)

    sentence = payload.get('input_text')
    model = payload.get('model')
    if not isinstance(sentence, str) or not isinstance(model, str):
        return error_response("'input_text' and 'model' must be strings", 400)
    # Whitespace-only text carries nothing to summarise; treat it as empty.
    if sentence.strip() == '':
        return error_response('Empty input', 400)

    model_name = model.lower()
    try:
        if model_name == 'unilm':
            output = UniLM_autotitle.generate(sentence)
        else:
            output = mt5_autotitle.generate(sentence)
    except Exception as ex:
        # Generation failures are genuine server-side errors: log and report.
        res = {'message': str(ex)}
        print(res)
        return error_response(res['message'], 500)

    return flask.jsonify({
        'response': {
            'summary': str(output),
            'model': model_name,
        }
    })

if __name__ == '__main__':
    # Load saved weights into both models. The paths are relative, so they
    # only resolve to the Drive mount at /content/gdrive when the current
    # working directory is /content.
    UniLM_model.load_weights('./gdrive/MyDrive/UniLM_Bert4Keras/best_model.weights')
    mt5_model.load_weights('./gdrive/MyDrive/T5_Bert4Keras/best_model.weights')
    # Start the Flask server (blocks until interrupted).
    app.run()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
Encoder-Input-Token (InputLayer [(None, None)]       0                                            
__________________________________________________________________________________________________
Embedding-Token (Embedding)     (None, None, 512)    16690176    Encoder-Input-Token[0][0]        
__________________________________________________________________________________________________
Encoder-Embedding-Dropout (Drop (None, None, 512)    0           Embedding-Token[0][0]            
__________________________________________________________________________________________________
Encoder-Transformer-0-MultiHead (None, None, 512)    512         Encoder-Embedding-Dropout[0][0]  
______________________________________________________________________________________________

 * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)


 * Running on http://2c05544be289.ngrok.io
 * Traffic stats available on http://127.0.0.1:4040


127.0.0.1 - - [14/May/2021 00:46:35] "[37mGET / HTTP/1.1[0m" 200 -
127.0.0.1 - - [14/May/2021 00:46:36] "[37mGET /static/js/jquery-3.4.1.min.js HTTP/1.1[0m" 200 -
127.0.0.1 - - [14/May/2021 00:46:36] "[37mGET /static/js/bootstrap.min.js HTTP/1.1[0m" 200 -
127.0.0.1 - - [14/May/2021 00:46:36] "[37mGET /static/css/app.css HTTP/1.1[0m" 200 -
127.0.0.1 - - [14/May/2021 00:46:36] "[37mGET / HTTP/1.1[0m" 200 -
127.0.0.1 - - [14/May/2021 00:46:36] "[37mGET /static/css/bootstrap.min.css HTTP/1.1[0m" 200 -
127.0.0.1 - - [14/May/2021 00:46:36] "[37mGET /static/js/app.js HTTP/1.1[0m" 200 -
127.0.0.1 - - [14/May/2021 00:46:36] "[37mGET /static/css/app.css HTTP/1.1[0m" 200 -
127.0.0.1 - - [14/May/2021 00:46:36] "[37mGET /static/js/app.js HTTP/1.1[0m" 200 -
127.0.0.1 - - [14/May/2021 00:46:36] "[37mGET /static/js/jquery-3.4.1.min.js HTTP/1.1[0m" 200 -
127.0.0.1 - - [14/May/2021 00:46:36] "[37mGET /static/js/bootstrap.min.js HTTP/1.1[0m" 200 -
127.0.0.1 - - [14/May/2021 00:46:36



127.0.0.1 - - [14/May/2021 00:51:09] "[37mPOST /predict HTTP/1.1[0m" 200 -
127.0.0.1 - - [14/May/2021 00:52:24] "[37mPOST /predict HTTP/1.1[0m" 200 -
127.0.0.1 - - [14/May/2021 00:57:09] "[37mPOST /predict HTTP/1.1[0m" 200 -
127.0.0.1 - - [14/May/2021 01:00:09] "[37mPOST /predict HTTP/1.1[0m" 200 -
127.0.0.1 - - [14/May/2021 01:01:07] "[37mPOST /predict HTTP/1.1[0m" 200 -
127.0.0.1 - - [14/May/2021 01:01:37] "[37mPOST /predict HTTP/1.1[0m" 200 -
127.0.0.1 - - [14/May/2021 01:03:21] "[37mPOST /predict HTTP/1.1[0m" 200 -
127.0.0.1 - - [14/May/2021 01:03:30] "[37mPOST /predict HTTP/1.1[0m" 200 -
127.0.0.1 - - [14/May/2021 01:04:31] "[37mPOST /predict HTTP/1.1[0m" 200 -
127.0.0.1 - - [14/May/2021 01:05:02] "[37mPOST /predict HTTP/1.1[0m" 200 -
127.0.0.1 - - [14/May/2021 01:05:50] "[37mPOST /predict HTTP/1.1[0m" 200 -
127.0.0.1 - - [14/May/2021 01:11:23] "[37mPOST /predict HTTP/1.1[0m" 200 -
127.0.0.1 - - [14/May/2021 01:13:11] "[37mPOST /predict HTTP/1.1[0m" 200 -