In [None]:
# IMPORTANT: install the following before running this notebook.
# !git clone https://github.com/huggingface/transformers
# !cd transformers
# !pip install ./transformers/.
# !gsutil -m cp -R gs://tbrain-tsmc/Model_BIO . # If this cannot be fetched with the command, download it manually from gs://tbrain-tsmc/training_data/Model_BIO

# config.json: https://storage.googleapis.com/tbrain-tsmc/Model_BIO/config.json
# tf_model.h5: https://storage.googleapis.com/tbrain-tsmc/Model_BIO/tf_model.h5

In [None]:
import os, re
import numpy as np
import pandas as pd
import keras
import keras_bert
from keras_bert import load_trained_model_from_checkpoint, Tokenizer
from keras.layers import *
from keras.models import Model
import keras.backend as K
from keras.callbacks import Callback
from keras.optimizers import Adam
from flask import Flask
from flask import request
from flask import jsonify
import datetime
import hashlib
from transformers import *
import tensorflow as tf

def extract_name(text, label):
    """ Collect the names tagged by the model in one tokenized chunk
    @param text: 1-D tensor of token ids for the chunk; must expose ``.numpy()``
    @param label: per-token class scores; ``np.argmax(label, -1)[0]`` must give
        one class id per token (presumably shape (1, seq_len, num_labels) --
        confirm against the model output)
    @returns answers (set): decoded names that are longer than one character
    """
    token_labels = np.argmax(label, -1)[0]
    token_ids = text.numpy()  # convert once instead of once per tagged token
    answers = set()
    answer = ''

    # Class 0 is skipped, class 2 opens a new name, and every other non-zero
    # class extends the name currently being built.
    # NOTE(review): tagged tokens separated by untagged ones are still glued
    # onto the current name (unchanged from the previous logic) -- confirm
    # this is the intended behaviour.
    for i in np.where(token_labels > 0)[0].tolist():
        name = tokenizer.decode([token_ids[i]])

        if token_labels[i] == 2:
            # A new name starts here: store the one collected so far.
            if len(answer) > 1:
                answers.add(answer)
            answer = name
        else:
            answer += name

    # Flush the trailing name. Without this the last (or only) name of every
    # chunk was silently dropped.
    if len(answer) > 1:
        answers.add(answer)

    return answers

# Tokenizer, configuration and fine-tuned weights are loaded once, when this
# module is executed. 'Model_BIO' is passed to from_pretrained as-is (see the
# download notes at the top of the file); num_labels=3 corresponds to the three
# class ids handled in extract_name (0, 2, and the remaining non-zero class).
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
config = BertConfig.from_pretrained('Model_BIO', num_labels=3)
model = TFBertForTokenClassification.from_pretrained('Model_BIO', config=config)

# Flask application object on which the route handlers below are registered.
app = Flask(__name__)
# CAPTAIN_EMAIL is returned by /healthcheck and, together with a timestamp, is
# the input of generate_server_uuid; SALT is appended to that input before
# hashing.
# NOTE(review): both values are hard-coded in source -- consider loading them
# from configuration instead.
####### PUT YOUR INFORMATION HERE #######
CAPTAIN_EMAIL = 'sammo147@gmail.com'    #
SALT = 'my_salt'                        #
#########################################

def generate_server_uuid(input_string):
    """ Derive a server_uuid by hashing the input together with SALT
    @param input_string (str): information to be encoded as server_uuid
    @returns server_uuid (str): SHA-256 hex digest of input_string + SALT
    """
    salted_bytes = (input_string + SALT).encode("utf-8")
    return hashlib.sha256(salted_bytes).hexdigest()

def predict(article):
    """ Predict your model result
    The article is cut into consecutive pieces of at most MAX_LEN characters;
    each piece is tokenized and run through the model separately, and the
    names extracted from all pieces are merged.

    @param article (str): a news article
    @returns prediction (list): a list of name
    """
    MAX_LEN = 256
    answers = set()

    # A stepped range never yields an empty piece, so an empty article or one
    # whose length is an exact multiple of MAX_LEN no longer triggers an extra
    # model pass over a chunk that holds only special tokens.
    # NOTE(review): a name that straddles a piece boundary is split in two.
    for start in range(0, len(article), MAX_LEN):
        inputs = tokenizer(article[start:start + MAX_LEN], return_tensors="tf")
        input_ids = inputs["input_ids"]

        # Dummy all-ones labels (batch size 1). The unpacking below expects the
        # output to start with a loss followed by the scores -- presumably the
        # loss is only present when labels are supplied; confirm against the
        # installed transformers version before removing this.
        inputs["labels"] = tf.reshape(
            tf.constant([1] * tf.size(input_ids).numpy()),
            (-1, tf.size(input_ids)),
        )
        outputs = model(inputs)
        _, scores = outputs[:2]
        answers.update(extract_name(input_ids[0], scores))

    prediction = _check_datatype_to_list(list(answers))

    return prediction

def _check_datatype_to_list(prediction):
    """ Check if your prediction is in list type or not.
        And then convert your prediction to list type or raise error.

    @param prediction (list / numpy array / pandas DataFrame): your prediction
    @returns prediction (list): your prediction in list type
    @raises ValueError: if the prediction cannot be turned into a list
    """
    # Each conversion must *return* the recursive result; previously the value
    # was discarded and every ndarray / DataFrame input ended in ValueError.
    if isinstance(prediction, np.ndarray):
        # tolist() on a 0-d array gives a scalar, which the recursion rejects.
        return _check_datatype_to_list(prediction.tolist())
    if isinstance(prediction, pd.DataFrame):
        # .values is an ndarray, handled by the branch above.
        return _check_datatype_to_list(prediction.values)
    if isinstance(prediction, list):
        return prediction
    raise ValueError('Prediction is not in list type.')

@app.route('/healthcheck', methods=['POST'])
def healthcheck():
    """ API for health check
    Reads 'esun_uuid' from the JSON body and echoes it back together with a
    freshly generated server_uuid, the captain e-mail and the server time.
    """
    data = request.get_json(force=True)
    t = datetime.datetime.now()
    # Epoch seconds of the same instant used for server_timestamp. The former
    # t.utcnow().timestamp() created a second, naive UTC datetime and converted
    # it as if it were local time, which skews the value by the UTC offset.
    ts = str(int(t.timestamp()))
    server_uuid = generate_server_uuid(CAPTAIN_EMAIL + ts)
    server_timestamp = t.strftime("%Y-%m-%d %H:%M:%S")
    return jsonify({
        'esun_uuid': data['esun_uuid'],
        'server_uuid': server_uuid,
        'captain_email': CAPTAIN_EMAIL,
        'server_timestamp': server_timestamp,
    })

@app.route('/inference', methods=['POST'])
def inference():
    """ API that return your model predictions when E.SUN calls this API
    Reads 'esun_timestamp', 'news' and 'esun_uuid' from the JSON body and
    returns the predicted names together with request and server identifiers.

    @raises ValueError: 'Model error.' if the prediction step fails
    """
    data = request.get_json(force=True)
    esun_timestamp = data['esun_timestamp']  # read up front; echoed in the response

    t = datetime.datetime.now()
    # Epoch seconds of `t`. The former t.utcnow().timestamp() converted a naive
    # UTC datetime as if it were local time, skewing the value by the UTC offset.
    ts = str(int(t.timestamp()))
    server_uuid = generate_server_uuid(CAPTAIN_EMAIL + ts)

    try:
        answer = predict(data['news'])
    except Exception as err:
        # Exception (not a bare except) lets SystemExit / KeyboardInterrupt
        # through; chaining keeps the real cause in the traceback.
        raise ValueError('Model error.') from err
    # Taken after prediction, so it reflects when the answer was ready.
    server_timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    return jsonify({
        'esun_timestamp': esun_timestamp,
        'server_uuid': server_uuid,
        'answer': answer,
        'server_timestamp': server_timestamp,
        'esun_uuid': data['esun_uuid'],
    })

# Start the Flask server only when this file is executed directly.
if __name__ == "__main__":    
    # host='0.0.0.0' and port=8030 are passed straight to Flask's app.run.
    app.run(host='0.0.0.0', port=8030)