- Run this notebook in Google Colab on a T4 GPU runtime.

In [None]:
!pip install flask
!pip install flask_cors
!pip install flask_ngrok
!pip install pyngrok
!pip install --upgrade simpletransformers
!pip install torchvision
!pip install -U ckip-transformers
!pip install captum==0.6.0

In [5]:
import os
from flask import Flask, request, jsonify
from pyngrok import ngrok
from flask_cors import CORS  # 載入 CORS
from google.colab import drive
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from google.colab import drive
from captum.attr import IntegratedGradients
import torch.nn.functional as F
from transformers import BertConfig
import gc
import math
import threading  # Import the threading module
import time

In [None]:
# Let PyTorch's CUDA caching allocator use expandable segments to reduce memory
# fragmentation (added as a workaround for running out of memory). PyTorch reads
# this setting when its CUDA allocator initialises, so it is set before any model
# is loaded.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True")

In [None]:
# Authenticate pyngrok. The token is taken from the NGROK_AUTHTOKEN environment
# variable when it is set, so the secret does not have to be hard-coded in the
# notebook; otherwise replace the placeholder string below with your ngrok token.
ngrok.set_auth_token(os.environ.get("NGROK_AUTHTOKEN", "Your Ngork Token"))

In [None]:
# Mount Google Drive; the fine-tuned checkpoint path used later lives under this mount point.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# Default model and explainer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = "/content/drive/MyDrive/Colab Notebooks/complete_model1.pt"

# Build the BERT classifier, then load the fine-tuned weights from model_path.
# The checkpoint is read exactly once (previously it was also loaded into an
# object that was immediately overwritten, doubling peak memory). It is loaded
# onto the CPU so no extra copy of the weights sits in GPU memory before the
# model itself is moved to `device`.
# NOTE: load_state_dict() requires this file to contain a state_dict, not a
# pickled model object.
model = BertForSequenceClassification.from_pretrained("ckiplab/bert-base-chinese", num_labels=2)
_state_dict = torch.load(model_path, map_location="cpu")
model.load_state_dict(_state_dict)
del _state_dict
model.to(device).eval()

# NOTE(review): this replaces the config object built by from_pretrained() above
# with a freshly loaded one; its purpose is unclear — confirm before removing.
model.config = BertConfig.from_pretrained('ckiplab/bert-base-chinese')
tokenizer = BertTokenizer.from_pretrained('ckiplab/bert-base-chinese')


def _positive_class_probability(x, mask):
    """Forward function for Integrated Gradients: softmax probability of class index 1.

    Args:
        x: Embeddings passed to the model as ``inputs_embeds``.
        mask: Attention mask forwarded unchanged to the model.
    """
    return model(inputs_embeds=x, attention_mask=mask).logits.softmax(dim=1)[:, 1]


ig = IntegratedGradients(forward_func=_positive_class_probability)

# Compute an importance (contribution) score for every token
def get_token_importances(text):
    """Return per-token Integrated-Gradients attributions for *text*.

    Attributions are taken with respect to the explainer ``ig``'s forward
    function. They are summed over the embedding dimension, giving one score per
    token. Special tokens such as [CLS] and [SEP] are included.

    Args:
        text: Sentence to explain.

    Returns:
        A list of ``(token, score)`` tuples in input order; ``score`` is a float.
    """
    # Truncate the same way predict() does, so over-long inputs don't exceed the
    # model's position limit during the attribution pass.
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    # Attribute over the *word* embeddings only. The explainer hands these to the
    # model as `inputs_embeds`, and BERT then adds the position / token-type
    # embeddings and applies LayerNorm itself. Feeding it the full
    # `model.bert.embeddings(...)` output (as before) applied that step twice and
    # skewed the attributions. No autograd graph is needed here: Integrated
    # Gradients differentiates its own interpolated copies of these inputs.
    with torch.no_grad():
        embeddings = model.bert.embeddings.word_embeddings(input_ids)
    baseline = torch.zeros_like(embeddings)

    # The convergence delta was never used, so it is no longer requested; this
    # saves the extra forward passes needed to compute it.
    attributions = ig.attribute(
        inputs=embeddings,
        baselines=baseline,
        additional_forward_args=(attention_mask,),
    )

    token_scores = attributions.sum(dim=-1).squeeze(0).detach().cpu().tolist()
    token_list = tokenizer.convert_ids_to_tokens(input_ids[0])

    del inputs, input_ids, attention_mask, embeddings, baseline, attributions
    # Collect garbage first so freed tensors return their blocks to PyTorch's
    # cache, then release that cache back to the driver.
    gc.collect()
    torch.cuda.empty_cache()

    return list(zip(token_list, token_scores))

# Prediction function
def predict(input_sentence):
    """Classify *input_sentence* and attach per-token attribution scores.

    Args:
        input_sentence: Text to classify.

    Returns:
        ``(label, probability, tokens_with_scores)``, where:

        - ``label`` is the argmax class index (int).
        - ``probability`` is that class's softmax probability, floored to two
          decimals.
        - ``tokens_with_scores`` is a list of ``(token, score)`` pairs from
          get_token_importances(). Each score is multiplied by 100 and rounded
          to 4 decimals; pairs with NaN/Inf scores are dropped.
    """
    inputs = tokenizer(input_sentence, return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits

    label = torch.argmax(logits, dim=1)[0].item()  # highest-scoring class of the sample
    probability = F.softmax(logits, dim=1)[0, label].item()

    # Release the classification tensors *before* the memory-hungry attribution
    # pass so they don't add to its peak GPU usage. Collect garbage first, then
    # return the freed cached blocks to the driver.
    del inputs, logits
    gc.collect()
    torch.cuda.empty_cache()

    tokens_with_scores = [
        (tok, round(float(score) * 100, 4))  # scale by 100; round to 4 decimals
        for tok, score in get_token_importances(input_sentence)
        if math.isfinite(score)  # drop NaN and Inf
    ]

    prob_rounded = math.floor(probability * 100) / 100  # floor (truncate) to 2 decimals
    return label, prob_rounded, tokens_with_scores

In [None]:
# Create the Flask app and enable CORS so pages served from other origins can call it.
app = Flask(__name__)
CORS(app)

# Open a public ngrok tunnel to local port 5000 and show where the app can be reached.
public_url = ngrok.connect(5000)
print(f"Flask app is publicly accessible at: {public_url}")

@app.route('/process', methods=['POST'])
def process_text():
    """Classify the posted sentence and return its per-token importance scores.

    Request body: a JSON object with a string field ``text``.

    Response: ``{"prediction": int, "probability": float,
    "tokens": [{"token": str, "score": float}, ...]}``. The [CLS] and [SEP]
    special tokens are left out of ``tokens``.

    A body that is not a JSON object with a string ``text`` gets HTTP 400 instead
    of failing inside the model code with an unhandled error (HTTP 500).
    """
    data = request.get_json(silent=True)
    input_sentence = data.get('text') if isinstance(data, dict) else None
    if not isinstance(input_sentence, str):
        return jsonify({'error': "request body must be JSON with a string field 'text'"}), 400

    prediction, probability, tokens_with_scores = predict(input_sentence)
    special_tokens = {'[CLS]', '[SEP]'}
    tokens_list = [
        {'token': tok, 'score': score}
        for tok, score in tokens_with_scores
        if tok not in special_tokens
    ]

    # torch.cuda.ipc_collect() initialises CUDA and raises on a CPU-only runtime
    # (the model supports `device == cpu`), so only touch CUDA when a GPU exists.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    return jsonify({'prediction': prediction, 'probability': probability, 'tokens': tokens_list})

# ——————— 解決 RAM 空間不夠 ———————
@app.route('/restart', methods=['POST'])
def restart():
    """Hard-kill the current process to reclaim all of its memory.

    Sends SIGKILL to this process, so no HTTP response is ever returned. The
    runtime (and this notebook) must be restarted and re-run afterwards.

    SECURITY NOTE(review): this route is unauthenticated and is exposed publicly
    through the ngrok tunnel. Anyone who learns the URL can kill the runtime.
    Consider requiring a shared secret before keeping it enabled.
    """
    import signal  # local import: only this endpoint needs it

    os.kill(os.getpid(), signal.SIGKILL)

# Periodic GPU memory cleanup
def clean_memory_periodically(interval_seconds=30):
    """Loop forever, freeing unreachable objects and cached CUDA memory.

    Never returns; in this notebook it is started on a daemon thread.

    Args:
        interval_seconds: Pause between cleanup passes, in seconds. The default
            of 30 matches the previously hard-coded interval.
    """
    while True:
        # Collect garbage first: tensors freed by the collector return their
        # blocks to PyTorch's cache, which empty_cache() can then release.
        gc.collect()
        torch.cuda.empty_cache()
        time.sleep(interval_seconds)

# Run the periodic memory cleaner in the background; daemon=True so the thread
# never keeps the process alive on its own.
_memory_cleaner_thread = threading.Thread(target=clean_memory_periodically, daemon=True)
_memory_cleaner_thread.start()
# ———————————————————————————————

if __name__ == "__main__":
    # Serve on the local port that ngrok.connect() tunnels to.
    flask_port = 5000
    app.run(port=flask_port)